Made transitive closure to find co-scheduled tasks more efficient

git-svn-id: https://hyracks.googlecode.com/svn/branches/hyracks_dev_next@806 123451ca-8445-de46-9d55-352943316053
diff --git a/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/scheduler/ActivityClusterPlanner.java b/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/scheduler/ActivityClusterPlanner.java
index e5d9f64..c721eca 100644
--- a/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/scheduler/ActivityClusterPlanner.java
+++ b/hyracks-control-cc/src/main/java/edu/uci/ics/hyracks/control/cc/scheduler/ActivityClusterPlanner.java
@@ -127,10 +127,11 @@
                         }
                         Set<TaskId> cluster = taskClusterMap.get(ac1TaskStates[i].getTaskId());
                         for (int j = targetBitmap.nextSetBit(0); j >= 0; j = targetBitmap.nextSetBit(j + 1)) {
-                            cInfoList.add(new Pair<TaskId, ConnectorDescriptorId>(ac2TaskStates[j].getTaskId(), cdId));
+                            TaskId targetTID = ac2TaskStates[j].getTaskId();
+                            cInfoList.add(new Pair<TaskId, ConnectorDescriptorId>(targetTID, cdId));
                             IConnectorPolicy cPolicy = connectorPolicies.get(cdId);
                             if (cPolicy.requiresProducerConsumerCoscheduling()) {
-                                cluster.add(ac2TaskStates[j].getTaskId());
+                                cluster.add(targetTID);
                             }
                         }
                     }
@@ -138,50 +139,7 @@
             }
         }
 
