示例#1
0
def test_error_infer():
    """Adding symbols that come from two different builders raises InferError."""
    from vmad.core.error import InferError

    with Builder() as outer:
        with Builder() as inner:
            lhs = outer.input('a')
            rhs = inner.input('b')
            # the operands belong to different models, so the call must fail
            with pytest.raises(InferError):
                add(x1=lhs, x2=rhs)
示例#2
0
def test_model_partial_out():
    """Check value, vjp and jvp of a model that outputs both `a` and `a + a`."""
    with Builder() as model:
        a = model.input('a')
        doubled = add(x1=a, x2=a)
        model.output(c=doubled)
        model.output(a=a)

    (a, c), tape = model.compute(
        init={'a': 3}, vout=['a', 'c'], return_tape=True)
    assert c == 6
    assert a == 3

    vjp = tape.get_vjp()

    # seed one output gradient at a time and read the gradient on 'a'
    seeds_and_expected = (
        ({'_c': 1.0, '_a': 0.0}, 2.0),
        ({'_c': 0.0, '_a': 1.0}, 1.0),
    )
    for seed, expected in seeds_and_expected:
        _a = vjp.compute(init=seed, vout='_a', monitor=print)
        assert _a == expected

    jvp = tape.get_jvp()
    a_, c_ = jvp.compute(init={'a_': 1.0}, vout=['a_', 'c_'], monitor=print)
    assert c_ == 2.0
    assert a_ == 1.0
示例#3
0
def test_model_nested():
    """Use `example` and `example.forgetful` in one model; both must agree."""
    with Builder() as model:
        a = model.input('a')
        plain = example(a, 2)
        forgetful = example.forgetful(a, 2)
        model.output(b=plain, c=forgetful)

    (b, c), tape = model.compute(
        init={'a': 1.0},
        vout=['b', 'c'],
        monitor=print,
        return_tape=True,
    )
    assert b == 4.0
    assert c == 4.0

    # backward pass: both output gradients are seeded with 1
    vjp = tape.get_vjp()
    _a = vjp.compute(init={'_b': 1.0, '_c': 1.0}, vout='_a', monitor=print)
    assert _a == 8.0

    # forward pass: unit tangent on the input
    jvp = tape.get_jvp()
    b_, c_ = jvp.compute(init={'a_': 1.0}, vout=['b_', 'c_'], monitor=print)
    assert b_ == 4.0
    assert c_ == 4.0
示例#4
0
def test_operator_list_out():
    """An operator with a list output (`split`) yields a `List` symbol usable as a model output."""
    from numpy.testing import assert_array_equal
    from vmad.core.symbol import List, ListPlaceholder

    with Builder() as model:
        a = model.input('a')

        pieces = split(x=a, axis=0, args=ListPlaceholder(2))
        # the placeholder of length 2 becomes a List symbol of the same length
        assert isinstance(pieces, List)
        assert len(pieces) == 2

        model.output(c=pieces)

    c, tape = model.compute(
        init={'a': [[1, 1], [2, 2]]},
        vout='c',
        return_tape=True,
        monitor=print,
    )
    assert_array_equal(c, [[1, 1], [2, 2]])

    vjp = tape.get_vjp()
    _a = vjp.compute(init={'_c': [[1, 1], [1, 1]]}, vout='_a', monitor=print)
    assert_array_equal(_a, [[1, 1], [1, 1]])

    jvp = tape.get_jvp()
    c_ = jvp.compute(init={'a_': [[1, 1], [1, 1]]}, vout='c_', monitor=print)
    assert_array_equal(c_, [[1, 1], [1, 1]])
示例#5
0
def test_operator_default_jvp():
    """An operator that defines only `apl` and `vjp` must still support a jvp.

    The local operator doubles its input; no `jvp` method is declared on it.
    """
    @operator
    class op:
        ain = 'x'
        aout = 'y1'

        def apl(node, x):
            return dict(y1=x * 2)

        def vjp(node, _y1):
            return dict(_x=_y1 * 2)

    with Builder() as model:
        a = model.input('a')
        doubled = op(x=a)
        model.output(c=doubled)

    c, tape = model.compute(init={'a': 3}, vout='c', return_tape=True)
    assert c == 6

    vjp = tape.get_vjp()
    _a = vjp.compute(init={'_c': 1}, vout='_a', monitor=print)
    assert _a == 2

    # no jvp was declared on `op`; the tape is still expected to provide one
    jvp = tape.get_jvp()
    c_ = jvp.compute(init={'a_': 1}, vout='c_', monitor=print)
    assert c_ == 2
