def test_context_reuse_same_instance(self):
  """Modules keep resolving their params when their context is re-entered."""
  params = {
      "parent_module/~/child_module": {"w": jnp.array(2.)},
      "parent_module/~/child_module_1": {"w": jnp.array(3.)},
      "parent_module_1/~/child_module": {"w": jnp.array(4.)},
      "parent_module_1/~/child_module_1": {"w": jnp.array(5.)},
  }
  # Expected (child1, child2) outputs for the first and second parent.
  expected_pairs = ((2., 3.), (4., 5.))

  def check_children(first_parent, second_parent):
    for parent, (want_a, want_b) in zip((first_parent, second_parent),
                                        expected_pairs):
      self.assertEqual(parent.child1(), want_a)
      self.assertEqual(parent.child2(), want_b)

  with base.new_context(params=params) as ctx:
    mod1 = ParentModule()
    mod2 = ParentModule()
    self.assertEqual(mod1.module_name, "parent_module")
    self.assertEqual(mod2.module_name, "parent_module_1")
    check_children(mod1, mod2)

  # Re-enter the same context and reuse the same module instances.
  with ctx:
    check_children(mod1, mod2)

  # Creating a new context should not be a problem.
  with base.new_context(params=ctx.collect_params()) as ctx:
    mod1 = ParentModule()
    mod2 = ParentModule()
    self.assertEqual(mod1.module_name, "parent_module")
    self.assertEqual(mod2.module_name, "parent_module_1")
    check_children(mod1, mod2)
def test_get_state_no_init_raises(self):
  """get_state with no init function raises, with or without seeded state."""
  for context_kwargs in ({}, {"state": {"~": {}}}):
    with base.new_context(**context_kwargs):
      with self.assertRaisesRegex(ValueError, "set an init function"):
        base.get_state("i")
def test_get_state_no_shape_raises(self):
  """get_state with an init but no shape/dtype raises in both setups."""
  for context_kwargs in ({}, {"state": {"~": {}}}):
    with base.new_context(**context_kwargs):
      with self.assertRaisesRegex(ValueError, "provide shape and dtype"):
        base.get_state("i", init=jnp.zeros)
def test_with_rng(self, seed):
  """with_rng replaces the context's RNG sequence with one seeded by `key`."""
  outer_key = jax.random.PRNGKey(seed * 2 + 1)
  override_key = jax.random.PRNGKey(seed)
  # The test expects the first key handed out under with_rng(override_key)
  # to be the second half of split(override_key).
  expected_output = jax.random.uniform(jax.random.split(override_key)[1], ())

  def draw_sample():
    return jax.random.uniform(base.next_rng_key(), ()).item()

  with base.new_context(rng=outer_key):
    without_decorator_out = draw_sample()

  with base.new_context(rng=outer_key), base.with_rng(override_key):
    with_decorator_out = draw_sample()

  self.assertNotEqual(without_decorator_out, expected_output)
  self.assertEqual(with_decorator_out, expected_output)
def test_do_not_store(self):
  """Creators/setters returning DO_NOT_STORE keep values out of the context."""

  def my_creator(next_creator, shape, dtype, init, context):
    # Short-circuit creation; nothing should be recorded.
    del next_creator, shape, dtype, init, context
    return base.DO_NOT_STORE

  def my_getter(next_getter, value, context):
    assert value is base.DO_NOT_STORE
    # Nothing was stored, so rebuild the value from the original request.
    rebuilt = context.original_init(context.original_shape,
                                    context.original_dtype)
    return next_getter(rebuilt)

  def my_setter(next_setter, value, context):
    del next_setter, value, context
    return base.DO_NOT_STORE

  with base.new_context() as ctx:
    with base.custom_creator(my_creator, state=True):
      with base.custom_getter(my_getter, state=True):
        with base.custom_setter(my_setter):
          param_value = base.get_parameter("w", [], init=jnp.ones)
          self.assertEqual(param_value, 1)
          state_value = base.get_state("s1", [], init=jnp.ones)
          self.assertEqual(state_value, 1)
          base.set_state("s2", jnp.ones([]))

  self.assertEmpty(ctx.collect_params())
  self.assertEmpty(ctx.collect_state())
