def test_lobpcg(self):
    m = 50
    k = 10
    dtype = np.float64
    A = random_spd_coo(m, dtype=dtype)
    rng = np.random.default_rng(0)
    X0 = rng.uniform(size=(m, k)).astype(dtype)
    B = None
    iK = None

    E_expected, X_expected = eigh_general(A.todense(), B, False)
    E_expected = E_expected[:k]
    X_expected = X_expected[:, :k]

    data, row, col, shape = coo_components(A.tocoo())
    A_coo = coo.matmul_fun(
        jnp.asarray(data),
        jnp.asarray(row),
        jnp.asarray(col),
        jnp.zeros((shape[0],)),
    )
    data, indices, indptr, _ = csr_components(A.tocsr())
    A_csr = csr.matmul_fun(jnp.asarray(data), jnp.asarray(indices),
                           jnp.asarray(indptr))

    for A_fun in (jnp.asarray(A.todense()), A_coo, A_csr):
        E_actual, X_actual = lobpcg(A=A_fun,
                                    B=B,
                                    X0=X0,
                                    iK=iK,
                                    largest=False,
                                    max_iters=200)
        X_actual = standardize_eigenvector_signs(X_actual)
        self.assertAllClose(E_expected, E_actual, rtol=1e-8, atol=1e-10)
        self.assertAllClose(X_expected, X_actual, rtol=1e-4, atol=1e-10)
def lobpcg_coo(data, row, col, X0, largest, k):
    size = X0.shape[0]
    data = coo.symmetrize(data, row, col, size)
    A = coo.matmul_fun(data, row, col, jnp.zeros((size,)))
    w, v = lobpcg(A, X0, largest=largest, k=k)
    v = standardize_eigenvector_signs(v)
    return w, v
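# Hypothetical usage sketch (not from the library): the two smallest
# eigenpairs of a diagonal matrix stored in COO form. A diagonal matrix is
# already symmetric, so the `coo.symmetrize` call inside `lobpcg_coo` is a
# no-op here; the helper name `_example_lobpcg_coo` is invented for
# illustration.
def _example_lobpcg_coo():
    m, k = 10, 2
    data = jnp.arange(1.0, m + 1)  # diag(1, 2, ..., 10)
    row = jnp.arange(m)
    col = jnp.arange(m)
    X0 = jax.random.normal(jax.random.PRNGKey(42), shape=(m, k))
    w, v = lobpcg_coo(data, row, col, X0, largest=False, k=k)
    return w, v  # expect w close to [1., 2.]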
def eigh_partial_rev(res, g):
    w, v, data, row, col = res
    size = v.shape[0]
    grad_w, grad_v = g
    # Fixed seed makes the least-squares warm start deterministic.
    rng_key = jax.random.PRNGKey(0)
    x0 = jax.random.normal(rng_key, shape=v.shape, dtype=w.dtype)
    grad_data, x0 = cg.eigh_partial_rev(
        grad_w,
        grad_v,
        w,
        v,
        coo.matmul_fun(data, row, col, jnp.zeros((size,))),
        x0,
        outer_impl=coo.masked_outer_fun(row, col),
    )
    grad_data = coo.symmetrize(grad_data, row, col, size)
    # One entry per primal input; only `data` carries a gradient.
    return (grad_data, None, None, None, None, None)
def lobpcg_rev(res, g):
    grad_w, grad_v = g
    w, v, data, row, col = res
    size = v.shape[0]
    A = coo.matmul_fun(data, row, col, jnp.zeros((size,)))
    x0 = jax.random.normal(jax.random.PRNGKey(0), shape=v.shape, dtype=v.dtype)
    grad_data, x0 = eigh_partial_rev(grad_w,
                                     grad_v,
                                     w,
                                     v,
                                     A,
                                     x0,
                                     outer_impl=coo.masked_outer_fun(row, col))
    grad_data = coo.symmetrize(grad_data, row, col, size)
    return grad_data, None, None, None, None, None
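# Hedged example (assumes the helpers above are in scope; the name
# `_example_lobpcg_rev` is invented): calling the reverse rule directly.
# `res` carries what a custom-VJP forward pass would save (the eigenpair
# plus the COO components) and `g` carries the output cotangents; the same
# convention applies to `eigh_partial_rev` above. Seeding `grad_w` with
# ones recovers d(sum(w))/d(data).
def _example_lobpcg_rev():
    m, k = 10, 2
    data = jnp.arange(1.0, m + 1)
    row = jnp.arange(m)
    col = jnp.arange(m)
    X0 = jax.random.normal(jax.random.PRNGKey(0), shape=(m, k))
    w, v = lobpcg_coo(data, row, col, X0, largest=False, k=k)
    g = (jnp.ones_like(w), jnp.zeros_like(v))
    grads = lobpcg_rev((w, v, data, row, col), g)
    return grads[0]  # gradient w.r.t. `data`; the remaining entries are None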
def chebyshev_subspace_iteration_coo(
    data: jnp.ndarray,
    row: jnp.ndarray,
    col: jnp.ndarray,
    sized: jnp.ndarray,
    v0: jnp.ndarray,
    tol: Optional[float] = None,
    max_iters: int = 1000,
    scale: float = 1.0,
    order: int = 8,
):
    a = coo.matmul_fun(data, row, col, sized)
    w, v, info = si.chebyshev_subspace_iteration(order,
                                                 scale,
                                                 a,
                                                 v0,
                                                 tol=tol,
                                                 max_iters=max_iters)
    del info
    v = utils.standardize_signs(v)
    return w, v
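# Hedged usage sketch: `sized` here is a dummy vector whose length fixes the
# operator dimension, the same role as the `jnp.zeros((size,))` arguments
# above. Treat the choice of `scale` (assumed to bound the spectrum so the
# Chebyshev filter is stable) and the helper name as assumptions.
def _example_chebyshev_coo():
    m = 10
    data = jnp.arange(1.0, m + 1)  # diag(1, ..., 10)
    row = jnp.arange(m)
    col = jnp.arange(m)
    v0 = jax.random.normal(jax.random.PRNGKey(0), shape=(m, 2))
    w, v = chebyshev_subspace_iteration_coo(data, row, col, jnp.zeros((m,)),
                                            v0, scale=float(m))
    return w, v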
def eigh_partial_rev_coo(
    grad_w: jnp.ndarray,
    grad_v: jnp.ndarray,
    w: jnp.ndarray,
    v: jnp.ndarray,
    l0: jnp.ndarray,
    data: jnp.ndarray,
    row: jnp.ndarray,
    col: jnp.ndarray,
    sized: jnp.ndarray,
    tol: float = 1e-5,
):
    """
    Args:
        grad_w: [k] gradient w.r.t. eigenvalues.
        grad_v: [m, k] gradient w.r.t. eigenvectors.
        w: [k] eigenvalues.
        v: [m, k] eigenvectors.
        l0: initial solution to the least squares problem.
        data, row, col: COO components of the [m, m] matrix.
        sized: [m] dummy array whose length fixes the operator size.
        tol: tolerance of the conjugate gradient solve.

    Returns:
        grad_data: gradient of `data` input, same `shape` and `dtype`.
        l0: solution to the least squares problem.
    """
    outer_impl = coo.masked_outer_fun(row, col)
    a = coo.matmul_fun(data, row, col, sized)
    grad_data, l0 = cg.eigh_partial_rev(grad_w, grad_v, w, v, l0, a,
                                        outer_impl, tol=tol)
    return grad_data, l0
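# Hedged sketch (helper name invented): unlike the wrappers above, which
# draw a random warm start internally, the caller supplies `l0`; zeros are a
# reasonable default when no previous solution is available. `w`, `v`,
# `data`, `row`, `col` are assumed to come from a prior forward
# eigendecomposition of the same matrix.
def _example_eigh_partial_rev_coo(w, v, data, row, col):
    m = v.shape[0]
    grad_w = jnp.ones_like(w)   # cotangent of sum(w)
    grad_v = jnp.zeros_like(v)  # no eigenvector contribution
    l0 = jnp.zeros_like(v)
    grad_data, l0 = eigh_partial_rev_coo(grad_w, grad_v, w, v, l0, data, row,
                                         col, jnp.zeros((m,)))
    return grad_data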