Beispiel #1
0
 def test_vmap_split_rng_without_default(self, require_split_rng):
   """Checks an explicit `split_rng` wins over `require_split_rng`.

   Args:
     require_split_rng: Value assigned to `stateful.vmap.require_split_rng`
       for the duration of the test.
   """
   x = jnp.arange(2)
   stateful.vmap.require_split_rng = require_split_rng
   # `require_split_rng` is process-wide state; restore it even when an
   # assertion fails so that later tests are not affected.
   try:
     # split_rng=True: the two mapped calls must see different keys.
     k1, k2 = stateful.vmap(lambda x: base.next_rng_key(), split_rng=True)(x)
     self.assertTrue((k1 != k2).all())
     # split_rng=False: the two mapped calls must see the same key.
     k1, k2 = stateful.vmap(lambda x: base.next_rng_key(), split_rng=False)(x)
     self.assertTrue((k1 == k2).all())
   finally:
     stateful.vmap.require_split_rng = True
Beispiel #2
0
            def lifted_fn(x):
                """Applies two `Bias` modules named "inner" around a vmap.

                The first `Bias` is built outside the vmapped function and is
                only reached through the closure; the second is built after
                the vmapped call and applied to its result.
                """
                closed_over_bias = Bias(name="inner")

                def closing_over_fn(inputs):
                    # Uses the module captured from the enclosing scope.
                    return closed_over_bias(inputs)

                vmapped_fn = stateful.vmap(closing_over_fn, split_rng=False)
                x = vmapped_fn(x)
                second_bias = Bias(name="inner")
                return second_bias(x)
Beispiel #3
0
 def test_vmap_no_split_rng(self):
   """With split_rng=False all mapped calls return the same RNG key.

   That shared key must differ from the keys drawn directly before and
   directly after the vmapped call, which must also differ from each other.
   """
   key_before = base.next_rng_key()
   vmapped_next_key = stateful.vmap(
       lambda _: base.next_rng_key(), split_rng=False)
   k1, k2, k3, k4 = vmapped_next_key(jnp.arange(4))
   key_after = base.next_rng_key()

   # Checking each adjacent pair covers equality of all four mapped keys.
   mapped_keys = (k1, k2, k3, k4)
   for left, right in zip(mapped_keys, mapped_keys[1:]):
     np.testing.assert_array_equal(left, right)

   for outer_key in (key_before, key_after):
     self.assertFalse(np.array_equal(outer_key, k1))
   self.assertFalse(np.array_equal(key_before, key_after))
Beispiel #4
0
  def test_vmap_split_rng_with_default(self):
    """Exercises `stateful.vmap` when the `split_rng` argument is omitted.

    First checks that omitting `split_rng` raises a `TypeError` whose message
    mentions `hk.vmap.require_split_rng = False`. Then, with
    `stateful.vmap.require_split_rng` set to False, checks that the same
    omission is accepted and that both mapped calls return the same key.
    """
    expected_error = "hk.vmap.require_split_rng = False"
    with self.assertRaisesRegex(TypeError, expected_error):
      # split_rng is deliberately left out here.
      stateful.vmap(lambda: None)

    with self.subTest("require_split_rng=0"):
      stateful.vmap.require_split_rng = False
      try:
        # No error is expected from this call although split_rng, which looks
        # required in the function signature, is not passed. This relies on
        # require_split_rng being propagated to vmap through a decorator, and
        # exists only for users importing code they cannot edit (e.g. from a
        # read only file system) that does not pass the argument.
        vmapped_next_key = stateful.vmap(base.next_rng_key, axis_size=2)
      finally:
        # Always restore the flag for the remaining tests.
        stateful.vmap.require_split_rng = True

    # Equal keys across the mapped axis show split_rng=False was implied.
    first_key, second_key = vmapped_next_key()
    self.assertTrue((first_key == second_key).all())
