def register(cls):
    """Registers a class as a JAX pytree node.

    Splits the instance's ``parameters`` dict in two: entries whose names
    appear in ``non_shape_params[cls.__name__]`` become pytree children
    (leaves JAX may trace/transform), while all remaining parameters travel
    as static auxiliary metadata.
    """
    def flatten(linop):
        # Parameter names that should be exposed as pytree leaves.
        param_names = set(non_shape_params[cls.__name__])
        components = {
            param_name: value
            for param_name, value in linop.parameters.items()
            if param_name in param_names
        }
        # Everything else (e.g. shape-like/static config) is aux metadata.
        metadata = {
            param_name: value
            for param_name, value in linop.parameters.items()
            if param_name not in param_names
        }
        if components:
            # Sort by name so the leaf order is deterministic across instances.
            keys, values = zip(*sorted(components.items()))
        else:
            keys, values = (), ()
        return values, (keys, metadata)

    def unflatten(info, xs):
        keys, metadata = info
        # Recombine traced leaves with the static metadata and rebuild.
        parameters = dict(list(zip(keys, xs)), **metadata)
        return cls(**parameters)

    tree_util.register_pytree_node(cls, flatten, unflatten)
def differentiable(cls: Type[T]) -> Type[T]:
    """Class decorator registering `cls` as a pytree node.

    Attributes listed under ``keys[_DIFFERENTIABLE]`` become pytree children
    (differentiated through), while ``keys[_NON_DIFFERENTIABLE]`` attributes
    are carried as static aux data. Returns `cls` unchanged so it can be used
    as a decorator.
    """
    keys = _get_keys(cls)

    def _tree_flatten(node: Module) -> Tuple[Tuple[Dict[str, Any]], Dict[str, Any]]:
        children = {}
        aux_data = {}
        for key in keys[_DIFFERENTIABLE]:
            children[key] = getattr(node, key)
        for key in keys[_NON_DIFFERENTIABLE]:
            aux_data[key] = getattr(node, key)
        logger.debug('=' * 50)
        logger.debug('flatten: %s', cls)
        logger.debug('aux_data: %s', aux_data)
        logger.debug('children: %s', children)
        # Children dict is wrapped in a 1-tuple to match unflatten's unpacking.
        return (children,), aux_data

    def _tree_unflatten(aux_data: Tuple[Dict[str, Any]],
                        children: Dict[str, Any]) -> Module:
        logger.debug('=' * 50)
        logger.debug('unflatten: %s', cls)
        logger.debug('aux_data: %s', aux_data)
        logger.debug('children: %s', children)
        # Reconstruct by merging static and differentiable attributes.
        return cls(**aux_data, **children[0])  # type: ignore

    register_pytree_node(cls, _tree_flatten, _tree_unflatten)
    return cls
def register_pytree_namedtuple(cls):
    """Decorator making a namedtuple class usable as a JAX pytree node."""
    def _flatten(xs):
        # Children are the tuple's fields in order; no static metadata.
        return tuple(xs), None

    def _unflatten(_, xs):
        # Rebuild the namedtuple from its flattened fields.
        return cls(*xs)

    register_pytree_node(cls, _flatten, _unflatten)
    return cls
def __init_subclass__(cls, **kwargs):
    """Automatically registers every subclass as a JAX pytree node,
    using the subclass's own `flatten`/`unflatten` methods."""
    super().__init_subclass__(**kwargs)
    tree_util.register_pytree_node(
        cls,
        cls.flatten,
        # Pytype incorrectly thinks that cls.unflatten accepts three arguments.
        cls.unflatten  # type: ignore
    )
