Exemple #1
0
def _lu_tpu_translation_rule(c, operand):
    """LU translation rule (the TPU variant, going by its name).

    When ``xops`` exposes an ``LU`` op, its three outputs are packed into an
    XLA tuple. Otherwise the ``_lu_python`` implementation is lowered with
    ``xla.lower_fun`` and applied to ``operand``.
    """
    if not hasattr(xops, "LU"):
        # No native op on this xops build: lower the Python implementation.
        fallback = xla.lower_fun(_lu_python, multiple_results=True)
        return fallback(c, operand)
    factors = xops.LU(operand)
    lu, pivot, perm = factors
    return xops.Tuple(c, [lu, pivot, perm])
Exemple #2
0
    assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg))

    cts = [
        ad_util.zeros_like_aval(ct_aval) if type(ct) is ad_util.Zero else ct
        for ct, ct_aval in zip(cts, call.out_avals)
    ]
    ct_out = tree_unflatten(out_tree, cts)
    ct_lin = rule(res_arg, ct_out)
    ct_lin_flat, ct_lin_tree = tree_flatten(ct_lin)
    check_transpose_rule_trees(rule, lin_tree, ct_lin_tree)
    return [None] * len(tree_leaves(res_arg)) + ct_lin_flat


def custom_transpose_abstract_eval(*in_avals, call, **_):
    """Abstract evaluation rule: report the output avals stored on ``call``.

    The positional input avals and any other keyword parameters are ignored.
    """
    del in_avals  # Unused: the result depends only on ``call``.
    out_avals = call.out_avals
    return out_avals


# Primitive named 'custom_transpose_call'; it is marked as producing multiple
# results and gets its impl and abstract-eval rules attached here.
custom_transpose_p = core.Primitive('custom_transpose_call')
custom_transpose_p.multiple_results = True
custom_transpose_p.def_impl(custom_transpose_impl)
custom_transpose_p.def_abstract_eval(custom_transpose_abstract_eval)
# Transposition of this primitive dispatches to custom_transpose_transpose_rule.
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
# The XLA translation and the MLIR lowering below are both derived from the
# same function, custom_transpose_impl.
xla.register_translation(custom_transpose_p,
                         xla.lower_fun(custom_transpose_impl,
                                       new_style=True,
                                       multiple_results=True),
                         initial_style=True)
mlir.register_lowering(
    custom_transpose_p,
    mlir.lower_fun(custom_transpose_impl, multiple_results=True))
Exemple #3
0
    return xla_client.ops.Tuple(c, [zeros, zeros])
  def _broadcast(x):
    ndims = c.get_shape(x).rank()
    return xla_client.ops.BroadcastInDim(x, shape,
                                         tuple(range(rank - ndims, rank)))
  return cuda_prng.threefry2x32(
      c, (_broadcast(k1), _broadcast(k2)), (_broadcast(x1), _broadcast(x2)))


# The "threefry2x32" primitive: multiple results, implemented through
# xla.apply_primitive, and registered with batching.defbroadcasting.
threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
# Default translation: _threefry2x32_lowering with use_rolled_loops=False.
xla.translations_with_avals[threefry2x32_p] = xla.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=False),
    multiple_results=True, with_avals=True)
# CPU-specific translation: same lowering function, use_rolled_loops=True.
xla.backend_specific_translations['cpu'][threefry2x32_p] = xla.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=True),
    multiple_results=True)
# The GPU-specific rule is only installed when cuda_prng is truthy.
if cuda_prng:
  xla.backend_specific_translations['gpu'][threefry2x32_p] = \
      _threefry2x32_gpu_translation_rule


