예제 #1
0
파일: hmc.py 프로젝트: xuxyang/probability
    def one_step(self, current_state, previous_kernel_results):
        """Runs a single uncalibrated Hamiltonian Monte Carlo proposal step.

        A momentum is drawn with `tf.random_normal` for every state part, then
        a `tf.while_loop` applies `_leapfrog_integrator_one_step` for as long
        as the loop counter is below `self.num_leapfrog_steps`.

        Args:
          current_state: `Tensor` or list-like of `Tensor`s holding the
            current state. When it is not list-like (according to
            `mcmc_util.is_list_like`), only the first next-state part is
            returned.
          previous_kernel_results: Object exposing `target_log_prob` and
            `grads_target_log_prob`; both are forwarded to `_prepare_args`.

        Returns:
          A two-element list `[next_state, kernel_results]`, where
          `kernel_results` is an
          `UncalibratedHamiltonianMonteCarloKernelResults`.
        """
        scope_name = mcmc_util.make_name(self.name, 'hmc', 'one_step')
        scope_values = [
            self.step_size,
            self.num_leapfrog_steps,
            current_state,
            previous_kernel_results.target_log_prob,
            previous_kernel_results.grads_target_log_prob,
        ]
        with tf.name_scope(name=scope_name, values=scope_values):
            (
                state_parts,
                step_sizes,
                target_log_prob,
                target_log_prob_grad_parts,
            ) = _prepare_args(
                self.target_log_prob_fn,
                current_state,
                self.step_size,
                previous_kernel_results.target_log_prob,
                previous_kernel_results.grads_target_log_prob,
                maybe_expand=True,
                state_gradients_are_stopped=self.state_gradients_are_stopped)

            independent_chain_ndims = distributions_util.prefer_static_rank(
                target_log_prob)

            # One momentum draw per state part; each draw pulls a fresh seed
            # from the seed stream, in state-part order.
            initial_momentum_parts = [
                tf.random_normal(shape=tf.shape(part),
                                 dtype=part.dtype.base_dtype,
                                 seed=self._seed_stream())
                for part in state_parts
            ]

            def _keep_integrating(step, *_):
                """Loop condition: continue until all leapfrog steps ran."""
                return step < self.num_leapfrog_steps

            def _integrate_once(step, *leapfrog_state):
                """Loop body: one leapfrog update plus counter increment."""
                updated = _leapfrog_integrator_one_step(
                    target_log_prob_fn=self.target_log_prob_fn,
                    independent_chain_ndims=independent_chain_ndims,
                    step_sizes=step_sizes,
                    current_momentum_parts=leapfrog_state[0],
                    current_state_parts=leapfrog_state[1],
                    current_target_log_prob=leapfrog_state[2],
                    current_target_log_prob_grad_parts=leapfrog_state[3],
                    state_gradients_are_stopped=(
                        self.state_gradients_are_stopped))
                return [step + 1] + list(updated)

            # Do leapfrog integration.
            loop_outputs = tf.while_loop(
                cond=_keep_integrating,
                body=_integrate_once,
                loop_vars=[
                    tf.zeros([], tf.int32, name='iter'),
                    initial_momentum_parts,
                    state_parts,
                    target_log_prob,
                    target_log_prob_grad_parts,
                ])
            # Drop the loop counter; keep only the integrator state.
            (
                final_momentum_parts,
                final_state_parts,
                final_target_log_prob,
                final_target_log_prob_grad_parts,
            ) = loop_outputs[1:]

            # Hand back the same structure the caller supplied.
            if mcmc_util.is_list_like(current_state):
                next_state = final_state_parts
            else:
                next_state = final_state_parts[0]

            log_acceptance_correction = _compute_log_acceptance_correction(
                initial_momentum_parts, final_momentum_parts,
                independent_chain_ndims)

            kernel_results = UncalibratedHamiltonianMonteCarloKernelResults(
                log_acceptance_correction=log_acceptance_correction,
                target_log_prob=final_target_log_prob,
                grads_target_log_prob=final_target_log_prob_grad_parts,
            )
            return [next_state, kernel_results]
예제 #2
0
def auto_correlation(
    x,
    axis=-1,
    max_lags=None,
    center=True,
    normalize=True,
    name="auto_correlation"):
  """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as  (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x:  Floating point or complex `Tensor` (anything convertible to one).
    axis:  Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.  Negative values count
      from the last dimension.
    max_lags:  Positive `int` tensor.  The maximum value of `m` to consider
      (in equation above).  If `max_lags >= x.shape[axis]`, we effectively
      re-set `max_lags` to `x.shape[axis] - 1`.  If `None` (the default),
      `x.shape[axis] - 1` is used.
    center:  Python `bool`.  If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `True`, the result is divided by its lag-0
      entry, which plays the role of the variance estimate `s**2`.  If
      `False`, this division is skipped.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError:  If `x.dtype` is neither floating point nor complex.
  """
  # Implementation details:
  # Extend length N / 2 1-D array x to length N by zero padding onto the end.
  # Then, set
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # It is not hard to see that
  #   F[x]_k Conj(F[x]_k) = F[R]_k, where
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
  # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

  # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
  # based version of estimating RXX.
  # Note that this is a special case of the Wiener-Khinchin Theorem.
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")

    # Rotate dimensions of x in order to put axis at the rightmost dim.
    # FFT op requires this.
    rank = util.prefer_static_rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # Suppose x.shape[axis] = T, so there are T "time" steps.
    #   ==> x_rotated.shape = B + [T],
    # where B is x_rotated's batch shape.
    x_rotated = util.rotate_transpose(x, shift)

    if center:
      # Subtract the per-series mean taken over the time (last) dimension.
      x_rotated -= math_ops.reduce_mean(x_rotated, axis=-1, keepdims=True)

    # x_len = N / 2 from above explanation.  The length of x along axis.
    # Get a value for x_len that works in all cases.
    x_len = util.prefer_static_shape(x_rotated)[-1]

    # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
    # the moment it is necessary so that all FFT implementations work.
    # Zero pad to the smallest power of 2 that is at least 2 * x_len, which
    # equals 2**(ceil(Log_2(2 * x_len))).  Note: Log_2(X) = Log_e(X) / Log_e(2).
    x_len_float64 = math_ops.cast(x_len, np.float64)
    target_length = math_ops.pow(
        np.float64(2.),
        math_ops.ceil(math_ops.log(x_len_float64 * 2) / np.log(2.)))
    pad_length = math_ops.cast(target_length - x_len_float64, np.int32)

    # We should have:
    # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
    #                     = B + [T + pad_length]
    x_rotated_pad = util.pad(x_rotated, axis=-1, back=True, count=pad_length)

    dtype = x.dtype
    if not dtype.is_complex:
      if not dtype.is_floating:
        raise TypeError("Argument x must have either float or complex dtype"
                        " found: {}".format(dtype))
      # Real input: attach a zero imaginary part so the FFT receives complex
      # values.
      x_rotated_pad = math_ops.complex(x_rotated_pad,
                                       dtype.real_dtype.as_numpy_dtype(0.))

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    fft_x_rotated_pad = spectral_ops.fft(x_rotated_pad)
    spectral_density = fft_x_rotated_pad * math_ops.conj(fft_x_rotated_pad)
    # shifted_product is R[m] from above detailed explanation.
    # It is the inner product sum_n X[n] * Conj(X[n - m]).
    shifted_product = spectral_ops.ifft(spectral_density)

    # Cast back to real-valued if x was real to begin with.
    shifted_product = math_ops.cast(shifted_product, dtype)

    # Figure out if we can deduce the final static shape, and set max_lags.
    # Use x_rotated as a reference, because it has the time dimension in the far
    # right, and was created before we performed all sorts of crazy shape
    # manipulations.
    know_static_shape = True
    if not x_rotated.shape.is_fully_defined():
      know_static_shape = False
    if max_lags is None:
      # Default: keep every lag available in the un-padded series.
      max_lags = x_len - 1
    else:
      max_lags = ops.convert_to_tensor(max_lags, name="max_lags")
      max_lags_ = tensor_util.constant_value(max_lags)
      if max_lags_ is None or not know_static_shape:
        # Either max_lags or the shape of x is only known at graph run time,
        # so the clamp has to be a graph op and no static shape can be set.
        know_static_shape = False
        max_lags = math_ops.minimum(x_len - 1, max_lags)
      else:
        # Everything is static: clamp with plain Python.
        max_lags = min(x_len - 1, max_lags_)

    # Chop off the padding.
    # We allow users to provide a huge max_lags, but cut it off here.
    # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags + 1]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    # If possible, set shape.
    if know_static_shape:
      chopped_shape = x_rotated.shape.as_list()
      chopped_shape[-1] = min(x_len, max_lags + 1)
      shifted_product_chopped.set_shape(chopped_shape)

    # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
    # other terms were zeros arising only due to zero padding.
    # `denominator = (N / 2 - m)` (defined below) is the proper term to
    # divide by to make this an unbiased estimate of the expectation
    # E[X[n] Conj(X[n - m])].
    x_len = math_ops.cast(x_len, dtype.real_dtype)
    max_lags = math_ops.cast(max_lags, dtype.real_dtype)
    denominator = x_len - math_ops.range(0., max_lags + 1.)
    denominator = math_ops.cast(denominator, dtype)
    shifted_product_rotated = shifted_product_chopped / denominator

    if normalize:
      # Divide by the lag-0 entry so the lag-0 output becomes 1.
      shifted_product_rotated /= shifted_product_rotated[..., :1]

    # Transpose dimensions back to those of x.
    return util.rotate_transpose(shifted_product_rotated, -shift)
