def test_coo_matvec_ad(self, shape, dtype, bshape):
  """Check coo_matvec forward/reverse derivatives against a dense reference.

  Derivatives are checked with respect to the dense vector operand and with
  respect to the stored nonzero values of the sparse matrix. In both cases the
  sparse function's JVP and VJP results must match those of the equivalent
  dense computation to a dtype-dependent tolerance.
  """
  # Per-dtype tolerance table; assertAllClose selects the entry by dtype.
  tol = {np.float32: 1E-6, np.float64: 1E-13,
         np.complex64: 1E-6, np.complex128: 1E-13}

  rng = rand_sparse(self.rng(), post=jnp.array)
  rng_b = jtu.rand_default(self.rng())

  M = rng(shape, dtype)
  data, row, col = sparse_ops.coo_fromdense(M, nnz=(M != 0).sum())
  x = rng_b(bshape, dtype)
  xdot = rng_b(bshape, dtype)

  def check_ad(f_sparse, f_dense, primal, tangent):
    # Forward mode: primal outputs and output tangents must agree.
    v_sparse, t_sparse = api.jvp(f_sparse, [primal], [tangent])
    v_dense, t_dense = api.jvp(f_dense, [primal], [tangent])
    self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
    self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

    # Reverse mode: each function's own output is pulled back as the cotangent.
    primals_dense, vjp_dense = api.vjp(f_dense, primal)
    primals_sparse, vjp_sparse = api.vjp(f_sparse, primal)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    # Compare the full primal outputs, not only their first element.
    self.assertAllClose(primals_dense, primals_sparse, atol=tol, rtol=tol)
    self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)

  # Derivatives with respect to the vector.
  check_ad(
      lambda vec: sparse_ops.coo_matvec(data, row, col, vec, shape=M.shape),
      lambda vec: M @ vec,
      x, xdot)

  # Derivatives with respect to the nonzero elements of the matrix. Fresh
  # random values (same length/dtype as the stored data) are used as primal
  # and tangent; the draw order matches the original test.
  values = rng((len(data),), data.dtype)
  values_dot = rng((len(data),), data.dtype)
  check_ad(
      lambda d: sparse_ops.coo_matvec(d, row, col, x, shape=M.shape),
      lambda d: sparse_ops.coo_todense(d, row, col, shape=M.shape) @ x,
      values, values_dot)
def test_coo_todense_ad(self, shape, dtype):
  """Verify JVP and VJP of coo_todense with respect to the stored values."""
  rng = rand_sparse(self.rng(), post=jnp.array)
  M = rng(shape, dtype)
  data, row, col = sparse_ops.coo_fromdense(M, nnz=(M != 0).sum())

  def to_dense(values):
    return sparse_ops.coo_todense(values, row, col, shape=M.shape)

  # Forward mode: with a tangent of all ones on the stored values, the test
  # expects a dense tangent that is 1 at every stored (row, col) and 0 elsewhere.
  ones = jnp.ones_like(data)
  dense_out, dense_dot = api.jvp(to_dense, [data], [ones])
  self.assertArraysEqual(dense_out, to_dense(data))
  expected_dot = jnp.zeros_like(M).at[row, col].set(1)
  self.assertArraysEqual(dense_dot, expected_dot)

  # Reverse mode: pulling the dense output back through the VJP is expected
  # to recover the stored values themselves.
  dense_out, pullback = api.vjp(to_dense, data)
  grad_values, = pullback(dense_out)
  self.assertArraysEqual(dense_out, to_dense(data))
  self.assertArraysEqual(grad_values, data)
def test_coo_fromdense(self, shape, dtype):
  """coo_fromdense must match scipy's COO conversion, both eagerly and under jit.

  The data, row, and col arrays are compared against scipy.sparse.coo_matrix.
  Indices are compared after casting to the requested index dtype.
  """
  rng = rand_sparse(self.rng())
  M = rng(shape, dtype)
  M_coo = sparse.coo_matrix(M)

  nnz = M_coo.nnz
  index_dtype = jnp.int32
  # Use the same index dtype that the expected values are cast to below.
  fromdense = lambda M: sparse_ops.coo_fromdense(
      M, nnz=nnz, index_dtype=index_dtype)

  # Check the eager and jitted conversions identically. Previously the jitted
  # result was unpacked into unused names, so its row/col were never checked.
  for convert in (fromdense, jit(fromdense)):
    data, row, col = convert(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))
def test_coo_fromdense_ad(self, shape, dtype):
  """Verify JVP and VJP of coo_fromdense with respect to the dense input."""
  rng = rand_sparse(self.rng(), post=jnp.array)
  M = rng(shape, dtype)
  nnz = (M != 0).sum()

  def to_coo(mat):
    return sparse_ops.coo_fromdense(mat, nnz=nnz)

  def check_primals(primals):
    # Each of the (data, row, col) outputs must equal a direct call on M.
    for got, want in zip(primals, to_coo(M)):
      self.assertArraysEqual(got, want)

  # Forward mode: with an all-ones input tangent, the data tangent is expected
  # to be all ones. The integer index outputs carry float0 tangents.
  primals, tangents = api.jvp(to_coo, [M], [jnp.ones_like(M)])
  check_primals(primals)
  data_dot, row_dot, col_dot = tangents
  self.assertArraysEqual(data_dot, jnp.ones(nnz, dtype=dtype))
  self.assertEqual(row_dot.dtype, dtypes.float0)
  self.assertEqual(col_dot.dtype, dtypes.float0)

  # Reverse mode: feeding the outputs back as cotangents is expected to
  # reconstruct the dense input.
  primals, pullback = api.vjp(to_coo, M)
  M_out, = pullback(primals)
  check_primals(primals)
  self.assertArraysEqual(M_out, M)