Example #1
0
 def __getitem__(self, key):
     """Build an index grid from one slice or a tuple of slices."""
     # A lone slice maps directly to a single 1-D grid.
     if isinstance(key, slice):
         return _make_1d_grid_from_slice(key, op_name=self.op_name)
     axes = [_make_1d_grid_from_slice(k, op_name=self.op_name) for k in key]
     grids = meshgrid(*axes, indexing='ij', sparse=self.sparse)
     if self.sparse:
         return grids
     return stack(grids, 0)
Example #2
0
File: scatter.py  Project: tataudat/jax
def segment_sum(data,
                segment_ids,
                num_segments=None,
                indices_are_sorted=False,
                unique_indices=False,
                bucket_size=None):  # TODO(zhangqiaorjc): use non-None default.
    """Computes the sum within segments of an array.

    Similar to TensorFlow's segment_sum:
    https://www.tensorflow.org/api_docs/python/tf/math/segment_sum

    Args:
      data: an array with the values to be summed.
      segment_ids: an array with integer dtype that indicates the segments of
        `data` (along its leading axis) to be summed. Values can be repeated
        and need not be sorted. Values outside of the range
        [0, num_segments) are dropped and do not contribute to the sum.
      num_segments: optional, an int with nonnegative value indicating the
        number of segments. Defaults to ``max(segment_ids) + 1``. Since
        `num_segments` determines the output size, a static value must be
        provided to use ``segment_sum`` in a ``jit``-compiled function.
      indices_are_sorted: whether ``segment_ids`` is known to be sorted.
      unique_indices: whether `segment_ids` is known to be free of duplicates.
      bucket_size: size of bucket to group indices into. ``segment_sum`` is
        performed on each bucket separately to improve numerical stability of
        addition. Default ``None`` means no bucketing.

    Returns:
      An array with shape ``(num_segments,) + data.shape[1:]`` representing
      the segment sums.
    """
    if num_segments is None:
        num_segments = jnp.max(segment_ids) + 1
    # Must be concrete: it determines the output shape.
    num_segments = int(num_segments)

    out = jnp.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)

    if bucket_size is None:
        num_buckets = 1
    else:
        num_buckets = util.ceil_of_ratio(segment_ids.size, bucket_size)

    if num_buckets == 1:
        return _scatter_update(out, segment_ids, data, lax.scatter_add,
                               indices_are_sorted, unique_indices,
                               normalize_indices=False)

    # Split inputs into buckets, sum each bucket independently, then combine;
    # this improves the numerical stability of the addition.
    data_chunks = jnp.array_split(data, num_buckets)
    id_chunks = jnp.array_split(segment_ids, num_buckets)
    partial_sums = [
        segment_sum(chunk, ids, num_segments, indices_are_sorted,
                    unique_indices)
        for chunk, ids in zip(data_chunks, id_chunks)
    ]
    return jnp.sum(jnp.stack(partial_sums), axis=0)
Example #3
0
 def __getitem__(self, key):
   """Build an index grid, promoting all axes to a common dtype."""
   if isinstance(key, slice):
     # A lone slice maps directly to a single 1-D grid.
     return _make_1d_grid_from_slice(key, op_name=self.op_name)
   axes = [_make_1d_grid_from_slice(k, op_name=self.op_name) for k in key]
   # Promote under 'standard' rules so mixed int/float slices agree.
   with jax.numpy_dtype_promotion('standard'):
     axes = _promote_dtypes(*axes)
   grids = meshgrid(*axes, indexing='ij', sparse=self.sparse)
   if self.sparse:
     return grids
   return stack(grids, 0)
Example #4
0
 def __getitem__(self, key):
     """Return an index grid for a single slice or a tuple of slices."""
     if isinstance(key, slice):
         # One slice: its 1-D grid is the whole answer.
         return _make_1d_grid_from_slice(key, op_name=self.op_name)
     axes = [_make_1d_grid_from_slice(k, op_name=self.op_name) for k in key]
     grids = meshgrid(*axes, indexing='ij', sparse=self.sparse)
     return grids if self.sparse else stack(grids, 0)
