def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
    """Handler for call_primitives, like jit or layer_call.

    When an UnzipTracer hits a call primitive, there is either a variable
    inside of the call primitive, in which case the input function needs to be
    unzipped into two, or there are no variables in the function, so the
    call_primitive is recorded in the trace as-is. We use `unzip_eval`, whose
    auxiliary output reports whether or not an unzip was successful. If it was
    successful, we record two new Jaxprs into the trace (one for init, one for
    apply). Otherwise, we just record the Jaxpr corresponding to the function
    call.

    Args:
      call_primitive: a call primitive like xla_call
      f: a jax.linear_util wrapped function to be called
      tracers: inputs to the function
      params: parameters of the primitives
      is_map: whether or not the primitive is a map primitive (e.g. xla_pmap)

    Returns:
      A list of output tracers

    Raises:
      NotImplementedError: if `call_primitive` has no custom rule but is
        registered in `pe.call_partial_eval_rules`.
    """
    # Fall back to the wrapped function's own name when no 'name' param is set.
    name = params.get('name', f.__name__)
    settings = trace_util.get_dynamic_context(self).settings
    tracers = safe_map(self.instantiate_const_abstracted, tracers)
    # A registered custom rule takes over handling of this primitive entirely.
    if call_primitive in current_custom_rules():
        return current_custom_rules()[call_primitive](self, f, *tracers,
                                                      **params)
    if call_primitive in pe.call_partial_eval_rules:
        raise NotImplementedError
    in_pvals = [t.pval for t in tracers]
    if is_map:
        unknown = pe.PartialVal.unknown
        in_pvals = [
            pval if pval.is_known() or in_axis is None else unknown(
                mapped_aval(params['axis_size'], in_axis, pval[0]))
            for pval, in_axis in zip(in_pvals, params['in_axes'])
        ]
    # NOTE(review): `in_pvals` is not read again below; `pvs` is rebuilt from
    # `t.pval` directly. Confirm whether the mapped-aval adjustment above was
    # meant to feed `unzip_eval`.
    pvs, in_consts = jax_util.unzip2(t.pval for t in tracers)
    keys = tuple(t.is_key() for t in tracers)
    new_settings = UnzipSettings(settings.tag, call_primitive in block_registry)
    fun, aux = unzip_eval(f, self, keys, tuple(pvs), new_settings)
    out_flat = call_primitive.bind(fun, *in_consts, **params)
    # `aux` may only be queried after `bind` has run the wrapped function.
    success, results = aux()
    if not success:
        # No unzip happened: record the call as a single Jaxpr.
        out_pvs, out_keys, jaxpr, env = results
        out_pv_consts, consts = jax_util.split_list(out_flat, [len(out_pvs)])
        out_tracers = self._bound_output_tracers(call_primitive, params, jaxpr,
                                                 consts, env, tracers, out_pvs,
                                                 out_pv_consts, out_keys, name,
                                                 is_map)
        return out_tracers
    # Unzip succeeded: `results` carries paired (init, apply) information.
    init_name = jax_util.wrap_name(name, 'init')
    apply_name = jax_util.wrap_name(name, 'apply')
    init_pvs, num_init_consts, apply_pvs = results[0]
    init_jaxpr, apply_jaxpr = results[1]
    init_env, apply_env = results[2]
    variable_names, variable_tree, apply_keys = results[3]
    key_tracers = [t for t in tracers if t.is_key()]
    abstract_tracers = [t for t in tracers if not t.is_key()]
    # `out_flat` is laid out as [init pv consts, init consts, apply pv consts,
    # apply consts]; peel the pieces apart in that order.
    all_init_consts, all_apply_consts = jax_util.split_list(
        out_flat, [len(init_pvs) + num_init_consts])
    init_pv_consts, init_consts = jax_util.split_list(
        all_init_consts, [len(init_pvs)])
    apply_pv_consts, apply_consts = jax_util.split_list(
        all_apply_consts, [len(apply_pvs)])
    # The init call receives only the key tracers, and every one of its
    # outputs is marked as a key.
    variable_tracers = self._bound_output_tracers(call_primitive, params,
                                                  init_jaxpr, init_consts,
                                                  init_env, key_tracers,
                                                  init_pvs, init_pv_consts,
                                                  [True] * len(init_pvs),
                                                  init_name, is_map)
    unflat_variables = tree_util.tree_unflatten(variable_tree, variable_tracers)
    if call_primitive is harvest.nest_p:
        # For `nest_p`, all variables are sown together as one dict under the
        # nest's scope name.
        variable_dict = harvest.sow(dict(
            safe_zip(variable_names, unflat_variables)),
                                    tag=settings.tag,
                                    name=params['scope'],
                                    mode='strict')
        unflat_variables = tuple(variable_dict[name] for name in variable_names)
    else:
        # Otherwise each variable is sown individually under its own name.
        unflat_variables = [
            harvest.sow(  # pylint: disable=g-complex-comprehension
                unflat_variable, tag=settings.tag, name=name, mode='strict')
            for unflat_variable, name in safe_zip(
                unflat_variables, variable_names)
        ]
    variable_tracers = tree_util.tree_leaves(unflat_variables)
    # The apply call receives the sown variables followed by the non-key inputs.
    out_tracers = self._bound_output_tracers(
        call_primitive, params, apply_jaxpr, apply_consts, apply_env,
        variable_tracers + abstract_tracers, apply_pvs, apply_pv_consts,
        apply_keys, apply_name, is_map)
    return out_tracers
