Ejemplo n.º 1
0
    def test_coo_matvec_ad(self, shape, dtype, bshape):
        """Check autodiff of ``sparse_ops.coo_matvec`` against a dense reference.

        Forward-mode (``api.jvp``) and reverse-mode (``api.vjp``) results of the
        COO matvec are compared with those of the equivalent dense product.
        This is done first with respect to the right-hand operand ``x``, then
        with respect to the stored nonzero values ``data``.

        Args:
          shape: shape of the dense matrix ``M`` that is generated.
          dtype: element dtype used for ``M`` and ``x``.
          bshape: shape of the right-hand operand ``x`` (presumably supplied by
            a parameterization decorator outside this view -- TODO confirm).
        """
        # Per-dtype tolerance table, passed as both atol and rtol below.
        tol = {
            np.float32: 1E-6,
            np.float64: 1E-13,
            np.complex64: 1E-6,
            np.complex128: 1E-13
        }

        rng = rand_sparse(self.rng(), post=jnp.array)
        rng_b = jtu.rand_default(self.rng())

        M = rng(shape, dtype)
        data, row, col = sparse_ops.coo_fromdense(M, nnz=(M != 0).sum())
        x = rng_b(bshape, dtype)
        xdot = rng_b(bshape, dtype)

        # Forward-mode with respect to the vector
        f_dense = lambda x: M @ x
        f_sparse = lambda x: sparse_ops.coo_matvec(
            data, row, col, x, shape=M.shape)
        v_sparse, t_sparse = api.jvp(f_sparse, [x], [xdot])
        v_dense, t_dense = api.jvp(f_dense, [x], [xdot])
        self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
        self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

        # Reverse-mode with respect to the vector. Each pullback is seeded with
        # its own primal output as the cotangent.
        primals_sparse, vjp_sparse = api.vjp(f_sparse, x)
        primals_dense, vjp_dense = api.vjp(f_dense, x)
        out_sparse, = vjp_sparse(primals_sparse)
        out_dense, = vjp_dense(primals_dense)
        # Compare the entire primal output, not only its first entry; indexing
        # [0] would hide mismatches in every other row.
        self.assertAllClose(primals_sparse, primals_dense, atol=tol, rtol=tol)
        self.assertAllClose(out_sparse, out_dense, atol=tol, rtol=tol)

        # Forward-mode with respect to nonzero elements of the matrix
        f_sparse = lambda data: sparse_ops.coo_matvec(
            data, row, col, x, shape=M.shape)
        f_dense = lambda data: sparse_ops.coo_todense(
            data, row, col, shape=M.shape) @ x
        # Draw fresh stored values and tangents. The sparsity pattern
        # (row, col) stays the one produced by coo_fromdense above.
        data = rng((len(data), ), data.dtype)
        data_dot = rng((len(data), ), data.dtype)
        v_sparse, t_sparse = api.jvp(f_sparse, [data], [data_dot])
        v_dense, t_dense = api.jvp(f_dense, [data], [data_dot])

        self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
        self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

        # Reverse-mode with respect to nonzero elements of the matrix
        primals_sparse, vjp_sparse = api.vjp(f_sparse, data)
        primals_dense, vjp_dense = api.vjp(f_dense, data)
        out_sparse, = vjp_sparse(primals_sparse)
        out_dense, = vjp_dense(primals_dense)
        # As above, check the full primal output rather than its first entry.
        self.assertAllClose(primals_sparse, primals_dense, atol=tol, rtol=tol)
        self.assertAllClose(out_sparse, out_dense, atol=tol, rtol=tol)
Ejemplo n.º 2
0
  def test_coo_todense_ad(self, shape, dtype):
    """Verify jvp and vjp of ``sparse_ops.coo_todense`` w.r.t. its ``data``.

    A random sparse matrix is converted to COO form. The dense-reconstruction
    function is then differentiated with respect to the stored values, while
    the ``row``/``col`` indices are held fixed.
    """
    M = rand_sparse(self.rng(), post=jnp.array)(shape, dtype)
    data, row, col = sparse_ops.coo_fromdense(M, nnz=(M != 0).sum())

    def to_dense(values):
      return sparse_ops.coo_todense(values, row, col, shape=M.shape)

    # Forward mode: a unit tangent on every stored value is expected to place
    # ones at exactly the (row, col) positions of the dense result.
    out, out_dot = api.jvp(to_dense, [data], [jnp.ones_like(data)])
    self.assertArraysEqual(out, to_dense(data))
    expected_dot = jnp.zeros_like(M).at[row, col].set(1)
    self.assertArraysEqual(out_dot, expected_dot)

    # Reverse mode: pulling back the dense output itself is expected to
    # recover the stored values.
    out, pullback = api.vjp(to_dense, data)
    data_bar, = pullback(out)
    self.assertArraysEqual(out, to_dense(data))
    self.assertArraysEqual(data_bar, data)
Ejemplo n.º 3
0
  def test_coo_fromdense(self, shape, dtype):
    """Check ``sparse_ops.coo_fromdense`` against SciPy's COO conversion.

    The conversion runs both eagerly and under ``jit``. In each case the
    returned ``data``, ``row`` and ``col`` arrays must equal the corresponding
    fields of ``sparse.coo_matrix(M)``, cast to the expected dtypes.
    """
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_coo = sparse.coo_matrix(M)

    nnz = M_coo.nnz
    index_dtype = jnp.int32
    # Use the same index_dtype for the conversion and for the expectations
    # below, so that changing it in one place keeps the test consistent.
    fromdense = lambda M: sparse_ops.coo_fromdense(
        M, nnz=nnz, index_dtype=index_dtype)

    # Eager conversion.
    data, row, col = fromdense(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

    # Jitted conversion. Rebind row/col here; otherwise the asserts below
    # would re-check the eager outputs and never inspect the jitted indices.
    data, row, col = jit(fromdense)(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))
Ejemplo n.º 4
0
  def test_coo_fromdense_ad(self, shape, dtype):
    """Verify jvp and vjp of ``sparse_ops.coo_fromdense`` w.r.t. the dense input.

    ``coo_fromdense`` returns a ``(data, row, col)`` triple. Only ``data``
    depends differentiably on ``M``; the integer index outputs are checked to
    carry ``float0`` tangents.
    """
    M = rand_sparse(self.rng(), post=jnp.array)(shape, dtype)
    nnz = (M != 0).sum()

    def to_coo(dense):
      return sparse_ops.coo_fromdense(dense, nnz=nnz)

    # Non-differentiated result, used as the reference in both modes.
    reference = to_coo(M)

    # Forward mode: a ones tangent on M should become a ones tangent on data.
    out, out_dot = api.jvp(to_coo, [M], [jnp.ones_like(M)])
    for got, want in zip(out, reference):
      self.assertArraysEqual(got, want)
    data_dot, row_dot, col_dot = out_dot
    self.assertArraysEqual(data_dot, jnp.ones(nnz, dtype=dtype))
    self.assertEqual(row_dot.dtype, dtypes.float0)
    self.assertEqual(col_dot.dtype, dtypes.float0)

    # Reverse mode: pulling back the COO triple itself should rebuild M.
    out, pullback = api.vjp(to_coo, M)
    M_bar, = pullback(out)
    for got, want in zip(out, reference):
      self.assertArraysEqual(got, want)
    self.assertArraysEqual(M_bar, M)