def test_bcoo_dedupe(self, shape, dtype, n_batch, n_dense):
    """Check that `_dedupe_bcoo` leaves the densified matrix unchanged.

    A sparse representation is built from a random matrix, every sparse
    index row is then overwritten with random in-range positions
    (presumably so that repeated indices occur), and the dense result
    before and after `_dedupe_bcoo` is compared.
    """
    index_rng = self.rng()
    make_sparse = rand_sparse(self.rng())
    dense = make_sparse(shape, dtype)
    data, indices = sparse.bcoo_fromdense(
        dense, n_batch=n_batch, n_dense=n_dense)

    # Sizes of the dimensions lying between the batch and dense ones.
    # The stop bound is spelled `len(shape) - n_dense` (not `-n_dense`)
    # so that n_dense == 0 still selects through the end of `shape`.
    sparse_dims = shape[n_batch:len(shape) - n_dense]
    for axis, dim_size in enumerate(sparse_dims):
        n_entries = indices.shape[-1]
        random_positions = index_rng.randint(0, dim_size, size=n_entries)
        indices = indices.at[..., axis, :].set(random_positions)

    deduped_data, deduped_indices = _dedupe_bcoo(data, indices)

    expected = sparse.bcoo_todense(data, indices, shape=shape)
    actual = sparse.bcoo_todense(deduped_data, deduped_indices, shape=shape)
    self.assertAllClose(expected, actual)
def test_bcoo_todense_partial_batch(self, shape, dtype, n_batch, n_dense):
    """Compare size-1 leading slices against explicitly repeated inputs.

    Densifying with `indices[:1]` must match densifying with
    `indices[0]` stacked `shape[0]` times, and likewise for `data`.
    Presumably this exercises broadcasting of a size-1 batch dimension
    in `bcoo_todense` -- confirm against the library documentation.
    """
    make_sparse = rand_sparse(self.rng())
    dense = make_sparse(shape, dtype)
    data, indices = sparse.bcoo_fromdense(
        dense, n_batch=n_batch, n_dense=n_dense)

    out_shape = dense.shape
    n_leading = shape[0]

    # Indices truncated to one leading slice vs. that slice repeated.
    truncated_idx = sparse.bcoo_todense(data, indices[:1], shape=out_shape)
    repeated_indices = jnp.stack([indices[0]] * n_leading)
    repeated_idx = sparse.bcoo_todense(data, repeated_indices, shape=out_shape)
    self.assertAllClose(truncated_idx, repeated_idx)

    # Same comparison with the roles of data and indices swapped.
    truncated_data = sparse.bcoo_todense(data[:1], indices, shape=out_shape)
    repeated_values = jnp.stack([data[0]] * n_leading)
    repeated_data = sparse.bcoo_todense(
        repeated_values, indices, shape=out_shape)
    self.assertAllClose(truncated_data, repeated_data)
def _todense_sparse_rule(spenv, argspec, *, tree):
    """Densify the array described by `argspec` and register the result.

    Args:
      spenv: environment object; used to resolve the data/indices of
        `argspec` and to `push` the densified output.
      argspec: descriptor providing `data(spenv)`, `indices(spenv)` and
        `shape` for the input.
      tree: accepted for interface compatibility and discarded.

    Returns:
      A one-element tuple holding an `ArgSpec` built from the input
      shape, the value returned by `spenv.push`, and `None`.
    """
    del tree  # TODO(jakvdp): we should assert that tree is PytreeDef(*)
    values = argspec.data(spenv)
    positions = argspec.indices(spenv)
    dense = sparse.bcoo_todense(values, positions, shape=argspec.shape)
    return (sparse.transform.ArgSpec(argspec.shape, spenv.push(dense), None),)
def test_bcoo_reduce_sum(self, shape, dtype, n_batch, n_dense, axes):
    """Check `bcoo_reduce_sum` against a dense `sum` over the same axes.

    The sparse reduction output is densified with the shape it reports
    and compared to `M.sum(axes)` using per-dtype tolerances.
    """
    make_sparse = rand_sparse(self.rng())
    dense = make_sparse(shape, dtype)
    data, indices = sparse.bcoo_fromdense(
        dense, n_batch=n_batch, n_dense=n_dense)

    reduced = sparse.bcoo_reduce_sum(data, indices, shape=shape, axes=axes)
    reduced_data, reduced_indices, reduced_shape = reduced

    expected = dense.sum(axes)
    actual = sparse.bcoo_todense(
        reduced_data, reduced_indices, shape=reduced_shape)

    # The same per-dtype table serves as absolute and relative tolerance.
    tolerance = {np.float32: 1E-6, np.float64: 1E-14}
    self.assertAllClose(expected, actual, atol=tolerance, rtol=tolerance)
def test_bcoo_dot_general_partial_batch(self, lhs_shape, rhs_shape, dtype,
                                        dimension_numbers, n_batch, n_dense):
    """Compare `bcoo_dot_general` with `lax.dot_general` on sliced inputs.

    All four combinations of full / `[:1]`-sliced `data` and `indices`
    are tried. For each one the left operand is densified from that
    combination and the dense product is compared to the sparse product.
    """
    make_dense = jtu.rand_small(self.rng())
    make_sparse = rand_sparse(self.rng())

    lhs = make_sparse(lhs_shape, dtype)
    full_data, full_indices = sparse.bcoo_fromdense(
        lhs, n_batch=n_batch, n_dense=n_dense)
    rhs = make_dense(rhs_shape, dtype)

    # Every variant is densified to, and multiplied as, the shape of the
    # originally generated left operand.
    lhs_dense_shape = lhs.shape
    data_variants = (full_data, full_data[:1])
    index_variants = (full_indices, full_indices[:1])

    # Data varies in the outer loop and indices in the inner one, which
    # is the order itertools.product would yield the pairs in.
    for data in data_variants:
        for indices in index_variants:
            lhs_dense = sparse.bcoo_todense(
                data, indices, shape=lhs_dense_shape)
            expected = lax.dot_general(
                lhs_dense, rhs, dimension_numbers=dimension_numbers)
            actual = sparse.bcoo_dot_general(
                data, indices, rhs,
                lhs_shape=lhs_dense_shape,
                dimension_numbers=dimension_numbers)
            self.assertAllClose(expected, actual)