Пример #1
0
 def testRollStatic(self):
     """Checks `rotate_transpose` when the shift is a Python/NumPy integer.

     For every input array and every shift in [-5, 5), the evaluated result
     must equal `self._np_rotate_transpose(x, shift)`, and the static output
     shape must equal the input shape rolled by the same shift. A `None`
     input must raise `ValueError`.
     """
     input_shapes = ((1,), (2, 1), (3, 2, 1))
     with self.test_session():
         with self.assertRaisesRegexp(ValueError,
                                      "None values not supported."):
             distribution_util.rotate_transpose(None, 1)
         for arr in (np.ones(dims) for dims in input_shapes):
             for k in np.arange(-5, 5):
                 rotated = distribution_util.rotate_transpose(arr, k)
                 want_values = self._np_rotate_transpose(arr, k)
                 self.assertAllEqual(want_values, rotated.eval())
                 # The static shape must already be known without running.
                 want_shape = np.roll(arr.shape, k)
                 got_shape = rotated.get_shape().as_list()
                 self.assertAllEqual(want_shape, got_shape)
Пример #2
0
    def _vectorize_then_blockify(self, matrix):
        """Shape a batch matrix into a batch vector, then blockify its last dim.

        Worked example, with `self.block_depth = 2` and
        `self.block_shape = [b0, b1]` (note `b0 * b1 == m2`):

          matrix:                            [m0, m1, m2, m3]
          vectorize (rotate right by one):   [m3, m0, m1, m2]
          blockify (split the last dim):     [m3, m0, m1, b0, b1]

        "Vectorize" refers to taking the final matrix dim `m3` and turning the
        result into a size-`m3` batch of vectors.

        Args:
          matrix: `Tensor` whose final two dimensions are matrix dimensions.

        Returns:
          `Tensor` reshaped as in the example above.
        """
        # Move the final dimension to the front.
        vec = distribution_util.rotate_transpose(matrix, shift=1)

        static_ok = (vec.shape.is_fully_defined()
                     and self.block_shape.is_fully_defined())
        if not static_ok:
            # Some shape is only known at run time: build the target with ops.
            leading = array_ops.shape(vec)[:-1]
            target_shape = array_ops.concat(
                (leading, self.block_shape_tensor()), 0)
            return array_ops.reshape(vec, target_shape)

        # Everything is static. The leading dims [m3, m0, m1] are kept as-is
        # and the trailing dim is replaced by the block shape.
        target_shape = vec.shape[:-1].concatenate(self.block_shape)
        return array_ops.reshape(vec, target_shape)
Пример #3
0
  def _unblockify_then_matricize(self, vec):
    """Collapse the trailing block dims, then rotate into a batch matrix.

    Worked example, with `self.block_depth = 2` and
    `vec.shape = [v0, v1, v2, v3]` (leading shape `[v0, v1]`, block shape
    `[v2, v3]`):

      un-blockify (flatten block dims):  [v0, v1, v2*v3]
      matricize (rotate left by one):    [v1, v2*v3, v0]

    The result is a `[v1]`-shaped batch of `[v2*v3, v0]` matrices.

    Args:
      vec: `Tensor` whose last `self.block_depth` dims are block dims.

    Returns:
      `Tensor` reshaped/transposed as in the example above.
    """
    depth = self.block_depth
    static_shape = vec.get_shape()
    if not static_shape.is_fully_defined():
      # Dynamic path: compute the flattened shape with graph ops.
      dyn_shape = array_ops.shape(vec)
      leading = dyn_shape[:-depth]
      block = dyn_shape[-depth:]
      flat_shape = array_ops.concat(
          (leading, [math_ops.reduce_prod(block)]), 0)
    else:
      # Static path: compute the flattened shape in Python.
      dims = static_shape.as_list()
      flat_shape = dims[:-depth] + [np.prod(dims[-depth:])]

    flat = array_ops.reshape(vec, flat_shape)
    # Move the leading dim to the end: [v0, v1, v2*v3] -> [v1, v2*v3, v0].
    return distribution_util.rotate_transpose(flat, shift=-1)
