def test_remove_input(self):
    """Removing a graph input also drops every output that depends on it."""
    x0, x1, x2, x3, x4 = (MyVariable(f"var{i}") for i in range(5))

    inner = op1(x1, x0)
    res0 = op2(inner, x2)
    res1 = op1(x3, x4)
    res1.name = "out1"
    res2 = op1(res1, x0)
    res2.name = "out2"
    # The same variable appears twice among the outputs.
    res3 = res1

    fg = FunctionGraph(
        [x0, x1, x2, x3, x4],
        [res0, res1, res2, res3],
        clone=False,
    )

    # `var4` feeds `out1` (and, through it, `out2`), so removing it should
    # leave `out0` as the only surviving output.
    fg.remove_input(4)
    fg.check_integrity()

    assert fg.inputs == [x0, x1, x2, x3]
    assert fg.outputs == [res0]
def test_remove_output_3(self):
    """Removing one output keeps all inputs and the other outputs intact."""
    v = [MyVariable(f"var{i}") for i in range(7)]

    shared = op1(v[1], v[0])
    o0 = op2(shared, v[2])
    o1 = op1(v[3], v[4])
    o1.name = "out1"
    o2 = op1(shared, v[5])
    o2.name = "out2"
    o3 = op1(v[3], v[6])
    o3.name = "out3"
    # An intermediate result and a raw input are also used as outputs.
    o4 = shared
    o5 = v[3]

    fg = FunctionGraph(list(v), [o0, o1, o2, o3, o4, o5], clone=False)

    fg.remove_output(1)
    fg.check_integrity()

    assert fg.inputs == v
    assert fg.outputs == [o0, o2, o3, o4, o5]
    # Nothing else consumes `out1`, so it must be gone from the clients map.
    assert o1 not in fg.clients
def test_init(self):
    """Check the state of a freshly constructed ``FunctionGraph``."""
    a = MyVariable("var1")
    b = MyVariable("var2")
    c = op1(a)
    d = op2(c, b)

    fg = FunctionGraph([a, b], [c, d], clone=False)
    assert fg.inputs == [a, b]
    assert fg.outputs == [c, d]
    assert fg.apply_nodes == {c.owner, d.owner}
    assert fg.update_mapping is None
    assert fg.check_integrity() is None
    assert fg.variables == {a, b, c, d}

    expected_clients = (
        (a, [(c.owner, 0)]),
        (b, [(d.owner, 1)]),
        (c, [("output", 0), (d.owner, 0)]),
        (d, [("output", 1)]),
    )
    for var, clients in expected_clients:
        assert fg.get_clients(var) == clients

    # With only `outputs` given, the inputs are inferred from the graph;
    # the constant `varC` does not appear among them.
    const = MyConstant("varC")
    e = op1(a, const)
    fg = FunctionGraph(outputs=[c, d, e], clone=False)
    assert fg.inputs == [a, b]

    # Cloning fills `memo` with the original -> clone mapping.
    memo = {}
    fg = FunctionGraph(outputs=[c, d], clone=True, memo=memo)
    for original in (a, b):
        assert memo[original].type == original.type
        assert memo[original].name == original.name
    assert c in memo
    assert d in memo
def test_constant(self):
    """A constant in a pattern matches a distinct graph constant of equal value."""
    x = Constant(MyType(), 2, name="x")
    y = MyVariable("y")
    z = Constant(MyType(), 2, name="z")

    inner = op1(x, y)
    g = FunctionGraph([y], [op1(inner, y)])

    in_pattern = (op1, z, "1")
    out_pattern = (op2, "1", z)
    PatternOptimizer(in_pattern, out_pattern).optimize(g)

    assert str(g) == "FunctionGraph(Op1(Op2(y, z), y))"
def test_nested_even(self):
    """Collapsing pairs of nested ``op1`` calls reduces an even nest to ``x``."""
    # The result must not depend on the order in which nodes are optimized.
    x, y, z = inputs()

    e = x
    for _ in range(4):
        e = op1(e)

    g = FunctionGraph([x, y, z], [e])
    PatternOptimizer((op1, (op1, "1")), "1").optimize(g)

    assert str(g) == "FunctionGraph(x)"
