Example #1
0
 def test_stateful_fori_with_rng_use(self, iteration_count):
   """Checks `stateful.fori_loop` with a body that draws many RNG keys.

   Each iteration of `body_fun` draws 10 keys, more than the 5 keys
   reserved before the loop starts.

   Args:
     iteration_count: number of loop iterations (supplied by the test
       parameterization, which is not visible here).
   """
   tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
   # Temporarily raise the module-level reserve size for this test only.
   base.DEFAULT_PRNG_RESERVE_SIZE = 64
   try:
     def body_fun(_, x):
       for _ in range(10):
         _ = base.next_rng_key()
       return x

     base.reserve_rng_keys(5)
     _ = stateful.fori_loop(0, iteration_count, body_fun, 1)
   finally:
     # Restore the global even if the loop raises, so the override cannot
     # leak into tests that run afterwards.
     base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default
Example #2
0
 def test_stateful_scan_with_rng_use(self, iteration_count):
   """Checks `stateful.scan` with a body that draws many RNG keys.

   Each step of `body_fun` draws 10 keys, more than the 5 keys reserved
   before the scan starts.

   Args:
     iteration_count: scan length (supplied by the test parameterization,
       which is not visible here).
   """
   # TODO(lenamartens): remove when default changes to > 1.
   tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
   base.DEFAULT_PRNG_RESERVE_SIZE = 64
   try:
     def body_fun(c, x):
       for _ in range(10):
         _ = base.next_rng_key()
       return c, x

     base.reserve_rng_keys(5)
     _ = stateful.scan(body_fun, (), (), length=iteration_count)
   finally:
     # Restore the global even if the scan raises, so the override cannot
     # leak into tests that run afterwards.
     base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default
Example #3
0
  def test_stateful_switch_with_rng_use(self):
    """Checks `stateful.switch` when branches draw different numbers of keys.

    Branch `i` draws `i` RNG keys and returns `i`. Both a key-drawing branch
    (index 3) and a key-free branch (index 0) must run without raising and
    return their own index.
    """
    tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
    base.DEFAULT_PRNG_RESERVE_SIZE = 64
    try:
      # Using a different number of keys in different branches must NOT
      # result in an error.
      def branch_f(i):
        for _ in range(i):
          _ = base.next_rng_key()
        return i

      base.reserve_rng_keys(5)
      # `i=i` binds each index at definition time; without it every lambda
      # would see the final loop value.
      branches = [lambda _, i=i: branch_f(i) for i in range(5)]
      self.assertEqual(stateful.switch(3, branches, None), 3)
      self.assertEqual(stateful.switch(0, branches, None), 0)
    finally:
      # Restore the global even if an assertion fails, so the override
      # cannot leak into tests that run afterwards.
      base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default
Example #4
0
  def test_stateful_cond_with_rng_use(self):
    """Checks `stateful.cond` when its branches draw different key counts.

    The true branch draws one RNG key and the false branch draws two. Both
    predicates are exercised, and neither call may raise.
    """
    tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
    base.DEFAULT_PRNG_RESERVE_SIZE = 64
    try:
      # Using a different number of keys in different branches must NOT
      # result in an error.
      def true_branch(x):
        _ = base.next_rng_key()
        return x

      def false_branch(x):
        _ = base.next_rng_key()
        _ = base.next_rng_key()
        return x

      base.reserve_rng_keys(5)
      _ = stateful.cond(True, true_branch, false_branch, 0)
      _ = stateful.cond(False, true_branch, false_branch, 0)
    finally:
      # Restore the global even if either cond raises, so the override
      # cannot leak into tests that run afterwards.
      base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default
Example #5
0
def ignore_index(f):
    """Wrap a one-argument function as a two-argument `(i, x) -> f(x)` callable.

    The leading index argument is discarded. This lets a function of a single
    value be used where a `(index, value)` body is expected, e.g. a
    `fori_loop`-style body.
    """
    def wrapper(i, x):
        del i  # Deliberately unused: only the value is forwarded.
        return f(x)
    return wrapper


def with_rng_example():
    """Enter and immediately exit a `base.with_rng` scope keyed with seed 42.

    Used as the zero-argument `with_rng` entry of SIDE_EFFECTING_FUNCTIONS.
    """
    seeded_key = jax.random.PRNGKey(42)
    with base.with_rng(seeded_key):
        pass


# Methods in Haiku that mutate internal state.
# Each entry is a `(name, fn)` pair in which `fn` takes no arguments and
# invokes a single Haiku parameter/state/RNG API. The lambdas defer the call,
# so nothing here runs at import time.
# NOTE(review): presumably consumed by parameterized tests elsewhere in this
# file (not visible here) -- confirm before renaming or reordering entries.
SIDE_EFFECTING_FUNCTIONS = (
    ("get_parameter", lambda: base.get_parameter("w", [], init=jnp.zeros)),
    ("get_state", lambda: base.get_state("w", [], init=jnp.zeros)),
    ("set_state", lambda: base.set_state("w", 1)),
    # Passed uncalled; it is already a zero-argument callable.
    ("next_rng_key", base.next_rng_key),
    ("next_rng_keys", lambda: base.next_rng_keys(2)),
    ("reserve_rng_keys", lambda: base.reserve_rng_keys(2)),
    # Enters and exits a `base.with_rng` scope (defined above in this file).
    ("with_rng", with_rng_example),
)

# JAX transforms and control flow that need to be aware of Haiku internal
# state to operate unsurprisingly.
# pylint: disable=g-long-lambda
JAX_PURE_EXPECTING_FNS = (
    # Just-in-time compilation.
    ("jit", jax.jit),
    ("make_jaxpr", jax.make_jaxpr),
    ("eval_shape", lambda f: (lambda x: jax.eval_shape(f, x))),

    # Parallelization.
    # TODO(tomhennigan): Add missing features (e.g. pjit,xmap).
    ("pmap", lambda f: jax.pmap(f, "i")),