Addressed Vinayak's comments on sum and count.
git-svn-id: https://asterixdb.googlecode.com/svn/branches/asterix_fix_agg@614 eaa15691-b419-025a-1212-ee371bd00084
diff --git a/asterix-app/src/test/resources/runtimets/results/aggregate/count_null.adm b/asterix-app/src/test/resources/runtimets/results/aggregate/count_null.adm
index df462fe..51d5f4f 100644
--- a/asterix-app/src/test/resources/runtimets/results/aggregate/count_null.adm
+++ b/asterix-app/src/test/resources/runtimets/results/aggregate/count_null.adm
@@ -1 +1 @@
-{ "count": 2 }
\ No newline at end of file
+{ "count": null }
\ No newline at end of file
diff --git a/asterix-app/src/test/resources/runtimets/results/aggregate/scalar_count_null.adm b/asterix-app/src/test/resources/runtimets/results/aggregate/scalar_count_null.adm
index 4ff1111..1abbc3f 100644
--- a/asterix-app/src/test/resources/runtimets/results/aggregate/scalar_count_null.adm
+++ b/asterix-app/src/test/resources/runtimets/results/aggregate/scalar_count_null.adm
@@ -1,7 +1,7 @@
-4
-4
-4
-4
-4
-4
-4
\ No newline at end of file
+null
+null
+null
+null
+null
+null
+null
\ No newline at end of file
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/AbstractScalarAggregateDescriptor.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/AbstractScalarAggregateDescriptor.java
index fdf9325..3de05f8 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/AbstractScalarAggregateDescriptor.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/AbstractScalarAggregateDescriptor.java
@@ -26,7 +26,7 @@
@Override
public ICopyEvaluator createEvaluator(IDataOutputProvider output) throws AlgebricksException {
// The aggregate function will get a SingleFieldFrameTupleReference that points to the result of the ScanCollection.
- // The list-item will always reside in the first filed (column) of the SingleFieldFrameTupleReference.
+ // The list-item will always reside in the first field (column) of the SingleFieldFrameTupleReference.
ICopyEvaluatorFactory[] aggFuncArgs = new ICopyEvaluatorFactory[1];
aggFuncArgs[0] = new ColumnAccessEvalFactory(0);
// Create aggregate function from this scalar version.
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/GenericScalarAggregateFunction.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/GenericScalarAggregateFunction.java
index 0811968..195b391 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/GenericScalarAggregateFunction.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/scalar/GenericScalarAggregateFunction.java
@@ -31,6 +31,7 @@
@Override
public void evaluate(IFrameTupleReference tuple) throws AlgebricksException {
scanCollection.init(tuple);
+ aggFunc.init();
while (scanCollection.step()) {
itemTuple.reset(listItemOut.getByteArray(), 0, listItemOut.getLength());
aggFunc.step(itemTuple);
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableCountAggregateDescriptor.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableCountAggregateDescriptor.java
index d8f2553..7f42938 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableCountAggregateDescriptor.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableCountAggregateDescriptor.java
@@ -7,26 +7,30 @@
import edu.uci.ics.asterix.formats.nontagged.AqlSerializerDeserializerProvider;
import edu.uci.ics.asterix.om.base.AInt32;
import edu.uci.ics.asterix.om.base.AMutableInt32;
+import edu.uci.ics.asterix.om.base.ANull;
import edu.uci.ics.asterix.om.functions.IFunctionDescriptor;
import edu.uci.ics.asterix.om.functions.IFunctionDescriptorFactory;
+import edu.uci.ics.asterix.om.types.ATypeTag;
import edu.uci.ics.asterix.om.types.BuiltinType;
+import edu.uci.ics.asterix.om.types.EnumDeserializer;
import edu.uci.ics.asterix.runtime.aggregates.base.AbstractSerializableAggregateFunctionDynamicDescriptor;
import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException;
import edu.uci.ics.hyracks.algebricks.core.algebra.functions.FunctionIdentifier;
+import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluator;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluatorFactory;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopySerializableAggregateFunction;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopySerializableAggregateFunctionFactory;
import edu.uci.ics.hyracks.api.dataflow.value.ISerializerDeserializer;
+import edu.uci.ics.hyracks.data.std.util.ArrayBackedValueStorage;
import edu.uci.ics.hyracks.dataflow.common.data.accessors.IFrameTupleReference;
/**
- * NULLs are also counted.
+ * count(NULL) returns NULL.
*/
public class SerializableCountAggregateDescriptor extends AbstractSerializableAggregateFunctionDynamicDescriptor {
private static final long serialVersionUID = 1L;
- public final static FunctionIdentifier FID = new FunctionIdentifier(FunctionConstants.ASTERIX_NS, "count-serial",
- 1);
+ public final static FunctionIdentifier FID = new FunctionIdentifier(FunctionConstants.ASTERIX_NS, "count-serial", 1);
public static final IFunctionDescriptorFactory FACTORY = new IFunctionDescriptorFactory() {
public IFunctionDescriptor createFunctionDescriptor() {
return new SerializableCountAggregateDescriptor();
@@ -39,8 +43,8 @@
}
@Override
- public ICopySerializableAggregateFunctionFactory createSerializableAggregateFunctionFactory(final ICopyEvaluatorFactory[] args)
- throws AlgebricksException {
+ public ICopySerializableAggregateFunctionFactory createSerializableAggregateFunctionFactory(
+ final ICopyEvaluatorFactory[] args) throws AlgebricksException {
return new ICopySerializableAggregateFunctionFactory() {
private static final long serialVersionUID = 1L;
@@ -52,10 +56,16 @@
@SuppressWarnings("unchecked")
private ISerializerDeserializer<AInt32> int32Serde = AqlSerializerDeserializerProvider.INSTANCE
.getSerializerDeserializer(BuiltinType.AINT32);
+ @SuppressWarnings("unchecked")
+ private ISerializerDeserializer<ANull> nullSerde = AqlSerializerDeserializerProvider.INSTANCE
+ .getSerializerDeserializer(BuiltinType.ANULL);
+ private ArrayBackedValueStorage inputVal = new ArrayBackedValueStorage();
+ private ICopyEvaluator eval = args[0].createEvaluator(inputVal);
@Override
public void init(DataOutput state) throws AlgebricksException {
try {
+ state.writeBoolean(false);
state.writeInt(0);
} catch (IOException e) {
throw new AlgebricksException(e);
@@ -65,17 +75,32 @@
@Override
public void step(IFrameTupleReference tuple, byte[] state, int start, int len)
throws AlgebricksException {
- int cnt = BufferSerDeUtil.getInt(state, start);
- cnt++;
- BufferSerDeUtil.writeInt(cnt, state, start);
+ boolean metNull = BufferSerDeUtil.getBoolean(state, start);
+ int cnt = BufferSerDeUtil.getInt(state, start + 1);
+ inputVal.reset();
+ eval.evaluate(tuple);
+ ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER
+ .deserialize(inputVal.getByteArray()[0]);
+ if (typeTag == ATypeTag.NULL) {
+ metNull = true;
+ } else {
+ cnt++;
+ }
+ BufferSerDeUtil.writeBoolean(metNull, state, start);
+ BufferSerDeUtil.writeInt(cnt, state, start + 1);
}
@Override
public void finish(byte[] state, int start, int len, DataOutput out) throws AlgebricksException {
- int cnt = BufferSerDeUtil.getInt(state, start);
+ boolean metNull = BufferSerDeUtil.getBoolean(state, start);
+ int cnt = BufferSerDeUtil.getInt(state, start + 1);
try {
- result.setValue(cnt);
- int32Serde.serialize(result, out);
+ if (metNull) {
+ nullSerde.serialize(ANull.NULL, out);
+ } else {
+ result.setValue(cnt);
+ int32Serde.serialize(result, out);
+ }
} catch (IOException e) {
throw new AlgebricksException(e);
}
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableSumAggregateFunction.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableSumAggregateFunction.java
index 4cfde34..c244d89 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableSumAggregateFunction.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/serializable/std/SerializableSumAggregateFunction.java
@@ -40,23 +40,18 @@
private AMutableInt8 aInt8 = new AMutableInt8((byte) 0);
@SuppressWarnings("rawtypes")
private ISerializerDeserializer serde;
- private final boolean isLocalAgg = false;
-
+ private final boolean isLocalAgg;
+
public SerializableSumAggregateFunction(ICopyEvaluatorFactory[] args, boolean isLocalAgg)
throws AlgebricksException {
eval = args[0].createEvaluator(inputVal);
+ this.isLocalAgg = isLocalAgg;
}
-
+
@Override
public void init(DataOutput state) throws AlgebricksException {
try {
- state.writeBoolean(false);
- state.writeBoolean(false);
- state.writeBoolean(false);
- state.writeBoolean(false);
- state.writeBoolean(false);
- state.writeBoolean(false);
- state.writeBoolean(false);
+ state.writeByte(ATypeTag.SYSTEM_NULL.serialize());
state.writeDouble(0.0);
} catch (IOException e) {
throw new AlgebricksException(e);
@@ -64,145 +59,132 @@
}
@Override
- public void step(IFrameTupleReference tuple, byte[] state, int start, int len)
- throws AlgebricksException {
- int pos = start;
- boolean metInt8s = BufferSerDeUtil.getBoolean(state, pos++);
- boolean metInt16s = BufferSerDeUtil.getBoolean(state, pos++);
- boolean metInt32s = BufferSerDeUtil.getBoolean(state, pos++);
- boolean metInt64s = BufferSerDeUtil.getBoolean(state, pos++);
- boolean metFloats = BufferSerDeUtil.getBoolean(state, pos++);
- boolean metDoubles = BufferSerDeUtil.getBoolean(state, pos++);
- boolean metNull = BufferSerDeUtil.getBoolean(state, pos++);
- double sum = BufferSerDeUtil.getDouble(state, pos);
-
+ public void step(IFrameTupleReference tuple, byte[] state, int start, int len) throws AlgebricksException {
+ ATypeTag aggType = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(state[start]);
+ double sum = BufferSerDeUtil.getDouble(state, start + 1);
inputVal.reset();
eval.evaluate(tuple);
- if (inputVal.getLength() > 0) {
- ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(inputVal
- .getByteArray()[0]);
- switch (typeTag) {
- case INT8: {
- metInt8s = true;
- byte val = AInt8SerializerDeserializer.getByte(inputVal.getByteArray(), 1);
- sum += val;
- break;
+ ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(inputVal.getByteArray()[0]);
+ if (typeTag == ATypeTag.NULL) {
+ aggType = ATypeTag.NULL;
+ }
+ if (aggType == ATypeTag.NULL) {
+ return;
+ } else if (aggType == ATypeTag.SYSTEM_NULL) {
+ aggType = typeTag;
+ } else if (typeTag != ATypeTag.SYSTEM_NULL && typeTag != aggType) {
+ throw new AlgebricksException("Unexpected type " + typeTag
+ + " in sum-aggregation input stream. Expected type " + aggType + ".");
+ }
+ switch (typeTag) {
+ case INT8: {
+ byte val = AInt8SerializerDeserializer.getByte(inputVal.getByteArray(), 1);
+ sum += val;
+ break;
+ }
+ case INT16: {
+ short val = AInt16SerializerDeserializer.getShort(inputVal.getByteArray(), 1);
+ sum += val;
+ break;
+ }
+ case INT32: {
+ int val = AInt32SerializerDeserializer.getInt(inputVal.getByteArray(), 1);
+ sum += val;
+ break;
+ }
+ case INT64: {
+ long val = AInt64SerializerDeserializer.getLong(inputVal.getByteArray(), 1);
+ sum += val;
+ break;
+ }
+ case FLOAT: {
+ float val = AFloatSerializerDeserializer.getFloat(inputVal.getByteArray(), 1);
+ sum += val;
+ break;
+ }
+ case DOUBLE: {
+ double val = ADoubleSerializerDeserializer.getDouble(inputVal.getByteArray(), 1);
+ sum += val;
+ break;
+ }
+ case NULL: {
+ aggType = typeTag;
+ break;
+ }
+ case SYSTEM_NULL: {
+ // For global aggregates simply ignore system null here,
+ // but if all input value are system null, then we should return
+ // null in finish().
+ if (isLocalAgg) {
+ throw new AlgebricksException("Type SYSTEM_NULL encountered in local aggregate.");
}
- case INT16: {
- metInt16s = true;
- short val = AInt16SerializerDeserializer.getShort(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case INT32: {
- metInt32s = true;
- int val = AInt32SerializerDeserializer.getInt(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case INT64: {
- metInt64s = true;
- long val = AInt64SerializerDeserializer.getLong(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case FLOAT: {
- metFloats = true;
- float val = AFloatSerializerDeserializer.getFloat(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case DOUBLE: {
- metDoubles = true;
- double val = ADoubleSerializerDeserializer.getDouble(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case NULL: {
- metNull = true;
- break;
- }
- case SYSTEM_NULL: {
- // For global aggregates simply ignore system null here,
- // but if all input value are system null, then we should return
- // null in finish().
- if (isLocalAgg) {
- throw new AlgebricksException("Type SYSTEM_NULL encountered in local aggregate.");
- }
- break;
- }
- default: {
- throw new NotImplementedException("Cannot compute SUM for values of type "
- + typeTag);
- }
+ break;
+ }
+ default: {
+ throw new NotImplementedException("Cannot compute SUM for values of type " + typeTag + ".");
}
}
-
- pos = start;
- BufferSerDeUtil.writeBoolean(metInt8s, state, pos++);
- BufferSerDeUtil.writeBoolean(metInt16s, state, pos++);
- BufferSerDeUtil.writeBoolean(metInt32s, state, pos++);
- BufferSerDeUtil.writeBoolean(metInt64s, state, pos++);
- BufferSerDeUtil.writeBoolean(metFloats, state, pos++);
- BufferSerDeUtil.writeBoolean(metDoubles, state, pos++);
- BufferSerDeUtil.writeBoolean(metNull, state, pos++);
- BufferSerDeUtil.writeDouble(sum, state, pos);
+ state[start] = aggType.serialize();
+ BufferSerDeUtil.writeDouble(sum, state, start + 1);
}
@SuppressWarnings("unchecked")
@Override
public void finish(byte[] state, int start, int len, DataOutput out) throws AlgebricksException {
- int pos = start;
- boolean metInt8s = BufferSerDeUtil.getBoolean(state, pos++);
- boolean metInt16s = BufferSerDeUtil.getBoolean(state, pos++);
- boolean metInt32s = BufferSerDeUtil.getBoolean(state, pos++);
- boolean metInt64s = BufferSerDeUtil.getBoolean(state, pos++);
- boolean metFloats = BufferSerDeUtil.getBoolean(state, pos++);
- boolean metDoubles = BufferSerDeUtil.getBoolean(state, pos++);
- boolean metNull = BufferSerDeUtil.getBoolean(state, pos++);
- double sum = BufferSerDeUtil.getDouble(state, pos);
+ ATypeTag aggType = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(state[start]);
+ double sum = BufferSerDeUtil.getDouble(state, start + 1);
try {
- if (metNull) {
- serde = AqlSerializerDeserializerProvider.INSTANCE
- .getSerializerDeserializer(BuiltinType.ANULL);
- serde.serialize(ANull.NULL, out);
- } else if (metDoubles) {
- serde = AqlSerializerDeserializerProvider.INSTANCE
- .getSerializerDeserializer(BuiltinType.ADOUBLE);
- aDouble.setValue(sum);
- serde.serialize(aDouble, out);
- } else if (metFloats) {
- serde = AqlSerializerDeserializerProvider.INSTANCE
- .getSerializerDeserializer(BuiltinType.AFLOAT);
- aFloat.setValue((float) sum);
- serde.serialize(aFloat, out);
- } else if (metInt64s) {
- serde = AqlSerializerDeserializerProvider.INSTANCE
- .getSerializerDeserializer(BuiltinType.AINT64);
- aInt64.setValue((long) sum);
- serde.serialize(aInt64, out);
- } else if (metInt32s) {
- serde = AqlSerializerDeserializerProvider.INSTANCE
- .getSerializerDeserializer(BuiltinType.AINT32);
- aInt32.setValue((int) sum);
- serde.serialize(aInt32, out);
- } else if (metInt16s) {
- serde = AqlSerializerDeserializerProvider.INSTANCE
- .getSerializerDeserializer(BuiltinType.AINT16);
- aInt16.setValue((short) sum);
- serde.serialize(aInt16, out);
- } else if (metInt8s) {
- serde = AqlSerializerDeserializerProvider.INSTANCE
- .getSerializerDeserializer(BuiltinType.AINT8);
- aInt8.setValue((byte) sum);
- serde.serialize(aInt8, out);
- } else {
- // Empty stream. For local agg return system null. For global agg return null.
- if (isLocalAgg) {
- out.writeByte(ATypeTag.SYSTEM_NULL.serialize());
- } else {
+ switch (aggType) {
+ case INT8: {
+ serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT8);
+ aInt8.setValue((byte) sum);
+ serde.serialize(aInt8, out);
+ break;
+ }
+ case INT16: {
+ serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT16);
+ aInt16.setValue((short) sum);
+ serde.serialize(aInt16, out);
+ break;
+ }
+ case INT32: {
+ serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT32);
+ aInt32.setValue((int) sum);
+ serde.serialize(aInt32, out);
+ break;
+ }
+ case INT64: {
+ serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT64);
+ aInt64.setValue((long) sum);
+ serde.serialize(aInt64, out);
+ break;
+ }
+ case FLOAT: {
+ serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AFLOAT);
+ aFloat.setValue((float) sum);
+ serde.serialize(aFloat, out);
+ break;
+ }
+ case DOUBLE: {
+ serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.ADOUBLE);
+ aDouble.setValue(sum);
+ serde.serialize(aDouble, out);
+ break;
+ }
+ case NULL: {
serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.ANULL);
serde.serialize(ANull.NULL, out);
+ break;
+ }
+ case SYSTEM_NULL: {
+ // Empty stream. For local agg return system null. For global agg return null.
+ if (isLocalAgg) {
+ out.writeByte(ATypeTag.SYSTEM_NULL.serialize());
+ } else {
+ serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.ANULL);
+ serde.serialize(ANull.NULL, out);
+ }
+ break;
}
}
} catch (IOException e) {
@@ -212,8 +194,7 @@
}
@Override
- public void finishPartial(byte[] state, int start, int len, DataOutput out)
- throws AlgebricksException {
+ public void finishPartial(byte[] state, int start, int len, DataOutput out) throws AlgebricksException {
finish(state, start, len, out);
}
}
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AvgAggregateDescriptor.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AvgAggregateDescriptor.java
index bd0603f..bbd0460 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AvgAggregateDescriptor.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AvgAggregateDescriptor.java
@@ -170,10 +170,10 @@
if (count == 0) {
GlobalConfig.ASTERIX_LOGGER.fine("AVG aggregate ran over empty input.");
try {
- nullSerde.serialize(ANull.NULL, out);
- } catch (HyracksDataException e) {
- throw new AlgebricksException(e);
- }
+ nullSerde.serialize(ANull.NULL, out);
+ } catch (HyracksDataException e) {
+ throw new AlgebricksException(e);
+ }
} else {
try {
if (metNull)
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateDescriptor.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateDescriptor.java
index a264c91..2b214d6 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateDescriptor.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateDescriptor.java
@@ -38,7 +38,7 @@
@Override
public ICopyAggregateFunction createAggregateFunction(IDataOutputProvider provider) throws AlgebricksException {
- return new CountAggregateFunction(provider);
+ return new CountAggregateFunction(args, provider);
}
};
}
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateFunction.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateFunction.java
index 7dfc601..e4d015b 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateFunction.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/CountAggregateFunction.java
@@ -6,40 +6,69 @@
import edu.uci.ics.asterix.formats.nontagged.AqlSerializerDeserializerProvider;
import edu.uci.ics.asterix.om.base.AInt32;
import edu.uci.ics.asterix.om.base.AMutableInt32;
+import edu.uci.ics.asterix.om.base.ANull;
+import edu.uci.ics.asterix.om.types.ATypeTag;
import edu.uci.ics.asterix.om.types.BuiltinType;
+import edu.uci.ics.asterix.om.types.EnumDeserializer;
import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyAggregateFunction;
+import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluator;
+import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluatorFactory;
import edu.uci.ics.hyracks.api.dataflow.value.ISerializerDeserializer;
import edu.uci.ics.hyracks.data.std.api.IDataOutputProvider;
+import edu.uci.ics.hyracks.data.std.util.ArrayBackedValueStorage;
import edu.uci.ics.hyracks.dataflow.common.data.accessors.IFrameTupleReference;
+/**
+ * count(NULL) returns NULL.
+ */
public class CountAggregateFunction implements ICopyAggregateFunction {
private AMutableInt32 result = new AMutableInt32(-1);
@SuppressWarnings("unchecked")
private ISerializerDeserializer<AInt32> int32Serde = AqlSerializerDeserializerProvider.INSTANCE
.getSerializerDeserializer(BuiltinType.AINT32);
+ @SuppressWarnings("unchecked")
+ private ISerializerDeserializer<ANull> nullSerde = AqlSerializerDeserializerProvider.INSTANCE
+ .getSerializerDeserializer(BuiltinType.ANULL);
+ private ArrayBackedValueStorage inputVal = new ArrayBackedValueStorage();
+ private ICopyEvaluator eval;
+ private boolean metNull;
private int cnt;
private DataOutput out;
-
- public CountAggregateFunction(IDataOutputProvider output) {
+
+ public CountAggregateFunction(ICopyEvaluatorFactory[] args, IDataOutputProvider output) throws AlgebricksException {
+ eval = args[0].createEvaluator(inputVal);
out = output.getDataOutput();
}
-
+
@Override
public void init() {
cnt = 0;
+ metNull = false;
}
@Override
public void step(IFrameTupleReference tuple) throws AlgebricksException {
- cnt++;
+ inputVal.reset();
+ eval.evaluate(tuple);
+ ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(inputVal.getByteArray()[0]);
+ // Ignore SYSTEM_NULL.
+ if (typeTag == ATypeTag.NULL) {
+ metNull = true;
+ } else {
+ cnt++;
+ }
}
@Override
public void finish() throws AlgebricksException {
try {
- result.setValue(cnt);
- int32Serde.serialize(result, out);
+ if (metNull) {
+ nullSerde.serialize(ANull.NULL, out);
+ } else {
+ result.setValue(cnt);
+ int32Serde.serialize(result, out);
+ }
} catch (IOException e) {
throw new AlgebricksException(e);
}
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/SumAggregateFunction.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/SumAggregateFunction.java
index 9ea13ae..4870745 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/SumAggregateFunction.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/SumAggregateFunction.java
@@ -34,8 +34,8 @@
private DataOutput out;
private ArrayBackedValueStorage inputVal = new ArrayBackedValueStorage();
private ICopyEvaluator eval;
- private boolean metInt8s, metInt16s, metInt32s, metInt64s, metFloats, metDoubles, metNull;
private double sum;
+ private ATypeTag aggType;
private AMutableDouble aDouble = new AMutableDouble(0);
private AMutableFloat aFloat = new AMutableFloat(0);
private AMutableInt64 aInt64 = new AMutableInt64(0);
@@ -56,13 +56,7 @@
@Override
public void init() {
- metInt8s = false;
- metInt16s = false;
- metInt32s = false;
- metInt64s = false;
- metFloats = false;
- metDoubles = false;
- metNull = false;
+ aggType = ATypeTag.SYSTEM_NULL;
sum = 0.0;
}
@@ -70,61 +64,63 @@
public void step(IFrameTupleReference tuple) throws AlgebricksException {
inputVal.reset();
eval.evaluate(tuple);
- if (inputVal.getLength() > 0) {
- ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(inputVal.getByteArray()[0]);
- switch (typeTag) {
- case INT8: {
- metInt8s = true;
- byte val = AInt8SerializerDeserializer.getByte(inputVal.getByteArray(), 1);
- sum += val;
- break;
+ ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(inputVal.getByteArray()[0]);
+ if (typeTag == ATypeTag.NULL) {
+ aggType = ATypeTag.NULL;
+ }
+ if (aggType == ATypeTag.NULL) {
+ return;
+ } else if (aggType == ATypeTag.SYSTEM_NULL) {
+ aggType = typeTag;
+ } else if (typeTag != ATypeTag.SYSTEM_NULL && typeTag != aggType) {
+ throw new AlgebricksException("Unexpected type " + typeTag
+ + " in sum-aggregation input stream. Expected type " + aggType + ".");
+ }
+ switch (typeTag) {
+ case INT8: {
+ byte val = AInt8SerializerDeserializer.getByte(inputVal.getByteArray(), 1);
+ sum += val;
+ break;
+ }
+ case INT16: {
+ short val = AInt16SerializerDeserializer.getShort(inputVal.getByteArray(), 1);
+ sum += val;
+ break;
+ }
+ case INT32: {
+ int val = AInt32SerializerDeserializer.getInt(inputVal.getByteArray(), 1);
+ sum += val;
+ break;
+ }
+ case INT64: {
+ long val = AInt64SerializerDeserializer.getLong(inputVal.getByteArray(), 1);
+ sum += val;
+ break;
+ }
+ case FLOAT: {
+ float val = AFloatSerializerDeserializer.getFloat(inputVal.getByteArray(), 1);
+ sum += val;
+ break;
+ }
+ case DOUBLE: {
+ double val = ADoubleSerializerDeserializer.getDouble(inputVal.getByteArray(), 1);
+ sum += val;
+ break;
+ }
+ case NULL: {
+ break;
+ }
+ case SYSTEM_NULL: {
+ // For global aggregates simply ignore system null here,
+ // but if all input value are system null, then we should return
+ // null in finish().
+ if (isLocalAgg) {
+ throw new AlgebricksException("Type SYSTEM_NULL encountered in local aggregate.");
}
- case INT16: {
- metInt16s = true;
- short val = AInt16SerializerDeserializer.getShort(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case INT32: {
- metInt32s = true;
- int val = AInt32SerializerDeserializer.getInt(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case INT64: {
- metInt64s = true;
- long val = AInt64SerializerDeserializer.getLong(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case FLOAT: {
- metFloats = true;
- float val = AFloatSerializerDeserializer.getFloat(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case DOUBLE: {
- metDoubles = true;
- double val = ADoubleSerializerDeserializer.getDouble(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case NULL: {
- metNull = true;
- break;
- }
- case SYSTEM_NULL: {
- // For global aggregates simply ignore system null here,
- // but if all input value are system null, then we should return
- // null in finish().
- if (isLocalAgg) {
- throw new AlgebricksException("Type SYSTEM_NULL encountered in local aggregate.");
- }
- break;
- }
- default: {
- throw new NotImplementedException("Cannot compute SUM for values of type " + typeTag);
- }
+ break;
+ }
+ default: {
+ throw new NotImplementedException("Cannot compute SUM for values of type " + typeTag + ".");
}
}
}
@@ -133,40 +129,57 @@
@Override
public void finish() throws AlgebricksException {
try {
- if (metNull) {
- serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.ANULL);
- serde.serialize(ANull.NULL, out);
- } else if (metDoubles) {
- serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.ADOUBLE);
- aDouble.setValue(sum);
- serde.serialize(aDouble, out);
- } else if (metFloats) {
- serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AFLOAT);
- aFloat.setValue((float) sum);
- serde.serialize(aFloat, out);
- } else if (metInt64s) {
- serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT64);
- aInt64.setValue((long) sum);
- serde.serialize(aInt64, out);
- } else if (metInt32s) {
- serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT32);
- aInt32.setValue((int) sum);
- serde.serialize(aInt32, out);
- } else if (metInt16s) {
- serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT16);
- aInt16.setValue((short) sum);
- serde.serialize(aInt16, out);
- } else if (metInt8s) {
- serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT8);
- aInt8.setValue((byte) sum);
- serde.serialize(aInt8, out);
- } else {
- // Empty stream. For local agg return system null. For global agg return null.
- if (isLocalAgg) {
- out.writeByte(ATypeTag.SYSTEM_NULL.serialize());
- } else {
+ switch (aggType) {
+ case INT8: {
+ serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT8);
+ aInt8.setValue((byte) sum);
+ serde.serialize(aInt8, out);
+ break;
+ }
+ case INT16: {
+ serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT16);
+ aInt16.setValue((short) sum);
+ serde.serialize(aInt16, out);
+ break;
+ }
+ case INT32: {
+ serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT32);
+ aInt32.setValue((int) sum);
+ serde.serialize(aInt32, out);
+ break;
+ }
+ case INT64: {
+ serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AINT64);
+ aInt64.setValue((long) sum);
+ serde.serialize(aInt64, out);
+ break;
+ }
+ case FLOAT: {
+ serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.AFLOAT);
+ aFloat.setValue((float) sum);
+ serde.serialize(aFloat, out);
+ break;
+ }
+ case DOUBLE: {
+ serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.ADOUBLE);
+ aDouble.setValue(sum);
+ serde.serialize(aDouble, out);
+ break;
+ }
+ case NULL: {
serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.ANULL);
serde.serialize(ANull.NULL, out);
+ break;
+ }
+ case SYSTEM_NULL: {
+ // Empty stream. For local agg return system null. For global agg return null.
+ if (isLocalAgg) {
+ out.writeByte(ATypeTag.SYSTEM_NULL.serialize());
+ } else {
+ serde = AqlSerializerDeserializerProvider.INSTANCE.getSerializerDeserializer(BuiltinType.ANULL);
+ serde.serialize(ANull.NULL, out);
+ }
+ break;
}
}
} catch (IOException e) {