コード例 #1
0
ファイル: sparsify_test.py プロジェクト: ahoenselaar/jax
    def testArgSpec(self):
        """Round-trip dense and BCOO arrays through ArgSpec / SparseEnv.

        Two layouts are exercised:

        * one produced by ``arrays_to_argspecs``, where every BCOO argument
          gets its own data and indices slots in the environment;
        * a hand-built one where two BCOO ArgSpecs point at the same indices
          slot while keeping separate data slots.

        In both cases ``argspecs_to_arrays`` must rebuild the original inputs.
        """
        dense = jnp.arange(5)
        sparse = BCOO.fromdense(dense)
        args = (dense, sparse, sparse)

        def check_roundtrip(env, specs):
            # Rebuild the arrays described by `specs` and compare them with
            # `args`: position 0 is dense, positions 1 and 2 are BCOO.
            rebuilt = argspecs_to_arrays(env, specs)
            self.assertEqual(len(rebuilt), len(args))
            self.assertArraysEqual(args[0], rebuilt[0])
            for i in (1, 2):
                self.assertBcooIdentical(args[i], rebuilt[i])

        # Independent index: slot 0 holds the dense array, and each BCOO
        # argument takes a (data, indices) pair of slots, giving 1 + 2 + 2 = 5.
        env = SparseEnv()
        specs = arrays_to_argspecs(env, args)
        self.assertEqual(len(specs), len(args))
        self.assertEqual(env.size(), 5)
        expected_specs = (
            ArgSpec(dense.shape, 0, None),
            ArgSpec(dense.shape, 1, 2),
            ArgSpec(dense.shape, 3, 4),
        )
        self.assertEqual(specs, expected_specs)
        check_roundtrip(env, specs)

        # Shared index: both BCOO specs read their indices from slot 2, while
        # their data live in slots 1 and 3 respectively.
        shared_specs = (
            ArgSpec(dense.shape, 0, None),
            ArgSpec(dense.shape, 1, 2),
            ArgSpec(dense.shape, 3, 2),
        )
        shared_env = SparseEnv([dense, sparse.data, sparse.indices, sparse.data])
        check_roundtrip(shared_env, shared_specs)
コード例 #2
0
ファイル: sparsify_test.py プロジェクト: ahoenselaar/jax
    def testSparseMul(self):
        """Check sparsified multiplication by a scalar and by a BCOO operand.

        The BCOO-by-BCOO case drives ``sparsify_raw`` directly so that both
        operands can be described by ArgSpecs sharing one indices slot.
        """
        lhs = BCOO.fromdense(jnp.arange(5))
        rhs = BCOO.fromdense(2 * jnp.arange(5))

        # Scalar multiplication
        scaled = self.sparsify(operator.mul)(lhs, 2.5)
        self.assertArraysEqual(scaled.todense(), lhs.todense() * 2.5)

        # Shared indices – requires lower level call.
        # Both dense inputs are nonzero at exactly positions 1-4, so reusing
        # lhs.indices (slot 0) for rhs is expected to describe rhs correctly.
        specs = [ArgSpec(lhs.shape, 1, 0), ArgSpec(rhs.shape, 2, 0)]
        env = SparseEnv([lhs.indices, lhs.data, rhs.data])

        out_specs, _ = sparsify_raw(operator.mul)(env, *specs)
        (product,) = argspecs_to_arrays(env, out_specs)

        self.assertAllClose(product.todense(), lhs.todense() * rhs.todense())
コード例 #3
0
ファイル: sparsify_test.py プロジェクト: ahoenselaar/jax
    def testSparseSubtract(self):
        """Check sparsified subtraction for distinct and for shared indices.

        With distinct indices the high-level ``self.sparsify`` path is used. The
        shared-indices case goes through ``sparsify_raw`` with hand-built
        ArgSpecs that point both operands at the same indices slot.
        """
        minuend = BCOO.fromdense(3 * jnp.arange(5))
        subtrahend = BCOO.fromdense(jnp.arange(5))

        # Distinct indices: each operand stores 4 nonzeros and the result
        # concatenates them, hence 8 stored elements.
        difference = self.sparsify(operator.sub)(minuend, subtrahend)
        self.assertEqual(difference.nse, 8)  # uses concatenation.
        self.assertArraysEqual(difference.todense(), 2 * jnp.arange(5))

        # Shared indices – requires lower level call.
        # Both dense inputs are nonzero at exactly positions 1-4, so reusing
        # minuend.indices (slot 0) for subtrahend is expected to be valid.
        specs = [ArgSpec(minuend.shape, 1, 0), ArgSpec(subtrahend.shape, 2, 0)]
        env = SparseEnv([minuend.indices, minuend.data, subtrahend.data])

        out_specs, _ = sparsify_raw(operator.sub)(env, *specs)
        (shared_diff,) = argspecs_to_arrays(env, out_specs)

        self.assertAllClose(
            shared_diff.todense(), minuend.todense() - subtrahend.todense())