Example #1
0
 def test_value(self):
     """Checks static_cond on plain Python return values.

     Truthy predicates (True, 1) should yield fn1()'s result, 'fn1';
     falsy ones (False, 0) should yield fn2()'s result, 'fn2'.
     """
     fn1 = lambda: 'fn1'
     fn2 = lambda: 'fn2'
     expected = lambda v: 'fn1' if v else 'fn2'
     for v in [True, False, 1, 0]:
         o = utils.static_cond(v, fn1, fn2)
         want = expected(v)
         # All predicates share one test method, so name the failing one:
         # True vs. 1 (and False vs. 0) would otherwise fail identically.
         # Values are included because a custom msg replaces the default
         # diff when longMessage is False.
         self.assertEqual(
             o, want, msg='pred=%r: got %r, expected %r' % (v, o, want))
Example #2
0
 def test_value(self):
   """Checks static_cond on plain Python return values.

   Truthy predicates (True, 1) should yield fn1()'s result, 'fn1';
   falsy ones (False, 0) should yield fn2()'s result, 'fn2'.
   """
   fn1 = lambda: 'fn1'
   fn2 = lambda: 'fn2'
   expected = lambda v: 'fn1' if v else 'fn2'
   for v in [True, False, 1, 0]:
     o = utils.static_cond(v, fn1, fn2)
     want = expected(v)
     # All predicates share one test method, so name the failing one:
     # True vs. 1 (and False vs. 0) would otherwise fail identically.
     # Values are included because a custom msg replaces the default
     # diff when longMessage is False.
     self.assertEqual(
         o, want, msg='pred=%r: got %r, expected %r' % (v, o, want))
Example #3
0
 def test_tensors(self):
     """Checks static_cond when each branch builds a small tensor expression.

     Truthy predicates (True, 1) should select fn1, whose result evaluates
     to 0 - 1 = -1; falsy ones (False, 0) select fn2, evaluating to -2.
     """
     fn1 = lambda: tf.constant(0) - tf.constant(1)
     fn2 = lambda: tf.constant(0) - tf.constant(2)
     expected = lambda v: -1 if v else -2
     for v in [True, False, 1, 0]:
         o = utils.static_cond(v, fn1, fn2)
         with self.test_session():
             got = o.eval()
             want = expected(v)
             # Name the predicate on failure: True vs. 1 (and False vs. 0)
             # would otherwise fail with indistinguishable messages.
             self.assertEqual(
                 got, want, msg='pred=%r: got %r, expected %r' % (v, got, want))
Example #4
0
 def test_constant(self):
     """Checks static_cond when each branch returns a string constant tensor.

     Truthy predicates (True, 1) should select fn1 and falsy ones
     (False, 0) fn2; the evaluated values are compared as bytes
     (b'fn1' / b'fn2').
     """
     fn1 = lambda: tf.constant('fn1')
     fn2 = lambda: tf.constant('fn2')
     expected = lambda v: b'fn1' if v else b'fn2'
     for v in [True, False, 1, 0]:
         o = utils.static_cond(v, fn1, fn2)
         with self.test_session():
             got = o.eval()
             want = expected(v)
             # Name the predicate on failure: True vs. 1 (and False vs. 0)
             # would otherwise fail with indistinguishable messages.
             self.assertEqual(
                 got, want, msg='pred=%r: got %r, expected %r' % (v, got, want))
Example #5
0
 def test_constant(self):
   """Checks static_cond when each branch returns a string constant tensor.

   Truthy predicates (True, 1) should select fn1 and falsy ones
   (False, 0) fn2; the evaluated values are compared as bytes
   (b'fn1' / b'fn2').
   """
   fn1 = lambda: constant_op.constant('fn1')
   fn2 = lambda: constant_op.constant('fn2')
   expected = lambda v: b'fn1' if v else b'fn2'
   for v in [True, False, 1, 0]:
     o = utils.static_cond(v, fn1, fn2)
     with self.test_session():
       got = o.eval()
       want = expected(v)
       # Name the predicate on failure: True vs. 1 (and False vs. 0)
       # would otherwise fail with indistinguishable messages.
       self.assertEqual(
           got, want, msg='pred=%r: got %r, expected %r' % (v, got, want))
Example #6
0
 def test_tensors(self):
   """Checks static_cond when each branch builds a small tensor expression.

   Truthy predicates (True, 1) should select fn1, whose result evaluates
   to 0 - 1 = -1; falsy ones (False, 0) select fn2, evaluating to -2.
   """
   fn1 = lambda: constant_op.constant(0) - constant_op.constant(1)
   fn2 = lambda: constant_op.constant(0) - constant_op.constant(2)
   expected = lambda v: -1 if v else -2
   for v in [True, False, 1, 0]:
     o = utils.static_cond(v, fn1, fn2)
     with self.test_session():
       got = o.eval()
       want = expected(v)
       # Name the predicate on failure: True vs. 1 (and False vs. 0)
       # would otherwise fail with indistinguishable messages.
       self.assertEqual(
           got, want, msg='pred=%r: got %r, expected %r' % (v, got, want))
Example #7
0
 def test_variable(self):
     """Checks static_cond when each branch creates a string Variable.

     Truthy predicates (True, 1) should select fn1 and falsy ones
     (False, 0) fn2; after running the variable initializer, the selected
     value is compared as bytes (b'fn1' / b'fn2').
     """
     fn1 = lambda: tf.Variable('fn1')
     fn2 = lambda: tf.Variable('fn2')
     expected = lambda v: b'fn1' if v else b'fn2'
     for v in [True, False, 1, 0]:
         o = utils.static_cond(v, fn1, fn2)
         with self.test_session() as sess:
             # Initialize variables in this session before reading o.
             sess.run(tf.global_variables_initializer())
             got = o.eval()
             want = expected(v)
             # Name the predicate on failure: True vs. 1 (and False vs. 0)
             # would otherwise fail with indistinguishable messages.
             self.assertEqual(
                 got, want, msg='pred=%r: got %r, expected %r' % (v, got, want))
Example #8
0
 def test_variable(self):
   """Checks static_cond when each branch creates a string Variable.

   Truthy predicates (True, 1) should select fn1 and falsy ones
   (False, 0) fn2; after running the variable initializer, the selected
   value is compared as bytes (b'fn1' / b'fn2').
   """
   fn1 = lambda: variables.Variable('fn1')
   fn2 = lambda: variables.Variable('fn2')
   expected = lambda v: b'fn1' if v else b'fn2'
   for v in [True, False, 1, 0]:
     o = utils.static_cond(v, fn1, fn2)
     with self.test_session() as sess:
       # Initialize variables in this session before reading o.
       sess.run(variables.global_variables_initializer())
       got = o.eval()
       want = expected(v)
       # Name the predicate on failure: True vs. 1 (and False vs. 0)
       # would otherwise fail with indistinguishable messages.
       self.assertEqual(
           got, want, msg='pred=%r: got %r, expected %r' % (v, got, want))