Example #1
0
 def test_index_to_eqn(self):
     """Unpacking a two-output tracer yields one IndexEqn per output slot."""
     first, second = tracer.Tracer('fake_output', 2)
     got_eqns, got_outputs = tracer.traces_to_eqns((first, second))
     # Both unpacked elements index into the same 'fake_output' trace.
     want_eqns = [
         tracer.IndexEqn(0, 'fake_output', 'var0'),
         tracer.IndexEqn(1, 'fake_output', 'var1'),
     ]
     want_outputs = ('var0', 'var1')
     self.assertEqual(got_eqns, want_eqns)
     self.assertEqual(got_outputs, want_outputs)
Example #2
0
 def test_apply_to_eqn(self):
     """Applying a layer to two named tracers produces a single ApplyEqn."""
     layer = cb.Add()
     lhs, rhs = tracer.Tracer('a'), tracer.Tracer('b')
     applied = layer @ (lhs, rhs)
     got_eqns, got_outputs = tracer.traces_to_eqns(applied)
     self.assertEqual(
         got_eqns, [tracer.ApplyEqn(layer, ('a', 'b'), ('var0', ))])
     self.assertEqual(got_outputs, ('var0', ))
Example #3
0
 def test_apply_index_to_eqn(self):
     """Unpacking a multi-output layer application emits ApplyEqn then IndexEqns."""
     tanh_pair = cb.Parallel(activation_fns.Tanh(), activation_fns.Tanh())
     in_a, in_b = tracer.Tracer('a'), tracer.Tracer('b')
     out_c, out_d = tanh_pair @ (in_a, in_b)
     got_eqns, got_outputs = tracer.traces_to_eqns((out_c, out_d))
     # The layer's combined result is bound to 'var2'; the two unpacked
     # outputs are then pulled out of it by index as 'var0' and 'var1'.
     want_eqns = [
         tracer.ApplyEqn(tanh_pair, ('a', 'b'), ('var2', )),
         tracer.IndexEqn(0, 'var2', 'var0'),
         tracer.IndexEqn(1, 'var2', 'var1'),
     ]
     self.assertEqual(got_eqns, want_eqns)
     self.assertEqual(got_outputs, ('var0', 'var1'))