def test_setter_tree(self):
  """A custom setter can replace an entire pytree state value.

  The setter checks the context it receives (shape/dtype trees and names),
  then returns `y` instead of delegating. `get_state` must hand back `y`.
  """
  witness = []
  x = {"a": jnp.ones([]), "b": jnp.zeros([123])}
  # `jax.tree_util.tree_map` is the stable spelling; the top-level
  # `jax.tree_map` alias is deprecated in recent JAX releases.
  y = jax.tree_util.tree_map(lambda leaf: leaf + 1, x)

  def my_setter(next_setter, value, ctx):
    self.assertIs(value, x)
    self.assertEqual(ctx.original_shape, {"a": (), "b": (123, )})
    self.assertEqual(ctx.original_dtype, {
        "a": jnp.float32,
        "b": jnp.float32
    })
    self.assertEqual(ctx.full_name, "~/x")
    self.assertEqual(ctx.name, "x")
    self.assertIsNone(ctx.module)
    witness.append(None)
    del next_setter
    return y

  with base.new_context():
    with base.custom_setter(my_setter):
      base.set_state("x", x)
      # Use a fresh name so the `x` captured by `my_setter` is never rebound.
      stored = base.get_state("x")

  self.assertIs(stored, y)
  self.assertNotEmpty(witness)
def test_unable_to_mutate_name(self):
  """A legacy (name-passing) creator may not rename the parameter."""

  def mutates_name(next_creator, name, shape, dtype, init):
    renamed = name + "_foo"
    next_creator(renamed, shape, dtype, init)

  error_pattern = "Modifying .*name.* not supported"
  with base.new_context(), base.custom_creator(mutates_name):
    with self.assertRaisesRegex(ValueError, error_pattern):
      base.get_parameter("w", [], init=jnp.ones)
def test_stateful(self):
  """Incrementing a counter updates state while initial state stays at 0."""
  num_steps = 10
  with base.new_context() as ctx:
    for _ in range(num_steps):
      current = base.get_state("count", (), jnp.int32, jnp.zeros)
      base.set_state("count", current + 1)

  self.assertEqual(ctx.collect_initial_state(), {"~": {"count": 0}})
  self.assertEqual(ctx.collect_state(), {"~": {"count": num_steps}})
def test_context_copies_input(self):
  """new_context neither aliases nor mutates the mappings passed to it."""
  before = {"~": {"w": jnp.array(1.)}}
  with base.new_context(params=before, state=before) as ctx:
    base.get_parameter("w", [], init=jnp.ones)
    base.set_state("w", jnp.array(2.))

  self.assertEqual(ctx.collect_params(), {"~": {"w": jnp.array(1.)}})
  initial_state = ctx.collect_initial_state()
  self.assertIsNot(initial_state, before)
  self.assertEqual(initial_state, before)
  self.assertEqual(ctx.collect_state(), {"~": {"w": jnp.array(2.)}})
  # The caller's dict must be untouched by set_state above.
  self.assertEqual(before, {"~": {"w": jnp.array(1.)}})
def init_fn(
    rng: Optional[Union[PRNGKey, PRNGSeed]],
    *args,
    **kwargs,
) -> Tuple[Params, State]:
  """Initializes your function collecting parameters and state.

  Runs `f` once inside a fresh context seeded from `rng` and returns the
  parameters and the initial state it created.
  """
  rng_sequence = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR)
  with base.new_context(rng=rng_sequence) as ctx:
    f(*args, **kwargs)
  collected_params = ctx.collect_params()
  collected_state = ctx.collect_initial_state()
  return collected_params, collected_state
def init(
    self,
    rng: tp.Optional[tp.Union[jnp.ndarray, int]],
    *args,
    **kwargs,
) -> tp.Tuple[tp.Any, haiku.Params, haiku.State]:
  """Initializes your function collecting parameters and state.

  Returns the output of `self.f` together with the collected parameters and
  the initial state.
  """
  rng_sequence = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR)
  with new_context(rng=rng_sequence) as ctx:
    output = self.f(*args, **kwargs)
  collected_params = ctx.collect_params()
  collected_state = ctx.collect_initial_state()
  return output, collected_params, collected_state
