def _coo_todense_transpose(ct, data, row, col, *, shape):
    """Transpose rule for the COO-to-dense conversion.

    Only ``data`` may be the undefined (linear) operand; the index operands
    ``row`` and ``col`` must both be known.

    Args:
      ct: cotangent of the dense output; asserted to have shape ``shape``
        and the same dtype as ``data``'s abstract value.
      data: undefined primal standing in for the sparse values.
      row: known row indices.
      col: known column indices; asserted to share ``row``'s dtype.
      shape: expected shape of ``ct``.

    Returns:
      A 3-tuple ``(ct_data, row, col)`` where ``ct_data`` is the result of
      ``_coo_extract(row, col, ct)`` and the indices are passed back as-is.

    Raises:
      ValueError: if ``row`` or ``col`` is an undefined primal.
    """
    # NOTE: the cotangent is assumed to share the sparsity pattern described
    # by (row, col); nothing here verifies that. Open question from the
    # original author: can this be checked?
    assert ad.is_undefined_primal(data)

    indices_are_linear = any(
        ad.is_undefined_primal(idx) for idx in (row, col)
    )
    if indices_are_linear:
        raise ValueError("Cannot transpose with respect to sparse indices")

    assert ct.shape == shape
    assert row.aval.dtype == col.aval.dtype
    assert ct.dtype == data.aval.dtype

    ct_data = _coo_extract(row, col, ct)
    return ct_data, row, col
def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype):
    """JVP rule for ``coo_fromdense``.

    Args:
      primals: 1-tuple holding the dense input.
      tangents: 1-tuple holding the tangent of the dense input (may be an
        ``ad.Zero``).
      nse: forwarded unchanged to ``coo_fromdense``.
      index_dtype: forwarded unchanged to ``coo_fromdense``.

    Returns:
      ``(primals_out, tangents_out)``. ``primals_out`` is whatever
      ``coo_fromdense`` returns, unpacked here as ``(data, row, col)``.
      ``tangents_out`` is ``(data_dot, zero_row, zero_col)``; the two index
      tangents are always ``ad.Zero`` instances built from the index outputs.
    """
    (dense,) = primals
    (dense_dot,) = tangents

    primals_out = coo_fromdense(dense, nse=nse, index_dtype=index_dtype)
    data, row, col = primals_out

    # A symbolic-zero input tangent yields a symbolic-zero data tangent;
    # otherwise the data tangent is taken from the dense tangent via
    # _coo_extract at the indices of the primal output.
    data_dot = (
        ad.Zero.from_value(data)
        if type(dense_dot) is ad.Zero
        else _coo_extract(row, col, dense_dot)
    )

    tangents_out = (
        data_dot,
        ad.Zero.from_value(row),
        ad.Zero.from_value(col),
    )
    return primals_out, tangents_out
def _todense_transpose(ct, *bufs, tree):
    """Transpose rule for the generic ``todense`` operation.

    The first buffer must be the undefined (linear) operand and every
    remaining buffer must be known. The container type is recovered by
    unflattening ``tree`` with placeholder leaves, and the rule dispatches
    on that type.

    Args:
      ct: cotangent of the dense output.
      *bufs: flattened buffers of the input; ``bufs[0]`` is the linear one.
      tree: tree definition used to rebuild the container from ``bufs``.

    Returns:
      A tuple with one entry per buffer: the cotangent for the first buffer
      followed by the remaining buffers unchanged.
        * bare leaf (unflattening returns the placeholder): ``(ct,)``
        * ``BCOO``: ``(bcoo_extract(indices, ct), indices)``
        * ``COO``: ``(_coo_extract(row, col, ct), row, col)``

    Raises:
      NotImplementedError: if the rebuilt container is of any other type.
    """
    assert ad.is_undefined_primal(bufs[0])
    assert all(not ad.is_undefined_primal(known) for known in bufs[1:])

    # Rebuild the container from throwaway leaves; only its type is used.
    placeholder = object()
    obj = tree_util.tree_unflatten(tree, [placeholder] * len(bufs))

    # Function-scope import, kept at this point in the execution order
    # (presumably deferred to avoid an import cycle -- TODO confirm).
    from jax.experimental.sparse import BCOO, bcoo_extract

    # The tree is a single leaf, so the cotangent passes straight through.
    if obj is placeholder:
        return (ct,)

    if isinstance(obj, BCOO):
        _, indices = bufs
        return bcoo_extract(indices, ct), indices

    if isinstance(obj, COO):
        _, row, col = bufs
        return _coo_extract(row, col, ct), row, col

    raise NotImplementedError(f"todense_transpose for {type(obj)}")