def _register_dataclass_type(data_class):
    """Register dataclass so JAX knows how to handle it."""
    # Flatten via the instance __dict__; the returned treedef remembers the
    # attribute names so the dict can be reconstructed later.
    flatten = lambda d: jax.tree_flatten(d.__dict__)
    # `s` is the treedef from flatten; unflattening restores the attribute
    # dict, which is splatted back into the dataclass constructor.
    unflatten = lambda s, xs: data_class(**s.unflatten(xs))
    try:
        tree_util.register_pytree_node(
            nodetype=data_class,
            flatten_func=flatten,
            unflatten_func=unflatten)
    except ValueError:
        # Registering the same type twice raises ValueError; treat as benign.
        logging.info('%s is already registered as JAX PyTree node.', data_class)
def py_tree_registered_dataclass(cls, *args, **kwargs):
    """Creates a new dataclass type and registers it as a pytree node."""
    dcls = dataclasses.dataclass(cls, *args, **kwargs)

    def _flatten(instance):
        # Field values in declaration order are the children; no aux data.
        children = [getattr(instance, f.name)
                    for f in dataclasses.fields(instance)]
        return children, None

    def _unflatten(_, instance_args):
        # Positional reconstruction mirrors the declaration-order flatten.
        return dcls(*instance_args)

    tree_util.register_pytree_node(dcls, _flatten, _unflatten)
    return dcls
def register_graph_as_jax_pytree(cls: Type[T]) -> None:
    """Registers a networkx-style graph class as a JAX pytree node.

    A graph flattens into its node-data dict and edge-data dict (as the two
    children, with no static aux data), and unflattens by rebuilding a fresh
    graph from those dicts.
    """

    def tree_unflatten(hashed: Hashable, trees: Sequence[PyTree]) -> T:
        node_dicts, edge_dicts = trees
        # Guard against malformed children (e.g. a tree_map that replaced the
        # dicts with something else); raise with a diagnostic message instead
        # of a bare TypeError.
        if not isinstance(node_dicts, dict):
            raise TypeError(
                f'Expected dict of node data, got {type(node_dicts).__name__}')
        if not isinstance(edge_dicts, dict):
            raise TypeError(
                f'Expected dict of edge data, got {type(edge_dicts).__name__}')
        graph = cls()
        graph.add_nodes_from(node_dicts.items())
        graph.add_edges_from([(source, target, data)
                              for (source, target), data in edge_dicts.items()
                              ])
        return graph

    def tree_flatten(graph: T) -> Tuple[Sequence[PyTree], Hashable]:
        # Children: (node-data dict, edge-data dict); no aux data.
        return ((dict(graph.nodes), dict(graph.edges)), None)

    register_pytree_node(cls, tree_flatten, tree_unflatten)
  def __dir__(self):
    # Surface the wrapped mapping's keys so attribute completion works.
    if isinstance(self._data, dict):
      return list(self._data.keys())
    elif isinstance(self._data, FrozenDict):
      return list(self._data._dict.keys())
    else:
      return []

  def __repr__(self):
    return f'{self._data}'

  def __hash__(self):
    # Note: will only work when wrapping FrozenDict.
    return hash(self._data)

  def copy(self, **kwargs):
    # Delegate to the wrapped container's copy, preserving its concrete type.
    return self._data.__class__(self._data.copy(**kwargs))


# Register DotGetter as a transparent pytree wrapper: flattening exposes only
# the wrapped container, and unflattening intentionally drops the wrapper.
tree_util.register_pytree_node(
    DotGetter,
    lambda x: ((x._data, ), ()),  # pylint: disable=protected-access
    lambda _, data: data[0])  # Note: restores as raw dict, intentionally.

serialization.register_serialization_state(
    DotGetter,
    serialization._dict_state_dict,  # pylint: disable=protected-access
    serialization._restore_dict)  # pylint: disable=protected-access
zeros_like_p.def_abstract_eval(lambda x: x)


class Zero:
  """Symbolic zero tangent/cotangent, carrying only its abstract value."""
  # Keep instances tiny: the aval is the only state.
  __slots__ = ['aval']

  def __init__(self, aval):
    self.aval = aval

  def __repr__(self):
    return 'Zero({})'.format(self.aval)

  @staticmethod
  def from_value(val):
    # Build a Zero matching the shape/dtype of a concrete value.
    return Zero(raise_to_shaped(get_aval(val)))

