def test_Dimshuffle_lift_restrictions():
    """Check when `local_dimshuffle_rv_lift` may move a `DimShuffle` into a random variable.

    The lift must not happen while another, non-`DimShuffle` client uses the
    random variable itself (that client depends on the RNG state). It must
    happen once the other client only uses the variable's shape.
    """
    shared_rng = shared(np.random.RandomState(1233532), borrow=False)

    rv = normal(tt.arange(2).reshape((2,)), 100, size=(2, 2, 2), rng=shared_rng)
    rv_t = rv.dimshuffle(1, 0, 2)

    def lift(output):
        # Build a non-cloning graph over `output` and run
        # `local_dimshuffle_rv_lift` on it through `EquilibriumOptimizer`
        graph = FunctionGraph([shared_rng], [output], clone=False)
        EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(graph)
        return graph

    # `rv` is also consumed directly by the subtraction, which depends on the
    # RNG state, so the `DimShuffle` has to stay where it is
    graph = lift(rv - rv_t)
    ds_node = graph.outputs[0].owner.inputs[1].owner
    assert ds_node == rv_t.owner
    assert isinstance(ds_node.op, DimShuffle)
    assert ds_node.inputs[0].owner.op == normal

    # Only the shape of `rv` is used by the other client, which doesn't depend
    # on the RNG state, so the lift goes through: the `DimShuffle` now sits on
    # the random variable node's last two inputs
    graph = lift(tt.ones(rv.shape) - rv_t)
    lifted_rv_node = graph.outputs[0].owner.inputs[1].owner
    assert lifted_rv_node.op == normal
    assert isinstance(lifted_rv_node.inputs[-1].owner.op, DimShuffle)
    assert isinstance(lifted_rv_node.inputs[-2].owner.op, DimShuffle)
def test_Subtensor_lift_restrictions():
    """Check when `local_subtensor_rv_lift` may lift a `Subtensor` through a random variable.

    The lift is refused while another (non-`Subtensor`) client consumes the
    random variable itself. It is performed once the other client only uses
    the variable's shape.
    """
    rng = shared(np.random.RandomState(1233532), borrow=False)

    x = normal(tt.arange(2), tt.ones(2), rng=rng)
    y = x[1]

    # The non-`Subtensor` client depends on the RNG state, so we can't perform
    # the lift
    z = x - y

    fg = FunctionGraph([rng], [z], clone=False)
    EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)

    # `y` reaches the subtraction through an intermediate node (presumably the
    # broadcasting `DimShuffle` of the scalar), hence the extra
    # `.inputs[0].owner` hop
    subtensor_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner
    assert subtensor_node == y.owner
    assert isinstance(subtensor_node.op, Subtensor)
    assert subtensor_node.inputs[0].owner.op == normal

    # The non-`Subtensor` client doesn't depend on the RNG state, so we can
    # perform the lift
    z = tt.ones(x.shape) - x[1]

    fg = FunctionGraph([rng], [z], clone=False)
    EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)

    rv_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner
    assert rv_node.op == normal
    # After the lift, the indexing is applied to the random variable node's
    # last two inputs instead of to its output
    assert isinstance(rv_node.inputs[-1].owner.op, Subtensor)
    assert isinstance(rv_node.inputs[-2].owner.op, Subtensor)
def test_merge_with_weird_eq():
    """`MergeOptimizer` must merge equal constants whose values are NumPy arrays.

    NumPy arrays don't compare equal like other Python objects, so this runs
    the same check for a scalar (0-d) array and for a non-scalar array. The
    non-scalar case was created to test TensorConstantSignature.
    """
    for make_value in (lambda: np.asarray(1), lambda: np.ones(5)):
        # Two distinct but equal-valued constants feeding one addition
        x = tt.constant(make_value(), name="x")
        y = tt.constant(make_value(), name="y")
        g = FunctionGraph([x, y], [x + y])
        MergeOptimizer().optimize(g)

        assert len(g.apply_nodes) == 1
        (node,) = g.apply_nodes
        assert len(node.inputs) == 2
        # Both inputs must have been merged into one and the same constant
        assert node.inputs[0] is node.inputs[1]
