def test_bcoo_dot_general_ad(self, lhs_shape, rhs_shape, dtype,
                             dimension_numbers, n_batch, n_dense):
  """Compare Jacobians of sparse and dense dot_general w.r.t. the dense operand.

  A sparse lhs is converted to BCOO ``(data, indices)`` form.
  ``jax.jacfwd`` and ``jax.jacrev`` of ``sparse.bcoo_dot_general`` are then
  checked against those of ``lax.dot_general`` on the dense lhs, and
  forward-mode is checked against reverse-mode for the sparse version.
  """
  dense_rng = jtu.rand_small(self.rng())
  sparse_rng = rand_sparse(self.rng())
  X = sparse_rng(lhs_shape, dtype)
  data, indices = sparse.bcoo_fromdense(X, n_batch=n_batch, n_dense=n_dense)
  Y = dense_rng(rhs_shape, dtype)

  def dense_product(rhs):
    return lax.dot_general(X, rhs, dimension_numbers=dimension_numbers)

  def sparse_product(rhs):
    return sparse.bcoo_dot_general(data, indices, rhs, lhs_shape=X.shape,
                                   dimension_numbers=dimension_numbers)

  fwd_dense = jax.jacfwd(dense_product)(Y)
  rev_dense = jax.jacrev(dense_product)(Y)
  fwd_sparse = jax.jacfwd(sparse_product)(Y)
  rev_sparse = jax.jacrev(sparse_product)(Y)

  # TPU runs use a looser float32 relative tolerance; elsewhere the default.
  tol = {np.float32: 5E-3} if jtu.device_under_test() == "tpu" else {}
  self.assertAllClose(fwd_dense, fwd_sparse, rtol=tol)
  self.assertAllClose(rev_dense, rev_sparse, rtol=tol)
  self.assertAllClose(fwd_sparse, rev_sparse, rtol=tol)
def test_bcoo_extract(self, shape, dtype, n_batch, n_dense):
  """Check that ``bcoo_extract`` recovers the data from ``bcoo_fromdense``.

  ``n_batch`` and ``n_dense`` are forwarded to ``bcoo_fromdense`` so the
  test exercises every batch/dense layout it is parameterized over. The
  extraction is checked both eagerly and under ``jit``.
  """
  rng = rand_sparse(self.rng())
  M = rng(shape, dtype)
  # Forward the layout parameters; without them only the default
  # (unbatched, no dense dims) layout would ever be tested.
  data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)
  data2 = sparse.bcoo_extract(indices, M)
  self.assertArraysEqual(data, data2)
  data3 = jit(sparse.bcoo_extract)(indices, M)
  self.assertArraysEqual(data, data3)
def test_bcoo_reduce_sum(self, shape, dtype, n_batch, n_dense, axes):
  """Check ``bcoo_reduce_sum`` against a dense ``sum`` over the same axes."""
  sparse_rng = rand_sparse(self.rng())
  M = sparse_rng(shape, dtype)
  data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

  reduced = sparse.bcoo_reduce_sum(data, indices, shape=shape, axes=axes)
  data_out, indices_out, shape_out = reduced

  expected = M.sum(axes)
  actual = sparse.bcoo_todense(data_out, indices_out, shape=shape_out)
  tol = {np.float32: 1E-6, np.float64: 1E-14}
  self.assertAllClose(expected, actual, atol=tol, rtol=tol)
def test_bcoo_dedupe(self, shape, dtype, n_batch, n_dense):
  """Check that ``_dedupe_bcoo`` leaves the densified array unchanged.

  Each sparse index dimension is first overwritten with random values in
  ``[0, extent)``, which can introduce duplicate index entries.
  """
  np_rng = self.rng()
  sparse_rng = rand_sparse(self.rng())
  M = sparse_rng(shape, dtype)
  data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

  sparse_extents = shape[n_batch:len(shape) - n_dense]
  for dim, extent in enumerate(sparse_extents):
    random_idx = np_rng.randint(0, extent, size=indices.shape[-1])
    indices = indices.at[..., dim, :].set(random_idx)

  deduped_data, deduped_indices = _dedupe_bcoo(data, indices)
  before = sparse.bcoo_todense(data, indices, shape=shape)
  after = sparse.bcoo_todense(deduped_data, deduped_indices, shape=shape)
  self.assertAllClose(before, after)
def test_bcoo_todense_partial_batch(self, shape, dtype, n_batch, n_dense):
  """Check densifying with a size-1 leading slice of data or indices.

  Passing ``indices[:1]`` (resp. ``data[:1]``) must give the same dense
  result as explicitly stacking the first element ``shape[0]`` times.
  NOTE(review): slicing the leading axis presumes this test is
  parameterized with ``n_batch >= 1`` -- confirm against the decorator.
  """
  sparse_rng = rand_sparse(self.rng())
  M = sparse_rng(shape, dtype)
  data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)
  reps = shape[0]

  # Leading slice of indices vs. indices[0] repeated along the leading axis.
  sliced = sparse.bcoo_todense(data, indices[:1], shape=M.shape)
  stacked = sparse.bcoo_todense(data, jnp.stack(reps * [indices[0]]),
                                shape=M.shape)
  self.assertAllClose(sliced, stacked)

  # Same comparison with the roles of data and indices swapped.
  sliced = sparse.bcoo_todense(data[:1], indices, shape=M.shape)
  stacked = sparse.bcoo_todense(jnp.stack(reps * [data[0]]), indices,
                                shape=M.shape)
  self.assertAllClose(sliced, stacked)
