def test_vmap_split_rng_without_default(self, require_split_rng):
  # Tests that when split_rng is passed explicitly the value of
  # require_split_rng has no impact.
  #
  # `require_split_rng` is the value written to the shared
  # `stateful.vmap.require_split_rng` attribute while the checks run.
  x = jnp.arange(2)
  stateful.vmap.require_split_rng = require_split_rng
  try:
    # Explicit split_rng=True: the two mapped calls must see different keys.
    k1, k2 = stateful.vmap(lambda x: base.next_rng_key(), split_rng=True)(x)
    self.assertTrue((k1 != k2).all())
    # Explicit split_rng=False: the two mapped calls must see the same key.
    k1, k2 = stateful.vmap(lambda x: base.next_rng_key(), split_rng=False)(x)
    self.assertTrue((k1 == k2).all())
  finally:
    # Reset the shared attribute even when a check above fails, so that a
    # failure in this test cannot change what later tests observe.
    stateful.vmap.require_split_rng = True
def lifted_fn(x):
  # Builds one Bias named "inner" up front, applies it from inside a vmapped
  # closure, then applies a second, separately constructed Bias with the same
  # name to the vmapped result.
  captured_bias = Bias(name="inner")

  def closing_over_fn(x):
    # Uses the module instance created in the enclosing scope.
    return captured_bias(x)

  vmapped_fn = stateful.vmap(closing_over_fn, split_rng=False)
  vmapped_out = vmapped_fn(x)
  trailing_bias = Bias(name="inner")
  return trailing_bias(vmapped_out)
def test_vmap_no_split_rng(self):
  # With split_rng=False all four mapped calls are expected to return the
  # same key, and that key is expected to differ from the keys drawn
  # immediately before and after the vmapped call.
  key_before = base.next_rng_key()
  draw_keys = stateful.vmap(lambda _: base.next_rng_key(), split_rng=False)
  k1, k2, k3, k4 = draw_keys(jnp.arange(4))
  key_after = base.next_rng_key()

  # Neighbouring keys must match: k1 == k2, k2 == k3, k3 == k4.
  mapped_keys = (k1, k2, k3, k4)
  for left, right in zip(mapped_keys, mapped_keys[1:]):
    np.testing.assert_array_equal(left, right)

  # The shared mapped key must not be one of the keys drawn outside the vmap.
  for outside_key in (key_before, key_after):
    self.assertFalse(np.array_equal(outside_key, k1))

  # Drawing a key outside of the vmap must still advance the key sequence.
  self.assertFalse(np.array_equal(key_before, key_after))
def test_vmap_split_rng_with_default(self):
  # Covers calls to stateful.vmap that omit split_rng: first the error raised
  # while require_split_rng is left alone, then the fallback used when it is
  # switched off.
  expected_error = "hk.vmap.require_split_rng = False"
  with self.assertRaisesRegex(TypeError, expected_error):
    # Intentionally missing split_rng arg.
    stateful.vmap(lambda: None)

  with self.subTest("require_split_rng=0"):
    stateful.vmap.require_split_rng = False
    try:
      # This call should not trigger an error, even though we are missing the
      # split_rng argument which appears required (if you look at the function
      # signature). It only works because require_split_rng is
      # propagated to vmap via a sneaky decorator. This only exists to support
      # users who import code that they cannot edit (e.g. from a read only
      # file system) that is not passing the argument.
      f = stateful.vmap(base.next_rng_key, axis_size=2)
    finally:
      # Put the shared attribute back whatever happened above.
      stateful.vmap.require_split_rng = True

  # Check that split_rng=False was implied.
  first_key, second_key = f()
  keys_match = (first_key == second_key).all()
  self.assertTrue(keys_match)