def test_assert_no_new_parameters(self):
  """assert_no_new_parameters allows reuse but rejects newly created params."""

  def fetch(param_name):
    return base.get_parameter(param_name, [], init=jnp.zeros)

  with base.new_context():
    fetch("w")

    # Reusing the existing "w" is fine.
    with base.assert_no_new_parameters():
      fetch("w")

    # "x" has not been created yet, so creating it here must fail.
    expected_message = "New parameters were created: .*x"
    with self.assertRaisesRegex(AssertionError, expected_message):
      with base.assert_no_new_parameters():
        fetch("x")
def apply_fn(
    params: Params,
    state: State,
    rng: Optional[Union[PRNGKey, PRNGSeed]],
    *args,
    **kwargs,
) -> Tuple[Any, State]:
  """Applies your function injecting parameters and state.

  Returns the output of `f` and the state collected after it ran.
  """
  # The RNG error message depends on whether any state was supplied.
  rng_error = APPLY_RNG_STATE_ERROR if state else APPLY_RNG_ERROR
  rng_sequence = to_prng_sequence(rng, err_msg=rng_error)
  with base.new_context(params=params, state=state, rng=rng_sequence) as ctx:
    output = f(*args, **kwargs)
  return output, ctx.collect_state()
def test_init_custom_creator(self, get_x, custom_x, collect_x):
  """A custom creator may swap the initializer used for a new value."""

  def zeros_creator(next_creator, shape, dtype, init, context):
    self.assertEqual(context.full_name, "~/w")
    self.assertEqual(shape, [])
    self.assertEqual(dtype, jnp.float32)
    self.assertEqual(init, jnp.ones)
    # Ignore the requested initializer and always create zeros.
    return next_creator(shape, dtype, jnp.zeros)

  with base.new_context() as ctx:
    with custom_x(zeros_creator):
      get_x("w", [], init=jnp.ones)

  collected = getattr(ctx, collect_x)()
  self.assertEqual(collected, {"~": {"w": jnp.zeros([])}})
def test_set_then_get(self):
  """Setting state before reading it records the set value as initial state."""
  expected_initial = {"~": {"i": 1}}

  with base.new_context() as ctx:
    base.set_state("i", 1)
    base.get_state("i")

  self.assertEqual(ctx.collect_initial_state(), expected_initial)

  # Re-entering the same context repeatedly must not disturb initial state.
  for _ in range(10):
    with ctx:
      base.set_state("i", 1)
      fetched = base.get_state("i")
      self.assertEqual(fetched, 1)

  self.assertEqual(ctx.collect_initial_state(), expected_initial)
def test_init_custom_creator(self):
  """A legacy (name-passing) creator may swap the initializer."""

  def zeros_creator(next_creator, param_name, shape, dtype, init):
    self.assertEqual(param_name, "~/w")
    self.assertEqual(shape, [])
    self.assertEqual(dtype, jnp.float32)
    self.assertEqual(init, jnp.ones)
    # Always create zeros regardless of the requested initializer.
    return next_creator(param_name, shape, dtype, jnp.zeros)

  with base.new_context() as ctx:
    with base.custom_creator(zeros_creator):
      base.get_parameter("w", [], init=jnp.ones)

  collected = ctx.collect_params()
  self.assertEqual(collected, {"~": {"w": jnp.zeros([])}})
def init_fn(
    rng: Optional[Union[PRNGKey, int]],
    *args,
    **kwargs,
) -> Tuple[hk.Params, hk.State]:
  """Initializes your function collecting parameters and state.

  Raises:
    jax.errors.UnexpectedTracerError: re-raised with `unexpected_tracer_hint`
      as its message when `f` raises one; the original error is chained as
      the cause.
  """
  rng_sequence = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR)
  with base.new_context(rng=rng_sequence) as ctx:
    try:
      f(*args, **kwargs)
    except jax.errors.UnexpectedTracerError as tracer_error:
      raise jax.errors.UnexpectedTracerError(
          unexpected_tracer_hint) from tracer_error
  collected_params = ctx.collect_params()
  collected_state = ctx.collect_initial_state()
  return collected_params, collected_state
