Example #1
    def _inverse(self, y):
        # To derive the inverse mapping note that:
        #   y[i] = exp(x[i]) / normalization
        # and
        #   y[end] = 1 / normalization.
        # Thus:
        # x[i] = log(exp(x[i])) - log(y[end]) - log(normalization)
        #      = log(exp(x[i])/normalization) - log(y[end])
        #      = log(y[i]) - log(y[end])

        # Compute this first so that common subexpression elimination (CSE)
        # can reuse it when it recurs in _inverse_log_det_jacobian.
        x = tf.math.log(y)

        log_normalization = (-x[..., -1])[..., tf.newaxis]
        x = x[..., :-1] + log_normalization

        # Set shape hints.
        if tensorshape_util.rank(y.shape) is not None:
            last_dim = tf.compat.dimension_value(y.shape[-1])
            shape = y.shape[:-1].concatenate(
                None if last_dim is None else last_dim - 1)
            tensorshape_util.assert_is_compatible_with(x.shape, shape)
            x.set_shape(shape)

        return x
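As a sanity check on the log-ratio identity in the comments, here is a minimal NumPy sketch, independent of the bijector class (the helper name is illustrative only, not part of the original code):

import numpy as np

def softmax_with_appended_zero(x):
    # Forward map sketched in Example #2: append a zero logit, then softmax.
    z = np.concatenate([x, [0.0]])
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()

x = np.array([0.3, -1.2, 2.0])
y = softmax_with_appended_zero(x)
# Inverse via the identity x[i] = log(y[i]) - log(y[end]).
x_recovered = np.log(y[:-1]) - np.log(y[-1])
print(np.allclose(x, x_recovered))  # True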
Example #2
    def _forward(self, x):
        # Pad the last dim with a zeros vector. We need this because it lets us
        # infer the scale in the inverse function.
        y = distribution_util.pad(x, axis=-1, back=True)

        # Set shape hints.
        if tensorshape_util.rank(x.shape) is not None:
            last_dim = tf.compat.dimension_value(x.shape[-1])
            shape = x.shape[:-1].concatenate(
                None if last_dim is None else last_dim + 1)
            tensorshape_util.assert_is_compatible_with(y.shape, shape)
            y.set_shape(shape)

        return tf.nn.softmax(y)
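These two methods appear to be the forward/inverse pair of TFP's `SoftmaxCentered` bijector; assuming that class, a round trip looks like this sketch:

import tensorflow as tf
import tensorflow_probability as tfp

bijector = tfp.bijectors.SoftmaxCentered()
x = tf.constant([0.3, -1.2, 2.0])
y = bijector.forward(x)       # shape [4], entries sum to 1
x_back = bijector.inverse(y)  # recovers x, shape [3]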
Example #3
def covariance(x,
               y=None,
               sample_axis=0,
               event_axis=-1,
               keepdims=False,
               name=None):
    """Sample covariance between observations indexed by `event_axis`.

  Given `N` samples of scalar random variables `X` and `Y`, covariance may be
  estimated as

  ```none
  Cov[X, Y] := N^{-1} sum_{n=1}^N (X_n - Xbar) Conj{(Y_n - Ybar)}
  Xbar := N^{-1} sum_{n=1}^N X_n
  Ybar := N^{-1} sum_{n=1}^N Y_n
  ```

  For vector-variate random variables `X = (X1, ..., Xd)`, `Y = (Y1, ..., Yd)`,
  one is often interested in the covariance matrix, `C_{ij} := Cov[Xi, Yj]`.

  ```python
  x = tf.random.normal(shape=(100, 2, 3))
  y = tf.random.normal(shape=(100, 2, 3))

  # cov[i, j] is the sample covariance between x[:, i, j] and y[:, i, j].
  cov = tfp.stats.covariance(x, y, sample_axis=0, event_axis=None)

  # cov_matrix[i, m, n] is the sample covariance of x[:, i, m] and y[:, i, n]
  cov_matrix = tfp.stats.covariance(x, y, sample_axis=0, event_axis=-1)
  ```

  Notice we divide by `N`, which does not create `NaN` when `N = 1`, but is
  slightly biased.

  Args:
    x:  A numeric `Tensor` holding samples.
    y:  Optional `Tensor` with same `dtype` and `shape` as `x`.
      Default value: `None` (`y` is effectively set to `x`).
    sample_axis: Scalar or vector `Tensor` designating axis holding samples, or
      `None` (meaning all axes hold samples).
      Default value: `0` (leftmost dimension).
    event_axis:  Scalar or vector `Tensor`, or `None` (scalar events).
      Axis indexing random events, whose covariance we are interested in.
      If a vector, entries must form a contiguous block of dims. `sample_axis`
      and `event_axis` should not intersect.
      Default value: `-1` (rightmost axis holds events).
    keepdims:  Boolean.  Whether to keep the sample axis as singletons.
    name: Python `str` name prefixed to Ops created by this function.
          Default value: `None` (i.e., `'covariance'`).

  Returns:
    cov: A `Tensor` of same `dtype` as `x`, and rank equal to
      `rank(x) - len(sample_axis) + len(event_axis)`.

  Raises:
    AssertionError:  If `x` and `y` are found to have different shape.
    ValueError:  If `sample_axis` and `event_axis` are found to overlap.
    ValueError:  If `event_axis` is found to not be contiguous.
  """

    with tf.name_scope(name or 'covariance'):
        x = tf.convert_to_tensor(x, name='x')
        # Covariance *only* uses the centered versions of x (and y).
        x = x - tf.reduce_mean(x, axis=sample_axis, keepdims=True)

        if y is None:
            y = x
        else:
            y = tf.convert_to_tensor(y, name='y', dtype=x.dtype)
            # If x and y have different shape, sample_axis and event_axis will likely
            # be wrong for one of them!
            tensorshape_util.assert_is_compatible_with(x.shape, y.shape)
            y = y - tf.reduce_mean(y, axis=sample_axis, keepdims=True)

        if event_axis is None:
            return tf.reduce_mean(x * tf.math.conj(y),
                                  axis=sample_axis,
                                  keepdims=keepdims)

        if sample_axis is None:
            raise ValueError(
                'sample_axis was None, which means all axes hold samples, and '
                'this overlaps with event_axis ({})'.format(event_axis))

        event_axis = _make_positive_axis(event_axis, ps.rank(x))
        sample_axis = _make_positive_axis(sample_axis, ps.rank(x))

        # If we get lucky and axis is statically defined, we can do some checks.
        if _is_list_like(event_axis) and _is_list_like(sample_axis):
            event_axis = tuple(map(int, event_axis))
            sample_axis = tuple(map(int, sample_axis))
            if set(event_axis).intersection(sample_axis):
                raise ValueError(
                    'sample_axis ({}) and event_axis ({}) overlapped'.format(
                        sample_axis, event_axis))
            if (np.diff(np.array(sorted(event_axis))) > 1).any():
                raise ValueError(
                    'event_axis must be contiguous. Found: {}'.format(
                        event_axis))
            batch_axis = list(
                sorted(
                    set(range(tensorshape_util.rank(
                        x.shape))).difference(sample_axis + event_axis)))
        else:
            batch_axis = ps.setdiff1d(ps.range(0, ps.rank(x)),
                                      ps.concat((sample_axis, event_axis), 0))

        event_axis = ps.cast(event_axis, dtype=tf.int32)
        sample_axis = ps.cast(sample_axis, dtype=tf.int32)
        batch_axis = ps.cast(batch_axis, dtype=tf.int32)

        # Permute x/y until shape = B + E + S
        perm_for_xy = ps.concat((batch_axis, event_axis, sample_axis), 0)
        x_permed = tf.transpose(a=x, perm=perm_for_xy)
        y_permed = tf.transpose(a=y, perm=perm_for_xy)

        batch_ndims = ps.size(batch_axis)
        batch_shape = ps.shape(x_permed)[:batch_ndims]
        event_ndims = ps.size(event_axis)
        event_shape = ps.shape(x_permed)[batch_ndims:batch_ndims + event_ndims]
        sample_shape = ps.shape(x_permed)[batch_ndims + event_ndims:]
        sample_ndims = ps.size(sample_shape)
        n_samples = ps.reduce_prod(sample_shape)
        n_events = ps.reduce_prod(event_shape)

        # Flatten event_axis into one long dim, and sample_axis into another.
        x_permed_flat = tf.reshape(
            x_permed, ps.concat((batch_shape, [n_events], [n_samples]), 0))
        y_permed_flat = tf.reshape(
            y_permed, ps.concat((batch_shape, [n_events], [n_samples]), 0))

        # After matmul, cov.shape = batch_shape + [n_events, n_events]
        cov = tf.matmul(x_permed_flat, y_permed_flat,
                        adjoint_b=True) / ps.cast(n_samples, x.dtype)

        # Insert some singletons to make
        # cov.shape = batch_shape + event_shape**2 + [1,...,1]
        # This is just like x_permed.shape, except the sample_axis is all 1's, and
        # the [n_events] became event_shape**2.
        cov = tf.reshape(
            cov,
            ps.concat(
                (
                    batch_shape,
                    # event_shape**2 used here because it is the same length as
                    # event_shape, and has the same number of elements as one
                    # batch of covariance.
                    event_shape**2,
                    ps.ones([sample_ndims], tf.int32)),
                0))
        # Transposing by the inverse permutation undoes the earlier transpose,
        # making cov.shape have ones in the positions where there were samples,
        # and [n_events * n_events] in the event position.
        cov = tf.transpose(a=cov, perm=ps.invert_permutation(perm_for_xy))

        # Now expand event_shape**2 into event_shape + event_shape.
        # We here use (for the first time) the fact that we require event_axis to be
        # contiguous.
        e_start = event_axis[0]
        e_len = 1 + event_axis[-1] - event_axis[0]
        cov = tf.reshape(
            cov,
            ps.concat((ps.shape(cov)[:e_start], event_shape, event_shape,
                       ps.shape(cov)[e_start + e_len:]), 0))

        # tf.squeeze requires python ints for axis, not Tensor.  This is enough to
        # require our axis args to be constants.
        if not keepdims:
            squeeze_axis = ps.where(sample_axis < e_start, sample_axis,
                                    sample_axis + e_len)
            cov = _squeeze(cov, axis=squeeze_axis)

        return cov
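For intuition, a hedged NumPy cross-check of the estimator (names local to this sketch; note the `1/N` normalization here, versus `np.cov`'s default `1/(N - 1)`):

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(100, 3))  # 100 samples of a 3-dimensional variable

# Center, then apply the same matmul trick as above: E[x conj(x)^T].
xc = x - x.mean(axis=0, keepdims=True)
cov = xc.T @ np.conj(xc) / x.shape[0]

# Agrees with np.cov when told to divide by N rather than N - 1.
print(np.allclose(cov, np.cov(x, rowvar=False, bias=True)))  # True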
Example #4
def iterative_mergesort(y, permutation, name=None):
    """Non-recusive mergesort that counts exchanges.

  Args:
    y: a `Tensor` of shape `[n]` containing values to be sorted.
    permutation: `Tensor` of shape `[n]` with original ordering.
    name: Optional Python `str` name for ops created by this method.
      Default value: `None` (i.e., 'iterative_mergesort').

  Returns:
    exchanges: `int32` scalar that counts the number of exchanges required to
      produce a sorted permutation.
    permutation: a `tf.int32` Tensor containing the permutation that sorts the
      values of `y`.
  """

    with tf.name_scope(name or 'iterative_mergesort'):
        y = tf.convert_to_tensor(y, name='y')
        permutation = tf.convert_to_tensor(permutation,
                                           name='permutation',
                                           dtype=tf.int32)
        shape = permutation.shape
        tensorshape_util.assert_is_compatible_with(y.shape, shape)
        n = ps.size(y)

        def outer_body(k, exchanges, permutation):
            # The outer body progressively merges lists as k grows by powers of 2,
            # tracking the total swaps required in exchanges as the new permutation is
            # built in place.
            y_ordered = tf.gather(y, permutation)

            def middle_body(left, exchanges, permutation):
                # The middle body steps through the sublists of size k,
                # advancing the left edge until the end of the input is reached.
                right = left + k
                end = tf.minimum(right + k, n)

                # See explanation here
                # https://www.geeksforgeeks.org/counting-inversions/.

                def inner_body(i, j, x, np, p):
                    # The [left, right) and [right, end) lists are merged in
                    # sorted order, with i and j tracking the advance through
                    # each range. x counts the (bubble-sort equivalent) swaps
                    # implied by each insertion, and np is the size of the
                    # output permutation filled in so far via the p tensor.
                    y_less = y_ordered[i] <= y_ordered[j]
                    element = tf.where(y_less, [permutation[i]],
                                       [permutation[j]])
                    new_p = tf.concat([p[0:np], element, p[np + 1:n]], axis=0)
                    tensorshape_util.set_shape(new_p, p.shape)
                    return (tf.where(y_less, i + 1,
                                     i), tf.where(y_less, j, j + 1),
                            tf.where(y_less, x, x + right - i), np + 1, new_p)

                i_j_x_np_p = (left, right, exchanges, 0,
                              tf.zeros([n], dtype=tf.int32))
                (i, j, exchanges, np,
                 p) = tf.while_loop(cond=lambda i, j, x, np, p: tf.math.
                                    logical_and(i < right, j < end),
                                    body=inner_body,
                                    loop_vars=i_j_x_np_p)
                permutation = tf.concat([
                    permutation[0:left], p[0:np], permutation[i:right],
                    permutation[j:end], permutation[end:n]
                ],
                                        axis=0)
                tensorshape_util.set_shape(permutation, shape)
                return left + 2 * k, exchanges, permutation

            _, exchanges, permutation = tf.while_loop(
                cond=lambda left, exchanges, permutation: left < n - k,
                body=middle_body,
                loop_vars=(0, exchanges, permutation))
            k *= 2
            return k, exchanges, permutation

        _, exchanges, permutation = tf.while_loop(
            cond=lambda k, exchanges, permutation: k < n,
            body=outer_body,
            loop_vars=(1, 0, permutation))
        return exchanges, permutation
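The exchange count equals the number of bubble-sort inversions, which is what kendalls_tau below relies on. A small usage sketch, assuming eager TensorFlow:

import tensorflow as tf

y = tf.constant([3., 1., 2.])
perm = tf.range(3, dtype=tf.int32)
exchanges, sorted_perm = iterative_mergesort(y, perm)
# [3, 1, 2] has two inversions, (3, 1) and (3, 2), so exchanges == 2, and
# gathering y by sorted_perm == [1, 2, 0] yields the sorted [1., 2., 3.].
print(int(exchanges), sorted_perm.numpy())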
Example #5
def kendalls_tau(y_true, y_pred, name=None):
    """Computes Kendall's Tau for two ordered lists.

  Kendall's Tau measures the correlation between ordinal rankings. This
  implementation is similar to the one used in scipy.stats.kendalltau.
  The provided values may be of any type that is sortable, with the
  argsort indices indicating the true or proposed ordinal sequence.

  Args:
    y_true: a `Tensor` of shape `[n]` containing the true ordinal ranking.
    y_pred: a `Tensor` of shape `[n]` containing the predicted ordering of the
      same N items.
    name: Optional Python `str` name for ops created by this method.
      Default value: `None` (i.e., 'kendalls_tau').

  Returns:
    kendalls_tau: Kendall's Tau, the 1945 tau-b formulation that ignores
      ordering of ties, as a `float32` scalar Tensor.
  """
    with tf.name_scope(name or 'kendalls_tau'):
        in_type = dtype_util.common_dtype([y_true, y_pred],
                                          dtype_hint=tf.float32)
        y_true = tf.convert_to_tensor(y_true, name='y_true', dtype=in_type)
        y_pred = tf.convert_to_tensor(y_pred, name='y_pred', dtype=in_type)
        tensorshape_util.assert_is_compatible_with(y_true.shape, y_pred.shape)
        assertions = [
            assert_util.assert_rank(y_true, 1),
            assert_util.assert_greater(
                ps.size(y_true), 1, 'Ordering requires at least 2 elements.')
        ]
        with tf.control_dependencies(assertions):
            lexa = lexicographical_indirect_sort(y_true, y_pred)

        # See "A Computer Method for Calculating Kendall's Tau with Ungrouped
        # Data" by William R. Knight, Journal of the American Statistical
        # Association, Vol. 61, No. 314, Part 1 (Jun., 1966), pp. 436-439,
        # for notation: https://www.jstor.org/stable/2282833

        def jointly_tied_pairs_body(first, t, i):
            not_equal = tf.math.logical_or(
                tf.not_equal(y_true[lexa[first]], y_true[lexa[i]]),
                tf.not_equal(y_pred[lexa[first]], y_pred[lexa[i]]))
            return (tf.where(not_equal, i, first),
                    tf.where(not_equal,
                             t + ((i - first) * (i - first - 1)) // 2,
                             t), i + 1)

        n = ps.size0(y_true)
        first, t, _ = tf.while_loop(cond=lambda first, t, i: i < n,
                                    body=jointly_tied_pairs_body,
                                    loop_vars=(0, 0, 1))
        t += ((n - first) * (n - first - 1)) // 2

        def ties_y_true_body(first, v, i):
            not_equal = tf.not_equal(y_true[lexa[first]], y_true[lexa[i]])
            return (tf.where(not_equal, i, first),
                    tf.where(not_equal,
                             v + ((i - first) * (i - first - 1)) // 2,
                             v), i + 1)

        first, v, _ = tf.while_loop(cond=lambda first, v, i: i < n,
                                    body=ties_y_true_body,
                                    loop_vars=(0, 0, 1))
        v += ((n - first) * (n - first - 1)) // 2

        # count exchanges
        exchanges, newperm = iterative_mergesort(y_pred, lexa)

        def ties_in_y_pred_body(first, u, i):
            not_equal = tf.not_equal(y_pred[newperm[first]],
                                     y_pred[newperm[i]])
            return (tf.where(not_equal, i, first),
                    tf.where(not_equal,
                             u + ((i - first) * (i - first - 1)) // 2,
                             u), i + 1)

        first, u, _ = tf.while_loop(cond=lambda first, u, i: i < n,
                                    body=ties_in_y_pred_body,
                                    loop_vars=(0, 0, 1))
        u += ((n - first) * (n - first - 1)) // 2
        n0 = (n * (n - 1)) // 2
        assertions = [
            assert_util.assert_less(v, tf.cast(n0, tf.int32),
                                    'All ranks are ties for y_true.'),
            assert_util.assert_less(u, tf.cast(n0, tf.int32),
                                    'All ranks are ties for y_pred.')
        ]
        with tf.control_dependencies(assertions):
            return (tf.cast(n0 - (u + v - t), tf.float32) -
                    2.0 * tf.cast(exchanges, tf.float32)) / tf.math.sqrt(
                        tf.cast(n0 - v, tf.float32) *
                        tf.cast(n0 - u, tf.float32))
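With no ties, tau-b reduces to (C - D) / n0 over concordant and discordant pairs; a worked sketch assuming eager TensorFlow:

import tensorflow as tf

y_true = tf.constant([1., 2., 3., 4.])
y_pred = tf.constant([1., 3., 2., 4.])
# n0 = 6 pairs: 5 concordant, 1 discordant (items 2 and 3 are swapped),
# so tau = (5 - 1) / 6 = 0.6667.
print(float(kendalls_tau(y_true, y_pred)))  # ~0.6667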
Example #6
def batch_interp_regular_nd_grid(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis,
                                 fill_value='constant_extension',
                                 name=None):
  """Multi-linear interpolation on a regular (constant spacing) grid.

  Given [a batch of] reference values, this function computes a multi-linear
  interpolant and evaluates it on [a batch of] new `x` values.

  The interpolant is built from reference values indexed by `nd` dimensions
  of `y_ref`, starting at `axis`.

  For example, take the case of a `2-D` scalar valued function and no leading
  batch dimensions.  In this case, `y_ref.shape = [C1, C2]` and `y_ref[i, j]`
  is the reference value corresponding to grid point

  ```
  [x_ref_min[0] + i * (x_ref_max[0] - x_ref_min[0]) / (C1 - 1),
   x_ref_min[1] + j * (x_ref_max[1] - x_ref_min[1]) / (C2 - 1)]
  ```

  In the general case, dimensions to the left of `axis` in `y_ref` are broadcast
  with leading dimensions in `x`, `x_ref_min`, `x_ref_max`.

  Args:
    x: Numeric `Tensor` The x-coordinates of the interpolated output values for
      each batch.  Shape `[..., D, nd]`, designating [a batch of] `D`
      coordinates in `nd` space.  `D` must be `>= 1` and is not a batch dim.
    x_ref_min:  `Tensor` of same `dtype` as `x`.  The minimum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    x_ref_max:  `Tensor` of same `dtype` as `x`.  The maximum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    y_ref:  `Tensor` of same `dtype` as `x`.  The reference output values. Shape
      `[..., C1, ..., Cnd, B1,...,BM]`, designating [a batch of] reference
      values indexed by `nd` dimensions, of a shape `[B1,...,BM]` valued
      function (for `M >= 0`).
    axis:  Scalar integer `Tensor`.  Dimensions `[axis, axis + nd)` of `y_ref`
      index the interpolation table.  E.g. `3-D` interpolation of a scalar
      valued function requires `axis=-3` and a `3-D` matrix valued function
      requires `axis=-5`.
    fill_value:  Determines what values the output should take for `x` values that
      are below `x_ref_min` or above `x_ref_max`. Scalar `Tensor` or
      'constant_extension' ==> Extend as constant function.
      Default value: `'constant_extension'`
    name:  A name to prepend to created ops.
      Default value: `'batch_interp_regular_nd_grid'`.

  Returns:
    y_interp:  Interpolation between members of `y_ref`, at points `x`.
      `Tensor` of same `dtype` as `x`, and shape `[..., D, B1, ..., BM].`

  Raises:
    ValueError:  If `rank(x) < 2` is determined statically.
    ValueError:  If `axis` is not a scalar, as determined statically.
    ValueError:  If `axis + nd > rank(y_ref)` is determined statically.

  #### Examples

  Interpolate a function of one variable.

  ```python
  y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20))

  tfp.math.batch_interp_regular_nd_grid(
      # x.shape = [3, 1], x_ref_min/max.shape = [1].  Trailing `1` for `1-D`.
      x=[[6.0], [0.5], [3.3]], x_ref_min=[0.], x_ref_max=[10.], y_ref=y_ref,
      axis=0)
  ==> approx [exp(6.0), exp(0.5), exp(3.3)]
  ```

  Interpolate a scalar function of two variables.

  ```python
  x_ref_min = [0., 0.]
  x_ref_max = [2 * np.pi, 2 * np.pi]

  # Build y_ref.
  x0s, x1s = tf.meshgrid(
      tf.linspace(x_ref_min[0], x_ref_max[0], num=100),
      tf.linspace(x_ref_min[1], x_ref_max[1], num=100),
      indexing='ij')

  def func(x0, x1):
    return tf.sin(x0) * tf.cos(x1)

  y_ref = func(x0s, x1s)

  x = np.pi * tf.random.uniform(shape=(10, 2))

  tfp.math.batch_interp_regular_nd_grid(x, x_ref_min, x_ref_max, y_ref, axis=-2)
  ==> tf.sin(x[:, 0]) * tf.cos(x[:, 1])
  ```

  """
  with tf.name_scope(name or 'interp_regular_nd_grid'):
    dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                    dtype_hint=tf.float32)

    # Arg checking.
    if isinstance(fill_value, str):
      if fill_value != 'constant_extension':
        raise ValueError(
            'A fill value ({}) was not an allowed string ({})'.format(
                fill_value, 'constant_extension'))
    else:
      fill_value = tf.convert_to_tensor(
          fill_value, name='fill_value', dtype=dtype)
      _assert_ndims_statically(fill_value, expect_ndims=0)

    # x.shape = [..., nd].
    x = tf.convert_to_tensor(x, name='x', dtype=dtype)
    _assert_ndims_statically(x, expect_ndims_at_least=2)

    # y_ref.shape = [..., C1,...,Cnd, B1,...,BM]
    y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

    # x_ref_min.shape = [nd]
    x_ref_min = tf.convert_to_tensor(
        x_ref_min, name='x_ref_min', dtype=dtype)
    x_ref_max = tf.convert_to_tensor(
        x_ref_max, name='x_ref_max', dtype=dtype)
    _assert_ndims_statically(
        x_ref_min, expect_ndims_at_least=1, expect_static=True)
    _assert_ndims_statically(
        x_ref_max, expect_ndims_at_least=1, expect_static=True)

    # nd is the number of dimensions indexing the interpolation table; it's
    # the 'nd' in the function name.
    nd = tf.compat.dimension_value(x_ref_min.shape[-1])
    if nd is None:
      raise ValueError('`x_ref_min.shape[-1]` must be known statically.')
    tensorshape_util.assert_is_compatible_with(
        x_ref_max.shape[-1:], x_ref_min.shape[-1:])

    # Convert axis and check it statically.
    axis = tf.convert_to_tensor(axis, dtype=tf.int32, name='axis')
    axis = ps.non_negative_axis(axis, tf.rank(y_ref))
    tensorshape_util.assert_has_rank(axis.shape, 0)
    axis_ = tf.get_static_value(axis)
    y_ref_rank_ = tf.get_static_value(tf.rank(y_ref))
    if axis_ is not None and y_ref_rank_ is not None:
      if axis_ + nd > y_ref_rank_:
        raise ValueError(
            'Since dims `[axis, axis + nd)` index the interpolation table, we '
            'must have `axis + nd <= rank(y_ref)`.  Found: '
            '`axis`: {},  rank(y_ref): {}, and inferred `nd` from trailing '
            'dimensions of `x_ref_min` to be {}.'.format(
                axis_, y_ref_rank_, nd))

    x_batch_shape = ps.shape(x)[:-2]
    x_ref_min_batch_shape = ps.shape(x_ref_min)[:-1]
    x_ref_max_batch_shape = ps.shape(x_ref_max)[:-1]
    y_ref_batch_shape = ps.shape(y_ref)[:axis]

    # Do a brute-force broadcast of batch dims (add zeros).
    batch_shape = y_ref_batch_shape
    for tensor in [x_batch_shape, x_ref_min_batch_shape, x_ref_max_batch_shape]:
      batch_shape = ps.broadcast_shape(batch_shape, tensor)

    def _batch_shape_of_zeros_with_rightmost_singletons(n_singletons):
      """Return Tensor of zeros with some singletons on the rightmost dims."""
      ones = ps.ones(shape=[n_singletons], dtype=tf.int32)
      return ps.concat([batch_shape, ones], axis=0)

    x = _broadcast_with(
        x, _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=2))
    x_ref_min = _broadcast_with(
        x_ref_min,
        _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=1))
    x_ref_max = _broadcast_with(
        x_ref_max,
        _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=1))
    y_ref = _broadcast_with(
        y_ref,
        _batch_shape_of_zeros_with_rightmost_singletons(
            n_singletons=tf.rank(y_ref) - axis))

    return _batch_interp_with_gather_nd(
        x=x,
        x_ref_min=x_ref_min,
        x_ref_max=x_ref_max,
        y_ref=y_ref,
        nd=nd,
        fill_value=fill_value,
        batch_dims=ps.rank(x) - 2)
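In `1-D`, the grid-point formula above reduces to a fractional-index blend; a standalone NumPy sketch of that core step (illustrative names only, not the library's internals):

import numpy as np

def interp_regular_1d(x, x_min, x_max, y_ref):
    # Fractional index into the table: 0 maps to x_min, C - 1 to x_max.
    c = len(y_ref)
    t = (x - x_min) / (x_max - x_min) * (c - 1)
    lo = np.clip(np.floor(t).astype(int), 0, c - 2)
    frac = t - lo
    # Linear blend of the two neighboring table entries.
    return y_ref[lo] * (1 - frac) + y_ref[lo + 1] * frac

y_ref = np.exp(np.linspace(0., 10., 20))
print(interp_regular_1d(np.array([6.0, 0.5, 3.3]), 0., 10., y_ref))
# approx [exp(6.0), exp(0.5), exp(3.3)]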
Example #7
def trapz(
    y,
    x=None,
    dx=None,
    axis=-1,
    name=None,
):
    """Integrate y(x) on the specified axis using the trapezoidal rule.

  Computes ∫ y(x) dx ≈ Σ [0.5 (y_k + y_{k+1}) * (x_{k+1} - x_k)]

  Args:
    y: Float `Tensor` of values to integrate.
    x: Optional, Float `Tensor` of points corresponding to the y values. The
      shape of x should match that of y. If x is None, the sample points are
      assumed to be evenly spaced dx apart.
    dx: Scalar float `Tensor`. The spacing between sample points when x is None.
      If neither x nor dx is provided then the default is dx = 1.
    axis: Scalar integer `Tensor`. The axis along which to integrate.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None`, uses name='trapz'.

  Returns:
    Float `Tensor` integral approximated by trapezoidal rule.
      Has the shape of y but with the dimension associated with axis removed.
  """
    with tf.name_scope(name or 'trapz'):
        if not (x is None or dx is None):
            raise ValueError(
                'Not permitted to specify both x and dx input args.')
        dtype = dtype_util.common_dtype([y, x, dx], dtype_hint=tf.float32)
        axis = ps.convert_to_shape_tensor(axis, dtype=tf.int32, name='axis')
        axis_rank = tensorshape_util.rank(axis.shape)
        if axis_rank is None:
            raise ValueError('Require axis to have a static shape.')
        if axis_rank:
            raise ValueError(
                'Only permitted to specify one axis, got axis={}'.format(axis))
        y = tf.convert_to_tensor(y, dtype=dtype, name='y')
        y_shape = ps.convert_to_shape_tensor(ps.shape(y), dtype=tf.int32)
        length = y_shape[axis]
        if x is None:
            if dx is None:
                dx = 1.
            dx = tf.convert_to_tensor(dx, dtype=dtype, name='dx')
            if ps.shape(dx):
                raise ValueError(
                    'Expected dx to be a scalar, got dx={}'.format(dx))
            elem_sum = tf.reduce_sum(y, axis=axis)
            elem_sum -= 0.5 * tf.reduce_sum(tf.gather(y, [0, length - 1],
                                                      axis=axis),
                                            axis=axis)  # half weight endpoints
            return elem_sum * dx
        else:
            x = tf.convert_to_tensor(x, dtype=dtype, name='x')
            tensorshape_util.assert_is_compatible_with(x.shape, y.shape)
            dx = (tf.gather(x, ps.range(1, length), axis=axis) -
                  tf.gather(x, ps.range(0, length - 1), axis=axis))
            return 0.5 * tf.reduce_sum(
                (tf.gather(y, ps.range(1, length), axis=axis) +
                 tf.gather(y, ps.range(0, length - 1), axis=axis)) * dx,
                axis=axis)
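A quick cross-check of trapz against NumPy's trapezoidal rule (a sketch assuming eager TensorFlow):

import numpy as np
import tensorflow as tf

grid = np.linspace(0., np.pi, 101)
approx = trapz(tf.sin(tf.constant(grid, dtype=tf.float32)),
               x=tf.constant(grid, dtype=tf.float32))
print(float(approx))                 # ~2.0, the integral of sin on [0, pi]
print(np.trapz(np.sin(grid), grid))  # same value from NumPy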