コード例 #1
0
 def eigh_partial_rev(res, g):
     """Backward rule for a partial eigendecomposition with a plain matrix input.

     The ``(res, g)`` signature and the trailing ``None`` entries match the
     bwd-rule convention of ``jax.custom_vjp`` -- NOTE(review): confirm against
     the matching forward function.

     Args:
         res: residuals ``(w, v, a)`` -- eigenvalues, eigenvectors and ``a``
             (presumably the input matrix).
         g: cotangents ``(grad_w, grad_v)`` for ``w`` and ``v``.

     Returns:
         ``(grad_a, None, None)``, where ``grad_a`` is the gradient returned by
         ``cg.eigh_partial_rev`` after passing through ``symmetrize``.
     """
     w, v, a = res
     grad_w, grad_v = g
     # A constant key makes the random CG starting point identical on every call.
     initial_guess = jax.random.normal(
         jax.random.PRNGKey(0), v.shape, dtype=v.dtype)
     raw_grad, _unused_solution = cg.eigh_partial_rev(
         grad_w, grad_v, w, v, a, initial_guess)
     return (symmetrize(raw_grad), None, None)
コード例 #2
0
 def lobpcg_rev(res, g):
     """Backward rule paired with an LOBPCG-based eigensolver (matrix input).

     Follows the ``jax.custom_vjp`` bwd convention of ``(res, g)`` in and one
     cotangent per primal input out -- NOTE(review): confirm the forward
     function really takes four inputs.

     Args:
         res: residuals ``(w, v, a)`` -- eigenvalues, eigenvectors and ``a``.
         g: cotangents ``(grad_w, grad_v)``.

     Returns:
         ``(grad_a, None, None, None)`` with ``grad_a`` passed through
         ``symmetrize``.
     """
     grad_w, grad_v = g
     w, v, a = res
     # Fixed seed: deterministic initial guess for the linear solve.
     key = jax.random.PRNGKey(0)
     start = jax.random.normal(key, shape=v.shape, dtype=v.dtype)
     raw_grad, _ = eigh_partial_rev(grad_w, grad_v, w, v, a, start)
     return symmetrize(raw_grad), None, None, None
コード例 #3
0
 def eigh_partial_rev(res, g):
     """Backward rule for a partial eigendecomposition of a CSR matrix.

     The ``(res, g)`` signature and ``None`` padding follow the
     ``jax.custom_vjp`` bwd convention -- NOTE(review): confirm against the
     forward function.

     Args:
         res: residuals ``(w, v, data, indices, indptr)`` -- eigenvalues,
             eigenvectors and the CSR components of the input matrix.
         g: cotangents ``(grad_w, grad_v)`` for ``w`` and ``v``.

     Returns:
         A 5-tuple: the gradient w.r.t. ``data`` (passed through
         ``csr.symmetrize``) followed by ``None`` for the remaining entries.
     """
     w, v, data, indices, indptr = res
     grad_w, grad_v = g
     # Constant key -> the random CG starting point is deterministic.
     rng_key = jax.random.PRNGKey(0)
     # The initial guess has ``v``'s shape, so give it ``v``'s dtype as well.
     # ``w`` holds eigenvalues, which are real even when the eigenvectors of a
     # Hermitian matrix are complex, so ``w.dtype`` could be the wrong dtype.
     x0 = jax.random.normal(rng_key, shape=v.shape, dtype=v.dtype)
     # The second output (the refined CG solution) is not needed here.
     grad_data, _ = cg.eigh_partial_rev(
         grad_w,
         grad_v,
         w,
         v,
         csr.matmul_fun(data, indices, indptr),
         x0,
         outer_impl=csr.masked_outer_fun(indices, indptr),
     )
     grad_data = csr.symmetrize(grad_data, indices)
     return grad_data, None, None, None, None
コード例 #4
0
 def eigh_partial_rev(res, g):
     """Backward rule for a partial eigendecomposition of a COO matrix.

     The ``(res, g)`` signature and ``None`` padding follow the
     ``jax.custom_vjp`` bwd convention -- NOTE(review): confirm against the
     forward function.

     Args:
         res: residuals ``(w, v, data, row, col)`` -- eigenvalues,
             eigenvectors and the COO components of the input matrix.
         g: cotangents ``(grad_w, grad_v)`` for ``w`` and ``v``.

     Returns:
         A 6-tuple: the gradient w.r.t. ``data`` (passed through
         ``coo.symmetrize``) followed by five ``None`` entries.
         NOTE(review): this is one more entry than ``res`` has; presumably the
         forward takes an extra non-differentiable argument -- confirm.
     """
     w, v, data, row, col = res
     # Matrix dimension, taken from the number of rows of the eigenvectors.
     size = v.shape[0]
     grad_w, grad_v = g
     # Constant key -> the random CG starting point is deterministic.
     rng_key = jax.random.PRNGKey(0)
     # The initial guess has ``v``'s shape, so give it ``v``'s dtype as well.
     # ``w`` holds eigenvalues, which are real even when the eigenvectors of a
     # Hermitian matrix are complex, so ``w.dtype`` could be the wrong dtype.
     x0 = jax.random.normal(rng_key, shape=v.shape, dtype=v.dtype)
     # The second output (the refined CG solution) is not needed here.
     grad_data, _ = cg.eigh_partial_rev(
         grad_w,
         grad_v,
         w,
         v,
         # ``jnp.zeros((size,))`` is passed alongside the COO arrays;
         # presumably only its shape (the matrix size) matters -- TODO confirm.
         coo.matmul_fun(data, row, col, jnp.zeros((size, ))),
         x0,
         outer_impl=coo.masked_outer_fun(row, col),
     )
     grad_data = coo.symmetrize(grad_data, row, col, size)
     return (grad_data, None, None, None, None, None)
