def mapped_fun(*args):
  """Runs the vmapped pure function on `args`, threading Haiku state through.

  This is a closure: `pure_fun`, `in_axes`, `out_axes`, `axis_name`,
  `axis_size` and `split_rng` come from the enclosing scope (not visible in
  this chunk), as do the helpers `internal_state`, `update_internal_state`,
  `get_mapped_axis_size` and `InternalState`.

  When `split_rng` is truthy, a batch of keys is drawn with
  `base.next_rng_keys` and substituted for the context RNG inside the mapped
  call. The RNG captured before the call is written back afterwards.

  Args:
    *args: Positional arguments, forwarded as a single tuple to the vmapped
      pure function (alongside the current internal state).

  Returns:
    `out`, the first element of the pair returned by the vmapped pure
    function. The second element (the updated state) is written back via
    `update_internal_state`.
  """
  base.assert_context("vmap")

  mapped_pure_fun = jax.vmap(pure_fun, in_axes=in_axes, out_axes=out_axes,
                             axis_name=axis_name, axis_size=axis_size)
  state = internal_state()

  if split_rng:
    # Need to take a new key and split.
    # NOTE(review): `in_axes[0]` presumably is the axis spec for `args`
    # (`mapped_pure_fun` is called with `(args, state)`) — confirm against
    # the enclosing function.
    num = get_mapped_axis_size(args, in_axes[0])
    rng = base.next_rng_keys(num)
    state = internal_state()  # Needed since we mutated internal RNG.
    # Remember the RNG as it stands *after* drawing the keys, so it can be
    # restored once the mapped call returns.
    saved_rng = state.rng
    state = InternalState(state.params, state.state, rng)

  out, state = mapped_pure_fun(args, state)

  if split_rng:
    # Swap the batched keys back out for the RNG saved above before the
    # state is written back to the context.
    state = InternalState(state.params, state.state, saved_rng)

  update_internal_state(state)

  return out
def test_rngs(self):
  """Checks that four keys drawn in two batches of two are distinct objects."""
  a, b = base.next_rng_keys(2)
  c, d = base.next_rng_keys(2)
  keys = (a, b, c, d)
  # Compare every ordered pair of different positions, in the same order as
  # itertools.permutations(keys, 2) would produce them.
  for i, left in enumerate(keys):
    for j, right in enumerate(keys):
      if i != j:
        self.assertIsNot(left, right)