def params(self):
    """Rebuild the parameter pytree from the stored treedef and flat leaves."""
    treedef, leaves = self.params_tree, self.params_flat
    return tree_util.tree_unflatten(treedef, leaves)
def unflatten_tree(tree, xs):
    """Inverse operation of `flatten_tree`.

    Arranges the flat leaves `xs` into the pytree structure of `tree`.
    """
    structure = tree_util.tree_structure(tree)
    return tree_util.tree_unflatten(structure, xs)
def structured(self):
    """Return the unflattened pytree, building and caching it on first use."""
    if self._structured is not None:
        return self._structured
    rebuilt = tree_util.tree_unflatten(self.treedef, self.leaves)
    self._structured = rebuilt
    return rebuilt
def block(arrays):
    """Apply `jnp.block` to `arrays` and wrap the result in a `JaxArray`.

    `JaxArray` instances are treated as pytree leaves and replaced by their
    `.value` before the nested structure is handed to `jnp.block`.
    """

    def _is_jax_array(node):
        return isinstance(node, JaxArray)

    flat, structure = tree_flatten(arrays, is_leaf=_is_jax_array)
    unwrapped = []
    for leaf in flat:
        unwrapped.append(leaf.value if _is_jax_array(leaf) else leaf)
    return JaxArray(jnp.block(tree_unflatten(structure, unwrapped)))
def flat_propagate(tree, *flat_invals):
    """Generator that unflattens its inputs and flattens the value sent back.

    First yields ``((invals, outvals), {})`` rebuilt from `flat_invals` using
    `tree`; the value sent back in is flattened and yielded as
    ``(leaves, treedef)``.
    """
    unflattened = tree_util.tree_unflatten(tree, flat_invals)
    invals, outvals = unflattened
    subenv = yield (invals, outvals), {}
    leaves, structure = tree_util.tree_flatten(subenv)
    yield leaves, structure
def testRoundtripWithFlattenUpTo(self, inputs):
    """Checks that `flatten_up_to` followed by unflatten reproduces `inputs`."""
    treedef = tree_util.tree_flatten(inputs)[1]
    leaves = treedef.flatten_up_to(inputs)
    rebuilt = tree_util.tree_unflatten(treedef, leaves)
    self.assertEqual(rebuilt, inputs)
def doit():
    """Run the enclosing `fun` on `args` through `_interpret_fun`.

    The arguments are flattened (with empty kwargs), the flat function is
    interpreted, and the flat outputs are restored to the output structure
    of `fun`.
    """
    wrapped = lu.wrap_init(fun)
    flat_args, args_tree = tree_util.tree_flatten((args, {}))
    flat_fun, out_tree = flatten_fun(wrapped, args_tree)
    flat_out = _interpret_fun(flat_fun, flat_args)
    # The output treedef only becomes available once the function has run.
    out_structure = out_tree()
    return tree_util.tree_unflatten(out_structure, flat_out)