示例#6
0
def test_autooperator_bind():
    """`example.bind(n=2)` gives a new operator with `n` fixed; `example` is untouched."""
    bound = example.bind(n=2)
    # the built model is not inspected; the name is rebound by the builder below
    model = example.build(n=2)

    assert 'n' in bound.hyperargs
    assert 'n' not in example.hyperargs

    with Builder() as model:
        a = model.input('a')
        via_plain = example(a, 2)
        via_bound = bound(a)
        model.output(b=via_plain, c=via_bound)

    (b, c), tape = model.compute(
        init={'a': 1.0},
        vout=['b', 'c'],
        monitor=print,
        return_tape=True,
    )

    assert b == c
    assert b == 4.0

    # only the bound operator carries a bound model
    assert example.__bound_model__ is None
    assert bound.__bound_model__ is not None

    bound.build()
示例#7
0
def test_error_many_output():
    """Declaring the same output name twice raises DuplicatedOutput.

    The first declaration is made outside the `pytest.raises` block so the
    test only passes when the *second*, duplicate declaration is the one
    that raises.
    """
    from vmad.core.error import DuplicatedOutput
    with Builder() as m:
        a = m.input('a')
        # first declaration is legitimate and must succeed
        m.output(a=a)
        with pytest.raises(DuplicatedOutput):
            m.output(a=a)
示例#8
0
def test_operator_watchpoint():
    """A `watchpoint` monitor is invoked during compute_with_vjp and compute_with_jvp."""
    from vmad.core.stdlib import watchpoint

    # one-element list so the nested callback can record that it ran
    fired = [0]

    def monitor(x):
        fired[0] = 1

    with Builder() as model:
        a = model.input('a')
        total = linalg.add(a, a)
        watchpoint(a, monitor=monitor)
        model.output(c=total)

    init = [('a', 1)]
    c, tape = model.compute(init=init, vout='c', return_tape=True)

    [c], [_a] = model.compute_with_vjp(init=init, v=[('_c', 1.0)])
    assert fired[0] == 1
    assert c == 2
    assert _a == 2.0

    # reset the flag and repeat with the forward-mode entry point
    fired[0] = 0
    c, c_ = model.compute_with_jvp(vout='c', init=init, v=[('a_', 1.0)])
    assert fired[0] == 1
    assert c == 2
    assert c_ == 2.0
示例#9
0
def test_operator_record_extra():
    """The keys returned by an operator's `rcd` decide what is stored on the tape.

    `rcd` stores the hyper-argument `p` under the name `u`, so the tape entry
    must contain `u` and must not contain `p`.
    """
    @operator
    class myrecord:
        ain = 'x'
        aout = 'y'

        def apl(node, x, p):
            return dict(y=x * p, extra=p)

        def rcd(node, x, p, y, extra):
            return dict(x=x, u=p, extra=extra)

        def vjp(node, x, _y, u):
            return _y * u

        def jvp(node, x, x_, u):
            return x_ * u

    with Builder() as model:
        a = model.input('a')
        scaled = myrecord(x=a, p=2.0)
        model.output(b=scaled)

    b, tape = model.compute(
        init={'a': 1.0}, vout='b', monitor=print, return_tape=True)

    assert b == 2.0

    recorded = tape[0].impl_kwargs
    assert 'p' not in recorded
    assert 'u' in recorded
示例#10
0
def test_error_unexpected_output():
    """Computing an output without supplying the input it needs raises ResolveError."""
    # NOTE(review): a second test with this exact name appears later in this
    # listing; if both live in the same module the later definition shadows
    # this one and it is never collected -- confirm and rename if so.
    from vmad.core.error import ResolveError

    with Builder() as model:
        a = model.input('a')
        model.output(a=a)

    # 'a' is requested but init provides nothing
    with pytest.raises(ResolveError):
        model.compute(vout='a', init={})
示例#11
0
def test_error_unexpected_output():
    """Requesting an output name the model never declared raises UnexpectedOutput."""
    from vmad.core.error import UnexpectedOutput
    with Builder() as m:
        a = m.input('a')
        m.output(a=a)

    # the model only declares 'a'; 'b' is not a known output.
    # The name below must match the import above: the previous spelling
    # `UnexpectredOutput` was undefined and raised NameError.
    with pytest.raises(UnexpectedOutput):
        m.compute(vout='b', init=dict(a=1.0))
示例#12
0
def test_model_compute_with_jvp():
    """`compute_with_jvp` returns the value and the tangent of `a + a` in one call."""
    with Builder() as model:
        a = model.input('a')
        doubled = add(x1=a, x2=a)
        model.output(b=doubled)

    b, b_ = model.compute_with_jvp(
        vout='b',
        init=[('a', 1)],
        v=[('a_', 1.0)],
    )
    assert b == 2.0
    assert b_ == 2.0