def apply_fn(
    params: Optional[hk.Params],
    state: Optional[hk.State],
    rng: Optional[Union[PRNGKey, int]],
    *args,
    **kwargs,
) -> Tuple[Any, hk.State]:
  """Applies your function injecting parameters and state.

  Returns the output of `f` and the state collected after it ran.
  """
  # `check_mapping` (defined elsewhere) processes each incoming collection
  # before it is handed to the context.
  params = check_mapping("params", params)
  state = check_mapping("state", state)
  # The RNG error message depends on whether any state was supplied.
  rng_error = APPLY_RNG_STATE_ERROR if state else APPLY_RNG_ERROR
  rng_sequence = to_prng_sequence(rng, err_msg=rng_error)
  with base.new_context(params=params, state=state, rng=rng_sequence) as ctx:
    output = f(*args, **kwargs)
  return output, ctx.collect_state()
def test_original_shape(self, get_x, custom_x):
  """An inner creator can restore the caller's original shape request.

  The outer creator rewrites the shape to (1, 2, 3). The inner creator
  checks that it received that shape and then forwards
  `context.original_shape`, so the created value has the caller's shape.
  """

  def new_shape_creator(next_creator, shape, dtype, init, context):
    del shape, context
    new_shape = (1, 2, 3)
    return next_creator(new_shape, dtype, init)

  def original_shape_restorer(next_creator, shape, dtype, init, context):
    # TestCase assertions still run under `python -O`; bare `assert` does not.
    self.assertEqual(shape, (1, 2, 3))
    return next_creator(context.original_shape, dtype, init)

  with base.new_context():
    with custom_x(new_shape_creator):
      with custom_x(original_shape_restorer):
        value = get_x("w", [5], jnp.bfloat16, jnp.ones)

  self.assertEqual(value.shape, (5,))
def test_getter_types(self, params, state):
  """A custom getter fires only for the collections it is registered on."""
  seen_names = []

  def logging_getter(next_getter, value, context):
    seen_names.append(context.full_name)
    return next_getter(value)

  with base.new_context():
    with base.custom_getter(logging_getter, params=params, state=state):
      base.get_parameter("params", [], init=jnp.zeros)
      base.get_state("state", [], init=jnp.zeros)

  self.assertLen(seen_names, int(params) + int(state))
  for enabled, full_name in ((params, "~/params"), (state, "~/state")):
    if enabled:
      self.assertIn(full_name, seen_names)
def test_custom_getter_bf16(self, get_x, custom_x, collect_x):
  """A getter can cast float32 values to bfloat16 without changing storage."""

  def bf16_getter(next_getter, value, context):
    del context
    # Only float32 values are downcast; other dtypes pass through unchanged.
    is_f32 = value.dtype == jnp.float32
    cast_value = value.astype(jnp.bfloat16) if is_f32 else value
    return next_getter(cast_value)

  with base.new_context() as ctx:
    with custom_x(bf16_getter):
      float_value = get_x("f", [], jnp.float32, init=jnp.ones)
      int_value = get_x("i", [], jnp.int32, init=jnp.ones)

  stored = getattr(ctx, collect_x)()["~"]
  self.assertEqual(stored["f"].dtype, jnp.float32)
  self.assertEqual(float_value.dtype, jnp.bfloat16)
  self.assertEqual(stored["i"].dtype, jnp.int32)
  self.assertEqual(int_value.dtype, jnp.int32)
def test_nested_creators(self):
  """Legacy creators run outermost-first, in the order they were entered."""
  call_order = []

  def logging_creator(tag):
    def record_and_forward(next_creator, name, shape, dtype, init):
      call_order.append(tag)
      return next_creator(name, shape, dtype, init)
    return record_and_forward

  with base.new_context():
    with base.custom_creator(logging_creator("a")):
      with base.custom_creator(logging_creator("b")):
        with base.custom_creator(logging_creator("c")):
          base.get_parameter("w", [], init=jnp.ones)

  self.assertEqual(call_order, ["a", "b", "c"])
