Example #1
0
  def init_fn(rng: Optional[Union[PRNGKey]],
              inputs: Mapping[str, jnp.ndarray],
              batch_axes=(),
              return_initial_output=False,
              **kwargs
  ) -> Tuple[Params, State]:
    """ Initializes your function collecting parameters and state. """
    rng = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR)
    with new_custom_context(rng=rng) as ctx:
      # Create the model
      model = create_fun()

      # Load the batch axes for the inputs
      Layer.batch_axes = batch_axes

      key = hk.next_rng_key()

      # Initialize the model
      outputs = model(inputs, key, **kwargs)

      # Unset the batch axes
      Layer.batch_axes = ()

    nonlocal constants
    params, state, constants = ctx.collect_params(), ctx.collect_initial_state(), ctx.collect_constants()

    if return_initial_output:
      return params, state, outputs

    return params, state
Example #2
0
  def apply_fn(params: Optional[Params],
               state: Optional[State],
               rng: Optional[Union[PRNGKey]],
               inputs,
               **kwargs
  ) -> Tuple[Any, State]:
    """Runs the model with externally supplied parameters and state.

    ``params`` and ``state`` are validated with ``check_mapping``. The model
    built by ``create_fun()`` is then called inside a custom context that also
    carries the enclosing-scope ``constants``.

    Returns:
      A pair ``(output, new_state)``, where ``new_state`` is whatever the
      context's ``collect_state()`` reports after the call.
    """
    checked_params = check_mapping("params", params)
    checked_state = check_mapping("state", state)

    # The RNG error message depends on whether any (validated) state was given.
    if checked_state:
      rng_error = APPLY_RNG_STATE_ERROR
    else:
      rng_error = APPLY_RNG_ERROR
    rng_seq = to_prng_sequence(rng, err_msg=rng_error)

    with new_custom_context(params=checked_params, state=checked_state,
                            constants=constants, rng=rng_seq) as ctx:
      network = create_fun()
      result = network(inputs, hk.next_rng_key(), **kwargs)
    return result, ctx.collect_state()