예제 #1
0
    def testSparseMatmul(self):
        # Compares sparsified operator.matmul against the dense product for
        # sparse @ dense, dense @ sparse and sparse @ sparse operands.
        dense_mat = jnp.arange(16).reshape(4, 4)
        dense_vec = jnp.ones(4)
        sparse_mat = BCOO.fromdense(dense_mat)
        sparse_vec = BCOO.fromdense(dense_vec)

        sp_matmul = self.sparsify(operator.matmul)

        # Both mixed sparse/dense products run with this efficiency warning
        # suppressed.
        unsorted_msg = (
            "bcoo_dot_general GPU lowering requires matrices with sorted indices*"
        )

        # dot_general
        with jtu.ignore_warning(category=CuSparseEfficiencyWarning,
                                message=unsorted_msg):
            actual = sp_matmul(sparse_mat, dense_vec)
        expected = operator.matmul(dense_mat, dense_vec)
        self.assertAllClose(actual, expected)

        # rdot_general
        with jtu.ignore_warning(category=CuSparseEfficiencyWarning,
                                message=unsorted_msg):
            actual = sp_matmul(dense_vec, sparse_mat)
        expected = operator.matmul(dense_vec, dense_mat)
        self.assertAllClose(actual, expected)

        # spdot_general
        actual = self.sparsify(operator.matmul)(sparse_mat, sparse_vec)
        expected = operator.matmul(dense_mat, dense_vec)
        self.assertAllClose(actual.todense(), expected)
예제 #2
0
 def spvalue_to_array(spvalue):
     """Convert `spvalue` to an array using the enclosing `spenv`.

     Returns a BCOO built from the value's data and indices buffers when the
     value is sparse, otherwise the data buffer itself.
     """
     if not spvalue.is_sparse():
         return spenv.data(spvalue)
     assert spvalue.indices_ref is not None
     buffers = (spenv.data(spvalue), spenv.indices(spvalue))
     return BCOO(buffers, shape=spvalue.shape)
예제 #3
0
  def testSparseCondMismatchError(self):
    # A sparsified lax.cond is called with dense/dense and sparse/sparse
    # operands (both accepted) and then with a sparse/dense mix, which is
    # expected to raise a TypeError.
    def select(x, y):
      return lax.cond(False, lambda ops: ops[0], lambda ops: ops[1], (x, y))

    func = self.sparsify(select)

    dense_x = jnp.arange(5.0)
    dense_y = jnp.arange(5.0)
    sparse_x = BCOO.fromdense(dense_x)
    sparse_y = BCOO.fromdense(dense_y)

    func(dense_x, dense_y)  # No error
    func(sparse_x, sparse_y)  # No error

    mismatch_pattern = "sparsified true_fun and false_fun output.*"
    with self.assertRaisesRegex(TypeError, mismatch_pattern):
      func(sparse_x, dense_y)
예제 #4
0
    def testSparsifyValue(self):
        # Round-trips one dense and two BCOO arguments through
        # SparsifyEnv / SparsifyValue, first with every buffer stored
        # separately and then with both sparse values sharing one indices
        # buffer.
        dense = jnp.arange(5)
        sparse = BCOO.fromdense(dense)
        shape = dense.shape

        args = (dense, sparse, sparse)

        def check_roundtrip(args_out):
            # The recovered arguments must match the originals one by one.
            self.assertEqual(len(args_out), len(args))
            self.assertArraysEqual(args[0], args_out[0])
            self.assertBcooIdentical(args[1], args_out[1])
            self.assertBcooIdentical(args[2], args_out[2])

        # Independent index
        spenv = SparsifyEnv()
        spvalues = arrays_to_spvalues(spenv, args)
        self.assertEqual(len(spvalues), len(args))
        # One dense buffer plus a (data, indices) pair per BCOO argument.
        self.assertLen(spenv._buffers, 5)
        expected_spvalues = (
            SparsifyValue(shape, 0, None),
            SparsifyValue(shape, 1, 2),
            SparsifyValue(shape, 3, 4),
        )
        self.assertEqual(spvalues, expected_spvalues)

        check_roundtrip(spvalues_to_arrays(spenv, spvalues))

        # Shared index
        spvalues = (
            SparsifyValue(shape, 0, None),
            SparsifyValue(shape, 1, 2),
            SparsifyValue(shape, 3, 2),
        )
        spenv = SparsifyEnv([dense, sparse.data, sparse.indices, sparse.data])

        check_roundtrip(spvalues_to_arrays(spenv, spvalues))
예제 #5
0
    def testSparseWhileLoopDuplicateIndices(self):
        # Runs a while_loop whose carry contains the same matrix twice and
        # checks every element of the sparsified result against the dense run.
        def cond_fun(params):
            i, _, _ = params
            return i < 5

        def body_fun(params):
            i, A, B = params
            # TODO(jakevdp): track shared indices through while loop & use this
            #   version of the test, which requires shared indices in order for
            #   the nse of the result to remain the same.
            # return i + 1, A, A + B

            # This version is fine without shared indices, and tests that we're
            # flattening non-shared indices consistently.
            return i + 1, B, A

        def f(A):
            return lax.while_loop(cond_fun, body_fun, (0, A, A))

        A = jnp.arange(4).reshape((2, 2))
        out_dense = f(A)

        Asp = BCOO.fromdense(A)
        out_sparse = self.sparsify(f)(Asp)

        self.assertEqual(len(out_dense), 3)
        self.assertEqual(len(out_sparse), 3)
        # Compare the loop counter across the dense and sparse runs.
        self.assertArraysEqual(out_dense[0], out_sparse[0])
        self.assertArraysEqual(out_dense[1], out_sparse[1].todense())
        self.assertArraysEqual(out_dense[2], out_sparse[2].todense())
예제 #6
0
    def testArgSpec(self):
        # Round-trips one dense and two BCOO arguments through
        # SparseEnv / ArgSpec, first with every buffer stored separately and
        # then with both sparse specs sharing one indices buffer.
        dense = jnp.arange(5)
        sparse = BCOO.fromdense(dense)
        shape = dense.shape

        args = (dense, sparse, sparse)

        def check_roundtrip(args_out):
            # The recovered arguments must match the originals one by one.
            self.assertEqual(len(args_out), len(args))
            self.assertArraysEqual(args[0], args_out[0])
            self.assertBcooIdentical(args[1], args_out[1])
            self.assertBcooIdentical(args[2], args_out[2])

        # Independent index
        spenv = SparseEnv()
        argspecs = arrays_to_argspecs(spenv, args)
        self.assertEqual(len(argspecs), len(args))
        # One dense buffer plus a (data, indices) pair per BCOO argument.
        self.assertEqual(spenv.size(), 5)
        expected_argspecs = (
            ArgSpec(shape, 0, None),
            ArgSpec(shape, 1, 2),
            ArgSpec(shape, 3, 4),
        )
        self.assertEqual(argspecs, expected_argspecs)

        check_roundtrip(argspecs_to_arrays(spenv, argspecs))

        # Shared index
        argspecs = (
            ArgSpec(shape, 0, None),
            ArgSpec(shape, 1, 2),
            ArgSpec(shape, 3, 2),
        )
        spenv = SparseEnv([dense, sparse.data, sparse.indices, sparse.data])

        check_roundtrip(argspecs_to_arrays(spenv, argspecs))
예제 #7
0
    def testSparseMul(self, shape, dtype, n_batch, n_dense, unique_indices):
        # Checks sparsified multiplication by a scalar, then multiplication of
        # two sparse values that reference the same indices buffer.
        make_input = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
        dense_x = make_input(shape, dtype)
        x = BCOO.fromdense(dense_x, n_batch=n_batch, n_dense=n_dense)

        # Scalar multiplication
        scalar = 2
        y = self.sparsify(operator.mul)(x, scalar)
        self.assertArraysEqual(x.todense() * scalar, y.todense())

        # Shared indices – requires lower level call
        # Buffer layout: 0 -> indices of x (used by both values),
        # 1 -> data of x, 2 -> data of y.
        spenv = SparsifyEnv([x.indices, x.data, y.data])
        spvalues = [
            spenv.sparse(operand.shape,
                         data_ref=data_ref,
                         indices_ref=0,
                         unique_indices=unique_indices)
            for operand, data_ref in ((x, 1), (y, 2))
        ]

        spvalues_out, _ = sparsify_raw(operator.mul)(spenv, *spvalues)
        product, = spvalues_to_arrays(spenv, spvalues_out)

        self.assertAllClose(product.todense(), x.todense() * y.todense())
예제 #8
0
 def testPytreeInput(self):
     # A sparsified identity function must hand back a (dense, BCOO) tuple
     # with both entries unchanged.
     identity = self.sparsify(lambda tree: tree)
     dense = jnp.arange(4)
     sparse = BCOO.fromdense(jnp.arange(4))
     args = (dense, sparse)

     out = identity(args)

     self.assertLen(out, 2)
     self.assertArraysEqual(dense, out[0])
     self.assertBcooIdentical(sparse, out[1])
예제 #9
0
 def argspec_to_array(argspec):
   """Convert `argspec` to a value using the enclosing `spenv`.

   Returns a BCOO built from the spec's data and indices when the spec is
   sparse, `core.unit` when it is a unit, and the spec's data otherwise.
   """
   if argspec.is_sparse():
     assert argspec.indices_ref is not None
     buffers = (argspec.data(spenv), argspec.indices(spenv))
     return BCOO(buffers, shape=argspec.shape)
   if argspec.is_unit():
     return core.unit
   return argspec.data(spenv)
예제 #10
0
    def testSparseMul(self):
        # Checks sparsified multiplication by a scalar, then multiplication of
        # two sparse arguments that reference the same indices buffer.
        x = BCOO.fromdense(jnp.arange(5))
        y = BCOO.fromdense(2 * jnp.arange(5))

        # Scalar multiplication
        scaled = self.sparsify(operator.mul)(x, 2.5)
        self.assertArraysEqual(scaled.todense(), x.todense() * 2.5)

        # Shared indices – requires lower level call
        # Buffer layout: 0 -> indices of x (used by both specs),
        # 1 -> data of x, 2 -> data of y.
        spenv = SparseEnv([x.indices, x.data, y.data])
        argspecs = [ArgSpec(x.shape, 1, 0), ArgSpec(y.shape, 2, 0)]

        specs_out, _ = sparsify_raw(operator.mul)(spenv, *argspecs)
        product, = argspecs_to_arrays(spenv, specs_out)

        self.assertAllClose(product.todense(), x.todense() * y.todense())
예제 #11
0
    def testSparseSubtract(self):
        # Checks sparsified subtraction of two independent BCOO arrays, then of
        # two sparse arguments that reference the same indices buffer.
        x = BCOO.fromdense(3 * jnp.arange(5))
        y = BCOO.fromdense(jnp.arange(5))

        # Distinct indices
        difference = self.sparsify(operator.sub)(x, y)
        self.assertEqual(difference.nse, 8)  # uses concatenation.
        self.assertArraysEqual(difference.todense(), 2 * jnp.arange(5))

        # Shared indices – requires lower level call
        # Buffer layout: 0 -> indices of x (used by both specs),
        # 1 -> data of x, 2 -> data of y.
        spenv = SparseEnv([x.indices, x.data, y.data])
        argspecs = [ArgSpec(x.shape, 1, 0), ArgSpec(y.shape, 2, 0)]

        specs_out, _ = sparsify_raw(operator.sub)(spenv, *argspecs)
        difference, = argspecs_to_arrays(spenv, specs_out)

        self.assertAllClose(difference.todense(), x.todense() - y.todense())
예제 #12
0
    def testSparseReshapeMethod(self, shape, new_shape, n_batch, n_dense):
        # The reshape method of a BCOO array must agree with reshaping the
        # dense array it was built from.
        make_input = jtu.rand_some_zero(self.rng())
        dense = make_input(shape, 'int32')
        sparse = BCOO.fromdense(dense, n_batch=n_batch, n_dense=n_dense)

        expected = dense.reshape(new_shape)
        reshaped_sparse = sparse.reshape(new_shape)

        self.assertArraysEqual(expected, reshaped_sparse.todense())
예제 #13
0
 def testWeakTypes(self):
     # Regression test for https://github.com/google/jax/issues/8267
     # Multiplying by a Python scalar must give the same values and dtype for
     # the dense and the sparsified computation.
     dense = jnp.arange(12, dtype='int32').reshape(3, 4)
     sparse = BCOO.fromdense(dense)

     expected = operator.mul(2, dense)
     actual = self.sparsify(operator.mul)(2, sparse).todense()

     self.assertArraysEqual(expected, actual, check_dtypes=True)
예제 #14
0
 def testToDense(self):
   # A sparsified function calling todense must give the same result for a
   # dense and a BCOO input, both eagerly and under jit.
   dense = jnp.arange(4)
   sparse = BCOO.fromdense(dense)

   def densify_plus_one(M):
     return todense(M) + 1

   func = self.sparsify(densify_plus_one)
   operands = (dense, sparse)

   for operand in operands:
     self.assertArraysEqual(func(operand), dense + 1)
   for operand in operands:
     self.assertArraysEqual(jit(func)(operand), dense + 1)
예제 #15
0
  def testDropvar(self):
    # Sparsifies a function that discards one output of a jitted inner call
    # and compares the result with the dense computation.
    def double_and_triple(x):
      return x * 2, x * 3

    def outer(x):
      _, tripled = jit(double_and_triple)(x)
      return tripled * 4

    dense = jnp.arange(5)
    sparse = BCOO.fromdense(dense)

    actual = self.sparsify(outer)(sparse).todense()
    expected = outer(dense)
    self.assertArraysEqual(actual, expected)
예제 #16
0
    def testSparseMul(self):
        # Checks sparsified multiplication by a scalar, then multiplication of
        # two sparse values that reference the same indices buffer.
        x = BCOO.fromdense(jnp.arange(5))
        y = BCOO.fromdense(2 * jnp.arange(5))

        # Scalar multiplication
        scaled = self.sparsify(operator.mul)(x, 2.5)
        self.assertArraysEqual(scaled.todense(), x.todense() * 2.5)

        # Shared indices – requires lower level call
        # Buffer layout: 0 -> indices of x (used by both values),
        # 1 -> data of x, 2 -> data of y.
        spenv = SparsifyEnv([x.indices, x.data, y.data])
        x_spvalue = spenv.sparse(x.shape, data_ref=1, indices_ref=0)
        y_spvalue = spenv.sparse(y.shape, data_ref=2, indices_ref=0)

        spvalues_out, _ = sparsify_raw(operator.mul)(spenv, x_spvalue, y_spvalue)
        product, = spvalues_to_arrays(spenv, spvalues_out)

        self.assertAllClose(product.todense(), x.todense() * y.todense())
예제 #17
0
def _sparse_rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
                           mode=None, fill_value=None):
  """Index `arr` with `idx` by sparsifying lax_numpy's gather.

  Mirrors lax_numpy._rewriting_take. A non-BCOO result of size zero is
  converted to BCOO before being returned.
  """
  treedef, static_idx, dynamic_idx = lax_numpy._split_index_for_jit(idx, arr.shape)

  def gather(operand, dynamic):
    return lax_numpy._gather(operand, treedef, static_idx, dynamic, indices_are_sorted,
                             unique_indices, mode, fill_value)

  result = sparsify(gather)(arr, dynamic_idx)
  # Account for a corner case in the rewriting_take implementation.
  if isinstance(result, BCOO) or np.size(result) != 0:
    return result
  return BCOO.fromdense(result)
예제 #18
0
    def testSparseCondSimple(self):
        # A lax.cond applied to a BCOO operand through sparsify must match the
        # dense result.
        def branch(operand):
            return lax.cond(False, lambda v: v, lambda v: 2 * v, operand)

        dense = jnp.arange(5.0)
        expected = branch(dense)

        sparse = BCOO.fromdense(dense)
        actual = self.sparsify(branch)(sparse)

        self.assertArraysAllClose(expected, actual.todense())
예제 #19
0
    def testSparsifySparseXlaCall(self):
        # Test sparse lowering of XLA call
        def double(M):
            return 2 * M

        dense = jnp.arange(6).reshape(2, 3)
        sparse = BCOO.fromdense(dense)

        expected = double(dense)
        # The jitted function is what gets sparsified here.
        actual = self.sparsify(jit(double))(sparse)
        self.assertArraysEqual(expected, actual.todense())
예제 #20
0
    def testSparseSqueeze(self, shape, dimensions, n_batch, n_dense):
        # Sparsified lax.squeeze must give the same result for a dense input
        # and for the BCOO array built from it.
        make_input = jtu.rand_default(self.rng())

        dense = make_input(shape, np.float32)
        sparse = BCOO.fromdense(dense, n_batch=n_batch, n_dense=n_dense)
        squeeze = self.sparsify(partial(lax.squeeze, dimensions=dimensions))

        expected = squeeze(dense)
        actual = squeeze(sparse).todense()

        self.assertAllClose(actual, expected)
예제 #21
0
    def testSparseMatmul(self):
        # Compares sparsified operator.matmul against the dense product for
        # sparse @ dense, dense @ sparse and sparse @ sparse operands.
        dense_mat = jnp.arange(16).reshape(4, 4)
        dense_vec = jnp.ones(4)
        sparse_mat = BCOO.fromdense(dense_mat)
        sparse_vec = BCOO.fromdense(dense_vec)

        # dot_general
        actual = self.sparsify(operator.matmul)(sparse_mat, dense_vec)
        expected = operator.matmul(dense_mat, dense_vec)
        self.assertAllClose(actual, expected)

        # rdot_general
        actual = self.sparsify(operator.matmul)(dense_vec, sparse_mat)
        expected = operator.matmul(dense_vec, dense_mat)
        self.assertAllClose(actual, expected)

        # spdot_general
        actual = self.sparsify(operator.matmul)(sparse_mat, sparse_vec)
        expected = operator.matmul(dense_mat, dense_vec)
        self.assertAllClose(actual.todense(), expected)
예제 #22
0
    def testSparseSubtract(self):
        # Checks sparsified subtraction of two independent BCOO arrays, then of
        # two sparse values that reference the same indices buffer.
        x = BCOO.fromdense(3 * jnp.arange(5))
        y = BCOO.fromdense(jnp.arange(5))

        # Distinct indices
        difference = self.sparsify(operator.sub)(x, y)
        self.assertEqual(difference.nse, 8)  # uses concatenation.
        self.assertArraysEqual(difference.todense(), 2 * jnp.arange(5))

        # Shared indices – requires lower level call
        # Buffer layout: 0 -> indices of x (used by both values),
        # 1 -> data of x, 2 -> data of y.
        spenv = SparsifyEnv([x.indices, x.data, y.data])
        x_spvalue = spenv.sparse(x.shape, data_ref=1, indices_ref=0)
        y_spvalue = spenv.sparse(y.shape, data_ref=2, indices_ref=0)

        spvalues_out, _ = sparsify_raw(operator.sub)(spenv, x_spvalue, y_spvalue)
        difference, = spvalues_to_arrays(spenv, spvalues_out)

        self.assertAllClose(difference.todense(), x.todense() - y.todense())
예제 #23
0
    def testSparsifyWithConsts(self):
        # Sparsifies a function that wraps a jitted sum over axis 1 and
        # compares the BCOO result with the dense one.
        dense = jnp.arange(24).reshape(4, 6)
        sparse = BCOO.fromdense(dense)

        def sum_axis_one(x):
            return jit(lambda mat: jnp.sum(mat, 1))(x)

        func = self.sparsify(sum_axis_one)

        expected = func(dense)
        actual = func(sparse)

        self.assertAllClose(actual.todense(), expected)
예제 #24
0
def argspecs_to_arrays(
    spenv: SparseEnv,
    argspecs: Sequence[ArgSpec],
    ) -> Sequence[AnyArray]:
  """Convert each ArgSpec in `argspecs` to an array using `spenv`.

  Sparse specs become BCOO arrays built from their data and indices; all
  other specs yield their data directly. Each resulting array is asserted to
  have the shape recorded on its spec. Returns the arrays as a tuple.
  """
  def materialize(argspec):
    if not argspec.is_sparse():
      return argspec.data(spenv)
    assert argspec.indices_ref is not None
    buffers = (argspec.data(spenv), argspec.indices(spenv))
    return BCOO(buffers, shape=argspec.shape)

  arrays = []
  for argspec in argspecs:
    array = materialize(argspec)
    assert array.shape == argspec.shape
    arrays.append(array)
  return tuple(arrays)
예제 #25
0
    def testSparseForiLoop(self):
        # A fori_loop whose body multiplies by a BCOO matrix must match the
        # same loop run with the dense matrix.
        def func(M, x):
            def body_fun(i, val):
                return (M @ val) / M.shape[1]

            return lax.fori_loop(0, 2, body_fun, x)

        vec = jnp.arange(5.0)
        dense_mat = jnp.arange(25).reshape(5, 5)
        sparse_mat = BCOO.fromdense(dense_mat)

        expected = func(dense_mat, vec)
        actual = self.sparsify(func)(sparse_mat, vec)

        self.assertArraysAllClose(expected, actual)
예제 #26
0
    def testSparseReshapeWithDimensions(self, shape, new_shape, n_batch,
                                        n_dense, dimensions):
        # Sparsified lax.reshape with an explicit `dimensions` argument must
        # give the same result for a dense input and its BCOO counterpart.
        make_input = jtu.rand_some_zero(self.rng())
        dense = make_input(shape, 'int32')
        sparse = BCOO.fromdense(dense, n_batch=n_batch, n_dense=n_dense)

        def reshape(x):
            return lax.reshape(x, new_shape, dimensions=dimensions)

        f = self.sparsify(reshape)

        expected = f(dense)
        actual = f(sparse)

        self.assertArraysEqual(expected, actual.todense())
예제 #27
0
    def testSparsify(self):
        # One sparsified function is evaluated on a dense matrix and on the
        # BCOO array built from it; both results must agree.
        dense = jnp.arange(24).reshape(4, 6)
        sparse = BCOO.fromdense(dense)
        vec = jnp.arange(dense.shape[0])

        def neg_sin_matvec(x, v):
            return -jnp.sin(jnp.pi * x).T @ (v + 1)

        func = self.sparsify(neg_sin_matvec)

        expected = func(dense, vec)
        actual = func(sparse, vec)

        self.assertAllClose(actual, expected)
예제 #28
0
    def testNotImplementedMessages(self):
        # Checks the NotImplementedError messages raised when sparsifying
        # lax.cos and lax.complex on BCOO input.
        x = BCOO.fromdense(jnp.arange(5.0))
        densifying_pattern = (
            r"^sparse rule for cos is not implemented because it would result in dense output\."
        )
        generic_pattern = r"^sparse rule for complex is not implemented\.$"

        # Test a densifying primitive
        with self.assertRaisesRegex(NotImplementedError, densifying_pattern):
            self.sparsify(lax.cos)(x)

        # Test a generic not implemented primitive.
        with self.assertRaisesRegex(NotImplementedError, generic_pattern):
            self.sparsify(lax.complex)(x, x)
예제 #29
0
    def testSparseSum(self):
        # Sums a BCOO matrix over no axis, axis 0, axis 1 and both axes through
        # sparsify and compares each result with the dense computation.
        x = jnp.arange(20).reshape(4, 5)
        xsp = BCOO.fromdense(x)

        def f(x):
            return x.sum(), x.sum(0), x.sum(1), x.sum((0, 1))

        result_dense = f(x)
        result_sparse = self.sparsify(f)(xsp)

        # A unittest assertion still runs under `python -O` and reports both
        # lengths on failure, unlike a bare `assert`.
        self.assertEqual(len(result_dense), len(result_sparse))

        for res_dense, res_sparse in zip(result_dense, result_sparse):
            # Results may come back as BCOO or as dense arrays; densify BCOO
            # results before comparing.
            if isinstance(res_sparse, BCOO):
                res_sparse = res_sparse.todense()
            self.assertArraysAllClose(res_dense, res_sparse)
예제 #30
0
    def testSparsify(self):
        # One sparsified function is evaluated on a BCOO matrix (with the
        # cuSPARSE efficiency warning suppressed) and on the dense matrix it
        # was built from; both results must agree.
        dense = jnp.arange(24).reshape(4, 6)
        sparse = BCOO.fromdense(dense)
        vec = jnp.arange(dense.shape[0])

        def neg_sin_matvec(x, v):
            return -jnp.sin(jnp.pi * x).T @ (v + 1)

        func = self.sparsify(neg_sin_matvec)

        unsorted_msg = (
            "bcoo_dot_general GPU lowering requires matrices with sorted indices*"
        )
        with jtu.ignore_warning(category=CuSparseEfficiencyWarning,
                                message=unsorted_msg):
            actual = func(sparse, vec)
        expected = func(dense, vec)
        self.assertAllClose(actual, expected)