Example #1
def lobpcg(
    A: ArrayOrFun,
    X0: jnp.ndarray,
    B: Optional[ArrayOrFun] = None,
    iK: Optional[ArrayOrFun] = None,
    largest: bool = False,
    k: Optional[int] = None,
    tol: Optional[float] = None,
    max_iters: int = 1000,
):
    """
    Find some of the eigenpairs for the generalized eigenvalue problem (A, B).

    Args:
        A: `[m, m]` hermitian matrix, or function representing pre-multiplication by an
            `[m, m]` hermitian matrix.
        X0: `[m, n]`, `k <= n < m`. Initial guess of eigenvectors.
        B: same type as A. If not given, identity is used.
        iK: Optional inverse preconditioner. If not given, identity is used.
        largest: if True, return the largest `k` eigenvalues, otherwise the smallest.
        k: number of eigenpairs to return. Uses `n` if not provided.
        tol: tolerance for convergence.
        max_iters: maximum number of iterations.

    Returns:
        w: [k] smallest/largest eigenvalues of generalized eigenvalue problem `(A, B)`.
        v: [m, k] eigenvectors associated with `w`. `v[:, i]` matches `w[i]`.
    """
    # Perform argument checks and fix default / computed arguments
    if B is not None:
        raise NotImplementedError("Implementations with non-None B have issues")
    if iK is not None:
        raise NotImplementedError("Inplementations with non-None iK have issues")
    ohm = jax.random.normal(jax.random.PRNGKey(0), shape=X0.shape, dtype=X0.dtype)
    A = as_array_fun(A)
    A_norm = utils.approx_matrix_norm2(A, ohm)
    if B is None:
        B = utils.identity
        B_norm = jnp.ones((), dtype=X0.dtype)
    else:
        B = as_array_fun(B)
        B_norm = utils.approx_matrix_norm2(B, ohm)
    if iK is None:
        iK = utils.identity
    else:
        iK = as_array_fun(iK)

    if tol is None:
        dtype = X0.dtype
        if dtype == jnp.float32:
            feps = 1.2e-7
        elif dtype == jnp.float64:
            feps = 2.23e-16
        else:
            raise KeyError(dtype)
        tol = feps ** 0.5

    k = k or X0.shape[1]
    return _lobpcg(A, X0, B, iK, largest, k, tol, max_iters, A_norm, B_norm)
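A minimal usage sketch (not part of the library itself), assuming `lobpcg` and its helpers (`as_array_fun`, `utils`, `_lobpcg`) are in scope; the test matrix below is illustrative only:

import jax
import jax.numpy as jnp

m, n = 100, 4
key_a, key_x = jax.random.split(jax.random.PRNGKey(0))
M = jax.random.normal(key_a, (m, m))
A = M @ M.T + m * jnp.eye(m)  # symmetric positive definite test matrix
X0 = jax.random.normal(key_x, (m, n))  # initial guess of eigenvectors
w, v = lobpcg(A, X0, largest=False, k=2)
# w: [2] smallest eigenvalues, v: [m, 2] corresponding eigenvectors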
Example #2
def rayleigh_ritz(S: jnp.ndarray,
                  A: ArrayOrFun,
                  B: Optional[ArrayOrFun] = None,
                  largest: bool = False):
    """

    Based on algorithm2 of [duersch2018](
        https://epubs.siam.org/doi/abs/10.1137/17M1129830)

    Args:
        S: [m, ns] float array, matrix basis for search space. Columns must be linearly
            independent and well-conditioned with respect to `B`.
        A: Callable simulating [m, m] float matrix multiplication.
        B: Callable simulating [m, m] float matrix multiplication.

    Returns:
        (eig_vals, C) satisfying the following:
            C.T @ S.T @ B(S) @ C = jnp.eye(ns)
            C.T @ S.T @ A(S) @ C = jnp.diag(eig_vals)

        eig_vals: [ns] eigenvalues. Sorted in descending order if largest, otherwise
            ascending.
        C: [ns, ns] float matrix satisfying:
    """
    A = as_array_fun(A)
    if B is None:
        BS = S
    else:
        BS = as_array_fun(B)(S)
    SBS = S.T @ BS
    d_right = jnp.diag(SBS)**-0.5  # d_right * X == X @ D
    d_left = jnp.expand_dims(d_right, 1)  # d_left * X == D @ X
    R_low = jnp.linalg.cholesky(d_left * SBS * d_right)  # lower triangular
    R_up = R_low.T

    # R_inv = jnp.linalg.inv(R_up)
    # RDSASDR = R_inv.T @ (d_left * (S.T @ A(S)) * d_right) @ R_inv

    DSASD = d_left * (S.T @ A(S)) * d_right
    RDSASD = jax.scipy.linalg.solve_triangular(R_low, DSASD, lower=True)
    RDSASDR = jax.scipy.linalg.solve_triangular(R_low, RDSASD.T, lower=True).T

    eig_vals, Z = eigh(RDSASDR, largest=largest)
    if B is not None:
        Z /= jnp.linalg.norm(Z, ord=2, axis=0)
    # C = d_left * R_inv @ Z
    C = d_left * (jax.scipy.linalg.solve_triangular(R_up, Z, lower=False))
    return eig_vals, C
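A minimal sketch of what the returned pair satisfies, assuming `rayleigh_ritz` and its helpers (`as_array_fun`, `eigh`) are in scope:

import jax
import jax.numpy as jnp

m, ns = 50, 3
key_a, key_s = jax.random.split(jax.random.PRNGKey(0))
M = jax.random.normal(key_a, (m, m))
A = M @ M.T + m * jnp.eye(m)
S = jax.random.normal(key_s, (m, ns))  # search-space basis
eig_vals, C = rayleigh_ritz(S, A, largest=True)
X = S @ C  # [m, ns] Ritz vectors
# With B = None: X.T @ X ~= jnp.eye(ns) and X.T @ A @ X ~= jnp.diag(eig_vals)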
Example #3
def matrix_poly_vector_prod_from_roots(
    roots: jnp.ndarray, A: ArrayOrFun, x: jnp.ndarray
):
    """
    Evaluate matrix-polynomial-vector product with the given roots.

    i.e. `(A - r_0 * I) @ (A - r_1 * I) @ ... @ x`

    The full matrix `A` is never computed.

    Args:
        roots: 1D array of polynomial roots.
        A: `ArrayOrFun` for square matrix at which the polynomial is evaluated.
        x: rhs vector.

    Returns:
        output of the same size as `x`.
    """
    A = as_array_fun(A)
    assert len(roots.shape) == 1

    def body_fun(carry: jnp.ndarray, ri: float):
        # Apply the next linear factor: carry <- (A - ri * I) @ carry
        el = A(carry) - ri * carry
        return el, el

    res, _ = jax.lax.scan(body_fun, x, roots)
    return res
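A small correctness sketch against a dense evaluation, assuming the function is in scope; the linear factors commute, so the application order does not change the result:

import jax
import jax.numpy as jnp

A = jax.random.normal(jax.random.PRNGKey(0), (5, 5))
x = jnp.arange(5.0)
roots = jnp.array([1.0, -2.0])
out = matrix_poly_vector_prod_from_roots(roots, A, x)
I = jnp.eye(5)
expected = (A - (-2.0) * I) @ (A - 1.0 * I) @ x  # same as out, up to round-off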
Example #4
def matrix_poly_vector_prod_from_coeffs(
    coeffs: jnp.ndarray, A: ArrayOrFun, x: jnp.ndarray
):
    """
    Evaluate matrix-polynomial-vector product with the given coefficients.

    i.e. (coeffs[0] * I + coeffs[1] * A + ... + coeffs[n] * A ** n) @ x

    The full matrix `A` is never computed.

    Args:
        coeffs: 1D array of polynomial coefficients, lowest order first.
        A: `ArrayOrFun` for square matrix at which the polynomial is evaluated.
        x: rhs vector.

    Returns:
        output of the same size as `x`.
    """
    A = as_array_fun(A)
    assert len(coeffs.shape) == 1

    def body_fun(carry: jnp.ndarray, coeff: float):
        # Horner step: carry <- A @ carry + coeff * x
        el = A(carry) + coeff * x
        return el, el

    res, _ = jax.lax.scan(body_fun, coeffs[-1] * x, coeffs[-2::-1])
    return res
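The scan above is a Horner evaluation: the carry starts at `coeffs[-1] * x` and each step computes `A @ carry + coeff * x` for the remaining coefficients in decreasing order. A small sketch, assuming the function is in scope:

import jax
import jax.numpy as jnp

A = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
x = jnp.ones((4,))
coeffs = jnp.array([2.0, 0.0, 1.0])  # 2*I + 0*A + 1*A**2
out = matrix_poly_vector_prod_from_coeffs(coeffs, A, x)
expected = 2.0 * x + A @ (A @ x)  # same as out, up to round-off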
Example #5
def approx_matrix_norm2(A: Optional[ArrayOrFun], ohm: jnp.ndarray):
    """
    Approximation of matrix 2-norm of `A`.

        |A ohm|_fro / |ohm|_fro <= |A|_2

    This function returns the lower bound (left hand side).

    Args:
        A: matrix or callable that simulates matrix multiplication.
        ohm: block-vector used in formula. Should be Gaussian.

    Returns:
        Scalar, lower bound on 2-norm of A.
    """
    A = as_array_fun(A)
    return jnp.linalg.norm(A(ohm), "fro") / jnp.linalg.norm(ohm, "fro")
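A quick sketch, assuming the function and `as_array_fun` are in scope; the estimate is a lower bound on the true spectral norm:

import jax
import jax.numpy as jnp

key_a, key_o = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_a, (20, 20))
ohm = jax.random.normal(key_o, (20, 5))  # Gaussian block-vector
est = approx_matrix_norm2(A, ohm)
true_norm = jnp.linalg.norm(A, 2)
# est <= true_norm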
Example #6
def eigh_general(A, B, largest: bool):
    """Solve the dense (generalized) eigenproblem `(A, B)`, returning sorted eigenvalues
    and B-normalized, sign-standardized eigenvectors."""
    if B is None:
        w, v = jnp.linalg.eigh(A)
        B = lambda x: x
    else:
        w, v = jnp.linalg.eig(jnp.linalg.solve(B, A))
        w = w.real
        v = v.real
        i = jnp.argsort(w)
        w = w[i]
        v = v[:, i]
        B = as_array_fun(B)
    if largest:
        w = w[-1::-1]
        v = v[:, -1::-1]

    norm2 = jax.vmap(lambda vi: (vi.conj() @ B(vi)).real, in_axes=1)(v)
    norm = jnp.sqrt(norm2)
    v = v / norm
    v = standardize_eigenvector_signs(v)
    return w, v
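A minimal sketch, assuming `eigh_general` and its helpers (`as_array_fun`, `standardize_eigenvector_signs`) are in scope; `B=None` exercises the plain Hermitian path:

import jax
import jax.numpy as jnp

M = jax.random.normal(jax.random.PRNGKey(0), (6, 6))
A = M @ M.T
w, v = eigh_general(A, None, largest=True)
# w sorted in descending order; columns of v are normalized with standardized signs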
Example #7
def compute_residual(E: jnp.ndarray, X: jnp.ndarray, A: ArrayOrFun,
                     B: Optional[ArrayOrFun]):
    """Residual `A @ X - B @ X * E` of approximate eigenpairs `(E, X)` of `(A, B)`."""
    BX = X if B is None else as_array_fun(B)(X)
    return as_array_fun(A)(X) - BX * E
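A small sketch, assuming the function and `as_array_fun` are in scope; exact eigenpairs give a residual near zero:

import jax
import jax.numpy as jnp

M = jax.random.normal(jax.random.PRNGKey(0), (6, 6))
A = M @ M.T
w, v = jnp.linalg.eigh(A)
R = compute_residual(w, v, A, None)  # ~zero: A @ v - v * w for exact eigenpairs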
Example #8
def eigh_partial_rev(grad_w, grad_v, w, v, a, x0, outer_impl=jnp.outer):
    """
    Reverse-mode gradients for a partial eigendecomposition of `a`.

    Args:
        grad_w: [k] gradient w.r.t. eigenvalues
        grad_v: [m, k] gradient w.r.t. eigenvectors
        w: [k] eigenvalues
        v: [m, k] eigenvectors
        a: matmul function
        x0: [m, k] initial solution to `(A - w[i] I) x[:, i] = Proj(grad_v[:, i])`
        outer_impl: function used to form outer products, `jnp.outer` by default.

    Returns:
        grad_a: [m, m]
        x0: [m, k]
    """
    # based on
    # https://github.com/fancompute/legume/blob/99dd012feee28156292787330dac5e4f0c41d4c8/legume/primitives.py#L170-L210
    a = as_array_fun(a)
    grad_As = []

    # Eigenvalue contribution: sum_i grad_w[i] * outer(conj(v[:, i]), v[:, i])
    grad_As.append(
        jax.vmap(lambda grad_wi, vi: grad_wi * outer_impl(vi.conj(), vi),
                 (0, 1))(grad_w, v).sum(0))
    if grad_v is not None:
        # Add eigenvector part only if non-zero backward signal is present.
        # This can avoid NaN results for degenerate cases if the function
        # depends on the eigenvalues only.

        def f_inner(grad_vi, wi, vi, x0i):
            def if_any(operand):
                grad_vi, wi, vi, x0i = operand

                # Amat = (a - wi * jnp.eye(m, dtype=a.dtype)).T
                Amat = lambda x: (a(x.conj())).conj() - wi * x

                # Projection operator on space orthogonal to v
                P = projector(vi)

                # Find a solution lambda_0 using conjugate gradient
                (l0, _) = jax.scipy.sparse.linalg.cg(Amat,
                                                     P(grad_vi),
                                                     x0=P(x0i),
                                                     atol=0)
                # (l0, _) = jax.scipy.sparse.linalg.gmres(Amat, P(grad_vi), x0=P(x0i))
                # l0 = jax.numpy.linalg.lstsq(Amat, P(grad_vi))[0]
                # Project to correct for round-off errors
                # print(Amat(l0) - P(grad_vi))
                l0 = P(l0)
                return -outer_impl(l0, vi), l0

            def if_none(operand):
                x0i = operand[-1]
                return jnp.zeros_like(grad_As[0]), x0i

            operand = (grad_vi, wi, vi, x0i)
            # return if_any(operand) if jnp.any(grad_vi) else if_none(operand)
            return jax.lax.cond(jnp.any(grad_vi), if_any, if_none, operand)

        # x0s = []
        # for k in range(grad_v.shape[1]):
        #     out = f_inner(grad_v[:, k], w[k], v[:, k], x0[:, k])
        #     grad_As.append(out[0])
        #     x0s.append(out[1])
        # x0 = jnp.stack(x0s, axis=0)
        grad_a, x0 = jax.vmap(f_inner, in_axes=(1, 0, 1, 1),
                              out_axes=(0, 1))(grad_v, w, v, x0)
        grad_As.append(grad_a.sum(0))
    return sum(grad_As), x0
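A shape-level sketch, assuming `eigh_partial_rev` and its helpers (`as_array_fun`, `projector`) are in scope; the cotangents below are arbitrary placeholders, not a real backward pass:

import jax
import jax.numpy as jnp

m, k = 8, 3
M = jax.random.normal(jax.random.PRNGKey(0), (m, m))
A = M @ M.T
w_all, v_all = jnp.linalg.eigh(A)
w, v = w_all[:k], v_all[:, :k]  # k smallest eigenpairs
grad_w = jnp.ones((k,))  # placeholder cotangent w.r.t. eigenvalues
grad_v = jnp.zeros_like(v).at[0, :].set(1.0)  # placeholder cotangent w.r.t. eigenvectors
x0 = jnp.zeros_like(v)  # initial guesses for the projected CG solves
grad_a, x0 = eigh_partial_rev(grad_w, grad_v, w, v, A, x0)
# grad_a: [m, m] cotangent w.r.t. the matrix, x0: [m, k] updated solver state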