def test_dfs_variants():
    def f(x):
        z = x * x

        def g(y):
            return y + z

        w = z + 3
        q = g(w)
        return q

    # Find the single graph constant reachable from f's output: it wraps g.
    outer = parse(f)
    (g_constant,) = [node for node in dfs(outer.return_)
                     if is_constant_graph(node)]
    g_graph = g_constant.value
    g_output = g_graph.return_

    def names(succ, *extra):
        """Return the names of the nodes visited from g's output."""
        return _name_nodes(_dfs(g_output, succ, *extra))

    assert names(succ_deep) == {'.', 'return', 'add', 'y', 'z', 'mul', 'x'}
    assert names(succ_deeper) == {
        '.', 'return', 'add', 'y', 'z', 'mul', 'x', 'w', '3', 'q', 'g',
    }

    # Bounding by g's free variables: with True the free variable z is still
    # reported, with False it is not.
    with_fv = names(succ_deeper, freevars_boundary(g_graph, True))
    assert with_fv == {'.', 'return', 'add', 'y', 'z'}
    without_fv = names(succ_deeper, freevars_boundary(g_graph, False))
    assert without_fv == {'.', 'return', 'add', 'y'}

    # Excluding the starting node itself yields nothing at all.
    assert names(succ_deeper, exclude_from_set([g_output])) == set()
def test_disconnect():
    def f(x):
        a = x * x
        _b = a + x  # Not connected to any output # noqa
        c = a * a
        d = c * c  # Connected to g's output

        def g(y):
            return d * y

        return g(c)

    graph = Parser(ENV, f).parse(False)
    # Include {None} to get Constants (makes it easier to compare to live)
    coverage = {None} | accessible_graphs(graph)

    def reachable(succ):
        """Return the names of the nodes visited from the graph's output."""
        return _name_nodes(_dfs(graph.return_, succ))

    live = reachable(succ_deep)
    assert live == {'x', 'a', 'mul', 'c', 'return', '.', 'g', 'd', 'y'}

    # The bidirectional walk also picks up the unused `_b = a + x`.
    before = reachable(succ_bidirectional(coverage))
    assert before == {
        'x', 'a', 'mul', 'c', 'return', '.', 'g', 'd', 'y', '_b', 'add',
    }

    destroy_disconnected_nodes(graph)
    after = reachable(succ_bidirectional(coverage))
    assert after == live
def destroy_disconnected_nodes(root: Graph) -> None:
    """Clear the inputs of dead nodes in the graphs accessible from root.

    A node can still be found through the `uses` of other nodes without being
    connected to the output of its graph (e.g. `_, x = pair`, where `_` is
    unused). Such nodes are the ones seen by a bidirectional walk restricted to
    the accessible graphs but not by a deep walk from `root.return_`; their
    `inputs` are cleared.
    """
    # Restricting the walk to graphs accessible from root avoids accidentally
    # destroying nodes of other graphs that are users of the constants we use.
    scope = accessible_graphs(root)
    connected = dfs(root.return_, follow_graph=True)
    every_node = _dfs(root.return_, succ_bidirectional(scope))
    orphans = set(every_node) - set(connected)
    for orphan in orphans:
        orphan.inputs.clear()  # type: ignore
def dfs(root: ANFNode, follow_graph: bool = False) -> Iterable[ANFNode]:
    """Perform a depth-first search starting at `root`.

    Args:
        root: The node the search starts from.
        follow_graph: If true, successors are given by `succ_deep`;
            otherwise by `succ_incoming`.

    Returns:
        The result of `_dfs` for `root` with the selected successor function.
    """
    if follow_graph:
        successors = succ_deep
    else:
        successors = succ_incoming
    return _dfs(root, successors)