Ejemplo n.º 1
0
  def test_bcoo_dedupe(self, shape, dtype, n_batch, n_dense):
    """Check that _dedupe_bcoo leaves the densified BCOO value unchanged.

    The sparse coordinates returned by ``bcoo_fromdense`` are overwritten with
    random in-range values (which may create duplicate entries); densifying
    before and after ``_dedupe_bcoo`` must give the same array.
    """
    np_rng = self.rng()
    make_sparse = rand_sparse(self.rng())
    dense = make_sparse(shape, dtype)
    data, indices = sparse.bcoo_fromdense(dense, n_batch=n_batch, n_dense=n_dense)

    sparse_dims = shape[n_batch:len(shape) - n_dense]
    for dim, extent in enumerate(sparse_dims):
      # Replace coordinate `dim` of every stored entry with a random value in
      # [0, extent); the indexing implies a (..., n_sparse, nse) layout.
      random_coords = np_rng.randint(0, extent, size=indices.shape[-1])
      indices = indices.at[..., dim, :].set(random_coords)

    deduped_data, deduped_indices = _dedupe_bcoo(data, indices)
    expected = sparse.bcoo_todense(data, indices, shape=shape)
    actual = sparse.bcoo_todense(deduped_data, deduped_indices, shape=shape)

    self.assertAllClose(expected, actual)
Ejemplo n.º 2
0
  def test_bcoo_todense_partial_batch(self, shape, dtype, n_batch, n_dense):
    """Check that a size-1 leading batch in data or indices broadcasts.

    Slicing ``indices`` (or ``data``) down to its first batch element must
    densify to the same array as explicitly repeating that first element
    ``shape[0]`` times.
    """
    make_sparse = rand_sparse(self.rng())
    M = make_sparse(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    def repeat_first(arr):
      # Materialize the broadcast explicitly: shape[0] copies of arr[0].
      return jnp.stack([arr[0]] * shape[0])

    # Only the indices carry a size-1 batch.
    partial_indices = sparse.bcoo_todense(data, indices[:1], shape=M.shape)
    tiled_indices = sparse.bcoo_todense(data, repeat_first(indices), shape=M.shape)
    self.assertAllClose(partial_indices, tiled_indices)

    # Only the data carry a size-1 batch.
    partial_data = sparse.bcoo_todense(data[:1], indices, shape=M.shape)
    tiled_data = sparse.bcoo_todense(repeat_first(data), indices, shape=M.shape)
    self.assertAllClose(partial_data, tiled_data)
Ejemplo n.º 3
0
def _todense_sparse_rule(spenv, argspec, *, tree):
    """Sparse rule that densifies a BCOO argument.

    Reads the data and indices referenced by ``argspec`` from ``spenv``,
    builds the dense array of shape ``argspec.shape`` with
    ``sparse.bcoo_todense``, pushes it onto ``spenv``, and returns a 1-tuple
    holding an ``ArgSpec`` for the pushed buffer. The third ``ArgSpec`` field
    is ``None`` (presumably the indices reference, marking the result as
    dense; TODO confirm against ``ArgSpec``).
    """
    # TODO(jakvdp): we should assert that tree is PytreeDef(*)
    del tree
    data = argspec.data(spenv)
    indices = argspec.indices(spenv)
    dense = sparse.bcoo_todense(data, indices, shape=argspec.shape)
    dense_ref = spenv.push(dense)
    return (sparse.transform.ArgSpec(argspec.shape, dense_ref, None),)
Ejemplo n.º 4
0
 def test_bcoo_reduce_sum(self, shape, dtype, n_batch, n_dense, axes):
   """Check bcoo_reduce_sum against summing the dense matrix over ``axes``."""
   # Per-dtype tolerances passed as both atol and rtol.
   tol = {np.float32: 1E-6, np.float64: 1E-14}

   make_sparse = rand_sparse(self.rng())
   M = make_sparse(shape, dtype)
   expected = M.sum(axes)

   data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)
   out_data, out_indices, out_shape = sparse.bcoo_reduce_sum(
       data, indices, shape=shape, axes=axes)
   actual = sparse.bcoo_todense(out_data, out_indices, shape=out_shape)

   self.assertAllClose(expected, actual, atol=tol, rtol=tol)
Ejemplo n.º 5
0
  def test_bcoo_dot_general_partial_batch(self, lhs_shape, rhs_shape, dtype,
                                          dimension_numbers, n_batch, n_dense):
    """Check bcoo_dot_general when data and/or indices have a size-1 batch.

    Every combination of full vs. first-batch-only ``data`` and ``indices``
    is densified, and the sparse product is compared with ``lax.dot_general``
    applied to that dense matrix.
    """
    rand_dense = jtu.rand_small(self.rng())
    rand_sp = rand_sparse(self.rng())

    lhs_dense = rand_sp(lhs_shape, dtype)
    full_data, full_indices = sparse.bcoo_fromdense(
        lhs_dense, n_batch=n_batch, n_dense=n_dense)
    rhs = rand_dense(rhs_shape, dtype)
    dense_shape = lhs_dense.shape

    data_variants = [full_data, full_data[:1]]
    indices_variants = [full_indices, full_indices[:1]]
    for data, indices in itertools.product(data_variants, indices_variants):
      # Densify this exact combination so the reference sees the same values.
      reference_lhs = sparse.bcoo_todense(data, indices, shape=dense_shape)
      expected = lax.dot_general(reference_lhs, rhs,
                                 dimension_numbers=dimension_numbers)
      actual = sparse.bcoo_dot_general(data, indices, rhs,
                                       lhs_shape=dense_shape,
                                       dimension_numbers=dimension_numbers)
      self.assertAllClose(expected, actual)