Пример #4
0
    def undo_make_batch_of_event_sample_matrices(
            self,
            x,
            sample_shape,
            expand_batch_dim=True,
            name="undo_make_batch_of_event_sample_matrices"):
        """Reshapes/transposes `Distribution` `Tensor` from B_+E_+S_ to S+B+E.

        Where:
          - `B_ = B if B or not expand_batch_dim else [1]`,
          - `E_ = E if E else [1]`,
          - `S_ = [tf.reduce_prod(S)]`.

        This function "reverses" `make_batch_of_event_sample_matrices`.

        Args:
          x: `Tensor` of shape `B_+E_+S_`.
          sample_shape: `Tensor` (1D, `int32`).
          expand_batch_dim: Python `bool`. If `True` the batch dims will be
            expanded such that `batch_ndims>=1`.
          name: Python `str`. The name to give this op.

        Returns:
          x: `Tensor`. Input transposed/reshaped to `S+B+E`.
        """
        with self._name_scope(name, values=[x, sample_shape]):
            x = ops.convert_to_tensor(x, name="x")
            # x.shape: B_+E_+[prod(S)]
            sample_shape = ops.convert_to_tensor(sample_shape,
                                                 name="sample_shape")
            x = distribution_util.rotate_transpose(x, shift=1)
            # x.shape: [prod(S)]+B_+E_
            if self._is_all_constant_helper(self.batch_ndims,
                                            self.event_ndims):
                if self._batch_ndims_is_0 or self._event_ndims_is_0:
                    squeeze_dims = []
                    if self._event_ndims_is_0:
                        squeeze_dims += [-1]
                    if self._batch_ndims_is_0 and expand_batch_dim:
                        squeeze_dims += [1]
                    if squeeze_dims:
                        # `axis` supersedes the deprecated `squeeze_dims` kwarg.
                        x = array_ops.squeeze(x, axis=squeeze_dims)
                        # x.shape: [prod(S)]+B+E
                _, batch_shape, event_shape = self.get_shape(x)
            else:
                # `event_start` below is always a Tensor (output of
                # `array_ops.where`), and on this branch batch_ndims /
                # event_ndims are presumably not all static. A Python list
                # (from `x.get_shape().as_list()`) cannot be sliced by a
                # graph-mode Tensor, so always slice the dynamic shape Tensor.
                s = array_ops.shape(x)
                batch_shape = s[1:1 + self.batch_ndims]
                # Since sample_dims=1 and is left-most, we add 1 to the number
                # of batch_ndims to get the event start dim.
                event_start = array_ops.where(
                    math_ops.logical_and(expand_batch_dim,
                                         self._batch_ndims_is_0), 2,
                    1 + self.batch_ndims)
                event_shape = s[event_start:event_start + self.event_ndims]
            new_shape = array_ops.concat(
                [sample_shape, batch_shape, event_shape], 0)
            x = array_ops.reshape(x, shape=new_shape)
            # x.shape: S+B+E
            return x
  def _unblockify_then_matricize(self, vec):
    """Flatten the block dimensions, then reshape to a batch matrix.

    With `self.block_depth = 2` and `vec.shape = [v0, v1, v2, v3]`, the
    leading shape is `[v0, v1]` and the block shape is `[v2, v3]`:

      [v0, v1, v2, v3]  --un-blockify-->  [v0, v1, v2*v3]
                        --matricize---->  [v1, v2*v3, v0]

    i.e. a `[v1]`-shaped batch of `[v2*v3, v0]` matrices.

    Args:
      vec: `Tensor` whose last `self.block_depth` dims are block dims.

    Returns:
      `Tensor` reshaped/transposed as described above.
    """
    depth = self.block_depth
    if vec.get_shape().is_fully_defined():
      # Shape known statically: do the arithmetic in Python.
      dims = vec.get_shape().as_list()
      head, block = dims[:-depth], dims[-depth:]
      flat_shape = head + [np.prod(block)]
    else:
      # Shape only known at run time: do the arithmetic with ops.
      dims = array_ops.shape(vec)
      head, block = dims[:-depth], dims[-depth:]
      flat_shape = array_ops.concat(
          (head, [math_ops.reduce_prod(block)]), 0)
    flat = array_ops.reshape(vec, flat_shape)
    # Rotate the leading dim to the back: [v0, v1, v2*v3] -> [v1, v2*v3, v0].
    matrix = distribution_util.rotate_transpose(flat, shift=-1)
    return matrix
  def _vectorize_then_blockify(self, matrix):
    """Reshape a batch matrix to a batch vector, then blockify its last dim.

    With `self.block_depth = 2` and `self.block_shape = [b0, b1]`
    (where `b0 * b1 == m2`):

      [m0, m1, m2, m3]  --vectorize-->  [m3, m0, m1, m2]
                        --blockify--->  [m3, m0, m1, b0, b1]

    "Vectorize": the final matrix dim `m3` becomes a leading dim, giving a
    size-`m3` batch of vectors.

    Args:
      matrix: `Tensor` whose final two dimensions are matrix dimensions.

    Returns:
      `Tensor` reshaped as described above.
    """
    vec = distribution_util.rotate_transpose(matrix, shift=1)

    # Blockify: replace the last dim of `vec` by `self.block_shape`.
    static = (vec.get_shape().is_fully_defined() and
              self.block_shape.is_fully_defined())
    if static:
      # The leading dims [m3, m0, m1] are not blockified.
      final_shape = vec.get_shape()[:-1].concatenate(self.block_shape)
    else:
      final_shape = array_ops.concat(
          (array_ops.shape(vec)[:-1], self.block_shape_tensor()), 0)
    return array_ops.reshape(vec, final_shape)