def test_validate_inputs(self):
    """`FunctionGraph` must reject malformed `inputs`/`outputs` arguments."""
    var1 = op1()
    var2 = op2()

    # Passing a bare variable instead of a list, for either the inputs or the
    # outputs, is expected to raise `TypeError`
    with pytest.raises(TypeError):
        FunctionGraph(var1, [var2])

    with pytest.raises(TypeError):
        FunctionGraph([var1], var2)

    # `var3` is the output of an `Apply` node (it has an owner). Build it
    # outside the `raises` block so that only the `FunctionGraph` constructor
    # is expected to raise `ValueError` for it.
    var3 = op1(var1)
    with pytest.raises(ValueError):
        FunctionGraph([var3], [var2], clone=False)
def test_replace(self):
    """Exercise `FunctionGraph.replace` error paths and a basic `replace_all`."""
    var1 = MyVariable("var1")
    var2 = MyVariable("var2")
    var3 = op1(var2, var1)
    var4 = op2(var3, var2)
    var5 = op3(var4, var2, var2)
    fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

    # Trigger a `FunctionGraph` ownership error by pointing `var4` at a foreign
    # object. The setup stays outside the `raises` block so only `replace`
    # itself is expected to raise.
    var4.fgraph = object()
    with pytest.raises(Exception, match="Cannot replace.*"):
        fg.replace(var4, var1, verbose=True)
    var4.fgraph = fg

    # The types don't match and one cannot be converted to the other
    var0 = MyVariable2("var0")
    with pytest.raises(BadOptimization):
        fg.replace(var3, var0)

    # Test a basic replacement
    fg.replace_all([(var3, var1)])
    assert var3 not in fg.variables
    assert fg.apply_nodes == {var4.owner, var5.owner}
    assert var4.owner.inputs == [var1, var2]
def test_change_input(self):
    """Exercise `FunctionGraph.change_input`: type errors, a no-op change and a real one."""
    var1 = MyVariable("var1")
    var2 = MyVariable("var2")
    var3 = op1(var2, var1)
    var4 = op2(var3, var2)
    var5 = op3(var4, var2, var2)
    fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

    # A `MyVariable2` is expected to be rejected with `TypeError` both for a
    # graph output slot and for an `Apply` node input
    incompatible = MyVariable2("var6")
    with pytest.raises(TypeError):
        fg.change_input("output", 1, incompatible)
    with pytest.raises(TypeError):
        fg.change_input(var5.owner, 1, incompatible)

    nodes_before = set(fg.apply_nodes)
    variables_before = set(fg.variables)
    var5_clients_before = list(var5.clients)

    # `var2` already is input 1 of `var5.owner`, so this must change nothing
    fg.change_input(var5.owner, 1, var2)
    assert nodes_before == fg.apply_nodes
    assert variables_before == fg.variables
    assert var5_clients_before == var5.clients

    # Now actually swap input 1 of `var5.owner` for `var1`
    fg.change_input(var5.owner, 1, var1)
    assert var5.owner.inputs[1] is var1
    assert (var5.owner, 1) not in var2.clients
def test_import_var(self):
    """Exercise `FunctionGraph.import_var` on owner-less, owned and null-typed variables."""
    var1 = MyVariable("var1")
    var2 = MyVariable("var2")
    var3 = op1(var2, var1)
    var4 = op2(var3, var2)
    var5 = op3(var4, var2, var2)
    fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

    # We can't import a new `FunctionGraph` input (i.e. something without an
    # owner). Only `import_var` itself is expected to raise.
    var0 = MyVariable("var0")
    with pytest.raises(MissingInputError):
        fg.import_var(var0, "testing")

    # We can import variables with owners. Use a fresh name rather than
    # rebinding `var5`, which is one of the graph's outputs.
    var6 = op2()
    fg.import_var(var6, "testing")
    assert var6 in fg.variables
    assert var6.owner in fg.apply_nodes

    # Importing a `NullType` variable is expected to raise `TypeError`.
    # Setup stays outside the `raises` block.
    from theano.gof.null_type import NullType

    null_var = NullType()()
    with pytest.raises(TypeError, match="Computation graph contains.*"):
        fg.import_var(null_var, "testing")
def test_replace_output(self):
    """A pattern rooted at the output node replaces the whole graph.

    `op1(op2(x, y), z)` is expected to become `op4(z, y)`.
    """
    x, y, z = inputs()
    graph = FunctionGraph([x, y, z], [op1(op2(x, y), z)])
    optimizer = PatternOptimizer((op1, (op2, "1", "2"), "3"), (op4, "3", "2"))
    optimizer.optimize(graph)
    assert str(graph) == "[Op4(z, y)]"