# Zero has no leaves; its aval rides along as static aux data.
register_pytree_node(Zero, lambda z: ((), z.aval), lambda aval, _: Zero(aval))


def _stop_gradient_impl(x):
  if not valid_jaxtype(x):
    raise TypeError("stop_gradient only works on valid JAX arrays, but "
                    f"input argument is: {x}")
  return x


# stop_gradient is an identity on values; only its differentiation rules
# (defined elsewhere) make it special.
stop_gradient_p: Primitive = Primitive('stop_gradient')
stop_gradient_p.def_impl(_stop_gradient_impl)
stop_gradient_p.def_abstract_eval(lambda x: x)
def closed_backward_pass(jaxpr: core.ClosedJaxpr, reduce_axes, transform_stack,
                         primals_in, cotangents_in):
  """Runs backward_pass on a ClosedJaxpr by splitting out its consts."""
  return backward_pass(jaxpr.jaxpr, reduce_axes, transform_stack, jaxpr.consts,
                       primals_in, cotangents_in)


class UndefinedPrimal:
  """Placeholder for a primal value that is unknown during transposition."""
  __slots__ = ['aval']

  def __init__(self, aval):
    self.aval = aval

  def __repr__(self):
    return 'UndefinedPrimal({})'.format(self.aval)


def is_undefined_primal(x):
  return type(x) is UndefinedPrimal

# No leaves; the aval is carried as static aux data.
register_pytree_node(UndefinedPrimal,
                     lambda z: ((), z.aval),
                     lambda aval, _: UndefinedPrimal(aval))


def get_primitive_transpose(p):
  """Looks up the transpose rule for primitive `p`, raising if missing."""
  try:
    return primitive_transposes[p]
  except KeyError as err:
    raise NotImplementedError(
        "Transpose rule (for reverse-mode differentiation) for '{}' "
        "not implemented".format(p)) from err


@lu.transformation_with_aux
def nonzero_tangent_outputs(*args, **kwargs):
  # Aux output flags which tangents are symbolically nonzero (not Zero).
  results = (_, tangents_out) = yield args, kwargs
  yield results, [type(r) is not Zero for r in tangents_out]
from vmap import grad as fx_grad import torch.fx as fx from jax import grad, jit, partial, random, value_and_grad, lax from jax.flatten_util import ravel_pytree import jax.numpy as np from jax import random from jax.tree_util import register_pytree_node, tree_multimap # (q, p) -> (position (param value), momentum) IntegratorState = namedtuple("IntegratorState", ["q", "p", "potential_energy", "q_grad"]) # a tree-like JAX primitive that allows program transformations # to work on Python containers (https://jax.readthedocs.io/en/latest/pytrees.html) register_pytree_node(IntegratorState, lambda xs: (tuple(xs), None), lambda _, xs: IntegratorState(*xs)) def leapfrog(potential_fn, kinetic_fn): r""" Second order symplectic integrator that uses the leapfrog algorithm for position `q` and momentum `p`. :param potential_fn: Python callable that computes the potential energy given input parameters. The input parameters to `potential_fn` can be any python collection type. :param kinetic_fn: Python callable that returns the kinetic energy given inverse mass matrix and momentum. :return: a pair of (`init_fn`, `update_fn`). """ def init_fn(q, p):
def register_pytree_namedtuple(cls):
  """Registers a namedtuple class as a JAX pytree node.

  Children are the tuple's fields in order; there is no static aux data.
  Returns `cls` so this helper can also be used as a class decorator
  (without the return, decorating would rebind the class name to None).
  """
  register_pytree_node(cls,
                       lambda xs: (tuple(xs), None),
                       lambda _, xs: cls(*xs))
  return cls
  def scatter(self, indices, value, name=None):
    raise NotImplementedError('If you need this feature, please email '
                              '`[email protected]`.')

  def split(self, value, lengths, name=None):
    raise NotImplementedError('If you need this feature, please email '
                              '`[email protected]`.')

  def __repr__(self):
    return ('tf2{}.TensorArray(dtype={}, size={}, dynamic_size={}, '
            'element_shape={}, len(data)={})').format(
                'jax' if JAX_MODE else 'numpy', self._dtype, self._size,
                self._dynamic_size, self._element_shape, len(self._data))