def test_ambiguous(self):
    """An odd nest of ``op1`` calls collapses to a single ``op1``.

    This should always work with ``TopoOptimizer`` and ``ignore_newtrees``
    set to ``False``; with ``ignore_newtrees=True`` or other
    ``NavigatorOptimizer``s the behavior may differ.
    """
    x, y, z = inputs()

    e = x
    for _ in range(5):
        e = op1(e)

    g = FunctionGraph([x, y, z], [e])
    TopoPatternOptimizer((op1, (op1, "1")), (op1, "1"), ign=False).optimize(g)

    assert str(g) == "FunctionGraph(Op1(x))"
def test_validate_inputs(self):
    """``FunctionGraph`` rejects malformed ``inputs``/``outputs`` arguments."""
    var1 = op1()
    var2 = op2()

    # Passing a bare variable where a list of variables is expected.
    with pytest.raises(TypeError):
        FunctionGraph(var1, [var2])

    with pytest.raises(TypeError):
        FunctionGraph([var1], var2)

    # Build `var3` *outside* the `raises` block so that only the
    # `FunctionGraph` constructor can satisfy the expected `ValueError`.
    # NOTE(review): presumably rejected because `var3` already has an
    # owner and `clone=False` — confirm against `FunctionGraph.__init__`.
    var3 = op1(var1)
    with pytest.raises(ValueError):
        FunctionGraph([var3], [var2], clone=False)
def test_match_same_illegal(self):
    """A pattern constraint can veto matches.

    Here, ``op1`` nodes whose two inputs are the same variable are skipped.
    """
    x, y, z = inputs()
    e = op2(op1(x, x), op1(x, y))
    g = FunctionGraph([x, y, z], [e])

    def constraint(r):
        # Only allow the match when the matched `op1` node's two inputs are
        # distinct variables; this rejects `op1(x, x)` but accepts
        # `op1(x, y)`.
        return r.owner.inputs[0] is not r.owner.inputs[1]

    PatternOptimizer(
        {"pattern": (op1, "x", "y"), "constraint": constraint},
        (op3, "x", "y"),
    ).optimize(g)

    assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))"
def test_constraints(self):
    """A constraint on a sub-pattern restricts which inputs may match."""
    x, y, z = inputs()
    lhs = op1(op2(x, y))
    rhs = op1(op1(x, y))
    g = FunctionGraph([x, y, z], [op4(lhs, rhs)])

    def only_op2(r):
        # Accept the match only when the input was produced by `op2`.
        return r.owner.op == op2

    sub_pattern = {"pattern": "1", "constraint": only_op2}
    PatternOptimizer((op1, sub_pattern), (op3, "1")).optimize(g)

    assert str(g) == "FunctionGraph(Op4(Op3(Op2(x, y)), Op1(Op1(x, y))))"
def test_replace_subgraph(self):
    """Rewrite an inner node while leaving the outer node in place."""
    x, y, z = inputs()
    g = FunctionGraph([x, y, z], [op1(op2(x, y), z)])

    # Turn every `op2(a, b)` into `op1(b, a)`.
    PatternOptimizer((op2, "1", "2"), (op1, "2", "1")).optimize(g)

    expected = "FunctionGraph(Op1(Op1(y, x), z))"
    assert str(g) == expected
def test_multiple(self):
    """Every occurrence of the pattern is replaced, not just the first."""
    x, y, z = inputs()
    pair_xy_a = op2(x, y)
    pair_xy_b = op2(x, y)
    pair_yz = op2(y, z)
    g = FunctionGraph([x, y, z], [op1(pair_xy_a, pair_xy_b, pair_yz)])

    PatternOptimizer((op2, "1", "2"), (op4, "1")).optimize(g)

    expected = "FunctionGraph(Op1(Op4(x), Op4(x), Op4(y)))"
    assert str(g) == expected
def test_multiple_merges(self):
    """Check the printed form of a graph with two reused subexpressions after merging."""
    x, y, z = inputs()
    e1 = op1(x, y)
    e2 = op2(op3(x), y, z)
    e = op1(e1, op4(e2, e1), op1(e2))
    g = FunctionGraph([x, y, z], [e])

    MergeOptimizer().optimize(g)

    # note: graph.as_string can only produce the following two possibilities, but if
    # the implementation was to change there are 6 other acceptable answers.
    acceptable = {
        "FunctionGraph(Op1(*1 -> Op1(x, y), Op4(*2 -> Op2(Op3(x), y, z), *1), Op1(*2)))",
        "FunctionGraph(Op1(*2 -> Op1(x, y), Op4(*1 -> Op2(Op3(x), y, z), *2), Op1(*1)))",
    }
    assert str(g) in acceptable
