def build_graph_from_outputs(outputs: Iterable[DataPlaceholder]) -> DiGraph:
    """Construct the fit-time graph reachable backwards from ``outputs``.

    Starting at each given data placeholder, the step that produced it is
    registered as a node, and then (depth-first) the producers of that step's
    inputs and of its targets are visited in turn. The walk stops at steps
    that have nothing to follow (e.g. an InputStep) or that were already
    registered. Because targets are followed as well, the result is the graph
    as required at fit time.

    Parameters
    ----------
    outputs
        Data placeholders from which the backtracking starts.

    Returns
    -------
    graph
        Graph whose nodes are steps and whose edges are annotated with the
        data placeholders passed from one step to another.

    Raises
    ------
    RuntimeError
        If ``find_duplicated_items`` reports any repeated step name.
    """
    graph = DiGraph()

    def visit(placeholder):
        # Register the step producing `placeholder`, then walk its sources.
        producer = placeholder.step
        if producer in graph:
            # Already visited through another path
            return
        graph.add_node(producer)
        for source in producer.inputs:
            visit(source)
        for source in producer.targets:
            visit(source)

    # Nodes: every step reachable from the requested outputs
    for placeholder in outputs:
        visit(placeholder)

    # Edges: one per (producer step, consumer step), carrying the placeholder
    for consumer in graph:
        for placeholder in consumer.inputs:
            graph.add_edge(placeholder.step, consumer, placeholder)
        for placeholder in consumer.targets:
            graph.add_edge(placeholder.step, consumer, placeholder)

    # Step names must be unique within a graph
    step_names = [step.name for step in graph]
    duplicated_names = find_duplicated_items(step_names)
    if not duplicated_names:
        return graph

    raise RuntimeError(
        "A graph cannot contain steps with duplicated names. "
        "Found the following duplicates:\n"
        "{}".format(duplicated_names)
    )
def test_in_degree():
    """``in_degree`` must report the number of incoming edges of each node."""
    graph = DiGraph()
    nodes = range(4)
    for node in nodes:
        graph.add_node(node)

    for source, destination in [(0, 1), (1, 3), (2, 3)]:
        graph.add_edge(source, destination)

    # Node 3 is the only one with two incoming edges; 0 and 2 have none
    in_degrees = [graph.in_degree(node) for node in nodes]
    assert [0, 1, 0, 2] == in_degrees
def test_topological_sort_cyclic_graph():
    """Sorting a graph that contains a cycle must raise CyclicDiGraphError."""
    graph = DiGraph()
    for node in (0, 1, 2):
        graph.add_node(node)

    # 0 -> 1 -> 2 -> 0 closes the loop
    for source, destination in ((0, 1), (1, 2), (2, 0)):
        graph.add_edge(source, destination)

    with pytest.raises(CyclicDiGraphError):
        graph.topological_sort()
def test_clear():
    """``clear`` must remove every node and every edge from the graph."""
    graph = DiGraph()
    for node in (0, 1):
        graph.add_node(node)
    graph.add_edge(0, 1, "foo")

    # Sanity check: the graph is populated before clearing
    for node in (0, 1):
        assert node in graph
    assert [(0, 1, {"foo"})] == list(graph.edges)

    graph.clear()

    # Nothing may survive the clear
    for node in (0, 1):
        assert node not in graph
    assert [] == list(graph)
    assert [] == list(graph.edges)
def test_topological_sort():
    """``topological_sort`` must return the expected order for a known DAG."""
    # Example randomly generated with
    # https://www.cs.usfca.edu/~galles/visualization/TopoSortDFS.html
    graph = DiGraph()
    for node in range(8):
        graph.add_node(node)

    # Edges are inserted in this exact sequence; the expected ordering below
    # was recorded for it.
    edges = [
        (0, 2),
        (0, 3),
        (2, 4),
        (2, 6),
        (4, 7),
        (6, 7),
        (3, 5),
        (1, 5),
        (3, 7),
    ]
    for source, destination in edges:
        graph.add_edge(source, destination)

    expected_order = [1, 0, 3, 5, 2, 6, 4, 7]
    assert expected_order == graph.topological_sort()
def test_edges():
    """``edges`` must yield (source, destination, data-set) for each edge."""
    graph = DiGraph()
    for node in ("A", "B", "C"):
        graph.add_node(node)
    graph.add_edge("A", "B", 123)
    graph.add_edge("A", "C", 456)

    # Cannot make sets of sets to compare and assert
    # so we use lists and do brute-force comparison
    def equal(x, y):
        # Order-insensitive comparison: strike each element of `x` off a copy
        # of `y`; they match when nothing is missing and nothing is left over.
        remaining = list(y)
        for elem in x:
            if elem not in remaining:
                return False
            remaining.remove(elem)
        return not remaining

    expected_edges = [("A", "B", {123}), ("A", "C", {456})]
    assert equal(expected_edges, list(graph.edges))
