def test_unsupported_op(self):
  """Masking a primitive that has no masking rule raises NotImplementedError."""
  def _identity(x):
    return x

  prim = core.Primitive('unsupported_op')
  prim.def_abstract_eval(_identity)
  prim.def_impl(_identity)
  expected = "Masking rule for unsupported_op not implemented yet."

  def apply_masked():
    # Build the masked function and apply it to one length-3 input whose
    # logical size 'n' is 2; the test expects this to raise.
    masked_bind = mask(prim.bind, ['n'], 'n')
    masked_bind([np.arange(3)], {'n': 2})

  self.assertRaisesWithLiteralMatch(NotImplementedError, expected, apply_masked)
def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None,
                       weak_type_rule=None, named_shape_rule=None):
  """Create a ``core.Primitive`` wired up from shape/dtype rules.

  The primitive's impl dispatches through ``xla.apply_primitive``. Its abstract
  evaluation is ``standard_abstract_eval`` specialised with the given rules.
  An XLA translation is registered for it.

  Args:
    shape_rule: passed to ``standard_abstract_eval``.
    dtype_rule: passed to ``standard_abstract_eval``.
    name: name of the new primitive.
    translation_rule: XLA translation; when falsy,
      ``partial(_standard_translate, name)`` is used.
    weak_type_rule: when falsy, ``_standard_weak_type_rule`` is used.
    named_shape_rule: when falsy, ``standard_named_shape_rule`` is used.

  Returns:
    The newly created primitive.
  """
  weak_rule = weak_type_rule if weak_type_rule else _standard_weak_type_rule
  named_rule = (named_shape_rule if named_shape_rule
                else standard_named_shape_rule)

  primitive = core.Primitive(name)
  primitive.def_impl(partial(xla.apply_primitive, primitive))
  abstract_eval = partial(standard_abstract_eval, primitive, shape_rule,
                          dtype_rule, weak_rule, named_rule)
  primitive.def_abstract_eval(abstract_eval)

  lowering = (translation_rule if translation_rule
              else partial(_standard_translate, name))
  xla.register_translation(primitive, lowering)
  return primitive
def test_shapecheck_unsupported_op(self):
  """shapecheck on a primitive without a shape rule raises NotImplementedError."""
  prim = jc.Primitive('unsupported_op')
  prim.def_impl(lambda x: x)
  expected_regex = "Shape rule for unsupported_op not implemented yet."

  def define_checked_identity():
    # The test expects the decoration itself to raise; `identity` is never
    # called.
    @shapecheck(['n'], 'n')
    def identity(x):
      return prim.bind(x)

  self.assertRaisesRegex(NotImplementedError, expected_regex,
                         define_checked_identity)
def setup_spec(spec, grad=True):
    """Register the celerite2 XLA custom calls and JAX primitives for ``spec``.

    Registers the CPU custom-call target ``spec["xla_name"]`` and builds the
    base primitive ``"celerite2_" + spec["name"]``. When ``grad`` is true it
    also registers the ``spec["xla_name"] + b"_rev"`` target and builds the
    ``_jvp`` and ``_rev`` primitives used for differentiation.

    Args:
        spec: dict describing the op. Reads ``"name"`` (str) and ``"xla_name"``
            (bytes, since ``b"_rev"`` is appended to it). Mutated in place:
            ``"base_primitive"`` is always set; ``"jvp_primitive"`` and
            ``"rev_primitive"`` are set only when ``grad`` is true.
        grad: whether to register the derivative primitives and rules.

    Returns:
        The base primitive.
    """
    # NOTE(review): `getattr(xla_ops, ...)()` presumably returns the compiled
    # custom-call function capsule expected by XLA; confirm in xla_ops.
    xla_client.register_cpu_custom_call_target(
        spec["xla_name"], getattr(xla_ops, spec["name"])())

    # Forward primitive: eager impl via XLA, CPU-only translation rule.
    prim = core.Primitive("celerite2_" + spec["name"])
    prim.multiple_results = True
    spec["base_primitive"] = prim
    prim.def_impl(partial(xla.apply_primitive, prim))
    prim.def_abstract_eval(partial(_abstract_eval, spec))
    xla.backend_specific_translations["cpu"][prim] = partial(
        _translation_rule, spec)

    if not grad:
        return prim

    # Reverse-mode kernel lives under a separate "<xla_name>_rev" target.
    xla_client.register_cpu_custom_call_target(
        spec["xla_name"] + b"_rev", getattr(xla_ops, spec["name"] + "_rev")())

    jvp = core.Primitive("celerite2_" + spec["name"] + "_jvp")
    jvp.multiple_results = True
    rev = core.Primitive("celerite2_" + spec["name"] + "_rev")
    rev.multiple_results = True
    spec["jvp_primitive"] = jvp
    spec["rev_primitive"] = rev

    # NOTE(review): `jvp` gets only an abstract-eval and a transpose rule (no
    # impl or translation), so it appears to be meant to be transposed in
    # reverse mode rather than evaluated directly.
    ad.primitive_jvps[prim] = partial(_jvp, spec)
    jvp.def_abstract_eval(partial(_jvp_abstract_eval, spec))
    ad.primitive_transposes[jvp] = partial(_jvp_transpose, spec)

    rev.def_impl(partial(xla.apply_primitive, rev))
    rev.def_abstract_eval(partial(_rev_abstract_eval, spec))
    xla.backend_specific_translations["cpu"][rev] = partial(
        _rev_translation_rule, spec)

    return prim
