Esempio n. 1
0
def test_dfs_variants():
    def f(x):
        z = x * x

        def g(y):
            return y + z

        w = z + 3
        q = g(w)
        return q

    def expected(spec):
        # Expected node names are written as one space-separated string.
        return set(spec.split())

    outer_graph = parse(f)
    # Exactly one constant node reachable from f's output wraps the graph `g`.
    [graph_constant] = [node for node in dfs(outer_graph.return_)
                        if is_constant_graph(node)]
    inner_graph = graph_constant.value
    ret = inner_graph.return_

    def reached(successors, *extra):
        """Return the names of nodes `_dfs` visits from g's return node."""
        return _name_nodes(_dfs(ret, successors, *extra))

    # succ_deep reaches g's body plus the nodes g's free variable depends on.
    assert reached(succ_deep) == expected('. return add y z mul x')

    # succ_deeper additionally reaches the rest of the enclosing graph.
    assert reached(succ_deeper) == \
        expected('. return add y z mul x w 3 q g')

    # Boundary that keeps free variables but does not go past them.
    assert reached(succ_deeper, freevars_boundary(inner_graph, True)) == \
        expected('. return add y z')

    # Boundary that excludes free variables altogether.
    assert reached(succ_deeper, freevars_boundary(inner_graph, False)) == \
        expected('. return add y')

    # Excluding the starting node itself yields nothing at all.
    assert reached(succ_deeper, exclude_from_set([ret])) == set()
Esempio n. 2
0
def test_disconnect():
    def f(x):
        a = x * x
        _b = a + x  # Not connected to any output # noqa
        c = a * a
        d = c * c  # Connected to g's output

        def g(y):
            return d * y

        return g(c)

    graph = Parser(ENV, f).parse(False)

    # Putting None in the scope lets the bidirectional walk also pick up
    # Constant nodes, so its result is directly comparable with `live`.
    scope = {None} | accessible_graphs(graph)

    def names(successors):
        """Return the names of nodes `_dfs` visits from the output node."""
        return _name_nodes(_dfs(graph.return_, successors))

    live = names(succ_deep)
    assert live == set('x a mul c return . g d y'.split())

    # Before cleanup, the unused `_b = a + x` is still found via uses.
    everything = names(succ_bidirectional(scope))
    assert everything == set('x a mul c return . g d y _b add'.split())

    destroy_disconnected_nodes(graph)

    # After cleanup, walking in both directions finds only the live nodes.
    assert names(succ_bidirectional(scope)) == live
Esempio n. 3
0
def destroy_disconnected_nodes(root: Graph) -> None:
    """Detach nodes that are reachable via uses but not from the output.

    A node can be referenced through the `uses` of other nodes without
    contributing to a graph's output (e.g. the ignored `_` in
    `_, x = pair`). For each such node within the graphs accessible from
    `root`, this function empties its `inputs`. Presumably that also drops
    the node from its inputs' `uses`; confirm against `inputs.clear()`.
    """
    # Only graphs reachable from root are considered. Otherwise, nodes of
    # unrelated graphs that happen to use our constants could be destroyed.
    scope = accessible_graphs(root)
    # Forward walk, following into graph constants.
    reachable = dfs(root.return_, True)
    # Walk in both directions, restricted to the graphs in scope.
    everything = _dfs(root.return_, succ_bidirectional(scope))
    unreachable = set(everything) - set(reachable)
    for dead_node in unreachable:
        dead_node.inputs.clear()  # type: ignore
Esempio n. 4
0
def dfs(root: ANFNode, follow_graph: bool = False) -> Iterable[ANFNode]:
    """Walk the nodes reachable from `root` depth-first.

    With `follow_graph` true, successors come from `succ_deep`. Otherwise
    they come from `succ_incoming`. The traversal itself is `_dfs`.
    """
    if follow_graph:
        successors = succ_deep
    else:
        successors = succ_incoming
    return _dfs(root, successors)