예제 #3
0
def auto_correlation(
    x,
    axis=-1,
    max_lags=None,
    center=True,
    normalize=True,
    name="auto_correlation"):
  """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as  (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x:  Floating point or complex `Tensor` (anything convertible to one).
    axis:  Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.  Negative values count
      from the last dimension.
    max_lags:  Positive `int` tensor.  The maximum value of `m` to consider
      (in equation above).  If `max_lags >= x.shape[axis]`, we effectively
      re-set `max_lags` to `x.shape[axis] - 1`.  If `None` (the default),
      `x.shape[axis] - 1` is used.
    center:  Python `bool`.  If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `True`, the result is divided by its lag-0
      entry, which plays the role of the variance estimate `s**2`.  If
      `False`, this division is skipped.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError:  If `x.dtype` is neither floating point nor complex.
  """
  # Implementation details:
  # Extend length N / 2 1-D array x to length N by zero padding onto the end.
  # Then, set
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # It is not hard to see that
  #   F[x]_k Conj(F[x]_k) = F[R]_k, where
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
  # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

  # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
  # based version of estimating RXX.
  # Note that this is a special case of the Wiener-Khinchin Theorem.
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")

    # Rotate dimensions of x in order to put axis at the rightmost dim.
    # FFT op requires this.
    rank = util.prefer_static_rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # Suppose x.shape[axis] = T, so there are T "time" steps.
    #   ==> x_rotated.shape = B + [T],
    # where B is x_rotated's batch shape.
    x_rotated = util.rotate_transpose(x, shift)

    if center:
      # Subtract the per-series mean taken over the time (last) dimension.
      x_rotated -= math_ops.reduce_mean(x_rotated, axis=-1, keepdims=True)

    # x_len = N / 2 from above explanation.  The length of x along axis.
    # Get a value for x_len that works in all cases.
    x_len = util.prefer_static_shape(x_rotated)[-1]

    # TODO (langmore) Investigate whether this zero padding helps or hurts.  At id:595 gh:596
    # the moment it is necessary so that all FFT implementations work.
    # Zero pad to the smallest power of 2 that is at least 2 * x_len, which
    # equals 2**(ceil(Log_2(2 * x_len))).  Note: Log_2(X) = Log_e(X) / Log_e(2).
    x_len_float64 = math_ops.cast(x_len, np.float64)
    target_length = math_ops.pow(
        np.float64(2.),
        math_ops.ceil(math_ops.log(x_len_float64 * 2) / np.log(2.)))
    pad_length = math_ops.cast(target_length - x_len_float64, np.int32)

    # We should have:
    # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
    #                     = B + [T + pad_length]
    x_rotated_pad = util.pad(x_rotated, axis=-1, back=True, count=pad_length)

    dtype = x.dtype
    if not dtype.is_complex:
      if not dtype.is_floating:
        raise TypeError("Argument x must have either float or complex dtype"
                        " found: {}".format(dtype))
      # Real input: attach a zero imaginary part so the FFT receives complex
      # values.
      x_rotated_pad = math_ops.complex(x_rotated_pad,
                                       dtype.real_dtype.as_numpy_dtype(0.))

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    fft_x_rotated_pad = spectral_ops.fft(x_rotated_pad)
    spectral_density = fft_x_rotated_pad * math_ops.conj(fft_x_rotated_pad)
    # shifted_product is R[m] from above detailed explanation.
    # It is the inner product sum_n X[n] * Conj(X[n - m]).
    shifted_product = spectral_ops.ifft(spectral_density)

    # Cast back to real-valued if x was real to begin with.
    shifted_product = math_ops.cast(shifted_product, dtype)

    # Figure out if we can deduce the final static shape, and set max_lags.
    # Use x_rotated as a reference, because it has the time dimension in the far
    # right, and was created before we performed all sorts of crazy shape
    # manipulations.
    know_static_shape = True
    if not x_rotated.shape.is_fully_defined():
      know_static_shape = False
    if max_lags is None:
      # Default: keep every lag available in the un-padded series.
      max_lags = x_len - 1
    else:
      max_lags = ops.convert_to_tensor(max_lags, name="max_lags")
      max_lags_ = tensor_util.constant_value(max_lags)
      if max_lags_ is None or not know_static_shape:
        # Either max_lags or the shape of x is only known at graph run time,
        # so the clamp has to be a graph op and no static shape can be set.
        know_static_shape = False
        max_lags = math_ops.minimum(x_len - 1, max_lags)
      else:
        # Everything is static: clamp with plain Python.
        max_lags = min(x_len - 1, max_lags_)

    # Chop off the padding.
    # We allow users to provide a huge max_lags, but cut it off here.
    # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags + 1]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    # If possible, set shape.
    if know_static_shape:
      chopped_shape = x_rotated.shape.as_list()
      chopped_shape[-1] = min(x_len, max_lags + 1)
      shifted_product_chopped.set_shape(chopped_shape)

    # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
    # other terms were zeros arising only due to zero padding.
    # `denominator = (N / 2 - m)` (defined below) is the proper term to
    # divide by to make this an unbiased estimate of the expectation
    # E[X[n] Conj(X[n - m])].
    x_len = math_ops.cast(x_len, dtype.real_dtype)
    max_lags = math_ops.cast(max_lags, dtype.real_dtype)
    denominator = x_len - math_ops.range(0., max_lags + 1.)
    denominator = math_ops.cast(denominator, dtype)
    shifted_product_rotated = shifted_product_chopped / denominator

    if normalize:
      # Divide by the lag-0 entry so the lag-0 output becomes 1.
      shifted_product_rotated /= shifted_product_rotated[..., :1]

    # Transpose dimensions back to those of x.
    return util.rotate_transpose(shifted_product_rotated, -shift)
