Esempio n. 1
0
  def init_fn(rng: Optional[Union[PRNGKey]],
              inputs: Mapping[str, jnp.ndarray],
              batch_axes=(),
              return_initial_output=False,
              **kwargs
  ) -> Tuple[Params, State]:
    """ Initializes your function collecting parameters and state. """
    rng = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR)
    with new_custom_context(rng=rng) as ctx:
      # Create the model
      model = create_fun()

      # Load the batch axes for the inputs
      Layer.batch_axes = batch_axes

      key = hk.next_rng_key()

      # Initialize the model
      outputs = model(inputs, key, **kwargs)

      # Unset the batch axes
      Layer.batch_axes = ()

    nonlocal constants
    params, state, constants = ctx.collect_params(), ctx.collect_initial_state(), ctx.collect_constants()

    if return_initial_output:
      return params, state, outputs

    return params, state
Esempio n. 2
0
 def init(
     self,
     rng: tp.Optional[tp.Union[jnp.ndarray, int]],
     *args,
     **kwargs,
 ) -> tp.Tuple[tp.Any, haiku.Params, haiku.State]:
     """Run ``self.f`` once in a fresh context to collect parameters and state.

     ``rng`` is converted by ``to_prng_sequence`` (using ``INIT_RNG_ERROR`` as
     the error message) and handed to ``new_context``. Positional and keyword
     arguments are forwarded unchanged to ``self.f``.

     Returns:
       A triple ``(output, params, initial_state)``: the value returned by
       ``self.f``, followed by what the context collected via
       ``collect_params()`` and ``collect_initial_state()``.
     """
     rng_sequence = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR)

     with new_context(rng=rng_sequence) as context:
         result = self.f(*args, **kwargs)

     collected_params = context.collect_params()
     initial_state = context.collect_initial_state()
     return result, collected_params, initial_state
Esempio n. 3
0
  def apply_fn(params: Optional[Params],
               state: Optional[State],
               rng: Optional[Union[PRNGKey]],
               inputs,
               **kwargs
  ) -> Tuple[Any, State]:
    """Run a freshly built model with parameters, state and constants injected.

    ``params`` and ``state`` are first passed through ``check_mapping``. The
    RNG sequence comes from ``to_prng_sequence``; its error message is
    ``APPLY_RNG_STATE_ERROR`` when ``state`` is truthy and
    ``APPLY_RNG_ERROR`` otherwise. The enclosing scope's ``constants`` are
    injected into the context alongside params and state.

    Returns:
      ``(output, new_state)``: the model output for ``inputs`` and the state
      collected from the context after the call.
    """
    params = check_mapping("params", params)
    state = check_mapping("state", state)

    if state:
      rng_error_message = APPLY_RNG_STATE_ERROR
    else:
      rng_error_message = APPLY_RNG_ERROR
    rng_sequence = to_prng_sequence(rng, err_msg=rng_error_message)

    with new_custom_context(params=params, state=state, constants=constants, rng=rng_sequence) as context:
      network = create_fun()
      call_key = hk.next_rng_key()
      result = network(inputs, call_key, **kwargs)

    updated_state = context.collect_state()
    return result, updated_state
Esempio n. 4
0
 def apply(
     self,
     params: tp.Optional[haiku.Params],
     state: tp.Optional[haiku.State],
     rng: tp.Optional[tp.Union[jnp.ndarray, int]],
     *args,
     **kwargs,
 ) -> tp.Tuple[tp.Any, haiku.State]:
     """Run ``self.f`` with ``params`` and ``state`` injected into the context.

     Both mappings are first passed through ``check_mapping``. ``rng`` is
     turned into an RNG sequence by ``to_prng_sequence``; the error message
     used there is ``APPLY_RNG_STATE_ERROR`` when ``state`` is truthy and
     ``APPLY_RNG_ERROR`` otherwise.

     Returns:
       ``(output, new_state)``: whatever ``self.f`` returned and the state
       collected from the context via ``collect_state()`` after the call.
     """
     params = check_mapping("params", params)
     state = check_mapping("state", state)

     if state:
         rng_error_message = APPLY_RNG_STATE_ERROR
     else:
         rng_error_message = APPLY_RNG_ERROR
     rng_sequence = to_prng_sequence(rng, err_msg=rng_error_message)

     with new_context(params=params, state=state, rng=rng_sequence) as context:
         result = self.f(*args, **kwargs)

     updated_state = context.collect_state()
     return result, updated_state