Exemple #1
0
def _use_qr(u, params):
    """Performs one QDWH iteration step using a QR decomposition.

    Stacks ``sqrt(c) * u`` on top of an ``n x n`` identity, takes the reduced
    QR factorization of that stack with ``Q = [q1; q2']``, and returns

        e * u + (a - e) / sqrt(c) * q1 @ q2'^H,    where e = b / c.

    Args:
      u: the current iterate, an ``m x n`` matrix.
      params: the QDWH parameters ``(a, b, c)`` for this step.

    Returns:
      The next iterate, with the same shape as ``u``.
    """
    a, b, c = params
    m, n = u.shape
    # Build the identity in u's dtype. jnp.eye otherwise uses the default float
    # dtype, which can silently promote the stacked matrix (and therefore the
    # result), e.g. float32 -> float64 when x64 mode is enabled.
    y = jnp.concatenate([jnp.sqrt(c) * u, jnp.eye(n, dtype=u.dtype)])
    q, _ = lax_linalg.qr(y, full_matrices=False)
    q1 = q[:m, :]
    q2 = q[m:, :].T.conj()
    e = b / c
    u = e * u + (a - e) / jnp.sqrt(c) * jnp.einsum('ij,jk->ik', q1, q2)
    return u
Exemple #2
0
def qr(a, mode="reduced"):
    """QR decomposition with a ``numpy.linalg.qr``-style ``mode`` argument.

    Args:
      a: array-like matrix to factor; converted with ``jnp.asarray`` and then
        passed through ``_promote_arg_dtypes``.
      mode: ``"reduced"``, ``"r"`` or ``"full"`` request reduced factors;
        ``"complete"`` requests complete factors.

    Returns:
      ``r`` alone when ``mode == "r"``, otherwise the pair ``(q, r)``.

    Raises:
      ValueError: if ``mode`` is not one of the supported values.
    """
    if mode in ("reduced", "r", "full"):
        full_matrices = False
    elif mode == "complete":
        full_matrices = True
    else:
        raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
    a = _promote_arg_dtypes(jnp.asarray(a))
    # Pass full_matrices by keyword: it is keyword-only in current
    # jax.lax.linalg.qr, so a positional argument raises TypeError there.
    q, r = lax_linalg.qr(a, full_matrices=full_matrices)
    if mode == "r":
        return r
    return q, r
Exemple #3
0
def _qr(a, mode, pivoting):
    """Computes a QR factorization using ``scipy.linalg.qr``-style arguments.

    Args:
      a: array-like matrix to factor; converted with ``jnp.asarray`` and then
        passed through ``_promote_dtypes_inexact``.
      mode: ``"full"`` or ``"r"`` for complete factors, ``"economic"`` for
        reduced factors.
      pivoting: column-pivoting flag; only a falsy value is supported.

    Returns:
      The 1-tuple ``(r,)`` when ``mode == "r"``, otherwise ``(q, r)``.

    Raises:
      NotImplementedError: if ``pivoting`` is truthy (checked before ``mode``).
      ValueError: if ``mode`` is not one of the supported values.
    """
    if pivoting:
        raise NotImplementedError(
            "The pivoting=True case of qr is not implemented.")
    if mode not in ("full", "r", "economic"):
        raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
    # Every accepted mode except "economic" wants the complete factorization.
    full_matrices = mode in ("full", "r")
    (promoted,) = _promote_dtypes_inexact(jnp.asarray(a))
    q, r = lax_linalg.qr(promoted, full_matrices=full_matrices)
    return (r,) if mode == "r" else (q, r)