def test_check_integrity(self):
    """Corrupt a ``FunctionGraph``'s bookkeeping in several ways and check
    that ``check_integrity`` reports each inconsistency.

    Between cases the graph state is (partially) restored, so the order of
    the cases below matters.
    """
    var1 = MyVariable("var1")
    var2 = MyVariable("var2")
    var3 = op1(var2, var1)
    var4 = op2(var3, var2)
    var5 = op3(var4, var2, var2)
    fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

    # An `Apply` node missing from `apply_nodes`.
    with pytest.raises(Exception, match="The nodes are .*"):
        fg.apply_nodes.remove(var5.owner)

        fg.check_integrity()

    # Restore the node, then drop one of `var2`'s client entries.
    with pytest.raises(Exception, match="Inconsistent clients.*"):
        fg.apply_nodes.add(var5.owner)
        fg.remove_client(var2, (var5.owner, 1))

        fg.check_integrity()

    fg.add_client(var2, (var5.owner, 1))

    # A variable missing from `variables`.
    with pytest.raises(Exception, match="The variables are.*"):
        fg.variables.remove(var4)

        fg.check_integrity()

    fg.variables.add(var4)

    # A node input that is not one of the graph's inputs (`var6` is created
    # directly, presumably without an owner).
    with pytest.raises(Exception, match="Undeclared input.*"):
        var6 = MyVariable2("var6")
        fg.clients[var6] = [(var5.owner, 3)]
        fg.variables.add(var6)
        var5.owner.inputs.append(var6)

        fg.check_integrity()

    fg.variables.remove(var6)
    var5.owner.inputs.remove(var6)

    # TODO: What if the index value is greater than 1? It will throw an
    # `IndexError`, but that doesn't sound like anything we'd want.
    # A bogus "output" client entry: output 1 is `var5`, not `var4`.
    with pytest.raises(Exception, match="Inconsistent clients list.*"):
        fg.add_client(var4, ("output", 1))

        fg.check_integrity()

    fg.remove_client(var4, ("output", 1))

    # A client entry whose node does not belong to the graph.
    with pytest.raises(Exception, match="Client not in FunctionGraph.*"):
        fg.add_client(var4, (var6.owner, 0))

        fg.check_integrity()

    fg.remove_client(var4, (var6.owner, 0))

    # A client entry claiming `var4` is input 0 of `var3.owner`, whose input
    # 0 is actually `var2`.
    with pytest.raises(Exception, match="Inconsistent clients list.*"):
        fg.add_client(var4, (var3.owner, 0))

        fg.check_integrity()
def test_change_input(self):
    """Exercise ``FunctionGraph.change_input`` on invalid, no-op and real changes."""
    v1 = MyVariable("var1")
    v2 = MyVariable("var2")
    v3 = op1(v2, v1)
    v4 = op2(v3, v2)
    v5 = op3(v4, v2, v2)
    fg = FunctionGraph([v1, v2], [v3, v5], clone=False)

    # `MyVariable2` presumably carries an incompatible type, so both
    # replacements are expected to be refused.
    bad_var = MyVariable2("var6")
    for node, idx in (("output", 1), (v5.owner, 1)):
        with pytest.raises(TypeError):
            fg.change_input(node, idx, bad_var)

    snapshot_nodes = set(fg.apply_nodes)
    snapshot_vars = set(fg.variables)
    snapshot_clients = list(fg.get_clients(v5))

    # Re-installing the variable that is already there changes nothing.
    fg.change_input(v5.owner, 1, v2)
    assert fg.apply_nodes == snapshot_nodes
    assert fg.variables == snapshot_vars
    assert fg.get_clients(v5) == snapshot_clients

    # A real change rewires the input and updates `var2`'s clients.
    fg.change_input(v5.owner, 1, v1)
    assert v5.owner.inputs[1] is v1
    assert (v5.owner, 1) not in fg.get_clients(v2)
