def test_misc_2():
    """Replace a transpose view of ``x`` by ``x`` under an in-place add.

    The graph output is ``add_in_place(x, transpose_view(x))``.  The
    ``inconsistent`` helper is applied before and after the view is
    replaced by ``x`` itself.
    """
    x, y, _ = inputs()
    x_view = transpose_view(x)
    result = add_in_place(x, x_view)
    graph = Env([x, y], [result], False)
    inconsistent(graph)

    graph.replace(x_view, x)
    inconsistent(graph)
def test_value_repl_2():
    """Replace ``sigmoid(y)`` by a transpose view of a constant.

    The graph output is ``add_in_place(x, sigmoid(y))``.  The ``consistent``
    helper is applied before and after the second operand is replaced by
    ``transpose_view(MyConstant("abc"))``.
    """
    x, y, _ = inputs()
    sig_y = sigmoid(y)
    result = add_in_place(x, sig_y)
    graph = Env([x, y], [result], False)
    consistent(graph)

    constant_view = transpose_view(MyConstant("abc"))
    graph.replace(sig_y, constant_view)
    consistent(graph)
def test_usage_loop_through_views_2():
    """Cut a chain of views off from the ``sigmoid`` it was built on.

    The graph output is ``dot(add_in_place(x, y), transpose_view(e0))``
    where ``e0`` is two transpose views stacked on ``sigmoid(x)``.  After
    ``e0`` is replaced by ``x`` the graph is expected to be inconsistent.
    """
    x, y, z = inputs()

    # Two stacked views over sigmoid(x).
    sig_x = sigmoid(x)
    inner_view = transpose_view(sig_x)
    view_chain = transpose_view(inner_view)

    updated = add_in_place(x, y)
    outer_view = transpose_view(view_chain)
    result = dot(updated, outer_view)

    graph = Env([x, y, z], [result])
    # because sigmoid can do the copy
    consistent(graph)

    graph.replace(view_chain, x)
    # we cut off the path to the sigmoid
    inconsistent(graph)
def test_indirect_2():
    """Replace a view of ``x`` by ``add(view, y)`` and expect consistency.

    The graph output is ``dot(sigmoid(add_in_place(x, y)), transpose_view(x))``.
    The ``inconsistent`` helper is applied first; after the view is replaced
    by ``add(view, y)`` the ``consistent`` helper is applied.
    """
    x, y, z = inputs()
    x_view = transpose_view(x)

    updated = add_in_place(x, y)
    squashed = sigmoid(updated)
    result = dot(squashed, x_view)

    graph = Env([x, y, z], [result], False)
    inconsistent(graph)

    view_plus_y = add(x_view, y)
    graph.replace(x_view, view_plus_y)
    consistent(graph)
def test_repair_destroy_path():
    """Two in-place adds on view chains of ``x``, before and after a rewrite.

    ``e1`` is two transpose views over ``x`` and ``e2`` is two more over
    ``e1``.  Both are operands of ``add_in_place`` nodes that are graph
    outputs.  The ``inconsistent`` helper is applied before and after
    ``e2`` is replaced by a single ``transpose_view(x)``.
    """
    x, y, z = inputs()

    short_chain = transpose_view(transpose_view(x))
    long_chain = transpose_view(transpose_view(short_chain))
    out_long = add_in_place(long_chain, y)
    out_short = add_in_place(short_chain, z)

    graph = Env([x, y, z], [out_long, out_short], False)
    inconsistent(graph)

    single_view = transpose_view(x)
    graph.replace(long_chain, single_view)
    inconsistent(graph)
Example #6
0
    def test_replace_bad_state(self):
        """Replacing a graph input by an unrelated variable raises ``MissingInputError``.

        ``var0`` is a fresh variable that is not among the inputs passed to
        the ``FunctionGraph``; ``fg.replace(var1, var0)`` is expected to
        raise.
        """
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        # Built outside the ``pytest.raises`` block so that only the
        # ``replace`` call itself can satisfy the expected exception.
        var0 = MyVariable("var0")

        with pytest.raises(MissingInputError):
            # FIXME TODO XXX: This breaks the state of the `FunctionGraph`,
            # because it doesn't check for validity of the replacement *first*.
            fg.replace(var1, var0, verbose=True)