Пример #7
0
  def make_batch_of_event_sample_matrices(
      self, x, expand_batch_dim=True,
      name="make_batch_of_event_sample_matrices"):
    """Reshapes/transposes a `Distribution` `Tensor` from S+B+E to B_+E_+S_.

    Notation:
      - `B_ = B if B or not expand_batch_dim else [1]`,
      - `E_ = E if E else [1]`,
      - `S_ = [tf.reduce_prod(S)]`.

    Args:
      x: `Tensor`.
      expand_batch_dim: Python `bool`. If `True` the batch dims will be expanded
        such that `batch_ndims >= 1`.
      name: Python `str`. The name to give this op.

    Returns:
      x: `Tensor`. Input transposed/reshaped to `B_+E_+S_`.
      sample_shape: `Tensor` (1D, `int32`).
    """
    with self._name_scope(name, values=[x]):
      x = tf.convert_to_tensor(x, name="x")
      # Split the S+B+E shape into its three parts.
      sample_shape, batch_shape, event_shape = self.get_shape(x)
      # When event_ndims is 0, substitute [1] (E_ = [1]).
      event_part = distribution_util.pick_vector(
          self._event_ndims_is_0, [1], event_shape)
      batch_part = batch_shape
      if expand_batch_dim:
        # When batch_ndims is 0, substitute [1] (B_ = [1]).
        batch_part = distribution_util.pick_vector(
            self._batch_ndims_is_0, [1], batch_shape)
      # Collapse all sample dims into one leading dim: [prod(S)]+B_+E_.
      flat_shape = tf.concat([[-1], batch_part, event_part], 0)
      flat = tf.reshape(x, shape=flat_shape)
      # Rotate the collapsed sample dim to the end: B_+E_+[prod(S)].
      return distribution_util.rotate_transpose(flat, shift=-1), sample_shape
Пример #8
0
  def make_batch_of_event_sample_matrices(
      self, x, expand_batch_dim=True,
      name="make_batch_of_event_sample_matrices"):
    """Reshapes/transposes a `Distribution` `Tensor` from S+B+E to B_+E_+S_.

    Notation:
      - `B_ = B if B or not expand_batch_dim else [1]`,
      - `E_ = E if E else [1]`,
      - `S_ = [tf.reduce_prod(S)]`.

    Args:
      x: `Tensor`.
      expand_batch_dim: Python `bool`. If `True` the batch dims will be expanded
        such that `batch_ndims >= 1`.
      name: Python `str`. The name to give this op.

    Returns:
      x: `Tensor`. Input transposed/reshaped to `B_+E_+S_`.
      sample_shape: `Tensor` (1D, `int32`).
    """
    with self._name_scope(name, values=[x]):
      x = ops.convert_to_tensor(x, name="x")
      # Split the S+B+E shape into its three parts.
      sample_shape, batch_shape, event_shape = self.get_shape(x)
      # When event_ndims is 0, substitute [1] (E_ = [1]).
      event_part = distribution_util.pick_vector(
          self._event_ndims_is_0, [1], event_shape)
      batch_part = batch_shape
      if expand_batch_dim:
        # When batch_ndims is 0, substitute [1] (B_ = [1]).
        batch_part = distribution_util.pick_vector(
            self._batch_ndims_is_0, [1], batch_shape)
      # Collapse all sample dims into one leading dim: [prod(S)]+B_+E_.
      flat_shape = array_ops.concat([[-1], batch_part, event_part], 0)
      flat = array_ops.reshape(x, shape=flat_shape)
      # Rotate the collapsed sample dim to the end: B_+E_+[prod(S)].
      return distribution_util.rotate_transpose(flat, shift=-1), sample_shape