def est_both_assert_merge_2_reverse(self):
    # Test case "test_both_assert_merge_2" but in reverse order
    # NOTE(review): the `est_` name prefix (instead of `test_`) looks like a
    # deliberate way to keep pytest from collecting this test, i.e. it is
    # disabled.
    # NOTE(review): in the reviewed copy the expected tree string `strref`
    # appears with its line breaks and column alignment collapsed to single
    # spaces. It is reproduced verbatim here; confirm the real layout against
    # version control before re-enabling this test.
    x1 = T.matrix('x1')
    x2 = T.matrix('x2')
    x3 = T.matrix('x3')
    # Two `dot`s that each guard a different operand with an `assert_op`
    e = T.dot(x1, T.opt.assert_op(x2, (x2 > x3).all())) +\
        T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2)
    g = FunctionGraph([x1, x2, x3], [e])
    MergeOptimizer().optimize(g)
    # With file='str', debugprint presumably returns the rendered tree as a
    # string (it is compared against `strref` below)
    strg = theano.printing.debugprint(g, file='str')
    strref = '''Elemwise{add,no_inplace} [@A] '' 7 |dot [@B] '' 6 | |Assert{msg='Theano Assert failed!'} [@C] '' 5 | | |x1 [@D] | | |All [@E] '' 3 | | |Elemwise{gt,no_inplace} [@F] '' 1 | | |x1 [@D] | | |x3 [@G] | |Assert{msg='Theano Assert failed!'} [@H] '' 4 | |x2 [@I] | |All [@J] '' 2 | |Elemwise{gt,no_inplace} [@K] '' 0 | |x2 [@I] | |x3 [@G] |dot [@B] '' 6 '''
    print(strg)
    assert strg == strref, (strg, strref)
def test_multiple(self):
    """Every occurrence of the pattern is replaced, not just the first one."""
    x, y, z = inputs()
    first, second, third = op2(x, y), op2(x, y), op2(y, z)
    graph = FunctionGraph([x, y, z], [op1(first, second, third)])
    PatternOptimizer((op2, "1", "2"), (op4, "1")).optimize(graph)
    assert str(graph) == "FunctionGraph(Op1(Op4(x), Op4(x), Op4(y)))"
def test_multi(self):
    """A match whose inner node is shared with another client is left alone.

    The graph contains `op4(op1(x, y))`, but the same `op1(x, y)` node is also
    a direct input of `op3`. The expected output shows the graph unchanged.
    """
    x, y, z = inputs()
    shared_node = op1(x, y)
    output = op3(op4(shared_node), shared_node)
    graph = FunctionGraph([x, y, z], [output])
    PatternOptimizer((op4, (op1, "x", "y")), (op3, "x", "y")).optimize(graph)
    assert str(graph) == "FunctionGraph(Op3(Op4(*1 -> Op1(x, y)), *1))"
def est_both_assert_merge_2(self):
    # Merge two nodes, both have assert on different node
    # NOTE(review): the `est_` name prefix (instead of `test_`) looks like a
    # deliberate way to keep pytest from collecting this test, i.e. it is
    # disabled.
    # NOTE(review): in the reviewed copy the expected tree string `strref`
    # appears with its line breaks and column alignment collapsed to single
    # spaces. It is reproduced verbatim here; confirm the real layout against
    # version control before re-enabling this test.
    x1 = T.matrix('x1')
    x2 = T.matrix('x2')
    x3 = T.matrix('x3')
    e = T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) +\
        T.dot(x1, T.opt.assert_op(x2, (x2 > x3).all()))
    g = FunctionGraph([x1, x2, x3], [e])
    MergeOptimizer().optimize(g)
    # With file='str', debugprint presumably returns the rendered tree as a
    # string (it is compared against `strref` below)
    strg = theano.printing.debugprint(g, file='str')
    # Both summands are expected to end up as the same `dot` node ([id B]),
    # carrying both asserts
    strref = '''Elemwise{add,no_inplace} [id A] '' 7 |dot [id B] '' 6 | |Assert{msg='Theano Assert failed!'} [id C] '' 5 | | |x1 [id D] | | |All [id E] '' 3 | | |Elemwise{gt,no_inplace} [id F] '' 1 | | |x1 [id D] | | |x3 [id G] | |Assert{msg='Theano Assert failed!'} [id H] '' 4 | |x2 [id I] | |All [id J] '' 2 | |Elemwise{gt,no_inplace} [id K] '' 0 | |x2 [id I] | |x3 [id G] |dot [id B] '' 6 '''
    # print(strg)
    assert strg == strref, (strg, strref)