Example #5
0
def _segment_update(name: str,
                    data: Array,
                    segment_ids: Array,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,
                    unique_indices: bool = False,
                    bucket_size: Optional[int] = None,
                    reducer: Optional[Callable] = None,
                    mode: Optional[lax.GatherScatterMode] = None) -> Array:
    """Generic segment reduction backing the public ``segment_*`` functions.

    Scatters ``data`` (along its leading axis) into ``num_segments`` output
    slots using ``scatter_op``, starting from that op's identity value.

    Args:
      name: public API name, used in error messages (e.g. "segment_sum").
      data: array of values to reduce.
      segment_ids: integer array mapping each leading-axis entry of ``data``
        to a segment; out-of-range ids are dropped under the default mode.
      scatter_op: lax scatter primitive implementing the reduction
        (e.g. ``lax.scatter_add``).
      num_segments: static, nonnegative number of output segments; defaults
        to ``max(segment_ids) + 1`` and must be concrete.
      indices_are_sorted: whether ``segment_ids`` is known to be sorted.
      unique_indices: whether ``segment_ids`` is known to be duplicate-free.
      bucket_size: if given, reduce in buckets of this many indices and
        combine the partial results with ``reducer`` (improves numerical
        stability); ``None`` means no bucketing.
      reducer: combiner applied across bucket partials; required when
        bucketing is in effect.
      mode: out-of-bounds scatter handling; defaults to
        ``lax.GatherScatterMode.FILL_OR_DROP``.

    Returns:
      Array of shape ``(num_segments,) + data.shape[1:]``.

    Raises:
      ValueError: if ``num_segments`` is negative.
    """
    jnp._check_arraylike(name, data, segment_ids)
    mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode
    data = jnp.asarray(data)
    segment_ids = jnp.asarray(segment_ids)
    dtype = data.dtype
    if num_segments is None:
        num_segments = jnp.max(segment_ids) + 1
    # Report the actual public API name: this helper backs several
    # segment_* functions, not just segment_sum.
    num_segments = core.concrete_or_error(
        int, num_segments, f"{name}() `num_segments` argument.")
    # `num_segments` is always an int here (assigned above), so test it
    # directly rather than re-checking for None.
    if num_segments < 0:
        raise ValueError("num_segments must be non-negative.")

    # Start from the reduction identity so untouched segments are
    # well-defined for any scatter_op.
    out = jnp.full((num_segments, ) + data.shape[1:],
                   _get_identity(scatter_op, dtype),
                   dtype=dtype)

    num_buckets = 1 if bucket_size is None \
                    else util.ceil_of_ratio(segment_ids.size, bucket_size)
    if num_buckets == 1:
        return _scatter_update(out,
                               segment_ids,
                               data,
                               scatter_op,
                               indices_are_sorted,
                               unique_indices,
                               normalize_indices=False,
                               mode=mode)

    # Bucketize indices and perform segment_update on each bucket to improve
    # numerical stability for operations like product and sum.
    assert reducer is not None
    outs = []
    for sub_data, sub_segment_ids in zip(
            jnp.array_split(data, num_buckets),
            jnp.array_split(segment_ids, num_buckets)):
        outs.append(
            # Fix: propagate `mode` to the recursive call; previously the
            # sub-reductions silently fell back to the FILL_OR_DROP default
            # regardless of the caller's choice. (No bucket_size is passed,
            # so each sub-call takes the num_buckets == 1 path.)
            _segment_update(name, sub_data, sub_segment_ids, scatter_op,
                            num_segments, indices_are_sorted, unique_indices,
                            mode=mode))
    return reducer(jnp.stack(outs), axis=0).astype(dtype)