Exemplo n.º 1
0
    def test_bcoo_dot_general_partial_batch(self, lhs_shape, rhs_shape, dtype,
                                            dimension_numbers, n_batch,
                                            n_dense):
        """Check bcoo_dot_general when data/indices are sliced to one leading entry.

        The left operand is converted to BCOO form, then every combination of
        the full buffers and their ``[:1]`` leading slices is multiplied by a
        dense right operand. Each result must match ``lax.dot_general`` applied
        to the densified (possibly sliced) left operand.
        """
        dense_rng = jtu.rand_small(self.rng())
        sparse_rng = rand_sparse(self.rng())

        X = sparse_rng(lhs_shape, dtype)
        full_data, full_indices = sparse_ops.bcoo_fromdense(
            X, n_batch=n_batch, n_dense=n_dense)
        Y = dense_rng(rhs_shape, dtype)
        # The densified operand below always has this same shape, so capture
        # it once instead of re-reading it from a rebound variable.
        dense_shape = X.shape

        def f_dense(lhs, rhs):
            return lax.dot_general(lhs, rhs,
                                   dimension_numbers=dimension_numbers)

        def f_sparse(lhs_data, lhs_indices, rhs):
            return sparse_ops.bcoo_dot_general(
                lhs_data, lhs_indices, rhs,
                lhs_shape=dense_shape,
                dimension_numbers=dimension_numbers)

        data_variants = [full_data, full_data[:1]]
        index_variants = [full_indices, full_indices[:1]]
        for data, indices in itertools.product(data_variants, index_variants):
            # Densify exactly the buffers handed to the sparse kernel so the
            # dense reference sees the same matrix.
            lhs_dense = sparse_ops.bcoo_todense(data, indices,
                                                shape=dense_shape)
            self.assertAllClose(f_dense(lhs_dense, Y),
                                f_sparse(data, indices, Y))
Exemplo n.º 2
0
    def test_bcoo_dot_general_ad(self, lhs_shape, rhs_shape, dtype,
                                 dimension_numbers, n_batch, n_dense):
        """Compare forward- and reverse-mode Jacobians of bcoo_dot_general.

        Derivatives are taken with respect to the dense right operand ``Y``
        only; the sparse left operand is held fixed. The sparse product's
        Jacobians must match those of ``lax.dot_general`` on the dense
        operand, and the sparse forward/reverse Jacobians must agree.
        """
        dense_rng = jtu.rand_small(self.rng())
        sparse_rng = rand_sparse(self.rng())

        X = sparse_rng(lhs_shape, dtype)
        data, indices = sparse_ops.bcoo_fromdense(
            X, n_batch=n_batch, n_dense=n_dense)
        Y = dense_rng(rhs_shape, dtype)

        def f_dense(rhs):
            return lax.dot_general(X, rhs, dimension_numbers=dimension_numbers)

        def f_sparse(rhs):
            return sparse_ops.bcoo_dot_general(
                data, indices, rhs,
                lhs_shape=X.shape,
                dimension_numbers=dimension_numbers)

        # (jacfwd, jacrev) pairs, computed dense-first then sparse.
        jacobians = {}
        for label, fn in (("dense", f_dense), ("sparse", f_sparse)):
            jacobians[label] = (jax.jacfwd(fn)(Y), jax.jacrev(fn)(Y))
        jf_dense, jr_dense = jacobians["dense"]
        jf_sparse, jr_sparse = jacobians["sparse"]

        # Relax the float32 relative tolerance when running on TPU.
        tol = {np.float32: 5E-3} if jtu.device_under_test() == "tpu" else {}

        comparisons = (
            (jf_dense, jf_sparse),
            (jr_dense, jr_sparse),
            (jf_sparse, jr_sparse),
        )
        for expected, actual in comparisons:
            self.assertAllClose(expected, actual, rtol=tol)
Exemplo n.º 3
0
 def args_maker():
     """Return ``(data, indices, lhs, rhs)`` for one test invocation.

     ``lhs`` is drawn with ``rng`` and ``rhs`` with ``rng_sparse``;
     ``data``/``indices`` are the BCOO encoding of ``rhs`` using the
     enclosing scope's ``n_batch``/``n_dense``.
     """
     dense_lhs = rng(lhs_shape, dtype)
     dense_rhs = rng_sparse(rhs_shape, dtype)
     rhs_data, rhs_indices = sparse_ops.bcoo_fromdense(
         dense_rhs, n_batch=n_batch, n_dense=n_dense)
     return rhs_data, rhs_indices, dense_lhs, dense_rhs