def test_nested_odd(self):
    """Collapsing pairs of `op1` in a chain of five leaves exactly one `op1`."""
    x, y, z = inputs()
    chain = x
    for _ in range(5):
        chain = op1(chain)
    graph = FunctionGraph([x, y, z], [chain])
    PatternOptimizer((op1, (op1, "1")), "1").optimize(graph)
    assert str(graph) == "[Op1(x)]"
def test_nested_out_pattern(self):
    """The replacement pattern may itself be a nested expression."""
    x, y, z = inputs()
    graph = FunctionGraph([x, y, z], [op1(x, y)])
    in_pattern = (op1, "1", "2")
    out_pattern = (op4, (op1, "1"), (op2, "2"), (op3, "1", "2"))
    PatternOptimizer(in_pattern, out_pattern).optimize(graph)
    assert str(graph) == "[Op4(Op1(x), Op2(y), Op3(x, y))]"
def test_replace_subgraph(self):
    """A pattern that matches an inner node rewrites only that subgraph."""
    x, y, z = inputs()
    inner = op2(x, y)
    graph = FunctionGraph([x, y, z], [op1(inner, z)])
    # Turn the inner `op2(x, y)` into `op1(y, x)` (operands swapped)
    PatternOptimizer((op2, "1", "2"), (op1, "2", "1")).optimize(graph)
    assert str(graph) == "FunctionGraph(Op1(Op1(y, x), z))"
def test_match_same(self):
    """Two distinct pattern variables may bind to the same graph variable."""
    x, y, z = inputs()
    graph = FunctionGraph([x, y, z], [op1(x, x)])
    PatternOptimizer((op1, "x", "y"), (op3, "x", "y")).optimize(graph)
    assert str(graph) == "[Op3(x, x)]"
def test_merge_outputs(self):
    """Two structurally identical outputs are merged into one shared node."""
    x, y, z = inputs()
    outputs = [op3(op2(x, y)) for _ in range(2)]
    graph = FunctionGraph([x, y, z], outputs)
    MergeOptimizer().optimize(graph)
    assert str(graph) == "[*1 -> Op3(Op2(x, y)), *1]"
def est_both_assert_merge_2_reverse(self):
    # Test case "test_both_assert_merge_2" but in reverse order
    # NOTE(review): the `est_` name prefix (instead of `test_`) looks like a
    # deliberate way to keep pytest from collecting this test, i.e. it is
    # disabled.
    # NOTE(review): in the reviewed copy the expected tree string `strref`
    # appears with its line breaks and column alignment collapsed to single
    # spaces. It is reproduced verbatim here; confirm the real layout against
    # version control before re-enabling this test.
    x1 = tt.matrix("x1")
    x2 = tt.matrix("x2")
    x3 = tt.matrix("x3")
    # Two `dot`s that each guard a different operand with an `assert_op`
    e = tt.dot(x1, tt.opt.assert_op(x2, (x2 > x3).all())) + tt.dot(
        tt.opt.assert_op(x1, (x1 > x3).all()), x2
    )
    g = FunctionGraph([x1, x2, x3], [e])
    MergeOptimizer().optimize(g)
    # With file="str", debugprint presumably returns the rendered tree as a
    # string (it is compared against `strref` below)
    strg = theano.printing.debugprint(g, file="str")
    strref = """Elemwise{add,no_inplace} [id A] '' 7 |dot [id B] '' 6 | |Assert{msg='Theano Assert failed!'} [id C] '' 5 | | |x1 [id D] | | |All [id E] '' 3 | | |Elemwise{gt,no_inplace} [id F] '' 1 | | |x1 [id D] | | |x3 [id G] | |Assert{msg='Theano Assert failed!'} [id H] '' 4 | |x2 [id I] | |All [id J] '' 2 | |Elemwise{gt,no_inplace} [id K] '' 0 | |x2 [id I] | |x3 [id G] |dot [id B] '' 6 """
    print(strg)
    assert strg == strref, (strg, strref)