@partial(jit, inline=True)
def threefry_2x32(keypair, count):
  """Apply the Threefry 2x32 hash.

  Args:
    keypair: a pair of 32bit unsigned integers used for the key.
Exemple #4
0
    return xops.Tuple(c, [lu, pivot, perm])


def _lu_tpu_translation_rule(c, operand):
    """Translation rule for the LU primitive on the 'tpu' backend.

    Prefers the ``xops.LU`` op when it exists and returns its three outputs
    as an XLA tuple; otherwise applies the ``xla.lower_fun`` lowering of
    ``_lu_python`` to ``operand``.
    """
    has_native_lu = hasattr(xops, "LU")
    if not has_native_lu:
        python_lowering = xla.lower_fun(_lu_python, multiple_results=True)
        return python_lowering(c, operand)
    lu, pivot, perm = xops.LU(operand)
    outputs = [lu, pivot, perm]
    return xops.Tuple(c, outputs)


# The 'lu' primitive: multiple results, with impl, abstract-eval, JVP and
# batching rules attached below.
lu_p = Primitive('lu')
lu_p.multiple_results = True
lu_p.def_impl(_lu_impl)
lu_p.def_abstract_eval(_lu_abstract_eval)
# Default translation: lower the _lu_python implementation.
xla.translations[lu_p] = xla.lower_fun(_lu_python, multiple_results=True)
ad.primitive_jvps[lu_p] = _lu_jvp_rule
batching.primitive_batchers[lu_p] = _lu_batching_rule

# CPU and GPU share one translation rule, parameterized by the getrf
# implementation (lapack.getrf on CPU, cusolver.getrf on GPU).
xla.backend_specific_translations['cpu'][lu_p] = partial(
    _lu_cpu_gpu_translation_rule, lapack.getrf)

xla.backend_specific_translations['gpu'][lu_p] = partial(
    _lu_cpu_gpu_translation_rule, cusolver.getrf)

# TPU uses its own rule.
xla.backend_specific_translations['tpu'][lu_p] = _lu_tpu_translation_rule


# Define this outside lu_pivots_to_permutation to ensure fori_loop cache hits
def _lu_pivots_body_fn(i, permutation_and_swaps):
    permutation, swaps = permutation_and_swaps
Exemple #5
0
# Primitive named 'identity'; its impl and abstract-eval rules are attached
# below through the def_impl / def_abstract_eval decorators.
identity_p = core.Primitive('identity')


@identity_p.def_impl
def _identity_impl(mat):
    """Implementation rule for ``identity_p``: returns ``mat`` unchanged."""
    return mat


@identity_p.def_abstract_eval
def _identity_abstract_eval(mat):
    """Abstract evaluation for ``identity_p``.

    Builds an ``AbstractSparseArray`` from the shape, dtype, index dtype and
    nnz read off ``mat``.
    """
    shape, dtype = mat.shape, mat.dtype
    return AbstractSparseArray(shape, dtype, mat.index_dtype, mat.nnz)


# Translation for identity_p: lower _identity_impl (single result) and store
# it in the with-avals translation table.
xla.translations_with_avals[identity_p] = xla.lower_fun(_identity_impl,
                                                        multiple_results=False,
                                                        with_avals=True)


def split(x):
    """Bind ``x`` to the ``split_p`` primitive and return the result."""
    return split_p.bind(x)


# Primitive named 'split'; marked as producing multiple results.
split_p = core.Primitive('split')
split_p.multiple_results = True


@split_p.def_impl
def _split_impl(mat):
    """Implementation rule for ``split_p``: returns ``mat`` twice, as a pair."""
    return mat, mat
Exemple #6
0
                              rule=jvp_of_rule_rule,
                              in_tree=jvp_in_tree)
    assert len(outs) % 2 == 0, len(outs)
    out_primals, out_tangents = util.split_list(outs, [len(outs) // 2])
    return out_primals, out_tangents


# Primitive named 'custom_vmap_call'; multiple results, with impl,
# abstract-eval, batching and JVP rules attached below.
custom_vmap_p = core.Primitive('custom_vmap_call')
custom_vmap_p.multiple_results = True
custom_vmap_p.def_impl(custom_vmap_impl)
custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval)
batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching
ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp
# The XLA translation and the MLIR lowering are both derived from
# custom_vmap_impl.
xla.register_translation(custom_vmap_p,
                         xla.lower_fun(custom_vmap_impl,
                                       new_style=True,
                                       multiple_results=True),
                         initial_style=True)
mlir.register_lowering(custom_vmap_p,
                       mlir.lower_fun(custom_vmap_impl, multiple_results=True))

# -- custom vmap applications


def tree_split(mask, tree):
    """Split ``tree`` into two trees according to ``mask``.

    Returns ``(lhs, rhs)``. In ``lhs`` a leaf is kept where the matching mask
    leaf is truthy and replaced by ``None`` elsewhere; ``rhs`` is the
    complement.
    """
    def keep_where_set(flag, leaf):
        return leaf if flag else None

    def keep_where_unset(flag, leaf):
        return None if flag else leaf

    return (tree_map(keep_where_set, mask, tree),
            tree_map(keep_where_unset, mask, tree))


def tree_merge(mask, lhs_tree, rhs_tree):
Exemple #7
0
                               forward=forward,
                               length=length,
                               jaxpr=jaxpr,
                               num_consts=num_consts,
                               num_carry=num_carry,
                               linear=linear)


# The "scan" primitive: multiple results, with a custom bind (scan_bind) and
# _scan_impl as its implementation.
scan_p = core.Primitive("scan")
scan_p.multiple_results = True
scan_p.def_custom_bind(scan_bind)
scan_p.def_impl(_scan_impl)
# Rules for JVP, transposition and partial evaluation.
ad.primitive_jvps[scan_p] = _scan_jvp
ad.primitive_transposes[scan_p] = _scan_transpose
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
# Initial-style translation obtained by lowering _scan_impl.
xla.initial_style_translations[scan_p] = xla.lower_fun(_scan_impl,
                                                       initial_style=True)
batching.primitive_batchers[scan_p] = _scan_batching_rule


def map(f, xs):
    """Map a function over leading array axes.

  Like Python's builtin map, except inputs and outputs are in the form of
  stacked arrays. Consider using the ``jax.vmap`` transform instead, unless you
  need to apply a function element by element for reduced memory usage or
  heterogeneous computation with other control flow primitives.

  When ``xs`` is an array type, the semantics of ``map`` are given by this
  Python implementation::

    def map(f, xs):
