예제 #1
0
  def test_bcoo_dot_general_ad(self, lhs_shape, rhs_shape, dtype,
                               dimension_numbers, n_batch, n_dense):
    """Compare Jacobians of sparse and dense dot_general w.r.t. the rhs.

    The lhs is drawn from the sparse sampler and converted to BCOO
    ``(data, indices)`` form. Forward- and reverse-mode Jacobians of
    ``sparse.bcoo_dot_general`` with respect to the dense rhs are checked
    against those of ``lax.dot_general`` on the dense lhs, and against each
    other.
    """
    dense_sampler = jtu.rand_small(self.rng())
    sparse_sampler = rand_sparse(self.rng())

    lhs_dense = sparse_sampler(lhs_shape, dtype)
    lhs_data, lhs_indices = sparse.bcoo_fromdense(
        lhs_dense, n_batch=n_batch, n_dense=n_dense)
    rhs = dense_sampler(rhs_shape, dtype)

    def dense_product(rhs_arg):
      # Reference implementation operating on the dense lhs.
      return lax.dot_general(lhs_dense, rhs_arg,
                             dimension_numbers=dimension_numbers)

    def sparse_product(rhs_arg):
      # Implementation under test operating on the BCOO buffers.
      return sparse.bcoo_dot_general(lhs_data, lhs_indices, rhs_arg,
                                     lhs_shape=lhs_dense.shape,
                                     dimension_numbers=dimension_numbers)

    fwd_dense = jax.jacfwd(dense_product)(rhs)
    rev_dense = jax.jacrev(dense_product)(rhs)
    fwd_sparse = jax.jacfwd(sparse_product)(rhs)
    rev_sparse = jax.jacrev(sparse_product)(rhs)

    # A looser float32 relative tolerance is used only when running on TPU.
    on_tpu = jtu.device_under_test() == "tpu"
    rtol = {np.float32: 5E-3} if on_tpu else {}

    self.assertAllClose(fwd_dense, fwd_sparse, rtol=rtol)
    self.assertAllClose(rev_dense, rev_sparse, rtol=rtol)
    self.assertAllClose(fwd_sparse, rev_sparse, rtol=rtol)
예제 #2
0
 def test_bcoo_extract(self, shape, dtype, n_batch, n_dense):
   """Check that bcoo_extract recovers the data buffer made by bcoo_fromdense.

   The extraction is run both eagerly and through ``jit``; each result must
   equal the ``data`` returned by ``sparse.bcoo_fromdense`` exactly.
   """
   # NOTE(review): n_batch and n_dense are accepted but not forwarded to
   # bcoo_fromdense, so only its default layout is exercised here -- confirm
   # whether that is intentional.
   sampler = rand_sparse(self.rng())
   dense = sampler(shape, dtype)
   expected, indices = sparse.bcoo_fromdense(dense)

   extracted_eager = sparse.bcoo_extract(indices, dense)
   self.assertArraysEqual(expected, extracted_eager)

   extract_jitted = jit(sparse.bcoo_extract)
   extracted_jit = extract_jitted(indices, dense)
   self.assertArraysEqual(expected, extracted_jit)
예제 #3
0
 def test_bcoo_reduce_sum(self, shape, dtype, n_batch, n_dense, axes):
   """Check bcoo_reduce_sum against a dense ``sum`` over the same axes.

   The sparse result is converted back to dense with the shape returned by
   ``sparse.bcoo_reduce_sum`` before the comparison.
   """
   sampler = rand_sparse(self.rng())
   dense = sampler(shape, dtype)
   data, indices = sparse.bcoo_fromdense(dense, n_batch=n_batch, n_dense=n_dense)

   reduced = sparse.bcoo_reduce_sum(data, indices, shape=shape, axes=axes)
   reduced_data, reduced_indices, reduced_shape = reduced

   expected = dense.sum(axes)
   actual = sparse.bcoo_todense(reduced_data, reduced_indices, shape=reduced_shape)

   # The same per-dtype tolerances serve as absolute and relative bounds.
   tolerances = {np.float32: 1E-6, np.float64: 1E-14}
   self.assertAllClose(expected, actual, atol=tolerances, rtol=tolerances)