def test_bcoo_extract_ad(self, shape, dtype, n_batch, n_dense):
  """Check derivatives of ``bcoo_extract`` w.r.t. the dense input.

  Forward- and reverse-mode Jacobians must agree. The Jacobian must have
  shape ``data.shape + M.shape`` and the Hessian
  ``data.shape + M.shape + M.shape``.
  """
  sparse_rng = rand_sparse(self.rng())
  M = sparse_rng(shape, dtype)
  data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

  def extract(mat):
    return sparse.bcoo_extract(indices, mat)

  jac_fwd = jax.jacfwd(extract)(M)
  jac_rev = jax.jacrev(extract)(M)
  hess = jax.hessian(extract)(M)

  self.assertArraysAllClose(jac_fwd, jac_rev)
  self.assertEqual(jac_fwd.shape, data.shape + M.shape)
  self.assertEqual(hess.shape, data.shape + M.shape + M.shape)
def test_bcoo_todense_ad(self, shape, dtype, n_batch, n_dense):
  """Check derivatives of ``bcoo_todense`` w.r.t. the ``data`` array.

  Forward- and reverse-mode Jacobians must agree. The Jacobian must have
  shape ``M.shape + data.shape`` and the Hessian
  ``M.shape + data.shape + data.shape``.
  """
  sparse_rng = rand_sparse(self.rng())
  M = sparse_rng(shape, dtype)
  data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

  def densify(values):
    return sparse.bcoo_todense(values, indices=indices, shape=shape)

  jac_fwd = jax.jacfwd(densify)(data)
  jac_rev = jax.jacrev(densify)(data)
  hess = jax.hessian(densify)(data)

  self.assertArraysAllClose(jac_fwd, jac_rev)
  self.assertEqual(jac_fwd.shape, M.shape + data.shape)
  self.assertEqual(hess.shape, M.shape + data.shape + data.shape)
def test_bcoo_dense_round_trip(self, shape, dtype, n_batch, n_dense):
  """Check the BCOO layout from ``bcoo_fromdense`` and the dense round trip.

  Verifies the dtypes and shapes of ``data`` and ``indices`` for the given
  batch/dense split. Then checks that ``bcoo_todense`` (eagerly and under
  ``jit``) reproduces ``M`` exactly.
  """
  rng = rand_sparse(self.rng())
  M = rng(shape, dtype)
  n_sparse = M.ndim - n_batch - n_dense
  nse = int(_bcoo_nse(M, n_batch=n_batch, n_dense=n_dense))
  data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)
  # TODO: test fromdense JIT

  # unittest assertions instead of bare `assert`: they are not stripped
  # under `python -O` and they report both operands on failure.
  self.assertEqual(data.dtype, dtype)
  self.assertEqual(data.shape,
                   shape[:n_batch] + (nse,) + shape[n_batch + n_sparse:])
  self.assertEqual(indices.dtype, jnp.int32)  # TODO: test passing this arg
  self.assertEqual(indices.shape, shape[:n_batch] + (n_sparse, nse))

  todense = partial(sparse.bcoo_todense, shape=shape)
  self.assertArraysEqual(M, todense(data, indices))
  self.assertArraysEqual(M, jit(todense)(data, indices))
def test_bcoo_dot_general_partial_batch(self, lhs_shape, rhs_shape, dtype,
                                        dimension_numbers, n_batch, n_dense):
  """Check ``bcoo_dot_general`` when data and/or indices are sliced to 1.

  Every combination of full vs. ``[:1]``-sliced ``data`` and ``indices`` is
  densified and compared against ``lax.dot_general`` on that dense lhs.
  """
  dense_rng = jtu.rand_small(self.rng())
  sparse_rng = rand_sparse(self.rng())
  X = sparse_rng(lhs_shape, dtype)
  data, indices = sparse.bcoo_fromdense(X, n_batch=n_batch, n_dense=n_dense)
  Y = dense_rng(rhs_shape, dtype)
  dense_lhs_shape = X.shape

  def f_dense(lhs, rhs):
    return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)

  def f_sparse(lhs_data, lhs_indices, rhs):
    return sparse.bcoo_dot_general(lhs_data, lhs_indices, rhs,
                                   lhs_shape=dense_lhs_shape,
                                   dimension_numbers=dimension_numbers)

  data_variants = [data, data[:1]]
  indices_variants = [indices, indices[:1]]
  for lhs_data, lhs_indices in itertools.product(data_variants,
                                                 indices_variants):
    lhs_dense = sparse.bcoo_todense(lhs_data, lhs_indices,
                                    shape=dense_lhs_shape)
    self.assertAllClose(f_dense(lhs_dense, Y),
                        f_sparse(lhs_data, lhs_indices, Y))
def args_maker():
  """Build ``(rhs_data, rhs_indices, lhs, rhs)`` test arguments.

  ``lhs`` is dense and ``rhs`` is drawn from ``rng_sparse`` and also returned
  in BCOO form. All shapes, RNGs and layout parameters come from the
  enclosing scope.
  """
  dense_lhs = rng(lhs_shape, dtype)
  sparse_rhs = rng_sparse(rhs_shape, dtype)
  rhs_data, rhs_indices = sparse.bcoo_fromdense(
      sparse_rhs, n_batch=n_batch, n_dense=n_dense)
  return rhs_data, rhs_indices, dense_lhs, sparse_rhs
def fromdense(M):
  """Return the first element (the ``data`` array) of ``bcoo_fromdense``.

  ``nse``, ``n_batch`` and ``n_dense`` are taken from the enclosing scope.
  """
  converted = sparse.bcoo_fromdense(
      M, nse=nse, n_batch=n_batch, n_dense=n_dense)
  return converted[0]