if JAX_MODE:
  from jax import tree_util  # pylint: disable=g-import-not-at-top

  def flatten(val):
    # The backing data is the only leaf; dtype/shape config is static aux.
    vals = (val._data, )  # pylint: disable=protected-access
    aux = dict(dtype=val.dtype,
               element_shape=val.element_shape,
               dynamic_size=val.dynamic_size)
    return vals, aux

  def unflatten(aux, vals):
    # Rebuild the TensorArray from its data leaf and static config.
    return TensorArray(data=vals[0], **aux)

  tree_util.register_pytree_node(TensorArray, flatten, unflatten)
""" out = TFCDict(self) if isinstance(o, dict) or (type(o) is type(self)): for key in self._keys: out[key] -= o[key] elif isinstance(o, np.ndarray): o = o.flatten() for k in range(self._nKeys): out[self._keys[k]] -= o[self._slices[k]] return out # Register TFCDict as a JAX type register_pytree_node( TFCDict, lambda x: (list(x.values()), list(x.keys())), lambda keys, values: TFCDict(safe_zip(keys, values)), ) class TFCDictRobust(OrderedDict): """This class is like the :class:`TFCDict <tfc.utils.TFCUtils.TFCDict>` class, but it handles non-flat arrays.""" def __init__(self, *args): """Initialize TFCDictRobust using the OrderedDict method.""" # Store dictionary and keep a record of the keys. Keys will stay in same # order, so that adding and subtracting is repeatable. super().__init__(*args) self._keys = list(self.keys()) self._nKeys = len(self._keys)
raise NotImplementedError('If you need this feature, please email ' '`[email protected]`.') def grad(self, source, flow=None, name=None): raise NotImplementedError('If you need this feature, please email ' '`[email protected]`.') def scatter(self, indices, value, name=None): raise NotImplementedError('If you need this feature, please email ' '`[email protected]`.') def split(self, value, lengths, name=None): raise NotImplementedError('If you need this feature, please email ' '`[email protected]`.') if JAX_MODE: from jax import tree_util # pylint: disable=g-import-not-at-top def to_tree(val): vals = (val._data, ) # pylint: disable=protected-access aux = dict(dtype=val.dtype, element_shape=val.element_shape, dynamic_size=val.dynamic_size) return vals, aux def from_tree(aux, vals): return TensorArray(data=vals[0], **aux) tree_util.register_pytree_node(TensorArray, to_tree, from_tree)
f'at mapped index {", ".join(map(str, idx))}: ' # type: ignore f'{_format_msg(self.msgs[int(self.code[idx])], self.payload[idx])}' # type: ignore for idx, e in np.ndenumerate(self.err) if e) or None return None def throw(self): """Throw ValueError with error message if error happened.""" err = self.get() if err: raise ValueError(err) register_pytree_node( Error, lambda e: ((e.err, e.code, e.payload), tuple(sorted(e.msgs.items()))), lambda msgs, data: Error( data[0], data[1], # type: ignore dict(msgs), data[2])) # type: ignore init_error = Error(False, 0, {}) next_code = it.count(1).__next__ # globally unique ids, could be uuid4 def assert_func(error: Error, pred: Bool, msg: str, payload: Optional[Payload]) -> Error: code = next_code() payload = init_payload if payload is None else payload out_err = error.err | jnp.logical_not(pred) out_code = lax.select(error.err, error.code, code) out_payload = lax.select(error.err, error.payload, payload)
Return ------ g : mean.shape array The new gvars. """ cov = gvar.gvar.cov mean = np.asarray(mean) shape = mean.shape mean = mean.flat jac = np.array(jac) # TODO patches gvar issue #27 jac = jac.reshape(len(mean), len(indices)) g = np.zeros(len(mean), object) for i, jacrow in enumerate(jac): der = gvar.svec(len(indices)) der._assign(jac[i], indices) g[i] = gvar.GVar(mean[i], der, cov) return g.reshape(shape) def bufferdict_flatten(bd): return tuple(bd.values()), tuple(bd.keys()) def bufferdict_unflatten(keys, values): return gvar.BufferDict(zip(keys, values)) # register BufferDict as a pytree tree_util.register_pytree_node(gvar.BufferDict, bufferdict_flatten, bufferdict_unflatten)
nn_params = nn_module.init(rng_key, jnp.ones(input_shape)) # haiku init returns an immutable dict nn_params = haiku.data_structures.to_mutable_dict(nn_params) # we cast it to a mutable one to be able to set priors for parameters # make sure that nn_params keep the same order after unflatten params_flat, tree_def = tree_flatten(nn_params) nn_params = tree_unflatten(tree_def, params_flat) numpyro.param(module_key, nn_params) return partial(nn_module.apply, nn_params, None) # register an "empty" parameter which only stores its shape # so that the optimizer can skip optimize this parameter, while # it still provides shape information for priors ParamShape = namedtuple("ParamShape", ["shape"]) register_pytree_node(ParamShape, lambda x: ((None,), x.shape), lambda shape, x: ParamShape(shape)) def _update_params(params, new_params, prior, prefix=''): """ A helper to recursively set prior to new_params. """ for name, item in params.items(): flatten_name = ".".join([prefix, name]) if prefix else name if isinstance(item, dict): assert not isinstance(prior, dict) or flatten_name not in prior new_item = new_params[name] _update_params(item, new_item, prior, prefix=flatten_name) elif (not isinstance(prior, dict)) or flatten_name in prior: d = prior[flatten_name] if isinstance(prior, dict) else prior if isinstance(params[name], ParamShape):
@_add_doc(optimizers.sm3) class SM3(_NumPyroOptim): def __init__(self, *args, **kwargs): super(SM3, self).__init__(optimizers.sm3, *args, **kwargs) # TODO: currently, jax.scipy.optimize.minimize only supports 1D input, # so we need to add the following mechanism to transform params to flat_params # and pass `unravel_fn` arround. # When arbitrary pytree is supported in JAX, we can just simply use # identity functions for `init_fn` and `get_params`. _MinimizeState = namedtuple("MinimizeState", ["flat_params", "unravel_fn"]) register_pytree_node( _MinimizeState, lambda state: ((state.flat_params, ), (state.unravel_fn, )), lambda data, xs: _MinimizeState(xs[0], data[0]), ) def _minimize_wrapper(): def init_fn(params): flat_params, unravel_fn = ravel_pytree(params) return _MinimizeState(flat_params, unravel_fn) def update_fn(i, grad_tree, opt_state): # we don't use update_fn in Minimize, so let it do nothing return opt_state def get_params(opt_state): flat_params, unravel_fn = opt_state
- **num_steps** - Number of steps in the Hamiltonian trajectory (for diagnostics). - **accept_prob** - Acceptance probability of the proposal. Note that ``z`` does not correspond to the proposal if it is rejected. - **mean_accept_prob** - Mean acceptance probability until current iteration during warmup adaptation or sampling (for diagnostics). - **step_size** - Step size to be used by the integrator in the next iteration. This is adapted during warmup. - **inverse_mass_matrix** - The inverse mass matrix to be be used for the next iteration. This is adapted during warmup. - **rng** - random number generator seed used for the iteration. """ register_pytree_node( HMCState, lambda xs: (tuple(xs), None), lambda _, xs: HMCState(*xs) ) HMCState.update = HMCState._replace def _get_num_steps(step_size, trajectory_length): num_steps = np.clip(trajectory_length / step_size, a_min=1) # NB: casting to np.int64 does not take effect (returns np.int32 instead) # if jax_enable_x64 is False return num_steps.astype(np.int64) def _sample_momentum(unpack_fn, mass_matrix_sqrt, rng):
def __init_subclass__(cls, **kwargs):
    """Automatically registers each subclass as a JAX pytree node using the
    subclass's `tree_flatten`/`tree_unflatten` methods."""
    super().__init_subclass__(**kwargs)
    tree_util.register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten)
