support recursive cast of record type in static casting

git-svn-id: https://asterixdb.googlecode.com/svn/branches/asterix_opentype@318 eaa15691-b419-025a-1212-ee371bd00084
diff --git a/asterix-algebra/src/main/java/edu/uci/ics/asterix/optimizer/rules/IntroduceStaticTypeCastRule.java b/asterix-algebra/src/main/java/edu/uci/ics/asterix/optimizer/rules/IntroduceStaticTypeCastRule.java
index fc0a2e8..f76d6a7 100644
--- a/asterix-algebra/src/main/java/edu/uci/ics/asterix/optimizer/rules/IntroduceStaticTypeCastRule.java
+++ b/asterix-algebra/src/main/java/edu/uci/ics/asterix/optimizer/rules/IntroduceStaticTypeCastRule.java
@@ -36,6 +36,10 @@
 
 public class IntroduceStaticTypeCastRule implements IAlgebraicRewriteRule {
 
+    // nested open field rec type
+    private static ARecordType nestedOpenRecType = new ARecordType("nested-open", new String[] {}, new IAType[] {},
+            true);
+
     @Override
     public boolean rewritePre(Mutable<ILogicalOperator> opRef, IOptimizationContext context) throws AlgebricksException {
         return false;
@@ -84,13 +88,13 @@
 
         AbstractLogicalOperator currentOperator = oldAssignOperator;
         List<LogicalVariable> producedVariables = new ArrayList<LogicalVariable>();
-        boolean changed = false;
 
         /**
          * find the assign operator for the "input record" to the insert_delete
          * operator
          */
         do {
+            context.addToDontApplySet(this, currentOperator);
             if (currentOperator.getOperatorTag() == LogicalOperatorTag.ASSIGN) {
                 producedVariables.clear();
                 VariableUtilities.getProducedVariables(currentOperator, producedVariables);
@@ -105,21 +109,9 @@
                     ILogicalExpression expr = expressionPointers.get(position).getValue();
                     if (expr.getExpressionTag() == LogicalExpressionTag.FUNCTION_CALL) {
                         ScalarFunctionCallExpression funcExpr = (ScalarFunctionCallExpression) expr;
-                        changed = TypeComputerUtilities.setRequiredAndInputTypes(funcExpr, requiredRecordType,
-                                inputRecordType);
-                        changed &= !requiredRecordType.equals(inputRecordType);
-                        if (changed) {
-                            staticTypeCast(funcExpr, requiredRecordType, inputRecordType);
-                            List<Mutable<ILogicalExpression>> args = funcExpr.getArguments();
-                            int openPartStart = requiredRecordType.getFieldTypes().length * 2;
-                            for (int j = openPartStart; j < args.size(); j++) {
-                                ILogicalExpression arg = args.get(j).getValue();
-                                if (arg.getExpressionTag() == LogicalExpressionTag.FUNCTION_CALL) {
-                                    AbstractFunctionCallExpression argFunc = (AbstractFunctionCallExpression) arg;
-                                    TypeComputerUtilities.setOpenType(argFunc, true);
-                                }
-                            }
-                        }
+                        if(TypeComputerUtilities.getRequiredType(funcExpr)!=null)
+                            return false;
+                        rewriteFuncExpr(funcExpr, requiredRecordType, inputRecordType);
                     }
                     context.computeAndSetTypeEnvironmentForOperator(originalAssign);
                 }
@@ -129,7 +121,24 @@
             else
                 break;
         } while (currentOperator != null);
-        return changed;
+        return true;
+    }
+
+    private void rewriteFuncExpr(ScalarFunctionCallExpression funcExpr, ARecordType requiredRecordType,
+            ARecordType inputRecordType) {
+        TypeComputerUtilities.setRequiredAndInputTypes(funcExpr, requiredRecordType, inputRecordType);
+        staticTypeCast(funcExpr, requiredRecordType, inputRecordType);
+        List<Mutable<ILogicalExpression>> args = funcExpr.getArguments();
+        int openPartStart = requiredRecordType.getFieldTypes().length * 2;
+        if (requiredRecordType.isOpen()) {
+            for (int j = openPartStart; j < args.size(); j++) {
+                ILogicalExpression arg = args.get(j).getValue();
+                if (arg.getExpressionTag() == LogicalExpressionTag.FUNCTION_CALL) {
+                    AbstractFunctionCallExpression argFunc = (AbstractFunctionCallExpression) arg;
+                    TypeComputerUtilities.setOpenType(argFunc, true);
+                }
+            }
+        }
     }
 
     private void staticTypeCast(ScalarFunctionCallExpression func, ARecordType reqType, ARecordType inputType) {
@@ -163,6 +172,12 @@
                         fieldPermutation[j] = i;
                         openFields[i] = false;
                         matched = true;
+
+                        if (fieldType.getTypeTag() == ATypeTag.RECORD) {
+                            ScalarFunctionCallExpression scalarFunc = (ScalarFunctionCallExpression) func
+                                    .getArguments().get(2 * i + 1).getValue();
+                            rewriteFuncExpr(scalarFunc, (ARecordType) reqFieldType, (ARecordType) fieldType);
+                        }
                         break;
                     }
 
@@ -171,13 +186,33 @@
                             && NonTaggedFormatUtil.isOptionalField((AUnionType) reqFieldType)) {
                         IAType itemType = ((AUnionType) reqFieldType).getUnionList().get(
                                 NonTaggedFormatUtil.OPTIONAL_TYPE_INDEX_IN_UNION_LIST);
+                        reqFieldType = itemType;
                         if (fieldType.equals(BuiltinType.ANULL) || fieldType.equals(itemType)) {
                             fieldPermutation[j] = i;
                             openFields[i] = false;
                             matched = true;
+
+                            // rewrite record expr
+                            if (reqFieldType.getTypeTag() == ATypeTag.RECORD
+                                    && fieldType.getTypeTag() == ATypeTag.RECORD) {
+                                ScalarFunctionCallExpression scalarFunc = (ScalarFunctionCallExpression) func
+                                        .getArguments().get(2 * i + 1).getValue();
+                                rewriteFuncExpr(scalarFunc, (ARecordType) reqFieldType, (ARecordType) fieldType);
+                            }
                             break;
                         }
                     }