def test_original_shape(self):
  """An inner creator can restore the caller's original parameter shape.

  The outer creator rewrites the shape to (1, 2, 3). The inner creator
  checks that it received that shape and then forwards
  `context.original_shape`, so the parameter keeps the requested shape.
  """

  def new_shape_creator(next_creator, shape, dtype, init, context):
    del shape, context
    new_shape = (1, 2, 3)
    return next_creator(new_shape, dtype, init)

  def original_shape_restorer(next_creator, shape, dtype, init, context):
    # TestCase assertions still run under `python -O`; bare `assert` does not.
    self.assertEqual(shape, (1, 2, 3))
    return next_creator(context.original_shape, dtype, init)

  with base.new_context():
    with base.custom_creator(new_shape_creator):
      with base.custom_creator(original_shape_restorer):
        param = base.get_parameter("w", [5], jnp.bfloat16, jnp.ones)

  self.assertEqual(param.shape, (5,))
def apply(
    self,
    params: tp.Optional[haiku.Params],
    state: tp.Optional[haiku.State],
    rng: tp.Optional[tp.Union[jnp.ndarray, int]],
    *args,
    **kwargs,
) -> tp.Tuple[tp.Any, haiku.State]:
  """Applies your function injecting parameters and state.

  Returns the output of `self.f` and the state collected after it ran.
  """
  # `check_mapping` (defined elsewhere) processes each incoming collection.
  params = check_mapping("params", params)
  state = check_mapping("state", state)
  # The RNG error message depends on whether any state was supplied.
  rng_error = APPLY_RNG_STATE_ERROR if state else APPLY_RNG_ERROR
  rng_sequence = to_prng_sequence(rng, err_msg=rng_error)
  with new_context(params=params, state=state, rng=rng_sequence) as ctx:
    output = self.f(*args, **kwargs)
  return output, ctx.collect_state()
def test_nested_creators(self, get_x, custom_x):
  """Creators run outermost-first, in the order they were entered."""
  call_order = []

  def logging_creator(tag):
    def record_and_forward(next_creator, shape, dtype, init, context):
      del context
      call_order.append(tag)
      return next_creator(shape, dtype, init)
    return record_and_forward

  with base.new_context():
    with custom_x(logging_creator("a")):
      with custom_x(logging_creator("b")):
        with custom_x(logging_creator("c")):
          get_x("w", [], init=jnp.ones)

  self.assertEqual(call_order, ["a", "b", "c"])
def test_custom_getter_bf16(self):
  """A getter can downcast float32 params to bfloat16; storage is unchanged."""

  def bf16_getter(next_getter, value, context):
    del context
    # Only float32 values are downcast; other dtypes pass through unchanged.
    is_f32 = value.dtype == jnp.float32
    cast_value = value.astype(jnp.bfloat16) if is_f32 else value
    return next_getter(cast_value)

  with base.new_context() as ctx:
    with base.custom_getter(bf16_getter):
      float_param = base.get_parameter("f", [], jnp.float32, init=jnp.ones)
      int_param = base.get_parameter("i", [], jnp.int32, init=jnp.ones)

  stored = ctx.collect_params()["~"]
  self.assertEqual(stored["f"].dtype, jnp.float32)
  self.assertEqual(float_param.dtype, jnp.bfloat16)
  self.assertEqual(stored["i"].dtype, jnp.int32)
  self.assertEqual(int_param.dtype, jnp.int32)
