def eigh_partial_rev(res, g):
    """Reverse-mode (VJP) rule for a partial symmetric eigendecomposition.

    Args:
        res: forward-pass residuals `(w, v, a)` — eigenvalues, eigenvectors
            and the dense matrix.
        g: cotangents `(grad_w, grad_v)` w.r.t. eigenvalues / eigenvectors.

    Returns:
        3-tuple: symmetrized gradient w.r.t. `a`, then `None` for the
        remaining (non-differentiable) primal inputs.
    """
    grad_w, grad_v = g
    w, v, a = res
    # Fixed seed => deterministic starting point for the iterative solve.
    x0 = jax.random.normal(jax.random.PRNGKey(0), v.shape, dtype=v.dtype)
    grad_a, x0 = cg.eigh_partial_rev(grad_w, grad_v, w, v, a, x0)
    return symmetrize(grad_a), None, None
def lobpcg_rev(res, g):
    """Reverse-mode (VJP) rule for the dense lobpcg forward pass.

    Args:
        res: forward-pass residuals `(w, v, a)`.
        g: cotangents `(grad_w, grad_v)` w.r.t. eigenvalues / eigenvectors.

    Returns:
        4-tuple: symmetrized gradient w.r.t. the matrix argument, then
        `None` for each remaining primal input.
    """
    w, v, a = res
    grad_w, grad_v = g
    # Deterministic initial guess for the least-squares solve.
    key = jax.random.PRNGKey(0)
    guess = jax.random.normal(key, shape=v.shape, dtype=v.dtype)
    grad_a, guess = eigh_partial_rev(grad_w, grad_v, w, v, a, guess)
    return symmetrize(grad_a), None, None, None
def eigh_partial_rev(res, g):
    """Reverse-mode (VJP) rule for a partial eigendecomposition of a CSR matrix.

    Args:
        res: forward-pass residuals `(w, v, data, indices, indptr)` —
            eigenvalues, eigenvectors and the CSR-formatted matrix.
        g: cotangents `(grad_w, grad_v)` w.r.t. eigenvalues / eigenvectors.

    Returns:
        5-tuple: symmetrized gradient w.r.t. `data` (same sparsity), then
        `None` for each remaining primal input.
    """
    grad_w, grad_v = g
    w, v, data, indices, indptr = res
    # Deterministic initial guess for the iterative solve.
    x0 = jax.random.normal(jax.random.PRNGKey(0), shape=v.shape, dtype=w.dtype)
    matvec = csr.matmul_fun(data, indices, indptr)
    # Only the entries on the existing sparsity pattern are accumulated.
    grad_data, x0 = cg.eigh_partial_rev(
        grad_w,
        grad_v,
        w,
        v,
        matvec,
        x0,
        outer_impl=csr.masked_outer_fun(indices, indptr),
    )
    return csr.symmetrize(grad_data, indices), None, None, None, None
def eigh_partial_rev(res, g):
    """Reverse-mode (VJP) rule for a partial eigendecomposition of a COO matrix.

    Args:
        res: forward-pass residuals `(w, v, data, row, col)` — eigenvalues,
            eigenvectors and the COO-formatted matrix.
        g: cotangents `(grad_w, grad_v)` w.r.t. eigenvalues / eigenvectors.

    Returns:
        6-tuple: symmetrized gradient w.r.t. `data` (same sparsity), then
        `None` for each remaining primal input.
    """
    grad_w, grad_v = g
    w, v, data, row, col = res
    size = v.shape[0]
    # Deterministic initial guess for the iterative solve.
    x0 = jax.random.normal(jax.random.PRNGKey(0), shape=v.shape, dtype=w.dtype)
    matvec = coo.matmul_fun(data, row, col, jnp.zeros((size,)))
    # Only the entries on the existing sparsity pattern are accumulated.
    grad_data, x0 = cg.eigh_partial_rev(
        grad_w,
        grad_v,
        w,
        v,
        matvec,
        x0,
        outer_impl=coo.masked_outer_fun(row, col),
    )
    grad_data = coo.symmetrize(grad_data, row, col, size)
    return grad_data, None, None, None, None, None
