Exemplo n.º 1
0
def _transpose_sparse(spenv, *argspecs, permutation):
    """Sparsify rule for ``transpose`` on a single sparse operand.

    The operand's buffers are resolved through ``argspecs_to_arrays`` and
    transposed with ``sparse.bcoo_transpose``. A buffer that the permutation
    cannot affect keeps its existing reference from ``argspecs[0]``. A buffer
    that may change is registered with ``spenv.push`` instead.

    Args:
      spenv: sparsify environment. The value returned by ``spenv.push`` is
        used as the reference for a newly registered buffer.
      *argspecs: operand specs; only ``argspecs[0]`` is consulted.
      permutation: iterable of axis indices giving the output axis order.

    Returns:
      A one-element tuple containing the ``ArgSpec`` of the result.
    """
    perm = tuple(permutation)
    operand = argspecs_to_arrays(spenv, argspecs)[0]
    in_shape = operand.shape
    new_data, new_indices = sparse.bcoo_transpose(
        operand.data, operand.indices, permutation=perm, shape=in_shape)
    out_shape = tuple(in_shape[axis] for axis in perm)

    # Axes are ordered batch, then sparse, then dense. In this layout the
    # batch rank is indices.ndim - 2 and the sparse rank is the size of the
    # second-to-last indices axis.
    n_batch = operand.indices.ndim - 2
    n_sparse = operand.indices.shape[-2]
    sparse_end = n_batch + n_sparse

    def axes_fixed(lo, hi=None):
        # True when perm[lo:hi] is exactly lo, lo + 1, ... (identity there).
        # An open upper bound compares against axes up to len(in_shape).
        stop = len(in_shape) if hi is None else hi
        return perm[lo:hi] == tuple(range(lo, stop))

    batch_fixed = axes_fixed(0, n_batch)
    sparse_fixed = axes_fixed(n_batch, sparse_end)
    dense_fixed = axes_fixed(sparse_end)

    # Data is unchanged if neither batch nor dense axes move, so the input's
    # reference can be shared instead of pushing a copy.
    data_ref = (argspecs[0].data_ref if batch_fixed and dense_fixed
                else spenv.push(new_data))
    # Indices are unchanged if neither batch nor sparse axes move.
    indices_ref = (argspecs[0].indices_ref if batch_fixed and sparse_fixed
                   else spenv.push(new_indices))

    return (ArgSpec(out_shape, data_ref, indices_ref),)
Exemplo n.º 2
0
def _transpose_sparse(spenv, *spvalues, permutation):
  """Sparsify rule for ``transpose`` on a single BCOO operand.

  The operand is rebuilt as a ``sparse.BCOO`` from its resolved buffers and
  transposed with ``sparse.bcoo_transpose``. When the permutation cannot
  change a buffer, the output reuses the input's reference (``data_ref`` /
  ``indices_ref``). Otherwise the transposed array is passed to
  ``spenv.sparse`` directly (``data`` / ``indices``).

  Args:
    spenv: sparsify environment; ``spenv.sparse`` builds the output value.
    *spvalues: operand values; only ``spvalues[0]`` is consulted.
    permutation: iterable of axis indices giving the output axis order.

  Returns:
    A one-element tuple containing the spvalue produced by ``spenv.sparse``.
  """
  perm = tuple(permutation)
  operand = spvalues_to_arrays(spenv, spvalues)[0]
  in_shape = operand.shape
  transposed = sparse.bcoo_transpose(
      sparse.BCOO((operand.data, operand.indices), shape=in_shape),
      permutation=perm)
  out_shape = tuple(in_shape[axis] for axis in perm)

  # Axes are ordered batch, then sparse, then dense. In this layout the
  # batch rank is indices.ndim - 2 and the sparse rank is the size of the
  # last indices axis.
  n_batch = operand.indices.ndim - 2
  n_sparse = operand.indices.shape[-1]
  sparse_end = n_batch + n_sparse
  ndim = len(in_shape)
  batch_fixed = perm[:n_batch] == tuple(range(n_batch))
  sparse_fixed = perm[n_batch:sparse_end] == tuple(range(n_batch, sparse_end))
  dense_fixed = perm[sparse_end:] == tuple(range(sparse_end, ndim))

  # Data is unchanged if neither batch nor dense axes move.
  if batch_fixed and dense_fixed:
    data_entry = ('data_ref', spvalues[0].data_ref)
  else:
    data_entry = ('data', transposed.data)

  # Indices are unchanged if neither batch nor sparse axes move.
  if batch_fixed and sparse_fixed:
    indices_entry = ('indices_ref', spvalues[0].indices_ref)
  else:
    indices_entry = ('indices', transposed.indices)

  return (spenv.sparse(out_shape, **dict([data_entry, indices_entry])),)