Added initial SumAccumulator implementation to clean up aggregates.

git-svn-id: https://asterixdb.googlecode.com/svn/branches/asterix_fix_agg@566 eaa15691-b419-025a-1212-ee371bd00084
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/base/IAccumulator.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/base/IAccumulator.java
new file mode 100644
index 0000000..4f4d32c
--- /dev/null
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/base/IAccumulator.java
@@ -0,0 +1,16 @@
+package edu.uci.ics.asterix.runtime.aggregates.base;
+
+import java.io.DataOutput;
+import java.io.IOException;
+
+import edu.uci.ics.hyracks.data.std.api.IMutableValueStorage;
+import edu.uci.ics.hyracks.data.std.api.IValueReference;
+
+public interface IAccumulator {
+    public void init(IMutableValueStorage state, IValueReference defaultValue) throws IOException;
+
+    public void step(IMutableValueStorage state, IValueReference value) throws IOException;
+
+    // TODO: Second param was initially an IPointable.
+    public void finish(IMutableValueStorage state, DataOutput out) throws IOException;
+}
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/base/SumAccumulator.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/base/SumAccumulator.java
new file mode 100644
index 0000000..e33ead9
--- /dev/null
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/base/SumAccumulator.java
@@ -0,0 +1,164 @@
+package edu.uci.ics.asterix.runtime.aggregates.base;
+
+import java.io.DataOutput;
+import java.io.IOException;
+
+import edu.uci.ics.asterix.dataflow.data.nontagged.serde.ADoubleSerializerDeserializer;
+import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AFloatSerializerDeserializer;
+import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt16SerializerDeserializer;
+import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt32SerializerDeserializer;
+import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt64SerializerDeserializer;
+import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt8SerializerDeserializer;
+import edu.uci.ics.asterix.formats.nontagged.AqlSerializerDeserializerProvider;
+import edu.uci.ics.asterix.om.base.AMutableDouble;
+import edu.uci.ics.asterix.om.base.AMutableFloat;
+import edu.uci.ics.asterix.om.base.AMutableInt16;
+import edu.uci.ics.asterix.om.base.AMutableInt32;
+import edu.uci.ics.asterix.om.base.AMutableInt64;
+import edu.uci.ics.asterix.om.base.AMutableInt8;
+import edu.uci.ics.asterix.om.base.ANull;
+import edu.uci.ics.asterix.om.types.ATypeTag;
+import edu.uci.ics.asterix.om.types.BuiltinType;
+import edu.uci.ics.asterix.om.types.EnumDeserializer;
+import edu.uci.ics.asterix.runtime.aggregates.serializable.std.BufferSerDeUtil;
+import edu.uci.ics.hyracks.algebricks.common.exceptions.NotImplementedException;
+import edu.uci.ics.hyracks.api.dataflow.value.ISerializerDeserializer;
+import edu.uci.ics.hyracks.data.std.api.IMutableValueStorage;
+import edu.uci.ics.hyracks.data.std.api.IValueReference;
+
+public class SumAccumulator implements IAccumulator {
+    private static final int SUM_OFF = 0;
+    // TODO: Let's encode this in a single byte.
+    private static final int MET_INT8_OFF = 8;
+    private static final int MET_INT16_OFF = 9;
+    private static final int MET_INT32_OFF = 10;
+    private static final int MET_INT64_OFF = 11;
+    private static final int MET_FLOAT_OFF = 12;
+    private static final int MET_DOUBLE_OFF = 13;
+    private static final int MET_NULL_OFF = 14;
+        
+    private AMutableInt8 aInt8 = new AMutableInt8((byte) 0);
+    private AMutableInt16 aInt16 = new AMutableInt16((short) 0);
+    private AMutableInt32 aInt32 = new AMutableInt32(0);
+    private AMutableInt64 aInt64 = new AMutableInt64(0);
+    private AMutableFloat aFloat = new AMutableFloat(0);
+    private AMutableDouble aDouble = new AMutableDouble(0);
+    @SuppressWarnings("rawtypes")
+    private ISerializerDeserializer serde;
+    
+    private IValueReference defaultValue;
+    
+    @Override
+    public void init(IMutableValueStorage state, IValueReference defaultValue) throws IOException {
+        // Set initial value.
+        state.getDataOutput().writeDouble(0);
+        // Initialize met flags to false.
+        state.getDataOutput().write((byte) 0);
+        state.getDataOutput().write((byte) 0);
+        state.getDataOutput().write((byte) 0);
+        state.getDataOutput().write((byte) 0);
+        state.getDataOutput().write((byte) 0);
+        state.getDataOutput().write((byte) 0);
+        state.getDataOutput().write((byte) 0);
+        // Remember default value.
+        this.defaultValue = defaultValue;
+    }
+
+    @Override
+    public void step(IMutableValueStorage state, IValueReference value) {
+        byte[] valueBytes = value.getByteArray();
+        int stateStartOff = state.getStartOffset();
+        ATypeTag valueTypeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(valueBytes[0]);
+        double sum = BufferSerDeUtil.getDouble(state.getByteArray(), stateStartOff + SUM_OFF);
+        switch (valueTypeTag) {
+            case INT8: {
+                state.getByteArray()[stateStartOff + MET_INT8_OFF] = 1;
+                sum += AInt8SerializerDeserializer.getByte(valueBytes, 1);
+                break;
+            }
+            case INT16: {
+                state.getByteArray()[stateStartOff + MET_INT16_OFF] = 1;
+                sum += AInt16SerializerDeserializer.getShort(valueBytes, 1);
+                break;
+            }
+            case INT32: {
+                state.getByteArray()[stateStartOff + MET_INT32_OFF] = 1;
+                sum += AInt32SerializerDeserializer.getInt(valueBytes, 1);
+                break;
+            }
+            case INT64: {
+                state.getByteArray()[stateStartOff + MET_INT64_OFF] = 1;
+                sum += AInt64SerializerDeserializer.getLong(valueBytes, 1);
+                break;
+            }
+            case FLOAT: {
+                state.getByteArray()[stateStartOff + MET_FLOAT_OFF] = 1;
+                sum += AFloatSerializerDeserializer.getFloat(valueBytes, 1);
+                break;
+            }
+            case DOUBLE: {
+                state.getByteArray()[stateStartOff + MET_DOUBLE_OFF] = 1;
+                sum += ADoubleSerializerDeserializer.getDouble(valueBytes, 1);
+                break;
+            }
+            case NULL: {
+                state.getByteArray()[stateStartOff + MET_NULL_OFF] = 1;
+                break;
+            }
+            case SYSTEM_NULL: {
+                // Ignore.
+                break;
+            }
+            default: {
+                throw new NotImplementedException("Cannot compute SUM for values of type "
+                        + valueTypeTag);
+            }
+        }
+        BufferSerDeUtil.writeDouble(sum, state.getByteArray(), stateStartOff + SUM_OFF);
+    }
+
+    @SuppressWarnings("unchecked")
+    @Override
+    public void finish(IMutableValueStorage state, DataOutput out) throws IOException {
+        byte[] stateBytes = state.getByteArray();
+        int stateStartOff = state.getStartOffset();        
+        double sum = BufferSerDeUtil.getDouble(stateBytes, stateStartOff + SUM_OFF);
+        if (stateBytes[stateStartOff + MET_NULL_OFF] == 1) {
+            serde = AqlSerializerDeserializerProvider.INSTANCE
+                    .getSerializerDeserializer(BuiltinType.ANULL);
+            serde.serialize(ANull.NULL, out);
+        } else if (stateBytes[stateStartOff + MET_DOUBLE_OFF] == 1) {
+            serde = AqlSerializerDeserializerProvider.INSTANCE
+                    .getSerializerDeserializer(BuiltinType.ADOUBLE);
+            aDouble.setValue(sum);
+            serde.serialize(aDouble, out);
+        } else if (stateBytes[stateStartOff + MET_FLOAT_OFF] == 1) {
+            serde = AqlSerializerDeserializerProvider.INSTANCE
+                    .getSerializerDeserializer(BuiltinType.AFLOAT);
+            aFloat.setValue((float) sum);
+            serde.serialize(aFloat, out);
+        } else if (stateBytes[stateStartOff + MET_INT64_OFF] == 1) {
+            serde = AqlSerializerDeserializerProvider.INSTANCE
+                    .getSerializerDeserializer(BuiltinType.AINT64);
+            aInt64.setValue((long) sum);
+            serde.serialize(aInt64, out);
+        } else if (stateBytes[stateStartOff + MET_INT32_OFF] == 1) {
+            serde = AqlSerializerDeserializerProvider.INSTANCE
+                    .getSerializerDeserializer(BuiltinType.AINT32);
+            aInt32.setValue((int) sum);
+            serde.serialize(aInt32, out);
+        } else if (stateBytes[stateStartOff + MET_INT16_OFF] == 1) {
+            serde = AqlSerializerDeserializerProvider.INSTANCE
+                    .getSerializerDeserializer(BuiltinType.AINT16);
+            aInt16.setValue((short) sum);
+            serde.serialize(aInt16, out);
+        } else if (stateBytes[stateStartOff + MET_INT8_OFF] == 1) {
+            serde = AqlSerializerDeserializerProvider.INSTANCE
+                    .getSerializerDeserializer(BuiltinType.AINT8);
+            aInt8.setValue((byte) sum);
+            serde.serialize(aInt8, out);
+        } else {
+            out.write(defaultValue.getByteArray(), defaultValue.getStartOffset(), defaultValue.getLength());
+        }
+    }
+}
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/base/WrappingMutableValueStorage.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/base/WrappingMutableValueStorage.java
new file mode 100644
index 0000000..cbfa41a
--- /dev/null
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/base/WrappingMutableValueStorage.java
@@ -0,0 +1,51 @@
+package edu.uci.ics.asterix.runtime.aggregates.base;
+
+import java.io.DataOutput;
+
+import edu.uci.ics.hyracks.data.std.api.IMutableValueStorage;
+
+public class WrappingMutableValueStorage implements IMutableValueStorage {
+
+    private byte[] bytes;
+    private int start;
+    private int length;
+    private DataOutput dataOutput;
+    
+    @Override
+    public byte[] getByteArray() {
+        return bytes;
+    }
+
+    @Override
+    public int getStartOffset() {
+        return start;
+    }
+
+    @Override
+    public int getLength() {
+       return length;
+    }
+
+    @Override
+    public DataOutput getDataOutput() {
+       return dataOutput;
+    }
+
+    @Override
+    public void reset() {
+        dataOutput = null;
+        bytes = null;
+        start = -1;
+        length = -1;
+    }
+    
+    public void wrap(DataOutput dataOutput) {
+        this.dataOutput = dataOutput;
+    }
+    
+    public void wrap(byte[] bytes, int start, int length) {
+        this.bytes = bytes;
+        this.start = start;
+        this.length = length;
+    }
+}
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableSumAggregateFunction.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableSumAggregateFunction.java
index 4cfde34..eed89fc 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableSumAggregateFunction.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableSumAggregateFunction.java
@@ -3,212 +3,68 @@
 import java.io.DataOutput;
 import java.io.IOException;
 
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.ADoubleSerializerDeserializer;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AFloatSerializerDeserializer;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt16SerializerDeserializer;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt32SerializerDeserializer;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt64SerializerDeserializer;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt8SerializerDeserializer;
-import edu.uci.ics.asterix.formats.nontagged.AqlSerializerDeserializerProvider;
-import edu.uci.ics.asterix.om.base.AMutableDouble;
-import edu.uci.ics.asterix.om.base.AMutableFloat;
-import edu.uci.ics.asterix.om.base.AMutableInt16;
-import edu.uci.ics.asterix.om.base.AMutableInt32;
-import edu.uci.ics.asterix.om.base.AMutableInt64;
-import edu.uci.ics.asterix.om.base.AMutableInt8;
-import edu.uci.ics.asterix.om.base.ANull;
 import edu.uci.ics.asterix.om.types.ATypeTag;
