Example #1
0
def eig_cpu_translation_rule(c, operand, *, compute_left_eigenvectors,
                             compute_right_eigenvectors):
    """CPU lowering for general (nonsymmetric) eigendecomposition via geev.

    Calls the LAPACK-backed geev kernel and replaces results of failed batch
    entries (nonzero ``info``) with NaNs. Returns an XLA tuple whose first
    element is the eigenvalues, optionally followed by the left and/or right
    eigenvector matrices depending on the two flags.
    """
    operand_shape = c.get_shape(operand)
    batch_dims = operand_shape.dimensions()[:-2]

    w, vl, vr, info = _cpu_geev(c,
                                operand,
                                jobvl=compute_left_eigenvectors,
                                jobvr=compute_right_eigenvectors)

    # geev signals success per batch entry through info == 0.
    succeeded = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))

    def _nan_on_failure(value, trailing_dims):
        # Broadcast the per-batch success mask over the trailing dimensions
        # of `value`, substituting NaNs where the kernel reported failure.
        mask = xops.Reshape(succeeded, batch_dims + trailing_dims)
        return _broadcasting_select(c, mask, value, _nan_like(c, value))

    outputs = [_nan_on_failure(w, (1,))]
    if compute_left_eigenvectors:
        outputs.append(_nan_on_failure(vl, (1, 1)))
    if compute_right_eigenvectors:
        outputs.append(_nan_on_failure(vr, (1, 1)))
    return xops.Tuple(c, outputs)
Example #2
0
def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, c, operand,
                                 full_matrices):
    """Shared CPU/GPU lowering for QR via geqrf (factor) + orgqr (build Q).

    Args:
      geqrf_impl: kernel returning ``(r, tau, info)`` — the packed QR factor,
        the Householder scalars, and a per-batch status code.
      orgqr_impl: kernel materializing Q from the Householder reflectors.
      c: XLA builder.
      operand: batched matrix operand of shape ``[..., m, n]``.
      full_matrices: if true (and m >= n), produce the full m x m Q.

    Returns:
      An XLA tuple ``(q, r)``; batch entries where either kernel reported a
      nonzero ``info`` are replaced with NaNs.
    """
    shape = c.get_shape(operand)
    dims = shape.dimensions()
    m, n = dims[-2:]
    batch_dims = dims[:-2]
    r, tau, info_geqrf = geqrf_impl(c, operand)
    if m < n:
        # Wide case: Q is m x m; build it from the first m columns of the
        # packed factor.
        q = xops.Slice(r, [0] * len(dims),
                       list(batch_dims) + [m, m], [1] * len(dims))
        q, info_orgqr = orgqr_impl(c, q, tau)
    elif not full_matrices:
        # Reduced QR: Q is m x n; trim R to its n x n upper block.
        q, info_orgqr = orgqr_impl(c, r, tau)
        r = xops.Slice(r, [0] * len(dims),
                       list(batch_dims) + [n, n], [1] * len(dims))
    else:
        # Full QR: zero-pad the factor's last dimension from n to m so orgqr
        # can expand an m x m Q.
        padding_config = [(0, 0, 0)] * len(dims)
        padding_config[-1] = (0, m - n, 0)
        # Use ConstantLiteral for consistency with every other constant in
        # this file (the original mixed in xops.Constant here).
        q = xops.Pad(r,
                     xops.ConstantLiteral(
                         c, np.array(0, dtype=shape.element_type())),
                     xla_client.make_padding_config(padding_config))
        q, info_orgqr = orgqr_impl(c, q, tau)

    # Both kernels must report info == 0 for a batch entry to be valid.
    ok = xops.And(
        xops.Eq(info_geqrf, xops.ConstantLiteral(c, np.array(0, np.int32))),
        xops.Eq(info_orgqr, xops.ConstantLiteral(c, np.array(0, np.int32))))
    q = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), q,
                             _nan_like(c, q))
    r = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), r,
                             _nan_like(c, r))
    # Zero out below-diagonal entries so R is exactly upper triangular.
    r = xla.lower_fun(jnp.triu, multiple_results=False)(c, r)
    return xops.Tuple(c, [q, r])