예제 #4
0
    def one_step(self, current_state, previous_kernel_results):
        """Takes one Metropolis-Hastings step around the inner kernel.

        The inner kernel proposes a new state; the proposal is then accepted
        or rejected element-wise by comparing the log of a uniform draw with
        the log acceptance ratio.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing
            the current state(s) of the Markov chain(s).
          previous_kernel_results: A (possibly nested) `tuple`, `namedtuple`
            or `list` of `Tensor`s representing internal calculations made
            within the previous call to this function (or as returned by
            `bootstrap_results`). Its `accepted_results` member is what gets
            passed to the inner kernel.

        Returns:
          next_state: `Tensor` or Python `list` of `Tensor`s representing the
            next state(s) of the Markov chain(s).
          kernel_results: A `MetropolisHastingsKernelResults` holding the
            accepted results, the acceptance mask, the log acceptance ratio
            and the raw proposal.

        Raises:
          ValueError: if `inner_kernel` results doesn't contain the member
            "target_log_prob".
        """
        last_accepted = previous_kernel_results.accepted_results

        # Take one inner step.
        proposed_state, proposed_results = self.inner_kernel.one_step(
            current_state, last_accepted)

        both_have_target_log_prob = (
            has_target_log_prob(proposed_results)
            and has_target_log_prob(last_accepted))
        if not both_have_target_log_prob:
            raise ValueError('"target_log_prob" must be a member of '
                             '`inner_kernel` results.')

        # Compute log(acceptance_ratio).
        log_ratio_terms = [
            proposed_results.target_log_prob,
            -last_accepted.target_log_prob,
        ]
        try:
            correction = proposed_results.log_acceptance_correction
        except AttributeError:
            warnings.warn(
                'Supplied inner `TransitionKernel` does not have a '
                '`log_acceptance_correction`. Assuming its value is `0.`')
        else:
            log_ratio_terms.append(correction)
        log_accept_ratio = mcmc_util.safe_sum(
            log_ratio_terms, name='compute_log_accept_ratio')

        # Acceptance rule:
        #   u < min(1, accept_ratio), with u ~ Uniform[0, 1)
        #   <==> log(u) < log_accept_ratio,
        # so a proposal that raises the target density is always accepted and
        # one that lowers it is accepted at random.
        # The seed is refreshed BEFORE it is used so that successive calls are
        # not correlated, and so that a user handing the same seed to the
        # inner kernel does not get identical draws here.
        self._seed = distributions_util.gen_new_seed(
            self.seed, salt='metropolis_hastings_one_step')
        uniform_draw = tf.random_uniform(
            shape=tf.shape(proposed_results.target_log_prob),
            dtype=proposed_results.target_log_prob.dtype.base_dtype,
            seed=self.seed)
        is_accepted = tf.log(uniform_draw) < log_accept_ratio

        independent_chain_ndims = distributions_util.prefer_static_rank(
            proposed_results.target_log_prob)

        next_state = mcmc_util.choose(is_accepted, proposed_state,
                                      current_state, independent_chain_ndims)

        # Field by field, keep the proposed value where the proposal was
        # accepted and the previously accepted value elsewhere.
        chosen_fields = {
            field: mcmc_util.choose(
                is_accepted,
                getattr(proposed_results, field),
                getattr(last_accepted, field),
                independent_chain_ndims)
            for field in proposed_results._fields
        }
        accepted_results = type(proposed_results)(**chosen_fields)

        kernel_results = MetropolisHastingsKernelResults(
            accepted_results=accepted_results,
            is_accepted=is_accepted,
            log_accept_ratio=log_accept_ratio,
            proposed_state=proposed_state,
            proposed_results=proposed_results,
        )
        return [next_state, kernel_results]
def move_dimension(x, source_idx, dest_idx):
    """Relocate one dimension of `x`, keeping all others in order.

    This is a special case of `tf.transpose()`, which applies
    arbitrary permutations to tensor dimensions.

    Args:
      x: Tensor of rank `ndims`.
      source_idx: Integer index into `x.shape` (negative indexing is
        supported).
      dest_idx: Integer index into `x.shape` (negative indexing is
        supported).

    Returns:
      x_perm: Tensor of rank `ndims`, in which the dimension at original
        index `source_idx` has been moved to new index `dest_idx`, with
        all other dimensions retained in their original order. When the
        two indices are equal, `x` itself is returned.

    Example:

    ```python
    x = tf.compat.v1.placeholder(tf.float32, shape=[200, 30, 4, 1, 6])
    x_perm = move_dimension(x, 1, 1) # no-op
    x_perm = move_dimension(x, 0, 3) # result shape [30, 4, 1, 200, 6]
    x_perm = move_dimension(x, 0, -2) # equivalent to previous
    x_perm = move_dimension(x, 4, 2) # result shape [200, 30, 6, 4, 1]
    ```
    """
    rank = util.prefer_static_rank(x)
    index_dtype = (dtypes.int32 if isinstance(source_idx, int)
                   else dtypes.as_dtype(source_idx.dtype))

    # Resolve negative indices. `rank` may be dynamic, in which case the
    # resolved indices become dynamic as well.
    if source_idx < 0:
        source_idx = rank + source_idx
    if dest_idx < 0:
        dest_idx = rank + dest_idx

    def _span(start, limit):
        """Consecutive dimension indices `start, ..., limit - 1`."""
        return math_ops.range(start, limit, dtype=index_dtype)

    def _perm_source_after_dest():
        """Permutation used when the source sits at or after the target."""
        pieces = [
            _span(0, dest_idx),
            [source_idx],
            _span(dest_idx, source_idx),
            _span(source_idx + 1, rank),
        ]
        return util.prefer_static_value(array_ops.concat(pieces, axis=0))

    def _perm_source_before_dest():
        """Permutation used when the source sits before the target."""
        pieces = [
            _span(0, source_idx),
            _span(source_idx + 1, dest_idx + 1),
            [source_idx],
            _span(dest_idx + 1, rank),
        ]
        return util.prefer_static_value(array_ops.concat(pieces, axis=0))

    def _transposed():
        """Applies whichever permutation matches the move direction."""
        perm = smart_cond.smart_cond(source_idx < dest_idx,
                                     _perm_source_before_dest,
                                     _perm_source_after_dest)
        return array_ops.transpose(x, perm=perm)

    # Equal indices mean there is nothing to move; skip the transpose.
    indices_equal = math_ops.equal(source_idx, dest_idx)
    return smart_cond.smart_cond(indices_equal, lambda: x, _transposed)