示例#13
0
def test_model_compute_with_vjp():
    """`compute_with_vjp` returns the outputs and the input gradients of `a + a`."""
    with Builder() as model:
        a = model.input('a')
        doubled = add(x1=a, x2=a)
        model.output(b=doubled)

    # both results come back as one-element sequences
    [b], [_a] = model.compute_with_vjp(
        init=[('a', 1)],
        v=[('_b', 1.0)],
    )
    assert b == 2.0
    assert _a == 2.0
示例#14
0
def test_operator_assert_true():
    """`assert_true` raises AssertionError at compute time when its predicate fails."""
    from vmad.core.stdlib import assert_true

    with Builder() as model:
        a = model.input('a')
        assert_true(a, lambda x: isinstance(x, int))
        model.output(c=a)

    # a float input violates the int-only predicate
    with pytest.raises(AssertionError):
        model.compute(vout='c', init={'a': 1.09})

    # an int input passes without raising
    model.compute(vout='c', init={'a': 1})
示例#15
0
def test_model_compute_with_gnDp():
    """`compute_with_gnDp` on `a * a` at a=1 with unit direction yields 1.0 and 4.0."""
    with Builder() as model:
        a = model.input('a')
        squared = mul(x1=a, x2=a)
        model.output(b=squared)

    # second result comes back as a one-element sequence
    b, [_a_] = model.compute_with_gnDp(
        vout='b', init=[('a', 1)], v=[('a_', 1.0)])
    assert b == 1.0
    assert _a_ == 4.0
示例#16
0
def test_operator_list_with_zero_vjp():
    """An operator fed a list argument yields zero for both vjp and jvp.

    NOTE(review): `zero_jac` is defined outside this block; the expected
    forward value 3 and the zero derivatives follow from its definition.
    """
    with Builder() as model:
        value = model.input('value')
        result = zero_jac([value])
        model.output(my_array=result)

    my_array, tape = model.compute(
        init={'value': 1}, vout='my_array', return_tape=True)
    assert my_array == 3

    vjp = tape.get_vjp()
    _value = vjp.compute(init={'_my_array': 5}, vout='_value')
    assert _value == 0

    jvp = tape.get_jvp()
    my_array_ = jvp.compute(init={'value_': 5}, vout='my_array_')
    assert my_array_ == 0
示例#17
0
def test_model_many_rewrites():
    """Rebind the same variable through repeated `x = x + x` and check value and vjp.

    This is a nasty model with many variable rewrites: each pass doubles
    `x`, so after `n` passes both the output and its gradient with respect
    to the input equal ``2 ** n``.
    """
    n = 2
    # exact in floating point for small n (4.0 when n == 2)
    expected = 2.0 ** n

    with Builder() as m:
        x = m.input('x')
        # drive the loop from n instead of a second hard-coded literal
        for _ in range(n):
            x = add(x1=x, x2=x)

        m.output(y=x)

    init = dict(x=1.0)
    y, tape = m.compute(init=init, vout='y', return_tape=True)
    assert y == expected

    vjp = tape.get_vjp()
    init = dict(_y=1.0)
    _x = vjp.compute(init=init, vout='_x', monitor=print)
    assert _x == expected
示例#18
0
def test_autooperator_as_member():
    """An `autooperator` declared as a method can read instance state via `self`."""
    class MyType(object):
        def __init__(self):
            self.n = 3

        @autooperator('x->y')
        def example_func(self, x):
            return dict(y=x * self.n)

    obj = MyType()

    # used as a node inside another model
    with Builder() as model:
        a = model.input('a')
        tripled = obj.example_func(a)
        model.output(b=tripled)

    y = model.compute(vout='b', init={'a': 1})
    assert y == 3

    # built directly into a standalone model
    standalone = obj.example_func.build()
    y = standalone.compute(vout='y', init={'x': 1.})
    assert y == 3
示例#19
0
def test_model_partial():
    """Check value, vjp and jvp of the single-output model `c = a + a`."""
    with Builder() as model:
        a = model.input('a')
        doubled = add(x1=a, x2=a)
        model.output(c=doubled)

    c, tape = model.compute(init={'a': 3}, vout='c', return_tape=True)
    assert c == 6

    vjp = tape.get_vjp()
    _a = vjp.compute(init={'_c': 1.0}, vout='_a', monitor=print)
    assert _a == 2.0

    jvp = tape.get_jvp()
    c_ = jvp.compute(init={'a_': 1.0}, vout='c_', monitor=print)
    assert c_ == 2.0
