Esempio n. 1
0
    def init_with_output(self,
                         rngs: Union[PRNGKey, RNGSequences],
                         *args,
                         method: Optional[Callable[..., Any]] = None,
                         **kwargs) -> Tuple[Any, FrozenVariableDict]:
        """Initializes a module method with variables and returns output and modified variables.

        Args:
          rngs: The rngs for the variable collections. Either a dict of rngs,
            or a single key of shape ``(2,)``, which is used for the
            ``'params'`` collection.
          *args: Positional arguments forwarded to ``self.apply``.
          method: An optional method. If provided, applies this method. If not
            provided, applies the ``__call__`` method.
          **kwargs: Keyword arguments forwarded to ``self.apply``.

        Returns:
          ``(output, vars)``, where ``vars`` is a dict of the modified
          collections.

        Raises:
          errors.InvalidRngError: If ``rngs`` is not a dict and does not have
            shape ``(2,)``. This includes objects without a ``shape``
            attribute at all, such as a plain integer seed or ``None``.
        """
        if not isinstance(rngs, dict):
            # getattr keeps inputs without a ``shape`` attribute (e.g. an int
            # seed) on the intended InvalidRngError path instead of failing
            # with an unrelated AttributeError.
            if getattr(rngs, 'shape', None) != (2,):
                raise errors.InvalidRngError(
                    'RNGs should be of shape (2,) in Module '
                    f'{self.__class__.__name__}, but rngs are: {rngs}')
            rngs = {'params': rngs}
        # Start from no variables and treat every collection as mutable so the
        # variables created during the call are returned alongside the output.
        return self.apply({},
                          *args,
                          rngs=rngs,
                          method=method,
                          mutable=True,
                          **kwargs)
Esempio n. 2
0
  def init_with_output(self,
                       rngs: Union[PRNGKey, RNGSequences],
                       *args,
                       method: Optional[Callable[..., Any]] = None,
                       mutable: CollectionFilter = DenyList("intermediates"),
                       **kwargs) -> Tuple[Any, FrozenVariableDict]:
    """Initializes a module method with variables and returns output and modified variables.

    Args:
      rngs: The rngs for the variable collections. Either a dict of rngs, or
        a single key of shape ``(2,)``, which is used for the ``'params'``
        collection.
      *args: Positional arguments forwarded to ``self.apply``.
      method: An optional method. If provided, applies this method. If not
        provided, applies the ``__call__`` method.
      mutable: Can be bool, str, or list. Specifies which collections should be
        treated as mutable: ``bool``: all/no collections are mutable.
        ``str``: The name of a single mutable collection. ``list``: A
        list of names of mutable collections. By default all collections
        except "intermediates" are mutable.
      **kwargs: Keyword arguments forwarded to ``self.apply``.

    Returns:
      ``(output, vars)``, where ``vars`` is a dict of the modified
      collections.

    Raises:
      errors.InvalidRngError: If ``rngs`` is not a dict and does not have
        shape ``(2,)``. This includes objects without a ``shape`` attribute
        at all, such as a plain integer seed or ``None``.
    """
    if not isinstance(rngs, dict):
      # getattr keeps inputs without a ``shape`` attribute (e.g. an int seed)
      # on the intended InvalidRngError path instead of failing with an
      # unrelated AttributeError.
      if getattr(rngs, 'shape', None) != (2,):
        raise errors.InvalidRngError(
            'RNGs should be of shape (2,) in Module '
            f'{self.__class__.__name__}, but rngs are: {rngs}')
      rngs = {'params': rngs}
    # Initialization starts from an empty variable dict.
    return self.apply(
        {}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
Esempio n. 3
0
 def make_rng(self, name: str) -> PRNGKey:
   """Generates a PRNGKey from the PRNG sequence called ``name``.

   Every call bumps the sequence's counter and folds the new counter value
   into the sequence's base key, so repeated calls with the same ``name``
   fold in distinct counter values.

   Args:
     name: Name of the PRNG sequence to draw from.

   Returns:
     ``random.fold_in`` of the sequence's base key with its updated counter.

   Raises:
     errors.InvalidRngError: If no PRNG sequence named ``name`` is available.
   """
   if self.has_rng(name):
     # Validity checks run only after confirming the sequence exists.
     self._check_valid()
     self._validate_trace_level()
     self.rng_counters[name] += 1
     base_key = self.rngs[name]
     return random.fold_in(base_key, self.rng_counters[name])
   raise errors.InvalidRngError(f'{self.name} needs PRNG for "{name}"')
Esempio n. 4
0
def bind(variables: VariableDict,
         rngs: Optional[RNGSequences] = None,
         mutable: CollectionFilter = False):
    """Binds variables and rngs to a new ``Scope``.

    ``bind`` provides a ``Scope`` instance without transforming a function
    with ``apply``. This is particularly useful for debugging and interactive
    use cases such as notebooks, where wrapping code in a function would limit
    the ability to split it up into different cells.

    A ``Scope`` instance is a stateful object. Note that idiomatic JAX is
    functional, and therefore a ``Scope`` does not mix well with vanilla JAX
    APIs. We therefore recommend using ``apply`` when code should be reusable
    and compatible across the JAX software ecosystem.

    Args:
      variables: The variable collections to bind; must pass
        ``_is_valid_variables``.
      rngs: Optional PRNG sequences. When given, they must pass
        ``_is_valid_rngs`` (a dictionary mapping strings to ``jax.PRNGKey``,
        according to the error message raised otherwise).
      mutable: Collection filter selecting the mutable collections. It is
        passed both to ``_unfreeze_variables`` and to the new ``Scope``.

    Returns:
      A new ``Scope`` over the result of
      ``_unfreeze_variables(variables, mutable)``.

    Raises:
      errors.ApplyScopeInvalidVariablesError: If ``variables`` is invalid.
      errors.InvalidRngError: If ``rngs`` is given but invalid.
    """
    if not _is_valid_variables(variables):
        raise errors.ApplyScopeInvalidVariablesError()
    if rngs is not None:
        if not _is_valid_rngs(rngs):
            raise errors.InvalidRngError(
                'rngs should be a dictionary mapping strings to `jax.PRNGKey`.')
    return Scope(_unfreeze_variables(variables, mutable),
                 rngs=rngs,
                 mutable=mutable)