def testSparsifyValue(self):
  """Round-trip arrays through SparsifyEnv / SparsifyValue and back.

  Two buffer layouts are exercised:
    * independent index -- ``arrays_to_spvalues`` builds the environment and
      each sparse argument gets its own data and indices buffers;
    * shared index -- a hand-built environment in which both sparse values
      reference the same indices buffer.
  In both cases ``spvalues_to_arrays`` must reproduce the original inputs.
  """
  x_dense = jnp.arange(5)
  x_bcoo = BCOO.fromdense(x_dense)
  args = (x_dense, x_bcoo, x_bcoo)

  def check_roundtrip(env, values):
    # Convert back to arrays; the dense slot must match exactly and the two
    # sparse slots must be identical BCOO matrices.
    restored = spvalues_to_arrays(env, values)
    self.assertEqual(len(restored), len(args))
    self.assertArraysEqual(args[0], restored[0])
    for expected, actual in zip(args[1:], restored[1:]):
      self.assertBcooIdentical(expected, actual)

  # Independent index: the test expects 5 buffers even though both sparse
  # arguments are the same BCOO object -- buffer 0 for the dense array, then
  # (data, indices) pairs at (1, 2) and (3, 4).
  spenv = SparsifyEnv()
  spvalues = arrays_to_spvalues(spenv, args)
  self.assertEqual(len(spvalues), len(args))
  self.assertLen(spenv._buffers, 5)
  self.assertEqual(
      spvalues,
      (SparsifyValue(x_dense.shape, 0, None),
       SparsifyValue(x_dense.shape, 1, 2),
       SparsifyValue(x_dense.shape, 3, 4)))
  check_roundtrip(spenv, spvalues)

  # Shared index: both sparse values point at buffer 2 for their indices,
  # while keeping separate data buffers (1 and 3).
  shared_values = (SparsifyValue(x_dense.shape, 0, None),
                   SparsifyValue(x_dense.shape, 1, 2),
                   SparsifyValue(x_dense.shape, 3, 2))
  shared_env = SparsifyEnv(
      [x_dense, x_bcoo.data, x_bcoo.indices, x_bcoo.data])
  check_roundtrip(shared_env, shared_values)
def testSparseMul(self, shape, dtype, n_batch, n_dense, unique_indices):
  """Check sparsified ``operator.mul`` on a random BCOO operand.

  Args:
    shape, dtype: shape and dtype of the random dense array that is
      converted to BCOO.
    n_batch, n_dense: batch / dense dimension counts passed to
      ``BCOO.fromdense``.
    unique_indices: flag forwarded to ``spenv.sparse`` for both operands of
      the shared-index product.
  """
  make_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
  x = BCOO.fromdense(make_sparse(shape, dtype),
                     n_batch=n_batch, n_dense=n_dense)

  # Scalar multiplication.
  scalar = 2
  y = self.sparsify(operator.mul)(x, scalar)
  self.assertArraysEqual(x.todense() * scalar, y.todense())

  # Shared indices -- requires a lower-level call. Both operands use
  # x.indices (buffer 0); this assumes y keeps x's index layout, since y was
  # produced from x by a scalar multiply.
  spenv = SparsifyEnv([x.indices, x.data, y.data])
  operands = [
      spenv.sparse(arr.shape, data_ref=ref, indices_ref=0,
                   unique_indices=unique_indices)
      for arr, ref in ((x, 1), (y, 2))
  ]
  spvalues_out, _ = sparsify_raw(operator.mul)(spenv, *operands)
  (out,) = spvalues_to_arrays(spenv, spvalues_out)
  self.assertAllClose(out.todense(), x.todense() * y.todense())
def testSparseMul(self):
  """Check sparsified ``operator.mul`` on fixed 1-D BCOO operands.

  NOTE(review): this listing also contains a parameterized method named
  ``testSparseMul``. If both are defined in the same class, the later
  definition replaces the earlier one and only one of them runs. Confirm,
  and rename one of them if so.
  """
  base = jnp.arange(5)
  x = BCOO.fromdense(base)
  y = BCOO.fromdense(2 * base)

  # Scalar multiplication by a non-integer factor.
  factor = 2.5
  scaled = self.sparsify(operator.mul)(x, factor)
  self.assertArraysEqual(scaled.todense(), x.todense() * factor)

  # Shared indices -- requires a lower-level call. x and y come from arrays
  # with the same zero pattern, so both operands point at x.indices
  # (buffer 0) with their own data buffers (1 and 2).
  spenv = SparsifyEnv([x.indices, x.data, y.data])
  operands = [spenv.sparse(arr.shape, data_ref=ref, indices_ref=0)
              for arr, ref in ((x, 1), (y, 2))]
  spvalues_out, _ = sparsify_raw(operator.mul)(spenv, *operands)
  (out,) = spvalues_to_arrays(spenv, spvalues_out)
  self.assertAllClose(out.todense(), x.todense() * y.todense())
def testSparseSubtract(self):
  """Check sparsified ``operator.sub`` with distinct and shared indices."""
  base = jnp.arange(5)
  x = BCOO.fromdense(3 * base)
  y = BCOO.fromdense(base)

  # Distinct indices: each operand carries its own index buffer. Each input
  # has 4 nonzeros, and the expected nse of 8 reflects that the entries are
  # concatenated rather than merged.
  diff = self.sparsify(operator.sub)(x, y)
  self.assertEqual(diff.nse, 8)
  self.assertArraysEqual(diff.todense(), 2 * base)

  # Shared indices -- requires a lower-level call. x and y come from arrays
  # with the same zero pattern, so both operands point at x.indices
  # (buffer 0) with their own data buffers (1 and 2).
  spenv = SparsifyEnv([x.indices, x.data, y.data])
  operands = [spenv.sparse(arr.shape, data_ref=ref, indices_ref=0)
              for arr, ref in ((x, 1), (y, 2))]
  spvalues_out, _ = sparsify_raw(operator.sub)(spenv, *operands)
  (out,) = spvalues_to_arrays(spenv, spvalues_out)
  self.assertAllClose(out.todense(), x.todense() - y.todense())