Beispiel #5
0
 def test_vmap_split_rng(self):
   """With split_rng=True no two keys involved are equal.

   Compares the four keys returned by the mapped calls together with the keys
   drawn directly before and after the vmapped call, pair by pair.
   """
   key_before = base.next_rng_key()
   vmapped_next_key = stateful.vmap(
       lambda _: base.next_rng_key(), split_rng=True)
   k1, k2, k3, k4 = vmapped_next_key(jnp.arange(4))
   key_after = base.next_rng_key()

   # Names are kept alongside the keys so a failure says which pair collided.
   keys_by_name = {
       "k1": k1,
       "k2": k2,
       "k3": k3,
       "k4": k4,
       "key_before": key_before,
       "key_after": key_after,
   }
   for (a_name, a), (b_name, b) in it.combinations(keys_by_name.items(), 2):
     self.assertFalse(
         np.array_equal(a, b),
         msg=f"Keys should not be equal, but {a_name} == {b_name}")
Beispiel #6
0
 def f(x):
   """Maps `g` over `x` with `stateful.vmap`, leaving `split_rng` unset."""
   vmapped_g = stateful.vmap(g)
   return vmapped_g(x)
Beispiel #7
0
 def test_vmap_in_axes_different_size(self):
   """Mapping the same array over axes of different sizes raises ValueError."""
   # Shape [1, 2]: axis 0 has size 1 while axis 1 has size 2.
   x = jnp.ones([1, 2])
   expected_error = "vmap got inconsistent sizes for array axes to be mapped"
   with self.assertRaisesRegex(ValueError, expected_error):
     mismatched_vmap = stateful.vmap(
         lambda a, b: None, in_axes=(0, 1), split_rng=False)
     mismatched_vmap(x, x)
Beispiel #8
0
 def test_vmap_no_in_axes(self):
   """Building a vmap with in_axes=None raises a ValueError."""
   # The function's name is part of the expected error message below.
   def fn_name(_):
     return None

   expected_error = "fn_name must have at least one non-None value in in_axes"
   with self.assertRaisesRegex(ValueError, expected_error):
     stateful.vmap(fn_name, in_axes=None, split_rng=False)
Beispiel #9
0
 def test_vmap_must_be_called_in_transform(self):
   """Calling the vmapped function directly here raises a ValueError.

   Building the vmapped function is done outside the `assertRaisesRegex`
   block; only the call is expected to raise, with a message that refers to
   `hk.transform`.
   """
   vmapped_identity = stateful.vmap(lambda x: x, split_rng=False)
   expected_error = "must be used as part of an.*hk.transform"
   with self.assertRaisesRegex(ValueError, expected_error):
     vmapped_identity(0)
Beispiel #10
0
 def f(x):
   """Maps `g` over `x` with `stateful.vmap`, passing split_rng=False."""
   vmapped_g = stateful.vmap(g, split_rng=False)
   return vmapped_g(x)
Beispiel #11
0
# state to operate unsurprisingly.
# pylint: disable=g-long-lambda
HK_OVERLOADED_JAX_PURE_EXPECTING_FNS = (
    # Just-in-time compilation.
    ("jit", stateful.jit),

    # ("make_jaxpr", stateful.make_jaxpr),
    ("eval_shape", lambda f: (lambda x: [f(x), stateful.eval_shape(f, x)])),
    ("named_call", stateful.named_call),

    # Parallelization.
    # TODO(tomhennigan): Add missing features (e.g. pjit,xmap).
    # ("pmap", lambda f: stateful.pmap(f, "i")),

    # Vectorization.
    ("vmap", lambda f: stateful.vmap(f, split_rng=False)),

    # Control flow.
    # TODO(tomhennigan): Enable for associative_scan.
    # ("associative_scan", lambda f:
    #  (lambda x: jax.lax.associative_scan(f, x))),
    ("cond", lambda f: (lambda x: stateful.cond(True, f, f, x))),
    ("fori_loop", lambda f:
     (lambda x: stateful.fori_loop(0, 1, base_test.ignore_index(f), x))),
    # ("map", lambda f: (lambda x: stateful.map(f, x))),
    ("scan", lambda f:
     (lambda x: stateful.scan(base_test.identity_carry(f), None, x))),
    ("switch", lambda f: (lambda x: stateful.switch(0, [f, f], x))),
    ("while_loop", lambda f: toggle(
        f, lambda x: stateful.while_loop(lambda xs: xs[0] == 0,
                                         lambda xs: (1, f(xs[1])),