Exemplo n.º 4
0
 def test_bcoo_extract(self, shape, dtype, n_batch, n_dense):
     """bcoo_extract(indices, M) must reproduce the data from bcoo_fromdense(M).

     The check runs both eagerly and under ``jit``, for the BCOO layout
     selected by the ``n_batch``/``n_dense`` test parameters.
     """
     rng = rand_sparse(self.rng())
     M = rng(shape, dtype)
     # Forward the layout parameters. Without them, bcoo_fromdense uses its
     # defaults and every parametrization tests the same layout.
     data, indices = sparse_ops.bcoo_fromdense(M,
                                               n_batch=n_batch,
                                               n_dense=n_dense)
     data2 = sparse_ops.bcoo_extract(indices, M)
     self.assertArraysEqual(data, data2)
     data3 = jit(sparse_ops.bcoo_extract)(indices, M)
     self.assertArraysEqual(data, data3)
Exemplo n.º 5
0
    def test_bcoo_extract_ad(self, shape, dtype, n_batch, n_dense):
        """Check derivatives of bcoo_extract with respect to the dense input.

        Forward- and reverse-mode Jacobians must agree. The Jacobian must have
        shape ``data.shape + M.shape`` and the Hessian
        ``data.shape + M.shape + M.shape``.
        """
        M = rand_sparse(self.rng())(shape, dtype)
        data, indices = sparse_ops.bcoo_fromdense(
            M, n_batch=n_batch, n_dense=n_dense)

        def extract(mat):
            # Indices are fixed; only the dense matrix is differentiated.
            return sparse_ops.bcoo_extract(indices, mat)

        jac_fwd = jax.jacfwd(extract)(M)
        jac_rev = jax.jacrev(extract)(M)
        hess = jax.hessian(extract)(M)

        self.assertArraysAllClose(jac_fwd, jac_rev)
        self.assertEqual(jac_fwd.shape, data.shape + M.shape)
        self.assertEqual(hess.shape, data.shape + M.shape + M.shape)
Exemplo n.º 6
0
 def test_bcoo_reduce_sum(self, shape, dtype, n_batch, n_dense, axes):
     """bcoo_reduce_sum over ``axes`` must match a dense ``M.sum(axes)``."""
     M = rand_sparse(self.rng())(shape, dtype)
     data, indices = sparse_ops.bcoo_fromdense(
         M, n_batch=n_batch, n_dense=n_dense)

     reduced = sparse_ops.bcoo_reduce_sum(data, indices, shape=shape, axes=axes)
     data_out, indices_out, shape_out = reduced

     expected = M.sum(axes)
     actual = sparse_ops.bcoo_todense(data_out, indices_out, shape=shape_out)

     # Allow small floating-point round-off between the two reductions.
     tol = {np.float32: 1E-6, np.float64: 1E-14}
     self.assertAllClose(expected, actual, atol=tol, rtol=tol)
Exemplo n.º 7
0
    def test_bcoo_dedupe(self, shape, dtype, n_batch, n_dense):
        """_dedupe_bcoo must not change the dense matrix a BCOO array represents.

        The sparse index rows are overwritten with random coordinates, so
        duplicates may occur. Densifying before and after deduplication must
        give the same result.
        """
        np_rng = self.rng()
        sparse_rng = rand_sparse(self.rng())
        M = sparse_rng(shape, dtype)
        data, indices = sparse_ops.bcoo_fromdense(
            M, n_batch=n_batch, n_dense=n_dense)

        sparse_dims = shape[n_batch:len(shape) - n_dense]
        nse = indices.shape[-1]
        for axis, size in enumerate(sparse_dims):
            random_coords = np_rng.randint(0, size, size=nse)
            indices = indices.at[..., axis, :].set(random_coords)

        deduped_data, deduped_indices = sparse_ops._dedupe_bcoo(data, indices)
        before = sparse_ops.bcoo_todense(data, indices, shape=shape)
        after = sparse_ops.bcoo_todense(deduped_data, deduped_indices,
                                        shape=shape)

        self.assertAllClose(before, after)
