def test_misc_2():
    """Swap a transpose view of ``x`` for ``x`` itself under an in-place add.

    The graph's single output is ``add_in_place(x, transpose_view(x))``.  The
    ``inconsistent`` helper is applied both before and after the view is
    replaced by the variable it was built from.
    """
    x, y, _ = inputs()
    view_of_x = transpose_view(x)
    out = add_in_place(x, view_of_x)
    graph = Env([x, y], [out], False)
    inconsistent(graph)

    # Replace the view by its own source variable and re-check.
    graph.replace(view_of_x, x)
    inconsistent(graph)
def test_value_repl_2():
    """Replace the second operand of an in-place add by a view of a constant.

    The graph output is ``add_in_place(x, sigmoid(y))``.  The ``consistent``
    helper is applied before and after ``sigmoid(y)`` is replaced by
    ``transpose_view(MyConstant("abc"))``.
    """
    x, y, _ = inputs()
    sig_y = sigmoid(y)
    out = add_in_place(x, sig_y)
    graph = Env([x, y], [out], False)
    consistent(graph)

    constant = MyConstant("abc")
    graph.replace(sig_y, transpose_view(constant))
    consistent(graph)
def test_usage_loop_through_views_2():
    """Check consistency before and after bypassing a ``sigmoid`` under views.

    The output is ``dot(add_in_place(x, y), transpose_view(e0))`` where ``e0``
    is a double transpose view of ``sigmoid(x)``.  The graph is first checked
    with ``consistent``; after ``e0`` is replaced by ``x`` it is checked with
    ``inconsistent``.
    """
    x, y, z = inputs()
    sig_x = sigmoid(x)
    double_view = transpose_view(transpose_view(sig_x))
    destroyed = add_in_place(x, y)
    out = dot(destroyed, transpose_view(double_view))
    graph = Env([x, y, z], [out])
    consistent(graph)  # because sigmoid can do the copy

    graph.replace(double_view, x)
    inconsistent(graph)  # we cut off the path to the sigmoid
def test_indirect_2():
    """Replace a view of ``x`` by ``add(view, y)`` next to an in-place add on ``x``.

    The output is ``dot(sigmoid(add_in_place(x, y)), transpose_view(x))``.
    The graph is checked with ``inconsistent`` first, then with ``consistent``
    once the view operand has been replaced by ``add(view, y)``.
    """
    x, y, z = inputs()
    view_of_x = transpose_view(x)
    activated = sigmoid(add_in_place(x, y))
    out = dot(activated, view_of_x)
    graph = Env([x, y, z], [out], False)
    inconsistent(graph)

    replacement = add(view_of_x, y)
    graph.replace(view_of_x, replacement)
    consistent(graph)
def test_repair_destroy_path():
    """Shorten a chain of views that feeds one of two in-place adds.

    Two in-place adds are built: one on a four-deep chain of transpose views
    of ``x`` and one on the two-deep prefix of that chain.  The graph is
    checked with ``inconsistent`` before and after the four-deep view is
    replaced by a single ``transpose_view(x)``.
    """
    x, y, z = inputs()
    two_views = transpose_view(transpose_view(x))
    four_views = transpose_view(transpose_view(two_views))
    destroy_deep = add_in_place(four_views, y)
    destroy_shallow = add_in_place(two_views, z)
    graph = Env([x, y, z], [destroy_deep, destroy_shallow], False)
    inconsistent(graph)

    graph.replace(four_views, transpose_view(x))
    inconsistent(graph)
def test_replace_bad_state(self):
    """Replacing a graph input by a fresh variable must raise ``MissingInputError``.

    ``var0`` is a new ``MyVariable`` that is not among the inputs given to the
    ``FunctionGraph``; substituting it for ``var1`` is expected to fail.
    """
    var1 = MyVariable("var1")
    var2 = MyVariable("var2")
    var3 = op1(var2, var1)
    var4 = op2(var3, var2)
    var5 = op3(var4, var2, var2)
    fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

    # Built outside the `raises` block so that only `fg.replace` itself can
    # satisfy the expectation, not an error while constructing the variable.
    var0 = MyVariable("var0")

    with pytest.raises(MissingInputError):
        # FIXME TODO XXX: This breaks the state of the `FunctionGraph`,
        # because it doesn't check for validity of the replacement *first*.
        fg.replace(var1, var0, verbose=True)