def scan(f, init, xs):
    """Scan a function over leading array axes while carrying along state.

    The type signature in brief is

    .. code-block:: haskell

      scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

    where we use [t] here to denote the type t with an additional leading axis.
    That is, if t is an array type then [t] represents the type with an
    additional leading axis, and if t is a pytree (container) type with array
    leaves then [t] represents the type with the same pytree structure and
    corresponding leaves each with an additional leading axis.

    When both ``a`` and ``b`` are array types, the semantics of ``scan`` are
    given by this Python implementation::

      def scan(f, init, xs):
        carry = init
        ys = []
        for x in xs:
          carry, y = f(carry, x)
          ys.append(y)
        return carry, np.stack(ys)

    Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
    types, and so multiple arrays can be scanned over at once and produce
    multiple output arrays.

    Also unlike that Python version, ``scan`` is a JAX primitive and is lowered
    to a single XLA While HLO. That makes it useful for reducing compilation
    times for jit-compiled functions, since native Python loop constructs in an
    ``@jit`` function are unrolled, leading to large XLA computations.

    Args:
      f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
        that ``f`` accepts two arguments where the first is a value of the loop
        carry and the second is a slice of ``xs`` along its leading axis, and
        that ``f`` returns a pair where the first element represents a new
        value for the loop carry and the second represents a slice of the
        output.
      init: an initial loop carry value of type ``c``, which can be a scalar,
        array, or any pytree (nested Python tuple/list/dict) thereof,
        representing the initial loop carry value.
      xs: the value of type ``[a]`` over which to scan along the leading axis,
        where ``[a]`` can be an array or any pytree (nested Python
        tuple/list/dict) thereof with consistent leading axis sizes.

    Returns:
      A pair of type ``(c, [b])`` where the first element represents the final
      loop carry value and the second element represents the stacked outputs
      of the second output of ``f`` when scanned over the leading axis of the
      inputs.

    Raises:
      ValueError: if ``xs`` has no leaves, if any leaf of ``xs`` has no leading
        axis (no ``shape`` attribute, or a rank-0 shape), or if the leaves of
        ``xs`` disagree on the size of their leading axis.
      TypeError: if the carry returned by ``f`` does not have the same abstract
        type as ``init``.
    """
    num_carry = len(tree_flatten(init)[0])
    in_flat, in_tree = tree_flatten((init, xs))
    init_flat, xs_flat = in_flat[:num_carry], in_flat[num_carry:]

    # The scan length is inferred from `xs`, so there must be something to
    # infer it from; report that directly rather than as a size mismatch.
    if not xs_flat:
        raise ValueError(
            "scan got no values to scan over: `xs` has no leaves.")
    # Leaves without a `shape`, and rank-0 leaves whose shape is `()`, both
    # lack a leading axis; the latter would otherwise surface as IndexError.
    no_leading_axis = [x for x in xs_flat if not getattr(x, 'shape', ())]
    if no_leading_axis:
        msg = "scan got value with no leading axis to scan over: {}."
        raise ValueError(msg.format(no_leading_axis))
    leading_sizes = [x.shape[0] for x in xs_flat]
    if len(set(leading_sizes)) > 1:
        msg = "scan got values with different leading axis sizes: {}."
        raise ValueError(msg.format(leading_sizes))
    length = leading_sizes[0]

    carry_avals = tuple(_map(_abstractify, init_flat))
    # Per-iteration avals: drop the scanned (leading) axis from each input.
    x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
    x_dtypes = [x.dtype for x in xs_flat]
    x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))
    jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree,
                                                   carry_avals + x_avals)

    carry_avals_out, y_avals = split_list(jaxpr.out_avals, [num_carry])
    if tuple(carry_avals_out) != carry_avals:
        msg = "scan carry output type must match carry input type, got {} and {}."
        raise TypeError(msg.format(tuple(carry_avals_out), carry_avals))

    out = scan_p.bind(*itertools.chain(consts, in_flat),
                      forward=True,
                      length=length,
                      jaxpr=jaxpr,
                      num_consts=len(consts),
                      num_carry=num_carry,
                      linear=(False, ) * (len(consts) + len(in_flat)))
    return tree_unflatten(out_tree, out)
def checked_fun(*args, **kwargs):
    """Run the enclosing `fun` through `checkify_flat`.

    Returns:
      A pair ``(Error(err, code, msgs), out)`` where ``out`` is the function
      output restored to its pytree structure.
    """
    flat_inputs, inputs_tree = tree_flatten((args, kwargs))
    flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), inputs_tree)
    checked, msgs = checkify_flat(flat_fun, errors, *flat_inputs)
    err, code, out_flat = checked
    result = tree_unflatten(out_tree(), out_flat)
    return Error(err, code, msgs), result