Exemplo n.º 8
0
    def test_bcoo_todense_ad(self, shape, dtype, n_batch, n_dense):
        """Check derivatives of bcoo_todense with respect to the ``data`` buffer.

        Forward- and reverse-mode Jacobians must agree. The Jacobian must have
        shape ``M.shape + data.shape`` and the Hessian
        ``M.shape + data.shape + data.shape``.
        """
        M = rand_sparse(self.rng())(shape, dtype)
        data, indices = sparse_ops.bcoo_fromdense(
            M, n_batch=n_batch, n_dense=n_dense)

        def todense(values):
            # Indices and output shape are fixed; only data is differentiated.
            return sparse_ops.bcoo_todense(values, indices=indices,
                                           shape=shape)

        jac_fwd = jax.jacfwd(todense)(data)
        jac_rev = jax.jacrev(todense)(data)
        hess = jax.hessian(todense)(data)

        self.assertArraysAllClose(jac_fwd, jac_rev)
        self.assertEqual(jac_fwd.shape, M.shape + data.shape)
        self.assertEqual(hess.shape, M.shape + data.shape + data.shape)
Exemplo n.º 9
0
    def test_bcoo_todense_partial_batch(self, shape, dtype, n_batch, n_dense):
        """A size-1 leading slice must densify like an explicitly tiled one.

        Slicing ``indices`` (and separately ``data``) to its first leading
        entry must give the same dense matrix as stacking that entry
        ``shape[0]`` times.
        """
        M = rand_sparse(self.rng())(shape, dtype)
        data, indices = sparse_ops.bcoo_fromdense(
            M, n_batch=n_batch, n_dense=n_dense)
        # NOTE(review): treats the leading axis of data/indices as a batch
        # axis of length shape[0], so this presumably assumes n_batch >= 1;
        # confirm against the test's parametrization.
        reps = shape[0]

        def densify(d, i):
            return sparse_ops.bcoo_todense(d, i, shape=M.shape)

        # Sliced indices vs. indices[0] tiled explicitly.
        self.assertAllClose(densify(data, indices[:1]),
                            densify(data, jnp.stack(reps * [indices[0]])))

        # Sliced data vs. data[0] tiled explicitly.
        self.assertAllClose(densify(data[:1], indices),
                            densify(jnp.stack(reps * [data[0]]), indices))
Exemplo n.º 10
0
    def test_bcoo_dense_round_trip(self, shape, dtype, n_batch, n_dense):
        """Dense -> BCOO -> dense must reproduce ``M`` exactly.

        Also checks the layout of the BCOO buffers: ``data`` keeps ``M``'s
        dtype with shape ``shape[:n_batch] + (nse,) + dense dims``, and
        ``indices`` is int32 with shape ``shape[:n_batch] + (n_sparse, nse)``.
        The round trip is checked both eagerly and under ``jit``.
        """
        rng = rand_sparse(self.rng())
        M = rng(shape, dtype)
        n_sparse = M.ndim - n_batch - n_dense
        nse = int(sparse_ops._bcoo_nse(M, n_batch=n_batch, n_dense=n_dense))
        data, indices = sparse_ops.bcoo_fromdense(M,
                                                  n_batch=n_batch,
                                                  n_dense=n_dense)
        # TODO: test fromdense JIT

        # Use TestCase assertions instead of bare ``assert``. Bare asserts are
        # stripped under ``python -O``, which would let these layout checks
        # pass silently, and they do not report the mismatched values.
        self.assertEqual(data.dtype, dtype)
        self.assertEqual(data.shape,
                         shape[:n_batch] + (nse, ) + shape[n_batch + n_sparse:])
        self.assertEqual(indices.dtype, jnp.int32)  # TODO: test passing this arg
        self.assertEqual(indices.shape, shape[:n_batch] + (n_sparse, nse))

        todense = partial(sparse_ops.bcoo_todense, shape=shape)
        self.assertArraysEqual(M, todense(data, indices))
        self.assertArraysEqual(M, jit(todense)(data, indices))
Exemplo n.º 11
0
 def fromdense(M):
     """Return only the ``data`` buffer of the BCOO encoding of ``M``.

     ``nse``, ``n_batch`` and ``n_dense`` come from the enclosing scope.
     """
     bcoo_parts = sparse_ops.bcoo_fromdense(
         M, nse=nse, n_batch=n_batch, n_dense=n_dense)
     return bcoo_parts[0]