def test_import_var(self):
    """Exercise ``FunctionGraph.import_var`` for missing inputs, variables
    with owners, and graphs containing a ``NullType`` variable."""
    from aesara.graph.null_type import NullType

    var1 = MyVariable("var1")
    var2 = MyVariable("var2")
    var3 = op1(var2, var1)
    var4 = op2(var3, var2)
    var5 = op3(var4, var2, var2)
    fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

    var0 = MyVariable("var0")

    with pytest.raises(MissingInputError):
        # We can't import a new `FunctionGraph` input (i.e. something
        # without an owner), at least not without setting `import_missing`
        fg.import_var(var0, "testing")

    fg.import_var(var0, import_missing=True)

    assert var0 in fg.inputs

    # We can import variables with owners.  Use a fresh name instead of
    # rebinding `var5`, which is still one of the graph's outputs.
    owned_var = op2()
    fg.import_var(owned_var, "testing")

    assert owned_var in fg.variables
    assert owned_var.owner in fg.apply_nodes

    # Build the `NullType` variable outside the `raises` block so that only
    # `import_var` can satisfy the expected `TypeError`.
    null_var = NullType()()
    with pytest.raises(TypeError, match="Computation graph contains.*"):
        fg.import_var(null_var, "testing")
def test_import_node(self):
    """Exercise ``FunctionGraph.import_node`` with and without missing inputs."""
    a = MyVariable("var1")
    b = MyVariable("var2")
    c = op1(b, a)
    d = op2(c, b)
    e = op3(d, b, b)
    fg = FunctionGraph([a, b], [c, e], clone=False)

    # A node whose input `var8` is unknown to the graph.
    stray_input = MyVariable("var8")
    stray_out = op2(stray_input)

    with pytest.raises(MissingInputError):
        fg.import_node(stray_out.owner)
    assert stray_input not in fg.variables

    # With `import_missing=True` the unknown input becomes a graph input.
    fg.import_node(stray_out.owner, import_missing=True)
    assert stray_input in fg.inputs
    assert stray_out.owner in fg.apply_nodes

    # A node whose only input is already in the graph; importing it tags it.
    new_out = op2(b)
    node = new_out.owner
    assert not hasattr(node.tag, "imported_by")
    fg.import_node(node)
    assert hasattr(node.tag, "imported_by")
    assert new_out in fg.variables
    assert node in fg.apply_nodes
    assert (node, 0) in fg.get_clients(b)
def test_verbose(self, capsys):
    """``replace_all_validate`` reports rewrites and validation failures when verbose."""
    first = MyVariable("var1")
    second = MyVariable("var2")
    result = op1(second, first)
    fg = FunctionGraph([first, second], [result], clone=False)

    validator = ReplaceValidate()
    fg.attach_feature(validator)

    validator.replace_all_validate(
        fg, [(result, first)], reason="test-reason", verbose=True
    )
    captured = capsys.readouterr()
    assert captured.err == ""
    assert "optimizer: rewrite test-reason replaces Op1.0 with var1" in captured.out

    class TestFeature(Feature):
        def validate(self, *args):
            # Always reject, forcing the validation-failure path.
            raise Exception()

    fg.attach_feature(TestFeature())

    with pytest.raises(Exception):
        validator.replace_all_validate(
            fg, [(result, first)], reason="test-reason", verbose=True
        )

    captured = capsys.readouterr()
    assert "optimizer: validate failed on node Op1.0" in captured.out
def test_remove_in_and_out(self):
    """Remove an output and then an input, checking that clients are pruned."""
    var1 = MyVariable("var1")
    var2 = MyVariable("var2")
    op1_out = op1(var2, var1)
    op2_out = op2(op1_out, var2)
    op3_out = op3(op2_out, var2, var2)
    fg = FunctionGraph([var1, var2], [op1_out, op3_out], clone=False)

    def all_clients():
        # Lazily flatten every `(node, index)` client entry in the graph;
        # `sum(list_of_lists, [])` would copy the accumulator each step.
        return (client for clients in fg.clients.values() for client in clients)

    # Remove an output
    fg.remove_output(1)
    fg.check_integrity()

    assert fg.outputs == [op1_out]
    assert op3_out not in fg.clients
    assert not any(op3_out.owner in client for client in all_clients())

    # Remove an input
    fg.remove_input(0)
    fg.check_integrity()

    assert var1 not in fg.variables
    assert fg.inputs == [var2]
    assert fg.outputs == []
    assert not any(op1_out.owner in client for client in all_clients())
def test_remove_node(self):
    """After ``remove_node``, only nodes still needed by the remaining outputs stay."""
    in_a = MyVariable("var1")
    in_b = MyVariable("var2")
    n1 = op1(in_a)
    n2 = op2(in_b, n1)
    n3 = op3(n2)

    # (outputs, node to remove, expected `apply_nodes` afterwards)
    cases = [
        ([n3], n3.owner, set()),
        ([n2, n3], n3.owner, {n1.owner, n2.owner}),
        ([n2, n3], n2.owner, set()),
    ]
    for outputs, doomed, expected in cases:
        fg = FunctionGraph([in_a, in_b], outputs, clone=False)
        fg.remove_node(doomed)
        fg.check_integrity()
        assert fg.apply_nodes == expected