コード例 #5
0
 def lobpcg_rev(res, g):
     """Backward rule paired with an LOBPCG-based eigensolver (CSR input).

     Follows the ``jax.custom_vjp`` bwd convention -- NOTE(review): confirm the
     forward function takes six inputs, matching the returned tuple.

     Args:
         res: residuals ``(w, v, data, indices, indptr)``.
         g: cotangents ``(grad_w, grad_v)``.

     Returns:
         ``(grad_data, None, None, None, None, None)`` with ``grad_data``
         passed through ``csr.symmetrize``.
     """
     grad_w, grad_v = g
     w, v, data, indices, indptr = res
     a_op = csr.matmul_fun(data, indices, indptr)
     # Fixed seed: deterministic initial guess for the linear solve.
     key = jax.random.PRNGKey(0)
     start = jax.random.normal(key, shape=v.shape, dtype=v.dtype)
     masked_outer = csr.masked_outer_fun(indices, indptr)
     raw_grad, _ = eigh_partial_rev(grad_w, grad_v, w, v, a_op, start,
                                    outer_impl=masked_outer)
     return csr.symmetrize(raw_grad, indices), None, None, None, None, None
コード例 #6
0
 def lobpcg_rev(res, g):
     """Backward rule paired with an LOBPCG-based eigensolver (COO input).

     Follows the ``jax.custom_vjp`` bwd convention -- NOTE(review): confirm the
     forward function takes six inputs, matching the returned tuple.

     Args:
         res: residuals ``(w, v, data, row, col)``.
         g: cotangents ``(grad_w, grad_v)``.

     Returns:
         ``(grad_data, None, None, None, None, None)`` with ``grad_data``
         passed through ``coo.symmetrize``.
     """
     grad_w, grad_v = g
     w, v, data, row, col = res
     n = v.shape[0]
     # ``jnp.zeros((n,))`` presumably only conveys the matrix size -- TODO confirm.
     a_op = coo.matmul_fun(data, row, col, jnp.zeros((n, )))
     # Fixed seed: deterministic initial guess for the linear solve.
     key = jax.random.PRNGKey(0)
     start = jax.random.normal(key, shape=v.shape, dtype=v.dtype)
     masked_outer = coo.masked_outer_fun(row, col)
     raw_grad, _ = eigh_partial_rev(grad_w, grad_v, w, v, a_op, start,
                                    outer_impl=masked_outer)
     return coo.symmetrize(raw_grad, row, col, n), None, None, None, None, None
コード例 #7
0
def eigh_partial_rev_coo(
    grad_w: jnp.ndarray,
    grad_v: jnp.ndarray,
    w: jnp.ndarray,
    v: jnp.ndarray,
    l0: jnp.ndarray,
    data: jnp.ndarray,
    row: jnp.ndarray,
    col: jnp.ndarray,
    sized: jnp.ndarray,
    tol: float = 1e-5,
):
    """Gradient of a partial eigendecomposition w.r.t. the `data` of a COO matrix.

    Builds the COO matmul and masked-outer-product functions and delegates to
    `cg.eigh_partial_rev`.

    Args:
        grad_w: [k] gradient w.r.t eigenvalues.
        grad_v: [m, k] gradient w.r.t eigenvectors.
        w: [k] eigenvalues.
        v: [m, k] eigenvectors.
        l0: initial solution to least squares problem.
        data, row, col: coo formatted [m, m] matrix (values, row indices,
            column indices).
        sized: forwarded to `coo.matmul_fun` together with `data`, `row` and
            `col`; presumably only its shape (the matrix size `m`) is used --
            TODO confirm.
        tol: tolerance forwarded to `cg.eigh_partial_rev`.

    Returns:
        grad_data: gradient of `data` input, same `shape` and `dtype`.
        solution: solution to least squares problem, as returned by
            `cg.eigh_partial_rev`.
    """
    outer_impl = coo.masked_outer_fun(row, col)
    a = coo.matmul_fun(data, row, col, sized)
    # Note: this cg.eigh_partial_rev call takes the initial solution *before*
    # the matrix/operator argument.
    grad_data, solution = cg.eigh_partial_rev(grad_w,
                                              grad_v,
                                              w,
                                              v,
                                              l0,
                                              a,
                                              outer_impl,
                                              tol=tol)
    return grad_data, solution