Improved task attempt lookup in job scheduler

git-svn-id: https://hyracks.googlecode.com/svn/branches/hyracks_dev_next@1202 123451ca-8445-de46-9d55-352943316053
diff --git a/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/job/JobRun.java b/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/job/JobRun.java
index a72a866..707ef1f 100644
--- a/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/job/JobRun.java
+++ b/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/job/JobRun.java
@@ -320,7 +320,7 @@
                             attempt.put("end-time", tca.getEndTime());
 
                             JSONArray taskAttempts = new JSONArray();
-                            for (TaskAttempt ta : tca.getTaskAttempts()) {
+                            for (TaskAttempt ta : tca.getTaskAttempts().values()) {
                                 JSONObject taskAttempt = new JSONObject();
                                 taskAttempt.put("task-id", ta.getTaskAttemptId().getTaskId());
                                 taskAttempt.put("task-attempt-id", ta.getTaskAttemptId());
diff --git a/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/job/TaskClusterAttempt.java b/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/job/TaskClusterAttempt.java
index 84848bb..2b74585 100644
--- a/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/job/TaskClusterAttempt.java
+++ b/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/job/TaskClusterAttempt.java
@@ -14,6 +14,10 @@
  */
 package edu.uci.ics.hyracks.control.cc.job;
 
+import java.util.Map;
+
+import edu.uci.ics.hyracks.api.dataflow.TaskId;
+
 public class TaskClusterAttempt {
     public enum TaskClusterStatus {
         RUNNING,
@@ -26,7 +30,7 @@
 
     private final int attempt;
 
-    private TaskAttempt[] taskAttempts;
+    private Map<TaskId, TaskAttempt> taskAttempts;
 
     private TaskClusterStatus status;
 
@@ -47,11 +51,11 @@
         return taskCluster;
     }
 
-    public void setTaskAttempts(TaskAttempt[] taskAttempts) {
+    public void setTaskAttempts(Map<TaskId, TaskAttempt> taskAttempts) {
         this.taskAttempts = taskAttempts;
     }
 
-    public TaskAttempt[] getTaskAttempts() {
+    public Map<TaskId, TaskAttempt> getTaskAttempts() {
         return taskAttempts;
     }
 
@@ -84,7 +88,7 @@
     }
 
     public void initializePendingTaskCounter() {
-        pendingTaskCounter = taskAttempts.length;
+        pendingTaskCounter = taskAttempts.size();
     }
 
     public int getPendingTaskCounter() {
diff --git a/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/scheduler/JobScheduler.java b/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/scheduler/JobScheduler.java
index eada1f5..a4e340c 100644
--- a/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/scheduler/JobScheduler.java
+++ b/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/scheduler/JobScheduler.java
@@ -26,6 +26,7 @@
 import java.util.logging.Logger;
 
 import edu.uci.ics.hyracks.api.application.ICCApplicationContext;
+import edu.uci.ics.hyracks.api.comm.NetworkAddress;
 import edu.uci.ics.hyracks.api.constraints.Constraint;
 import edu.uci.ics.hyracks.api.constraints.IConstraintAcceptor;
 import edu.uci.ics.hyracks.api.constraints.expressions.LValueConstraintExpression;
@@ -47,6 +48,7 @@
 import edu.uci.ics.hyracks.control.cc.ClusterControllerService;
 import edu.uci.ics.hyracks.control.cc.NodeControllerState;
 import edu.uci.ics.hyracks.control.cc.job.ActivityCluster;
+import edu.uci.ics.hyracks.control.cc.job.ActivityPlan;
 import edu.uci.ics.hyracks.control.cc.job.IConnectorDescriptorVisitor;
 import edu.uci.ics.hyracks.control.cc.job.IOperatorDescriptorVisitor;
 import edu.uci.ics.hyracks.control.cc.job.JobRun;
@@ -326,7 +328,7 @@
         List<TaskClusterAttempt> tcAttempts = tc.getAttempts();
         int attempts = tcAttempts.size();
         TaskClusterAttempt tcAttempt = new TaskClusterAttempt(tc, attempts);
-        TaskAttempt[] taskAttempts = new TaskAttempt[tasks.length];
+        Map<TaskId, TaskAttempt> taskAttempts = new HashMap<TaskId, TaskAttempt>();
         Map<TaskId, LValueConstraintExpression> locationMap = new HashMap<TaskId, LValueConstraintExpression>();
         for (int i = 0; i < tasks.length; ++i) {
             Task ts = tasks[i];
@@ -336,57 +338,15 @@
             taskAttempt.setStatus(TaskAttempt.TaskStatus.INITIALIZED, null);
             locationMap.put(tid,
                     new PartitionLocationExpression(tid.getActivityId().getOperatorDescriptorId(), tid.getPartition()));
-            taskAttempts[i] = taskAttempt;
+            taskAttempts.put(tid, taskAttempt);
         }
         tcAttempt.setTaskAttempts(taskAttempts);
         solver.solve(locationMap.values());
         for (int i = 0; i < tasks.length; ++i) {
             Task ts = tasks[i];
             TaskId tid = ts.getTaskId();
-            TaskAttempt taskAttempt = taskAttempts[i];
-            ActivityId aid = tid.getActivityId();
-            Set<ActivityId> blockers = jag.getBlocked2BlockerMap().get(aid);
-            String nodeId = null;
-            if (blockers != null) {
-                for (ActivityId blocker : blockers) {
-                    nodeId = findLocationOfBlocker(jobRun, jag, new TaskId(blocker, tid.getPartition()));
-                    if (nodeId != null) {
-                        break;
-                    }
-                }
-            }
-            Set<String> liveNodes = ccs.getNodeMap().keySet();
-            if (nodeId == null) {
-                LValueConstraintExpression pLocationExpr = locationMap.get(tid);
-                Object location = solver.getValue(pLocationExpr);
-                if (location == null) {
-                    // pick any
-                    nodeId = liveNodes.toArray(new String[liveNodes.size()])[Math.abs(new Random().nextInt())
-                            % liveNodes.size()];
-                } else if (location instanceof String) {
-                    nodeId = (String) location;
-                } else if (location instanceof String[]) {
-                    for (String choice : (String[]) location) {
-                        if (liveNodes.contains(choice)) {
-                            nodeId = choice;
-                            break;
-                        }
-                    }
-                    if (nodeId == null) {
-                        throw new HyracksException("No satisfiable location found for "
-                                + taskAttempt.getTaskAttemptId());
-                    }
-                } else {
-                    throw new HyracksException("Unknown type of value for " + pLocationExpr + ": " + location + "("
-                            + location.getClass() + ")");
-                }
-            }
-            if (nodeId == null) {
-                throw new HyracksException("No satisfiable location found for " + taskAttempt.getTaskAttemptId());
-            }
-            if (!liveNodes.contains(nodeId)) {
-                throw new HyracksException("Node " + nodeId + " not live");
-            }
+            TaskAttempt taskAttempt = taskAttempts.get(tid);
+            String nodeId = assignLocation(jag, locationMap, tid, taskAttempt);
             taskAttempt.setNodeId(nodeId);
             taskAttempt.setStatus(TaskAttempt.TaskStatus.RUNNING, null);
             taskAttempt.setStartTime(System.currentTimeMillis());
@@ -396,32 +356,100 @@
                 taskAttemptMap.put(nodeId, tads);
             }
             ActivityPartitionDetails apd = ts.getActivityPlan().getActivityPartitionDetails();
-            tads.add(new TaskAttemptDescriptor(taskAttempt.getTaskAttemptId(), apd.getPartitionCount(), apd
-                    .getInputPartitionCounts(), apd.getOutputPartitionCounts()));
+            TaskAttemptDescriptor tad = new TaskAttemptDescriptor(taskAttempt.getTaskAttemptId(),
+                    apd.getPartitionCount(), apd.getInputPartitionCounts(), apd.getOutputPartitionCounts());
+            tads.add(tad);
         }
         tcAttempt.initializePendingTaskCounter();
         tcAttempts.add(tcAttempt);
+
+        /* TODO - Further improvement for reducing messages -- not yet complete.
+        for (Map.Entry<String, List<TaskAttemptDescriptor>> e : taskAttemptMap.entrySet()) {
+            List<TaskAttemptDescriptor> tads = e.getValue();
+            for (TaskAttemptDescriptor tad : tads) {
+                TaskId tid = tad.getTaskAttemptId().getTaskId();
+                ActivityId aid = tid.getActivityId();
+                List<IConnectorDescriptor> inConnectors = jag.getActivityInputConnectorDescriptors(aid);
+                int[] inPartitionCounts = tad.getInputPartitionCounts();
+                NetworkAddress[][] partitionLocations = new NetworkAddress[inPartitionCounts.length][];
+                for (int i = 0; i < inPartitionCounts.length; ++i) {
+                    ConnectorDescriptorId cdId = inConnectors.get(i).getConnectorId();
+                    ActivityId producerAid = jag.getProducerActivity(cdId);
+                    partitionLocations[i] = new NetworkAddress[inPartitionCounts[i]];
+                    for (int j = 0; j < inPartitionCounts[i]; ++j) {
+                        TaskId producerTaskId = new TaskId(producerAid, j);
+                        String nodeId = findTaskLocation(producerTaskId);
+                        partitionLocations[i][j] = ccs.getNodeMap().get(nodeId).getDataPort();
+                    }
+                }
+                tad.setInputPartitionLocations(partitionLocations);
+            }
+        }
+        */
+
         tcAttempt.setStatus(TaskClusterAttempt.TaskClusterStatus.RUNNING);
         tcAttempt.setStartTime(System.currentTimeMillis());
         inProgressTaskClusters.add(tc);
     }
 
-    private static String findLocationOfBlocker(JobRun jobRun, JobActivityGraph jag, TaskId tid) {
-        ActivityId blockerAID = tid.getActivityId();
-        ActivityCluster blockerAC = jobRun.getActivityClusterMap().get(blockerAID);
-        Task[] blockerTasks = blockerAC.getPlan().getActivityPlanMap().get(blockerAID).getTasks();
-        List<TaskClusterAttempt> tcAttempts = blockerTasks[tid.getPartition()].getTaskCluster().getAttempts();
+    private String assignLocation(JobActivityGraph jag, Map<TaskId, LValueConstraintExpression> locationMap,
+            TaskId tid, TaskAttempt taskAttempt) throws HyracksException {
+        ActivityId aid = tid.getActivityId();
+        Set<ActivityId> blockers = jag.getBlocked2BlockerMap().get(aid);
+        String nodeId = null;
+        if (blockers != null) {
+            for (ActivityId blocker : blockers) {
+                nodeId = findTaskLocation(new TaskId(blocker, tid.getPartition()));
+                if (nodeId != null) {
+                    break;
+                }
+            }
+        }
+        Set<String> liveNodes = ccs.getNodeMap().keySet();
+        if (nodeId == null) {
+            LValueConstraintExpression pLocationExpr = locationMap.get(tid);
+            Object location = solver.getValue(pLocationExpr);
+            if (location == null) {
+                // pick any
+                nodeId = liveNodes.toArray(new String[liveNodes.size()])[Math.abs(new Random().nextInt())
+                        % liveNodes.size()];
+            } else if (location instanceof String) {
+                nodeId = (String) location;
+            } else if (location instanceof String[]) {
+                for (String choice : (String[]) location) {
+                    if (liveNodes.contains(choice)) {
+                        nodeId = choice;
+                        break;
+                    }
+                }
+                if (nodeId == null) {
+                    throw new HyracksException("No satisfiable location found for " + taskAttempt.getTaskAttemptId());
+                }
+            } else {
+                throw new HyracksException("Unknown type of value for " + pLocationExpr + ": " + location + "("
+                        + location.getClass() + ")");
+            }
+        }
+        if (nodeId == null) {
+            throw new HyracksException("No satisfiable location found for " + taskAttempt.getTaskAttemptId());
+        }
+        if (!liveNodes.contains(nodeId)) {
+            throw new HyracksException("Node " + nodeId + " not live");
+        }
+        return nodeId;
+    }
+
+    private String findTaskLocation(TaskId tid) {
+        ActivityId aid = tid.getActivityId();
+        ActivityCluster ac = jobRun.getActivityClusterMap().get(aid);
+        Task[] tasks = ac.getPlan().getActivityPlanMap().get(aid).getTasks();
+        List<TaskClusterAttempt> tcAttempts = tasks[tid.getPartition()].getTaskCluster().getAttempts();
         if (tcAttempts == null || tcAttempts.isEmpty()) {
             return null;
         }
         TaskClusterAttempt lastTCA = tcAttempts.get(tcAttempts.size() - 1);
-        for (TaskAttempt ta : lastTCA.getTaskAttempts()) {
-            TaskId blockerTID = ta.getTaskAttemptId().getTaskId();
-            if (tid.equals(blockerTID)) {
-                return ta.getNodeId();
-            }
-        }
-        return null;
+        TaskAttempt ta = lastTCA.getTaskAttempts().get(tid);
+        return ta == null ? null : ta.getNodeId();
     }
 
     private static TaskClusterAttempt findLastTaskClusterAttempt(TaskCluster tc) {
@@ -470,7 +498,7 @@
         LOGGER.fine("Aborting task cluster: " + tcAttempt.getAttempt());
         Set<TaskAttemptId> abortTaskIds = new HashSet<TaskAttemptId>();
         Map<String, List<TaskAttemptId>> abortTaskAttemptMap = new HashMap<String, List<TaskAttemptId>>();
-        for (TaskAttempt ta : tcAttempt.getTaskAttempts()) {
+        for (TaskAttempt ta : tcAttempt.getTaskAttempts().values()) {
             TaskAttemptId taId = ta.getTaskAttemptId();
             TaskAttempt.TaskStatus status = ta.getStatus();
             abortTaskIds.add(taId);
@@ -646,7 +674,7 @@
                                 && (lastTaskClusterAttempt.getStatus() == TaskClusterAttempt.TaskClusterStatus.COMPLETED || lastTaskClusterAttempt
                                         .getStatus() == TaskClusterAttempt.TaskClusterStatus.RUNNING)) {
                             boolean abort = false;
-                            for (TaskAttempt ta : lastTaskClusterAttempt.getTaskAttempts()) {
+                            for (TaskAttempt ta : lastTaskClusterAttempt.getTaskAttempts().values()) {
                                 assert (ta.getStatus() == TaskAttempt.TaskStatus.COMPLETED || ta.getStatus() == TaskAttempt.TaskStatus.RUNNING);
                                 if (deadNodes.contains(ta.getNodeId())) {
                                     ta.setStatus(TaskAttempt.TaskStatus.FAILED, "Node " + ta.getNodeId() + " failed");
diff --git a/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/work/AbstractTaskLifecycleWork.java b/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/work/AbstractTaskLifecycleWork.java
index a00f4c4..01f14f3 100644
--- a/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/work/AbstractTaskLifecycleWork.java
+++ b/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/work/AbstractTaskLifecycleWork.java
@@ -60,11 +60,9 @@
                     List<TaskClusterAttempt> taskClusterAttempts = tc.getAttempts();
                     if (taskClusterAttempts != null && taskClusterAttempts.size() > taId.getAttempt()) {
                         TaskClusterAttempt tca = taskClusterAttempts.get(taId.getAttempt());
-                        for (TaskAttempt ta : tca.getTaskAttempts()) {
-                            if (ta.getTaskAttemptId().equals(taId)) {
-                                performEvent(ta);
-                                break;
-                            }
+                        TaskAttempt ta = tca.getTaskAttempts().get(tid);
+                        if (ta != null) {
+                            performEvent(ta);
                         }
                     }
                 }
diff --git a/hyracks-control-common/src/main/java/edu/uci/ics/hyracks/control/common/job/TaskAttemptDescriptor.java b/hyracks-control-common/src/main/java/edu/uci/ics/hyracks/control/common/job/TaskAttemptDescriptor.java
index af89489..f6d1f78 100644
--- a/hyracks-control-common/src/main/java/edu/uci/ics/hyracks/control/common/job/TaskAttemptDescriptor.java
+++ b/hyracks-control-common/src/main/java/edu/uci/ics/hyracks/control/common/job/TaskAttemptDescriptor.java
@@ -17,6 +17,7 @@
 import java.io.Serializable;
 import java.util.Arrays;
 
+import edu.uci.ics.hyracks.api.comm.NetworkAddress;
 import edu.uci.ics.hyracks.api.dataflow.TaskAttemptId;
 
 public class TaskAttemptDescriptor implements Serializable {
@@ -30,6 +31,8 @@
 
     private final int[] nOutputPartitions;
 
+    private NetworkAddress[][] inputPartitionLocations;
+
     public TaskAttemptDescriptor(TaskAttemptId taId, int nPartitions, int[] nInputPartitions, int[] nOutputPartitions) {
         this.taId = taId;
         this.nPartitions = nPartitions;
@@ -53,6 +56,14 @@
         return nOutputPartitions;
     }
 
+    public void setInputPartitionLocations(NetworkAddress[][] inputPartitionLocations) {
+        this.inputPartitionLocations = inputPartitionLocations;
+    }
+
+    public NetworkAddress[][] getInputPartitionLocations() {
+        return inputPartitionLocations;
+    }
+
     @Override
     public String toString() {
         return "TaskAttemptDescriptor[taId = " + taId + ", nPartitions = " + nPartitions + ", nInputPartitions = "