def test_index_to_eqn(self):
  """Unpacking a two-output tracer yields one IndexEqn per output."""
  first, second = tracer.Tracer('fake_output', 2)
  got_eqns, got_outputs = tracer.traces_to_eqns((first, second))
  # Each unpacked element should index into the named 'fake_output' source.
  want_eqns = [
      tracer.IndexEqn(0, 'fake_output', 'var0'),
      tracer.IndexEqn(1, 'fake_output', 'var1'),
  ]
  want_outputs = ('var0', 'var1')
  self.assertEqual(got_eqns, want_eqns)
  self.assertEqual(got_outputs, want_outputs)
def test_apply_to_eqn(self):
  """Applying a layer to two named tracers yields a single ApplyEqn."""
  add_layer = cb.Add()
  lhs = tracer.Tracer('a')
  rhs = tracer.Tracer('b')
  summed = add_layer @ (lhs, rhs)
  # A single (non-tuple) trace is passed straight to traces_to_eqns.
  got_eqns, got_outputs = tracer.traces_to_eqns(summed)
  want_eqns = [tracer.ApplyEqn(add_layer, ('a', 'b'), ('var0',))]
  want_outputs = ('var0',)
  self.assertEqual(got_eqns, want_eqns)
  self.assertEqual(got_outputs, want_outputs)
def test_apply_index_to_eqn(self):
  """A two-output layer application yields an ApplyEqn plus two IndexEqns."""
  parallel_tanh = cb.Parallel(activation_fns.Tanh(), activation_fns.Tanh())
  lhs = tracer.Tracer('a')
  rhs = tracer.Tracer('b')
  out_left, out_right = parallel_tanh @ (lhs, rhs)
  got_eqns, got_outputs = tracer.traces_to_eqns((out_left, out_right))
  # Per the expected values below, the shared ApplyEqn output is named
  # 'var2', and the two indexed outputs are 'var0' and 'var1'.
  want_eqns = [
      tracer.ApplyEqn(parallel_tanh, ('a', 'b'), ('var2',)),
      tracer.IndexEqn(0, 'var2', 'var0'),
      tracer.IndexEqn(1, 'var2', 'var1'),
  ]
  want_outputs = ('var0', 'var1')
  self.assertEqual(got_eqns, want_eqns)
  self.assertEqual(got_outputs, want_outputs)