def wrapper(*args, **kwargs):
    """Call ``fun`` after reserving the number of RNG keys it will request.

    ``fun`` is first traced abstractly with ``jax.eval_shape``. The trace runs
    inside a fresh Haiku context built from a snapshot of the current frame's
    params, state and RNG. During the trace, ``count_hk_rngs_requested``
    records how many keys ``fun`` asks for. If that count is non-zero, that
    many keys are reserved on the current frame's RNG sequence. ``fun`` is
    then called for real with the original arguments.

    Returns:
      Whatever ``fun(*args, **kwargs)`` returns.
    """
    base.assert_context("optimize_rng_use")

    # Snapshot everything the abstract trace needs from the current frame.
    # Empty/falsy params or state are normalised to None.
    frame = base.current_frame()
    params = frame.params or None
    if params is not None:
        params = data_structures.to_haiku_dict(params)
    state = frame.state or None
    if state is not None:
        # Extracted with initial=True for the trace.
        state = base.extract_state(state, initial=True)
    rng = frame.rng_stack.peek()
    if rng is not None:
        rng = rng.internal_state

    # The caller's args/kwargs are handed over as two containers instead of
    # being splatted. Otherwise a caller keyword named ``params``, ``state``,
    # ``rng`` (clashing with pure_fun) or ``fun`` (clashing with
    # jax.eval_shape) would raise "got multiple values for argument".
    def pure_fun(params, state, rng, fun_args, fun_kwargs):
        with base.new_context(params=params, state=state, rng=rng):
            return fun(*fun_args, **fun_kwargs)

    with count_hk_rngs_requested() as rng_count_f:
        jax.eval_shape(pure_fun, params, state, rng, args, kwargs)
    rng_count = rng_count_f()

    if rng_count:
        # NOTE(review): presumably pre-splits ``rng_count`` keys in one step so
        # the real call below draws from them — confirm against reserve().
        base.current_frame().rng_stack.peek().reserve(rng_count)
    return fun(*args, **kwargs)
def __call__(self, *args, **kwargs):
    """Move an inner transform's params/state into or out of the outer frame.

    When ``hk.running_init()`` is true, ``self._init_fn`` is run. Its params
    and state are then packed into the current frame's params/state under
    ``self._prefix_name``. Otherwise the params/state found in the current
    frame under the ``"<prefix>/"`` key prefix are unpacked instead.

    Returns:
      A ``(params, state)`` pair for the inner transform.
    """
    current = base.current_frame()
    enclosing_params = current.params
    enclosing_state = current.state

    if not hk.running_init():
        # Apply path: recover the sub-collections stored under "<prefix>/".
        key_prefix = f"{self._prefix_name}/" if self._prefix_name else ""
        lifted_params = unpack_from_dict(enclosing_params, key_prefix)
        lifted_state = unpack_from_dict(enclosing_state, key_prefix)
        lifted_state = base.extract_state(lifted_state, initial=False)
        return (hk.data_structures.to_haiku_dict(lifted_params),
                hk.data_structures.to_haiku_dict(lifted_state))

    # Init path: create the inner values, then lift them into the outer frame.
    fresh_params, fresh_state = self._init_fn(*args, **kwargs)
    # NOTE(review): presumably makes pack_into_dict reject duplicate entries
    # unless reuse is allowed — confirm against pack_into_dict.
    reuse_check = not self._allow_reuse
    pack_into_dict(fresh_params, enclosing_params, self._prefix_name,
                   check_param_reuse=reuse_check)
    pack_into_dict(fresh_state, enclosing_state, self._prefix_name,
                   state=True, check_param_reuse=reuse_check)
    return fresh_params, fresh_state
def collect_state(self) -> State:
    """Return this object's private state passed through ``extract_state``.

    Uses ``initial=False``. NOTE(review): by name this presumably yields the
    current value of each state entry rather than its initial value. Confirm
    against ``extract_state``.
    """
    # ``self.__state`` is name-mangled by the enclosing class; keep it verbatim.
    held_state = self.__state
    return extract_state(held_state, initial=False)
def collect_initial_state(self) -> State:
    """Return this object's private state passed through ``extract_state``.

    Uses ``initial=True``. NOTE(review): by name this presumably yields the
    value each state entry was initialised with rather than its current
    value. Confirm against ``extract_state``.
    """
    # ``self.__state`` is name-mangled by the enclosing class; keep it verbatim.
    held_state = self.__state
    return extract_state(held_state, initial=True)