def testSparseMatmul(self):
  """Compare sparsified matmul with dense matmul for every operand layout.

  Covers sparse @ dense (dot_general), dense @ sparse (rdot_general) and
  sparse @ sparse (spdot_general). The first two run with the
  CuSparseEfficiencyWarning about unsorted indices suppressed.
  """
  mat = jnp.arange(16).reshape(4, 4)
  mat_sp = BCOO.fromdense(mat)
  vec = jnp.ones(4)
  vec_sp = BCOO.fromdense(vec)
  sparse_matmul = self.sparsify(operator.matmul)

  def quiet():
    # A fresh context manager for each use; silences the GPU-lowering warning.
    return jtu.ignore_warning(
        category=CuSparseEfficiencyWarning,
        message=
        "bcoo_dot_general GPU lowering requires matrices with sorted indices*")

  # dot_general: sparse matrix times dense vector.
  with quiet():
    got = sparse_matmul(mat_sp, vec)
  self.assertAllClose(got, operator.matmul(mat, vec))

  # rdot_general: dense vector times sparse matrix.
  with quiet():
    got = sparse_matmul(vec, mat_sp)
  self.assertAllClose(got, operator.matmul(vec, mat))

  # spdot_general: both operands sparse; the result is densified to compare.
  got = sparse_matmul(mat_sp, vec_sp)
  self.assertAllClose(got.todense(), operator.matmul(mat, vec))
def spvalue_to_array(spvalue):
  """Rebuild a BCOO (sparse values) or plain array (dense values) from `spvalue`.

  Buffers are read from `spenv`, a name this helper takes from its enclosing
  scope (not visible here).
  """
  if not spvalue.is_sparse():
    return spenv.data(spvalue)
  assert spvalue.indices_ref is not None
  data, indices = spenv.data(spvalue), spenv.indices(spvalue)
  return BCOO((data, indices), shape=spvalue.shape)
def testSparseCondMismatchError(self):
  """lax.cond raises TypeError when its sparsified branches disagree in output.

  Each branch returns a different operand. When both operands are dense, or
  both are sparse, the branches agree. Mixing a sparse and a dense operand
  must raise.
  """
  @self.sparsify
  def func(x, y):
    pick_first = lambda operands: operands[0]
    pick_second = lambda operands: operands[1]
    return lax.cond(False, pick_first, pick_second, (x, y))

  dense_x = jnp.arange(5.0)
  dense_y = jnp.arange(5.0)
  sparse_x, sparse_y = BCOO.fromdense(dense_x), BCOO.fromdense(dense_y)

  func(dense_x, dense_y)  # No error
  func(sparse_x, sparse_y)  # No error
  with self.assertRaisesRegex(TypeError,
                              "sparsified true_fun and false_fun output.*"):
    func(sparse_x, dense_y)
def testSparsifyValue(self):
  """Round-trip arrays through SparsifyValues with independent and shared indices."""
  X = jnp.arange(5)
  X_BCOO = BCOO.fromdense(X)
  args = (X, X_BCOO, X_BCOO)

  def check_roundtrip(env, values):
    # Convert the values back to arrays and compare each one with `args`.
    restored = spvalues_to_arrays(env, values)
    self.assertEqual(len(restored), len(args))
    self.assertArraysEqual(args[0], restored[0])
    for expected, actual in zip(args[1:], restored[1:]):
      self.assertBcooIdentical(expected, actual)

  # Independent index: one dense buffer plus a data buffer and an indices
  # buffer for each BCOO, giving 5 buffers in total.
  spenv = SparsifyEnv()
  spvalues = arrays_to_spvalues(spenv, args)
  self.assertEqual(len(spvalues), len(args))
  self.assertLen(spenv._buffers, 5)
  expected_values = (SparsifyValue(X.shape, 0, None),
                     SparsifyValue(X.shape, 1, 2),
                     SparsifyValue(X.shape, 3, 4))
  self.assertEqual(spvalues, expected_values)
  check_roundtrip(spenv, spvalues)

  # Shared index: both sparse values point at buffer 2 for their indices.
  shared_values = (SparsifyValue(X.shape, 0, None),
                   SparsifyValue(X.shape, 1, 2),
                   SparsifyValue(X.shape, 3, 2))
  shared_env = SparsifyEnv([X, X_BCOO.data, X_BCOO.indices, X_BCOO.data])
  check_roundtrip(shared_env, shared_values)
