Exemple #1
0
    def test_eigh_general(self):
        """Check ``eigh_general`` on a generalized problem ``A v = w B v``.

        Both ``A`` and ``B`` are random SPD matrices (different seeds). The test
        verifies the eigen-equation residual and that the returned eigenvectors
        are orthonormal under the ``B`` inner product (``v.T B v == I``).
        """
        size = 50
        precision = np.float64
        want_largest = False
        lhs = random_spd_coo(size, dtype=precision).todense()
        rhs = random_spd_coo(size, dtype=precision, seed=1).todense()

        eigvals, eigvecs = eigh_general(lhs, rhs, want_largest)

        # A v should equal B v scaled column-wise by the eigenvalues.
        applied_lhs = lhs @ eigvecs
        applied_rhs = (rhs @ eigvecs) * eigvals
        self.assertAllClose(applied_lhs, applied_rhs, rtol=1e-6)

        # B-orthonormality of the eigenvector basis.
        b_gram = eigvecs.T @ rhs @ eigvecs
        self.assertAllClose(b_gram, jnp.eye(size), rtol=1e-6, atol=1e-8)
Exemple #2
0
    def test_eigh_vjp(self):
        """Check a hand-written VJP for dense ``eigh`` against finite differences.

        The primal wraps ``jax.numpy.linalg.eigh`` and standardizes eigenvector
        signs. The backward pass delegates to ``cg.eigh_rev`` and symmetrizes
        the resulting matrix cotangent.
        """
        size = 20
        precision = np.float64
        mat = random_spd_coo(n=size, dtype=precision).todense()

        def eigh(a):
            eigvals, eigvecs = jax.numpy.linalg.eigh(a)
            # Sign-standardize (presumably so finite-difference checks see a
            # consistent choice of eigenvector sign).
            return eigvals, standardize_eigenvector_signs(eigvecs)

        def eigh_fwd(a):
            eigvals, eigvecs = eigh(a)
            # The eigenpairs themselves are the only residuals the VJP needs.
            return (eigvals, eigvecs), (eigvals, eigvecs)

        def eigh_rev(res, g):
            eigvals, eigvecs = res
            grad_eigvals, grad_eigvecs = g
            grad_mat = symmetrize(
                cg.eigh_rev(grad_eigvals, grad_eigvecs, eigvals, eigvecs))
            return (grad_mat, )

        eigh_fun = jax.custom_vjp(eigh)
        eigh_fun.defvjp(eigh_fwd, eigh_rev)
        jtu.check_grads(eigh_fun, (mat, ), order=1, modes="rev", rtol=1e-3)

        # Sanity-check the primal: A v == v diag(w).
        eigvals, eigvecs = eigh(mat)
        self.assertAllClose(mat @ eigvecs, eigvecs * eigvals, rtol=1e-6)
Exemple #3
0
    def test_eigh_partial_vjp(self):
        """Check a hand-written VJP for ``eigh_partial`` on a dense SPD matrix.

        Only the matrix argument is differentiated; ``k`` and ``largest`` are
        fixed via ``functools.partial`` and receive ``None`` cotangents.
        """
        precision = np.float64
        size = 20
        num_eigs = 4
        want_largest = False
        mat = random_spd_coo(size, dtype=precision).todense()

        def eigh_partial_fwd(a, k: int, largest: bool):
            eigvals, eigvecs = eigh_partial(a, k, largest)
            # Keep the input matrix too: the backward pass passes it to
            # cg.eigh_partial_rev.
            return (eigvals, eigvecs), (eigvals, eigvecs, a)

        def eigh_partial_rev(res, g):
            eigvals, eigvecs, a = res
            grad_eigvals, grad_eigvecs = g
            # Deterministic initial guess handed to cg.eigh_partial_rev.
            guess = jax.random.normal(jax.random.PRNGKey(0),
                                      eigvecs.shape,
                                      dtype=eigvecs.dtype)
            grad_a, _ = cg.eigh_partial_rev(grad_eigvals, grad_eigvecs,
                                            eigvals, eigvecs, a, guess)
            return (symmetrize(grad_a), None, None)

        eigh_partial_fun = jax.custom_vjp(eigh_partial)
        eigh_partial_fun.defvjp(eigh_partial_fwd, eigh_partial_rev)

        checked_fun = partial(eigh_partial_fun, k=num_eigs, largest=want_largest)
        jtu.check_grads(checked_fun, (mat, ), 1, modes=["rev"], rtol=1e-3)
