def testSparseCondSimple(self):
    """lax.cond applied to a BCOO operand agrees with the dense computation."""
    def branchy(operand):
        # The predicate is the constant False, so the second branch
        # (doubling) is the one applied.
        return lax.cond(False, lambda v: v, lambda v: 2 * v, operand)

    dense_in = jnp.arange(5.0)
    expected = branchy(dense_in)
    sparse_in = BCOO.fromdense(dense_in)
    actual = sparsify(branchy)(sparse_in).todense()
    self.assertArraysAllClose(expected, actual)
def testSparsifySparseXlaCall(self):
    """Sparse lowering of a jit-compiled (XLA call) function."""
    def double(mat):
        return 2 * mat

    dense_mat = jnp.arange(6).reshape(2, 3)
    expected = double(dense_mat)
    sparse_mat = BCOO.fromdense(dense_mat)
    jitted_sparse = sparsify(jit(double))
    self.assertArraysEqual(expected, jitted_sparse(sparse_mat).todense())
def testSparseSqueeze(self, shape, dimensions, n_batch, n_dense):
    """sparsify(lax.squeeze) gives the same result on dense and BCOO inputs.

    Args:
      shape: shape of the random input array.
      dimensions: axes passed to ``lax.squeeze``.
      n_batch: number of batch dimensions for the BCOO conversion.
      n_dense: number of dense dimensions for the BCOO conversion.
    """
    rng = jtu.rand_default(self.rng())
    dense = rng(shape, np.float32)
    sparse = BCOO.fromdense(dense, n_batch=n_batch, n_dense=n_dense)
    squeeze = sparsify(partial(lax.squeeze, dimensions=dimensions))
    from_dense = squeeze(dense)
    from_sparse = squeeze(sparse).todense()
    self.assertAllClose(from_sparse, from_dense)
def testSparseForiLoop(self):
    """lax.fori_loop closing over a BCOO matrix matches the dense loop."""
    def run(mat, vec):
        def step(_, carry):
            # Multiply by the matrix, then divide by mat.shape[1] (its column count).
            return (mat @ carry) / mat.shape[1]
        return lax.fori_loop(0, 2, step, vec)

    vec = jnp.arange(5.0)
    mat = jnp.arange(25).reshape(5, 5)
    expected = run(mat, vec)
    actual = sparsify(run)(BCOO.fromdense(mat), vec)
    self.assertArraysAllClose(expected, actual)
def testSparseSum(self):
    """Full and per-axis sums of a BCOO matrix match the dense sums."""
    x = jnp.arange(20).reshape(4, 5)
    xsp = BCOO.fromdense(x)

    def f(x):
        return x.sum(), x.sum(0), x.sum(1), x.sum((0, 1))

    result_dense = f(x)
    result_sparse = sparsify(f)(xsp)
    # Use a unittest assertion instead of a bare `assert`. Bare asserts are
    # removed under `python -O`, and zip() below would then silently truncate
    # if the sparse run produced fewer outputs.
    self.assertEqual(len(result_dense), len(result_sparse))
    for res_dense, res_sparse in zip(result_dense, result_sparse):
        # Some reductions may come back as BCOO; densify before comparing.
        if isinstance(res_sparse, BCOO):
            res_sparse = res_sparse.todense()
        self.assertArraysAllClose(res_dense, res_sparse)
def testSparseMul(self):
    """Sparse multiplication by a scalar, and of two operands sharing indices."""
    lhs = BCOO.fromdense(jnp.arange(5))
    rhs = BCOO.fromdense(2 * jnp.arange(5))

    # Scalar multiplication.
    scaled = sparsify(operator.mul)(lhs, 2.5)
    self.assertArraysEqual(scaled.todense(), lhs.todense() * 2.5)

    # Shared indices - requires the lower-level sparsify_raw call. Both
    # ArgSpecs point at env slot 0 (lhs.indices) and at distinct data slots
    # (1 and 2); presumably the fields are (shape, data_ref, indices_ref).
    env = SparseEnv([lhs.indices, lhs.data, rhs.data])
    specs = [ArgSpec(lhs.shape, 1, 0), ArgSpec(rhs.shape, 2, 0)]
    out_specs, _ = sparsify_raw(operator.mul)(env, *specs)
    (product,) = argspecs_to_arrays(env, out_specs)
    self.assertAllClose(product.todense(), lhs.todense() * rhs.todense())
