Exemplo n.º 1
0
def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, c, operand,
                                 full_matrices):
    """Lower a (batched) QR decomposition using geqrf/orgqr-style kernels.

    ``geqrf_impl(builder, x)`` must return ``(r, tau, info)`` and
    ``orgqr_impl(builder, a, tau)`` must return ``(q, info)``.  Both are
    invoked with ``xb.computation_builder_shim(c)`` rather than with ``c``
    itself.  The names follow LAPACK's ``geqrf``/``orgqr`` routines.

    Args:
      geqrf_impl: kernel that factors ``operand``.
      orgqr_impl: kernel that builds ``q`` from the factored matrix and
        ``tau``.
      c: XLA computation builder that receives the emitted ops.
      operand: XLA op whose last two dimensions form the ``m x n`` matrix;
        any leading dimensions are treated as batch dimensions.
      full_matrices: only consulted when ``m >= n``.  If true, ``orgqr``
        runs on ``r`` zero-padded to ``m`` columns and ``r`` keeps its
        ``m x n`` shape.  If false, ``orgqr`` runs on ``r`` as-is and ``r``
        is sliced to its leading ``n x n`` block.

    Returns:
      An XLA tuple op ``(q, r)``.  Batch entries for which either kernel's
      ``info`` is nonzero are replaced by ``_nan_like`` values (presumably
      NaN-filled).

    NOTE(review): ``r`` is returned as ``geqrf_impl`` produced it (possibly
    sliced).  No upper-triangular masking happens here, so the caller
    presumably applies ``triu``; confirm.
    """
    operand_shape = c.GetShape(operand)
    all_dims = operand_shape.dimensions()
    rank = len(all_dims)
    m, n = all_dims[-2:]
    leading = all_dims[:-2]
    origin = [0] * rank
    unit_strides = [1] * rank
    shim = xb.computation_builder_shim(c)

    r, tau, geqrf_status = geqrf_impl(shim, operand)

    if m < n:
        # Wide input: orgqr only needs the leading m x m block.
        square = xops.Slice(r, origin, list(leading) + [m, m], unit_strides)
        q, orgqr_status = orgqr_impl(shim, square, tau)
    elif full_matrices:
        # Tall input, complete Q requested: widen to m columns with zeros.
        pads = [(0, 0, 0)] * rank
        pads[-1] = (0, m - n, 0)
        zero = xops.Constant(
            c, onp.array(0, dtype=operand_shape.element_type()))
        widened = xops.Pad(r, zero, xla_client.make_padding_config(pads))
        q, orgqr_status = orgqr_impl(shim, widened, tau)
    else:
        # Tall input, reduced factorization: trim r to n x n afterwards.
        q, orgqr_status = orgqr_impl(shim, r, tau)
        r = xops.Slice(r, origin, list(leading) + [n, n], unit_strides)

    # Success only when both kernels report a zero status.
    geqrf_ok = xops.Eq(geqrf_status,
                       xops.ConstantLiteral(c, onp.array(0, onp.int32)))
    orgqr_ok = xops.Eq(orgqr_status,
                       xops.ConstantLiteral(c, onp.array(0, onp.int32)))
    ok = xops.And(geqrf_ok, orgqr_ok)

    # Reshape the per-batch flag so it broadcasts over the matrix dims.
    mask_shape = leading + (1, 1)
    guarded = []
    for result in (q, r):
        mask = xops.Reshape(ok, mask_shape)
        guarded.append(
            _broadcasting_select(c, mask, result, _nan_like(c, result)))
    return xops.Tuple(c, guarded)
Exemplo n.º 2
0
def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, c, operand,
                                 full_matrices):
  """Lowers a (batched) QR decomposition using geqrf/orgqr-style kernels.

  ``geqrf_impl(c, x)`` must return ``(r, tau, info)`` and
  ``orgqr_impl(c, a, tau)`` must return ``(q, info)``.  Either ``info`` may
  be ``None`` when the backend does not report a status (rocsolver, for
  example, does not return info).  The names follow LAPACK's
  ``geqrf``/``orgqr`` routines.

  Args:
    geqrf_impl: kernel that factors ``operand``.
    orgqr_impl: kernel that builds ``q`` from the factored matrix and ``tau``.
    c: XLA builder that receives the emitted ops.
    operand: XLA op whose last two dimensions form the ``m x n`` matrix; any
      leading dimensions are treated as batch dimensions.
    full_matrices: only consulted when ``m >= n``.  If true, ``orgqr`` runs
      on ``r`` zero-padded to ``m`` columns and ``r`` keeps its ``m x n``
      shape.  If false, ``orgqr`` runs on ``r`` as-is and ``r`` is sliced to
      its leading ``n x n`` block.

  Returns:
    An XLA tuple op ``(q, r)``, where ``r`` has been passed through
    ``jnp.triu``.  For every status that is reported, batch entries with a
    nonzero status are replaced by ``_nan_like`` values (presumably
    NaN-filled) before the ``triu`` is applied.
  """
  shape = c.get_shape(operand)
  dims = shape.dimensions()
  m, n = dims[-2:]
  batch_dims = dims[:-2]
  r, tau, info_geqrf = geqrf_impl(c, operand)
  if m < n:
    # Wide input: orgqr only needs the leading m x m block of r.
    q = xops.Slice(r, [0] * len(dims), list(batch_dims) + [m, m],
                   [1] * len(dims))
    q, info_orgqr = orgqr_impl(c, q, tau)
  elif not full_matrices:
    # Reduced factorization: form q from r as-is, then trim r to n x n.
    q, info_orgqr = orgqr_impl(c, r, tau)
    r = xops.Slice(r, [0] * len(dims), list(batch_dims) + [n, n],
                   [1] * len(dims))
  else:
    # Full factorization: zero-pad r to m columns before forming q.
    padding_config = [(0, 0, 0)] * len(dims)
    padding_config[-1] = (0, m - n, 0)
    q = xops.Pad(r, xops.Constant(c, np.array(0, dtype=shape.element_type())),
                 xla_client.make_padding_config(padding_config))
    q, info_orgqr = orgqr_impl(c, q, tau)

  # Check each reported status independently.  If a backend returns info for
  # only one kernel, this avoids building Eq(None, ...) and still honors the
  # status that was reported.
  infos = [info for info in (info_geqrf, info_orgqr) if info is not None]
  if infos:
    zero = xops.ConstantLiteral(c, np.array(0, np.int32))
    ok = xops.Eq(infos[0], zero)
    for info in infos[1:]:
      ok = xops.And(ok, xops.Eq(info, zero))
    # Per-batch flag reshaped so it broadcasts over the matrix dimensions.
    ok = xops.Reshape(ok, batch_dims + (1, 1))
    q = _broadcasting_select(c, ok, q, _nan_like(c, q))
    r = _broadcasting_select(c, ok, r, _nan_like(c, r))

  # Keep only the upper triangle of r.
  r = xla.lower_fun(jnp.triu, multiple_results=False)(c, r)
  return xops.Tuple(c, [q, r])