def test_error_infer():
    """Combining symbols that belong to two different Builders must raise InferError."""
    from vmad.core.error import InferError

    with Builder() as outer:
        with Builder() as inner:
            sym_outer = outer.input('a')
            sym_inner = inner.input('b')
            with pytest.raises(InferError):
                add(x1=sym_outer, x2=sym_inner)
def test_model_partial_out():
    """A model exposing both its input and a derived value as outputs.

    Checks the forward pass, the VJP seeded on each output separately,
    and the JVP over both outputs at once.
    """
    with Builder() as model:
        x = model.input('a')
        doubled = add(x1=x, x2=x)
        model.output(c=doubled)
        model.output(a=x)

    (a, c), tape = model.compute(init={'a': 3}, vout=['a', 'c'], return_tape=True)
    assert c == 6
    assert a == 3

    vjp = tape.get_vjp()
    # test two outputs individually: seed one cotangent at a time
    _a = vjp.compute(init={'_c': 1.0, '_a': 0.0}, vout='_a', monitor=print)
    assert _a == 2.0
    _a = vjp.compute(init={'_c': 0.0, '_a': 1.0}, vout='_a', monitor=print)
    assert _a == 1.0

    jvp = tape.get_jvp()
    a_, c_ = jvp.compute(init={'a_': 1.0}, vout=['a_', 'c_'], monitor=print)
    assert c_ == 2.0
    assert a_ == 1.0
def test_model_nested():
    """An autooperator and its `forgetful` variant must agree in value, VJP and JVP."""
    with Builder() as model:
        x = model.input('a')
        regular = example(x, 2)
        forgetful = example.forgetful(x, 2)
        model.output(b=regular, c=forgetful)

    (b, c), tape = model.compute(
        init={'a': 1.0}, vout=['b', 'c'], monitor=print, return_tape=True,
    )
    assert b == 4.0
    assert c == 4.0

    _a = tape.get_vjp().compute(
        init={'_b': 1.0, '_c': 1.0}, vout='_a', monitor=print,
    )
    assert _a == 8.0

    b_, c_ = tape.get_jvp().compute(
        init={'a_': 1.0}, vout=['b_', 'c_'], monitor=print,
    )
    assert b_ == 4.0
    assert c_ == 4.0
def test_operator_list_out():
    """An operator whose output is a List of symbols (split) through forward, VJP and JVP."""
    from numpy.testing import assert_array_equal
    from vmad.core.symbol import ListPlaceholder
    from vmad.core.symbol import List

    with Builder() as model:
        x = model.input('a')
        pieces = split(x=x, axis=0, args=ListPlaceholder(2))
        # the placeholder of length 2 yields a List of two symbols
        assert isinstance(pieces, List)
        assert len(pieces) == 2
        model.output(c=pieces)

    c, tape = model.compute(
        init={'a': [[1, 1], [2, 2]]}, vout='c', return_tape=True, monitor=print,
    )
    assert_array_equal(c, [[1, 1], [2, 2]])

    _a = tape.get_vjp().compute(
        init={'_c': [[1, 1], [1, 1]]}, vout='_a', monitor=print,
    )
    assert_array_equal(_a, [[1, 1], [1, 1]])

    c_ = tape.get_jvp().compute(
        init={'a_': [[1, 1], [1, 1]]}, vout='c_', monitor=print,
    )
    assert_array_equal(c_, [[1, 1], [1, 1]])
def test_operator_default_jvp():
    """An operator that defines only apl and vjp must still be usable through the JVP tape."""
    @operator
    class op:
        ain = 'x'
        aout = 'y1'

        def apl(node, x):
            return dict(y1=x * 2)

        def vjp(node, _y1):
            return dict(_x=_y1 * 2)

    with Builder() as model:
        x = model.input('a')
        y = op(x=x)
        model.output(c=y)

    c, tape = model.compute(init={'a': 3}, vout='c', return_tape=True)
    assert c == 6

    _a = tape.get_vjp().compute(init={'_c': 1}, vout='_a', monitor=print)
    assert _a == 2

    # no jvp method was given on `op`; this exercises the fallback
    c_ = tape.get_jvp().compute(init={'a_': 1}, vout='c_', monitor=print)
    assert c_ == 2