def testSparseWhileLoopDuplicateIndices(self):
  """A sparsified while_loop with the same BCOO twice in its carry matches dense.

  The carry is (counter, A, B), initialised as (0, A, A). Each iteration
  increments the counter and swaps the two arrays, stopping once the counter
  reaches 5.
  """
  def cond_fun(params):
    i, A, B = params
    return i < 5

  def body_fun(params):
    i, A, B = params
    # TODO(jakevdp): track shared indices through while loop & use this
    # version of the test, which requires shared indices in order for
    # the nse of the result to remain the same.
    # return i + 1, A, A + B
    # This version is fine without shared indices, and tests that we're
    # flattening non-shared indices consistently.
    return i + 1, B, A

  def f(A):
    return lax.while_loop(cond_fun, body_fun, (0, A, A))

  A = jnp.arange(4).reshape((2, 2))
  out_dense = f(A)
  Asp = BCOO.fromdense(A)
  out_sparse = self.sparsify(f)(Asp)

  self.assertEqual(len(out_dense), 3)
  self.assertEqual(len(out_sparse), 3)
  # The loop counter is a dense value in both runs, so compare the sparse
  # run's counter directly against the dense run's counter.
  self.assertArraysEqual(out_dense[0], out_sparse[0])
  self.assertArraysEqual(out_dense[1], out_sparse[1].todense())
  self.assertArraysEqual(out_dense[2], out_sparse[2].todense())
def testArgSpec(self):
  """Round-trip arrays through ArgSpecs with independent and shared indices."""
  X = jnp.arange(5)
  X_BCOO = BCOO.fromdense(X)
  args = (X, X_BCOO, X_BCOO)

  def check_roundtrip(env, specs):
    # Convert the specs back to arrays and compare each one with `args`.
    restored = argspecs_to_arrays(env, specs)
    self.assertEqual(len(restored), len(args))
    self.assertArraysEqual(args[0], restored[0])
    for expected, actual in zip(args[1:], restored[1:]):
      self.assertBcooIdentical(expected, actual)

  # Independent index: one dense buffer plus a data buffer and an indices
  # buffer for each BCOO, giving 5 buffers in total.
  spenv = SparseEnv()
  argspecs = arrays_to_argspecs(spenv, args)
  self.assertEqual(len(argspecs), len(args))
  self.assertEqual(spenv.size(), 5)
  expected_specs = (ArgSpec(X.shape, 0, None),
                    ArgSpec(X.shape, 1, 2),
                    ArgSpec(X.shape, 3, 4))
  self.assertEqual(argspecs, expected_specs)
  check_roundtrip(spenv, argspecs)

  # Shared index: both sparse specs point at buffer 2 for their indices.
  shared_specs = (ArgSpec(X.shape, 0, None),
                  ArgSpec(X.shape, 1, 2),
                  ArgSpec(X.shape, 3, 2))
  shared_env = SparseEnv([X, X_BCOO.data, X_BCOO.indices, X_BCOO.data])
  check_roundtrip(shared_env, shared_specs)
