Beispiel #1
0
def test_clone_simple():
    """Check node sharing between a parsed graph and its clones.

    With ``clone_constants=True`` the original and the clone must not share
    any node. With ``clone_constants=False`` every shared node must be a
    constant, and the shared values must be exactly ``P.scalar_add``,
    ``P.scalar_mul`` and ``P.return_``.
    """
    def f(x, y):
        a = x * x
        b = y * y
        c = a + b
        return c

    g = parse(f)

    def reachable(graph):
        # All nodes found by dfs starting at the graph's return node.
        return set(dfs(graph.return_, succ_deeper))

    # Constants are cloned too: nothing may be shared.
    full_clone = GraphCloner(g, clone_constants=True)[g]
    assert not (reachable(g) & reachable(full_clone))

    # Constants are not cloned: only constants may be shared.
    shared_clone = GraphCloner(g, clone_constants=False)[g]
    common = reachable(g) & reachable(shared_clone)
    for node in common:
        assert node.is_constant()

    expected_values = {P.scalar_add, P.scalar_mul, P.return_}
    assert {node.value for node in common} == expected_values
Beispiel #2
0
def _check_toposort(order, root, succ=_succ, incl=_incl):
    """Assert that `order` is a topological ordering of the nodes from `root`.

    `order` must contain exactly the nodes produced by
    ``dfs(root, succ, incl)``. For every such node, each successor that
    appears in `order` must be positioned before the node itself.

    Arguments:
        order: Sequence of nodes to validate.
        root: Node the search starts from.
        succ: Function returning the successors of a node.
        incl: Inclusion function forwarded to `dfs`.

    Raises:
        AssertionError: If the node sets differ or a successor is placed
            after the node that depends on it.
    """
    nodes = set(dfs(root, succ, incl))
    assert set(order) == nodes

    # Map each entry to the index of its first occurrence, which is what
    # `order.index` returns, so the checks below need no repeated linear
    # scans. Entries are hashable: `set(order)` above already requires it.
    position = {}
    for idx, entry in enumerate(order):
        position.setdefault(entry, idx)

    for node in nodes:
        for i in succ(node):
            # Successors missing from `order` are ignored, as before.
            if i in position:
                assert position[i] < position[node]
Beispiel #3
0
def test_dfs_dups():
    """Check the order dfs yields when the structure repeats values.

    The value ``2`` and the tuple ``(1, 2, 2)`` both occur several times in
    the input, yet each is expected exactly once in the result.
    """
    leaves = (1, 2, 2)
    inner = (3, 4, 5, 2)
    middle = (inner, 6)
    top = (leaves, leaves, middle)

    visited = list(dfs(top, _succ))

    expected = [top, leaves, 1, 2, middle, inner, 3, 4, 5, 6]
    assert visited == expected
Beispiel #4
0
def _successful_inlining(cl, orig, new_params, target):
    """Assert that `cl` holds the expected result of inlining `orig`.

    Checks, in order: `cl` maps `orig` to itself (and not to `target`);
    the clone of `orig.output` is a different node; every node in
    `new_params` is reachable from that clone; clones of nodes owned by
    `orig` belong to `target` or to no graph; and `target.output` is still
    `THREE`.
    """
    assert cl[orig] is not target
    assert cl[orig] is orig

    new_root = cl[orig.output]
    assert new_root is not orig.output

    orig_nodes = set(dfs(orig.output, succ_incoming))
    new_nodes = set(dfs(new_root, succ_incoming))

    assert all(param in new_nodes for param in new_params)

    # Clones of orig's nodes should belong to target
    for node in orig_nodes:
        if node.graph is orig:
            assert cl[node].graph in {target, None}

    # Clone did not change target
    assert target.output is THREE
Beispiel #5
0
def test_dfs_bad_include():
    """Check that dfs raises ValueError when the include function returns None."""
    left = (1, 2)
    deep = (3, 4, 5)
    right = (deep, 6)
    top = (left, right)

    def always_none(_node):
        # Returns None for every node; the test expects dfs to reject this.
        return None

    with pytest.raises(ValueError):
        list(dfs(top, _succ, always_none))