def lobpcg_rev(res, g):
    """Reverse-mode (VJP) rule for lobpcg on a CSR-formatted matrix.

    Args:
        res: forward-pass residuals `(w, v, data, indices, indptr)`.
        g: cotangents `(grad_w, grad_v)` w.r.t. eigenvalues / eigenvectors.

    Returns:
        6-tuple: symmetrized gradient w.r.t. `data` (same sparsity), then
        `None` for each remaining primal input.
    """
    w, v, data, indices, indptr = res
    grad_w, grad_v = g
    matvec = csr.matmul_fun(data, indices, indptr)
    # Deterministic initial guess for the least-squares solve.
    guess = jax.random.normal(jax.random.PRNGKey(0), shape=v.shape, dtype=v.dtype)
    # Gradient is accumulated only on the existing sparsity pattern.
    grad_data, guess = eigh_partial_rev(
        grad_w,
        grad_v,
        w,
        v,
        matvec,
        guess,
        outer_impl=csr.masked_outer_fun(indices, indptr),
    )
    return csr.symmetrize(grad_data, indices), None, None, None, None, None
def lobpcg_rev(res, g):
    """Reverse-mode (VJP) rule for lobpcg on a COO-formatted matrix.

    Args:
        res: forward-pass residuals `(w, v, data, row, col)`.
        g: cotangents `(grad_w, grad_v)` w.r.t. eigenvalues / eigenvectors.

    Returns:
        6-tuple: symmetrized gradient w.r.t. `data` (same sparsity), then
        `None` for each remaining primal input.
    """
    w, v, data, row, col = res
    grad_w, grad_v = g
    size = v.shape[0]
    matvec = coo.matmul_fun(data, row, col, jnp.zeros((size,)))
    # Deterministic initial guess for the least-squares solve.
    guess = jax.random.normal(jax.random.PRNGKey(0), shape=v.shape, dtype=v.dtype)
    # Gradient is accumulated only on the existing sparsity pattern.
    grad_data, guess = eigh_partial_rev(
        grad_w,
        grad_v,
        w,
        v,
        matvec,
        guess,
        outer_impl=coo.masked_outer_fun(row, col),
    )
    grad_data = coo.symmetrize(grad_data, row, col, size)
    return grad_data, None, None, None, None, None
def eigh_partial_rev_coo(
    grad_w: jnp.ndarray,
    grad_v: jnp.ndarray,
    w: jnp.ndarray,
    v: jnp.ndarray,
    l0: jnp.ndarray,
    data: jnp.ndarray,
    row: jnp.ndarray,
    col: jnp.ndarray,
    sized: jnp.ndarray,
    tol: float = 1e-5,
):
    """Reverse-mode gradient of a partial eigendecomposition of a COO matrix.

    Args:
        grad_w: [k] gradient w.r.t eigenvalues.
        grad_v: [m, k] gradient w.r.t eigenvectors.
        w: [k] eigenvalues.
        v: [m, k] eigenvectors.
        l0: initial solution to least squares problem.
        data, row, col, sized: coo formatted [m, m] matrix.
        tol: tolerance forwarded to the conjugate-gradient solver.

    Returns:
        grad_data: gradient of `data` input, same `shape` and `dtype`,
            restricted to the existing sparsity pattern.
        l0: updated solution to the least squares problem (can be used to
            warm-start a subsequent call).
    """
    # Accumulate the outer-product contributions only on (row, col) entries.
    outer_impl = coo.masked_outer_fun(row, col)
    a = coo.matmul_fun(data, row, col, sized)
    # NOTE(review): positional order here is (l0, a); other call sites in this
    # file pass (matvec, x0) — presumably different `cg.eigh_partial_rev`
    # versions; confirm against the `cg` module before unifying.
    grad_data, l0 = cg.eigh_partial_rev(grad_w, grad_v, w, v, l0, a,
                                        outer_impl, tol=tol)
    return grad_data, l0