Пример #9
0
  def undo_make_batch_of_event_sample_matrices(
      self, x, sample_shape, expand_batch_dim=True,
      name="undo_make_batch_of_event_sample_matrices"):
    """Reshapes/transposes `Distribution` `Tensor` from B_+E_+S_ to S+B+E.

    Where:
      - `B_ = B if B or not expand_batch_dim else [1]`,
      - `E_ = E if E else [1]`,
      - `S_ = [tf.reduce_prod(S)]`.

    This function "reverses" `make_batch_of_event_sample_matrices`.

    Args:
      x: `Tensor` of shape `B_+E_+S_`.
      sample_shape: `Tensor` (1D, `int32`).
      expand_batch_dim: Python `bool`. If `True` the batch dims will be expanded
        such that `batch_ndims>=1`.
      name: Python `str`. The name to give this op.

    Returns:
      x: `Tensor`. Input transposed/reshaped to `S+B+E`.
    """
    with self._name_scope(name, values=[x, sample_shape]):
      x = tf.convert_to_tensor(x, name="x")
      # x.shape: B_+E_+[prod(S)]
      sample_shape = tf.convert_to_tensor(sample_shape, name="sample_shape")
      x = distribution_util.rotate_transpose(x, shift=1)
      # x.shape: [prod(S)]+B_+E_
      if self._is_all_constant_helper(self.batch_ndims, self.event_ndims):
        if self._batch_ndims_is_0 or self._event_ndims_is_0:
          squeeze_dims = []
          if self._event_ndims_is_0:
            squeeze_dims += [-1]
          if self._batch_ndims_is_0 and expand_batch_dim:
            squeeze_dims += [1]
          if squeeze_dims:
            x = tf.squeeze(x, axis=squeeze_dims)
            # x.shape: [prod(S)]+B+E
        _, batch_shape, event_shape = self.get_shape(x)
      else:
        # `event_start` below is always a Tensor (output of `tf.where`), and on
        # this branch batch_ndims / event_ndims are presumably not all static.
        # A Python list (from `x.get_shape().as_list()`) cannot be sliced by a
        # graph-mode Tensor, so always slice the dynamic shape Tensor.
        s = tf.shape(x)
        batch_shape = s[1:1+self.batch_ndims]
        # Since sample_dims=1 and is left-most, we add 1 to the number of
        # batch_ndims to get the event start dim.
        event_start = tf.where(
            tf.logical_and(expand_batch_dim, self._batch_ndims_is_0), 2,
            1 + self.batch_ndims)
        event_shape = s[event_start:event_start+self.event_ndims]
      new_shape = tf.concat([sample_shape, batch_shape, event_shape], 0)
      x = tf.reshape(x, shape=new_shape)
      # x.shape: S+B+E
      return x
