Added generic scalar aggregate function that wraps around existing non-scalar aggregate functions (this way we should be able to easily add scalar versions of all agg funcs). Implemented scalar count using it.

git-svn-id: https://asterixdb.googlecode.com/svn/branches/asterix_fix_agg@577 eaa15691-b419-025a-1212-ee371bd00084
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/base/SingleFieldFrameTupleReference.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/base/SingleFieldFrameTupleReference.java
new file mode 100644
index 0000000..a42a3e9
--- /dev/null
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/base/SingleFieldFrameTupleReference.java
@@ -0,0 +1,47 @@
+package edu.uci.ics.asterix.runtime.aggregates.base;
+
+import edu.uci.ics.hyracks.api.comm.IFrameTupleAccessor;
+import edu.uci.ics.hyracks.dataflow.common.data.accessors.IFrameTupleReference;
+
+public class SingleFieldFrameTupleReference implements IFrameTupleReference {
+
+    private byte[] fildData;
+    private int start;
+    private int length;
+    
+    public void reset(byte[] fildData, int start, int length) {
+        this.fildData = fildData;
+        this.start = start;
+        this.length = length;
+    }
+    
+    @Override
+    public int getFieldCount() {
+        return 1;
+    }
+
+    @Override
+    public byte[] getFieldData(int fIdx) {
+        return fildData;
+    }
+
+    @Override
+    public int getFieldStart(int fIdx) {
+       return start;
+    }
+
+    @Override
+    public int getFieldLength(int fIdx) {
+       return length;
+    }
+
+    @Override
+    public IFrameTupleAccessor getFrameTupleAccessor() {
+        return null;
+    }
+
+    @Override
+    public int getTupleIndex() {
+        return 0;
+    }
+}
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/AbstractScalarAggregateDescriptor.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/AbstractScalarAggregateDescriptor.java
new file mode 100644
index 0000000..17f01b9
--- /dev/null
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/AbstractScalarAggregateDescriptor.java
@@ -0,0 +1,40 @@
+package edu.uci.ics.asterix.runtime.aggregates.scalar;
+
+import edu.uci.ics.asterix.om.functions.AsterixBuiltinFunctions;
+import edu.uci.ics.asterix.om.functions.FunctionManagerHolder;
+import edu.uci.ics.asterix.om.functions.IFunctionDescriptor;
+import edu.uci.ics.asterix.om.functions.IFunctionManager;
+import edu.uci.ics.asterix.runtime.aggregates.base.AbstractAggregateFunctionDynamicDescriptor;
+import edu.uci.ics.asterix.runtime.evaluators.base.AbstractScalarFunctionDynamicDescriptor;
+import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException;
+import edu.uci.ics.hyracks.algebricks.core.algebra.functions.FunctionIdentifier;
+import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyAggregateFunctionFactory;
+import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluator;
+import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluatorFactory;
+import edu.uci.ics.hyracks.algebricks.runtime.evaluators.ColumnAccessEvalFactory;
+import edu.uci.ics.hyracks.data.std.api.IDataOutputProvider;
+
+public abstract class AbstractScalarAggregateDescriptor extends AbstractScalarFunctionDynamicDescriptor {
+    private static final long serialVersionUID = 1L;
+
+    @Override
+    public ICopyEvaluatorFactory createEvaluatorFactory(final ICopyEvaluatorFactory[] args) throws AlgebricksException {
+        return new ICopyEvaluatorFactory() {
+            private static final long serialVersionUID = 1L;
+
+            @Override
+            public ICopyEvaluator createEvaluator(IDataOutputProvider output) throws AlgebricksException {
+                // Always access column 0.
+                ICopyEvaluatorFactory[] aggFuncArgs = new ICopyEvaluatorFactory[1];
+                aggFuncArgs[0] = new ColumnAccessEvalFactory(0);
+                // Create aggregate function from this scalar version.
+                FunctionIdentifier fid = AsterixBuiltinFunctions.getAggregateFunction(getIdentifier());
+                IFunctionManager mgr = FunctionManagerHolder.getFunctionManager();
+                IFunctionDescriptor fd = mgr.lookupFunction(fid);
+                AbstractAggregateFunctionDynamicDescriptor aggFuncDesc = (AbstractAggregateFunctionDynamicDescriptor) fd;
+                ICopyAggregateFunctionFactory aggFuncFactory = aggFuncDesc.createAggregateFunctionFactory(aggFuncArgs);
+                return new GenericScalarAggregateFunction(args, aggFuncFactory.createAggregateFunction(output));
+            }
+        };
+    }
+}
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/GenericScalarAggregateFunction.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/GenericScalarAggregateFunction.java
new file mode 100644
index 0000000..90e62cf
--- /dev/null
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/GenericScalarAggregateFunction.java
@@ -0,0 +1,65 @@
+package edu.uci.ics.asterix.runtime.aggregates.scalar;
+
+import edu.uci.ics.asterix.common.exceptions.AsterixException;
+import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AOrderedListSerializerDeserializer;
+import edu.uci.ics.asterix.om.types.ATypeTag;
+import edu.uci.ics.asterix.om.types.EnumDeserializer;
+import edu.uci.ics.asterix.om.util.NonTaggedFormatUtil;
+import edu.uci.ics.asterix.runtime.aggregates.base.SingleFieldFrameTupleReference;
+import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException;
+import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyAggregateFunction;
+import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluator;
+import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluatorFactory;
+import edu.uci.ics.hyracks.data.std.util.ArrayBackedValueStorage;
+import edu.uci.ics.hyracks.dataflow.common.data.accessors.IFrameTupleReference;
+
+/**
+ * Implements scalar aggregates via the corresponding
+ * ICopyAggregateFunction which should be passed inside this function.
+ */
+public class GenericScalarAggregateFunction implements ICopyEvaluator {
+
+    private final ArrayBackedValueStorage argOut = new ArrayBackedValueStorage();
+    private final ICopyEvaluator argEval;
+    private final ICopyAggregateFunction aggFunc;
+
+    private final SingleFieldFrameTupleReference itemTuple = new SingleFieldFrameTupleReference();
+    
+    public GenericScalarAggregateFunction(ICopyEvaluatorFactory[] args, ICopyAggregateFunction aggFunc) throws AlgebricksException {
+        this.argEval = args[0].createEvaluator(argOut);
+        this.aggFunc = aggFunc;
+    }
+
+    @Override
+    public void evaluate(IFrameTupleReference tuple) throws AlgebricksException {
+        argOut.reset();
+        argEval.evaluate(tuple);
+        byte[] listBytes = argOut.getByteArray();
+        ATypeTag argType = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(listBytes[0]);        
+        switch (argType) {
+            case UNORDEREDLIST: {
+                break;
+            }
+            case ORDEREDLIST: {
+                try {
+                    aggFunc.init();
+                    ATypeTag itemTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(listBytes[1]);
+                    int numItems = AOrderedListSerializerDeserializer.getNumberOfItems(listBytes, 0);
+                    for (int i = 0; i < numItems; i++) {
+                        int itemOffset = AOrderedListSerializerDeserializer.getItemOffset(listBytes, i);
+                        int itemLength = NonTaggedFormatUtil.getFieldValueLength(listBytes, itemOffset, itemTag, false);
+                        itemTuple.reset(listBytes, itemOffset, itemLength);
+                        aggFunc.step(itemTuple);
+                    }
+                    aggFunc.finish();
+                } catch (AsterixException e) {
+                    throw new AlgebricksException(e);
+                }
+                break;
+            }
+            default: {
+                throw new AlgebricksException("Unsupported type '" + argType + "' as input to scalar aggregate function.");
+            }
+        }        
+    }
+}
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/ScalarCountAggregateDescriptor.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/ScalarCountAggregateDescriptor.java
new file mode 100644
index 0000000..6f3baa6
--- /dev/null
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/ScalarCountAggregateDescriptor.java
@@ -0,0 +1,22 @@
+package edu.uci.ics.asterix.runtime.aggregates.scalar;
+
+import edu.uci.ics.asterix.common.functions.FunctionConstants;
+import edu.uci.ics.asterix.om.functions.IFunctionDescriptor;
+import edu.uci.ics.asterix.om.functions.IFunctionDescriptorFactory;
+import edu.uci.ics.hyracks.algebricks.core.algebra.functions.FunctionIdentifier;
+
+public class ScalarCountAggregateDescriptor extends AbstractScalarAggregateDescriptor {
+
+    private static final long serialVersionUID = 1L;
+    public final static FunctionIdentifier FID = new FunctionIdentifier(FunctionConstants.ASTERIX_NS, "count", 1);
+    public static final IFunctionDescriptorFactory FACTORY = new IFunctionDescriptorFactory() {
+        public IFunctionDescriptor createFunctionDescriptor() {
+            return new ScalarCountAggregateDescriptor();
+        }
+    };
+
+    @Override
+    public FunctionIdentifier getIdentifier() {
+        return FID;
+    }
+}
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateDescriptor.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateDescriptor.java
index 615f543..a264c91 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateDescriptor.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateDescriptor.java
@@ -1,24 +1,15 @@
 package edu.uci.ics.asterix.runtime.aggregates.std;
 
