Finished the AVG abstraction.
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AbstractAvgAggregateFunction.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AbstractAvgAggregateFunction.java
index 5c724a4..30a0f71 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AbstractAvgAggregateFunction.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AbstractAvgAggregateFunction.java
@@ -28,6 +28,7 @@
import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt32SerializerDeserializer;
import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt64SerializerDeserializer;
import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt8SerializerDeserializer;
+import edu.uci.ics.asterix.dataflow.data.nontagged.serde.ARecordSerializerDeserializer;
import edu.uci.ics.asterix.formats.nontagged.AqlSerializerDeserializerProvider;
import edu.uci.ics.asterix.om.base.ADouble;
import edu.uci.ics.asterix.om.base.AInt64;
@@ -60,11 +61,12 @@
private DataOutput out;
private ArrayBackedValueStorage inputVal = new ArrayBackedValueStorage();
private ICopyEvaluator eval;
+ private ATypeTag aggType;
private double sum;
private long count;
- private ATypeTag aggType;
private AMutableDouble aDouble = new AMutableDouble(0);
private AMutableInt64 aInt64 = new AMutableInt64(0);
+ private final boolean isLocalAgg;
private ArrayBackedValueStorage avgBytes = new ArrayBackedValueStorage();
private ByteArrayAccessibleOutputStream sumBytes = new ByteArrayAccessibleOutputStream();
@@ -85,18 +87,24 @@
private ISerializerDeserializer<ANull> nullSerde = AqlSerializerDeserializerProvider.INSTANCE
.getSerializerDeserializer(BuiltinType.ANULL);
- public AbstractAvgAggregateFunction(ICopyEvaluatorFactory[] args, IDataOutputProvider output)
+ public AbstractAvgAggregateFunction(ICopyEvaluatorFactory[] args, IDataOutputProvider output, boolean isLocalAgg)
throws AlgebricksException {
eval = args[0].createEvaluator(inputVal);
out = output.getDataOutput();
+ this.isLocalAgg = isLocalAgg;
List<IAType> unionList = new ArrayList<IAType>();
unionList.add(BuiltinType.ANULL);
unionList.add(BuiltinType.ADOUBLE);
ARecordType tmpRecType;
try {
- tmpRecType = new ARecordType(null, new String[] { "sum", "count" }, new IAType[] {
- new AUnionType(unionList, "OptionalDouble"), BuiltinType.AINT64 }, true);
+ if (isLocalAgg) {
+ tmpRecType = new ARecordType(null, new String[] { "sum", "count" }, new IAType[] {
+ new AUnionType(unionList, "OptionalDouble"), BuiltinType.AINT64 }, false);
+ } else {
+ tmpRecType = new ARecordType(null, new String[] { "sum", "count" }, new IAType[] {
+ new AUnionType(unionList, "OptionalDouble"), BuiltinType.AINT64 }, true);
+ }
} catch (AsterixException e) {
throw new AlgebricksException(e);
}
@@ -117,10 +125,21 @@
public void step(IFrameTupleReference tuple) throws AlgebricksException {
inputVal.reset();
eval.evaluate(tuple);
- ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(inputVal.getByteArray()[0]);
- if (typeTag == ATypeTag.NULL || aggType == ATypeTag.NULL) {
+ byte[] serBytes = inputVal.getByteArray();
+ ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(serBytes[0]);
+ if (typeTag == ATypeTag.NULL) {
aggType = ATypeTag.NULL;
return;
+ } else if (aggType == ATypeTag.NULL) {
+ return;
+ } else if (typeTag == ATypeTag.RECORD) {
+ // Global aggregate
+ if (isLocalAgg) {
+ throw new AlgebricksException("Record type can not be processed by in a local-avg operation.");
+ } else if (typeTag == ATypeTag.SYSTEM_NULL) {
+ // ignore
+ return;
+ }
} else if (aggType == ATypeTag.SYSTEM_NULL) {
aggType = typeTag;
} else if (typeTag != ATypeTag.SYSTEM_NULL && !ATypeHierarchy.isCompatible(typeTag, aggType)) {
@@ -130,7 +149,7 @@
aggType = typeTag;
}
- if (typeTag != ATypeTag.SYSTEM_NULL) {
+ if (typeTag != ATypeTag.SYSTEM_NULL && typeTag != ATypeTag.RECORD) {
++count;
}
@@ -168,34 +187,45 @@
case NULL: {
break;
}
+ case SYSTEM_NULL: {
+ if (isLocalAgg) {
+ throw new AlgebricksException("SYSTEM_NULL can not be processed by in a local-avg operation.");
+ }
+ break;
+ }
+ case RECORD: {
+ // Expected for global aggregate.
+ // The record length helps us determine whether the input record fields are nullable.
+ int recordLength = ARecordSerializerDeserializer.getRecordLength(serBytes, 1);
+ int nullBitmapSize = 1;
+ if (recordLength == 29) {
+ nullBitmapSize = 0;
+ }
+ int offset1 = ARecordSerializerDeserializer.getFieldOffsetById(serBytes, 0, nullBitmapSize, false);
+ if (offset1 == 0) // the sum is null
+ aggType = ATypeTag.NULL;
+ else
+ sum += ADoubleSerializerDeserializer.getDouble(serBytes, offset1);
+ int offset2 = ARecordSerializerDeserializer.getFieldOffsetById(serBytes, 1, nullBitmapSize, false);
+ if (offset2 != 0) // the count is not null
+ count += AInt64SerializerDeserializer.getLong(serBytes, offset2);
+ break;
+ }
default: {
throw new NotImplementedException("Cannot compute AVG for values of type " + typeTag);
}
}
+ inputVal.reset();
}
@Override
public void finish() throws AlgebricksException {
try {
- if (count == 0 || aggType == ATypeTag.NULL) {
- nullSerde.serialize(ANull.NULL, out);
- } else {
- aDouble.setValue(sum / count);
- doubleSerde.serialize(aDouble, out);
- }
- } catch (IOException e) {
- throw new AlgebricksException(e);
- }
- }
-
- @Override
- public void finishPartial() throws AlgebricksException {
- if (count == 0) {
- if (GlobalConfig.DEBUG) {
- GlobalConfig.ASTERIX_LOGGER.finest("AVG aggregate ran over empty input.");
- }
- } else {
- try {
+ if (isLocalAgg) {
+ if (count == 0 && aggType != ATypeTag.NULL) {
+ out.writeByte(ATypeTag.SYSTEM_NULL.serialize());
+ return;
+ }
if (aggType == ATypeTag.NULL) {
sumBytes.reset();
nullSerde.serialize(ANull.NULL, sumBytesOutput);
@@ -208,10 +238,46 @@
aInt64.setValue(count);
intSerde.serialize(aInt64, countBytesOutput);
recordEval.evaluate(null);
- } catch (IOException e) {
- throw new AlgebricksException(e);
+ } else {
+ if (count == 0 || aggType == ATypeTag.NULL) {
+ nullSerde.serialize(ANull.NULL, out);
+ } else {
+ aDouble.setValue(sum / count);
+ doubleSerde.serialize(aDouble, out);
+ }
}
+ } catch (IOException e) {
+ throw new AlgebricksException(e);
}
}
+ @Override
+ public void finishPartial() throws AlgebricksException {
+ if (isLocalAgg) {
+ finish();
+ } else {
+ if (count == 0) {
+ if (GlobalConfig.DEBUG) {
+ GlobalConfig.ASTERIX_LOGGER.finest("AVG aggregate ran over empty input.");
+ }
+ } else {
+ try {
+ if (count == 0 || aggType == ATypeTag.NULL) {
+ sumBytes.reset();
+ nullSerde.serialize(ANull.NULL, sumBytesOutput);
+ } else {
+ sumBytes.reset();
+ aDouble.setValue(sum);
+ doubleSerde.serialize(aDouble, sumBytesOutput);
+ }
+ countBytes.reset();
+ aInt64.setValue(count);
+ intSerde.serialize(aInt64, countBytesOutput);
+ recordEval.evaluate(null);
+ } catch (IOException e) {
+ throw new AlgebricksException(e);
+ }
+ }
+ }
+ }
}
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AvgAggregateFunction.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AvgAggregateFunction.java
index 5fd1e9c..4fc9e37 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AvgAggregateFunction.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AvgAggregateFunction.java
@@ -22,7 +22,7 @@
public class AvgAggregateFunction extends AbstractAvgAggregateFunction {
public AvgAggregateFunction(ICopyEvaluatorFactory[] args, IDataOutputProvider output) throws AlgebricksException {
- super(args, output);
+ super(args, output, false);
}
}
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/GlobalAvgAggregateFunction.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/GlobalAvgAggregateFunction.java
index 47f87ff..de5185e 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/GlobalAvgAggregateFunction.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/GlobalAvgAggregateFunction.java
@@ -19,10 +19,10 @@
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluatorFactory;
import edu.uci.ics.hyracks.data.std.api.IDataOutputProvider;
-public class GlobalAvgAggregateFunction extends AbstractGlobalAvgAggregateFunction {
+public class GlobalAvgAggregateFunction extends AbstractAvgAggregateFunction {
public GlobalAvgAggregateFunction(ICopyEvaluatorFactory[] args, IDataOutputProvider output)
throws AlgebricksException {
- super(args, output);
+ super(args, output, false);
}
}
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/LocalAvgAggregateFunction.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/LocalAvgAggregateFunction.java
index 49bbd2f..9f89c5a 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/LocalAvgAggregateFunction.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/LocalAvgAggregateFunction.java
@@ -19,10 +19,10 @@
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluatorFactory;
import edu.uci.ics.hyracks.data.std.api.IDataOutputProvider;
-public class LocalAvgAggregateFunction extends AbstractLocalAvgAggregateFunction {
+public class LocalAvgAggregateFunction extends AbstractAvgAggregateFunction {
public LocalAvgAggregateFunction(ICopyEvaluatorFactory[] args, IDataOutputProvider output)
throws AlgebricksException {
- super(args, output);
+ super(args, output, true);
}
}