Пример #10
0
 def testRollDynamic(self):
     """Checks `rotate_transpose` when both input and shift are placeholders.

     For every fed array and every shift in [-5, 5), the run-time result must
     equal `self._np_rotate_transpose(x_value, shift_value)`.
     """
     with self.test_session() as sess:
         x = array_ops.placeholder(dtypes.float32)
         shift = array_ops.placeholder(dtypes.int32)
         # `as_numpy_dtype` is already the NumPy type; calling it would build a
         # scalar value (0.0) rather than a dtype.
         np_dtype = x.dtype.as_numpy_dtype
         # Both inputs are placeholders, so one op serves every fed case.
         # Build it once instead of adding a new op to the graph per iteration.
         rotated = distribution_util.rotate_transpose(x, shift)
         for x_value in (np.ones(1, dtype=np_dtype),
                         np.ones((2, 1), dtype=np_dtype),
                         np.ones((3, 2, 1), dtype=np_dtype)):
             for shift_value in np.arange(-5, 5):
                 self.assertAllEqual(
                     self._np_rotate_transpose(x_value, shift_value),
                     sess.run(rotated,
                              feed_dict={
                                  x: x_value,
                                  shift: shift_value
                              }))
Пример #11
0
def auto_correlation(
    x,
    axis=-1,
    max_lags=None,
    center=True,
    normalize=True,
    name="auto_correlation"):
  """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x:  `float32` or `complex64` `Tensor`.  Any floating or complex dtype is
      accepted; other dtypes raise `TypeError`.
    axis:  Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags:  Positive `int` tensor.  The maximum value of `m` to consider
      (in equation above).  If `max_lags >= x.shape[axis]`, we effectively
      re-set `max_lags` to `x.shape[axis] - 1`.  Defaults to
      `x.shape[axis] - 1` when `None`.
    center:  Python `bool`.  If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `True`, divide the result by its lag-0 value
      `rxx[0]` (the variance estimate `s**2` when `center=True`), so that
      `rxx[0] == 1`.  If `False`, no such division is done.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1` (after `max_lags` has
      been clipped to at most `x.shape[axis] - 1`).

  Raises:
    TypeError:  If `x` is not a supported type.
  """
  # Implementation details:
  # Extend length N / 2 1-D array x to length N by zero padding onto the end.
  # Then, set
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # It is not hard to see that
  #   F[x]_k Conj(F[x]_k) = F[R]_k, where
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
  # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

  # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
  # based version of estimating RXX.
  # Note that this is a special case of the Wiener-Khinchin Theorem.
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")

    # Rotate dimensions of x in order to put axis at the rightmost dim.
    # FFT op requires this.
    # NOTE(review): `util.prefer_static_rank` presumably yields a static int
    # when the rank is known and a `Tensor` otherwise, so `shift` may be
    # either -- confirm against `util`.
    rank = util.prefer_static_rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # Suppose x.shape[axis] = T, so there are T "time" steps.
    #   ==> x_rotated.shape = B + [T],
    # where B is x_rotated's batch shape.
    x_rotated = util.rotate_transpose(x, shift)

    if center:
      # Subtract the per-sequence sample mean `mu` along the time axis.
      x_rotated -= math_ops.reduce_mean(x_rotated, axis=-1, keepdims=True)

    # x_len = N / 2 from above explanation.  The length of x along axis.
    # Get a value for x_len that works in all cases.
    x_len = util.prefer_static_shape(x_rotated)[-1]

    # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
    # the moment it is necessary so that all FFT implementations work.
    # Zero pad to the next power of 2 greater than or equal to 2 * x_len,
    # which equals 2**(ceil(Log_2(2 * x_len))).
    # Note: Log_2(X) = Log_e(X) / Log_e(2).
    x_len_float64 = math_ops.cast(x_len, np.float64)
    target_length = math_ops.pow(
        np.float64(2.),
        math_ops.ceil(math_ops.log(x_len_float64 * 2) / np.log(2.)))
    pad_length = math_ops.cast(target_length - x_len_float64, np.int32)

    # We should have:
    # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
    #                     = B + [T + pad_length]
    x_rotated_pad = util.pad(x_rotated, axis=-1, back=True, count=pad_length)

    dtype = x.dtype
    if not dtype.is_complex:
      if not dtype.is_floating:
        raise TypeError("Argument x must have either float or complex dtype"
                        " found: {}".format(dtype))
      # The FFT op takes complex input: attach a zero imaginary part.
      x_rotated_pad = math_ops.complex(x_rotated_pad,
                                       dtype.real_dtype.as_numpy_dtype(0.))

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    fft_x_rotated_pad = spectral_ops.fft(x_rotated_pad)
    spectral_density = fft_x_rotated_pad * math_ops.conj(fft_x_rotated_pad)
    # shifted_product is R[m] from above detailed explanation.
    # It is the inner product sum_n X[n] * Conj(X[n - m]).
    shifted_product = spectral_ops.ifft(spectral_density)

    # Cast back to real-valued if x was real to begin with.
    shifted_product = math_ops.cast(shifted_product, dtype)

    # Figure out if we can deduce the final static shape, and set max_lags.
    # Use x_rotated as a reference, because it has the time dimension in the far
    # right, and was created before we performed all sorts of crazy shape
    # manipulations.
    know_static_shape = True
    if not x_rotated.shape.is_fully_defined():
      know_static_shape = False
    if max_lags is None:
      max_lags = x_len - 1
    else:
      max_lags = ops.convert_to_tensor(max_lags, name="max_lags")
      max_lags_ = tensor_util.constant_value(max_lags)
      if max_lags_ is None or not know_static_shape:
        know_static_shape = False
        # Clip dynamically; max_lags remains a Tensor.
        max_lags = math_ops.minimum(x_len - 1, max_lags)
      else:
        # Both max_lags and the shape of x_rotated are static: clip eagerly.
        max_lags = min(x_len - 1, max_lags_)

    # Chop off the padding.
    # We allow users to provide a huge max_lags, but cut it off here.
    # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags + 1]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    # If possible, set shape.
    if know_static_shape:
      chopped_shape = x_rotated.shape.as_list()
      chopped_shape[-1] = min(x_len, max_lags + 1)
      shifted_product_chopped.set_shape(chopped_shape)

    # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
    # other terms were zeros arising only due to zero padding.
    # `denominator = (N / 2 - m)` (defined below) is the proper term to
    # divide by to make this an unbiased estimate of the expectation
    # E[X[n] Conj(X[n - m])].
    x_len = math_ops.cast(x_len, dtype.real_dtype)
    max_lags = math_ops.cast(max_lags, dtype.real_dtype)
    denominator = x_len - math_ops.range(0., max_lags + 1.)
    denominator = math_ops.cast(denominator, dtype)
    shifted_product_rotated = shifted_product_chopped / denominator

    if normalize:
      # Divide every lag by the lag-0 value so that rxx[0] == 1.
      shifted_product_rotated /= shifted_product_rotated[..., :1]

    # Transpose dimensions back to those of x.
    return util.rotate_transpose(shifted_product_rotated, -shift)