def testSparseMul(self, shape, dtype, n_batch, n_dense, unique_indices):
  """Sparse multiply by a scalar, and elementwise with a shared index buffer."""
  rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
  x = BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch,
                     n_dense=n_dense)

  # Multiply a sparse array by a Python scalar.
  factor = 2
  y = self.sparsify(operator.mul)(x, factor)
  self.assertArraysEqual(x.todense() * factor, y.todense())

  # Shared indices – requires lower level call. y is a scaled copy of x, so
  # both values can point at buffer 0 (x.indices) for their indices.
  spenv = SparsifyEnv([x.indices, x.data, y.data])
  spvalues = [
      spenv.sparse(arr.shape, data_ref=ref, indices_ref=0,
                   unique_indices=unique_indices)
      for arr, ref in ((x, 1), (y, 2))
  ]
  args_out, _ = sparsify_raw(operator.mul)(spenv, *spvalues)
  out, = spvalues_to_arrays(spenv, args_out)
  self.assertAllClose(out.todense(), x.todense() * y.todense())
def testPytreeInput(self):
  """A sparsified identity returns a (dense, sparse) tuple argument unchanged."""
  identity = self.sparsify(lambda x: x)
  dense_part = jnp.arange(4)
  sparse_part = BCOO.fromdense(jnp.arange(4))

  out = identity((dense_part, sparse_part))
  self.assertLen(out, 2)
  self.assertArraysEqual(dense_part, out[0])
  self.assertBcooIdentical(sparse_part, out[1])
def argspec_to_array(argspec):
  """Materialize `argspec` as a BCOO, `core.unit`, or a dense array.

  Buffers are read from `spenv`, a name this helper takes from its enclosing
  scope (not visible here).
  """
  if argspec.is_sparse():
    assert argspec.indices_ref is not None
    data, indices = argspec.data(spenv), argspec.indices(spenv)
    return BCOO((data, indices), shape=argspec.shape)
  if argspec.is_unit():
    return core.unit
  return argspec.data(spenv)
def testSparseMul(self):
  """Sparse multiply by a scalar, and elementwise with a shared index buffer."""
  x = BCOO.fromdense(jnp.arange(5))
  y = BCOO.fromdense(2 * jnp.arange(5))

  # Scalar multiplication.
  factor = 2.5
  scaled = self.sparsify(operator.mul)(x, factor)
  self.assertArraysEqual(scaled.todense(), x.todense() * factor)

  # Shared indices – requires lower level call. x and y are zero at the same
  # position, so both specs can point at buffer 0 (x.indices).
  spenv = SparseEnv([x.indices, x.data, y.data])
  argspecs = [ArgSpec(arr.shape, ref, 0) for arr, ref in ((x, 1), (y, 2))]
  args_out, _ = sparsify_raw(operator.mul)(spenv, *argspecs)
  out, = argspecs_to_arrays(spenv, args_out)
  self.assertAllClose(out.todense(), x.todense() * y.todense())
def testSparseSubtract(self):
  """Sparse subtraction with distinct index buffers and with a shared one."""
  x = BCOO.fromdense(3 * jnp.arange(5))
  y = BCOO.fromdense(jnp.arange(5))

  # Distinct indices: the result concatenates both operands' entries, so the
  # nse is 4 + 4.
  diff = self.sparsify(operator.sub)(x, y)
  self.assertEqual(diff.nse, 8)  # uses concatenation.
  self.assertArraysEqual(diff.todense(), 2 * jnp.arange(5))

  # Shared indices – requires lower level call.
  spenv = SparseEnv([x.indices, x.data, y.data])
  argspecs = [ArgSpec(arr.shape, ref, 0) for arr, ref in ((x, 1), (y, 2))]
  args_out, _ = sparsify_raw(operator.sub)(spenv, *argspecs)
  out, = argspecs_to_arrays(spenv, args_out)
  self.assertAllClose(out.todense(), x.todense() - y.todense())
def testSparseReshapeMethod(self, shape, new_shape, n_batch, n_dense):
  """BCOO.reshape produces the same values as reshaping the dense array."""
  rng = jtu.rand_some_zero(self.rng())
  dense = rng(shape, 'int32')
  sparse = BCOO.fromdense(dense, n_batch=n_batch, n_dense=n_dense)

  expected = dense.reshape(new_shape)
  actual = sparse.reshape(new_shape).todense()
  self.assertArraysEqual(expected, actual)
