Beispiel #1
 def testRollStatic(self):
     with self.assertRaisesRegexp(Exception, 'None'):
         distribution_util.rotate_transpose(None, 1)
     for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))):
         for shift in np.arange(-5, 5):
             y = distribution_util.rotate_transpose(x, shift)
             self.assertAllEqual(self._np_rotate_transpose(x, shift),
             self.assertAllEqual(np.roll(x.shape, shift),
Beispiel #2
 def testRollStatic(self):
   if tf.executing_eagerly():
     error_message = r'Attempt to convert a value \(None\)'
     error_message = 'None values not supported.'
   with self.assertRaisesRegexp(ValueError, error_message):
     distribution_util.rotate_transpose(None, 1)
   for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))):
     for shift in np.arange(-5, 5):
       y = distribution_util.rotate_transpose(x, shift)
           self._np_rotate_transpose(x, shift), self.evaluate(y))
           np.roll(x.shape, shift), tensorshape_util.as_list(y.shape))
Beispiel #3
    def undo_make_batch_of_event_sample_matrices(
        """Reshapes/transposes `Distribution` `Tensor` from B_+E_+S_ to S+B+E.

      - `B_ = B if B or not expand_batch_dim else [1]`,
      - `E_ = E if E else [1]`,
      - `S_ = [tf.reduce_prod(S)]`.

    This function "reverses" `make_batch_of_event_sample_matrices`.

      x: `Tensor` of shape `B_+E_+S_`.
      sample_shape: `Tensor` (1D, `int32`).
      expand_batch_dim: Python `bool`. If `True` the batch dims will be expanded
        such that `batch_ndims>=1`.
      name: Python `str`. The name to give this op.

      x: `Tensor`. Input transposed/reshaped to `S+B+E`.
        with self._name_scope(name, values=[x, sample_shape]):
            x = tf.convert_to_tensor(x, name="x")
            # x.shape: _B+_E+[prod(S)]
            sample_shape = tf.convert_to_tensor(sample_shape,
            x = distribution_util.rotate_transpose(x, shift=1)
            # x.shape: [prod(S)]+_B+_E
            if self._is_all_constant_helper(self.batch_ndims,
                if self._batch_ndims_is_0 or self._event_ndims_is_0:
                    squeeze_dims = []
                    if self._event_ndims_is_0:
                        squeeze_dims += [-1]
                    if self._batch_ndims_is_0 and expand_batch_dim:
                        squeeze_dims += [1]
                    if squeeze_dims:
                        x = tf.squeeze(x, axis=squeeze_dims)
                        # x.shape: [prod(S)]+B+E
                _, batch_shape, event_shape = self.get_shape(x)
                s = (x.shape.as_list()
                     if x.shape.is_fully_defined() else tf.shape(x))
                batch_shape = s[1:1 + self.batch_ndims]
                # Since sample_dims=1 and is left-most, we add 1 to the number of
                # batch_ndims to get the event start dim.
                event_start = tf.where(
                    tf.logical_and(expand_batch_dim, self._batch_ndims_is_0),
                    2, 1 + self.batch_ndims)
                event_shape = s[event_start:event_start + self.event_ndims]
            new_shape = tf.concat([sample_shape, batch_shape, event_shape], 0)
            x = tf.reshape(x, shape=new_shape)
            # x.shape: S+B+E
            return x
 def testRollDynamic(self):
   for x_value in (np.ones(1, dtype=np.float32),
                   np.ones([2, 1], dtype=np.float32),
                   np.ones([3, 2, 1], dtype=np.float32)):
     for shift_value in np.arange(-5, 5).astype(np.int32):
       x = tf1.placeholder_with_default(x_value, shape=None)
       shift = tf1.placeholder_with_default(shift_value, shape=None)
           self._np_rotate_transpose(x_value, shift_value),
           self.evaluate(distribution_util.rotate_transpose(x, shift)))
Beispiel #5
def _sub_diag(nonmatrix):
    """Get the first sub-diagonal of a shape [N, N, ...] 'non matrix'."""
    with tf.name_scope('sub_matrix'):
        # TODO(b/143702351) Once array_ops.matrix_diag_part_v3 is ready and exposed,
        # replace the call to matrix_diag_part_v2 below with tf.linalg.matrix_diag.
        # We can also stop special casing for matrix_dim < 2 at that point.
        # Until then, OpError raised for 1x1 matricies without static shape.
        # In fact, non-static shape breaks matrix_diag_part_v2, so we must raise
        # this message now.
        # See http://b/138403336 for the TF issue tracker.
        if not tensorshape_util.is_fully_defined(nonmatrix.shape[:2]):
            raise ValueError(
                '`inverse_temperatures did not have statically defined shape, '
                'which breaks tracking of is_swap_{proposed,accepted}.  '
                'Please provide an inverse_temperatures with statically known shape.'

        # The sub-matrix of a 1x1 matrix is not defined (throws exception), so in
        # this special case return an empty matrix.
        # TODO(b/143702351) Remove this special case handling once
        # matrix_diag_part_v3 is ready.
        matrix_dim = ps.size0(nonmatrix)
        if matrix_dim is not None and matrix_dim < 2:
            # Shape is [..., 0], so returned tensor is empty, thus contains no
            # values...and therefore the fact that we use 'ones' doesn't matter.
            shape = ps.pad(ps.shape(nonmatrix)[2:],
                           paddings=[[0, 1]],
            matrix_sub_diag = tf.cast(tf.ones(shape), nonmatrix.dtype)

            # Get first sub-diagonal.  `padding_value` is not used (since matrix is
            # square), but is required for the API since this is raw gen_array_ops.
            matrix_sub_diag = tf.raw_ops.MatrixDiagPartV2(
                input=distribution_util.rotate_transpose(nonmatrix, shift=-2),
                k=ps.convert_to_shape_tensor(-1, dtype=tf.int32),
                padding_value=tf.cast(0.0, dtype=nonmatrix.dtype))

        return distribution_util.rotate_transpose(matrix_sub_diag, shift=1)
def _observation_particles_cov_linop(
    """LinearOperatorLowRankUpdate holding observation noise covariance.

  All arguments can be derived from `observation_particles_dist`. We pass them
  as arguments to have a simpler graph, and encourage calling `.sample` once.

    predicted_observation_particles: Ensemble of state particles fed through the
      observation function.  `observation_particles_dist.mean()`
    ensemble_mean_observations: Ensemble mean (mean across `axis=0`) of
    observation_cov: `LinearOperator` defining the observation noise covariance.

    LinearOperatorLowRankUpdate with covariance the sum of `observation_cov`
      and the ensemble covariance of `predicted_observation_particles`.
    # In our usual docstring notation, let B be a batch shape, X be the ensemble
    # of states, and G(X) the deterministic observation transformation of X. Then,
    # predicted_observations_particles = G(X)  (an ensemble)
    #                  shape = [n_ensemble] + B + [n_observations]
    # ensemble_mean_observations =
    #    tf.reduce_mean(predicted_observations, axis=0)  # Ensemble mean

    # Create matrix U with shape B + [n_observations, n_ensemble] so that, with
    # Cov the ensemble covariance, Cov(G(X)) = UUᵀ.
    centered_observations = (predicted_observation_particles -
    n_ensemble = tf.cast(
        tf.shape(centered_observations)[0], centered_observations.dtype)
    u = distribution_util.rotate_transpose(
        centered_observations / tf.sqrt(n_ensemble), -1)

    # cov_operator ~ Γ + Cov(G(X))
    return tf.linalg.LinearOperatorLowRankUpdate(
        base_operator=observation_cov,  # = Γ
        u=u,  # UUᵀ = Cov(G(X))
Beispiel #7
    def make_batch_of_event_sample_matrices(
        """Reshapes/transposes `Distribution` `Tensor` from S+B+E to B_+E_+S_.

      - `B_ = B if B or not expand_batch_dim else [1]`,
      - `E_ = E if E else [1]`,
      - `S_ = [tf.reduce_prod(S)]`.

      x: `Tensor`.
      expand_batch_dim: Python `bool`. If `True` the batch dims will be expanded
        such that `batch_ndims >= 1`.
      name: Python `str`. The name to give this op.

      x: `Tensor`. Input transposed/reshaped to `B_+E_+S_`.
      sample_shape: `Tensor` (1D, `int32`).
        with self._name_scope(name, values=[x]):
            x = tf.convert_to_tensor(x, name="x")
            # x.shape: S+B+E
            sample_shape, batch_shape, event_shape = self.get_shape(x)
            event_shape = distribution_util.pick_vector(
                self._event_ndims_is_0, [1], event_shape)
            if expand_batch_dim:
                batch_shape = distribution_util.pick_vector(
                    self._batch_ndims_is_0, [1], batch_shape)
            new_shape = tf.concat([[-1], batch_shape, event_shape], 0)
            x = tf.reshape(x, shape=new_shape)
            # x.shape: [prod(S)]+B_+E_
            x = distribution_util.rotate_transpose(x, shift=-1)
            # x.shape: B_+E_+[prod(S)]
            return x, sample_shape
Beispiel #8
def auto_correlation(x,
    """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as  (with `E` expectation and `Conj` complex conjugate)

  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

    x:  `float32` or `complex64` `Tensor`.
    axis:  Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags:  Positive `int` tensor.  The maximum value of `m` to consider (in
      equation above).  If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center:  Python `bool`.  If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name:  `String` name to prepend to created ops.

    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

    TypeError:  If `x` is not a supported type.
    # Implementation details:
    # Extend length N / 2 1-D array x to length N by zero padding onto the end.
    # Then, set
    #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
    # It is not hard to see that
    #   F[x]_k Conj(F[x]_k) = F[R]_k, where
    #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
    # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

    # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
    # based version of estimating RXX.
    # Note that this is a special case of the Wiener-Khinchin Theorem.
    with tf.name_scope(name):
        x = tf.convert_to_tensor(x, name='x')

        # Rotate dimensions of x in order to put axis at the rightmost dim.
        # FFT op requires this.
        rank = ps.rank(x)
        if axis < 0:
            axis = rank + axis
        shift = rank - 1 - axis
        # Suppose x.shape[axis] = T, so there are T 'time' steps.
        #   ==> x_rotated.shape = B + [T],
        # where B is x_rotated's batch shape.
        x_rotated = distribution_util.rotate_transpose(x, shift)

        if center:
            x_rotated = x_rotated - tf.reduce_mean(
                x_rotated, axis=-1, keepdims=True)

        # x_len = N / 2 from above explanation.  The length of x along axis.
        # Get a value for x_len that works in all cases.
        x_len = ps.shape(x_rotated)[-1]

        # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
        # the moment is necessary so that all FFT implementations work.
        # Zero pad to the next power of 2 greater than 2 * x_len, which equals
        # 2**(ceil(Log_2(2 * x_len))).  Note: Log_2(X) = Log_e(X) / Log_e(2).
        x_len_float64 = ps.cast(x_len, np.float64)
        target_length = ps.pow(np.float64(2.),
                               ps.ceil(ps.log(x_len_float64 * 2) / np.log(2.)))
        pad_length = ps.cast(target_length - x_len_float64, np.int32)

        # We should have:
        # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
        #                     = B + [T + pad_length]
        x_rotated_pad = distribution_util.pad(x_rotated,

        dtype = x.dtype
        if not dtype_util.is_complex(dtype):
            if not dtype_util.is_floating(dtype):
                raise TypeError(
                    'Argument x must have either float or complex dtype'
                    ' found: {}'.format(dtype))
            x_rotated_pad = tf.complex(

        # Autocorrelation is IFFT of power-spectral density (up to some scaling).
        fft_x_rotated_pad = tf.signal.fft(x_rotated_pad)
        spectral_density = fft_x_rotated_pad * tf.math.conj(fft_x_rotated_pad)
        # shifted_product is R[m] from above detailed explanation.
        # It is the inner product sum_n X[n] * Conj(X[n - m]).
        shifted_product = tf.signal.ifft(spectral_density)

        # Cast back to real-valued if x was real to begin with.
        shifted_product = tf.cast(shifted_product, dtype)

        # Figure out if we can deduce the final static shape, and set max_lags.
        # Use x_rotated as a reference, because it has the time dimension in the far
        # right, and was created before we performed all sorts of crazy shape
        # manipulations.
        know_static_shape = True
        if not tensorshape_util.is_fully_defined(x_rotated.shape):
            know_static_shape = False
        if max_lags is None:
            max_lags = x_len - 1
            max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
            max_lags_ = tf.get_static_value(max_lags)
            if max_lags_ is None or not know_static_shape:
                know_static_shape = False
                max_lags = tf.minimum(x_len - 1, max_lags)
                max_lags = min(x_len - 1, max_lags_)

        # Chop off the padding.
        # We allow users to provide a huge max_lags, but cut it off here.
        # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags]
        shifted_product_chopped = shifted_product[..., :max_lags + 1]

        # If possible, set shape.
        if know_static_shape:
            chopped_shape = tensorshape_util.as_list(x_rotated.shape)
            chopped_shape[-1] = min(x_len, max_lags + 1)
            tensorshape_util.set_shape(shifted_product_chopped, chopped_shape)

        # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
        # other terms were zeros arising only due to zero padding.
        # `denominator = (N / 2 - m)` (defined below) is the proper term to
        # divide by to make this an unbiased estimate of the expectation
        # E[X[n] Conj(X[n - m])].
        x_len = ps.cast(x_len, dtype_util.real_dtype(dtype))
        max_lags = ps.cast(max_lags, dtype_util.real_dtype(dtype))
        denominator = x_len - ps.range(0., max_lags + 1.)
        denominator = ps.cast(denominator, dtype)
        shifted_product_rotated = shifted_product_chopped / denominator

        if normalize:
            shifted_product_rotated /= shifted_product_rotated[..., :1]

        # Transpose dimensions back to those of x.
        return distribution_util.rotate_transpose(shifted_product_rotated,
Beispiel #9
    def bootstrap_results(self, init_state):
        """Returns an object with the same type as returned by `one_step`.

      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        initial state(s) of the Markov chain(s).

      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This inculdes replica states.
        with tf.name_scope(
                mcmc_util.make_name(, 'remc', 'bootstrap_results')):
            init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(

            inverse_temperatures = tf.convert_to_tensor(
                self.inverse_temperatures, name='inverse_temperatures')

            if self._state_includes_replicas:
                it_n_replica = inverse_temperatures.shape[0]
                state_n_replica = init_state[0].shape[0]
                if ((it_n_replica is not None)
                        and (state_n_replica is not None)
                        and (it_n_replica != state_n_replica)):
                    raise ValueError(
                        'Number of replicas implied by initial state ({}) must equal '
                        'number of replicas implied by inverse_temperatures ({}), but '
                        'did not'.format(it_n_replica, state_n_replica))

            # We will now replicate each of a possible batch of initial stats, one for
            # each inverse_temperature. So if init_state=[x, y] of shapes [Sx, Sy]
            # then the new shape is [(T, Sx), (T, Sy)] where (a, b) means
            # concatenation and T=shape(inverse_temperature).
            num_replica = ps.size0(inverse_temperatures)
            replica_shape = ps.convert_to_shape_tensor([num_replica])

            if self._state_includes_replicas:
                replica_states = init_state
                replica_states = [
                    tf.broadcast_to(  # pylint: disable=g-complex-comprehension
                        ps.concat([replica_shape, ps.shape(x)], axis=0),
                        name='replica_states') for x in init_state

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
            # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
            except TypeError as e:
                if 'argument' not in str(e):
                raise TypeError(
                    '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a second '
                    '(`seed`) argument. `TransitionKernel` instances now receive seeds '
                    'via `one_step`.')

            replica_results = inner_kernel.bootstrap_results(replica_states)

            pre_swap_replica_target_log_prob = _get_field(
                replica_results, 'target_log_prob')

            replica_and_batch_shape = ps.shape(
            batch_shape = replica_and_batch_shape[1:]

            inverse_temperatures = bu.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Pretend we did a "null swap", which will always be accepted.
            swaps = bu.left_justified_broadcast_to(tf.range(num_replica),
            # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
            is_swap_accepted = distribution_util.rotate_transpose(tf.eye(
                num_replica, batch_shape=batch_shape, dtype=tf.bool),

            return ReplicaExchangeMCKernelResults(
                step_count=tf.zeros(shape=(), dtype=tf.int32),
Beispiel #10
def percentile(x,
    """Compute the `q`-th percentile(s) of `x`.

  Given a vector `x`, the `q`-th percentile of `x` is the value `q / 100` of the
  way from the minimum to the maximum in a sorted copy of `x`.

  The values and distances of the two nearest neighbors as well as the
  `interpolation` parameter will determine the percentile if the normalized
  ranking does not match the location of `q` exactly.

  This function is the same as the median if `q = 50`, the same as the minimum
  if `q = 0` and the same as the maximum if `q = 100`.

  Multiple percentiles can be computed at once by using `1-D` vector `q`.
  Dimension zero of the returned `Tensor` will index the different percentiles.

  Compare to `numpy.percentile`.

    x:  Numeric `N-D` `Tensor` with `N > 0`.  If `axis` is not `None`,
      `x` must have statically known number of dimensions.
    q:  Scalar or vector `Tensor` with values in `[0, 100]`. The percentile(s).
    axis:  Optional `0-D` or `1-D` integer `Tensor` with constant values. The
      axis that index independent samples over which to return the desired
      percentile.  If `None` (the default), treat every dimension as a sample
      dimension, returning a scalar.
    interpolation : {'nearest', 'linear', 'lower', 'higher', 'midpoint'}.
      Default value: 'nearest'.  This specifies the interpolation method to
      use when the desired quantile lies between two data points `i < j`:
        * linear: i + (j - i) * fraction, where fraction is the fractional part
          of the index surrounded by i and j.
        * lower: `i`.
        * higher: `j`.
        * nearest: `i` or `j`, whichever is nearest.
        * midpoint: (i + j) / 2.
      `linear` and `midpoint` interpolation do not work with integer dtypes.
    keep_dims:  Python `bool`. If `True`, the last dimension is kept with size 1
      If `False`, the last dimension is removed from the output shape.
    validate_args:  Whether to add runtime checks of argument validity. If
      False, and arguments are incorrect, correct behavior is not guaranteed.
    preserve_gradients:  Python `bool`.  If `True`, ensure that gradient w.r.t
      the percentile `q` is preserved in the case of linear interpolation.
      If `False`, the gradient will be (incorrectly) zero when `q` corresponds
      to a point in `x`.
    name:  A Python string name to give this `Op`.  Default is 'percentile'

    A `(rank(q) + N - len(axis))` dimensional `Tensor` of same dtype as `x`, or,
      if `axis` is `None`, a `rank(q)` `Tensor`.  The first `rank(q)` dimensions
      index quantiles for different values of `q`.

    ValueError:  If argument 'interpolation' is not an allowed type.
    ValueError:  If interpolation type not compatible with `dtype`.

  #### Examples

  # Get 30th percentile with default ('nearest') interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=30.)
  ==> 2.0

  # Get 30th percentile with 'linear' interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=30., interpolation='linear')
  ==> 1.9

  # Get 30th and 70th percentiles with 'lower' interpolation
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=[30., 70.], interpolation='lower')
  ==> [1., 3.]

  # Get 100th percentile (maximum).  By default, this is computed over every dim
  x = [[1., 2.]
       [3., 4.]]
  tfp.stats.percentile(x, q=100.)
  ==> 4.

  # Treat the leading dim as indexing samples, and find the 100th quantile (max)
  # over all such samples.
  x = [[1., 2.]
       [3., 4.]]
  tfp.stats.percentile(x, q=100., axis=[0])
  ==> [3., 4.]

    name = name or 'percentile'
    allowed_interpolations = {
        'linear', 'lower', 'higher', 'nearest', 'midpoint'

    if interpolation is None:
        interpolation = 'nearest'
        if interpolation not in allowed_interpolations:
            raise ValueError(
                'Argument `interpolation` must be in %s.  Found %s' %
                (allowed_interpolations, interpolation))

    with tf1.name_scope(name, values=[x, q]):
        x = tf.convert_to_tensor(value=x, name='x')

        if interpolation in {'linear', 'midpoint'} and x.dtype.is_integer:
            raise TypeError(
                '{} interpolation not allowed with dtype {}'.format(
                    interpolation, x.dtype))

        # Double is needed here and below, else we get the wrong index if the array
        # is huge along axis.
        q = tf.cast(q, tf.float64)
        _get_static_ndims(q, expect_ndims_no_more_than=1)

        if validate_args:
            q = distribution_util.with_dependencies([
                tf1.assert_rank_in(q, [0, 1]),
                tf1.assert_greater_equal(q, tf.cast(0., tf.float64)),
                tf1.assert_less_equal(q, tf.cast(100., tf.float64))
            ], q)

        # Move `axis` dims of `x` to the rightmost, call it `y`.
        if axis is None:
            y = tf.reshape(x, [-1])
            x_ndims = _get_static_ndims(x,
            axis = _make_static_axis_non_negative_list(axis, x_ndims)
            y = _move_dims_to_flat_end(x, axis, x_ndims, right_end=True)

        frac_at_q_or_above = 1. - q / 100.

        # Sort everything, not just the top 'k' entries, which allows multiple calls
        # to sort only once (under the hood) and use CSE.
        sorted_y = _sort_tensor(y)

        d = tf.cast(tf.shape(input=y)[-1], tf.float64)

        def _get_indices(interp_type):
            """Get values of y at the indices implied by interp_type."""
            # Note `lower` <--> ceiling.  Confusing, huh?  Due to the fact that
            # _sort_tensor sorts highest to lowest, tf.ceil corresponds to the higher
            # index, but the lower value of y!
            if interp_type == 'lower':
                indices = tf.math.ceil((d - 1) * frac_at_q_or_above)
            elif interp_type == 'higher':
                indices = tf.floor((d - 1) * frac_at_q_or_above)
            elif interp_type == 'nearest':
                indices = tf.round((d - 1) * frac_at_q_or_above)
            # d - 1 will be distinct from d in int32, but not necessarily double.
            # So clip to avoid out of bounds errors.
            return tf.clip_by_value(tf.cast(indices, tf.int32), 0,
                                    tf.shape(input=y)[-1] - 1)

        if interpolation in ['nearest', 'lower', 'higher']:
            gathered_y = tf.gather(sorted_y,
        elif interpolation == 'midpoint':
            gathered_y = 0.5 * (
                tf.gather(sorted_y, _get_indices('lower'), axis=-1) +
                tf.gather(sorted_y, _get_indices('higher'), axis=-1))
        elif interpolation == 'linear':
            # Copy-paste of docstring on interpolation:
            # linear: i + (j - i) * fraction, where fraction is the fractional part
            # of the index surrounded by i and j.
            larger_y_idx = _get_indices('lower')
            exact_idx = (d - 1) * frac_at_q_or_above
            if preserve_gradients:
                # If q corresponds to a point in x, we will initially have
                # larger_y_idx == smaller_y_idx.
                # This results in the gradient w.r.t. fraction being zero (recall `q`
                # enters only through `fraction`...and see that things cancel).
                # The fix is to ensure that smaller_y_idx and larger_y_idx are always
                # separated by exactly 1.
                smaller_y_idx = tf.maximum(larger_y_idx - 1, 0)
                larger_y_idx = tf.minimum(smaller_y_idx + 1,
                                          tf.shape(input=y)[-1] - 1)
                fraction = tf.cast(larger_y_idx, tf.float64) - exact_idx
                smaller_y_idx = _get_indices('higher')
                fraction = tf.math.ceil(
                    (d - 1) * frac_at_q_or_above) - exact_idx

            fraction = tf.cast(fraction, y.dtype)
            gathered_y = (
                tf.gather(sorted_y, larger_y_idx, axis=-1) * (1 - fraction) +
                tf.gather(sorted_y, smaller_y_idx, axis=-1) * fraction)

        # Propagate NaNs
        if x.dtype in (tf.bfloat16, tf.float16, tf.float32, tf.float64):
            # Apparently tf.is_nan doesn't like other dtypes
            nan_batch_members = tf.reduce_any(input_tensor=tf.math.is_nan(x),
            right_rank_matched_shape = tf.pad(
                paddings=[[0, tf.rank(input=q)]],
            nan_batch_members = tf.reshape(nan_batch_members,
            nan = np.array(np.nan, gathered_y.dtype.as_numpy_dtype)
            gathered_y = tf.where(nan_batch_members, nan, gathered_y)

        # Expand dimensions if requested
        if keep_dims:
            if axis is None:
                ones_vec = tf.ones(shape=[
                    _get_best_effort_ndims(x) + _get_best_effort_ndims(q)
                gathered_y *= tf.ones(ones_vec, dtype=x.dtype)
                gathered_y = _insert_back_keep_dims(gathered_y, axis)

        # If q is a scalar, then result has the right shape.
        # If q is a vector, then result has trailing dim of shape q.shape, which
        # needs to be rotated to dim 0.
        return distribution_util.rotate_transpose(gathered_y, tf.rank(q))
Beispiel #11
def find_bins(x,
    """Bin values into discrete intervals.

  Given `edges = [c0, ..., cK]`, defining intervals
  `I0 = [c0, c1)`, `I1 = [c1, c2)`, ..., `I_{K-1} = [c_{K-1}, cK]`,
  This function returns `bins`, such that:
  `edges[bins[i]] <= x[i] < edges[bins[i] + 1]`.

    x:  Numeric `N-D` `Tensor` with `N > 0`.
    edges:  `Tensor` of same `dtype` as `x`.  The first dimension indexes edges
      of intervals.  Must either be `1-D` or have
      `x.shape[1:] == edges.shape[1:]`.  If `rank(edges) > 1`, `edges[k]`
      designates a shape `edges.shape[1:]` `Tensor` of bin edges for the
      corresponding dimensions of `x`.
    extend_lower_interval:  Python `bool`.  If `True`, extend the lowest
      interval `I0` to `(-inf, c1]`.
    extend_upper_interval:  Python `bool`.  If `True`, extend the upper
      interval `I_{K-1}` to `[c_{K-1}, +inf)`.
    dtype: The output type (`int32` or `int64`). `Default value:` `x.dtype`.
      This effects the output values when `x` is below/above the intervals,
      which will be `-1/K+1` for `int` types and `NaN` for `float`s.
      At indices where `x` is `NaN`, the output values will be `0` for `int`
      types and `NaN` for floats.
    name:  A Python string name to prepend to created ops. Default: 'find_bins'

    bins: `Tensor` with same `shape` as `x` and `dtype`.
      Has whole number values.  `bins[i] = k` means the `x[i]` falls into the
      `kth` bin, ie, `edges[bins[i]] <= x[i] < edges[bins[i] + 1]`.

    ValueError:  If `edges.shape[0]` is determined to be less than 2.

  #### Examples

  Cut a `1-D` array

  x = [0., 5., 6., 10., 20.]
  edges = [0., 5., 10.]
  tfp.stats.find_bins(x, edges)
  ==> [0., 0., 1., 1., np.nan]

  Cut `x` into its deciles

  x = tf.random_uniform(shape=(100, 200))
  decile_edges = tfp.stats.quantiles(x, num_quantiles=10)
  bins = tfp.stats.find_bins(x, edges=decile_edges)
  ==> (100, 200)
  tf.reduce_mean(bins == 0.)
  ==> approximately 0.1
  tf.reduce_mean(bins == 1.)
  ==> approximately 0.1

    # TFP users may be surprised to see the "action" in the leftmost dim of
    # edges, rather than the rightmost (event) dim.  Why?
    # 1. Most likely you created edges by getting quantiles over samples, and
    #    quantile/percentile return these edges in the leftmost (sample) dim.
    # 2. Say you have event_shape = [5], then we expect the bin will be different
    #    for all 5 events, so the index of the bin should not be in the event dim.
    with tf1.name_scope(name, default_name='find_bins', values=[x, edges]):
        in_type = dtype_util.common_dtype([x, edges], dtype_hint=tf.float32)
        edges = tf.convert_to_tensor(value=edges, name='edges', dtype=in_type)
        x = tf.convert_to_tensor(value=x, name='x', dtype=in_type)

        if (tf.compat.dimension_value(edges.shape[0]) is not None
                and tf.compat.dimension_value(edges.shape[0]) < 2):
            raise ValueError(
                'First dimension of `edges` must have length > 1 to index 1 or '
                'more bin. Found: {}'.format(edges.shape))

        flattening_x = edges.shape.ndims == 1 and x.shape.ndims > 1

        if flattening_x:
            x_orig_shape = tf.shape(input=x)
            x = tf.reshape(x, [-1])

        if dtype is None:
            dtype = in_type
        dtype = tf.as_dtype(dtype)

        # Move first dims into the rightmost.
        x_permed = distribution_util.rotate_transpose(x, shift=-1)
        edges_permed = distribution_util.rotate_transpose(edges, shift=-1)

        # If...
        #   x_permed = [0, 1, 6., 10]
        #   edges = [0, 5, 10.]
        #   ==> almost_output = [0, 1, 2, 2]
        searchsorted_type = dtype if dtype in [tf.int32, tf.int64] else None
        almost_output_permed = tf.searchsorted(sorted_sequence=edges_permed,
        # Move the rightmost dims back to the leftmost.
        almost_output = tf.cast(
            distribution_util.rotate_transpose(almost_output_permed, shift=1),

        # In above example, we want [0, 0, 1, 1], so correct this here.
        bins = tf.clip_by_value(almost_output - 1, tf.cast(0, dtype),
                                tf.cast(tf.shape(input=edges)[0] - 2, dtype))

        if not extend_lower_interval:
            low_fill = np.nan if dtype.is_floating else -1
            bins = tf.where(x < tf.expand_dims(edges[0], 0),
                            tf.cast(low_fill, dtype), bins)

        if not extend_upper_interval:
            up_fill = np.nan if dtype.is_floating else tf.shape(
                input=edges)[0] - 1
            bins = tf.where(x > tf.expand_dims(edges[-1], 0),
                            tf.cast(up_fill, dtype), bins)

        if flattening_x:
            bins = tf.reshape(bins, x_orig_shape)

        return bins
Beispiel #12
# NOTE(review): this block looks damaged -- source lines are missing (the rest
# of the signature after `x,`, an `else:` in the interpolation check and in the
# `axis` / `keep_dims` branches, and the trailing arguments of several
# multi-line calls). It cannot run as written; restore it from its upstream
# source (presumably TensorFlow Probability) instead of patching it piecemeal.
def percentile(x,
    """Compute the `q`-th percentile(s) of `x`.

  Given a vector `x`, the `q`-th percentile of `x` is the value `q / 100` of the
  way from the minimum to the maximum in a sorted copy of `x`.

  The values and distances of the two nearest neighbors as well as the
  `interpolation` parameter will determine the percentile if the normalized
  ranking does not match the location of `q` exactly.

  This function is the same as the median if `q = 50`, the same as the minimum
  if `q = 0` and the same as the maximum if `q = 100`.

  Multiple percentiles can be computed at once by using `1-D` vector `q`.
  Dimension zero of the returned `Tensor` will index the different percentiles.

  Examples:

  # Get 30th percentile with default ('nearest') interpolation.
  x = [1., 2., 3., 4.]
  percentile(x, q=30.)
  ==> 2.0

  # Get 30th and 70th percentiles with 'lower' interpolation
  x = [1., 2., 3., 4.]
  percentile(x, q=[30., 70.], interpolation='lower')
  ==> [1., 3.]

  # Get 100th percentile (maximum).  By default, this is computed over every dim
  x = [[1., 2.]
       [3., 4.]]
  percentile(x, q=100.)
  ==> 4.

  # Treat the leading dim as indexing samples, and find the 100th quantile (max)
  # over all such samples.
  x = [[1., 2.]
       [3., 4.]]
  percentile(x, q=100., axis=[0])
  ==> [3., 4.]

  Compare to `numpy.percentile`.

  Args:
    x:  Floating point `N-D` `Tensor` with `N > 0`.  If `axis` is not `None`,
      `x` must have a statically known number of dimensions.
    q:  Scalar or vector `Tensor` with values in `[0, 100]`. The percentile(s).
    axis:  Optional `0-D` or `1-D` integer `Tensor` with constant values. The
      axes that hold independent samples over which to return the desired
      percentile.  If `None` (the default), treat every dimension as a sample
      dimension, returning a scalar.
    interpolation : {'lower', 'higher', 'nearest'}.  Default: 'nearest'.  This
      optional parameter specifies the interpolation method to
      use when the desired quantile lies between two data points `i < j`:
        * lower: `i`.
        * higher: `j`.
        * nearest: `i` or `j`, whichever is nearest.
    keep_dims:  Python `bool`. If `True`, the last dimension is kept with size 1.
      If `False`, the last dimension is removed from the output shape.
    validate_args:  Whether to add runtime checks of argument validity. If
      False, and arguments are incorrect, correct behavior is not guaranteed.
    name:  A Python string name to give this `Op`.  Default is 'percentile'.

  Returns:
    A `(rank(q) + N - len(axis))` dimensional `Tensor` of same dtype as `x`, or,
      if `axis` is `None`, a `rank(q)` `Tensor`.  The first `rank(q)` dimensions
      index quantiles for different values of `q`.

  Raises:
    ValueError:  If argument 'interpolation' is not an allowed type.
    ValueError:  If `axis` is given but its value is not statically known.
  """
    name = name or 'percentile'
    allowed_interpolations = {'lower', 'higher', 'nearest'}

    if interpolation is None:
        interpolation = 'nearest'
        # NOTE(review): this check is nested under `if interpolation is None:`,
        # right after `interpolation` was set to 'nearest', so it can never
        # raise; an `else:` looks to be missing above it. As written, an
        # unsupported value skips validation and reaches the interpolation
        # dispatch below, which has no fallback branch, leaving `indices`
        # unbound.
        if interpolation not in allowed_interpolations:
            raise ValueError(
                'Argument `interpolation` must be in %s.  Found %s' %
                (allowed_interpolations, interpolation))

    with tf.name_scope(name, values=[x, q]):
        x = tf.convert_to_tensor(x, name='x')
        # Double is needed here and below, else we get the wrong index if the array
        # is huge along axis.
        q = tf.to_double(q, name='q')
        # Result is discarded; presumably called only to check statically that
        # `q` has rank <= 1 -- confirm against `_get_static_ndims`.
        _get_static_ndims(q, expect_ndims_no_more_than=1)

        if validate_args:
            # Runtime checks: `q` is a scalar or vector with values in [0, 100].
            q = control_flow_ops.with_dependencies([
                tf.assert_rank_in(q, [0, 1]),
                tf.assert_greater_equal(q, tf.to_double(0.)),
                tf.assert_less_equal(q, tf.to_double(100.))
            ], q)

        if axis is None:
            # Every dimension is a sample dimension: flatten `x` to 1-D.
            y = tf.reshape(x, [-1])
            # NOTE(review): an `else:` (with a dedent) appears to be missing here;
            # as written the explicit-axis handling below also runs when `axis`
            # is None. The `_get_static_ndims(axis,` and `_get_static_ndims(x,`
            # calls below are also missing their remaining arguments. The
            # explicit-axis path requires `axis` to be a constant so the sample
            # dims can be moved statically.
            axis = tf.convert_to_tensor(axis, name='axis')
            axis_ndims = _get_static_ndims(axis,
            axis_const = tensor_util.constant_value(axis)
            if axis_const is None:
                raise ValueError(
                    'Expected argument `axis` to be statically available.  Found: %s'
                    % axis)
            axis = axis_const
            if axis_ndims == 0:
                axis = [axis]
            axis = [int(a) for a in axis]
            x_ndims = _get_static_ndims(x,
            axis = _make_static_axis_non_negative(axis, x_ndims)
            # Move dims in axis to the end, since _sort_tensor, which calls top_k,
            # only sorts the last dim.
            y = _move_dims_to_flat_end(x, axis, x_ndims)

        # Fraction of samples at or above the q-th percentile; scaled by `d - 1`
        # it gives a (fractional) position in the sorted last dimension of `y`.
        frac_at_q_or_above = 1. - q / 100.
        d = tf.to_double(tf.shape(y)[-1])

        # Using ceil for 'lower' and floor for 'higher' presumes `_sort_tensor`
        # sorts in descending order (it is described above as calling top_k) --
        # confirm.
        if interpolation == 'lower':
            indices = tf.ceil((d - 1) * frac_at_q_or_above)
        elif interpolation == 'higher':
            indices = tf.floor((d - 1) * frac_at_q_or_above)
        elif interpolation == 'nearest':
            indices = tf.round((d - 1) * frac_at_q_or_above)

        # If d is gigantic, then we would have d == d - 1, even in double... So
        # let's use max/min to avoid out of bounds errors.
        d = tf.shape(y)[-1]
        # d - 1 will be distinct from d in int32.
        indices = tf.clip_by_value(tf.to_int32(indices), 0, d - 1)

        # Sort everything, not just the top 'k' entries, which allows multiple calls
        # to sort only once (under the hood) and use CSE.
        sorted_y = _sort_tensor(y)

        # Gather the indices along the sorted (last) dimension.
        # If q is a vector, the last dim of gathered_y indexes different q[i].
        gathered_y = tf.gather(sorted_y, indices, axis=-1)

        if keep_dims:
            if axis is None:
                # NOTE(review): the closing `])` of the `tf.ones(shape=[` call and
                # an `else:` before `_insert_back_keep_dims` appear to be missing.
                ones_vec = tf.ones(shape=[
                    _get_best_effort_ndims(x) + _get_best_effort_ndims(q)
                gathered_y *= tf.ones(ones_vec, dtype=x.dtype)
                gathered_y = _insert_back_keep_dims(gathered_y, axis)

        # If q is a scalar, then result has the right shape.
        # If q is a vector, then result has trailing dim of shape q.shape, which
        # needs to be rotated to dim 0.
        return util.rotate_transpose(gathered_y, tf.rank(q))
Beispiel #13
    # NOTE(review): this method looks damaged -- source lines are missing (the
    # first argument of `mcmc_util.make_name`, the remaining arguments of several
    # multi-line calls, an `else:` branch, the `try:` line, and the method's
    # tail). It cannot run as written; restore it from its upstream source
    # rather than patching it piecemeal.
    def bootstrap_results(self, init_state):
        """Returns an object with the same type as returned by `one_step`.

    Args:
      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        initial state(s) of the Markov chain(s).

    Returns:
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.

    Raises:
      ValueError: If states include replicas and the number of replicas implied
        by `init_state` statically differs from that implied by
        `inverse_temperatures`.
        """
        # NOTE(review): `make_name` is missing its first argument below, and the
        # `prepare_state_parts(` call is missing its arguments and closing paren.
        with tf.name_scope(
                mcmc_util.make_name(, 'remc', 'bootstrap_results')):
            init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(

            inverse_temperatures = tf.convert_to_tensor(
                self.inverse_temperatures, name='inverse_temperatures')

            if self._state_includes_replicas:
                # Static consistency check: leading dims of `inverse_temperatures`
                # and of the first state part must agree when both are known.
                it_n_replica = inverse_temperatures.shape[0]
                state_n_replica = init_state[0].shape[0]
                if ((it_n_replica is not None)
                        and (state_n_replica is not None)
                        and (it_n_replica != state_n_replica)):
                    # NOTE(review): the `.format` arguments look swapped -- the
                    # message names the initial state first, but `it_n_replica`
                    # (from `inverse_temperatures`) is passed first.
                    raise ValueError(
                        'Number of replicas implied by initial state ({}) must equal '
                        'number of replicas implied by inverse_temperatures ({}), but '
                        'did not'.format(it_n_replica, state_n_replica))

            # We will now replicate each of a possible batch of initial states, one for
            # each inverse_temperature. So if init_state=[x, y] of shapes [Sx, Sy]
            # then the new shape is [(T, Sx), (T, Sy)] where (a, b) means
            # concatenation and T=shape(inverse_temperature).
            num_replica = ps.size0(inverse_temperatures)
            replica_shape = tf.convert_to_tensor([num_replica])

            if self._state_includes_replicas:
                replica_states = init_state
                # NOTE(review): an `else:` appears to be missing here, as do the
                # tensor argument `x` of `tf.broadcast_to` and the list's closing
                # `]`. As written, the broadcast overwrites `init_state` when the
                # states already include replicas, and `replica_states` is never
                # assigned otherwise.
                replica_states = [
                    tf.broadcast_to(  # pylint: disable=g-complex-comprehension
                        ps.concat([replica_shape, ps.shape(x)], axis=0),
                        name='replica_states') for x in init_state

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                self.target_log_prob_fn, inverse_temperatures)
            # Seed handling complexity is due to users possibly expecting an old-style
            # stateful seed to be passed to `self.make_kernel_fn`.
            # In other words:
            # - We try `make_kernel_fn` without a seed first; this is the future. The
            #   kernel will receive a seed later, as part of `one_step`.
            # - If the user code doesn't like that (Python complains about a missing
            #   required argument), we fall back to the previous behavior and warn.
            # NOTE(review): the `try:` line, the arguments of the first
            # `make_kernel_fn(` call, the re-`raise` under the
            # `if 'argument' not in str(e):` check, and the call that emits the
            # deprecation message (presumably a warning) appear to be missing.
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
            except TypeError as e:
                if 'argument' not in str(e):
                    'The second (`seed`) argument to `ReplicaExchangeMC`s '
                    '`make_kernel_fn` is deprecated. `TransitionKernel` instances now '
                    'receive seeds via `bootstrap_results` and `one_step`. This '
                    'fallback may become an error 2020-09-20.')
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel, self._seed_stream())

            replica_results = inner_kernel.bootstrap_results(replica_states)

            pre_swap_replica_target_log_prob = _get_field(
                replica_results, 'target_log_prob')

            # Leading dim indexes replicas; the rest is the batch shape.
            # NOTE(review): the argument of `ps.shape(` is missing.
            replica_and_batch_shape = ps.shape(
            batch_shape = replica_and_batch_shape[1:]

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Pretend we did a "null swap", which will always be accepted.
            swaps = mcmc_util.left_justified_broadcast_to(
                tf.range(num_replica), replica_and_batch_shape)
            # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
            # NOTE(review): the remaining arguments of `rotate_transpose` (the
            # shift) and of the calls below are missing; the method is truncated.
            is_swap_accepted = distribution_util.rotate_transpose(tf.eye(
                num_replica, batch_shape=batch_shape, dtype=tf.bool),

            post_swap_replica_results = _make_post_swap_replica_results(
                lambda x: x,

            return ReplicaExchangeMCKernelResults(
                step_count=tf.zeros(shape=(), dtype=tf.int32),
Beispiel #14
    # NOTE(review): this method looks damaged -- source lines are missing (the
    # first argument of `mcmc_util.make_name`, the tensor/shape arguments of
    # `tf.broadcast_to`, the arguments of several multi-line calls, and the
    # method's tail). It cannot run as written; restore it from its upstream
    # source rather than patching it piecemeal.
    def bootstrap_results(self, init_state):
        """Returns an object with the same type as returned by `one_step`.

    Args:
      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        initial state(s) of the Markov chain(s).

    Returns:
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
        """
        # NOTE(review): `make_name` is missing its first argument below, and the
        # `prepare_state_parts(` call is missing its arguments and closing paren.
        with tf.name_scope(
                mcmc_util.make_name(, 'remc', 'bootstrap_results')):
            init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(

            inverse_temperatures = tf.convert_to_tensor(
                self.inverse_temperatures, name='inverse_temperatures')

            # We will now replicate each of a possible batch of initial states, one for
            # each inverse_temperature. So if init_state=[x, y] of shapes [Sx, Sy]
            # then the new shape is [(T, Sx), (T, Sy)] where (a, b) means
            # concatenation and T=shape(inverse_temperature).
            num_replica = prefer_static.size0(inverse_temperatures)
            replica_shape = tf.convert_to_tensor([num_replica])

            # NOTE(review): the line passing `x` to `tf.broadcast_to` and opening
            # the shape concatenation, plus the list's closing `]`, are missing.
            replica_states = [
                tf.broadcast_to(  # pylint: disable=g-complex-comprehension
                        [replica_shape, prefer_static.shape(x)], axis=0),
                    name='replica_states') for x in init_state

            # NOTE(review): the arguments of `make_kernel_fn(` are missing.
            inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
            replica_results = inner_kernel.bootstrap_results(replica_states)

            pre_swap_replica_target_log_prob = _get_field(
                replica_results, 'target_log_prob')

            # Leading dim indexes replicas; the rest is the batch shape.
            # NOTE(review): the argument of `prefer_static.shape(` is missing.
            replica_and_batch_shape = prefer_static.shape(
            batch_shape = replica_and_batch_shape[1:]

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Pretend we did a "null swap", which will always be accepted.
            swaps = mcmc_util.left_justified_broadcast_to(
                tf.range(num_replica), replica_and_batch_shape)
            # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
            # NOTE(review): the remaining arguments of `rotate_transpose` (the
            # shift) and of the calls below are missing; the method is truncated.
            is_swap_accepted = distribution_util.rotate_transpose(tf.eye(
                num_replica, batch_shape=batch_shape, dtype=tf.bool),

            post_swap_replica_results = _make_post_swap_replica_results(
                lambda x: x,

            return ReplicaExchangeMCKernelResults(
Beispiel #15
# NOTE(review): this block looks damaged -- source lines are missing (the rest
# of the signature after `x,`, the second argument of `tf.complex`, `else:`
# branches in the `max_lags` handling, and the call that applies
# `chopped_shape`). It cannot run as written; restore it from its upstream
# source (presumably TensorFlow Probability) instead of patching it piecemeal.
def auto_correlation(x,
  """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as  (with `E` expectation and `Conj` complex conjugate)

  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x:  `float32` or `complex64` `Tensor`.
    axis:  Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags:  Positive `int` tensor.  The maximum value of `m` to consider (in
      equation above).  If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center:  Python `bool`.  If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError:  If `x` is not a supported type.
  """
  # Implementation details:
  # Extend length N / 2 1-D array x to length N by zero padding onto the end.
  # Then, set
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # It is not hard to see that
  #   F[x]_k Conj(F[x]_k) = F[R]_k, where
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
  # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

  # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
  # based version of estimating RXX.
  # Note that this is a special case of the Wiener-Khinchin Theorem.
  with tf.name_scope(name, values=[x]):
    x = tf.convert_to_tensor(x, name='x')

    # Rotate dimensions of x in order to put axis at the rightmost dim.
    # FFT op requires this.
    rank = util.prefer_static_rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # Suppose x.shape[axis] = T, so there are T 'time' steps.
    #   ==> x_rotated.shape = B + [T],
    # where B is x_rotated's batch shape.
    x_rotated = util.rotate_transpose(x, shift)

    if center:
      x_rotated -= tf.reduce_mean(x_rotated, axis=-1, keepdims=True)

    # x_len = N / 2 from above explanation.  The length of x along axis.
    # Get a value for x_len that works in all cases.
    x_len = util.prefer_static_shape(x_rotated)[-1]

    # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
    # the moment is necessary so that all FFT implementations work.
    # Zero pad to the next power of 2 greater than 2 * x_len, which equals
    # 2**(ceil(Log_2(2 * x_len))).  Note: Log_2(X) = Log_e(X) / Log_e(2).
    x_len_float64 = tf.cast(x_len, np.float64)
    target_length = tf.pow(
        np.float64(2.), tf.ceil(tf.log(x_len_float64 * 2) / np.log(2.)))
    pad_length = tf.cast(target_length - x_len_float64, np.int32)

    # We should have:
    # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
    #                     = B + [T + pad_length]
    x_rotated_pad = util.pad(x_rotated, axis=-1, back=True, count=pad_length)

    dtype = x.dtype
    if not dtype.is_complex:
      if not dtype.is_floating:
        raise TypeError('Argument x must have either float or complex dtype'
                        ' found: {}'.format(dtype))
      # Real input is promoted to complex for the FFT.
      # NOTE(review): the imaginary-part argument of `tf.complex` (presumably
      # zeros shaped like `x_rotated_pad`) and the closing paren are missing.
      x_rotated_pad = tf.complex(x_rotated_pad,

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    fft_x_rotated_pad = tf.fft(x_rotated_pad)
    spectral_density = fft_x_rotated_pad * tf.conj(fft_x_rotated_pad)
    # shifted_product is R[m] from above detailed explanation.
    # It is the inner product sum_n X[n] * Conj(X[n - m]).
    shifted_product = tf.ifft(spectral_density)

    # Cast back to real-valued if x was real to begin with.
    shifted_product = tf.cast(shifted_product, dtype)

    # Figure out if we can deduce the final static shape, and set max_lags.
    # Use x_rotated as a reference, because it has the time dimension in the far
    # right, and was created before we performed all sorts of crazy shape
    # manipulations.
    know_static_shape = True
    if not x_rotated.shape.is_fully_defined():
      know_static_shape = False
    # NOTE(review): `else:` lines appear to be missing after
    # `max_lags = x_len - 1` and after the `tf.minimum` line. As written, a
    # user-supplied `max_lags` skips clamping to `x_len - 1`, and
    # `min(x_len - 1, max_lags_)` runs even when `max_lags_` is None.
    if max_lags is None:
      max_lags = x_len - 1
      max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
      max_lags_ = tf.contrib.util.constant_value(max_lags)
      if max_lags_ is None or not know_static_shape:
        know_static_shape = False
        max_lags = tf.minimum(x_len - 1, max_lags)
        max_lags = min(x_len - 1, max_lags_)

    # Chop off the padding.
    # We allow users to provide a huge max_lags, but cut it off here.
    # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    # If possible, set shape.
    # NOTE(review): `chopped_shape` is computed but never applied here; a
    # `set_shape` call on `shifted_product_chopped` appears to be missing.
    if know_static_shape:
      chopped_shape = x_rotated.shape.as_list()
      chopped_shape[-1] = min(x_len, max_lags + 1)

    # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
    # other terms were zeros arising only due to zero padding.
    # `denominator = (N / 2 - m)` (defined below) is the proper term to
    # divide by to make this an unbiased estimate of the expectation
    # E[X[n] Conj(X[n - m])].
    x_len = tf.cast(x_len, dtype.real_dtype)
    max_lags = tf.cast(max_lags, dtype.real_dtype)
    denominator = x_len - tf.range(0., max_lags + 1.)
    denominator = tf.cast(denominator, dtype)
    shifted_product_rotated = shifted_product_chopped / denominator

    # Normalize by the lag-0 value (the variance estimate) so rxx[0] == 1.
    if normalize:
      shifted_product_rotated /= shifted_product_rotated[..., :1]

    # Transpose dimensions back to those of x.
    return util.rotate_transpose(shifted_product_rotated, -shift)
    def _get_exchanged_states(self, old_states, exchange_proposed,
                              exchange_proposed_n, sampled_replica_states,
        """Get list of TensorArrays holding exchanged states, and zeros."""
        with tf1.name_scope('get_exchanged_states'):

            target_log_probs = []
            for replica in range(self.num_replica):
                replica_log_prob = _get_field(sampled_replica_results[replica],
                inverse_temp = self.inverse_temperatures[replica]
                target_log_probs.append(replica_log_prob / inverse_temp)
            target_log_probs = tf.stack(target_log_probs, axis=0)

            dtype = target_log_probs.dtype
            num_state_parts = len(sampled_replica_states[0])
            # exchanged_states[k][i] is Tensor of (new) state part k, for replica i.
            # The `k` will be known statically, and `i` is a Tensor.
            # We will insert values into indices `i` for every replica with a proposed
            # exchange.
            exchanged_states = [
                    # State part k has same shape, regardless of replica.  So use 0.
                for k in range(num_state_parts)

            # Two TensorArrays, for KernelResults only.
            if self._exchange_between_adjacent_only:
                # Since exchanges are between adjacent only, we track exchanges by the
                # index of the edge between replicas.  E.g., if we have replicas
                # [0, 1, 2, 3], then edge index 0 is for exchanges between replicas 0
                # and 1.
                is_exchange_proposed_for_kr = tf.TensorArray(
                    dtype=tf.bool,  # Initialized to False
                    size=self.num_replica - 1,
                is_exchange_accepted_for_kr = tf.TensorArray(
                    dtype=tf.bool,  # Initialized to False
                    size=self.num_replica - 1,
                is_exchange_proposed_for_kr = tf.convert_to_tensor(np.nan)
                is_exchange_accepted_for_kr = tf.convert_to_tensor(np.nan)

            # Draw random variables here, to avoid sampling in the loop (and losing
            # reproducibility).  This may mean we sample too many, but we will always
            # have enough.
            sample_shape = tf.concat(
                ([self.num_replica // 2
                  ], tf.shape(input=target_log_probs)[1:]),
            log_uniforms = tf.math.log(

            def _swap(is_exchange_accepted, x, y):
                """Swap batches of x, y where accepted."""
                with tf1.name_scope('swap_where_exchange_accepted'):
                    new_x = mcmc_util.choose(is_exchange_accepted, y, x)
                    new_y = mcmc_util.choose(is_exchange_accepted, x, y)
                return new_x, new_y

            def cond(i, unused_exchanged_states, unused_is_exchanged_for_kr,
                return i < exchange_proposed_n

            def body(i, exchanged_states, is_exchange_proposed_for_kr,
                """Body of while loop for exchanging states."""
                # Propose exchange between replicas indexed by m and n.
                m, n = tf.unstack(exchange_proposed[i])

                # Construct log_accept_ratio:  -temp_diff * target_log_prob_diff.
                # Note target_log_prob_diff = -EnergyDiff (common definition is in terms
                # of energy).
                temp_diff = self.inverse_temperatures[
                    m] - self.inverse_temperatures[n]
                # Difference of target log probs may be +- Inf or NaN.  We want the
                # product of this with the temperature difference to have "alt value" of
                # -Inf.
                log_accept_ratio = mcmc_util.safe_sum([
                    -temp_diff * target_log_probs[m],
                    temp_diff * target_log_probs[n]

                is_exchange_accepted = log_uniforms[i] < log_accept_ratio

                if self._exchange_between_adjacent_only:
                    exchange_edge = tf.minimum(m, n)
                    is_exchange_proposed_for_kr = is_exchange_proposed_for_kr.write(
                        exchange_edge, True)
                    is_exchange_accepted_for_kr = is_exchange_accepted_for_kr.write(
                        exchange_edge, is_exchange_accepted)

                for k in range(num_state_parts):
                    new_m, new_n = _swap(is_exchange_accepted,
                    exchanged_states[k] = exchanged_states[k].write(m, new_m)
                    exchanged_states[k] = exchanged_states[k].write(n, new_n)

                return (i + 1, exchanged_states, is_exchange_proposed_for_kr,

            # At this point, exchanged_states[k] is a length num_replicas TensorArray.
            (exchanged_states, is_exchange_proposed_for_kr,
             is_exchange_accepted_for_kr) = tf.while_loop(
                 ])[1:]  # Remove `i`
            if self._exchange_between_adjacent_only:
                # Stack to give shape [self.num_replica]
                is_exchange_proposed_for_kr = is_exchange_proposed_for_kr.stack(
                is_exchange_proposed_for_kr.set_shape([self.num_replica - 1])

                # Stack on axis=-1 to give shape batch_shape + [self.num_replica]
                # ...TensorArray.stack stacks on axis=0, and doesn't take an axis kwarg,
                # so must rotate_transpose.
                is_exchange_accepted_for_kr = distribution_util.rotate_transpose(
                    is_exchange_accepted_for_kr.stack(), shift=-1)
                    target_log_probs[-1].shape.concatenate(self.num_replica -

            return (exchanged_states, is_exchange_proposed_for_kr,