def test_vmap_split_rng(self):
  # With split_rng=True each of the four mapped calls is expected to return
  # its own key, and none of them should match the keys drawn immediately
  # before and after the vmapped call.
  key_before = base.next_rng_key()
  draw_keys = stateful.vmap(lambda _: base.next_rng_key(), split_rng=True)
  k1, k2, k3, k4 = draw_keys(jnp.arange(4))
  key_after = base.next_rng_key()

  # Test that none of the keys are equal.
  labels = ("k1", "k2", "k3", "k4", "key_before", "key_after")
  values = (k1, k2, k3, k4, key_before, key_after)
  named_keys = tuple(zip(labels, values))
  for (a_name, a), (b_name, b) in it.combinations(named_keys, 2):
    keys_match = np.array_equal(a, b)
    self.assertFalse(
        keys_match,
        msg=f"Keys should not be equal, but {a_name} == {b_name}")
def f(x):
  # Maps `g` over `x` with stateful.vmap; split_rng is deliberately not
  # passed here.
  vmapped_g = stateful.vmap(g)
  return vmapped_g(x)
def test_vmap_in_axes_different_size(self):
  # The same [1, 2] array is passed twice, mapped over axis 0 for the first
  # argument and axis 1 for the second, so the mapped sizes (1 and 2)
  # disagree and a ValueError is expected.
  x = jnp.ones([1, 2])
  expected_error = "vmap got inconsistent sizes for array axes to be mapped"
  with self.assertRaisesRegex(ValueError, expected_error):
    mismatched = stateful.vmap(
        lambda a, b: None, in_axes=(0, 1), split_rng=False)
    mismatched(x, x)
def test_vmap_no_in_axes(self):
  # Passing in_axes=None leaves nothing to map over; the expected ValueError
  # mentions the wrapped function by name.
  def fn_name(_):
    pass

  expected_error = "fn_name must have at least one non-None value in in_axes"
  with self.assertRaisesRegex(ValueError, expected_error):
    stateful.vmap(fn_name, in_axes=None, split_rng=False)
def test_vmap_must_be_called_in_transform(self):
  # Building the vmapped function is fine; calling it here is expected to
  # raise a ValueError that points the user at hk.transform.
  vmapped_identity = stateful.vmap(lambda x: x, split_rng=False)
  expected_error = "must be used as part of an.*hk.transform"
  with self.assertRaisesRegex(ValueError, expected_error):
    vmapped_identity(0)
def f(x):
  # Maps `g` over `x` with stateful.vmap, passing split_rng=False.
  vmapped_g = stateful.vmap(g, split_rng=False)
  return vmapped_g(x)
# state to operate unsurprisingly. # pylint: disable=g-long-lambda HK_OVERLOADED_JAX_PURE_EXPECTING_FNS = ( # Just-in-time compilation. ("jit", stateful.jit), # ("make_jaxpr", stateful.make_jaxpr), ("eval_shape", lambda f: (lambda x: [f(x), stateful.eval_shape(f, x)])), ("named_call", stateful.named_call), # Parallelization. # TODO(tomhennigan): Add missing features (e.g. pjit,xmap). # ("pmap", lambda f: stateful.pmap(f, "i")), # Vectorization. ("vmap", lambda f: stateful.vmap(f, split_rng=False)), # Control flow. # TODO(tomhennigan): Enable for associative_scan. # ("associative_scan", lambda f: # (lambda x: jax.lax.associative_scan(f, x))), ("cond", lambda f: (lambda x: stateful.cond(True, f, f, x))), ("fori_loop", lambda f: (lambda x: stateful.fori_loop(0, 1, base_test.ignore_index(f), x))), # ("map", lambda f: (lambda x: stateful.map(f, x))), ("scan", lambda f: (lambda x: stateful.scan(base_test.identity_carry(f), None, x))), ("switch", lambda f: (lambda x: stateful.switch(0, [f, f], x))), ("while_loop", lambda f: toggle( f, lambda x: stateful.while_loop(lambda xs: xs[0] == 0, lambda xs: (1, f(xs[1])),