def testWeakTypes(self):
  """Python scalar times sparse int32 array yields the same dtype as dense.

  Regression test for https://github.com/google/jax/issues/8267
  """
  M = jnp.arange(12, dtype='int32').reshape(3, 4)
  expected = operator.mul(2, M)
  sparse_product = self.sparsify(operator.mul)(2, BCOO.fromdense(M))
  self.assertArraysEqual(expected, sparse_product.todense(), check_dtypes=True)
def testToDense(self):
  """`todense` inside a sparsified function works on dense and sparse input, with and without jit."""
  M = jnp.arange(4)
  Msp = BCOO.fromdense(M)

  @self.sparsify
  def func(mat):
    return todense(mat) + 1

  expected = M + 1
  for fn in (func, jit(func)):
    for operand in (M, Msp):
      self.assertArraysEqual(fn(operand), expected)
def testDropvar(self):
  """Sparsify a function that discards one output of a jitted inner call."""
  def inner(x):
    return x * 2, x * 3

  def f(x):
    _, y = jit(inner)(x)
    return y * 4

  x_dense = jnp.arange(5)
  expected = f(x_dense)
  actual = self.sparsify(f)(BCOO.fromdense(x_dense)).todense()
  self.assertArraysEqual(actual, expected)
def testSparseMul(self):
  """Sparse multiply by a scalar, and elementwise with a shared index buffer."""
  x = BCOO.fromdense(jnp.arange(5))
  y = BCOO.fromdense(2 * jnp.arange(5))

  # Scalar multiplication.
  factor = 2.5
  scaled = self.sparsify(operator.mul)(x, factor)
  self.assertArraysEqual(scaled.todense(), x.todense() * factor)

  # Shared indices – requires lower level call. x and y are zero at the same
  # position, so both values can point at buffer 0 (x.indices).
  spenv = SparsifyEnv([x.indices, x.data, y.data])
  spvalues = [
      spenv.sparse(arr.shape, data_ref=ref, indices_ref=0)
      for arr, ref in ((x, 1), (y, 2))
  ]
  args_out, _ = sparsify_raw(operator.mul)(spenv, *spvalues)
  out, = spvalues_to_arrays(spenv, args_out)
  self.assertAllClose(out.todense(), x.todense() * y.todense())
def _sparse_rewriting_take(arr, idx, indices_are_sorted=False,
                           unique_indices=False, mode=None, fill_value=None):
  """Index `arr` with `idx` by sparsifying `lax_numpy._gather`.

  Mirrors lax_numpy._rewriting_take. The index is split into static and
  dynamic parts, and only the dynamic part is passed through the sparsified
  gather. If the gather yields a non-BCOO result with zero elements, that
  result is wrapped with BCOO.fromdense before it is returned.
  """
  treedef, static_idx, dynamic_idx = lax_numpy._split_index_for_jit(
      idx, arr.shape)

  def gather(operand, dyn_idx):
    return lax_numpy._gather(operand, treedef, static_idx, dyn_idx,
                             indices_are_sorted, unique_indices, mode,
                             fill_value)

  result = sparsify(gather)(arr, dynamic_idx)
  # Account for a corner case in the rewriting_take implementation.
  needs_wrap = not isinstance(result, BCOO) and np.size(result) == 0
  return BCOO.fromdense(result) if needs_wrap else result
def testSparseCondSimple(self):
  """A sparsified lax.cond on a sparse operand matches the dense result."""
  def identity_branch(v):
    return v

  def double_branch(v):
    return 2 * v

  def func(x):
    return lax.cond(False, identity_branch, double_branch, x)

  x = jnp.arange(5.0)
  expected = func(x)
  actual = self.sparsify(func)(BCOO.fromdense(x))
  self.assertArraysAllClose(expected, actual.todense())
def testSparsifySparseXlaCall(self):
  """Sparse lowering of an XLA call: sparsify a jit-wrapped function."""
  def double(M):
    return 2 * M

  dense = jnp.arange(6).reshape(2, 3)
  sparse = BCOO.fromdense(dense)
  expected = double(dense)
  actual = self.sparsify(jit(double))(sparse)
  self.assertArraysEqual(expected, actual.todense())