def test_autooperator_bind():
    """bind() returns a new autooperator with hyperargs fixed, leaving the original untouched."""
    op1 = example.bind(n=2)
    prebuilt = example.build(n=2)  # smoke-test build(); the result is not inspected
    assert 'n' in op1.hyperargs
    assert 'n' not in example.hyperargs

    with Builder() as model:
        x = model.input('a')
        unbound_out = example(x, 2)
        bound_out = op1(x)
        model.output(b=unbound_out, c=bound_out)

    (b, c), tape = model.compute(
        init={'a': 1.0}, vout=['b', 'c'], monitor=print, return_tape=True,
    )
    assert b == c
    assert b == 4.0
    assert example.__bound_model__ is None
    assert op1.__bound_model__ is not None
    op1.build()
def test_error_many_output():
    """Declaring the same output name twice must raise DuplicatedOutput."""
    from vmad.core.error import DuplicatedOutput

    with Builder() as model:
        x = model.input('a')
        with pytest.raises(DuplicatedOutput):
            model.output(a=x)
            model.output(a=x)
def test_operator_watchpoint():
    """watchpoint must call its monitor on plain, VJP and JVP evaluation paths."""
    from vmad.core.stdlib import watchpoint

    fired = [0]

    def monitor(x):
        fired[0] = 1

    with Builder() as m:
        a = m.input('a')
        b = linalg.add(a, a)
        watchpoint(a, monitor=monitor)
        m.output(c=b)

    init = [('a', 1)]

    c, tape = m.compute(init=init, vout='c', return_tape=True)
    assert fired[0] == 1
    assert c == 2

    # Reset the flag before each evaluation path. Otherwise the plain compute
    # above has already set it, and the VJP check would pass even if
    # compute_with_vjp never invoked the monitor.
    fired[0] = 0
    [c], [_a] = m.compute_with_vjp(init=init, v=[('_c', 1.0)])
    assert fired[0] == 1
    assert c == 2
    assert _a == 2.0

    fired[0] = 0
    c, c_ = m.compute_with_jvp(vout='c', init=init, v=[('a_', 1.0)])
    assert fired[0] == 1
    assert c == 2
    assert c_ == 2.0
def test_operator_record_extra():
    """rcd decides what is stored on the tape: only recorded names appear in impl_kwargs."""
    # assert used extra args are recorded on the tape
    @operator
    class myrecord:
        ain = 'x'
        aout = 'y'

        def apl(node, x, p):
            return dict(y=x * p, extra=p)

        def rcd(node, x, p, y, extra):
            return dict(x=x, u=p, extra=extra)

        def vjp(node, x, _y, u):
            return _y * u

        def jvp(node, x, x_, u):
            return x_ * u

    with Builder() as model:
        x = model.input('a')
        y = myrecord(x=x, p=2.0)
        model.output(b=y)

    b, tape = model.compute(init={'a': 1.0}, vout='b', monitor=print, return_tape=True)
    assert b == 2.0

    recorded = tape[0].impl_kwargs
    # `p` was renamed to `u` by rcd, so only `u` should be recorded
    assert 'p' not in recorded
    assert 'u' in recorded
def test_error_unresolved_input():
    """Computing without supplying the model's declared input 'a' must raise ResolveError.

    This test was previously also named ``test_error_unexpected_output``, so
    the later definition of that name shadowed it and pytest never collected it.
    """
    from vmad.core.error import ResolveError
    with Builder() as m:
        a = m.input('a')
        m.output(a=a)
    with pytest.raises(ResolveError):
        m.compute(vout='a', init={})
def test_error_unexpected_output():
    """Requesting an output name ('b') the model never declared must raise UnexpectedOutput."""
    from vmad.core.error import UnexpectedOutput
    with Builder() as m:
        a = m.input('a')
        m.output(a=a)
    # Previously spelled `UnexpectredOutput`, which raised NameError before
    # pytest.raises could even be entered.
    with pytest.raises(UnexpectedOutput):
        m.compute(vout='b', init=dict(a=1.0))