def _build_op(name, spec):
    """Register the XLA custom calls and JAX primitives for one celerite2 op.

    Args:
        name: custom-call target name (bytes, since ``b"_rev"`` is appended).
        spec: dict describing the op; reads ``"name"`` (str) and
            ``"has_rev"``. Not mutated.

    Returns:
        The forward primitive ``celerite2_<spec["name"]>``. When
        ``spec["has_rev"]`` is true, JVP and reverse primitives are also built
        and wired in as its differentiation rules.
    """
    xla_client.register_cpu_custom_call_target(
        name, getattr(xla_ops, spec["name"])())

    # Forward primitive: eager impl via XLA, CPU-only translation rule.
    prim = core.Primitive(f"celerite2_{spec['name']}")
    prim.multiple_results = True
    prim.def_impl(partial(xla.apply_primitive, prim))
    prim.def_abstract_eval(partial(_abstract_eval, spec))
    xla.backend_specific_translations["cpu"][prim] = partial(
        _translation_rule, name, spec)

    if not spec["has_rev"]:
        return prim

    xla_client.register_cpu_custom_call_target(
        name + b"_rev", getattr(xla_ops, f"{spec['name']}_rev")())

    jvp_prim = core.Primitive(f"celerite2_{spec['name']}_jvp")
    jvp_prim.multiple_results = True
    rev_prim = core.Primitive(f"celerite2_{spec['name']}_rev")
    rev_prim.multiple_results = True

    # Setup a dummy JVP rule
    # (jvp_prim has no impl; it is only abstract-evaluated and transposed.)
    ad.primitive_jvps[prim] = partial(_jvp, prim, jvp_prim, spec)
    jvp_prim.def_abstract_eval(partial(_jvp_abstract_eval, spec))
    ad.primitive_transposes[jvp_prim] = partial(_jvp_transpose, rev_prim, spec)

    # Handle reverse pass using custom op
    rev_prim.def_impl(partial(xla.apply_primitive, rev_prim))
    rev_prim.def_abstract_eval(partial(_rev_abstract_eval, spec))
    xla.backend_specific_translations["cpu"][rev_prim] = partial(
        _rev_translation_rule, name + b"_rev", spec)

    return prim
from jax import core
from jax.interpreters import xla
from jax.lib import cusparse
from jax.lib import xla_bridge
from jax.lib import xla_client
import jax.numpy as jnp
import numpy as np

# Short module aliases.
xb = xla_bridge
xops = xla_client.ops

#--------------------------------------------------------------------
# csr_todense

csr_todense_p = core.Primitive('csr_todense')

def csr_todense(data, indices, indptr, *, shape):
  """Convert CSR-format sparse matrix to a dense matrix.

  Args:
    data : array of shape ``(nnz,)`` holding the stored values.
    indices : array of shape ``(nnz,)`` holding the column index of each
      stored value.
    indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
      holding the row pointers into ``data`` / ``indices``.
    shape : length-2 tuple representing the matrix shape.

  Returns:
    mat : array with specified shape and dtype matching ``data``.
  """
  # Thin wrapper: the work is done by whatever rules are registered on
  # ``csr_todense_p``.
  return csr_todense_p.bind(data, indices, indptr, shape=shape)
local_nparts=local_nparts, name=flat_fun.__name__) return tree_unflatten(out_tree(), out) return wrapped


def _sharding_constraint_impl(x, partitions):
  """Eager impl: always raises, since the constraint only has meaning when staged."""
  # TODO(skye): can we also prevent this from being called in other
  # non-sharded_jit contexts? (e.g. pmap, control flow)
  raise NotImplementedError(
      "with_sharding_constraint() should only be called inside sharded_jit()"
  )

sharding_constraint_p = core.Primitive("sharding_constraint")
sharding_constraint_p.def_impl(_sharding_constraint_impl)
# Shape/dtype pass through unchanged.
sharding_constraint_p.def_abstract_eval(lambda x, partitions: x)
# Linear: the cotangent gets the same sharding constraint applied.
ad.deflinear2(
    sharding_constraint_p,
    lambda ct, _, partitions: (with_sharding_constraint(ct, partitions), ))


def _sharding_constraint_lowering(ctx, x_node, partitions):
  """MLIR lowering: wrap the operand in a sharding op built from `partitions`."""
  return [
      mlir.wrap_with_sharding_op(x_node, xla.sharding_to_proto(partitions))
  ]