def unravel_pytree(arr):
    """Split `arr` with `unravel_list` and rebuild the pytree from `treedef`."""
    leaves = unravel_list(arr)
    return tree_unflatten(treedef, leaves)
def _wrapped(*args):
    """Call the enclosing `fun` through `sparsify_fun` on flattened arguments.

    Arguments are flattened with `_is_bcoo` as the leaf predicate, and the
    flat result is restored to the output structure of `fun`.
    """
    flat_args, args_tree = tree_flatten(args, is_leaf=_is_bcoo)
    flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), args_tree)
    flat_out = sparsify_fun(flat_fun, flat_args)
    return tree_unflatten(out_tree(), flat_out)
def jet(fun, primals, series):
    r"""Taylor-mode higher-order automatic differentiation.

    Args:
      fun: Function to be differentiated. Its arguments should be arrays,
        scalars, or standard Python containers of arrays or scalars. It should
        return an array, scalar, or standard Python container of arrays or
        scalars.
      primals: The primal values at which the Taylor approximation of ``fun``
        should be evaluated. Should be either a tuple or a list of arguments,
        and its length should be equal to the number of positional parameters
        of ``fun``. Each entry must currently be a single array or scalar.
      series: Higher order Taylor-series-coefficients. Together, `primals` and
        `series` make up a truncated Taylor polynomial. Should be either a
        tuple or a list of tuples or lists, one per argument, all of the same
        length; that common length dictates the degree of the truncated Taylor
        polynomial. Each term must currently be a single array or scalar.

    Returns:
      A ``(primals_out, series_out)`` pair, where ``primals_out`` is
      ``fun(*primals)``, and together, ``primals_out`` and ``series_out`` are a
      truncated Taylor polynomial of :math:`f(h(\cdot))`. Both values are
      rebuilt using the Python tree structure of the output of ``fun``.

    Raises:
      ValueError: if the entries of ``series`` have inconsistent lengths, or
        if any primal or any series term is not a pytree leaf.

    For example:

    >>> import jax
    >>> import jax.numpy as np

    Consider the function :math:`h(z) = z^3`, :math:`x = 0.5`,
    and the first few Taylor coefficients
    :math:`h_0=x^3`, :math:`h_1=3x^2`, and :math:`h_2=6x`.
    Let :math:`f(y) = \sin(y)`.

    >>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
    >>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)

    :func:`jet` returns the Taylor coefficients of :math:`f(h(z)) = \sin(z^3)`
    according to Faà di Bruno's formula:

    >>> f0, (f1, f2) =  jet(f, (h0,), ((h1, h2),))
    >>> print(f0,  f(h0))
    0.12467473 0.12467473

    >>> print(f1, df(h0) * h1)
    0.7441479 0.74414825

    >>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
    2.9064622 2.9064634
    """
    # Every argument must supply the same number of Taylor terms; that shared
    # count is the order of the expansion.
    try:
        term_counts = set(map(len, series))
        order, = term_counts
    except ValueError:
        raise ValueError(
            "jet terms have inconsistent lengths for different arguments"
        ) from None

    # TODO(mattjj): consider supporting pytree inputs
    for arg_idx, (primal, terms) in enumerate(zip(primals, series)):
        if not treedef_is_leaf(tree_structure(primal)):
            raise ValueError(
                f"primal value at position {arg_idx} is not an array")
        for term_idx, term in enumerate(terms):
            if not treedef_is_leaf(tree_structure(term)):
                raise ValueError(
                    f"term {term_idx} for argument {arg_idx} is not an array")

    @lu.transformation_with_aux
    def flatten_fun_output(*args):
        ans = yield args, {}
        yield tree_flatten(ans)

    f, out_tree = flatten_fun_output(lu.wrap_init(fun))
    jet_traced = jet_fun(jet_subtrace(f), order)
    out_primals, out_terms = jet_traced.call_wrapped(primals, series)
    primals_out = tree_unflatten(out_tree(), out_primals)
    series_out = tree_unflatten(out_tree(), out_terms)
    return primals_out, series_out
