def test_import_var(self):
    """`import_var` rejects new inputs and null-typed graphs, accepts owned vars."""
    v1 = MyVariable("var1")
    v2 = MyVariable("var2")
    v3 = op1(v2, v1)
    v4 = op2(v3, v2)
    v5 = op3(v4, v2, v2)
    fg = FunctionGraph([v1, v2], [v3, v5], clone=False)

    with pytest.raises(MissingInputError):
        v0 = MyVariable("var0")
        # An ownerless variable would be a brand-new graph input, which
        # `import_var` must refuse to introduce.
        fg.import_var(v0, "testing")

    # A variable that is the output of an `Apply` node can be imported.
    owned = op2()
    fg.import_var(owned, "testing")
    assert owned in fg.variables
    assert owned.owner in fg.apply_nodes

    with pytest.raises(TypeError, match="Computation graph contains.*"):
        from theano.gof.null_type import NullType

        fg.import_var(NullType()(), "testing")
def test_multiple(self):
    """Every occurrence of the matched pattern is rewritten, not just the first."""
    x, y, z = inputs()
    expr = op1(op2(x, y), op2(x, y), op2(y, z))
    graph = FunctionGraph([x, y, z], [expr])
    # Each `Op2(a, b)` becomes `Op4(a)`.
    rewrite = PatternOptimizer((op2, "1", "2"), (op4, "1"))
    rewrite.optimize(graph)
    assert str(graph) == "FunctionGraph(Op1(Op4(x), Op4(x), Op4(y)))"
def test_merge_outputs(self):
    """Two structurally identical outputs collapse onto one shared node."""
    x, y, z = inputs()
    first = op3(op2(x, y))
    second = op3(op2(x, y))
    graph = FunctionGraph([x, y, z], [first, second])
    merger = MergeOptimizer()
    merger.optimize(graph)
    assert str(graph) == "FunctionGraph(*1 -> Op3(Op2(x, y)), *1)"
def test_constant_merging(self):
    """Equal constants are merged, which in turn lets their consumers merge."""
    x = MyVariable("x")
    y = Constant(MyType(), 2, name="y")
    z = Constant(MyType(), 2, name="z")
    expr = op1(op2(x, y), op2(x, y), op2(x, z))
    graph = FunctionGraph([x, y, z], [expr])
    MergeOptimizer().optimize(graph)
    rendered = str(graph)
    # Either of the two equal constants may be the one that survives.
    accepted = (
        "FunctionGraph(Op1(*1 -> Op2(x, y), *1, *1))",
        "FunctionGraph(Op1(*1 -> Op2(x, z), *1, *1))",
    )
    assert rendered in accepted
def test_change_input(self):
    """`change_input` type-checks the new input and rewires client lists."""
    v1 = MyVariable("var1")
    v2 = MyVariable("var2")
    v3 = op1(v2, v1)
    v4 = op2(v3, v2)
    v5 = op3(v4, v2, v2)
    fg = FunctionGraph([v1, v2], [v3, v5], clone=False)

    # A `MyVariable2` is not type-compatible, whether it targets an output
    # slot or an `Apply` node input.
    incompatible = MyVariable2("var6")
    with pytest.raises(TypeError):
        fg.change_input("output", 1, incompatible)
    with pytest.raises(TypeError):
        fg.change_input(v5.owner, 1, incompatible)

    nodes_before = set(fg.apply_nodes)
    vars_before = set(fg.variables)
    v5_clients_before = list(v5.clients)

    # Input 1 of `v5.owner` already is `v2`, so this must be a no-op.
    fg.change_input(v5.owner, 1, v2)
    assert nodes_before == fg.apply_nodes
    assert vars_before == fg.variables
    assert v5_clients_before == v5.clients

    # A genuine change: `v2` stops being a client at that slot.
    fg.change_input(v5.owner, 1, v1)
    assert v5.owner.inputs[1] is v1
    assert (v5.owner, 1) not in v2.clients
