def init_by_shape(cls, _rng, input_specs, *args, name=None, **kwargs):
  """Initialize module parameters from input shapes, without real computation.

  Parameter initializers may depend on the shapes of the inputs but never on
  their values. When ``stochastic.make_rng()`` succeeds at call time, the RNG
  it returns is used to open a new ``stochastic.stochastic`` scope inside the
  lazy evaluation. This lets ``init_by_shape`` be combined with a stochastic
  scope.

  Example::

    input_shape = (batch_size, image_size, image_size, 3)
    model_output, initial_params = model.init_by_shape(
        jax.random.PRNGKey(0),
        input_specs=[(input_shape, jnp.float32)])

  Args:
    _rng: the random number generator used to initialize parameters.
    input_specs: an iterable of ``(shape, dtype)`` pairs describing the inputs.
    *args: extra positional arguments forwarded to ``cls.init`` after the
      shape-derived inputs (i.e. to the module's apply function).
    name: name of this module.
    **kwargs: extra keyword arguments forwarded to ``cls.init``.

  Returns:
    A pair consisting of the model output and the initialized parameters.
  """
  try:
    outer_rng = stochastic.make_rng()
  except ValueError:
    # Either no stochastic scope is active, or the active one has been
    # invalidated by another jax transformation. In both cases there is
    # nothing to lift into the lazy evaluation.
    outer_rng = None

  def _init_from_shapes(*shape_inputs):
    full_args = shape_inputs + args
    if outer_rng is None:
      return cls.init(_rng, *full_args, name=name, **kwargs)
    # Re-create the stochastic scope inside the lazy evaluation so that
    # stochastic modules still see a valid scope while being initialized.
    with stochastic.stochastic(outer_rng):
      return cls.init(_rng, *full_args, name=name, **kwargs)

  return jax_utils.partial_eval_by_shape(_init_from_shapes, input_specs)
def init_by_shape(cls, rng, input_specs, *args, name=None, **kwargs):
  """Initialize module parameters from input shapes, without real computation.

  Parameter initializers may depend on the shapes of the inputs but never on
  their values. ``cls.init`` is evaluated lazily through
  ``jax_utils.partial_eval_by_shape`` using only ``input_specs``.

  Args:
    rng: the random number generator used to initialize parameters.
    input_specs: an iterable of ``(shape, dtype)`` pairs describing the inputs.
    *args: extra positional arguments forwarded to ``cls.init`` after the
      shape-derived inputs (i.e. to the module's apply function).
    name: name of this module.
    **kwargs: extra keyword arguments forwarded to ``cls.init``.

  Returns:
    A pair consisting of the model output and the initialized parameters.
  """
  def _init_from_shapes(*shape_inputs):
    full_args = shape_inputs + args
    return cls.init(rng, *full_args, name=name, **kwargs)

  return jax_utils.partial_eval_by_shape(_init_from_shapes, input_specs)