Example #1
    def test_lobpcg(self):
        m = 50
        k = 10
        dtype = np.float64
        A = random_spd_coo(m, dtype=dtype)
        rng = np.random.default_rng(0)
        X0 = rng.uniform(size=(m, k)).astype(dtype)

        B = None
        iK = None

        E_expected, X_expected = eigh_general(A.todense(), B, False)
        E_expected = E_expected[:k]
        X_expected = X_expected[:, :k]

        data, row, col, shape = coo_components(A.tocoo())
        A_coo = coo.matmul_fun(
            jnp.asarray(data),
            jnp.asarray(row),
            jnp.asarray(col),
            jnp.zeros((shape[0],)),
        )
        data, indices, indptr, _ = csr_components(A.tocsr())
        A_csr = csr.matmul_fun(jnp.asarray(data), jnp.asarray(indices),
                               jnp.asarray(indptr))
        for A_fun in (jnp.asarray(A.todense()), A_coo, A_csr):
            E_actual, X_actual = lobpcg(A=A_fun,
                                        B=B,
                                        X0=X0,
                                        iK=iK,
                                        largest=False,
                                        max_iters=200)
            X_actual = standardize_eigenvector_signs(X_actual)
            self.assertAllClose(E_expected, E_actual, rtol=1e-8, atol=1e-10)
            self.assertAllClose(X_expected, X_actual, rtol=1e-4, atol=1e-10)
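
Eigenvectors are only determined up to sign, so the test normalizes both computed and reference eigenvectors with `standardize_eigenvector_signs` before comparing. That helper is project-local; a minimal sketch of what such a normalization typically does (an assumption, not necessarily the repository's implementation):

import jax.numpy as jnp

def standardize_eigenvector_signs(v):
    # Flip each column so its largest-magnitude entry is positive,
    # making eigenvector comparisons stable across solvers.
    idx = jnp.argmax(jnp.abs(v), axis=0)
    signs = jnp.sign(v[idx, jnp.arange(v.shape[1])])
    return v * signs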
Example #2
def lobpcg_coo(data, row, col, X0, largest, k):
    size = X0.shape[0]
    # Symmetrize the COO data so the operator passed to lobpcg is exactly self-adjoint.
    data = coo.symmetrize(data, row, col, size)
    A = coo.matmul_fun(data, row, col, jnp.zeros((size,)))
    w, v = lobpcg(A, X0, largest=largest, k=k)
    # Fix the sign ambiguity of the returned eigenvectors.
    v = standardize_eigenvector_signs(v)
    return w, v
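
A hedged usage sketch for `lobpcg_coo`, assuming it and its dependencies are importable from the surrounding project; the SPD matrix construction here is illustrative only:

import jax
import numpy as np
import scipy.sparse as sp
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # match the float64 inputs

m, k = 50, 5
rng = np.random.default_rng(0)
dense = rng.uniform(size=(m, m))
A = sp.coo_matrix(dense + dense.T + m * np.eye(m))  # symmetric positive definite
X0 = jnp.asarray(rng.uniform(size=(m, k)))

w, v = lobpcg_coo(jnp.asarray(A.data), jnp.asarray(A.row),
                  jnp.asarray(A.col), X0, largest=False, k=k)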
Example #3
def eigh_partial_rev(res, g):
    w, v, data, row, col = res
    size = v.shape[0]
    grad_w, grad_v = g
    # Random initial guess for the conjugate gradient least squares solve.
    rng_key = jax.random.PRNGKey(0)
    x0 = jax.random.normal(rng_key, shape=v.shape, dtype=w.dtype)
    grad_data, x0 = cg.eigh_partial_rev(
        grad_w,
        grad_v,
        w,
        v,
        coo.matmul_fun(data, row, col, jnp.zeros((size,))),
        x0,
        outer_impl=coo.masked_outer_fun(row, col),
    )
    # Match the symmetrization applied to `data` in the forward pass.
    grad_data = coo.symmetrize(grad_data, row, col, size)
    return (grad_data, None, None, None, None, None)
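
For a symmetric matrix the eigenvalue part of this reverse rule has a closed form, dw_k/dA = v_k v_k^T; it is the eigenvector part that needs the conjugate gradient least squares solve seeded by `x0`. A self-contained dense reference for the eigenvalue contribution, useful for spot-checking `grad_data` on small problems (the function name is ours):

import jax.numpy as jnp

def eigenvalue_grad_dense(grad_w, v):
    # dw_k/dA = v_k v_k^T for symmetric A, so the grad_w contribution
    # to grad_A is V @ diag(grad_w) @ V^T.
    return (v * grad_w) @ v.T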
Example #4
def lobpcg_rev(res, g):
    grad_w, grad_v = g
    w, v, data, row, col = res
    size = v.shape[0]
    A = coo.matmul_fun(data, row, col, jnp.zeros((size,)))
    x0 = jax.random.normal(jax.random.PRNGKey(0),
                           shape=v.shape,
                           dtype=v.dtype)
    grad_data, x0 = eigh_partial_rev(grad_w,
                                     grad_v,
                                     w,
                                     v,
                                     A,
                                     x0,
                                     outer_impl=coo.masked_outer_fun(
                                         row, col))
    grad_data = coo.symmetrize(grad_data, row, col, size)
    return grad_data, None, None, None, None, None
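
`lobpcg_rev` has the `(res, g)` signature of a `jax.custom_vjp` backward rule, and its six outputs line up with six primal inputs `(data, row, col, X0, largest, k)`. A sketch of how it might be wired up, reusing `lobpcg_coo` from Example #2 as the forward solve (this wiring is our assumption, not shown in the source):

import jax

@jax.custom_vjp
def lobpcg_coo_op(data, row, col, X0, largest, k):
    return lobpcg_coo(data, row, col, X0, largest, k)

def lobpcg_coo_fwd(data, row, col, X0, largest, k):
    w, v = lobpcg_coo(data, row, col, X0, largest, k)
    # The residuals saved here are exactly what lobpcg_rev unpacks.
    return (w, v), (w, v, data, row, col)

lobpcg_coo_op.defvjp(lobpcg_coo_fwd, lobpcg_rev)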
Example #5
def chebyshev_subspace_iteration_coo(
    data: jnp.ndarray,
    row: jnp.ndarray,
    col: jnp.ndarray,
    sized: jnp.ndarray,
    v0: jnp.ndarray,
    tol: Optional[float] = None,
    max_iters: int = 1000,
    scale: float = 1.0,
    order: int = 8,
):
    a = coo.matmul_fun(data, row, col, sized)
    w, v, info = si.chebyshev_subspace_iteration(order,
                                                 scale,
                                                 a,
                                                 v0,
                                                 tol=tol,
                                                 max_iters=max_iters)
    del info
    v = utils.standardize_signs(v)
    return w, v
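
Chebyshev subspace iteration replaces each plain power step `a(V)` with a degree-`order` Chebyshev polynomial of the scaled operator, damping the part of the spectrum inside `[-scale, scale]` so the wanted subspace converges faster. A minimal sketch of the three-term recurrence such a filter applies (assumed; `si.chebyshev_subspace_iteration`'s internals may differ):

import jax.numpy as jnp

def chebyshev_apply(matvec, v, order, scale):
    # T_0(x) = 1, T_1(x) = x, T_{j+1}(x) = 2 x T_j(x) - T_{j-1}(x),
    # evaluated at the scaled operator x = A / scale; assumes order >= 1.
    t_prev, t_curr = v, matvec(v) / scale
    for _ in range(order - 1):
        t_prev, t_curr = t_curr, 2.0 * matvec(t_curr) / scale - t_prev
    return t_curr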
Example #6
def eigh_partial_rev_coo(
    grad_w: jnp.ndarray,
    grad_v: jnp.ndarray,
    w: jnp.ndarray,
    v: jnp.ndarray,
    l0: jnp.ndarray,
    data: jnp.ndarray,
    row: jnp.ndarray,
    col: jnp.ndarray,
    sized: jnp.ndarray,
    tol: float = 1e-5,
):
    """
    Args:
        grad_w: [k] gradient w.r.t. eigenvalues.
        grad_v: [m, k] gradient w.r.t. eigenvectors.
        w: [k] eigenvalues.
        v: [m, k] eigenvectors.
        l0: initial solution to least squares problem.
        data, row, col, sized: coo formatted [m, m] matrix.
        tol: tolerance used in the conjugate gradient least squares solve.

    Returns:
        grad_data: gradient of `data` input, same `shape` and `dtype`.
        l: solution to least squares problem.
    """
    outer_impl = coo.masked_outer_fun(row, col)
    a = coo.matmul_fun(data, row, col, sized)
    grad_data, l0 = cg.eigh_partial_rev(grad_w,
                                        grad_v,
                                        w,
                                        v,
                                        l0,
                                        a,
                                        outer_impl,
                                        tol=tol)
    return grad_data, l0
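
Note how the least squares solution is threaded through this variant: `l0` comes in as the initial guess and the updated solution is returned alongside `grad_data`, presumably so repeated backward passes can warm-start the conjugate gradient solve instead of drawing a fresh random `x0` as Examples #3 and #4 do.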