示例#1
0
def _mul_sparse(spenv, *argspecs):
    """Sparsify rule for elementwise multiplication of two operands.

    Supported combinations:

    * sparse * dense, in either order: delegated to ``bcoo_multiply_dense``;
      the result reuses the sparse operand's index buffer.
    * sparse * sparse with equal shapes and the same ``indices_ref``: the two
      data buffers are multiplied with ``lax.mul`` and the shared index
      buffer is reused.

    Every other sparse * sparse combination raises ``NotImplementedError``.

    NOTE(review): the shared-index path multiplies stored entries pairwise;
    presumably this is inaccurate if the shared index buffer contains
    duplicate coordinates -- confirm.

    Args:
      spenv: environment that resolves operand buffers (via
        ``argspec.data(spenv)`` / ``argspec.indices(spenv)``) and registers
        new buffers with ``spenv.push``.
      *argspecs: exactly two argspecs, the left and right operands.

    Returns:
      A 1-tuple holding the ``ArgSpec`` that describes the product.

    Raises:
      NotImplementedError: for two sparse operands whose shapes differ, whose
        index/data ranks differ, or whose index buffers are not shared.
    """
    lhs, rhs = argspecs

    if not (lhs.is_sparse() and rhs.is_sparse()):
        # Mixed case: normalize so the sparse operand comes first.
        sparse_arg, dense_arg = (rhs, lhs) if rhs.is_sparse() else (lhs, rhs)
        scaled = bcoo_multiply_dense(sparse_arg.data(spenv),
                                     sparse_arg.indices(spenv),
                                     dense_arg.data(spenv),
                                     shape=sparse_arg.shape)
        return (ArgSpec(sparse_arg.shape, spenv.push(scaled),
                        sparse_arg.indices_ref),)

    if lhs.shape != rhs.shape:
        raise NotImplementedError(
            "Multiplication between sparse matrices of different shapes.")

    if lhs.indices_ref == rhs.indices_ref:
        # Same index buffer: entries line up one-to-one.
        product = lax.mul(lhs.data(spenv), rhs.data(spenv))
        return (ArgSpec(lhs.shape, spenv.push(product), lhs.indices_ref),)

    # Distinct index buffers are unsupported; choose the more specific message.
    layout_differs = (lhs.indices(spenv).ndim != rhs.indices(spenv).ndim
                      or lhs.data(spenv).ndim != rhs.data(spenv).ndim)
    if layout_differs:
        raise NotImplementedError(
            "Multiplication between sparse matrices with different batch/dense dimensions."
        )
    raise NotImplementedError(
        "Multiplication between sparse matrices with different sparsity patterns."
    )
示例#2
0
def _mul_sparse(spenv, *argspecs):
    """Sparsify rule for elementwise multiplication of two operands.

    Three cases are handled:

    * sparse * dense, in either order: ``bcoo_multiply_dense`` scales the
      sparse data, and the result reuses the sparse operand's index buffer.
    * sparse * sparse sharing one ``indices_ref``: the data buffers are
      multiplied with ``lax.mul`` and the shared index buffer is reused.
    * sparse * sparse with distinct index buffers: ``bcoo_multiply_sparse``
      computes the product, and its data and indices are pushed into
      ``spenv`` as new buffers (data first, then indices).

    Args:
      spenv: environment that resolves operand buffers (via
        ``argspec.data(spenv)`` / ``argspec.indices(spenv)``) and registers
        new buffers with ``spenv.push``.
      *argspecs: exactly two argspecs, the left and right operands.

    Returns:
      A 1-tuple holding the ``ArgSpec`` that describes the product.
    """
    lhs, rhs = argspecs

    if not (lhs.is_sparse() and rhs.is_sparse()):
        # Mixed case: normalize so the sparse operand comes first.
        sparse_arg, dense_arg = (rhs, lhs) if rhs.is_sparse() else (lhs, rhs)
        scaled = bcoo_multiply_dense(sparse_arg.data(spenv),
                                     sparse_arg.indices(spenv),
                                     dense_arg.data(spenv),
                                     shape=sparse_arg.shape)
        return (ArgSpec(sparse_arg.shape, spenv.push(scaled),
                        sparse_arg.indices_ref),)

    if lhs.indices_ref == rhs.indices_ref:
        # TODO(jakevdp): this is inaccurate if there are duplicate indices
        product = lax.mul(lhs.data(spenv), rhs.data(spenv))
        return (ArgSpec(lhs.shape, spenv.push(product), lhs.indices_ref),)

    new_data, new_indices, new_shape = bcoo_multiply_sparse(
        lhs.data(spenv), lhs.indices(spenv),
        rhs.data(spenv), rhs.indices(spenv),
        lhs_shape=lhs.shape, rhs_shape=rhs.shape)
    data_ref = spenv.push(new_data)
    indices_ref = spenv.push(new_indices)
    return (ArgSpec(new_shape, data_ref, indices_ref),)