def test_model_compute_with_jvp():
    """compute_with_jvp returns the output value together with its tangent."""
    with Builder() as model:
        x = model.input('a')
        model.output(b=add(x1=x, x2=x))

    inputs = [('a', 1)]
    b, b_ = model.compute_with_jvp(vout='b', init=inputs, v=[('a_', 1.0)])
    assert b == 2.0
    assert b_ == 2.0
def test_model_compute_with_vjp():
    """compute_with_vjp returns the list of outputs and the list of input gradients."""
    with Builder() as model:
        x = model.input('a')
        model.output(b=add(x1=x, x2=x))

    inputs = [('a', 1)]
    [b], [_a] = model.compute_with_vjp(init=inputs, v=[('_b', 1.0)])
    assert b == 2.0
    assert _a == 2.0
def test_operator_assert_true():
    """assert_true raises AssertionError when its predicate fails and is transparent otherwise."""
    from vmad.core.stdlib import assert_true
    with Builder() as m:
        a = m.input('a')
        assert_true(a, lambda x: isinstance(x, int))
        m.output(c=a)

    # 1.09 is a float, so the isinstance(x, int) predicate fails.
    with pytest.raises(AssertionError):
        m.compute(vout='c', init=dict(a=1.09))

    # An int satisfies the predicate; the value must pass through unchanged.
    # (Previously this result was computed but never checked.)
    c = m.compute(vout='c', init=dict(a=1))
    assert c == 1
def test_model_compute_with_gnDp():
    """compute_with_gnDp on b = a*a returns the value and the J^T J-vector product."""
    with Builder() as model:
        x = model.input('a')
        model.output(b=mul(x1=x, x2=x))

    inputs = [('a', 1)]
    b, [_a_] = model.compute_with_gnDp(vout='b', init=inputs, v=[('a_', 1.0)])
    assert b == 1.0
    assert _a_ == 4.0
def test_operator_list_with_zero_vjp():
    """An operator taking a list input whose gradients are zero in both VJP and JVP."""
    with Builder() as model:
        value = model.input('value')
        model.output(my_array=zero_jac([value]))

    my_array, tape = model.compute(init={'value': 1}, vout='my_array', return_tape=True)
    assert my_array == 3

    _value = tape.get_vjp().compute(init={'_my_array': 5}, vout='_value')
    assert _value == 0

    my_array_ = tape.get_jvp().compute(init={'value_': 5}, vout='my_array_')
    assert my_array_ == 0
def test_model_many_rewrites():
    """Rebinding one Python name across repeated ops must still differentiate correctly."""
    # This is a nasty model with many variable rewrites: `x` is rebound to a
    # fresh symbol on every iteration. `n` controls how many rewrites occur;
    # it was previously declared but ignored in favour of a hard-coded 2.
    n = 2
    with Builder() as m:
        x = m.input('x')
        for _ in range(n):
            x = add(x1=x, x2=x)
        m.output(y=x)

    # Each iteration doubles x, so y = 2**n * x and dy/dx = 2**n.
    expected = 2.0 ** n

    y, tape = m.compute(init=dict(x=1.0), vout='y', return_tape=True)
    assert y == expected

    vjp = tape.get_vjp()
    _x = vjp.compute(init=dict(_y=1.0), vout='_x', monitor=print)
    assert _x == expected
def test_autooperator_as_member():
    """An autooperator defined as a method sees `self` both in a Builder and when built alone."""
    class MyType(object):
        def __init__(self):
            self.n = 3

        @autooperator('x->y')
        def example_func(self, x):
            return dict(y=x * self.n)

    instance = MyType()

    with Builder() as model:
        x = model.input('a')
        model.output(b=instance.example_func(x))

    result = model.compute(vout='b', init={'a': 1})
    assert result == 3

    standalone = instance.example_func.build()
    result = standalone.compute(vout='y', init={'x': 1.})
    assert result == 3
def test_model_partial():
    """Forward, VJP and JVP of c = a + a."""
    with Builder() as model:
        x = model.input('a')
        model.output(c=add(x1=x, x2=x))

    c, tape = model.compute(init={'a': 3}, vout='c', return_tape=True)
    assert c == 6

    _a = tape.get_vjp().compute(init={'_c': 1.0}, vout='_a', monitor=print)
    assert _a == 2.0

    c_ = tape.get_jvp().compute(init={'a_': 1.0}, vout='c_', monitor=print)
    assert c_ == 2.0
