def test_run_while_python(self):
    """Exercise run_while with plain Python float loop state.

    The loop state is [value, threshold, scale]. The condition keeps looping
    while value > threshold. The body multiplies value by scale and passes
    threshold and scale through unchanged.
    """

    def keep_going(value, threshold, scale):
        return value > threshold

    def shrink(value, threshold, scale):
        return (value * scale, threshold, scale)

    # Each case is (threshold, expected final value), starting from 3.0 with scale 0.5:
    #  - threshold 1.0: 3.0 -> 1.5 -> 0.75, then 0.75 > 1.0 is false.
    #  - threshold 4.0: 3.0 > 4.0 is false up front, so the body never runs.
    cases = ((1.0, 0.75), (4.0, 3.0))
    for threshold, expected in cases:
        final_value, _, _ = multiple_dispatch.run_while(
            keep_going, shrink, [3.0, threshold, 0.5])
        self.assertEqual(final_value, expected)
def test_run_while_tf(self):
    """Exercise run_while when the initial value is a `constant(3.0)`.

    The condition and body match the pure-Python variant. Here the result is
    fetched with `sess.run` and compared against the same expected values.
    """

    def keep_going(value, threshold, scale):
        return value > threshold

    def shrink(value, threshold, scale):
        return (value * scale, threshold, scale)

    # Each case is (threshold, expected final value). The starting value 3.0
    # and the scale 0.5 are fixed across cases.
    cases = ((1.0, 0.75), (4.0, 3.0))
    with Session() as sess:
        for threshold, expected in cases:
            # A fresh constant is built per case, as the original did per call.
            final_tensor, _, _ = multiple_dispatch.run_while(
                keep_going, shrink, [constant(3.0), threshold, 0.5])
            self.assertEqual(sess.run(final_tensor), expected)