def init_fn(rng: Optional[Union[PRNGKey]], inputs: Mapping[str, jnp.ndarray], batch_axes=(), return_initial_output=False, **kwargs ) -> Tuple[Params, State]: """ Initializes your function collecting parameters and state. """ rng = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR) with new_custom_context(rng=rng) as ctx: # Create the model model = create_fun() # Load the batch axes for the inputs Layer.batch_axes = batch_axes key = hk.next_rng_key() # Initialize the model outputs = model(inputs, key, **kwargs) # Unset the batch axes Layer.batch_axes = () nonlocal constants params, state, constants = ctx.collect_params(), ctx.collect_initial_state(), ctx.collect_constants() if return_initial_output: return params, state, outputs return params, state
def apply_fn(
    params: Optional[Params],
    state: Optional[State],
    rng: Optional[Union[PRNGKey]],
    inputs,
    **kwargs
) -> Tuple[Any, State]:
    """Run the wrapped model with the given parameters and state injected.

    ``params`` and ``state`` are validated with ``check_mapping``. ``rng`` is
    converted with ``to_prng_sequence``. A fresh model is then built with
    ``create_fun()`` inside a custom context holding the params, state, the
    enclosing scope's ``constants`` and the rng sequence, and the model is
    called once on ``inputs``.

    Args:
        params: Parameters to inject into the context.
        state: State to inject into the context.
        rng: Random key converted with ``to_prng_sequence``. The error message
            is ``APPLY_RNG_STATE_ERROR`` when ``state`` is truthy, otherwise
            ``APPLY_RNG_ERROR``.
        inputs: Inputs passed to the model.
        **kwargs: Extra keyword arguments forwarded to the model call.

    Returns:
        ``(output, new_state)``, where ``new_state`` comes from
        ``collect_state()`` on the context.
    """
    params = check_mapping("params", params)
    state = check_mapping("state", state)

    # Pick the rng error message according to whether any state was given.
    if state:
        rng_error = APPLY_RNG_STATE_ERROR
    else:
        rng_error = APPLY_RNG_ERROR
    rng = to_prng_sequence(rng, err_msg=rng_error)

    with new_custom_context(params=params, state=state, constants=constants, rng=rng) as frame:
        net = create_fun()
        # The model is built first, then a key is drawn from the context.
        result = net(inputs, hk.next_rng_key(), **kwargs)
    return result, frame.collect_state()