def f(xs):
  """Scans a freshly built CountingModule over ``xs`` and returns the outputs.

  The scan carry is the empty tuple; the step function asserts that it is
  passed through unchanged on every iteration.
  """
  counter = CountingModule()

  def step(carry, x):
    # The carry is never modified, so it must still be the initial ``()``.
    self.assertEqual(carry, ())
    out = counter(x)
    return carry, out

  _final_carry, ys = stateful.scan(step, (), xs)
  return ys
def test_stateful_scan_with_rng_use(self, iteration_count):
  """stateful.scan runs when its body draws more RNG keys than were reserved.

  Five keys are reserved up front, but each scan step requests ten keys via
  ``base.next_rng_key()``. ``base.DEFAULT_PRNG_RESERVE_SIZE`` is temporarily
  raised to 64 for the duration of the test.

  Args:
    iteration_count: value passed as ``length`` to ``stateful.scan``.
  """
  # TODO(lenamartens): remove when default changes to > 1.
  tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
  base.DEFAULT_PRNG_RESERVE_SIZE = 64

  def body_fun(c, x):
    for _ in range(10):
      _ = base.next_rng_key()
    return c, x

  try:
    base.reserve_rng_keys(5)
    _ = stateful.scan(body_fun, (), (), length=iteration_count)
  finally:
    # Restore the module-level default even if the scan fails; otherwise the
    # override would leak into every test that runs afterwards.
    base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default
def test_scan_no_transform(self):
  """Calling stateful.scan directly, without a transform, raises ValueError."""
  inputs = jnp.arange(3)

  def identity_step(carry, x):
    return carry, x

  # NOTE(review): ``msg`` is only the failure message reported when no
  # ValueError is raised; it does not check the exception's text. Consider
  # assertRaisesRegex if the message itself should be verified.
  with self.assertRaises(ValueError, msg="Use jax.scan() instead"):
    stateful.scan(identity_step, (), inputs)
def model(x, *, allow_reuse):
  """Scans an ``Outer(allow_reuse)`` module over ``x`` with an empty carry."""
  outer = Outer(allow_reuse)
  initial_carry = ()
  return stateful.scan(outer, initial_carry, x)
# TODO(tomhennigan): Add missing features (e.g. pjit,xmap). # ("pmap", lambda f: stateful.pmap(f, "i")), # Vectorization. ("vmap", lambda f: stateful.vmap(f, split_rng=False)), # Control flow. # TODO(tomhennigan): Enable for associative_scan. # ("associative_scan", lambda f: # (lambda x: jax.lax.associative_scan(f, x))), ("cond", lambda f: (lambda x: stateful.cond(True, f, f, x))), ("fori_loop", lambda f: (lambda x: stateful.fori_loop(0, 1, base_test.ignore_index(f), x))), # ("map", lambda f: (lambda x: stateful.map(f, x))), ("scan", lambda f: (lambda x: stateful.scan(base_test.identity_carry(f), None, x))), ("switch", lambda f: (lambda x: stateful.switch(0, [f, f], x))), ("while_loop", lambda f: toggle( f, lambda x: stateful.while_loop(lambda xs: xs[0] == 0, lambda xs: (1, f(xs[1])), (0, x)))), # Automatic differentiation. # TODO(tomhennigan): Add missing features (e.g. custom_vjp, custom_jvp). ("grad", lambda f: stateful.grad(lambda x: f(x).sum())), ("value_and_grad", lambda f: stateful.value_and_grad(lambda x: f(x).sum())), ("checkpoint", stateful.remat), ) # pylint: enable=g-long-lambda