def sow_unzip(in_tracers, out_tracers, name=None, tree=None, tag=None, **_):
    """Return ``(name, in_tracers, out_tracers)`` for a sow.

    When `tree` is truthy, both tracer lists are unflattened with it first.
    `tag` and any extra keyword arguments are ignored.
    """
    del tag
    if not tree:
        return name, in_tracers, out_tracers
    structured_in = tree_util.tree_unflatten(tree, in_tracers)
    structured_out = tree_util.tree_unflatten(tree, out_tracers)
    return name, structured_in, structured_out
def _optimization_barrier(arg):
    """Bind `optimization_barrier_p` over the leaves of `arg`.

    The primitive's outputs are returned in the pytree structure of `arg`.
    """
    leaves, structure = tree_flatten(arg)
    barriered = optimization_barrier_p.bind(*leaves)
    return tree_unflatten(structure, barriered)
def tree_get_params(opt_state):
    """Extract the parameter pytree from a flattened optimizer state.

    `opt_state` unpacks into per-leaf flat states, the outer parameter treedef
    and one treedef per leaf state. Each leaf state is rebuilt, passed through
    `get_params`, and the results are arranged in the outer tree.
    """
    flat_states, outer_tree, inner_trees = opt_state
    leaf_states = map(tree_unflatten, inner_trees, flat_states)
    leaf_params = map(get_params, leaf_states)
    return tree_unflatten(outer_tree, leaf_params)
def wrapped(*args, **params):
    """Evaluate the enclosing `f_raw` on argspecs in a fresh `SparseEnv`.

    Inputs are converted to argspecs, `f_raw` is called on them, and the
    resulting argspecs are converted back to arrays and unflattened with the
    output tree that `f_raw` returns.
    """
    env = SparseEnv()
    in_specs = arrays_to_argspecs(env, args)
    out_specs, out_tree = f_raw(env, *in_specs, **params)
    out_arrays = argspecs_to_arrays(env, out_specs)
    return tree_unflatten(out_tree, out_arrays)
def _cau_jaxpr(self, *args, **kwargs):
    """Evaluate the stored jaxpr on the leaves of `args`.

    Keyword arguments are forwarded to `eval_jaxpr_with_kwargs`; the flat
    outputs are returned in the structure of `self._out_tree`.
    """
    leaves = tree_util.tree_leaves(args)
    closed = self._jaxpr
    flat_out = eval_jaxpr_with_kwargs(closed.jaxpr, closed.literals, *leaves,
                                      **kwargs)
    return tree_util.tree_unflatten(self._out_tree, flat_out)
def tree_unflatten(cls, meta, data):
    """Build a `FlatCache` from a treedef (`meta`) and its children (`data`).

    If `data` is not made up entirely of leaves, it is first unflattened with
    `meta` and flattened again so the cache stores true leaves together with
    the matching treedef.
    """
    if tree_util.all_leaves(data):
        return FlatCache(None, leaves=data, treedef=meta)
    nested = tree_util.tree_unflatten(meta, data)
    data, meta = tree_util.tree_flatten(nested)
    return FlatCache(None, leaves=data, treedef=meta)
def todo(x):
    """Rewrap flat outputs as `JetTracer`s at the current sublevel.

    `x` is unflattened with the enclosing `treedef` into primals and series,
    which are paired up into tracers of a new `JetTrace`.
    """
    primals, series = tree_unflatten(treedef, x)
    make_tracer = partial(JetTracer, JetTrace(main, core.cur_sublevel()))
    return map(make_tracer, primals, series)
def testRoundtripIsLeaf(self, tree):
    """Checks flatten/unflatten round-trips when tuples are treated as leaves."""

    def tuple_is_leaf(node):
        return isinstance(node, tuple)

    leaves, treedef = tree_util.tree_flatten(tree, is_leaf=tuple_is_leaf)
    rebuilt = tree_util.tree_unflatten(treedef, leaves)
    self.assertEqual(rebuilt, tree)
def traceable(in_tree_def, *primals_and_series):
    """Generator adapting a flat calling convention to (primals, series) pairs.

    First yields ``((primals_in, series_in), {})`` rebuilt from the flat
    inputs; the ``(primals_out, series_out)`` pair sent back is flattened and
    yielded as ``(leaves, treedef)``.
    """
    unflattened = tree_unflatten(in_tree_def, primals_and_series)
    primals_in, series_in = unflattened
    result = yield (primals_in, series_in), {}
    primals_out, series_out = result
    flat_out, out_structure = tree_flatten((primals_out, series_out))
    yield flat_out, out_structure