예제 #6
0
def move_dimension(x, source_idx, dest_idx):
  """Move a single dimension of `x` to a new position in its shape.

  This is a special case of `tf.transpose()`, which applies
  arbitrary permutations to tensor dimensions.

  Args:
    x: Tensor of rank `ndims`.
    source_idx: Integer index into `x.shape` (negative indexing is
      supported).
    dest_idx: Integer index into `x.shape` (negative indexing is
      supported).

  Returns:
    x_perm: Tensor of rank `ndims`, in which the dimension at original
      index `source_idx` has been moved to new index `dest_idx`, with
      all other dimensions retained in their original order. When the
      two indices are equal, `x` itself is returned.

  Example:

  ```python
  x = tf.placeholder(tf.float32, shape=[200, 30, 4, 1, 6])
  x_perm = move_dimension(x, 1, 1) # no-op
  x_perm = move_dimension(x, 0, 3) # result shape [30, 4, 1, 200, 6]
  x_perm = move_dimension(x, 0, -2) # equivalent to previous
  x_perm = move_dimension(x, 4, 2) # result shape [200, 30, 6, 4, 1]
  ```
  """
  num_dims = util.prefer_static_rank(x)
  perm_dtype = (
      tf.int32 if isinstance(source_idx, int)
      else tf.as_dtype(source_idx.dtype))

  # Turn negative indices into non-negative ones. `num_dims` may be dynamic,
  # which makes the resolved indices possibly dynamic too.
  if source_idx < 0:
    source_idx = num_dims + source_idx
  if dest_idx < 0:
    dest_idx = num_dims + dest_idx

  def _join(segments):
    """Concatenates index segments into a single permutation vector."""
    return util.prefer_static_value(tf.concat(segments, axis=0))

  def _shift_towards_front():
    """Permutation for a source located at or after the destination."""
    return _join([
        tf.range(0, dest_idx, dtype=perm_dtype),
        [source_idx],
        tf.range(dest_idx, source_idx, dtype=perm_dtype),
        tf.range(source_idx + 1, num_dims, dtype=perm_dtype),
    ])

  def _shift_towards_back():
    """Permutation for a source located before the destination."""
    return _join([
        tf.range(0, source_idx, dtype=perm_dtype),
        tf.range(source_idx + 1, dest_idx + 1, dtype=perm_dtype),
        [source_idx],
        tf.range(dest_idx + 1, num_dims, dtype=perm_dtype),
    ])

  def _apply_move():
    """Transposes `x` with the permutation matching the move direction."""
    chosen_perm = smart_cond.smart_cond(
        source_idx < dest_idx, _shift_towards_back, _shift_towards_front)
    return tf.transpose(x, perm=chosen_perm)

  def _leave_untouched():
    """Returns `x` as is (source and destination coincide)."""
    return x

  # Special-case equal indices so no transpose op is needed for a no-op move.
  same_position = tf.equal(source_idx, dest_idx)
  return smart_cond.smart_cond(same_position, _leave_untouched, _apply_move)
예제 #7
0
    def one_step(self, current_state, previous_kernel_results):
        """Runs one uncalibrated HMC proposal via `_leapfrog_integrator`.

        Args:
          current_state: `Tensor` or list-like of `Tensor`s holding the
            current state. When it is not list-like (according to
            `mcmc_util.is_list_like`), only the first next-state part is
            returned.
          previous_kernel_results: Object exposing `target_log_prob` and
            `grads_target_log_prob`; both are forwarded to `_prepare_args`.

        Returns:
          A two-element list `[next_state, kernel_results]`, where
          `kernel_results` is an
          `UncalibratedHamiltonianMonteCarloKernelResults`.
        """
        scope_values = [
            self.step_size,
            self.num_leapfrog_steps,
            self.seed,
            current_state,
            previous_kernel_results.target_log_prob,
            previous_kernel_results.grads_target_log_prob,
        ]
        with tf.name_scope(self.name, 'hmc_kernel', scope_values):
            with tf.name_scope('initialize'):
                (
                    state_parts,
                    step_sizes,
                    target_log_prob,
                    grads_target_log_prob,
                ) = _prepare_args(
                    self.target_log_prob_fn,
                    current_state,
                    self.step_size,
                    previous_kernel_results.target_log_prob,
                    previous_kernel_results.grads_target_log_prob,
                    maybe_expand=True)

                def _draw_momentum(state_part):
                    """Samples a normal momentum shaped like `state_part`."""
                    # The seed is replaced BEFORE it is consumed:
                    # - successive draws then are not correlated, and
                    # - a user supplying the same seed to an outer kernel,
                    #   e.g., `MetropolisHastings`, does not get identical
                    #   draws here.
                    self._seed = distributions_util.gen_new_seed(
                        self.seed, salt='hmc_kernel_momentums')
                    return tf.random_normal(
                        shape=tf.shape(state_part),
                        dtype=state_part.dtype.base_dtype,
                        seed=self.seed)

                initial_momentums = [
                    _draw_momentum(part) for part in state_parts
                ]

                leapfrog_steps = tf.convert_to_tensor(
                    self.num_leapfrog_steps,
                    dtype=tf.int32,
                    name='num_leapfrog_steps')

            independent_chain_ndims = distributions_util.prefer_static_rank(
                target_log_prob)

            (
                final_momentums,
                final_state_parts,
                final_target_log_prob,
                final_grads_target_log_prob,
            ) = _leapfrog_integrator(
                initial_momentums,
                self.target_log_prob_fn,
                state_parts,
                step_sizes,
                leapfrog_steps,
                target_log_prob,
                grads_target_log_prob)

            # Hand back the same structure the caller supplied.
            if mcmc_util.is_list_like(current_state):
                next_state = final_state_parts
            else:
                next_state = final_state_parts[0]

            log_acceptance_correction = _compute_log_acceptance_correction(
                initial_momentums, final_momentums, independent_chain_ndims)

            kernel_results = UncalibratedHamiltonianMonteCarloKernelResults(
                log_acceptance_correction=log_acceptance_correction,
                target_log_prob=final_target_log_prob,
                grads_target_log_prob=final_grads_target_log_prob,
            )
            return [next_state, kernel_results]
