Пример #1
0
    def testSparseMul(self, shape, dtype, n_batch, n_dense, unique_indices):
        # Checks sparsified operator.mul in two configurations:
        #   1. a BCOO operand times a Python scalar (high-level wrapper);
        #   2. two operands that reference the same indices buffer
        #      (low-level sparsify_raw call).
        make_dense = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
        dense_input = make_dense(shape, dtype)
        x = BCOO.fromdense(dense_input, n_batch=n_batch, n_dense=n_dense)

        # Case 1: scalar multiplication.
        scalar = 2
        sparse_mul = self.sparsify(operator.mul)
        y = sparse_mul(x, scalar)
        self.assertArraysEqual(x.todense() * scalar, y.todense())

        # Case 2: shared indices - requires the lower level call.
        # Both sparse values use indices_ref=0 (x.indices is element 0 of
        # the buffer list); data_ref/indices_ref are presumably positions
        # in that list - confirm against SparsifyEnv.
        buffers = [x.indices, x.data, y.data]
        spenv = SparsifyEnv(buffers)
        lhs = spenv.sparse(
            x.shape,
            data_ref=1,
            indices_ref=0,
            unique_indices=unique_indices,
        )
        rhs = spenv.sparse(
            y.shape,
            data_ref=2,
            indices_ref=0,
            unique_indices=unique_indices,
        )

        raw_mul = sparsify_raw(operator.mul)
        # The second element of the raw result is not needed here.
        args_out, _ = raw_mul(spenv, lhs, rhs)
        out, = spvalues_to_arrays(spenv, args_out)

        self.assertAllClose(out.todense(), x.todense() * y.todense())
Пример #2
0
    def testSparseMul(self):
        # Checks sparsified operator.mul for a scalar operand and for two
        # sparse operands that reference the same indices buffer.
        base = jnp.arange(5)
        x = BCOO.fromdense(base)
        y = BCOO.fromdense(2 * base)

        # Scalar multiplication through the high-level wrapper.
        sparse_mul = self.sparsify(operator.mul)
        scaled = sparse_mul(x, 2.5)
        self.assertArraysEqual(scaled.todense(), x.todense() * 2.5)

        # Shared indices - requires the lower level call.
        # Positional ArgSpec arguments are presumably
        # (shape, data_ref, indices_ref) into the buffer list below -
        # confirm against the ArgSpec definition.
        lhs_spec = ArgSpec(x.shape, 1, 0)
        rhs_spec = ArgSpec(y.shape, 2, 0)
        spenv = SparseEnv([x.indices, x.data, y.data])

        raw_mul = sparsify_raw(operator.mul)
        # The second element of the raw result is not needed here.
        args_out, _ = raw_mul(spenv, lhs_spec, rhs_spec)
        product, = argspecs_to_arrays(spenv, args_out)

        self.assertAllClose(product.todense(), x.todense() * y.todense())
Пример #3
0
    def testSparseSubtract(self):
        # Checks sparsified operator.sub for operands with independent
        # indices and for operands that reference the same indices buffer.
        base = jnp.arange(5)
        x = BCOO.fromdense(3 * base)
        y = BCOO.fromdense(base)

        # Distinct indices: the expected nse of 8 reflects concatenation
        # of the two operands' stored elements.
        sparse_sub = self.sparsify(operator.sub)
        difference = sparse_sub(x, y)
        self.assertEqual(difference.nse, 8)
        self.assertArraysEqual(difference.todense(), 2 * base)

        # Shared indices - requires the lower level call.
        # Positional ArgSpec arguments are presumably
        # (shape, data_ref, indices_ref) into the buffer list below -
        # confirm against the ArgSpec definition.
        lhs_spec = ArgSpec(x.shape, 1, 0)
        rhs_spec = ArgSpec(y.shape, 2, 0)
        spenv = SparseEnv([x.indices, x.data, y.data])

        raw_sub = sparsify_raw(operator.sub)
        # The second element of the raw result is not needed here.
        args_out, _ = raw_sub(spenv, lhs_spec, rhs_spec)
        shared, = argspecs_to_arrays(spenv, args_out)

        self.assertAllClose(shared.todense(), x.todense() - y.todense())
Пример #4
0
    def testSparseMul(self):
        # Checks sparsified operator.mul for a scalar operand and for two
        # sparse operands that reference the same indices buffer.
        base = jnp.arange(5)
        x = BCOO.fromdense(base)
        y = BCOO.fromdense(2 * base)

        # Scalar multiplication through the high-level wrapper.
        sparse_mul = self.sparsify(operator.mul)
        scaled = sparse_mul(x, 2.5)
        self.assertArraysEqual(scaled.todense(), x.todense() * 2.5)

        # Shared indices - requires the lower level call.
        # Both sparse values use indices_ref=0 (x.indices is element 0 of
        # the buffer list); the refs are presumably positions in that
        # list - confirm against SparsifyEnv.
        buffers = [x.indices, x.data, y.data]
        spenv = SparsifyEnv(buffers)
        lhs = spenv.sparse(x.shape, data_ref=1, indices_ref=0)
        rhs = spenv.sparse(y.shape, data_ref=2, indices_ref=0)

        raw_mul = sparsify_raw(operator.mul)
        # The second element of the raw result is not needed here.
        args_out, _ = raw_mul(spenv, lhs, rhs)
        product, = spvalues_to_arrays(spenv, args_out)

        self.assertAllClose(product.todense(), x.todense() * y.todense())
Пример #5
0
    def testSparseSubtract(self):
        # Checks sparsified operator.sub for operands with independent
        # indices and for operands that reference the same indices buffer.
        base = jnp.arange(5)
        x = BCOO.fromdense(3 * base)
        y = BCOO.fromdense(base)

        # Distinct indices: the expected nse of 8 reflects concatenation
        # of the two operands' stored elements.
        sparse_sub = self.sparsify(operator.sub)
        difference = sparse_sub(x, y)
        self.assertEqual(difference.nse, 8)
        self.assertArraysEqual(difference.todense(), 2 * base)

        # Shared indices - requires the lower level call.
        # Both sparse values use indices_ref=0 (x.indices is element 0 of
        # the buffer list); the refs are presumably positions in that
        # list - confirm against SparsifyEnv.
        buffers = [x.indices, x.data, y.data]
        spenv = SparsifyEnv(buffers)
        lhs = spenv.sparse(x.shape, data_ref=1, indices_ref=0)
        rhs = spenv.sparse(y.shape, data_ref=2, indices_ref=0)

        raw_sub = sparsify_raw(operator.sub)
        # The second element of the raw result is not needed here.
        args_out, _ = raw_sub(spenv, lhs, rhs)
        shared, = spvalues_to_arrays(spenv, args_out)

        self.assertAllClose(shared.todense(), x.todense() - y.todense())