def testSparseAdd(self):
    """Sparse addition with distinct indices (via sparsify) and shared indices (via sparsify_raw)."""
    lhs = BCOO.fromdense(jnp.arange(5))
    rhs = BCOO.fromdense(2 * jnp.arange(5))

    # Distinct indices: the result uses concatenation, so nnz is the sum of
    # both operands' stored entries.
    total = sparsify(operator.add)(lhs, rhs)
    self.assertEqual(total.nnz, 8)
    self.assertArraysEqual(total.todense(), 3 * jnp.arange(5))

    # Shared indices - requires the lower-level sparsify_raw call. Both
    # ArgSpecs point at env slot 0 (lhs.indices) and at distinct data slots
    # (1 and 2); presumably the fields are (shape, data_ref, indices_ref).
    env = SparseEnv([lhs.indices, lhs.data, rhs.data])
    specs = [ArgSpec(lhs.shape, 1, 0), ArgSpec(rhs.shape, 2, 0)]
    out_specs, _ = sparsify_raw(operator.add)(env, *specs)
    (summed,) = argspecs_to_arrays(env, out_specs)
    self.assertAllClose(summed.todense(), lhs.todense() + rhs.todense())
def testSparseWhileLoop(self):
    """lax.while_loop carrying a BCOO array matches the dense loop.

    The carry is (counter, array): the counter stops at 5 and the array is
    doubled on each iteration.
    """
    def cond_fun(params):
        i, A = params
        return i < 5

    def body_fun(params):
        i, A = params
        return i + 1, 2 * A

    def f(A):
        return lax.while_loop(cond_fun, body_fun, (0, A))

    A = jnp.arange(4)
    out_dense = f(A)

    Asp = BCOO.fromdense(A)
    out_sparse = sparsify(f)(Asp)

    self.assertEqual(len(out_dense), 2)
    self.assertEqual(len(out_sparse), 2)
    # Compare the sparse run's loop counter against the dense run's. The
    # previous check compared out_dense[0] with itself and could never fail.
    # Only element 1 is densified below, so element 0 is presumably a dense
    # array in the sparse run.
    self.assertArraysEqual(out_dense[0], out_sparse[0])
    self.assertArraysEqual(out_dense[1], out_sparse[1].todense())
def sparsify(cls, f):
    """Transform ``f`` with the module-level ``sparsify`` using ``use_tracer=True``.

    NOTE(review): the test methods visible alongside this override call the
    module-level ``sparsify`` directly rather than ``self.sparsify``, so this
    hook does not affect them. Confirm whether they are meant to go through it.
    """
    # The bare name `sparsify` below is the module-level function, not this
    # method: a method's own name is not in scope inside its body.
    tracer_options = dict(use_tracer=True)
    return sparsify(f, **tracer_options)
def sparsify(cls, f):
    """Transform ``f`` with the module-level ``sparsify`` using ``use_tracer=False``.

    NOTE(review): the test methods visible alongside this override call the
    module-level ``sparsify`` directly rather than ``self.sparsify``, so this
    hook does not affect them. Confirm whether they are meant to go through it.
    """
    # The bare name `sparsify` below is the module-level function, not this
    # method: a method's own name is not in scope inside its body.
    tracer_options = dict(use_tracer=False)
    return sparsify(f, **tracer_options)
def testUnitHandling(self):
    """Passing core.unit as an extra argument returns the BCOO input unchanged."""
    sp = BCOO.fromdense(jnp.arange(5))
    first_arg = jit(lambda a, b: a)
    # Wrapping an already-jitted function in jit again yields a nested call.
    out = sparsify(jit(first_arg))(sp, core.unit)
    self.assertBcooIdentical(out, sp)
def testSparsifyDenseXlaCall(self):
    """A jit-wrapped (xla_call) function on purely dense input runs through the sparsify jaxpr interpreter."""
    increment = jit(lambda v: v + 1)
    self.assertEqual(sparsify(increment)(0.0), 1.0)