Exemple #1
0
 def eigh_partial_coo(data, indices, indptr, k: int, largest: bool):
     """Partial eigendecomposition of a sparse matrix via a dense fallback.

     The stored values are passed through ``csr.symmetrize``, the matrix is
     expanded to a dense ``(n, n)`` array with ``n = indptr.size - 1``, and
     ``eigh_partial`` is run on it with ``k`` and ``largest`` forwarded
     unchanged.

     NOTE(review): despite the ``_coo`` suffix the arguments look like a
     CSR triple (``data``, ``indices``, ``indptr``) -- confirm with callers.

     Returns:
         ``(w, v)`` as produced by ``eigh_partial``, with ``v`` passed
         through ``standardize_eigenvector_signs``.
     """
     n = indptr.size - 1
     dense = csr.to_dense(csr.symmetrize(data, indices), indices, indptr,
                          (n, n))
     eigvals, eigvecs = eigh_partial(dense, k, largest)
     return eigvals, standardize_eigenvector_signs(eigvecs)
Exemple #2
0
 def eigh_csr_rev(res, g):
     """Map cotangents of ``(w, v)`` back to a gradient for ``data``.

     NOTE(review): the ``(res, g)`` signature and the trailing ``None``
     entries look like the backward rule of a ``jax.custom_vjp`` -- confirm
     against the registration site.

     Args:
         res: 4-tuple ``(w, v, indices, indptr)``.
         g: 2-tuple ``(grad_w, grad_v)``.

     Returns:
         ``(grad_data, None, None)`` where ``grad_data`` is the output of
         ``cg.eigh_rev`` passed through ``csr.symmetrize``.
     """
     w, v, indices, indptr = res
     grad_w, grad_v = g
     masked_matmul = csr.masked_matmul_fun(indices, indptr)
     raw_grad = cg.eigh_rev(grad_w, grad_v, w, v, masked_matmul)
     return csr.symmetrize(raw_grad, indices), None, None
Exemple #3
0
 def eigh_csr(data, indices, indptr):
     """Full eigendecomposition of a sparse matrix via a dense fallback.

     The stored values are passed through ``csr.symmetrize``, the matrix is
     expanded to a dense ``(n, n)`` array with ``n = indptr.size - 1``, and
     ``jnp.linalg.eigh`` is applied to it.

     Returns:
         ``(w, v)`` from ``jnp.linalg.eigh``, with ``v`` passed through
         ``standardize_eigenvector_signs``.
     """
     n = indptr.size - 1
     dense = csr.to_dense(csr.symmetrize(data, indices), indices, indptr,
                          (n, n))
     eigvals, eigvecs = jnp.linalg.eigh(dense)
     return eigvals, standardize_eigenvector_signs(eigvecs)
Exemple #4
0
    def test_symmetrize(self):
        """``csr.symmetrize`` should leave the data of ``(A + A.T) / 2`` as is."""
        size = 50
        rng = np.random.default_rng(0)
        base = random_csr(rng, (size, size), sparsity=0.1)
        # Build the reference matrix as (A + A^T) / 2 and canonicalise it.
        sym = ((base + base.T) / 2).tocsr()
        sym.sum_duplicates()

        result = csr.symmetrize(sym.data, sym.indices)

        self.assertAllClose(result, sym.data)
Exemple #5
0
 def eigh_partial_rev(res, g):
     """Map cotangents of a partial eigendecomposition back to ``data``.

     NOTE(review): the ``(res, g)`` signature and the trailing ``None``
     entries look like the backward rule of a ``jax.custom_vjp`` -- confirm
     against the registration site.

     Args:
         res: 5-tuple ``(w, v, data, indices, indptr)``.
         g: 2-tuple ``(grad_w, grad_v)``.

     Returns:
         ``(grad_data, None, None, None, None)`` where ``grad_data`` is the
         first output of ``cg.eigh_partial_rev`` passed through
         ``csr.symmetrize``.
     """
     w, v, data, indices, indptr = res
     grad_w, grad_v = g

     # Fixed PRNG seed: the x0 array handed to the solver is the same on
     # every call. It has v's shape and w's dtype.
     x0 = jax.random.normal(jax.random.PRNGKey(0),
                            shape=v.shape,
                            dtype=w.dtype)

     matmul = csr.matmul_fun(data, indices, indptr)
     masked_outer = csr.masked_outer_fun(indices, indptr)
     # The second output of the solver is not used here.
     grad_data, _ = cg.eigh_partial_rev(grad_w,
                                        grad_v,
                                        w,
                                        v,
                                        matmul,
                                        x0,
                                        outer_impl=masked_outer)
     return csr.symmetrize(grad_data, indices), None, None, None, None
Exemple #6
0
 def lobpcg_rev(res, g):
     """Map cotangents of the LOBPCG outputs ``(w, v)`` back to ``data``.

     NOTE(review): the ``(res, g)`` signature and the trailing ``None``
     entries look like the backward rule of a ``jax.custom_vjp`` -- confirm
     against the registration site.

     Args:
         res: 5-tuple ``(w, v, data, indices, indptr)``.
         g: 2-tuple ``(grad_w, grad_v)``.

     Returns:
         A 6-tuple whose first entry is the first output of
         ``eigh_partial_rev`` passed through ``csr.symmetrize`` and whose
         remaining five entries are ``None``.
     """
     w, v, data, indices, indptr = res
     grad_w, grad_v = g

     matmul = csr.matmul_fun(data, indices, indptr)
     # Fixed PRNG seed: the x0 array handed to the solver is the same on
     # every call. It has v's shape and dtype.
     key = jax.random.PRNGKey(0)
     x0 = jax.random.normal(key, shape=v.shape, dtype=v.dtype)
     masked_outer = csr.masked_outer_fun(indices, indptr)

     # The second output of the solver is not used here.
     grad_data, _ = eigh_partial_rev(grad_w,
                                     grad_v,
                                     w,
                                     v,
                                     matmul,
                                     x0,
                                     outer_impl=masked_outer)
     return (csr.symmetrize(grad_data, indices), None, None, None, None,
             None)
Exemple #7
0
 def lobpcg_csr(data, indices, indptr, X0, largest, k):
     """Run ``lobpcg`` on a sparse matrix given as ``(data, indices, indptr)``.

     The stored values are passed through ``csr.symmetrize`` and wrapped by
     ``csr.matmul_fun`` before being handed to ``lobpcg`` together with
     ``X0``; ``largest`` and ``k`` are forwarded as keyword arguments.

     Returns:
         ``(w, v)`` from ``lobpcg``, with ``v`` passed through
         ``standardize_eigenvector_signs``.
     """
     sym_data = csr.symmetrize(data, indices)
     matmul = csr.matmul_fun(sym_data, indices, indptr)
     eigvals, eigvecs = lobpcg(matmul, X0, largest=largest, k=k)
     return eigvals, standardize_eigenvector_signs(eigvecs)