def test_multi(self):
    """A pattern over a subgraph whose inner node has other clients leaves this graph unchanged."""
    x, y, z = inputs()
    shared = op1(x, y)
    g = FunctionGraph([x, y, z], [op3(op4(shared), shared)])

    PatternOptimizer((op4, (op1, "x", "y")), (op3, "x", "y")).optimize(g)

    # `*1` marks the shared `Op1` node in the printed graph.
    assert str(g) == "FunctionGraph(Op3(Op4(*1 -> Op1(x, y)), *1))"
def test_patternsub_values_eq_approx(out_pattern, tracks):
    """``PatternSub`` propagates ``values_eq_approx`` to its replacement.

    Regression test: ``PatternSub`` used to fail when both
    ``values_eq_approx`` and ``get_nodes`` were specified.
    """
    x = MyVariable("x")
    e = op1(x)
    fg = FunctionGraph([x], [e], clone=False)

    def _get_nodes(fgraph, node):
        return [node]

    sub = PatternSub(
        (op1, "x"),
        out_pattern,
        tracks=[op1] if tracks else (),
        get_nodes=_get_nodes if tracks else None,
        values_eq_approx=values_eq_approx_always_true,
    )
    EquilibriumOptimizer([sub], max_use_ratio=1).optimize(fg)

    output = fg.outputs[0]
    if isinstance(out_pattern, tuple):
        assert output.owner.op == op2
    elif out_pattern == "x":
        assert output is x
    else:
        # The replacement types do not match, so the substitution should've
        # failed
        assert output is e
        return

    assert output.tag.values_eq_approx is values_eq_approx_always_true
def test_optimizer_verbose(self, capsys):
    """``LocalOptGroup.transform`` logs each rewrite when ``optimizer_verbose`` is on."""
    x = MyVariable("x")
    y = MyVariable("y")
    o1 = op1(x, y)
    fgraph = FunctionGraph([x, y], [o1], clone=False)

    # The function names below appear verbatim in the logged messages.
    @local_optimizer(None)
    def local_opt_1(fgraph, node):
        # Replace a leading `x` input with `y`.
        if node.inputs[0] == x:
            return [op2(y, *node.inputs[1:])]

    @local_optimizer(None)
    def local_opt_2(fgraph, node):
        # Replace a leading `y` input with `x`.
        if node.inputs[0] == y:
            return [op2(x, *node.inputs[1:])]

    opt_group = LocalOptGroup(local_opt_1, local_opt_2)

    with config.change_flags(optimizer_verbose=True):
        new_res, = opt_group.transform(fgraph, o1.owner)
        opt_group.transform(fgraph, new_res.owner)

    captured = capsys.readouterr()
    assert captured.err == ""

    expected_messages = (
        "optimizer: rewrite local_opt_1 replaces Op1(x, y) with [Op2.0]",
        "optimizer: rewrite local_opt_2 replaces Op2(y, y) with [Op2.0]",
    )
    for message in expected_messages:
        assert message in captured.out
def test_nested_out_pattern(self):
    """The output pattern may itself contain nested sub-patterns."""
    x, y, z = inputs()
    g = FunctionGraph([x, y, z], [op1(x, y)])

    out_pattern = (op4, (op1, "1"), (op2, "2"), (op3, "1", "2"))
    PatternOptimizer((op1, "1", "2"), out_pattern).optimize(g)

    assert str(g) == "FunctionGraph(Op4(Op1(x), Op2(y), Op3(x, y)))"
def test_replace_output(self):
    """A pattern rooted at the output node can replace the whole graph."""
    x, y, z = inputs()
    inner = op2(x, y)
    g = FunctionGraph([x, y, z], [op1(inner, z)])

    PatternOptimizer(
        (op1, (op2, "1", "2"), "3"),
        (op4, "3", "2"),
    ).optimize(g)

    assert str(g) == "FunctionGraph(Op4(z, y))"