예제 #4
0
  def test_bcoo_dedupe(self, shape, dtype, n_batch, n_dense):
    """Check that _dedupe_bcoo preserves the dense value of a BCOO array.

    Each sparse-dimension row of ``indices`` is overwritten with random
    in-range indices, so repeated index entries are possible. The dense
    arrays built from the original and the deduplicated buffers must agree.
    """
    np_rng = self.rng()
    sparse_sampler = rand_sparse(self.rng())
    dense = sparse_sampler(shape, dtype)
    data, indices = sparse.bcoo_fromdense(dense, n_batch=n_batch, n_dense=n_dense)

    # Sizes of the sparse dimensions: everything between batch and dense dims.
    sparse_dim_sizes = shape[n_batch:len(shape) - n_dense]
    for row, dim_size in enumerate(sparse_dim_sizes):
      random_row = np_rng.randint(0, dim_size, size=indices.shape[-1])
      indices = indices.at[..., row, :].set(random_row)

    deduped_data, deduped_indices = _dedupe_bcoo(data, indices)
    dense_before = sparse.bcoo_todense(data, indices, shape=shape)
    dense_after = sparse.bcoo_todense(deduped_data, deduped_indices, shape=shape)

    self.assertAllClose(dense_before, dense_after)
예제 #5
0
  def test_bcoo_todense_partial_batch(self, shape, dtype, n_batch, n_dense):
    """Check bcoo_todense when one operand has a leading dimension of size 1.

    Passing only the first leading slice of ``indices`` (or of ``data``)
    must give the same dense result as repeating that slice ``shape[0]``
    times along a new leading axis.
    """
    sampler = rand_sparse(self.rng())
    dense = sampler(shape, dtype)
    data, indices = sparse.bcoo_fromdense(dense, n_batch=n_batch, n_dense=n_dense)
    n_leading = shape[0]

    # Size-1 leading dimension on the indices vs. explicitly tiled indices.
    tiled_indices = jnp.stack(n_leading * [indices[0]])
    from_sliced_indices = sparse.bcoo_todense(data, indices[:1], shape=dense.shape)
    from_tiled_indices = sparse.bcoo_todense(data, tiled_indices, shape=dense.shape)
    self.assertAllClose(from_sliced_indices, from_tiled_indices)

    # Size-1 leading dimension on the data vs. explicitly tiled data.
    tiled_data = jnp.stack(n_leading * [data[0]])
    from_sliced_data = sparse.bcoo_todense(data[:1], indices, shape=dense.shape)
    from_tiled_data = sparse.bcoo_todense(tiled_data, indices, shape=dense.shape)
    self.assertAllClose(from_sliced_data, from_tiled_data)
예제 #6
0
  def test_bcoo_extract_ad(self, shape, dtype, n_batch, n_dense):
    """Check differentiation of bcoo_extract with respect to the dense matrix.

    Forward- and reverse-mode Jacobians must agree, and the Jacobian and
    Hessian must have the expected output-shape + input-shape layout.
    """
    sampler = rand_sparse(self.rng())
    dense = sampler(shape, dtype)
    data, indices = sparse.bcoo_fromdense(dense, n_batch=n_batch, n_dense=n_dense)

    def extract_at_indices(mat):
      # Indices are held fixed; only the dense matrix is differentiated.
      return sparse.bcoo_extract(indices, mat)

    jac_fwd = jax.jacfwd(extract_at_indices)(dense)
    jac_rev = jax.jacrev(extract_at_indices)(dense)
    hessian = jax.hessian(extract_at_indices)(dense)

    self.assertArraysAllClose(jac_fwd, jac_rev)
    self.assertEqual(jac_fwd.shape, data.shape + dense.shape)
    self.assertEqual(hessian.shape, data.shape + 2 * dense.shape)
예제 #7
0
  def test_bcoo_todense_ad(self, shape, dtype, n_batch, n_dense):
    """Check differentiation of bcoo_todense with respect to the data buffer.

    Forward- and reverse-mode Jacobians must agree, and the Jacobian and
    Hessian must have the expected output-shape + input-shape layout.
    """
    sampler = rand_sparse(self.rng())
    dense = sampler(shape, dtype)
    data, indices = sparse.bcoo_fromdense(dense, n_batch=n_batch, n_dense=n_dense)

    def densify(values):
      # Indices and shape are held fixed; only the data is differentiated.
      return sparse.bcoo_todense(values, indices=indices, shape=shape)

    jac_fwd = jax.jacfwd(densify)(data)
    jac_rev = jax.jacrev(densify)(data)
    hessian = jax.hessian(densify)(data)

    self.assertArraysAllClose(jac_fwd, jac_rev)
    self.assertEqual(jac_fwd.shape, dense.shape + data.shape)
    self.assertEqual(hessian.shape, dense.shape + 2 * data.shape)
