def _lu_solve_core(lu, permutation, b, trans):
  m = lu.shape[0]
  x = jnp.reshape(b, (m, -1))
  if trans == 0:
    # Solve A @ x = b: permute the rows of b, then forward- and
    # back-substitute through the packed L and U factors.
    x = x[permutation, :]
    x = triangular_solve(lu, x, left_side=True, lower=True,
                         unit_diagonal=True)
    x = triangular_solve(lu, x, left_side=True, lower=False)
  elif trans == 1 or trans == 2:
    # Solve A.T @ x = b (trans == 1) or A.conj().T @ x = b (trans == 2):
    # apply the transposed factors in reverse order, then invert the row
    # permutation.
    conj = trans == 2
    x = triangular_solve(lu, x, left_side=True, lower=False, transpose_a=True,
                         conjugate_a=conj)
    x = triangular_solve(lu, x, left_side=True, lower=True, unit_diagonal=True,
                         transpose_a=True, conjugate_a=conj)
    x = x[jnp.argsort(permutation), :]
  else:
    raise ValueError(f"'trans' value must be 0, 1, or 2, got {trans}")
  return lax.reshape(x, b.shape)
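
# A minimal usage sketch (hypothetical, not part of the module): driving
# `_lu_solve_core` with the factorization from `jax.lax.linalg.lu`, which
# returns the packed L/U factors and the row permutation this helper expects.
# Assumes the module-level `jnp`, `lax`, and `triangular_solve` imports that
# the function above relies on.
def _example_lu_solve():
  import jax.numpy as jnp
  from jax import lax
  A = jnp.array([[4., 3.], [6., 3.]])
  b = jnp.array([10., 12.])
  lu, _, permutation = lax.linalg.lu(A)
  x = _lu_solve_core(lu, permutation, b, trans=0)   # solves A @ x = b
  xt = _lu_solve_core(lu, permutation, b, trans=1)  # solves A.T @ x = b
  assert jnp.allclose(A @ x, b, atol=1e-5)
  assert jnp.allclose(A.T @ xt, b, atol=1e-5)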
def eigh(H, *, precision="float32", termination_size=256, n=None): """ Computes the eigendecomposition of the symmetric/Hermitian matrix H. Args: H: The `n x n` Hermitian input. precision: :class:`~jax.lax.Precision` object specifying the matmul precision. symmetrize: If True, `0.5 * (H + H.conj().T)` rather than `H` is used. termination_size: Recursion ends once the blocks reach this linear size. Returns: vals: The `n` eigenvalues of `H`, sorted from lowest to highest. vecs: A unitary matrix such that `vecs[:, i]` is a normalized eigenvector of `H` corresponding to `vals[i]`. We have `H @ vecs = vals * vecs` up to numerical error. """ M, N = H.shape if M != N: raise TypeError(f"Input H of shape {H.shape} must be square.") if N <= termination_size: return jnp_linalg.eigh(H) # TODO(phawkins): consider rounding N up to a larger size to maximize reuse # between matrices. n = N if n is None else n with jax.default_matmul_precision(precision): eig_vals, eig_vecs = _eigh_work(H, n, termination_size=termination_size) eig_vals = _mask(eig_vals, (n,), jnp.nan) sort_idxs = jnp.argsort(eig_vals) eig_vals = eig_vals[sort_idxs] eig_vecs = eig_vecs[:, sort_idxs] return eig_vals, eig_vecs
def eigh(H, *, precision="float32", termination_size=256, n=None, sort_eigenvalues=True): """ Computes the eigendecomposition of the symmetric/Hermitian matrix H. Args: H: The `n x n` Hermitian input, padded to `N x N`. precision: :class:`~jax.lax.Precision` object specifying the matmul precision. termination_size: Recursion ends once the blocks reach this linear size. n: the true (dynamic) size of the matrix. sort_eigenvalues: If `True`, the eigenvalues will be sorted from lowest to highest. Returns: vals: The `n` eigenvalues of `H`. vecs: A unitary matrix such that `vecs[:, i]` is a normalized eigenvector of `H` corresponding to `vals[i]`. We have `H @ vecs = vals * vecs` up to numerical error. """ M, N = H.shape if M != N: raise TypeError(f"Input H of shape {H.shape} must be square.") if N <= termination_size: if n is not None: H = _mask(H, (n, n), jnp.eye(N, dtype=H.dtype)) return lax_linalg.eigh_jacobi(H, sort_eigenvalues=sort_eigenvalues) # TODO(phawkins): consider rounding N up to a larger size to maximize reuse # between matrices. n = N if n is None else n with jax.default_matmul_precision(precision): eig_vals, eig_vecs = _eigh_work(H, n, termination_size=termination_size) eig_vals = _mask(jnp.real(eig_vals), (n, ), jnp.nan) if sort_eigenvalues: sort_idxs = jnp.argsort(eig_vals) eig_vals = eig_vals[sort_idxs] eig_vecs = eig_vecs[:, sort_idxs] return eig_vals, eig_vecs
def _projector_subspace(P, H, n, rank, maxiter=2):
  """ Decomposes the `n x n` rank-`rank` Hermitian projector `P` into an
  `n x rank` isometry `V_minus` such that `P = V_minus @ V_minus.conj().T`
  and an `n x (n - rank)` isometry `V_plus` such that
  `(I - P) = V_plus @ V_plus.conj().T`.

  The subspaces are computed using the naive QR eigendecomposition algorithm,
  which converges very quickly due to the sharp separation between the
  relevant eigenvalues of the projector.

  Args:
    P: A rank-`rank` Hermitian projector into the space of `H`'s first `rank`
      eigenpairs. `P` is padded to `N x N`.
    H: The aforementioned Hermitian matrix, which is used to track
      convergence.
    n: The true (dynamic) shape of `P`.
    rank: Rank of `P`.
    maxiter: Maximum number of iterations.
  Returns:
    V_minus, V_plus: Isometries into the eigenspaces described in the
      docstring.
  """
  # Choose an initial guess: the `rank` largest-norm columns of P. The norms
  # are negated so that `jnp.argsort`'s ascending order puts them first.
  N, _ = P.shape
  negative_column_norms = -jnp_linalg.norm(P, axis=1)
  # `jnp.argsort` ensures NaNs sort last, so set masked-out column norms to
  # NaN.
  negative_column_norms = _mask(negative_column_norms, (n,), jnp.nan)
  sort_idxs = jnp.argsort(negative_column_norms)
  X = P[:, sort_idxs]
  # X = X[:, :rank]
  X = _mask(X, (n, rank))

  H_norm = jnp_linalg.norm(H)
  thresh = 10 * jnp.finfo(X.dtype).eps * H_norm

  # First iteration skips the matmul.
  def body_f_after_matmul(X):
    Q, _ = jnp_linalg.qr(X, mode="complete")
    # V1 = Q[:, :rank]
    # V2 = Q[:, rank:]
    V1 = _mask(Q, (n, rank))
    V2 = _slice(Q, (0, rank), (n, n - rank), (N, N))

    # TODO: might be able to get away with lower precision here
    error_matrix = jnp.dot(V2.conj().T, H)
    error_matrix = jnp.dot(error_matrix, V1)
    error = jnp_linalg.norm(error_matrix) / H_norm
    return V1, V2, error

  def cond_f(args):
    _, _, j, error = args
    still_counting = j < maxiter
    unconverged = error > thresh
    return jnp.logical_and(still_counting, unconverged)[0]

  def body_f(args):
    V1, _, j, _ = args
    X = jnp.dot(P, V1)
    V1, V2, error = body_f_after_matmul(X)
    return V1, V2, j + 1, error

  V1, V2, error = body_f_after_matmul(X)
  one = jnp.ones(1, dtype=jnp.int32)
  V1, V2, _, error = lax.while_loop(cond_f, body_f, (V1, V2, one, error))
  return V1, V2
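
# A minimal sketch of the contract (hypothetical, not a library test): for a
# projector P onto `rank` eigenvectors of H, the returned `V1` spans range(P)
# and `V2` its orthogonal complement, so `P @ V1 == V1` and `P @ V2 == 0` up
# to numerical error. Padding is not exercised here (n == N), and the private
# `_mask`/`_slice` helpers are assumed to zero-fill as in the surrounding
# module.
def _example_projector_subspace():
  import jax
  import jax.numpy as jnp
  N, rank = 6, 2
  H = jax.random.normal(jax.random.PRNGKey(1), (N, N))
  H = 0.5 * (H + H.T)
  _, vecs = jnp.linalg.eigh(H)
  P = vecs[:, :rank] @ vecs[:, :rank].T  # projector onto the lowest 2 pairs
  V1, V2 = _projector_subspace(P, H, N, rank)
  assert jnp.allclose(P @ V1, V1, atol=1e-5)
  assert jnp.allclose(P @ V2, jnp.zeros_like(V2), atol=1e-5)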