Esempio n. 1
0
    def init_by_shape(cls, _rng, input_specs, *args, name=None, **kwargs):
        """Initialize the module parameters from input specifications only.

        The parameters are initialized without running the real computation,
        so initializer functions may depend on the shape of the inputs but not
        on their values.

        If an rng can be drawn from the current stochastic scope when this
        method is called, that rng is used to open a new stochastic scope
        inside the lazily evaluated initialization.

        Example::

          input_shape = (batch_size, image_size, image_size, 3)
          model_output, initial_params = model.init_by_shape(
              jax.random.PRNGKey(0),
              input_specs=[(input_shape, jnp.float32)])

        Args:
          _rng: the random number generator used to initialize parameters.
          input_specs: an iterable of (shape, dtype) pairs specifying the
            inputs.
          *args: extra positional arguments, appended after the inputs and
            forwarded to ``cls.init`` (documented as reaching the module's
            apply function).
          name: name of this module.
          **kwargs: keyword arguments forwarded to ``cls.init`` (documented as
            reaching the module's apply function).

        Returns:
          A pair consisting of the model output and the initialized
          parameters.
        """
        try:
            scope_rng = stochastic.make_rng()
        except ValueError:
            # Either no stochastic scope exists, or the current one has been
            # invalidated by another jax transformation. In both situations
            # the scope must not be lifted into the lazy evaluation.
            scope_rng = None

        def run_init(*inputs):
            def initialize():
                return cls.init(_rng, *(inputs + args), name=name, **kwargs)

            if scope_rng is None:
                return initialize()

            # Open a fresh stochastic scope within the lazy evaluation so that
            # stochastic code can be combined with init_by_shape.
            with stochastic.stochastic(scope_rng):
                return initialize()

        return jax_utils.partial_eval_by_shape(run_init, input_specs)
Esempio n. 2
0
    def init_by_shape(cls, rng, input_specs, *args, name=None, **kwargs):
        """Initialize the module parameters from input specifications only.

        The parameters are initialized without running the real computation,
        so initializer functions may depend on the shape of the inputs but not
        on their values.

        Args:
          rng: the random number generator used to initialize parameters.
          input_specs: an iterable of (shape, dtype) pairs specifying the
            inputs.
          *args: extra positional arguments, appended after the inputs and
            forwarded to ``cls.init`` (documented as reaching the module's
            apply function).
          name: name of this module.
          **kwargs: keyword arguments forwarded to ``cls.init`` (documented as
            reaching the module's apply function).

        Returns:
          A pair consisting of the model output and the initialized
          parameters.
        """
        def run_init(*inputs):
            # The inputs supplied by the shape-based evaluation come first,
            # followed by the caller's own positional arguments.
            positional = inputs + args
            return cls.init(rng, *positional, name=name, **kwargs)

        return jax_utils.partial_eval_by_shape(run_init, input_specs)