Addressed Vinayak's comments on sum and count.

git-svn-id: https://asterixdb.googlecode.com/svn/branches/asterix_fix_agg@614 eaa15691-b419-025a-1212-ee371bd00084
diff --git a/asterix-app/src/test/resources/runtimets/results/aggregate/count_null.adm b/asterix-app/src/test/resources/runtimets/results/aggregate/count_null.adm
index df462fe..51d5f4f 100644
--- a/asterix-app/src/test/resources/runtimets/results/aggregate/count_null.adm
+++ b/asterix-app/src/test/resources/runtimets/results/aggregate/count_null.adm
@@ -1 +1 @@
-{ "count": 2 }
\ No newline at end of file
+{ "count": null }
\ No newline at end of file
diff --git a/asterix-app/src/test/resources/runtimets/results/aggregate/scalar_count_null.adm b/asterix-app/src/test/resources/runtimets/results/aggregate/scalar_count_null.adm
index 4ff1111..1abbc3f 100644
--- a/asterix-app/src/test/resources/runtimets/results/aggregate/scalar_count_null.adm
+++ b/asterix-app/src/test/resources/runtimets/results/aggregate/scalar_count_null.adm
@@ -1,7 +1,7 @@
-4
-4
-4
-4
-4
-4
-4
\ No newline at end of file
+null
+null
+null
+null
+null
+null
+null
\ No newline at end of file
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/AbstractScalarAggregateDescriptor.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/AbstractScalarAggregateDescriptor.java
index fdf9325..3de05f8 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/AbstractScalarAggregateDescriptor.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/AbstractScalarAggregateDescriptor.java
@@ -26,7 +26,7 @@
             @Override
             public ICopyEvaluator createEvaluator(IDataOutputProvider output) throws AlgebricksException {
                 // The aggregate function will get a SingleFieldFrameTupleReference that points to the result of the ScanCollection.
-                // The list-item will always reside in the first filed (column) of the SingleFieldFrameTupleReference.
+                // The list-item will always reside in the first field (column) of the SingleFieldFrameTupleReference.
                 ICopyEvaluatorFactory[] aggFuncArgs = new ICopyEvaluatorFactory[1];
                 aggFuncArgs[0] = new ColumnAccessEvalFactory(0);
                 // Create aggregate function from this scalar version.
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/GenericScalarAggregateFunction.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/GenericScalarAggregateFunction.java
index 0811968..195b391 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/GenericScalarAggregateFunction.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/GenericScalarAggregateFunction.java
@@ -31,6 +31,7 @@
     @Override
     public void evaluate(IFrameTupleReference tuple) throws AlgebricksException {
         scanCollection.init(tuple);
+        aggFunc.init();
         while (scanCollection.step()) {
             itemTuple.reset(listItemOut.getByteArray(), 0, listItemOut.getLength());
             aggFunc.step(itemTuple);
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableCountAggregateDescriptor.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableCountAggregateDescriptor.java
index d8f2553..7f42938 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableCountAggregateDescriptor.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableCountAggregateDescriptor.java
@@ -7,26 +7,30 @@
 import edu.uci.ics.asterix.formats.nontagged.AqlSerializerDeserializerProvider;
 import edu.uci.ics.asterix.om.base.AInt32;
 import edu.uci.ics.asterix.om.base.AMutableInt32;
+import edu.uci.ics.asterix.om.base.ANull;
 import edu.uci.ics.asterix.om.functions.IFunctionDescriptor;
 import edu.uci.ics.asterix.om.functions.IFunctionDescriptorFactory;
+import edu.uci.ics.asterix.om.types.ATypeTag;
 import edu.uci.ics.asterix.om.types.BuiltinType;
+import edu.uci.ics.asterix.om.types.EnumDeserializer;
 import edu.uci.ics.asterix.runtime.aggregates.base.AbstractSerializableAggregateFunctionDynamicDescriptor;
 import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException;
 import edu.uci.ics.hyracks.algebricks.core.algebra.functions.FunctionIdentifier;
+import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluator;
 import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluatorFactory;
 import edu.uci.ics.hyracks.algebricks.runtime.base.ICopySerializableAggregateFunction;
 import edu.uci.ics.hyracks.algebricks.runtime.base.ICopySerializableAggregateFunctionFactory;
 import edu.uci.ics.hyracks.api.dataflow.value.ISerializerDeserializer;
+import edu.uci.ics.hyracks.data.std.util.ArrayBackedValueStorage;
 import edu.uci.ics.hyracks.dataflow.common.data.accessors.IFrameTupleReference;
 
 /**
- * NULLs are also counted.
+ * count(NULL) returns NULL.
  */
 public class SerializableCountAggregateDescriptor extends AbstractSerializableAggregateFunctionDynamicDescriptor {
 
     private static final long serialVersionUID = 1L;
-    public final static FunctionIdentifier FID = new FunctionIdentifier(FunctionConstants.ASTERIX_NS, "count-serial",
-            1);
+    public final static FunctionIdentifier FID = new FunctionIdentifier(FunctionConstants.ASTERIX_NS, "count-serial", 1);
     public static final IFunctionDescriptorFactory FACTORY = new IFunctionDescriptorFactory() {
         public IFunctionDescriptor createFunctionDescriptor() {
             return new SerializableCountAggregateDescriptor();
@@ -39,8 +43,8 @@
     }
 
     @Override
-    public ICopySerializableAggregateFunctionFactory createSerializableAggregateFunctionFactory(final ICopyEvaluatorFactory[] args)
-            throws AlgebricksException {
+    public ICopySerializableAggregateFunctionFactory createSerializableAggregateFunctionFactory(
+            final ICopyEvaluatorFactory[] args) throws AlgebricksException {
         return new ICopySerializableAggregateFunctionFactory() {
             private static final long serialVersionUID = 1L;
 
@@ -52,10 +56,16 @@
                     @SuppressWarnings("unchecked")
                     private ISerializerDeserializer<AInt32> int32Serde = AqlSerializerDeserializerProvider.INSTANCE
                             .getSerializerDeserializer(BuiltinType.AINT32);
+                    @SuppressWarnings("unchecked")
+                    private ISerializerDeserializer<ANull> nullSerde = AqlSerializerDeserializerProvider.INSTANCE
+                            .getSerializerDeserializer(BuiltinType.ANULL);
+                    private ArrayBackedValueStorage inputVal = new ArrayBackedValueStorage();
+                    private ICopyEvaluator eval = args[0].createEvaluator(inputVal);
 
                     @Override
                     public void init(DataOutput state) throws AlgebricksException {
                         try {
+                            state.writeBoolean(false);
                             state.writeInt(0);
                         } catch (IOException e) {
                             throw new AlgebricksException(e);
@@ -65,17 +75,32 @@
                     @Override
                     public void step(IFrameTupleReference tuple, byte[] state, int start, int len)
                             throws AlgebricksException {
-                        int cnt = BufferSerDeUtil.getInt(state, start);
-                        cnt++;
-                        BufferSerDeUtil.writeInt(cnt, state, start);
+                        boolean metNull = BufferSerDeUtil.getBoolean(state, start);
+                        int cnt = BufferSerDeUtil.getInt(state, start + 1);
+                        inputVal.reset();
+                        eval.evaluate(tuple);
+                        ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER
+                                .deserialize(inputVal.getByteArray()[0]);
+                        if (typeTag == ATypeTag.NULL) {
+                            metNull = true;
+                        } else {
+                            cnt++;
+                        }
+                        BufferSerDeUtil.writeBoolean(metNull, state, start);
+                        BufferSerDeUtil.writeInt(cnt, state, start + 1);
                     }
 
                     @Override
                     public void finish(byte[] state, int start, int len, DataOutput out) throws AlgebricksException {
-                        int cnt = BufferSerDeUtil.getInt(state, start);
+                        boolean metNull = BufferSerDeUtil.getBoolean(state, start);
+                        int cnt = BufferSerDeUtil.getInt(state, start + 1);
                         try {
-                            result.setValue(cnt);
-                            int32Serde.serialize(result, out);
+                            if (metNull) {
+                                nullSerde.serialize(ANull.NULL, out);
+                            } else {
+                                result.setValue(cnt);
+                                int32Serde.serialize(result, out);
+                            }
                         } catch (IOException e) {
                             throw new AlgebricksException(e);
                         }
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableSumAggregateFunction.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableSumAggregateFunction.java
index 4cfde34..c244d89 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableSumAggregateFunction.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableSumAggregateFunction.java
@@ -40,23 +40,18 @@
     private AMutableInt8 aInt8 = new AMutableInt8((byte) 0);
     @SuppressWarnings("rawtypes")
     private ISerializerDeserializer serde;
-    private final boolean isLocalAgg = false;
-    
+    private final boolean isLocalAgg;
+
     public SerializableSumAggregateFunction(ICopyEvaluatorFactory[] args, boolean isLocalAgg)
             throws AlgebricksException {
         eval = args[0].createEvaluator(inputVal);
+        this.isLocalAgg = isLocalAgg;
     }
-    
+
     @Override
     public void init(DataOutput state) throws AlgebricksException {
         try {
-            state.writeBoolean(false);
-            state.writeBoolean(false);
-            state.writeBoolean(false);
-            state.writeBoolean(false);
-            state.writeBoolean(false);
-            state.writeBoolean(false);
-            state.writeBoolean(false);
+            state.writeByte(ATypeTag.SYSTEM_NULL.serialize());
             state.writeDouble(0.0);
         } catch (IOException e) {
             throw new AlgebricksException(e);
@@ -64,145 +59,132 @@
     }
 
     @Override
-    public void step(IFrameTupleReference tuple, byte[] state, int start, int len)
-            throws AlgebricksException {
-        int pos = start;
-        boolean metInt8s = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metInt16s = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metInt32s = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metInt64s = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metFloats = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metDoubles = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metNull = BufferSerDeUtil.getBoolean(state, pos++);
-        double sum = BufferSerDeUtil.getDouble(state, pos);
-
+    public void step(IFrameTupleReference tuple, byte[] state, int start, int len) throws AlgebricksException {
+        ATypeTag aggType = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(state[start]);
+        double sum = BufferSerDeUtil.getDouble(state, start + 1);
         inputVal.reset();
         eval.evaluate(tuple);
-        if (inputVal.getLength() > 0) {
-            ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(inputVal
-                    .getByteArray()[0]);
-            switch (typeTag) {
-                case INT8: {
-                    metInt8s = true;
-                    byte val = AInt8SerializerDeserializer.getByte(inputVal.getByteArray(), 1);
-                    sum += val;
-                    break;
+        ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(inputVal.getByteArray()[0]);
+        if (typeTag == ATypeTag.NULL) {
+            aggType = ATypeTag.NULL;
+        }
+        if (aggType == ATypeTag.NULL) {
+            return;
+        } else if (aggType == ATypeTag.SYSTEM_NULL) {
+            aggType = typeTag;
+        } else if (typeTag != ATypeTag.SYSTEM_NULL && typeTag != aggType) {
+            throw new AlgebricksException("Unexpected type " + typeTag
+                    + " in sum-aggregation input stream. Expected type " + aggType + ".");
+        }
+        switch (typeTag) {
+            case INT8: {
+                byte val = AInt8SerializerDeserializer.getByte(inputVal.getByteArray(), 1);
+                sum += val;
+                break;
+            }
+            case INT16: {
+                short val = AInt16SerializerDeserializer.getShort(inputVal.getByteArray(), 1);
+                sum += val;
+                break;
+            }
+            case INT32: {
+                int val = AInt32SerializerDeserializer.getInt(inputVal.getByteArray(), 1);
+                sum += val;
+                break;
+            }
+            case INT64: {
+                long val = AInt64SerializerDeserializer.getLong(inputVal.getByteArray(), 1);
+                sum += val;
+                break;
+            }
+            case FLOAT: {
+                float val = AFloatSerializerDeserializer.getFloat(inputVal.getByteArray(), 1);
+                sum += val;
+                break;
+            }
+            case DOUBLE: {
+                double val = ADoubleSerializerDeserializer.getDouble(inputVal.getByteArray(), 1);
+                sum += val;
+                break;
+            }
+            case NULL: {
+                aggType = typeTag;
+                break;
+            }
+            case SYSTEM_NULL: {
+                // For global aggregates simply ignore system null here,
+                // but if all input value are system null, then we should return
+                // null in finish().
+                if (isLocalAgg) {
+                    throw new AlgebricksException("Type SYSTEM_NULL encountered in local aggregate.");
                 }
-                case INT16: {
-                    metInt16s = true;
-                    short val = AInt16SerializerDeserializer.getShort(inputVal.getByteArray(), 1);
-                    sum += val;
-                    break;
-                }
-                case INT32: {
-                    metInt32s = true;
-                    int val = AInt32SerializerDeserializer.getInt(inputVal.getByteArray(), 1);
-                    sum += val;
-                    break;
-                }
-                case INT64: {
-                    metInt64s = true;
-                    long val = AInt64SerializerDeserializer.getLong(inputVal.getByteArray(), 1);
-                    sum += val;
-                    break;
-                }
-                case FLOAT: {
-                    metFloats = true;
-                    float val = AFloatSerializerDeserializer.getFloat(inputVal.getByteArray(), 1);
-                    sum += val;
-                    break;
-                }
-                case DOUBLE: {
-                    metDoubles = true;
-                    double val = ADoubleSerializerDeserializer.getDouble(inputVal.getByteArray(), 1);
-                    sum += val;
-                    break;
-                }
-                case NULL: {
-                    metNull = true;
-                    break;
-                }
-                case SYSTEM_NULL: {
-                    // For global aggregates simply ignore system null here,
-                    // but if all input value are system null, then we should return
-                    // null in finish().
-                    if (isLocalAgg) {
-                        throw new AlgebricksException("Type SYSTEM_NULL encountered in local aggregate.");
-                    }
-                    break;
-                }
-                default: {
-                    throw new NotImplementedException("Cannot compute SUM for values of type "
-                            + typeTag);
-                }
+                break;
+            }
+            default: {
+                throw new NotImplementedException("Cannot compute SUM for values of type " + typeTag + ".");
             }
         }
-
-        pos = start;
-        BufferSerDeUtil.writeBoolean(metInt8s, state, pos++);
-        BufferSerDeUtil.writeBoolean(metInt16s, state, pos++);
-        BufferSerDeUtil.writeBoolean(metInt32s, state, pos++);
-        BufferSerDeUtil.writeBoolean(metInt64s, state, pos++);
-        BufferSerDeUtil.writeBoolean(metFloats, state, pos++);
-        BufferSerDeUtil.writeBoolean(metDoubles, state, pos++);
-        BufferSerDeUtil.writeBoolean(metNull, state, pos++);
-        BufferSerDeUtil.writeDouble(sum, state, pos);
+        state[start] = aggType.serialize();
+        BufferSerDeUtil.writeDouble(sum, state, start + 1);
     }
 
     @SuppressWarnings("unchecked")
     @Override
     public void finish(byte[] state, int start, int len, DataOutput out) throws AlgebricksException {
-        int pos = start;
-        boolean metInt8s = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metInt16s = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metInt32s = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metInt64s = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metFloats = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metDoubles = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metNull = BufferSerDeUtil.getBoolean(state, pos++);
-        double sum = BufferSerDeUtil.getDouble(state, pos);
+        ATypeTag aggType = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(state[start]);
+        double sum = BufferSerDeUtil.getDouble(state, start + 1);
         try {
-            if (metNull) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE
-                        .getSerializerDeserializer(BuiltinType.ANULL);
-                serde.serialize(ANull.NULL, out);
-            } else if (metDoubles) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE
-                        .getSerializerDeserializer(BuiltinType.ADOUBLE);
-                aDouble.setValue(sum);
-                serde.serialize(aDouble, out);
-            } else if (metFloats) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE
-                        .getSerializerDeserializer(BuiltinType.AFLOAT);
-                aFloat.setValue((float) sum);
-                serde.serialize(aFloat, out);
-            } else if (metInt64s) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE
-                        .getSerializerDeserializer(BuiltinType.AINT64);
-                aInt64.setValue((long) sum);
-                serde.serialize(aInt64, out);
-            } else if (metInt32s) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE
-                        .getSerializerDeserializer(BuiltinType.AINT32);
-                aInt32.setValue((int) sum);
-                serde.serialize(aInt32, out);
-            } else if (metInt16s) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE
-                        .getSerializerDeserializer(BuiltinType.AINT16);
-                aInt16.setValue((short) sum);
-                serde.serialize(aInt16, out);
-            } else if (metInt8s) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE
-                        .getSerializerDeserializer(BuiltinType.AINT8);
-                aInt8.setValue((byte) sum);
-                serde.serialize(aInt8, out);
-            } else {
-                // Empty stream. For local agg return system null. For global agg return null.
-                if (isLocalAgg) {
-                    out.writeByte(ATypeTag.SYSTEM_NULL.serialize());
-                } else {
+            switch (aggType) {
+                case INT8: {
+                    serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT8);
+                    aInt8.setValue((byte) sum);
+                    serde.serialize(aInt8, out);
+                    break;
+                }
+                case INT16: {
+                    serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT16);
+                    aInt16.setValue((short) sum);
+                    serde.serialize(aInt16, out);
+                    break;
+                }
+                case INT32: {
+                    serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT32);
+                    aInt32.setValue((int) sum);
+                    serde.serialize(aInt32, out);
+                    break;
+                }
+                case INT64: {
+                    serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT64);
+                    aInt64.setValue((long) sum);
+                    serde.serialize(aInt64, out);
+                    break;
+                }
+                case FLOAT: {
+                    serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AFLOAT);
+                    aFloat.setValue((float) sum);
+                    serde.serialize(aFloat, out);
+                    break;
+                }
+                case DOUBLE: {
+                    serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.ADOUBLE);
+                    aDouble.setValue(sum);
+                    serde.serialize(aDouble, out);
+                    break;
+                }
+                case NULL: {
                     serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.ANULL);
                     serde.serialize(ANull.NULL, out);
+                    break;
+                }
+                case SYSTEM_NULL: {
+                    // Empty stream. For local agg return system null. For global agg return null.
+                    if (isLocalAgg) {
+                        out.writeByte(ATypeTag.SYSTEM_NULL.serialize());
+                    } else {
+                        serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.ANULL);
+                        serde.serialize(ANull.NULL, out);
+                    }
+                    break;
                 }
             }
         } catch (IOException e) {
@@ -212,8 +194,7 @@
     }
 
     @Override
-    public void finishPartial(byte[] state, int start, int len, DataOutput out)
-            throws AlgebricksException {
+    public void finishPartial(byte[] state, int start, int len, DataOutput out) throws AlgebricksException {
         finish(state, start, len, out);
     }
 }
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AvgAggregateDescriptor.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AvgAggregateDescriptor.java
index bd0603f..bbd0460 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AvgAggregateDescriptor.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AvgAggregateDescriptor.java
@@ -170,10 +170,10 @@
                         if (count == 0) {
                             GlobalConfig.ASTERIX_LOGGER.fine("AVG aggregate ran over empty input.");
                             try {
-								nullSerde.serialize(ANull.NULL, out);
-							} catch (HyracksDataException e) {
-								throw new AlgebricksException(e);
-							}
+                                nullSerde.serialize(ANull.NULL, out);
+                            } catch (HyracksDataException e) {
+                                throw new AlgebricksException(e);
+                            }
                         } else {
                             try {
                                 if (metNull)
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateDescriptor.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateDescriptor.java
index a264c91..2b214d6 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateDescriptor.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateDescriptor.java
@@ -38,7 +38,7 @@
 
             @Override
             public ICopyAggregateFunction createAggregateFunction(IDataOutputProvider provider) throws AlgebricksException {
-                return new CountAggregateFunction(provider);
+                return new CountAggregateFunction(args, provider);
             }
         };
     }
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateFunction.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateFunction.java
index 7dfc601..e4d015b 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateFunction.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateFunction.java
@@ -6,40 +6,69 @@
 import edu.uci.ics.asterix.formats.nontagged.AqlSerializerDeserializerProvider;
 import edu.uci.ics.asterix.om.base.AInt32;
 import edu.uci.ics.asterix.om.base.AMutableInt32;
+import edu.uci.ics.asterix.om.base.ANull;
+import edu.uci.ics.asterix.om.types.ATypeTag;
 import edu.uci.ics.asterix.om.types.BuiltinType;
+import edu.uci.ics.asterix.om.types.EnumDeserializer;
 import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException;
 import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyAggregateFunction;
+import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluator;
+import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluatorFactory;
 import edu.uci.ics.hyracks.api.dataflow.value.ISerializerDeserializer;
 import edu.uci.ics.hyracks.data.std.api.IDataOutputProvider;
+import edu.uci.ics.hyracks.data.std.util.ArrayBackedValueStorage;
 import edu.uci.ics.hyracks.dataflow.common.data.accessors.IFrameTupleReference;
 
+/**
+ * count(NULL) returns NULL.
+ */
 public class CountAggregateFunction implements ICopyAggregateFunction {
     private AMutableInt32 result = new AMutableInt32(-1);
     @SuppressWarnings("unchecked")
     private ISerializerDeserializer<AInt32> int32Serde = AqlSerializerDeserializerProvider.INSTANCE
             .getSerializerDeserializer(BuiltinType.AINT32);
+    @SuppressWarnings("unchecked")
+    private ISerializerDeserializer<ANull> nullSerde = AqlSerializerDeserializerProvider.INSTANCE
+            .getSerializerDeserializer(BuiltinType.ANULL);
+    private ArrayBackedValueStorage inputVal = new ArrayBackedValueStorage();
+    private ICopyEvaluator eval;
+    private boolean metNull;
     private int cnt;
     private DataOutput out;
-    
-    public CountAggregateFunction(IDataOutputProvider output) {
+
+    public CountAggregateFunction(ICopyEvaluatorFactory[] args, IDataOutputProvider output) throws AlgebricksException {
+        eval = args[0].createEvaluator(inputVal);
         out = output.getDataOutput();
     }
-    
+
     @Override
     public void init() {
         cnt = 0;
+        metNull = false;
     }
 
     @Override
     public void step(IFrameTupleReference tuple) throws AlgebricksException {
-        cnt++;
+        inputVal.reset();
+        eval.evaluate(tuple);
+        ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(inputVal.getByteArray()[0]);
+        // Ignore SYSTEM_NULL.
+        if (typeTag == ATypeTag.NULL) {
+            metNull = true;
+        } else {
+            cnt++;
+        }
     }
 
     @Override
     public void finish() throws AlgebricksException {
         try {
-            result.setValue(cnt);
-            int32Serde.serialize(result, out);
+            if (metNull) {
+                nullSerde.serialize(ANull.NULL, out);
+            } else {
+                result.setValue(cnt);
+                int32Serde.serialize(result, out);
+            }
         } catch (IOException e) {
             throw new AlgebricksException(e);
         }
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/SumAggregateFunction.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/SumAggregateFunction.java
index 9ea13ae..4870745 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/SumAggregateFunction.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/SumAggregateFunction.java
@@ -34,8 +34,8 @@
     private DataOutput out;
     private ArrayBackedValueStorage inputVal = new ArrayBackedValueStorage();
     private ICopyEvaluator eval;
-    private boolean metInt8s, metInt16s, metInt32s, metInt64s, metFloats, metDoubles, metNull;
     private double sum;
+    private ATypeTag aggType;
     private AMutableDouble aDouble = new AMutableDouble(0);
     private AMutableFloat aFloat = new AMutableFloat(0);
     private AMutableInt64 aInt64 = new AMutableInt64(0);
@@ -56,13 +56,7 @@
 
     @Override
     public void init() {
-        metInt8s = false;
-        metInt16s = false;
-        metInt32s = false;
-        metInt64s = false;
-        metFloats = false;
-        metDoubles = false;
-        metNull = false;
+        aggType = ATypeTag.SYSTEM_NULL;
         sum = 0.0;
     }
 
@@ -70,61 +64,63 @@
     public void step(IFrameTupleReference tuple) throws AlgebricksException {
         inputVal.reset();
         eval.evaluate(tuple);
-        if (inputVal.getLength() > 0) {
-            ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(inputVal.getByteArray()[0]);
-            switch (typeTag) {
-                case INT8: {
-                    metInt8s = true;
-                    byte val = AInt8SerializerDeserializer.getByte(inputVal.getByteArray(), 1);
-                    sum += val;
-                    break;
+        ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(inputVal.getByteArray()[0]);
+        if (typeTag == ATypeTag.NULL) {
+            aggType = ATypeTag.NULL;
+        }
+        if (aggType == ATypeTag.NULL) {
+            return;
+        } else if (aggType == ATypeTag.SYSTEM_NULL) {
+            aggType = typeTag;
+        } else if (typeTag != ATypeTag.SYSTEM_NULL && typeTag != aggType) {
+            throw new AlgebricksException("Unexpected type " + typeTag
+                    + " in sum-aggregation input stream. Expected type " + aggType + ".");
+        }
+        switch (typeTag) {
+            case INT8: {
+                byte val = AInt8SerializerDeserializer.getByte(inputVal.getByteArray(), 1);
+                sum += val;
+                break;
+            }
+            case INT16: {
+                short val = AInt16SerializerDeserializer.getShort(inputVal.getByteArray(), 1);
+                sum += val;
+                break;
+            }
+            case INT32: {
+                int val = AInt32SerializerDeserializer.getInt(inputVal.getByteArray(), 1);
+                sum += val;
+                break;
+            }
+            case INT64: {
+                long val = AInt64SerializerDeserializer.getLong(inputVal.getByteArray(), 1);
+                sum += val;
+                break;
+            }
+            case FLOAT: {
+                float val = AFloatSerializerDeserializer.getFloat(inputVal.getByteArray(), 1);
+                sum += val;
+                break;
+            }
+            case DOUBLE: {
+                double val = ADoubleSerializerDeserializer.getDouble(inputVal.getByteArray(), 1);
+                sum += val;
+                break;
+            }
+            case NULL: {
+                break;
+            }
+            case SYSTEM_NULL: {
+                // For global aggregates simply ignore system null here,
+                // but if all input value are system null, then we should return
+                // null in finish().
+                if (isLocalAgg) {
+                    throw new AlgebricksException("Type SYSTEM_NULL encountered in local aggregate.");
                 }
-                case INT16: {
-                    metInt16s = true;
-                    short val = AInt16SerializerDeserializer.getShort(inputVal.getByteArray(), 1);
-                    sum += val;
-                    break;
-                }
-                case INT32: {
-                    metInt32s = true;
-                    int val = AInt32SerializerDeserializer.getInt(inputVal.getByteArray(), 1);
-                    sum += val;
-                    break;
-                }
-                case INT64: {
-                    metInt64s = true;
-                    long val = AInt64SerializerDeserializer.getLong(inputVal.getByteArray(), 1);
-                    sum += val;
-                    break;
-                }
-                case FLOAT: {
-                    metFloats = true;
-                    float val = AFloatSerializerDeserializer.getFloat(inputVal.getByteArray(), 1);
-                    sum += val;
-                    break;
-                }
-                case DOUBLE: {
-                    metDoubles = true;
-                    double val = ADoubleSerializerDeserializer.getDouble(inputVal.getByteArray(), 1);
-                    sum += val;
-                    break;
-                }
-                case NULL: {
-                    metNull = true;
-                    break;
-                }
-                case SYSTEM_NULL: {
-                    // For global aggregates simply ignore system null here,
-                    // but if all input value are system null, then we should return
-                    // null in finish().
-                    if (isLocalAgg) {
-                        throw new AlgebricksException("Type SYSTEM_NULL encountered in local aggregate.");
-                    }
-                    break;
-                }
-                default: {
-                    throw new NotImplementedException("Cannot compute SUM for values of type " + typeTag);
-                }
+                break;
+            }
+            default: {
+                throw new NotImplementedException("Cannot compute SUM for values of type " + typeTag + ".");
             }
         }
     }
@@ -133,40 +129,57 @@
     @Override
     public void finish() throws AlgebricksException {
         try {
-            if (metNull) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.ANULL);
-                serde.serialize(ANull.NULL, out);
-            } else if (metDoubles) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.ADOUBLE);
-                aDouble.setValue(sum);
-                serde.serialize(aDouble, out);
-            } else if (metFloats) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AFLOAT);
-                aFloat.setValue((float) sum);
-                serde.serialize(aFloat, out);
-            } else if (metInt64s) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT64);
-                aInt64.setValue((long) sum);
-                serde.serialize(aInt64, out);
-            } else if (metInt32s) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT32);
-                aInt32.setValue((int) sum);
-                serde.serialize(aInt32, out);
-            } else if (metInt16s) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT16);
-                aInt16.setValue((short) sum);
-                serde.serialize(aInt16, out);
-            } else if (metInt8s) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT8);
-                aInt8.setValue((byte) sum);
-                serde.serialize(aInt8, out);
-            } else {
-                // Empty stream. For local agg return system null. For global agg return null.
-                if (isLocalAgg) {
-                    out.writeByte(ATypeTag.SYSTEM_NULL.serialize());
-                } else {
+            switch (aggType) {
+                case INT8: {
+                    serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT8);
+                    aInt8.setValue((byte) sum);
+                    serde.serialize(aInt8, out);
+                    break;
+                }
+                case INT16: {
+                    serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT16);
+                    aInt16.setValue((short) sum);
+                    serde.serialize(aInt16, out);
+                    break;
+                }
+                case INT32: {
+                    serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT32);
+                    aInt32.setValue((int) sum);
+                    serde.serialize(aInt32, out);
+                    break;
+                }
+                case INT64: {
+                    serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT64);
+                    aInt64.setValue((long) sum);
+                    serde.serialize(aInt64, out);
+                    break;
+                }
+                case FLOAT: {
+                    serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AFLOAT);
+                    aFloat.setValue((float) sum);
+                    serde.serialize(aFloat, out);
+                    break;
+                }
+                case DOUBLE: {
+                    serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.ADOUBLE);
+                    aDouble.setValue(sum);
+                    serde.serialize(aDouble, out);
+                    break;
+                }
+                case NULL: {
                     serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.ANULL);
                     serde.serialize(ANull.NULL, out);
+                    break;
+                }
+                case SYSTEM_NULL: {
+                    // Empty stream. For local agg return system null. For global agg return null.
+                    if (isLocalAgg) {
+                        out.writeByte(ATypeTag.SYSTEM_NULL.serialize());
+                    } else {
+                        serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.ANULL);
+                        serde.serialize(ANull.NULL, out);
+                    }
+                    break;
                 }
             }
         } catch (IOException e) {