def test_clone_simple():
    """Cloning a simple graph, with and without cloning its constants."""
    def f(x, y):
        a = x * x
        b = y * y
        c = a + b
        return c

    def reachable(graph):
        # Every node reachable from the graph's return node, descending into
        # nested graphs.
        return set(dfs(graph.return_, succ_deeper))

    g = parse(f)

    # Cloning constants as well: the two node sets must not overlap at all.
    full_clone = GraphCloner(g, clone_constants=True)[g]
    assert reachable(g) & reachable(full_clone) == set()

    # Without cloning constants, only the primitive constants may be shared.
    shallow_clone = GraphCloner(g, clone_constants=False)[g]
    shared = reachable(g) & reachable(shallow_clone)
    assert all(node.is_constant() for node in shared)
    shared_values = {node.value for node in shared}
    assert shared_values == {P.scalar_add, P.scalar_mul, P.return_}
def _check_toposort(order, root, succ=_succ, incl=_incl):
    """Assert that ``order`` is a topological ordering of the nodes under ``root``.

    ``order`` must contain exactly the nodes visited by
    ``dfs(root, succ, incl)``. For every visited node, each successor that
    also appears in ``order`` must come strictly before that node.
    """
    reachable = set(dfs(root, succ, incl))
    assert set(order) == reachable
    for node in reachable:
        # Successors that are not in ``order`` are not constrained.
        ordered_children = (child for child in succ(node) if child in order)
        for child in ordered_children:
            assert order.index(child) < order.index(node)
def test_dfs_dups():
    """dfs yields each node once, even when it is reachable several times."""
    a = (1, 2, 2)
    b = (3, 4, 5, 2)
    c = (b, 6)
    d = (a, a, c)
    # ``a`` is a child of ``d`` twice and ``2`` appears three times, yet each
    # is expected exactly once, at its first visit.
    expected = [d, a, 1, 2, c, b, 3, 4, 5, 6]
    visited = list(dfs(d, _succ))
    assert visited == expected
def _successful_inlining(cl, orig, new_params, target):
    """Check the state of cloner ``cl`` after inlining ``orig`` into ``target``.

    Asserts that:
    - ``orig`` maps to itself rather than to ``target``;
    - the output node of ``orig`` was cloned;
    - every node in ``new_params`` is reachable from the cloned output;
    - clones of ``orig``'s own nodes belong to ``target`` (or to no graph);
    - ``target.output`` is still ``THREE``.
    """
    # The graph object itself is not replaced when inlining.
    assert cl[orig] is not target
    assert cl[orig] is orig

    new_root = cl[orig.output]
    assert new_root is not orig.output

    orig_nodes = set(dfs(orig.output, succ_incoming))
    new_nodes = set(dfs(new_root, succ_incoming))
    assert all(param in new_nodes for param in new_params)

    # Clones of orig's nodes should belong to target.
    for node in orig_nodes:
        if node.graph is orig:
            assert cl[node].graph in {target, None}

    # Cloning must leave target's own output untouched.
    assert target.output is THREE
def test_dfs_bad_include():
    """dfs raises ValueError when the include callback returns None."""
    root = ((1, 2), ((3, 4, 5), 6))
    with pytest.raises(ValueError):
        # dfs is lazy; list() forces the traversal so the error surfaces.
        list(dfs(root, _succ, lambda node: None))
def test_clone_recursive():
    """Cloning and inlining a graph that calls itself."""
    def f(x, y):
        a = x * x
        b = y * y
        return f(a, b)

    g = parse(f)

    def reachable(start):
        # Every node reachable from ``start``, descending into nested graphs.
        return set(dfs(start, succ_deeper))

    # A full clone (constants included) shares no node with the original.
    copy = GraphCloner(g, clone_constants=True)[g]
    assert reachable(g.return_) & reachable(copy.return_) == set()

    # Inline g into a fresh target graph, binding its parameters to ONE, TWO.
    inliner = GraphCloner(clone_constants=True)
    target = _graph_for_inline()
    new_params = [ONE, TWO]
    inliner.add_clone(g, target, new_params)
    _successful_inlining(inliner, g, new_params, target)
    # The recursive call still refers to the original graph.
    inlined_nodes = reachable(inliner[g.output])
    assert any(node.value is g for node in inlined_nodes)

    # Combining inlining with total=True is expected to fail.
    total_cloner = GraphCloner(total=True, clone_constants=True)
    target = _graph_for_inline()
    new_params = [ONE, TWO]
    with pytest.raises(Exception):
        total_cloner.add_clone(g, target, new_params)
        total_cloner[g.output]
