Beispiel #1
0
def _lu_solve_core(lu, permutation, b, trans):
    m = lu.shape[0]
    x = jnp.reshape(b, (m, -1))
    if trans == 0:
        x = x[permutation, :]
        x = triangular_solve(lu,
                             x,
                             left_side=True,
                             lower=True,
                             unit_diagonal=True)
        x = triangular_solve(lu, x, left_side=True, lower=False)
    elif trans == 1 or trans == 2:
        conj = trans == 2
        x = triangular_solve(lu,
                             x,
                             left_side=True,
                             lower=False,
                             transpose_a=True,
                             conjugate_a=conj)
        x = triangular_solve(lu,
                             x,
                             left_side=True,
                             lower=True,
                             unit_diagonal=True,
                             transpose_a=True,
                             conjugate_a=conj)
        x = x[jnp.argsort(permutation), :]
    else:
        raise ValueError(
            "'trans' value must be 0, 1, or 2, got {}".format(trans))
    return lax.reshape(x, b.shape)
Beispiel #2
0
def eigh(H, *, precision="float32", termination_size=256, n=None):
  """Computes the eigendecomposition of the symmetric/Hermitian matrix H.

  Args:
    H: The `N x N` Hermitian input (possibly padded; see `n`).
    precision: Matmul precision setting passed to
      `jax.default_matmul_precision` on the recursive path.
    termination_size: Recursion ends once the blocks reach this linear size.
      Inputs with `N <= termination_size` are passed directly to
      `jnp_linalg.eigh`.
    n: The true (dynamic) size of the matrix when `H` is padded to `N x N`;
      defaults to `N`. NOTE(review): `n` is not applied on the
      `N <= termination_size` path; confirm callers never pad small inputs.
  Returns:
    vals: The eigenvalues of `H`, sorted from lowest to highest. On the
      recursive path, slots past the first `n` are masked to NaN via `_mask`
      (presumably), and `jnp.argsort` places NaNs last.
    vecs: A unitary matrix such that `vecs[:, i]` is a normalized eigenvector
      of `H` corresponding to `vals[i]`. We have `H @ vecs = vals * vecs` up
      to numerical error.
  Raises:
    TypeError: If `H` is not square.
  """
  M, N = H.shape
  if M != N:
    raise TypeError(f"Input H of shape {H.shape} must be square.")

  if N <= termination_size:
    return jnp_linalg.eigh(H)

  # TODO(phawkins): consider rounding N up to a larger size to maximize reuse
  # between matrices.

  n = N if n is None else n
  with jax.default_matmul_precision(precision):
    eig_vals, eig_vecs = _eigh_work(H, n, termination_size=termination_size)
  # Eigenvalues of a Hermitian matrix are real. Taking the real part makes
  # this path return a real dtype, like `jnp_linalg.eigh` on the small path;
  # it is a no-op for real inputs.
  eig_vals = _mask(jnp.real(eig_vals), (n,), jnp.nan)
  sort_idxs = jnp.argsort(eig_vals)
  eig_vals = eig_vals[sort_idxs]
  eig_vecs = eig_vecs[:, sort_idxs]
  return eig_vals, eig_vecs
Beispiel #3
0
def eigh(H,
         *,
         precision="float32",
         termination_size=256,
         n=None,
         sort_eigenvalues=True):
    """Eigendecomposition of a (possibly padded) symmetric/Hermitian matrix.

    Inputs with `N <= termination_size` go to `lax_linalg.eigh_jacobi`;
    larger ones go through the recursive `_eigh_work` under the requested
    matmul precision.

    Args:
      H: `N x N` Hermitian input; only the leading `n x n` block is meaningful.
      precision: Matmul precision setting, applied through
        `jax.default_matmul_precision` on the recursive path.
      termination_size: Linear block size at which the recursion stops.
      n: True (dynamic) size of the matrix; `None` means `N`.
      sort_eigenvalues: If True, eigenvalues (and the matching eigenvector
        columns) are returned in ascending order.

    Returns:
      A pair `(vals, vecs)` where `vecs[:, i]` is a normalized eigenvector of
      `H` for `vals[i]`, so `H @ vecs = vals * vecs` up to numerical error.
      On the recursive path, eigenvalue slots past the first `n` are masked
      to NaN.

    Raises:
      TypeError: If `H` is not square.
    """
    num_rows, size = H.shape
    if num_rows != size:
        raise TypeError(f"Input H of shape {H.shape} must be square.")

    if size <= termination_size:
        if n is not None:
            # Fill everything outside the true n x n block from the identity
            # (presumably what `_mask` does with this fill value).
            H = _mask(H, (n, n), jnp.eye(size, dtype=H.dtype))
        return lax_linalg.eigh_jacobi(H, sort_eigenvalues=sort_eigenvalues)

    # TODO(phawkins): consider rounding N up to a larger size to maximize reuse
    # between matrices.
    true_size = size if n is None else n
    with jax.default_matmul_precision(precision):
        vals, vecs = _eigh_work(H, true_size, termination_size=termination_size)

    # Hermitian eigenvalues are real; drop any imaginary part and blank out
    # the padding slots with NaN.
    vals = _mask(jnp.real(vals), (true_size,), jnp.nan)
    if not sort_eigenvalues:
        return vals, vecs

    order = jnp.argsort(vals)
    return vals[order], vecs[:, order]