mlir.register_lowering(sharding_constraint_p, _sharding_constraint_lowering)
from oryx.core.interpreters import log_prob as lp
from oryx.core.ppl import transformations

# Short aliases for the PPL API under test.
seed = random.PRNGKey
conditional = transformations.conditional
graph_replace = transformations.graph_replace
joint_log_prob = transformations.joint_log_prob
joint_sample = transformations.joint_sample
log_prob = transformations.log_prob
intervene = transformations.intervene
random_variable = transformations.random_variable

# Define a random normal primitive so we can register it with the `log_prob`
# transformation.
random_normal_p = jax_core.Primitive('random_normal')


def random_normal(key):
  """Bind `random_normal_p` on a PRNG key."""
  return random_normal_p.bind(key)


def random_normal_impl(rng):
  # Calls random.normal with its default shape; presumably a float32 scalar,
  # matching random_normal_abstract below -- confirm.
  return random.normal(rng)


def random_normal_abstract(_):
  # Ignores the key's aval: the output is always a float32 scalar.
  return abstract_arrays.ShapedArray((), np.float32)


def random_normal_log_prob(_, x):
for operand in contract_fake_ops: idx = tuple(i for i, fake_op in enumerate(fake_ops) if operand is fake_op) assert len(idx) == 1 contract_operands.append(operands[idx[0]]) return contract_operands, contractions

# Route einsum contraction-path computation for polynomial dimensions through
# _einsum_contract_path.
lax_numpy._polymorphic_einsum_contract_path_handlers[
    _DimPolynomial] = _einsum_contract_path

# A JAX primitive with no array arguments but with a dimension parameter
# that is a DimPoly. The value of the primitive is the value of the dimension.
# This primitive is used only in the context of jax2tf, so it does not need
# XLA translation rules.
dim_as_value_p = core.Primitive("dim_as_value")

def _dim_as_value_abstract(dim: DimSize) -> core.AbstractValue:
  # `dim` is ignored: the result is always an int32 scalar.
  return core.ShapedArray((), np.int32)

dim_as_value_p.def_abstract_eval(_dim_as_value_abstract)

def _dim_as_value(dim: DimSize):
  """Bind `dim_as_value_p` with `dim` passed as a parameter (no operands)."""
  return dim_as_value_p.bind(dim=dim)

