コード例 #1
0
def _coo_todense_transpose(ct, data, row, col, *, shape):
    """Transpose rule for converting a COO matrix to dense.

    Maps the dense cotangent ``ct`` back onto the COO ``data`` buffer using
    ``_coo_extract(row, col, ct)``. Only ``data`` may be the linear
    (undefined) input; the index buffers ``row`` and ``col`` must be known.

    Args:
      ct: dense cotangent; asserted to have shape ``shape`` and the dtype of
        ``data``.
      data: undefined primal standing in for the COO data buffer.
      row: known row-index buffer.
      col: known column-index buffer; asserted to share ``row``'s dtype.
      shape: dense shape of the forward output.

    Returns:
      ``(data_ct, row, col)``: the extracted data cotangent, followed by the
      index buffers passed through unchanged.

    Raises:
      ValueError: if ``row`` or ``col`` is itself an undefined primal.
    """
    # NOTE: this assumes the cotangent has the same sparsity pattern as the
    # primal. Nothing here verifies that assumption.
    assert ad.is_undefined_primal(data)
    index_is_linear = ad.is_undefined_primal(row) or ad.is_undefined_primal(col)
    if index_is_linear:
        raise ValueError("Cannot transpose with respect to sparse indices")
    # Internal consistency checks on the cotangent and the index buffers.
    assert ct.shape == shape
    assert row.aval.dtype == col.aval.dtype
    assert ct.dtype == data.aval.dtype
    data_ct = _coo_extract(row, col, ct)
    return data_ct, row, col
コード例 #2
0
def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype):
    """JVP rule for ``coo_fromdense``.

    Args:
      primals: one-element sequence holding the dense input matrix.
      tangents: one-element sequence holding that input's tangent, which may
        be a symbolic ``ad.Zero``.
      nse: forwarded unchanged to ``coo_fromdense``.
      index_dtype: forwarded unchanged to ``coo_fromdense``.

    Returns:
      ``(primals_out, tangents_out)``. ``primals_out`` is the object returned
      by ``coo_fromdense``, unpacked here as ``(data, row, col)``.
      ``tangents_out`` is ``(data_dot, zero_row, zero_col)``.
    """
    (mat,) = primals
    (mat_dot,) = tangents

    out = coo_fromdense(mat, nse=nse, index_dtype=index_dtype)
    data, row, col = out

    # A symbolic-zero input tangent stays symbolic. Otherwise the data tangent
    # is extracted from the dense tangent at the output's (row, col) positions.
    data_dot = (ad.Zero.from_value(data) if type(mat_dot) is ad.Zero
                else _coo_extract(row, col, mat_dot))

    # The index buffers never carry a tangent; they are always symbolic zeros.
    zero_row = ad.Zero.from_value(row)
    zero_col = ad.Zero.from_value(col)
    return out, (data_dot, zero_row, zero_col)
コード例 #3
0
ファイル: api.py プロジェクト: jbampton/jax
def _todense_transpose(ct, *bufs, tree):
  """Transpose rule for ``todense`` over a flattened (possibly sparse) pytree.

  ``bufs`` are the flattened leaves described by ``tree``. Only ``bufs[0]``
  may be an undefined (linear) primal; every other buffer must be known.

  Args:
    ct: dense cotangent of the ``todense`` output.
    *bufs: flattened leaves. ``bufs[0]`` is the linear input and the rest are
      known buffers (index arrays for the sparse containers handled below).
    tree: treedef used to reassemble ``bufs`` into the original object.

  Returns:
    A tuple aligned with ``bufs``:
      - ``(ct,)`` when ``tree`` is a single leaf;
      - ``(bcoo_extract(indices, ct), indices)`` for a ``BCOO``;
      - ``(_coo_extract(row, col, ct), row, col)`` for a ``COO``.

  Raises:
    NotImplementedError: for any other container type.
  """
  assert ad.is_undefined_primal(bufs[0])
  assert all(not ad.is_undefined_primal(b) for b in bufs[1:])

  # Rebuild the container around placeholder leaves only to discover its type.
  sentinel = object()
  obj = tree_util.tree_unflatten(tree, [sentinel for _ in bufs])
  # Function-scope import, kept after unflattening as in the original.
  # NOTE(review): presumably local to avoid a circular import; confirm.
  from jax.experimental.sparse import BCOO, bcoo_extract

  if obj is sentinel:
    # Single-leaf tree (no sparse container): the cotangent passes through.
    return (ct,)

  if isinstance(obj, BCOO):
    _, indices = bufs
    return bcoo_extract(indices, ct), indices

  if isinstance(obj, COO):
    _, row, col = bufs
    return _coo_extract(row, col, ct), row, col

  raise NotImplementedError(f"todense_transpose for {type(obj)}")