Example #7
0
def test_misc():
    """Optimize, replace and revert a graph, checking its string form each time.

    Starts from four stacked transpose views of ``x``, lets a
    ``PatternOptimizer`` collapse pairs of views, performs two replacements
    and finally reverts to the checkpoint taken at the start.
    """
    x, y, z = inputs()

    # Stack four transpose views on top of x.
    stacked = x
    for _ in range(4):
        stacked = transpose_view(stacked)

    graph = Env([x, y, z], [stacked])
    consistent(graph)
    checkpoint = graph.checkpoint()

    double_view = (transpose_view, (transpose_view, 'x'))
    optimizer = PatternOptimizer(double_view, 'x')
    optimizer.optimize(graph)
    assert str(graph) == "[x]"

    x_plus_y = add(x, y)
    graph.replace_validate(x, x_plus_y)
    assert str(graph) == "[Add(x, y)]"

    in_place_sum = add_in_place(x, y)
    x_view = transpose_view(x)
    graph.replace(x_plus_y, dot(in_place_sum, x_view))
    assert str(graph) == "[Dot(AddInPlace(x, y), TransposeView(x))]"
    inconsistent(graph)

    graph.revert(checkpoint)
    consistent(graph)
    assert str(graph) == "[TransposeView(TransposeView(TransposeView(TransposeView(x))))]"
Example #8
0
 def test_straightforward(self):
     """``NodeFinder`` tracks nodes per op before and after a replacement.

     After ``NodeFinder`` is added with ``g.extend``, ``g.get_nodes(op)`` is
     expected to yield 3 ``add``, 3 ``sigmoid`` and 2 ``dot`` nodes.  Once
     ``dot(y, z)`` is replaced by ``add(y, z)`` the expected counts are
     4, 3 and 1.
     """
     x, y, z = inputs()
     e0 = dot(y, z)
     e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
     g = Env([x, y, z], [e])
     g.extend(NodeFinder())
     assert hasattr(g, 'get_nodes')

     def check_counts(expected):
         # ``op``/``num`` avoid shadowing the builtin ``type`` and the
         # graph input ``x`` that the original loop variables reused.
         for op, num in expected:
             if len(list(g.get_nodes(op))) != num:
                 raise Exception("Expected: %i times %s" % (num, op))

     check_counts(((add, 3), (sigmoid, 3), (dot, 2)))

     new_e0 = add(y, z)
     assert e0.owner in g.get_nodes(dot)
     assert new_e0.owner not in g.get_nodes(add)
     g.replace(e0, new_e0)
     assert e0.owner not in g.get_nodes(dot)
     assert new_e0.owner in g.get_nodes(add)

     check_counts(((add, 4), (sigmoid, 3), (dot, 1)))
Example #9
0
    def test_replace_test_value(self):
        """A replacement whose test value has a different shape is rejected.

        ``fg.replace`` is expected to raise an ``AssertionError`` whose
        message matches ``"The replacement.*"`` when the new variable's
        ``tag.test_value`` shape differs from the replaced variable's.
        """
        in_a = MyVariable("var1")
        in_a.tag.test_value = 1
        in_b = MyVariable("var2")
        in_b.tag.test_value = 2

        first_out = op1(in_b, in_a)
        target = op2(first_out, in_b)
        target.tag.test_value = np.array([1, 2])
        second_out = op3(target, in_b, in_b)

        fg = FunctionGraph([in_a, in_b], [first_out, second_out], clone=False)

        substitute = op3()
        substitute.tag.test_value = np.array(0)

        # Precondition: the two test values really do differ in shape.
        substitute_shape = substitute.tag.test_value.shape
        target_shape = target.tag.test_value.shape
        assert substitute_shape != target_shape

        with pytest.raises(AssertionError, match="The replacement.*"):
            fg.replace(target, substitute)
Example #10
0
    def test_replace(self):
        """``FunctionGraph.replace`` rejects a mismatched type and ``replace_all`` rewires clients.

        First, replacing ``var3`` by a ``MyVariable2`` is expected to raise
        ``BadOptimization``.  Then ``var3`` is replaced by ``var1`` and the
        graph's variables, apply nodes and the inputs of ``var4``'s owner
        are checked.
        """
        var1 = MyVariable("var1")
        var2 = MyVariable("var2")
        var3 = op1(var2, var1)
        var4 = op2(var3, var2)
        var5 = op3(var4, var2, var2)
        fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

        # Built outside the ``pytest.raises`` block so that only the
        # ``replace`` call itself can satisfy the expected exception.
        var0 = MyVariable2("var0")

        with pytest.raises(BadOptimization):
            # The types don't match and one cannot be converted to the other
            fg.replace(var3, var0)

        # Test a basic replacement
        fg.replace_all([(var3, var1)])
        assert var3 not in fg.variables
        assert fg.apply_nodes == {var4.owner, var5.owner}
        assert var4.owner.inputs == [var1, var2]