class PolyShape(tuple): """Tuple of polymorphic dimension specifications.
from jax import lax
from jax import linear_util as lu
from jax.config import config
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import mlir
from jax._src import lib as jaxlib
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src import util
from jax._src.lax import control_flow as lcf
import numpy as np

config.parse_flags_with_absl()

# Test-only primitive: no operands, no outputs, and a single effect chosen by
# the `effect` parameter.
effect_p = core.Primitive('effect')
effect_p.multiple_results = True


@effect_p.def_effectful_abstract_eval
def _(*, effect):
  # Returns (output avals, effect set): no outputs, just `{effect}`.
  return [], {effect}


# Effect names the tests use, declared lowerable; 'foo' is also marked ordered.
mlir.lowerable_effects.add('foo')
mlir.lowerable_effects.add('foo2')
mlir.lowerable_effects.add('bar')
mlir.lowerable_effects.add('while')
mlir.lowerable_effects.add('while1')
mlir.lowerable_effects.add('while2')
core.ordered_effects.add('foo')
def _custom_ivjp(fun, ivjp, args):
  """Stage `fun` and its inverse-VJP `ivjp` to jaxprs and bind `custom_ivjp_p`.

  Raises:
    ValueError: if tracing `ivjp` hits a RecursionError (e.g. `ivjp` calls
      `fun` again), which is not supported.
  """
  in_avals = [raise_to_shaped(get_aval(arg)) for arg in args]
  fun_jaxpr = custom_derivatives._initial_style_jaxpr(fun, in_avals)
  # ivjp is traced on the inputs followed by two copies of the output avals
  # (presumably outputs and output cotangents -- confirm).
  ivjp_in_avals = in_avals + fun_jaxpr.out_avals * 2
  try:
    ivjp_jaxpr = custom_derivatives._initial_style_jaxpr(ivjp, ivjp_in_avals)
  except RecursionError:
    raise ValueError(
        f"Calls to {fun.__name__} from its custom ivjp aren't supported yet")
  return custom_ivjp_p.bind(*args, fun_jaxpr=fun_jaxpr, ivjp_jaxpr=ivjp_jaxpr)


def _custom_ivjp_impl(*args, fun_jaxpr, **_):
  """Eager impl: just evaluate the primal jaxpr."""
  evaluate = core.jaxpr_as_fun(fun_jaxpr)
  return evaluate(*args)


def _custom_ivjp_abstract_eval(*_, fun_jaxpr, **__):
  """Output avals are those of the staged primal jaxpr."""
  return fun_jaxpr.out_avals


custom_ivjp_p = core.Primitive('custom_ivjp')
custom_ivjp_p.multiple_results = True
custom_ivjp_p.def_impl(_custom_ivjp_impl)
custom_ivjp_p.def_abstract_eval(_custom_ivjp_abstract_eval)


def _custom_ivjp_jvp(primals, tangents, *, fun_jaxpr, ivjp_jaxpr):
  """JVP rule: primal outputs via the primitive, tangents via plain JVP of fun."""
  primals_out = custom_ivjp_p.bind(
      *primals, fun_jaxpr=fun_jaxpr, ivjp_jaxpr=ivjp_jaxpr)
  primal_fun = core.jaxpr_as_fun(fun_jaxpr)
  # FIXME: This might compute the primals multiple times, but we only need to do
  # this trick while linearizing. It should be possible to do it through
  # a custom partial eval rule.
  jvp_fun = ad.jvp(lu.wrap_init(primal_fun))
  _, tangents_out = jvp_fun.call_wrapped(primals, tangents)
  return primals_out, tangents_out


ad.primitive_jvps[custom_ivjp_p] = _custom_ivjp_jvp
_partition_knowns)
from ..core import raise_to_shaped, get_aval, Literal, Jaxpr
from ..custom_derivatives import _initial_style_jaxpr, _resolve_kwargs
from ..api_util import flatten_fun_nokwargs
from ..tree_util import tree_flatten, tree_unflatten
from ..util import safe_map, safe_zip, unzip2, split_list, cache
from .. import source_info_util

# Shadow the builtins with the util versions (presumably length-checked).
map = safe_map
zip = safe_zip

################################################################################
# Reverse call primitive
################################################################################

# A call primitive: binding goes through core.call_bind, and eager evaluation
# through core.call_impl.
invertible_call_p = core.Primitive('invertible_call')
invertible_call_p.call_primitive = True
invertible_call = partial(core.call_bind, invertible_call_p)
invertible_call_p.def_custom_bind(invertible_call)
invertible_call_p.def_impl(core.call_impl)
invertible_call_p.multiple_results = True

def _invertible_call_make_output_tracers(trace, in_tracers, out_tracers, params): uks = [not t.pval.is_known() for t in out_tracers] out_tracers_known, out_tracers_unknown = _partition_knowns(out_tracers, uks) # Add dummy arguments representing the outputs to the jaxpr. Those should # remain unused if the expression is evaluated, but they make it well-formed. out_known_avals = [raise_to_shaped(t.pval.get_aval()) for t in out_tracers_known] out_consts = [trace.instantiate_const(t) for t in out_tracers_known] new_jaxpr = _append_invars(params['call_jaxpr'], tuple(out_known_avals))
c.GetShape(k2).dimensions(), c.GetShape(x1).dimensions(), c.GetShape(x2).dimensions()) rank = len(shape) def _broadcast(x): ndims = c.GetShape(x).rank() return xla_client.ops.BroadcastInDim(x, shape, tuple(range(rank - ndims, rank))) return cuda_prng.threefry2x32(xla_bridge.computation_builder_shim(c), (_broadcast(k1), _broadcast(k2)), (_broadcast(x1), _broadcast(x2)))

threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
# Generic translation: lower the Python implementation with unrolled loops.
xla.translations[threefry2x32_p] = xla.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=False))
# CPU override: the same implementation, but with rolled loops.
xla.backend_specific_translations['cpu'][threefry2x32_p] = xla.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=True))
# GPU override: the CUDA kernel, only when cuda_prng is available (truthy).
if cuda_prng:
  xla.backend_specific_translations['gpu'][threefry2x32_p] = \
      _threefry2x32_gpu_translation_rule

@jit def threefry_2x32(keypair, count):
@custom_transforms
def cumsum(x):
    """Cumulative sum along the last axis."""
    return np.cumsum(x, axis=-1)


# cumsum is linear, so its JVP is the cumsum of the tangent.
defjvp(cumsum, lambda g, ans, x: np.cumsum(g, axis=-1))


# XXX work around the issue: batching rule for 'reduce_window' not implemented
# when using @custom_transforms decorator
def _cumprod_impl(x):
    return np.cumprod(x, axis=-1)


cumprod_p = core.Primitive('cumprod')
cumprod_p.def_impl(_cumprod_impl)
cumprod_p.def_abstract_eval(
    partial(partial_eval.abstract_eval_fun, _cumprod_impl))
xla.translations[cumprod_p] = partial(xla.lower_fun, _cumprod_impl)
# JVP: d(prod_{j<=k} x_j) = ans_k * sum_{i<=k} g_i / x_i.
# XXX this implementation does not address the case x=0, hence the result in that case will be nan
# Ref: https://stackoverflow.com/questions/40916955/how-to-compute-gradient-of-cumprod-safely
ad.defjvp2(cumprod_p, lambda g, ans, x: np.cumsum(g / x, axis=-1) * ans)
batching.defvectorized(cumprod_p)


def cumprod(x):
    """Cumulative product along the last axis (via the `cumprod_p` primitive)."""
    return cumprod_p.bind(x)


def promote_shapes(*args, shape=()):
new_invars, new_outvars, jaxpr.jaxpr.eqns) return core.ClosedJaxpr(new_jaxpr, jaxpr.consts)

def _perm(primal_counts, tangent_counts, lst):
  """Reorder `lst` = [all primals..., all tangents...] into interleaved groups.

  The first sum(primal_counts) entries are primals and the rest tangents. Both
  are split into groups by `split_list` (presumably sizes primal_counts /
  tangent_counts). The result is group_0 primals, group_0 tangents,
  group_1 primals, and so on.
  """
  n = sum(primal_counts)
  primals, tangents = lst[:n], lst[n:]
  primal_groups = split_list(primals, primal_counts[:-1])
  tangent_groups = split_list(tangents, tangent_counts[:-1])
  return _interleave(primal_groups, tangent_groups)

def _interleave(xs, ys):
  """Flatten [xs[0], ys[0], xs[1], ys[1], ...]; each entry is itself a list."""
  assert len(xs) == len(ys)
  return [e for pair in zip(xs, ys) for l in pair for e in l]

custom_lin_p: core.Primitive = core.Primitive('custom_lin')
# Output avals are supplied directly as the `out_avals` parameter.
custom_lin_p.def_abstract_eval(lambda *_, out_avals, **__: out_avals)
custom_lin_p.multiple_results = True

def _raise_custom_vjp_error_on_jvp(*_, **__):
  """Eager impl: evaluating custom_lin means forward-mode AD was attempted."""
  raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp "
                  "function.")
custom_lin_p.def_impl(_raise_custom_vjp_error_on_jvp)

def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals):
  """Transpose: run the user `bwd` on (residuals, output cotangents).

  Symbolic-zero cotangents are instantiated first. The residual inputs get
  `None` cotangents.
  """
  res, _ = split_list(invals, [num_res])
  cts_out = map(instantiate_zeros_aval, out_avals, cts_out)
  cts_in = bwd.call_wrapped(*res, *cts_out)
  return [None] * num_res + list(cts_in)
primitive_transposes[custom_lin_p] = _custom_lin_transpose
cts = [ zeros_like_aval(a) if type(ct) is Zero else ct for ct, a in zip(cts, cts_avals) ] cts_out = linear_call_p.bind(*t_consts, *f_consts, *operands_res, *cts, callee=transpose, transpose=callee, num_callee_consts=len(t_consts), num_transpose_consts=len(f_consts), num_res=len(operands_res)) return [None ] * (num_callee_consts + num_transpose_consts + num_res) + cts_out


def _linear_call_abstract_eval(*args, **kwargs):
  """Output avals: the callee's output avals raised to shaped."""
  # NOTE(review): if `map` here is the builtin, this returns an iterator, not a
  # list; the JAX convention is `map = safe_map` -- confirm at the file top.
  return map(core.raise_to_shaped, kwargs['callee'].out_avals)