def func(flat_args):
    """Call the enclosing `fn` on `flat_args` restored to `in_tree`'s structure."""
    return fn(tree_util.tree_unflatten(in_tree, flat_args))
def _scan_harvest_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                       num_consts, num_carry, linear, unroll):
    """Collects and injects values into/from the scan body.

    Rebuilds the scan so that its body runs under `harvest`: plants for
    'clobber'-mode sows are closed over by the new body, plants for
    'append'-mode sows are passed in as extra scanned inputs, and the values
    reaped inside the body are returned as extra scanned outputs which are
    then re-sown (in 'strict' mode) outside the scan.

    Args:
      trace: the `HarvestTrace` handling the scan.
      *tracers: input tracers, laid out as consts, then carry, then xs.
      length: forwarded to `scan_p.bind`.
      reverse: forwarded to `scan_p.bind`.
      jaxpr: the original scan body; evaluated via its `.jaxpr`/`.literals`.
      num_consts: number of leading entries of `tracers` that are consts.
      num_carry: number of entries following the consts that form the carry.
      linear: per-input linearity flags for the original scan.
      unroll: forwarded to `scan_p.bind`.

    Returns:
      The new carry values followed by the scanned outputs (`carry + ys`).

    Raises:
      ValueError: if any sow found in the body for this tag uses 'strict' mode.
    """
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    values = [t.val for t in tracers]
    consts, init, xs = jax_util.split_list(values, [num_consts, num_carry])
    active_sows = _find_sows(jaxpr, settings.tag)
    active_modes = [params['mode'] for params in active_sows]
    if any(mode == 'strict' for mode in active_modes):
        raise ValueError('Cannot use strict mode in a scan.')
    active_names = [params['name'] for params in active_sows]
    sow_modes = {name: mode for name, mode in zip(active_names, active_modes)}
    # Plants are split by the mode of the sow they target: 'clobber' plants
    # are closed over by `body`, 'append' plants are scanned over with `xs`.
    carry_plants = {
        name: context.plants[name]
        for name in active_names
        if name in context.plants and sow_modes[name] == 'clobber'
    }
    xs_plants = {
        name: context.plants[name]
        for name in active_names
        if name in context.plants and sow_modes[name] == 'append'
    }

    def jaxpr_fun(carry, x):
        # Evaluates the original scan body on (consts, carry, x).
        body_out = jax_core.eval_jaxpr(jaxpr.jaxpr, jaxpr.literals,
                                       *(consts + carry + x))
        carry, y = jax_util.split_list(body_out, [num_carry])
        return carry, y

    harvest_body = harvest(jaxpr_fun,
                           tag=settings.tag,
                           allowlist=settings.allowlist,
                           blocklist=settings.blocklist,
                           mode=settings.mode)

    def body(carry, x):
        # New scan body: merges both kinds of plants for this step and
        # returns the reaps alongside the original per-step output.
        x_plants, x_vals = x
        (carry, y), reaps = harvest_body({
            **carry_plants,
            **x_plants
        }, carry, x_vals)
        return carry, (y, reaps)

    xs_flat = tree_util.tree_leaves((xs_plants, xs))
    x_avals = []
    for x in xs_flat:
        x_aval = jax_core.get_aval(x)
        if x_aval is jax_core.abstract_unit:
            x_avals.append(x_aval)
        else:
            # Per-step aval: drop the leading axis of the scanned input.
            x_shape, x_dtype = masking.padded_shape_as_value(
                x.shape[1:]), x.dtype
            x_avals.append(abstract_arrays.ShapedArray(x_shape, x_dtype))
    x_avals = tuple(x_avals)
    init_avals = tuple(
        abstract_arrays.raise_to_shaped(jax_core.get_aval(a)) for a in init)
    in_flat, in_tree = tree_util.tree_flatten((init, (xs_plants, xs)))
    body_jaxpr, new_consts, out_tree = (
        lax_control_flow._initial_style_jaxpr(  # pylint: disable=protected-access
            body, in_tree, init_avals + x_avals))
    new_values = list(new_consts) + in_flat
    # Whatever is not a new const, a carry or an original x is a planted x.
    num_xs_plants = len(new_values) - len(init) - len(xs) - len(new_consts)
    remaining_linear = linear[num_consts:]
    # Keep the original linearity flags for carry and xs; new consts and
    # planted xs are marked non-linear.
    new_linear = ((False, ) * len(new_consts) + remaining_linear[:len(init)] +
                  (False, ) * num_xs_plants + remaining_linear[len(init):])
    assert len(new_linear) == len(new_values)

    outs = lax.scan_p.bind(*new_values,
                           length=length,
                           reverse=reverse,
                           jaxpr=body_jaxpr,
                           num_consts=len(new_consts),
                           num_carry=num_carry,
                           linear=new_linear,
                           unroll=unroll)
    outs = safe_map(trace.pure, outs)
    carry, (ys, reaps) = tree_util.tree_unflatten(out_tree, outs)
    out_reaps = {}
    for k, val in reaps.items():
        mode = sow_modes.get(k, 'strict')
        # NOTE(review): reaps presumably come back with one entry per scan
        # step along the leading axis — confirm against `harvest`.
        if mode == 'append':
            val = tree_util.tree_map(np.concatenate, val)
        elif mode == 'clobber':
            # Keep only the last entry along the leading axis.
            val = tree_util.tree_map(lambda x: x[-1], val)
        out_reaps[k] = sow(val, tag=settings.tag, name=k, mode='strict')
    (carry, ys) = prim.tie_in(out_reaps, (carry, ys))
    return carry + ys