Example #11
0
    def test_straightforward(self):
        """``NodeFinder`` tracks nodes per op before and after a replacement.

        After ``NodeFinder`` is attached, ``g.get_nodes(op)`` is expected to
        yield 3 ``add``, 3 ``sigmoid`` and 2 ``dot`` nodes.  Once ``dot(y, z)``
        is replaced by ``add(y, z)`` the expected counts are 4, 3 and 1.
        """
        x, y, z = inputs()
        e0 = dot(y, z)
        e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
        g = FunctionGraph([x, y, z], [e], clone=False)
        g.attach_feature(NodeFinder())

        assert hasattr(g, 'get_nodes')

        def check_counts(expected):
            # ``op``/``num`` avoid shadowing the builtin ``type`` and the
            # graph input ``x`` that the original loop variables reused.
            for op, num in expected:
                if len(list(g.get_nodes(op))) != num:
                    raise Exception("Expected: %i times %s" % (num, op))

        check_counts(((add, 3), (sigmoid, 3), (dot, 2)))

        new_e0 = add(y, z)
        assert e0.owner in g.get_nodes(dot)
        assert new_e0.owner not in g.get_nodes(add)
        g.replace(e0, new_e0)
        assert e0.owner not in g.get_nodes(dot)
        assert new_e0.owner in g.get_nodes(add)

        check_counts(((add, 4), (sigmoid, 3), (dot, 1)))
Example #12
0
def test_aliased_inputs_replacement():
    """Swap the second operand of an in-place add between views and a sigmoid.

    The graph output starts as ``add_in_place(x, transpose_view(x))``.  The
    operand is replaced in turn by ``sigmoid(x)``, the view again, a view of
    the view, and ``sigmoid(x)`` again.  The graph is expected to be
    consistent after each ``sigmoid`` replacement and inconsistent otherwise.
    """
    x, y, _ = inputs()
    view = transpose_view(x)
    view_of_view = transpose_view(view)
    sig_x = sigmoid(x)
    result = add_in_place(x, view)

    graph = Env([x, y], [result], False)
    inconsistent(graph)

    graph.replace(view, sig_x)
    consistent(graph)

    graph.replace(sig_x, view)
    inconsistent(graph)

    graph.replace(view, view_of_view)
    inconsistent(graph)

    graph.replace(view, sig_x)
    consistent(graph)
Example #13
0
    def test_straightforward(self):
        """``NodeFinder`` tracks nodes per op before and after a replacement.

        The test defines its own minimal ``MyType``/``MyOp`` and three ops
        (``Sigmoid``, ``Add``, ``Dot``).  After ``NodeFinder`` is attached,
        ``g.get_nodes(op)`` is expected to yield 3 ``add``, 3 ``sigmoid`` and
        2 ``dot`` nodes.  Once ``dot(y, z)`` is replaced by ``add(y, z)`` the
        expected counts are 4, 3 and 1.
        """

        class MyType(Type):
            """Named type; any two ``MyType`` instances compare equal."""

            def __init__(self, name):
                self.name = name

            def __str__(self):
                return self.name

            def __repr__(self):
                return self.name

            # NOTE(review): ``__eq__`` is defined without ``__hash__``; in
            # Python 3 that normally leaves instances unhashable -- confirm
            # this is intended or handled by the ``Type`` base class.
            def __eq__(self, other):
                return isinstance(other, MyType)

        class MyOp(Op):
            """Op taking exactly ``nin`` ``MyType`` variables and producing one output."""

            __props__ = ("nin", "name")

            def __init__(self, nin, name):
                self.nin = nin
                self.name = name

            def make_node(self, *inputs):
                def as_variable(x):
                    assert isinstance(x, Variable)
                    return x

                assert len(inputs) == self.nin
                inputs = list(map(as_variable, inputs))
                # ``inp`` rather than ``input`` to avoid shadowing the builtin.
                for inp in inputs:
                    if not isinstance(inp.type, MyType):
                        raise Exception("Error 1")
                outputs = [MyType(self.name + "_R")()]
                return Apply(self, inputs, outputs)

            def __str__(self):
                return self.name

        sigmoid = MyOp(1, "Sigmoid")
        add = MyOp(2, "Add")
        dot = MyOp(2, "Dot")

        def MyVariable(name):
            return Variable(MyType(name), None, None)

        def inputs():
            x = MyVariable("x")
            y = MyVariable("y")
            z = MyVariable("z")
            return x, y, z

        x, y, z = inputs()
        e0 = dot(y, z)
        e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
        g = FunctionGraph([x, y, z], [e], clone=False)
        g.attach_feature(NodeFinder())

        assert hasattr(g, "get_nodes")

        def check_counts(expected):
            # ``op`` rather than ``type`` to avoid shadowing the builtin.
            for op, num in expected:
                if len(list(g.get_nodes(op))) != num:
                    raise Exception("Expected: %i times %s" % (num, op))

        check_counts(((add, 3), (sigmoid, 3), (dot, 2)))

        new_e0 = add(y, z)
        assert e0.owner in g.get_nodes(dot)
        assert new_e0.owner not in g.get_nodes(add)
        g.replace(e0, new_e0)
        assert e0.owner not in g.get_nodes(dot)
        assert new_e0.owner in g.get_nodes(add)

        check_counts(((add, 4), (sigmoid, 3), (dot, 1)))