Example #1
0
  def wrapper(variables: VariableDict,
              *args,
              rngs: Optional[RNGSequences] = None,
              **kwargs) -> Union[Any, Tuple[Any, VariableDict]]:
    """Run ``fn`` inside a temporary ``Scope`` built from ``variables``.

    Args:
      variables: variable collections to bind; rejected with
        ``errors.ApplyScopeInvalidVariablesError`` when
        ``_is_valid_variables`` fails.
      *args: positional arguments forwarded to ``fn`` after the root scope.
      rngs: optional RNG sequences for the scope; rejected with
        ``errors.ApplyScopeInvalidRngsError`` when given and
        ``_is_valid_rngs`` fails.
      **kwargs: keyword arguments forwarded to ``fn``.

    Returns:
      ``fn``'s output alone when the enclosing ``mutable`` is ``False``;
      otherwise an ``(output, collections)`` pair where ``collections`` is
      the frozen dict of every collection whose name passes
      ``in_filter(mutable, name)``.
    """
    if not _is_valid_variables(variables):
      raise errors.ApplyScopeInvalidVariablesError()
    # Absent rngs are allowed; present ones must pass validation.
    if not (rngs is None or _is_valid_rngs(rngs)):
      raise errors.ApplyScopeInvalidRngsError()

    working = _unfreeze_variables(variables, mutable)
    scope = Scope(working, rngs=rngs, mutable=mutable)
    with scope.temporary() as root:
      output = fn(root, *args, **kwargs)

    if mutable is False:
      return output
    # Report back only the collections selected by the `mutable` filter.
    touched = {name: collection
               for name, collection in working.items()
               if in_filter(mutable, name)}
    return output, freeze(touched)
Example #2
0
def bind(variables: VariableDict,
         rngs: Optional[RNGSequences] = None,
         mutable: CollectionFilter = False) -> 'Scope':
  """Bind variables and rngs to a new ``Scope``.

  ``bind`` provides a ``Scope`` instance without transforming a function
  with ``apply``. This is particularly useful for debugging and for
  interactive use cases such as notebooks, where wrapping code in a function
  would limit the ability to split it up into different cells.

  A ``Scope`` instance is a stateful object. Note that idiomatic JAX is
  functional, and therefore a ``Scope`` does not mix well with vanilla JAX
  APIs. We therefore recommend using ``apply`` when code should be reusable
  and compatible across the JAX software ecosystem.

  Args:
    variables: the variable collections to bind; must satisfy
      ``_is_valid_variables``.
    rngs: optional RNG sequences for the scope; when given, they must
      satisfy ``_is_valid_rngs``.
    mutable: collection filter forwarded to ``_unfreeze_variables`` and to
      the ``Scope`` constructor, selecting which variable collections the
      scope may mutate.

  Returns:
    A new ``Scope`` over ``_unfreeze_variables(variables, mutable)``.

  Raises:
    errors.ApplyScopeInvalidVariablesError: if ``variables`` is invalid.
    errors.ApplyScopeInvalidRngsError: if ``rngs`` is given and invalid.
  """
  if not _is_valid_variables(variables):
    raise errors.ApplyScopeInvalidVariablesError()
  # ``None`` means "no rngs"; only validate when some were supplied.
  if rngs is not None and not _is_valid_rngs(rngs):
    raise errors.ApplyScopeInvalidRngsError()
  new_variables = _unfreeze_variables(variables, mutable)
  return Scope(new_variables, rngs=rngs, mutable=mutable)