linear_call_p = core.Primitive('linear_call')
linear_call_p.multiple_results = True
linear_call_p.def_impl(_linear_call_impl)
linear_call_p.def_abstract_eval(_linear_call_abstract_eval)
ad.primitive_transposes[linear_call_p] = _linear_call_transpose_rule
# Lowered by tracing _linear_call_impl in initial style.
xla.initial_style_translations[linear_call_p] = xla.lower_fun_initial_style(
    _linear_call_impl)
return tuple(x)

def _threefry2x32_gpu_translation_rule(c, k1, k2, x1, x2):
  """GPU translation: broadcast all operands to one shape, then call cuda_prng."""
  shape = lax.broadcast_shapes(
      c.GetShape(k1).dimensions(), c.GetShape(k2).dimensions(),
      c.GetShape(x1).dimensions(), c.GetShape(x2).dimensions())
  rank = len(shape)
  def _broadcast(x):
    # Map the operand's dims onto the trailing dims of `shape` (right-aligned,
    # NumPy-style broadcasting).
    ndims = c.GetShape(x).rank()
    return c.BroadcastInDim(x, shape, tuple(range(rank - ndims, rank)))
  return cuda_prng.threefry2x32(
      c, (_broadcast(k1), _broadcast(k2)), (_broadcast(x1), _broadcast(x2)))

threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
# Generic translation: unrolled loops; CPU override: rolled loops.
xla.translations[threefry2x32_p] = xla.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=False), instantiate=True)
xla.backend_specific_translations['cpu'][threefry2x32_p] = xla.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=True), instantiate=True)
# GPU override only when cuda_prng is available (truthy).
if cuda_prng:
  xla.backend_specific_translations['gpu'][threefry2x32_p] = \
      _threefry2x32_gpu_translation_rule