def test_2(self):
    """Repeatedly applying a chain of ``PatternSub`` rules reduces the graph to ``Op2(x, y)``."""
    x, y, z = map(MyVariable, "xyz")
    g = FunctionGraph([x, y, z], [op1(op1(op3(x, y)))])

    rules = [
        ((op1, (op2, "x", "y")), (op4, "x", "y")),
        ((op3, "x", "y"), (op4, "x", "y")),
        ((op4, "x", "y"), (op5, "x", "y")),
        ((op5, "x", "y"), (op6, "x", "y")),
        ((op6, "x", "y"), (op2, "x", "y")),
    ]
    opt = EquilibriumOptimizer(
        [PatternSub(src, dst) for src, dst in rules],
        max_use_ratio=10,
    )
    opt.optimize(g)

    assert str(g) == "FunctionGraph(Op2(x, y))"
def test_deep_merge(self):
    """Identical multi-level subgraphs are merged into a single node."""
    x, y, z = inputs()
    left = op3(op2(x, y), z)
    right = op4(op3(op2(x, y), z))
    g = FunctionGraph([x, y, z], [op1(left, right)], clone=False)

    MergeOptimizer().optimize(g)

    first, second = g.outputs[0].owner.inputs
    # After merging, `op4`'s input is the same node as `op1`'s first input.
    assert second.owner.inputs[0] is first
def test_allow_multiple_clients(self):
    """Pattern matching on a node with multiple clients leaves this graph unchanged."""
    x, y, z = inputs()
    e0 = op1(x, y)
    # `e0` has multiple clients (i.e. the `op4` and `op3` nodes)
    wrapped = op4(e0)
    g = FunctionGraph([x, y, z], [op3(wrapped, e0)])

    in_pat = (op4, (op1, "x", "y"))
    out_pat = (op3, "x", "y")
    PatternOptimizer(in_pat, out_pat).optimize(g)

    assert str(g) == "FunctionGraph(Op3(Op4(*1 -> Op1(x, y)), *1))"
def test_pickle(self):
    """A ``FunctionGraph`` survives a ``pickle`` round trip with the same structure."""
    var1 = op1()
    var2 = op2()
    var3 = op1(var1)
    var4 = op2(var3, var2)
    func = FunctionGraph([var1, var2], [var4])

    s = pickle.dumps(func)
    new_func = pickle.loads(s)

    # `inputs`/`outputs` are ordered, so compare them pairwise -- but check
    # the lengths first, because `zip` silently truncates.
    assert len(new_func.inputs) == len(func.inputs)
    assert len(new_func.outputs) == len(func.outputs)
    assert all(type(a) is type(b) for a, b in zip(func.inputs, new_func.inputs))
    assert all(type(a) is type(b) for a, b in zip(func.outputs, new_func.outputs))
    assert [v.type for v in func.inputs] == [v.type for v in new_func.inputs]

    # `apply_nodes` and `variables` are sets: their iteration order after
    # unpickling is unrelated to the original's, so zipping them pairs up
    # arbitrary elements.  Pair the nodes by topological order instead.
    old_nodes = func.toposort()
    new_nodes = new_func.toposort()
    assert len(new_nodes) == len(old_nodes) == len(func.apply_nodes)
    for old_node, new_node in zip(old_nodes, new_nodes):
        assert type(old_node.op) is type(new_node.op)
        assert [v.type for v in old_node.outputs] == [v.type for v in new_node.outputs]

    assert len(new_func.variables) == len(func.variables)
def test_straightforward(self):
    """Duplicate ``op2`` nodes with identical inputs merge; a different one does not."""
    x, y, z = inputs()
    dup_a = op2(x, y)
    dup_b = op2(x, y)
    distinct = op2(x, z)
    g = FunctionGraph([x, y, z], [op1(dup_a, dup_b, distinct)], clone=False)

    MergeOptimizer().optimize(g)

    first, second, third = g.outputs[0].owner.inputs
    assert first is second
    assert first is not third
def test_eq(self):
    """Pattern ops match graph nodes by equality rather than identity.

    The pattern is written with ``op_z`` while the graph is built with
    ``op_y``.  The expected result shows that the match still happens
    (presumably because the two ops compare equal).
    """
    x = MyVariable("x")
    y = MyVariable("y")
    z = MyVariable("z")
    g = FunctionGraph([x, y, z], [op1(op_y(x, y), z)])

    PatternOptimizer((op1, (op_z, "1", "2"), "3"), (op4, "3", "2")).optimize(g)

    rendered = str(g)
    assert rendered == "FunctionGraph(Op4(z, y))"