Recursively checking for agg funcs in an assign op.
git-svn-id: https://asterixdb.googlecode.com/svn/branches/asterix_fix_agg@502 eaa15691-b419-025a-1212-ee371bd00084
diff --git a/asterix-algebra/src/main/java/edu/uci/ics/asterix/optimizer/rules/PushAggFuncIntoStandaloneAggregateRule.java b/asterix-algebra/src/main/java/edu/uci/ics/asterix/optimizer/rules/PushAggFuncIntoStandaloneAggregateRule.java
index 5770f4b..0b96934 100644
--- a/asterix-algebra/src/main/java/edu/uci/ics/asterix/optimizer/rules/PushAggFuncIntoStandaloneAggregateRule.java
+++ b/asterix-algebra/src/main/java/edu/uci/ics/asterix/optimizer/rules/PushAggFuncIntoStandaloneAggregateRule.java
@@ -14,6 +14,8 @@
*/
package edu.uci.ics.asterix.optimizer.rules;
+import java.util.Collection;
+import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
@@ -29,6 +31,7 @@
import edu.uci.ics.hyracks.algebricks.core.algebra.base.LogicalOperatorTag;
import edu.uci.ics.hyracks.algebricks.core.algebra.base.LogicalVariable;
import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.AbstractFunctionCallExpression;
+import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.AbstractFunctionCallExpression.FunctionKind;
import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.AggregateFunctionCallExpression;
import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.ConstantExpression;
import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.VariableReferenceExpression;
@@ -54,7 +57,7 @@
@Override
public boolean rewritePost(Mutable<ILogicalOperator> opRef, IOptimizationContext context)
throws AlgebricksException {
- // Pattern to match: assign <-- aggregate.
+ // Pattern to match: assign <-- aggregate <-- !(group-by)
AbstractLogicalOperator op = (AbstractLogicalOperator) opRef.getValue();
if (op.getOperatorTag() != LogicalOperatorTag.ASSIGN) {
return false;
@@ -64,7 +67,13 @@
if (op2.getOperatorTag() != LogicalOperatorTag.AGGREGATE) {
return false;
}
-
+ // If there's a group by below the agg, then we want to have the agg pushed into the group by.
+ Mutable<ILogicalOperator> opRef3 = op2.getInputs().get(0);
+ AbstractLogicalOperator op3 = (AbstractLogicalOperator) opRef3.getValue();
+ if (op3.getOperatorTag() == LogicalOperatorTag.GROUP) {
+ return false;
+ }
+
AssignOperator assignOp = (AssignOperator) op;
AggregateOperator aggOp = (AggregateOperator) op2;
if (aggOp.getVariables().size() != 1) {
@@ -77,38 +86,12 @@
return false;
}
- Mutable<ILogicalExpression> srcAssignExprRef = null;
- FunctionIdentifier aggFuncIdent = null;
- List<Mutable<ILogicalExpression>> assignExprRefs = assignOp.getExpressions();
- for (Mutable<ILogicalExpression> assignExprRef : assignExprRefs) {
- // Continue if assignExprRef is not an aggregate function.
- ILogicalExpression assignExpr = assignExprRef.getValue();
- if (assignExpr.getExpressionTag() != LogicalExpressionTag.FUNCTION_CALL) {
- continue;
- }
- AbstractFunctionCallExpression assignFuncExpr = (AbstractFunctionCallExpression) assignExpr;
- aggFuncIdent = AsterixBuiltinFunctions.getAggregateFunction(assignFuncExpr.getFunctionIdentifier());
- if (aggFuncIdent == null) {
- continue;
- }
- // Make sure this is the expr that uses aggVar.
- List<Mutable<ILogicalExpression>> aggFuncArgRefs = assignFuncExpr.getArguments();
- for (Mutable<ILogicalExpression> aggFuncExprRef : aggFuncArgRefs) {
- ILogicalExpression aggFuncArgExpr = aggFuncExprRef.getValue();
- if (aggFuncArgExpr.getExpressionTag() != LogicalExpressionTag.VARIABLE) {
- continue;
- }
- VariableReferenceExpression varRefExpr = (VariableReferenceExpression) aggFuncArgExpr;
- if (varRefExpr.getVariableReference() != aggVar) {
- continue;
- }
- srcAssignExprRef = assignExprRef;
- break;
- }
- }
+ Mutable<ILogicalExpression> srcAssignExprRef = fingAggFuncExprRef(assignOp.getExpressions(), aggVar);
if (srcAssignExprRef == null) {
- return false;
+ return false;
}
+ AbstractFunctionCallExpression assignFuncExpr = (AbstractFunctionCallExpression) srcAssignExprRef.getValue();
+ FunctionIdentifier aggFuncIdent = AsterixBuiltinFunctions.getAggregateFunction(assignFuncExpr.getFunctionIdentifier());
// Push the agg func into the agg op.
AbstractFunctionCallExpression aggOpExpr = (AbstractFunctionCallExpression) aggOp.getExpressions().get(0).getValue();
@@ -134,4 +117,26 @@
return true;
}
+
+ private Mutable<ILogicalExpression> fingAggFuncExprRef(List<Mutable<ILogicalExpression>> exprRefs, LogicalVariable aggVar) {
+ for (Mutable<ILogicalExpression> exprRef : exprRefs) {
+ ILogicalExpression expr = exprRef.getValue();
+ if (expr.getExpressionTag() != LogicalExpressionTag.FUNCTION_CALL) {
+ continue;
+ }
+ AbstractFunctionCallExpression funcExpr = (AbstractFunctionCallExpression) expr;
+ FunctionIdentifier funcIdent = AsterixBuiltinFunctions.getAggregateFunction(funcExpr.getFunctionIdentifier());
+ if (funcIdent == null) {
+ // Recursively look in func args.
+ return fingAggFuncExprRef(funcExpr.getArguments(), aggVar);
+ }
+ // Check if this is the expr that uses aggVar.
+ Collection<LogicalVariable> usedVars = new HashSet<LogicalVariable>();
+ funcExpr.getUsedVariables(usedVars);
+ if (usedVars.contains(aggVar)) {
+ return exprRef;
+ }
+ }
+ return null;
+ }
}