Exemplo n.º 1
    def init_with_output(self,
                         rngs: Union[PRNGKey, RNGSequences],
                         method: Optional[Callable[..., Any]] = None,
                         **kwargs) -> Tuple[Any, FrozenVariableDict]:
        """Initializes a module method with variables and returns output and modified variables.

      rngs: The rngs for the variable collections.
      method: An optional method. If provided, applies this method. If not
              provided, applies the ``__call__`` method.
      ``(output, vars)``, where ``vars`` is a dict of the modified collections.
        if not isinstance(rngs, dict):
            if rngs.shape != (2, ):
                raise errors.InvalidRngError(
                    'RNGs should be of shape (2,) in Module '
                    f'{self.__class__.__name__}, but rngs are: {rngs}')
            rngs = {'params': rngs}
        return self.apply({},
Exemplo n.º 2
  def init_with_output(self,
                       rngs: Union[PRNGKey, RNGSequences],
                       *args,
                       method: Optional[Callable[..., Any]] = None,
                       mutable: CollectionFilter = DenyList("intermediates"),
                       **kwargs) -> Tuple[Any, FrozenVariableDict]:
    """Initializes a module method with variables and returns output and modified variables.

    Args:
      rngs: The rngs for the variable collections. Either a dict of rngs, or
        a single rng of shape ``(2,)``, which is then used as the
        ``'params'`` rng.
      *args: Positional arguments forwarded to ``self.apply``.
      method: An optional method. If provided, applies this method. If not
        provided, applies the ``__call__`` method.
      mutable: Can be bool, str, or list. Specifies which collections should be
        treated as mutable: ``bool``: all/no collections are mutable.
        ``str``: The name of a single mutable collection. ``list``: A
        list of names of mutable collections. By default all collections
        except "intermediates" are mutable.
      **kwargs: Keyword arguments forwarded to ``self.apply``.

    Returns:
      ``(output, vars)``, where ``vars`` is a dict of the modified
      collections.

    Raises:
      errors.InvalidRngError: If ``rngs`` is not a dict and its shape is not
        ``(2,)``.
    """
    if not isinstance(rngs, dict):
      if rngs.shape != (2,):
        raise errors.InvalidRngError(
            'RNGs should be of shape (2,) in Module '
            f'{self.__class__.__name__}, but rngs are: {rngs}')
      # A bare rng is treated as the rng for the 'params' collection.
      rngs = {'params': rngs}
    # Start from an empty variable dict; `apply` is given `mutable` so it can
    # report the collections that were created or modified.
    return self.apply(
        {}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
Exemplo n.º 3
 def make_rng(self, name: str) -> PRNGKey:
   """Derive a new PRNGKey from the RNG sequence registered as ``name``.

   Each call bumps the per-name counter by one and folds the new counter
   value into the base key stored in ``self.rngs[name]``.

   Args:
     name: Name of the RNG sequence to draw from.

   Returns:
     The result of ``random.fold_in`` applied to the base key and the
     incremented counter.

   Raises:
     errors.InvalidRngError: If ``self.has_rng(name)`` is false.
   """
   if self.has_rng(name):
     # Advance the counter first so the folded value is the updated count.
     next_count = self.rng_counters[name] + 1
     self.rng_counters[name] = next_count
     base_key = self.rngs[name]
     return random.fold_in(base_key, next_count)
   raise errors.InvalidRngError(f'{self.name} needs PRNG for "{name}"')
Exemplo n.º 4
def bind(variables: VariableDict,
         rngs: Optional[RNGSequences] = None,
         mutable: CollectionFilter = False):
    """Bind variables and rngs to a new ``Scope``.

    ``bind`` provides a ``Scope`` instance without transforming a function
    with ``apply``. This is particularly useful for debugging and
    interactive use cases like notebooks, where a function would limit
    the ability to split up code into different cells.

    A ``Scope`` instance is a stateful object. Note that idiomatic JAX is
    functional and therefore a ``Scope`` does not mix well with vanilla JAX
    APIs. Therefore, we recommend using ``apply`` when code should be
    reusable and compatible across the JAX software ecosystem.

    Args:
      variables: The variables to bind. Must be accepted by
        ``_is_valid_variables``.
      rngs: Optional rngs for the scope. When given, must be accepted by
        ``_is_valid_rngs`` (a dictionary mapping strings to
        ``jax.PRNGKey``).
      mutable: Collection filter passed both to ``_unfreeze_variables`` and
        to the new ``Scope``. Defaults to ``False``.

    Returns:
      A ``Scope`` constructed from the unfrozen variables, ``rngs`` and
      ``mutable``.

    Raises:
      errors.ApplyScopeInvalidVariablesError: If ``variables`` fails
        validation.
      errors.InvalidRngError: If ``rngs`` is given but fails validation.
    """
    if not _is_valid_variables(variables):
        raise errors.ApplyScopeInvalidVariablesError()
    if rngs is not None and not _is_valid_rngs(rngs):
        raise errors.InvalidRngError(
            'rngs should be a dictionary mapping strings to `jax.PRNGKey`.')
    new_variables = _unfreeze_variables(variables, mutable)
    return Scope(new_variables, rngs=rngs, mutable=mutable)