def test_misc():
    """Exercise optimize / replace / revert on a graph of nested views.

    Starting from four nested transpose views of ``x``, the test:

    1. takes a checkpoint,
    2. runs a ``PatternOptimizer`` and expects the graph to print as ``[x]``,
    3. performs two replacements, checking the printed graph after each and
       that the last state is reported by ``inconsistent``,
    4. reverts to the checkpoint and expects the original printed graph.
    """
    x, y, z = inputs()

    # Wrap `x` in four transpose views, innermost first.
    nested = x
    for _ in range(4):
        nested = transpose_view(nested)

    graph = Env([x, y, z], [nested])
    consistent(graph)
    snapshot = graph.checkpoint()

    # NOTE(review): the pattern presumably rewrites a doubled transpose view
    # of a variable to the variable itself -- the assertion below agrees.
    view_collapser = PatternOptimizer((transpose_view, (transpose_view, "x")), "x")
    view_collapser.optimize(graph)
    assert str(graph) == "[x]"

    total = add(x, y)
    graph.replace_validate(x, total)
    assert str(graph) == "[Add(x, y)]"

    graph.replace(total, dot(add_in_place(x, y), transpose_view(x)))
    assert str(graph) == "[Dot(AddInPlace(x, y), TransposeView(x))]"
    inconsistent(graph)

    graph.revert(snapshot)
    consistent(graph)
    assert str(graph) == "[TransposeView(TransposeView(TransposeView(TransposeView(x))))]"
def test_straightforward(self):
    """Check that ``NodeFinder`` tracks nodes per op across a replacement.

    After extending the ``Env`` with a ``NodeFinder``, ``get_nodes`` must
    report the expected number of ``add`` / ``sigmoid`` / ``dot`` nodes, both
    before and after ``dot(y, z)`` is replaced by ``add(y, z)``.

    Raises:
        Exception: if the number of nodes found for an op differs from the
            expected count.
    """
    x, y, z = inputs()
    e0 = dot(y, z)
    e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
    g = Env([x, y, z], [e])
    g.extend(NodeFinder())
    assert hasattr(g, "get_nodes")

    def check_counts(expected_counts):
        # Compare the number of nodes `get_nodes` yields for each op against
        # its expected count.  Avoids shadowing the builtin `type` and the
        # test variable `x`, which the duplicated inline loops used to do.
        for op, expected in expected_counts:
            if len(list(g.get_nodes(op))) != expected:
                raise Exception("Expected: %i times %s" % (expected, op))

    check_counts(((add, 3), (sigmoid, 3), (dot, 2)))

    new_e0 = add(y, z)
    assert e0.owner in g.get_nodes(dot)
    assert new_e0.owner not in g.get_nodes(add)

    g.replace(e0, new_e0)
    assert e0.owner not in g.get_nodes(dot)
    assert new_e0.owner in g.get_nodes(add)

    # One `dot` node was swapped for an `add` node.
    check_counts(((add, 4), (sigmoid, 3), (dot, 1)))
def test_replace_test_value(self):
    """Replacing a variable by one whose test value has another shape must fail.

    The replaced variable carries a test value of shape ``(2,)`` and the
    replacement one of shape ``()``; ``FunctionGraph.replace`` is expected to
    raise an ``AssertionError`` whose message matches ``"The replacement.*"``.
    """
    first = MyVariable("var1")
    first.tag.test_value = 1
    second = MyVariable("var2")
    second.tag.test_value = 2

    out1 = op1(second, first)
    target = op2(out1, second)
    target.tag.test_value = np.array([1, 2])
    out2 = op3(target, second, second)
    fg = FunctionGraph([first, second], [out1, out2], clone=False)

    replacement = op3()
    replacement.tag.test_value = np.array(0)
    # Sanity check: the two test values really do differ in shape.
    assert replacement.tag.test_value.shape != target.tag.test_value.shape

    with pytest.raises(AssertionError, match="The replacement.*"):
        fg.replace(target, replacement)
def test_replace(self):
    """Check ``FunctionGraph.replace`` rejection and ``replace_all`` success.

    A replacement by a ``MyVariable2`` is expected to raise
    ``BadOptimization``; a replacement of ``var3`` by the input ``var1`` via
    ``replace_all`` is expected to remove ``var3`` from the graph and rewire
    its client.
    """
    var1 = MyVariable("var1")
    var2 = MyVariable("var2")
    var3 = op1(var2, var1)
    var4 = op2(var3, var2)
    var5 = op3(var4, var2, var2)
    fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

    # Built outside the `raises` block so that only `fg.replace` itself can
    # satisfy the expectation.
    var0 = MyVariable2("var0")

    with pytest.raises(BadOptimization):
        # The types don't match and one cannot be converted to the other
        fg.replace(var3, var0)

    # Test a basic replacement
    fg.replace_all([(var3, var1)])
    assert var3 not in fg.variables
    assert fg.apply_nodes == {var4.owner, var5.owner}
    assert var4.owner.inputs == [var1, var2]
