def test_value(self):
  """smart_cond with a constant-tensor predicate returns the chosen branch value.

  The branch functions return plain Python strings and the assertion compares
  against those same strings. The test therefore expects smart_cond to hand
  back the selected function's return value as-is. Both boolean (True/False)
  and integer (1/0) predicate values are exercised.
  """

  def true_branch():
    return 'fn1'

  def false_branch():
    return 'fn2'

  for pred_value in (True, False, 1, 0):
    outcome = utils.smart_cond(
        constant_op.constant(pred_value), true_branch, false_branch)
    self.assertEqual(outcome, 'fn1' if pred_value else 'fn2')
def test_tensors(self):
  """smart_cond with a constant predicate selects a tensor-computing branch.

  Each branch builds a small subtraction (0 - 1 or 0 - 2). The result is
  evaluated in the cached test session and checked against the expected
  integer for both boolean and integer predicate values.
  """

  def true_branch():
    return constant_op.constant(0) - constant_op.constant(1)

  def false_branch():
    return constant_op.constant(0) - constant_op.constant(2)

  # Keyed by truthiness so integer predicates (1/0) map like True/False.
  expected_by_truth = {True: -1, False: -2}

  for pred_value in (True, False, 1, 0):
    outcome = utils.smart_cond(
        constant_op.constant(pred_value), true_branch, false_branch)
    with self.cached_session():
      self.assertEqual(outcome.eval(), expected_by_truth[bool(pred_value)])
def test_constant(self):
  """smart_cond with a constant predicate selects a string-constant branch.

  The branches return string tensors. Once evaluated in the cached test
  session, these are compared against the corresponding bytes literals.
  """

  def true_branch():
    return constant_op.constant('fn1')

  def false_branch():
    return constant_op.constant('fn2')

  for pred_value in (True, False, 1, 0):
    outcome = utils.smart_cond(
        constant_op.constant(pred_value), true_branch, false_branch)
    wanted = b'fn1' if pred_value else b'fn2'
    with self.cached_session():
      self.assertEqual(outcome.eval(), wanted)
def test_tensors(self):
  """smart_cond with a placeholder predicate selects a tensor-computing branch.

  The predicate is a scalar boolean placeholder, so its value is only supplied
  through ``feed_dict`` when the result is evaluated. Integer feeds (1/0) are
  included alongside True/False.
  """

  def true_branch():
    return constant_op.constant(0) - constant_op.constant(1)

  def false_branch():
    return constant_op.constant(0) - constant_op.constant(2)

  pred = array_ops.placeholder(dtypes.bool, [])
  for feed_value in (True, False, 1, 0):
    outcome = utils.smart_cond(pred, true_branch, false_branch)
    with self.cached_session():
      actual = outcome.eval(feed_dict={pred: feed_value})
      self.assertEqual(actual, -1 if feed_value else -2)
def test_value(self):
  """smart_cond with a placeholder predicate selects a converted-string branch.

  The branches wrap Python strings with ``ops.convert_to_tensor``. The scalar
  boolean predicate is fed at evaluation time, and the result is compared
  against the matching bytes literal.
  """

  def true_branch():
    return ops.convert_to_tensor('fn1')

  def false_branch():
    return ops.convert_to_tensor('fn2')

  # Keyed by truthiness so integer feeds (1/0) map like True/False.
  expected_by_truth = {True: b'fn1', False: b'fn2'}

  pred = array_ops.placeholder(dtypes.bool, [])
  for feed_value in (True, False, 1, 0):
    outcome = utils.smart_cond(pred, true_branch, false_branch)
    with self.cached_session():
      self.assertEqual(
          outcome.eval(feed_dict={pred: feed_value}),
          expected_by_truth[bool(feed_value)])
def test_variable(self):
  """smart_cond with a placeholder predicate selects a Variable-creating branch.

  Everything is built inside a freshly created graph. On every iteration the
  global-variables initializer is run before the result is evaluated with the
  predicate fed through ``feed_dict``.
  """
  with tf.Graph().as_default():

    def true_branch():
      return variables.Variable('fn1')

    def false_branch():
      return variables.Variable('fn2')

    pred = array_ops.placeholder(dtypes.bool, [])
    for feed_value in (True, False, 1, 0):
      outcome = utils.smart_cond(pred, true_branch, false_branch)
      with self.cached_session() as session:
        session.run(variables.global_variables_initializer())
        wanted = b'fn1' if feed_value else b'fn2'
        self.assertEqual(outcome.eval(feed_dict={pred: feed_value}), wanted)
def test_variable(self):
  """smart_cond with a constant predicate selects a Variable-creating branch.

  Everything is built inside a freshly created graph. On every iteration the
  global-variables initializer is run before the result is evaluated and
  compared against the matching bytes literal.
  """
  with tf.Graph().as_default():

    def true_branch():
      return variables.Variable('fn1')

    def false_branch():
      return variables.Variable('fn2')

    # Keyed by truthiness so integer predicates (1/0) map like True/False.
    expected_by_truth = {True: b'fn1', False: b'fn2'}

    for pred_value in (True, False, 1, 0):
      outcome = utils.smart_cond(
          constant_op.constant(pred_value), true_branch, false_branch)
      with self.cached_session() as session:
        session.run(variables.global_variables_initializer())
        self.assertEqual(outcome.eval(), expected_by_truth[bool(pred_value)])