Exemple #8
0
        return _coo_extract(row, col, ct), row, col
    else:
        raise NotImplementedError(f"todense_transpose for {type(obj)}")


def _todense_batching_rule(batched_args, batch_dims, *, tree):
    """Batching rule for todense: ``jax.vmap`` of ``_todense_impl``.

    ``_todense_impl`` (with ``tree`` bound) is mapped over ``batch_dims`` and
    applied to ``batched_args``; the result is returned with batch dimension 0.
    """
    impl = partial(_todense_impl, tree=tree)
    batched_impl = jax.vmap(impl, batch_dims)
    return batched_impl(*batched_args), 0


# Register the JVP, transpose and batching rules for todense_p.
ad.primitive_jvps[todense_p] = _todense_jvp
ad.primitive_transposes[todense_p] = _todense_transpose
batching.primitive_batchers[todense_p] = _todense_batching_rule
# XLA translation: lower _todense_impl (single result).
xla.register_translation(
    todense_p,
    xla.lower_fun(_todense_impl, multiple_results=False, new_style=True))


def empty(shape,
          dtype=None,
          index_dtype='int32',
          sparse_format='bcoo',
          **kwds):
    """Create an empty sparse array.

  Args:
    shape: sequence of integers giving the array shape.
    dtype: (optional) dtype of the array.
    index_dtype: (optional) dtype of the index arrays.
    format: string specifying the matrix format (e.g. ['bcoo']).
    **kwds: additional keywords passed to the format-specific _empty constructor.
Exemple #9
0
@csr_todense_p.def_impl
def _csr_todense_impl(data, indices, indptr, *, shape):
    """Implementation rule for ``csr_todense_p``.

    Passes ``indices`` and ``indptr`` through ``_csr_to_coo`` and forwards the
    unpacked result, together with ``data`` and ``shape``, to
    ``_coo_todense_impl``.
    """
    coo_indices = _csr_to_coo(indices, indptr)
    return _coo_todense_impl(data, *coo_indices, shape=shape)


@csr_todense_p.def_abstract_eval
def _csr_todense_abstract_eval(data, indices, indptr, *, shape):
    """Abstract evaluation for ``csr_todense_p``.

    Asserts that all three inputs are rank 1, that ``indices`` and ``indptr``
    share a dtype, that ``data`` and ``indices`` share a shape, and that
    ``indptr`` has ``shape[0] + 1`` entries. The result is a
    ``core.ShapedArray`` of ``shape`` with the dtype of ``data``.
    """
    ranks = (data.ndim, indices.ndim, indptr.ndim)
    assert all(rank == 1 for rank in ranks)
    assert indices.dtype == indptr.dtype
    assert data.shape == indices.shape
    expected_indptr_len = shape[0] + 1
    assert indptr.shape[0] == expected_indptr_len
    out_dtype = data.dtype
    return core.ShapedArray(shape, out_dtype)


# Default translation rule for csr_todense: the lowering of _csr_todense_impl
# (single result). The GPU rule below calls it as its fallback.
_csr_todense_translation_rule = xla.lower_fun(_csr_todense_impl,
                                              multiple_results=False,
                                              new_style=True)


def _csr_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
                                      indptr, *, shape):
    dtype = avals_in[0].dtype
    if not (np.issubdtype(dtype, np.floating)
            or np.issubdtype(dtype, np.complexfloating)):
        warnings.warn(
            f"csr_todense cusparse lowering not available for dtype={dtype}. "
            "Falling back to default implementation.",
            CuSparseEfficiencyWarning)
        return _csr_todense_translation_rule(ctx,
                                             avals_in,
                                             avals_out,
    window_dimensions = (1, ) + window_dimensions
    window_strides = (1, ) + window_strides
    padding = ((0, 0), ) + padding
    base_dilation = (1, ) + base_dilation
    window_dilation = (1, ) + window_dilation
    out = _select_and_gather_add(t, x, select_prim, window_dimensions,
                                 window_strides, padding, base_dilation,
                                 window_dilation)
    return (out, 0)


# select_and_gather_add primitive, built with lax.standard_primitive from its
# shape rule, lax._input_dtype as the dtype rule, its name, and a translation
# lowered from _select_and_gather_add_using_variadic_reducewindow.
select_and_gather_add_p = lax.standard_primitive(
    _select_and_gather_add_shape_rule, lax._input_dtype,
    'select_and_gather_add',
    xla.lower_fun(_select_and_gather_add_using_variadic_reducewindow,
                  new_style=True,
                  multiple_results=False))
# JVP, transpose and batching rules.
ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp
ad.primitive_transposes[select_and_gather_add_p] = \
  _select_and_gather_add_transpose
batching.primitive_batchers[select_and_gather_add_p] = \
  _select_and_gather_add_batching_rule
# GPU gets a separate translation, _select_and_gather_add_translation.
# TODO(b/183233858): use variadic reducewindow on GPU, when implemented.
xla.register_translation(select_and_gather_add_p,
                         _select_and_gather_add_translation,
                         platform='gpu')

# MLIR lowering derived from the same variadic-reducewindow function.
mlir.register_lowering(
    select_and_gather_add_p,
    mlir.lower_fun(_select_and_gather_add_using_variadic_reducewindow,
                   multiple_results=False))
Exemple #11
0
            (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)))
    else:
        return hip_prng.threefry2x32_lowering(
            (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
            (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)))


# The "threefry2x32" primitive: multiple results, implemented through
# xla.apply_primitive, and registered with batching.defbroadcasting.
threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
# Default XLA translation: _threefry2x32_lowering with use_rolled_loops=False.
xla.register_translation(
    threefry2x32_p,
    xla.lower_fun(partial(_threefry2x32_lowering, use_rolled_loops=False),
                  multiple_results=True,
                  new_style=True))
# CPU XLA translation: same lowering function, use_rolled_loops=True.
xla.register_translation(threefry2x32_p,
                         xla.lower_fun(partial(_threefry2x32_lowering,
                                               use_rolled_loops=True),
                                       multiple_results=True,
                                       new_style=True),
                         platform='cpu')
# Default MLIR lowering, again with use_rolled_loops=False.
mlir.register_lowering(
    threefry2x32_p,
    mlir.lower_fun(partial(_threefry2x32_lowering, use_rolled_loops=False),
                   multiple_results=True))
mlir.register_lowering(threefry2x32_p,
                       mlir.lower_fun(partial(_threefry2x32_lowering,
                                              use_rolled_loops=True),
                                      multiple_results=True),
Exemple #12
0
def _lu_pivots_to_permutation_translation_rule(c, pivots, *, permutation_size):
  """Translation rule that lowers ``_generic_lu_pivots_to_permutation``.

  ``permutation_size`` is closed over, so the lowered function takes only the
  pivots; it is lowered as a single-result function and applied to ``pivots``.
  """
  return xla.lower_fun(
      lambda p: _generic_lu_pivots_to_permutation(p, permutation_size),
      multiple_results=False)(c, pivots)
Exemple #13
0
def identity(x):
  """Bind ``x`` to the ``identity_p`` primitive and return the result."""
  return identity_p.bind(x)

# Primitive named 'identity'; its rules are attached below via decorators.
identity_p = core.Primitive('identity')

@identity_p.def_impl
def _identity_impl(mat):
  """Implementation rule for ``identity_p``: returns ``mat`` unchanged."""
  return mat

@identity_p.def_abstract_eval
def _identity_abstract_eval(mat):
  """Abstract evaluation for ``identity_p``.

  Produces an ``AbstractSparseArray`` with the shape, dtype, index dtype and
  nnz taken from ``mat``.
  """
  shape, dtype = mat.shape, mat.dtype
  return AbstractSparseArray(shape, dtype, mat.index_dtype, mat.nnz)

# XLA translation for identity_p: lower _identity_impl (single result).
xla.register_translation(
    identity_p, xla.lower_fun(_identity_impl, multiple_results=False,
                              new_style=True))


# MLIR lowering derived from the same implementation function.
mlir.register_lowering(
    identity_p, mlir.lower_fun(_identity_impl, multiple_results=False))

def split(x):
  """Bind ``x`` to the ``split_p`` primitive and return the result."""
  return split_p.bind(x)

# Primitive named 'split'; marked as producing multiple results.
split_p = core.Primitive('split')
split_p.multiple_results = True

@split_p.def_impl
def _split_impl(mat):
  """Implementation rule for ``split_p``: returns ``mat`` twice, as a pair."""
  return mat, mat
Exemple #14
0
    shape = c.GetShape(operand)
    batch_dims = shape.dimensions()[:-2]
    lu, pivot, info = getrf_impl(c, operand)
    # Subtract 1 from the pivot to get 0-based indices.
    pivot = c.Sub(pivot, c.ConstantS32Scalar(1))
    ok = c.Eq(info, c.ConstantS32Scalar(0))
    lu = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), lu,
                              _nan_like(c, lu))
    return c.Tuple(lu, pivot)


# The 'lu' primitive: multiple results, with impl, abstract-eval, JVP and
# batching rules attached below.
lu_p = Primitive('lu')
lu_p.multiple_results = True
lu_p.def_impl(_lu_impl)
lu_p.def_abstract_eval(_lu_abstract_eval)
# Default translation: lower the _lu_python implementation.
xla.translations[lu_p] = xla.lower_fun(_lu_python, instantiate=True)
ad.primitive_jvps[lu_p] = _lu_jvp_rule
batching.primitive_batchers[lu_p] = _lu_batching_rule

# CPU and GPU share one translation rule, parameterized by the getrf
# implementation (lapack.getrf on CPU, cusolver.getrf on GPU).
xla.backend_specific_translations['cpu'][lu_p] = partial(
    _lu_cpu_gpu_translation_rule, lapack.getrf)

xla.backend_specific_translations['gpu'][lu_p] = partial(
    _lu_cpu_gpu_translation_rule, cusolver.getrf)


def lu_pivots_to_permutation(swaps, m):
    """Converts the pivots (row swaps) returned by LU to a permutation.

  We build a permutation rather than applying `swaps` directly to the rows
  of a matrix because lax loops aren't differentiable.