def test_check_integrity(self):
    """`check_integrity` detects several distinct kinds of corrupted graph state.

    Each `pytest.raises` block corrupts `fg` in one specific way and expects
    `check_integrity` to complain; most blocks are followed by lines that
    undo that corruption before the next case.
    """
    var1 = MyVariable("var1")
    var2 = MyVariable("var2")
    var3 = op1(var2, var1)
    var4 = op2(var3, var2)
    var5 = op3(var4, var2, var2)
    fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

    # Case: a node missing from `apply_nodes`.
    with pytest.raises(Exception, match="The nodes are .*"):
        fg.apply_nodes.remove(var5.owner)

        fg.check_integrity()

    # Case: the node is restored, but one of `var2`'s client entries is dropped.
    with pytest.raises(Exception, match="Inconsistent clients.*"):
        fg.apply_nodes.add(var5.owner)
        fg.remove_client(var2, (var5.owner, 1))

        fg.check_integrity()

    fg.add_client(var2, (var5.owner, 1))

    # Case: a variable missing from `variables`.
    with pytest.raises(Exception, match="The variables are.*"):
        fg.variables.remove(var4)

        fg.check_integrity()

    fg.variables.add(var4)

    # Case: a node input that is not a graph input (`var6` is created
    # standalone, presumably with no owner). Index 3 is the slot the
    # `append` below creates, since `var5.owner` already has 3 inputs.
    with pytest.raises(Exception, match="Undeclared input.*"):
        var6 = MyVariable2("var6")
        fg.clients[var6] = [(var5.owner, 3)]
        fg.variables.add(var6)
        var5.owner.inputs.append(var6)

        fg.check_integrity()

    fg.variables.remove(var6)
    var5.owner.inputs.remove(var6)
    # NOTE(review): the `fg.clients[var6]` entry added above is not removed.

    # Case: an "output" client whose index points at a different output
    # (output 1 is `var5`, not `var4`).
    # TODO: What if the index value is greater than 1? It will throw an
    # `IndexError`, but that doesn't sound like anything we'd want.
    with pytest.raises(Exception, match="Inconsistent clients list.*"):
        fg.add_client(var4, ("output", 1))

        fg.check_integrity()

    fg.remove_client(var4, ("output", 1))

    # Case: a client whose node is not part of the graph
    # (`var6.owner` is presumably `None`).
    with pytest.raises(Exception, match="Client not in FunctionGraph.*"):
        fg.add_client(var4, (var6.owner, 0))

        fg.check_integrity()

    fg.remove_client(var4, (var6.owner, 0))

    # Case: a client entry claiming `var4` feeds input 0 of `var3.owner`,
    # whose input 0 is actually `var2`.
    with pytest.raises(Exception, match="Inconsistent clients list.*"):
        fg.add_client(var4, (var3.owner, 0))

        fg.check_integrity()
def test_replace_output(self):
    """A pattern spanning the whole output expression replaces the output."""
    x, y, z = inputs()
    e = op1(op2(x, y), z)
    g = FunctionGraph([x, y, z], [e])
    PatternOptimizer((op1, (op2, "1", "2"), "3"), (op4, "3", "2")).optimize(g)
    # `str(FunctionGraph)` wraps the rendered outputs in "FunctionGraph(...)".
    assert str(g) == "FunctionGraph(Op4(z, y))"
def test_replace_subgraph(self):
    """A pattern matched strictly inside the graph is rewritten in place."""
    x, y, z = inputs()
    expr = op1(op2(x, y), z)
    graph = FunctionGraph([x, y, z], [expr])
    # `Op2(a, b)` -> `Op1(b, a)`
    swap_args = PatternOptimizer((op2, "1", "2"), (op1, "2", "1"))
    swap_args.optimize(graph)
    assert str(graph) == "FunctionGraph(Op1(Op1(y, x), z))"