Exemple #4
0
    def test_lobpcg(self):
        """Compare ``lobpcg`` against a dense reference for several operator forms.

        The same SPD matrix is supplied as a dense array, a COO matmul closure
        and a CSR matmul closure. Each run must reproduce the first ``k``
        eigenpairs returned by ``eigh_general(A, None, False)``.
        """
        m = 50
        k = 10
        dtype = np.float64
        A = random_spd_coo(m, dtype=dtype)
        rng = np.random.default_rng(0)
        X0 = rng.uniform(size=(m, k)).astype(dtype)

        B = None
        iK = None

        E_expected, X_expected = eigh_general(A.todense(), B, False)
        E_expected = E_expected[:k]
        # Eigenvectors are only defined up to sign. Normalize the reference the
        # same way as the solver output below so the comparison does not depend
        # on which sign each solver happened to pick.
        X_expected = standardize_eigenvector_signs(X_expected[:, :k])

        data, row, col, shape = coo_components(A.tocoo())
        A_coo = coo.matmul_fun(
            jnp.asarray(data),
            jnp.asarray(row),
            jnp.asarray(col),
            jnp.zeros((shape[0], )),
        )
        data, indices, indptr, _ = csr_components(A.tocsr())
        A_csr = csr.matmul_fun(jnp.asarray(data), jnp.asarray(indices),
                               jnp.asarray(indptr))
        for A_fun in (jnp.asarray(A.todense()), A_coo, A_csr):
            E_actual, X_actual = lobpcg(A=A_fun,
                                        B=B,
                                        X0=X0,
                                        iK=iK,
                                        largest=False,
                                        max_iters=200)
            X_actual = standardize_eigenvector_signs(X_actual)
            self.assertAllClose(E_expected, E_actual, rtol=1e-8, atol=1e-10)
            self.assertAllClose(X_expected, X_actual, rtol=1e-4, atol=1e-10)
Exemple #5
0
    def test_eigh_partial_coo_vjp(self):
        """Check a hand-written VJP for partial ``eigh`` of a COO matrix.

        The primal symmetrizes the COO ``data``, densifies it, and takes ``k``
        eigenpairs. The gradient is taken w.r.t. ``data`` only. The backward
        pass uses ``cg.eigh_partial_rev`` with a masked outer product over
        ``(row, col)``.
        """
        dtype = np.float64
        n = 20
        k = 4
        largest = False
        a = random_spd_coo(n, dtype=dtype)

        def eigh_partial_coo(data, row, col, size, k: int, largest: bool):
            data = coo.symmetrize(data, row, col, size)
            a = coo.to_dense(data, row, col, (size, size))
            w, v = eigh_partial(a, k, largest)
            v = standardize_eigenvector_signs(v)
            return w, v

        def eigh_partial_fwd(data, row, col, size, k: int, largest: bool):
            w, v = eigh_partial_coo(data, row, col, size, k, largest)
            return (w, v), (w, v, data, row, col)

        def eigh_partial_rev(res, g):
            w, v, data, row, col = res
            size = v.shape[0]
            grad_w, grad_v = g
            # (w, v) are eigenpairs of the *symmetrized* matrix built in the
            # forward pass. Rebuild that same operator here, not one from the
            # raw residual data.
            sym_data = coo.symmetrize(data, row, col, size)
            rng_key = jax.random.PRNGKey(0)
            x0 = jax.random.normal(rng_key, shape=v.shape, dtype=w.dtype)
            grad_data, _ = cg.eigh_partial_rev(
                grad_w,
                grad_v,
                w,
                v,
                coo.matmul_fun(sym_data, row, col, jnp.zeros((size, ))),
                x0,
                outer_impl=coo.masked_outer_fun(row, col),
            )
            grad_data = coo.symmetrize(grad_data, row, col, size)
            # Only `data` is differentiable; the remaining args get None.
            return (grad_data, None, None, None, None, None)

        eigh_partial_fn = jax.custom_vjp(eigh_partial_coo)
        eigh_partial_fn.defvjp(eigh_partial_fwd, eigh_partial_rev)

        data, row, col, _ = coo_components(a)
        self.assertTrue(coo.is_symmetric(row, col, data))
        self.assertTrue(coo.is_ordered(row, col))
        jtu.check_grads(
            partial(eigh_partial_fn,
                    k=k,
                    largest=largest,
                    row=row,
                    col=col,
                    size=n),
            (data, ),
            1,
            modes=["rev"],
            rtol=1e-3,
        )
