[ASTERIXDB-3101][COMP] Optimize pushing assign ops down

- user model changes: no
- storage format changes: no
- interface changes: yes

Details:
One of the things that PushFieldAccessRule attempts to do is
push an assign operator down as close as possible to its respective
data-scan operator. The assign operator is pushed recursively
through the operators below it, one by one, until the data-scan is
reached. This becomes expensive when there is a large number
of assigns. In the case where all the operators below the assign
operator are other assign operators, the assign operator
can be moved directly above the data-scan, skipping all
the intermediate assign operators.

- add default method to IAlgebraicRewriteRule to allow the rules
  to know if they are about to rewrite a nested plan root.

Optimize ExtractCommonExpressionsRule since the current traversal
of operators becomes expensive with a large number of operators.
- optimize ExtractCommonExpressionsRule to work only on the roots of
  plans, since the implementation descends to children recursively.
  The check for whether an operator was already rewritten is now done
  after descending to its children, to allow a post-order traversal
  from the root.

Change-Id: I035b72089f973bb08dccf5f9305f8b06da7fc458
Reviewed-on: https://asterix-gerrit.ics.uci.edu/c/asterixdb/+/17316
Integration-Tests: Jenkins <jenkins@fulliautomatix.ics.uci.edu>
Tested-by: Jenkins <jenkins@fulliautomatix.ics.uci.edu>
Reviewed-by: Michael Blow <mblow@apache.org>
(cherry picked from commit 964ff7be6d2c026704001bd00430ae3a78bc66f6)
Reviewed-on: https://asterix-gerrit.ics.uci.edu/c/asterixdb/+/17245
Reviewed-by: Ali Alsuliman <ali.al.solaiman@gmail.com>
diff --git a/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/PushFieldAccessRule.java b/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/PushFieldAccessRule.java
index c82aa33..371f460 100644
--- a/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/PushFieldAccessRule.java
+++ b/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/PushFieldAccessRule.java
@@ -22,6 +22,7 @@
 import java.util.HashSet;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Set;
 
 import org.apache.asterix.algebra.base.OperatorAnnotation;
 import org.apache.asterix.common.config.DatasetConfig.DatasetType;
@@ -176,11 +177,11 @@
         return e1.equals(e2);
     }
 