示例#20
0
def test_operator_skip_unused():
    """A node whose result feeds no output does not disturb compute, vjp or jvp.

    NOTE(review): `error` is defined outside this block; presumably it raises
    if it is ever executed -- confirm against its definition.
    """
    with Builder() as model:
        a = model.input('a')
        # result is deliberately not connected to any output
        unused = error(x=a)
        model.output(c=a)

    c, tape = model.compute(init={'a': 3}, vout='c', return_tape=True)
    assert c == 3

    vjp = tape.get_vjp()
    _a = vjp.compute(init={'_c': 0}, vout='_a', monitor=print)
    assert _a == 0

    jvp = tape.get_jvp()
    c_ = jvp.compute(init={'a_': 0}, vout='c_', monitor=print)
    assert c_ == 0
示例#21
0
def test_operator_assert_isinstance():
    """`assert_isinstance` passes values through and raises TypeError on a wrong type."""
    from vmad.core.stdlib import assert_isinstance

    with Builder() as model:
        a = model.input('a')
        assert_isinstance(a, int)
        model.output(c=a)

    int_input = [('a', 1)]
    c, tape = model.compute(init=int_input, return_tape=True, vout='c')

    # the check must not alter the value or its derivatives
    [c], [_a] = model.compute_with_vjp(init=int_input, v=[('_c', 1.0)])
    assert c == 1
    assert _a == 1.0

    c, c_ = model.compute_with_jvp(vout='c', init=int_input, v=[('a_', 1.0)])
    assert c == 1
    assert c_ == 1.0

    # a float is not an int, so compute must fail
    with pytest.raises(TypeError):
        model.compute(vout='c', init={'a': 1.09})
示例#22
0
def test_operator_zero():
    """Seeding vjp/jvp with `ZeroGradient` produces a zero result.

    NOTE(review): `error_on_grad` is defined outside this block; presumably
    it raises if its gradient code actually runs -- confirm against its
    definition.
    """
    with Builder() as model:
        a = model.input('a')
        passed = error_on_grad(x=a)
        model.output(c=passed)

    c, tape = model.compute(init={'a': 3}, vout='c', return_tape=True)
    assert c == 3

    vjp = tape.get_vjp()
    _a = vjp.compute(init={'_c': ZeroGradient}, vout='_a', monitor=print)
    assert _a == 0

    jvp = tape.get_jvp()
    c_ = jvp.compute(init={'a_': ZeroGradient}, vout='c_', monitor=print)
    assert c_ == 0
示例#23
0
def test_model_attr():
    """A value derived through `Symbol.eval` (here `b.size`) can be an operand."""
    import numpy

    with Builder() as model:
        a, b = model.input('a', 'b')
        # extra node on b that does not feed the output
        d = add(x1=b, x2=1)
        total = add(x1=a, x2=b.eval(lambda b: b.size))
        model.output(c=total)

    inputs = {'a': 2, 'b': numpy.array([2,])}
    c, tape = model.compute(init=inputs, vout='c', return_tape=True)
    # b has one element, so c = 2 + 1
    assert c == 3

    vjp = tape.get_vjp()
    _a = vjp.compute(init={'_c': 1.0}, vout='_a', monitor=print)
    assert _a == 1.0

    jvp = tape.get_jvp()
    c_ = jvp.compute(init={'a_': 1.0}, vout='c_', monitor=print)
    assert c_ == 1.0
示例#24
0
def test_operator_multi_out():
    """An operator with two outputs (`y1 = x`, `y2 = 2 * x`) unpacks into two symbols."""
    @operator
    class op:
        ain = 'x'
        # a tuple keeps the output order explicit (needed on python 2.x)
        aout = 'y1', 'y2'

        def apl(node, x):
            return dict(y1=x, y2=2 * x)

        def vjp(node, _y1, _y2):
            return dict(_x=_y1 + 2 * _y2)

        def jvp(node, x_):
            return dict(y1_=x_, y2_=2 * x_)

    with Builder() as model:
        a = model.input('a')
        first, second = op(x=a)
        model.output(c=first, d=second)

    (c, d), tape = model.compute(
        init={'a': 3}, vout=('c', 'd'), return_tape=True)
    assert c == 3
    assert d == 6

    # both output gradients seeded with 1: 1 + 2 * 1
    vjp = tape.get_vjp()
    _a = vjp.compute(init={'_c': 1, '_d': 1}, vout='_a', monitor=print)
    assert _a == 3

    jvp = tape.get_jvp()
    c_, d_ = jvp.compute(init={'a_': 1}, vout=('c_', 'd_'), monitor=print)
    assert c_ == 1
    assert d_ == 2
