コード例 #1
0
  def test_switch_traces_cases_with_same_id_once(self, n):
    """Checks that repeated branch functions are each invoked only once.

    `[f, g] * n` contains the same two function objects repeated `n` times.
    Each function appends to a witness list whenever it is called, so the
    length of that list is its call count. The counts seen with
    `stateful.switch` must be exactly 1 and must match `jax.lax.switch`.

    Args:
      n: Number of times the `[f, g]` pair is repeated in the branch list.
    """
    f_witness = []
    g_witness = []

    def f(x):
      f_witness.append(None)
      return x ** 2

    def g(x):
      g_witness.append(None)
      return x ** 2

    stateful.switch(0, [f, g] * n, 2)
    f_hk_call_count = len(f_witness)
    g_hk_call_count = len(g_witness)
    self.assertEqual(f_hk_call_count, 1)
    self.assertEqual(g_hk_call_count, 1)

    # Ensure we are in sync with JAX: reset the witnesses in place (the
    # closures above hold references to these same list objects).
    f_witness.clear()
    g_witness.clear()
    jax.lax.switch(0, [f, g] * n, 2)
    f_jax_call_count = len(f_witness)
    g_jax_call_count = len(g_witness)
    self.assertEqual(f_hk_call_count, f_jax_call_count)
    # Compare g against g (previously compared f's count with g's).
    self.assertEqual(g_hk_call_count, g_jax_call_count)
コード例 #2
0
  def test_stateful_switch_with_rng_use(self):
    """Checks `stateful.switch` when branches draw different numbers of keys.

    Branch `i` calls `base.next_rng_key()` `i` times and returns `i`; the
    switch must still return the value of the selected branch.
    """
    tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
    base.DEFAULT_PRNG_RESERVE_SIZE = 64
    try:
      # Each branch consumes a different amount of RNG keys.
      def branch_f(i):
        for _ in range(i):
          base.next_rng_key()
        return i

      base.reserve_rng_keys(5)
      # `i=i` binds the loop value at definition time (avoids late binding).
      branches = [lambda _, i=i: branch_f(i) for i in range(5)]
      self.assertEqual(stateful.switch(3, branches, None), 3)
      self.assertEqual(stateful.switch(0, branches, None), 0)
    finally:
      # Always restore the module-level setting, even if an assertion above
      # fails, so the override cannot leak into other tests.
      base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default
コード例 #3
0
 def test_switch_no_transform(self):
   """Expects `stateful.switch` to raise `ValueError` when called directly.

   Per the test name, this presumably exercises the "outside of a transform"
   path; the enclosing test setup is not visible here.
   """
   index = jnp.array(2)
   operand = jnp.array(42.)
   branches = [jnp.square] * 3
   # NOTE: `msg` is only the message reported if no ValueError is raised; it
   # is not matched against the exception text.
   with self.assertRaises(ValueError, msg="Use jax.switch() instead"):
     stateful.switch(index, branches, operand)
コード例 #4
0
 def f(i, x):
   """Routes `x` through one of three branches sharing one `SquareModule`.

   Branch 0 applies the module directly; branches 1 and 2 apply it to
   `x + 1` and `x + 2` respectively. `i` selects the branch.
   """
   square = SquareModule()
   cases = [
       square,
       lambda v: square(v + 1),
       lambda v: square(v + 2),
   ]
   return stateful.switch(i, cases, x)
コード例 #5
0
ファイル: stateful_test.py プロジェクト: deepmind/dm-haiku
    # ("pmap", lambda f: stateful.pmap(f, "i")),

    # Vectorization.
    ("vmap", lambda f: stateful.vmap(f, split_rng=False)),

    # Control flow.
    # TODO(tomhennigan): Enable for associative_scan.
    # ("associative_scan", lambda f:
    #  (lambda x: jax.lax.associative_scan(f, x))),
    ("cond", lambda f: (lambda x: stateful.cond(True, f, f, x))),
    ("fori_loop", lambda f:
     (lambda x: stateful.fori_loop(0, 1, base_test.ignore_index(f), x))),
    # ("map", lambda f: (lambda x: stateful.map(f, x))),
    ("scan", lambda f:
     (lambda x: stateful.scan(base_test.identity_carry(f), None, x))),
    ("switch", lambda f: (lambda x: stateful.switch(0, [f, f], x))),
    ("while_loop", lambda f: toggle(
        f, lambda x: stateful.while_loop(lambda xs: xs[0] == 0,
                                         lambda xs: (1, f(xs[1])),
                                         (0, x)))),

    # Automatic differentiation.
    # TODO(tomhennigan): Add missing features (e.g. custom_vjp, custom_jvp).

    ("grad", lambda f: stateful.grad(lambda x: f(x).sum())),
    ("value_and_grad", lambda f: stateful.value_and_grad(lambda x: f(x).sum())),
    ("checkpoint", stateful.remat),
)
# pylint: enable=g-long-lambda