def lexsort(keys, axis=-1):
    """Apply `jnp.lexsort` to `keys` and wrap the result in a `JaxArray`.

    `JaxArray` instances are treated as pytree leaves and passed through
    `_remove_jaxarray` before sorting.
    """
    flat, structure = tree_flatten(
        keys, is_leaf=lambda x: isinstance(x, JaxArray))
    unwrapped = []
    for leaf in flat:
        unwrapped.append(_remove_jaxarray(leaf))
    raw_keys = tree_unflatten(structure, unwrapped)
    return JaxArray(jnp.lexsort(raw_keys, axis))
def par_from_array(arr):
    """Rebuild a parameter pytree from the array `arr`.

    `arr` is split with `section_sizes`, each piece is reshaped to the
    matching entry of `section_shapes`, and the pieces are arranged according
    to `value_tree`.
    """
    pieces = jnp.split(arr, section_sizes)
    shaped = []
    for piece, shape in zip(pieces, section_shapes):
        shaped.append(piece.reshape(shape))
    return tree_unflatten(value_tree, shaped)
def testRoundtrip(self, inputs):
    """Checks that flattening then unflattening reproduces `inputs`."""
    leaves, treedef = tree_util.tree_flatten(inputs)
    rebuilt = tree_util.tree_unflatten(treedef, leaves)
    self.assertEqual(rebuilt, inputs)
def tree_get_params(opt_state):
    """Extract the parameter pytree from a packed optimizer state.

    `opt_state` unpacks into the packed per-leaf states, the outer parameter
    treedef and one treedef per leaf state. Each leaf state is rebuilt, passed
    through `get_params`, and the results are arranged in the outer tree.
    """
    packed, outer_tree, inner_trees = opt_state
    leaf_states = map(tree_unflatten, inner_trees, packed)
    return tree_unflatten(outer_tree, map(get_params, leaf_states))
def _psum_transpose_rule(cts, axis_name, axis_index_groups):
    """Transpose rule for psum: bind `psum_p` over the leaves of `cts`.

    The results are returned in the pytree structure of `cts`.
    """
    flat_cts, cts_tree = tree_util.tree_flatten(cts)
    summed = psum_p.bind(*flat_cts,
                         axis_name=axis_name,
                         axis_index_groups=axis_index_groups)
    return tree_util.tree_unflatten(cts_tree, summed)