def test_replace(self):
    """`replace` enforces graph ownership and types; `replace_all` rewires."""
    v1 = MyVariable("var1")
    v2 = MyVariable("var2")
    v3 = op1(v2, v1)
    v4 = op2(v3, v2)
    v5 = op3(v4, v2, v2)
    fg = FunctionGraph([v1, v2], [v3, v5], clone=False)

    with pytest.raises(Exception, match="Cannot replace.*"):
        # Pretend `v4` belongs to some other graph.
        v4.fgraph = object()
        fg.replace(v4, v1, verbose=True)
    v4.fgraph = fg

    with pytest.raises(BadOptimization):
        # Neither type can be converted into the other.
        v0 = MyVariable2("var0")
        fg.replace(v3, v0)

    fg.replace_all([(v3, v1)])
    assert v3 not in fg.variables
    assert fg.apply_nodes == {v4.owner, v5.owner}
    assert v4.owner.inputs == [v1, v2]
def test_no_recurse(self):
    """With `ign=True`, an output pattern that itself matches is not re-applied.

    The output pattern `(op2, "2", "1")` would match the input pattern
    again. Because `ign=True`, the optimizer performs the replacement once
    and then stops, instead of swapping the arguments back and forth.
    """
    x, y, z = inputs()
    e = op1(op2(x, y), z)
    g = FunctionGraph([x, y, z], [e])
    PatternOptimizer((op2, "1", "2"), (op2, "2", "1"), ign=True).optimize(g)
    # `str(FunctionGraph)` wraps the rendered outputs in "FunctionGraph(...)".
    assert str(g) == "FunctionGraph(Op1(Op2(y, x), z))"
def test_pickle(self):
    """A pickled `FunctionGraph` round-trips with the same structure.

    `inputs` and `outputs` are ordered, so they are compared positionally.
    `apply_nodes` and `variables` are sets, so their iteration order need
    not agree between the original and the unpickled graph. Zipping them
    would pair up unrelated elements, so they are compared as multisets.
    """
    from collections import Counter

    var1 = op1()
    var2 = op2()
    var3 = op1(var1)
    var4 = op2(var3, var2)
    func = FunctionGraph([var1, var2], [var4])

    s = pickle.dumps(func)
    new_func = pickle.loads(s)

    # Check lengths first: `zip` silently truncates to the shorter sequence.
    assert len(func.inputs) == len(new_func.inputs)
    assert all(type(a) is type(b) for a, b in zip(func.inputs, new_func.inputs))
    assert len(func.outputs) == len(new_func.outputs)
    assert all(type(a) is type(b) for a, b in zip(func.outputs, new_func.outputs))

    # Classes are always hashable, so a `Counter` gives an order-free comparison.
    assert Counter(type(node.op) for node in func.apply_nodes) == Counter(
        type(node.op) for node in new_func.apply_nodes
    )

    # Variable types may not be hashable, so match them up by equality.
    unmatched = [v.type for v in new_func.variables]
    for var in func.variables:
        for idx, candidate in enumerate(unmatched):
            if candidate == var.type:
                del unmatched[idx]
                break
        else:
            raise AssertionError(
                "No variable of type {} after unpickling".format(var.type)
            )
    assert not unmatched
def test_unification_2(self):
    """A pattern variable used twice must bind to the very same argument."""
    x, y, z = inputs()
    # `Op2` receives two *different* arguments here...
    expr = op1(op2(x, y), z)
    graph = FunctionGraph([x, y, z], [expr])
    # ...while the pattern requires identical ones (both slots are "1").
    in_pattern = (op1, (op2, "1", "1"), "2")
    out_pattern = (op4, "2", "1")
    PatternOptimizer(in_pattern, out_pattern).optimize(graph)
    # Therefore nothing may be rewritten.
    assert str(graph) == "FunctionGraph(Op1(Op2(x, y), z))"
