Example #1
File: lax_linalg.py  Project: yotarok/jax
def _triangular_solve_cpu_translation_rule(c, a, b, left_side, lower,
                                           transpose_a, conjugate_a,
                                           unit_diagonal):
    shape = c.GetShape(a)
    dtype = shape.element_type().type

    if conjugate_a and not transpose_a:
        # Fold a bare conjugation into the operand; the remaining flags then
        # only describe transposition.
        a = xops.Conj(a)
        conjugate_a = False
    if len(shape.dimensions()) == 2 and onp.dtype(dtype) in _cpu_lapack_types:
        # Fast path: a single unbatched matrix of a LAPACK-supported dtype is
        # solved directly by trsm.
        return lapack.jax_trsm(xb.computation_builder_shim(c),
                               xb.constant(c, onp.array(1, dtype=dtype)), a, b,
                               left_side, lower, transpose_a, conjugate_a,
                               unit_diagonal)
    else:
        # Fall back to the HLO implementation for unsupported types or batching.
        # TODO: Consider swapping XLA for LAPACK in batched case
        if not transpose_a:
            transpose = xops.TriangularSolveOptions_Transpose.NO_TRANSPOSE
        else:
            transpose = (xops.TriangularSolveOptions_Transpose.ADJOINT
                         if conjugate_a else
                         xops.TriangularSolveOptions_Transpose.TRANSPOSE)
        return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal,
                                    transpose)
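For context, a minimal usage sketch of the triangular solve this rule lowers (hedged: it uses the public jax.scipy wrapper, which is not part of the snippet, and assumes a LAPACK-supported dtype so the fast path applies):

import numpy as onp
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# A well-conditioned lower-triangular system.
a = jnp.array(onp.tril(onp.random.rand(3, 3)) + 3 * onp.eye(3))
b = jnp.ones((3, 2))
x = solve_triangular(a, b, lower=True)
print(jnp.allclose(a @ x, b, atol=1e-5))  # True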
Example #2
File: lax_linalg.py  Project: yotarok/jax
def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, c, operand,
                                 full_matrices):
    shape = c.GetShape(operand)
    dims = shape.dimensions()
    m, n = dims[-2:]
    batch_dims = dims[:-2]
    cs = xb.computation_builder_shim(c)
    r, tau, info_geqrf = geqrf_impl(cs, operand)
    if m < n:
        # Wide case: Q is m x m, so slice the leading square block out of the
        # geqrf output before forming Q.
        q = xops.Slice(r, [0] * len(dims),
                       list(batch_dims) + [m, m], [1] * len(dims))
        q, info_orgqr = orgqr_impl(cs, q, tau)
    elif not full_matrices:
        # Economy case: form Q from the full geqrf output, then trim R to its
        # n x n upper-triangular block.
        q, info_orgqr = orgqr_impl(cs, r, tau)
        r = xops.Slice(r, [0] * len(dims),
                       list(batch_dims) + [n, n], [1] * len(dims))
    else:
        # Full case: pad the geqrf output with m - n zero columns so orgqr
        # yields a full m x m Q.
        padding_config = [(0, 0, 0)] * len(dims)
        padding_config[-1] = (0, m - n, 0)
        q = xops.Pad(
            r, xops.Constant(c, onp.array(0, dtype=shape.element_type())),
            xla_client.make_padding_config(padding_config))
        q, info_orgqr = orgqr_impl(cs, q, tau)

    # Both LAPACK calls must succeed; otherwise q and r are replaced by NaNs
    # below.
    ok = xops.And(
        xops.Eq(info_geqrf, xops.ConstantLiteral(c, onp.array(0, onp.int32))),
        xops.Eq(info_orgqr, xops.ConstantLiteral(c, onp.array(0, onp.int32))))
    q = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), q,
                             _nan_like(c, q))
    r = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), r,
                             _nan_like(c, r))
    return xops.Tuple(c, [q, r])
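A hedged sketch of the output shapes the three branches above correspond to, written against the public jnp.linalg.qr wrapper (not part of this rule):

import jax.numpy as jnp

tall = jnp.ones((5, 3))
q, r = jnp.linalg.qr(tall, mode="reduced")   # economy branch: q (5, 3), r (3, 3)
q, r = jnp.linalg.qr(tall, mode="complete")  # full branch:    q (5, 5), r (5, 3)
wide = jnp.ones((3, 5))
q, r = jnp.linalg.qr(wide)                   # m < n branch:   q (3, 3), r (3, 5)
print(q.shape, r.shape)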
Example #3
def _cholesky_cpu_gpu_translation_rule(potrf_impl, c, operand):
  shape = c.GetShape(operand)
  batch_dims = shape.dimensions()[:-2]
  result, info = potrf_impl(xb.computation_builder_shim(c), operand, lower=True)
  # A nonzero info from potrf means the input was not positive definite;
  # replace the factor with NaNs in that case.
  ok = xops.Eq(info, xops.ConstantLiteral(c, onp.array(0, onp.int32)))
  return _broadcasting_select(c,
                              xops.Reshape(ok, batch_dims + (1, 1)), result,
                              _nan_like(c, result))
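The NaN-fill on failure is visible from user code; a small sketch, assuming the public jnp.linalg.cholesky wrapper reaches this lowering:

import jax.numpy as jnp

spd = jnp.array([[4., 2.], [2., 3.]])  # positive definite
print(jnp.linalg.cholesky(spd))        # finite lower factor
print(jnp.linalg.cholesky(-spd))       # potrf fails -> all-NaN result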
Example #4
File: lax_linalg.py  Project: yotarok/jax
def _lu_cpu_gpu_translation_rule(getrf_impl, c, operand):
    shape = c.GetShape(operand)
    batch_dims = shape.dimensions()[:-2]
    lu, pivot, info = getrf_impl(xb.computation_builder_shim(c), operand)
    # Subtract 1 from the pivot to get 0-based indices.
    pivot = xops.Sub(pivot, xops.ConstantLiteral(c, onp.array(1, onp.int32)))
    # getrf sets info > 0 for an exactly singular U but still returns valid
    # factors, so only info < 0 (an invalid argument) counts as failure.
    ok = xops.Ge(info, xops.ConstantLiteral(c, onp.array(0, onp.int32)))
    lu = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), lu,
                              _nan_like(c, lu))
    return xops.Tuple(c, [lu, pivot])
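A quick sketch of the 0-based pivots produced by the Sub above, via the public jax.scipy lu_factor wrapper (not part of this rule; the printed value is illustrative):

import jax.numpy as jnp
from jax.scipy.linalg import lu_factor

lu, piv = lu_factor(jnp.array([[0., 1.], [1., 0.]]))
print(piv)  # 0-based row-swap indices, e.g. [1 1]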
Example #5
def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower):
  shape = c.GetShape(operand)
  batch_dims = shape.dimensions()[:-2]
  v, w, info = syevd_impl(xb.computation_builder_shim(c), operand, lower=lower)
  ok = xops.Eq(info, xops.ConstantLiteral(c, onp.array(0, onp.int32)))
  # v holds an eigenvector matrix per batch element and w an eigenvalue
  # vector, so the success mask broadcasts as (1, 1) for v but (1,) for w.
  v = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), v,
                           _nan_like(c, v))
  w = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), w,
                           _nan_like(c, w))
  return xops.Tuple(c, [v, w])
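A minimal sketch of why the two reshapes differ, using the public jnp.linalg.eigh wrapper (an assumption, not part of the snippet):

import jax.numpy as jnp

a = jnp.array([[2., 1.], [1., 2.]])
w, v = jnp.linalg.eigh(a)
print(w)        # per-matrix eigenvalues are a vector: [1. 3.] -> (1,) mask
print(v.shape)  # (2, 2) eigenvector matrix -> (1, 1) mask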
Example #6
File: lax_linalg.py  Project: yotarok/jax
def eig_cpu_translation_rule(c, operand):
    shape = c.GetShape(operand)
    batch_dims = shape.dimensions()[:-2]
    w, vl, vr, info = _cpu_geev(xb.computation_builder_shim(c), operand)
    ok = xops.Eq(info, xops.ConstantLiteral(c, onp.array(0, onp.int32)))
    w = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), w,
                             _nan_like(c, w))
    vl = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vl,
                              _nan_like(c, vl))
    vr = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vr,
                              _nan_like(c, vr))
    return xops.Tuple(c, [w, vl, vr])
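As the name says, this rule is CPU-specific. A hedged usage sketch via the public jnp.linalg.eig wrapper (assumed here; it returns w and the right eigenvectors, while the rule itself also emits vl):

import jax.numpy as jnp

a = jnp.array([[0., 1.], [-1., 0.]])  # rotation; eigenvalues are +/- 1j
w, v = jnp.linalg.eig(a)
print(w)  # complex-conjugate pair, roughly [0.+1.j 0.-1.j]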
Example #7
def _threefry2x32_gpu_translation_rule(c, k1, k2, x1, x2):
  shape = lax.broadcast_shapes(
      c.GetShape(k1).dimensions(), c.GetShape(k2).dimensions(),
      c.GetShape(x1).dimensions(), c.GetShape(x2).dimensions())
  rank = len(shape)
  def _broadcast(x):
    # Right-align the operand's dimensions against the target shape, matching
    # NumPy broadcasting semantics.
    ndims = c.GetShape(x).rank()
    return xla_client.ops.BroadcastInDim(x, shape,
                                         tuple(range(rank - ndims, rank)))
  return cuda_prng.threefry2x32(
      xla_bridge.computation_builder_shim(c),
      (_broadcast(k1), _broadcast(k2)), (_broadcast(x1), _broadcast(x2)))
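The _broadcast helper right-aligns dimensions, NumPy-style. A small sketch of the same alignment using the public lax.broadcast_in_dim (an assumption; the rule itself calls the raw XLA op):

import jax.numpy as jnp
from jax import lax

# A rank-1 input maps onto the trailing axis of the rank-2 target shape,
# mirroring tuple(range(rank - ndims, rank)) above.
x = jnp.arange(3)
y = lax.broadcast_in_dim(x, shape=(4, 3), broadcast_dimensions=(1,))
print(y.shape)  # (4, 3)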
Example #8
def _svd_cpu_gpu_translation_rule(gesvd_impl, c, operand, full_matrices,
                                  compute_uv):
  shape = c.GetShape(operand)
  batch_dims = shape.dimensions()[:-2]
  s, u, vt, info = gesvd_impl(xb.computation_builder_shim(c), operand,
                              full_matrices=full_matrices,
                              compute_uv=compute_uv)
  ok = xops.Eq(info, xops.ConstantLiteral(c, onp.array(0, onp.int32)))
  s = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), s,
                           _nan_like(c, s))
  u = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), u,
                           _nan_like(c, u))
  vt = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vt,
                            _nan_like(c, vt))
  return xops.Tuple(c, [s, u, vt])
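A hedged sketch of how full_matrices and compute_uv surface in the public jnp.linalg.svd wrapper (not part of this rule):

import jax.numpy as jnp

a = jnp.ones((5, 3))
u, s, vt = jnp.linalg.svd(a, full_matrices=False)  # u (5, 3), s (3,), vt (3, 3)
print(u.shape, s.shape, vt.shape)
s_only = jnp.linalg.svd(a, compute_uv=False)       # singular values alone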
Example #9
def _triangular_solve_gpu_translation_rule(
    c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  shape = c.GetShape(a)
  dtype = shape.element_type().type
  dims = shape.dimensions()
  m, n = dims[-2:]
  batch = prod(dims[:-2])
  if conjugate_a and not transpose_a:
    a = xops.Conj(a)
    conjugate_a = False
  if batch > 1 and m <= 32 and n <= 32:
    # Batched path: hand many small matrices to cusolver's batched trsm.
    return cusolver.trsm(
      xb.computation_builder_shim(c), a, b, left_side, lower, transpose_a,
      conjugate_a, unit_diagonal)
  else:
    # Use the XLA implementation for unbatched or larger matrices.
    if not transpose_a:
      transpose = xops.TriangularSolveOptions_Transpose.NO_TRANSPOSE
    else:
      transpose = (xops.TriangularSolveOptions_Transpose.ADJOINT if conjugate_a
                   else xops.TriangularSolveOptions_Transpose.TRANSPOSE)
    return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal,
                                transpose)
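A minimal sketch of the shape regime the cusolver branch above targets: many small systems, here expressed with vmap over the public solve_triangular wrapper (an assumption, not part of the snippet):

import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

a = jnp.tile(3 * jnp.eye(8), (16, 1, 1))  # batch of 16 small triangular systems
b = jnp.ones((16, 8, 2))
x = jax.vmap(lambda a, b: solve_triangular(a, b, lower=True))(a, b)
print(x.shape)  # (16, 8, 2)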