Example #1
0
  def test_cond_traces_branches_with_same_id_once(self):
    """Using one function for both branches should invoke it only once.

    Counts invocations of `f` during `stateful.cond`, requires exactly one,
    then checks that `jax.lax.cond` produces the same count.
    """
    calls = []

    def f(x):
      # Record every invocation of the branch function.
      calls.append(None)
      return x ** 2

    stateful.cond(False, f, f, 0)
    hk_call_count = len(calls)
    self.assertEqual(hk_call_count, 1)

    # Haiku's behaviour should stay in lockstep with plain JAX.
    calls.clear()
    jax.lax.cond(False, f, f, 0)
    self.assertEqual(hk_call_count, len(calls))
Example #2
0
 def test_cond_two_args(self):
   """A True predicate selects the branch that swaps its two operands."""
   def swap(a, b):
     return b, a

   def keep(a, b):
     return a, b

   first, second = stateful.cond(True, swap, keep, 2, 1)
   self.assertEqual(first, 1)
   self.assertEqual(second, 2)
Example #3
0
 def test_cond_three_args(self):
   """A True predicate selects the branch that reverses three operands."""
   def reverse(a, b, c):
     return c, b, a

   def identity(a, b, c):
     return a, b, c

   a, b, c = stateful.cond(True, reverse, identity, 3, 2, 1)
   for got, want in ((a, 1), (b, 2), (c, 3)):
     self.assertEqual(got, want)
Example #4
0
  def test_stateful_cond_with_rng_use(self):
    """Branches that draw different numbers of RNG keys both run cleanly.

    The true branch draws one key and the false branch draws two. After
    reserving keys up front, the test passes only if both `stateful.cond`
    calls (one per predicate value) complete without raising.
    """
    def true_branch(x):
      _ = base.next_rng_key()
      return x

    def false_branch(x):
      _ = base.next_rng_key()
      _ = base.next_rng_key()
      return x

    # Temporarily override a module-level default (presumably the size of
    # RNG key batches reserved by default -- confirm in `base`).
    tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
    base.DEFAULT_PRNG_RESERVE_SIZE = 64
    try:
      base.reserve_rng_keys(5)
      _ = stateful.cond(True, true_branch, false_branch, 0)
      _ = stateful.cond(False, true_branch, false_branch, 0)
    finally:
      # Always restore the global, even on failure, so the override cannot
      # leak into tests that run afterwards.
      base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default
Example #5
0
 def test_cond_no_transform(self):
   """Calling `stateful.cond` outside a transform raises a ValueError.

   `assertRaises(..., msg=...)` only uses `msg` as the failure message and
   never inspects the exception text. `assertRaisesRegex` is used instead so
   that an unrelated ValueError cannot satisfy the test. The pattern is kept
   loose because the expectation is only that the message points the user at
   JAX's own cond.
   """
   x = jnp.array(3.)
   with self.assertRaisesRegex(ValueError, "jax.*cond",
                               msg="Use jax.cond() instead"):
     stateful.cond(x == 2, x, lambda x: x**2, x, lambda x: (x + 1)**2)
Example #6
0
 def f(x):
   """Square x, or x + 1, via `stateful.cond` using one module instance.

   Appears to use the five-argument form
   (pred, true_operand, true_fun, false_operand, false_fun); both branches
   call the same `SquareModule` instance.
   """
   square = SquareModule()
   is_two = x == 2
   return stateful.cond(is_two, x, square, x, lambda y: square(y + 1))
Example #7
0
 def f(x):
   """Apply one `SquareModule` through `stateful.cond`.

   The call shape depends on the enclosing `single_arg` flag. When it is
   truthy, the (pred, true_fun, false_fun, operand) form is used. Otherwise
   the call appears to use the five-argument
   (pred, true_operand, true_fun, false_operand, false_fun) form.
   """
   square = SquareModule()
   shifted = lambda y: square(y + 1)
   if not single_arg:
     return stateful.cond(x == 2, x, square, x, shifted)
   return stateful.cond(x == 2, square, shifted, x)
Example #8
0
 def test_cond_operand_kwarg_and_operands(self):
   """Supplying operands both positionally and via `operand=` is rejected."""
   def plus_five(x):
     return x + 5

   def plus_four(x):
     return x + 4

   with self.assertRaisesRegex(ValueError, "cannot.*pass.*positionally"):
     stateful.cond(True, plus_five, plus_four, 1, operand=1)
Example #9
0
 def test_cond_operand_kwarg(self):
   """The single operand may be passed with the `operand=` keyword."""
   result = stateful.cond(
       True,
       lambda v: v + 5,
       lambda v: v + 4,
       operand=1,
   )
   self.assertEqual(result, 6)
Example #10
0
 def test_cond_no_args(self):
   """Branch functions that take no operands are supported."""
   true_fn, false_fn = (lambda: 5), (lambda: 4)
   self.assertEqual(stateful.cond(True, true_fn, false_fn), 5)
Example #11
0
    # ("make_jaxpr", stateful.make_jaxpr),
    ("eval_shape", lambda f: (lambda x: [f(x), stateful.eval_shape(f, x)])),
    ("named_call", stateful.named_call),

    # Parallelization.
    # TODO(tomhennigan): Add missing features (e.g. pjit,xmap).
    # ("pmap", lambda f: stateful.pmap(f, "i")),

    # Vectorization.
    ("vmap", lambda f: stateful.vmap(f, split_rng=False)),

    # Control flow.
    # TODO(tomhennigan): Enable for associative_scan.
    # ("associative_scan", lambda f:
    #  (lambda x: jax.lax.associative_scan(f, x))),
    ("cond", lambda f: (lambda x: stateful.cond(True, f, f, x))),
    ("fori_loop", lambda f:
     (lambda x: stateful.fori_loop(0, 1, base_test.ignore_index(f), x))),
    # ("map", lambda f: (lambda x: stateful.map(f, x))),
    ("scan", lambda f:
     (lambda x: stateful.scan(base_test.identity_carry(f), None, x))),
    ("switch", lambda f: (lambda x: stateful.switch(0, [f, f], x))),
    ("while_loop", lambda f: toggle(
        f, lambda x: stateful.while_loop(lambda xs: xs[0] == 0,
                                         lambda xs: (1, f(xs[1])),
                                         (0, x)))),

    # Automatic differentiation.
    # TODO(tomhennigan): Add missing features (e.g. custom_vjp, custom_jvp).

    ("grad", lambda f: stateful.grad(lambda x: f(x).sum())),