@jit
def threefry_2x32(keypair, count): """Apply the Threefry 2x32 hash.
# buffers from different XLA backends are passed through the host. backend = xb.get_device_backend(device) moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device) return device_array.make_device_array(x.aval, device, moved_buf) def _device_put_impl(x, device: Optional[Device] = None): if device_array.type_is_device_array(x): return _copy_device_array_to_device(x, device) try: a = xla.abstractify(x) except TypeError as err: raise TypeError( f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err return aval_to_result_handler(device, a)(*device_put(x, device)) device_put_p = core.Primitive('device_put') device_put_p.def_impl(_device_put_impl) device_put_p.def_abstract_eval(lambda x, device=None: x) xla.translations[device_put_p] = lambda c, x, device=None: x ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent]) masking.defvectorized(device_put_p) batching.defvectorized(device_put_p) def _device_put_lowering(ctx, x, *, device): return [x] mlir.register_lowering(device_put_p, _device_put_lowering)
from oryx.core import trace_util

__all__ = [
    'HarvestTrace',
    'HarvestTracer',
    'call_and_reap',
    'harvest',
    'nest',
    'plant',
    'reap',
    'sow',
]

Value = Any

# `sow` is an identity on its operands; as defined here both the impl and the
# abstract eval return their inputs unchanged.
sow_p = jax_core.Primitive('sow')
sow_p.multiple_results = True


@sow_p.def_impl
def _sow_impl(*args, **_):
  return args


@sow_p.def_abstract_eval
def _sow_abstract_eval(*avals, **_):
  return avals


@functools.partial(ad.deflinear, sow_p)
def _sow_transpose(cts_in, *_, **__):
custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')

def _custom_jvp_call_jaxpr_impl(*args, fun_jaxpr: core.ClosedJaxpr, **params):
  """Eager impl: evaluate the primal jaxpr, ignoring the JVP machinery."""
  del params  # other params ignored because we're just executing the primal fun
  return core.jaxpr_as_fun(fun_jaxpr)(*args)

def _custom_jvp_call_jaxpr_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr, **params):
  """Output avals are those of the primal jaxpr."""
  del args, params
  return fun_jaxpr.out_avals

custom_jvp_call_jaxpr_p = core.Primitive('custom_jvp_call_jaxpr')
custom_jvp_call_jaxpr_p.multiple_results = True
custom_jvp_call_jaxpr_p.def_impl(_custom_jvp_call_jaxpr_impl)
custom_jvp_call_jaxpr_p.def_abstract_eval(_custom_jvp_call_jaxpr_abstract_eval)
# Link the final-style primitive class to its initial-style (jaxpr) counterpart.
CustomJVPCallPrimitive.initial_style = custom_jvp_call_jaxpr_p

def _custom_jvp_call_jaxpr_jvp(primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], num_consts: int): _, args = split_list(primals, [num_consts]) consts_dot, args_dot = split_list(tangents, [num_consts]) if any(type(t) is not Zero for t in consts_dot):
def standard_pmap_primitive(name):
  """Create a primitive whose eager impl is ``pxla.apply_parallel_primitive``.

  The abstract-evaluation rule returns the first positional operand's aval
  unchanged; any further operands and params are ignored.

  Args:
    name: name of the new primitive.

  Returns:
    The newly created primitive.
  """
  def _first_operand_aval(x, *args, **params):
    return x

  primitive = core.Primitive(name)
  primitive.def_impl(partial(pxla.apply_parallel_primitive, primitive))
  primitive.def_abstract_eval(_first_operand_aval)
  return primitive
# https://github.com/google/jax/issues/1142 # courtesy of mattjj def mybar_impl(w): A, _ = pymbar.BAR(w[0], w[1]) return A def mybar_vjp(g, w): return g * tmbar.dG_dw(w) def mybar(x): return mybar_p.bind(x) mybar_p = core.Primitive('mybar') mybar_p.def_impl(mybar_impl) ad.defvjp(mybar_p, mybar_vjp) def BAR_leg(insertion_du_dls, deletion_du_dls, lambda_schedule): insertion_W = math_utils.trapz(insertion_du_dls, lambda_schedule) deletion_W = math_utils.trapz(deletion_du_dls, lambda_schedule) return mybar(jnp.stack([insertion_W, deletion_W])) def BAR_loss( complex_insertion_du_dls, # [C, N] complex_deletion_du_dls, # [C, N] solvent_insertion_du_dls, # [C, N]
else: val_out, arg_out = approx_min_k(operand, k, reduction_dimension, recall_target, reduction_input_size_override, aggregate_to_topk) if type(tangent) is ad_util.Zero: tangent_out = ad_util.Zero.from_value(val_out) else: arg_shape = arg_out.shape rank = len(arg_shape) if reduction_dimension < 0: reduction_dimension += rank iotas = [ lax.broadcasted_iota(arg_out.dtype, arg_shape, i) for i in range(rank) ] idx = tuple( arg_out if i == reduction_dimension else iotas[i] for i in range(rank)) tangent_out = tangent[idx] return (val_out, arg_out), (tangent_out, ad_util.Zero.from_value(arg_out))

approx_top_k_p = core.Primitive('approx_top_k')
approx_top_k_p.multiple_results = True
approx_top_k_p.def_impl(partial(xla.apply_primitive, approx_top_k_p))
approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval)
# Fallback translation for all platforms, overridden on TPU.
xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
xla.register_translation(approx_top_k_p, _approx_top_k_tpu_translation,
                         platform='tpu')
batching.primitive_batchers[approx_top_k_p] = _approx_top_k_batch_rule
ad.primitive_jvps[approx_top_k_p] = _approx_top_k_jvp
axis_name=axis_name, axis_env=axis_env, axis_index_groups=axis_index_groups, platform=platform) dtype = c.get_shape(val).numpy_dtype() if dtypes.issubdtype(dtype, np.complexfloating): return xops.Complex(psum(xops.Real(val)), psum(xops.Imag(val))) else: return psum(val) return xops.Tuple(c, list(map(_translate, args)))

def _psum_transpose_rule(cts, axis_name, axis_index_groups):
  """Transpose of psum is psum: all-reduce the cotangents over the same axis."""
  nonzero_out_cts, treedef = tree_util.tree_flatten(cts)
  nonzero_in_cts = psum_p.bind(*nonzero_out_cts, axis_name=axis_name,
                               axis_index_groups=axis_index_groups)
  return tree_util.tree_unflatten(treedef, nonzero_in_cts)

psum_p = core.Primitive('psum')
psum_p.multiple_results = True
# Each output has its input's aval, raised to shaped.
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
pxla.soft_pmap_rules[psum_p] = \
    partial(_allreduce_soft_pmap_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
ad.deflinear(psum_p, _psum_transpose_rule)
pxla.multi_host_supported_collectives.add(psum_p)
batching.split_axis_rules[psum_p] = partial(_split_axis_comm_assoc, psum_p)
batching.primitive_batchers[psum_p] = partial(_collective_batcher, psum_p)
# Batched-collective rule: sum over the batch dim `d`, or scale by the axis
# size.
batching.collective_rules[psum_p] = \
    partial(_batched_reduction_collective,
            psum_p,
            lambda v, d: v.sum(d),
            lambda v, axis_size: axis_size * v)
fun = lu.wrap_init(f, kwargs) flat_args, in_tree = tree_util.tree_flatten(args) flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) out_tree_dest = None out = prim.bind(flat_fun, *flat_args, num_args=len(flat_args), name=f.__name__, in_tree=in_tree, out_tree=lambda: out_tree_dest, **params) out_tree_dest = out_tree() return tree_util.tree_unflatten(out_tree_dest, out) return wrapped return bind

# `tie_all` is an identity over all its operands: the impl returns them, the
# abstract eval returns their shaped avals, and XLA packs them into one tuple.
tie_all_p = jax_core.Primitive('tie_all')
tie_all_p.multiple_results = True
tie_all_p.def_impl(lambda *args: args)
tie_all_p.def_abstract_eval(lambda *args: safe_map(  # pylint: disable=g-long-lambda
    abstract_arrays.raise_to_shaped, args))
xla.translations[tie_all_p] = lambda c, *args: xc.ops.Tuple(c, args)


def _tie_all_batch_rule(batched_args, batch_dims):
  # Identity: operands and their batch dims pass through unchanged.
  # NOTE(review): registration of this rule is not in this excerpt.
  return batched_args, batch_dims


def _tie_all_transpose(cts_in, *args, **params):
  # Linear identity: cotangents pass straight through.
  del args, params
  return cts_in


ad.deflinear(tie_all_p, _tie_all_transpose)
for data dependency, for implementing the "result" feature, and for the current token. * tapped_args_treedef_: the treedef of the tapped positional arguments. * tap_func_: the actual (Python) function to invoke with the tapped positional arguments (unflatted according to tapped_args_treedef_) and the parameters that were passed to the id_tap function. * transforms: a tuple of the transformations that have been applied. Each element of the tuple is itself a tuple with the first element the name of the transform. The remaining elements depend on the transform. For example, for `batch`, the parameters are the dimensions that have been batched, and for `mask` the logical shapes. These are unpacked by _ConsumerCallable before passing to the user function. * the remaining parameters are from the user's invocation of the id_tap API function and are passed to the tap function. """

