def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
  args, in_tree = tree_flatten((args, kwargs))

  def f_(*args):
    args, kwargs = tree_unflatten(in_tree, args)
    return f(*args, **kwargs)

  jaxpr = jax.make_jaxpr(lambda *args: jax.linearize(f_, *args)[1])(*args).jaxpr
  res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
  res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}

  results = []

  for x in res_lits:
    results.append((x.aval, 'from a literal'))

  for v in jaxpr.constvars:
    if v in res_vars:
      results.append((v.aval, 'from a constant'))

  assert len(jaxpr.invars) == len(args)
  for i, v in enumerate(jaxpr.invars):
    if v in res_vars:
      src = f'from {pe.arg_info_pytree(f, in_tree, True, [i])}'
      results.append((v.aval, src))

  for eqn in jaxpr.eqns:
    src = source_info_util.summarize(eqn.source_info)
    for v in eqn.outvars:
      if v in res_vars:
        if eqn.primitive is name_p:
          results.append(
              (v.aval, f"named '{eqn.params['name']}' from {src}"))
        else:
          results.append((v.aval, f'from {src}'))

  assert len(results) == len(jaxpr.outvars)
  return results

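# Minimal usage sketch for the helper above. It assumes the function is exposed
# as `jax.ad_checkpoint.saved_residuals` (true for recent JAX releases); each
# returned entry pairs a residual's abstract value with a human-readable source.
import jax
import jax.numpy as jnp

def _example_fn(x):
  return jnp.sin(jnp.exp(x))

for aval, source in jax.ad_checkpoint.saved_residuals(_example_fn, 3.0):
  print(aval, source)
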
def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
                            reduce_axes, False)
  fun, nz_arg_cts = nonzero_outputs(fun)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  # Preserve axis for primal arguments, skip tangents (represented as
  # undefined primals).
  in_axes, out_axes = params['in_axes'], params['out_axes']
  new_in_axes = (*[axis for axis, x in zip(in_axes, args)
                   if not is_undefined_primal(x)],
                 *[axis for axis, x in zip(out_axes, ct)
                   if type(x) is not Zero])
  # The interim strategy we use below (until avals-with-names) only works
  # when all outputs are mapped.
  assert all(out_axis is not None for out_axis in out_axes), out_axes
  # NOTE: This assumes that the output cotangents being zero is a deterministic
  # function of which input cotangents were zero.
  @as_hashable_function(closure=(in_axes, tuple(type(c) is Zero for c in ct)))
  def out_axes_thunk():
    return tuple(axis or 0 for axis, nz in zip(in_axes, nz_arg_cts()) if nz)
  new_params = dict(params, name=wrap_name(params['name'], 'transpose'),
                    in_axes=new_in_axes, out_axes_thunk=out_axes_thunk)
  del new_params['out_axes']
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    new_params = update_params(new_params, map(is_undefined_primal, args),
                               [type(x) is not Zero for x in ct])
  out_flat = primitive.bind(fun, *all_args, **new_params)
  arg_cts = tree_unflatten(out_tree(), out_flat)

  # The freevars are being fanned out (not mapped). During transpose the
  # dual of fan-out is fan-in-sum. We apply it to the unmapped invars.
  assert len(in_axes) == len(arg_cts)
  def unmap_zero(zero, in_axis):
    return (zero if in_axis is None else
            Zero(core.unmapped_aval(params['axis_size'], params['axis_name'],
                                    in_axis, zero.aval)))
  arg_cts = (unmap_zero(arg_ct, in_axis) if type(arg_ct) is Zero else
             arg_ct if in_axis is not None else
             arg_ct.sum(0)
             for arg_ct, in_axis in zip(arg_cts, in_axes))
  return tuple(arg_cts)

def map_ildj(prim, incells, outcells, **params):
  """InverseAndILDJ rule for the map primitives."""
  f, incells = incells[0], incells[1:]

  def slice_aval(aval):
    return abstract_arrays.ShapedArray(aval.shape[1:], aval.dtype,
                                       aval.weak_type)

  def add_slice(cell, old_cell):
    new_slices = [
        NDSlice(ndslice.value, ndslice.ildj, Slice(0, old_cell.aval.shape[0]),
                *ndslice.slices) for ndslice in cell.slices
    ]
    return InverseAndILDJ(old_cell.aval, new_slices)

  def remove_slice(cell):
    new_slices = [
        NDSlice(ndslice.value, ndslice.ildj, *ndslice.slices[1:])
        for ndslice in cell.slices
    ]
    aval = slice_aval(cell.aval)
    return InverseAndILDJ(aval, new_slices)

  mapped_incells = safe_map(remove_slice, incells)
  mapped_outcells = safe_map(remove_slice, outcells)
  flat_vals, in_tree = tree_util.tree_flatten((mapped_incells, mapped_outcells))
  f, aux = flat_propagate(f, in_tree)
  # Assume all invars as mapped
  new_mapped_invars = (True,) * len(flat_vals)
  new_params = dict(params, mapped_invars=new_mapped_invars)
  subenv_vals = prim.bind(f, *flat_vals, **new_params)
  subenv_tree = aux()
  subenv = tree_util.tree_unflatten(subenv_tree, subenv_vals)
  new_incells = [subenv.read(var) for var in subenv.jaxpr.invars]
  new_outcells = [subenv.read(var) for var in subenv.jaxpr.outvars]
  new_incells = [add_slice(v, old_v)
                 for old_v, v in safe_zip(incells, new_incells)]
  new_outcells = [add_slice(v, old_v)
                  for old_v, v in safe_zip(outcells, new_outcells)]
  return new_incells, new_outcells, subenv

def haiku_module(name, nn_module, *, input_shape=None, **kwargs):
    """
    Declare a :mod:`~haiku` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    :param str name: name of the module to be registered.
    :param haiku.Module nn_module: a `haiku` Module which has .init and .apply methods
    :param tuple input_shape: shape of the input taken by the
        neural network.
    :param kwargs: optional keyword arguments to initialize the haiku neural
        network as an alternative to `input_shape`
    :return: a callable with bound parameters that takes an array
        as an input and returns the neural network transformed output
        array.
    """
    try:
        import haiku  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "It looks like you want to use haiku to declare "
            "nn modules. This is an experimental feature. "
            "You need to install `haiku` to be able to use this feature. "
            "It can be installed with `pip install dm-haiku`.") from e

    module_key = name + '$params'
    nn_params = numpyro.param(module_key)
    if nn_params is None:
        # feed in dummy data to init params
        args = (jnp.ones(input_shape),) if input_shape is not None else ()
        rng_key = numpyro.prng_key()
        nn_params = nn_module.init(rng_key, *args, **kwargs)
        # haiku init returns an immutable dict
        nn_params = haiku.data_structures.to_mutable_dict(nn_params)
        # we cast it to a mutable one to be able to set priors for parameters
        # make sure that nn_params keep the same order after unflatten
        params_flat, tree_def = tree_flatten(nn_params)
        nn_params = tree_unflatten(tree_def, params_flat)
        numpyro.param(module_key, nn_params)
    return partial(nn_module.apply, nn_params, None)

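# Hedged usage sketch (not part of the source above): registering a dm-haiku
# linear layer inside a NumPyro model via `haiku_module`. Assumes numpyro and
# dm-haiku are installed; `hk.transform` supplies the `.init`/`.apply` pair
# that the helper expects.
import haiku as hk
import numpyro
import numpyro.distributions as dist

def model(x, y=None):
    net = haiku_module("net", hk.transform(lambda x: hk.Linear(1)(x)),
                       input_shape=(1, x.shape[-1]))
    mean = net(x).squeeze(-1)
    numpyro.sample("y", dist.Normal(mean, 1.0), obs=y)
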
def process_call(self, call_primitive, f, tracers, params):
  assert call_primitive.multiple_results
  primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
  nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
  nz_tangents = [type(t) is not Zero for t in tangents]
  if 'name' in params and not config.jax_experimental_name_stack:
    params = dict(params, name=wrap_name(params['name'], 'jvp'))
  f_jvp = jvp_subtrace(f, self.main)
  f_jvp, nz_tangents_out = nonzero_tangent_outputs(f_jvp)
  if isinstance(call_primitive, core.MapPrimitive):
    in_axes = params['in_axes']
    tangent_in_axes = [ax for ax, nz in zip(in_axes, nz_tangents) if nz]
    out_axes_thunk = params['out_axes_thunk']
    # The new thunk depends deterministically on the old thunk and the wrapped
    # function. Any caching already has to include the wrapped function as part
    # of the key, so we only use the previous thunk for equality checks.
    # NOTE: This assumes that the output tangents being zero is a deterministic
    # function of which input tangents were zero.
    @as_hashable_function(closure=(tuple(nz_tangents), out_axes_thunk))
    def new_out_axes_thunk():
      out_axes = out_axes_thunk()
      return (*out_axes, *(ax for ax, nz in zip(out_axes, nz_tangents_out())
                           if nz))
    params = dict(params, in_axes=(*in_axes, *tangent_in_axes),
                  out_axes_thunk=new_out_axes_thunk)
  f_jvp, out_tree_def = traceable(f_jvp, len(primals), tangent_tree_def)
  update_params = call_param_updaters.get(call_primitive)
  new_params = update_params(params, nz_tangents) if update_params else params
  f_jvp = _update_annotation(f_jvp, f.in_type, nz_tangents)
  result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
  primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
  return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

def _scan(f: Callable[[_Carry, _Input], Tuple[_Carry, _Output]],
          init: _Carry,
          xs: Iterable[_Input]) -> Tuple[_Carry, _Output]:
  """Implements an unrolled version of scan.

  Based on `jax.lax.scan` and has a similar API.

  TODO(schsam): We introduce this function because lax.scan currently has a
  higher peak memory usage than the unrolled version. We will aim to swap this
  out for lax.scan when issue #1273 and related have been resolved.
  """
  carry = init
  ys = []
  flat_xs, tree_def = tree_flatten(xs)
  for flat_x in zip(*flat_xs):
    x = tree_unflatten(tree_def, flat_x)
    carry, y = f(carry, x)
    ys += [y]
  return carry, tree_multimap(lambda *y: np.stack(y), *ys)

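# Minimal sketch of calling the unrolled `_scan` above; its API mirrors
# `jax.lax.scan`, so a running sum over a 1-D array looks like this. It assumes
# the module-level imports `_scan` relies on (`tree_flatten`, `tree_unflatten`,
# `tree_multimap`, `np`).
import numpy as onp

def _cumsum_step(carry, x):
  new_carry = carry + x
  return new_carry, new_carry

total, partial_sums = _scan(_cumsum_step, 0.0, onp.arange(5.0))
# total == 10.0, partial_sums == [0., 1., 3., 6., 10.]
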
def start_tracing_body(self):
  """Called upon starting the tracing of the loop body."""
  # TODO: This is the first part of partial_eval.trace_to_subjaxpr. Share.
  self.trace = self.scope.start_subtrace()
  # The entire state is carried.
  self.carried_state_names = sorted(self.scope._mutable_state.keys())
  for key in self.carried_state_names:
    init_val = self.scope._mutable_state[key]
    flat_init_vals, init_tree = tree_util.tree_flatten(init_val)
    flat_init_avals = safe_map(_BodyTracer.abstractify, flat_init_vals)
    flat_init_pvals = safe_map(pe.PartialVal.unknown, flat_init_avals)
    flat_init_vars = safe_map(self.trace.new_arg, flat_init_pvals)
    self.carried_state_vars[key] = flat_init_vars
    # Set the scope._mutable_state to new tracing variables.
    self.scope._mutable_state[key] = init_tree.unflatten(flat_init_vars)
    self.scope._mutable_state_aval[key] = init_tree.unflatten(flat_init_avals)
    # Make a copy of the initial state by unflattening the flat_init_vals.
    self.carried_state_initial[key] = init_tree.unflatten(flat_init_vals)

  index_var_aval = _BodyTracer.abstractify(0)
  index_var_pval = pe.PartialVal.unknown(index_var_aval)
  self._index_var = self.trace.new_arg(index_var_pval)

def transposed(*args):
  in_primals, out_cts = tree_unflatten(treedef, args)
  in_pvals = [pe.PartialVal.unknown(x.aval) if ad.is_undefined_primal(x)
              else pe.PartialVal.known(x) for x in in_primals]
  primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
  t_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(primal_fun, in_pvals, False)
  dummy_args = [ad.UndefinedPrimal(v.aval) for v in t_jaxpr.invars]
  in_cts = ad.backward_pass(t_jaxpr, reduce_axes, False, consts, dummy_args,
                            out_cts)
  in_cts_ = iter(in_cts)
  in_cts = [next(in_cts_) if ad.is_undefined_primal(x) else ad_util.Zero(x.aval)
            for x in in_primals]
  assert next(in_cts_, None) is None
  in_cts, cell.treedef = tree_flatten(in_cts)
  return in_cts

def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:
  if not self.jvp:
    msg = "No JVP defined for custom_jvp function {} using defjvp."
    raise AttributeError(msg.format(self.__name__))
  args = _resolve_kwargs(self.fun, args, kwargs)
  if self.nondiff_argnums:
    nondiff_argnums = set(self.nondiff_argnums)
    args = tuple(_stop_gradient(x) if i in nondiff_argnums else x
                 for i, x in enumerate(args))
    diff_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
    f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), diff_argnums, args)
    static_args = [args[i] for i in self.nondiff_argnums]
    jvp = _add_args(lu.wrap_init(self.jvp), static_args)
  else:
    f_, dyn_args = lu.wrap_init(self.fun), args
    jvp = lu.wrap_init(self.jvp)
  args_flat, in_tree = tree_flatten(dyn_args)
  flat_fun, out_tree1 = flatten_fun_nokwargs(f_, in_tree)
  flat_jvp, out_tree2 = _flatten_jvp(jvp, in_tree)
  out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
  _, out_tree = lu.merge_linear_aux(out_tree1, out_tree2)
  return tree_unflatten(out_tree, out_flat)

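# For context, a hedged sketch of the public API that this `__call__` serves:
# defining a function with `jax.custom_jvp` and a matching `defjvp` rule.
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
  return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = log1pexp(x)
  return ans, x_dot * jax.nn.sigmoid(x)

print(jax.grad(log1pexp)(10.0))  # ~1.0, with a numerically stable derivative
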
def linearize(traceable, *primals, **kwargs):
  has_aux = kwargs.pop('has_aux', False)
  if not has_aux:
    jvpfun = jvp(traceable)
  else:
    jvpfun, aux = jvp(traceable, has_aux=True)

  in_pvals = (tuple(pe.PartialVal.known(p) for p in primals)
              + tuple(pe.PartialVal.unknown(get_aval(p).at_least_vspace())
                      for p in primals))
  _, in_tree = tree_flatten(((primals, primals), {}))
  jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
  assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
  _, out_primals_consts = unzip2(out_primals_pvals)
  jaxpr.invars = jaxpr.invars[len(primals):]
  jaxpr.outvars = jaxpr.outvars[len(out_primals_pvals):]
  if not has_aux:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts
  else:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux()

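# Hedged sketch of the user-facing wrapper around this machinery,
# `jax.linearize`: it returns the primal output plus a function that evaluates
# the JVP at the linearization point.
import jax
import jax.numpy as jnp

y, f_jvp = jax.linearize(jnp.sin, 1.0)
print(y)           # sin(1.0)
print(f_jvp(0.5))  # 0.5 * cos(1.0)
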
def psum(x, axis_name, *, axis_index_groups=None):
  """Compute an all-reduce sum on ``x`` over the pmapped axis ``axis_name``.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  Inputs of boolean dtype are converted to integers before the reduction.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would perform psums over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and all groups must be the same size.

  Returns:
    Array(s) with the same shape as ``x`` representing the result of an
    all-reduce sum along the axis ``axis_name``.

  For example, with 4 XLA devices available:

  >>> x = np.arange(4)
  >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [6 6 6 6]
  >>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [ 0. 0.16666667 0.33333334 0.5 ]
  """
  _validate_axis_index_groups(axis_index_groups)
  leaves, treedef = tree_util.tree_flatten(x)
  leaves = [lax.convert_element_type(l, onp.int32)
            if dtypes.dtype(l) == onp.bool_ else l for l in leaves]
  out_flat = psum_p.bind(*leaves, axis_name=axis_name,
                         axis_index_groups=axis_index_groups)
  return tree_util.tree_unflatten(treedef, out_flat)

def update(updates, state, params=None):
  inner_state = state.inner_state
  flat_updates = tree_flatten(updates)[0]
  isfinite = jnp.all(
      jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates]))
  notfinite_count = jnp.where(
      isfinite, jnp.zeros([], jnp.int64), 1 + state.notfinite_count)

  def do_update(_):
    return inner.update(updates, inner_state, params)

  def reject_update(_):
    return (tree_map(jnp.zeros_like, updates), inner_state)

  updates, new_inner_state = lax.cond(
      jnp.logical_or(isfinite, notfinite_count > max_consecutive_errors),
      do_update, reject_update, operand=None)

  return updates, ApplyIfFiniteState(
      notfinite_count=notfinite_count,
      last_finite=isfinite,
      total_notfinite=jnp.logical_not(isfinite) + state.total_notfinite,
      inner_state=new_inner_state)

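# Hedged usage sketch: the `update` above is the inner rule behind
# `optax.apply_if_finite`, which skips steps whose gradients contain NaN/inf
# and gives up after a fixed number of consecutive bad steps.
import optax

optimizer = optax.apply_if_finite(optax.adam(1e-3), max_consecutive_errors=5)
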
def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None):
    """
    Vectorizing map that maps a function `fn` over `batch_ndims` leading axes
    of `xs`. This uses jax.vmap over smaller chunks of the batch dimensions
    to keep memory usage constant.

    :param callable fn: The function to map over.
    :param xs: JAX pytree (e.g. an array, a list/tuple/dict of arrays,...)
    :param int batch_ndims: The number of leading dimensions of `xs`
        over which to apply `fn` element-wise.
    :param int chunk_size: Size of each chunk of `xs`.
        Defaults to the size of batch dimensions.
    :returns: output of `fn(xs)`.
    """
    flatten_xs = tree_flatten(xs)[0]
    batch_shape = np.shape(flatten_xs[0])[:batch_ndims]
    for x in flatten_xs[1:]:
        assert np.shape(x)[:batch_ndims] == batch_shape

    # we'll do map(vmap(fn), xs) and make xs.shape = (num_chunks, chunk_size, ...)
    num_chunks = batch_size = int(np.prod(batch_shape))
    prepend_shape = (-1,) if batch_size > 1 else ()
    xs = tree_map(lambda x: jnp.reshape(x, prepend_shape + jnp.shape(x)[batch_ndims:]), xs)
    # XXX: probably for the default behavior with chunk_size=None,
    # it is better to catch OOM error and reduce chunk_size by half until OOM disappears.
    chunk_size = batch_size if chunk_size is None else min(batch_size, chunk_size)
    if chunk_size > 1:
        pad = chunk_size - (batch_size % chunk_size)
        xs = tree_map(lambda x: jnp.pad(x, ((0, pad),) + ((0, 0),) * (np.ndim(x) - 1)), xs)
        num_chunks = batch_size // chunk_size + int(pad > 0)
        prepend_shape = (-1,) if num_chunks > 1 else ()
        xs = tree_map(lambda x: jnp.reshape(x, prepend_shape + (chunk_size,) + jnp.shape(x)[1:]), xs)
        fn = vmap(fn)

    ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
    map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
    ys = tree_map(lambda y: jnp.reshape(y, (-1,) + jnp.shape(y)[map_ndims:])[:batch_size], ys)
    return tree_map(lambda y: jnp.reshape(y, batch_shape + jnp.shape(y)[1:]), ys)

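# Minimal usage sketch for `soft_vmap` above (also exported as
# `numpyro.util.soft_vmap`): map a per-row function over 10,000 rows in chunks
# of 1,000 to cap peak memory.
import jax.numpy as jnp

xs = jnp.arange(40_000.0).reshape(10_000, 4)
ys = soft_vmap(lambda row: jnp.sum(row ** 2), xs, batch_ndims=1, chunk_size=1_000)
assert ys.shape == (10_000,)
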
def _call_tf_impl(*args_jax_flat, args_treedef, func_tf, out_avals, **_):
  # On GPU we use dlpack to avoid copies of data to the host.
  def _arg_jax_to_tf(arg_jax):
    if (isinstance(arg_jax, xla.DeviceArray) and
        arg_jax.device_buffer.client.platform in _DLPACK_PLATFORMS and
        arg_jax.dtype in dlpack.SUPPORTED_DTYPES):
      arg_dlpack = jax.dlpack.to_dlpack(arg_jax, take_ownership=False)
      return tf.experimental.dlpack.from_dlpack(arg_dlpack)
    # The following avoids copies to the host on CPU, always for DeviceArray
    # and even for ndarray if they are sufficiently aligned.
    # TODO(necula): on TPU this copies to the host!
    return tf.constant(np.asarray(arg_jax))

  args_tf_flat = tuple(map(_arg_jax_to_tf, args_jax_flat))
  with jax2tf_internal.inside_call_tf():
    # Call in TF eager mode
    res_tf = func_tf(*args_treedef.unflatten(args_tf_flat))
  res_tf_flat, _ = tree_util.tree_flatten(res_tf)
  # TODO(necula): check the result for tree and aval

  def _res_tf_to_jax(res_tf: TfVal, out_aval: core.AbstractValue):
    res_tf, _ = jax2tf_internal._tfval_to_tensor_jax_dtype(
        res_tf, jax_dtype=out_aval.dtype)
    if isinstance(res_tf, tf.Tensor) and res_tf.dtype in dlpack.SUPPORTED_DTYPES:
      res_tf_platform = tf.DeviceSpec.from_string(
          res_tf.backing_device).device_type
      res_jax_platform = res_tf_platform.lower()
      if res_jax_platform in _DLPACK_PLATFORMS:
        res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf)
        return jax.dlpack.from_dlpack(
            res_dlpack, backend=xla_bridge.get_backend(res_jax_platform))
    return jnp.asarray(np.asarray(res_tf))

  return list(map(_res_tf_to_jax, res_tf_flat, out_avals))

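# Hedged context sketch: this is the eager implementation behind
# `jax.experimental.jax2tf.call_tf`, which wraps a TensorFlow callable so it
# can be invoked from JAX code. Assumes TensorFlow is installed.
import tensorflow as tf
from jax.experimental import jax2tf

tf_sin = jax2tf.call_tf(tf.math.sin)
y = tf_sin(1.0)  # evaluates tf.math.sin through the implementation above
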
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:  # pytype: disable=invalid-annotation
  if not self.fwd or not self.bwd:
    msg = "No VJP defined for custom_vjp function {} using defvjp."
    raise AttributeError(msg.format(self.__name__))
  args = _resolve_kwargs(self.fun, args, kwargs)
  if self.nondiff_argnums:
    for i in self.nondiff_argnums:
      _check_for_tracers(args[i])
    nondiff_argnums = set(self.nondiff_argnums)
    dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
    f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args,
                                   require_static_args_hashable=False)
    static_args = [args[i] for i in self.nondiff_argnums]
    fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args,
                             require_static_args_hashable=False)
    bwd = _add_args(lu.wrap_init(self.bwd), static_args)
  else:
    f_, dyn_args = lu.wrap_init(self.fun), args
    fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
  args_flat, in_tree = tree_flatten(dyn_args)
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
  flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
  flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
  flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
  out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, *args_flat,
                                    out_trees=out_trees)
  fst, aux = lu.merge_linear_aux(out_tree, out_trees)
  out_tree = aux if fst else aux[0]
  return tree_unflatten(out_tree, out_flat)

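# For context, a hedged sketch of the public `jax.custom_vjp` API that this
# `__call__` serves: a gradient-clipping identity with explicit fwd/bwd rules.
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_gradient(x):
  return x  # identity in the forward pass

def clip_gradient_fwd(x):
  return x, None  # no residuals needed

def clip_gradient_bwd(_, g):
  return (jnp.clip(g, -1.0, 1.0),)

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
print(jax.grad(lambda x: 10.0 * clip_gradient(x))(0.0))  # cotangent 10 clipped to 1.0
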
def process_higher_order_primitive(self, trace, call_primitive, f, tracers,
                                   params, is_map):
  del is_map
  name = jax_util.wrap_name(params.pop('name', f.__name__), 'reap')
  context = trace_util.get_dynamic_context(trace)
  vals = [t.val for t in tracers]
  plants = context.plants
  if 'in_axes' in params:
    # TODO(b/199459308): figure out if invars are mapped or unmapped
    params = dict(
        params,
        in_axes=(0,) * len(tree_util.tree_leaves(plants)) + params['in_axes'])
  if 'donated_invars' in params:
    params = dict(params)
    params['donated_invars'] = (
        (False,) * len(tree_util.tree_leaves(plants)) +
        params['donated_invars'])
  elif call_primitive is nest_p:
    plants = plants.get(params['scope'], {})
  all_vals, all_tree = tree_util.tree_flatten((plants, vals))
  f = plant_eval(f, trace, self.settings, all_tree)
  out_vals = call_primitive.bind(f, *all_vals, name=name, **params)
  return jax_util.safe_map(trace.pure, out_vals)

def _sparsify_jaxpr(spenv, jaxpr, *argspecs):
  # TODO(jakevdp): currently this approach discards all information about
  #   shared data & indices when generating the sparsified jaxpr. The
  #   current approach produces valid sparsified while loops, but they
  #   don't work in corner cases (see associated TODO in sparsify_test.py)
  out_tree = None

  @lu.wrap_init
  def wrapped(*args_flat):
    nonlocal out_tree
    args = tree_unflatten(in_tree, args_flat)
    argspecs = arrays_to_argspecs(spenv, args)
    result = eval_sparse(jaxpr.jaxpr, jaxpr.consts, argspecs, spenv)
    out = argspecs_to_arrays(spenv, result)
    out_flat, out_tree = tree_flatten(out)
    return out_flat

  args = argspecs_to_arrays(spenv, argspecs)
  args_flat, in_tree = tree_flatten(args)
  avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
  sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
  sp_jaxpr = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(sp_jaxpr), consts)
  return sp_jaxpr, out_tree

def jax_shape_for_update(update, shape_like):
    r"""Reshapes a flat update array into a tree-like structure, if necessary
    for the update.

    Args:
        update: a 1d jax/numpy array
        shape_like: an instance having the same type and shape as the
            desired conversion.

    Returns:
        A possibly non-flat structure of jax arrays containing a copy of data
        compatible with the given shape if jax_available and a copy of update
        otherwise
    """

    shf, tree = tree_flatten(shape_like)

    updatelist = []
    k = 0
    for s in shf:
        size = s.size
        updatelist.append(jnp.asarray(update[k:k + size]).reshape(s.shape))
        k += size

    return tree_unflatten(tree, updatelist)

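# Hedged usage sketch for the helper above: unflatten a 1-D update vector back
# into the shape of a parameter pytree (the names here are illustrative only).
import jax.numpy as jnp

example_params = {"b": jnp.zeros(3), "w": jnp.zeros((2, 3))}
flat_update = jnp.arange(9.0)  # 3 + 2*3 = 9 entries
tree_update = jax_shape_for_update(flat_update, example_params)
# tree_update["b"].shape == (3,), tree_update["w"].shape == (2, 3)
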
def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params):
  assert not jaxpr.constvars
  cell = lambda: None

  @lu.wrap_init
  def transposed(*args):
    in_primals, out_cts = tree_unflatten(treedef, args)
    in_pvals = [pe.PartialVal.unknown(x.aval) if ad.is_undefined_primal(x)
                else pe.PartialVal.known(x) for x in in_primals]
    primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
    tangent_jaxpr, _, consts = pe.trace_to_jaxpr(primal_fun, in_pvals, False)
    dummy_args = [ad.UndefinedPrimal(v.aval) for v in tangent_jaxpr.invars]
    in_cts_ = ad.backward_pass(tangent_jaxpr, reduce_axes, False, consts,
                               dummy_args, out_cts)
    in_cts, cell.treedef = tree_flatten(in_cts_)
    return in_cts

  args, treedef = tree_flatten((in_primals, out_cts))
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
  transposed_jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
  transposed_jaxpr = pe.convert_constvars_jaxpr(transposed_jaxpr_)
  in_cts = remat_p.bind(*consts, *args, jaxpr=transposed_jaxpr, **params)
  return tree_unflatten(cell.treedef, in_cts)  # type: ignore

def trainable_params_by_layer(self) -> Dict[str, int]:
    params, _ = tree_flatten(self)
    layer_names = get_keys(self.dict())
    sizes: Dict[str, int] = {}
    if len(layer_names) != len(set(layer_names)):
        counts = Counter(layer_names)
        incrementer: Counter = Counter()
        for weights, name in zip(params, layer_names):
            if counts[name] > 0:
                counts[name] -= 1
                new_name = f"{name}_{incrementer[name]}"
                incrementer[name] += 1
                sizes[new_name] = weights.size
    else:
        for weights, name in zip(params, layer_names):
            sizes[name] = weights.size
    return sizes

def l2_norm_sq(params_tree, parameterization='ntk', W_std=1., b_std=0.05):
  """Parameterization-dependent reweighted L2 norm of `params_tree`.

  Following Lemma 3 of https://arxiv.org/abs/1806.03335, we reweight weight and
  bias parameters according to their initialisation variances, which are
  `parameterization` dependent.

  Args:
    params_tree (pytree): Pytree of parameters to take L2 norm of
    parameterization (str): 'ntk' or 'standard'
    W_std (float): Weight standard deviation
    b_std (float): Bias standard deviation

  Returns:
    l2_norm (float): Weighted L2 norm of `params_tree`
  """
  leaves, _ = tree_flatten(params_tree)
  # In NTK parameterisation all parameters have variance 1: easy reweighting.
  if parameterization == 'ntk':
    return sum(np.vdot(leaf, leaf) for leaf in leaves)
  # In standard parameterization, we need to multiply by `N/W_std^2` for
  # weights and `1/b_std^2` for biases, where N is the width.
  # NB this only works for stax.Dense MLPs right now.
  # This is ugly and fragile with respect to the arbitrary choice between the
  # weight matrix and its transpose.
  elif parameterization == 'standard':
    reg_list = []
    for leaf in leaves:
      assert leaf.ndim == 1 or leaf.ndim == 2
      if leaf.ndim == 1:
        reg_list.append(1 / b_std**2)
      elif leaf.ndim == 2:
        reg_list.append(leaf.shape[0] / W_std**2)
    return sum(reg_coef * np.vdot(leaf, leaf)
               for reg_coef, leaf in zip(reg_list, leaves))

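# Hedged usage sketch: the weighted penalty for a two-layer stax.Dense-style
# parameter pytree of (W, b) pairs under standard parameterization. Assumes the
# module-level `np`/`tree_flatten` imports that `l2_norm_sq` relies on.
import jax.numpy as jnp

example_params = [(jnp.ones((4, 8)), jnp.zeros(8)),   # layer 1: W is (in, out), b is (out,)
                  (jnp.ones((8, 1)), jnp.zeros(1))]   # layer 2
penalty = l2_norm_sq(example_params, parameterization='standard',
                     W_std=1.0, b_std=0.05)
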
def _calc_replica_ids(global_mesh: pxla.Mesh, mesh_axes: MeshAxes):
  pspec = _canonicalize_mesh_axes(mesh_axes)
  mesh_values = list(global_mesh.shape.values())
  flattened_pspec, _ = tree_flatten(tuple(pspec))
  # Get the location (coordinates) of each device in the device mesh.
  device_location = np.array(
      np.unravel_index([d.id for d in global_mesh.devices.flat], mesh_values))
  # Find all the axes that were replicated.
  # If mesh_axes = (('x', 'y'), None, 'z') and ('x', 'y', 'z') were the mesh's
  # axes, then the replicated axes will be None since all axes are being used
  # to shard the input.
  replicated_axis = np.isin(list(global_mesh.shape.keys()), flattened_pspec,
                            invert=True)
  # If all elements in replicated_axis are False then the input is fully
  # sharded, so replica ids should be all 0s.
  if not any(replicated_axis):
    return [0] * global_mesh.devices.size
  else:
    # Drop all the sharded axes and find the location of coordinates in a
    # linear array.
    return np.ravel_multi_index(device_location[replicated_axis],
                                np.array(mesh_values)[replicated_axis])

def custom_transpose_transpose_rule(cts, *args, call, transpose, out_types,
                                    res_tree, lin_tree, out_tree):
  call_in_tree = treedef_tuple((res_tree, lin_tree))

  # TODO(frostig,mattjj): `lin_arg` indicates the inputs with respect
  # to which we are transposing (via `ad.is_undefined_primal`).
  # Consider passing this information to the custom transpose rule?

  res_arg, lin_arg = tree_unflatten(call_in_tree, args)
  del lin_arg
  assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg))

  cts = [ad_util.zeros_like_aval(ct.aval) if type(ct) is ad_util.Zero else ct
         for ct in cts]
  ct_out = tree_unflatten(out_tree, cts)
  ct_lin = transpose(res_arg, ct_out)
  check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin))
  ct_lin_flat, _ = tree_flatten(
      tree_broadcast(lin_tree, ct_lin, is_leaf=lambda x: x is None),
      is_leaf=lambda x: x is None)
  return [None] * len(tree_leaves(res_arg)) + ct_lin_flat

def numpy_unflatten(self, data, shape_like):
    r"""Attempts a deserialization of the given numpy data.
    This is typically used to deserialize parameters and gradients.

    Args:
        data: a 1d numpy array.
        shape_like: an instance having the same type and shape as the
            desired conversion.

    Returns:
        A possibly non-flat structure of jax arrays containing a copy of data
        compatible with the given shape.
    """
    shf, tree = tree_flatten(shape_like)

    datalist = []
    k = 0
    for s in shf:
        size = s.size
        datalist.append(jax.numpy.asarray(data[k:k + size]).reshape(s.shape))
        k += size

    return tree_unflatten(tree, datalist)

def result(*args, **kwargs):
    expected = function(*args, **kwargs)

    arguments, _ = j_tree_util.tree_flatten((args, kwargs))
    (result,) = j_core.eval_jaxpr(jaxpr, constants, *arguments)
    assert _are_equal(result, expected)

    (result,) = numpy_function(*args, **kwargs)
    assert _are_equal(result, expected)

    try:
        (result,) = numba_function(*args, **kwargs)
    except:
        if not catch_numba:
            raise
        traceback.print_exc()
    else:
        assert _are_equal(result, expected)

    return expected

def default_call_interpreter_rule(primitive: jax_core.CallPrimitive,
                                  rules: Rules, state: Value,
                                  invals: Sequence[Value],
                                  call_jaxpr: jax_core.Jaxpr,
                                  **params: Any) -> Tuple[Value, Value]:
  """Handles simple call primitives like `jax_core.call_p`.

  When evaluating call primitives, the input `state` needs to be an additional
  input to the call primitive and it also needs to return an additional output
  `state`. After flattening the state along with the regular inputs, this
  handler recursively calls `eval_jaxpr_with_state` on the primitive's
  `call_jaxpr`. The output state from the recursive call is returned from the
  call primitive.

  Args:
    primitive: A `jax_core.CallPrimitive` such as `jax_core.call_p`.
    rules: A `dict` that maps JAX primitives to functions that take in
      `(state, *args)` and return `(output, new_state)`.
    state: The interpreter `state` value at the time of evaluating the call
      primitive.
    invals: The input values to the call primitive.
    call_jaxpr: The `jax_core.Jaxpr` that corresponds to the body of the call
      primitive.
    **params: The parameters of the call primitive.

  Returns:
    A tuple of the output of the call primitive and its output state.
  """
  # Recursively use the effect handler for the call primitive's jaxpr.
  fun = lu.wrap_init(
      functools.partial(eval_jaxpr_with_state, call_jaxpr, rules, []))

  state_invals, state_invals_tree = tree_util.tree_flatten((state, *invals))
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, state_invals_tree)

  ans_state = primitive.bind(flat_fun, *state_invals, **params)
  return tree_util.tree_unflatten(out_tree(), ans_state)

def update_fn(updates, state, params=None):
  flat_mask, treedef = tree_flatten(mask(updates) if callable(mask) else mask)

  # Flatten, then filter out updates/params not in the mask:
  flat_updates = treedef.flatten_up_to(updates)
  masked_updates = [g for g, m in zip(flat_updates, flat_mask) if m]

  if params is not None:
    flat_params = treedef.flatten_up_to(params)
    masked_params = [p for p, m in zip(flat_params, flat_mask) if m]
  else:
    masked_params = None

  # Compute new updates
  new_masked_updates, new_inner_state = inner.update(
      masked_updates, state.inner_state, masked_params)

  # Incorporate new_masked_updates into flat_updates, then unflatten
  new_masked_updates = iter(new_masked_updates)
  for i, m in enumerate(flat_mask):
    if m:
      flat_updates[i] = next(new_masked_updates)

  new_updates = treedef.unflatten(flat_updates)
  return new_updates, MaskedState(inner_state=new_inner_state)

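# Hedged usage sketch: `update_fn` above backs `optax.masked`, which applies an
# inner transformation only where the mask pytree is True (e.g. decaying
# weights but not biases).
import optax

example_mask = {"w": True, "b": False}
optimizer = optax.masked(optax.add_decayed_weights(1e-4), example_mask)
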
def sow(value, *, tag, name, mode='strict', key=None):
  """Marks a value with a name and a tag.

  Args:
    value: A JAX value to be tagged and named.
    tag (str): a string representing the tag of the sown value.
    name (str): a string representing the name to sow the value with.
    mode (str): The mode by which to sow the value. There are three options:
      1. `strict` - if another value is sown with the same name and tag in the
         same context, harvest will throw an error.
      2. `clobber` - if another value is sown with the same name and tag, it
         will replace this value.
      3. `append` - sown values of the same name and tag are appended to a
         growing list. Append mode assumes some ordering on the values being
         sown defined by data-dependence.
    key: an optional JAX value that will be tied into the sown value.

  Returns:
    The original `value` that was passed in.
  """
  if key is not None:
    value = prim.tie_in(key, value)
  flat_args, in_tree = tree_util.tree_flatten(value)
  out_flat = sow_p.bind(*flat_args, name=name, tag=tag, mode=mode,
                        tree=in_tree)
  return tree_util.tree_unflatten(in_tree, out_flat)

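# Hedged usage sketch: `sow` pairs with a harvest-style `reap` transformation
# (as in `oryx.core.reap`) that collects tagged intermediates by name.
from oryx.core import reap

def f(x):
  y = sow(x ** 2, tag='intermediate', name='y')
  return y + 1.0

print(reap(f, tag='intermediate')(2.0))  # {'y': 4.0}
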
def l2_norm(tree):
  """Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
  leaves, _ = tree_flatten(tree)
  return np.sqrt(sum(np.vdot(x, x) for x in leaves))

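# Minimal usage sketch: a weight-decay penalty from the helper above (assumes
# the module-level `np`/`tree_flatten` imports that `l2_norm` relies on).
import jax.numpy as jnp

example_params = {"w": jnp.ones((3, 3)), "b": jnp.zeros(3)}
weight_decay = 1e-4 * l2_norm(example_params) ** 2  # squared norm times a coefficient
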
def tree_init(x0_tree):
  x0_flat, tree = tree_flatten(x0_tree)
  initial_states = [init(x0) for x0 in x0_flat]
  states_flat, subtrees = unzip2(map(tree_flatten, initial_states))
  packed_state = pack(map(pack, states_flat))
  return OptimizerState(packed_state, tree, subtrees)