def _mul_sparse(spenv, *argspecs): X, Y = argspecs if X.is_sparse() and Y.is_sparse(): if X.shape != Y.shape: raise NotImplementedError( "Multiplication between sparse matrices of different shapes.") if X.indices_ref == Y.indices_ref: out_data = lax.mul(X.data(spenv), Y.data(spenv)) out_argspec = ArgSpec(X.shape, spenv.push(out_data), X.indices_ref) elif X.indices(spenv).ndim != Y.indices(spenv).ndim or X.data( spenv).ndim != Y.data(spenv).ndim: raise NotImplementedError( "Multiplication between sparse matrices with different batch/dense dimensions." ) else: raise NotImplementedError( "Multiplication between sparse matrices with different sparsity patterns." ) else: if Y.is_sparse(): X, Y = Y, X out_data = bcoo_multiply_dense(X.data(spenv), X.indices(spenv), Y.data(spenv), shape=X.shape) out_argspec = ArgSpec(X.shape, spenv.push(out_data), X.indices_ref) return (out_argspec, )
def _mul_sparse(spenv, *argspecs):
  """Elementwise multiplication of two operands within the sparsify environment.

  Dispatch:
    * sparse * sparse, same indices buffer: multiply the stored values and
      reuse that indices buffer.
    * sparse * sparse, different indices buffers: ``bcoo_multiply_sparse``
      returns new data, indices and shape, all of which become the result.
    * sparse * dense (in either order): ``bcoo_multiply_dense``; the result
      reuses the sparse operand's indices buffer.

  Args:
    spenv: environment handed to each ArgSpec's ``data``/``indices``
      accessors; ``spenv.push`` is used to add new output buffers.
    *argspecs: exactly two ArgSpec operands.

  Returns:
    A one-element tuple containing the ArgSpec of the product.
  """
  lhs, rhs = argspecs

  if not (lhs.is_sparse() and rhs.is_sparse()):
    # Mixed case: move the sparse operand to the front for the dense kernel.
    sparse_op, other = (rhs, lhs) if rhs.is_sparse() else (lhs, rhs)
    product = bcoo_multiply_dense(
        sparse_op.data(spenv), sparse_op.indices(spenv), other.data(spenv),
        shape=sparse_op.shape)
    return (ArgSpec(sparse_op.shape, spenv.push(product),
                    sparse_op.indices_ref),)

  if lhs.indices_ref == rhs.indices_ref:
    # Shared indices buffer: the stored values line up one-to-one.
    # TODO(jakevdp): this is inaccurate if there are duplicate indices
    product = lax.mul(lhs.data(spenv), rhs.data(spenv))
    return (ArgSpec(lhs.shape, spenv.push(product), lhs.indices_ref),)

  # Distinct sparsity patterns: the kernel builds a fresh pattern; data is
  # pushed before indices, and both become new buffers.
  new_data, new_indices, new_shape = bcoo_multiply_sparse(
      lhs.data(spenv), lhs.indices(spenv),
      rhs.data(spenv), rhs.indices(spenv),
      lhs_shape=lhs.shape, rhs_shape=rhs.shape)
  return (ArgSpec(new_shape, spenv.push(new_data), spenv.push(new_indices)),)
def _mul_sparse(spenv, *spvalues):
  """Elementwise multiplication of two sparsify values.

  Dispatch:
    * sparse * sparse sharing one indices buffer that is flagged as holding
      unique indices: multiply the stored values directly and keep the shared
      buffer. The result is marked ``unique_indices=True`` and inherits
      ``indices_sorted``.
    * any other sparse * sparse: convert both via ``spvalues_to_arrays`` and
      use ``bcoo_multiply_sparse``, taking shape, data and indices from its
      result.
    * sparse * dense (in either order): ``bcoo_multiply_dense``; the result
      keeps the sparse operand's indices buffer and its ``indices_sorted`` /
      ``unique_indices`` flags.

  Args:
    spenv: sparsify environment providing ``data`` and ``sparse``.
    *spvalues: exactly two operand spvalues.

  Returns:
    A one-element tuple containing the spvalue of the product.
  """
  lhs, rhs = spvalues

  if not (lhs.is_sparse() and rhs.is_sparse()):
    # Mixed case: move the sparse operand to the front for the dense kernel.
    sparse_op, other = (rhs, lhs) if rhs.is_sparse() else (lhs, rhs)
    sparse_arr = spvalues_to_arrays(spenv, sparse_op)
    product = bcoo_multiply_dense(sparse_arr, spenv.data(other))
    return (spenv.sparse(sparse_op.shape, product,
                         indices_ref=sparse_op.indices_ref,
                         indices_sorted=sparse_op.indices_sorted,
                         unique_indices=sparse_op.unique_indices),)

  if lhs.indices_ref == rhs.indices_ref and lhs.unique_indices:
    if config.jax_enable_checks:
      # Operands sharing an indices buffer must agree on its metadata.
      assert lhs.indices_sorted == rhs.indices_sorted
      assert lhs.unique_indices == rhs.unique_indices
    # Unique, shared indices: the per-entry product is exact.
    product = lax.mul(spenv.data(lhs), spenv.data(rhs))
    return (spenv.sparse(lhs.shape, product,
                         indices_ref=lhs.indices_ref,
                         indices_sorted=lhs.indices_sorted,
                         unique_indices=True),)

  # General case: let the sparse-sparse kernel compute a new pattern.
  lhs_arr, rhs_arr = spvalues_to_arrays(spenv, spvalues)
  result = bcoo_multiply_sparse(lhs_arr, rhs_arr)
  return (spenv.sparse(result.shape, result.data, result.indices),)
def _mul_sparse(spenv, *spvalues): X, Y = spvalues if X.is_sparse() and Y.is_sparse(): if X.indices_ref == Y.indices_ref: # TODO(jakevdp): this is inaccurate if there are duplicate indices out_data = lax.mul(spenv.data(X), spenv.data(Y)) out_spvalue = spenv.sparse(X.shape, out_data, indices_ref=X.indices_ref) else: X_promoted, Y_promoted = spvalues_to_arrays(spenv, spvalues) mat = bcoo_multiply_sparse(X_promoted, Y_promoted) out_spvalue = spenv.sparse(mat.shape, mat.data, mat.indices) else: if Y.is_sparse(): X, Y = Y, X X_promoted = spvalues_to_arrays(spenv, X) out_data = bcoo_multiply_dense(X_promoted, spenv.data(Y)) out_spvalue = spenv.sparse(X.shape, out_data, indices_ref=X.indices_ref) return (out_spvalue,)