예제 #1
0
 def update(self, state: hk.State):
     """Writes every entry of `state` into the current Haiku frame's state.

     Module names are prefixed with `self._name` (joined with "/") whenever
     that name is not `None`. For an entry the frame already holds, the
     stored `initial` value is kept and only the current value is replaced;
     an entry that is new to the frame uses `value` for both fields.

     Args:
       state: Mapping of module name to a mapping of state name to value.
     """
     frame = base.current_frame()
     for module, entries in state.items():
         scoped = module if self._name is None else f"{self._name}/{module}"
         for key, current in entries.items():
             # Used only when the frame has no pair for this key yet.
             fallback = base.StatePair(current, current)
             # Looked up inside the inner loop so that an empty bundle never
             # touches `frame.state` at all.
             slot = frame.state[scoped]
             slot[key] = base.StatePair(slot.get(key, fallback).initial, current)
예제 #2
0
 def test_difference_update_state(self):
   """Checks the diff reported after overwriting an existing state entry.

   Both "a" and "b" exist before the first snapshot; only "b" is set to a
   new value before the second one. The diff must report "a" as `None`,
   "b" as `StatePair(0., 1.)`, no params and no rng.
   """
   base.get_state("a", [], init=jnp.zeros)
   base.get_state("b", [], init=jnp.zeros)
   snapshot_before = stateful.internal_state()
   # The only change between the two snapshots.
   base.set_state("b", jnp.ones([]))
   snapshot_after = stateful.internal_state()

   delta = stateful.difference(snapshot_before, snapshot_after)

   self.assertEmpty(delta.params)
   expected_state = {"~": {"a": None, "b": base.StatePair(0., 1.)}}
   self.assertEqual(delta.state, expected_state)
   self.assertIsNone(delta.rng)
예제 #3
0
 def test_difference_new(self, get_x):
   """Checks the diff reported for an entry created between two snapshots.

   "a" is created through `get_x` before the first snapshot and "b" after
   it. When `get_x` is `base.get_state` the new entry must show up under
   `diff.state` as `StatePair(b, b)` with `diff.params` empty; for any other
   getter it must show up under `diff.params` with `diff.state` empty. In
   both cases "a" maps to `None` and `diff.rng` is `None`.

   Args:
     get_x: Getter called as `get_x(name, shape, init=...)`.
   """
   get_x("a", [], init=jnp.zeros)
   snapshot_before = stateful.internal_state()
   created = get_x("b", [], init=jnp.zeros)
   snapshot_after = stateful.internal_state()
   delta = stateful.difference(snapshot_before, snapshot_after)

   if get_x != base.get_state:
     # Non-state getter: the new value is reported under params.
     self.assertEqual(delta.params, {"~": {"a": None, "b": created}})
     self.assertEmpty(delta.state)
   else:
     # State getter: the new value is both the initial and current value.
     self.assertEmpty(delta.params)
     new_pair = base.StatePair(created, created)
     self.assertEqual(delta.state, {"~": {"a": None, "b": new_pair}})
   self.assertIsNone(delta.rng)
예제 #4
0
def pack_into_dict(src: hk.Params,
                   dst: MutableMapping[str, Any],
                   prefix: str,
                   state: bool = False,
                   check_param_reuse: bool = True):
    """Copies each bundle of `src` into `dst` under a prefixed key.

    `dst` is modified in place. Bundles are written one at a time, so any
    keys stored before a collision is detected remain in `dst`.

    Args:
      src: Mapping of module name to a mapping of entry name to value.
      dst: Destination mapping that receives the prefixed keys.
      prefix: Joined to each key of `src` as "{prefix}/{key}"; when empty
        the key is used unchanged.
      state: If true, each value `v` is stored as `base.StatePair(v, v)`.
      check_param_reuse: If true, a key already present in `dst` is an
        error instead of being overwritten.

    Raises:
      ValueError: If `check_param_reuse` is true and a prefixed key already
        exists in `dst`.
    """
    for module_name, bundle in src.items():
        if prefix:
            qualified_name = f"{prefix}/{module_name}"
        else:
            qualified_name = module_name

        if check_param_reuse and qualified_name in dst:
            raise ValueError(
                f"Key '{qualified_name}' already exists in the destination params. To "
                "prevent accidental parameter re-use during lift, you can't re-use a "
                "parameter already defined in the outer scope.")

        # Shallow copy, so `dst` never shares the inner mapping with `src`.
        entries = dict(bundle)
        if state:
            entries = {k: base.StatePair(v, v) for k, v in entries.items()}
        dst[qualified_name] = entries