def test_unification_1(self):
    """A pattern variable used twice matches when both arguments are identical."""
    x, y, z = inputs()
    e = op1(op2(x, x), z)  # the arguments to op2 are the same
    g = FunctionGraph([x, y, z], [e])
    PatternOptimizer(
        (op1, (op2, "1", "1"), "2"),  # they are the same in the pattern
        (op4, "2", "1"),
    ).optimize(g)
    # So the replacement should occur. `str(FunctionGraph)` wraps the
    # rendered outputs in "FunctionGraph(...)".
    assert str(g) == "FunctionGraph(Op4(z, x))"
def test_constraints(self):
    """A sub-pattern constraint restricts which matches are rewritten."""
    x, y, z = inputs()
    e = op4(op1(op2(x, y)), op1(op1(x, y)))
    g = FunctionGraph([x, y, z], [e])

    def constraint(r):
        # Only match when the single input of `op1` is produced by `op2`
        return r.owner.op == op2

    PatternOptimizer(
        (op1, {"pattern": "1", "constraint": constraint}), (op3, "1")
    ).optimize(g)
    # `str(FunctionGraph)` wraps the rendered outputs in "FunctionGraph(...)".
    assert str(g) == "FunctionGraph(Op4(Op3(Op2(x, y)), Op1(Op1(x, y))))"
def test_validate_inputs(self):
    """`FunctionGraph` validates the form of its inputs and outputs."""
    a = op1()
    b = op2()

    # Passing a bare variable instead of a list is rejected, on either side.
    with pytest.raises(TypeError):
        FunctionGraph(a, [b])
    with pytest.raises(TypeError):
        FunctionGraph([a], b)

    # Without cloning, an input produced by an `Apply` node is rejected.
    with pytest.raises(ValueError):
        c = op1(a)
        FunctionGraph([c], [b], clone=False)
def test_match_same_illegal(self):
    """A whole-pattern constraint can reject matches with repeated inputs."""
    x, y, z = inputs()
    e = op2(op1(x, x), op1(x, y))
    g = FunctionGraph([x, y, z], [e])

    def constraint(r):
        # Only match when the two inputs of `op1` are different variables
        return r.owner.inputs[0] is not r.owner.inputs[1]

    PatternOptimizer(
        {"pattern": (op1, "x", "y"), "constraint": constraint}, (op3, "x", "y")
    ).optimize(g)
    # `str(FunctionGraph)` wraps the rendered outputs in "FunctionGraph(...)".
    assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))"
def test_contains(self):
    """The `in` operator recognizes both graph variables and `Apply` nodes."""
    v1 = MyVariable("var1")
    v2 = MyVariable("var2")
    v3 = op1(v2, v1)
    v4 = op2(v3, v2)
    v5 = op3(v4, v2, v2)
    fg = FunctionGraph([v1, v2], [v3, v5], clone=False)

    for member in (v1, v3, v3.owner, v5, v5.owner):
        assert member in fg
def test_import_node(self):
    """`import_node` refuses unknown inputs and tags nodes it imports."""
    v1 = MyVariable("var1")
    v2 = MyVariable("var2")
    v3 = op1(v2, v1)
    v4 = op2(v3, v2)
    v5 = op3(v4, v2, v2)
    fg = FunctionGraph([v1, v2], [v3, v5], clone=False)

    # A node fed by a variable the graph has never seen cannot be imported.
    stranger = MyVariable("var5")
    orphan = op2(stranger)
    with pytest.raises(MissingInputError):
        fg.import_node(orphan.owner)

    fresh = op2(v2)
    assert not hasattr(fresh.owner.tag, "imported_by")
    fg.import_node(fresh.owner)
    assert hasattr(fresh.owner.tag, "imported_by")
    assert fresh in fg.variables
    assert fresh.owner in fg.apply_nodes
    assert (fresh.owner, 0) in v2.clients