Exemple #6
0
    def test_lobpcg_coo_vjp(self):
        """Check a hand-written VJP for ``lobpcg`` on a COO matrix w.r.t. its data.

        The primal symmetrizes ``data`` and runs ``lobpcg`` on a COO matmul
        closure. The backward pass delegates to ``eigh_partial_rev`` with
        ``coo.masked_outer_fun(row, col)`` as ``outer_impl`` and symmetrizes
        the resulting data cotangent.
        """
        m = 50
        k = 10
        largest = False
        dtype = np.float64

        def lobpcg_coo(data, row, col, X0, largest, k):
            size = X0.shape[0]
            data = coo.symmetrize(data, row, col, size)
            A = coo.matmul_fun(data, row, col, jnp.zeros((size, )))
            w, v = lobpcg(A, X0, largest=largest, k=k)
            v = standardize_eigenvector_signs(v)
            return w, v

        def lobpcg_fwd(data, row, col, X0, largest, k):
            w, v = lobpcg_coo(data, row, col, X0, largest, k)
            return (w, v), (w, v, data, row, col)

        def lobpcg_rev(res, g):
            grad_w, grad_v = g
            w, v, data, row, col = res
            size = v.shape[0]
            # (w, v) were computed from the symmetrized data in the forward
            # pass, so build the backward operator from that same data.
            sym_data = coo.symmetrize(data, row, col, size)
            A = coo.matmul_fun(sym_data, row, col, jnp.zeros((size, )))
            x0 = jax.random.normal(jax.random.PRNGKey(0),
                                   shape=v.shape,
                                   dtype=v.dtype)
            grad_data, _ = eigh_partial_rev(grad_w,
                                            grad_v,
                                            w,
                                            v,
                                            A,
                                            x0,
                                            outer_impl=coo.masked_outer_fun(
                                                row, col))
            grad_data = coo.symmetrize(grad_data, row, col, size)
            # Only `data` is differentiable; the remaining args get None.
            return grad_data, None, None, None, None, None

        lobpcg_fun = jax.custom_vjp(lobpcg_coo)
        lobpcg_fun.defvjp(lobpcg_fwd, lobpcg_rev)

        rng = np.random.default_rng(0)
        A = random_spd_coo(m, sparsity=0.1, dtype=dtype)
        data, row, col, _ = coo_components(A)

        X0 = rng.uniform(size=(m, k)).astype(dtype)
        jtu.check_grads(
            partial(lobpcg_fun, row=row, col=col, X0=X0, largest=largest, k=k),
            (data, ),
            order=1,
            modes=["rev"],
            rtol=2e-3,
        )