def register_pytree(cls):
    """Idempotently registers `cls` as a JAX pytree node.

    Children are the instance's tuple fields (in order); there is no static
    aux data. A `_registered` sentinel on the class prevents a double
    registration, which JAX rejects.
    """
    # Check cls.__dict__ directly: a plain getattr would find `_registered`
    # inherited from an already-registered base class and wrongly skip
    # registering the subclass itself.
    if not cls.__dict__.get('_registered', False):
        register_pytree_node(cls,
                             lambda xs: (tuple(xs), None),
                             lambda _, xs: cls(*xs))
        cls._registered = True
def __init__(self, x, y, z): self.x = x self.y = y self.z = z def __eq__(self, other): return self.x == other.x and self.y == other.y and self.z == other.z def __hash__(self): return hash((self.x, self.y, self.z)) def __repr__(self): return "AnObject({},{},{})".format(self.x, self.y, self.z) tree_util.register_pytree_node(AnObject, lambda o: ((o.x, o.y), o.z), lambda z, xy: AnObject(xy[0], xy[1], z)) @tree_util.register_pytree_node_class class Special: def __init__(self, x, y): self.x = x self.y = y def __repr__(self): return "Special(x={}, y={})".format(self.x, self.y) def tree_flatten(self): return ((self.x, self.y), None) @classmethod def tree_unflatten(cls, aux_data, children):
def make_wrapper_type(cls):
  """Creates a flattenable Distribution type."""
  clsid = (cls.__module__, cls.__name__)

  if clsid not in _registry:

    class _WrapperType(cls):
      """Oryx distribution wrapper type."""

      def __init__(self, *args, **kwargs):
        # Record constructor args so the wrapper can be re-materialized,
        # then eagerly build the real instance for attribute delegation.
        self._args = args
        self._kwargs = kwargs
        self._instance = object.__new__(cls)
        cls.__init__(self._instance, *self._args, **self._kwargs)

      def __getattr__(self, key):
        # Delegate everything except the wrapper's own bookkeeping fields.
        if key not in ('_args', '_kwargs', '_type_spec', '_instance'):
          return getattr(self._instance, key)
        return object.__getattribute__(self, key)

      @property
      def _type_spec(self):
        kwargs = dict(self._kwargs)
        param_specs = {}
        try:
          event_ndims = self._params_event_ndims()
        except NotImplementedError:
          # Some distributions don't declare event ndims; treat as none.
          event_ndims = {}
        # Tensor-like parameters are recorded as (shape, dtype) specs;
        # others by their Python type.
        for k in event_ndims:
          if k in kwargs and kwargs[k] is not None:
            elem = kwargs.pop(k)
            if type(elem) == object:  # pylint: disable=unidiomatic-typecheck
              param_specs[k] = object
            elif tf.is_tensor(elem):
              param_specs[k] = (elem.shape, elem.dtype)
            else:
              param_specs[k] = type(elem)
        # Nested distributions are themselves flattenable components.
        for k, v in list(kwargs.items()):
          if isinstance(v, tfd.Distribution):
            param_specs[k] = kwargs.pop(k)
        return _JaxDistributionTypeSpec(clsid, param_specs, kwargs)

      def __str__(self):
        return repr(self)

      def __repr__(self):
        return '{}()'.format(self.__class__.__name__)

    _WrapperType.__name__ = cls.__name__ + 'Wrapper'

    def to_tree(obj):
      # Sort components by key for a deterministic leaf ordering.
      type_spec = obj._type_spec  # pylint: disable=protected-access
      components = type_spec._to_components(obj)  # pylint: disable=protected-access
      keys, values = list(zip(*sorted(components.items())))
      return values, (keys, type_spec)

    def from_tree(info, xs):
      keys, type_spec = info
      components = dict(list(zip(keys, xs)))
      return type_spec._from_components(components)  # pylint: disable=protected-access

    tree_util.register_pytree_node(_WrapperType, to_tree, from_tree)
    # Cache so repeated calls return the same wrapper class.
    _registry[clsid] = _WrapperType
  return _registry[clsid]
