def segment_sum(data,
                segment_ids,
                num_segments=None,
                indices_are_sorted=False,
                unique_indices=False,
                bucket_size=None):  # TODO(zhangqiaorjc): use non-None default.
  """Computes the sum within segments of an array.

  Similar to TensorFlow's segment_sum:
  https://www.tensorflow.org/api_docs/python/tf/math/segment_sum

  Args:
    data: an array with the values to be summed.
    segment_ids: an array with integer dtype that indicates the segments of
      `data` (along its leading axis) to be summed. Values can be repeated and
      need not be sorted. Values outside of the range [0, num_segments) are
      dropped and do not contribute to the sum.
    num_segments: optional, an int with nonnegative value indicating the number
      of segments. The default is set to be the minimum number of segments that
      would support all indices in ``segment_ids``, calculated as
      ``max(segment_ids) + 1``.
      Since `num_segments` determines the size of the output, a static value
      must be provided to use ``segment_sum`` in a ``jit``-compiled function.
    indices_are_sorted: whether ``segment_ids`` is known to be sorted.
    unique_indices: whether `segment_ids` is known to be free of duplicates.
    bucket_size: size of bucket to group indices into. ``segment_sum`` is
      performed on each bucket separately to improve numerical stability of
      addition. Default ``None`` means no bucketing.

  Returns:
    An array with shape :code:`(num_segments,) + data.shape[1:]` representing
    the segment sums.

  Raises:
    ValueError: if ``num_segments`` is negative.
  """
  # Accept any array-like input; `.shape`, `.dtype` and `.size` are read below.
  data = jnp.asarray(data)
  segment_ids = jnp.asarray(segment_ids)

  if num_segments is None:
    num_segments = jnp.max(segment_ids) + 1
  num_segments = int(num_segments)
  if num_segments < 0:
    raise ValueError("num_segments must be non-negative.")

  # Zero-initialised accumulator; every scatter below starts from it.
  out = jnp.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)

  num_buckets = (1 if bucket_size is None
                 else util.ceil_of_ratio(segment_ids.size, bucket_size))
  # `<= 1` (rather than `== 1`) also routes a zero bucket count, which an
  # empty `segment_ids` can produce, through the single scatter instead of
  # attempting to split the inputs into zero pieces.
  if num_buckets <= 1:
    return _scatter_update(out, segment_ids, data, lax.scatter_add,
                           indices_are_sorted, unique_indices,
                           normalize_indices=False)

  # Bucketize indices and perform segment_sum on each bucket to improve
  # numerical stability.
  partial_sums = [
      _scatter_update(out, sub_segment_ids, sub_data, lax.scatter_add,
                      indices_are_sorted, unique_indices,
                      normalize_indices=False)
      for sub_data, sub_segment_ids in zip(
          jnp.array_split(data, num_buckets),
          jnp.array_split(segment_ids, num_buckets))
  ]
  return jnp.sum(jnp.stack(partial_sums), axis=0)
def _segment_update(name: str,
                    data: Array,
                    segment_ids: Array,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,
                    unique_indices: bool = False,
                    bucket_size: Optional[int] = None,
                    reducer: Optional[Callable] = None,
                    mode: Optional[lax.GatherScatterMode] = None) -> Array:
  """Shared implementation of the segment reductions built on a scatter op.

  Args:
    name: name of the calling function; used in argument-check and error
      messages.
    data: array-like with the values to be combined.
    segment_ids: array-like of segment indices for ``data``. The bucketed path
      indexes it as ``segment_ids[None, :]`` and uses ``segment_ids.shape[0]``,
      so it is presumably one-dimensional -- TODO confirm against callers.
    scatter_op: scatter primitive handed to ``_scatter_update``; also selects
      the initial fill value through ``_get_identity``.
    num_segments: number of output segments. When ``None`` it is computed as
      ``max(segment_ids) + 1``; the value must be concrete.
    indices_are_sorted: forwarded to ``_scatter_update``.
    unique_indices: forwarded to ``_scatter_update``.
    bucket_size: when not ``None``, rows are grouped into buckets of this many
      consecutive elements, each bucket is scattered into its own slice, and
      the slices are combined with ``reducer``.
    reducer: callable invoked as ``reducer(out, axis=0)`` to combine the
      per-bucket results. Required whenever more than one bucket is used.
    mode: scatter mode forwarded to ``_scatter_update``; defaults to
      ``lax.GatherScatterMode.FILL_OR_DROP``.

  Returns:
    An array of shape ``(num_segments,) + data.shape[1:]`` with the dtype of
    ``data``.

  Raises:
    ValueError: if ``num_segments`` is negative.
  """
  jnp._check_arraylike(name, data, segment_ids)
  mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode
  data = jnp.asarray(data)
  segment_ids = jnp.asarray(segment_ids)
  dtype = data.dtype

  if num_segments is None:
    num_segments = jnp.max(segment_ids) + 1
  # Report the actual caller in the error context instead of always blaming
  # segment_sum().
  num_segments = core.concrete_or_error(
      int, num_segments, f"{name}() `num_segments` argument.")
  if num_segments < 0:
    raise ValueError("num_segments must be non-negative.")

  num_buckets = (1 if bucket_size is None
                 else util.ceil_of_ratio(segment_ids.size, bucket_size))
  identity = _get_identity(scatter_op, dtype)

  # `<= 1` (rather than `== 1`) also covers a zero bucket count, which an
  # empty `segment_ids` can produce; the bucketed path would otherwise reduce
  # over an empty leading axis.
  if num_buckets <= 1:
    out = jnp.full((num_segments,) + data.shape[1:], identity, dtype=dtype)
    return _scatter_update(out, segment_ids, data, scatter_op,
                           indices_are_sorted, unique_indices,
                           normalize_indices=False, mode=mode)

  # Bucketize indices and perform segment_update on each bucket to improve
  # numerical stability for operations like product and sum.
  assert reducer is not None, "reducer is required when bucketing is used"
  out = jnp.full((num_buckets, num_segments) + data.shape[1:], identity,
                 dtype=dtype)
  # Row i is scattered into bucket i // bucket_size, at its own segment id.
  bucket_ids = lax.div(jnp.arange(segment_ids.shape[0]), bucket_size)
  out = _scatter_update(
      out, np.index_exp[bucket_ids, segment_ids[None, :]], data, scatter_op,
      indices_are_sorted, unique_indices, normalize_indices=False, mode=mode)
  return reducer(out, axis=0).astype(dtype)
