def test_run_cond_python(self):
    """Check run_cond with plain Python bool predicates.

    Each branch callable returns a distinct float, so the asserted value
    identifies which branch run_cond selected for each predicate.
    """

    def pick_true():
        return 2.0

    def pick_false():
        return 3.0

    # Same order as before: the True predicate first, then the False one.
    for predicate, expected in ((True, 2.0), (False, 3.0)):
        self.assertEqual(
            multiple_dispatch.run_cond(predicate, pick_true, pick_false),
            expected)
def test_run_cond_tf(self):
    """Check run_cond with TensorFlow boolean tensor predicates.

    The branch callables build constant tensors, so each result is
    evaluated in a Session before being compared against the value of
    the branch that should have been taken.
    """

    def pick_true():
        return constant([2.0])

    def pick_false():
        return constant([3.0])

    cases = ((True, 2.0), (False, 3.0))
    with Session() as sess:
        for predicate, expected in cases:
            out = multiple_dispatch.run_cond(
                constant(predicate), pick_true, pick_false)
            # NOTE(review): sess.run presumably returns a one-element array
            # here; assertEqual against a float relies on that array's
            # single-element truthiness - confirm.
            self.assertEqual(sess.run(out), expected)