def test_cond_traces_branches_with_same_id_once(self):
  """Passing one function object as both branches traces it a single time.

  Also cross-checks the trace count against `jax.lax.cond`.
  """
  calls = []

  def branch(x):
    calls.append(None)
    return x ** 2

  # Same function object for both branches.
  stateful.cond(False, branch, branch, 0)
  hk_call_count = len(calls)
  self.assertEqual(hk_call_count, 1)

  # Ensure we are in sync with JAX.
  calls.clear()
  jax.lax.cond(False, branch, branch, 0)
  self.assertEqual(hk_call_count, len(calls))
def test_cond_two_args(self):
  """Two positional operands reach the selected (true) branch."""
  def swap(a, b):
    return b, a

  def keep(a, b):
    return a, b

  first, second = stateful.cond(True, swap, keep, 2, 1)
  self.assertEqual(first, 1)
  self.assertEqual(second, 2)
def test_cond_three_args(self):
  """Three positional operands reach the selected (true) branch."""
  def reverse(a, b, c):
    return c, b, a

  def keep(a, b, c):
    return a, b, c

  a, b, c = stateful.cond(True, reverse, keep, 3, 2, 1)
  for got, want in ((a, 1), (b, 2), (c, 3)):
    self.assertEqual(got, want)
def test_stateful_cond_with_rng_use(self):
  """Branches that draw different numbers of RNG keys run without raising.

  The true branch draws one key from `base.next_rng_key()` and the false
  branch draws two, after reserving 5 keys. The test makes no assertions, so it
  passes exactly when neither `stateful.cond` call raises.
  """
  tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
  base.DEFAULT_PRNG_RESERVE_SIZE = 64
  try:
    def true_branch(x):
      _ = base.next_rng_key()
      return x

    def false_branch(x):
      _ = base.next_rng_key()
      _ = base.next_rng_key()
      return x

    base.reserve_rng_keys(5)
    _ = stateful.cond(True, true_branch, false_branch, 0)
    _ = stateful.cond(False, true_branch, false_branch, 0)
  finally:
    # Always restore the module-level default, even when the body raises,
    # so this override cannot leak into other tests.
    base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default
def test_cond_no_transform(self):
  """Calling `stateful.cond` with no transform active raises ValueError.

  The call uses the five-argument form
  (pred, true_operand, true_fun, false_operand, false_fun).
  """
  x = jnp.array(3.)
  # `msg=` on `assertRaises` is only the text reported when nothing is raised;
  # it never checks the exception message. `assertRaisesRegex` does. The regex
  # accepts both the "jax.cond()" and "jax.lax.cond()" spellings.
  with self.assertRaisesRegex(ValueError, r"Use jax\.(lax\.)?cond\(\) instead"):
    stateful.cond(x == 2, x, lambda x: x**2, x, lambda x: (x + 1)**2)
def f(x):
  """Selects between a module call and a shifted module call via `cond`.

  Uses the five-argument form (pred, true_operand, true_fun, false_operand,
  false_fun). Both branches share one `SquareModule` instance.
  """
  square = SquareModule()

  def shifted_square(y):
    return square(y + 1)

  return stateful.cond(x == 2, x, square, x, shifted_square)
def f(x):
  """Selects between a module call and a shifted module call via `cond`.

  `single_arg` comes from the enclosing scope and picks the calling form:
  (pred, true_fun, false_fun, operand) when truthy, otherwise the five-argument
  (pred, true_operand, true_fun, false_operand, false_fun) form. Both branches
  share one `SquareModule` instance.
  """
  square = SquareModule()

  def shifted_square(y):
    return square(y + 1)

  if not single_arg:
    return stateful.cond(x == 2, x, square, x, shifted_square)
  return stateful.cond(x == 2, square, shifted_square, x)
def test_cond_operand_kwarg_and_operands(self):
  """Supplying both a positional operand and `operand=` is rejected."""
  def add_five(x):
    return x + 5

  def add_four(x):
    return x + 4

  with self.assertRaisesRegex(ValueError, "cannot.*pass.*positionally"):
    stateful.cond(True, add_five, add_four, 1, operand=1)
def test_cond_operand_kwarg(self):
  """The operand may be passed by keyword as `operand=`."""
  def add_five(x):
    return x + 5

  def add_four(x):
    return x + 4

  result = stateful.cond(True, add_five, add_four, operand=1)
  self.assertEqual(result, 6)
def test_cond_no_args(self):
  """Branches that take no operands are supported."""
  def five():
    return 5

  def four():
    return 4

  result = stateful.cond(True, five, four)
  self.assertEqual(result, 5)
# ("make_jaxpr", stateful.make_jaxpr), ("eval_shape", lambda f: (lambda x: [f(x), stateful.eval_shape(f, x)])), ("named_call", stateful.named_call), # Parallelization. # TODO(tomhennigan): Add missing features (e.g. pjit,xmap). # ("pmap", lambda f: stateful.pmap(f, "i")), # Vectorization. ("vmap", lambda f: stateful.vmap(f, split_rng=False)), # Control flow. # TODO(tomhennigan): Enable for associative_scan. # ("associative_scan", lambda f: # (lambda x: jax.lax.associative_scan(f, x))), ("cond", lambda f: (lambda x: stateful.cond(True, f, f, x))), ("fori_loop", lambda f: (lambda x: stateful.fori_loop(0, 1, base_test.ignore_index(f), x))), # ("map", lambda f: (lambda x: stateful.map(f, x))), ("scan", lambda f: (lambda x: stateful.scan(base_test.identity_carry(f), None, x))), ("switch", lambda f: (lambda x: stateful.switch(0, [f, f], x))), ("while_loop", lambda f: toggle( f, lambda x: stateful.while_loop(lambda xs: xs[0] == 0, lambda xs: (1, f(xs[1])), (0, x)))), # Automatic differentiation. # TODO(tomhennigan): Add missing features (e.g. custom_vjp, custom_jvp). ("grad", lambda f: stateful.grad(lambda x: f(x).sum())),