示例#25
0
def test_operator_defaults():
    """Keyword defaults on `apl`, `vjp` and `jvp` are used when the caller omits them.

    Each implementation asserts that it received the default `False`.
    """
    @operator
    class with_defaults:
        ain = 'x'
        aout = 'y'

        def apl(node, x, defaults=False):
            assert defaults == False
            return x

        def vjp(node, _y, defaults=False):
            assert defaults == False
            return _y

        def jvp(node, x_, defaults=False):
            assert defaults == False
            return x_

    with Builder() as model:
        a = model.input('a')
        # `defaults` is intentionally not passed
        passed = with_defaults(x=a)
        model.output(c=passed)

    c, tape = model.compute(init={'a': 3}, vout='c', return_tape=True)
    assert c == 3

    vjp = tape.get_vjp()
    _a = vjp.compute(init={'_c': 1}, vout='_a', monitor=print)
    assert _a == 1

    jvp = tape.get_jvp()
    c_ = jvp.compute(init={'a_': 1}, vout='c_', monitor=print)
    assert c_ == 1
示例#26
0
def test_model_unused():
    """Inputs that do not reach the (constant) output get zero gradients."""
    with Builder() as model:
        a, b = model.input('a', 'b')
        # `a` is referenced twice with a dependency in between. This used to
        # trigger a last_ref problem in autodiff: the line is not executed by
        # the context, so last_ref is not True for the last ref on the tape.
        d = (a + b) + a
        model.output(c=1.0)

    c, tape = model.compute(
        init={'a': 3, 'b': 4}, vout='c', return_tape=True)
    assert c == 1.0

    vjp = tape.get_vjp()
    _a, _b = vjp.compute(init={'_c': 1.0}, vout=['_a', '_b'], monitor=print)
    assert _a == 0
    assert _b == 0

    jvp = tape.get_jvp()
    c_ = jvp.compute(init={'a_': 1.0, 'b_': 1.0}, vout='c_', monitor=print)
    assert c_ == 0
示例#27
0
def test_autooperator_precompute2():
    """`example_func.precompute(...)` gives an operator with a bound tape; the original has none."""
    precomputed = example_func.precompute(n=2, x=1)
    # the built model is not inspected; the name is rebound by the builder below
    model = example_func.build(n=2)

    with Builder() as model:
        a = model.input('a')
        via_plain = example_func(a, 2)
        via_precomputed = precomputed(a, n=2)
        model.output(b=via_plain, c=via_precomputed)

    (b, c), tape = model.compute(
        init={'a': 1.0},
        vout=['b', 'c'],
        monitor=print,
        return_tape=True,
    )

    assert b == c
    assert b == 4.0

    # only the precomputed operator carries a bound tape
    assert example_func.__bound_tape__ is None
    assert precomputed.__bound_tape__ is not None

    precomputed.build()
示例#28
0
def test_operator_list_in():
    """An operator taking a list of symbols (`stack`) supports compute, vjp and jvp."""
    from numpy.testing import assert_array_equal

    with Builder() as model:
        a = model.input('a')
        # the same symbol is stacked three times along axis 1
        stacked = stack(args=[a, a, a], axis=1)
        model.output(c=stacked)

    c, tape = model.compute(init={'a': [1, 2]}, vout='c', return_tape=True)
    assert_array_equal(c, [[1, 1, 1], [2, 2, 2]])

    # each element of `a` appears three times, so unit seeds sum to 3
    vjp = tape.get_vjp()
    _a = vjp.compute(
        init={'_c': [[1, 1, 1], [1, 1, 1]]}, vout='_a', monitor=print)
    assert_array_equal(_a, [3, 3])

    jvp = tape.get_jvp()
    c_ = jvp.compute(init={'a_': [1, 1]}, vout='c_', monitor=print)
    assert_array_equal(c_, [[1, 1, 1], [1, 1, 1]])
示例#29
0
def test_error_missing():
    """Calling `add` without its `x1` argument raises MissingArgument."""
    from vmad.core.error import MissingArgument

    with Builder() as model:
        a = model.input('a')
        # only x2 is supplied
        with pytest.raises(MissingArgument):
            add(x2=1)
示例#30
0
def test_error_overwrite():
    """Passing an existing symbol as the `y` argument of `add` raises OverwritePrecaution."""
    from vmad.core.error import OverwritePrecaution

    with Builder() as model:
        a = model.input('a')
        # `a` is used as both operands and as `y`
        with pytest.raises(OverwritePrecaution):
            add(x1=a, x2=a, y=a)