def test_stateful_fori_with_rng_use(self, iteration_count):
  """stateful.fori_loop runs a body that draws RNG keys.

  Five keys are reserved up front, while the loop body draws ten keys per
  iteration for `iteration_count` iterations.
  """
  tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
  base.DEFAULT_PRNG_RESERVE_SIZE = 64
  try:
    def body_fun(_, x):
      for _ in range(10):
        _ = base.next_rng_key()
      return x

    base.reserve_rng_keys(5)
    _ = stateful.fori_loop(0, iteration_count, body_fun, 1)
  finally:
    # Restore the module-level default even if the loop raises, so the
    # override cannot leak into other tests.
    base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default
def test_stateful_scan_with_rng_use(self, iteration_count):
  """stateful.scan runs a body that draws RNG keys.

  Five keys are reserved up front, while the scan body draws ten keys per
  step for `iteration_count` steps (empty carry and inputs).
  """
  # TODO(lenamartens): remove when default changes to > 1.
  tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
  base.DEFAULT_PRNG_RESERVE_SIZE = 64
  try:
    def body_fun(c, x):
      for _ in range(10):
        _ = base.next_rng_key()
      return c, x

    base.reserve_rng_keys(5)
    _ = stateful.scan(body_fun, (), (), length=iteration_count)
  finally:
    # Restore the module-level default even if the scan raises, so the
    # override cannot leak into other tests.
    base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default
def test_stateful_switch_with_rng_use(self):
  """stateful.switch works when branches draw different numbers of keys.

  Branch `i` draws `i` RNG keys and returns `i`. The test checks that the
  selected branch's value comes back for indices 3 and 0 without raising.
  """
  tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
  base.DEFAULT_PRNG_RESERVE_SIZE = 64
  try:
    # Each branch consumes a different number of RNG keys; switching between
    # them must still succeed.
    def branch_f(i):
      for _ in range(i):
        _ = base.next_rng_key()
      return i

    base.reserve_rng_keys(5)
    # Bind `i` as a default argument so each lambda keeps its own index
    # rather than the comprehension's final value.
    branches = [lambda _, i=i: branch_f(i) for i in range(5)]
    self.assertEqual(stateful.switch(3, branches, None), 3)
    self.assertEqual(stateful.switch(0, branches, None), 0)
  finally:
    # Restore the module-level default even if an assertion fails, so the
    # override cannot leak into other tests.
    base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default
def test_stateful_cond_with_rng_use(self):
  """stateful.cond works when its branches draw different numbers of keys.

  The true branch draws one RNG key and the false branch draws two. Both
  predicates are exercised, and neither call may raise.
  """
  tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
  base.DEFAULT_PRNG_RESERVE_SIZE = 64
  try:
    # The branches consume different numbers of RNG keys; taking either one
    # must still succeed.
    def true_branch(x):
      _ = base.next_rng_key()
      return x

    def false_branch(x):
      _ = base.next_rng_key()
      _ = base.next_rng_key()
      return x

    base.reserve_rng_keys(5)
    _ = stateful.cond(True, true_branch, false_branch, 0)
    _ = stateful.cond(False, true_branch, false_branch, 0)
  finally:
    # Restore the module-level default even if cond raises, so the override
    # cannot leak into other tests.
    base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default
ignore_index = lambda f: lambda i, x: f(x) def with_rng_example(): with base.with_rng(jax.random.PRNGKey(42)): pass # Methods in Haiku that mutate internal state. SIDE_EFFECTING_FUNCTIONS = ( ("get_parameter", lambda: base.get_parameter("w", [], init=jnp.zeros)), ("get_state", lambda: base.get_state("w", [], init=jnp.zeros)), ("set_state", lambda: base.set_state("w", 1)), ("next_rng_key", base.next_rng_key), ("next_rng_keys", lambda: base.next_rng_keys(2)), ("reserve_rng_keys", lambda: base.reserve_rng_keys(2)), ("with_rng", with_rng_example), ) # JAX transforms and control flow that need to be aware of Haiku internal # state to operate unsurprisingly. # pylint: disable=g-long-lambda JAX_PURE_EXPECTING_FNS = ( # Just-in-time compilation. ("jit", jax.jit), ("make_jaxpr", jax.make_jaxpr), ("eval_shape", lambda f: (lambda x: jax.eval_shape(f, x))), # Parallelization. # TODO(tomhennigan): Add missing features (e.g. pjit,xmap). ("pmap", lambda f: jax.pmap(f, "i")),