Added PushAggFuncIntoStandaloneAggregateRule

git-svn-id: https://asterixdb.googlecode.com/svn/branches/asterix_fix_agg@499 eaa15691-b419-025a-1212-ee371bd00084
diff --git a/asterix-algebra/src/main/java/edu/uci/ics/asterix/optimizer/rules/PushAggFuncIntoStandaloneAggregateRule.java b/asterix-algebra/src/main/java/edu/uci/ics/asterix/optimizer/rules/PushAggFuncIntoStandaloneAggregateRule.java
new file mode 100644
index 0000000..65372e7
--- /dev/null
+++ b/asterix-algebra/src/main/java/edu/uci/ics/asterix/optimizer/rules/PushAggFuncIntoStandaloneAggregateRule.java
@@ -0,0 +1,120 @@
+/*
+ * Copyright 2009-2010 by The Regents of the University of California
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * you may obtain a copy of the License from
+ * 
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.uci.ics.asterix.optimizer.rules;
+
+import java.util.LinkedList;
+import java.util.List;
+
+import org.apache.commons.lang3.mutable.Mutable;
+
+import edu.uci.ics.asterix.om.functions.AsterixBuiltinFunctions;
+import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException;
+import edu.uci.ics.hyracks.algebricks.core.algebra.base.ILogicalExpression;
+import edu.uci.ics.hyracks.algebricks.core.algebra.base.ILogicalOperator;
+import edu.uci.ics.hyracks.algebricks.core.algebra.base.IOptimizationContext;
+import edu.uci.ics.hyracks.algebricks.core.algebra.base.LogicalExpressionTag;
+import edu.uci.ics.hyracks.algebricks.core.algebra.base.LogicalOperatorTag;
+import edu.uci.ics.hyracks.algebricks.core.algebra.base.LogicalVariable;
+import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.AbstractFunctionCallExpression;
+import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.AggregateFunctionCallExpression;
+import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.VariableReferenceExpression;
+import edu.uci.ics.hyracks.algebricks.core.algebra.functions.FunctionIdentifier;
+import edu.uci.ics.hyracks.algebricks.core.algebra.operators.logical.AbstractLogicalOperator;
+import edu.uci.ics.hyracks.algebricks.core.algebra.operators.logical.AggregateOperator;
+import edu.uci.ics.hyracks.algebricks.core.algebra.operators.logical.AssignOperator;
+import edu.uci.ics.hyracks.algebricks.core.algebra.operators.logical.visitors.VariableUtilities;
+import edu.uci.ics.hyracks.algebricks.core.rewriter.base.IAlgebraicRewriteRule;
+
+/**
+ * When aggregates appear w/o group-by, a default group by a constant is
+ * introduced.
+ */
+
+public class PushAggFuncIntoStandaloneAggregateRule implements IAlgebraicRewriteRule {
+
+    @Override
+    public boolean rewritePre(Mutable<ILogicalOperator> opRef, IOptimizationContext context) throws AlgebricksException {
+        return false;
+    }
+
+    @Override
+    public boolean rewritePost(Mutable<ILogicalOperator> opRef, IOptimizationContext context)
+            throws AlgebricksException {
+        // Pattern to match: assign <-- aggregate.
+        AbstractLogicalOperator op = (AbstractLogicalOperator) opRef.getValue();
+        if (op.getOperatorTag() != LogicalOperatorTag.ASSIGN) {
+            return false;
+        }
+        Mutable<ILogicalOperator> opRef2 = op.getInputs().get(0);
+        AbstractLogicalOperator op2 = (AbstractLogicalOperator) opRef2.getValue();
+        if (op2.getOperatorTag() != LogicalOperatorTag.AGGREGATE) {
+            return false;
+        }
+
+        AssignOperator assignOp = (AssignOperator) op;
+        AggregateOperator aggOp = (AggregateOperator) op2;
+        if (aggOp.getVariables().size() != 1) {
+            return false;
+        }
+        LogicalVariable aggVar = aggOp.getVariables().get(0);
+        List<LogicalVariable> used = new LinkedList<LogicalVariable>();
+        VariableUtilities.getUsedVariables(assignOp, used);
+        if (!used.contains(aggVar)) {
+            return false;
+        }
+        
+        Mutable<ILogicalExpression> srcAssignExprRef = null;        
+        FunctionIdentifier aggFuncIdent = null;
+        List<Mutable<ILogicalExpression>> assignExprRefs = assignOp.getExpressions();
+        for (Mutable<ILogicalExpression> assignExprRef : assignExprRefs) {
+            // Continue if assignExprRef is not an aggregate function.
+            ILogicalExpression assignExpr = assignExprRef.getValue();
+            if (assignExpr.getExpressionTag() != LogicalExpressionTag.FUNCTION_CALL) {
+                continue;
+            }
+            AbstractFunctionCallExpression assignFuncExpr = (AbstractFunctionCallExpression) assignExpr;
+            aggFuncIdent = AsterixBuiltinFunctions.getAggregateFunction(assignFuncExpr.getFunctionIdentifier());
+            if (aggFuncIdent == null) {
+                continue;
+            }
+            // Make sure this is the expr that uses aggVar.
+            List<Mutable<ILogicalExpression>> aggFuncArgRefs = assignFuncExpr.getArguments();
+            for (Mutable<ILogicalExpression> aggFuncExprRef : aggFuncArgRefs) {
+                ILogicalExpression aggFuncArgExpr = aggFuncExprRef.getValue();
+                if (aggFuncArgExpr.getExpressionTag() != LogicalExpressionTag.VARIABLE) {
+                    continue;
+                }
+                VariableReferenceExpression varRefExpr = (VariableReferenceExpression) aggFuncArgExpr;
+                if (varRefExpr.getVariableReference() != aggVar) {
+                    continue;
+                }
+                srcAssignExprRef = assignExprRef;
+                break;
+            }
+        }
+        if (srcAssignExprRef == null) {
+            return false;
+        }
+        
+        // Push the agg func into the agg op.                
+        AbstractFunctionCallExpression aggOpExpr = (AbstractFunctionCallExpression) aggOp.getExpressions().get(0).getValue();
+        aggOp.getExpressions().get(0).setValue(new AggregateFunctionCallExpression(AsterixBuiltinFunctions.getAsterixFunctionInfo(aggFuncIdent), false, aggOpExpr.getArguments().get(0)));
+        
+        // The assign now just "renames" the variable to make sure the upstream plan still works.
+        srcAssignExprRef.setValue(new VariableReferenceExpression(aggVar));
+
+        return true;
+    }
+}