예제 #8
0
    def one_step(self, current_state, previous_kernel_results):
        """Runs one uncalibrated Langevin proposal step.

        A normal draw is made for every state part, `_euler_method` turns the
        draws into the next state parts, and `_prepare_args` is called again
        on those parts to obtain the quantities stored in the results.

        Args:
          current_state: `Tensor` or list-like of `Tensor`s holding the
            current state. When it is not list-like (according to
            `mcmc_util.is_list_like`), single-element outputs are returned
            un-wrapped.
          previous_kernel_results: Object exposing `target_log_prob`,
            `grads_target_log_prob`, `volatility`, `grads_volatility` and
            `diffusion_drift`.

        Returns:
          A two-element list `[next_state, kernel_results]`, where
          `kernel_results` is an `UncalibratedLangevinKernelResults`.
        """
        scope_name = mcmc_util.make_name(self.name, 'mala', 'one_step')
        scope_values = [
            self.step_size,
            current_state,
            previous_kernel_results.target_log_prob,
            previous_kernel_results.grads_target_log_prob,
            previous_kernel_results.volatility,
            previous_kernel_results.diffusion_drift,
        ]
        with tf.name_scope(name=scope_name, values=scope_values):
            with tf.name_scope('initialize'):
                # Prepare input arguments to be passed to `_euler_method`.
                (
                    state_parts,
                    step_size_parts,
                    target_log_prob,
                    _,  # grads_target_log_prob
                    volatility_parts,
                    _,  # grads_volatility
                    drift_parts,
                ) = _prepare_args(
                    self.target_log_prob_fn,
                    self.volatility_fn,
                    current_state,
                    self.step_size,
                    previous_kernel_results.target_log_prob,
                    previous_kernel_results.grads_target_log_prob,
                    previous_kernel_results.volatility,
                    previous_kernel_results.grads_volatility)

                # One normal draw per state part; each draw pulls a fresh
                # seed from the seed stream, in state-part order.
                noise_parts = [
                    tf.random_normal(shape=tf.shape(part),
                                     dtype=part.dtype.base_dtype,
                                     seed=self._seed_stream())
                    for part in state_parts
                ]

            # Number of independent chains run by the algorithm.
            independent_chain_ndims = distributions_util.prefer_static_rank(
                target_log_prob)

            # Generate the next state of the algorithm using Euler-Maruyama
            # method.
            proposed_state_parts = _euler_method(
                noise_parts,
                state_parts,
                drift_parts,
                step_size_parts,
                volatility_parts)

            # Re-run `_prepare_args` at the proposed state to get the values
            # needed by `_compute_log_acceptance_correction` and by the next
            # call to `one_step`.
            (
                _,  # state_parts
                _,  # step_sizes
                proposed_target_log_prob,
                proposed_grads_target_log_prob,
                proposed_volatility_parts,
                proposed_grads_volatility,
                proposed_drift_parts,
            ) = _prepare_args(
                self.target_log_prob_fn,
                self.volatility_fn,
                proposed_state_parts,
                step_size_parts)

            def _match_input_structure(parts):
                """Un-wraps `parts` when the caller passed a bare state."""
                if mcmc_util.is_list_like(current_state):
                    return parts
                return parts[0]

            next_state = _match_input_structure(proposed_state_parts)

            log_acceptance_correction = _compute_log_acceptance_correction(
                state_parts, proposed_state_parts,
                volatility_parts, proposed_volatility_parts,
                drift_parts, proposed_drift_parts,
                step_size_parts,
                independent_chain_ndims)

            kernel_results = UncalibratedLangevinKernelResults(
                log_acceptance_correction=log_acceptance_correction,
                target_log_prob=proposed_target_log_prob,
                grads_target_log_prob=proposed_grads_target_log_prob,
                volatility=_match_input_structure(proposed_volatility_parts),
                grads_volatility=proposed_grads_volatility,
                diffusion_drift=proposed_drift_parts)
            return [next_state, kernel_results]
