def wrapper(variables: VariableDict,
            *args,
            rngs: Optional[RNGSequences] = None,
            **kwargs) -> Union[Any, Tuple[Any, VariableDict]]:
    """Run ``fn`` on a temporary root ``Scope`` built from ``variables``.

    ``fn`` and ``mutable`` are taken from the enclosing scope.

    Args:
      variables: variable collections; rejected with
        ``errors.ApplyScopeInvalidVariablesError`` if
        ``_is_valid_variables`` returns false.
      *args: forwarded positionally to ``fn`` after the root scope.
      rngs: optional rng sequences; when given, rejected with
        ``errors.ApplyScopeInvalidRngsError`` if ``_is_valid_rngs`` returns
        false.
      **kwargs: forwarded to ``fn`` as keyword arguments.

    Returns:
      ``fn``'s output when ``mutable is False``. Otherwise a tuple of that
      output and a frozen dict of every collection whose name passes
      ``in_filter(mutable, name)``.
    """
    if not _is_valid_variables(variables):
        raise errors.ApplyScopeInvalidVariablesError()
    if rngs is not None and not _is_valid_rngs(rngs):
        raise errors.ApplyScopeInvalidRngsError()

    # The root scope operates on this dict directly, so after ``fn`` returns
    # it holds whatever collections the call produced or updated.
    working_vars = _unfreeze_variables(variables, mutable)
    root_scope = Scope(working_vars, rngs=rngs, mutable=mutable)
    with root_scope.temporary() as root:
        output = fn(root, *args, **kwargs)

    # Only an explicit ``False`` means "nothing mutable": return output alone.
    if mutable is False:
        return output

    updated = {}
    for col_name, col in working_vars.items():
        if in_filter(mutable, col_name):
            updated[col_name] = col
    return output, freeze(updated)
def bind(variables: VariableDict,
         rngs: Optional[RNGSequences] = None,
         mutable: CollectionFilter = False):
    """Bind variables and rngs to a new ``Scope``.

    ``bind`` provides a ``Scope`` instance without transforming a function
    with ``apply``. This is particularly useful for debugging and for
    interactive use cases like notebooks. There, wrapping code in a function
    would limit the ability to split it up into different cells.

    A ``Scope`` instance is a stateful object. Idiomatic JAX is functional,
    so a ``Scope`` does not mix well with vanilla JAX APIs. We therefore
    recommend using ``apply`` when code should be reusable and compatible
    across the JAX software ecosystem.

    Args:
      variables: the variable collections to bind.
      rngs: optional rng sequences to make available to the scope.
      mutable: filter selecting which collections the scope may mutate.

    Returns:
      A ``Scope`` constructed over
      ``_unfreeze_variables(variables, mutable)`` with the given ``rngs``
      and ``mutable``.

    Raises:
      errors.ApplyScopeInvalidVariablesError: if ``_is_valid_variables``
        rejects ``variables``.
      errors.ApplyScopeInvalidRngsError: if ``rngs`` is given and
        ``_is_valid_rngs`` rejects it.
    """
    if not _is_valid_variables(variables):
        raise errors.ApplyScopeInvalidVariablesError()
    # ``rngs`` is optional; validate it only when the caller supplied one.
    rngs_given = rngs is not None
    if rngs_given and not _is_valid_rngs(rngs):
        raise errors.ApplyScopeInvalidRngsError()
    return Scope(_unfreeze_variables(variables, mutable),
                 rngs=rngs,
                 mutable=mutable)