Example #1
sp_indices_p = core.Primitive('sp_indices')

@sp_indices_p.def_impl
def _sp_indices_impl(mat):
  return mat.indices

@sp_indices_p.def_abstract_eval
def _sp_indices_abstract_eval(mat):
  return mat.indices_aval

def _sp_indices_translation_rule(ctx, avals_in, avals_out, data, indices):
  return [indices]

# Note: cannot use lower_fun to define attribute access primitives
# because it leads to infinite recursion.
xla.register_translation(sp_indices_p, _sp_indices_translation_rule)

def _sp_indices_mhlo_lowering(ctx, avals_in, avals_out, data_and_indices):
  return [data_and_indices[1]]

mlir.register_lowering(sp_indices_p, _sp_indices_mhlo_lowering)

sp_data_p = core.Primitive('sp_data')

@sp_data_p.def_impl
def _sp_data_impl(mat):
  return mat.data

@sp_data_p.def_abstract_eval
def _sp_data_abstract_eval(mat):
  return mat.data_aval
Example #2

def _sharded_call_impl(fun, *args, nparts, in_parts, out_parts_thunk,
                       local_in_parts, local_out_parts_thunk, local_nparts,
                       name):
    compiled_fun = _sharded_callable(fun, nparts, in_parts, out_parts_thunk,
                                     local_in_parts, local_out_parts_thunk,
                                     local_nparts, name,
                                     *map(xla.abstractify, args))
    return compiled_fun(*args)


sharded_call_p = core.CallPrimitive("sharded_call")
sharded_call = sharded_call_p.bind
sharded_call_p.def_impl(_sharded_call_impl)
xla.register_translation(sharded_call_p, _sharded_jit_translation_rule)
mlir.register_lowering(sharded_call_p, _sharded_jit_lowering)


class _UnconstrainedPartitionSingleton:
    def __str__(self):
        return "UNCONSTRAINED"


# Unconstrained sentinel value for PartitionSpec, representing a dimension for
# which the user wants XLA to assign the best partitioning.
# TODO(yashkatariya): May rename to AUTO.
_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()


class PartitionSpec(tuple):
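For context, PartitionSpec (exposed in later JAX releases as jax.sharding.PartitionSpec) names how each array dimension maps onto mesh axes; a minimal sketch, assuming that public module path:

from jax.sharding import PartitionSpec as P

spec = P('data', None)  # shard dim 0 along the 'data' mesh axis, replicate dim 1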
Example #3

batching.axis_primitive_batchers[remat_p] = remat_vmap


name_p = core.Primitive('name')  # identity primitive that tags a value with a name

def checkpoint_name(x, name):
    return name_p.bind(x, name=name)


name_p.def_impl(lambda x, *, name: x)
name_p.def_abstract_eval(lambda x, *, name: x)


def name_jvp(primals, tangents, *, name):
    (x, ), (xdot, ) = primals, tangents
    return name_p.bind(x, name=name), xdot  # don't name the tangent value


ad.primitive_jvps[name_p] = name_jvp

xla.register_translation(name_p,
                         lambda ctx, avals_in, avals_out, x, *, name: [x])


def name_batcher(args, dims, *, name):
    (x, ), (d, ) = args, dims
    return name_p.bind(x, name=name), d


batching.primitive_batchers[name_p] = name_batcher
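Because every rule for name_p is the identity on the value and only threads the static name through, checkpoint_name is transparent to jvp, vmap, and lowering while still tagging the intermediate. A usage sketch, assuming the public jax.ad_checkpoint and jax.checkpoint_policies APIs:

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def f(x):
    y = checkpoint_name(jnp.sin(x), 'hidden')  # tag this intermediate
    return jnp.sum(y ** 2)

# Under remat, save only intermediates tagged 'hidden' for the backward pass.
f_remat = jax.checkpoint(
    f, policy=jax.checkpoint_policies.save_only_these_names('hidden'))
print(jax.grad(f_remat)(jnp.ones(3)))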
Example #4
File: coo.py Project: 0x0is1/jax
def _coo_todense_transpose(ct, data, row, col, *, shape):
    # Note: we assume that transpose has the same sparsity pattern.
    # Can we check this?
    assert ad.is_undefined_primal(data)
    if ad.is_undefined_primal(row) or ad.is_undefined_primal(col):
        raise ValueError("Cannot transpose with respect to sparse indices")
    assert ct.shape == shape
    assert row.aval.dtype == col.aval.dtype
    assert ct.dtype == data.aval.dtype
    return _coo_extract(row, col, ct), row, col


ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
xla.register_translation(coo_todense_p, _coo_todense_translation_rule)
if (cusparse and cusparse.is_supported) or (hipsparse
                                            and hipsparse.is_supported):
    xla.register_translation(coo_todense_p,
                             _coo_todense_gpu_translation_rule,
                             platform='gpu')

#--------------------------------------------------------------------
# coo_fromdense

coo_fromdense_p = core.Primitive('coo_fromdense')
coo_fromdense_p.multiple_results = True


def coo_fromdense(mat, *, nse, index_dtype=jnp.int32):
    """Create COO-format sparse matrix from a dense matrix.
Example #5
        return xla.xla_destructure(
            ctx.builder,
            hip_prng.threefry2x32(
                ctx.builder,
                (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
                (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval))))


threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
xla.register_translation(
    threefry2x32_p,
    xla.lower_fun(partial(_threefry2x32_lowering, use_rolled_loops=False),
                  multiple_results=True,
                  new_style=True))
xla.register_translation(threefry2x32_p,
                         xla.lower_fun(partial(_threefry2x32_lowering,
                                               use_rolled_loops=True),
                                       multiple_results=True,
                                       new_style=True),
                         platform='cpu')
if cuda_prng:
    xla.register_translation(threefry2x32_p,
                             _threefry2x32_gpu_translation_rule,
                             platform='gpu')
if hip_prng:
    xla.register_translation(threefry2x32_p,
                             _threefry2x32_gpu_translation_rule,
                             platform='gpu')
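threefry2x32_p implements the ThreeFry-2x32 block cipher that backs JAX's default PRNG, so ordinary jax.random calls exercise these registrations (the rolled-loop variant on CPU, the custom kernels on GPU when available). For example:

import jax

key = jax.random.PRNGKey(0)     # ThreeFry-based key
k1, k2 = jax.random.split(key)  # split() binds threefry2x32_p
x = jax.random.uniform(k1, (4,))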
Example #6
        _, row, col = bufs
        return _coo_extract(row, col, ct), row, col
    else:
        raise NotImplementedError(f"todense_transpose for {type(obj)}")


def _todense_batching_rule(batched_args, batch_dims, *, tree):
    return jax.vmap(partial(_todense_impl, tree=tree),
                    batch_dims)(*batched_args), 0


ad.primitive_jvps[todense_p] = _todense_jvp
ad.primitive_transposes[todense_p] = _todense_transpose
batching.primitive_batchers[todense_p] = _todense_batching_rule
xla.register_translation(
    todense_p,
    xla.lower_fun(_todense_impl, multiple_results=False, new_style=True))


def empty(shape,
          dtype=None,
          index_dtype='int32',
          sparse_format='bcoo',
          **kwds):
    """Create an empty sparse array.

  Args:
    shape: sequence of integers giving the array shape.
    dtype: (optional) dtype of the array.
    index_dtype: (optional) dtype of the index arrays.
    sparse_format: string specifying the matrix format (e.g. 'bcoo').
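The empty function in this snippet is exposed as jax.experimental.sparse.empty; a brief usage sketch:

from jax.experimental import sparse

mat = sparse.empty((3, 4), dtype='float32', sparse_format='bcoo')
print(mat.shape, mat.nse)  # an all-zero BCOO matrix with no stored elements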
Example #7
    padding = ((0, 0), ) + padding
    base_dilation = (1, ) + base_dilation
    window_dilation = (1, ) + window_dilation
    out = _select_and_gather_add(t, x, select_prim, window_dimensions,
                                 window_strides, padding, base_dilation,
                                 window_dilation)
    return (out, 0)


select_and_gather_add_p = lax.standard_primitive(
    _select_and_gather_add_shape_rule, lax._input_dtype,
    'select_and_gather_add')
ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp
ad.primitive_transposes[select_and_gather_add_p] = \
  _select_and_gather_add_transpose
batching.primitive_batchers[select_and_gather_add_p] = \
  _select_and_gather_add_batching_rule

mlir.register_lowering(
    select_and_gather_add_p,
    mlir.lower_fun(_select_and_gather_add_using_variadic_reducewindow,
                   multiple_results=False))

# TODO(b/183233858): use variadic reducewindow on GPU, when implemented.
xla.register_translation(select_and_gather_add_p,
                         _select_and_gather_add_translation,
                         platform='gpu')
mlir.register_lowering(select_and_gather_add_p,
                       mlir.xla_fallback_lowering(select_and_gather_add_p),
                       platform="gpu")
Example #8
    # Use JAX's convention for complex gradients
    # https://github.com/google/jax/issues/6223#issuecomment-807740707
    return lax.conj(out)


def fft_transpose_rule(t, operand, fft_type, fft_lengths):
    if fft_type == xla_client.FftType.RFFT:
        result = _rfft_transpose(t, fft_lengths)
    elif fft_type == xla_client.FftType.IRFFT:
        result = _irfft_transpose(t, fft_lengths)
    else:
        result = fft(t, fft_type, fft_lengths)
    return result,


def fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
    x, = batched_args
    bd, = batch_dims
    x = batching.moveaxis(x, bd, 0)
    return fft(x, fft_type, fft_lengths), 0


fft_p = Primitive('fft')
fft_p.def_impl(fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
xla.register_translation(fft_p, _fft_translation_rule)
ad.deflinear2(fft_p, fft_transpose_rule)
batching.primitive_batchers[fft_p] = fft_batching_rule
if pocketfft:
    xla.register_translation(fft_p, _fft_translation_rule_cpu, platform='cpu')
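These registrations are what make jnp.fft differentiable and batchable end to end; for instance:

import jax
import jax.numpy as jnp

x = jnp.ones((3, 8))
print(jax.vmap(jnp.fft.rfft)(x).shape)  # fft_batching_rule: (3, 5)
# Differentiating through an RFFT uses fft_transpose_rule / _rfft_transpose.
g = jax.grad(lambda v: jnp.abs(jnp.fft.rfft(v)).sum())(x[0])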
Example #9
                                        recall_target,
                                        reduction_input_size_override,
                                        aggregate_to_topk)
    if type(tangent) is ad_util.Zero:
        tangent_out = ad_util.Zero.from_value(val_out)
    else:
        arg_shape = arg_out.shape
        rank = len(arg_shape)
        if reduction_dimension < 0:
            reduction_dimension += rank
        iotas = [
            lax.broadcasted_iota(arg_out.dtype, arg_shape, i)
            for i in range(rank)
        ]
        idx = tuple(arg_out if i == reduction_dimension else iotas[i]
                    for i in range(rank))
        tangent_out = tangent[idx]
    return (val_out, arg_out), (tangent_out, ad_util.Zero.from_value(arg_out))


approx_top_k_p = core.Primitive('approx_top_k')
approx_top_k_p.multiple_results = True
approx_top_k_p.def_impl(partial(xla.apply_primitive, approx_top_k_p))
approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval)
xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
xla.register_translation(approx_top_k_p,
                         _approx_top_k_tpu_translation,
                         platform='tpu')
batching.primitive_batchers[approx_top_k_p] = _approx_top_k_batch_rule
ad.primitive_jvps[approx_top_k_p] = _approx_top_k_jvp
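The public wrappers are jax.lax.approx_max_k / approx_min_k; the JVP above differentiates the values output by gathering the tangent at the returned indices. A usage sketch:

import jax
import jax.numpy as jnp

x = jnp.array([3., 1., 4., 1., 5.])
vals, idxs = jax.lax.approx_max_k(x, k=2)  # binds approx_top_k_p
print(vals, idxs)  # approximately the top-2 values and their indices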
Example #10
    assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg))

    cts = [
        ad_util.zeros_like_aval(ct_aval) if type(ct) is ad_util.Zero else ct
        for ct, ct_aval in zip(cts, call.out_avals)
    ]
    ct_out = tree_unflatten(out_tree, cts)
    ct_lin = rule(res_arg, ct_out)
    ct_lin_flat, ct_lin_tree = tree_flatten(ct_lin)
    check_transpose_rule_trees(rule, lin_tree, ct_lin_tree)
    return [None] * len(tree_leaves(res_arg)) + ct_lin_flat


def custom_transpose_abstract_eval(*in_avals, call, **_):
    return call.out_avals


custom_transpose_p = core.Primitive('custom_transpose_call')
custom_transpose_p.multiple_results = True
custom_transpose_p.def_impl(custom_transpose_impl)
custom_transpose_p.def_abstract_eval(custom_transpose_abstract_eval)
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
xla.register_translation(custom_transpose_p,
                         xla.lower_fun(custom_transpose_impl,
                                       new_style=True,
                                       multiple_results=True),
                         initial_style=True)
mlir.register_lowering(
    custom_transpose_p,
    mlir.lower_fun(custom_transpose_impl, multiple_results=True))
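The same transpose-registration pattern reads more clearly on a self-contained toy: a sketch of a linear primitive, assuming JAX's internal (version-dependent) extension points jax.core and jax.interpreters.ad:

from jax import core
from jax.interpreters import ad
import jax

double_p = core.Primitive('double')      # toy linear primitive: x -> 2x
double_p.def_impl(lambda x: 2 * x)
double_p.def_abstract_eval(lambda x: x)  # output aval matches the input aval
# For a linear primitive, deflinear2 installs both the JVP and the transpose
# rule from this single cotangent mapping.
ad.deflinear2(double_p, lambda ct, x: (2 * ct,))

print(jax.grad(double_p.bind)(3.0))      # 2.0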
Example #11
                                consts=consts,
                                window_dimensions=window_dimensions,
                                window_strides=window_strides,
                                padding=padding,
                                base_dilation=base_dilation,
                                window_dilation=window_dilation)
    return outs, (0, ) * num_operands


reduce_window_p = core.Primitive('reduce_window')
reduce_window_p.multiple_results = True
reduce_window_p.def_impl(partial(xla.apply_primitive, reduce_window_p))
reduce_window_p.def_abstract_eval(_reduce_window_abstract_eval_rule)
batching.primitive_batchers[
    reduce_window_p] = _generic_reduce_window_batch_rule
xla.register_translation(reduce_window_p, _reduce_window_translation_rule)


def _generic_reduce_window_lower(ctx, avals_in, avals_out, *args, jaxpr,
                                 consts, window_dimensions, window_strides,
                                 padding, base_dilation, window_dilation):
    operands, init_values = util.split_list(args, [len(args) // 2])
    _, init_value_avals = util.split_list(avals_in, [len(operands)])
    scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
    rw = mhlo.ReduceWindowOp(
        map(mlir.aval_to_ir_type, avals_out), operands, init_values,
        mlir.dense_int_elements(window_dimensions),
        mlir.dense_int_elements(window_strides),
        mlir.dense_int_elements(base_dilation),
        mlir.dense_int_elements(window_dilation),
        ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
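Window reductions with a non-standard reducer route through this generic reduce_window_p rather than the specialized monoid primitives (sum, max, ...); a sketch:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.)
# A custom (non-monoid) reducer keeps lax.reduce_window on the generic path.
out = lax.reduce_window(x, 0., lambda acc, v: acc + v * v,
                        window_dimensions=(2,), window_strides=(1,),
                        padding='VALID')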
Example #12
            if result_shape.is_tuple():
                res_op = xops.GetTupleElement(res_tf, idx)
            if res_aval.dtype != res_shape.numpy_dtype():
                res_op = xops.ConvertElementType(
                    res_op,
                    new_element_type=xla.dtype_to_primitive_type(
                        res_aval.dtype))
            return res_op

        return [
            post_process_result(i, res_aval, res_shape)
            for i, (res_aval,
                    res_shape) in enumerate(zip(result_avals, result_shapes))
        ]

    return code_gen, result_avals


xla.register_translation(call_tf_p, _call_tf_translation_rule)

TfVal = jax2tf_internal.TfVal


def _jax2tf_call_tf(*args: TfVal, callable_flat_tf: Callable, **_) -> TfVal:
    with jax2tf_internal.inside_call_tf():
        res_tf_flat = callable_flat_tf(*args)
    return res_tf_flat


jax2tf_internal.tf_impl[call_tf_p] = _jax2tf_call_tf
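End to end, call_tf lets a TensorFlow function be staged into a JAX computation (and round-trip back through jax2tf, as the tf_impl registration above shows). A usage sketch, assuming TensorFlow is installed:

import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

jax_sin = jax2tf.call_tf(tf.math.sin)  # wraps tf.math.sin behind call_tf_p
print(jax_sin(jnp.ones(3)))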