def test_straightforward(self):
    """Check that ``NodeFinder`` tracks nodes per op across a replacement.

    After attaching a ``NodeFinder`` feature to the ``FunctionGraph``,
    ``get_nodes`` must report the expected number of ``add`` / ``sigmoid`` /
    ``dot`` nodes, both before and after ``dot(y, z)`` is replaced by
    ``add(y, z)``.

    Raises:
        Exception: if the number of nodes found for an op differs from the
            expected count.
    """
    x, y, z = inputs()
    e0 = dot(y, z)
    e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
    g = FunctionGraph([x, y, z], [e], clone=False)
    g.attach_feature(NodeFinder())
    assert hasattr(g, "get_nodes")

    def check_counts(expected_counts):
        # Compare the number of nodes `get_nodes` yields for each op against
        # its expected count.  Avoids shadowing the builtin `type` and the
        # test variable `x`, which the duplicated inline loops used to do.
        for op, expected in expected_counts:
            if len(list(g.get_nodes(op))) != expected:
                raise Exception("Expected: %i times %s" % (expected, op))

    check_counts(((add, 3), (sigmoid, 3), (dot, 2)))

    new_e0 = add(y, z)
    assert e0.owner in g.get_nodes(dot)
    assert new_e0.owner not in g.get_nodes(add)

    g.replace(e0, new_e0)
    assert e0.owner not in g.get_nodes(dot)
    assert new_e0.owner in g.get_nodes(add)

    # One `dot` node was swapped for an `add` node.
    check_counts(((add, 4), (sigmoid, 3), (dot, 1)))
def test_aliased_inputs_replacement():
    """Alternate the second operand of an in-place add between views and a sigmoid.

    The output starts as ``add_in_place(x, transpose_view(x))``.  A fixed
    sequence of replacements is then applied, and after each one the graph is
    checked with the listed helper (``consistent`` or ``inconsistent``).
    """
    x, y, _ = inputs()
    tv = transpose_view(x)
    tvv = transpose_view(tv)
    sx = sigmoid(x)
    out = add_in_place(x, tv)
    graph = Env([x, y], [out], False)
    inconsistent(graph)

    # (variable to replace, replacement, check to run afterwards)
    steps = (
        (tv, sx, consistent),
        (sx, tv, inconsistent),
        (tv, tvv, inconsistent),
        (tv, sx, consistent),
    )
    for old, new, check in steps:
        graph.replace(old, new)
        check(graph)
def test_straightforward(self):
    """Check that ``NodeFinder`` tracks nodes per op across a replacement.

    The test defines its own minimal ``Type`` / ``Op`` subclasses, builds a
    small graph from them, attaches a ``NodeFinder`` feature, and verifies the
    per-op node counts reported by ``get_nodes`` before and after
    ``dot(y, z)`` is replaced by ``add(y, z)``.

    Raises:
        Exception: if the number of nodes found for an op differs from the
            expected count.
    """

    class MyType(Type):
        def __init__(self, name):
            self.name = name

        def __str__(self):
            return self.name

        def __repr__(self):
            return self.name

        def __eq__(self, other):
            # Any two `MyType` instances compare equal, whatever their name.
            return isinstance(other, MyType)

    class MyOp(Op):
        __props__ = ("nin", "name")

        def __init__(self, nin, name):
            self.nin = nin
            self.name = name

        def make_node(self, *inputs):
            def as_variable(x):
                assert isinstance(x, Variable)
                return x

            assert len(inputs) == self.nin
            inputs = list(map(as_variable, inputs))
            # `inp` rather than `input` to avoid shadowing the builtin.
            for inp in inputs:
                if not isinstance(inp.type, MyType):
                    raise Exception("Error 1")
            outputs = [MyType(self.name + "_R")()]
            return Apply(self, inputs, outputs)

        def __str__(self):
            return self.name

    sigmoid = MyOp(1, "Sigmoid")
    add = MyOp(2, "Add")
    dot = MyOp(2, "Dot")

    def MyVariable(name):
        return Variable(MyType(name), None, None)

    def inputs():
        x = MyVariable("x")
        y = MyVariable("y")
        z = MyVariable("z")
        return x, y, z

    x, y, z = inputs()
    e0 = dot(y, z)
    e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
    g = FunctionGraph([x, y, z], [e], clone=False)
    g.attach_feature(NodeFinder())
    assert hasattr(g, "get_nodes")

    def check_counts(expected_counts):
        # Compare the number of nodes `get_nodes` yields for each op against
        # its expected count.  Avoids shadowing the builtin `type`, which the
        # duplicated inline loops used to do.
        for op, expected in expected_counts:
            if len(list(g.get_nodes(op))) != expected:
                raise Exception("Expected: %i times %s" % (expected, op))

    check_counts(((add, 3), (sigmoid, 3), (dot, 2)))

    new_e0 = add(y, z)
    assert e0.owner in g.get_nodes(dot)
    assert new_e0.owner not in g.get_nodes(add)

    g.replace(e0, new_e0)
    assert e0.owner not in g.get_nodes(dot)
    assert new_e0.owner in g.get_nodes(add)

    # One `dot` node was swapped for an `add` node.
    check_counts(((add, 4), (sigmoid, 3), (dot, 1)))