def make_wrapper_type(cls): """Creates new Bijector type that can be flattened/unflattened and is lazy.""" clsid = (cls.__module__, cls.__name__) def bijector_bind(bijector, x, **kwargs): return core.call_bind( bijector_p, direction=kwargs['direction'])(_bijector)(bijector, x, **kwargs) def _bijector(bij, x, **kwargs): direction = kwargs.pop('direction', 'forward') if direction == 'forward': return cls.forward(bij, x, **kwargs) elif direction == 'inverse': return cls.inverse(bij, x, **kwargs) else: raise ValueError( 'Bijector direction must be "forward" or "inverse".') if clsid not in _registry: class _WrapperType(cls): """Oryx bijector wrapper type.""" def __init__(self, *args, **kwargs): self.use_primitive = kwargs.pop('use_primitive', True) self._args = args self._kwargs = kwargs def forward(self, x, **kwargs): if self.use_primitive: return bijector_bind(self, x, direction='forward', **kwargs) return cls.forward(self, x, **kwargs) def inverse(self, x, **kwargs): if self.use_primitive: return bijector_bind(self, x, direction='inverse', **kwargs) return cls.inverse(self, x, **kwargs) def _get_instance(self): obj = object.__new__(cls) cls.__init__(obj, *self._args, **self._kwargs) return obj def __getattr__(self, key): if key not in ('_args', '_kwargs', 'parameters', '_type_spec'): return getattr(self._get_instance(), key) return object.__getattribute__(self, key) @property def parameters(self): return self._get_instance().parameters @property def _type_spec(self): kwargs = dict(self._kwargs) param_specs = {} event_ndims = {} for k in event_ndims: if k in kwargs and kwargs[k] is not None: elem = kwargs.pop(k) if type(elem) == object: # pylint: disable=unidiomatic-typecheck param_specs[k] = object elif tf.is_tensor(elem): param_specs[k] = (elem.shape, elem.dtype) else: param_specs[k] = type(elem) for k, v in list(kwargs.items()): if isinstance(v, tfb.Bijector): param_specs[k] = kwargs.pop(k) return _JaxBijectorTypeSpec(clsid, param_specs, kwargs) def __str__(self): 
return repr(self) def __repr__(self): return '{}()'.format(self.__class__.__name__) _WrapperType.__name__ = cls.__name__ + 'Wrapper' def to_tree(obj): type_spec = obj._type_spec # pylint: disable=protected-access components = type_spec._to_components(obj) # pylint: disable=protected-access keys, values = list(zip(*sorted(components.items()))) return values, (keys, type_spec) def from_tree(info, xs): keys, type_spec = info components = dict(list(zip(keys, xs))) return type_spec._from_components(components) # pylint: disable=protected-access tree_util.register_pytree_node(_WrapperType, to_tree, from_tree) _registry[clsid] = _WrapperType return _registry[clsid]
"""Dataclass for storing parameters of a Linear RNN.""" A: jnp.array # Input weights. pylint: disable=invalid-name W: jnp.array # Recurrent weights. pylint: disable=invalid-name b: jnp.array # Bias. def apply(self, x, h) -> jnp.array: """Linear RNN Update.""" return self.A @ x + self.W @ h + self.b def flatten(self): return (self.A, self.W, self.b) # Register the LinearRNN dataclass as a pytree, so that we can directly # pass it to other jax functions (optimizers, flatten, etc.) register_pytree_node(LinearRNN, lambda node: (node.flatten(), None), lambda _, children: LinearRNN(*children)) class RNNCell: """Base class for all RNN Cells. An RNNCell must implement the following methods: init(PRNGKey, input_shape) -> output_shape, rnn_params apply(params, inputs, state) -> next_state """ def __init__(self, num_units, h_init=zeros): """Initializes an RNNCell.""" self.num_units = num_units self.h_init = h_init # Compute RNN Jacobians.
# TODO(mattjj): don't just ignore custom jvp rules? del primitive, jvp # Unused. return fun.call_wrapped(*tracers) def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees): del primitive, fwd, bwd, out_trees # Unused. return fun.call_wrapped(*tracers) class ZeroTerm(object): pass zero_term = ZeroTerm() register_pytree_node(ZeroTerm, lambda z: ((), None), lambda _, xs: zero_term) class ZeroSeries(object): pass zero_series = ZeroSeries() register_pytree_node(ZeroSeries, lambda z: ((), None), lambda _, xs: zero_series) call_param_updaters = {} def _xla_call_param_updater(params, num_inputs): donated_invars = params['donated_invars']
self.size += 1 def as_array(self): return tree_util.tree_multimap(lambda *args: np.array(list(args)), *self.data) def flatten(self): return (self.data, ), (self.size, self.idx) @classmethod def unflatten(cls, data, xs): size, idx = data return HarvestList(xs[0], size=size, idx=idx) tree_util.register_pytree_node(HarvestList, HarvestList.flatten, HarvestList.unflatten) class HarvestTrace(jax_core.Trace): """A HarvestTrace manages HarvestTracer objects. Since HarvestTracers are just wrappers around known values, HarvestTrace just passes these values through primitives, except in the case of `sow` and `nest`, which are specially handled by the active HarvestContext. Default primitive logic lives in `process_primitive`, with special logic for `sow` in `handle_sow`. """ def pure(self, val): return HarvestTracer(self, val)
map = safe_map zip = safe_zip # The implementation here basically works by flattening pytrees. There are two # levels of pytrees to think about: the pytree of params, which we can think of # as defining an "outer pytree", and a pytree produced by applying init_fun to # each leaf of the params pytree, which we can think of as the "inner pytrees". # Since pytrees can be flattened, that structure is isomorphic to a list of # lists (with no further nesting). pack = tuple OptimizerState = namedtuple("OptimizerState", ["packed_state", "tree_def", "subtree_defs"]) register_pytree_node( OptimizerState, lambda xs: ((xs.packed_state, ), (xs.tree_def, xs.subtree_defs)), lambda data, xs: OptimizerState(xs[0], data[0], data[1])) def optimizer(opt_maker): """Decorator to make an optimizer defined for arrays generalize to containers. With this decorator, you can write init, update, and get_params functions that each operate only on single arrays, and convert them to corresponding functions that operate on pytrees of parameters. See the optimizers defined in optimizers.py for examples. Args: opt_maker: a function that returns an ``(init_fun, update_fun, get_params)`` triple of functions that might only work with ndarrays, as per
maxval=high)) def truncated_normal(self, lower, upper, size, scale=1.): rands = jr.truncated_normal(self.split_key(), lower=lower, upper=upper, shape=_size2shape(size)) return JaxArray(rands * scale) def bernoulli(self, p, size=None): return JaxArray( jr.bernoulli(self.split_key(), p=p, shape=_size2shape(size))) register_pytree_node( RandomState, lambda t: ((t.value, ), None), lambda aux_data, flat_contents: RandomState(*flat_contents)) DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32)) def seed(seed=None): global DEFAULT DEFAULT.seed(np.random.randint(0, 100000) if seed is None else seed) def rand(*dn): return JaxArray( jr.uniform(DEFAULT.split_key(), shape=dn, minval=0., maxval=1.))