Beispiel #4
0
def _projector_subspace(P, H, n, rank, maxiter=2):
  """Splits a Hermitian projector into isometries onto its two eigenspaces.

  Decomposes the `n x n` rank-`rank` Hermitian projector `P` into an
  `n x rank` isometry `V_minus` such that `P = V_minus @ V_minus.conj().T`
  and an `n x (n - rank)` isometry `V_plus` such that
  `I - P = V_plus @ V_plus.conj().T`.

  The subspaces are computed with a simple QR-based (orthogonal) iteration,
  which converges very quickly because the eigenvalues of a projector
  (0 or 1) are sharply separated.

  Args:
    P: A rank-`rank` Hermitian projector into the space of `H`'s
       first `rank` eigenpairs. `P` is padded to `N x N`.
    H: The aforementioned Hermitian matrix, which is used to track
       convergence.
    n: The true (dynamic) size of `P`.
    rank: Rank of `P`.
    maxiter: Maximum number of QR iterations, counting the initial one.
  Returns:
    V_minus, V_plus: Isometries into the eigenspaces described above.
  """
  N, _ = P.shape
  # Initial guess: the `rank` largest-norm columns of P. P is Hermitian, so
  # its row norms (axis=1) equal its column norms. The norms are negated so
  # that the ascending `jnp.argsort` puts the largest ones first.
  negative_column_norms = -jnp_linalg.norm(P, axis=1)
  # `jnp.argsort` ensures NaNs sort last, so set masked-out column norms to NaN.
  negative_column_norms = _mask(negative_column_norms, (n,), jnp.nan)
  sort_idxs = jnp.argsort(negative_column_norms)
  X = P[:, sort_idxs]
  # Masked analogue of X[:, :rank] within the true n rows (via `_mask`).
  X = _mask(X, (n, rank))

  H_norm = jnp_linalg.norm(H)
  # Tolerance on the absolute Frobenius norm of V2^H H V1; scaling by ||H||
  # makes the test equivalent to a relative error of 10 * eps.
  thresh = 10 * jnp.finfo(X.dtype).eps * H_norm

  # First iteration skips the matmul.
  def body_f_after_matmul(X):
    # Orthonormalize X and split Q into its leading `rank` columns (V1) and
    # the following `n - rank` columns (V2); masked analogues of
    # Q[:, :rank] and Q[:, rank:].
    Q, _ = jnp_linalg.qr(X, mode="complete")
    V1 = _mask(Q, (n, rank))
    V2 = _slice(Q, (0, rank), (n, n - rank), (N, N))

    # V2^H H V1 vanishes when range(V1) is an invariant subspace of H.
    # TODO: might be able to get away with lower precision here
    error_matrix = jnp.dot(V2.conj().T, H)
    error_matrix = jnp.dot(error_matrix, V1)
    # Absolute norm, to match the ||H||-scaled `thresh` above.
    error = jnp_linalg.norm(error_matrix)
    return V1, V2, error

  def cond_f(args):
    _, _, j, error = args
    still_counting = j < maxiter
    unconverged = error > thresh
    # `j` is a shape-(1,) array, so index out the scalar predicate.
    return jnp.logical_and(still_counting, unconverged)[0]

  def body_f(args):
    V1, _, j, _ = args
    X = jnp.dot(P, V1)
    V1, V2, error = body_f_after_matmul(X)
    return V1, V2, j + 1, error

  V1, V2, error = body_f_after_matmul(X)
  # The call above counts as iteration 1.
  one = jnp.ones(1, dtype=jnp.int32)
  V1, V2, _, error = lax.while_loop(cond_f, body_f, (V1, V2, one, error))
  return V1, V2