def test_value(self):
    """Check static_cond with branches that return plain Python strings.

    For each condition in (True, False, 1, 0) the result of
    ``utils.static_cond`` must equal 'fn1' when the condition is truthy
    and 'fn2' otherwise.
    """

    def true_branch():
        return 'fn1'

    def false_branch():
        return 'fn2'

    # Both real booleans and their integer equivalents are exercised.
    for cond in (True, False, 1, 0):
        result = utils.static_cond(cond, true_branch, false_branch)
        wanted = 'fn1' if cond else 'fn2'
        self.assertEqual(result, wanted)
def test_constant(self):
    """Check static_cond with branches that build string constants.

    Each branch wraps its label in ``constant_op.constant``; the value
    obtained from ``eval()`` inside a cached session is compared against
    the bytes literal b'fn1' (truthy condition) or b'fn2' (falsy one).
    """

    def true_branch():
        return constant_op.constant('fn1')

    def false_branch():
        return constant_op.constant('fn2')

    for cond in (True, False, 1, 0):
        result = utils.static_cond(cond, true_branch, false_branch)
        # The comparison below is made against bytes, not str.
        wanted = b'fn1' if cond else b'fn2'
        with self.cached_session():
            self.assertEqual(result.eval(), wanted)
def test_tensors(self):
    """Check static_cond with branches that return a computed expression.

    The first branch computes ``constant(0) - constant(1)`` and the second
    ``constant(0) - constant(2)``; the evaluated result must be -1 for a
    truthy condition and -2 for a falsy one.
    """

    def true_branch():
        minuend = constant_op.constant(0)
        subtrahend = constant_op.constant(1)
        return minuend - subtrahend

    def false_branch():
        minuend = constant_op.constant(0)
        subtrahend = constant_op.constant(2)
        return minuend - subtrahend

    for cond in (True, False, 1, 0):
        result = utils.static_cond(cond, true_branch, false_branch)
        wanted = -1 if cond else -2
        with self.cached_session():
            self.assertEqual(result.eval(), wanted)
def test_variable(self):
    """Check static_cond with branches that create variables.

    Each branch returns a ``variables.Variable`` holding its label. The
    global variables initializer is run in the cached session before the
    result is evaluated and compared with b'fn1' (truthy condition) or
    b'fn2' (falsy one).
    """

    def true_branch():
        return variables.Variable('fn1')

    def false_branch():
        return variables.Variable('fn2')

    for cond in (True, False, 1, 0):
        result = utils.static_cond(cond, true_branch, false_branch)
        wanted = b'fn1' if cond else b'fn2'
        with self.cached_session() as sess:
            # Run the initializer before reading the variable's value.
            init_op = variables.global_variables_initializer()
            sess.run(init_op)
            self.assertEqual(result.eval(), wanted)