def _csr_todense_transpose(ct, data, indices, indptr, *, shape):
  """Transpose rule that maps a dense cotangent back onto CSR buffers.

  Only ``data`` may be the linear (undefined-primal) operand. The index
  buffers ``indices`` and ``indptr`` must be concrete, because the
  cotangent is extracted at the positions they describe.

  NOTE(review): this assumes the transpose has the same sparsity pattern as
  the primal. That assumption is not checked here, and it is an open
  question whether it can be.

  Args:
    ct: cotangent of the dense output. It is asserted to have shape
      ``shape`` and the same dtype as ``data``.
    data: CSR data buffer; must be an undefined primal.
    indices: CSR column-index buffer; must be concrete.
    indptr: CSR row-pointer buffer; must be concrete and have the same
      dtype as ``indices``.
    shape: expected shape of ``ct``.

  Returns:
    A tuple ``(data_ct, indices, indptr)``. ``data_ct`` is
    ``_csr_extract(indices, indptr, ct)``, and the index buffers are
    returned unchanged.

  Raises:
    ValueError: if ``indices`` or ``indptr`` is an undefined primal.
  """
  assert ad.is_undefined_primal(data)
  # Differentiating with respect to the sparsity structure itself is unsupported.
  if any(ad.is_undefined_primal(buf) for buf in (indices, indptr)):
    raise ValueError("Cannot transpose with respect to sparse indices")
  # Internal consistency checks on the incoming cotangent and index buffers.
  assert ct.shape == shape
  assert indices.aval.dtype == indptr.aval.dtype
  assert ct.dtype == data.aval.dtype
  data_ct = _csr_extract(indices, indptr, ct)
  return data_ct, indices, indptr
def _csr_fromdense_jvp(primals, tangents, *, nse, index_dtype):
  """JVP rule for ``csr_fromdense``.

  The primal outputs come from calling ``csr_fromdense`` on the single
  dense input. The data tangent is taken from the input tangent using the
  primal output's ``indices``/``indptr``. The index buffers always get
  symbolic-zero tangents.

  Args:
    primals: one-element sequence holding the dense input.
    tangents: one-element sequence holding the dense input's tangent. It
      may be a symbolic ``ad.Zero``.
    nse: forwarded unchanged to ``csr_fromdense``.
    index_dtype: forwarded unchanged to ``csr_fromdense``.

  Returns:
    ``(primals_out, tangents_out)``. ``primals_out`` is the
    ``(data, indices, indptr)`` result of ``csr_fromdense``, and
    ``tangents_out`` is ``(data_dot, Zero, Zero)``.
  """
  (mat,) = primals
  (mat_dot,) = tangents
  primals_out = csr_fromdense(mat, nse=nse, index_dtype=index_dtype)
  data, indices, indptr = primals_out
  # A symbolic-zero input tangent becomes a symbolic-zero data tangent.
  # Otherwise, read the tangent at the sparsity pattern of the primal output.
  data_dot = (ad.Zero.from_value(data) if type(mat_dot) is ad.Zero
              else _csr_extract(indices, indptr, mat_dot))
  # The index buffers define structure only, so their tangents are always zero.
  zero_index_tangents = (ad.Zero.from_value(indices),
                         ad.Zero.from_value(indptr))
  return primals_out, (data_dot, *zero_index_tangents)