[ASTERIXDB-3589][COMP] replace complex join predicate with expressions

Ext-ref: MB-66121

Change-Id: I355943fbf65fa0879b8a1e1827f6a4405997b05b
Reviewed-on: https://asterix-gerrit.ics.uci.edu/c/asterixdb/+/19608
Reviewed-by: <murali.krishna@couchbase.com>
Reviewed-by: <preethampoluparthi@gmail.com>
Reviewed-by: Hussain Towaileb <hussainht@gmail.com>
Tested-by: Jenkins <jenkins@fulliautomatix.ics.uci.edu>
Integration-Tests: Jenkins <jenkins@fulliautomatix.ics.uci.edu>
diff --git a/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/cbo/JoinEnum.java b/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/cbo/JoinEnum.java
index 11a9c95..4be0e66 100644
--- a/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/cbo/JoinEnum.java
+++ b/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/cbo/JoinEnum.java
@@ -19,6 +19,8 @@
 
 package org.apache.asterix.optimizer.rules.cbo;
 
+import static org.apache.asterix.om.functions.BuiltinFunctions.getBuiltinFunctionInfo;
+
 import java.time.LocalDateTime;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -264,8 +266,8 @@
             JoinCondition jc = joinConditions.get(newJoinConditions.get(0));
             return jc.joinCondition;
         }
-        ScalarFunctionCallExpression andExpr = new ScalarFunctionCallExpression(
-                BuiltinFunctions.getBuiltinFunctionInfo(AlgebricksBuiltinFunctions.AND));
+        ScalarFunctionCallExpression andExpr =
+                new ScalarFunctionCallExpression(getBuiltinFunctionInfo(AlgebricksBuiltinFunctions.AND));
 
         for (int joinNum : newJoinConditions) {
             // Need to AND all the expressions.
@@ -284,8 +286,8 @@
             JoinCondition jc = joinConditions.get(newJoinConditions.get(0));
             return jc.joinCondition;
         }
-        ScalarFunctionCallExpression andExpr = new ScalarFunctionCallExpression(
-                BuiltinFunctions.getBuiltinFunctionInfo(AlgebricksBuiltinFunctions.AND));
+        ScalarFunctionCallExpression andExpr =
+                new ScalarFunctionCallExpression(getBuiltinFunctionInfo(AlgebricksBuiltinFunctions.AND));
         for (int joinNum : newJoinConditions) {
             // need to AND all the expressions. skip derived exprs for now.
             JoinCondition jc = joinConditions.get(joinNum);
@@ -315,8 +317,8 @@
             }
             return null;
         }
-        ScalarFunctionCallExpression andExpr = new ScalarFunctionCallExpression(
-                BuiltinFunctions.getBuiltinFunctionInfo(AlgebricksBuiltinFunctions.AND));
+        ScalarFunctionCallExpression andExpr =
+                new ScalarFunctionCallExpression(getBuiltinFunctionInfo(AlgebricksBuiltinFunctions.AND));
 
         // at least one equality predicate needs to be present for a hash join to be possible.
         boolean eqPredFound = false;