class BaseTest(parameterized.TestCase):
  """Tests for `base`: parameters, state, RNG keys, custom creators/getters."""

  @test_utils.transform_and_run
  def test_parameter_reuse(self):
    """Getting the same parameter twice returns the identical object."""
    w1 = base.get_parameter("w", [], init=jnp.zeros)
    w2 = base.get_parameter("w", [], init=jnp.zeros)
    self.assertIs(w1, w2)

  def test_params(self):
    """A parameter created outside any module is collected under "~"."""
    with base.new_context() as ctx:
      w = base.get_parameter("w", [], init=jnp.zeros)
    self.assertEqual(ctx.collect_params(), {"~": {"w": w}})

  @test_utils.transform_and_run
  def test_naked_get_parameter(self):
    """Same check as `test_parameter_reuse`."""
    # NOTE(review): body is identical to test_parameter_reuse; consider
    # removing one or making this exercise a distinct scenario.
    w1 = base.get_parameter("w", [], init=jnp.zeros)
    w2 = base.get_parameter("w", [], init=jnp.zeros)
    self.assertIs(w1, w2)

  def test_naked_parameter_in_tilde_collection(self):
    """Distinct top-level parameters all land in the "~" collection."""
    with base.new_context() as ctx:
      w1 = base.get_parameter("w1", [], init=jnp.zeros)
      w2 = base.get_parameter("w2", [], init=jnp.ones)
      self.assertIsNot(w1, w2)
    self.assertEqual(ctx.collect_params(), {"~": {"w1": w1, "w2": w2}})

  @parameterized.parameters(({},), ({"~": {}},))
  def test_parameter_in_immutable_ctx(self, params):
    """Creating a missing parameter fails when `params` were supplied."""
    with base.new_context(params=params):
      with self.assertRaisesRegex(
          ValueError, "parameters must be created as part of `init`"):
        base.get_parameter("w", [], init=jnp.zeros)

  def test_get_parameter_wrong_shape(self):
    """Re-requesting a parameter with a different shape raises."""
    with base.new_context():
      with self.assertRaisesRegex(ValueError, "does not match shape"):
        base.get_parameter("w", (1,), init=jnp.zeros)
        base.get_parameter("w", (2,), init=jnp.zeros)

  @parameterized.parameters(base.next_rng_key,
                            lambda: base.next_rng_keys(1))
  def test_rng_no_transform(self, f):
    """RNG helpers raise when called outside of a transform."""
    with self.assertRaisesRegex(
        ValueError, "must be used as part of an `hk.transform`"):
      f()

  @test_utils.transform_and_run
  def test_rng(self):
    """Successive `next_rng_key` calls return distinct objects."""
    a = base.next_rng_key()
    b = base.next_rng_key()
    self.assertIsNot(a, b)

  @test_utils.transform_and_run
  def test_rngs(self):
    """Keys from two `next_rng_keys(2)` calls are pairwise distinct objects."""
    a, b = base.next_rng_keys(2)
    c, d = base.next_rng_keys(2)
    for l, r in it.permutations((a, b, c, d), 2):
      self.assertIsNot(l, r)

  @test_utils.transform_and_run(seed=None)
  def test_no_rng(self):
    """`next_rng_key` raises when the transform is run with `seed=None`."""
    with self.assertRaisesRegex(ValueError, "must pass a non-None PRNGKey"):
      base.next_rng_key()

  def test_invalid_rng(self):
    """`new_context` rejects an `rng` that is not a PRNGKey."""
    with self.assertRaisesRegex(ValueError, "not a JAX PRNGKey"):
      base.new_context(rng="nonsense")

  def test_invalid_rng_none_ignored(self):
    """`new_context(rng=None)` is accepted."""
    with base.new_context(rng=None):
      pass

  def test_maybe_rng_no_transform(self):
    """`maybe_next_rng_key` raises when called outside of a transform."""
    with self.assertRaisesRegex(
        ValueError, "must be used as part of an `hk.transform`"):
      base.maybe_next_rng_key()

  @test_utils.transform_and_run(seed=None)
  def test_maybe_no_rng(self):
    """`maybe_next_rng_key` returns None when run with `seed=None`."""
    self.assertIsNone(base.maybe_next_rng_key())

  def test_maybe_rng_vs_not(self):
    """If we have an rng, then next_rng_key() == maybe_next_rng_key()."""
    rngs = []
    maybes = []

    @test_utils.transform_and_run
    def three():
      for _ in range(3):
        rngs.append(base.next_rng_key())

    @test_utils.transform_and_run
    def maybe_three():
      for _ in range(3):
        maybes.append(base.maybe_next_rng_key())

    three()
    maybe_three()
    # Each decorated function body executes twice (3 keys per run -> 6).
    self.assertLen(rngs, 6)
    self.assertLen(maybes, 6)
    np.testing.assert_array_equal(jnp.array(maybes), jnp.array(rngs))

  @parameterized.parameters(
      (base.get_parameter, base.custom_creator, "collect_params"),
      (base.get_state, custom_state_creator, "collect_state"))
  def test_init_custom_creator(self, get_x, custom_x, collect_x):
    """A custom creator sees the request metadata and can replace `init`."""
    def zeros_creator(next_creator, shape, dtype, init, context):
      self.assertEqual(context.full_name, "~/w")
      self.assertEqual(context.module_name, "~")
      self.assertEqual(context.name, "w")
      self.assertEqual(shape, [])
      self.assertEqual(dtype, jnp.float32)
      self.assertEqual(init, jnp.ones)
      return next_creator(shape, dtype, jnp.zeros)

    with base.new_context() as ctx:
      with custom_x(zeros_creator):
        get_x("w", [], init=jnp.ones)

    self.assertEqual(getattr(ctx, collect_x)(), {"~": {"w": jnp.zeros([])}})

  @parameterized.parameters((base.get_parameter, base.custom_creator),
                            (base.get_state, custom_state_creator))
  def test_nested_creators(self, get_x, custom_x):
    """Nested creators are invoked outermost-first."""
    log = []

    def logging_creator(log_msg):
      def _logging_creator(next_creator, shape, dtype, init, context):
        del context
        log.append(log_msg)
        return next_creator(shape, dtype, init)
      return _logging_creator

    with base.new_context():
      with custom_x(logging_creator("a")), \
           custom_x(logging_creator("b")), \
           custom_x(logging_creator("c")):
        get_x("w", [], init=jnp.ones)

    self.assertEqual(log, ["a", "b", "c"])

  @parameterized.parameters((base.get_parameter, base.custom_creator,
                             base.custom_getter, "collect_params"),
                            (base.get_state, custom_state_creator,
                             custom_state_getter, "collect_state"))
  def test_original_dtype(self, get_x, custom_create_x, custom_get_x,
                          collect_x):
    """`context.original_dtype` survives a creator changing the dtype."""

    def dtype_cast_creator(next_creator, shape, dtype, init, context):
      if context.original_dtype == jnp.bfloat16:
        dtype = jnp.float32
      return next_creator(shape, dtype, init)

    def dtype_recast_getter(next_getter, value, context):
      if context.original_dtype == jnp.bfloat16:
        # Use the TestCase assertion so the check survives `python -O`.
        self.assertEqual(value.dtype, jnp.float32)
        value = value.astype(jnp.bfloat16)
      return next_getter(value)

    with base.new_context() as ctx:
      with custom_create_x(dtype_cast_creator), \
           custom_get_x(dtype_recast_getter):
        value = get_x("w", [], jnp.bfloat16, jnp.ones)
        orig_value = jax.tree_util.tree_leaves(getattr(ctx, collect_x)())[0]

        self.assertEqual(value.dtype, jnp.bfloat16)
        self.assertEqual(orig_value.dtype, jnp.float32)

  @parameterized.parameters((base.get_parameter, base.custom_creator),
                            (base.get_state, custom_state_creator))
  def test_original_shape(self, get_x, custom_x):
    """`context.original_shape` lets a later creator undo a reshape."""

    def new_shape_creator(next_creator, shape, dtype, init, context):
      del shape
      del context
      new_shape = (1, 2, 3)
      return next_creator(new_shape, dtype, init)

    def original_shape_restorer(next_creator, shape, dtype, init, context):
      self.assertEqual(shape, (1, 2, 3))
      return next_creator(context.original_shape, dtype, init)

    with base.new_context():
      with custom_x(new_shape_creator):
        with custom_x(original_shape_restorer):
          value = get_x("w", [5], jnp.bfloat16, jnp.ones)
          self.assertEqual(value.shape, (5,))

  @parameterized.parameters(
      (base.get_parameter, base.custom_getter, "collect_params"),
      (base.get_state, custom_state_getter, "collect_state"))
  def test_custom_getter_bf16(self, get_x, custom_x, collect_x):
    """A getter can cast values on read without changing the stored dtype."""
    def bf16_getter(next_getter, value, context):
      del context
      if value.dtype == jnp.float32:
        value = value.astype(jnp.bfloat16)
      return next_getter(value)

    with base.new_context() as ctx:
      with custom_x(bf16_getter):
        f = get_x("f", [], jnp.float32, init=jnp.ones)
        i = get_x("i", [], jnp.int32, init=jnp.ones)

    collection = getattr(ctx, collect_x)()
    self.assertEqual(collection["~"]["f"].dtype, jnp.float32)
    self.assertEqual(f.dtype, jnp.bfloat16)
    self.assertEqual(collection["~"]["i"].dtype, jnp.int32)
    self.assertEqual(i.dtype, jnp.int32)

  @parameterized.parameters((base.get_parameter, base.custom_getter),
                            (base.get_state, custom_state_getter))
  def test_nested_getters(self, get_x, custom_x):
    """Nested getters run outermost-first, each seeing the previous output."""
    log = []

    def logging_getter(log_msg, dtype_in, dtype_out):
      def _logging_getter(next_getter, value, context):
        del context
        log.append(log_msg)
        self.assertEqual(value.dtype, dtype_in)
        value = value.astype(dtype_out)
        return next_getter(value)
      return _logging_getter

    with base.new_context():
      with custom_x(logging_getter("a", jnp.float32, jnp.bfloat16)), \
           custom_x(logging_getter("b", jnp.bfloat16, jnp.int32)), \
           custom_x(logging_getter("c", jnp.int32, jnp.int8)):
        w = get_x("w", [], init=jnp.ones)

    self.assertEqual(w.dtype, jnp.int8)
    self.assertEqual(log, ["a", "b", "c"])

  # `product` covers all four (params, state) flag combinations;
  # `permutations(..., 2)` only produced (True, False) and (False, True).
  @parameterized.parameters(*it.product([True, False], repeat=2))
  def test_creator_types(self, params, state):
    """`custom_creator` only intercepts the collections it was enabled for."""
    log = []

    def logging_creator(next_creator, shape, dtype, init, context):
      log.append(context.full_name)
      return next_creator(shape, dtype, init)

    with base.new_context():
      with base.custom_creator(logging_creator, params=params, state=state):
        base.get_parameter("params", [], init=jnp.zeros)
        base.get_state("state", [], init=jnp.zeros)

    self.assertLen(log, int(params) + int(state))
    if params:
      self.assertIn("~/params", log)
    if state:
      self.assertIn("~/state", log)

  @parameterized.parameters(*it.product([True, False], repeat=2))
  def test_getter_types(self, params, state):
    """`custom_getter` only intercepts the collections it was enabled for."""
    log = []

    def logging_getter(next_getter, value, context):
      log.append(context.full_name)
      return next_getter(value)

    with base.new_context():
      with base.custom_getter(logging_getter, params=params, state=state):
        base.get_parameter("params", [], init=jnp.zeros)
        base.get_state("state", [], init=jnp.zeros)

    self.assertLen(log, int(params) + int(state))
    if params:
      self.assertIn("~/params", log)
    if state:
      self.assertIn("~/state", log)

  @parameterized.parameters(base.custom_creator, custom_state_creator)
  def test_creator_requires_context(self, custom_x):
    """Installing a creator outside of a transform raises."""
    def my_creator(next_creator, shape, dtype, init, context):
      del context
      return next_creator(shape, dtype, init)

    with self.assertRaisesRegex(
        ValueError, "must be used as part of an `hk.transform`"):
      with custom_x(my_creator):
        pass

  @parameterized.parameters(base.custom_getter, custom_state_getter)
  def test_getter_requires_context(self, custom_x):
    """Installing a getter outside of a transform raises."""
    def my_getter(next_getter, value, context):
      del context
      return next_getter(value)

    with self.assertRaisesRegex(
        ValueError, "must be used as part of an `hk.transform`"):
      with custom_x(my_getter):
        pass

  def test_get_state_no_init_raises(self):
    """`get_state` of a missing entry without `init` raises."""
    with base.new_context():
      with self.assertRaisesRegex(ValueError, "set an init function"):
        base.get_state("i")

    with base.new_context(state={"~": {}}):
      with self.assertRaisesRegex(ValueError, "set an init function"):
        base.get_state("i")

  def test_get_state_no_shape_raises(self):
    """`get_state` of a missing entry with `init` but no shape raises."""
    with base.new_context():
      with self.assertRaisesRegex(ValueError, "provide shape and dtype"):
        base.get_state("i", init=jnp.zeros)

    with base.new_context(state={"~": {}}):
      with self.assertRaisesRegex(ValueError, "provide shape and dtype"):
        base.get_state("i", init=jnp.zeros)

  def test_set_then_get(self):
    """`set_state` before `get_state` defines the initial state."""
    with base.new_context() as ctx:
      base.set_state("i", 1)
      base.get_state("i")

    self.assertEqual(ctx.collect_initial_state(), {"~": {"i": 1}})

    for _ in range(10):
      with ctx:
        base.set_state("i", 1)
        y = base.get_state("i")
        self.assertEqual(y, 1)
    self.assertEqual(ctx.collect_initial_state(), {"~": {"i": 1}})

  def test_stateful(self):
    """Initial state keeps the first value; `collect_state` the latest."""
    with base.new_context() as ctx:
      for _ in range(10):
        count = base.get_state("count", (), jnp.int32, jnp.zeros)
        base.set_state("count", count + 1)

    self.assertEqual(ctx.collect_initial_state(), {"~": {"count": 0}})
    self.assertEqual(ctx.collect_state(), {"~": {"count": 10}})

  def test_new_state_in_apply(self):
    """`set_state` of a new name works when params/state were supplied."""
    with base.new_context(params={}, state={}) as ctx:
      base.set_state("count", 1)

    self.assertEqual(ctx.collect_initial_state(), {"~": {"count": 1}})
    self.assertEqual(ctx.collect_state(), {"~": {"count": 1}})

  @parameterized.parameters((42, True), (42, False), (28, True), (28, False))
  def test_prng_sequence(self, seed, wrap_seed):
    """`PRNGSequence` matches manual split-then-use, from a seed or a key."""
    # Values using our sequence.
    key_or_seed = jax.random.PRNGKey(seed) if wrap_seed else seed
    key_seq = base.PRNGSequence(key_or_seed)
    seq_v1 = jax.random.normal(next(key_seq), [])
    seq_v2 = jax.random.normal(next(key_seq), [])
    # Generate values using manual splitting.
    key = jax.random.PRNGKey(seed)
    key, temp_key = jax.random.split(key)
    raw_v1 = jax.random.normal(temp_key, [])
    _, temp_key = jax.random.split(key)
    raw_v2 = jax.random.normal(temp_key, [])
    self.assertEqual(raw_v1, seq_v1)
    self.assertEqual(raw_v2, seq_v2)

  def test_prng_sequence_invalid_input(self):
    """`PRNGSequence` rejects input that is neither a seed nor a key."""
    with self.assertRaisesRegex(ValueError, "not a JAX PRNGKey"):
      base.PRNGSequence("nonsense")

  def test_prng_sequence_wrong_shape(self):
    """`PRNGSequence` rejects a batch of keys."""
    with self.assertRaisesRegex(
        ValueError, "key did not have expected shape and/or dtype"):
      base.PRNGSequence(jax.random.split(jax.random.PRNGKey(42), 2))

  def _assert_keys_equal(self, actual, expected):
    """Asserts two sequences of PRNG keys have equal length and elements.

    Replaces `jax.tree_multimap`, which has been removed from JAX.
    """
    actual = tuple(actual)
    expected = tuple(expected)
    self.assertLen(actual, len(expected))
    for actual_key, expected_key in zip(actual, expected):
      np.testing.assert_array_equal(actual_key, expected_key)

  def test_prng_reserve(self):
    """`reserve(n)` yields the keys `jax.random.split(k, n + 1)[1:]`."""
    k = jax.random.PRNGKey(42)
    s = base.PRNGSequence(k)
    s.reserve(10)
    hk_keys = tuple(next(s) for _ in range(10))
    jax_keys = tuple(jax.random.split(k, num=11)[1:])
    self._assert_keys_equal(hk_keys, jax_keys)

  def test_prng_reserve_twice(self):
    """Two `reserve` calls split from the carried key each time."""
    k = jax.random.PRNGKey(42)
    s = base.PRNGSequence(k)
    s.reserve(2)
    s.reserve(2)
    hk_keys = tuple(next(s) for _ in range(4))
    k, subkey1, subkey2 = tuple(jax.random.split(k, num=3))
    _, subkey3, subkey4 = tuple(jax.random.split(k, num=3))
    jax_keys = (subkey1, subkey2, subkey3, subkey4)
    self._assert_keys_equal(hk_keys, jax_keys)

  def test_prng_sequence_split(self):
    """`take(n)` yields the keys `jax.random.split(k, n + 1)[1:]`."""
    k = jax.random.PRNGKey(42)
    s = base.PRNGSequence(k)
    hk_keys = s.take(10)
    jax_keys = tuple(jax.random.split(k, num=11)[1:])
    self._assert_keys_equal(hk_keys, jax_keys)

  @parameterized.parameters(42, 28)
  def test_with_rng(self, seed):
    """`with_rng` overrides the context rng for keys drawn inside it."""
    ctx_key = jax.random.PRNGKey(seed * 2 + 1)
    key = jax.random.PRNGKey(seed)
    _, next_key = jax.random.split(key)
    expected_output = jax.random.uniform(next_key, ())

    with base.new_context(rng=ctx_key):
      without_decorator_out = jax.random.uniform(base.next_rng_key(), ()).item()

    with base.new_context(rng=ctx_key):
      with base.with_rng(key):
        with_decorator_out = jax.random.uniform(
            base.next_rng_key(), ()).item()

    self.assertNotEqual(without_decorator_out, expected_output)
    self.assertEqual(with_decorator_out, expected_output)

  def test_with_rng_no_transform(self):
    """`with_rng` raises when used outside of a transform."""
    with self.assertRaisesRegex(
        ValueError, "must be used as part of an `hk.transform`"):
      with base.with_rng(jax.random.PRNGKey(428)):
        pass

  def test_new_context(self):
    """A fresh context collects no params or state."""
    with base.new_context() as ctx:
      pass
    self.assertEmpty(ctx.collect_params())
    self.assertEmpty(ctx.collect_initial_state())
    self.assertEmpty(ctx.collect_state())

  def test_context_copies_input(self):
    """The context does not mutate the dicts passed to it."""
    before = {"~": {"w": jnp.array(1.)}}
    with base.new_context(params=before, state=before) as ctx:
      base.get_parameter("w", [], init=jnp.ones)
      base.set_state("w", jnp.array(2.))
    self.assertEqual(ctx.collect_params(), {"~": {"w": jnp.array(1.)}})
    self.assertIsNot(ctx.collect_initial_state(), before)
    self.assertEqual(ctx.collect_initial_state(), before)
    self.assertEqual(ctx.collect_state(), {"~": {"w": jnp.array(2.)}})
    self.assertEqual(before, {"~": {"w": jnp.array(1.)}})

  def test_assert_no_new_parameters(self):
    """`assert_no_new_parameters` only fails when a parameter is created."""
    with base.new_context():
      base.get_parameter("w", [], init=jnp.zeros)
      with base.assert_no_new_parameters():
        # Should not raise, "w" already exists.
        base.get_parameter("w", [], init=jnp.zeros)

      with self.assertRaisesRegex(AssertionError,
                                  "New parameters were created: .*x"):
        with base.assert_no_new_parameters():
          # Should raise, "x" does not exist.
          base.get_parameter("x", [], init=jnp.zeros)

  def test_context_cleanup_after_error(self):
    """The frame stack is emptied even when an error escapes the context."""
    # The error must propagate *through* `new_context` so its exit path runs
    # with an active exception. Catching it inside the context (the previous
    # nesting) only exercised the normal, error-free exit.
    with self.assertRaisesRegex(ValueError, "expected"):
      with base.new_context():
        raise ValueError("expected")
    self.assertEmpty(base.frame_stack)