-import java.io.DataOutput;
-import java.io.IOException;
-
 import edu.uci.ics.asterix.common.functions.FunctionConstants;
-import edu.uci.ics.asterix.formats.nontagged.AqlSerializerDeserializerProvider;
-import edu.uci.ics.asterix.om.base.AInt32;
-import edu.uci.ics.asterix.om.base.AMutableInt32;
 import edu.uci.ics.asterix.om.functions.IFunctionDescriptor;
 import edu.uci.ics.asterix.om.functions.IFunctionDescriptorFactory;
-import edu.uci.ics.asterix.om.types.BuiltinType;
 import edu.uci.ics.asterix.runtime.aggregates.base.AbstractAggregateFunctionDynamicDescriptor;
 import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException;
 import edu.uci.ics.hyracks.algebricks.core.algebra.functions.FunctionIdentifier;
 import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyAggregateFunction;
 import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyAggregateFunctionFactory;
 import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluatorFactory;
-import edu.uci.ics.hyracks.api.dataflow.value.ISerializerDeserializer;
 import edu.uci.ics.hyracks.data.std.api.IDataOutputProvider;
-import edu.uci.ics.hyracks.dataflow.common.data.accessors.IFrameTupleReference;
 
 /**
  * NULLs are also counted.
@@ -47,43 +38,8 @@
 
             @Override
             public ICopyAggregateFunction createAggregateFunction(IDataOutputProvider provider) throws AlgebricksException {
-
-                final DataOutput out = provider.getDataOutput();
-
-                return new ICopyAggregateFunction() {
-                    private AMutableInt32 result = new AMutableInt32(-1);
-                    @SuppressWarnings("unchecked")
-                    private ISerializerDeserializer<AInt32> int32Serde = AqlSerializerDeserializerProvider.INSTANCE
-                            .getSerializerDeserializer(BuiltinType.AINT32);
-                    private int cnt;
-
-                    @Override
-                    public void init() {
-                        cnt = 0;
-                    }
-
-                    @Override
-                    public void step(IFrameTupleReference tuple) throws AlgebricksException {
-                        cnt++;
-                    }
-
-                    @Override
-                    public void finish() throws AlgebricksException {
-                        try {
-                            result.setValue(cnt);
-                            int32Serde.serialize(result, out);
-                        } catch (IOException e) {
-                            throw new AlgebricksException(e);
-                        }
-                    }
-
-                    @Override
-                    public void finishPartial() throws AlgebricksException {
-                        finish();
-                    }
-                };
+                return new CountAggregateFunction(provider);
             }
         };
     }
-
 }
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateFunction.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateFunction.java
new file mode 100644
index 0000000..7dfc601
--- /dev/null
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateFunction.java
@@ -0,0 +1,52 @@
+package edu.uci.ics.asterix.runtime.aggregates.std;
+
+import java.io.DataOutput;
+import java.io.IOException;
+
+import edu.uci.ics.asterix.formats.nontagged.AqlSerializerDeserializerProvider;
+import edu.uci.ics.asterix.om.base.AInt32;
+import edu.uci.ics.asterix.om.base.AMutableInt32;
+import edu.uci.ics.asterix.om.types.BuiltinType;
+import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException;
+import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyAggregateFunction;
+import edu.uci.ics.hyracks.api.dataflow.value.ISerializerDeserializer;
+import edu.uci.ics.hyracks.data.std.api.IDataOutputProvider;
+import edu.uci.ics.hyracks.dataflow.common.data.accessors.IFrameTupleReference;
+
+public class CountAggregateFunction implements ICopyAggregateFunction {
+    private AMutableInt32 result = new AMutableInt32(-1);
+    @SuppressWarnings("unchecked")
+    private ISerializerDeserializer<AInt32> int32Serde = AqlSerializerDeserializerProvider.INSTANCE
+            .getSerializerDeserializer(BuiltinType.AINT32);
+    private int cnt;
+    private DataOutput out;
+    
+    public CountAggregateFunction(IDataOutputProvider output) {
+        out = output.getDataOutput();
+    }
+    
+    @Override
+    public void init() {
+        cnt = 0;
+    }
+
+    @Override
+    public void step(IFrameTupleReference tuple) throws AlgebricksException {
+        cnt++;
+    }
+
+    @Override
+    public void finish() throws AlgebricksException {
+        try {
+            result.setValue(cnt);
+            int32Serde.serialize(result, out);
+        } catch (IOException e) {
+            throw new AlgebricksException(e);
+        }
+    }
+
+    @Override
+    public void finishPartial() throws AlgebricksException {
+        finish();
+    }
+}
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/formats/NonTaggedDataFormat.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/formats/NonTaggedDataFormat.java
index a3b38a9..60e29f9 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/formats/NonTaggedDataFormat.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/formats/NonTaggedDataFormat.java
@@ -41,6 +41,7 @@
 import edu.uci.ics.asterix.om.types.AUnorderedListType;
 import edu.uci.ics.asterix.om.types.IAType;
 import edu.uci.ics.asterix.runtime.aggregates.collections.ListifyAggregateDescriptor;
+import edu.uci.ics.asterix.runtime.aggregates.scalar.ScalarCountAggregateDescriptor;
 import edu.uci.ics.asterix.runtime.aggregates.serializable.std.SerializableAvgAggregateDescriptor;
 import edu.uci.ics.asterix.runtime.aggregates.serializable.std.SerializableCountAggregateDescriptor;
 import edu.uci.ics.asterix.runtime.aggregates.serializable.std.SerializableGlobalAvgAggregateDescriptor;
@@ -320,6 +321,9 @@
         temp.add(SerializableSumAggregateDescriptor.FACTORY);
         temp.add(SerializableLocalSumAggregateDescriptor.FACTORY);
 
+        // scalar aggregates
+        temp.add(ScalarCountAggregateDescriptor.FACTORY);
+        
         // new functions - constructors
         temp.add(ABooleanConstructorDescriptor.FACTORY);
         temp.add(ANullConstructorDescriptor.FACTORY);
@@ -528,7 +532,7 @@
         typeInference(expr, fd, context);
         return fd;
     }
-
+    
     private void typeInference(ILogicalExpression expr, IFunctionDescriptor fd, IVariableTypeEnvironment context)
             throws AlgebricksException {
         if (fd.getIdentifier().equals(AsterixBuiltinFunctions.LISTIFY)) {