Recursively checking for agg funcs in an assign op.

git-svn-id: https://asterixdb.googlecode.com/svn/branches/asterix_fix_agg@502 eaa15691-b419-025a-1212-ee371bd00084
diff --git a/asterix-algebra/src/main/java/edu/uci/ics/asterix/optimizer/rules/PushAggFuncIntoStandaloneAggregateRule.java b/asterix-algebra/src/main/java/edu/uci/ics/asterix/optimizer/rules/PushAggFuncIntoStandaloneAggregateRule.java
index 5770f4b..0b96934 100644
--- a/asterix-algebra/src/main/java/edu/uci/ics/asterix/optimizer/rules/PushAggFuncIntoStandaloneAggregateRule.java
+++ b/asterix-algebra/src/main/java/edu/uci/ics/asterix/optimizer/rules/PushAggFuncIntoStandaloneAggregateRule.java
@@ -14,6 +14,8 @@
  */
 package edu.uci.ics.asterix.optimizer.rules;
 
+import java.util.Collection;
+import java.util.HashSet;
 import java.util.LinkedList;
 import java.util.List;
 
@@ -29,6 +31,7 @@
 import edu.uci.ics.hyracks.algebricks.core.algebra.base.LogicalOperatorTag;
 import edu.uci.ics.hyracks.algebricks.core.algebra.base.LogicalVariable;
 import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.AbstractFunctionCallExpression;
+import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.AbstractFunctionCallExpression.FunctionKind;
 import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.AggregateFunctionCallExpression;
 import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.ConstantExpression;
 import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.VariableReferenceExpression;
@@ -54,7 +57,7 @@
     @Override
     public boolean rewritePost(Mutable<ILogicalOperator> opRef, IOptimizationContext context)
             throws AlgebricksException {
-        // Pattern to match: assign <-- aggregate.
+        // Pattern to match: assign <-- aggregate <-- !(group-by)
         AbstractLogicalOperator op = (AbstractLogicalOperator) opRef.getValue();
         if (op.getOperatorTag() != LogicalOperatorTag.ASSIGN) {
             return false;
@@ -64,7 +67,13 @@
         if (op2.getOperatorTag() != LogicalOperatorTag.AGGREGATE) {
             return false;
         }
-
+        // If there's a group by below the agg, then we want to have the agg pushed into the group by.
+        Mutable<ILogicalOperator> opRef3 = op2.getInputs().get(0);
+        AbstractLogicalOperator op3 = (AbstractLogicalOperator) opRef3.getValue();
+        if (op3.getOperatorTag() == LogicalOperatorTag.GROUP) {
+            return false;
+        }
+        
         AssignOperator assignOp = (AssignOperator) op;
         AggregateOperator aggOp = (AggregateOperator) op2;
         if (aggOp.getVariables().size() != 1) {
@@ -77,38 +86,12 @@
             return false;
         }
         
-        Mutable<ILogicalExpression> srcAssignExprRef = null;        
-        FunctionIdentifier aggFuncIdent = null;
-        List<Mutable<ILogicalExpression>> assignExprRefs = assignOp.getExpressions();
-        for (Mutable<ILogicalExpression> assignExprRef : assignExprRefs) {
-            // Continue if assignExprRef is not an aggregate function.
-            ILogicalExpression assignExpr = assignExprRef.getValue();
-            if (assignExpr.getExpressionTag() != LogicalExpressionTag.FUNCTION_CALL) {
-                continue;
-            }
-            AbstractFunctionCallExpression assignFuncExpr = (AbstractFunctionCallExpression) assignExpr;
-            aggFuncIdent = AsterixBuiltinFunctions.getAggregateFunction(assignFuncExpr.getFunctionIdentifier());
-            if (aggFuncIdent == null) {
-                continue;
-            }
-            // Make sure this is the expr that uses aggVar.
-            List<Mutable<ILogicalExpression>> aggFuncArgRefs = assignFuncExpr.getArguments();
-            for (Mutable<ILogicalExpression> aggFuncExprRef : aggFuncArgRefs) {
-                ILogicalExpression aggFuncArgExpr = aggFuncExprRef.getValue();
-                if (aggFuncArgExpr.getExpressionTag() != LogicalExpressionTag.VARIABLE) {
-                    continue;
-                }
-                VariableReferenceExpression varRefExpr = (VariableReferenceExpression) aggFuncArgExpr;
-                if (varRefExpr.getVariableReference() != aggVar) {
-                    continue;
-                }
-                srcAssignExprRef = assignExprRef;
-                break;
-            }
-        }
+        Mutable<ILogicalExpression> srcAssignExprRef = fingAggFuncExprRef(assignOp.getExpressions(), aggVar);
         if (srcAssignExprRef == null) {
-            return false;
+        	return false;
         }
+        AbstractFunctionCallExpression assignFuncExpr = (AbstractFunctionCallExpression) srcAssignExprRef.getValue();
+        FunctionIdentifier aggFuncIdent = AsterixBuiltinFunctions.getAggregateFunction(assignFuncExpr.getFunctionIdentifier());
         
         // Push the agg func into the agg op.                
         AbstractFunctionCallExpression aggOpExpr = (AbstractFunctionCallExpression) aggOp.getExpressions().get(0).getValue();
@@ -134,4 +117,26 @@
         
         return true;
     }
+    
+    private Mutable<ILogicalExpression> fingAggFuncExprRef(List<Mutable<ILogicalExpression>> exprRefs, LogicalVariable aggVar) {
+    	for (Mutable<ILogicalExpression> exprRef : exprRefs) {
+            ILogicalExpression expr = exprRef.getValue();
+            if (expr.getExpressionTag() != LogicalExpressionTag.FUNCTION_CALL) {
+                continue;
+            }
+            AbstractFunctionCallExpression funcExpr = (AbstractFunctionCallExpression) expr;
+            FunctionIdentifier funcIdent = AsterixBuiltinFunctions.getAggregateFunction(funcExpr.getFunctionIdentifier());
+            if (funcIdent == null) {
+            	// Recursively look in func args.
+            	return fingAggFuncExprRef(funcExpr.getArguments(), aggVar);
+            }
+            // Check if this is the expr that uses aggVar.
+            Collection<LogicalVariable> usedVars = new HashSet<LogicalVariable>();
+            funcExpr.getUsedVariables(usedVars);
+            if (usedVars.contains(aggVar)) {
+            	return exprRef;
+            }
+    	}
+    	return null;
+    }
 }