Beispiel #1
0
 def test_difference_rng(self):
   """Drawing an RNG key is reported only in the ``rng`` field of the diff."""
   snapshot_before = stateful.internal_state()
   base.next_rng_key()
   snapshot_after = stateful.internal_state()
   delta = stateful.difference(snapshot_before, snapshot_after)
   # Neither params nor state were touched, only the RNG was advanced.
   for untouched in (delta.params, delta.state):
     self.assertEmpty(untouched)
   self.assertIsNotNone(delta.rng)
Beispiel #2
0
 def test_difference_update_state(self):
   """Overwriting an existing state entry shows up as a StatePair in the diff."""
   base.get_state("a", [], init=jnp.zeros)
   base.get_state("b", [], init=jnp.zeros)
   snapshot_before = stateful.internal_state()
   base.set_state("b", jnp.ones([]))
   snapshot_after = stateful.internal_state()
   delta = stateful.difference(snapshot_before, snapshot_after)
   # "a" was never modified and is expected as None; "b" went from 0. to 1.
   expected_state = {"~": {"a": None, "b": base.StatePair(0., 1.)}}
   self.assertEmpty(delta.params)
   self.assertEqual(delta.state, expected_state)
   self.assertIsNone(delta.rng)
Beispiel #3
0
 def test_difference_new(self, get_x):
   """A value created after the first snapshot is reported in the diff.

   ``get_x`` is a test parameter (the parameterizing decorator is not
   visible here). When it is ``base.get_state``, the new value is expected
   under ``state``. Otherwise it is expected under ``params``; presumably
   ``get_x`` is then ``base.get_parameter`` (TODO confirm).
   """
   get_x("a", [], init=jnp.zeros)
   snapshot_before = stateful.internal_state()
   created = get_x("b", [], init=jnp.zeros)
   snapshot_after = stateful.internal_state()
   delta = stateful.difference(snapshot_before, snapshot_after)
   if get_x != base.get_state:
     self.assertEqual(delta.params, {"~": {"a": None, "b": created}})
     self.assertEmpty(delta.state)
   else:
     self.assertEmpty(delta.params)
     expected_state = {"~": {"a": None,
                             "b": base.StatePair(created, created)}}
     self.assertEqual(delta.state, expected_state)
   self.assertIsNone(delta.rng)
Beispiel #4
0
 def test_difference_empty(self):
   """Diffing two snapshots with nothing in between yields no pytree leaves."""
   before = stateful.internal_state()
   after = stateful.internal_state()
   diff = stateful.difference(before, after)
   # `jax.tree_leaves` is a deprecated top-level alias that recent JAX
   # releases no longer provide; `jax.tree_util.tree_leaves` is the stable API.
   self.assertEmpty(jax.tree_util.tree_leaves(diff))
Beispiel #5
0
 def stateful_fun(*args, **kwargs) -> Tuple[Any, stateful.InternalState]:
     """Explicitly returns the changed Haiku state after ``fun`` has run.

     ``fun`` and ``state`` are free variables from the enclosing scope, which
     is not visible here. ``fun`` is called with the given arguments while
     ``state`` is installed via ``stateful.temporary_internal_state``.

     Returns:
       A ``(out, diff)`` tuple, where ``out`` is ``fun``'s return value and
       ``diff`` is ``stateful.difference(state, <internal state after fun>)``.
     """
     with stateful.temporary_internal_state(state):
         result = fun(*args, **kwargs)
         # Read the post-call internal state before the temporary context exits.
         changes = stateful.difference(state, stateful.internal_state())
         return result, changes