def test_eigh_general(self):
    m = 50
    dtype = np.float64
    largest = False
    A = random_spd_coo(m, dtype=dtype).todense()
    B = random_spd_coo(m, dtype=dtype, seed=1).todense()
    w, v = eigh_general(A, B, largest)
    # Generalized eigenproblem: A v = B v diag(w), with B-orthonormal v.
    self.assertAllClose(A @ v, B @ v * w, rtol=1e-6)
    self.assertAllClose(v.T @ B @ v, jnp.eye(m), rtol=1e-6, atol=1e-8)

def test_eigh_vjp(self):
    n = 20
    dtype = np.float64
    a = random_spd_coo(n=n, dtype=dtype)
    a = a.todense()

    def eigh(a):
        w, v = jnp.linalg.eigh(a)
        v = standardize_eigenvector_signs(v)
        return w, v

    def eigh_fwd(a):
        w, v = eigh(a)
        return (w, v), (w, v)

    def eigh_rev(res, g):
        grad_w, grad_v = g
        w, v = res
        grad_a = cg.eigh_rev(grad_w, grad_v, w, v)
        grad_a = symmetrize(grad_a)
        return (grad_a,)

    eigh_fun = jax.custom_vjp(eigh)
    eigh_fun.defvjp(eigh_fwd, eigh_rev)

    jtu.check_grads(eigh_fun, (a,), order=1, modes=["rev"], rtol=1e-3)
    w, v = eigh(a)
    self.assertAllClose(a @ v, v * w, rtol=1e-6)

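# A minimal reference sketch (not the library code) of the standard dense eigh
# VJP that `cg.eigh_rev` is presumed to implement, assuming distinct
# eigenvalues; the name `_eigh_rev_reference` is hypothetical. For
# A = v diag(w) v^T, the rule is
#   grad_A = v (diag(grad_w) + F * (v^T grad_v)) v^T,
# where F[i, j] = 1 / (w[j] - w[i]) off the diagonal and 0 on it.
def _eigh_rev_reference(grad_w, grad_v, w, v):
    w_diff = w[None, :] - w[:, None]
    # Zero the diagonal explicitly; the guarded divide avoids nan from 0 / 0.
    F = jnp.where(
        jnp.eye(w.size, dtype=bool), 0.0, 1.0 / jnp.where(w_diff == 0, 1.0, w_diff)
    )
    return v @ (jnp.diag(grad_w) + F * (v.T @ grad_v)) @ v.T
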
def test_eigh_partial_vjp(self):
    dtype = np.float64
    n = 20
    k = 4
    largest = False
    a = random_spd_coo(n, dtype=dtype).todense()

    def eigh_partial_fwd(a, k: int, largest: bool):
        w, v = eigh_partial(a, k, largest)
        return (w, v), (w, v, a)

    def eigh_partial_rev(res, g):
        w, v, a = res
        grad_w, grad_v = g
        rng_key = jax.random.PRNGKey(0)
        x0 = jax.random.normal(rng_key, v.shape, dtype=v.dtype)
        grad_a, x0 = cg.eigh_partial_rev(grad_w, grad_v, w, v, a, x0)
        grad_a = symmetrize(grad_a)
        return (grad_a, None, None)

    eigh_partial_fun = jax.custom_vjp(eigh_partial)
    eigh_partial_fun.defvjp(eigh_partial_fwd, eigh_partial_rev)

    jtu.check_grads(
        partial(eigh_partial_fun, k=k, largest=largest),
        (a,),
        1,
        modes=["rev"],
        rtol=1e-3,
    )

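# A dense reference sketch (an assumption, not the library code) of the
# partial-eigendecomposition VJP that `cg.eigh_partial_rev` approximates with
# CG solves; `x0` above is the warm-start iterate it threads through. With
# only k < n eigenpairs available, the eigenvector gradient needs
# minimum-norm solves of the singular shifted systems
#   (A - w[i] I) z_i = -(I - v_i v_i^T) grad_v[:, i].
# The name `_eigh_partial_rev_reference` is hypothetical.
def _eigh_partial_rev_reference(grad_w, grad_v, w, v, a):
    n, k = v.shape
    grad_a = v @ jnp.diag(grad_w) @ v.T  # eigenvalue contribution
    for i in range(k):
        vi = v[:, i]
        proj = jnp.eye(n) - jnp.outer(vi, vi)  # projector onto vi's complement
        rhs = -(proj @ grad_v[:, i])
        zi, *_ = jnp.linalg.lstsq(a - w[i] * jnp.eye(n), rhs)
        grad_a = grad_a + jnp.outer(proj @ zi, vi)
    return grad_a
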
def test_lobpcg(self):
    m = 50
    k = 10
    dtype = np.float64
    A = random_spd_coo(m, dtype=dtype)
    rng = np.random.default_rng(0)
    X0 = rng.uniform(size=(m, k)).astype(dtype)
    B = None
    iK = None
    # Reference spectrum from the dense general solver.
    E_expected, X_expected = eigh_general(A.todense(), B, False)
    E_expected = E_expected[:k]
    X_expected = X_expected[:, :k]

    data, row, col, shape = coo_components(A.tocoo())
    A_coo = coo.matmul_fun(
        jnp.asarray(data),
        jnp.asarray(row),
        jnp.asarray(col),
        jnp.zeros((shape[0],)),
    )
    data, indices, indptr, _ = csr_components(A.tocsr())
    A_csr = csr.matmul_fun(
        jnp.asarray(data), jnp.asarray(indices), jnp.asarray(indptr)
    )

    # lobpcg should agree with the dense solver for dense, COO and CSR operators.
    for A_fun in (jnp.asarray(A.todense()), A_coo, A_csr):
        E_actual, X_actual = lobpcg(
            A=A_fun, B=B, X0=X0, iK=iK, largest=False, max_iters=200
        )
        X_actual = standardize_eigenvector_signs(X_actual)
        self.assertAllClose(E_expected, E_actual, rtol=1e-8, atol=1e-10)
        self.assertAllClose(X_expected, X_actual, rtol=1e-4, atol=1e-10)

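# A plausible sketch of what `standardize_eigenvector_signs` does (an
# assumption about its behaviour; `_standardize_signs_reference` is a
# hypothetical name): eigenvectors are only defined up to sign, so flip each
# column to make its largest-magnitude entry positive, which makes the
# comparisons against reference eigenvectors above deterministic.
def _standardize_signs_reference(v):
    idx = jnp.argmax(jnp.abs(v), axis=0)
    signs = jnp.sign(v[idx, jnp.arange(v.shape[1])])
    return v * signs
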
def test_eigh_partial_coo_vjp(self):
    dtype = np.float64
    n = 20
    k = 4
    largest = False
    a = random_spd_coo(n, dtype=dtype)

    def eigh_partial_coo(data, row, col, size, k: int, largest: bool):
        data = coo.symmetrize(data, row, col, size)
        a = coo.to_dense(data, row, col, (size, size))
        w, v = eigh_partial(a, k, largest)
        v = standardize_eigenvector_signs(v)
        return w, v

    def eigh_partial_fwd(data, row, col, size, k: int, largest: bool):
        w, v = eigh_partial_coo(data, row, col, size, k, largest)
        return (w, v), (w, v, data, row, col)

    def eigh_partial_rev(res, g):
        w, v, data, row, col = res
        size = v.shape[0]
        grad_w, grad_v = g
        rng_key = jax.random.PRNGKey(0)
        x0 = jax.random.normal(rng_key, shape=v.shape, dtype=w.dtype)
        grad_data, x0 = cg.eigh_partial_rev(
            grad_w,
            grad_v,
            w,
            v,
            coo.matmul_fun(data, row, col, jnp.zeros((size,))),
            x0,
            outer_impl=coo.masked_outer_fun(row, col),
        )
        grad_data = coo.symmetrize(grad_data, row, col, size)
        return (grad_data, None, None, None, None, None)

    eigh_partial_fn = jax.custom_vjp(eigh_partial_coo)
    eigh_partial_fn.defvjp(eigh_partial_fwd, eigh_partial_rev)

    data, row, col, _ = coo_components(a)
    self.assertTrue(coo.is_symmetric(row, col, data))
    self.assertTrue(coo.is_ordered(row, col))
    jtu.check_grads(
        partial(eigh_partial_fn, k=k, largest=largest, row=row, col=col, size=n),
        (data,),
        1,
        modes=["rev"],
        rtol=1e-3,
    )

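# A sketch of what `coo.masked_outer_fun(row, col)` is assumed to provide
# (`_masked_outer_reference` is a hypothetical name): the gradient of a
# sparse operator only needs the outer product x @ y.T at the stored
# (row, col) positions, so the dense n x n outer product is never formed.
def _masked_outer_reference(row, col):
    def outer(x, y):
        # x, y: (n,) vectors or (n, k) column blocks; result has shape (nnz,).
        prod = x[row] * y[col]
        return prod if prod.ndim == 1 else prod.sum(axis=-1)

    return outer
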
def test_lobpcg_coo_vjp(self):
    m = 50
    k = 10
    largest = False
    dtype = np.float64

    def lobpcg_coo(data, row, col, X0, largest, k):
        size = X0.shape[0]
        data = coo.symmetrize(data, row, col, size)
        A = coo.matmul_fun(data, row, col, jnp.zeros((size,)))
        w, v = lobpcg(A, X0, largest=largest, k=k)
        v = standardize_eigenvector_signs(v)
        return w, v

    def lobpcg_fwd(data, row, col, X0, largest, k):
        w, v = lobpcg_coo(data, row, col, X0, largest, k)
        return (w, v), (w, v, data, row, col)

    def lobpcg_rev(res, g):
        grad_w, grad_v = g
        w, v, data, row, col = res
        size = v.shape[0]
        A = coo.matmul_fun(data, row, col, jnp.zeros((size,)))
        x0 = jax.random.normal(jax.random.PRNGKey(0), shape=v.shape, dtype=v.dtype)
        grad_data, x0 = cg.eigh_partial_rev(
            grad_w, grad_v, w, v, A, x0, outer_impl=coo.masked_outer_fun(row, col)
        )
        grad_data = coo.symmetrize(grad_data, row, col, size)
        return grad_data, None, None, None, None, None

    lobpcg_fun = jax.custom_vjp(lobpcg_coo)
    lobpcg_fun.defvjp(lobpcg_fwd, lobpcg_rev)

    rng = np.random.default_rng(0)
    A = random_spd_coo(m, sparsity=0.1, dtype=dtype)
    data, row, col, _ = coo_components(A)
    X0 = rng.uniform(size=(m, k)).astype(dtype)
    jtu.check_grads(
        partial(lobpcg_fun, row=row, col=col, X0=X0, largest=largest, k=k),
        (data,),
        order=1,
        modes=["rev"],
        rtol=2e-3,
    )

def test_lobpcg_vjp(self):
    m = 50
    k = 10
    largest = False
    dtype = np.float64

    def lobpcg_simple(A, X0, largest, k):
        A = symmetrize(A)
        w, v = lobpcg(A, X0, largest=largest, k=k)
        v = standardize_eigenvector_signs(v)
        return w, v

    def lobpcg_fwd(A, X0, largest, k):
        w, v = lobpcg_simple(A, X0, largest, k)
        return (w, v), (w, v, A)

    def lobpcg_rev(res, g):
        grad_w, grad_v = g
        w, v, a = res
        x0 = jax.random.normal(jax.random.PRNGKey(0), shape=v.shape, dtype=v.dtype)
        grad_a, x0 = cg.eigh_partial_rev(grad_w, grad_v, w, v, a, x0)
        grad_a = symmetrize(grad_a)
        return grad_a, None, None, None

    lobpcg_fun = jax.custom_vjp(lobpcg_simple)
    lobpcg_fun.defvjp(lobpcg_fwd, lobpcg_rev)

    A = random_spd_coo(m, dtype=dtype).todense()
    rng = np.random.default_rng(0)
    X0 = rng.uniform(size=(m, k)).astype(dtype)
    jtu.check_grads(
        partial(lobpcg_fun, X0=X0, largest=largest, k=k),
        (A,),
        order=1,
        modes=["rev"],
        rtol=1e-3,
    )

def test_eigh_coo_vjp(self):
    n = 20
    dtype = np.float64
    a = random_spd_coo(n=n, dtype=dtype)

    def eigh_coo(data, row, col, size):
        data = coo.symmetrize(data, row, col, size)
        a = coo.to_dense(data, row, col, (size, size))
        w, v = jnp.linalg.eigh(a)
        v = standardize_eigenvector_signs(v)
        return w, v

    def eigh_coo_fwd(data, row, col, size):
        w, v = eigh_coo(data, row, col, size)
        return (w, v), (w, v, row, col)

    def eigh_coo_rev(res, g):
        grad_w, grad_v = g
        w, v, row, col = res
        size = v.shape[0]
        grad_data = cg.eigh_rev(grad_w, grad_v, w, v, coo.masked_matmul_fun(row, col))
        grad_data = coo.symmetrize(grad_data, row, col, size)
        return (grad_data, None, None, None)

    eigh = jax.custom_vjp(eigh_coo)
    eigh.defvjp(eigh_coo_fwd, eigh_coo_rev)

    data, row, col, shape = coo_components(a)
    self.assertTrue(coo.is_symmetric(row, col, data, shape))
    jtu.check_grads(
        partial(eigh, row=row, col=col, size=n),
        (data,),
        order=1,
        modes=["rev"],
        rtol=1e-3,
    )

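# A sketch of `coo.symmetrize` under the assumptions checked in these tests
# via `coo.is_symmetric` / `coo.is_ordered` (symmetric sparsity pattern,
# row-major entry order); `_coo_symmetrize_reference` is a hypothetical name.
# Averaging each entry with its transposed counterpart yields the data of
# (A + A.T) / 2 on the same pattern.
def _coo_symmetrize_reference(data, row, col, size):
    del size  # pattern symmetry makes the explicit size unnecessary here
    # Row-major order of the transpose = sort entries by (col, then row).
    perm = jnp.lexsort((row, col))
    return 0.5 * (data + data[perm])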