Beispiel #6
0
def test_clone_recursive():
    """Check cloning and inlining of a graph that calls itself.

    Three scenarios are covered:

    * a plain clone with ``clone_constants=True`` shares no node with the
      original;
    * inlining into a target graph passes `_successful_inlining`, and a
      node whose value is the original graph is still reachable from the
      inlined output;
    * inlining with ``total=True`` raises.
    """
    def f(x, y):
        a = x * x
        b = y * y
        return f(a, b)

    g = parse(f)

    cloner = GraphCloner(g, clone_constants=True)
    g_clone = cloner[g]

    original_nodes = set(dfs(g.return_, succ_deeper))
    cloned_nodes = set(dfs(g_clone.return_, succ_deeper))

    # Both node sets should be disjoint
    assert original_nodes.isdisjoint(cloned_nodes)

    # Now test inlining
    inliner = GraphCloner(clone_constants=True)
    target = _graph_for_inline()
    new_params = [ONE, TWO]
    inliner.add_clone(g, target, new_params)

    _successful_inlining(inliner, g, new_params, target)

    # The recursive call still refers to the original graph
    inlined_nodes = set(dfs(inliner[g.output], succ_deeper))
    assert any(node.value is g for node in inlined_nodes)

    # Now test that inlining+total will fail
    total_inliner = GraphCloner(total=True, clone_constants=True)
    target = _graph_for_inline()
    new_params = [ONE, TWO]
    with pytest.raises(Exception):
        total_inliner.add_clone(g, target, new_params)
        total_inliner[g.output]
Beispiel #7
0
    def __init__(
        self,
        g,
        labeler=short_labeler,
        succ=succ_deeper,
        include=always_include,
    ):
        """Create an Index.

        `g` itself is passed to `_acquire`, followed by every node that
        `dfs` yields from `g.return_` using `succ` and `include`. Whenever
        such a node has a truthy `graph`, that graph is acquired as well.
        """
        self._index = defaultdict(set)
        self.labeler = labeler

        self._acquire(g)

        for node in dfs(g.return_, succ, include):
            self._acquire(node)
            owner = node.graph
            if not owner:
                continue
            self._acquire(owner)
Beispiel #8
0
    def graphs_used(self) -> Dict[Graph, Set[Graph]]:
        """Return, for each covered graph, the graphs it refers to directly.

        For every graph in `self.coverage()`, the nodes found by `dfs` from
        its return node (bounded by `freevars_boundary`) are scanned and the
        values of those that are graph constants are collected.
        """
        used: Dict[Graph, Set[Graph]] = {}
        for g in self.coverage():
            visited = dfs(g.return_, succ_incoming, freevars_boundary(g))
            used[g] = {
                node.value for node in visited if is_constant_graph(node)
            }
        return used
Beispiel #9
0
    def free_variables_direct(self) -> Dict[Graph, Iterable[ANFNode]]:
        """Map each covered graph to the free variables it uses directly.

        A node counts as a free variable of a graph when it is found by
        `dfs` from that graph's return node (bounded by
        `freevars_boundary`), has a graph, and that graph is a different
        one. Nested graphs are not taken into account here; see
        `free_variables_total` for that.
        """
        direct: Dict[Graph, Iterable[ANFNode]] = {}
        for g in self.coverage():
            free = []
            for node in dfs(g.return_, succ_incoming, freevars_boundary(g)):
                if node.graph and node.graph is not g:
                    free.append(node)
            direct[g] = free
        return direct
Beispiel #10
0
    def free_variables_total(self) -> Dict[Graph, Set[ANFNode]]:
        """Map each graph to all of its free variables.

        Unlike `free_variables_direct`, free variables needed by children
        graphs are included as well, and graph Constants may figure as free
        variables.

        For every input of every node reachable from the root's return
        node, the owner is the parent of the referenced graph (for graph
        constants) or the input's own graph (otherwise). The input is then
        recorded for the using node's graph and each of its ancestors, up
        to but excluding the owner.
        """
        parents = self.parents()
        total: Dict[Graph, Set[ANFNode]] = defaultdict(set)

        for node in dfs(self.root.return_, succ_deeper):
            for inp in node.inputs:
                owner = (parents[inp.value] if is_constant_graph(inp)
                         else inp.graph)
                if owner is not None:
                    # Climb the parent chain from the user up to the owner.
                    scope = node.graph
                    while scope is not owner:
                        total[scope].add(inp)
                        scope = parents[scope]

        return total
Beispiel #11
0
 def coverage(self) -> Iterable[Graph]:
     """Return the set of graphs accessible from the root.

     The search starts at a Constant wrapping the root graph. Each visited
     node contributes its value when it is a graph constant, and its
     owning graph otherwise; `None` entries are dropped.
     """
     start: ANFNode = Constant(self.root)
     graphs = {
         node.value if is_constant_graph(node) else node.graph
         for node in dfs(start, succ_deeper)
     }
     graphs.discard(None)
     return graphs
Beispiel #12
0
def _check_toposort(order, root, succ):
    """Assert that `order` is a topological ordering of the nodes from `root`.

    `order` must have as many entries as there are distinct nodes yielded
    by ``dfs(root, succ)``, and every successor of each such node must
    appear in `order` before the node itself.
    """
    reachable = set(dfs(root, succ))
    assert len(order) == len(reachable)
    for node in reachable:
        assert all(
            order.index(dep) < order.index(node) for dep in succ(node)
        )