Esempio n. 1
0
 def test_value(self):
   """Checks smart_cond with plain Python strings as branch results.

   The predicate is built with constant_op.constant from each of True,
   False, 1 and 0. The value returned by smart_cond is compared directly
   with the string of the branch selected by the truthiness of that value.
   """
   def on_true():
     return 'fn1'

   def on_false():
     return 'fn2'

   for pred_value in (True, False, 1, 0):
     pred = constant_op.constant(pred_value)
     result = utils.smart_cond(pred, on_true, on_false)
     want = 'fn1' if pred_value else 'fn2'
     self.assertEqual(result, want)
Esempio n. 2
0
 def test_tensors(self):
   """Checks smart_cond when each branch builds a tensor by subtraction.

   The predicate is built with constant_op.constant from each of True,
   False, 1 and 0. The true branch computes 0 - 1 and the false branch
   0 - 2; the evaluated result must equal -1 or -2 according to the
   truthiness of the predicate value.
   """
   def minus_one():
     return constant_op.constant(0) - constant_op.constant(1)

   def minus_two():
     return constant_op.constant(0) - constant_op.constant(2)

   for pred_value in (True, False, 1, 0):
     result = utils.smart_cond(
         constant_op.constant(pred_value), minus_one, minus_two)
     want = -1 if pred_value else -2
     with self.cached_session():
       self.assertEqual(result.eval(), want)
Esempio n. 3
0
 def test_constant(self):
   """Checks smart_cond when each branch returns a string constant tensor.

   The predicate is built with constant_op.constant from each of True,
   False, 1 and 0. The evaluated result is compared against the bytes
   value b'fn1' or b'fn2', chosen by the truthiness of the predicate value.
   """
   def on_true():
     return constant_op.constant('fn1')

   def on_false():
     return constant_op.constant('fn2')

   for pred_value in (True, False, 1, 0):
     pred = constant_op.constant(pred_value)
     result = utils.smart_cond(pred, on_true, on_false)
     want = b'fn1' if pred_value else b'fn2'
     with self.cached_session():
       self.assertEqual(result.eval(), want)
Esempio n. 4
0
 def test_tensors(self):
   """Checks smart_cond with a fed predicate and subtraction branches.

   The predicate is a bool placeholder created with shape []. For each of
   True, False, 1 and 0 a new smart_cond result is built and evaluated with
   that value fed to the placeholder; the outcome must be -1 when the fed
   value is truthy and -2 otherwise.
   """
   def minus_one():
     return constant_op.constant(0) - constant_op.constant(1)

   def minus_two():
     return constant_op.constant(0) - constant_op.constant(2)

   pred = array_ops.placeholder(dtypes.bool, [])
   for fed_value in (True, False, 1, 0):
     result = utils.smart_cond(pred, minus_one, minus_two)
     want = -1 if fed_value else -2
     with self.cached_session():
       actual = result.eval(feed_dict={pred: fed_value})
       self.assertEqual(actual, want)
Esempio n. 5
0
 def test_value(self):
   """Checks smart_cond with a fed predicate and converted-string branches.

   Each branch passes a Python string through ops.convert_to_tensor. The
   predicate is a bool placeholder created with shape []; for each of True,
   False, 1 and 0 the result is evaluated with that value fed in and
   compared against b'fn1' (truthy) or b'fn2' (falsy).
   """
   def on_true():
     return ops.convert_to_tensor('fn1')

   def on_false():
     return ops.convert_to_tensor('fn2')

   pred = array_ops.placeholder(dtypes.bool, [])
   for fed_value in (True, False, 1, 0):
     result = utils.smart_cond(pred, on_true, on_false)
     want = b'fn1' if fed_value else b'fn2'
     with self.cached_session():
       actual = result.eval(feed_dict={pred: fed_value})
       self.assertEqual(actual, want)
 def test_variable(self):
   """Checks smart_cond when each branch creates a string Variable.

   The predicate is built with constant_op.constant from each of True,
   False, 1 and 0. Inside the session the global variables initializer is
   run before the result is evaluated and compared against b'fn1' (truthy)
   or b'fn2' (falsy).
   """
   def on_true():
     return variables.Variable('fn1')

   def on_false():
     return variables.Variable('fn2')

   for pred_value in (True, False, 1, 0):
     pred = constant_op.constant(pred_value)
     result = utils.smart_cond(pred, on_true, on_false)
     want = b'fn1' if pred_value else b'fn2'
     with self.cached_session() as session:
       # Variables created by the branches must be initialized before eval.
       session.run(variables.global_variables_initializer())
       self.assertEqual(result.eval(), want)
 def test_variable(self):
   """Checks smart_cond with a fed predicate and Variable-creating branches.

   The predicate is a bool placeholder created with shape []. For each of
   True, False, 1 and 0 a new smart_cond result is built; inside the
   session the global variables initializer is run, then the result is
   evaluated with that value fed in and compared against b'fn1' (truthy)
   or b'fn2' (falsy).
   """
   def on_true():
     return variables.Variable('fn1')

   def on_false():
     return variables.Variable('fn2')

   pred = array_ops.placeholder(dtypes.bool, [])
   for fed_value in (True, False, 1, 0):
     result = utils.smart_cond(pred, on_true, on_false)
     want = b'fn1' if fed_value else b'fn2'
     with self.cached_session() as session:
       # Variables created by the branches must be initialized before eval.
       session.run(variables.global_variables_initializer())
       actual = result.eval(feed_dict={pred: fed_value})
       self.assertEqual(actual, want)