Example #1
    def reduce_fn(operands, inits, axis=None, keepdims=False):
        """Applies `reducer` to the given operands along the given axes.

        Args:
          operands: Tuple of tensors, all having the same shape.
          inits: Tuple of scalar tensors, with dtypes aligned to those of
            `operands`.
          axis: The axis or axes to reduce. One of `None`, an `int`, or a
            sequence of `int`. `None` is taken to mean "reduce all axes".
          keepdims: When `True`, we do not squeeze away the reduced dims,
            instead returning values with singleton dims in those axes.

        Returns:
          reduced: A tuple of the reduced operands.
        """
        # Static shape consistency checks.
        args_shape = operands[0].shape
        for arg in operands[1:]:
            args_shape = tensorshape_util.merge_with(args_shape, arg.shape)
        ndims = tensorshape_util.rank(args_shape)
        if ndims is None:
            raise ValueError(
                'Rank of at least one of `operands` must be known statically.')
        # Ensure the 'axis' arg is a tuple of non-negative ints.
        axis = np.arange(ndims) if axis is None else np.array(axis)
        if axis.ndim > 1:
            raise ValueError(
                '`axis` must be `None`, an `int`, or a sequence of '
                '`int`, but got {}'.format(axis))
        axis = np.reshape(axis, [-1])
        axis = np.where(axis < 0, axis + ndims, axis)
        axis = tuple(int(ax) for ax in axis)

        # Boolean n-hot mask over the dims, marking the reduced axes.
        axis_nhot = ps.reduce_sum(ps.one_hot(axis,
                                             depth=ndims,
                                             on_value=True,
                                             off_value=False,
                                             dtype=tf.bool),
                                  axis=0)
        # Prefer the static operand shape; fall back to the dynamic one.
        in_shape = args_shape
        if not tensorshape_util.is_fully_defined(in_shape):
            in_shape = tf.shape(operands[0])
        # Operand shape with the reduced axes collapsed to singletons.
        unsqueezed_shape = ps.where(axis_nhot, 1, in_shape)

        result = _variadic_reduce_custom_grad(operands, inits, axis, reducer,
                                              unsqueezed_shape)

        if keepdims:
            result = tf.nest.map_structure(
                lambda t: tf.reshape(t, unsqueezed_shape), result)
        return result
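The shape bookkeeping above can be reproduced standalone. The sketch below is illustrative only, substituting plain TF ops (`tf.reduce_any` in place of the boolean `ps.reduce_sum`) and a hard-coded rank-3 example:

import tensorflow as tf

ndims = 3
axis = (0, 2)                       # normalized, non-negative reduction axes
in_shape = tf.constant([4, 5, 6])   # common shape of the operands
# Boolean mask over the dims that is True at the reduced axes.
axis_mask = tf.reduce_any(
    tf.one_hot(axis, depth=ndims, on_value=True, off_value=False,
               dtype=tf.bool),
    axis=0)
unsqueezed_shape = tf.where(axis_mask, 1, in_shape)
# unsqueezed_shape -> [1, 5, 1]: reshaping each reduced operand to this
# restores singleton dims in the reduced axes when `keepdims=True`.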
Example #2
def _calculate_batch_shape(self):
    """Computes fully defined batch shape for the new distribution."""
    # Use each distribution's static batch shape when fully defined,
    # otherwise fall back to its dynamic batch-shape tensor.
    all_batch_shapes = [
        d.batch_shape.as_list() if tensorshape_util.is_fully_defined(
            d.batch_shape) else d.batch_shape_tensor()
        for d in self.distributions
    ]
    original_shape = ps.stack(all_batch_shapes, axis=0)
    # Boolean mask that is True only at the concatenation axis.
    index_mask = ps.cast(ps.one_hot(self._axis,
                                    ps.shape(original_shape)[1]),
                         dtype=tf.bool)
    # Along the concatenation axis, the component sizes add up.
    new_concat_dim = ps.cast(ps.reduce_sum(original_shape,
                                           axis=0)[self._axis],
                             dtype=tf.int32)
    # The remaining axes are shared; `reduce_max` picks the common
    # (possibly broadcast) size.
    return ps.where(index_mask, new_concat_dim,
                    ps.reduce_max(original_shape, axis=0))
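The shape arithmetic can be checked in isolation. The sketch below is an assumed example (two distributions with batch shapes [2, 3] and [4, 3] concatenated along axis 0), with plain TF ops standing in for the `ps` helpers:

import tensorflow as tf

all_batch_shapes = tf.constant([[2, 3], [4, 3]])  # one row per distribution
axis = 0                                          # concatenation axis
index_mask = tf.cast(tf.one_hot(axis, tf.shape(all_batch_shapes)[1]), tf.bool)
new_concat_dim = tf.reduce_sum(all_batch_shapes, axis=0)[axis]
batch_shape = tf.where(index_mask, new_concat_dim,
                       tf.reduce_max(all_batch_shapes, axis=0))
# batch_shape -> [6, 3]: sizes add along the concat axis and agree elsewhere.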
Example #3
def cumulative_variance(x, sample_axis=0, name=None):
    """Cumulative estimates of variance.

    Given `N` samples of a scalar-valued random variable `X`, we can compute
    cumulative variance estimates

      result[i] = variance(x[0:i+1])

    in O(N log(N)) time and O(1) TF kernel invocations.  This
    implementation also arranges to do so in a numerically accurate
    manner, i.e., without incurring a subtraction of floating-point
    numbers of size quadratic in the data `x`.  The underlying algorithm
    is from [1].

    Args:
      x: A numeric `Tensor` holding samples.
      sample_axis: Scalar `Tensor` designating the axis holding samples.
        Other axes are treated in batch.  Default value: `0` (leftmost
        dimension).
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., `'cumulative_variance'`).

    Returns:
      cum_var: A `Tensor` of the same shape and dtype as `x` giving
        cumulative variance estimates.  The zeroth element is the
        variance of a size-1 set of samples, so 0.

    #### References
    [1]: Philippe Pebay. Formulas for Robust, One-Pass Parallel Computation of
         Covariances and Arbitrary-Order Statistical Moments. _Technical Report
         SAND2008-6212_, 2008.
         https://prod-ng.sandia.gov/techlib-noauth/access-control.cgi/2008/086212.pdf
    """
    with tf.name_scope(name or 'cumulative_variance'):
        # At each index, we are interested in
        # - The count of items up to that index (inclusive and exclusive);
        # - The sum of items up to that index (exclusive);
        # - From which we compute the mean of items up to that index (exclusive);
        # - The residual of items up to that index (inclusive), which is
        #   the variance scaled by the count of items.
        #
        # The contribution from item i to the residual is that item's
        # squared discrepancy from the mean of all preceding items (i.e.,
        # the exclusive mean at the present item), adjusted by i-1/i.
        x = tf.convert_to_tensor(x)
        size = ps.shape(x)[sample_axis]
        counts_shp = ps.one_hot(sample_axis,
                                depth=ps.rank(x),
                                on_value=size,
                                off_value=1)
        excl_counts = tf.reshape(tf.range(size, dtype=x.dtype),
                                 shape=counts_shp)
        incl_counts = excl_counts + 1
        excl_sums = tf.cumsum(x, axis=sample_axis, exclusive=True)
        discrepancies = (excl_sums / excl_counts - x)**2
        discrepancies = tf.where(excl_counts == 0, x**2, discrepancies)
        adjustments = excl_counts / incl_counts
        # The zeroth item's residual contribution is 0, because it has no
        # other items to vary from.  The preceding expressions, however,
        # compute 0/0 at index 0, so we mask it out here.
        adjusted = tf.where(~tf.equal(excl_counts, 0),
                            adjustments * discrepancies, 0)
        incl_residual = tf.cumsum(adjusted, axis=sample_axis)
        return incl_residual / incl_counts
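A quick sanity check of the function above against NumPy's per-prefix variance (assumes the function and its `tf`/`ps` imports are in scope; values shown are approximate):

import numpy as np
import tensorflow as tf

x = tf.constant([1., 2., 4., 8.])
cum_var = cumulative_variance(x)
# Entry i matches the population variance of the prefix x[:i+1].
expected = [np.var([1.]), np.var([1., 2.]),
            np.var([1., 2., 4.]), np.var([1., 2., 4., 8.])]
# cum_var ~= [0., 0.25, 1.5556, 7.1875] == expected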
Example #4
    def reduce_fn(operands, inits, axis=None, keepdims=False):
        """Applies `reducer` to the given operands along the given axes.

        Args:
          operands: Tuple of tensors, all having the same shape.
          inits: Tuple of scalar tensors, with dtypes aligned to those of
            `operands`.
          axis: The axis or axes to reduce. One of `None`, an `int`, or a
            sequence of `int`. `None` is taken to mean "reduce all axes".
          keepdims: When `True`, we do not squeeze away the reduced dims,
            instead returning values with singleton dims in those axes.

        Returns:
          reduced: A tuple of the reduced operands.
        """
        # Static shape consistency checks.
        args_shape = operands[0].shape
        for arg in operands[1:]:
            args_shape = tensorshape_util.merge_with(args_shape, arg.shape)
        ndims = tensorshape_util.rank(args_shape)
        if ndims is None:
            raise ValueError(
                'Rank of at least one of `operands` must be known statically.')
        # Ensure the 'axis' arg is a tuple of non-negative ints.
        axis = np.arange(ndims) if axis is None else np.array(axis)
        if axis.ndim > 1:
            raise ValueError(
                '`axis` must be `None`, an `int`, or a sequence of '
                '`int`, but got {}'.format(axis))
        axis = np.reshape(axis, [-1])
        axis = np.where(axis < 0, axis + ndims, axis)
        axis = tuple(int(ax) for ax in axis)

        # Backend dispatch: JAX lowers to `lax.reduce`; eager TF and non-XLA
        # graphs use a pure-TF reduction; XLA-compiled graphs use the XLA path.
        if JAX_MODE:
            from jax import lax  # pylint: disable=g-import-not-at-top
            result = lax.reduce(operands,
                                init_values=inits,
                                dimensions=axis,
                                computation=reducer)
        elif (tf.executing_eagerly()
              or not control_flow_util.GraphOrParentsInXlaContext(
                  tf1.get_default_graph())):
            result = _variadic_reduce(operands,
                                      init=inits,
                                      axis=axis,
                                      reducer=reducer)
        else:
            result = _xla_reduce(operands, inits, axis)

        if keepdims:
            # Reshape each reduced operand back to the input shape with 1s
            # substituted at the reduced axes.
            axis_nhot = ps.reduce_sum(ps.one_hot(axis,
                                                 depth=ndims,
                                                 on_value=True,
                                                 off_value=False,
                                                 dtype=tf.bool),
                                      axis=0)
            in_shape = args_shape
            if not tensorshape_util.is_fully_defined(in_shape):
                in_shape = tf.shape(operands[0])
            final_shape = ps.where(axis_nhot, 1, in_shape)
            result = tf.nest.map_structure(
                lambda t: tf.reshape(t, final_shape), result)
        return result
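The JAX branch can be exercised directly through `jax.lax.reduce`. The sketch below is illustrative, not from the library: it assumes a variadic reducer that sums the first operand while taking the running max of the second, reducing along axis 0:

import jax.numpy as jnp
from jax import lax

def reducer(accum, elem):
    # Both arguments are tuples, one entry per operand.
    return accum[0] + elem[0], jnp.maximum(accum[1], elem[1])

x = jnp.array([[1., 2.], [3., 4.]])
y = jnp.array([[5., 6.], [7., 8.]])
sums, maxes = lax.reduce((x, y),
                         init_values=(jnp.float32(0.), jnp.float32(-jnp.inf)),
                         computation=reducer,
                         dimensions=(0,))
# sums -> [4., 6.], maxes -> [7., 8.]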