def test_multiple_merges(self):
    """Several independent duplicate subgraphs are merged in a single pass."""
    x, y, z = inputs()
    e1 = op1(x, y)
    e2 = op2(op3(x), y, z)
    e = op1(e1, op4(e2, e1), op1(e2))
    g = FunctionGraph([x, y, z], [e])
    MergeOptimizer().optimize(g)
    strg = str(g)
    # note: graph.as_string can only produce the following two possibilities, but if
    # the implementation was to change there are 6 other acceptable answers.
    # `str(FunctionGraph)` wraps the rendered outputs in "FunctionGraph(...)".
    assert (
        strg
        == "FunctionGraph(Op1(*1 -> Op1(x, y), Op4(*2 -> Op2(Op3(x), y, z), *1), Op1(*2)))"
        or strg
        == "FunctionGraph(Op1(*2 -> Op1(x, y), Op4(*1 -> Op2(Op3(x), y, z), *2), Op1(*1)))"
    )
def test_replace_bad_state(self):
    """Replacing a graph input with an unknown variable raises `MissingInputError`."""
    v1 = MyVariable("var1")
    v2 = MyVariable("var2")
    v3 = op1(v2, v1)
    v4 = op2(v3, v2)
    v5 = op3(v4, v2, v2)
    fg = FunctionGraph([v1, v2], [v3, v5], clone=False)

    with pytest.raises(MissingInputError):
        newcomer = MyVariable("var0")
        # FIXME TODO XXX: This breaks the state of the `FunctionGraph`,
        # because it doesn't check for validity of the replacement *first*.
        fg.replace(v1, newcomer, verbose=True)
def test_replace_circular(self):
    """Cycles are permitted in a `FunctionGraph`, for better or worse."""
    v1 = MyVariable("var1")
    v2 = MyVariable("var2")
    v3 = op1(v2, v1)
    v4 = op2(v3, v2)
    v5 = op3(v4, v2, v2)
    fg = FunctionGraph([v1, v2], [v3, v5], clone=False)

    fg.replace_all([(v3, v4)])

    # `v4` is mutated in place, so its own node now consumes it. That is a
    # cycle, which is why these (somewhat gross) assertions hold.
    assert fg.apply_nodes == {v4.owner, v5.owner}
    assert v4.owner.inputs == [v4, v2]
def test_init(self):
    """A freshly built graph records inputs, outputs, nodes, variables and clients."""
    v1 = MyVariable("var1")
    v2 = MyVariable("var2")
    v3 = op1(v1)
    v4 = op2(v3, v2)
    fg = FunctionGraph([v1, v2], [v3, v4], clone=False)

    assert fg.inputs == [v1, v2]
    assert fg.outputs == [v3, v4]
    assert fg.apply_nodes == {v3.owner, v4.owner}
    assert fg.update_mapping is None
    assert fg.check_integrity() is None
    assert fg.variables == {v1, v2, v3, v4}

    # Clients are (node, input-index) pairs; "output" marks a graph-output slot.
    assert fg.clients(v1) == [(v3.owner, 0)]
    assert fg.clients(v2) == [(v4.owner, 1)]
    assert fg.clients(v3) == [(v4.owner, 0), ("output", 0)]
    assert fg.clients(v4) == [("output", 1)]