def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
                linear_args):
    """Call a linear function, with a custom implementation for its transpose.

    The type signatures of ``fun`` and ``fun_transpose`` are:

    .. code-block:: haskell

      fun           :: r -> a -o b
      fun_transpose :: r -> b -o a

    where the ``-o`` arrow indicates a linear function, ``r`` is the
    residual input type and ``a`` is the linear input type.

    The functions ``fun`` and ``fun_transpose`` are coupled as
    transposes of one another. Specifically, the transpose of a
    ``linear_call`` primitive is another ``linear_call`` to
    ``fun_transpose``, with ``fun`` as its custom transposition.

    For example:

    >>> def f(r, x):
    ...   return x / r

    >>> def t(r, t):
    ...   return t / r

    >>> def div_add(x, denom):
    ...   return x + linear_call(f, t, denom, x)

    >>> def transpose(f, x_example):
    ...   def transposed(y):
    ...     x, = jax.linear_transpose(f, x_example)(y)
    ...     return x
    ...   return transposed

    >>> div_add(9., 3.)
    DeviceArray(12., dtype=float32, weak_type=True)

    >>> transpose(partial(div_add, denom=3.), 1.)(18.)  # custom
    DeviceArray(24., dtype=float32, weak_type=True)

    >>> transpose(lambda x: x + x / 3., 1.)(18.)  # reference
    DeviceArray(24., dtype=float32, weak_type=True)

    The above definition of ``f`` illustrates the purpose of a residual
    argument: division is linear in one of its inputs (the dividend
    ``x``) but not the other (the divisor ``r``).

    As another example:

    >>> def custom_id(x):
    ...   def f(_, x): return x
    ...   def t(_, t): return 7.
    ...   return linear_call(f, t, (), x)
    >>> custom_id(1.)
    1.0
    >>> transpose(custom_id, 1.)(1.)
    7.0
    >>> transpose(transpose(custom_id, 1.), 1.)(1.)
    1.0
    >>> transpose(transpose(transpose(custom_id, 1.), 1.), 1.)(1.)
    7.0

    Args:
      fun: a Python callable specifying a linear function. It should
        take two arguments: one of "residual" inputs (type ``r``),
        i.e. inputs in which the function is not necessarily linear, and
        one of "linear" inputs (type ``a``). It should return output
        whose components are linear in the linear input (type ``b``).
      fun_transpose: a Python callable specifying a structurally linear
        function that is the transpose of ``fun`` with respect to its
        linear inputs. Its first argument is the same residual inputs
        (``r``) as ``fun``. Its second argument is of type
        ``b``. Finally, its output is of type ``a`` and each of its
        components is linear in its second argument (the ``b`` inputs).
      residual_args: Argument in which ``fun`` and ``fun_transpose`` are
        not necessarily linear. Not involved in transposition.
      linear_args: Argument in which ``fun`` and ``fun_transpose`` are
        linear and with respect to which the two are transposes.

    Returns:
      The call result, i.e. ``fun(residual_args, linear_args)``.

    Raises:
      TypeError: if the pytree structure returned by ``fun_transpose`` differs
        from the pytree structure of ``linear_args``.
    """
    operands_res, res_tree = tree_flatten(residual_args)
    operands_lin, lin_tree = tree_flatten(linear_args)

    f_in_tree = treedef_tuple((res_tree, lin_tree))
    f, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), f_in_tree)

    # NOTE(review): `res_avals` is splatted twice below, so `map` here must
    # return a reusable sequence rather than the lazy builtin iterator —
    # confirm it is rebound at module level.
    res_avals = map(abstractify, operands_res)
    lin_avals = map(abstractify, operands_lin)
    f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
    f_jaxpr = _close_jaxpr(f_jaxpr)
    out_avals = map(core.raise_to_shaped, f_jaxpr.out_avals)

    # `out_tree()` is called only after `f` has been traced above; the
    # transpose takes the residuals plus something shaped like `fun`'s output.
    t_in_tree = treedef_tuple((res_tree, out_tree()))
    t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose), t_in_tree)

    t_jaxpr, t_consts = _initial_style_jaxpr(t, (*res_avals, *out_avals))
    t_jaxpr = _close_jaxpr(t_jaxpr)

    if t_out_tree() != lin_tree:
        raise TypeError(
            'transpose output pytree structure must match that of linear inputs, '
            f'got output structure {t_out_tree()} '
            f'and input structure {lin_tree}.')

    # Operand order: callee consts, transpose consts, residuals, linear inputs;
    # the counts passed as params let the primitive split them apart again.
    out = linear_call_p.bind(*f_consts,
                             *t_consts,
                             *operands_res,
                             *operands_lin,
                             callee=f_jaxpr,
                             transpose=t_jaxpr,
                             num_callee_consts=len(f_consts),
                             num_transpose_consts=len(t_consts),
                             num_res=len(operands_res))

    return tree_unflatten(out_tree(), out)