Пример #12
0
def auto_correlation(
    x,
    axis=-1,
    max_lags=None,
    center=True,
    normalize=True,
    name="auto_correlation"):
  """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x:  `float32` or `complex64` `Tensor`.  Any floating or complex dtype is
      accepted; other dtypes raise `TypeError`.
    axis:  Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags:  Positive `int` tensor.  The maximum value of `m` to consider
      (in equation above).  If `max_lags >= x.shape[axis]`, we effectively
      re-set `max_lags` to `x.shape[axis] - 1`.  Defaults to
      `x.shape[axis] - 1` when `None`.
    center:  Python `bool`.  If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `True`, divide the result by its lag-0 value
      `rxx[0]` (the variance estimate `s**2` when `center=True`), so that
      `rxx[0] == 1`.  If `False`, no such division is done.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1` (after `max_lags` has
      been clipped to at most `x.shape[axis] - 1`).

  Raises:
    TypeError:  If `x` is not a supported type.
  """
  # Implementation details:
  # Extend length N / 2 1-D array x to length N by zero padding onto the end.
  # Then, set
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # It is not hard to see that
  #   F[x]_k Conj(F[x]_k) = F[R]_k, where
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
  # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

  # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
  # based version of estimating RXX.
  # Note that this is a special case of the Wiener-Khinchin Theorem.
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")

    # Rotate dimensions of x in order to put axis at the rightmost dim.
    # FFT op requires this.
    # NOTE(review): `util.prefer_static_rank` presumably yields a static int
    # when the rank is known and a `Tensor` otherwise, so `shift` may be
    # either -- confirm against `util`.
    rank = util.prefer_static_rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # Suppose x.shape[axis] = T, so there are T "time" steps.
    #   ==> x_rotated.shape = B + [T],
    # where B is x_rotated's batch shape.
    x_rotated = util.rotate_transpose(x, shift)

    if center:
      # Subtract the per-sequence sample mean `mu` along the time axis.
      x_rotated -= math_ops.reduce_mean(x_rotated, axis=-1, keepdims=True)

    # x_len = N / 2 from above explanation.  The length of x along axis.
    # Get a value for x_len that works in all cases.
    x_len = util.prefer_static_shape(x_rotated)[-1]

    # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
    # the moment it is necessary so that all FFT implementations work.
    # Zero pad to the next power of 2 greater than or equal to 2 * x_len,
    # which equals 2**(ceil(Log_2(2 * x_len))).
    # Note: Log_2(X) = Log_e(X) / Log_e(2).
    x_len_float64 = math_ops.cast(x_len, np.float64)
    target_length = math_ops.pow(
        np.float64(2.),
        math_ops.ceil(math_ops.log(x_len_float64 * 2) / np.log(2.)))
    pad_length = math_ops.cast(target_length - x_len_float64, np.int32)

    # We should have:
    # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
    #                     = B + [T + pad_length]
    x_rotated_pad = util.pad(x_rotated, axis=-1, back=True, count=pad_length)

    dtype = x.dtype
    if not dtype.is_complex:
      if not dtype.is_floating:
        raise TypeError("Argument x must have either float or complex dtype"
                        " found: {}".format(dtype))
      # The FFT op takes complex input: attach a zero imaginary part.
      x_rotated_pad = math_ops.complex(x_rotated_pad,
                                       dtype.real_dtype.as_numpy_dtype(0.))

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    fft_x_rotated_pad = spectral_ops.fft(x_rotated_pad)
    spectral_density = fft_x_rotated_pad * math_ops.conj(fft_x_rotated_pad)
    # shifted_product is R[m] from above detailed explanation.
    # It is the inner product sum_n X[n] * Conj(X[n - m]).
    shifted_product = spectral_ops.ifft(spectral_density)

    # Cast back to real-valued if x was real to begin with.
    shifted_product = math_ops.cast(shifted_product, dtype)

    # Figure out if we can deduce the final static shape, and set max_lags.
    # Use x_rotated as a reference, because it has the time dimension in the far
    # right, and was created before we performed all sorts of crazy shape
    # manipulations.
    know_static_shape = True
    if not x_rotated.shape.is_fully_defined():
      know_static_shape = False
    if max_lags is None:
      max_lags = x_len - 1
    else:
      max_lags = ops.convert_to_tensor(max_lags, name="max_lags")
      max_lags_ = tensor_util.constant_value(max_lags)
      if max_lags_ is None or not know_static_shape:
        know_static_shape = False
        # Clip dynamically; max_lags remains a Tensor.
        max_lags = math_ops.minimum(x_len - 1, max_lags)
      else:
        # Both max_lags and the shape of x_rotated are static: clip eagerly.
        max_lags = min(x_len - 1, max_lags_)

    # Chop off the padding.
    # We allow users to provide a huge max_lags, but cut it off here.
    # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags + 1]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    # If possible, set shape.
    if know_static_shape:
      chopped_shape = x_rotated.shape.as_list()
      chopped_shape[-1] = min(x_len, max_lags + 1)
      shifted_product_chopped.set_shape(chopped_shape)

    # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
    # other terms were zeros arising only due to zero padding.
    # `denominator = (N / 2 - m)` (defined below) is the proper term to
    # divide by to make this an unbiased estimate of the expectation
    # E[X[n] Conj(X[n - m])].
    x_len = math_ops.cast(x_len, dtype.real_dtype)
    max_lags = math_ops.cast(max_lags, dtype.real_dtype)
    denominator = x_len - math_ops.range(0., max_lags + 1.)
    denominator = math_ops.cast(denominator, dtype)
    shifted_product_rotated = shifted_product_chopped / denominator

    if normalize:
      # Divide every lag by the lag-0 value so that rxx[0] == 1.
      shifted_product_rotated /= shifted_product_rotated[..., :1]

    # Transpose dimensions back to those of x.
    return util.rotate_transpose(shifted_product_rotated, -shift)