def test_difference_rng(self):
  """Calling `base.next_rng_key()` shows up only in the `rng` part of the diff.

  Snapshots the internal state before and after the call. Checks that the
  difference has empty `params` and `state` and a non-None `rng`.
  """
  snapshot = stateful.internal_state()
  base.next_rng_key()
  current = stateful.internal_state()
  delta = stateful.difference(snapshot, current)
  self.assertEmpty(delta.params)
  self.assertEmpty(delta.state)
  self.assertIsNotNone(delta.rng)
def test_difference_update_state(self):
  """Overwriting one state value is reported as an old/new `StatePair`.

  Creates states "a" and "b" (both zeros), snapshots, then sets "b" to ones.
  In the diff, the untouched "a" maps to None, "b" maps to
  `StatePair(0., 1.)`, params are empty and rng is None.
  """
  for state_name in ("a", "b"):
    base.get_state(state_name, [], init=jnp.zeros)
  snapshot = stateful.internal_state()
  base.set_state("b", jnp.ones([]))
  current = stateful.internal_state()
  delta = stateful.difference(snapshot, current)
  self.assertEmpty(delta.params)
  expected_state = {"~": {"a": None, "b": base.StatePair(0., 1.)}}
  self.assertEqual(delta.state, expected_state)
  self.assertIsNone(delta.rng)
def test_difference_new(self, get_x):
  """A value created after the snapshot appears in the diff; older ones map to None.

  Args:
    get_x: the accessor under test. It is compared against `base.get_state`
      to decide whether the new value should land in `state` (as a
      `StatePair(b, b)`) or in `params`.
  """
  get_x("a", [], init=jnp.zeros)
  snapshot = stateful.internal_state()
  created = get_x("b", [], init=jnp.zeros)
  current = stateful.internal_state()
  delta = stateful.difference(snapshot, current)
  is_state_accessor = get_x == base.get_state
  if is_state_accessor:
    self.assertEmpty(delta.params)
    expected_state = {
        "~": {"a": None, "b": base.StatePair(created, created)}}
    self.assertEqual(delta.state, expected_state)
  else:
    expected_params = {"~": {"a": None, "b": created}}
    self.assertEqual(delta.params, expected_params)
    self.assertEmpty(delta.state)
  self.assertIsNone(delta.rng)
def wrapper(*args, **kwargs):
  """Invokes `fun` through `_named_call`, threading Haiku state when transformed.

  Outside a transform, `fun` is simply wrapped and called. Inside one, the
  current internal state is passed explicitly through `statefulify`, and the
  resulting state is written back afterwards.
  """
  if not base.inside_transform():
    return _named_call(fun, name=name)(*args, **kwargs)

  # `fun` may be stateful. Thread the state in and out explicitly so the
  # wrapped callable stays functionally pure.
  current_state = stateful.internal_state()
  pure_fun = statefulify(fun, current_state)
  result, updated_state = _named_call(pure_fun, name=name)(*args, **kwargs)
  stateful.update_internal_state(updated_state)
  return result
def test_difference_empty(self):
  """Diffing two consecutive snapshots with no work between them is empty.

  Uses `jax.tree_util.tree_leaves`. The top-level `jax.tree_leaves` alias is
  deprecated and has been removed in newer JAX releases.
  """
  before = stateful.internal_state()
  after = stateful.internal_state()
  diff = stateful.difference(before, after)
  self.assertEmpty(jax.tree_util.tree_leaves(diff))
def stateful_fun(*args, **kwargs) -> Tuple[Any, stateful.InternalState]:
  """Explicitly returns `fun`'s output together with the resulting Haiku state.

  Runs `fun` under a temporary internal state seeded from the enclosing
  `state`. The internal state is read before that context exits.
  """
  with stateful.temporary_internal_state(state):
    result = fun(*args, **kwargs)
    # Read the state while the temporary context is still active.
    final_state = stateful.internal_state()
  return result, final_state
def test_temporary_state_resets_names(self):
  """A module named "foo" can be created again after a temporary state exits.

  One module is built inside `temporary_internal_state` and one after it
  exits. Both are expected to keep the module name "foo".
  """
  snapshot = stateful.internal_state()
  with stateful.temporary_internal_state(snapshot):
    inner = module.Module(name="foo")
  outer = module.Module(name="foo")
  self.assertEqual(inner.module_name, "foo")
  self.assertEqual(outer.module_name, "foo")