示例#3
0
def _mul_sparse(spenv, *spvalues):
    """Sparsify rule for elementwise multiplication of two operands.

    Three cases are handled:

    * sparse * dense, in either order: the sparse operand is promoted to an
      array and multiplied with ``bcoo_multiply_dense``. The result reuses
      that operand's index buffer along with its ``indices_sorted`` and
      ``unique_indices`` metadata.
    * sparse * sparse sharing one ``indices_ref`` whose indices are flagged
      unique: the data buffers are multiplied with ``lax.mul``. The result
      keeps the shared index buffer and its sortedness, and is marked
      ``unique_indices=True``.
    * any other sparse * sparse pair: both operands are promoted to arrays
      and multiplied with ``bcoo_multiply_sparse``.

    When ``config.jax_enable_checks`` is set, the shared-index path also
    asserts that both operands carry identical index metadata.

    Args:
      spenv: sparsify environment; ``spenv.data`` reads an operand's data
        buffer and ``spenv.sparse`` builds the resulting spvalue.
      *spvalues: exactly two operands, the left and right factors.

    Returns:
      A 1-tuple holding the spvalue for the product.
    """
    lhs, rhs = spvalues

    if not (lhs.is_sparse() and rhs.is_sparse()):
        # Mixed case: normalize so the sparse operand comes first.
        sparse_val, dense_val = (rhs, lhs) if rhs.is_sparse() else (lhs, rhs)
        sparse_arr = spvalues_to_arrays(spenv, sparse_val)
        scaled = bcoo_multiply_dense(sparse_arr, spenv.data(dense_val))
        result = spenv.sparse(sparse_val.shape,
                              scaled,
                              indices_ref=sparse_val.indices_ref,
                              indices_sorted=sparse_val.indices_sorted,
                              unique_indices=sparse_val.unique_indices)
        return (result,)

    # Pairwise data multiplication is only used when the shared index buffer
    # is known to be duplicate-free.
    shares_unique_indices = (lhs.indices_ref == rhs.indices_ref
                             and lhs.unique_indices)
    if not shares_unique_indices:
        lhs_arr, rhs_arr = spvalues_to_arrays(spenv, spvalues)
        product_mat = bcoo_multiply_sparse(lhs_arr, rhs_arr)
        return (spenv.sparse(product_mat.shape, product_mat.data,
                             product_mat.indices),)

    if config.jax_enable_checks:
        # Operands that share an index buffer should share its metadata too.
        assert lhs.indices_sorted == rhs.indices_sorted
        assert lhs.unique_indices == rhs.unique_indices
    product = lax.mul(spenv.data(lhs), spenv.data(rhs))
    result = spenv.sparse(lhs.shape,
                          product,
                          indices_ref=lhs.indices_ref,
                          indices_sorted=lhs.indices_sorted,
                          unique_indices=True)
    return (result,)
示例#4
0
def _mul_sparse(spenv, *spvalues):
  """Sparsify rule for elementwise multiplication ``X * Y``.

  Three cases are handled:

  * sparse * sparse sharing one ``indices_ref`` that is flagged as holding
    unique indices: the data buffers are multiplied with ``lax.mul`` and the
    shared index buffer is reused.
  * any other sparse * sparse pair: both operands are promoted to arrays and
    multiplied with ``bcoo_multiply_sparse``.
  * sparse * dense, in either order: the promoted sparse operand is multiplied
    with ``bcoo_multiply_dense`` and the result keeps its index buffer.

  Args:
    spenv: sparsify environment; ``spenv.data`` reads an operand's data
      buffer and ``spenv.sparse`` builds the resulting spvalue.
    *spvalues: exactly two operands ``(X, Y)``. NOTE(review): presumably at
      least one of them is sparse whenever this rule runs -- confirm.

  Returns:
    A 1-tuple holding the spvalue for the product.
  """
  X, Y = spvalues
  if X.is_sparse() and Y.is_sparse():
    # Multiplying stored entries pairwise is only exact when the shared index
    # buffer has no duplicates. A duplicated coordinate holds the sum of its
    # entries, and sum(a) * sum(b) != sum(a * b). Spvalues that carry no
    # ``unique_indices`` flag are treated as possibly duplicated and take the
    # general sparse-sparse path.
    shares_unique_indices = (X.indices_ref == Y.indices_ref
                             and getattr(X, "unique_indices", False))
    if shares_unique_indices:
      out_data = lax.mul(spenv.data(X), spenv.data(Y))
      out_spvalue = spenv.sparse(X.shape, out_data, indices_ref=X.indices_ref)
    else:
      X_promoted, Y_promoted = spvalues_to_arrays(spenv, spvalues)
      mat = bcoo_multiply_sparse(X_promoted, Y_promoted)
      out_spvalue = spenv.sparse(mat.shape, mat.data, mat.indices)
  else:
    # Mixed case: normalize so the sparse operand is X.
    if Y.is_sparse():
      X, Y = Y, X
    X_promoted = spvalues_to_arrays(spenv, X)
    out_data = bcoo_multiply_dense(X_promoted, spenv.data(Y))
    out_spvalue = spenv.sparse(X.shape, out_data, indices_ref=X.indices_ref)

  return (out_spvalue,)