def _segment_update(name: str,
                    data: Array,
                    segment_ids: Array,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,
                    unique_indices: bool = False,
                    bucket_size: Optional[int] = None,
                    reducer: Optional[Callable] = None) -> Array:
  """Shared implementation of the segment reductions built on a scatter op.

  Args:
    name: name of the calling function; used in argument-check and error
      messages.
    data: array-like with the values to be combined.
    segment_ids: array-like of segment indices for the leading axis of
      ``data``.
    scatter_op: scatter primitive handed to ``_scatter_update``; also selects
      the initial fill value through ``_get_identity``.
    num_segments: number of output segments. When ``None`` it is computed as
      ``max(segment_ids) + 1``; the value must be concrete.
    indices_are_sorted: forwarded to ``_scatter_update``.
    unique_indices: forwarded to ``_scatter_update``.
    bucket_size: when not ``None``, ``data`` and ``segment_ids`` are split
      into ``ceil(segment_ids.size / bucket_size)`` pieces, each piece is
      scattered separately, and the pieces are combined with ``reducer``.
    reducer: callable invoked as ``reducer(stacked, axis=0)`` to combine the
      per-bucket results. Required whenever more than one bucket is used.

  Returns:
    An array of shape ``(num_segments,) + data.shape[1:]`` with the dtype of
    ``data``.

  Raises:
    ValueError: if ``num_segments`` is negative.
  """
  jnp._check_arraylike(name, data, segment_ids)
  data = jnp.asarray(data)
  segment_ids = jnp.asarray(segment_ids)
  dtype = data.dtype

  if num_segments is None:
    num_segments = jnp.max(segment_ids) + 1
  # Report the actual caller in the error context instead of always blaming
  # segment_sum().
  num_segments = core.concrete_or_error(
      int, num_segments, f"{name}() `num_segments` argument.")
  if num_segments < 0:
    raise ValueError("num_segments must be non-negative.")

  # Identity-filled accumulator. Every scatter below starts from this same
  # array, so it is built once and shared by all buckets.
  out = jnp.full((num_segments,) + data.shape[1:],
                 _get_identity(scatter_op, dtype), dtype=dtype)

  num_buckets = (1 if bucket_size is None
                 else util.ceil_of_ratio(segment_ids.size, bucket_size))
  # `<= 1` (rather than `== 1`) also covers a zero bucket count, which an
  # empty `segment_ids` can produce; splitting into zero pieces is not
  # possible.
  if num_buckets <= 1:
    return _scatter_update(out, segment_ids, data, scatter_op,
                           indices_are_sorted, unique_indices,
                           normalize_indices=False)

  # Bucketize indices and perform segment_update on each bucket to improve
  # numerical stability for operations like product and sum.
  assert reducer is not None, "reducer is required when bucketing is used"
  per_bucket = [
      _scatter_update(out, sub_segment_ids, sub_data, scatter_op,
                      indices_are_sorted, unique_indices,
                      normalize_indices=False)
      for sub_data, sub_segment_ids in zip(
          jnp.array_split(data, num_buckets),
          jnp.array_split(segment_ids, num_buckets))
  ]
  return reducer(jnp.stack(per_bucket), axis=0).astype(dtype)