def test_do_not_store(self):
  """Creators, getters and setters returning DO_NOT_STORE keep the context empty.

  The creator and setter return ``base.DO_NOT_STORE``, so nothing should be
  recorded in the context. The getter receives that sentinel and materialises
  a value on the fly from the original init/shape/dtype.
  """

  def my_creator(next_creator, shape, dtype, init, context):
    del next_creator, shape, dtype, init, context
    return base.DO_NOT_STORE

  def my_getter(next_getter, value, context):
    # Use a TestCase assertion rather than a bare ``assert`` so this check is
    # not stripped when the tests run under ``python -O``.
    self.assertIs(value, base.DO_NOT_STORE)
    # Nothing was stored, so rebuild the value from the original request.
    return next_getter(
        context.original_init(context.original_shape, context.original_dtype))

  def my_setter(next_setter, value, context):
    del next_setter, value, context
    return base.DO_NOT_STORE

  with base.new_context() as ctx:
    with base.custom_creator(my_creator, state=True), \
         base.custom_getter(my_getter, state=True), \
         base.custom_setter(my_setter):
      self.assertEqual(base.get_parameter("w", [], init=jnp.ones), 1)
      self.assertEqual(base.get_state("s1", [], init=jnp.ones), 1)
      base.set_state("s2", jnp.ones([]))

  self.assertEmpty(ctx.collect_params())
  self.assertEmpty(ctx.collect_state())
def test_nested_creators(self):
  """Stacked custom creators are invoked in the order they were entered."""
  calls = []

  def make_creator(tag):
    def creator(next_creator, name, shape, dtype, init):
      calls.append(tag)
      return next_creator(name, shape, dtype, init)
    return creator

  with base.new_context():
    with base.custom_creator(make_creator("a")):
      with base.custom_creator(make_creator("b")):
        with base.custom_creator(make_creator("c")):
          base.get_parameter("w", [], init=jnp.ones)

  self.assertEqual(calls, ["a", "b", "c"])
def test_unable_to_mutate_name(self):
  """A creator that forwards a different parameter name triggers ValueError."""

  def rename_creator(next_creator, name, shape, dtype, init):
    # The call is expected to raise, so its result is never used.
    next_creator(name + "_foo", shape, dtype, init)

  pattern = "Modifying .*name.* not supported"
  with base.new_context():
    with base.custom_creator(rename_creator):
      with self.assertRaisesRegex(ValueError, pattern):
        base.get_parameter("w", [], init=jnp.ones)
def test_original_shape(self):
  """context.original_shape exposes the caller's shape after a creator changed it.

  The outer creator replaces the requested shape with (1, 2, 3). The inner
  creator sees that rewritten shape and restores ``context.original_shape``.
  The resulting parameter must therefore have the shape passed to
  ``get_parameter``.
  """

  def new_shape_creator(next_creator, shape, dtype, init, context):
    del shape, context
    new_shape = (1, 2, 3)
    return next_creator(new_shape, dtype, init)

  def original_shape_restorer(next_creator, shape, dtype, init, context):
    # TestCase assertions (not bare ``assert``) so the checks survive -O.
    self.assertEqual(shape, (1, 2, 3))
    return next_creator(context.original_shape, dtype, init)

  with base.new_context():
    with base.custom_creator(new_shape_creator):
      with base.custom_creator(original_shape_restorer):
        param = base.get_parameter("w", [5], jnp.bfloat16, jnp.ones)

  self.assertEqual(param.shape, (5,))
def test_nested_creators(self):
  """Creators stacked around an init_fn call are invoked in entry order."""
  calls = []

  def make_creator(tag):
    def creator(next_creator, name, shape, dtype, init):
      calls.append(tag)
      return next_creator(name, shape, dtype, init)
    return creator

  def create_w():
    return base.get_parameter("w", [], init=jnp.ones)

  init_fn, _ = base.transform(create_w)
  creator_a = make_creator("a")
  creator_b = make_creator("b")
  creator_c = make_creator("c")
  with base.custom_creator(creator_a):
    with base.custom_creator(creator_b):
      with base.custom_creator(creator_c):
        init_fn(None)

  self.assertEqual(calls, ["a", "b", "c"])
def test_creator_requires_context(self):
  """Using custom_creator without an enclosing Haiku context raises ValueError."""

  def passthrough_creator(next_creator, shape, dtype, init, context):
    del context
    return next_creator(shape, dtype, init)

  pattern = "must be used as part of an `hk.transform`"
  with self.assertRaisesRegex(ValueError, pattern):
    with base.custom_creator(passthrough_creator):
      pass
def test_unable_to_mutate_name(self):
  """Renaming a parameter inside a creator makes init_fn raise ValueError."""

  def rename_creator(next_creator, name, shape, dtype, init):
    # The call is expected to raise, so its result is never used.
    next_creator(name + "_foo", shape, dtype, init)

  def create_w():
    return base.get_parameter("w", [], init=jnp.ones)

  init_fn, _ = base.transform(create_w)
  pattern = "Modifying .*name.* not supported"
  with self.assertRaisesRegex(ValueError, pattern):
    with base.custom_creator(rename_creator):
      init_fn(None)