-        boolean done = false;
-        while (!done) {
-            done = true;
-            Set<TaskId> set = new HashSet<TaskId>();
-            Set<TaskId> oldSet = null;
-            for (Map.Entry<TaskId, Set<TaskId>> e : taskClusterMap.entrySet()) {
-                set.clear();
-                oldSet = e.getValue();
-                set.addAll(e.getValue());
-                for (TaskId tid : e.getValue()) {
-                    set.addAll(taskClusterMap.get(tid));
-                }
-                for (TaskId tid : set) {
-                    Set<TaskId> targetSet = taskClusterMap.get(tid);
-                    if (!targetSet.equals(set)) {
-                        done = false;
-                        break;
-                    }
-                }
-                if (!done) {
-                    break;
-                }
-            }
-            for (TaskId tid : oldSet) {
-                taskClusterMap.put(tid, set);
-            }
-        }
-
-        Set<Set<TaskId>> clusters = new HashSet<Set<TaskId>>(taskClusterMap.values());
-        Set<TaskCluster> tcSet = new HashSet<TaskCluster>();
-        int counter = 0;
-        for (Set<TaskId> cluster : clusters) {
-            Set<Task> taskStates = new HashSet<Task>();
-            for (TaskId tid : cluster) {
-                taskStates.add(activityPlanMap.get(tid.getActivityId()).getTasks()[tid.getPartition()]);
-            }
-            TaskCluster tc = new TaskCluster(new TaskClusterId(ac.getActivityClusterId(), counter++), ac,
-                    taskStates.toArray(new Task[taskStates.size()]));
-            tcSet.add(tc);
-            for (TaskId tid : cluster) {
-                activityPlanMap.get(tid.getActivityId()).getTasks()[tid.getPartition()].setTaskCluster(tc);
-            }
-        }
-        TaskCluster[] taskClusters = tcSet.toArray(new TaskCluster[tcSet.size()]);
+        TaskCluster[] taskClusters = buildTaskClusters(ac, activityPlanMap, taskClusterMap);
 
         for (TaskCluster tc : taskClusters) {
             Set<TaskCluster> tcDependencyTaskClusters = tc.getDependencyTaskClusters();
@@ -213,8 +171,8 @@
 
         if (LOGGER.isLoggable(Level.INFO)) {
             LOGGER.info("Plan for " + ac);
-            LOGGER.info("Built " + tcSet.size() + " Task Clusters");
-            for (TaskCluster tc : tcSet) {
+            LOGGER.info("Built " + taskClusters.length + " Task Clusters");
+            for (TaskCluster tc : taskClusters) {
                 LOGGER.info("Tasks: " + Arrays.toString(tc.getTasks()));
             }
         }
@@ -222,6 +180,80 @@
         ac.setPlan(new ActivityClusterPlan(taskClusters, activityPlanMap));
     }
 
+    private TaskCluster[] buildTaskClusters(ActivityCluster ac, Map<ActivityId, ActivityPlan> activityPlanMap,
+            Map<TaskId, Set<TaskId>> taskClusterMap) {
+        /*
+         * taskClusterMap contains for every TID x, x -> { coscheduled consumer TIDs U x }
+         * We compute the transitive closure of this relation to find the largest set of
+         * tasks that need to be co-scheduled
+         */
+        int counter = 0;
+        TaskId[] ordinalList = new TaskId[taskClusterMap.size()];
+        Map<TaskId, Integer> ordinalMap = new HashMap<TaskId, Integer>();
+        for (TaskId tid : taskClusterMap.keySet()) {
+            ordinalList[counter] = tid;
+            ordinalMap.put(tid, counter);
+            ++counter;
+        }
+
+        int n = ordinalList.length;
+        BitSet[] paths = new BitSet[n];
+        for (Map.Entry<TaskId, Set<TaskId>> e : taskClusterMap.entrySet()) {
+            int i = ordinalMap.get(e.getKey());
+            BitSet bsi = paths[i];
+            if (bsi == null) {
+                bsi = new BitSet(n);
+                paths[i] = bsi;
+            }
+            for (TaskId ttid : e.getValue()) {
+                int j = ordinalMap.get(ttid);
+                paths[i].set(j);
+                BitSet bsj = paths[j];
+                if (bsj == null) {
+                    bsj = new BitSet(n);
+                    paths[j] = bsj;
+                }
+                bsj.set(i);
+            }
+        }
+        for (int k = 0; k < n; ++k) {
+            for (int i = paths[k].nextSetBit(0); i >= 0; i = paths[k].nextSetBit(i + 1)) {
+                for (int j = paths[i].nextClearBit(0); j < n && j >= 0; j = paths[i].nextClearBit(j + 1)) {
+                    paths[i].set(j, paths[k].get(j));
+                    paths[j].set(i, paths[i].get(j));
+                }
+            }
+        }
+        BitSet pending = new BitSet(n);
+        pending.set(0, n);
+        List<List<TaskId>> clusters = new ArrayList<List<TaskId>>();
+        for (int i = pending.nextSetBit(0); i >= 0; i = pending.nextSetBit(i)) {
+            List<TaskId> cluster = new ArrayList<TaskId>();
+            for (int j = paths[i].nextSetBit(0); j >= 0; j = paths[i].nextSetBit(j + 1)) {
+                cluster.add(ordinalList[j]);
+                pending.clear(j);
+            }
+            clusters.add(cluster);
+        }
+
+        List<TaskCluster> tcSet = new ArrayList<TaskCluster>();
+        counter = 0;
+        for (List<TaskId> cluster : clusters) {
+            List<Task> taskStates = new ArrayList<Task>();
+            for (TaskId tid : cluster) {
+                taskStates.add(activityPlanMap.get(tid.getActivityId()).getTasks()[tid.getPartition()]);
+            }
+            TaskCluster tc = new TaskCluster(new TaskClusterId(ac.getActivityClusterId(), counter++), ac,
+                    taskStates.toArray(new Task[taskStates.size()]));
+            tcSet.add(tc);
+            for (TaskId tid : cluster) {
+                activityPlanMap.get(tid.getActivityId()).getTasks()[tid.getPartition()].setTaskCluster(tc);
+            }
+        }
+        TaskCluster[] taskClusters = tcSet.toArray(new TaskCluster[tcSet.size()]);
+        return taskClusters;
+    }
+
     private TaskCluster getTaskCluster(TaskId tid) {
         ActivityCluster ac = scheduler.getJobRun().getActivityClusterMap().get(tid.getActivityId());
         ActivityClusterPlan acp = ac.getPlan();