def testArgSpec(self):
    """Round-trip a dense array and two BCOO arrays through ArgSpecs.

    Covers two layouts of the SparseEnv buffer list: one where every BCOO
    argument owns its own data and indices slots, and one where two BCOO
    arguments reference a single shared indices slot.
    """
    dense = jnp.arange(5)
    sparse = BCOO.fromdense(dense)
    args = (dense, sparse, sparse)

    def check_round_trip(env, specs):
        # Rebuild arrays from the specs and compare each one to its input.
        rebuilt = argspecs_to_arrays(env, specs)
        self.assertEqual(len(rebuilt), len(args))
        self.assertArraysEqual(args[0], rebuilt[0])
        self.assertBcooIdentical(args[1], rebuilt[1])
        self.assertBcooIdentical(args[2], rebuilt[2])

    # Independent index: the dense array takes slot 0, and each BCOO gets a
    # (data, indices) slot pair of its own.
    env = SparseEnv()
    specs = arrays_to_argspecs(env, args)
    self.assertEqual(len(specs), len(args))
    self.assertEqual(env.size(), 5)
    expected_specs = (
        ArgSpec(dense.shape, 0, None),
        ArgSpec(dense.shape, 1, 2),
        ArgSpec(dense.shape, 3, 4),
    )
    self.assertEqual(specs, expected_specs)
    check_round_trip(env, specs)

    # Shared index: both BCOO specs point at the indices stored in slot 2.
    shared_specs = (
        ArgSpec(dense.shape, 0, None),
        ArgSpec(dense.shape, 1, 2),
        ArgSpec(dense.shape, 3, 2),
    )
    shared_env = SparseEnv([dense, sparse.data, sparse.indices, sparse.data])
    check_round_trip(shared_env, shared_specs)
def testSparseMul(self):
    """Check sparsified multiplication by a scalar and between shared-index BCOOs."""
    lhs = BCOO.fromdense(jnp.arange(5))
    rhs = BCOO.fromdense(2 * jnp.arange(5))

    # Scalar multiplication through the high-level sparsify wrapper.
    scaled = self.sparsify(operator.mul)(lhs, 2.5)
    self.assertArraysEqual(scaled.todense(), lhs.todense() * 2.5)

    # Shared indices require the lower-level sparsify_raw call. Both inputs
    # are zero only at position 0, so lhs.indices serves for both operands:
    # slot 0 holds those indices, and slots 1 and 2 hold each operand's data.
    specs = [ArgSpec(lhs.shape, 1, 0), ArgSpec(rhs.shape, 2, 0)]
    env = SparseEnv([lhs.indices, lhs.data, rhs.data])
    out_specs, _ = sparsify_raw(operator.mul)(env, *specs)
    (product,) = argspecs_to_arrays(env, out_specs)
    self.assertAllClose(product.todense(), lhs.todense() * rhs.todense())
def testSparseSubtract(self):
    """Check sparsified subtraction for distinct and for shared BCOO indices."""
    lhs = BCOO.fromdense(3 * jnp.arange(5))
    rhs = BCOO.fromdense(jnp.arange(5))

    # Distinct indices: the result is built by concatenating both operands'
    # stored entries, hence nse == 8.
    diff = self.sparsify(operator.sub)(lhs, rhs)
    self.assertEqual(diff.nse, 8)
    self.assertArraysEqual(diff.todense(), 2 * jnp.arange(5))

    # Shared indices require the lower-level sparsify_raw call. Both inputs
    # are zero only at position 0, so lhs.indices serves for both operands:
    # slot 0 holds those indices, and slots 1 and 2 hold each operand's data.
    specs = [ArgSpec(lhs.shape, 1, 0), ArgSpec(rhs.shape, 2, 0)]
    env = SparseEnv([lhs.indices, lhs.data, rhs.data])
    out_specs, _ = sparsify_raw(operator.sub)(env, *specs)
    (difference,) = argspecs_to_arrays(env, out_specs)
    self.assertAllClose(difference.todense(), lhs.todense() - rhs.todense())