Ejemplo n.º 1
0
 def f(xs):
   """Runs a fresh CountingModule over `xs` via stateful.scan.

   The scan carry starts as `()` and is passed through unchanged; the step
   asserts this on every call. `self` is resolved from the enclosing scope
   (presumably the surrounding test method -- not visible here).

   Returns:
     The per-step outputs collected by the scan.
   """
   counter = CountingModule()

   def step(carry, x):
     # The carry is never modified, so it must still be the initial `()`.
     self.assertEqual(carry, ())
     return carry, counter(x)

   _final_carry, outputs = stateful.scan(step, (), xs)
   return outputs
Ejemplo n.º 2
0
 def test_stateful_scan_with_rng_use(self, iteration_count):
   """Runs stateful.scan with a body that draws RNG keys on every iteration.

   Before the scan, the test calls `base.reserve_rng_keys(5)`. The scan body
   then calls `base.next_rng_key()` 10 times per iteration, for
   `iteration_count` iterations. The module-level
   `base.DEFAULT_PRNG_RESERVE_SIZE` is temporarily overridden to 64 and is
   always restored afterwards.
   """
   # TODO(lenamartens): remove when default changes to > 1.
   tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
   base.DEFAULT_PRNG_RESERVE_SIZE = 64
   try:
     def body_fun(c, x):
       for _ in range(10):
         _ = base.next_rng_key()
       return c, x
     base.reserve_rng_keys(5)
     _ = stateful.scan(body_fun, (), (), length=iteration_count)
   finally:
     # Restore the global even if the scan raises. Otherwise the override
     # would leak into every test that runs after this one.
     base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default
Ejemplo n.º 3
0
 def test_scan_no_transform(self):
   """stateful.scan called outside a transform must raise ValueError."""
   def passthrough(carry, x):
     return carry, x

   inputs = jnp.arange(3)
   # NOTE: `msg` is only reported if no ValueError is raised; it does not
   # check the text of the raised exception.
   with self.assertRaises(ValueError, msg="Use jax.scan() instead"):
     stateful.scan(passthrough, (), inputs)
Ejemplo n.º 4
0
 def model(x, *, allow_reuse):
     """Scans an `Outer(allow_reuse)` step over `x` with an empty carry."""
     step_fn = Outer(allow_reuse)
     return stateful.scan(step_fn, (), x)
Ejemplo n.º 5
0
    # TODO(tomhennigan): Add missing features (e.g. pjit,xmap).
    # ("pmap", lambda f: stateful.pmap(f, "i")),

    # Vectorization.
    ("vmap", lambda f: stateful.vmap(f, split_rng=False)),

    # Control flow.
    # TODO(tomhennigan): Enable for associative_scan.
    # ("associative_scan", lambda f:
    #  (lambda x: jax.lax.associative_scan(f, x))),
    ("cond", lambda f: (lambda x: stateful.cond(True, f, f, x))),
    ("fori_loop", lambda f:
     (lambda x: stateful.fori_loop(0, 1, base_test.ignore_index(f), x))),
    # ("map", lambda f: (lambda x: stateful.map(f, x))),
    ("scan", lambda f:
     (lambda x: stateful.scan(base_test.identity_carry(f), None, x))),
    ("switch", lambda f: (lambda x: stateful.switch(0, [f, f], x))),
    ("while_loop", lambda f: toggle(
        f, lambda x: stateful.while_loop(lambda xs: xs[0] == 0,
                                         lambda xs: (1, f(xs[1])),
                                         (0, x)))),

    # Automatic differentiation.
    # TODO(tomhennigan): Add missing features (e.g. custom_vjp, custom_jvp).

    ("grad", lambda f: stateful.grad(lambda x: f(x).sum())),
    ("value_and_grad", lambda f: stateful.value_and_grad(lambda x: f(x).sum())),
    ("checkpoint", stateful.remat),
)
# pylint: enable=g-long-lambda