def test_init_custom_creator(self):
  """A creator can substitute the initializer; the stored value reflects it."""

  def zeros_creator(next_creator, name, shape, dtype, init):
    # Verify what get_parameter forwarded to the creator...
    self.assertEqual(name, "~/w")
    self.assertEqual(shape, [])
    self.assertEqual(dtype, jnp.float32)
    self.assertEqual(init, jnp.ones)
    # ...then swap the requested initializer for jnp.zeros.
    return next_creator(name, shape, dtype, jnp.zeros)

  with base.new_context() as ctx:
    with base.custom_creator(zeros_creator):
      base.get_parameter("w", [], init=jnp.ones)

  collected = ctx.collect_params()
  self.assertEqual(collected, {"~": {"w": jnp.zeros([])}})
def test_init_custom_creator(self):
  """A creator wrapping init_fn can substitute the initializer with zeros."""

  def zeros_creator(next_creator, name, shape, dtype, init):
    # Verify what get_parameter forwarded to the creator...
    self.assertEqual(name, "~/w")
    self.assertEqual(shape, [])
    self.assertEqual(dtype, jnp.float32)
    self.assertEqual(init, jnp.ones)
    # ...then swap the requested initializer for jnp.zeros.
    return next_creator(name, shape, dtype, jnp.zeros)

  def create_w():
    return base.get_parameter("w", [], init=jnp.ones)

  init_fn, _ = base.transform(create_w)
  with base.custom_creator(zeros_creator):
    params = init_fn(None)

  expected = {"~": {"w": jnp.zeros([])}}
  self.assertEqual(params, expected)
def test_creator_types(self, params, state):
  """A creator sees params and/or state according to its `params`/`state` flags."""
  seen_names = []

  def logging_creator(next_creator, shape, dtype, init, context):
    seen_names.append(context.full_name)
    return next_creator(shape, dtype, init)

  with base.new_context():
    with base.custom_creator(logging_creator, params=params, state=state):
      base.get_parameter("params", [], init=jnp.zeros)
      base.get_state("state", [], init=jnp.zeros)

  # One logged creation per enabled kind.
  expected_count = sum(map(int, (params, state)))
  self.assertLen(seen_names, expected_count)
  if params:
    self.assertIn("~/params", seen_names)
  if state:
    self.assertIn("~/state", seen_names)
def test_original_dtype(self):
  """context.original_dtype lets a creator/getter pair store and return different dtypes.

  The creator stores bfloat16 requests as float32. The getter casts the
  stored float32 value back to bfloat16. The caller therefore sees bfloat16
  while the context holds float32.
  """

  def dtype_cast_creator(next_creator, shape, dtype, init, context):
    if context.original_dtype == jnp.bfloat16:
      dtype = jnp.float32
    return next_creator(shape, dtype, init)

  def dtype_recast_getter(next_getter, value, context):
    if context.original_dtype == jnp.bfloat16:
      # TestCase assertion (not bare ``assert``) so it survives ``python -O``.
      self.assertEqual(value.dtype, jnp.float32)
      value = value.astype(jnp.bfloat16)
    return next_getter(value)

  with base.new_context() as ctx:
    with base.custom_creator(dtype_cast_creator), \
         base.custom_getter(dtype_recast_getter):
      param = base.get_parameter("w", [], jnp.bfloat16, jnp.ones)
      # `jax.tree_leaves` is a deprecated top-level alias that newer JAX
      # releases no longer provide; `jax.tree_util.tree_leaves` works on
      # both old and new versions.
      orig_param = jax.tree_util.tree_leaves(ctx.collect_params())[0]

  self.assertEqual(param.dtype, jnp.bfloat16)
  self.assertEqual(orig_param.dtype, jnp.float32)
def net():
  # Construct and call the module while counting_creator is installed.
  with base.custom_creator(counting_creator):
    module = MultipleForwardMethods()
    result = module()
  return result
def f():
  # Install three logging creators (entered a, then b, then c) and create "w".
  creator_a = logging_creator("a")
  creator_b = logging_creator("b")
  creator_c = logging_creator("c")
  with base.custom_creator(creator_a):
    with base.custom_creator(creator_b):
      with base.custom_creator(creator_c):
        return base.get_parameter("w", [], init=jnp.ones)
def net():
  # Create four scalar parameters, w0..w3, with counting_creator installed.
  with base.custom_creator(counting_creator):
    for index in range(4):
      base.get_parameter(f"w{index}", [], init=jnp.zeros)
def f():
  # Request "w" with a ones initializer while zeros_creator is installed.
  with base.custom_creator(zeros_creator):
    w = base.get_parameter("w", [], init=jnp.ones)
  return w