Example #1
0
    def testSparseCondSimple(self):
        # lax.cond with a constant False predicate takes the false branch
        # (doubling the operand); dense and sparsified results must agree.
        def branch(x):
            return lax.cond(False, lambda x: x, lambda x: 2 * x, x)

        x_dense = jnp.arange(5.0)
        expected = branch(x_dense)

        x_sparse = BCOO.fromdense(x_dense)
        actual = sparsify(branch)(x_sparse).todense()

        self.assertArraysAllClose(expected, actual)
Example #2
0
    def testSparsifySparseXlaCall(self):
        # Test sparse lowering of XLA call: a jit-compiled function applied
        # to a BCOO input must match the same function on the dense input.
        def double(mat):
            return 2 * mat

        dense_input = jnp.arange(6).reshape(2, 3)
        sparse_input = BCOO.fromdense(dense_input)

        expected = double(dense_input)
        actual = sparsify(jit(double))(sparse_input).todense()
        self.assertArraysEqual(expected, actual)
Example #3
0
    def testSparseSqueeze(self, shape, dimensions, n_batch, n_dense):
        # lax.squeeze under sparsify must give the same values for a dense
        # input and for its BCOO counterpart with the given batch/dense dims.
        squeeze = sparsify(partial(lax.squeeze, dimensions=dimensions))
        rng = jtu.rand_default(self.rng())

        dense = rng(shape, np.float32)
        sparse = BCOO.fromdense(dense, n_batch=n_batch, n_dense=n_dense)

        expected = squeeze(dense)
        actual = squeeze(sparse).todense()

        self.assertAllClose(actual, expected)
Example #4
0
  def testSparseForiLoop(self):
    # Two fori_loop iterations of a mat-vec scaled by the column count,
    # with the (possibly sparse) matrix captured by the loop body.
    def func(mat, vec):
      def step(i, carry):
        return (mat @ carry) / mat.shape[1]
      return lax.fori_loop(0, 2, step, vec)

    mat = jnp.arange(25).reshape(5, 5)
    vec = jnp.arange(5.0)
    mat_sparse = BCOO.fromdense(mat)

    expected = func(mat, vec)
    actual = sparsify(func)(mat_sparse, vec)

    self.assertArraysAllClose(expected, actual)
Example #5
0
    def testSparseSum(self):
        # Full and axis-wise sums of a BCOO matrix must match the dense sums.
        x = jnp.arange(20).reshape(4, 5)
        xsp = BCOO.fromdense(x)

        def f(x):
            return x.sum(), x.sum(0), x.sum(1), x.sum((0, 1))

        result_dense = f(x)
        result_sparse = sparsify(f)(xsp)

        # Use a unittest assertion, not a bare `assert`: bare asserts are
        # stripped under `python -O` and report no values on failure.
        self.assertEqual(len(result_dense), len(result_sparse))

        for res_dense, res_sparse in zip(result_dense, result_sparse):
            # A sparsified reduction may come back as BCOO or as a dense array.
            if isinstance(res_sparse, BCOO):
                res_sparse = res_sparse.todense()
            self.assertArraysAllClose(res_dense, res_sparse)
Example #6
0
    def testSparseMul(self):
        x = BCOO.fromdense(jnp.arange(5))
        y = BCOO.fromdense(2 * jnp.arange(5))

        # Scalar multiplication
        scaled = sparsify(operator.mul)(x, 2.5)
        self.assertArraysEqual(scaled.todense(), x.todense() * 2.5)

        # Shared indices - requires lower level call.
        # Env buffers: 0 -> x.indices, 1 -> x.data, 2 -> y.data. Both specs
        # point at buffer 0, so x and y share the same indices.
        spenv = SparseEnv([x.indices, x.data, y.data])
        argspecs = [ArgSpec(x.shape, 1, 0), ArgSpec(y.shape, 2, 0)]

        args_out, _ = sparsify_raw(operator.mul)(spenv, *argspecs)
        (product,) = argspecs_to_arrays(spenv, args_out)

        self.assertAllClose(product.todense(), x.todense() * y.todense())
Example #7
0
    def testSparseAdd(self):
        x = BCOO.fromdense(jnp.arange(5))
        y = BCOO.fromdense(2 * jnp.arange(5))

        # Distinct indices: each operand contributes its 4 stored entries,
        # so the sum is expected to carry 8 (uses concatenation).
        summed = sparsify(operator.add)(x, y)
        self.assertEqual(summed.nnz, 8)
        self.assertArraysEqual(summed.todense(), 3 * jnp.arange(5))

        # Shared indices - requires lower level call.
        # Env buffers: 0 -> x.indices, 1 -> x.data, 2 -> y.data. Both specs
        # point at buffer 0, so x and y share the same indices.
        spenv = SparseEnv([x.indices, x.data, y.data])
        argspecs = [ArgSpec(x.shape, 1, 0), ArgSpec(y.shape, 2, 0)]

        args_out, _ = sparsify_raw(operator.add)(spenv, *argspecs)
        (total,) = argspecs_to_arrays(spenv, args_out)

        self.assertAllClose(total.todense(), x.todense() + y.todense())
Example #8
0
  def testSparseWhileLoop(self):
    # while_loop whose carry is (counter, A): the counter goes 0 -> 5 and A
    # is doubled on every step. Dense and sparsified outputs must agree.
    def cond_fun(params):
      i, _ = params
      return i < 5

    def body_fun(params):
      i, A = params
      return i + 1, 2 * A

    def f(A):
      return lax.while_loop(cond_fun, body_fun, (0, A))

    A = jnp.arange(4)
    out_dense = f(A)

    Asp = BCOO.fromdense(A)
    out_sparse = sparsify(f)(Asp)

    self.assertEqual(len(out_dense), 2)
    self.assertEqual(len(out_sparse), 2)
    # The counter starts as a plain int, so it is expected to stay dense under
    # sparsify; compare it against the sparse run (not against itself).
    self.assertArraysEqual(out_dense[0], out_sparse[0])
    self.assertArraysEqual(out_dense[1], out_sparse[1].todense())
Example #9
0
 def sparsify(cls, f):
     """Return ``f`` transformed by the module-level ``sparsify`` with ``use_tracer=True``."""
     # Inside this body the name resolves to the module-level function, not
     # to this method (method names are not in scope within their own body).
     transformed = sparsify(f, use_tracer=True)
     return transformed
Example #10
0
 def sparsify(cls, f):
     """Return ``f`` transformed by the module-level ``sparsify`` with ``use_tracer=False``."""
     # Inside this body the name resolves to the module-level function, not
     # to this method (method names are not in scope within their own body).
     transformed = sparsify(f, use_tracer=False)
     return transformed
Example #11
0
 def testUnitHandling(self):
   # Passing core.unit as an ignored second argument must leave the BCOO
   # first argument identical in the output.
   # NOTE(review): the function is jit-wrapped twice; presumably intentional
   # to exercise nested calls - confirm before simplifying.
   pick_first = jit(lambda x, y: x)
   x = BCOO.fromdense(jnp.arange(5))
   result = sparsify(jit(pick_first))(x, core.unit)
   self.assertBcooIdentical(result, x)
Example #12
0
 def testSparsifyDenseXlaCall(self):
   # Test handling of dense xla_call within jaxpr interpreter: a jitted
   # function on a plain scalar passes through sparsify unchanged.
   add_one = jit(lambda x: x + 1)
   out = sparsify(add_one)(0.0)
   self.assertEqual(out, 1.0)