예제 #9
0
def kernel(target_log_prob_fn,
           current_state,
           step_size,
           num_leapfrog_steps,
           seed=None,
           current_target_log_prob=None,
           current_grads_target_log_prob=None,
           name=None):
    """Runs one iteration of Hamiltonian Monte Carlo.

    Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC)
    algorithm that takes a series of gradient-informed steps to produce
    a Metropolis proposal. This function applies one step of HMC to
    randomly update `current_state`.

    This function can update multiple chains in parallel. It assumes that all
    leftmost dimensions of `current_state` index independent chain states (and
    are therefore updated independently). The output of `target_log_prob_fn()`
    should sum log-probabilities across all event dimensions. Slices along the
    rightmost dimensions may have different target distributions; for example,
    `current_state[0, :]` could have a different target distribution from
    `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of
    independent chains is `tf.size(target_log_prob_fn(*current_state))`.)

    #### Examples:

    ##### Simple chain with warm-up.

    ```python
    tfd = tf.contrib.distributions

    # Tuning acceptance rates:
    dtype = np.float32
    target_accept_rate = 0.631
    num_warmup_iter = 500
    num_chain_iter = 500

    x = tf.get_variable(name="x", initializer=dtype(1))
    step_size = tf.get_variable(name="step_size", initializer=dtype(1))

    target = tfd.Normal(loc=dtype(0), scale=dtype(1))

    # `kernel` returns `[next_state, kernel_results]`.
    new_x, other_results = hmc.kernel(
        target_log_prob_fn=target.log_prob,
        current_state=x,
        step_size=step_size,
        num_leapfrog_steps=3)

    x_update = x.assign(new_x)

    step_size_update = step_size.assign_add(
        step_size * tf.where(
          other_results.acceptance_probs > target_accept_rate,
          0.01, -0.01))

    warmup = tf.group([x_update, step_size_update])

    tf.global_variables_initializer().run()

    sess.graph.finalize()  # No more graph building.

    # Warm up the sampler and adapt the step size
    for _ in range(num_warmup_iter):
      sess.run(warmup)

    # Collect samples without adapting step size
    samples = np.zeros([num_chain_iter])
    for i in range(num_chain_iter):
      _, x_, target_log_prob_, grad_ = sess.run([
          x_update,
          x,
          other_results.current_target_log_prob,
          other_results.current_grads_target_log_prob])
      samples[i] = x_

    print(samples.mean(), samples.std())
    ```

    ##### Sample from more complicated posterior.

    I.e.,

    ```none
      W ~ MVN(loc=0, scale=sigma * eye(dims))
      for i=1...num_samples:
          X[i] ~ MVN(loc=0, scale=eye(dims))
        eps[i] ~ Normal(loc=0, scale=1)
          Y[i] = X[i].T * W + eps[i]
    ```

    ```python
    tfd = tf.contrib.distributions

    def make_training_data(num_samples, dims, sigma):
      dt = np.asarray(sigma).dtype
      zeros = tf.zeros(dims, dtype=dt)
      x = tfd.MultivariateNormalDiag(
          loc=zeros).sample(num_samples, seed=1)
      w = tfd.MultivariateNormalDiag(
          loc=zeros,
          scale_identity_multiplier=sigma).sample(seed=2)
      noise = tfd.Normal(
          loc=dt(0),
          scale=dt(1)).sample(num_samples, seed=3)
      y = tf.tensordot(x, w, axes=[[1], [0]]) + noise
      return y, x, w

    def make_prior(sigma, dims):
      # p(w | sigma)
      return tfd.MultivariateNormalDiag(
          loc=tf.zeros([dims], dtype=sigma.dtype),
          scale_identity_multiplier=sigma)

    def make_likelihood(x, w):
      # p(y | x, w)
      return tfd.MultivariateNormalDiag(
          loc=tf.tensordot(x, w, axes=[[1], [0]]))

    # Setup assumptions.
    dtype = np.float32
    num_samples = 150
    dims = 10
    num_iters = int(5e3)

    true_sigma = dtype(0.5)
    y, x, true_weights = make_training_data(num_samples, dims, true_sigma)

    # Estimate of `log(true_sigma)`.
    log_sigma = tf.get_variable(name="log_sigma", initializer=dtype(0))
    sigma = tf.exp(log_sigma)

    # State of the Markov chain.
    weights = tf.get_variable(
        name="weights",
        initializer=np.random.randn(dims).astype(dtype))

    prior = make_prior(sigma, dims)

    def joint_log_prob_fn(w):
      # f(w) = log p(w, y | x)
      return prior.log_prob(w) + make_likelihood(x, w).log_prob(y)

    weights_update = weights.assign(
        hmc.kernel(target_log_prob_fn=joint_log_prob_fn,
                   current_state=weights,
                   step_size=0.1,
                   num_leapfrog_steps=5)[0])

    with tf.control_dependencies([weights_update]):
      loss = -prior.log_prob(weights)

    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
    log_sigma_update = optimizer.minimize(loss, var_list=[log_sigma])

    # Build the initializer op before freezing the graph.
    tf.global_variables_initializer().run()

    sess.graph.finalize()  # No more graph building.

    sigma_history = np.zeros(num_iters, dtype)
    weights_history = np.zeros([num_iters, dims], dtype)

    for i in range(num_iters):
      _, sigma_, weights_ = sess.run([log_sigma_update, sigma, weights])
      weights_history[i, :] = weights_
      sigma_history[i] = sigma_

    true_weights_ = sess.run(true_weights)

    # Should converge to something close to true_sigma.
    plt.plot(sigma_history);
    plt.ylabel("sigma");
    plt.xlabel("iteration");
    ```

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s). The first `r` dimensions
        index independent chains,
        `r = tf.rank(target_log_prob_fn(*current_state))`.
      step_size: `Tensor` or Python `list` of `Tensor`s representing the step
        size for the leapfrog integrator. Must broadcast with the shape of
        `current_state`. Larger step sizes lead to faster progress, but
        too-large step sizes make rejection exponentially more likely. When
        possible, it's often helpful to match per-variable step sizes to the
        standard deviations of the target distribution in each variable.
      num_leapfrog_steps: Integer number of steps to run the leapfrog
        integrator for. Total progress per HMC step is roughly proportional to
        `step_size * num_leapfrog_steps`.
      seed: Python integer to seed the random number generator.
      current_target_log_prob: (Optional) `Tensor` representing the value of
        `target_log_prob_fn` at the `current_state`. The only reason to
        specify this argument is to reduce TF graph size.
        Default value: `None` (i.e., compute as needed).
      current_grads_target_log_prob: (Optional) Python list of `Tensor`s
        representing gradient of `current_target_log_prob` at the
        `current_state` and wrt the `current_state`. Must have same shape as
        `current_state`. The only reason to specify this argument is to reduce
        TF graph size.
        Default value: `None` (i.e., compute as needed).
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., "hmc_kernel").

    Returns:
      accepted_state: Tensor or Python list of `Tensor`s representing the
        state(s) of the Markov chain(s) after this step. A single `Tensor` is
        returned when `current_state` was not list-like.
      kernel_results: `KernelResults` holding the internal calculations of
        this step: `acceptance_probs`, `current_grads_target_log_prob`,
        `current_target_log_prob` (both evaluated at the accepted state),
        `energy_change`, `is_accepted`, `proposed_grads_target_log_prob`,
        `proposed_state`, `proposed_target_log_prob` and `random_positive`.

    Raises:
      ValueError: if there isn't one `step_size` or a list with same length as
        `current_state`.
    """
    with ops.name_scope(name, "hmc_kernel", [
            current_state, step_size, num_leapfrog_steps, seed,
            current_target_log_prob, current_grads_target_log_prob
    ]):
        with ops.name_scope("initialize"):
            [
                current_state_parts, step_sizes, current_target_log_prob,
                current_grads_target_log_prob
            ] = _prepare_args(target_log_prob_fn,
                              current_state,
                              step_size,
                              current_target_log_prob,
                              current_grads_target_log_prob,
                              maybe_expand=True)
            independent_chain_ndims = distributions_util.prefer_static_rank(
                current_target_log_prob)

            # Thread the derived seed through the loop so that every state
            # part is sampled with its own seed. Deriving it from the
            # unchanged `seed` for each part hands the same op-level seed to
            # every `random_normal`, so equally shaped parts would draw the
            # same momentum whenever `seed` is set. The first part keeps the
            # seed it was given before this change.
            current_momentums = []
            momentum_seed = seed
            for state_part in current_state_parts:
                momentum_seed = distributions_util.gen_new_seed(
                    momentum_seed, salt="hmc_kernel_momentums")
                current_momentums.append(
                    random_ops.random_normal(
                        shape=array_ops.shape(state_part),
                        dtype=state_part.dtype.base_dtype,
                        seed=momentum_seed))

        [
            proposed_momentums,
            proposed_state_parts,
            proposed_target_log_prob,
            proposed_grads_target_log_prob,
        ] = _leapfrog_integrator(current_momentums, target_log_prob_fn,
                                 current_state_parts, step_sizes,
                                 num_leapfrog_steps, current_target_log_prob,
                                 current_grads_target_log_prob)

        energy_change = _compute_energy_change(current_target_log_prob,
                                               current_momentums,
                                               proposed_target_log_prob,
                                               proposed_momentums,
                                               independent_chain_ndims)

        # Metropolis accept/reject, done in log space:
        #   u < exp(min(-energy, 0)),  where u~Uniform[0,1)
        #   ==> -log(u) >= max(e, 0)
        #   ==> -log(u) >= e
        # (Perhaps surprisingly, we don't have a better way to obtain a random
        # uniform from positive reals, i.e., `tf.random_uniform(minval=0,
        # maxval=np.inf)` won't work.)
        random_uniform = random_ops.random_uniform(
            shape=array_ops.shape(energy_change),
            dtype=energy_change.dtype,
            seed=seed)
        random_positive = -math_ops.log(random_uniform)
        is_accepted = random_positive >= energy_change

        # Per chain, keep the proposal where accepted and the current value
        # otherwise.
        accepted_target_log_prob = array_ops.where(is_accepted,
                                                   proposed_target_log_prob,
                                                   current_target_log_prob)

        accepted_state_parts = [
            _choose(is_accepted, proposed_state_part, current_state_part,
                    independent_chain_ndims)
            for current_state_part, proposed_state_part in zip(
                current_state_parts, proposed_state_parts)
        ]

        accepted_grads_target_log_prob = [
            _choose(is_accepted, proposed_grad, grad, independent_chain_ndims)
            for proposed_grad, grad in zip(proposed_grads_target_log_prob,
                                           current_grads_target_log_prob)
        ]

        def maybe_flatten(x):
            # Hand back a bare element when the caller passed a single state
            # rather than a list of state parts.
            return x if _is_list_like(current_state) else x[0]

        return [
            maybe_flatten(accepted_state_parts),
            KernelResults(
                acceptance_probs=math_ops.exp(
                    math_ops.minimum(-energy_change, 0.)),
                current_grads_target_log_prob=accepted_grads_target_log_prob,
                current_target_log_prob=accepted_target_log_prob,
                energy_change=energy_change,
                is_accepted=is_accepted,
                proposed_grads_target_log_prob=proposed_grads_target_log_prob,
                proposed_state=maybe_flatten(proposed_state_parts),
                proposed_target_log_prob=proposed_target_log_prob,
                random_positive=random_positive,
            ),
        ]
예제 #10
0
    def one_step(self, current_state, previous_kernel_results):
        """Advances the slice-sampling chain(s) by exactly one transition.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s holding the
        current state(s) of the Markov chain(s). The leading `r` dimensions
        index independent chains, where
        `r = tf.rank(target_log_prob_fn(*current_state))`.
      previous_kernel_results: `collections.namedtuple` of `Tensor`s carried
        over from the previous call to this function (or produced by
        `bootstrap_results`). Only its `target_log_prob` field is read here.

    Returns:
      next_state: Tensor or Python list of `Tensor`s with the state(s) of the
        Markov chain(s) after one step. Same type and shape as
        `current_state`.
      kernel_results: `SliceSamplerKernelResults` namedtuple with the internal
        calculations used to advance the chain.

    Raises:
      ValueError: if there isn't one `step_size` or a list with same length as
        `current_state`.
      TypeError: if `not target_log_prob.dtype.is_floating`.
    """
        scope_name = mcmc_util.make_name(self.name, 'slice', 'one_step')
        scope_values = [
            self.step_size,
            self.max_doublings,
            self._seed_stream,
            current_state,
            previous_kernel_results.target_log_prob,
        ]
        with tf.name_scope(name=scope_name, values=scope_values):
            with tf.name_scope('initialize'):
                # Normalize state / step size into per-part lists.
                prepared = _prepare_args(
                    self.target_log_prob_fn,
                    current_state,
                    self.step_size,
                    previous_kernel_results.target_log_prob,
                    maybe_expand=True)
                state_parts, part_step_sizes, log_prob = prepared

                doublings = tf.convert_to_tensor(
                    self.max_doublings, dtype=tf.int32, name='max_doublings')

            # Rank of the log-prob gives the number of chain dimensions.
            chain_ndims = distributions_util.prefer_static_rank(log_prob)

            sampled = _sample_next(
                self.target_log_prob_fn,
                state_parts,
                part_step_sizes,
                doublings,
                log_prob,
                chain_ndims,
                seed=self._seed_stream())
            (next_parts, next_log_prob, bounds_ok, step_direction,
             upper, lower) = sampled

            # Return the state in the same form the caller supplied it:
            # a list stays a list, a lone tensor is unwrapped.
            if mcmc_util.is_list_like(current_state):
                next_state = next_parts
            else:
                next_state = next_parts[0]

            results = SliceSamplerKernelResults(
                target_log_prob=next_log_prob,
                bounds_satisfied=bounds_ok,
                direction=step_direction,
                upper_bounds=upper,
                lower_bounds=lower)
            return [next_state, results]
