def test_dfs_graphs():
    """Check that `dfs` only enters a constant's graph when `follow_graph=True`.

    A graph whose `return_` is a constant is wrapped in another constant and
    used as the sole input of an `Apply` node. The default traversal is
    expected to yield just the apply node and the wrapping constant; with
    `follow_graph=True` the wrapped graph's return node is expected as well.
    """
    wrapped_graph = Graph()
    graph_const = Constant(wrapped_graph)
    ret_const = Constant(1)
    wrapped_graph.return_ = ret_const

    root = Apply([graph_const], Graph())

    shallow = set(dfs(root))
    assert shallow == {root, graph_const}

    followed = set(dfs(root, follow_graph=True))
    assert followed == {root, graph_const, ret_const}
def test_dfs_variants():
    """Check `_dfs` from the inner graph's return node under several settings.

    The function `f` below is parsed, the single constant-graph node reachable
    from its return is located, and traversals are started from that graph's
    return node using `succ_deep`, `succ_deeper`, and `succ_deeper` combined
    with a third argument built by `freevars_boundary` or `exclude_from_set`.
    Each traversal is compared, by node name, against an expected set.
    """
    # NOTE: the expected name sets below are spelled with the identifiers
    # used inside `f`, so its body must stay exactly as written.
    def f(x):
        z = x * x

        def g(y):
            return y + z
        w = z + 3
        q = g(w)
        return q

    parsed = parse(f)

    # Single-element unpacking: fails unless exactly one node matches.
    [graph_constant] = [
        node for node in dfs(parsed.return_) if node.is_constant_graph()
    ]
    nested = graph_constant.value
    start = nested.return_

    def visited_names(*dfs_args):
        # Names of the nodes produced by `_dfs` starting at `start`.
        return _name_nodes(_dfs(start, *dfs_args))

    assert visited_names(succ_deep) == {
        'return', 'scalar_add', 'y', 'z', 'scalar_mul', 'x',
    }

    assert visited_names(succ_deeper) == {
        'return', 'scalar_add', 'y', 'z', 'scalar_mul', 'x',
        'w', '3', 'q', 'g',
    }

    with_free_vars = freevars_boundary(nested, True)
    assert visited_names(succ_deeper, with_free_vars) == {
        'return', 'scalar_add', 'y', 'z',
    }

    without_free_vars = freevars_boundary(nested, False)
    assert visited_names(succ_deeper, without_free_vars) == {
        'return', 'scalar_add', 'y',
    }

    # With the start node itself excluded, nothing is expected at all.
    skip_start = exclude_from_set([start])
    assert visited_names(succ_deeper, skip_start) == set()
def test_dfs():
    """Check that `dfs` yields the start node first and reaches its inputs.

    An `Apply` node is built over two constants; the first item produced by
    `dfs` is expected to equal that node, and the full traversal is expected
    to contain exactly the node and both constants.
    """
    zero = Constant(0)
    one = Constant(1)
    root = Apply([zero, one], Graph())

    first_visited = next(dfs(root))
    assert first_visited == root

    everything = set(dfs(root))
    assert everything == {root, zero, one}