def test_difference_rng(self):
  """Drawing an RNG key shows up only in the ``rng`` part of the diff."""
  snapshot_before = stateful.internal_state()
  base.next_rng_key()
  snapshot_after = stateful.internal_state()
  changes = stateful.difference(snapshot_before, snapshot_after)
  # No params or state were touched between the two snapshots.
  self.assertEmpty(changes.params)
  self.assertEmpty(changes.state)
  # ...but the RNG was advanced, so its entry must be present.
  self.assertIsNotNone(changes.rng)
def test_difference_update_state(self):
  """Only the state entry written via ``set_state`` is reported as changed."""
  for state_name in ("a", "b"):
    base.get_state(state_name, [], init=jnp.zeros)
  snapshot_before = stateful.internal_state()
  base.set_state("b", jnp.ones([]))
  snapshot_after = stateful.internal_state()
  changes = stateful.difference(snapshot_before, snapshot_after)
  # "a" was left alone, so it maps to None; "b" went from its zeros init to
  # the ones value written above.
  expected_state = {"~": {"a": None, "b": base.StatePair(0., 1.)}}
  self.assertEmpty(changes.params)
  self.assertEqual(changes.state, expected_state)
  self.assertIsNone(changes.rng)
def test_difference_new(self, get_x):
  """A value created between two snapshots appears in the matching diff part.

  Args:
    get_x: Getter used to create "a" and "b" with the signature
      ``(name, shape, init=...)``. It is either ``base.get_state`` or
      (presumably) a parameter getter; the test branches on which one it is.
  """
  get_x("a", [], init=jnp.zeros)
  snapshot_before = stateful.internal_state()
  created = get_x("b", [], init=jnp.zeros)
  snapshot_after = stateful.internal_state()
  changes = stateful.difference(snapshot_before, snapshot_after)

  is_state_getter = get_x == base.get_state
  if not is_state_getter:
    # Parameter path: "b" is new, "a" pre-existed and is reported as None.
    self.assertEqual(changes.params, {"~": {"a": None, "b": created}})
    self.assertEmpty(changes.state)
  else:
    # State path: a freshly created entry has identical before/after values.
    self.assertEmpty(changes.params)
    self.assertEqual(
        changes.state,
        {"~": {"a": None, "b": base.StatePair(created, created)}})
  self.assertIsNone(changes.rng)
def test_difference_empty(self):
  """Two snapshots with nothing in between differ by no leaves at all."""
  before = stateful.internal_state()
  after = stateful.internal_state()
  # ``jax.tree_leaves`` is a deprecated top-level alias that newer JAX
  # releases removed; ``jax.tree_util.tree_leaves`` is the stable spelling.
  leaves = jax.tree_util.tree_leaves(stateful.difference(before, after))
  self.assertEmpty(leaves)
def stateful_fun(*args, **kwargs) -> Tuple[Any, stateful.InternalState]:
  """Explicitly returns the changed Haiku state after fun has been executed.

  Installs the enclosing ``state`` as the temporary internal state, calls
  ``fun`` under it, and reports what changed relative to ``state``.

  Args:
    *args: Positional arguments forwarded unchanged to ``fun``.
    **kwargs: Keyword arguments forwarded unchanged to ``fun``.

  Returns:
    A ``(out, diff)`` tuple. ``out`` is the return value of ``fun``. ``diff``
    is ``stateful.difference(state, <internal state after fun ran>)``.
  """
  with stateful.temporary_internal_state(state):
    out = fun(*args, **kwargs)
    # Read the internal state before leaving the ``with`` block, so the
    # snapshot is taken under the temporary state that ``fun`` ran in.
    return out, stateful.difference(state, stateful.internal_state())