def init_with_output(self,
                     rngs: Union[PRNGKey, RNGSequences],
                     *args,
                     method: Optional[Callable[..., Any]] = None,
                     **kwargs) -> Tuple[Any, FrozenVariableDict]:
    """Initializes a module method with variables and returns output and modified variables.

    Delegates to ``self.apply`` with an empty variable dict and
    ``mutable=True``.

    Args:
      rngs: The rngs for the variable collections. Either a dict, which is
        passed through unchanged, or a single value whose ``shape`` is
        ``(2,)``, which is wrapped as ``{'params': rngs}``.
      *args: Positional arguments forwarded to ``self.apply``.
      method: An optional method. If provided, applies this method. If not
        provided, applies the ``__call__`` method.
      **kwargs: Keyword arguments forwarded to ``self.apply``.

    Returns:
      ``(output, vars)``, where ``vars`` is a dict of the modified
      collections.

    Raises:
      errors.InvalidRngError: If ``rngs`` is not a dict and does not have a
        ``shape`` equal to ``(2,)``. This includes values with no ``shape``
        attribute at all, such as ``None`` or a plain list.
    """
    if not isinstance(rngs, dict):
        # Read ``shape`` defensively so that a value without the attribute
        # is reported as an invalid rng rather than leaking a bare
        # AttributeError to the caller.
        if getattr(rngs, 'shape', None) != (2,):
            raise errors.InvalidRngError(
                'RNGs should be of shape (2,) in Module '
                f'{self.__class__.__name__}, but rngs are: {rngs}')
        rngs = {'params': rngs}
    return self.apply(
        {}, *args, rngs=rngs, method=method, mutable=True, **kwargs)
def init_with_output(self,
                     rngs: Union[PRNGKey, RNGSequences],
                     *args,
                     method: Optional[Callable[..., Any]] = None,
                     mutable: CollectionFilter = DenyList("intermediates"),
                     **kwargs) -> Tuple[Any, FrozenVariableDict]:
    """Initializes a module method with variables and returns output and modified variables.

    Delegates to ``self.apply`` with an empty variable dict and the given
    ``mutable`` filter.

    Args:
      rngs: The rngs for the variable collections. Either a dict, which is
        passed through unchanged, or a single value whose ``shape`` is
        ``(2,)``, which is wrapped as ``{'params': rngs}``.
      *args: Positional arguments forwarded to ``self.apply``.
      method: An optional method. If provided, applies this method. If not
        provided, applies the ``__call__`` method.
      mutable: Can be bool, str, or list. Specifies which collections should
        be treated as mutable:
        ``bool``: all/no collections are mutable.
        ``str``: The name of a single mutable collection.
        ``list``: A list of names of mutable collections.
        By default all collections except "intermediates" are mutable.
      **kwargs: Keyword arguments forwarded to ``self.apply``.

    Returns:
      ``(output, vars)``, where ``vars`` is a dict of the modified
      collections.

    Raises:
      errors.InvalidRngError: If ``rngs`` is not a dict and does not have a
        ``shape`` equal to ``(2,)``. This includes values with no ``shape``
        attribute at all, such as ``None`` or a plain list.
    """
    if not isinstance(rngs, dict):
        # Read ``shape`` defensively so that a value without the attribute
        # is reported as an invalid rng rather than leaking a bare
        # AttributeError to the caller.
        if getattr(rngs, 'shape', None) != (2,):
            raise errors.InvalidRngError(
                'RNGs should be of shape (2,) in Module '
                f'{self.__class__.__name__}, but rngs are: {rngs}')
        rngs = {'params': rngs}
    return self.apply(
        {}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
def make_rng(self, name: str) -> PRNGKey:
    """Generates a PRNGKey from the PRNGSequence called ``name``.

    Each call increments the counter stored for ``name`` and folds the new
    counter value into the base rng stored for ``name``.

    Args:
      name: Name of the rng sequence to draw from.

    Returns:
      The result of ``random.fold_in`` applied to the base rng for ``name``
      and the freshly incremented counter.

    Raises:
      errors.InvalidRngError: If ``self.has_rng(name)`` is false.
    """
    if self.has_rng(name):
        # Validity and trace-level checks run before any counter is touched.
        self._check_valid()
        self._validate_trace_level()
        self.rng_counters[name] += 1
        base_rng = self.rngs[name]
        step = self.rng_counters[name]
        return random.fold_in(base_rng, step)
    raise errors.InvalidRngError(f'{self.name} needs PRNG for "{name}"')
def bind(variables: VariableDict,
         rngs: Optional[RNGSequences] = None,
         mutable: CollectionFilter = False):
    """Binds variables and rngs to a new ``Scope``.

    ``bind`` provides a ``Scope`` instance without transforming a function
    with ``apply``. This is particularly useful for debugging and for
    interactive use cases like notebooks, where wrapping everything in a
    function would limit the ability to split code across cells.

    A ``Scope`` instance is a stateful object. Idiomatic JAX is functional,
    so a ``Scope`` does not mix well with vanilla JAX APIs. We therefore
    recommend using ``apply`` when code should be reusable and compatible
    across the JAX software ecosystem.

    Args:
      variables: The variables to bind; must pass ``_is_valid_variables``.
      rngs: Optional rngs to bind. When given, must pass ``_is_valid_rngs``.
      mutable: Collection filter passed both to ``_unfreeze_variables`` and
        to the ``Scope`` constructor.

    Returns:
      A new ``Scope`` built from the unfrozen variables, ``rngs`` and
      ``mutable``.

    Raises:
      errors.ApplyScopeInvalidVariablesError: If ``variables`` is not valid.
      errors.InvalidRngError: If ``rngs`` is given and is not valid.
    """
    if not _is_valid_variables(variables):
        raise errors.ApplyScopeInvalidVariablesError()
    # ``rngs`` is optional, so only a supplied value is validated.
    if not (rngs is None or _is_valid_rngs(rngs)):
        raise errors.InvalidRngError(
            'rngs should be a dictionary mapping strings to `jax.PRNGKey`.')
    return Scope(
        _unfreeze_variables(variables, mutable), rngs=rngs, mutable=mutable)