コード例 #1
0
ファイル: sparse_ops_test.py プロジェクト: zhangqiaorjc/jax
    def test_coo_matvec_ad(self, shape, dtype, bshape):
        """Check JVP/VJP of ``sparse_ops.coo_matvec`` against a dense reference.

        The derivatives are compared twice: once with respect to the dense
        operand ``x`` and once with respect to the stored values ``data`` of
        the COO representation.

        Args:
          shape: shape handed to the sparse random generator to build ``M``.
          dtype: dtype used for both the matrix and the dense operand.
          bshape: shape of the dense operand that ``M`` is multiplied with.
        """
        # Tolerances keyed by dtype; the same mapping serves as atol and rtol.
        tol = {
            np.float32: 1E-6,
            np.float64: 1E-13,
            np.complex64: 1E-6,
            np.complex128: 1E-13
        }

        rng = rand_sparse(self.rng(), post=jnp.array)
        rng_b = jtu.rand_default(self.rng())

        M = rng(shape, dtype)
        data, row, col = sparse_ops.coo_fromdense(M, nnz=(M != 0).sum())
        x = rng_b(bshape, dtype)
        xdot = rng_b(bshape, dtype)

        def check_ad(f_sparse, f_dense, primal, tangent):
            # Forward-mode: primal outputs and tangents must agree.
            v_sparse, t_sparse = api.jvp(f_sparse, [primal], [tangent])
            v_dense, t_dense = api.jvp(f_dense, [primal], [tangent])
            self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
            self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

            # Reverse-mode: each implementation's own primal output is fed
            # back in as the cotangent.
            primals_dense, vjp_dense = api.vjp(f_dense, primal)
            primals_sparse, vjp_sparse = api.vjp(f_sparse, primal)
            out_dense, = vjp_dense(primals_dense)
            out_sparse, = vjp_sparse(primals_sparse)
            # Compare the whole primal output rather than only its first
            # entry, so a mismatch anywhere in the result is detected.
            self.assertAllClose(primals_dense,
                                primals_sparse,
                                atol=tol,
                                rtol=tol)
            self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)

        # Differentiate with respect to the vector
        check_ad(
            lambda x: sparse_ops.coo_matvec(data, row, col, x, shape=M.shape),
            lambda x: M @ x,
            x,
            xdot)

        # Differentiate with respect to nonzero elements of the matrix.
        # Fresh values and tangents are drawn with the same length and dtype
        # as the stored data; the index arrays row/col are kept fixed.
        new_data = rng((len(data), ), data.dtype)
        data_dot = rng((len(new_data), ), new_data.dtype)
        check_ad(
            lambda data: sparse_ops.coo_matvec(
                data, row, col, x, shape=M.shape),
            lambda data: sparse_ops.coo_todense(
                data, row, col, shape=M.shape) @ x,
            new_data,
            data_dot)
コード例 #2
0
  def test_coo_todense(self, shape, dtype):
    """Check ``sparse_ops.coo_todense`` against the matrix's own ``toarray()``.

    A random sparse matrix is built via ``sparse.coo_matrix`` and densified
    both eagerly and under ``jit``; each result must equal ``toarray()``.
    """
    make_sparse = rand_sparse(self.rng(), post=sparse.coo_matrix)
    M = make_sparse(shape, dtype)

    def densify(*coo_args):
      # shape is closed over so only the COO buffers are passed through jit.
      return sparse_ops.coo_todense(*coo_args, shape=M.shape)

    coo_buffers = (M.data, M.row, M.col)
    expected = M.toarray()

    self.assertArraysEqual(expected, densify(*coo_buffers))
    self.assertArraysEqual(expected, jit(densify)(*coo_buffers))
コード例 #3
0
  def test_coo_todense_ad(self, shape, dtype):
    """Check forward- and reverse-mode derivatives of ``coo_todense``.

    Differentiation is with respect to the stored values ``data``; the
    index arrays ``row`` and ``col`` are held fixed.
    """
    make_sparse = rand_sparse(self.rng(), post=jnp.array)
    M = make_sparse(shape, dtype)
    nnz = (M != 0).sum()
    data, row, col = sparse_ops.coo_fromdense(M, nnz=nnz)

    def densify(values):
      return sparse_ops.coo_todense(values, row, col, shape=M.shape)

    # Forward-mode: an all-ones tangent on the stored values should place
    # ones exactly at the (row, col) positions of the dense result.
    unit_tangent = jnp.ones_like(data)
    dense_out, dense_tangent = api.jvp(densify, [data], [unit_tangent])
    self.assertArraysEqual(dense_out, densify(data))
    expected_tangent = jnp.zeros_like(M).at[row, col].set(1)
    self.assertArraysEqual(dense_tangent, expected_tangent)

    # Reverse-mode: using the dense output itself as the cotangent should
    # hand back the stored values.
    dense_out, pullback = api.vjp(densify, data)
    (data_cotangent,) = pullback(dense_out)
    self.assertArraysEqual(dense_out, densify(data))
    self.assertArraysEqual(data_cotangent, data)