-import edu.uci.ics.asterix.om.types.BuiltinType;
-import edu.uci.ics.asterix.om.types.EnumDeserializer;
+import edu.uci.ics.asterix.runtime.aggregates.base.IAccumulator;
+import edu.uci.ics.asterix.runtime.aggregates.base.SumAccumulator;
+import edu.uci.ics.asterix.runtime.aggregates.base.WrappingMutableValueStorage;
 import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException;
-import edu.uci.ics.hyracks.algebricks.common.exceptions.NotImplementedException;
 import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluator;
 import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluatorFactory;
 import edu.uci.ics.hyracks.algebricks.runtime.base.ICopySerializableAggregateFunction;
-import edu.uci.ics.hyracks.api.dataflow.value.ISerializerDeserializer;
 import edu.uci.ics.hyracks.data.std.util.ArrayBackedValueStorage;
 import edu.uci.ics.hyracks.dataflow.common.data.accessors.IFrameTupleReference;
 
 public class SerializableSumAggregateFunction implements ICopySerializableAggregateFunction {
     private ArrayBackedValueStorage inputVal = new ArrayBackedValueStorage();
+    private ArrayBackedValueStorage defaultVal = new ArrayBackedValueStorage();
+    private WrappingMutableValueStorage stateWrapper = new WrappingMutableValueStorage();
     private ICopyEvaluator eval;
-    private AMutableDouble aDouble = new AMutableDouble(0);
-    private AMutableFloat aFloat = new AMutableFloat(0);
-    private AMutableInt64 aInt64 = new AMutableInt64(0);
-    private AMutableInt32 aInt32 = new AMutableInt32(0);
-    private AMutableInt16 aInt16 = new AMutableInt16((short) 0);
-    private AMutableInt8 aInt8 = new AMutableInt8((byte) 0);
-    @SuppressWarnings("rawtypes")
-    private ISerializerDeserializer serde;
-    private final boolean isLocalAgg = false;
+    private IAccumulator accumulator = new SumAccumulator();
     
     public SerializableSumAggregateFunction(ICopyEvaluatorFactory[] args, boolean isLocalAgg)
             throws AlgebricksException {
         eval = args[0].createEvaluator(inputVal);
+        try {
+            if (isLocalAgg) {
+                defaultVal.getDataOutput().writeByte(ATypeTag.SYSTEM_NULL.serialize());
+            } else {
+                defaultVal.getDataOutput().writeByte(ATypeTag.NULL.serialize());
+            }
+        } catch (IOException e) {
+            throw new AlgebricksException(e);
+        }
     }
     
     @Override
     public void init(DataOutput state) throws AlgebricksException {
-        try {
-            state.writeBoolean(false);
-            state.writeBoolean(false);
-            state.writeBoolean(false);
-            state.writeBoolean(false);
-            state.writeBoolean(false);
-            state.writeBoolean(false);
-            state.writeBoolean(false);
-            state.writeDouble(0.0);
+        try {            
+            stateWrapper.wrap(state);
+            accumulator.init(stateWrapper, defaultVal);
         } catch (IOException e) {
             throw new AlgebricksException(e);
         }
     }
 
     @Override
-    public void step(IFrameTupleReference tuple, byte[] state, int start, int len)
-            throws AlgebricksException {
-        int pos = start;
-        boolean metInt8s = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metInt16s = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metInt32s = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metInt64s = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metFloats = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metDoubles = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metNull = BufferSerDeUtil.getBoolean(state, pos++);
-        double sum = BufferSerDeUtil.getDouble(state, pos);
-
+    public void step(IFrameTupleReference tuple, byte[] state, int start, int len) throws AlgebricksException {
         inputVal.reset();
         eval.evaluate(tuple);
-        if (inputVal.getLength() > 0) {
-            ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(inputVal
-                    .getByteArray()[0]);
-            switch (typeTag) {
-                case INT8: {
-                    metInt8s = true;
-                    byte val = AInt8SerializerDeserializer.getByte(inputVal.getByteArray(), 1);
-                    sum += val;
-                    break;
-                }
-                case INT16: {
-                    metInt16s = true;
-                    short val = AInt16SerializerDeserializer.getShort(inputVal.getByteArray(), 1);
-                    sum += val;
-                    break;
-                }
-                case INT32: {
-                    metInt32s = true;
-                    int val = AInt32SerializerDeserializer.getInt(inputVal.getByteArray(), 1);
-                    sum += val;
-                    break;
-                }
-                case INT64: {
-                    metInt64s = true;
-                    long val = AInt64SerializerDeserializer.getLong(inputVal.getByteArray(), 1);
-                    sum += val;
-                    break;
-                }
-                case FLOAT: {
-                    metFloats = true;
-                    float val = AFloatSerializerDeserializer.getFloat(inputVal.getByteArray(), 1);
-                    sum += val;
-                    break;
-                }
-                case DOUBLE: {
-                    metDoubles = true;
-                    double val = ADoubleSerializerDeserializer.getDouble(inputVal.getByteArray(), 1);
-                    sum += val;
-                    break;
-                }
-                case NULL: {
-                    metNull = true;
-                    break;
-                }
-                case SYSTEM_NULL: {
-                    // For global aggregates simply ignore system null here,
-                    // but if all input value are system null, then we should return
-                    // null in finish().
-                    if (isLocalAgg) {
-                        throw new AlgebricksException("Type SYSTEM_NULL encountered in local aggregate.");
-                    }
-                    break;
-                }
-                default: {
-                    throw new NotImplementedException("Cannot compute SUM for values of type "
-                            + typeTag);
-                }
-            }
-        }
-
-        pos = start;
-        BufferSerDeUtil.writeBoolean(metInt8s, state, pos++);
-        BufferSerDeUtil.writeBoolean(metInt16s, state, pos++);
-        BufferSerDeUtil.writeBoolean(metInt32s, state, pos++);
-        BufferSerDeUtil.writeBoolean(metInt64s, state, pos++);
-        BufferSerDeUtil.writeBoolean(metFloats, state, pos++);
-        BufferSerDeUtil.writeBoolean(metDoubles, state, pos++);
-        BufferSerDeUtil.writeBoolean(metNull, state, pos++);
-        BufferSerDeUtil.writeDouble(sum, state, pos);
-    }
-
-    @SuppressWarnings("unchecked")
-    @Override
-    public void finish(byte[] state, int start, int len, DataOutput out) throws AlgebricksException {
-        int pos = start;
-        boolean metInt8s = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metInt16s = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metInt32s = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metInt64s = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metFloats = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metDoubles = BufferSerDeUtil.getBoolean(state, pos++);
-        boolean metNull = BufferSerDeUtil.getBoolean(state, pos++);
-        double sum = BufferSerDeUtil.getDouble(state, pos);
+        stateWrapper.wrap(state, start, len);
         try {
-            if (metNull) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE
-                        .getSerializerDeserializer(BuiltinType.ANULL);
-                serde.serialize(ANull.NULL, out);
-            } else if (metDoubles) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE
-                        .getSerializerDeserializer(BuiltinType.ADOUBLE);
-                aDouble.setValue(sum);
-                serde.serialize(aDouble, out);
-            } else if (metFloats) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE
-                        .getSerializerDeserializer(BuiltinType.AFLOAT);
-                aFloat.setValue((float) sum);
-                serde.serialize(aFloat, out);
-            } else if (metInt64s) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE
-                        .getSerializerDeserializer(BuiltinType.AINT64);
-                aInt64.setValue((long) sum);
-                serde.serialize(aInt64, out);
-            } else if (metInt32s) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE
-                        .getSerializerDeserializer(BuiltinType.AINT32);
-                aInt32.setValue((int) sum);
-                serde.serialize(aInt32, out);
-            } else if (metInt16s) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE
-                        .getSerializerDeserializer(BuiltinType.AINT16);
-                aInt16.setValue((short) sum);
-                serde.serialize(aInt16, out);
-            } else if (metInt8s) {
-                serde = AqlSerializerDeserializerProvider.INSTANCE
-                        .getSerializerDeserializer(BuiltinType.AINT8);
-                aInt8.setValue((byte) sum);
-                serde.serialize(aInt8, out);
-            } else {
-                // Empty stream. For local agg return system null. For global agg return null.
-                if (isLocalAgg) {
-                    out.writeByte(ATypeTag.SYSTEM_NULL.serialize());
-                } else {
-                    serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.ANULL);
-                    serde.serialize(ANull.NULL, out);
-                }
-            }
+            accumulator.step(stateWrapper, inputVal);
         } catch (IOException e) {
             throw new AlgebricksException(e);
         }
+    }
 
+    @Override
+    public void finish(byte[] state, int start, int len, DataOutput out) throws AlgebricksException {
+        try {
+            stateWrapper.wrap(state, start, len);
+            accumulator.finish(stateWrapper, out);
+        } catch (IOException e) {
+            throw new AlgebricksException(e);
+        }
     }
 
     @Override