def testSparseSqueeze(self, shape, dimensions, n_batch, n_dense):
  """A sparsified lax.squeeze matches the dense squeeze."""
  rng = jtu.rand_default(self.rng())
  M_dense = rng(shape, np.float32)
  M_sparse = BCOO.fromdense(M_dense, n_batch=n_batch, n_dense=n_dense)

  squeeze = self.sparsify(partial(lax.squeeze, dimensions=dimensions))
  expected = squeeze(M_dense)
  actual = squeeze(M_sparse).todense()
  self.assertAllClose(actual, expected)
def testSparseMatmul(self):
  """Compare sparsified matmul with dense matmul for every operand layout."""
  X = jnp.arange(16).reshape(4, 4)
  Xsp = BCOO.fromdense(X)
  Y = jnp.ones(4)
  Ysp = BCOO.fromdense(Y)
  sparse_matmul = self.sparsify(operator.matmul)

  # dot_general (sparse @ dense), then rdot_general (dense @ sparse). Each
  # case is (sparse-side operands, dense reference operands).
  mixed_cases = [((Xsp, Y), (X, Y)), ((Y, Xsp), (Y, X))]
  for sparse_args, dense_args in mixed_cases:
    self.assertAllClose(sparse_matmul(*sparse_args),
                        operator.matmul(*dense_args))

  # spdot_general: both operands sparse; the result is densified to compare.
  self.assertAllClose(sparse_matmul(Xsp, Ysp).todense(),
                      operator.matmul(X, Y))
def testSparseSubtract(self):
  """Sparse subtraction with distinct index buffers and with a shared one."""
  x = BCOO.fromdense(3 * jnp.arange(5))
  y = BCOO.fromdense(jnp.arange(5))

  # Distinct indices: the result concatenates both operands' entries, so the
  # nse is 4 + 4.
  diff = self.sparsify(operator.sub)(x, y)
  self.assertEqual(diff.nse, 8)  # uses concatenation.
  self.assertArraysEqual(diff.todense(), 2 * jnp.arange(5))

  # Shared indices – requires lower level call.
  spenv = SparsifyEnv([x.indices, x.data, y.data])
  spvalues = [
      spenv.sparse(arr.shape, data_ref=ref, indices_ref=0)
      for arr, ref in ((x, 1), (y, 2))
  ]
  args_out, _ = sparsify_raw(operator.sub)(spenv, *spvalues)
  out, = spvalues_to_arrays(spenv, args_out)
  self.assertAllClose(out.todense(), x.todense() - y.todense())
def testSparsifyWithConsts(self):
  """A sparsified function that calls a jitted axis-1 sum matches dense."""
  M_dense = jnp.arange(24).reshape(4, 6)
  M_sparse = BCOO.fromdense(M_dense)

  @self.sparsify
  def func(x):
    def row_sum(v):
      return jnp.sum(v, 1)
    return jit(row_sum)(x)

  expected = func(M_dense)
  actual = func(M_sparse)
  self.assertAllClose(actual.todense(), expected)
def argspecs_to_arrays(
    spenv: SparseEnv,
    argspecs: Sequence[ArgSpec],
) -> Sequence[AnyArray]:
  """Materialize each ArgSpec from `spenv` and return the arrays as a tuple.

  Sparse specs become BCOO arrays built from their data and indices buffers.
  Other specs return their data buffer. Each materialized array's shape is
  asserted to equal the spec's shape.
  """
  def _materialize(argspec):
    if argspec.is_sparse():
      assert argspec.indices_ref is not None
      array = BCOO((argspec.data(spenv), argspec.indices(spenv)),
                   shape=argspec.shape)
    else:
      array = argspec.data(spenv)
    assert array.shape == argspec.shape
    return array

  return tuple(_materialize(spec) for spec in argspecs)