def test_check_integrity(self):
    """Corrupt a `FunctionGraph` in several ways and check `check_integrity` catches each.

    Each `pytest.raises` block introduces one inconsistency by mutating the
    graph's bookkeeping directly. The statements after a block undo that
    mutation before the next case (see the NOTE below for one leftover).
    """
    var1 = MyVariable("var1")
    var2 = MyVariable("var2")
    var3 = op1(var2, var1)
    var4 = op2(var3, var2)
    var5 = op3(var4, var2, var2)
    fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

    # An `Apply` node missing from `fg.apply_nodes`
    with pytest.raises(Exception, match="The nodes are .*"):
        fg.apply_nodes.remove(var5.owner)

        fg.check_integrity()

    # Put the node back, then drop one of `var2`'s client entries
    with pytest.raises(Exception, match="Inconsistent clients.*"):
        fg.apply_nodes.add(var5.owner)
        fg.remove_client(var2, (var5.owner, 1))

        fg.check_integrity()

    fg.add_client(var2, (var5.owner, 1))

    # A variable missing from `fg.variables`
    with pytest.raises(Exception, match="The variables are.*"):
        fg.variables.remove(var4)

        fg.check_integrity()

    fg.variables.add(var4)

    # A new owner-less variable spliced in as an extra input of `var5.owner`,
    # without being one of the graph's inputs
    with pytest.raises(Exception, match="Undeclared input.*"):
        var6 = MyVariable2("var6")
        fg.clients[var6] = [(var5.owner, 3)]
        fg.variables.add(var6)
        var5.owner.inputs.append(var6)

        fg.check_integrity()

    # NOTE(review): `fg.clients[var6]` is not removed here, so that entry
    # lingers for the remaining checks.
    fg.variables.remove(var6)
    var5.owner.inputs.remove(var6)

    # TODO: What if the index value is greater than 1? It will throw an
    # `IndexError`, but that doesn't sound like anything we'd want.
    with pytest.raises(Exception, match="Inconsistent clients list.*"):
        fg.add_client(var4, ("output", 1))

        fg.check_integrity()

    fg.remove_client(var4, ("output", 1))

    # `var6` was created standalone, so `var6.owner` is presumably `None`,
    # i.e. a client node that is not part of the graph
    with pytest.raises(Exception, match="Client not in FunctionGraph.*"):
        fg.add_client(var4, (var6.owner, 0))

        fg.check_integrity()

    fg.remove_client(var4, (var6.owner, 0))

    # `var3.owner`'s input 0 is `var2`, not `var4`, so this client entry
    # doesn't match the node's actual inputs
    with pytest.raises(Exception, match="Inconsistent clients list.*"):
        fg.add_client(var4, (var3.owner, 0))

        fg.check_integrity()
def test_expand(self):
    """Each original `op1` is wrapped in `op2` exactly once.

    The `op1` nodes created by the replacement are not expanded again;
    `ign=True` presumably tells the optimizer to ignore newly built trees.
    """
    x, y, z = inputs()
    chain = op1(op1(op1(x)))
    graph = FunctionGraph([x, y, z], [chain])
    expander = PatternOptimizer((op1, "1"), (op2, (op1, "1")), ign=True)
    expander.optimize(graph)
    assert str(graph) == "[Op2(Op1(Op2(Op1(Op2(Op1(x))))))]"
def Env(inputs, outputs, validate=True):
    """Build a non-cloning `FunctionGraph` with destroy handling and replace validation.

    Parameters
    ----------
    inputs, outputs
        Forwarded to `FunctionGraph` (with ``clone=False``).
    validate
        When true, call ``validate()`` on the new graph before returning it.

    Returns
    -------
    FunctionGraph
        The graph with a `DestroyHandler` and a `ReplaceValidate` feature
        attached, in that order.
    """
    fgraph = FunctionGraph(inputs, outputs, clone=False)
    for make_feature in (destroyhandler.DestroyHandler, ReplaceValidate):
        fgraph.attach_feature(make_feature())
    if validate:
        fgraph.validate()
    return fgraph
def test_unification_2(self):
    """A pattern variable used twice only matches when both slots hold the same variable."""
    x, y, z = inputs()
    # the arguments to op2 are different
    graph = FunctionGraph([x, y, z], [op1(op2(x, y), z)])
    # they are the same in the pattern
    same_args_pattern = (op1, (op2, "1", "1"), "2")
    PatternOptimizer(same_args_pattern, (op4, "2", "1")).optimize(graph)
    # The replacement should NOT occur
    assert str(graph) == "[Op1(Op2(x, y), z)]"
