def test_neighbors(self):
    """Test neighbors()."""
    # Create [{1, 2}, {2}, set()].
    a_graph = DirectedGraph()
    a_graph.connect(0, 1)
    a_graph.connect(0, 2)
    a_graph.connect(1, 2)
    # Check neighbors on existing vertices.
    self.assertEqual(a_graph.neighbors(0), {1, 2})
    self.assertEqual(a_graph.neighbors(1), {2})
    self.assertEqual(a_graph.neighbors(2), set())
    # Check neighbors on non-existing vertices: they are expected to have
    # no neighbors rather than raise.
    self.assertEqual(a_graph.neighbors(3), set())
    self.assertEqual(a_graph.neighbors(4), set())
def test_linked_list(self):
    """Topologically sort the chain 0 -> 1 -> 2.

    A chain admits exactly one valid ordering, in which every vertex
    appears after the vertex it points to.
    """
    graph = DirectedGraph()
    for tail, head in ((0, 1), (1, 2)):
        graph.connect(tail, head)

    order = TopologicalSort(graph).sort()
    self.assertEqual(order, (2, 1, 0))

    # Cross-check the ordering against path reachability.
    reachability = Reachability(graph)
    self.assertTrue(self._sorted(order, reachability))
def test_binary_tree(self):
    """Test public methods on a binary tree."""
    # Build a binary tree:
    #     0
    #    / \
    #   1   2
    a_graph = DirectedGraph()
    a_graph.connect(0, 1)
    a_graph.connect(0, 2)
    # Sort.
    sorted_vertices = TopologicalSort(a_graph).sort()
    # The two leaves are unordered relative to each other, so either
    # sequence is correct; the root always comes last. assertIn reports
    # the actual sequence on failure, unlike assertTrue(x in ...).
    self.assertIn(sorted_vertices, {(2, 1, 0), (1, 2, 0)})
    # Check the result using a graph.Reachability object.
    a_checker = Reachability(a_graph)
    self.assertTrue(self._sorted(sorted_vertices, a_checker))
def test_binary_tree(self):
    """Check has_path() on a root with two children."""
    # Build a binary tree:
    #     0
    #    / \
    #   1   2
    graph = DirectedGraph()
    for child in (1, 2):
        graph.connect(0, child)
    reach = Reachability(graph)

    # The root reaches both leaves.
    for leaf in (1, 2):
        self.assertTrue(reach.has_path(0, leaf))
    # Sibling leaves cannot reach each other.
    self.assertFalse(reach.has_path(1, 2))
    self.assertFalse(reach.has_path(2, 1))
    # No leaf reaches back up to the root.
    for leaf in (1, 2):
        self.assertFalse(reach.has_path(leaf, 0))
def test_implicit_adding_by_connecting(self):
    """Test connect(), which implicitly calls add()."""
    # Create [].
    a_graph = DirectedGraph()
    # Implicitly add 0, 1, 2 to [], then connect 1 with 2,
    # which makes [set(), {2}, set()].
    a_graph.connect(1, 2)
    self.assertEqual(a_graph.n_vertices(), 3)
    # Trivially connected vertices.
    self.assertTrue(a_graph.connected(0, 0))
    self.assertTrue(a_graph.connected(1, 1))
    self.assertTrue(a_graph.connected(2, 2))
    # Uni-directional connection created by connect().
    self.assertTrue(a_graph.connected(1, 2))
    self.assertFalse(a_graph.connected(2, 1))
    # Disconnected vertices.
    self.assertFalse(a_graph.connected(0, 1))
    self.assertFalse(a_graph.connected(1, 0))
    self.assertFalse(a_graph.connected(0, 2))
    self.assertFalse(a_graph.connected(2, 0))
def test_linked_list(self):
    """Check has_path() along the chain 0 -> 1 -> 2."""
    graph = DirectedGraph()
    graph.connect(0, 1)
    graph.connect(1, 2)
    reach = Reachability(graph)

    # Every vertex reaches itself.
    for vertex in range(3):
        self.assertTrue(reach.has_path(vertex, vertex))
    # Following edges forward (directly or transitively) succeeds ...
    for source, target in ((0, 1), (1, 2), (0, 2)):
        self.assertTrue(reach.has_path(source, target))
    # ... but travelling backward against the edges never does.
    for source, target in ((1, 0), (2, 0), (2, 1)):
        self.assertFalse(reach.has_path(source, target))
