Example #1
def pad_batch_dimension_for_multiple_chains(
    observed_time_series, model, chain_batch_shape):
  """"Expand the observed time series with extra batch dimension(s)."""
  # Running with multiple chains introduces an extra batch dimension. In
  # general we also need to pad the observed time series with a matching batch
  # dimension.
  #
  # For example, suppose our model has batch shape [3, 4] and
  # the observed time series has shape `concat([[5], [3, 4], [100]])`,
  # corresponding to `sample_shape`, `batch_shape`, and `num_timesteps`
  # respectively. The model will produce distributions with batch shape
  # `concat([chain_batch_shape, [3, 4]])`, so we pad `observed_time_series` to
  # have matching shape `[5, 1, 3, 4, 100]`, where the added `1` dimension
  # between the sample and batch shapes will broadcast to `chain_batch_shape`.

  observed_time_series = maybe_expand_trailing_dim(
      observed_time_series)  # Guarantee `event_ndims=2`

  event_ndims = 2  # event_shape = [num_timesteps, observation_size=1]

  model_batch_ndims = (
      model.batch_shape.ndims if model.batch_shape.ndims is not None else
      tf.shape(input=model.batch_shape_tensor())[0])

  # Compute ndims from chain_batch_shape.
  chain_batch_shape = tf.convert_to_tensor(
      value=chain_batch_shape, name='chain_batch_shape', dtype=tf.int32)
  if not chain_batch_shape.shape.is_fully_defined():
    raise ValueError('Batch shape must have static rank. (given: {})'.format(
        chain_batch_shape))
  if chain_batch_shape.shape.ndims == 0:  # expand int `k` to `[k]`.
    chain_batch_shape = chain_batch_shape[tf.newaxis]
  chain_batch_ndims = tf.compat.dimension_value(chain_batch_shape.shape[0])

  def do_padding(observed_time_series_tensor):
    current_sample_shape = tf.shape(
        input=observed_time_series_tensor)[:-(model_batch_ndims + event_ndims)]
    current_batch_and_event_shape = tf.shape(
        input=observed_time_series_tensor)[-(model_batch_ndims + event_ndims):]
    return tf.reshape(
        tensor=observed_time_series_tensor,
        shape=tf.concat([
            current_sample_shape,
            tf.ones([chain_batch_ndims], dtype=tf.int32),
            current_batch_and_event_shape], axis=0))

  # Padding is only needed if the observed time series has sample shape.
  observed_time_series = prefer_static.cond(
      (dist_util.prefer_static_rank(observed_time_series) >
       model_batch_ndims + event_ndims),
      lambda: do_padding(observed_time_series),
      lambda: observed_time_series)

  return observed_time_series
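
A minimal, self-contained sketch of the reshape this helper performs, using made-up shapes and plain TF ops rather than the surrounding STS machinery:

import tensorflow as tf

# Hypothetical shapes: sample_shape=[5], model batch_shape=[3, 4],
# num_timesteps=100, observation_size=1 (so event_ndims=2).
observed = tf.zeros([5, 3, 4, 100, 1])
model_batch_ndims = 2
event_ndims = 2
chain_batch_ndims = 1  # e.g. chain_batch_shape=[8]

sample_shape = tf.shape(observed)[:-(model_batch_ndims + event_ndims)]
batch_and_event_shape = tf.shape(observed)[-(model_batch_ndims + event_ndims):]
padded = tf.reshape(
    observed,
    tf.concat([sample_shape,
               tf.ones([chain_batch_ndims], dtype=tf.int32),
               batch_and_event_shape], axis=0))
print(padded.shape)  # (5, 1, 3, 4, 100, 1); the 1 broadcasts against chain_batch_shape.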
Example #2
def _batch_transpose(mat):
  """Transpose a possibly batched matrix.

  Args:
    mat: A `tf.Tensor` of shape `[..., n, m]`.

  Returns:
    A tensor of shape `[..., m, n]` with matching batch dimensions.
  """
  n = distribution_util.prefer_static_rank(mat)
  perm = tf.range(n)
  perm = tf.concat([perm[:-2], [perm[-1], perm[-2]]], axis=0)
  return tf.transpose(a=mat, perm=perm)
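
A quick hedged check of the same permutation trick on a made-up `[2, 3, 4]` batch of matrices, using the static rank in place of `prefer_static_rank`:

import tensorflow as tf

mat = tf.reshape(tf.range(24), [2, 3, 4])  # batch of two 3x4 matrices
rank = len(mat.shape)                      # static stand-in for prefer_static_rank
perm = tf.concat([tf.range(rank - 2), [rank - 1, rank - 2]], axis=0)
print(tf.transpose(mat, perm=perm).shape)  # (2, 4, 3)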
Example #3
def _mul_right(mat, vec):
    """Computes the product of a square matrix with a vector on the right.

  Note this accepts a generalized square matrix `M`, i.e. of shape `s + s`
  with `rank(s) >= 1`, a generalized vector `v` of shape `s`, and computes
  the product `M.v` (also of shape `s`).

  Furthermore, the shapes may be fully dynamic.

  Examples:

    v = tf.constant([0, 1])
    M = tf.constant([[0, 1], [2, 3]])
    _mul_right(M, v)
    # => [1, 3]

    v = tf.reshape(tf.range(6), shape=(2, 3))
    # => [[0, 1, 2],
    #     [3, 4, 5]]
    M = tf.reshape(tf.range(36), shape=(2, 3, 2, 3))
    _mul_right(M, v)
    # => [[ 55, 145, 235],
    #     [325, 415, 505]]

  Args:
    mat: A `tf.Tensor` of shape `s + s`.
    vec: A `tf.Tensor` of shape `s`.

  Returns:
    A tensor with the result of the product (also of shape `s`).
  """
    contraction_axes = tf.range(-distribution_util.prefer_static_rank(vec), 0)
    result = tf.tensordot(mat,
                          vec,
                          axes=tf.stack([contraction_axes, contraction_axes]))
    # This last reshape is needed to help with inference about the shape
    # information, otherwise a partially-known shape would become completely
    # unknown.
    return tf.reshape(result, distribution_util.prefer_static_shape(vec))
Example #4
def _mul_right(mat, vec):
  """Computes the product of a square matrix with a vector on the right.

  Note this accepts a generalized square matrix `M`, i.e. of shape `s + s`
  with `rank(s) >= 1`, a generalized vector `v` of shape `s`, and computes
  the product `M.v` (also of shape `s`).

  Furthermore, the shapes may be fully dynamic.

  Examples:

    v = tf.constant([0, 1])
    M = tf.constant([[0, 1], [2, 3]])
    _mul_right(M, v)
    # => [1, 3]

    v = tf.reshape(tf.range(6), shape=(2, 3))
    # => [[0, 1, 2],
    #     [3, 4, 5]]
    M = tf.reshape(tf.range(36), shape=(2, 3, 2, 3))
    _mul_right(M, v)
    # => [[ 55, 145, 235],
    #     [325, 415, 505]]

  Args:
    mat: A `tf.Tensor` of shape `s + s`.
    vec: A `tf.Tensor` of shape `s`.

  Returns:
    A tensor with the result of the product (also of shape `s`).
  """
  contraction_axes = tf.range(-distribution_util.prefer_static_rank(vec), 0)
  result = tf.tensordot(mat, vec, axes=tf.stack([contraction_axes,
                                                 contraction_axes]))
  # This last reshape is needed to help with inference about the shape
  # information, otherwise a partially-known shape would become completely
  # unknown.
  return tf.reshape(result, distribution_util.prefer_static_shape(vec))
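
A hedged numeric check of the docstring's second example: contracting the trailing `s = (2, 3)` axes of `M` against `v` with `tf.tensordot` (here with explicit axis lists) agrees with a hand-written einsum.

import numpy as np
import tensorflow as tf

v = tf.reshape(tf.range(6), shape=(2, 3))
M = tf.reshape(tf.range(36), shape=(2, 3, 2, 3))
result = tf.tensordot(M, v, axes=[[2, 3], [0, 1]])      # same contraction as above
print(result.numpy())                                    # [[ 55 145 235] [325 415 505]]
print(np.einsum('ijkl,kl->ij', M.numpy(), v.numpy()))    # agrees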
Example #5
    def __init__(self,
                 skewness,
                 tailweight,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name=None):
        """Construct Johnson's SU distributions.

    The distributions have shape parameters `tailweight` and `skewness`,
    mean `loc`, and scale `scale`.

    The parameters `tailweight`, `skewness`, `loc`, and `scale` must be shaped
    in a way that supports broadcasting
    (e.g. `skewness + tailweight + loc + scale` is a valid operation).

    Args:
      skewness: Floating-point `Tensor`. Skewness of the distribution(s).
      tailweight: Floating-point `Tensor`. Tail weight of the
        distribution(s). `tailweight` must contain only positive values.
      loc: Floating-point `Tensor`. The mean(s) of the distribution(s).
      scale: Floating-point `Tensor`. The scaling factor(s) for the
        distribution(s). Note that `scale` is not technically the standard
        deviation of this distribution but has semantics more similar to
        standard deviation than variance.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value '`NaN`' to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if any of skewness, tailweight, loc and scale are different
        dtypes.
    """
        parameters = dict(locals())
        with tf.name_scope(name or 'JohnsonSU') as name:
            dtype = dtype_util.common_dtype([skewness, tailweight, loc, scale],
                                            tf.float32)
            self._skewness = tensor_util.convert_nonref_to_tensor(
                skewness, name='skewness', dtype=dtype)
            self._tailweight = tensor_util.convert_nonref_to_tensor(
                tailweight, name='tailweight', dtype=dtype)
            self._loc = tensor_util.convert_nonref_to_tensor(loc,
                                                             name='loc',
                                                             dtype=dtype)
            self._scale = tensor_util.convert_nonref_to_tensor(scale,
                                                               name='scale',
                                                               dtype=dtype)

            norm_shift = invert_bijector.Invert(
                shift_bijector.Shift(shift=self._skewness,
                                     validate_args=validate_args))

            norm_scale = invert_bijector.Invert(
                scale_bijector.Scale(scale=self._tailweight,
                                     validate_args=validate_args))

            sinh = sinh_bijector.Sinh(validate_args=validate_args)

            scale = scale_bijector.Scale(scale=self._scale,
                                         validate_args=validate_args)

            shift = shift_bijector.Shift(shift=self._loc,
                                         validate_args=validate_args)

            bijector = shift(scale(sinh(norm_scale(norm_shift))))

            batch_rank = ps.reduce_max([
                distribution_util.prefer_static_rank(x)
                for x in (self._skewness, self._tailweight, self._loc,
                          self._scale)
            ])

            super(JohnsonSU, self).__init__(
                # TODO(b/160730249): Make `loc` a scalar `0.` and remove overridden
                # `batch_shape` and `batch_shape_tensor` when
                # TransformedDistribution's bijector can modify its `batch_shape`.
                distribution=normal.Normal(loc=tf.zeros(ps.ones(
                    batch_rank, tf.int32),
                                                        dtype=dtype),
                                           scale=tf.ones([], dtype=dtype),
                                           validate_args=validate_args,
                                           allow_nan_stats=allow_nan_stats),
                bijector=bijector,
                validate_args=validate_args,
                parameters=parameters,
                name=name)
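
For context, a hedged usage sketch of the resulting distribution, assuming it is exposed as `tfp.distributions.JohnsonSU` (parameter values are made up):

import tensorflow_probability as tfp

dist = tfp.distributions.JohnsonSU(skewness=-2., tailweight=2., loc=1.1, scale=1.5)
x = dist.sample(5)        # shape [5]
lp = dist.log_prob(x)     # shape [5]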
Example #6
  def one_step(self, current_state, previous_kernel_results):
    """Runs one iteration of Slice Sampler.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s). The first `r` dimensions
        index independent chains,
        `r = tf.rank(target_log_prob_fn(*current_state))`.
      previous_kernel_results: `collections.namedtuple` containing `Tensor`s
        representing values from previous calls to this function (or from the
        `bootstrap_results` function.)

    Returns:
      next_state: Tensor or Python list of `Tensor`s representing the state(s)
        of the Markov chain(s) after taking exactly one step. Has same type and
        shape as `current_state`.
      kernel_results: `collections.namedtuple` of internal calculations used to
        advance the chain.

    Raises:
      ValueError: if there isn't one `step_size` or a list with same length as
        `current_state`.
      TypeError: if `not target_log_prob.dtype.is_floating`.
    """
    with tf.compat.v1.name_scope(
        name=mcmc_util.make_name(self.name, 'slice', 'one_step'),
        values=[
            self.step_size, self.max_doublings, self._seed_stream,
            current_state, previous_kernel_results.target_log_prob
        ]):
      with tf.compat.v1.name_scope('initialize'):
        [
            current_state_parts,
            step_sizes,
            current_target_log_prob
        ] = _prepare_args(
            self.target_log_prob_fn,
            current_state,
            self.step_size,
            previous_kernel_results.target_log_prob,
            maybe_expand=True)

        max_doublings = tf.convert_to_tensor(
            value=self.max_doublings, dtype=tf.int32, name='max_doublings')

      independent_chain_ndims = distribution_util.prefer_static_rank(
          current_target_log_prob)

      [
          next_state_parts,
          next_target_log_prob,
          bounds_satisfied,
          direction,
          upper_bounds,
          lower_bounds
      ] = _sample_next(
          self.target_log_prob_fn,
          current_state_parts,
          step_sizes,
          max_doublings,
          current_target_log_prob,
          independent_chain_ndims,
          seed=self._seed_stream()
      )

      def maybe_flatten(x):
        return x if mcmc_util.is_list_like(current_state) else x[0]

      return [
          maybe_flatten(next_state_parts),
          SliceSamplerKernelResults(
              target_log_prob=next_target_log_prob,
              bounds_satisfied=bounds_satisfied,
              direction=direction,
              upper_bounds=upper_bounds,
              lower_bounds=lower_bounds
          ),
      ]
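
A hedged end-to-end sketch of driving this `one_step` through `tfp.mcmc.sample_chain`, assuming the kernel is exposed as `tfp.mcmc.SliceSampler`; the target and all values are made up:

import tensorflow as tf
import tensorflow_probability as tfp

kernel = tfp.mcmc.SliceSampler(
    target_log_prob_fn=tfp.distributions.Normal(0., 1.).log_prob,
    step_size=1.0,
    max_doublings=5)
states = tfp.mcmc.sample_chain(
    num_results=200,
    current_state=tf.zeros([4]),  # four independent chains
    kernel=kernel,
    num_burnin_steps=50,
    trace_fn=None)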
Example #7
  def one_step(self, current_state, previous_kernel_results):
    with tf.name_scope(
        name=mcmc_util.make_name(self.name, 'hmc', 'one_step'),
        values=[self.step_size,
                self.num_leapfrog_steps,
                current_state,
                previous_kernel_results.target_log_prob,
                previous_kernel_results.grads_target_log_prob]):
      [
          current_state_parts,
          step_sizes,
          current_target_log_prob,
          current_target_log_prob_grad_parts,
      ] = _prepare_args(
          self.target_log_prob_fn,
          current_state,
          self.step_size,
          previous_kernel_results.target_log_prob,
          previous_kernel_results.grads_target_log_prob,
          maybe_expand=True,
          state_gradients_are_stopped=self.state_gradients_are_stopped)

      independent_chain_ndims = distribution_util.prefer_static_rank(
          current_target_log_prob)

      current_momentum_parts = []
      for x in current_state_parts:
        current_momentum_parts.append(tf.random_normal(
            shape=tf.shape(x),
            dtype=x.dtype.base_dtype,
            seed=self._seed_stream()))

      def _leapfrog_one_step(*args):
        """Closure representing computation done during each leapfrog step."""
        return _leapfrog_integrator_one_step(
            target_log_prob_fn=self.target_log_prob_fn,
            independent_chain_ndims=independent_chain_ndims,
            step_sizes=step_sizes,
            current_momentum_parts=args[0],
            current_state_parts=args[1],
            current_target_log_prob=args[2],
            current_target_log_prob_grad_parts=args[3],
            state_gradients_are_stopped=self.state_gradients_are_stopped)

      num_leapfrog_steps = tf.convert_to_tensor(
          self.num_leapfrog_steps, dtype=tf.int64, name='num_leapfrog_steps')

      [
          next_momentum_parts,
          next_state_parts,
          next_target_log_prob,
          next_target_log_prob_grad_parts,

      ] = tf.while_loop(
          cond=lambda i, *args: i < num_leapfrog_steps,
          body=lambda i, *args: [i + 1] + list(_leapfrog_one_step(*args)),
          loop_vars=[
              tf.zeros([], tf.int64, name='iter'),
              current_momentum_parts,
              current_state_parts,
              current_target_log_prob,
              current_target_log_prob_grad_parts
          ])[1:]

      def maybe_flatten(x):
        return x if mcmc_util.is_list_like(current_state) else x[0]

      return [
          maybe_flatten(next_state_parts),
          UncalibratedHamiltonianMonteCarloKernelResults(
              log_acceptance_correction=_compute_log_acceptance_correction(
                  current_momentum_parts,
                  next_momentum_parts,
                  independent_chain_ndims),
              target_log_prob=next_target_log_prob,
              grads_target_log_prob=next_target_log_prob_grad_parts,
          ),
      ]
Example #8
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
  """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as  (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x:  `float32` or `complex64` `Tensor`.
    axis:  Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags:  Positive `int` tensor.  The maximum value of `m` to consider (in
      equation above).  If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center:  Python `bool`.  If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError:  If `x` is not a supported type.
  """
  # Implementation details:
  # Extend length N / 2 1-D array x to length N by zero padding onto the end.
  # Then, set
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # It is not hard to see that
  #   F[x]_k Conj(F[x]_k) = F[R]_k, where
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
  # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

  # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
  # based version of estimating RXX.
  # Note that this is a special case of the Wiener-Khinchin Theorem.
  with tf.name_scope(name, values=[x]):
    x = tf.convert_to_tensor(value=x, name='x')

    # Rotate dimensions of x in order to put axis at the rightmost dim.
    # FFT op requires this.
    rank = util.prefer_static_rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # Suppose x.shape[axis] = T, so there are T 'time' steps.
    #   ==> x_rotated.shape = B + [T],
    # where B is x_rotated's batch shape.
    x_rotated = util.rotate_transpose(x, shift)

    if center:
      x_rotated -= tf.reduce_mean(
          input_tensor=x_rotated, axis=-1, keepdims=True)

    # x_len = N / 2 from above explanation.  The length of x along axis.
    # Get a value for x_len that works in all cases.
    x_len = util.prefer_static_shape(x_rotated)[-1]

    # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
    # the moment it is necessary so that all FFT implementations work.
    # Zero pad to the next power of 2 greater than 2 * x_len, which equals
    # 2**(ceil(Log_2(2 * x_len))).  Note: Log_2(X) = Log_e(X) / Log_e(2).
    x_len_float64 = tf.cast(x_len, np.float64)
    target_length = tf.pow(
        np.float64(2.), tf.math.ceil(
            tf.math.log(x_len_float64 * 2) / np.log(2.)))
    pad_length = tf.cast(target_length - x_len_float64, np.int32)

    # We should have:
    # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
    #                     = B + [T + pad_length]
    x_rotated_pad = util.pad(x_rotated, axis=-1, back=True, count=pad_length)

    dtype = x.dtype
    if not dtype.is_complex:
      if not dtype.is_floating:
        raise TypeError('Argument x must have either float or complex dtype'
                        ' found: {}'.format(dtype))
      x_rotated_pad = tf.complex(x_rotated_pad,
                                 dtype.real_dtype.as_numpy_dtype(0.))

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    fft_x_rotated_pad = tf.signal.fft(x_rotated_pad)
    spectral_density = fft_x_rotated_pad * tf.math.conj(fft_x_rotated_pad)
    # shifted_product is R[m] from above detailed explanation.
    # It is the inner product sum_n X[n] * Conj(X[n - m]).
    shifted_product = tf.signal.ifft(spectral_density)

    # Cast back to real-valued if x was real to begin with.
    shifted_product = tf.cast(shifted_product, dtype)

    # Figure out if we can deduce the final static shape, and set max_lags.
    # Use x_rotated as a reference, because it has the time dimension in the far
    # right, and was created before we performed all sorts of crazy shape
    # manipulations.
    know_static_shape = True
    if not x_rotated.shape.is_fully_defined():
      know_static_shape = False
    if max_lags is None:
      max_lags = x_len - 1
    else:
      max_lags = tf.convert_to_tensor(value=max_lags, name='max_lags')
      max_lags_ = tf.get_static_value(max_lags)
      if max_lags_ is None or not know_static_shape:
        know_static_shape = False
        max_lags = tf.minimum(x_len - 1, max_lags)
      else:
        max_lags = min(x_len - 1, max_lags_)

    # Chop off the padding.
    # We allow users to provide a huge max_lags, but cut it off here.
    # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    # If possible, set shape.
    if know_static_shape:
      chopped_shape = x_rotated.shape.as_list()
      chopped_shape[-1] = min(x_len, max_lags + 1)
      shifted_product_chopped.set_shape(chopped_shape)

    # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
    # other terms were zeros arising only due to zero padding.
    # `denominator = (N / 2 - m)` (defined below) is the proper term to
    # divide by to make this an unbiased estimate of the expectation
    # E[X[n] Conj(X[n - m])].
    x_len = tf.cast(x_len, dtype.real_dtype)
    max_lags = tf.cast(max_lags, dtype.real_dtype)
    denominator = x_len - tf.range(0., max_lags + 1.)
    denominator = tf.cast(denominator, dtype)
    shifted_product_rotated = shifted_product_chopped / denominator

    if normalize:
      shifted_product_rotated /= shifted_product_rotated[..., :1]

    # Transpose dimensions back to those of x.
    return util.rotate_transpose(shifted_product_rotated, -shift)
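
A hedged NumPy cross-check of the estimator above (FFT of the zero-padded, centered and normalized series, squared magnitude, inverse FFT, divide by `L - m`), compared against the direct sum; the data are made up:

import numpy as np

x = np.random.RandomState(0).randn(100)
max_lags = 10
w = (x - x.mean()) / x.std()
n_fft = int(2 ** np.ceil(np.log2(2 * len(x))))
f = np.fft.fft(w, n=n_fft)
r = np.fft.ifft(f * np.conj(f)).real[:max_lags + 1]
rxx = r / (len(x) - np.arange(max_lags + 1))   # unbiased estimate, as in the code
rxx /= rxx[0]                                   # `normalize=True` behaviour
direct = np.array([(w[m:] * w[:len(w) - m]).sum() / (len(x) - m)
                   for m in range(max_lags + 1)])
direct /= direct[0]
print(np.allclose(rxx, direct))                 # True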
Example #9
 def testDynamicRankEndsUpBeingScalar(self):
     if tf.executing_eagerly(): return
     x = tf1.placeholder_with_default(np.array(1, dtype=np.int32),
                                      shape=None)
     rank = distribution_util.prefer_static_rank(x)
     self.assertAllEqual(0, self.evaluate(rank))
Example #10
    def __init__(self,
                 loc,
                 scale,
                 skewness=None,
                 tailweight=None,
                 distribution=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='SinhArcsinh'):
        """Construct SinhArcsinh distribution on `(-inf, inf)`.

    Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape
    (indexing batch dimensions).  They must all have the same `dtype`.

    Args:
      loc: Floating-point `Tensor`.
      scale:  `Tensor` of same `dtype` as `loc`.
      skewness:  Skewness parameter.  Default is `0.0` (no skew).
      tailweight:  Tailweight parameter. Default is `1.0` (unchanged tailweight)
      distribution: `tf.Distribution`-like instance. Distribution that is
        transformed to produce this distribution.
        Must have a batch shape to which the shapes of `loc`, `scale`,
        `skewness`, and `tailweight` all broadcast. Default is
        `tfd.Normal(batch_shape, 1.)`, where `batch_shape` is the broadcasted
        shape of the parameters. Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a `SinhArcsinh` sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = dict(locals())

        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight],
                                            tf.float32)
            self._loc = tensor_util.convert_nonref_to_tensor(loc,
                                                             name='loc',
                                                             dtype=dtype)
            self._scale = tensor_util.convert_nonref_to_tensor(scale,
                                                               name='scale',
                                                               dtype=dtype)
            tailweight = 1. if tailweight is None else tailweight
            has_default_skewness = skewness is None
            skewness = 0. if has_default_skewness else skewness
            self._tailweight = tensor_util.convert_nonref_to_tensor(
                tailweight, name='tailweight', dtype=dtype)
            self._skewness = tensor_util.convert_nonref_to_tensor(
                skewness, name='skewness', dtype=dtype)

            # Recall, with Z a random variable,
            #   Y := loc + scale * F(Z),
            #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) * C
            #   C := 2 / F_0(2)
            #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
            if distribution is None:
                batch_rank = tf.reduce_max([
                    distribution_util.prefer_static_rank(x)
                    for x in (self._skewness, self._tailweight, self._loc,
                              self._scale)
                ])
                # TODO(b/160730249): Make `loc` a scalar `0.` and remove overridden
                # `batch_shape` and `batch_shape_tensor` when
                # TransformedDistribution's bijector can modify its `batch_shape`.
                distribution = normal.Normal(loc=tf.zeros(tf.ones(
                    batch_rank, tf.int32),
                                                          dtype=dtype),
                                             scale=tf.ones([], dtype=dtype),
                                             allow_nan_stats=allow_nan_stats,
                                             validate_args=validate_args)

            # Make the SAS bijector, 'F'.
            f = sinh_arcsinh_bijector.SinhArcsinh(skewness=self._skewness,
                                                  tailweight=self._tailweight,
                                                  validate_args=validate_args)

            # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2))
            affine = affine_scalar_bijector.AffineScalar(
                shift=self._loc,
                scale=self._scale,
                validate_args=validate_args)

            bijector = chain_bijector.Chain([affine, f])

            super(SinhArcsinh, self).__init__(distribution=distribution,
                                              bijector=bijector,
                                              validate_args=validate_args,
                                              name=name)
            self._parameters = parameters
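
As with the Johnson's SU example above, a hedged usage sketch, assuming the class is exposed as `tfp.distributions.SinhArcsinh` (values are made up):

import tensorflow_probability as tfp

dist = tfp.distributions.SinhArcsinh(loc=0., scale=1., skewness=0.5, tailweight=1.2)
x = dist.sample(3)
lp = dist.log_prob(x)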
Example #11
 def testScalarTensor(self):
     x = tf.constant(1.)
     rank = distribution_util.prefer_static_rank(x)
     if not tf.executing_eagerly():
         self.assertIsInstance(rank, np.ndarray)
     self.assertEqual(0, rank)
Example #12
 def testNonEmptyConstantTensor(self):
     x = tf.zeros([2, 3, 4])
     rank = distribution_util.prefer_static_rank(x)
     if not tf.executing_eagerly():
         self.assertIsInstance(rank, np.ndarray)
     self.assertEqual(3, rank)
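
The tests in Examples #9, #11, #12 and #18 pin down the behaviour of `prefer_static_rank`: a NumPy value when the rank is statically known, a `tf.rank` op otherwise. A minimal, hypothetical re-implementation (not the library's actual code) that would satisfy them:

import numpy as np
import tensorflow as tf

def prefer_static_rank_sketch(x):
  """Hypothetical stand-in: static rank as a NumPy array when known, else tf.rank(x)."""
  x = tf.convert_to_tensor(x)
  if x.shape.ndims is not None:
    return np.array(x.shape.ndims, dtype=np.int32)  # np.ndarray, as the tests expect
  return tf.rank(x)                                 # dynamic fallback for unknown rank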
Example #13
    def one_step(self, current_state, previous_kernel_results):
        with tf.compat.v2.name_scope(
                mcmc_util.make_name(self.name, 'hmc', 'one_step')):
            if self._store_parameters_in_results:
                step_size = previous_kernel_results.step_size
                num_leapfrog_steps = previous_kernel_results.num_leapfrog_steps
            else:
                step_size = self.step_size
                num_leapfrog_steps = self.num_leapfrog_steps
            [
                current_state_parts,
                step_sizes,
                current_target_log_prob,
            ] = _prepare_args(
                self.target_log_prob_fn,
                current_state,
                step_size,
                previous_kernel_results.target_log_prob,
                maybe_expand=True,
                state_gradients_are_stopped=self.state_gradients_are_stopped)

            self.restoreShapes = []
            for x in current_state_parts:
                n = 1
                shape = x.shape
                for m in shape:
                    n *= m
                self.restoreShapes.append([shape, n])
            current_state_parts = [
                tf.reshape(part, [-1]) for part in current_state_parts
            ]
            current_state_parts = tf.concat(current_state_parts, -1)
            temp = []
            #print(current_state_parts)
            for x in range(current_state_parts.shape[0]):
                temp.append(current_state_parts[x])
            current_state_parts = temp
            #print(current_state_parts)

            current_momentum_parts = []

            for x in current_state_parts:
                current_momentum_parts.append(
                    tf.random.normal(shape=tf.shape(input=x),
                                     dtype=self._momentum_dtype
                                     or x.dtype.base_dtype,
                                     seed=self._seed_stream()))

            next_state_parts, initial_kinetic, final_kinetic, final_target_log_prob = self.run_integrator(
                step_sizes, num_leapfrog_steps, current_momentum_parts,
                current_state_parts)

            if self.state_gradients_are_stopped:
                next_state_parts = [
                    tf.stop_gradient(x) for x in next_state_parts
                ]

            def maybe_flatten(x):
                return x if mcmc_util.is_list_like(current_state) else x[0]

            independent_chain_ndims = distribution_util.prefer_static_rank(
                current_target_log_prob)

            next_state_parts = maybe_flatten(next_state_parts)

            new_kernel_results = previous_kernel_results._replace(
                log_acceptance_correction=_compute_log_acceptance_correction(
                    initial_kinetic, final_kinetic, independent_chain_ndims),
                target_log_prob=final_target_log_prob)
            argv = next_state_parts  #[0]
            next_state_parts = []
            index = 0
            #print(self.restoreShapes)
            for info in self.restoreShapes:
                next_state_parts.append(
                    tf.reshape(argv[index:index + info[1]], info[0]))
                index += info[1]

            return next_state_parts, new_kernel_results
Example #14
    def one_step(self, current_state, previous_kernel_results):
        with tf.compat.v2.name_scope(
                mcmc_util.make_name(self.name, 'hmc', 'one_step')):
            if self._store_parameters_in_results:
                step_size = previous_kernel_results.step_size
                num_leapfrog_steps = previous_kernel_results.num_leapfrog_steps
            else:
                step_size = self.step_size
                num_leapfrog_steps = self.num_leapfrog_steps

            [
                current_state_parts,
                step_sizes,
                current_target_log_prob,
                current_target_log_prob_grad_parts,
            ] = _prepare_args(
                self.target_log_prob_fn,
                current_state,
                step_size,
                previous_kernel_results.target_log_prob,
                previous_kernel_results.grads_target_log_prob,
                maybe_expand=True,
                state_gradients_are_stopped=self.state_gradients_are_stopped)

            current_momentum_parts = []
            for x in current_state_parts:
                current_momentum_parts.append(
                    tf.random.normal(shape=tf.shape(input=x),
                                     dtype=self._momentum_dtype
                                     or x.dtype.base_dtype,
                                     seed=self._seed_stream()))

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn, step_sizes, num_leapfrog_steps)

            [
                next_momentum_parts,
                next_state_parts,
                next_target_log_prob,
                next_target_log_prob_grad_parts,
            ] = integrator(current_momentum_parts, current_state_parts,
                           current_target_log_prob,
                           current_target_log_prob_grad_parts)
            if self.state_gradients_are_stopped:
                next_state_parts = [
                    tf.stop_gradient(x) for x in next_state_parts
                ]

            def maybe_flatten(x):
                return x if mcmc_util.is_list_like(current_state) else x[0]

            independent_chain_ndims = distribution_util.prefer_static_rank(
                current_target_log_prob)

            new_kernel_results = previous_kernel_results._replace(
                log_acceptance_correction=_compute_log_acceptance_correction(
                    current_momentum_parts, next_momentum_parts,
                    independent_chain_ndims),
                target_log_prob=next_target_log_prob,
                grads_target_log_prob=next_target_log_prob_grad_parts,
            )

            return maybe_flatten(next_state_parts), new_kernel_results
Example #15
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
  """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as  (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x:  `float32` or `complex64` `Tensor`.
    axis:  Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags:  Positive `int` tensor.  The maximum value of `m` to consider (in
      equation above).  If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center:  Python `bool`.  If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError:  If `x` is not a supported type.
  """
  # Implementation details:
  # Extend length N / 2 1-D array x to length N by zero padding onto the end.
  # Then, set
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # It is not hard to see that
  #   F[x]_k Conj(F[x]_k) = F[R]_k, where
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
  # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

  # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
  # based version of estimating RXX.
  # Note that this is a special case of the Wiener-Khinchin Theorem.
  with tf.name_scope(name, values=[x]):
    x = tf.convert_to_tensor(x, name='x')

    # Rotate dimensions of x in order to put axis at the rightmost dim.
    # FFT op requires this.
    rank = util.prefer_static_rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # Suppose x.shape[axis] = T, so there are T 'time' steps.
    #   ==> x_rotated.shape = B + [T],
    # where B is x_rotated's batch shape.
    x_rotated = util.rotate_transpose(x, shift)

    if center:
      x_rotated -= tf.reduce_mean(x_rotated, axis=-1, keepdims=True)

    # x_len = N / 2 from above explanation.  The length of x along axis.
    # Get a value for x_len that works in all cases.
    x_len = util.prefer_static_shape(x_rotated)[-1]

    # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
    # the moment it is necessary so that all FFT implementations work.
    # Zero pad to the next power of 2 greater than 2 * x_len, which equals
    # 2**(ceil(Log_2(2 * x_len))).  Note: Log_2(X) = Log_e(X) / Log_e(2).
    x_len_float64 = tf.cast(x_len, np.float64)
    target_length = tf.pow(
        np.float64(2.), tf.ceil(tf.log(x_len_float64 * 2) / np.log(2.)))
    pad_length = tf.cast(target_length - x_len_float64, np.int32)

    # We should have:
    # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
    #                     = B + [T + pad_length]
    x_rotated_pad = util.pad(x_rotated, axis=-1, back=True, count=pad_length)

    dtype = x.dtype
    if not dtype.is_complex:
      if not dtype.is_floating:
        raise TypeError('Argument x must have either float or complex dtype'
                        ' found: {}'.format(dtype))
      x_rotated_pad = tf.complex(x_rotated_pad,
                                 dtype.real_dtype.as_numpy_dtype(0.))

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    fft_x_rotated_pad = tf.fft(x_rotated_pad)
    spectral_density = fft_x_rotated_pad * tf.conj(fft_x_rotated_pad)
    # shifted_product is R[m] from above detailed explanation.
    # It is the inner product sum_n X[n] * Conj(X[n - m]).
    shifted_product = tf.ifft(spectral_density)

    # Cast back to real-valued if x was real to begin with.
    shifted_product = tf.cast(shifted_product, dtype)

    # Figure out if we can deduce the final static shape, and set max_lags.
    # Use x_rotated as a reference, because it has the time dimension in the far
    # right, and was created before we performed all sorts of crazy shape
    # manipulations.
    know_static_shape = True
    if not x_rotated.shape.is_fully_defined():
      know_static_shape = False
    if max_lags is None:
      max_lags = x_len - 1
    else:
      max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
      max_lags_ = tf.contrib.util.constant_value(max_lags)
      if max_lags_ is None or not know_static_shape:
        know_static_shape = False
        max_lags = tf.minimum(x_len - 1, max_lags)
      else:
        max_lags = min(x_len - 1, max_lags_)

    # Chop off the padding.
    # We allow users to provide a huge max_lags, but cut it off here.
    # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    # If possible, set shape.
    if know_static_shape:
      chopped_shape = x_rotated.shape.as_list()
      chopped_shape[-1] = min(x_len, max_lags + 1)
      shifted_product_chopped.set_shape(chopped_shape)

    # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
    # other terms were zeros arising only due to zero padding.
    # `denominator = (N / 2 - m)` (defined below) is the proper term to
    # divide by to make this an unbiased estimate of the expectation
    # E[X[n] Conj(X[n - m])].
    x_len = tf.cast(x_len, dtype.real_dtype)
    max_lags = tf.cast(max_lags, dtype.real_dtype)
    denominator = x_len - tf.range(0., max_lags + 1.)
    denominator = tf.cast(denominator, dtype)
    shifted_product_rotated = shifted_product_chopped / denominator

    if normalize:
      shifted_product_rotated /= shifted_product_rotated[..., :1]

    # Transpose dimensions back to those of x.
    return util.rotate_transpose(shifted_product_rotated, -shift)
Example #16
    def one_step(self, current_state, previous_kernel_results):
        with tf.name_scope(name=mcmc_util.make_name(self.name, 'mala',
                                                    'one_step'),
                           values=[
                               self.step_size, current_state,
                               previous_kernel_results.target_log_prob,
                               previous_kernel_results.grads_target_log_prob,
                               previous_kernel_results.volatility,
                               previous_kernel_results.diffusion_drift
                           ]):
            with tf.name_scope('initialize'):
                # Prepare input arguments to be passed to `_euler_method`.
                [
                    current_state_parts,
                    step_size_parts,
                    current_target_log_prob,
                    _,  # grads_target_log_prob
                    current_volatility_parts,
                    _,  # grads_volatility
                    current_drift_parts,
                ] = _prepare_args(
                    self.target_log_prob_fn, self.volatility_fn, current_state,
                    self.step_size, previous_kernel_results.target_log_prob,
                    previous_kernel_results.grads_target_log_prob,
                    previous_kernel_results.volatility,
                    previous_kernel_results.grads_volatility,
                    previous_kernel_results.diffusion_drift,
                    self.parallel_iterations)

                random_draw_parts = []
                for s in current_state_parts:
                    random_draw_parts.append(
                        tf.random_normal(shape=tf.shape(s),
                                         dtype=s.dtype.base_dtype,
                                         seed=self._seed_stream()))

            # Number of independent chains run by the algorithm.
            independent_chain_ndims = distribution_util.prefer_static_rank(
                current_target_log_prob)

            # Generate the next state of the algorithm using Euler-Maruyama method.
            next_state_parts = _euler_method(random_draw_parts,
                                             current_state_parts,
                                             current_drift_parts,
                                             step_size_parts,
                                             current_volatility_parts)

            # Compute helper `UncalibratedLangevinKernelResults` to be processed by
            # `_compute_log_acceptance_correction` and in the next iteration of
            # `one_step` function.
            [
                _,  # state_parts
                _,  # step_sizes
                next_target_log_prob,
                next_grads_target_log_prob,
                next_volatility_parts,
                next_grads_volatility,
                next_drift_parts,
            ] = _prepare_args(self.target_log_prob_fn,
                              self.volatility_fn,
                              next_state_parts,
                              step_size_parts,
                              parallel_iterations=self.parallel_iterations)

            def maybe_flatten(x):
                return x if mcmc_util.is_list_like(current_state) else x[0]

            # Decide whether to compute the acceptance ratio
            log_acceptance_correction_compute = _compute_log_acceptance_correction(
                current_state_parts, next_state_parts,
                current_volatility_parts, next_volatility_parts,
                current_drift_parts, next_drift_parts, step_size_parts,
                independent_chain_ndims)
            log_acceptance_correction_skip = tf.zeros_like(
                next_target_log_prob)

            log_acceptance_correction = tf.cond(
                self.compute_acceptance,
                lambda: log_acceptance_correction_compute,
                lambda: log_acceptance_correction_skip)

            return [
                maybe_flatten(next_state_parts),
                UncalibratedLangevinKernelResults(
                    log_acceptance_correction=log_acceptance_correction,
                    target_log_prob=next_target_log_prob,
                    grads_target_log_prob=next_grads_target_log_prob,
                    volatility=maybe_flatten(next_volatility_parts),
                    grads_volatility=next_grads_volatility,
                    diffusion_drift=next_drift_parts),
            ]
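
Paralleling the slice-sampler sketch above, a hedged usage of this uncalibrated Langevin kernel through its Metropolis-corrected wrapper, assumed to be `tfp.mcmc.MetropolisAdjustedLangevinAlgorithm`; target and values are made up:

import tensorflow as tf
import tensorflow_probability as tfp

kernel = tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
    target_log_prob_fn=tfp.distributions.Normal(0., 1.).log_prob,
    step_size=0.1)
states = tfp.mcmc.sample_chain(
    num_results=100,
    current_state=tf.zeros([4]),  # four independent chains
    kernel=kernel,
    num_burnin_steps=50,
    trace_fn=None)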
Example #17
  def one_step(self, current_state, previous_kernel_results):
    with tf.compat.v1.name_scope(
        name=mcmc_util.make_name(self.name, 'hmc', 'one_step'),
        values=[
            self.step_size, self.num_leapfrog_steps, current_state,
            previous_kernel_results.target_log_prob,
            previous_kernel_results.grads_target_log_prob
        ]):
      if self._store_parameters_in_results:
        step_size = previous_kernel_results.step_size
        num_leapfrog_steps = previous_kernel_results.num_leapfrog_steps
      else:
        step_size = self.step_size
        num_leapfrog_steps = self.num_leapfrog_steps

      [
          current_state_parts,
          step_sizes,
          current_target_log_prob,
          current_target_log_prob_grad_parts,
      ] = _prepare_args(
          self.target_log_prob_fn,
          current_state,
          step_size,
          previous_kernel_results.target_log_prob,
          previous_kernel_results.grads_target_log_prob,
          maybe_expand=True,
          state_gradients_are_stopped=self.state_gradients_are_stopped)

      independent_chain_ndims = distribution_util.prefer_static_rank(
          current_target_log_prob)

      current_momentum_parts = []
      for x in current_state_parts:
        current_momentum_parts.append(
            tf.random.normal(
                shape=tf.shape(input=x),
                dtype=x.dtype.base_dtype,
                seed=self._seed_stream()))

      def _leapfrog_one_step(*args):
        """Closure representing computation done during each leapfrog step."""
        return _leapfrog_integrator_one_step(
            target_log_prob_fn=self.target_log_prob_fn,
            independent_chain_ndims=independent_chain_ndims,
            step_sizes=step_sizes,
            current_momentum_parts=args[0],
            current_state_parts=args[1],
            current_target_log_prob=args[2],
            current_target_log_prob_grad_parts=args[3],
            state_gradients_are_stopped=self.state_gradients_are_stopped)

      num_leapfrog_steps = tf.convert_to_tensor(
          value=self.num_leapfrog_steps,
          dtype=tf.int32,
          name='num_leapfrog_steps')

      [
          next_momentum_parts,
          next_state_parts,
          next_target_log_prob,
          next_target_log_prob_grad_parts,

      ] = tf.while_loop(
          cond=lambda i, *args: i < num_leapfrog_steps,
          body=lambda i, *args: [i + 1] + list(_leapfrog_one_step(*args)),
          loop_vars=[
              tf.zeros([], tf.int32, name='iter'),
              current_momentum_parts,
              current_state_parts,
              current_target_log_prob,
              current_target_log_prob_grad_parts
          ])[1:]

      def maybe_flatten(x):
        return x if mcmc_util.is_list_like(current_state) else x[0]

      new_kernel_results = previous_kernel_results._replace(
          log_acceptance_correction=_compute_log_acceptance_correction(
              current_momentum_parts, next_momentum_parts,
              independent_chain_ndims),
          target_log_prob=next_target_log_prob,
          grads_target_log_prob=next_target_log_prob_grad_parts,
      )

      return maybe_flatten(next_state_parts), new_kernel_results
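
A hedged end-to-end sketch of the calibrated wrapper around this `one_step`, assumed to be `tfp.mcmc.HamiltonianMonteCarlo`, driven by `tfp.mcmc.sample_chain`; all values are made up:

import tensorflow as tf
import tensorflow_probability as tfp

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=tfp.distributions.Normal(0., 1.).log_prob,
    step_size=0.5,
    num_leapfrog_steps=3)
states = tfp.mcmc.sample_chain(
    num_results=100,
    current_state=tf.zeros([8]),  # eight independent chains
    kernel=kernel,
    num_burnin_steps=50,
    trace_fn=None)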
Example #18
 def testDynamicRankEndsUpBeingNonEmpty(self):
     if tf.executing_eagerly(): return
     x = tf1.placeholder_with_default(np.zeros([2, 3], dtype=np.float64),
                                      shape=None)
     rank = distribution_util.prefer_static_rank(x)
     self.assertAllEqual(2, self.evaluate(rank))