# Adapts a per-element function `f` into a `(carry, x) -> (carry, f(x))`
# body that passes the carry through unchanged (presumably for use with
# `jax.lax.scan`-style control flow — confirm against callers).
identity_carry = lambda f: lambda carry, x: (carry, f(x))
# Adapts `f` into an `(i, x) -> f(x)` body that discards the index
# (presumably for `jax.lax.fori_loop`-style control flow — confirm).
ignore_index = lambda f: lambda i, x: f(x)


def with_rng_example():
  """Enters and exits `base.with_rng` with a fixed key, doing nothing inside."""
  with base.with_rng(jax.random.PRNGKey(42)):
    pass

# Methods in Haiku that mutate internal state. Each entry is a
# `(name, zero-argument callable)` pair that invokes the function once.
SIDE_EFFECTING_FUNCTIONS = (
    ("get_parameter", lambda: base.get_parameter("w", [], init=jnp.zeros)),
    ("get_state", lambda: base.get_state("w", [], init=jnp.zeros)),
    ("set_state", lambda: base.set_state("w", 1)),
    ("next_rng_key", base.next_rng_key),
    ("next_rng_keys", lambda: base.next_rng_keys(2)),
    ("reserve_rng_keys", lambda: base.reserve_rng_keys(2)),
    ("with_rng", with_rng_example),
)

# JAX transforms and control flow that need to be aware of Haiku internal
# state to operate unsurprisingly. Each entry is a `(name, wrapper)` pair
# where the wrapper takes a function and returns a transformed function.
# pylint: disable=g-long-lambda
JAX_PURE_EXPECTING_FNS = (
    # Just-in-time compilation.
    ("jit", jax.jit),
    ("make_jaxpr", jax.make_jaxpr),
    ("eval_shape", lambda f: (lambda x: jax.eval_shape(f, x))),

    # Parallelization.
    # TODO(tomhennigan): Add missing features (e.g. pjit, xmap).