Esempio n. 1
0
def _scatter_update(x,
                    idx,
                    y,
                    scatter_op,
                    indices_are_sorted,
                    unique_indices,
                    normalize_indices=True):
    """Apply an indexed update to `x` functionally.

    Produces the array that the in-place statement::

        x[idx] op= y

    would leave behind, without mutating anything: the result is a new array.

    Args:
        x: array-like to be updated; converted with ``jnp.asarray``.
        idx: None, an integer, a slice, an ellipsis, an integer-dtype ndarray,
            or a tuple of those, selecting the locations of `x` that receive
            the scattered values of `y`.
        y: array-like of values to scatter; converted with ``jnp.asarray``.
        scatter_op: the scatter primitive to use; one of ``lax.scatter``,
            ``lax.scatter_add``, ``lax.scatter_min`` or ``lax.scatter_max``.
        indices_are_sorted: whether `idx` is known to be sorted.
        unique_indices: whether `idx` is known to contain no duplicates.
        normalize_indices: forwarded unchanged to ``_scatter_impl``;
            presumably controls normalization of negative indices. TODO
            confirm against ``_scatter_impl``.

    Returns:
        The updated array produced by ``_scatter_impl``.
    """
    # Coerce both operands to JAX arrays first (x before y, as callers expect
    # any conversion error from x to surface first).
    operand = jnp.asarray(x)
    updates = jnp.asarray(y)
    # XLA scatter is structurally close to a transposed gather, so the index
    # is decomposed the same way: a pytree structure plus its static and
    # dynamic (traceable) pieces.
    index_tree, static_parts, dynamic_parts = jnp._split_index_for_jit(idx)
    return _scatter_impl(
        operand, updates, scatter_op,
        index_tree, static_parts, dynamic_parts,
        indices_are_sorted, unique_indices, normalize_indices,
    )
Esempio n. 2
0
def _sparse_rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
                           mode=None, fill_value=None):
  """Indexing helper that mirrors ``lax_numpy._rewriting_take`` under sparsify.

  The index is split into static and dynamic parts, and the gather is run
  through ``sparsify`` so that (presumably) a sparse ``arr`` can be indexed
  without densifying it. TODO confirm intended input types with callers.

  Args:
    arr: array to index; only ``arr.shape`` is read directly here.
    idx: the user-supplied index expression.
    indices_are_sorted, unique_indices, mode, fill_value: forwarded
      unchanged to ``lax_numpy._gather``.

  Returns:
    The gathered result. If that result is not a ``BCOO`` and has zero
    elements, it is converted with ``BCOO.fromdense`` before returning.
  """
  treedef, static_idx, dynamic_idx = lax_numpy._split_index_for_jit(idx, arr.shape)

  # Only the operand and the dynamic index pieces are passed through
  # sparsify; the static pieces and options are closed over.
  def gather_fn(operand, dyn_idx):
    return lax_numpy._gather(operand, treedef, static_idx, dyn_idx,
                             indices_are_sorted, unique_indices, mode, fill_value)

  taken = sparsify(gather_fn)(arr, dynamic_idx)
  # Corner case of the rewriting_take path: an empty, non-BCOO result is
  # wrapped back into BCOO; everything else is returned as-is.
  if isinstance(taken, BCOO) or np.size(taken) != 0:
    return taken
  return BCOO.fromdense(taken)