Exemple #15
0
    rank = len(shape)

    def _broadcast(x):
        ndims = c.GetShape(x).rank()
        return c.BroadcastInDim(x, shape, tuple(range(rank - ndims, rank)))

    return cuda_prng.threefry2x32(c, (_broadcast(k1), _broadcast(k2)),
                                  (_broadcast(x1), _broadcast(x2)))


# The "threefry2x32" primitive: multiple results, implemented through
# xla.apply_primitive, and registered with batching.defbroadcasting.
threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
# Default translation: _threefry2x32_lowering with use_rolled_loops=False.
xla.translations[threefry2x32_p] = xla.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=False))
# CPU-specific translation: same lowering function, use_rolled_loops=True.
xla.backend_specific_translations['cpu'][threefry2x32_p] = xla.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=True))
# The GPU-specific rule is only installed when cuda_prng is truthy.
if cuda_prng:
    xla.backend_specific_translations['gpu'][threefry2x32_p] = \
        _threefry2x32_gpu_translation_rule


@jit
def threefry_2x32(keypair, count):
    """Apply the Threefry 2x32 hash.

  Args:
    keypair: a pair of 32bit unsigned integers used for the key.
    count: an array of dtype uint32 used for the counts.
Exemple #16
0
  shape = c.GetShape(operand)
  batch_dims = shape.dimensions()[:-2]
  lu, pivot, info = getrf_impl(c, operand)
  # Subtract 1 from the pivot to get 0-based indices.
  pivot = c.Sub(pivot, c.ConstantS32Scalar(1))
  ok = c.Ge(info, c.ConstantS32Scalar(0))
  lu = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), lu,
                            _nan_like(c, lu))
  return c.Tuple(lu, pivot)


# The 'lu' primitive: multiple results, with impl, abstract-eval, JVP and
# batching rules attached below.
lu_p = Primitive('lu')
lu_p.multiple_results = True
lu_p.def_impl(_lu_impl)
lu_p.def_abstract_eval(_lu_abstract_eval)
# Default translation: lower the _lu_python implementation.
xla.translations[lu_p] = xla.lower_fun(_lu_python)
ad.primitive_jvps[lu_p] = _lu_jvp_rule
batching.primitive_batchers[lu_p] = _lu_batching_rule

# CPU and GPU share one translation rule, parameterized by the getrf
# implementation (lapack.getrf on CPU, cusolver.getrf on GPU).
xla.backend_specific_translations['cpu'][lu_p] = partial(
  _lu_cpu_gpu_translation_rule, lapack.getrf)

xla.backend_specific_translations['gpu'][lu_p] = partial(
  _lu_cpu_gpu_translation_rule, cusolver.getrf)


# Define this outside lu_pivots_to_permutation to ensure fori_loop cache hits
def _lu_pivots_body_fn(i, permutation_and_swaps):
  permutation, swaps = permutation_and_swaps
  batch_dims = swaps.shape[:-1]
  j = swaps[..., i]
Exemple #17
0
 def _xla(c, *xla_args, **params):
     """Lower ``self.impl`` with ``xla.lower_fun`` and apply it to the XLA args."""
     return xla.lower_fun(self.impl, multiple_results=True)(c, *xla_args, **params)
Exemple #18
0
    def _broadcast(x):
        ndims = c.GetShape(x).rank()
        return c.BroadcastInDim(x, shape, tuple(range(rank - ndims, rank)))

    return cuda_prng.threefry2x32(c, (_broadcast(k1), _broadcast(k2)),
                                  (_broadcast(x1), _broadcast(x2)))


# The "threefry2x32" primitive: multiple results, implemented through
# xla.apply_primitive, and registered with batching.defbroadcasting.
threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
# Default translation: _threefry2x32_lowering with use_rolled_loops=False.
xla.translations[threefry2x32_p] = xla.lower_fun(partial(
    _threefry2x32_lowering, use_rolled_loops=False),
                                                 instantiate=True)
# CPU-specific translation: same lowering function, use_rolled_loops=True.
xla.backend_specific_translations['cpu'][threefry2x32_p] = xla.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=True), instantiate=True)
# The GPU-specific rule is only installed when cuda_prng is truthy.
if cuda_prng:
    xla.backend_specific_translations['gpu'][threefry2x32_p] = \
        _threefry2x32_gpu_translation_rule


@jit
def threefry_2x32(keypair, count):
    """Apply the Threefry 2x32 hash.

  Args:
    keypair: a pair of 32bit unsigned integers used for the key.
    count: an array of dtype uint32 used for the counts.