id_tap_p = core.Primitive("id_tap")
id_tap_p.multiple_results = True
# Registered in xla.outfeed_primitives.
xla.outfeed_primitives.add(id_tap_p)


def _add_transform(params: Dict, name: str, *transform_params) -> Dict:
  """Adds the `transform` to the params["transforms"].

  Uses a tuple representation internally, will be unpacked before the
  callback by _ConsumerCallable.

  Returns a new dict: a copy of `params` whose "transforms" tuple (empty if
  absent) has `(name, *transform_params)` appended. `params` is not mutated.
  """
  new_transform = (name, *transform_params)
  return dict(
      params, transforms=(params.get("transforms", ()) + (new_transform,)))
tangents = map(ad.instantiate_zeros, tangents) jvp_call, _ = ad.jvp_jaxpr(call, [True] * len(primals), True) jvp_in_tree = treedef_tuple((in_tree, in_tree)) jvp_out_tree = treedef_tuple((out_tree, out_tree)) outs = custom_vmap_p.bind(*primals, *tangents, call=jvp_call, rule=jvp_of_rule_rule, in_tree=jvp_in_tree, out_tree=jvp_out_tree) assert len(outs) % 2 == 0, len(outs) out_primals, out_tangents = util.split_list(outs, [len(outs) // 2]) return out_primals, out_tangents

# Note: the Python name is custom_vmap_p, but the primitive's name is
# 'custom_vmap_call'.
custom_vmap_p = core.Primitive('custom_vmap_call')
custom_vmap_p.multiple_results = True
custom_vmap_p.def_impl(custom_vmap_impl)
custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval)
batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching
ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp
xla.register_initial_style_primitive(custom_vmap_p)
# MLIR lowering: trace custom_vmap_impl.
mlir.register_lowering(custom_vmap_p, mlir.lower_fun(custom_vmap_impl,
                                                     multiple_results=True))

# -- custom vmap applications

def tree_split(mask, tree): lhs = tree_map(lambda l, x: x if l else None, mask, tree) rhs = tree_map(lambda l, x: None if l else x, mask, tree)
def sparse_array_constant_handler(c, val, canonicalize_dtypes):
  """Lower a constant SparseArray to a pair of XLA constants.

  Returns:
    A 2-tuple ``(data_constant, indices_constant)`` built with ``xb.constant``
    from ``val.data`` and ``val.indices``, in that order. ``c`` is unused.
  """
  parts = (val.data, val.indices)
  return tuple(xb.constant(part, canonicalize_dtypes) for part in parts)

# Abstract-value plumbing: a SparseArray carries its own `.aval`, and an
# AbstractSparseArray is already "shaped", so raising it is the identity.
core.pytype_aval_mappings[SparseArray] = lambda x: x.aval
core.raise_to_shaped_mappings[AbstractSparseArray] = lambda aval, _: aval

# XLA plumbing: dtype canonicalization is the identity; transfer, result
# and shape handling are delegated to the handlers defined for this type.
xla.pytype_aval_mappings[SparseArray] = lambda x: x.aval
xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x
xla.device_put_handlers[SparseArray] = sparse_array_device_put_handler
xla.xla_result_handlers[AbstractSparseArray] = sparse_array_result_handler
xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler
xb.register_constant_handler(SparseArray, sparse_array_constant_handler)

# Primitive that extracts the index array of a SparseArray.
sp_indices_p = core.Primitive('sp_indices')

@sp_indices_p.def_impl
def _sp_indices_impl(mat):
  return mat.indices

@sp_indices_p.def_abstract_eval
def _sp_indices_abstract_eval(mat):
  return mat.indices_aval

def _sp_indices_translation_rule(c, data, indices):
  # The lowered operand arrives as (data, indices); keep only the indices.
  return indices
return False

def checkpoint_dots(prim, *_, **__) -> bool:
  """Remat policy: save only the outputs of dot_general and conv_general_dilated."""
  # Matrix multiplies are expensive, so let's save them (and nothing else).
  return prim in {jax._src.lax.lax.dot_general_p,
                  jax._src.lax.convolution.conv_general_dilated_p}

def dot_with_no_batch_dims(prim, *_, **params) -> bool:
  """Remat policy: save dot_general outputs whose dimension_numbers have no batch dims."""
  # This is a useful heuristic for transformers.
  if prim is jax._src.lax.lax.dot_general_p:
    # dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
    (_, _), (lhs_b, rhs_b) = params['dimension_numbers']
    if not lhs_b and not rhs_b:
      return True
  return False

# Primitive whose `name` parameter the name-based policies below inspect.
name_p = core.Primitive('name')

def save_any_names_but_these(*names_not_to_save):
  """Build a remat policy that saves `name_p` values except the given names.

  Any primitive other than `name_p` is never saved.
  """
  # Save named values, excluding the names given.
  names_not_to_save = frozenset(names_not_to_save)
  def policy(prim, *_, **params):
    if prim is name_p:
      return params['name'] not in names_not_to_save
    return False  # only allow saving named values
  return policy

def save_only_these_names(*names_which_can_be_saved): # Save named values, only among the names given. names_which_can_be_saved = set(names_which_can_be_saved) def policy(prim, *_, **params): if prim is name_p: