sp_indices_p = core.Primitive('sp_indices')

@sp_indices_p.def_impl
def _sp_indices_impl(mat):
  return mat.indices

@sp_indices_p.def_abstract_eval
def _sp_indices_abstract_eval(mat):
  return mat.indices_aval

def _sp_indices_translation_rule(ctx, avals_in, avals_out, data, indices):
  return [indices]

# Note: cannot use lower_fun to define attribute access primitives
# because it leads to infinite recursion.
xla.register_translation(sp_indices_p, _sp_indices_translation_rule)

def _sp_indices_mhlo_lowering(ctx, avals_in, avals_out, data_and_indices):
  return [data_and_indices[1]]

mlir.register_lowering(sp_indices_p, _sp_indices_mhlo_lowering)

sp_data_p = core.Primitive('sp_data')

@sp_data_p.def_impl
def _sp_data_impl(mat):
  return mat.data

@sp_data_p.def_abstract_eval
def _sp_data_abstract_eval(mat):
  return mat.data_aval
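# Sketch (hedged, not part of the excerpt): the data-access half completes the
# same way, with a direct translation rule that simply forwards the `data`
# buffer, mirroring _sp_indices_translation_rule above. A lower_fun-based rule
# would trace _sp_data_impl, which binds sp_data_p again -- the infinite
# recursion the note warns about.
def _sp_data_translation_rule(ctx, avals_in, avals_out, data, indices):
  return [data]  # forward the stored data buffer unchanged

xla.register_translation(sp_data_p, _sp_data_translation_rule)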
def _sharded_call_impl(fun, *args, nparts, in_parts, out_parts_thunk,
                       local_in_parts, local_out_parts_thunk, local_nparts,
                       name):
  compiled_fun = _sharded_callable(fun, nparts, in_parts, out_parts_thunk,
                                   local_in_parts, local_out_parts_thunk,
                                   local_nparts, name,
                                   *map(xla.abstractify, args))
  return compiled_fun(*args)

sharded_call_p = core.CallPrimitive("sharded_call")
sharded_call = sharded_call_p.bind
sharded_call_p.def_impl(_sharded_call_impl)
xla.register_translation(sharded_call_p, _sharded_jit_translation_rule)
mlir.register_lowering(sharded_call_p, _sharded_jit_lowering)

class _UnconstrainedPartitionSingleton:

  def __str__(self):
    return "UNCONSTRAINED"

# Unconstrained sentinel value for PartitionSpec, representing a dimension for
# which the user wants XLA to assign the best partitioning.
# TODO(yashkatariya): May rename to AUTO.
_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()

class PartitionSpec(tuple):
batching.axis_primitive_batchers[remat_p] = remat_vmap

name_p = core.Primitive('name')

def checkpoint_name(x, name):
  return name_p.bind(x, name=name)

name_p.def_impl(lambda x, *, name: x)
name_p.def_abstract_eval(lambda x, *, name: x)

def name_jvp(primals, tangents, *, name):
  (x,), (xdot,) = primals, tangents
  return name_p.bind(x, name=name), xdot  # don't name the tangent value

ad.primitive_jvps[name_p] = name_jvp

xla.register_translation(name_p,
                         lambda ctx, avals_in, avals_out, x, *, name: [x])

def name_batcher(args, dims, *, name):
  (x,), (d,) = args, dims
  return name_p.bind(x, name=name), d

batching.primitive_batchers[name_p] = name_batcher
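# Usage sketch (hedged): checkpoint_name tags an intermediate value so a
# rematerialization policy can refer to it by name; the policy helper shown
# (jax.checkpoint_policies.save_only_these_names) is assumed from the same era
# of the library.
import jax
import jax.numpy as jnp

def net(x):
  h = checkpoint_name(jnp.sin(x), 'hidden')  # saveable, addressable by name
  return jnp.cos(h)

f = jax.checkpoint(
    net, policy=jax.checkpoint_policies.save_only_these_names('hidden'))
y = jax.grad(f)(1.0)  # 'hidden' is saved; other intermediates rematerialize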
def _coo_todense_transpose(ct, data, row, col, *, shape):
  # Note: we assume that transpose has the same sparsity pattern.
  # Can we check this?
  assert ad.is_undefined_primal(data)
  if ad.is_undefined_primal(row) or ad.is_undefined_primal(col):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert ct.shape == shape
  assert row.aval.dtype == col.aval.dtype
  assert ct.dtype == data.aval.dtype
  return _coo_extract(row, col, ct), row, col

ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
xla.register_translation(coo_todense_p, _coo_todense_translation_rule)
if (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported):
  xla.register_translation(coo_todense_p, _coo_todense_gpu_translation_rule,
                           platform='gpu')

#--------------------------------------------------------------------
# coo_fromdense

coo_fromdense_p = core.Primitive('coo_fromdense')
coo_fromdense_p.multiple_results = True

def coo_fromdense(mat, *, nse, index_dtype=jnp.int32):
  """Create COO-format sparse matrix from a dense matrix.
  return xla.xla_destructure(
      ctx.builder,
      hip_prng.threefry2x32(
          ctx.builder,
          (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
          (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval))))

threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
xla.register_translation(
    threefry2x32_p,
    xla.lower_fun(partial(_threefry2x32_lowering, use_rolled_loops=False),
                  multiple_results=True, new_style=True))
xla.register_translation(
    threefry2x32_p,
    xla.lower_fun(partial(_threefry2x32_lowering, use_rolled_loops=True),
                  multiple_results=True, new_style=True),
    platform='cpu')
if cuda_prng:
  xla.register_translation(threefry2x32_p, _threefry2x32_gpu_translation_rule,
                           platform='gpu')
if hip_prng:
  xla.register_translation(threefry2x32_p, _threefry2x32_gpu_translation_rule,
                           platform='gpu')
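# Sketch of the registration pattern above on a toy primitive (every name
# below is illustrative, not part of the prng module): a default rule serves
# all backends, and a platform-specific registration takes precedence on that
# backend, just as the 'cpu' and 'gpu' rules override the unrolled default.
from functools import partial
from jax import core
from jax.interpreters import xla

toy_p = core.Primitive('toy_double')
toy_p.def_impl(partial(xla.apply_primitive, toy_p))
toy_p.def_abstract_eval(lambda aval: aval)

# Default: lower by tracing a Python implementation, as lower_fun does above.
xla.register_translation(
    toy_p, xla.lower_fun(lambda x: x + x, multiple_results=False,
                         new_style=True))
# CPU override (same computation here; a real override would dispatch to a
# platform kernel, as the cuda_prng/hip_prng branches do).
xla.register_translation(
    toy_p, xla.lower_fun(lambda x: 2 * x, multiple_results=False,
                         new_style=True),
    platform='cpu')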
    _, row, col = bufs
    return _coo_extract(row, col, ct), row, col
  else:
    raise NotImplementedError(f"todense_transpose for {type(obj)}")

def _todense_batching_rule(batched_args, batch_dims, *, tree):
  return jax.vmap(partial(_todense_impl, tree=tree),
                  batch_dims)(*batched_args), 0

ad.primitive_jvps[todense_p] = _todense_jvp
ad.primitive_transposes[todense_p] = _todense_transpose
batching.primitive_batchers[todense_p] = _todense_batching_rule
xla.register_translation(
    todense_p,
    xla.lower_fun(_todense_impl, multiple_results=False, new_style=True))

def empty(shape, dtype=None, index_dtype='int32', sparse_format='bcoo', **kwds):
  """Create an empty sparse array.

  Args:
    shape: sequence of integers giving the array shape.
    dtype: (optional) dtype of the array.
    index_dtype: (optional) dtype of the index arrays.
    sparse_format: string specifying the matrix format (e.g. ['bcoo']).
  padding = ((0, 0),) + padding
  base_dilation = (1,) + base_dilation
  window_dilation = (1,) + window_dilation
  out = _select_and_gather_add(t, x, select_prim, window_dimensions,
                               window_strides, padding, base_dilation,
                               window_dilation)
  return (out, 0)

select_and_gather_add_p = lax.standard_primitive(
    _select_and_gather_add_shape_rule, lax._input_dtype,
    'select_and_gather_add')

ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp
ad.primitive_transposes[select_and_gather_add_p] = \
    _select_and_gather_add_transpose
batching.primitive_batchers[select_and_gather_add_p] = \
    _select_and_gather_add_batching_rule

mlir.register_lowering(
    select_and_gather_add_p,
    mlir.lower_fun(_select_and_gather_add_using_variadic_reducewindow,
                   multiple_results=False))

# TODO(b/183233858): use variadic reducewindow on GPU, when implemented.
xla.register_translation(select_and_gather_add_p,
                         _select_and_gather_add_translation,
                         platform='gpu')
mlir.register_lowering(select_and_gather_add_p,
                       mlir.xla_fallback_lowering(select_and_gather_add_p),
                       platform="gpu")
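# Sketch of the fallback pattern used for GPU above, on a toy primitive (names
# below are illustrative): when a platform only has an XLA translation rule,
# mlir.xla_fallback_lowering wraps the primitive so the MHLO path can still
# lower it there, while other platforms use the native MLIR lowering.
from functools import partial
from jax import core
from jax.interpreters import mlir, xla

legacy_p = core.Primitive('legacy_op')
legacy_p.def_impl(partial(xla.apply_primitive, legacy_p))
legacy_p.def_abstract_eval(lambda aval: aval)

_legacy_translation = xla.lower_fun(lambda x: x, multiple_results=False,
                                    new_style=True)  # stand-in XLA rule
xla.register_translation(legacy_p, _legacy_translation, platform='gpu')
mlir.register_lowering(legacy_p, mlir.xla_fallback_lowering(legacy_p),
                       platform='gpu')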
  # Use JAX's convention for complex gradients
  # https://github.com/google/jax/issues/6223#issuecomment-807740707
  return lax.conj(out)

def fft_transpose_rule(t, operand, fft_type, fft_lengths):
  if fft_type == xla_client.FftType.RFFT:
    result = _rfft_transpose(t, fft_lengths)
  elif fft_type == xla_client.FftType.IRFFT:
    result = _irfft_transpose(t, fft_lengths)
  else:
    result = fft(t, fft_type, fft_lengths)
  return result,

def fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return fft(x, fft_type, fft_lengths), 0

fft_p = Primitive('fft')
fft_p.def_impl(fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
xla.register_translation(fft_p, _fft_translation_rule)
ad.deflinear2(fft_p, fft_transpose_rule)
batching.primitive_batchers[fft_p] = fft_batching_rule

if pocketfft:
  xla.register_translation(fft_p, _fft_translation_rule_cpu, platform='cpu')
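# Sketch (toy primitive; every name below is illustrative): ad.deflinear2 is
# the hook used for fft_p above. For a linear primitive it derives the JVP
# automatically and installs the given transpose rule, which must return one
# cotangent per operand -- hence fft_transpose_rule's trailing-comma return.
from jax import core
from jax.interpreters import ad

double_p = core.Primitive('double')
double_p.def_impl(lambda x: 2 * x)
double_p.def_abstract_eval(lambda x: x)
# transpose of the linear map x -> 2x is ct -> 2*ct
ad.deflinear2(double_p, lambda ct, x: (double_p.bind(ct),))
# e.g. jax.linear_transpose(double_p.bind, 1.0)(1.0) == (2.0,)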
                                  recall_target,
                                  reduction_input_size_override,
                                  aggregate_to_topk)
  if type(tangent) is ad_util.Zero:
    tangent_out = ad_util.Zero.from_value(val_out)
  else:
    arg_shape = arg_out.shape
    rank = len(arg_shape)
    if reduction_dimension < 0:
      reduction_dimension += rank
    iotas = [
        lax.broadcasted_iota(arg_out.dtype, arg_shape, i) for i in range(rank)
    ]
    idx = tuple(
        arg_out if i == reduction_dimension else iotas[i] for i in range(rank))
    tangent_out = tangent[idx]
  return (val_out, arg_out), (tangent_out, ad_util.Zero.from_value(arg_out))

approx_top_k_p = core.Primitive('approx_top_k')
approx_top_k_p.multiple_results = True
approx_top_k_p.def_impl(partial(xla.apply_primitive, approx_top_k_p))
approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval)
xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
xla.register_translation(approx_top_k_p, _approx_top_k_tpu_translation,
                         platform='tpu')
batching.primitive_batchers[approx_top_k_p] = _approx_top_k_batch_rule
ad.primitive_jvps[approx_top_k_p] = _approx_top_k_jvp
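# Sketch of the iota-gather trick in the JVP above, on concrete toy values:
# build an iota for every axis, replace the reduction axis with the top-k
# indices, and index the tangent to select the tangent entries of the chosen
# elements. Values here are illustrative only.
import jax.numpy as jnp
from jax import lax

toy_tangent = jnp.arange(12.).reshape(3, 4)
toy_arg_out = jnp.array([[2, 0], [1, 3], [0, 1]])  # per-row top-2 indices
rows = lax.broadcasted_iota(jnp.int32, toy_arg_out.shape, 0)
toy_idx = (rows, toy_arg_out)      # axis 1 plays the reduction dimension
picked = toy_tangent[toy_idx]      # tangent entry for each selected element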
  assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg))
  cts = [
      ad_util.zeros_like_aval(ct_aval) if type(ct) is ad_util.Zero else ct
      for ct, ct_aval in zip(cts, call.out_avals)
  ]
  ct_out = tree_unflatten(out_tree, cts)
  ct_lin = rule(res_arg, ct_out)
  ct_lin_flat, ct_lin_tree = tree_flatten(ct_lin)
  check_transpose_rule_trees(rule, lin_tree, ct_lin_tree)
  return [None] * len(tree_leaves(res_arg)) + ct_lin_flat

def custom_transpose_abstract_eval(*in_avals, call, **_):
  return call.out_avals

custom_transpose_p = core.Primitive('custom_transpose_call')
custom_transpose_p.multiple_results = True
custom_transpose_p.def_impl(custom_transpose_impl)
custom_transpose_p.def_abstract_eval(custom_transpose_abstract_eval)
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
xla.register_translation(custom_transpose_p,
                         xla.lower_fun(custom_transpose_impl, new_style=True,
                                       multiple_results=True),
                         initial_style=True)
mlir.register_lowering(
    custom_transpose_p,
    mlir.lower_fun(custom_transpose_impl, multiple_results=True))
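# Usage sketch for the API this primitive backs (hedged: the decorator
# spelling and call convention are assumptions about the public wrapper of
# this era). The user rule receives the residual args and the output
# cotangent, matching rule(res_arg, ct_out) in the transpose rule above.
import jax

@jax.custom_transpose
def scale(r, x):            # r: residual (non-linear) arg, x: linear arg
  return r * x

@scale.def_transpose
def scale_transpose(r, ct):
  return r * ct             # transpose of x -> r*x is ct -> r*ct

f = lambda x: scale(2.0, x)
ct, = jax.linear_transpose(f, 1.0)(1.0)   # ct == 2.0, via scale_transpose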
        consts=consts,
        window_dimensions=window_dimensions,
        window_strides=window_strides,
        padding=padding,
        base_dilation=base_dilation,
        window_dilation=window_dilation)
  return outs, (0,) * num_operands

reduce_window_p = core.Primitive('reduce_window')
reduce_window_p.multiple_results = True
reduce_window_p.def_impl(partial(xla.apply_primitive, reduce_window_p))
reduce_window_p.def_abstract_eval(_reduce_window_abstract_eval_rule)
batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule
xla.register_translation(reduce_window_p, _reduce_window_translation_rule)

def _generic_reduce_window_lower(ctx, avals_in, avals_out, *args, jaxpr,
                                 consts, window_dimensions, window_strides,
                                 padding, base_dilation, window_dilation):
  operands, init_values = util.split_list(args, [len(args) // 2])
  _, init_value_avals = util.split_list(avals_in, [len(operands)])
  scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
  rw = mhlo.ReduceWindowOp(
      map(mlir.aval_to_ir_type, avals_out), operands, init_values,
      mlir.dense_int_elements(window_dimensions),
      mlir.dense_int_elements(window_strides),
      mlir.dense_int_elements(base_dilation),
      mlir.dense_int_elements(window_dilation),
      ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
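  # Sketch note (hedged): in the full definition, the reducer computation --
  # the `jaxpr` with its `consts` -- is then lowered into the ReduceWindowOp's
  # region and the op's results are returned; only the op construction with
  # its dense window/stride/dilation/padding attributes is shown here.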
      if result_shape.is_tuple():
        res_op = xops.GetTupleElement(res_tf, idx)
      if res_aval.dtype != res_shape.numpy_dtype():
        res_op = xops.ConvertElementType(
            res_op,
            new_element_type=xla.dtype_to_primitive_type(res_aval.dtype))
      return res_op

    return [
        post_process_result(i, res_aval, res_shape)
        for i, (res_aval, res_shape) in enumerate(
            zip(result_avals, result_shapes))
    ]

  return code_gen, result_avals

xla.register_translation(call_tf_p, _call_tf_translation_rule)

TfVal = jax2tf_internal.TfVal

def _jax2tf_call_tf(*args: TfVal,
                    callable_flat_tf: Callable,
                    **_) -> TfVal:
  with jax2tf_internal.inside_call_tf():
    res_tf_flat = callable_flat_tf(*args)
  return res_tf_flat

jax2tf_internal.tf_impl[call_tf_p] = _jax2tf_call_tf
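# Usage sketch (hedged): the public entry point these rules support is
# jax2tf.call_tf, which embeds a TF function in JAX; the tf_impl registration
# above is what lets jax2tf.convert round-trip such a function back to TF.
import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

f_jax = jax2tf.call_tf(tf.function(lambda x: tf.math.sin(x)))
y = jax.jit(f_jax)(jnp.float32(0.5))   # TF computation staged out under jit
f_tf_again = jax2tf.convert(f_jax)     # round-trip: uses tf_impl[call_tf_p]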