Example #1
0
  def wrapper(*args, **kwargs):
    """Calls `fun` after reserving the RNG keys it will request.

    `fun` is first evaluated abstractly with `jax.eval_shape` against a
    snapshot of the current frame's params, state and RNG. During that pass,
    `count_hk_rngs_requested` counts how many Haiku RNG keys are requested.
    If that count is non-zero, it is passed to `.reserve(...)` on the current
    frame's top RNG sequence. `fun` is then run for real with the original
    arguments.

    Args:
      *args: Positional arguments forwarded unchanged to `fun`.
      **kwargs: Keyword arguments forwarded unchanged to `fun`. Any keyword
        name is accepted, including `params`, `state` and `rng`.

    Returns:
      Whatever `fun(*args, **kwargs)` returns.
    """
    base.assert_context("optimize_rng_use")

    # Extract all current state. Empty params/state collections are
    # normalised to None before being handed to the abstract evaluation.
    frame = base.current_frame()
    params = frame.params or None
    if params is not None:
      params = data_structures.to_haiku_dict(params)
    state = frame.state or None
    if state is not None:
      state = base.extract_state(state, initial=True)
    rng = frame.rng_stack.peek()
    if rng is not None:
      rng = rng.internal_state

    # The frame snapshot, positional args and keyword args are passed as three
    # separate pytrees instead of one spliced signature. With
    # `pure_fun(params, state, rng, *args, **kwargs)`, a caller keyword named
    # `params`, `state` or `rng` raised "got multiple values for argument".
    def pure_fun(context, fun_args, fun_kwargs):
      inner_params, inner_state, inner_rng = context
      with base.new_context(params=inner_params, state=inner_state,
                            rng=inner_rng):
        return fun(*fun_args, **fun_kwargs)

    with count_hk_rngs_requested() as rng_count_f:
      jax.eval_shape(pure_fun, (params, state, rng), args, kwargs)
    rng_count = rng_count_f()

    if rng_count:
      base.current_frame().rng_stack.peek().reserve(rng_count)
    return fun(*args, **kwargs)
Example #2
0
 def __call__(self, *args, **kwargs):
     """Returns the inner ``(params, state)`` pair for this lifted call.

     When ``hk.running_init()`` is true, ``self._init_fn`` is called with the
     given arguments. Its params and state are handed to ``pack_into_dict``
     together with the enclosing frame's params/state and
     ``self._prefix_name``, and are returned unchanged. Otherwise the entries
     for this prefix are read back with ``unpack_from_dict``. State is passed
     through ``base.extract_state(..., initial=False)``, and both results are
     converted with ``hk.data_structures.to_haiku_dict``.

     Returns:
       An ``(inner_params, inner_state)`` tuple.
     """
     current = base.current_frame()
     outer_params, outer_state = current.params, current.state

     if not hk.running_init():
         # Apply path: read this transform's slice out of the outer dicts.
         prefix = f"{self._prefix_name}/" if self._prefix_name else ""
         params_out = unpack_from_dict(outer_params, prefix)
         state_out = base.extract_state(unpack_from_dict(outer_state, prefix),
                                        initial=False)
         return (hk.data_structures.to_haiku_dict(params_out),
                 hk.data_structures.to_haiku_dict(state_out))

     # Init path: run the inner init, then lift results into the outer dicts.
     params_out, state_out = self._init_fn(*args, **kwargs)
     reuse_check = not self._allow_reuse
     pack_into_dict(params_out, outer_params, self._prefix_name,
                    check_param_reuse=reuse_check)
     pack_into_dict(state_out, outer_state, self._prefix_name, state=True,
                    check_param_reuse=reuse_check)
     return params_out, state_out
Example #3
0
 def collect_state(self) -> State:
     """Returns the held state passed through ``extract_state(initial=False)``."""
     held_state = self.__state
     return extract_state(held_state, initial=False)
Example #4
0
 def collect_initial_state(self) -> State:
     """Returns the held state passed through ``extract_state(initial=True)``."""
     held_state = self.__state
     return extract_state(held_state, initial=True)