+
+                    // match the record field: need cast
+                    if (reqFieldType.getTypeTag() == ATypeTag.RECORD && fieldType.getTypeTag() == ATypeTag.RECORD) {
+                        ScalarFunctionCallExpression scalarFunc = (ScalarFunctionCallExpression) func.getArguments()
+                                .get(2 * i + 1).getValue();
+                        rewriteFuncExpr(scalarFunc, (ARecordType) reqFieldType, (ARecordType) fieldType);
+                        fieldPermutation[j] = i;
+                        openFields[i] = false;
+                        matched = true;
+                        break;
+                    }
                 }
             }
             if (matched)
@@ -196,7 +231,7 @@
                 String fieldName = inputFieldNames[j];
                 IAType fieldType = inputFieldTypes[j];
                 if (fieldName.equals(reqFieldName)) {
-                    if (fieldType.equals(reqFieldType)) {
+                    if (!openFields[j]) {
                         matched = true;
                         break;
                     }
@@ -250,7 +285,12 @@
         for (int i = 0; i < openFields.length; i++) {
             if (openFields[i]) {
                 arguments.add(argumentsClone.get(2 * i));
-                arguments.add(argumentsClone.get(2 * i + 1));
+                Mutable<ILogicalExpression> fExprRef = argumentsClone.get(2 * i + 1);
+                if (inputFieldTypes[i].getTypeTag() == ATypeTag.RECORD) {
+                    ScalarFunctionCallExpression funcExpr = (ScalarFunctionCallExpression) fExprRef.getValue();
+                    rewriteFuncExpr(funcExpr, nestedOpenRecType, (ARecordType) inputFieldTypes[i]);
+                }
+                arguments.add(fExprRef);
             }
         }
     }
diff --git a/asterix-app/src/test/resources/runtimets/queries/nestrecords/nestrecord.aql b/asterix-app/src/test/resources/runtimets/queries/nestrecords/nestrecord.aql
new file mode 100644
index 0000000..42d93a8
--- /dev/null
+++ b/asterix-app/src/test/resources/runtimets/queries/nestrecords/nestrecord.aql
@@ -0,0 +1,41 @@
+/* 
+ * Test case Name  : opentype-closed-optional.aql
+ * Description     : verify that closed type can have optional fields
+ * Expected Result : Success
+ */
+
+drop dataverse testdv2 if exists;
+create dataverse testdv2;
+use dataverse testdv2;
+
+
+create type AddressType as open{
+  street: string,
+  city: string
+}
+
+create type testtype as closed {
+  name: string,
+  id: string,
+  address: AddressType?
+}
+
+create dataset testds(testtype) partitioned by key id;
+
+insert into dataset testds (
+{ "id": "001", "name": "Person One", "address": {"street": "3019 DBH",  "city": "Irvine", "zip": 92697} }
+);
+
+insert into dataset testds (
+{ "id": "002", "name": "Person Two" }
+);
+
+insert into dataset testds (
+{ "id": "003", "name": "Person Three", "address": {"street": "2019 DBH",  "city": "Irvine"} }
+);
+
+write output to nc1:"rttest/nestrecords_nestrecord.adm";
+
+for $d in dataset("testds") 
+order by $d.id
+return $d
diff --git a/asterix-app/src/test/resources/runtimets/results/nestrecords/nestrecord.adm b/asterix-app/src/test/resources/runtimets/results/nestrecords/nestrecord.adm
new file mode 100644
index 0000000..3a67eae
--- /dev/null
+++ b/asterix-app/src/test/resources/runtimets/results/nestrecords/nestrecord.adm
@@ -0,0 +1,3 @@
+{ "name": "Person One", "id": "001", "address": { "street": "3019 DBH", "city": "Irvine", "zip": 92697 } }
+{ "name": "Person Two", "id": "002", "address": null }
+{ "name": "Person Three", "id": "003", "address": { "street": "2019 DBH", "city": "Irvine" } }
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/accessors/ARecordAccessor.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/accessors/ARecordAccessor.java
index 6bd798d..e05c67a 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/accessors/ARecordAccessor.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/accessors/ARecordAccessor.java
@@ -124,6 +124,7 @@
     private void reset() {
         typeBos.setByteArray(typeBuffer, closedPartTypeInfoSize);
         dataBos.setByteArray(dataBuffer, 0);
+        //reset the allocator
         allocator.reset();
 
         //clean up the returned containers
diff --git a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/accessors/cast/ARecordCaster.java b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/accessors/cast/ARecordCaster.java
index 8073c91..a3f31bb 100644
--- a/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/accessors/cast/ARecordCaster.java
+++ b/asterix-runtime/src/main/java/edu/uci/ics/asterix/runtime/accessors/cast/ARecordCaster.java
@@ -22,6 +22,7 @@
 import java.util.List;
 
 import edu.uci.ics.asterix.builders.RecordBuilder;
+import edu.uci.ics.asterix.common.exceptions.AsterixException;
 import edu.uci.ics.asterix.dataflow.data.nontagged.AqlNullWriterFactory;
 import edu.uci.ics.asterix.om.types.ARecordType;
 import edu.uci.ics.asterix.om.types.ATypeTag;
@@ -32,6 +33,7 @@
 import edu.uci.ics.asterix.runtime.accessors.ARecordAccessor;
 import edu.uci.ics.asterix.runtime.accessors.base.IBinaryAccessor;
 import edu.uci.ics.asterix.runtime.util.ResettableByteArrayOutputStream;
+import edu.uci.ics.hyracks.algebricks.common.utils.Triple;
 import edu.uci.ics.hyracks.api.dataflow.value.IBinaryComparator;
 import edu.uci.ics.hyracks.api.dataflow.value.INullWriter;
 import edu.uci.ics.hyracks.data.std.accessors.PointableBinaryComparatorFactory;
@@ -69,6 +71,10 @@
     private ResettableByteArrayOutputStream outputBos = new ResettableByteArrayOutputStream();
     private DataOutputStream outputDos = new DataOutputStream(outputBos);
 
+    private IBinaryAccessor fieldTempReference = AFlatValueAccessor.FACTORY.createElement(null);
+    private Triple<IBinaryAccessor, IAType, Boolean> nestedVisitorArg = new Triple<IBinaryAccessor, IAType, Boolean>(
+            fieldTempReference, null, null);
+
     public ARecordCaster() {
         try {
             bos.setByteArray(buffer, 0);
@@ -87,7 +93,7 @@
     }
 
     public void castRecord(ARecordAccessor recordAccessor, IBinaryAccessor resultAccessor, ARecordType reqType,
-            ACastVisitor visitor) throws IOException {
+            ACastVisitor visitor) throws IOException, AsterixException {
         List<IBinaryAccessor> fieldNames = recordAccessor.getFieldNames();
         List<IBinaryAccessor> fieldTypeTags = recordAccessor.getFieldTypeTags();
         List<IBinaryAccessor> fieldValues = recordAccessor.getFieldValues();
@@ -215,7 +221,8 @@
     }
 
     private void writeOutput(List<IBinaryAccessor> fieldNames, List<IBinaryAccessor> fieldTypeTags,
-            List<IBinaryAccessor> fieldValues, DataOutput output, ACastVisitor visitor) throws IOException {
+            List<IBinaryAccessor> fieldValues, DataOutput output, ACastVisitor visitor) throws IOException,
+            AsterixException {
         // reset the states of the record builder
         recBuilder.reset(cachedReqType);
         recBuilder.init();
@@ -229,7 +236,12 @@
             } else {
                 field = nullReference;
             }
-            recBuilder.addField(i, field);
+            IAType fType = cachedReqType.getFieldTypes()[i];
+            nestedVisitorArg.second = fType;
+            
+            //recursively casting, the result of casting can always be thought as flat
+            field.accept(visitor, nestedVisitorArg);
+            recBuilder.addField(i, nestedVisitorArg.first);
         }
 
         // write the open part