The first phase of refactoring the average aggregate function.
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AbstractLocalAvgAggregateFunction.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AbstractLocalAvgAggregateFunction.java
index 72da44d..e950ae7 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AbstractLocalAvgAggregateFunction.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AbstractLocalAvgAggregateFunction.java
@@ -59,9 +59,9 @@
private DataOutput out;
private ArrayBackedValueStorage inputVal = new ArrayBackedValueStorage();
private ICopyEvaluator eval;
+ private ATypeTag aggType;
private double sum;
private long count;
- private ATypeTag aggType;
private ArrayBackedValueStorage avgBytes = new ArrayBackedValueStorage();
private ByteArrayAccessibleOutputStream sumBytes = new ByteArrayAccessibleOutputStream();
@@ -71,7 +71,6 @@
private ICopyEvaluator evalSum = new AccessibleByteArrayEval(avgBytes.getDataOutput(), sumBytes);
private ICopyEvaluator evalCount = new AccessibleByteArrayEval(avgBytes.getDataOutput(), countBytes);
private ClosedRecordConstructorEval recordEval;
-
@SuppressWarnings("unchecked")
private ISerializerDeserializer<ADouble> doubleSerde = AqlSerializerDeserializerProvider.INSTANCE
.getSerializerDeserializer(BuiltinType.ADOUBLE);
@@ -95,7 +94,7 @@
ARecordType tmpRecType;
try {
tmpRecType = new ARecordType(null, new String[] { "sum", "count" }, new IAType[] {
- new AUnionType(unionList, "OptionalDouble"), BuiltinType.AINT64 }, true);
+ new AUnionType(unionList, "OptionalDouble"), BuiltinType.AINT64 }, false);
} catch (AsterixException e) {
throw new AlgebricksException(e);
}
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AvgAggregateDescriptor.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AvgAggregateDescriptor.java
index 7815606..3ff836a 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AvgAggregateDescriptor.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/AvgAggregateDescriptor.java
@@ -15,51 +15,16 @@
package edu.uci.ics.asterix.runtime.aggregates.std;
-import java.io.DataOutput;
-import java.io.DataOutputStream;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
-import edu.uci.ics.asterix.common.config.GlobalConfig;
-import edu.uci.ics.asterix.common.exceptions.AsterixException;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.ADoubleSerializerDeserializer;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AFloatSerializerDeserializer;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt16SerializerDeserializer;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt32SerializerDeserializer;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt64SerializerDeserializer;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt8SerializerDeserializer;
-import edu.uci.ics.asterix.formats.nontagged.AqlSerializerDeserializerProvider;
-import edu.uci.ics.asterix.om.base.ADouble;
-import edu.uci.ics.asterix.om.base.AInt64;
-import edu.uci.ics.asterix.om.base.AMutableDouble;
-import edu.uci.ics.asterix.om.base.AMutableInt64;
-import edu.uci.ics.asterix.om.base.ANull;
import edu.uci.ics.asterix.om.functions.AsterixBuiltinFunctions;
import edu.uci.ics.asterix.om.functions.IFunctionDescriptor;
import edu.uci.ics.asterix.om.functions.IFunctionDescriptorFactory;
-import edu.uci.ics.asterix.om.types.ARecordType;
-import edu.uci.ics.asterix.om.types.ATypeTag;
-import edu.uci.ics.asterix.om.types.AUnionType;
-import edu.uci.ics.asterix.om.types.BuiltinType;
-import edu.uci.ics.asterix.om.types.EnumDeserializer;
-import edu.uci.ics.asterix.om.types.IAType;
-import edu.uci.ics.asterix.om.types.hierachy.ATypeHierarchy;
import edu.uci.ics.asterix.runtime.aggregates.base.AbstractAggregateFunctionDynamicDescriptor;
-import edu.uci.ics.asterix.runtime.evaluators.common.AccessibleByteArrayEval;
-import edu.uci.ics.asterix.runtime.evaluators.common.ClosedRecordConstructorEvalFactory.ClosedRecordConstructorEval;
import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException;
-import edu.uci.ics.hyracks.algebricks.common.exceptions.NotImplementedException;
import edu.uci.ics.hyracks.algebricks.core.algebra.functions.FunctionIdentifier;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyAggregateFunction;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyAggregateFunctionFactory;
-import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluator;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluatorFactory;
-import edu.uci.ics.hyracks.api.dataflow.value.ISerializerDeserializer;
import edu.uci.ics.hyracks.data.std.api.IDataOutputProvider;
-import edu.uci.ics.hyracks.data.std.util.ArrayBackedValueStorage;
-import edu.uci.ics.hyracks.data.std.util.ByteArrayAccessibleOutputStream;
-import edu.uci.ics.hyracks.dataflow.common.data.accessors.IFrameTupleReference;
public class AvgAggregateDescriptor extends AbstractAggregateFunctionDynamicDescriptor {
@@ -78,165 +43,13 @@
@Override
public ICopyAggregateFunctionFactory createAggregateFunctionFactory(final ICopyEvaluatorFactory[] args)
throws AlgebricksException {
- List<IAType> unionList = new ArrayList<IAType>();
- unionList.add(BuiltinType.ANULL);
- unionList.add(BuiltinType.ADOUBLE);
- ARecordType tmpRecType;
- try {
- tmpRecType = new ARecordType(null, new String[] { "sum", "count" }, new IAType[] {
- new AUnionType(unionList, "OptionalDouble"), BuiltinType.AINT64 }, true);
- } catch (AsterixException e) {
- throw new AlgebricksException(e);
- }
-
- final ARecordType recType = tmpRecType;
-
return new ICopyAggregateFunctionFactory() {
private static final long serialVersionUID = 1L;
@Override
public ICopyAggregateFunction createAggregateFunction(final IDataOutputProvider provider)
throws AlgebricksException {
-
- return new ICopyAggregateFunction() {
-
- private DataOutput out = provider.getDataOutput();
- private ArrayBackedValueStorage inputVal = new ArrayBackedValueStorage();
- private ICopyEvaluator eval = args[0].createEvaluator(inputVal);
- private double sum;
- private long count;
- private ATypeTag aggType;
- private AMutableDouble aDouble = new AMutableDouble(0);
- private AMutableInt64 aInt64 = new AMutableInt64(0);
-
- private ArrayBackedValueStorage avgBytes = new ArrayBackedValueStorage();
- private ByteArrayAccessibleOutputStream sumBytes = new ByteArrayAccessibleOutputStream();
- private DataOutput sumBytesOutput = new DataOutputStream(sumBytes);
- private ByteArrayAccessibleOutputStream countBytes = new ByteArrayAccessibleOutputStream();
- private DataOutput countBytesOutput = new DataOutputStream(countBytes);
- private ICopyEvaluator evalSum = new AccessibleByteArrayEval(avgBytes.getDataOutput(), sumBytes);
- private ICopyEvaluator evalCount = new AccessibleByteArrayEval(avgBytes.getDataOutput(), countBytes);
- private ClosedRecordConstructorEval recordEval = new ClosedRecordConstructorEval(recType,
- new ICopyEvaluator[] { evalSum, evalCount }, avgBytes, out);
-
- @SuppressWarnings("unchecked")
- private ISerializerDeserializer<ADouble> doubleSerde = AqlSerializerDeserializerProvider.INSTANCE
- .getSerializerDeserializer(BuiltinType.ADOUBLE);
- @SuppressWarnings("unchecked")
- private ISerializerDeserializer<AInt64> intSerde = AqlSerializerDeserializerProvider.INSTANCE
- .getSerializerDeserializer(BuiltinType.AINT64);
- @SuppressWarnings("unchecked")
- private ISerializerDeserializer<ANull> nullSerde = AqlSerializerDeserializerProvider.INSTANCE
- .getSerializerDeserializer(BuiltinType.ANULL);
-
- @Override
- public void init() throws AlgebricksException {
- aggType = ATypeTag.SYSTEM_NULL;
- sum = 0.0;
- count = 0;
- }
-
- @Override
- public void step(IFrameTupleReference tuple) throws AlgebricksException {
- inputVal.reset();
- eval.evaluate(tuple);
- ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER
- .deserialize(inputVal.getByteArray()[0]);
- if (typeTag == ATypeTag.NULL || aggType == ATypeTag.NULL) {
- aggType = ATypeTag.NULL;
- return;
- } else if (aggType == ATypeTag.SYSTEM_NULL) {
- aggType = typeTag;
- } else if (typeTag != ATypeTag.SYSTEM_NULL && !ATypeHierarchy.isCompatible(typeTag, aggType)) {
- throw new AlgebricksException("Unexpected type " + typeTag
- + " in aggregation input stream. Expected type " + aggType + ".");
- } else if (ATypeHierarchy.canPromote(aggType, typeTag)) {
- aggType = typeTag;
- }
-
- if (typeTag != ATypeTag.SYSTEM_NULL) {
- ++count;
- }
- switch (typeTag) {
- case INT8: {
- byte val = AInt8SerializerDeserializer.getByte(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case INT16: {
- short val = AInt16SerializerDeserializer.getShort(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case INT32: {
- int val = AInt32SerializerDeserializer.getInt(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case INT64: {
- long val = AInt64SerializerDeserializer.getLong(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case FLOAT: {
- float val = AFloatSerializerDeserializer.getFloat(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case DOUBLE: {
- double val = ADoubleSerializerDeserializer.getDouble(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case NULL: {
- break;
- }
- default: {
- throw new NotImplementedException("Cannot compute AVG for values of type " + typeTag);
- }
- }
- }
-
- @Override
- public void finish() throws AlgebricksException {
- try {
- if (count == 0 || aggType == ATypeTag.NULL) {
- nullSerde.serialize(ANull.NULL, out);
- } else {
- aDouble.setValue(sum / count);
- doubleSerde.serialize(aDouble, out);
- }
- } catch (IOException e) {
- throw new AlgebricksException(e);
- }
- }
-
- @Override
- public void finishPartial() throws AlgebricksException {
- if (count == 0) {
- if (GlobalConfig.DEBUG) {
- GlobalConfig.ASTERIX_LOGGER.finest("AVG aggregate ran over empty input.");
- }
- } else {
- try {
- if (aggType == ATypeTag.NULL) {
- sumBytes.reset();
- nullSerde.serialize(ANull.NULL, sumBytesOutput);
- } else {
- sumBytes.reset();
- aDouble.setValue(sum);
- doubleSerde.serialize(aDouble, sumBytesOutput);
- }
- countBytes.reset();
- aInt64.setValue(count);
- intSerde.serialize(aInt64, countBytesOutput);
- recordEval.evaluate(null);
- } catch (IOException e) {
- throw new AlgebricksException(e);
- }
- }
- }
- };
+ return new AvgAggregateFunction(args, provider);
}
};
}
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/GlobalAvgAggregateDescriptor.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/GlobalAvgAggregateDescriptor.java
index eb5d6bd..230e06b 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/GlobalAvgAggregateDescriptor.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/GlobalAvgAggregateDescriptor.java
@@ -15,45 +15,16 @@
package edu.uci.ics.asterix.runtime.aggregates.std;
-import java.io.DataOutput;
-import java.io.DataOutputStream;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
-import edu.uci.ics.asterix.common.exceptions.AsterixException;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.ADoubleSerializerDeserializer;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt64SerializerDeserializer;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.ARecordSerializerDeserializer;
-import edu.uci.ics.asterix.formats.nontagged.AqlSerializerDeserializerProvider;
-import edu.uci.ics.asterix.om.base.ADouble;
-import edu.uci.ics.asterix.om.base.AInt64;
-import edu.uci.ics.asterix.om.base.AMutableDouble;
-import edu.uci.ics.asterix.om.base.AMutableInt64;
-import edu.uci.ics.asterix.om.base.ANull;
import edu.uci.ics.asterix.om.functions.AsterixBuiltinFunctions;
import edu.uci.ics.asterix.om.functions.IFunctionDescriptor;
import edu.uci.ics.asterix.om.functions.IFunctionDescriptorFactory;
-import edu.uci.ics.asterix.om.types.ARecordType;
-import edu.uci.ics.asterix.om.types.ATypeTag;
-import edu.uci.ics.asterix.om.types.AUnionType;
-import edu.uci.ics.asterix.om.types.BuiltinType;
-import edu.uci.ics.asterix.om.types.EnumDeserializer;
-import edu.uci.ics.asterix.om.types.IAType;
import edu.uci.ics.asterix.runtime.aggregates.base.AbstractAggregateFunctionDynamicDescriptor;
-import edu.uci.ics.asterix.runtime.evaluators.common.AccessibleByteArrayEval;
-import edu.uci.ics.asterix.runtime.evaluators.common.ClosedRecordConstructorEvalFactory.ClosedRecordConstructorEval;
import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException;
import edu.uci.ics.hyracks.algebricks.core.algebra.functions.FunctionIdentifier;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyAggregateFunction;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyAggregateFunctionFactory;
-import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluator;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluatorFactory;
-import edu.uci.ics.hyracks.api.dataflow.value.ISerializerDeserializer;
import edu.uci.ics.hyracks.data.std.api.IDataOutputProvider;
-import edu.uci.ics.hyracks.data.std.util.ArrayBackedValueStorage;
-import edu.uci.ics.hyracks.data.std.util.ByteArrayAccessibleOutputStream;
-import edu.uci.ics.hyracks.dataflow.common.data.accessors.IFrameTupleReference;
public class GlobalAvgAggregateDescriptor extends AbstractAggregateFunctionDynamicDescriptor {
@@ -73,142 +44,13 @@
@Override
public ICopyAggregateFunctionFactory createAggregateFunctionFactory(final ICopyEvaluatorFactory[] args)
throws AlgebricksException {
- List<IAType> unionList = new ArrayList<IAType>();
- unionList.add(BuiltinType.ANULL);
- unionList.add(BuiltinType.ADOUBLE);
- ARecordType tmpRecType;
- try {
- tmpRecType = new ARecordType(null, new String[] { "sum", "count" }, new IAType[] {
- new AUnionType(unionList, "OptionalDouble"), BuiltinType.AINT64 }, false);
- } catch (AsterixException e) {
- throw new AlgebricksException(e);
- }
-
- final ARecordType recType = tmpRecType;
-
return new ICopyAggregateFunctionFactory() {
private static final long serialVersionUID = 1L;
@Override
public ICopyAggregateFunction createAggregateFunction(final IDataOutputProvider provider)
throws AlgebricksException {
-
- return new ICopyAggregateFunction() {
-
- private DataOutput out = provider.getDataOutput();
- private ArrayBackedValueStorage inputVal = new ArrayBackedValueStorage();
- private ICopyEvaluator eval = args[0].createEvaluator(inputVal);
- private double globalSum;
- private long globalCount;
- private AMutableDouble aDouble = new AMutableDouble(0);
- private AMutableInt64 aInt64 = new AMutableInt64(0);
-
- private ArrayBackedValueStorage avgBytes = new ArrayBackedValueStorage();
- private ByteArrayAccessibleOutputStream sumBytes = new ByteArrayAccessibleOutputStream();
- private DataOutput sumBytesOutput = new DataOutputStream(sumBytes);
- private ByteArrayAccessibleOutputStream countBytes = new ByteArrayAccessibleOutputStream();
- private DataOutput countBytesOutput = new DataOutputStream(countBytes);
- private ICopyEvaluator evalSum = new AccessibleByteArrayEval(avgBytes.getDataOutput(), sumBytes);
- private ICopyEvaluator evalCount = new AccessibleByteArrayEval(avgBytes.getDataOutput(), countBytes);
- private ClosedRecordConstructorEval recordEval = new ClosedRecordConstructorEval(recType,
- new ICopyEvaluator[] { evalSum, evalCount }, avgBytes, out);
-
- @SuppressWarnings("unchecked")
- private ISerializerDeserializer<AInt64> longSerde = AqlSerializerDeserializerProvider.INSTANCE
- .getSerializerDeserializer(BuiltinType.AINT64);
- @SuppressWarnings("unchecked")
- private ISerializerDeserializer<ADouble> doubleSerde = AqlSerializerDeserializerProvider.INSTANCE
- .getSerializerDeserializer(BuiltinType.ADOUBLE);
- @SuppressWarnings("unchecked")
- private ISerializerDeserializer<ANull> nullSerde = AqlSerializerDeserializerProvider.INSTANCE
- .getSerializerDeserializer(BuiltinType.ANULL);
- private boolean metNull;
-
- @Override
- public void init() {
- globalSum = 0.0;
- globalCount = 0;
- metNull = false;
- }
-
- @Override
- public void step(IFrameTupleReference tuple) throws AlgebricksException {
- inputVal.reset();
- eval.evaluate(tuple);
- byte[] serBytes = inputVal.getByteArray();
- ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(serBytes[0]);
- switch (typeTag) {
- case NULL: {
- metNull = true;
- break;
- }
- case SYSTEM_NULL: {
- // Ignore and return.
- return;
- }
- case RECORD: {
- // Expected.
- break;
- }
- default: {
- throw new AlgebricksException("Global-Avg is not defined for values of type "
- + EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(serBytes[0]));
- }
- }
-
- // The record length helps us determine whether the input record fields are nullable.
- int recordLength = ARecordSerializerDeserializer.getRecordLength(serBytes, 1);
- int nullBitmapSize = 1;
- if (recordLength == 29) {
- nullBitmapSize = 0;
- }
- int offset1 = ARecordSerializerDeserializer.getFieldOffsetById(serBytes, 0, nullBitmapSize,
- false);
- if (offset1 == 0) // the sum is null
- metNull = true;
- else
- globalSum += ADoubleSerializerDeserializer.getDouble(serBytes, offset1);
- int offset2 = ARecordSerializerDeserializer.getFieldOffsetById(serBytes, 1, nullBitmapSize,
- false);
- if (offset2 != 0) // the count is not null
- globalCount += AInt64SerializerDeserializer.getLong(serBytes, offset2);
-
- }
-
- @Override
- public void finish() throws AlgebricksException {
- try {
- if (globalCount == 0 || metNull)
- nullSerde.serialize(ANull.NULL, out);
- else {
- aDouble.setValue(globalSum / globalCount);
- doubleSerde.serialize(aDouble, out);
- }
- } catch (IOException e) {
- throw new AlgebricksException(e);
- }
- }
-
- @Override
- public void finishPartial() throws AlgebricksException {
- try {
- if (metNull || globalCount == 0) {
- sumBytes.reset();
- nullSerde.serialize(ANull.NULL, sumBytesOutput);
- } else {
- sumBytes.reset();
- aDouble.setValue(globalSum);
- doubleSerde.serialize(aDouble, sumBytesOutput);
- }
- countBytes.reset();
- aInt64.setValue(globalCount);
- longSerde.serialize(aInt64, countBytesOutput);
- recordEval.evaluate(null);
- } catch (IOException e) {
- throw new AlgebricksException(e);
- }
- }
- };
+ return new GlobalAvgAggregateFunction(args, provider);
}
};
}
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/LocalAvgAggregateDescriptor.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/LocalAvgAggregateDescriptor.java
index 09a659c..fd98764 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/LocalAvgAggregateDescriptor.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/aggregates/std/LocalAvgAggregateDescriptor.java
@@ -15,50 +15,16 @@
package edu.uci.ics.asterix.runtime.aggregates.std;
-import java.io.DataOutput;
-import java.io.DataOutputStream;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
-import edu.uci.ics.asterix.common.exceptions.AsterixException;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.ADoubleSerializerDeserializer;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AFloatSerializerDeserializer;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt16SerializerDeserializer;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt32SerializerDeserializer;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt64SerializerDeserializer;
-import edu.uci.ics.asterix.dataflow.data.nontagged.serde.AInt8SerializerDeserializer;
-import edu.uci.ics.asterix.formats.nontagged.AqlSerializerDeserializerProvider;
-import edu.uci.ics.asterix.om.base.ADouble;
-import edu.uci.ics.asterix.om.base.AInt64;
-import edu.uci.ics.asterix.om.base.AMutableDouble;
-import edu.uci.ics.asterix.om.base.AMutableInt64;
-import edu.uci.ics.asterix.om.base.ANull;
import edu.uci.ics.asterix.om.functions.AsterixBuiltinFunctions;
import edu.uci.ics.asterix.om.functions.IFunctionDescriptor;
import edu.uci.ics.asterix.om.functions.IFunctionDescriptorFactory;
-import edu.uci.ics.asterix.om.types.ARecordType;
-import edu.uci.ics.asterix.om.types.ATypeTag;
-import edu.uci.ics.asterix.om.types.AUnionType;
-import edu.uci.ics.asterix.om.types.BuiltinType;
-import edu.uci.ics.asterix.om.types.EnumDeserializer;
-import edu.uci.ics.asterix.om.types.IAType;
-import edu.uci.ics.asterix.om.types.hierachy.ATypeHierarchy;
import edu.uci.ics.asterix.runtime.aggregates.base.AbstractAggregateFunctionDynamicDescriptor;
-import edu.uci.ics.asterix.runtime.evaluators.common.AccessibleByteArrayEval;
-import edu.uci.ics.asterix.runtime.evaluators.common.ClosedRecordConstructorEvalFactory.ClosedRecordConstructorEval;
import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException;
-import edu.uci.ics.hyracks.algebricks.common.exceptions.NotImplementedException;
import edu.uci.ics.hyracks.algebricks.core.algebra.functions.FunctionIdentifier;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyAggregateFunction;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyAggregateFunctionFactory;
-import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluator;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluatorFactory;
-import edu.uci.ics.hyracks.api.dataflow.value.ISerializerDeserializer;
import edu.uci.ics.hyracks.data.std.api.IDataOutputProvider;
-import edu.uci.ics.hyracks.data.std.util.ArrayBackedValueStorage;
-import edu.uci.ics.hyracks.data.std.util.ByteArrayAccessibleOutputStream;
-import edu.uci.ics.hyracks.dataflow.common.data.accessors.IFrameTupleReference;
public class LocalAvgAggregateDescriptor extends AbstractAggregateFunctionDynamicDescriptor {
@@ -77,157 +43,13 @@
@Override
public ICopyAggregateFunctionFactory createAggregateFunctionFactory(final ICopyEvaluatorFactory[] args)
throws AlgebricksException {
-
return new ICopyAggregateFunctionFactory() {
private static final long serialVersionUID = 1L;
@Override
public ICopyAggregateFunction createAggregateFunction(final IDataOutputProvider provider)
throws AlgebricksException {
-
- List<IAType> unionList = new ArrayList<IAType>();
- unionList.add(BuiltinType.ANULL);
- unionList.add(BuiltinType.ADOUBLE);
- ARecordType tmpRecType;
- try {
- tmpRecType = new ARecordType(null, new String[] { "sum", "count" }, new IAType[] {
- new AUnionType(unionList, "OptionalDouble"), BuiltinType.AINT64 }, false);
- } catch (AsterixException e) {
- throw new AlgebricksException(e);
- }
-
- final ARecordType recType = tmpRecType;
-
- return new ICopyAggregateFunction() {
-
- private DataOutput out = provider.getDataOutput();
- private ArrayBackedValueStorage inputVal = new ArrayBackedValueStorage();
- private ICopyEvaluator eval = args[0].createEvaluator(inputVal);
- private ATypeTag aggType;
- private double sum;
- private long count;
-
- private ArrayBackedValueStorage avgBytes = new ArrayBackedValueStorage();
- private ByteArrayAccessibleOutputStream sumBytes = new ByteArrayAccessibleOutputStream();
- private DataOutput sumBytesOutput = new DataOutputStream(sumBytes);
- private ByteArrayAccessibleOutputStream countBytes = new ByteArrayAccessibleOutputStream();
- private DataOutput countBytesOutput = new DataOutputStream(countBytes);
- private ICopyEvaluator evalSum = new AccessibleByteArrayEval(avgBytes.getDataOutput(), sumBytes);
- private ICopyEvaluator evalCount = new AccessibleByteArrayEval(avgBytes.getDataOutput(), countBytes);
- private ClosedRecordConstructorEval recordEval = new ClosedRecordConstructorEval(recType,
- new ICopyEvaluator[] { evalSum, evalCount }, avgBytes, out);
- @SuppressWarnings("unchecked")
- private ISerializerDeserializer<ADouble> doubleSerde = AqlSerializerDeserializerProvider.INSTANCE
- .getSerializerDeserializer(BuiltinType.ADOUBLE);
- @SuppressWarnings("unchecked")
- private ISerializerDeserializer<AInt64> longSerde = AqlSerializerDeserializerProvider.INSTANCE
- .getSerializerDeserializer(BuiltinType.AINT64);
- @SuppressWarnings("unchecked")
- private ISerializerDeserializer<ANull> nullSerde = AqlSerializerDeserializerProvider.INSTANCE
- .getSerializerDeserializer(BuiltinType.ANULL);
- private AMutableDouble aDouble = new AMutableDouble(0);
- private AMutableInt64 aInt64 = new AMutableInt64(0);
-
- @Override
- public void init() {
- aggType = ATypeTag.SYSTEM_NULL;
- sum = 0.0;
- count = 0;
- }
-
- @Override
- public void step(IFrameTupleReference tuple) throws AlgebricksException {
- inputVal.reset();
- eval.evaluate(tuple);
- ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER
- .deserialize(inputVal.getByteArray()[0]);
- if (typeTag == ATypeTag.NULL || aggType == ATypeTag.NULL) {
- aggType = ATypeTag.NULL;
- return;
- } else if (aggType == ATypeTag.SYSTEM_NULL) {
- aggType = typeTag;
- } else if (typeTag != ATypeTag.SYSTEM_NULL && !ATypeHierarchy.isCompatible(typeTag, aggType)) {
- throw new AlgebricksException("Unexpected type " + typeTag
- + " in aggregation input stream. Expected type " + aggType + ".");
- } else if (ATypeHierarchy.canPromote(aggType, typeTag)) {
- aggType = typeTag;
- }
-
- if (typeTag != ATypeTag.SYSTEM_NULL) {
- ++count;
- }
-
- switch (typeTag) {
- case INT8: {
- byte val = AInt8SerializerDeserializer.getByte(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case INT16: {
- short val = AInt16SerializerDeserializer.getShort(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case INT32: {
- int val = AInt32SerializerDeserializer.getInt(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case INT64: {
- long val = AInt64SerializerDeserializer.getLong(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case FLOAT: {
- float val = AFloatSerializerDeserializer.getFloat(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case DOUBLE: {
- double val = ADoubleSerializerDeserializer.getDouble(inputVal.getByteArray(), 1);
- sum += val;
- break;
- }
- case NULL: {
- break;
- }
- default: {
- throw new NotImplementedException("Cannot compute LOCAL-AVG for values of type "
- + typeTag);
- }
- }
- inputVal.reset();
- }
-
- @Override
- public void finish() throws AlgebricksException {
- try {
- if (count == 0 && aggType != ATypeTag.NULL) {
- out.writeByte(ATypeTag.SYSTEM_NULL.serialize());
- return;
- }
- if (aggType == ATypeTag.NULL) {
- sumBytes.reset();
- nullSerde.serialize(ANull.NULL, sumBytesOutput);
- } else {
- sumBytes.reset();
- aDouble.setValue(sum);
- doubleSerde.serialize(aDouble, sumBytesOutput);
- }
- countBytes.reset();
- aInt64.setValue(count);
- longSerde.serialize(aInt64, countBytesOutput);
- recordEval.evaluate(null);
- } catch (IOException e) {
- throw new AlgebricksException(e);
- }
- }
-
- @Override
- public void finishPartial() throws AlgebricksException {
- finish();
- }
- };
+ return new LocalAvgAggregateFunction(args, provider);
}
};
}