[ASTERIXDB-2843][COMP] Fix type computer for scalar aggregates
- user model changes: no
- storage format changes: no
- interface changes: no
Details:
- Align type computation for scalar aggregates with
regular aggregates
- Add testcase to verify it for all aggregate functions
Change-Id: Iddd8075b490c83cb6f493d02b7bea1eedb4a4129
Reviewed-on: https://asterix-gerrit.ics.uci.edu/c/asterixdb/+/10483
Integration-Tests: Jenkins <jenkins@fulliautomatix.ics.uci.edu>
Tested-by: Jenkins <jenkins@fulliautomatix.ics.uci.edu>
Reviewed-by: Dmitry Lychagin <dmitry.lychagin@couchbase.com>
diff --git a/asterixdb/asterix-app/src/test/resources/runtimets/queries_sqlpp/aggregate-sql/sum/scalar_sum_type/scalar_sum_type.1.query.sqlpp b/asterixdb/asterix-app/src/test/resources/runtimets/queries_sqlpp/aggregate-sql/sum/scalar_sum_type/scalar_sum_type.1.query.sqlpp
new file mode 100644
index 0000000..932661c9
--- /dev/null
+++ b/asterixdb/asterix-app/src/test/resources/runtimets/queries_sqlpp/aggregate-sql/sum/scalar_sum_type/scalar_sum_type.1.query.sqlpp
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * Test that scalar sum() produces correct output type
+ */
+
+select array_sum(array_reverse(lst))
+let lst = (
+ from range(1, 3) r
+ select value int32(to_string(r))
+)
\ No newline at end of file
diff --git a/asterixdb/asterix-app/src/test/resources/runtimets/queries_sqlpp/aggregate/sum/scalar_sum_type/scalar_sum_type.1.query.sqlpp b/asterixdb/asterix-app/src/test/resources/runtimets/queries_sqlpp/aggregate/sum/scalar_sum_type/scalar_sum_type.1.query.sqlpp
new file mode 100644
index 0000000..361a59b
--- /dev/null
+++ b/asterixdb/asterix-app/src/test/resources/runtimets/queries_sqlpp/aggregate/sum/scalar_sum_type/scalar_sum_type.1.query.sqlpp
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * Test that scalar sum() produces correct output type
+ */
+
+select strict_sum(array_reverse(lst))
+let lst = (
+ from range(1, 3) r
+ select value int32(to_string(r))
+)
\ No newline at end of file
diff --git a/asterixdb/asterix-app/src/test/resources/runtimets/results/aggregate-sql/sum/scalar_sum_type/scalar_sum_type.1.adm b/asterixdb/asterix-app/src/test/resources/runtimets/results/aggregate-sql/sum/scalar_sum_type/scalar_sum_type.1.adm
new file mode 100644
index 0000000..d9b1127
--- /dev/null
+++ b/asterixdb/asterix-app/src/test/resources/runtimets/results/aggregate-sql/sum/scalar_sum_type/scalar_sum_type.1.adm
@@ -0,0 +1 @@
+{ "$1": 6 }
\ No newline at end of file
diff --git a/asterixdb/asterix-app/src/test/resources/runtimets/results/aggregate/sum/scalar_sum_type/scalar_sum_type.1.adm b/asterixdb/asterix-app/src/test/resources/runtimets/results/aggregate/sum/scalar_sum_type/scalar_sum_type.1.adm
new file mode 100644
index 0000000..d9b1127
--- /dev/null
+++ b/asterixdb/asterix-app/src/test/resources/runtimets/results/aggregate/sum/scalar_sum_type/scalar_sum_type.1.adm
@@ -0,0 +1 @@
+{ "$1": 6 }
\ No newline at end of file
diff --git a/asterixdb/asterix-app/src/test/resources/runtimets/testsuite_sqlpp.xml b/asterixdb/asterix-app/src/test/resources/runtimets/testsuite_sqlpp.xml
index f8164e8..0c3b15b 100644
--- a/asterixdb/asterix-app/src/test/resources/runtimets/testsuite_sqlpp.xml
+++ b/asterixdb/asterix-app/src/test/resources/runtimets/testsuite_sqlpp.xml
@@ -834,6 +834,11 @@
</compilation-unit>
</test-case>
<test-case FilePath="aggregate">
+ <compilation-unit name="sum/scalar_sum_type">
+ <output-dir compare="Text">sum/scalar_sum_type</output-dir>
+ </compilation-unit>
+ </test-case>
+ <test-case FilePath="aggregate">
<compilation-unit name="scalar_var">
<output-dir compare="Text">scalar_var</output-dir>
</compilation-unit>
@@ -2097,6 +2102,11 @@
</compilation-unit>
</test-case>
<test-case FilePath="aggregate-sql">
+ <compilation-unit name="sum/scalar_sum_type">
+ <output-dir compare="Text">sum/scalar_sum_type</output-dir>
+ </compilation-unit>
+ </test-case>
+ <test-case FilePath="aggregate-sql">
<compilation-unit name="scalar_var">
<output-dir compare="Text">scalar_var</output-dir>
</compilation-unit>
diff --git a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/functions/BuiltinFunctions.java b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/functions/BuiltinFunctions.java
index 268c1f6..63573a1 100644
--- a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/functions/BuiltinFunctions.java
+++ b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/functions/BuiltinFunctions.java
@@ -1832,6 +1832,11 @@
addFunction(NEGINF_IF, DoubleIfTypeComputer.INSTANCE, true);
// Aggregate Functions
+ ScalarVersionOfAggregateResultType scalarNumericSumTypeComputer =
+ new ScalarVersionOfAggregateResultType(NumericSumAggTypeComputer.INSTANCE);
+ ScalarVersionOfAggregateResultType scalarMinMaxTypeComputer =
+ new ScalarVersionOfAggregateResultType(MinMaxAggTypeComputer.INSTANCE);
+
addPrivateFunction(LISTIFY, OrderedListConstructorTypeComputer.INSTANCE, true);
addFunction(SCALAR_ARRAYAGG, ScalarArrayAggTypeComputer.INSTANCE, true);
addFunction(MAX, MinMaxAggTypeComputer.INSTANCE, true);
@@ -1877,7 +1882,7 @@
// SUM
addFunction(SUM, NumericSumAggTypeComputer.INSTANCE, true);
- addFunction(SCALAR_SUM, ScalarVersionOfAggregateResultType.INSTANCE, true);
+ addFunction(SCALAR_SUM, scalarNumericSumTypeComputer, true);
addPrivateFunction(LOCAL_SUM, NumericSumAggTypeComputer.INSTANCE, true);
addPrivateFunction(INTERMEDIATE_SUM, NumericSumAggTypeComputer.INSTANCE, true);
addPrivateFunction(GLOBAL_SUM, NumericSumAggTypeComputer.INSTANCE, true);
@@ -1893,8 +1898,8 @@
addPrivateFunction(SERIAL_INTERMEDIATE_SQL_AVG, LocalAvgTypeComputer.INSTANCE, true);
addFunction(SCALAR_AVG, NullableDoubleTypeComputer.INSTANCE, true);
addFunction(SCALAR_COUNT, AInt64TypeComputer.INSTANCE, true);
- addFunction(SCALAR_MAX, ScalarVersionOfAggregateResultType.INSTANCE, true);
- addFunction(SCALAR_MIN, ScalarVersionOfAggregateResultType.INSTANCE, true);
+ addFunction(SCALAR_MAX, scalarMinMaxTypeComputer, true);
+ addFunction(SCALAR_MIN, scalarMinMaxTypeComputer, true);
addPrivateFunction(INTERMEDIATE_AVG, LocalAvgTypeComputer.INSTANCE, true);
addFunction(SCALAR_STDDEV_SAMP, NullableDoubleTypeComputer.INSTANCE, true);
addPrivateFunction(INTERMEDIATE_STDDEV_SAMP, LocalSingleVarStatisticsTypeComputer.INSTANCE, true);
@@ -1935,7 +1940,7 @@
// SQL SUM
addFunction(SQL_SUM, NumericSumAggTypeComputer.INSTANCE, true);
- addFunction(SCALAR_SQL_SUM, ScalarVersionOfAggregateResultType.INSTANCE, true);
+ addFunction(SCALAR_SQL_SUM, scalarNumericSumTypeComputer, true);
addPrivateFunction(LOCAL_SQL_SUM, NumericSumAggTypeComputer.INSTANCE, true);
addPrivateFunction(INTERMEDIATE_SQL_SUM, NumericSumAggTypeComputer.INSTANCE, true);
addPrivateFunction(GLOBAL_SQL_SUM, NumericSumAggTypeComputer.INSTANCE, true);
@@ -1959,8 +1964,8 @@
addPrivateFunction(GLOBAL_SQL_MIN, MinMaxAggTypeComputer.INSTANCE, true);
addFunction(SCALAR_SQL_AVG, NullableDoubleTypeComputer.INSTANCE, true);
addFunction(SCALAR_SQL_COUNT, AInt64TypeComputer.INSTANCE, true);
- addFunction(SCALAR_SQL_MAX, ScalarVersionOfAggregateResultType.INSTANCE, true);
- addFunction(SCALAR_SQL_MIN, ScalarVersionOfAggregateResultType.INSTANCE, true);
+ addFunction(SCALAR_SQL_MAX, scalarMinMaxTypeComputer, true);
+ addFunction(SCALAR_SQL_MIN, scalarMinMaxTypeComputer, true);
addPrivateFunction(INTERMEDIATE_SQL_AVG, LocalAvgTypeComputer.INSTANCE, true);
addFunction(SQL_STDDEV_SAMP, NullableDoubleTypeComputer.INSTANCE, true);
addPrivateFunction(GLOBAL_SQL_STDDEV_SAMP, NullableDoubleTypeComputer.INSTANCE, true);
@@ -2035,9 +2040,9 @@
addFunction(SCALAR_SQL_COUNT_DISTINCT, AInt64TypeComputer.INSTANCE, true);
addFunction(SUM_DISTINCT, NumericSumAggTypeComputer.INSTANCE, true);
- addFunction(SCALAR_SUM_DISTINCT, ScalarVersionOfAggregateResultType.INSTANCE, true);
+ addFunction(SCALAR_SUM_DISTINCT, scalarNumericSumTypeComputer, true);
addFunction(SQL_SUM_DISTINCT, NumericSumAggTypeComputer.INSTANCE, true);
- addFunction(SCALAR_SQL_SUM_DISTINCT, ScalarVersionOfAggregateResultType.INSTANCE, true);
+ addFunction(SCALAR_SQL_SUM_DISTINCT, scalarNumericSumTypeComputer, true);
addFunction(AVG_DISTINCT, NullableDoubleTypeComputer.INSTANCE, true);
addFunction(SCALAR_AVG_DISTINCT, NullableDoubleTypeComputer.INSTANCE, true);
@@ -2045,14 +2050,14 @@
addFunction(SCALAR_SQL_AVG_DISTINCT, NullableDoubleTypeComputer.INSTANCE, true);
addFunction(MAX_DISTINCT, MinMaxAggTypeComputer.INSTANCE, true);
- addFunction(SCALAR_MAX_DISTINCT, ScalarVersionOfAggregateResultType.INSTANCE, true);
+ addFunction(SCALAR_MAX_DISTINCT, scalarMinMaxTypeComputer, true);
addFunction(SQL_MAX_DISTINCT, MinMaxAggTypeComputer.INSTANCE, true);
- addFunction(SCALAR_SQL_MAX_DISTINCT, ScalarVersionOfAggregateResultType.INSTANCE, true);
+ addFunction(SCALAR_SQL_MAX_DISTINCT, scalarMinMaxTypeComputer, true);
addFunction(MIN_DISTINCT, MinMaxAggTypeComputer.INSTANCE, true);
- addFunction(SCALAR_MIN_DISTINCT, ScalarVersionOfAggregateResultType.INSTANCE, true);
+ addFunction(SCALAR_MIN_DISTINCT, scalarMinMaxTypeComputer, true);
addFunction(SQL_MIN_DISTINCT, MinMaxAggTypeComputer.INSTANCE, true);
- addFunction(SCALAR_SQL_MIN_DISTINCT, ScalarVersionOfAggregateResultType.INSTANCE, true);
+ addFunction(SCALAR_SQL_MIN_DISTINCT, scalarMinMaxTypeComputer, true);
addFunction(STDDEV_SAMP_DISTINCT, NullableDoubleTypeComputer.INSTANCE, true);
addFunction(SCALAR_STDDEV_SAMP_DISTINCT, NullableDoubleTypeComputer.INSTANCE, true);
diff --git a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/AggregateResultTypeComputer.java b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/AggregateResultTypeComputer.java
new file mode 100644
index 0000000..8e663a7
--- /dev/null
+++ b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/AggregateResultTypeComputer.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.asterix.om.typecomputer.impl;
+
+import org.apache.asterix.om.typecomputer.base.AbstractResultTypeComputer;
+import org.apache.asterix.om.types.IAType;
+import org.apache.hyracks.algebricks.common.exceptions.AlgebricksException;
+import org.apache.hyracks.algebricks.core.algebra.base.ILogicalExpression;
+import org.apache.hyracks.algebricks.core.algebra.functions.FunctionIdentifier;
+import org.apache.hyracks.api.exceptions.SourceLocation;
+
+public abstract class AggregateResultTypeComputer extends AbstractResultTypeComputer {
+ @Override
+ protected void checkArgType(FunctionIdentifier funcId, int argIndex, IAType type, SourceLocation sourceLoc)
+ throws AlgebricksException {
+ super.checkArgType(funcId, argIndex, type, sourceLoc);
+ }
+
+ @Override
+ protected abstract IAType getResultType(ILogicalExpression expr, IAType... strippedInputTypes)
+ throws AlgebricksException;
+}
\ No newline at end of file
diff --git a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/MinMaxAggTypeComputer.java b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/MinMaxAggTypeComputer.java
index c34b5ed..fc1eee5 100644
--- a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/MinMaxAggTypeComputer.java
+++ b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/MinMaxAggTypeComputer.java
@@ -19,7 +19,6 @@
package org.apache.asterix.om.typecomputer.impl;
import org.apache.asterix.dataflow.data.common.ILogicalBinaryComparator;
-import org.apache.asterix.om.typecomputer.base.AbstractResultTypeComputer;
import org.apache.asterix.om.types.ATypeTag;
import org.apache.asterix.om.types.AUnionType;
import org.apache.asterix.om.types.BuiltinType;
@@ -27,7 +26,7 @@
import org.apache.hyracks.algebricks.common.exceptions.AlgebricksException;
import org.apache.hyracks.algebricks.core.algebra.base.ILogicalExpression;
-public class MinMaxAggTypeComputer extends AbstractResultTypeComputer {
+public class MinMaxAggTypeComputer extends AggregateResultTypeComputer {
public static final MinMaxAggTypeComputer INSTANCE = new MinMaxAggTypeComputer();
diff --git a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/NumericSumAggTypeComputer.java b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/NumericSumAggTypeComputer.java
index 1c67e56..a4b5e34 100644
--- a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/NumericSumAggTypeComputer.java
+++ b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/NumericSumAggTypeComputer.java
@@ -18,42 +18,20 @@
*/
package org.apache.asterix.om.typecomputer.impl;
-import org.apache.asterix.om.exceptions.UnsupportedTypeException;
-import org.apache.asterix.om.typecomputer.base.AbstractResultTypeComputer;
import org.apache.asterix.om.types.ATypeTag;
import org.apache.asterix.om.types.AUnionType;
import org.apache.asterix.om.types.BuiltinType;
import org.apache.asterix.om.types.IAType;
import org.apache.hyracks.algebricks.common.exceptions.AlgebricksException;
import org.apache.hyracks.algebricks.core.algebra.base.ILogicalExpression;
-import org.apache.hyracks.algebricks.core.algebra.functions.FunctionIdentifier;
-import org.apache.hyracks.api.exceptions.SourceLocation;
-public class NumericSumAggTypeComputer extends AbstractResultTypeComputer {
+public class NumericSumAggTypeComputer extends AggregateResultTypeComputer {
public static final NumericSumAggTypeComputer INSTANCE = new NumericSumAggTypeComputer();
private NumericSumAggTypeComputer() {
}
@Override
- protected void checkArgType(FunctionIdentifier funcId, int argIndex, IAType type, SourceLocation sourceLoc)
- throws AlgebricksException {
- ATypeTag tag = type.getTypeTag();
- switch (tag) {
- case DOUBLE:
- case FLOAT:
- case BIGINT:
- case INTEGER:
- case SMALLINT:
- case TINYINT:
- case ANY:
- break;
- default:
- throw new UnsupportedTypeException(sourceLoc, funcId, tag);
- }
- }
-
- @Override
protected IAType getResultType(ILogicalExpression expr, IAType... strippedInputTypes) throws AlgebricksException {
ATypeTag tag = strippedInputTypes[0].getTypeTag();
switch (tag) {
@@ -61,15 +39,12 @@
case SMALLINT:
case INTEGER:
case BIGINT:
- IAType int64Type = BuiltinType.AINT64;
- return AUnionType.createNullableType(int64Type, "AggResult");
+ return AUnionType.createNullableType(BuiltinType.AINT64);
case FLOAT:
case DOUBLE:
- IAType doubleType = BuiltinType.ADOUBLE;
- return AUnionType.createNullableType(doubleType, "AggResult");
+ return AUnionType.createNullableType(BuiltinType.ADOUBLE);
case ANY:
- IAType anyType = strippedInputTypes[0];
- return AUnionType.createNullableType(anyType, "AggResult");
+ return BuiltinType.ANY;
default:
// All other possible cases.
return BuiltinType.ANULL;
diff --git a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/ScalarVersionOfAggregateResultType.java b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/ScalarVersionOfAggregateResultType.java
index 5b90974..fda0285 100644
--- a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/ScalarVersionOfAggregateResultType.java
+++ b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/ScalarVersionOfAggregateResultType.java
@@ -18,9 +18,7 @@
*/
package org.apache.asterix.om.typecomputer.impl;
-import org.apache.asterix.om.exceptions.TypeMismatchException;
import org.apache.asterix.om.typecomputer.base.AbstractResultTypeComputer;
-import org.apache.asterix.om.types.ATypeTag;
import org.apache.asterix.om.types.AUnionType;
import org.apache.asterix.om.types.AbstractCollectionType;
import org.apache.asterix.om.types.BuiltinType;
@@ -32,32 +30,48 @@
public class ScalarVersionOfAggregateResultType extends AbstractResultTypeComputer {
- public static final ScalarVersionOfAggregateResultType INSTANCE = new ScalarVersionOfAggregateResultType();
+ private final AggregateResultTypeComputer aggResultTypeComputer;
- private ScalarVersionOfAggregateResultType() {
+ public ScalarVersionOfAggregateResultType(AggregateResultTypeComputer aggResultTypeComputer) {
+ this.aggResultTypeComputer = aggResultTypeComputer;
}
@Override
- public void checkArgType(FunctionIdentifier funcId, int argIndex, IAType type, SourceLocation sourceLoc)
+ protected void checkArgType(FunctionIdentifier funcId, int argIndex, IAType argType, SourceLocation sourceLoc)
throws AlgebricksException {
- ATypeTag tag = type.getTypeTag();
- if (tag != ATypeTag.ANY && tag != ATypeTag.ARRAY && tag != ATypeTag.MULTISET) {
- throw new TypeMismatchException(sourceLoc, funcId, argIndex, tag, ATypeTag.ARRAY, ATypeTag.MULTISET);
+ if (argIndex == 0) {
+ switch (argType.getTypeTag()) {
+ case ARRAY:
+ case MULTISET:
+ AbstractCollectionType act = (AbstractCollectionType) argType;
+ aggResultTypeComputer.checkArgType(funcId, argIndex, act.getItemType(), sourceLoc);
+ break;
+ }
}
}
@Override
protected IAType getResultType(ILogicalExpression expr, IAType... strippedInputTypes) throws AlgebricksException {
- ATypeTag tag = strippedInputTypes[0].getTypeTag();
- if (tag == ATypeTag.ANY) {
- return BuiltinType.ANY;
+ IAType argType = strippedInputTypes[0];
+ switch (argType.getTypeTag()) {
+ case ARRAY:
+ case MULTISET:
+ AbstractCollectionType act = (AbstractCollectionType) argType;
+ IAType[] strippedInputTypes2 = strippedInputTypes.clone();
+ strippedInputTypes2[0] = act.getItemType();
+ IAType resultType = aggResultTypeComputer.getResultType(expr, strippedInputTypes2);
+ switch (resultType.getTypeTag()) {
+ case NULL:
+ case MISSING:
+ case ANY:
+ return resultType;
+ case UNION:
+ return AUnionType.createUnknownableType(((AUnionType) resultType).getActualType());
+ default:
+ return AUnionType.createUnknownableType(resultType);
+ }
+ default:
+ return BuiltinType.ANY;
}
- if (tag != ATypeTag.ARRAY && tag != ATypeTag.MULTISET) {
- // this condition being true would've thrown an exception above, no?
- return strippedInputTypes[0];
- }
- AbstractCollectionType act = (AbstractCollectionType) strippedInputTypes[0];
- IAType t = act.getItemType();
- return AUnionType.createUnknownableType(t);
}
}
diff --git a/asterixdb/asterix-runtime/src/test/java/org/apache/asterix/runtime/functions/ScalarAggregateTypeComputerTest.java b/asterixdb/asterix-runtime/src/test/java/org/apache/asterix/runtime/functions/ScalarAggregateTypeComputerTest.java
new file mode 100644
index 0000000..cbde36c
--- /dev/null
+++ b/asterixdb/asterix-runtime/src/test/java/org/apache/asterix/runtime/functions/ScalarAggregateTypeComputerTest.java
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.asterix.runtime.functions;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+import org.apache.asterix.dataflow.data.common.ExpressionTypeComputer;
+import org.apache.asterix.om.base.ABoolean;
+import org.apache.asterix.om.base.ADate;
+import org.apache.asterix.om.base.ADateTime;
+import org.apache.asterix.om.base.ADayTimeDuration;
+import org.apache.asterix.om.base.ADouble;
+import org.apache.asterix.om.base.ADuration;
+import org.apache.asterix.om.base.AFloat;
+import org.apache.asterix.om.base.AInt16;
+import org.apache.asterix.om.base.AInt32;
+import org.apache.asterix.om.base.AInt64;
+import org.apache.asterix.om.base.AInt8;
+import org.apache.asterix.om.base.AInterval;
+import org.apache.asterix.om.base.AOrderedList;
+import org.apache.asterix.om.base.ARecord;
+import org.apache.asterix.om.base.AString;
+import org.apache.asterix.om.base.ATime;
+import org.apache.asterix.om.base.AUnorderedList;
+import org.apache.asterix.om.base.AYearMonthDuration;
+import org.apache.asterix.om.base.IAObject;
+import org.apache.asterix.om.constants.AsterixConstantValue;
+import org.apache.asterix.om.exceptions.UnsupportedTypeException;
+import org.apache.asterix.om.functions.BuiltinFunctionInfo;
+import org.apache.asterix.om.functions.BuiltinFunctions;
+import org.apache.asterix.om.functions.IFunctionDescriptorFactory;
+import org.apache.asterix.om.types.AOrderedListType;
+import org.apache.asterix.om.types.ARecordType;
+import org.apache.asterix.om.types.ATypeTag;
+import org.apache.asterix.om.types.AUnionType;
+import org.apache.asterix.om.types.AUnorderedListType;
+import org.apache.asterix.om.types.BuiltinType;
+import org.apache.asterix.om.types.IAType;
+import org.apache.commons.lang3.mutable.MutableObject;
+import org.apache.hyracks.algebricks.common.exceptions.AlgebricksException;
+import org.apache.hyracks.algebricks.core.algebra.base.ILogicalExpression;
+import org.apache.hyracks.algebricks.core.algebra.base.LogicalVariable;
+import org.apache.hyracks.algebricks.core.algebra.expressions.AbstractFunctionCallExpression;
+import org.apache.hyracks.algebricks.core.algebra.expressions.AggregateFunctionCallExpression;
+import org.apache.hyracks.algebricks.core.algebra.expressions.ConstantExpression;
+import org.apache.hyracks.algebricks.core.algebra.expressions.IVariableTypeEnvironment;
+import org.apache.hyracks.algebricks.core.algebra.expressions.ScalarFunctionCallExpression;
+import org.apache.hyracks.algebricks.core.algebra.functions.FunctionIdentifier;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+/**
+ * Test alignment of type computers between aggregate functions and their scalar versions
+ */
+@RunWith(Parameterized.class)
+public class ScalarAggregateTypeComputerTest {
+
+ private static final IAObject[] ITEMS = {
+ //
+ ABoolean.TRUE,
+ //
+ new AInt8((byte) 0),
+ //
+ new AInt16((short) 0),
+ //
+ new AInt32(0),
+ //
+ new AInt64(0),
+ //
+ new AFloat(0),
+ //
+ new ADouble(0),
+ //
+ new AString(""),
+ //
+ new ADate(0),
+ //
+ new ADateTime(0),
+ //
+ new ATime(0),
+ //
+ new ADuration(0, 0),
+ //
+ new AYearMonthDuration(0),
+ //
+ new ADayTimeDuration(0),
+ //
+ new AInterval(0, 0, ATypeTag.DATETIME.serialize()),
+ //
+ new AOrderedList(AOrderedListType.FULL_OPEN_ORDEREDLIST_TYPE, Collections.singletonList(new AString(""))),
+ //
+ new AUnorderedList(AUnorderedListType.FULLY_OPEN_UNORDEREDLIST_TYPE,
+ Collections.singletonList(new AString(""))),
+ //
+ new ARecord(
+ new ARecordType("record-type", new String[] { "a" }, new IAType[] { BuiltinType.ASTRING }, false),
+ new IAObject[] { new AString("") }) };
+
+ // Test parameters
+ @Parameterized.Parameter
+ public String testName;
+
+ @Parameterized.Parameter(1)
+ public FunctionIdentifier scalarfid;
+
+ @Parameterized.Parameter(2)
+ public FunctionIdentifier aggfid;
+
+ @Parameterized.Parameter(3)
+ public IAObject item;
+
+ @Parameterized.Parameters(name = "ScalarAggregateTypeComputerTest {index}: {0}({3})")
+ public static Collection<Object[]> tests() {
+ List<Object[]> tests = new ArrayList<>();
+
+ FunctionCollection fcoll = FunctionCollection.createDefaultFunctionCollection();
+ for (IFunctionDescriptorFactory fdf : fcoll.getFunctionDescriptorFactories()) {
+ FunctionIdentifier fid = fdf.createFunctionDescriptor().getIdentifier();
+ FunctionIdentifier aggfid = BuiltinFunctions.getAggregateFunction(fid);
+ if (aggfid == null) {
+ continue;
+ }
+ for (IAObject item : ITEMS) {
+ tests.add(new Object[] { fid.getName(), fid, aggfid, item });
+ }
+
+ }
+ return tests;
+ }
+
+ @Test
+ public void test() throws Exception {
+
+ AOrderedListType listType = new AOrderedListType(item.getType(), null);
+ AOrderedList list = new AOrderedList(listType, Collections.singletonList(item));
+ ConstantExpression scalarArgExpr = new ConstantExpression(new AsterixConstantValue(list));
+ BuiltinFunctionInfo scalarfi = BuiltinFunctions.getBuiltinFunctionInfo(scalarfid);
+ ScalarFunctionCallExpression scalarCallExpr =
+ new ScalarFunctionCallExpression(scalarfi, new MutableObject<>(scalarArgExpr));
+ IAType scalarResultType = computeType(scalarCallExpr);
+
+ ConstantExpression aggArgExpr = new ConstantExpression(new AsterixConstantValue(item));
+ BuiltinFunctionInfo aggfi = BuiltinFunctions.getBuiltinFunctionInfo(aggfid);
+ AggregateFunctionCallExpression aggCallExpr = new AggregateFunctionCallExpression(aggfi, false,
+ Collections.singletonList(new MutableObject<>(aggArgExpr)));
+ IAType aggResultType = computeType(aggCallExpr);
+
+ if (!compareResultTypes(scalarResultType, aggResultType)) {
+ Assert.fail(String.format("%s(%s) returns %s != %s(%s) returns %s", scalarfid.getName(), item.getType(),
+ formatResultType(scalarResultType), aggfid.getName(), item.getType(),
+ formatResultType(aggResultType)));
+ }
+ }
+
+ private boolean compareResultTypes(IAType t1, IAType t2) {
+ // null means ERROR
+ if (t1 == null) {
+ // OK if both types are ERROR
+ return t2 == null;
+ } else if (t2 == null) {
+ return false;
+ }
+ boolean t1Union = false, t2Union = false;
+ if (t1.getTypeTag() == ATypeTag.UNION) {
+ t1Union = true;
+ t1 = ((AUnionType) t1).getActualType();
+ }
+ if (t2.getTypeTag() == ATypeTag.UNION) {
+ t2Union = true;
+ t2 = ((AUnionType) t2).getActualType();
+ }
+ return (t1Union == t2Union) && t1.deepEqual(t2);
+ }
+
+ private String formatResultType(IAType t) {
+ return t == null ? "ERROR" : t.toString();
+ }
+
+ private IAType computeType(AbstractFunctionCallExpression callExpr) throws AlgebricksException {
+ try {
+ BuiltinFunctionInfo fi = Objects.requireNonNull((BuiltinFunctionInfo) callExpr.getFunctionInfo());
+ return fi.getResultTypeComputer().computeType(callExpr, EMPTY_TYPE_ENV, null);
+ } catch (UnsupportedTypeException e) {
+ return null;
+ }
+ }
+
+ private static final IVariableTypeEnvironment EMPTY_TYPE_ENV = new IVariableTypeEnvironment() {
+
+ @Override
+ public boolean substituteProducedVariable(LogicalVariable v1, LogicalVariable v2) {
+ throw new IllegalStateException();
+ }
+
+ @Override
+ public void setVarType(LogicalVariable var, Object type) {
+ throw new IllegalStateException();
+ }
+
+ @Override
+ public Object getVarType(LogicalVariable var, List<LogicalVariable> nonNullVariables,
+ List<List<LogicalVariable>> correlatedNullableVariableLists) {
+ throw new IllegalStateException();
+ }
+
+ @Override
+ public Object getVarType(LogicalVariable var) {
+ throw new IllegalStateException();
+ }
+
+ @Override
+ public Object getType(ILogicalExpression expr) throws AlgebricksException {
+ return ExpressionTypeComputer.INSTANCE.getType(expr, null, this);
+ }
+ };
+}