def test_pre_greedy_local_optimizer():
    """Exercise `pre_greedy_local_optimizer` together with `constant_folding`.

    The test checks three things:

    - A node whose inputs are all `Constant`s is folded when it is not part
      of the given `FunctionGraph`.
    - The same node is left alone when it does belong to the graph.
    - A `MakeSlice` of a constant folds into a `SliceConstant` with a
      hashable signature.
    """

    empty_fgraph = FunctionGraph([], [])

    x = MyVariable("x")
    y = MyVariable("y")
    c1 = Constant(MyType(), 1, "c1")
    c2 = Constant(MyType(), 2, "c2")
    o1 = op2(c1, c2)
    o3 = op1(c1, y)
    # `o1` appears twice among `o2`'s inputs (positions 0 and 4)
    o2 = op1(o1, c2, x, o3, o1)

    assert o2.owner.inputs[0].owner is not None
    assert o2.owner.inputs[4].owner is not None

    # This should fold `o1`, because it has only `Constant` arguments, and
    # replace it with the `Constant` result
    cst = pre_greedy_local_optimizer(empty_fgraph, [constant_folding], o2)

    assert cst.owner.inputs[0].owner is None
    assert cst.owner.inputs[1] is c2
    assert cst.owner.inputs[2] is x
    # `o3` has a non-constant input (`y`), so it is kept as-is
    assert cst.owner.inputs[3] is o3
    # Both occurrences of `o1` are replaced by the same folded object
    assert cst.owner.inputs[4] is cst.owner.inputs[0]

    # We're going to do it again, except this time `o1` is
    # in the `fgraph`, so it shouldn't be folded
    fg = FunctionGraph([], [o1], clone=False)
    o2 = op1(o1, c2, x, o3, o1)

    cst = pre_greedy_local_optimizer(fg, [constant_folding], o2)

    assert cst.owner.inputs[0] is o1
    assert cst.owner.inputs[4] is cst.owner.inputs[0]

    # NOTE(review): the original intent of this case was not documented. As
    # written, it checks that `MakeSlice` applied to the constant `1` is
    # folded into a `SliceConstant`.
    ms = MakeSlice()(1)
    cst = pre_greedy_local_optimizer(empty_fgraph, [constant_folding], ms)

    assert isinstance(cst, SliceConstant)

    # Make sure constant of slice signature is hashable.
    assert isinstance(hash(cst.signature()), int)
def test_remove_client(self):
    """`remove_client` prunes a node once its output loses its last client."""
    v1 = MyVariable("var1")
    v2 = MyVariable("var2")
    v3 = op1(v2, v1)
    v4 = op2(v3, v2)
    v5 = op3(v4, v2, v2)
    fg = FunctionGraph([v1, v2], [v3, v5], clone=False)

    assert fg.variables == {v1, v2, v3, v4, v5}
    assert fg.clients(v2) == [
        (v3.owner, 0),
        (v4.owner, 1),
        (v5.owner, 1),
        (v5.owner, 2),
    ]

    fg.remove_client(v2, (v4.owner, 1))
    assert fg.clients(v2) == [(v3.owner, 0), (v5.owner, 1), (v5.owner, 2)]

    fg.remove_client(v1, (v3.owner, 1))
    assert fg.clients(v1) == []
    assert v4.owner in fg.apply_nodes

    # Dropping `v4`'s only client should remove `v4`'s `Apply` node and all
    # of its outputs from the graph.
    #
    # `(v4.owner, 1)` was already taken out of `v2`'s client list above, so
    # the cascading removal will try to remove it a second time. That used
    # to raise a `ValueError` because the entry was no longer present.
    fg.remove_client(v4, (v5.owner, 0), reason="testing")
    assert v4.owner not in fg.apply_nodes
    assert v4.owner.tag.removed_by == ["testing"]
    assert all(out not in fg.variables for out in v4.owner.outputs)
def test_replace_test_value(self):
    """`replace` rejects a replacement whose test-value shape differs."""
    v1 = MyVariable("var1")
    v1.tag.test_value = 1
    v2 = MyVariable("var2")
    v2.tag.test_value = 2
    v3 = op1(v2, v1)
    v4 = op2(v3, v2)
    v4.tag.test_value = np.array([1, 2])
    v5 = op3(v4, v2, v2)
    fg = FunctionGraph([v1, v2], [v3, v5], clone=False)

    substitute = op3()
    substitute.tag.test_value = np.array(0)
    # 0-d array vs. length-2 vector
    assert substitute.tag.test_value.shape != v4.tag.test_value.shape

    with pytest.raises(AssertionError, match="The replacement.*"):
        fg.replace(v4, substitute)