Exemple #19
0

@csr_todense_p.def_abstract_eval
def _csr_todense_abstract_eval(data, indices, indptr, *, shape):
    """Abstract evaluation for ``csr_todense_p``.

    Asserts that ``data``, ``indices`` and ``indptr`` are all rank 1, that the
    two index arrays have the same dtype, that ``data`` and ``indices`` have
    the same shape, and that ``indptr`` holds ``shape[0] + 1`` entries. The
    output is a ``core.ShapedArray`` of ``shape`` with the dtype of ``data``.
    """
    input_ranks = (data.ndim, indices.ndim, indptr.ndim)
    assert all(rank == 1 for rank in input_ranks)
    assert indices.dtype == indptr.dtype
    assert data.shape == indices.shape
    num_rows = shape[0]
    assert indptr.shape[0] == num_rows + 1
    return core.ShapedArray(shape, data.dtype)


def _csr_todense_gpu_translation_rule(c, data, indices, indptr, *, shape):
    """GPU translation rule for csr_todense; delegates to ``cusparse.csr_todense``."""
    return cusparse.csr_todense(c, data, indices, indptr, shape=shape)


# Default translation: lower _csr_todense_impl (single result).
xla.translations[csr_todense_p] = xla.lower_fun(_csr_todense_impl,
                                                multiple_results=False)
# The cusparse-backed GPU rule is only installed when the cusparse module is
# truthy and reports is_supported.
if cusparse and cusparse.is_supported:
    xla.backend_specific_translations['gpu'][
        csr_todense_p] = _csr_todense_gpu_translation_rule

#--------------------------------------------------------------------
# csr_fromdense

# Primitive named 'csr_fromdense'; marked as producing multiple results.
csr_fromdense_p = core.Primitive('csr_fromdense')
csr_fromdense_p.multiple_results = True


def csr_fromdense(mat, *, nnz, index_dtype=np.int32):
    """Create CSR-format sparse matrix from a dense matrix.

  Args: