Ejemplo n.º 1
0
def _mul_sparse(spenv, *argspecs):
    """Multiply two operands where at least one is expected to be sparse.

    Exactly two argspecs must be supplied. Supported combinations:

    * sparse * sparse, when both have the same shape and the same
      ``indices_ref``: the stored data buffers are multiplied directly.
    * sparse * dense (either order), when the dense operand's data is
      0-dimensional or has the same shape as the sparse operand. In the
      same-shape case the dense values are first passed through
      ``sparse.bcoo_extract`` with the sparse operand's indices
      (presumably selecting the entries at the stored positions).

    The result reuses the sparse operand's shape and ``indices_ref``.

    Returns:
        A one-element tuple holding the output ``ArgSpec``.

    Raises:
        NotImplementedError: for sparse operands of different shapes, with
            different batch/dense dimensions, or with different sparsity
            patterns; and for a dense operand that is neither
            0-dimensional nor the same shape as the sparse operand.
    """
    lhs, rhs = argspecs

    if lhs.is_sparse() and rhs.is_sparse():
        if lhs.shape != rhs.shape:
            raise NotImplementedError(
                "Multiplication between sparse matrices of different shapes.")

        if lhs.indices_ref == rhs.indices_ref:
            # Both operands refer to the same index buffer, so their data
            # buffers can be multiplied entry by entry.
            product = lax.mul(lhs.data(spenv), rhs.data(spenv))
            return (ArgSpec(lhs.shape, spenv.push(product), lhs.indices_ref), )

        # Different index buffers: unsupported either way; pick the message
        # that matches the kind of mismatch.
        same_layout = (
            lhs.indices(spenv).ndim == rhs.indices(spenv).ndim
            and lhs.data(spenv).ndim == rhs.data(spenv).ndim
        )
        if not same_layout:
            raise NotImplementedError(
                "Multiplication between sparse matrices with different batch/dense dimensions."
            )
        raise NotImplementedError(
            "Multiplication between sparse matrices with different sparsity patterns."
        )

    # Mixed case: arrange for `sp` to be the sparse operand when the second
    # argument is the sparse one; otherwise keep the given order.
    sp, other = (rhs, lhs) if rhs.is_sparse() else (lhs, rhs)
    other_data = other.data(spenv)

    if other_data.ndim != 0 and not (other_data.shape == sp.shape):
        raise NotImplementedError(
            "Multiplication between sparse and dense matrices of different shape."
        )

    sp_data = sp.data(spenv)
    if other_data.ndim == 0:
        # 0-dimensional operand: multiply the stored data by it directly.
        factor = other_data
    else:
        factor = sparse.bcoo_extract(sp.indices(spenv), other_data)

    product = lax.mul(sp_data, factor)
    return (ArgSpec(sp.shape, spenv.push(product), sp.indices_ref), )
Ejemplo n.º 2
0
 def test_bcoo_extract(self, shape, dtype, n_batch, n_dense):
   """Check that bcoo_extract reproduces the data from bcoo_fromdense.

   A random matrix is converted with ``sparse.bcoo_fromdense``; extracting
   from that same matrix at the resulting indices must give back the same
   data, both when called directly and when wrapped in ``jit``.
   """
   # NOTE(review): n_batch and n_dense are accepted but not used in this
   # body -- confirm whether they should be forwarded to bcoo_fromdense.
   make_matrix = rand_sparse(self.rng())
   dense = make_matrix(shape, dtype)
   expected, idx = sparse.bcoo_fromdense(dense)

   # Direct call.
   extracted = sparse.bcoo_extract(idx, dense)
   self.assertArraysEqual(expected, extracted)

   # Same call through jit.
   extract_jitted = jit(sparse.bcoo_extract)
   extracted_jitted = extract_jitted(idx, dense)
   self.assertArraysEqual(expected, extracted_jitted)
Ejemplo n.º 3
0
Archivo: api.py Proyecto: jbampton/jax
def _todense_transpose(ct, *bufs, tree):
  """Transpose rule for todense, dispatching on the container type in `tree`.

  The first buffer must be an undefined primal and none of the remaining
  buffers may be. The container type is recovered by unflattening `tree`
  with placeholder leaves.

  Returns:
    * ``(ct,)`` when the tree unflattens to a bare leaf;
    * ``(bcoo_extract(indices, ct), indices)`` for a ``BCOO``;
    * ``(_coo_extract(row, col, ct), row, col)`` for a ``COO``.

  Raises:
    NotImplementedError: for any other container type.
  """
  data_buf, *index_bufs = bufs
  assert ad.is_undefined_primal(data_buf)
  assert all(not ad.is_undefined_primal(b) for b in index_bufs)

  # Rebuild the container around a sentinel so its type can be inspected
  # without needing real buffer values.
  placeholder = object()
  container = tree_util.tree_unflatten(tree, [placeholder for _ in bufs])
  # Imported inside the function, as in the original (presumably to avoid a
  # circular import -- TODO confirm).
  from jax.experimental.sparse import BCOO, bcoo_extract

  if container is placeholder:
    # The tree is a single leaf: the cotangent passes through unchanged.
    return (ct,)

  if isinstance(container, BCOO):
    _, indices = bufs
    return bcoo_extract(indices, ct), indices

  if isinstance(container, COO):
    _, row, col = bufs
    return _coo_extract(row, col, ct), row, col

  raise NotImplementedError(f"todense_transpose for {type(container)}")