예제 #8
0
  def test_bcoo_dense_round_trip(self, shape, dtype, n_batch, n_dense):
    """Check that dense -> BCOO -> dense reproduces the original array.

    Also verifies the dtype and shape of the ``data`` and ``indices`` buffers
    returned by ``sparse.bcoo_fromdense``, and that ``sparse.bcoo_todense``
    gives the same result eagerly and under ``jit``.
    """
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    n_sparse = M.ndim - n_batch - n_dense
    nse = int(_bcoo_nse(M, n_batch=n_batch, n_dense=n_dense))
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)
    # TODO: test fromdense JIT

    # Use unittest assertions rather than bare `assert`: they are not stripped
    # under `python -O` and report both values when the check fails.
    self.assertEqual(data.dtype, dtype)
    self.assertEqual(data.shape,
                     shape[:n_batch] + (nse,) + shape[n_batch + n_sparse:])
    self.assertEqual(indices.dtype, jnp.int32)  # TODO: test passing this arg
    self.assertEqual(indices.shape, shape[:n_batch] + (n_sparse, nse))

    todense = partial(sparse.bcoo_todense, shape=shape)
    self.assertArraysEqual(M, todense(data, indices))
    self.assertArraysEqual(M, jit(todense)(data, indices))
예제 #9
0
  def test_bcoo_dot_general_partial_batch(self, lhs_shape, rhs_shape, dtype,
                                          dimension_numbers, n_batch, n_dense):
    """Check bcoo_dot_general when data and/or indices have a size-1 lead dim.

    Every combination of full and first-slice-only ``data`` and ``indices``
    is densified with ``sparse.bcoo_todense``; the dense ``lax.dot_general``
    on that array must match ``sparse.bcoo_dot_general`` on the buffers.
    """
    dense_sampler = jtu.rand_small(self.rng())
    sparse_sampler = rand_sparse(self.rng())

    lhs_dense = sparse_sampler(lhs_shape, dtype)
    full_data, full_indices = sparse.bcoo_fromdense(
        lhs_dense, n_batch=n_batch, n_dense=n_dense)
    rhs = dense_sampler(rhs_shape, dtype)
    # Every variant below is densified to this same full lhs shape.
    full_lhs_shape = lhs_dense.shape

    def dense_product(lhs, rhs_arg):
      return lax.dot_general(lhs, rhs_arg, dimension_numbers=dimension_numbers)

    def sparse_product(data_arg, indices_arg, rhs_arg):
      return sparse.bcoo_dot_general(data_arg, indices_arg, rhs_arg,
                                     lhs_shape=full_lhs_shape,
                                     dimension_numbers=dimension_numbers)

    data_variants = [full_data, full_data[:1]]
    indices_variants = [full_indices, full_indices[:1]]
    for data_i, indices_i in itertools.product(data_variants, indices_variants):
      lhs_i = sparse.bcoo_todense(data_i, indices_i, shape=full_lhs_shape)
      self.assertAllClose(dense_product(lhs_i, rhs),
                          sparse_product(data_i, indices_i, rhs))
예제 #10
0
 def args_maker():
   """Build ``(data, indices, lhs, rhs)`` for a dense-lhs / sparse-rhs case.

   ``data`` and ``indices`` are the BCOO buffers of the sparse-sampled
   ``rhs``; the dense ``rhs`` itself is returned alongside them.
   """
   # The lhs is drawn before the rhs, matching the original sampling order.
   dense_lhs = rng(lhs_shape, dtype)
   dense_rhs = rng_sparse(rhs_shape, dtype)
   rhs_data, rhs_indices = sparse.bcoo_fromdense(
       dense_rhs, n_batch=n_batch, n_dense=n_dense)
   return rhs_data, rhs_indices, dense_lhs, dense_rhs
예제 #11
0
 def fromdense(M):
   """Return only the first output of ``sparse.bcoo_fromdense`` for ``M``.

   ``nse``, ``n_batch`` and ``n_dense`` are taken from the enclosing scope.
   """
   bcoo_buffers = sparse.bcoo_fromdense(
       M, nse=nse, n_batch=n_batch, n_dense=n_dense)
   return bcoo_buffers[0]