def eigh_partial_coo(data, indices, indptr, k: int, largest: bool):
    """Compute a partial eigendecomposition of a sparse symmetric matrix.

    Despite the ``_coo`` suffix, the matrix is given in CSR layout
    (``data``/``indices``/``indptr``). The values are symmetrized with
    ``csr.symmetrize`` and densified to a square matrix. The result is then
    passed to ``eigh_partial``.

    Args:
        data: nonzero values of the CSR matrix.
        indices: CSR column indices.
        indptr: CSR row pointers; the matrix order is ``indptr.size - 1``.
        k: number of eigenpairs, forwarded to ``eigh_partial``.
        largest: forwarded to ``eigh_partial``. Presumably it selects the
            largest vs. smallest end of the spectrum; confirm against
            ``eigh_partial``.

    Returns:
        ``(w, v)`` from ``eigh_partial``, with ``v`` passed through
        ``standardize_eigenvector_signs``.
    """
    n = indptr.size - 1
    sym_data = csr.symmetrize(data, indices)
    dense = csr.to_dense(sym_data, indices, indptr, (n, n))
    eigvals, eigvecs = eigh_partial(dense, k, largest)
    return eigvals, standardize_eigenvector_signs(eigvecs)
def eigh_csr_rev(res, g):
    """Reverse-mode rule, apparently paired with ``eigh_csr``.

    Args:
        res: residuals ``(w, v, indices, indptr)`` saved by the forward pass.
        g: cotangents ``(grad_w, grad_v)`` for the forward outputs.

    Returns:
        A 3-tuple ``(grad_data, None, None)``. Only the CSR ``data`` receives a
        gradient; the integer sparsity structure gets ``None``.
    """
    grad_w, grad_v = g
    w, v, indices, indptr = res
    # The gradient is only needed at the stored nonzeros, so cg.eigh_rev is
    # given a matmul restricted to the CSR sparsity pattern.
    masked_matmul = csr.masked_matmul_fun(indices, indptr)
    raw_grad = cg.eigh_rev(grad_w, grad_v, w, v, masked_matmul)
    # The forward pass symmetrizes data, so the gradient is symmetrized too.
    return (csr.symmetrize(raw_grad, indices), None, None)
def eigh_csr(data, indices, indptr):
    """Full eigendecomposition of a sparse symmetric CSR matrix.

    The values are symmetrized with ``csr.symmetrize`` and densified to a
    square matrix of order ``indptr.size - 1``. The dense matrix is then
    decomposed with ``jnp.linalg.eigh``.

    Returns:
        ``(w, v)`` from ``jnp.linalg.eigh``, with ``v`` passed through
        ``standardize_eigenvector_signs``.
    """
    order = indptr.size - 1
    dense = csr.to_dense(csr.symmetrize(data, indices), indices, indptr, (order, order))
    eigvals, eigvecs = jnp.linalg.eigh(dense)
    return eigvals, standardize_eigenvector_signs(eigvecs)
def test_symmetrize(self):
    """``csr.symmetrize`` should leave an already-symmetric matrix's values unchanged."""
    n = 50
    rng = np.random.default_rng(0)
    mat = random_csr(rng, (n, n), sparsity=0.1)
    # Build a symmetric matrix and canonicalize its storage before comparing.
    mat = ((mat + mat.T) / 2).tocsr()
    mat.sum_duplicates()
    self.assertAllClose(csr.symmetrize(mat.data, mat.indices), mat.data)
def eigh_partial_rev(res, g):
    """Reverse-mode rule, apparently paired with ``eigh_partial_coo``.

    Args:
        res: residuals ``(w, v, data, indices, indptr)`` from the forward pass.
        g: cotangents ``(grad_w, grad_v)``.

    Returns:
        A 5-tuple: the symmetrized gradient w.r.t. ``data`` followed by four
        ``None`` entries for the non-differentiable inputs.
    """
    w, v, data, indices, indptr = res
    grad_w, grad_v = g
    # A fixed PRNG key keeps the initial guess (and hence the backward pass)
    # deterministic.
    x0 = jax.random.normal(jax.random.PRNGKey(0), shape=v.shape, dtype=w.dtype)
    matmul = csr.matmul_fun(data, indices, indptr)
    outer = csr.masked_outer_fun(indices, indptr)
    # cg.eigh_partial_rev also returns an updated solution; it is not needed here.
    grad_data, _ = cg.eigh_partial_rev(
        grad_w, grad_v, w, v, matmul, x0, outer_impl=outer
    )
    return csr.symmetrize(grad_data, indices), None, None, None, None
def lobpcg_rev(res, g):
    """Reverse-mode rule, apparently paired with ``lobpcg_csr``.

    Args:
        res: residuals ``(w, v, data, indices, indptr)`` from the forward pass.
        g: cotangents ``(grad_w, grad_v)``.

    Returns:
        A 6-tuple: the symmetrized gradient w.r.t. ``data`` followed by five
        ``None`` entries. There is one entry per ``lobpcg_csr`` input
        ``(data, indices, indptr, X0, largest, k)``; only ``data`` is
        differentiated.
    """
    grad_w, grad_v = g
    w, v, data, indices, indptr = res
    A = csr.matmul_fun(data, indices, indptr)
    # A fixed PRNG key keeps the backward pass deterministic.
    x0 = jax.random.normal(jax.random.PRNGKey(0), shape=v.shape, dtype=v.dtype)
    # Call the solver-level rule in ``cg``. The bare module-level name
    # ``eigh_partial_rev`` is a ``(res, g)`` VJP rule and raises TypeError
    # when given these arguments.
    grad_data, _ = cg.eigh_partial_rev(
        grad_w,
        grad_v,
        w,
        v,
        A,
        x0,
        outer_impl=csr.masked_outer_fun(indices, indptr),
    )
    grad_data = csr.symmetrize(grad_data, indices)
    return grad_data, None, None, None, None, None
def lobpcg_csr(data, indices, indptr, X0, largest, k):
    """Run ``lobpcg`` on a sparse symmetric CSR matrix.

    The values are symmetrized with ``csr.symmetrize`` and wrapped as a matmul
    operator via ``csr.matmul_fun``. The matrix is never densified.

    Args:
        data, indices, indptr: CSR representation of the matrix.
        X0: initial block of vectors, forwarded to ``lobpcg``.
        largest: forwarded to ``lobpcg``.
        k: forwarded to ``lobpcg``.

    Returns:
        ``(w, v)`` from ``lobpcg``, with ``v`` passed through
        ``standardize_eigenvector_signs``.
    """
    operator = csr.matmul_fun(csr.symmetrize(data, indices), indices, indptr)
    eigvals, eigvecs = lobpcg(operator, X0, largest=largest, k=k)
    return eigvals, standardize_eigenvector_signs(eigvecs)