def test_literal():
    """Check that ``Literal`` operands can be used when building a segment.

    The segment is evaluated through ``compute_with_gradient`` with
    ``a = 1.0``; the returned ``a`` must equal 3.0 and ``_a`` must equal 1.0.
    """
    segment = CodeSegment(MyEngine())
    segment.binary(x1='a', x2=Literal(2.0), y='a')
    segment.batch(u=Literal(2.0), v='d')

    requested = ['d', 'a', '_a']
    inputs = {'a': 1.0}
    gradient_inputs = {'_a': 1.0, '_d': 0.0}
    # 'd' is requested but its value is not checked by this test.
    _, a_out, a_grad = segment.compute_with_gradient(
        requested, inputs, gradient_inputs
    )

    assert_array_equal(a_out, 3.0)
    assert_array_equal(a_grad, 1.0)
def test_jvp_programme():
    """Exercise ``get_jvp`` on a three-operation segment with scalar input.

    With ``init={'a': 1.0}`` and ``a_ = 1.0`` the JVP segment must produce
    ``d_ == 4.0`` and ``e_ == 6.0``.
    """
    segment = CodeSegment(MyEngine())
    segment.batch(u='a', v='d')
    segment.unitary(x='d', y='d', factor=2.0)
    segment.batch_with_exarg(u='a', v='e', factor=3.0)

    jvp_segment = segment.get_jvp(init={'a': 1.0})
    outputs = jvp_segment.compute(['d_', 'e_'], {'a_': 1.0})
    d_out, e_out = outputs

    assert_array_equal(d_out, 4.0)
    assert_array_equal(e_out, 6.0)
def test_copy():
    """Smoke-test ``copy()`` on a segment made of several operations.

    Nothing is asserted about the copy itself; the test passes as long as
    building the segment and calling ``copy()`` do not raise.
    """
    segment = CodeSegment(MyEngine())
    segment.batch_with_sub(u='a', v='e')

    # Three consecutive unitary operations read 'a' with the same factor
    # and differ only in the variable they write to.
    for target in ('a', 'b1', 'b2'):
        segment.unitary(x='a', y=target, factor=3.0)

    segment.binary(x1='b1', x2='b2', y='b1')
    segment.unitary(x='b1', y='d', factor=3.0)
    segment.batch(u='b2', v='f')

    # The result is bound but intentionally not inspected.
    duplicate = segment.copy()
def test_jvp_vector():
    """Exercise ``get_jvp`` on a three-operation segment with array input.

    Both ``init['a']`` and ``a_`` are the two-element array ``[1.0, 1.0]``;
    the JVP segment must produce ``d_ == [4.0, 4.0]`` and
    ``e_ == [6.0, 6.0]``.
    """
    segment = CodeSegment(MyEngine())
    segment.batch(u='a', v='d')
    segment.unitary(x='d', y='d', factor=2.0)
    segment.batch_with_exarg(u='a', v='e', factor=3.0)

    # Two distinct array objects are created on purpose, one per call,
    # rather than sharing a single array between them.
    init_a = numpy.array([1.0, 1.0])
    jvp_segment = segment.get_jvp(init={'a': init_a})

    a_input = numpy.array([1.0, 1.0])
    d_out, e_out = jvp_segment.compute(['d_', 'e_'], {'a_': a_input})

    assert_array_equal(d_out, [4.0, 4.0])
    assert_array_equal(e_out, [6.0, 6.0])
def test_programme():
    """Check ``compute`` and ``compute_with_gradient`` on a two-op segment.

    With ``a = 1.0`` the outputs must be ``d == 2.0`` and ``e == 6.0``.
    ``compute_with_gradient`` is then run twice with different ``_d`` /
    ``_e`` inputs, expecting ``_a == 2.0`` and ``_a == 6.0`` respectively
    while ``d`` and ``e`` stay unchanged.
    """
    segment = CodeSegment(MyEngine())
    segment.batch(u='a', v='d')
    segment.batch_with_exarg(u='a', v='e', factor=3.0)

    # Plain evaluation; the tape is requested but not inspected here.
    outputs, tape = segment.compute(('d', 'e'), {'a': 1.0}, return_tape=True)
    d_out, e_out = outputs
    assert_array_equal(d_out, 2.0)
    assert_array_equal(e_out, 6.0)

    # Each case pairs the gradient inputs with the expected value of '_a'.
    gradient_cases = (
        ({'_d': 1.0, '_e': 0.0}, 2.0),
        ({'_d': 0.0, '_e': 1.0}, 6.0),
    )
    for gradient_inputs, expected_a_grad in gradient_cases:
        e_out, d_out, a_grad = segment.compute_with_gradient(
            ['e', 'd', '_a'], {'a': 1.0}, gradient_inputs
        )
        assert_array_equal(d_out, 2.0)
        assert_array_equal(e_out, 6.0)
        assert_array_equal(a_grad, expected_a_grad)
def test_to_graph():
    """Smoke-test ``to_graph()`` on a segment and on the VJP from its tape.

    Nothing is asserted about the graphs; the test passes as long as the
    computation, ``get_vjp()`` and both ``to_graph()`` calls do not raise.
    """
    segment = CodeSegment(MyEngine())
    segment.batch_with_sub(u='a', v='e')

    # Three consecutive unitary operations read 'a' with the same factor
    # and differ only in the variable they write to.
    for target in ('a', 'b1', 'b2'):
        segment.unitary(x='a', y=target, factor=3.0)

    segment.binary(x1='b1', x2='b2', y='b1')
    segment.unitary(x='b1', y='d', factor=3.0)
    segment.batch(u='b2', v='f')

    # Run once with a tape so that a VJP segment can be derived from it.
    requested = ('e', 'a', 'f', 'd')
    values, tape = segment.compute(requested, {'a': 1.0}, return_tape=True)
    vjp_segment = tape.get_vjp()

    # Both graphs are bound but intentionally not inspected.
    forward_graph = segment.to_graph()
    vjp_graph = vjp_segment.to_graph()