def testSparseForiLoop(self):
  """A sparsified fori_loop with a sparse matrix in its body matches dense."""
  def func(mat, vec):
    def body_fun(_, val):
      return (mat @ val) / mat.shape[1]
    return lax.fori_loop(0, 2, body_fun, vec)

  x = jnp.arange(5.0)
  M = jnp.arange(25).reshape(5, 5)
  expected = func(M, x)
  actual = self.sparsify(func)(BCOO.fromdense(M), x)
  self.assertArraysAllClose(expected, actual)
def testSparseReshapeWithDimensions(self, shape, new_shape, n_batch, n_dense,
                                    dimensions):
  """A sparsified lax.reshape with a `dimensions` permutation matches dense."""
  rng = jtu.rand_some_zero(self.rng())
  dense = rng(shape, 'int32')
  sparse = BCOO.fromdense(dense, n_batch=n_batch, n_dense=n_dense)

  def reshape(x):
    return lax.reshape(x, new_shape, dimensions=dimensions)

  f = self.sparsify(reshape)
  expected = f(dense)
  actual = f(sparse)
  self.assertArraysEqual(expected, actual.todense())
def testSparsify(self):
  """A chain of elementwise ops, a transpose and a matvec agrees on dense and sparse input."""
  M_dense = jnp.arange(24).reshape(4, 6)
  M_sparse = BCOO.fromdense(M_dense)
  v = jnp.arange(M_dense.shape[0])

  @self.sparsify
  def func(x, v):
    angles = jnp.pi * x
    return -jnp.sin(angles).T @ (v + 1)

  expected = func(M_dense, v)
  actual = func(M_sparse, v)
  self.assertAllClose(actual, expected)
def testNotImplementedMessages(self):
  """Primitives without sparse rules raise NotImplementedError with specific messages."""
  x = BCOO.fromdense(jnp.arange(5.0))
  cases = [
      # A densifying primitive.
      (r"^sparse rule for cos is not implemented because it would result in dense output\.",
       lax.cos, (x,)),
      # A generic not-implemented primitive.
      (r"^sparse rule for complex is not implemented\.$",
       lax.complex, (x, x)),
  ]
  for pattern, primitive_fn, operands in cases:
    with self.assertRaisesRegex(NotImplementedError, pattern):
      self.sparsify(primitive_fn)(*operands)
def testSparseSum(self):
  """Sparsified sums over all axes, over one axis, and over an explicit axis tuple match dense.

  Sparse results that come back as BCOO are densified before comparison.
  """
  x = jnp.arange(20).reshape(4, 5)
  xsp = BCOO.fromdense(x)

  def f(x):
    return x.sum(), x.sum(0), x.sum(1), x.sum((0, 1))

  result_dense = f(x)
  result_sparse = self.sparsify(f)(xsp)
  # A unittest assertion, unlike a bare `assert`, still runs under
  # `python -O` and reports both lengths on failure.
  self.assertLen(result_sparse, len(result_dense))
  for res_dense, res_sparse in zip(result_dense, result_sparse):
    if isinstance(res_sparse, BCOO):
      res_sparse = res_sparse.todense()
    self.assertArraysAllClose(res_dense, res_sparse)
def testSparsify(self):
  """Elementwise ops, a transpose and a matvec agree on dense and sparse input.

  The sparse evaluation runs with the CuSparseEfficiencyWarning about
  unsorted indices suppressed.
  """
  M_dense = jnp.arange(24).reshape(4, 6)
  M_sparse = BCOO.fromdense(M_dense)
  v = jnp.arange(M_dense.shape[0])

  @self.sparsify
  def func(x, v):
    angles = jnp.pi * x
    return -jnp.sin(angles).T @ (v + 1)

  with jtu.ignore_warning(
      category=CuSparseEfficiencyWarning,
      message=
      "bcoo_dot_general GPU lowering requires matrices with sorted indices*"
  ):
    actual = func(M_sparse, v)
  expected = func(M_dense, v)
  self.assertAllClose(actual, expected)