def test_dataclass(self, name):
  """DataMLP names its layers after its default name or the given one."""
  output_sizes = [300, 100, 10]
  with base.new_context() as ctx:
    mlp_kwargs = {} if name is None else {"name": "mlp"}
    mlp = DataMLP(output_sizes, **mlp_kwargs)
    mlp(jnp.ones([1, 28 * 28]))
    params = ctx.collect_params()

  module_names = (
      ["data_mlp/linear", "data_mlp/linear_1", "data_mlp/linear_2"]
      if name is None
      else ["mlp/linear", "mlp/linear_1", "mlp/linear_2"])

  self.assertEqual(list(params.keys()), module_names)
  for module_name, output_size in zip(module_names, output_sizes):
    layer = params[module_name]
    self.assertEqual(layer["w"].shape[-1], output_size)
    self.assertEqual(layer["b"].shape[-1], output_size)
def test_nested_getters(self, get_x, custom_x):
  """Getters chain outermost-first, each seeing the previous one's output."""
  call_order = []

  def logging_getter(tag, dtype_in, dtype_out):
    def check_and_cast(next_getter, value, context):
      del context
      call_order.append(tag)
      self.assertEqual(value.dtype, dtype_in)
      return next_getter(value.astype(dtype_out))
    return check_and_cast

  with base.new_context():
    with custom_x(logging_getter("a", jnp.float32, jnp.bfloat16)):
      with custom_x(logging_getter("b", jnp.bfloat16, jnp.int32)):
        with custom_x(logging_getter("c", jnp.int32, jnp.int8)):
          result = get_x("w", [], init=jnp.ones)

  self.assertEqual(result.dtype, jnp.int8)
  self.assertEqual(call_order, ["a", "b", "c"])
def test_original_dtype(self):
  """Creators and getters can see the dtype the caller originally requested.

  The creator stores bfloat16 requests as float32, and the getter casts them
  back. The caller therefore sees bfloat16 while the context holds float32.
  """

  def dtype_cast_creator(next_creator, shape, dtype, init, context):
    if context.original_dtype == jnp.bfloat16:
      dtype = jnp.float32
    return next_creator(shape, dtype, init)

  def dtype_recast_getter(next_getter, value, context):
    if context.original_dtype == jnp.bfloat16:
      # TestCase assertions still run under `python -O`; bare `assert` does not.
      self.assertEqual(value.dtype, jnp.float32)
      value = value.astype(jnp.bfloat16)
    return next_getter(value)

  with base.new_context() as ctx:
    with base.custom_creator(dtype_cast_creator), \
        base.custom_getter(dtype_recast_getter):
      param = base.get_parameter("w", [], jnp.bfloat16, jnp.ones)

  # `jax.tree_util.tree_leaves` is the stable spelling; the top-level
  # `jax.tree_leaves` alias is deprecated in recent JAX releases.
  orig_param = jax.tree_util.tree_leaves(ctx.collect_params())[0]
  self.assertEqual(param.dtype, jnp.bfloat16)
  self.assertEqual(orig_param.dtype, jnp.float32)
def test_original_dtype(self, get_x, custom_create_x, custom_get_x, collect_x):
  """Creators and getters can see the dtype the caller originally requested.

  The creator stores bfloat16 requests as float32, and the getter casts them
  back. The caller therefore sees bfloat16 while the collection holds float32.
  """

  def dtype_cast_creator(next_creator, shape, dtype, init, context):
    if context.original_dtype == jnp.bfloat16:
      dtype = jnp.float32
    return next_creator(shape, dtype, init)

  def dtype_recast_getter(next_getter, value, context):
    if context.original_dtype == jnp.bfloat16:
      # TestCase assertions still run under `python -O`; bare `assert` does not.
      self.assertEqual(value.dtype, jnp.float32)
      value = value.astype(jnp.bfloat16)
    return next_getter(value)

  with base.new_context() as ctx:
    with custom_create_x(dtype_cast_creator), \
        custom_get_x(dtype_recast_getter):
      value = get_x("w", [], jnp.bfloat16, jnp.ones)

  # `jax.tree_util.tree_leaves` is the stable spelling; the top-level
  # `jax.tree_leaves` alias is deprecated in recent JAX releases.
  orig_value = jax.tree_util.tree_leaves(getattr(ctx, collect_x)())[0]
  self.assertEqual(value.dtype, jnp.bfloat16)
  self.assertEqual(orig_value.dtype, jnp.float32)