def setUp(self):
    """Build one acyclic and one cyclic graph fixture."""
    # Acyclic: 0 fans out to 1, 2, 3; then 3 -> 4 -> 5.
    dag = DirectedGraph()
    for node in range(6):
        dag.add_node(node)
    for source, target in ((0, 1), (0, 2), (0, 3), (3, 4), (4, 5)):
        dag.connect(source, target)
    self.dag = dag

    # Cyclic: 0 -> 1 -> 2 -> 4 -> 0 closes a loop; 1 -> 3 and 2 -> 3 hang off it.
    cyclic = DirectedGraph()
    for node in range(5):
        cyclic.add_node(node)
    for source, target in ((0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (4, 0)):
        cyclic.connect(source, target)
    self.cyclic = cyclic
class Scheduler:
    """A scheduler supporting O(1) adding and O(N) scheduling.

    Each task is mapped to an integer id equal to the number of tasks
    added before it. Every prerequisite relation is recorded twice: as a
    directed edge ``task -> prerequisite`` in a DirectedGraph (used for
    ordering) and as a union in a UnionFind (used to split the tasks into
    independent groups). Tasks must be hashable.
    """

    def __init__(self):
        self._task_to_id = {}  # task -> integer id
        self._id_to_task = []  # integer id -> task (the list index is the id)
        self._graph = DirectedGraph()
        self._union = UnionFind()

    def add_a_task(self, task):
        """Add a new task.

        Do nothing, if the task has already been added.
        """
        if task not in self._task_to_id:
            # Ids are dense: the next id is the current task count.
            i_task = self.n_tasks()
            self._task_to_id[task] = i_task
            self._id_to_task.append(task)
        # Internal sanity checks that the two mappings stay inverse.
        assert len(self._id_to_task) == len(self._task_to_id)
        assert task == self._id_to_task[self._task_to_id[task]]

    def add_tasks(self, tasks):
        """Add multiple tasks."""
        for task in tasks:
            self.add_a_task(task)

    def n_tasks(self):
        """Return the number of tasks being added."""
        return len(self._task_to_id)

    def add_a_prerequisite(self, task, prerequisite):
        """Add a prerequisite for a task.

        Automatically add a new task, if any of the two is new.
        Do nothing, if the prerequisite has already been added.
        """
        self.add_a_task(task)
        self.add_a_task(prerequisite)
        i_task = self._task_to_id[task]
        i_prerequisite = self._task_to_id[prerequisite]
        # The edge points from the dependent task to what it needs.
        self._graph.connect(i_task, i_prerequisite)
        self._union.connect(i_task, i_prerequisite)

    def add_prerequisites(self, task, prerequisites):
        """Add multiple prerequisites for a task."""
        for prerequisite in prerequisites:
            self.add_a_prerequisite(task, prerequisite)

    def schedule(self):
        """Group the tasks into independent components, each sorted.

        Returns:
            A set of tuples. Each tuple holds the tasks of one component,
            meaning tasks linked directly or transitively by prerequisite
            relations. Within a tuple, tasks keep the order produced by
            TopologicalSort on the ``task -> prerequisite`` graph,
            presumably with prerequisites before the tasks needing them.
            The set imposes no order between components.

        NOTE(review): a task added via add_a_task() that never takes part
        in a prerequisite relation is never passed to the graph or the
        union-find. Whether it shows up here depends on DirectedGraph and
        UnionFind implicitly covering unseen ids. Confirm.
        """
        sorted_tasks = TopologicalSort(self._graph).sort()
        # Tuples make each component immutable and hashable, so they can
        # be stored in a set.
        return {
            tuple(a_component)
            for a_component in self._to_components(sorted_tasks)
        }

    def _to_components(self, sorted_tasks):
        """Bucket task ids by union-find root, preserving their order.

        Returns a dict values view of lists of tasks (not ids), one list
        per union-find root.
        """
        root_to_component = {}
        for i_task in sorted_tasks:
            component = root_to_component.setdefault(
                self._union.root(i_task), []
            )
            component.append(self._id_to_task[i_task])
        return root_to_component.values()