def test_operator_skip_unused():
    """A node whose result feeds no output must not break forward, VJP or JVP.

    NOTE(review): `error` is presumably an operator that fails when executed;
    the test relies on unused nodes being skipped. Confirm against its definition.
    """
    with Builder() as model:
        x = model.input('a')
        unused = error(x=x)
        model.output(c=x)

    c, tape = model.compute(init={'a': 3}, vout='c', return_tape=True)
    assert c == 3

    _a = tape.get_vjp().compute(init={'_c': 0}, vout='_a', monitor=print)
    assert _a == 0

    c_ = tape.get_jvp().compute(init={'a_': 0}, vout='c_', monitor=print)
    assert c_ == 0
def test_operator_assert_isinstance():
    """assert_isinstance is transparent for matching types and raises TypeError otherwise."""
    from vmad.core.stdlib import assert_isinstance

    with Builder() as model:
        x = model.input('a')
        assert_isinstance(x, int)
        model.output(c=x)

    inputs = [('a', 1)]
    c, tape = model.compute(init=inputs, return_tape=True, vout='c')

    [c], [_a] = model.compute_with_vjp(init=inputs, v=[('_c', 1.0)])
    assert c == 1
    assert _a == 1.0

    c, c_ = model.compute_with_jvp(vout='c', init=inputs, v=[('a_', 1.0)])
    assert c == 1
    assert c_ == 1.0

    # a float input does not satisfy isinstance(x, int)
    with pytest.raises(TypeError):
        c = model.compute(vout='c', init={'a': 1.09})
def test_operator_zero():
    """Seeding VJP/JVP with ZeroGradient must yield zero gradients.

    NOTE(review): `error_on_grad` presumably fails if its gradient is actually
    evaluated, so this checks that ZeroGradient short-circuits it. Confirm
    against its definition.
    """
    with Builder() as model:
        x = model.input('a')
        model.output(c=error_on_grad(x=x))

    c, tape = model.compute(init={'a': 3}, vout='c', return_tape=True)
    assert c == 3

    _a = tape.get_vjp().compute(init={'_c': ZeroGradient}, vout='_a', monitor=print)
    assert _a == 0

    c_ = tape.get_jvp().compute(init={'a_': ZeroGradient}, vout='c_', monitor=print)
    assert c_ == 0
def test_model_attr():
    """Symbol.eval lets a model use a non-differentiable attribute (b.size) of an input."""
    import numpy

    with Builder() as model:
        x, arr = model.input('a', 'b')
        unused_node = add(x1=arr, x2=1)  # its result is never used
        total = add(x1=x, x2=arr.eval(lambda b: b.size))
        model.output(c=total)

    c, tape = model.compute(
        init={'a': 2, 'b': numpy.array([2,])}, vout='c', return_tape=True,
    )
    assert c == 3

    _a = tape.get_vjp().compute(init={'_c': 1.0}, vout='_a', monitor=print)
    assert _a == 1.0

    c_ = tape.get_jvp().compute(init={'a_': 1.0}, vout='c_', monitor=print)
    assert c_ == 1.0
def test_operator_multi_out():
    """An operator with two outputs through forward, VJP and JVP."""
    @operator
    class op:
        ain = 'x'
        # aout is given as a tuple so the output order is preserved
        # (required on Python 2.x).
        aout = 'y1', 'y2'

        def apl(node, x):
            return dict(y1=x, y2=2 * x)

        def vjp(node, _y1, _y2):
            return dict(_x=_y1 + 2 * _y2)

        def jvp(node, x_):
            return dict(y1_=x_, y2_=2 * x_)

    with Builder() as model:
        x = model.input('a')
        first, second = op(x=x)
        model.output(c=first, d=second)

    (c, d), tape = model.compute(init={'a': 3}, vout=('c', 'd'), return_tape=True)
    assert c == 3
    assert d == 6

    _a = tape.get_vjp().compute(init={'_c': 1, '_d': 1}, vout='_a', monitor=print)
    assert _a == 3

    c_, d_ = tape.get_jvp().compute(init={'a_': 1}, vout=('c_', 'd_'), monitor=print)
    assert c_ == 1
    assert d_ == 2