def test_constant_unification(self):
    """A constant in a pattern matches a different, equal-valued constant in the graph.

    The pattern uses `z` while the graph contains `x`. Both hold 2, and the
    inner `op1(x, y)` is expected to be rewritten.
    """
    x = Constant(MyType(), 2, name="x")
    y = MyVariable("y")
    z = Constant(MyType(), 2, name="z")
    inner = op1(x, y)
    graph = FunctionGraph([y], [op1(inner, y)])
    PatternOptimizer((op1, z, "1"), (op2, "1", z)).optimize(graph)
    assert str(graph) == "FunctionGraph(Op1(Op2(y, z), y))"
def test_nested_even(self):
    """Collapsing pairs of `op1` in a chain of four removes all of them.

    Regardless of the order in which nodes get optimized, this should work.
    """
    x, y, z = inputs()
    chain = x
    for _ in range(4):
        chain = op1(chain)
    graph = FunctionGraph([x, y, z], [chain])
    PatternOptimizer((op1, (op1, "1")), "1").optimize(graph)
    assert str(graph) == "FunctionGraph(x)"
def test_constant_unification(self):
    """A pattern constant (`z`) matches an equal-valued graph constant (`x`)."""
    x = Constant(MyType(), 2, name="x")
    y = MyVariable("y")
    z = Constant(MyType(), 2, name="z")
    outer = op1(op1(x, y), y)
    graph = FunctionGraph([y], [outer])
    rewrite = PatternOptimizer((op1, z, "1"), (op2, "1", z))
    rewrite.optimize(graph)
    assert str(graph) == "[Op1(Op2(y, z), y)]"
def test_no_recurse(self):
    """With `ign=True`, an output pattern that also matches the input pattern is applied once.

    If the out pattern is an acceptable in pattern and the ignore_newtrees
    flag is True, the optimizer should do the replacement and stop.
    """
    x, y, z = inputs()
    graph = FunctionGraph([x, y, z], [op1(op2(x, y), z)])
    swap = PatternOptimizer((op2, "1", "2"), (op2, "2", "1"), ign=True)
    swap.optimize(graph)
    assert str(graph) == "[Op1(Op2(y, x), z)]"
def test_eq(self):
    """A pattern op matches a graph op that compares equal to it.

    The pattern names `op_z` while the graph uses `op_y`. The expected output
    shows the whole graph was still replaced.
    """
    x, y, z = inputs()
    graph = FunctionGraph([x, y, z], [op1(op_y(x, y), z)])
    PatternOptimizer((op1, (op_z, "1", "2"), "3"), (op4, "3", "2")).optimize(graph)
    assert str(graph) == "FunctionGraph(Op4(z, y))"
def test_eq(self):
    """A pattern op (`op_z`) matches an equal graph op (`op_y`); the whole graph is replaced."""
    x, y, z = inputs()
    inner = op_y(x, y)
    graph = FunctionGraph([x, y, z], [op1(inner, z)])
    optimizer = PatternOptimizer((op1, (op_z, "1", "2"), "3"), (op4, "3", "2"))
    optimizer.optimize(graph)
    assert str(graph) == "[Op4(z, y)]"
def test_ambiguous(self):
    """Reducing `op1(op1(t))` to `op1(t)` over a chain of five `op1` leaves one `op1`.

    This should always work with TopoOptimizer and the ignore_newtrees flag
    set to False. Behavior with ignore_newtrees = True or with other
    NavigatorOptimizers may differ.
    """
    x, y, z = inputs()
    chain = x
    for _ in range(5):
        chain = op1(chain)
    graph = FunctionGraph([x, y, z], [chain])
    reducer = TopoPatternOptimizer((op1, (op1, "1")), (op1, "1"), ign=False)
    reducer.optimize(graph)
    assert str(graph) == "[Op1(x)]"
def test_constant_merging(self):
    """Equal-valued constants are merged, so all three `op2` nodes collapse into one."""
    x = MyVariable("x")
    y = Constant(MyType(), 2, name="y")
    z = Constant(MyType(), 2, name="z")
    e = op1(op2(x, y), op2(x, y), op2(x, z))
    g = FunctionGraph([x, y, z], [e])
    MergeOptimizer().optimize(g)
    # Either constant may survive the merge, so both renderings are accepted
    accepted = ("[Op1(*1 -> Op2(x, y), *1, *1)]", "[Op1(*1 -> Op2(x, z), *1, *1)]")
    assert str(g) in accepted