@@ -496,16 +498,43 @@
         List<LogicalVariable> usedVars = new ArrayList<>();
         List<AssignOperator> erase = new ArrayList<>();
         for (JoinCondition jc : joinConditions) {
-            usedVars.clear();
             ILogicalExpression expr = jc.joinCondition;
+            AbstractFunctionCallExpression aexpr = (AbstractFunctionCallExpression) expr;
+            usedVars.clear();
             expr.getUsedVariables(usedVars);
-            for (AssignOperator aOp : assignOps) {
+            boolean fixed = false;
+            for (AssignOperator aOp : assignOps) { // These assignOps are internal assignOps (found between join nodes)
                 for (int i = 0; i < aOp.getVariables().size(); i++) {
                     if (usedVars.contains(aOp.getVariables().get(i))) {
                         OperatorManipulationUtil.replaceVarWithExpr((AbstractFunctionCallExpression) expr,
                                 aOp.getVariables().get(i), aOp.getExpressions().get(i).getValue());
                         jc.joinCondition = expr;
                         erase.add(aOp);
+                        fixed = true;
+                    }
+                }
+            }
+            if (!fixed) {
+                // now comes the hard part. Need to look thru all the assigns in the leafInputs
+                for (ILogicalOperator op : leafInputs) {
+                    while (op.getOperatorTag() != LogicalOperatorTag.EMPTYTUPLESOURCE) {
+                        if (op.getOperatorTag() == LogicalOperatorTag.ASSIGN) {
+                            AssignOperator aOp = (AssignOperator) op;
+                            ILogicalExpression a = aOp.getExpressions().get(0).getValue();
+                            usedVars.clear();
+                            a.getUsedVariables(usedVars);
+                            if (usedVars.size() > 1) {
+                                for (int i = 0; i < aOp.getVariables().size(); i++) {
+                                    if (usedVars.contains(aOp.getVariables().get(i))) {
+                                        OperatorManipulationUtil.replaceVarWithExpr(
+                                                (AbstractFunctionCallExpression) expr, aOp.getVariables().get(i),
+                                                aOp.getExpressions().get(i).getValue());
+                                        jc.joinCondition = expr;
+                                    }
+                                }
+                            }
+                        }
+                        op = op.getInputs().get(0).getValue();
                     }
                 }
             }
diff --git a/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/cbo/Stats.java b/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/cbo/Stats.java
index 32c70cb..4f5688a 100644
--- a/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/cbo/Stats.java
+++ b/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/cbo/Stats.java
@@ -119,7 +119,7 @@
         List<LogicalVariable> exprUsedVars = new ArrayList<>();
         joinExpr.getUsedVariables(exprUsedVars);
 
-        if (jc.numLeafInputs != 2) {
+        if ((jc.numLeafInputs != 2) || (exprUsedVars.size() <= 1)) {
             // we can only deal with binary joins. More checks should be in place as well such as R.a op S.a
             return 1.0;
         }
diff --git a/asterixdb/asterix-app/src/test/resources/runtimets/results_cbo/column/filter/subplan/subplan.042.plan b/asterixdb/asterix-app/src/test/resources/runtimets/results_cbo/column/filter/subplan/subplan.042.plan
index 6ba2a43..ef307c5 100644
--- a/asterixdb/asterix-app/src/test/resources/runtimets/results_cbo/column/filter/subplan/subplan.042.plan
+++ b/asterixdb/asterix-app/src/test/resources/runtimets/results_cbo/column/filter/subplan/subplan.042.plan
@@ -1,16 +1,16 @@
-distribute result [$$70] [cardinality: 3.08, doc-size: 11.0, op-cost: 0.0, total-cost: 28.6]
+distribute result [$$70] [cardinality: 6.0, doc-size: 11.0, op-cost: 0.0, total-cost: 28.6]
 -- DISTRIBUTE_RESULT  |UNPARTITIONED|
-  exchange [cardinality: 3.08, doc-size: 11.0, op-cost: 0.0, total-cost: 28.6]
+  exchange [cardinality: 6.0, doc-size: 11.0, op-cost: 0.0, total-cost: 28.6]
   -- ONE_TO_ONE_EXCHANGE  |UNPARTITIONED|
-    aggregate [$$70] <- [agg-sql-sum($$76)] [cardinality: 3.08, doc-size: 11.0, op-cost: 0.0, total-cost: 28.6]
+    aggregate [$$70] <- [agg-sql-sum($$76)] [cardinality: 6.0, doc-size: 11.0, op-cost: 0.0, total-cost: 28.6]
     -- AGGREGATE  |UNPARTITIONED|
-      exchange [cardinality: 3.08, doc-size: 11.0, op-cost: 0.0, total-cost: 28.6]
+      exchange [cardinality: 6.0, doc-size: 11.0, op-cost: 0.0, total-cost: 28.6]
       -- RANDOM_MERGE_EXCHANGE  |PARTITIONED|
-        aggregate [$$76] <- [agg-sql-count(1)] [cardinality: 3.08, doc-size: 11.0, op-cost: 0.0, total-cost: 28.6]
+        aggregate [$$76] <- [agg-sql-count(1)] [cardinality: 6.0, doc-size: 11.0, op-cost: 0.0, total-cost: 28.6]
         -- AGGREGATE  |PARTITIONED|
-          exchange [cardinality: 3.08, doc-size: 11.0, op-cost: 0.0, total-cost: 28.6]
+          exchange [cardinality: 6.0, doc-size: 11.0, op-cost: 0.0, total-cost: 28.6]
           -- ONE_TO_ONE_EXCHANGE  |PARTITIONED|
-            join (or(eq($$71, "7"), neq($$69, 0))) [cardinality: 3.08, doc-size: 11.0, op-cost: 12.6, total-cost: 28.6]
+            join (or(eq($$71, "7"), neq($$69, 0))) [cardinality: 6.0, doc-size: 11.0, op-cost: 12.6, total-cost: 28.6]
             -- NESTED_LOOP  |PARTITIONED|
               exchange [cardinality: 6.0, doc-size: 9.0, op-cost: 0.0, total-cost: 6.0]
               -- ONE_TO_ONE_EXCHANGE  |PARTITIONED|