Exemple #4
0
def _qr(a, mode, pivoting):
    """QR decomposition with a ``scipy.linalg.qr``-style ``mode`` argument.

    Args:
      a: array-like matrix to factor; converted with ``jnp.asarray`` and then
        passed through ``np_linalg._promote_arg_dtypes``.
      mode: ``"full"`` or ``"r"`` request complete factors; ``"economic"``
        requests reduced factors.
      pivoting: column-pivoting flag; only a falsy value is supported.

    Returns:
      ``r`` alone when ``mode == "r"``, otherwise the pair ``(q, r)``.
      NOTE(review): ``scipy.linalg.qr`` returns the 1-tuple ``(r,)`` for
      ``mode="r"``; confirm callers of this variant expect a bare array.

    Raises:
      NotImplementedError: if ``pivoting`` is truthy (checked before ``mode``).
      ValueError: if ``mode`` is not one of the supported values.
    """
    if pivoting:
        raise NotImplementedError(
            "The pivoting=True case of qr is not implemented.")
    if mode in ("full", "r"):
        full_matrices = True
    elif mode == "economic":
        full_matrices = False
    else:
        raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
    a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    # Pass full_matrices by keyword: it is keyword-only in current
    # jax.lax.linalg.qr, so a positional argument raises TypeError there.
    q, r = lax_linalg.qr(a, full_matrices=full_matrices)
    if mode == "r":
        return r
    return q, r
Exemple #5
0
def qr(a, mode="reduced"):
    """QR decomposition following the ``numpy.linalg.qr`` ``mode`` options.

    Args:
      a: array-like input; converted with ``jnp.asarray`` and passed through
        ``_promote_dtypes_inexact`` before any mode handling.
      mode: ``"raw"`` returns the packed output of ``lax_linalg.geqrf``;
        ``"reduced"``, ``"r"`` and ``"full"`` use reduced factors;
        ``"complete"`` uses complete factors.

    Returns:
      For ``"raw"``: ``(_T(h), taus)`` built from ``lax_linalg.geqrf``.
      For ``"r"``: ``r`` only. Otherwise: the pair ``(q, r)``.

    Raises:
      ValueError: for any other ``mode``.
    """
    (arr,) = _promote_dtypes_inexact(jnp.asarray(a))
    if mode == "raw":
        packed, taus = lax_linalg.geqrf(arr)
        return _T(packed), taus
    if mode == "complete":
        full_matrices = True
    elif mode in ("reduced", "r", "full"):
        full_matrices = False
    else:
        raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
    q, r = lax_linalg.qr(arr, full_matrices=full_matrices)
    return r if mode == "r" else (q, r)
Exemple #6
0
def _use_qr(u, m, n, params):
    """Performs one QDWH iteration step using a QR decomposition.

    Args:
      u: a matrix, with static (padded) shape M x N.
      m, n: the dynamic shape of the matrix, where m <= M and n <= N.
      params: the QDWH parameters ``(a, b, c)``.

    Returns:
      ``e * u + (a - e) / sqrt(c) * q1 @ q2`` with ``e = b / c``, where
      ``q1`` / ``q2`` are the masked top / (conjugate-transposed) bottom
      pieces of the reduced QR factor described below.
    """
    a, b, c = params
    rows, cols = u.shape
    sqrt_c = jnp.sqrt(c)

    # Stack sqrt(c) * u with an identity; presumably _dynamic_concat places
    # the identity right after the first m (dynamic) rows, matching the
    # q[:m] / q[m:] split taken below.
    stacked = _dynamic_concat(
        sqrt_c * u, jnp.eye(cols, dtype=jnp.dtype(u)), m)
    q, _ = lax_linalg.qr(stacked, full_matrices=False)

    # Dynamic-shape stand-in for q1 = q[:m, :].
    top = _mask(lax.slice(q, (0, 0), (rows, cols)), (m, n))
    # Dynamic-shape stand-in for q2 = (q[m:, :]).T.conj().
    bottom = _mask(lax.dynamic_slice_in_dim(q, m, cols, axis=0), (n, n))
    bottom_h = bottom.T.conj()

    e = b / c
    correction = jnp.einsum('ij,jk->ik', top, bottom_h)
    return e * u + (a - e) / sqrt_c * correction