예제 #11
0
def kernel(target_log_prob_fn,
           current_state,
           step_size,
           num_leapfrog_steps,
           seed=None,
           current_target_log_prob=None,
           current_grads_target_log_prob=None,
           name=None):
  """Performs a single Hamiltonian Monte Carlo transition.

  Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC)
  algorithm that builds a Metropolis proposal out of a sequence of
  gradient-informed (leapfrog) steps. One call to this function applies one
  such HMC step to `current_state`.

  Several chains can be advanced in parallel. All leftmost dimensions of
  `current_state` are taken to index independent chain states (and are thus
  updated independently). `target_log_prob_fn()` is expected to sum
  log-probabilities over every event dimension. Slices along the rightmost
  dimensions may follow different target distributions; e.g.
  `current_state[0, :]` may have a different target than
  `current_state[1, :]`. That is entirely up to `target_log_prob_fn()`. (The
  number of independent chains is
  `tf.size(target_log_prob_fn(*current_state))`.)

  #### Examples:

  ##### Simple chain with warm-up.

  ```python
  tfd = tf.contrib.distributions

  # Tuning acceptance rates:
  dtype = np.float32
  target_accept_rate = 0.631
  num_warmup_iter = 500
  num_chain_iter = 500

  x = tf.get_variable(name="x", initializer=dtype(1))
  step_size = tf.get_variable(name="step_size", initializer=dtype(1))

  target = tfd.Normal(loc=dtype(0), scale=dtype(1))

  # `kernel` returns `[next_state, kernel_results]`.
  new_x, other_results = hmc.kernel(
      target_log_prob_fn=target.log_prob,
      current_state=x,
      step_size=step_size,
      num_leapfrog_steps=3)

  x_update = x.assign(new_x)

  step_size_update = step_size.assign_add(
      step_size * tf.where(
        other_results.acceptance_probs > target_accept_rate,
        0.01, -0.01))

  warmup = tf.group([x_update, step_size_update])

  tf.global_variables_initializer().run()

  sess.graph.finalize()  # No more graph building.

  # Warm up the sampler and adapt the step size
  for _ in range(num_warmup_iter):
    sess.run(warmup)

  # Collect samples without adapting step size
  samples = np.zeros([num_chain_iter])
  for i in range(num_chain_iter):
    _, x_, target_log_prob_, grad_ = sess.run([
        x_update,
        x,
        other_results.current_target_log_prob,
        other_results.current_grads_target_log_prob])
    samples[i] = x_

  print(samples.mean(), samples.std())
  ```

  ##### Sample from more complicated posterior.

  I.e.,

  ```none
    W ~ MVN(loc=0, scale=sigma * eye(dims))
    for i=1...num_samples:
        X[i] ~ MVN(loc=0, scale=eye(dims))
      eps[i] ~ Normal(loc=0, scale=1)
        Y[i] = X[i].T * W + eps[i]
  ```

  ```python
  tfd = tf.contrib.distributions

  def make_training_data(num_samples, dims, sigma):
    dt = np.asarray(sigma).dtype
    zeros = tf.zeros(dims, dtype=dt)
    x = tfd.MultivariateNormalDiag(
        loc=zeros).sample(num_samples, seed=1)
    w = tfd.MultivariateNormalDiag(
        loc=zeros,
        scale_identity_multiplier=sigma).sample(seed=2)
    noise = tfd.Normal(
        loc=dt(0),
        scale=dt(1)).sample(num_samples, seed=3)
    y = tf.tensordot(x, w, axes=[[1], [0]]) + noise
    return y, x, w

  def make_prior(sigma, dims):
    # p(w | sigma)
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros([dims], dtype=sigma.dtype),
        scale_identity_multiplier=sigma)

  def make_likelihood(x, w):
    # p(y | x, w)
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(x, w, axes=[[1], [0]]))

  # Setup assumptions.
  dtype = np.float32
  num_samples = 150
  dims = 10
  num_iters = int(5e3)

  true_sigma = dtype(0.5)
  y, x, true_weights = make_training_data(num_samples, dims, true_sigma)

  # Estimate of `log(true_sigma)`.
  log_sigma = tf.get_variable(name="log_sigma", initializer=dtype(0))
  sigma = tf.exp(log_sigma)

  # State of the Markov chain.
  weights = tf.get_variable(
      name="weights",
      initializer=np.random.randn(dims).astype(dtype))

  prior = make_prior(sigma, dims)

  def joint_log_prob_fn(w):
    # f(w) = log p(w, y | x)
    return prior.log_prob(w) + make_likelihood(x, w).log_prob(y)

  weights_update = weights.assign(
      hmc.kernel(target_log_prob_fn=joint_log_prob_fn,
                 current_state=weights,
                 step_size=0.1,
                 num_leapfrog_steps=5)[0])

  with tf.control_dependencies([weights_update]):
    loss = -prior.log_prob(weights)

  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  log_sigma_update = optimizer.minimize(loss, var_list=[log_sigma])

  # Build the initializer op before freezing the graph.
  tf.global_variables_initializer().run()

  sess.graph.finalize()  # No more graph building.

  sigma_history = np.zeros(num_iters, dtype)
  weights_history = np.zeros([num_iters, dims], dtype)

  for i in range(num_iters):
    _, sigma_, weights_ = sess.run([log_sigma_update, sigma, weights])
    weights_history[i, :] = weights_
    sigma_history[i] = sigma_

  true_weights_ = sess.run(true_weights)

  # Should converge to something close to true_sigma.
  plt.plot(sigma_history);
  plt.ylabel("sigma");
  plt.xlabel("iteration");
  ```

  Args:
    target_log_prob_fn: Python callable accepting an argument like
      `current_state` (or `*current_state` when it is a list) and returning
      its (possibly unnormalized) log-density under the target distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s with the current
      state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
    step_size: `Tensor` or Python `list` of `Tensor`s giving the leapfrog
      integrator step size. Must broadcast with the shape of `current_state`.
      Larger step sizes make faster progress, but too-large step sizes make
      rejection exponentially more likely. When possible, it often helps to
      match per-variable step sizes to the standard deviations of the target
      distribution in each variable.
    num_leapfrog_steps: Integer number of leapfrog steps to run. Total
      progress per HMC step is roughly proportional to
      `step_size * num_leapfrog_steps`.
    seed: Python integer to seed the random number generator.
    current_target_log_prob: (Optional) `Tensor` with the value of
      `target_log_prob_fn` at `current_state`. The only reason to pass it is
      to reduce TF graph size.
      Default value: `None` (i.e., compute as needed).
    current_grads_target_log_prob: (Optional) Python list of `Tensor`s with
      the gradient of `current_target_log_prob` at, and with respect to,
      `current_state`. Must have the same shape as `current_state`. The only
      reason to pass it is to reduce TF graph size.
      Default value: `None` (i.e., compute as needed).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "hmc_kernel").

  Returns:
    accepted_state: Tensor or Python list of `Tensor`s with the state(s) of
      the Markov chain(s) after this step. Has same shape as `current_state`.
    kernel_results: `collections.namedtuple` of internal calculations used to
      advance the chain.

  Raises:
    ValueError: if there isn't one `step_size` or a list with same length as
      `current_state`.
  """
  scope_inputs = [
      current_state, step_size, num_leapfrog_steps, seed,
      current_target_log_prob, current_grads_target_log_prob,
  ]
  with ops.name_scope(name, "hmc_kernel", scope_inputs):
    with ops.name_scope("initialize"):
      prepared = _prepare_args(
          target_log_prob_fn,
          current_state,
          step_size,
          current_target_log_prob,
          current_grads_target_log_prob,
          maybe_expand=True)
      [
          current_state_parts,
          step_sizes,
          current_target_log_prob,
          current_grads_target_log_prob,
      ] = prepared
      independent_chain_ndims = distributions_util.prefer_static_rank(
          current_target_log_prob)

      # Draw one momentum per state part. The seed is re-derived after every
      # draw, so each part (and, afterwards, the accept/reject draw) is
      # sampled with a different seed value.
      rolling_seed = seed
      current_momentums = []
      for part in current_state_parts:
        momentum = random_ops.random_normal(
            shape=array_ops.shape(part),
            dtype=part.dtype.base_dtype,
            seed=rolling_seed)
        current_momentums.append(momentum)
        rolling_seed = distributions_util.gen_new_seed(
            rolling_seed, salt="hmc_kernel_momentums")

      leapfrog_steps = ops.convert_to_tensor(
          num_leapfrog_steps,
          dtype=dtypes.int32,
          name="num_leapfrog_steps")

    integrated = _leapfrog_integrator(
        current_momentums,
        target_log_prob_fn,
        current_state_parts,
        step_sizes,
        leapfrog_steps,
        current_target_log_prob,
        current_grads_target_log_prob)
    (proposed_momentums,
     proposed_state_parts,
     proposed_target_log_prob,
     proposed_grads_target_log_prob) = integrated

    energy_change = _compute_energy_change(
        current_target_log_prob,
        current_momentums,
        proposed_target_log_prob,
        proposed_momentums,
        independent_chain_ndims)

    # Metropolis test, carried out in log space:
    #   u < exp(min(-energy, 0)),  where u~Uniform[0,1)
    #   ==> -log(u) >= max(e, 0)
    #   ==> -log(u) >= e
    # (Perhaps surprisingly, there is no better way to obtain a random
    # uniform over the positive reals, i.e., `tf.random_uniform(minval=0,
    # maxval=np.inf)` won't work.)
    random_uniform = random_ops.random_uniform(
        shape=array_ops.shape(energy_change),
        dtype=energy_change.dtype,
        seed=rolling_seed)
    random_positive = -math_ops.log(random_uniform)
    is_accepted = random_positive >= energy_change

    # Per chain: keep the proposal where accepted, the current value elsewhere.
    accepted_target_log_prob = array_ops.where(
        is_accepted, proposed_target_log_prob, current_target_log_prob)

    accepted_state_parts = []
    for current_part, proposed_part in zip(current_state_parts,
                                           proposed_state_parts):
      accepted_state_parts.append(
          _choose(is_accepted, proposed_part, current_part,
                  independent_chain_ndims))

    accepted_grads_target_log_prob = []
    for proposed_grad, current_grad in zip(proposed_grads_target_log_prob,
                                           current_grads_target_log_prob):
      accepted_grads_target_log_prob.append(
          _choose(is_accepted, proposed_grad, current_grad,
                  independent_chain_ndims))

    def maybe_flatten(parts):
      # Unwrap the single element when the caller passed a lone state rather
      # than a list of state parts.
      return parts if _is_list_like(current_state) else parts[0]

    accepted_state = maybe_flatten(accepted_state_parts)
    kernel_results = KernelResults(
        acceptance_probs=math_ops.exp(math_ops.minimum(-energy_change, 0.)),
        current_grads_target_log_prob=accepted_grads_target_log_prob,
        current_target_log_prob=accepted_target_log_prob,
        energy_change=energy_change,
        is_accepted=is_accepted,
        proposed_grads_target_log_prob=proposed_grads_target_log_prob,
        proposed_state=maybe_flatten(proposed_state_parts),
        proposed_target_log_prob=proposed_target_log_prob,
        random_positive=random_positive,
    )
    return [accepted_state, kernel_results]