def test_operator_defaults():
    """Default keyword arguments on apl/vjp/jvp are honoured when the caller omits them."""
    @operator
    class with_defaults:
        ain = 'x'
        aout = 'y'

        def apl(node, x, defaults=False):
            assert defaults == False
            return x

        def vjp(node, _y, defaults=False):
            assert defaults == False
            return _y

        def jvp(node, x_, defaults=False):
            assert defaults == False
            return x_

    with Builder() as model:
        x = model.input('a')
        model.output(c=with_defaults(x=x))

    c, tape = model.compute(init={'a': 3}, vout='c', return_tape=True)
    assert c == 3

    _a = tape.get_vjp().compute(init={'_c': 1}, vout='_a', monitor=print)
    assert _a == 1

    c_ = tape.get_jvp().compute(init={'a_': 1}, vout='c_', monitor=print)
    assert c_ == 1
def test_model_unused():
    """Inputs that only feed unused nodes receive zero gradients."""
    with Builder() as model:
        x, y = model.input('a', 'b')
        # Use `a` twice with a dependency. This triggers a problem with
        # last_ref in autodiff: because this line is not executed by the
        # context, last_ref is not True for the last ref on the tape.
        unused = (x + y) + x
        model.output(c=1.0)

    c, tape = model.compute(init={'a': 3, 'b': 4}, vout='c', return_tape=True)
    assert c == 1.0

    _a, _b = tape.get_vjp().compute(init={'_c': 1.0}, vout=['_a', '_b'], monitor=print)
    assert _a == 0
    assert _b == 0

    c_ = tape.get_jvp().compute(init={'a_': 1.0, 'b_': 1.0}, vout='c_', monitor=print)
    assert c_ == 0
def test_autooperator_precompute2():
    """precompute() binds a tape to a new operator without touching the original."""
    op1 = example_func.precompute(n=2, x=1)
    prebuilt = example_func.build(n=2)  # smoke-test build(); the result is not inspected

    with Builder() as model:
        x = model.input('a')
        direct = example_func(x, 2)
        precomputed = op1(x, n=2)
        model.output(b=direct, c=precomputed)

    (b, c), tape = model.compute(
        init={'a': 1.0}, vout=['b', 'c'], monitor=print, return_tape=True,
    )
    assert b == c
    assert b == 4.0
    assert example_func.__bound_tape__ is None
    assert op1.__bound_tape__ is not None
    op1.build()
def test_operator_list_in():
    """An operator taking a list of symbols (stack) through forward, VJP and JVP."""
    from numpy.testing import assert_array_equal

    with Builder() as model:
        x = model.input('a')
        model.output(c=stack(args=[x, x, x], axis=1))

    c, tape = model.compute(init={'a': [1, 2]}, vout='c', return_tape=True)
    assert_array_equal(c, [[1, 1, 1], [2, 2, 2]])

    # `a` appears three times in the stack, so each gradient entry sums to 3
    _a = tape.get_vjp().compute(
        init={'_c': [[1, 1, 1], [1, 1, 1]]}, vout='_a', monitor=print,
    )
    assert_array_equal(_a, [3, 3])

    c_ = tape.get_jvp().compute(init={'a_': [1, 1]}, vout='c_', monitor=print)
    assert_array_equal(c_, [[1, 1, 1], [1, 1, 1]])
def test_error_missing():
    """Calling add without its x1 argument must raise MissingArgument."""
    from vmad.core.error import MissingArgument

    with Builder() as model:
        x = model.input('a')
        with pytest.raises(MissingArgument):
            add(x2=1)
def test_error_overwrite():
    """Passing an extra `y=` symbol to add must raise OverwritePrecaution."""
    from vmad.core.error import OverwritePrecaution

    with Builder() as model:
        x = model.input('a')
        with pytest.raises(OverwritePrecaution):
            add(x1=x, x2=x, y=x)