Esempio n. 1
0
 def test_tracer_apply(self):
     """Applying a layer to tracers with `@` records an ApplyExpr.

     The recorded expression references the layer and the input tracer
     names ('a', 'b') in order.
     """
     layer = cb.Add()
     inputs = (tracer.Tracer('a'), tracer.Tracer('b'))
     traced = layer @ inputs
     self.assertEqual(traced.expr, tracer.ApplyExpr(layer, ('a', 'b')))
Esempio n. 2
0
 def test_tracer_index(self):
     """Unpacking a two-output application yields per-position IndexExprs.

     Each output tracer's expr is expected to be
     IndexExpr(i, ApplyExpr(layer, ('a', 'b'))) for its position i.
     """
     layer = cb.Parallel(activation_fns.Tanh(), activation_fns.Tanh())
     inputs = (tracer.Tracer('a'), tracer.Tracer('b'))
     first, second = layer @ inputs
     expected = [
         tracer.IndexExpr(position, tracer.ApplyExpr(layer, ('a', 'b')))
         for position in (0, 1)
     ]
     for out, want in zip((first, second), expected):
         self.assertEqual(out.expr, want)
Esempio n. 3
0
 def test_apply_to_eqn(self):
     """A single-output application lowers to one ApplyEqn.

     traces_to_eqns should emit a single ApplyEqn over ('a', 'b') whose
     result is named 'var0', and report ('var0',) as the outputs.
     """
     layer = cb.Add()
     inputs = (tracer.Tracer('a'), tracer.Tracer('b'))
     traced = layer @ inputs
     eqns, outputs = tracer.traces_to_eqns(traced)
     self.assertEqual(eqns, [tracer.ApplyEqn(layer, ('a', 'b'), ('var0',))])
     self.assertEqual(outputs, ('var0',))
Esempio n. 4
0
 def test_apply_index_to_eqn(self):
     """Two unpacked outputs lower to one ApplyEqn plus one IndexEqn each.

     The shared application result is expected to be named 'var2', and the
     two indexed outputs 'var0' and 'var1'.
     """
     layer = cb.Parallel(activation_fns.Tanh(), activation_fns.Tanh())
     inputs = (tracer.Tracer('a'), tracer.Tracer('b'))
     first, second = layer @ inputs
     eqns, outputs = tracer.traces_to_eqns((first, second))
     expected_eqns = [tracer.ApplyEqn(layer, ('a', 'b'), ('var2',))]
     expected_eqns.extend(
         tracer.IndexEqn(index, 'var2', name)
         for index, name in ((0, 'var0'), (1, 'var1')))
     self.assertEqual(eqns, expected_eqns)
     self.assertEqual(outputs, ('var0', 'var1'))
Esempio n. 5
0
 def test_index_to_eqn(self):
     """Unpacking a bare two-output tracer lowers to IndexEqns only.

     Each unpacked element indexes 'fake_output' and is bound to
     'var0' / 'var1' in order.
     """
     pair = tracer.Tracer('fake_output', 2)
     first, second = pair
     eqns, outputs = tracer.traces_to_eqns((first, second))
     names = ('var0', 'var1')
     expected_eqns = [
         tracer.IndexEqn(index, 'fake_output', name)
         for index, name in enumerate(names)
     ]
     self.assertEqual(eqns, expected_eqns)
     self.assertEqual(outputs, names)