def test_switch_traces_cases_with_same_id_once(self, n):
  """Checks that repeated branch callables are each traced only once.

  The branch list is ``[f, g] * n``, so it contains only two distinct function
  objects regardless of ``n``. ``stateful.switch`` is expected to invoke
  (trace) each of ``f`` and ``g`` exactly once, and the counts must match what
  ``jax.lax.switch`` does for the same branch list.

  Args:
    n: How many times the ``[f, g]`` pair is repeated in the branch list.
  """
  f_witness = []
  g_witness = []

  def f(x):
    f_witness.append(None)  # One entry per invocation of `f`.
    return x ** 2

  def g(x):
    g_witness.append(None)  # One entry per invocation of `g`.
    return x ** 2

  stateful.switch(0, [f, g] * n, 2)
  f_hk_call_count = len(f_witness)
  g_hk_call_count = len(g_witness)
  self.assertEqual(f_hk_call_count, 1)
  self.assertEqual(g_hk_call_count, 1)

  # Ensure we are in sync with JAX.
  del f_witness[:], g_witness[:]
  jax.lax.switch(0, [f, g] * n, 2)
  f_jax_call_count = len(f_witness)
  g_jax_call_count = len(g_witness)
  self.assertEqual(f_hk_call_count, f_jax_call_count)
  # Compare g's Haiku count with g's JAX count (like with like).
  self.assertEqual(g_hk_call_count, g_jax_call_count)
def test_stateful_switch_with_rng_use(self):
  """Checks that branches drawing different numbers of RNG keys still work.

  Branch ``k`` draws ``k`` keys via ``base.next_rng_key()`` and returns ``k``.
  Selecting branches that consume different numbers of keys must not raise,
  and must return the selected branch's value.

  The module-level ``base.DEFAULT_PRNG_RESERVE_SIZE`` is temporarily set to 64
  for this test and is always restored afterwards.
  """
  tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
  base.DEFAULT_PRNG_RESERVE_SIZE = 64
  try:
    # Test that using a different number of keys in different branches does
    # NOT result in an error.
    def branch_f(i):
      for _ in range(i):
        _ = base.next_rng_key()
      return i

    base.reserve_rng_keys(5)
    # Bind `i` as a default argument to avoid late-binding of the loop var.
    branches = [lambda _, i=i: branch_f(i) for i in range(5)]
    self.assertEqual(stateful.switch(3, branches, None), 3)
    self.assertEqual(stateful.switch(0, branches, None), 0)
  finally:
    # Restore the global even if an assertion above fails, so the modified
    # reserve size cannot leak into other tests.
    base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default
def test_switch_no_transform(self):
  """Checks that ``stateful.switch`` raises ValueError in this context.

  Per the test name, no Haiku transform is presumably active here.

  Note: the ``msg`` passed to ``assertRaises`` is only the failure message
  reported when no ValueError is raised. It is not matched against the
  exception's text.
  """
  index = jnp.array(2)
  operand = jnp.array(42.)
  squares = [jnp.square] * 3
  with self.assertRaises(ValueError, msg="Use jax.switch() instead"):
    stateful.switch(index, squares, operand)
def f(i, x):
  """Applies branch ``i`` of three ``SquareModule`` variants to ``x``.

  All branches share one ``SquareModule`` instance. Branch 0 applies it to
  ``x``, branch 1 to ``x + 1`` and branch 2 to ``x + 2``.
  """
  square = SquareModule()

  def shifted_by_one(y):
    return square(y + 1)

  def shifted_by_two(y):
    return square(y + 2)

  return stateful.switch(i, [square, shifted_by_one, shifted_by_two], x)
# ("pmap", lambda f: stateful.pmap(f, "i")), # Vectorization. ("vmap", lambda f: stateful.vmap(f, split_rng=False)), # Control flow. # TODO(tomhennigan): Enable for associative_scan. # ("associative_scan", lambda f: # (lambda x: jax.lax.associative_scan(f, x))), ("cond", lambda f: (lambda x: stateful.cond(True, f, f, x))), ("fori_loop", lambda f: (lambda x: stateful.fori_loop(0, 1, base_test.ignore_index(f), x))), # ("map", lambda f: (lambda x: stateful.map(f, x))), ("scan", lambda f: (lambda x: stateful.scan(base_test.identity_carry(f), None, x))), ("switch", lambda f: (lambda x: stateful.switch(0, [f, f], x))), ("while_loop", lambda f: toggle( f, lambda x: stateful.while_loop(lambda xs: xs[0] == 0, lambda xs: (1, f(xs[1])), (0, x)))), # Automatic differentiation. # TODO(tomhennigan): Add missing features (e.g. custom_vjp, custom_jvp). ("grad", lambda f: stateful.grad(lambda x: f(x).sum())), ("value_and_grad", lambda f: stateful.value_and_grad(lambda x: f(x).sum())), ("checkpoint", stateful.remat), ) # pylint: enable=g-long-lambda