def test_pre_constant_merge():
    """Exercise `pre_constant_merge` on graphs containing duplicate constants.

    Equal constants are merged when the consuming node is outside the given
    `FunctionGraph`, and left as-is when the node belongs to it. Slice-related
    outputs are then checked as well.
    """

    empty_fgraph = FunctionGraph([], [])

    x = MyVariable("x")
    y = MyVariable("y")
    # Two distinct `Constant` objects with identical type, data and name
    c1 = Constant(MyType(), 1, "c1")
    c2 = Constant(MyType(), 1, "c1")
    o1 = op2(c1, x)
    o2 = op1(o1, y, c2)

    assert c1 is not c2

    # Outside any graph: `o2` itself is returned, and its `c2` input is
    # rewritten in place to the equivalent `c1`
    res = pre_constant_merge(empty_fgraph, [o2])

    assert [o2] == res
    assert o2.owner.inputs[2] is c1

    o2 = op1(o1, y, c2)
    fg = FunctionGraph([x, y], [o2], clone=False)

    assert o2.owner in fg.apply_nodes

    # `o2`'s node belongs to `fg`, so its `c2` input is not rewritten
    res = pre_constant_merge(fg, [o2])

    assert res == [o2]
    assert o2.owner.inputs[2] is c2

    # NOTE(review): the intent of the remaining checks was not documented.
    # As written, they check that a `MakeSlice` output, and an
    # `AdvancedSubtensor` indexed by a `SliceConstant`, are returned as
    # `[ms]` and `[adv]` respectively.
    ms = MakeSlice()(1)
    res = pre_constant_merge(empty_fgraph, [ms])

    assert res == [ms]

    const_slice = SliceConstant(type=slicetype, data=slice(1, None, 2))

    assert isinstance(const_slice, Constant)

    adv = AdvancedSubtensor()(tt.matrix(), [2, 3], const_slice)

    # NOTE(review): `adv` is passed bare rather than in a list, yet the result
    # is `[adv]`; presumably `pre_constant_merge` wraps a single variable
    # into a list. Confirm this is intentional.
    res = pre_constant_merge(empty_fgraph, adv)
    assert res == [adv]
def test_no_merge(self):
    """Nodes whose inputs differ in order are not merged."""
    x, y, z = inputs()
    # `Op2(x, y)` and `Op2(y, x)` are different computations.
    expr = op1(op3(op2(x, y)), op3(op2(y, x)))
    graph = FunctionGraph([x, y, z], [expr])
    merger = MergeOptimizer()
    merger.optimize(graph)
    assert str(graph) == "FunctionGraph(Op1(Op3(Op2(x, y)), Op3(Op2(y, x))))"
def test_deep_merge(self):
    """Identical multi-level subgraphs are merged all the way up."""
    x, y, z = inputs()
    expr = op1(op3(op2(x, y), z), op4(op3(op2(x, y), z)))
    graph = FunctionGraph([x, y, z], [expr])
    merger = MergeOptimizer()
    merger.optimize(graph)
    assert str(graph) == "FunctionGraph(Op1(*1 -> Op3(Op2(x, y), z), Op4(*1)))"
def test_straightforward(self):
    """Duplicate `Op2(x, y)` nodes merge; the distinct `Op2(x, z)` stays."""
    x, y, z = inputs()
    expr = op1(op2(x, y), op2(x, y), op2(x, z))
    graph = FunctionGraph([x, y, z], [expr])
    merger = MergeOptimizer()
    merger.optimize(graph)
    assert str(graph) == "FunctionGraph(Op1(*1 -> Op2(x, y), *1, Op2(x, z)))"
def test_straightforward_2(self):
    """`OpSubOptimizer` swaps every `Op3` application for `Op4`."""
    x, y, z = inputs()
    expr = op1(op2(x), op3(y), op4(z))
    graph = FunctionGraph([x, y, z], [expr])
    substitute = OpSubOptimizer(op3, op4)
    substitute.optimize(graph)
    assert str(graph) == "FunctionGraph(Op1(Op2(x), Op4(y), Op4(z)))"