-    private boolean pushDownFieldAccessRec(Mutable<ILogicalOperator> opRef, IOptimizationContext context,
+    private boolean pushDownFieldAccessRec(Mutable<ILogicalOperator> assignOpRef, IOptimizationContext context,
             String finalAnnot) throws AlgebricksException {
-        AssignOperator assignOp = (AssignOperator) opRef.getValue();
-        Mutable<ILogicalOperator> opRef2 = assignOp.getInputs().get(0);
-        AbstractLogicalOperator inputOp = (AbstractLogicalOperator) opRef2.getValue();
+        AssignOperator assignOp = (AssignOperator) assignOpRef.getValue();
+        Mutable<ILogicalOperator> inputOpRef = assignOp.getInputs().get(0);
+        AbstractLogicalOperator inputOp = (AbstractLogicalOperator) inputOpRef.getValue();
         // If it's not an indexed field, it is pushed so that scan can be rewritten into index search.
         if (inputOp.getOperatorTag() == LogicalOperatorTag.PROJECT
                 || context.checkAndAddToAlreadyCompared(assignOp, inputOp)
@@ -196,24 +197,31 @@
             return false;
         }
         if (testAndModifyRedundantOp(assignOp, inputOp)) {
-            pushDownFieldAccessRec(opRef2, context, finalAnnot);
+            pushDownFieldAccessRec(inputOpRef, context, finalAnnot);
             return true;
         }
-        HashSet<LogicalVariable> usedInAccess = new HashSet<>();
+        Set<LogicalVariable> usedInAccess = new HashSet<>();
         VariableUtilities.getUsedVariables(assignOp, usedInAccess);
-
-        HashSet<LogicalVariable> produced2 = new HashSet<>();
+        if (usedInAccess.isEmpty()) {
+            return false;
+        }
+        Set<LogicalVariable> produced = new HashSet<>();
+        ILogicalOperator dataScanOp =
+                getDataScanOp(assignOpRef, assignOp, inputOpRef, inputOp, usedInAccess, produced, context);
+        if (dataScanOp != null) {
+            // in this case, we don't need to keep pushing the assign op through all the assign operators below it since
+            // this is unnecessary. we just need to try replacing field access by the primary key if it refers to one
+            return rewriteFieldAccessToPK(context, finalAnnot, assignOp, dataScanOp);
+        }
+        produced.clear();
         if (inputOp.getOperatorTag() == LogicalOperatorTag.GROUP) {
-            VariableUtilities.getLiveVariables(inputOp, produced2);
+            VariableUtilities.getLiveVariables(inputOp, produced);
         } else {
-            VariableUtilities.getProducedVariables(inputOp, produced2);
+            VariableUtilities.getProducedVariables(inputOp, produced);
         }
         boolean pushItDown = false;
         HashSet<LogicalVariable> inter = new HashSet<>(usedInAccess);
-        if (inter.isEmpty()) { // ground value
-            return false;
-        }
-        inter.retainAll(produced2);
+        inter.retainAll(produced);
         if (inter.isEmpty()) {
             pushItDown = true;
         } else if (inputOp.getOperatorTag() == LogicalOperatorTag.GROUP) {
@@ -254,18 +262,18 @@
             if (inputOp.getOperatorTag() == LogicalOperatorTag.NESTEDTUPLESOURCE) {
                 Mutable<ILogicalOperator> childOfSubplan =
                         ((NestedTupleSourceOperator) inputOp).getDataSourceReference().getValue().getInputs().get(0);
-                pushAccessDown(opRef, inputOp, childOfSubplan, context, finalAnnot);
+                pushAccessDown(assignOpRef, inputOp, childOfSubplan, context, finalAnnot);
                 return true;
             }
             if (inputOp.getInputs().size() == 1 && !inputOp.hasNestedPlans()) {
-                pushAccessDown(opRef, inputOp, inputOp.getInputs().get(0), context, finalAnnot);
+                pushAccessDown(assignOpRef, inputOp, inputOp.getInputs().get(0), context, finalAnnot);
                 return true;
             } else {
                 for (Mutable<ILogicalOperator> inp : inputOp.getInputs()) {
                     HashSet<LogicalVariable> v2 = new HashSet<>();
                     VariableUtilities.getLiveVariables(inp.getValue(), v2);
                     if (v2.containsAll(usedInAccess)) {
-                        pushAccessDown(opRef, inputOp, inp, context, finalAnnot);
+                        pushAccessDown(assignOpRef, inputOp, inp, context, finalAnnot);
                         return true;
                     }
                 }
@@ -277,7 +285,7 @@
                         HashSet<LogicalVariable> v2 = new HashSet<>();
                         VariableUtilities.getLiveVariables(root.getValue(), v2);
                         if (v2.containsAll(usedInAccess)) {
-                            pushAccessDown(opRef, inputOp, root, context, finalAnnot);
+                            pushAccessDown(assignOpRef, inputOp, root, context, finalAnnot);
                             return true;
                         }
                     }
@@ -286,73 +294,124 @@
             return false;
         } else {
             // check if the accessed field is one of the partitioning key fields. If yes, we can equate the 2 variables
-            if (inputOp.getOperatorTag() == LogicalOperatorTag.DATASOURCESCAN) {
-                DataSourceScanOperator scan = (DataSourceScanOperator) inputOp;
-                IDataSource<DataSourceId> dataSource = (IDataSource<DataSourceId>) scan.getDataSource();
-                byte dsType = ((DataSource) dataSource).getDatasourceType();
-                if (dsType != DataSource.Type.INTERNAL_DATASET && dsType != DataSource.Type.EXTERNAL_DATASET) {
-                    return false;
+            return rewriteFieldAccessToPK(context, finalAnnot, assignOp, inputOp);
+        }
+    }
+
+    /**
+     * Tries to rewrite field access to its equivalent PK. For example, a data scan operator of dataset "ds" produces
+     * the following variables: $PK1, $PK2,.., $ds, ($meta_var?). Given field access: $$ds.getField("id") and given that
+     * the field "id" is one of the primary keys of ds, the field access $$ds.getField("id") is replaced by the primary
+     * key variable (one of the $PKs).
+     * @return true if the field access in the assign operator was replaced by the primary key variable.
+     */
+    private boolean rewriteFieldAccessToPK(IOptimizationContext context, String finalAnnot, AssignOperator assignOp,
+            ILogicalOperator inputOp) throws AlgebricksException {
+        if (inputOp.getOperatorTag() == LogicalOperatorTag.DATASOURCESCAN) {
+            DataSourceScanOperator scan = (DataSourceScanOperator) inputOp;
+            IDataSource<DataSourceId> dataSource = (IDataSource<DataSourceId>) scan.getDataSource();
+            byte dsType = ((DataSource) dataSource).getDatasourceType();
+            if (dsType != DataSource.Type.INTERNAL_DATASET && dsType != DataSource.Type.EXTERNAL_DATASET) {
+                return false;
+            }
+            DataSourceId asid = dataSource.getId();
+            MetadataProvider mp = (MetadataProvider) context.getMetadataProvider();
+            Dataset dataset = mp.findDataset(asid.getDataverseName(), asid.getDatasourceName());
+            if (dataset == null) {
+                throw new CompilationException(ErrorCode.UNKNOWN_DATASET_IN_DATAVERSE, scan.getSourceLocation(),
+                        asid.getDatasourceName(), asid.getDataverseName());
+            }
+            if (dataset.getDatasetType() != DatasetType.INTERNAL) {
+                setAsFinal(assignOp, context, finalAnnot);
+                return false;
+            }
+
+            List<LogicalVariable> allVars = scan.getVariables();
+            LogicalVariable dataRecVarInScan = ((DataSource) dataSource).getDataRecordVariable(allVars);
+            LogicalVariable metaRecVarInScan = ((DataSource) dataSource).getMetaVariable(allVars);
+
+            // data part
+            String dataTypeName = dataset.getItemTypeName();
+            IAType dataType = mp.findType(dataset.getItemTypeDataverseName(), dataTypeName);
+            if (dataType.getTypeTag() != ATypeTag.OBJECT) {
+                return false;
+            }
+            ARecordType dataRecType = (ARecordType) dataType;
+            Pair<ILogicalExpression, List<String>> fieldPathAndVar = getFieldExpression(assignOp, dataRecType);
+            ILogicalExpression targetRecVar = fieldPathAndVar.first;
+            List<String> targetFieldPath = fieldPathAndVar.second;
+            boolean rewrite = false;
+            boolean fieldFromMeta = false;
+            if (sameRecords(targetRecVar, dataRecVarInScan)) {
+                rewrite = true;
+            } else {
+                // check meta part
+                IAType metaType = mp.findMetaType(dataset); // could be null
+                if (metaType != null && metaType.getTypeTag() == ATypeTag.OBJECT) {
+                    fieldPathAndVar = getFieldExpression(assignOp, (ARecordType) metaType);
+                    targetRecVar = fieldPathAndVar.first;
+                    targetFieldPath = fieldPathAndVar.second;
+                    if (sameRecords(targetRecVar, metaRecVarInScan)) {
+                        rewrite = true;
+                        fieldFromMeta = true;
+                    }
                 }
-                DataSourceId asid = dataSource.getId();
-                MetadataProvider mp = (MetadataProvider) context.getMetadataProvider();
-                Dataset dataset = mp.findDataset(asid.getDataverseName(), asid.getDatasourceName());
-                if (dataset == null) {
-                    throw new CompilationException(ErrorCode.UNKNOWN_DATASET_IN_DATAVERSE, scan.getSourceLocation(),
-                            asid.getDatasourceName(), asid.getDataverseName());
-                }
-                if (dataset.getDatasetType() != DatasetType.INTERNAL) {
+            }
+
+            if (rewrite) {
+                int p = DatasetUtil.getPositionOfPartitioningKeyField(dataset, targetFieldPath, fieldFromMeta);
+                if (p < 0) { // not one of the partitioning fields
                     setAsFinal(assignOp, context, finalAnnot);
                     return false;
                 }
-
-                List<LogicalVariable> allVars = scan.getVariables();
-                LogicalVariable dataRecVarInScan = ((DataSource) dataSource).getDataRecordVariable(allVars);
-                LogicalVariable metaRecVarInScan = ((DataSource) dataSource).getMetaVariable(allVars);
-
-                // data part
-                String dataTypeName = dataset.getItemTypeName();
-                IAType dataType = mp.findType(dataset.getItemTypeDataverseName(), dataTypeName);
-                if (dataType.getTypeTag() != ATypeTag.OBJECT) {
-                    return false;
-                }
-                ARecordType dataRecType = (ARecordType) dataType;
-                Pair<ILogicalExpression, List<String>> fieldPathAndVar = getFieldExpression(assignOp, dataRecType);
-                ILogicalExpression targetRecVar = fieldPathAndVar.first;
-                List<String> targetFieldPath = fieldPathAndVar.second;
-                boolean rewrite = false;
-                boolean fieldFromMeta = false;
-                if (sameRecords(targetRecVar, dataRecVarInScan)) {
-                    rewrite = true;
-                } else {
-                    // check meta part
-                    IAType metaType = mp.findMetaType(dataset); // could be null
-                    if (metaType != null && metaType.getTypeTag() == ATypeTag.OBJECT) {
-                        fieldPathAndVar = getFieldExpression(assignOp, (ARecordType) metaType);
-                        targetRecVar = fieldPathAndVar.first;
-                        targetFieldPath = fieldPathAndVar.second;
-                        if (sameRecords(targetRecVar, metaRecVarInScan)) {
-                            rewrite = true;
-                            fieldFromMeta = true;
-                        }
-                    }
-                }
-
-                if (rewrite) {
-                    int p = DatasetUtil.getPositionOfPartitioningKeyField(dataset, targetFieldPath, fieldFromMeta);
-                    if (p < 0) { // not one of the partitioning fields
-                        setAsFinal(assignOp, context, finalAnnot);
-                        return false;
-                    }
-                    LogicalVariable keyVar = scan.getVariables().get(p);
-                    VariableReferenceExpression keyVarRef = new VariableReferenceExpression(keyVar);
-                    keyVarRef.setSourceLocation(targetRecVar.getSourceLocation());
-                    assignOp.getExpressions().get(0).setValue(keyVarRef);
-                    return true;
-                }
+                LogicalVariable keyVar = scan.getVariables().get(p);
+                VariableReferenceExpression keyVarRef = new VariableReferenceExpression(keyVar);
+                keyVarRef.setSourceLocation(targetRecVar.getSourceLocation());
+                assignOp.getExpressions().get(0).setValue(keyVarRef);
+                return true;
             }
-            setAsFinal(assignOp, context, finalAnnot);
-            return false;
         }
+        setAsFinal(assignOp, context, finalAnnot);
+        return false;
+    }
+
+    /**
+     * Looks for a data scan operator where the data scan operator is below only assign operators. Then, if
+     * applicable, the assign operator is moved down and placed above the data-scan.
+     *
+     * @return the data scan operator if it exists below multiple assign operators only and the assign operator is now
+     * above the data-scan.
+     */
+    private ILogicalOperator getDataScanOp(Mutable<ILogicalOperator> assignOpRef, AssignOperator assignOp,
+            Mutable<ILogicalOperator> assignInputRef, ILogicalOperator assignInput, Set<LogicalVariable> usedInAssign,
+            Set<LogicalVariable> producedByInput, IOptimizationContext context) throws AlgebricksException {
+        ILogicalOperator firstInput = assignInput;
+        while (assignInput.getOperatorTag() == LogicalOperatorTag.ASSIGN) {
+            if (isRedundantAssign(assignOp, assignInput)) {
+                return null;
+            }
+            assignInputRef = assignInput.getInputs().get(0);
+            assignInput = assignInputRef.getValue();
+        }
+        if (assignInput.getOperatorTag() != LogicalOperatorTag.DATASOURCESCAN) {
+            return null;
+        }
+        VariableUtilities.getProducedVariables(assignInput, producedByInput);
+        if (!producedByInput.containsAll(usedInAssign)) {
+            return null;
+        }
+        if (firstInput == assignInput) {
+            // the input to the assign operator is already a data-scan
+            return assignInput;
+        }
+        // move the assign op down, place it above the data-scan
+        assignOpRef.setValue(firstInput);
+        List<Mutable<ILogicalOperator>> assignInputs = assignOp.getInputs();
+        assignInputs.get(0).setValue(assignInput);
+        assignInputRef.setValue(assignOp);
+        context.computeAndSetTypeEnvironmentForOperator(assignOp);
+        context.computeAndSetTypeEnvironmentForOperator(firstInput);
+        return assignInput;
     }
 
     /**
@@ -398,12 +457,9 @@
     }
 
     private boolean testAndModifyRedundantOp(AssignOperator access, AbstractLogicalOperator op2) {
-        if (op2.getOperatorTag() != LogicalOperatorTag.ASSIGN) {
-            return false;
-        }
-        AssignOperator a2 = (AssignOperator) op2;
-        ILogicalExpression accessExpr0 = getFirstExpr(access);
-        if (accessExpr0.equals(getFirstExpr(a2))) {
+        if (isRedundantAssign(access, op2)) {
+            AssignOperator a2 = (AssignOperator) op2;
+            ILogicalExpression accessExpr0 = getFirstExpr(access);
             VariableReferenceExpression varRef = new VariableReferenceExpression(a2.getVariables().get(0));
             varRef.setSourceLocation(accessExpr0.getSourceLocation());
             access.getExpressions().get(0).setValue(varRef);
@@ -413,6 +469,14 @@
         }
     }
 
+    private static boolean isRedundantAssign(AssignOperator assignOp, ILogicalOperator inputOp) {
+        if (inputOp.getOperatorTag() != LogicalOperatorTag.ASSIGN) {
+            return false;
+        }
+        ILogicalExpression assignOpExpr = getFirstExpr(assignOp);
+        return assignOpExpr.equals(getFirstExpr((AssignOperator) inputOp));
+    }
+
     // indirect recursivity with pushDownFieldAccessRec
     private void pushAccessDown(Mutable<ILogicalOperator> fldAccessOpRef, ILogicalOperator op2,
             Mutable<ILogicalOperator> inputOfOp2, IOptimizationContext context, String finalAnnot)
@@ -429,8 +493,7 @@
         pushDownFieldAccessRec(inputOfOp2, context, finalAnnot);
     }
 
-    private ILogicalExpression getFirstExpr(AssignOperator assign) {
+    private static ILogicalExpression getFirstExpr(AssignOperator assign) {
         return assign.getExpressions().get(0).getValue();
     }
-
 }
diff --git a/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialFirstRuleCheckFixpointRuleController.java b/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialFirstRuleCheckFixpointRuleController.java
index 29c178a..79ec0fa 100644
--- a/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialFirstRuleCheckFixpointRuleController.java
+++ b/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialFirstRuleCheckFixpointRuleController.java
@@ -72,7 +72,7 @@
         do {
             anyChange = false;
             for (int i = 0; i < rules.size(); i++) {
-                boolean ruleFired = rewriteOperatorRef(root, rules.get(i), true, fullDfs);
+                boolean ruleFired = rewriteOperatorRef(root, rules.get(i), true, fullDfs, false);
                 // If the first rule returns false in the first iteration, stops applying the rules at all.
                 if (!firstRuleChecked && i == 0 && !ruleFired) {
                     return ruleFired;
diff --git a/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialFixpointRuleController.java b/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialFixpointRuleController.java
index 1fef33e..bbe281d 100644
--- a/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialFixpointRuleController.java
+++ b/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialFixpointRuleController.java
@@ -49,7 +49,7 @@
         do {
             anyChange = false;
             for (IAlgebraicRewriteRule rule : ruleCollection) {
-                boolean ruleFired = rewriteOperatorRef(root, rule, true, fullDfs);
+                boolean ruleFired = rewriteOperatorRef(root, rule, true, fullDfs, false);
                 if (ruleFired) {
                     anyChange = true;
                     anyRuleFired = true;
diff --git a/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialOnceRuleController.java b/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialOnceRuleController.java
index bcbc207..1090fe1 100644
--- a/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialOnceRuleController.java
+++ b/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialOnceRuleController.java
@@ -40,7 +40,7 @@
             throws AlgebricksException {
         boolean fired = false;
         for (IAlgebraicRewriteRule rule : rules) {
-            if (rewriteOperatorRef(root, rule, enterNestedPlans, true)) {
+            if (rewriteOperatorRef(root, rule, enterNestedPlans, true, false)) {
                 fired = true;
             }
         }
diff --git a/hyracks-fullstack/algebricks/algebricks-core/src/main/java/org/apache/hyracks/algebricks/core/rewriter/base/AbstractRuleController.java b/hyracks-fullstack/algebricks/algebricks-core/src/main/java/org/apache/hyracks/algebricks/core/rewriter/base/AbstractRuleController.java
index 0261106..9a47b8a 100644
--- a/hyracks-fullstack/algebricks/algebricks-core/src/main/java/org/apache/hyracks/algebricks/core/rewriter/base/AbstractRuleController.java
+++ b/hyracks-fullstack/algebricks/algebricks-core/src/main/java/org/apache/hyracks/algebricks/core/rewriter/base/AbstractRuleController.java
@@ -67,14 +67,15 @@
      */
     protected boolean rewriteOperatorRef(Mutable<ILogicalOperator> opRef, IAlgebraicRewriteRule rule)
             throws AlgebricksException {
-        return rewriteOperatorRef(opRef, rule, true, false);
+        return rewriteOperatorRef(opRef, rule, true, false, false);
     }
 
     protected boolean rewriteOperatorRef(Mutable<ILogicalOperator> opRef, IAlgebraicRewriteRule rule,
-            boolean enterNestedPlans, boolean fullDFS) throws AlgebricksException {
+            boolean enterNestedPlans, boolean fullDFS, boolean enteredNestedPlanRoot) throws AlgebricksException {
 
         String preBeforePlan = getPlanString(opRef);
         sanityCheckBeforeRewrite(rule, opRef);
+        rule.enteredNestedPlan(enteredNestedPlanRoot);
         if (rule.rewritePre(opRef, context)) {
             String preAfterPlan = getPlanString(opRef);
             printRuleApplication(rule, "fired", preBeforePlan, preAfterPlan);
@@ -88,7 +89,7 @@
         AbstractLogicalOperator op = (AbstractLogicalOperator) opRef.getValue();
 
         for (Mutable<ILogicalOperator> inp : op.getInputs()) {
-            if (rewriteOperatorRef(inp, rule, enterNestedPlans, fullDFS)) {
+            if (rewriteOperatorRef(inp, rule, enterNestedPlans, fullDFS, false)) {
                 rewritten = true;
                 if (!fullDFS) {
                     break;
@@ -100,7 +101,7 @@
             AbstractOperatorWithNestedPlans o2 = (AbstractOperatorWithNestedPlans) op;
             for (ILogicalPlan p : o2.getNestedPlans()) {
                 for (Mutable<ILogicalOperator> r : p.getRoots()) {
-                    if (rewriteOperatorRef(r, rule, enterNestedPlans, fullDFS)) {
+                    if (rewriteOperatorRef(r, rule, enterNestedPlans, fullDFS, true)) {
                         rewritten = true;
                         if (!fullDFS) {
                             break;
diff --git a/hyracks-fullstack/algebricks/algebricks-core/src/main/java/org/apache/hyracks/algebricks/core/rewriter/base/IAlgebraicRewriteRule.java b/hyracks-fullstack/algebricks/algebricks-core/src/main/java/org/apache/hyracks/algebricks/core/rewriter/base/IAlgebraicRewriteRule.java
index 128c372..33bc4a9 100644
--- a/hyracks-fullstack/algebricks/algebricks-core/src/main/java/org/apache/hyracks/algebricks/core/rewriter/base/IAlgebraicRewriteRule.java
+++ b/hyracks-fullstack/algebricks/algebricks-core/src/main/java/org/apache/hyracks/algebricks/core/rewriter/base/IAlgebraicRewriteRule.java
@@ -54,4 +54,13 @@
             throws AlgebricksException {
         return false;
     }
+
+    /**
+     * Called before calling {@link #rewritePre} to designate if the {@code opRef} is a nested plan root.
+     *
+     * @param enteredNestedPlanRoot whether the operator to be rewritten is a nested plan root.
+     */
+    default void enteredNestedPlan(boolean enteredNestedPlanRoot) {
+        // no op
+    }
 }
diff --git a/hyracks-fullstack/algebricks/algebricks-rewriter/src/main/java/org/apache/hyracks/algebricks/rewriter/rules/ExtractCommonExpressionsRule.java b/hyracks-fullstack/algebricks/algebricks-rewriter/src/main/java/org/apache/hyracks/algebricks/rewriter/rules/ExtractCommonExpressionsRule.java
index 9420498..e2ba557 100644
--- a/hyracks-fullstack/algebricks/algebricks-rewriter/src/main/java/org/apache/hyracks/algebricks/rewriter/rules/ExtractCommonExpressionsRule.java
+++ b/hyracks-fullstack/algebricks/algebricks-rewriter/src/main/java/org/apache/hyracks/algebricks/rewriter/rules/ExtractCommonExpressionsRule.java
@@ -77,17 +77,17 @@
  */
 public class ExtractCommonExpressionsRule implements IAlgebraicRewriteRule {
 
-    private final List<ILogicalExpression> originalAssignExprs = new ArrayList<ILogicalExpression>();
+    private final List<ILogicalExpression> originalAssignExprs = new ArrayList<>();
 
     private final CommonExpressionSubstitutionVisitor substVisitor = new CommonExpressionSubstitutionVisitor();
-    private final Map<ILogicalExpression, ExprEquivalenceClass> exprEqClassMap =
-            new HashMap<ILogicalExpression, ExprEquivalenceClass>();
+    private final Map<ILogicalExpression, ExprEquivalenceClass> exprEqClassMap = new HashMap<>();
 
     private final List<LogicalVariable> tmpLiveVars = new ArrayList<>();
     private final List<LogicalVariable> tmpProducedVars = new ArrayList<>();
+    private boolean enteredNestedPlan = false;
 
     // Set of operators for which common subexpression elimination should not be performed.
-    private static final Set<LogicalOperatorTag> ignoreOps = new HashSet<LogicalOperatorTag>(6);
+    private static final Set<LogicalOperatorTag> ignoreOps = new HashSet<>(6);
 
     static {
         ignoreOps.add(LogicalOperatorTag.UNNEST);
@@ -100,6 +100,11 @@
     }
 
     @Override
+    public void enteredNestedPlan(boolean enteredNestedPlanRoot) {
+        this.enteredNestedPlan = enteredNestedPlanRoot;
+    }
+
+    @Override
     public boolean rewritePost(Mutable<ILogicalOperator> opRef, IOptimizationContext context)
             throws AlgebricksException {
         return false;
@@ -108,6 +113,14 @@
     @Override
     public boolean rewritePre(Mutable<ILogicalOperator> opRef, IOptimizationContext context)
             throws AlgebricksException {
+        ILogicalOperator op = opRef.getValue();
+        if (enteredNestedPlan) {
+            enteredNestedPlan = false;
+        } else if (op.getOperatorTag() != LogicalOperatorTag.DISTRIBUTE_RESULT
+                && op.getOperatorTag() != LogicalOperatorTag.SINK
+                && op.getOperatorTag() != LogicalOperatorTag.DELEGATE_OPERATOR) {
+            return false;
+        }
         exprEqClassMap.clear();
         substVisitor.setContext(context);
         boolean modified = removeCommonExpressions(opRef, context);
@@ -155,9 +168,6 @@
     private boolean removeCommonExpressions(Mutable<ILogicalOperator> opRef, IOptimizationContext context)
             throws AlgebricksException {
         AbstractLogicalOperator op = (AbstractLogicalOperator) opRef.getValue();
-        if (context.checkIfInDontApplySet(this, opRef.getValue())) {
-            return false;
-        }
 
         boolean modified = false;
         // Recurse into children.
@@ -166,7 +176,9 @@
                 modified = true;
             }
         }
-
+        if (context.checkIfInDontApplySet(this, opRef.getValue())) {
+            return modified;
+        }
         // TODO: Deal with replicate properly. Currently, we just clear the expr equivalence map,
         // since we want to avoid incorrect expression replacement
         // (the resulting new variables should be assigned live below a replicate/split).