Exemple #7
0
    def test_lobpcg_vjp(self):
        """Check a hand-written VJP for ``lobpcg`` on a dense matrix.

        The primal symmetrizes ``A`` before calling ``lobpcg``. The backward
        pass delegates to ``eigh_partial_rev`` and symmetrizes the resulting
        matrix cotangent. Only ``A`` is differentiated.
        """
        m = 50
        k = 10
        largest = False
        dtype = np.float64

        def lobpcg_simple(A, X0, largest, k):
            A = symmetrize(A)
            w, v = lobpcg(A, X0, largest=largest, k=k)
            v = standardize_eigenvector_signs(v)
            return w, v

        def lobpcg_fwd(A, X0, largest, k):
            w, v = lobpcg_simple(A, X0, largest, k)
            return (w, v), (w, v, A)

        def lobpcg_rev(res, g):
            grad_w, grad_v = g
            w, v, a = res
            # (w, v) were computed from symmetrize(A) in the forward pass; use
            # that same matrix in the backward pass, not the raw residual.
            a = symmetrize(a)
            x0 = jax.random.normal(jax.random.PRNGKey(0),
                                   shape=v.shape,
                                   dtype=v.dtype)
            grad_a, _ = eigh_partial_rev(grad_w, grad_v, w, v, a, x0)
            grad_a = symmetrize(grad_a)
            # Only `A` is differentiable; X0, largest and k get None.
            return grad_a, None, None, None

        lobpcg_fun = jax.custom_vjp(lobpcg_simple)
        lobpcg_fun.defvjp(lobpcg_fwd, lobpcg_rev)

        A = random_spd_coo(m, dtype=dtype).todense()
        rng = np.random.default_rng(0)
        X0 = rng.uniform(size=(m, k)).astype(dtype)
        jtu.check_grads(
            partial(lobpcg_fun, X0=X0, largest=largest, k=k),
            (A, ),
            order=1,
            modes=["rev"],
            rtol=1e-3,
        )
Exemple #8
0
    def test_eigh_coo_vjp(self):
        """Check a hand-written VJP for full ``eigh`` of a COO matrix.

        The primal symmetrizes and densifies the COO ``data`` before calling
        ``jnp.linalg.eigh``. The backward pass calls ``cg.eigh_rev`` with
        ``coo.masked_matmul_fun(row, col)`` and symmetrizes the data cotangent.
        """
        size_n = 20
        precision = np.float64
        sparse_mat = random_spd_coo(n=size_n, dtype=precision)

        def eigh_coo(data, row, col, size):
            sym_data = coo.symmetrize(data, row, col, size)
            dense = coo.to_dense(sym_data, row, col, (size, size))
            eigvals, eigvecs = jnp.linalg.eigh(dense)
            return eigvals, standardize_eigenvector_signs(eigvecs)

        def eigh_coo_fwd(data, row, col, size):
            eigvals, eigvecs = eigh_coo(data, row, col, size)
            # row/col are kept so the backward pass can build the masked
            # product and symmetrize the result on the same pattern.
            return (eigvals, eigvecs), (eigvals, eigvecs, row, col)

        def eigh_coo_rev(res, g):
            eigvals, eigvecs, row, col = res
            grad_eigvals, grad_eigvecs = g
            size = eigvecs.shape[0]
            grad_data = cg.eigh_rev(grad_eigvals, grad_eigvecs, eigvals,
                                    eigvecs, coo.masked_matmul_fun(row, col))
            sym_grad = coo.symmetrize(grad_data, row, col, size)
            return (sym_grad, None, None, None)

        eigh_fn = jax.custom_vjp(eigh_coo)
        eigh_fn.defvjp(eigh_coo_fwd, eigh_coo_rev)

        data, row, col, shape = coo_components(sparse_mat)
        self.assertTrue(coo.is_symmetric(row, col, data, shape))
        jtu.check_grads(
            partial(eigh_fn, row=row, col=col, size=size_n),
            (data, ),
            order=1,
            modes="rev",
            rtol=1e-3,
        )