Example #3
0
def _svd_cpu_gpu_translation_rule(gesvd_impl, c, operand, full_matrices, compute_uv):
  """Shared CPU/GPU lowering for SVD via a gesvd-style kernel.

  Returns an XLA tuple of the singular values, followed by U and V^T when
  `compute_uv` is set. Batch entries whose kernel status is nonzero are
  replaced with NaNs.
  """
  dims = c.get_shape(operand).dimensions()
  m, n = dims[-2:]
  batch_dims = dims[:-2]

  # Degenerate (zero-sized) matrices bypass the kernel and use the
  # pure-Python empty-SVD lowering instead.
  if m == 0 or n == 0:
    return xla.lower_fun(_empty_svd, multiple_results=True)(
      c, operand, full_matrices=full_matrices, compute_uv=compute_uv)

  s, u, vt, info = gesvd_impl(c, operand,
                              full_matrices=full_matrices,
                              compute_uv=compute_uv)
  # info == 0 is treated as success for the corresponding batch entry.
  succeeded = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))

  def _nan_on_failure(value, trailing_dims):
    mask = xops.Reshape(succeeded, batch_dims + trailing_dims)
    return _broadcasting_select(c, mask, value, _nan_like(c, value))

  outputs = [_nan_on_failure(s, (1,))]
  if compute_uv:
    outputs.append(_nan_on_failure(u, (1, 1)))
    outputs.append(_nan_on_failure(vt, (1, 1)))
  return xops.Tuple(c, outputs)
Example #4
0
def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower):
    """Shared CPU/GPU lowering for symmetric/Hermitian eigendecomposition.

    Invokes a syevd-style kernel and returns an XLA tuple of (eigenvectors,
    eigenvalues), with NaNs substituted for batch entries whose status code
    is nonzero.
    """
    operand_shape = c.get_shape(operand)
    batch_dims = operand_shape.dimensions()[:-2]
    v, w, info = syevd_impl(c, operand, lower=lower)
    # The kernel reports per-batch success through info == 0.
    succeeded = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
    eigenvectors = _broadcasting_select(
        c, xops.Reshape(succeeded, batch_dims + (1, 1)), v, _nan_like(c, v))
    eigenvalues = _broadcasting_select(
        c, xops.Reshape(succeeded, batch_dims + (1,)), w, _nan_like(c, w))
    return xops.Tuple(c, [eigenvectors, eigenvalues])
Example #5
0
def _cholesky_cpu_gpu_translation_rule(potrf_impl, c, operand):
    """Shared CPU/GPU lowering for Cholesky via a potrf-style kernel.

    Returns the lower-triangular factor, with NaNs substituted for batch
    entries whose status code is nonzero.
    """
    batch_dims = c.get_shape(operand).dimensions()[:-2]
    factor, info = potrf_impl(c, operand, lower=True)
    # info == 0 marks a successfully factored batch entry.
    succeeded = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
    mask = xops.Reshape(succeeded, batch_dims + (1, 1))
    return _broadcasting_select(c, mask, factor, _nan_like(c, factor))
Example #6
0
def _svd_cpu_gpu_translation_rule(gesvd_impl, c, operand, full_matrices, compute_uv):
  """Shared CPU/GPU lowering for SVD via a gesvd-style kernel.

  Returns an XLA tuple of the singular values, followed by U and V^T when
  `compute_uv` is set; failed batch entries (nonzero status) become NaNs.
  """
  operand_shape = c.get_shape(operand)
  batch_dims = operand_shape.dimensions()[:-2]
  s, u, vt, info = gesvd_impl(c, operand,
                              full_matrices=full_matrices,
                              compute_uv=compute_uv)
  # info == 0 is treated as success for the corresponding batch entry.
  succeeded = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
  s = _broadcasting_select(c, xops.Reshape(succeeded, batch_dims + (1,)), s,
                           _nan_like(c, s))
  if not compute_uv:
    return xops.Tuple(c, [s])
  u = _broadcasting_select(c, xops.Reshape(succeeded, batch_dims + (1, 1)), u,
                           _nan_like(c, u))
  vt = _broadcasting_select(c, xops.Reshape(succeeded, batch_dims + (1, 1)),
                            vt, _nan_like(c, vt))
  return xops.Tuple(c, [s, u, vt])
Example #7
0
def _lu_cpu_gpu_translation_rule(getrf_impl, c, operand):
    """Shared CPU/GPU lowering for LU factorization via a getrf-style kernel.

    Returns an XLA tuple (lu, pivot, perm): the packed LU factor, the
    0-based pivot indices, and the equivalent permutation.
    """
    operand_shape = c.get_shape(operand)
    dims = operand_shape.dimensions()
    batch_dims = dims[:-2]
    m = dims[-2]
    lu, pivot, info = getrf_impl(c, operand)
    # LAPACK-style pivots are 1-based; shift them to 0-based indices.
    one = xops.ConstantLiteral(c, np.array(1, np.int32))
    pivot = xops.Sub(pivot, one)
    # Only negative info is treated as failure here (info >= 0 passes the
    # mask), so those batch entries alone are NaN-poisoned.
    ok = xops.Ge(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
    lu = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), lu,
                              _nan_like(c, lu))
    # Expand the pivot sequence into a full permutation of length m.
    perm = xla.lower_fun(lambda x: lu_pivots_to_permutation(x, m),
                         multiple_results=False)(c, pivot)
    return xops.Tuple(c, [lu, pivot, perm])