예제 #12
0
  def one_step(self, current_state, previous_kernel_results):
    """Takes a single Slice Sampler step for every chain.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s of fully defined
        static shape holding the current state(s) of the Markov chain(s).
        The leading `r` dimensions index independent chains, where
        `r = tf.rank(target_log_prob_fn(*current_state))`.
      previous_kernel_results: `collections.namedtuple` of `Tensor`s carried
        over from the previous call to this function (or produced by
        `bootstrap_results`). Only its `target_log_prob` field is read here.

    Returns:
      next_state: Tensor or Python list of `Tensor`s with the state(s) of the
        Markov chain(s) after exactly one step. Same type and shape as
        `current_state`.
      kernel_results: `SliceSamplerKernelResults` namedtuple with the internal
        calculations used to advance the chain.

    Raises:
      ValueError: if there isn't one `step_size` or a list with same length as
        `current_state`.
      ValueError: if `current_state` does not have a fully defined static shape.
      TypeError: if `not target_log_prob.dtype.is_floating`.
    """
    previous_log_prob = previous_kernel_results.target_log_prob
    with tf.name_scope(
        name=mcmc_util.make_name(self.name, 'slice', 'one_step'),
        values=[self.step_size, self.max_doublings, self._seed_stream,
                current_state, previous_log_prob]):
      with tf.name_scope('initialize'):
        # Normalize state / step size into per-part lists.
        state_parts, part_step_sizes, log_prob = _prepare_args(
            self.target_log_prob_fn,
            current_state,
            self.step_size,
            previous_kernel_results.target_log_prob,
            maybe_expand=True)

        doublings = tf.convert_to_tensor(
            self.max_doublings, dtype=tf.int32, name='max_doublings')

      # Rank of the log-prob gives the number of chain dimensions.
      chain_ndims = distributions_util.prefer_static_rank(log_prob)

      (next_parts,
       next_log_prob,
       bounds_ok,
       step_direction,
       upper,
       lower) = _sample_next(
           self.target_log_prob_fn,
           state_parts,
           part_step_sizes,
           doublings,
           log_prob,
           chain_ndims,
           seed=self._seed_stream())

      # Return the state in the same form the caller supplied it: a list
      # stays a list, a lone tensor is unwrapped.
      wrap_as_list = mcmc_util.is_list_like(current_state)
      next_state = next_parts if wrap_as_list else next_parts[0]

      kernel_results = SliceSamplerKernelResults(
          target_log_prob=next_log_prob,
          bounds_satisfied=bounds_ok,
          direction=step_direction,
          upper_bounds=upper,
          lower_bounds=lower)
      return [next_state, kernel_results]