def test_get_edge_data():
    """Edge data must accumulate as a set across repeated ``add_edge`` calls."""
    graph = DiGraph()
    for node in ("A", "B"):
        graph.add_node(node)

    # An edge added without data carries an empty set
    graph.add_edge("A", "B")
    edge_data = graph.get_edge_data("A", "B")
    assert set() == edge_data

    # Adding the same edge with one item stores that item
    graph.add_edge("A", "B", 123)
    edge_data = graph.get_edge_data("A", "B")
    assert {123} == edge_data

    # Several items in one call are merged with what was already there
    graph.add_edge("A", "B", 456, 789)
    edge_data = graph.get_edge_data("A", "B")
    assert {123, 456, 789} == edge_data
def build_graph_from_outputs(outputs: Iterable[DataPlaceholder]) -> DiGraph:
    """Construct the fit-time graph reachable backwards from ``outputs``.

    Starting at each given data placeholder, the node that produced it (a
    node represents a step at a given port) is registered, and then
    (depth-first) the producers of that node's inputs and of its targets are
    visited in turn. The walk stops at nodes that have nothing to follow
    (e.g. an InputStep) or that were already registered. Because targets are
    followed as well, the result is the graph as required at fit time.

    Parameters
    ----------
    outputs
        Data placeholders from which the backtracking starts.

    Returns
    -------
    graph
        Graph whose edges are annotated with the data placeholders passed
        from one node to another.

    Raises
    ------
    RuntimeError
        If two distinct step objects in the graph share the same name.
    """
    graph = DiGraph()

    def visit(placeholder):
        # Register the node producing `placeholder`, then walk its sources.
        producer = placeholder.node
        if producer in graph:
            # Already visited through another path
            return
        graph.add_node(producer)
        for source in producer.inputs:
            visit(source)
        for source in producer.targets:
            visit(source)

    # Nodes: everything reachable from the requested outputs
    for placeholder in outputs:
        visit(placeholder)

    # Edges: one per (producer node, consumer node), carrying the placeholder
    for consumer in graph:
        for placeholder in consumer.inputs:
            graph.add_edge(placeholder.node, consumer, placeholder)
        for placeholder in consumer.targets:
            graph.add_edge(placeholder.node, consumer, placeholder)

    # Several nodes may belong to the same step, so a repeated name is only
    # an error when it comes from a different step object.
    first_step_by_name = {}  # type: Dict[str, Step]
    duplicated_names = []
    for node in graph:
        step = node.step
        name = step.name
        if first_step_by_name.setdefault(name, step) is not step:
            duplicated_names.append(name)

    if not duplicated_names:
        return graph

    raise RuntimeError(
        "A graph cannot contain steps with duplicated names. "
        "Found the following duplicates:\n"
        "{}".format(duplicated_names)
    )
def test_ancestors():
    """``ancestors`` must return every node from which the given one is reachable."""
    graph = DiGraph()

    # Asking about a node that was never added is an error
    with pytest.raises(NodeNotFoundError):
        graph.ancestors(0)

    # Case 1:
    #  +--> [1] --> [3] --> [5]
    #  |                     ^
    # [0]                    |
    #  |                     |
    #  +--> [2] --> [4] -----+
    for node in range(6):
        graph.add_node(node)
    for source, destination in [(0, 1), (1, 3), (3, 5), (0, 2), (2, 4), (4, 5)]:
        graph.add_edge(source, destination)

    expectations = [
        (0, set()),
        (4, {0, 2}),
        (5, {0, 1, 2, 3, 4}),
    ]
    for node, expected in expectations:
        assert expected == graph.ancestors(node)

    # Case 2:
    # [0] --> [1] --> [4]
    #          ^       ^
    #          |       |
    #  + ------+       |
    #  |               |
    # [2] --> [3] -----+
    graph = DiGraph()
    for node in range(5):
        graph.add_node(node)
    for source, destination in [(0, 1), (1, 4), (2, 1), (2, 3), (3, 4)]:
        graph.add_edge(source, destination)

    expectations = [
        (0, set()),
        (1, {0, 2}),
        (3, {2}),
        (4, {0, 1, 2, 3}),
    ]
    for node, expected in expectations:
        assert expected == graph.ancestors(node)
def test_can_add_same_edge():
    """Adding an already existing edge a second time must not raise."""
    graph = DiGraph()
    for node in ("A", "B"):
        graph.add_node(node)

    # The test passes as long as neither call raises
    for _ in range(2):
        graph.add_edge("A", "B")
def test_add_edge_with_nonexistent_node():
    """An edge to a node that was never added must raise NodeNotFoundError."""
    graph = DiGraph()
    graph.add_node("A")

    # "B" was never added to the graph
    with pytest.raises(NodeNotFoundError):
        graph.add_edge("A", "B")
def test_add_edge():
    """An added edge must be visible from both of its endpoints."""
    graph = DiGraph()
    for node in ("A", "B"):
        graph.add_node(node)

    graph.add_edge("A", "B")

    # The destination is a successor of the source...
    assert "B" in graph.successors("A")
    # ...and the source is a predecessor of the destination
    assert "A" in graph.predecessors("B")