Example No. 1
0
File: csr.py, Project: jbampton/jax
def _csr_todense_transpose(ct, data, indices, indptr, *, shape):
  """Transpose rule for the CSR -> dense conversion.

  Maps a dense cotangent ``ct`` back onto a cotangent for the CSR ``data``
  buffer. Only ``data`` may be the linear (undefined-primal) input; the
  sparsity structure (``indices``/``indptr``) must be concrete.

  Args:
    ct: dense cotangent; must have shape ``shape`` and the same dtype as
      ``data``'s abstract value.
    data: the undefined primal for the stored values.
    indices: column-index array of the CSR structure (not linear).
    indptr: row-pointer array of the CSR structure (not linear); must share
      a dtype with ``indices``.
    shape: the dense shape that ``ct`` is checked against.

  Returns:
    A 3-tuple ``(data_ct, indices, indptr)``. ``data_ct`` is
    ``_csr_extract(indices, indptr, ct)``, which presumably reads ``ct`` at
    the stored positions (TODO confirm). ``indices`` and ``indptr`` are
    passed through unchanged.

  Raises:
    ValueError: if ``indices`` or ``indptr`` is an undefined primal.
  """
  # TODO: this assumes the transpose shares the same sparsity pattern as the
  # forward conversion. Can we check this?
  assert ad.is_undefined_primal(data)

  structure_is_linear = (
      ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr))
  if structure_is_linear:
    raise ValueError("Cannot transpose with respect to sparse indices")

  # Internal consistency checks between the cotangent and the primal avals.
  assert ct.shape == shape
  assert indices.aval.dtype == indptr.aval.dtype
  assert ct.dtype == data.aval.dtype

  data_ct = _csr_extract(indices, indptr, ct)
  return data_ct, indices, indptr
Example No. 2
0
def _csr_fromdense_jvp(primals, tangents, *, nse, index_dtype):
  """JVP rule for the dense -> CSR conversion.

  The primal outputs come from ``csr_fromdense(M, nse=nse,
  index_dtype=index_dtype)``, which is unpacked as ``(data, indices, indptr)``.

  Args:
    primals: a 1-tuple holding the dense input ``M``.
    tangents: a 1-tuple holding the dense tangent ``Mdot``, possibly a
      symbolic ``ad.Zero``.
    nse: forwarded unchanged to ``csr_fromdense``.
    index_dtype: forwarded unchanged to ``csr_fromdense``.

  Returns:
    ``(primals_out, tangents_out)``.
    - ``tangents_out[0]`` is a symbolic zero when ``Mdot`` is one.
      Otherwise it is ``_csr_extract(indices, indptr, Mdot)``, which
      presumably gathers ``Mdot`` at the stored positions (TODO confirm).
    - ``indices`` and ``indptr`` always receive symbolic-zero tangents.
  """
  (mat,), (mat_dot,) = primals, tangents

  out = csr_fromdense(mat, nse=nse, index_dtype=index_dtype)
  data, indices, indptr = out

  # Exact type check (not isinstance) follows JAX's convention for Zero.
  data_dot = (ad.Zero.from_value(data) if type(mat_dot) is ad.Zero
              else _csr_extract(indices, indptr, mat_dot))

  # The sparsity-structure outputs carry no tangent.
  indices_dot = ad.Zero.from_value(indices)
  indptr_dot = ad.Zero.from_value(indptr)

  return out, (data_dot, indices_dot, indptr_dot)