def _csr_matmat_transpose(ct, data, indices, indptr, B, *, shape, transpose):
    """Transpose (reverse-mode) rule for the CSR matrix-matrix product.

    Let ``A`` be the dense matrix of ``shape`` described by
    ``(data, indices, indptr)``. The primal computes ``A @ B``, or ``A.T @ B``
    when ``transpose`` is true. It is bilinear in ``data`` and ``B``. This rule
    treats ``B`` as the linear input when it is an undefined primal, and
    otherwise assumes ``data`` is.

    Args:
      ct: cotangent of the product's output.
      data, indices, indptr: CSR components of ``A``; ``indices`` and
        ``indptr`` must not be linear (undefined) inputs.
      B: the dense right-hand operand (indexed per row and reduced over
        axis 1 below, so presumably 2-D -- confirm at call sites).
      shape: dense shape of ``A``.
      transpose: whether the primal multiplied by ``A.T`` instead of ``A``.

    Returns:
      A 4-tuple aligned with ``(data, indices, indptr, B)``. The entry for the
      linear input holds its cotangent; the other entries are the primal
      values passed through unchanged.
    """
    assert not ad.is_undefined_primal(indices)
    assert not ad.is_undefined_primal(indptr)
    if ad.is_undefined_primal(B):
        # The product is linear in B with matrix A (or A.T), so B's cotangent
        # is the same sparse product applied to ct with `transpose` flipped.
        return data, indices, indptr, csr_matmat(
            data, indices, indptr, ct, shape=shape, transpose=not transpose)
    B = jnp.asarray(B)
    row, col = _csr_to_coo(indices, indptr)
    if transpose:
        # For A.T @ B, the stored entry A[r, c] contributes A[r, c] * B[r] to
        # output row c. Its cotangent therefore pairs ct[c] with B[r]: swap
        # the index roles before gathering.
        row, col = col, row
    # Per stored entry, the dot product of the matching rows of ct and B.
    # This equals sampling the dense ct @ B.T (or B @ ct.T when transposed)
    # at the sparsity pattern, without materializing that dense matrix.
    return (ct[row] * B[col]).sum(1), indices, indptr, B
def _csr_matvec_transpose(ct, data, indices, indptr, v, *, shape, transpose):
    """Transpose (reverse-mode) rule for the CSR matrix-vector product.

    Let ``A`` be the dense matrix of ``shape`` described by
    ``(data, indices, indptr)``. The primal computes ``A @ v``, or ``A.T @ v``
    when ``transpose`` is true. It is bilinear in ``data`` and ``v``. This rule
    treats ``v`` as the linear input when it is an undefined primal, and
    otherwise assumes ``data`` is.

    Args:
      ct: cotangent of the product's output.
      data, indices, indptr: CSR components of ``A``; ``indices`` and
        ``indptr`` must not be linear (undefined) inputs.
      v: the dense vector operand.
      shape: dense shape of ``A``.
      transpose: whether the primal multiplied by ``A.T`` instead of ``A``.

    Returns:
      A 4-tuple aligned with ``(data, indices, indptr, v)``. The entry for the
      linear input holds its cotangent; the other entries are the primal
      values passed through unchanged.
    """
    assert not ad.is_undefined_primal(indices)
    assert not ad.is_undefined_primal(indptr)
    if ad.is_undefined_primal(v):
        # The product is linear in v with matrix A (or A.T), so v's cotangent
        # is the same sparse product applied to ct with `transpose` flipped.
        return data, indices, indptr, csr_matvec(
            data, indices, indptr, ct, shape=shape, transpose=not transpose)
    v = jnp.asarray(v)
    row, col = _csr_to_coo(indices, indptr)
    if transpose:
        # For A.T @ v, the stored entry A[r, c] contributes A[r, c] * v[r] to
        # output c. Its cotangent is ct[c] * v[r]: swap the index roles.
        row, col = col, row
    # Equivalent to extracting the sparsity pattern of the dense outer
    # product (jnp.outer(ct, v), or jnp.outer(v, ct) when transposed), but
    # without materializing it.
    return ct[row] * v[col], indices, indptr, v
def _csr_matmat_impl(data, indices, indptr, B, *, shape, transpose):
    """Evaluate the CSR matrix-matrix product via the COO implementation.

    The compressed row pointers are expanded into explicit (row, col) index
    arrays with ``_csr_to_coo``. The product (with ``A.T`` when ``transpose``
    is set) is then delegated to ``_coo_matmat``.
    """
    coo_row, coo_col = _csr_to_coo(indices, indptr)
    coo_info = COOInfo(shape=shape)
    return _coo_matmat(data, coo_row, coo_col, B,
                       spinfo=coo_info, transpose=transpose)
def _csr_todense_impl(data, indices, indptr, *, shape):
    """Materialize a CSR matrix as a dense array by way of COO form.

    The row pointers are expanded to explicit (row, col) indices with
    ``_csr_to_coo``, and densification is delegated to ``_coo_todense``.
    """
    coo_row, coo_col = _csr_to_coo(indices, indptr)
    return _coo_todense(data, coo_row, coo_col, spinfo=COOInfo(shape=shape))
def _csr_matvec_impl(data, indices, indptr, v, *, shape, transpose):
    """Evaluate the CSR matrix-vector product via the COO implementation.

    The row pointers are expanded to explicit (row, col) indices with
    ``_csr_to_coo``, and the product (with ``A.T`` when ``transpose`` is set)
    is delegated to ``_coo_matvec_impl``.
    """
    coo_row, coo_col = _csr_to_coo(indices, indptr)
    # NOTE(review): the matmat/todense impls pass spinfo=COOInfo(shape=shape)
    # to their _coo_* helpers, whereas this one passes shape= directly to
    # _coo_matvec_impl. Confirm that helper's signature still accepts `shape`.
    return _coo_matvec_impl(data, coo_row, coo_col, v,
                            shape=shape, transpose=transpose)