def __init__(self, g, labeler=short_labeler, succ=succ_deeper,
             include=always_include):
    """Create an Index.

    Args:
        g: The graph to index. It is registered first, followed by every
            node visited by ``dfs(g.return_, succ, include)``.
        labeler: Stored as ``self.labeler``.
        succ: Successor function handed to ``dfs``.
        include: Include function handed to ``dfs``.
    """
    self.labeler = labeler
    self._index = defaultdict(set)
    self._acquire(g)
    for visited in dfs(g.return_, succ, include):
        self._acquire(visited)
        # Also register the graph that owns the node, when it has one.
        owner = visited.graph
        if owner:
            self._acquire(owner)
def graphs_used(self) -> Dict[Graph, Set[Graph]]:
    """Map each graph to the set of graphs it uses.

    For each graph, this is the set of graphs that it refers to directly.
    The walk from a graph's return node follows incoming edges and stops at
    that graph's free-variable boundary. Every graph constant found on the
    way is collected.
    """
    used: Dict[Graph, Set[Graph]] = {}
    for graph in self.coverage():
        walk = dfs(graph.return_, succ_incoming, freevars_boundary(graph))
        used[graph] = {
            node.value for node in walk if is_constant_graph(node)
        }
    return used
def free_variables_direct(self) -> Dict[Graph, Iterable[ANFNode]]:
    """Return a mapping from each graph to its free variables.

    The free variables returned are those that the graph refers to
    directly. These are the nodes reached from its return node, within its
    free-variable boundary, that belong to some other graph. Nested graphs
    are not taken into account, but they are in `free_variables_total`.
    """
    direct: Dict[Graph, Iterable[ANFNode]] = {}
    for graph in self.coverage():
        found = []
        for node in dfs(graph.return_, succ_incoming,
                        freevars_boundary(graph)):
            owner = node.graph
            # Nodes with no owning graph, or owned by this graph, are not
            # free variables.
            if owner and owner is not graph:
                found.append(node)
        direct[graph] = found
    return direct
def free_variables_total(self) -> Dict[Graph, Set[ANFNode]]:
    """Map each graph to its free variables.

    This differs from `free_variables_direct` in that it also includes free
    variables needed by children graphs. Furthermore, graph Constants may
    figure as free variables.

    Returns:
        A ``defaultdict(set)``. Graphs with no free variables may be absent.
    """
    parent_of = self.parents()
    fvs: Dict[Graph, Set[ANFNode]] = defaultdict(set)
    for node in dfs(self.root.return_, succ_deeper):
        for inp in node.inputs:
            # A graph constant is "owned" by its parent graph. Any other
            # input is owned by the graph it lives in.
            owner = (parent_of[inp.value] if is_constant_graph(inp)
                     else inp.graph)
            if owner is None:
                continue
            # The input is free in node's graph and in every ancestor up to,
            # but excluding, the owner.
            current = node.graph
            while current is not owner:
                fvs[current].add(inp)
                current = parent_of[current]
    return fvs
def coverage(self) -> Iterable[Graph]:
    """Return a collection of graphs accessible from the root.

    The search starts from a Constant wrapping ``self.root`` and descends
    through nested graphs. Graph constants contribute the graph they hold.
    Every other node contributes its owning graph. ``None`` is dropped.
    """
    start: ANFNode = Constant(self.root)
    graphs = set()
    for node in dfs(start, succ_deeper):
        if is_constant_graph(node):
            graphs.add(node.value)
        else:
            graphs.add(node.graph)
    graphs.discard(None)
    return graphs
def _check_toposort(order, root, succ):
    """Assert that ``order`` is a topological ordering of the nodes under ``root``.

    ``order`` must list every node visited by ``dfs(root, succ)`` exactly
    once. Every successor of a node must appear strictly before that node.
    """
    nodes = set(dfs(root, succ))
    # Equal length alone does not prove ``order`` holds the right nodes.
    # Together with set equality, it also rules out duplicates.
    assert len(order) == len(nodes)
    assert set(order) == nodes
    # Position lookup table. This avoids an O(n) ``order.index`` per edge.
    # Since ``order`` has no duplicates, this matches ``order.index``.
    position = {node: idx for idx, node in enumerate(order)}
    for node in nodes:
        for child in succ(node):
            assert position[child] < position[node]