Esempio n. 1
0
  def testSparseWhileLoop(self):
    """sparsify(f) must agree with f when f runs a lax.while_loop.

    The loop carries an integer counter and an array. Each iteration
    increments the counter and doubles the array until the counter reaches 5.
    For the sparse run the array is passed in as a BCOO matrix. Both elements
    of the result are compared against the dense run.
    """
    def cond_fun(params):
      i, _ = params
      return i < 5

    def body_fun(params):
      i, A = params
      return i + 1, 2 * A

    def f(A):
      return lax.while_loop(cond_fun, body_fun, (0, A))

    A = jnp.arange(4)
    out_dense = f(A)

    Asp = BCOO.fromdense(A)
    out_sparse = sparsify(f)(Asp)

    self.assertEqual(len(out_dense), 2)
    self.assertEqual(len(out_sparse), 2)
    # The loop counter must match between the dense and the sparse run.
    # Comparing out_dense[0] with itself would make this check vacuous.
    self.assertArraysEqual(out_dense[0], out_sparse[0])
    # The carried array comes back sparse (it is .todense()-ed here), so
    # densify it before comparing with the dense result.
    self.assertArraysEqual(out_dense[1], out_sparse[1].todense())
Esempio n. 2
0
    def testSparsifyValue(self):
        """Round-trip arrays through SparsifyEnv and SparsifyValue.

        First, arrays_to_spvalues is applied to one dense array and two BCOO
        arrays. The test checks that five buffers are registered, checks the
        exact SparsifyValue descriptors produced, and checks that
        spvalues_to_arrays rebuilds the original arguments.

        Second, descriptors are built by hand in which both sparse values
        refer to the same indices buffer (slot 2). The same reconstruction
        is checked again.
        """
        dense = jnp.arange(5)
        sparse = BCOO.fromdense(dense)
        shape = dense.shape
        args = (dense, sparse, sparse)

        def check_roundtrip(env, values):
            # Rebuild arrays from `values` and compare them with `args`.
            recovered = spvalues_to_arrays(env, values)
            self.assertEqual(len(recovered), len(args))
            self.assertArraysEqual(args[0], recovered[0])
            self.assertBcooIdentical(args[1], recovered[1])
            self.assertBcooIdentical(args[2], recovered[2])

        # Independent index: the dense array occupies slot 0. Each BCOO
        # argument gets its own pair of slots, (1, 2) and (3, 4).
        env = SparsifyEnv()
        values = arrays_to_spvalues(env, args)
        self.assertEqual(len(values), len(args))
        self.assertLen(env._buffers, 5)
        expected = (
            SparsifyValue(shape, 0, None,
                          indices_sorted=False, unique_indices=False),
            SparsifyValue(shape, 1, 2,
                          indices_sorted=True, unique_indices=True),
            SparsifyValue(shape, 3, 4,
                          indices_sorted=True, unique_indices=True),
        )
        self.assertEqual(values, expected)
        check_roundtrip(env, values)

        # Shared index: the two sparse values use different data slots
        # (1 and 3) but the same indices slot (2).
        shared_values = (
            SparsifyValue(shape, 0, None),
            SparsifyValue(shape, 1, 2),
            SparsifyValue(shape, 3, 2),
        )
        shared_env = SparsifyEnv(
            [dense, sparse.data, sparse.indices, sparse.data])
        check_roundtrip(shared_env, shared_values)
Esempio n. 3
0
 def testUnitHandling(self):
     """A sparsified, jitted function that ignores a core.unit argument
     returns its BCOO input unchanged."""
     def take_first(x, y):
         return x

     jitted = jit(take_first)
     sparse_x = BCOO.fromdense(jnp.arange(5))
     # The already-jitted function is wrapped in jit once more before
     # sparsifying.
     # NOTE(review): this may be redundant, or it may deliberately exercise
     # nested jit; confirm.
     # NOTE(review): core.unit is not available in newer JAX releases;
     # confirm it exists in the pinned JAX version.
     result = self.sparsify(jit(jitted))(sparse_x, core.unit)
     self.assertBcooIdentical(result, sparse_x)
Esempio n. 4
0
 def testSparseConcatenate(self, shapes, func, n_batch):
     """A sparsified jnp function gives the same result on dense and BCOO inputs.

     Args:
       shapes: shapes of the int32 operands to generate (some entries zero).
       func: name of a jax.numpy function. It is looked up with getattr and
         called with a single list of operands (presumably a
         concatenate/stack-style function).
       n_batch: number of batch dimensions for each BCOO operand.
     """
     sparse_func = self.sparsify(getattr(jnp, func))
     make_array = jtu.rand_some_zero(self.rng())
     dense_operands = [make_array(s, 'int32') for s in shapes]
     sparse_operands = [
         BCOO.fromdense(a, n_batch=n_batch) for a in dense_operands
     ]
     expected = sparse_func(dense_operands)
     # The sparse result is a BCOO array; densify it before comparing.
     actual = sparse_func(sparse_operands).todense()
     self.assertArraysEqual(expected, actual)