Example #1
            def body(i, exchanged_states):
                """Body of while loop for exchanging states."""
                # Propose exchange between replicas indexed by m and n.
                m, n = tf.unstack(exchange_proposed[i])

                # Construct log_accept_ratio:  -temp_diff * target_log_prob_diff.
                # Note target_log_prob_diff = -EnergyDiff (common definition is in terms
                # of energy).
                temp_diff = self.inverse_temperatures[
                    m] - self.inverse_temperatures[n]
                # Difference of target log probs may be +- Inf or NaN.  We want the
                # product of this with the temperature difference to have "alt value" of
                # -Inf.
                log_accept_ratio = mcmc_util.safe_sum([
                    -temp_diff * target_log_probs[m],
                    temp_diff * target_log_probs[n]
                ])

                is_exchange_accepted = log_uniforms[i] < log_accept_ratio

                for k in range(num_state_parts):
                    new_m, new_n = _swap(is_exchange_accepted,
                                         old_states[k].read(m),
                                         old_states[k].read(n))
                    exchanged_states[k] = exchanged_states[k].write(m, new_m)
                    exchanged_states[k] = exchanged_states[k].write(n, new_n)

                return i + 1, exchanged_states
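The loop body above relies on a `_swap` helper that the excerpt does not show. As a rough sketch of what the call site implies (TF1-style API; the name, signature, and the scalar accept decision are assumptions, not the library's actual code):

import tensorflow as tf

def _swap(is_exchange_accepted, x, y):
    """Sketch: return (y, x) when the exchange is accepted, else (x, y)."""
    # tf.cond requires a scalar boolean predicate; a single accept decision
    # per proposed swap is assumed here.
    return tf.cond(is_exchange_accepted,
                   lambda: (y, x),
                   lambda: (x, y))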
Example #2
  def one_step(self, current_state, previous_kernel_results):
    """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.

    Raises:
      ValueError: if `inner_kernel` results doesn't contain the member
        "target_log_prob".
    """
    with tf.name_scope(
        name=mcmc_util.make_name(self.name, 'mh', 'one_step'),
        values=[current_state, previous_kernel_results]):
      # Take one inner step.
      [
          proposed_state,
          proposed_results,
      ] = self.inner_kernel.one_step(
          current_state,
          previous_kernel_results.accepted_results)

      if (not has_target_log_prob(proposed_results) or
          not has_target_log_prob(previous_kernel_results.accepted_results)):
        raise ValueError('"target_log_prob" must be a member of '
                         '`inner_kernel` results.')

      # Compute log(acceptance_ratio).
      to_sum = [proposed_results.target_log_prob,
                -previous_kernel_results.accepted_results.target_log_prob]
      try:
        if (not mcmc_util.is_list_like(
            proposed_results.log_acceptance_correction)
            or proposed_results.log_acceptance_correction):
          to_sum.append(proposed_results.log_acceptance_correction)
      except AttributeError:
        warnings.warn('Supplied inner `TransitionKernel` does not have a '
                      '`log_acceptance_correction`. Assuming its value is `0.`')
      log_accept_ratio = mcmc_util.safe_sum(
          to_sum, name='compute_log_accept_ratio')

      # If proposed state reduces likelihood: randomly accept.
      # If proposed state increases likelihood: always accept.
      # I.e., u < min(1, accept_ratio),  where u ~ Uniform[0,1)
      #       ==> log(u) < log_accept_ratio
      log_uniform = tf.log(tf.random_uniform(
          shape=tf.shape(proposed_results.target_log_prob),
          dtype=proposed_results.target_log_prob.dtype.base_dtype,
          seed=self._seed_stream()))
      is_accepted = log_uniform < log_accept_ratio

      next_state = mcmc_util.choose(
          is_accepted,
          proposed_state,
          current_state,
          name='choose_next_state')

      kernel_results = MetropolisHastingsKernelResults(
          accepted_results=mcmc_util.choose(
              is_accepted,
              proposed_results,
              previous_kernel_results.accepted_results,
              name='choose_inner_results'),
          is_accepted=is_accepted,
          log_accept_ratio=log_accept_ratio,
          proposed_state=proposed_state,
          proposed_results=proposed_results,
          extra=[],
      )

      return next_state, kernel_results
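`has_target_log_prob` is used above without its definition. A plausible sketch of the predicate, offered as an assumption about its behavior rather than the actual source:

def has_target_log_prob(kernel_results):
  """Sketch: True if `kernel_results` carries a `target_log_prob` member."""
  return getattr(kernel_results, 'target_log_prob', None) is not None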
Example #3
File: hmc.py Project: xuxyang/probability
def _compute_log_acceptance_correction(current_momentums,
                                       proposed_momentums,
                                       independent_chain_ndims,
                                       name=None):
    """Helper to `kernel` which computes the log acceptance-correction.

  A sufficient but not necessary condition for the existence of a stationary
  distribution, `p(x)`, is "detailed balance", i.e.:

  ```none
  p(x'|x) p(x) = p(x|x') p(x')
  ```

  In the Metropolis-Hastings algorithm, a state is proposed according to
  `g(x'|x)` and accepted according to `a(x'|x)`, hence
  `p(x'|x) = g(x'|x) a(x'|x)`.

  Inserting this into the detailed balance equation implies:

  ```none
      g(x'|x) a(x'|x) p(x) = g(x|x') a(x|x') p(x')
  ==> a(x'|x) / a(x|x') = p(x') / p(x) [g(x|x') / g(x'|x)]    (*)
  ```

  One definition of `a(x'|x)` which satisfies (*) is:

  ```none
  a(x'|x) = min(1, p(x') / p(x) [g(x|x') / g(x'|x)])
  ```

  (To see that this satisfies (*), notice that under this definition at most
  one of `a(x'|x)` and `a(x|x')` can be other than one.)

  We call the bracketed term the "acceptance correction".

  In the case of UncalibratedHMC, the log acceptance-correction is not the log
  proposal-ratio. UncalibratedHMC augments the state-space with momentum, z.
  Assuming a standard Gaussian distribution for momentums, the chain eventually
  converges to:

  ```none
  p([x, z]) propto= target_prob(x) exp(-0.5 z**2)
  ```

  Relating this back to Metropolis-Hastings parlance, for HMC we have:

  ```none
  p([x, z]) propto= target_prob(x) exp(-0.5 z**2)
  g([x, z] | [x', z']) = g([x', z'] | [x, z])
  ```

  In other words, the MH bracketed term is `1`. However, since we wish to use a
  general MH framework, we can place the momentum probability ratio inside the
  Metropolis-correction factor, giving the acceptance probability:

  ```none
                       target_prob(x')
  accept_prob(x'|x) = -----------------  [exp(-0.5 z**2) / exp(-0.5 z'**2)]
                       target_prob(x)
  ```

  (Note: we actually need to handle the kinetic energy change at each leapfrog
  step, but this is the idea.)

  Args:
    current_momentums: `Tensor` representing the value(s) of the current
      momentum(s) of the state (parts).
    proposed_momentums: `Tensor` representing the value(s) of the proposed
      momentum(s) of the state (parts).
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'compute_log_acceptance_correction').

  Returns:
    log_acceptance_correction: `Tensor` representing the `log`
      acceptance-correction.  (See docstring for mathematical definition.)
  """
    with tf.name_scope(
            name, 'compute_log_acceptance_correction',
        [independent_chain_ndims, current_momentums, proposed_momentums]):
        log_current_kinetic, log_proposed_kinetic = [], []
        for current_momentum, proposed_momentum in zip(current_momentums,
                                                       proposed_momentums):
            axis = tf.range(independent_chain_ndims, tf.rank(current_momentum))
            log_current_kinetic.append(_log_sum_sq(current_momentum, axis))
            log_proposed_kinetic.append(_log_sum_sq(proposed_momentum, axis))
        current_kinetic = 0.5 * tf.exp(
            tf.reduce_logsumexp(tf.stack(log_current_kinetic, axis=-1),
                                axis=-1))
        proposed_kinetic = 0.5 * tf.exp(
            tf.reduce_logsumexp(tf.stack(log_proposed_kinetic, axis=-1),
                                axis=-1))
        return mcmc_util.safe_sum([current_kinetic, -proposed_kinetic])
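`_log_sum_sq` is called above but not shown. A minimal sketch consistent with the call sites, computing `log(sum(x**2))` over `axis` in a numerically stable way; the exact form is an assumption:

import tensorflow as tf

def _log_sum_sq(x, axis=None):
    """Sketch: log(sum(x**2)) via logsumexp, avoiding overflow in x**2."""
    return tf.reduce_logsumexp(2. * tf.log(tf.abs(x)), axis)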
Example #4
def _compute_log_acceptance_correction(current_state_parts,
                                       proposed_state_parts,
                                       current_volatility_parts,
                                       proposed_volatility_parts,
                                       current_drift_parts,
                                       proposed_drift_parts,
                                       step_size_parts,
                                       independent_chain_ndims,
                                       name=None):
    r"""Helper to `kernel` which computes the log acceptance-correction.

  Computes `log_acceptance_correction` as described in `MetropolisHastings`
  class. The proposal density is normal. More specifically,

  ```none
  q(proposed_state | current_state) ~ N(current_state + current_drift,
                                        step_size * current_volatility**2)

  q(current_state | proposed_state) ~ N(proposed_state + proposed_drift,
                                        step_size * proposed_volatility**2)
  ```

  The `log_acceptance_correction` is then

  ```none
  log_acceptance_correction = log q(current_state | proposed_state)
                              - log q(proposed_state | current_state)
  ```

  Args:
    current_state_parts: Python `list` of `Tensor`s representing the value(s) of
      the current state of the chain.
    proposed_state_parts: Python `list` of `Tensor`s representing the value(s)
      of the proposed state of the chain. Must broadcast with the shape of
      `current_state_parts`.
    current_volatility_parts: Python `list` of `Tensor`s representing the value
      of `volatility_fn(*current_state_parts)`. Must broadcast with the shape
      of `current_state_parts`.
    proposed_volatility_parts: Python `list` of `Tensor`s representing the value
      of `volatility_fn(*proposed_state_parts)`. Must broadcast with the shape
      of `current_state_parts`.
    current_drift_parts: Python `list` of `Tensor`s representing the value of
      the drift `_get_drift(*current_state_parts, ..)`. Must broadcast with the
      shape of `current_state_parts`.
    proposed_drift_parts: Python `list` of `Tensor`s representing the value of
      the drift `_get_drift(*proposed_state_parts, ..)`. Must broadcast with the
      shape of `current_state_parts`.
    step_size_parts: Python `list` of `Tensor`s representing the step size for
      Euler-Maruyama method. Must broadcast with the shape of
      `current_state_parts`.
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'compute_log_acceptance_correction').

  Returns:
    log_acceptance_correction: `Tensor` representing the `log`
      acceptance-correction.  (See docstring for mathematical definition.)
  """

    with tf.name_scope(name, 'compute_log_acceptance_correction', [
            current_state_parts, proposed_state_parts,
            current_volatility_parts, proposed_volatility_parts,
            current_drift_parts, proposed_drift_parts, step_size_parts,
            independent_chain_ndims
    ]):

        proposed_log_density_parts = []
        dual_log_density_parts = []

        for [
                current_state,
                proposed_state,
                current_volatility,
                proposed_volatility,
                current_drift,
                proposed_drift,
                step_size,
        ] in zip(
                current_state_parts,
                proposed_state_parts,
                current_volatility_parts,
                proposed_volatility_parts,
                current_drift_parts,
                proposed_drift_parts,
                step_size_parts,
        ):
            axis = tf.range(independent_chain_ndims, tf.rank(current_state))

            state_diff = proposed_state - current_state

            current_volatility *= tf.sqrt(step_size)

            proposed_energy = (state_diff - current_drift) / current_volatility

            proposed_volatility *= tf.sqrt(step_size)
            # Compute part of `q(proposed_state | current_state)`
            proposed_energy = (tf.reduce_sum(mcmc_util.safe_sum(
                [tf.log(current_volatility), 0.5 * (proposed_energy**2)]),
                                             axis=axis))
            proposed_log_density_parts.append(-proposed_energy)

            # Compute part of `q(current_state | proposed_state)`
            dual_energy = (state_diff + proposed_drift) / proposed_volatility
            dual_energy = (tf.reduce_sum(mcmc_util.safe_sum(
                [tf.log(proposed_volatility), 0.5 * (dual_energy**2)]),
                                         axis=axis))
            dual_log_density_parts.append(-dual_energy)

        # Compute `q(proposed_state | current_state)`
        proposed_log_density_reduce = tf.reduce_sum(tf.stack(
            proposed_log_density_parts, axis=-1),
                                                    axis=-1)
        # Compute `q(current_state | proposed_state)`
        dual_log_density_reduce = tf.reduce_sum(tf.stack(
            dual_log_density_parts, axis=-1),
                                                axis=-1)

        return mcmc_util.safe_sum(
            [dual_log_density_reduce, -proposed_log_density_reduce])
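To make the formula concrete, a tiny pure-Python illustration (not from the source) for a single scalar state, using the Gaussian proposal densities defined in the docstring; all numbers are made up:

import math

def log_normal(x, mean, scale):
    # Log density of N(mean, scale**2), up to an additive constant.
    return -math.log(scale) - 0.5 * ((x - mean) / scale)**2

x, x_new = 0.1, 0.4              # current / proposed state
drift, drift_new = 0.05, -0.02   # drift evaluated at each state
vol, vol_new = 1.0, 1.2          # volatility evaluated at each state
step = 0.5                       # Euler-Maruyama step size

scale, scale_new = vol * math.sqrt(step), vol_new * math.sqrt(step)
log_acceptance_correction = (
    log_normal(x, x_new + drift_new, scale_new)   # log q(current | proposed)
    - log_normal(x_new, x + drift, scale))        # log q(proposed | current)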
Example #5
    def one_step(self, current_state, previous_kernel_results):
        """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.

    Raises:
      ValueError: if `inner_kernel` results doesn't contain the member
        "target_log_prob".
    """
        # Take one inner step.
        [
            proposed_state,
            proposed_results,
        ] = self.inner_kernel.one_step(
            current_state, previous_kernel_results.accepted_results)

        if (not has_target_log_prob(proposed_results)
                or not has_target_log_prob(
                    previous_kernel_results.accepted_results)):
            raise ValueError('"target_log_prob" must be a member of '
                             '`inner_kernel` results.')

        # Compute log(acceptance_ratio).
        to_sum = [
            proposed_results.target_log_prob,
            -previous_kernel_results.accepted_results.target_log_prob
        ]
        try:
            to_sum.append(proposed_results.log_acceptance_correction)
        except AttributeError:
            warnings.warn(
                'Supplied inner `TransitionKernel` does not have a '
                '`log_acceptance_correction`. Assuming its value is `0.`')
        log_accept_ratio = mcmc_util.safe_sum(to_sum,
                                              name='compute_log_accept_ratio')

        # If proposed state reduces likelihood: randomly accept.
        # If proposed state increases likelihood: always accept.
        # I.e., u < min(1, accept_ratio),  where u ~ Uniform[0,1)
        #       ==> log(u) < log_accept_ratio
        # Note:
        # - We mutate seed state so subsequent calls are not correlated.
        # - We mutate seed BEFORE using it just in case users supplied the
        #   same seed to the inner kernel.
        self._seed = distributions_util.gen_new_seed(
            self.seed, salt='metropolis_hastings_one_step')
        log_uniform = tf.log(
            tf.random_uniform(
                shape=tf.shape(proposed_results.target_log_prob),
                dtype=proposed_results.target_log_prob.dtype.base_dtype,
                seed=self.seed))
        is_accepted = log_uniform < log_accept_ratio

        independent_chain_ndims = distributions_util.prefer_static_rank(
            proposed_results.target_log_prob)

        next_state = mcmc_util.choose(is_accepted, proposed_state,
                                      current_state, independent_chain_ndims)

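        # Rebuild the inner kernel's results structure field by field, taking
        # each field from the proposal where accepted and from the previously
        # accepted results otherwise.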
        accepted_results = type(proposed_results)(
            **dict([(fn,
                     mcmc_util.choose(
                         is_accepted, getattr(proposed_results, fn),
                         getattr(previous_kernel_results.accepted_results, fn),
                         independent_chain_ndims))
                    for fn in proposed_results._fields]))

        return [
            next_state,
            MetropolisHastingsKernelResults(
                accepted_results=accepted_results,
                is_accepted=is_accepted,
                log_accept_ratio=log_accept_ratio,
                proposed_state=proposed_state,
                proposed_results=proposed_results,
            )
        ]
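For orientation, a hedged sketch of how a kernel exposing this `one_step`/`bootstrap_results` interface is typically driven in graph mode; `make_kernel`, the initial state, and the step count are hypothetical stand-ins, not code from the source:

import tensorflow as tf

kernel = make_kernel()  # hypothetical factory for any such TransitionKernel
state = tf.zeros([])    # illustrative initial chain state
results = kernel.bootstrap_results(state)
for _ in range(3):      # unrolls three update steps into the graph
    state, results = kernel.one_step(state, results)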
Example #6
def _compute_log_acceptance_correction(current_momentums,
                                       proposed_momentums,
                                       independent_chain_ndims,
                                       name=None):
  """Helper to `kernel` which computes the log acceptance-correction.

  A sufficient but not necessary condition for the existence of a stationary
  distribution, `p(x)`, is "detailed balance", i.e.:

  ```none
  p(x'|x) p(x) = p(x|x') p(x')
  ```

  In the Metropolis-Hastings algorithm, a state is proposed according to
  `g(x'|x)` and accepted according to `a(x'|x)`, hence
  `p(x'|x) = g(x'|x) a(x'|x)`.

  Inserting this into the detailed balance equation implies:

  ```none
      g(x'|x) a(x'|x) p(x) = g(x|x') a(x|x') p(x')
  ==> a(x'|x) / a(x|x') = p(x') / p(x) [g(x|x') / g(x'|x)]    (*)
  ```

  One definition of `a(x'|x)` which satisfies (*) is:

  ```none
  a(x'|x) = min(1, p(x') / p(x) [g(x|x') / g(x'|x)])
  ```

  (To see that this satisfies (*), notice that under this definition at most
  one of `a(x'|x)` and `a(x|x')` can be other than one.)

  We call the bracketed term the "acceptance correction".

  In the case of UncalibratedHMC, the log acceptance-correction is not the log
  proposal-ratio. UncalibratedHMC augments the state-space with momentum, z.
  Assuming a standard Gaussian distribution for momentums, the chain eventually
  converges to:

  ```none
  p([x, z]) propto= target_prob(x) exp(-0.5 z**2)
  ```

  Relating this back to Metropolis-Hastings parlance, for HMC we have:

  ```none
  p([x, z]) propto= target_prob(x) exp(-0.5 z**2)
  g([x, z] | [x', z']) = g([x', z'] | [x, z])
  ```

  In other words, the MH bracketed term is `1`. However, since we wish to use a
  general MH framework, we can place the momentum probability ratio inside the
  Metropolis-correction factor, giving the acceptance probability:

  ```none
                       target_prob(x')
  accept_prob(x'|x) = -----------------  [exp(-0.5 z**2) / exp(-0.5 z'**2)]
                       target_prob(x)
  ```

  (Note: we actually need to handle the kinetic energy change at each leapfrog
  step, but this is the idea.)

  Args:
    current_momentums: `Tensor` representing the value(s) of the current
      momentum(s) of the state (parts).
    proposed_momentums: `Tensor` representing the value(s) of the proposed
      momentum(s) of the state (parts).
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'compute_log_acceptance_correction').

  Returns:
    log_acceptance_correction: `Tensor` representing the `log`
      acceptance-correction.  (See docstring for mathematical definition.)
  """
  with tf.name_scope(
      name, 'compute_log_acceptance_correction',
      [independent_chain_ndims, current_momentums, proposed_momentums]):
    log_current_kinetic, log_proposed_kinetic = [], []
    for current_momentum, proposed_momentum in zip(
        current_momentums, proposed_momentums):
      axis = tf.range(independent_chain_ndims, tf.rank(current_momentum))
      log_current_kinetic.append(_log_sum_sq(current_momentum, axis))
      log_proposed_kinetic.append(_log_sum_sq(proposed_momentum, axis))
    current_kinetic = 0.5 * tf.exp(
        tf.reduce_logsumexp(tf.stack(log_current_kinetic, axis=-1), axis=-1))
    proposed_kinetic = 0.5 * tf.exp(
        tf.reduce_logsumexp(tf.stack(log_proposed_kinetic, axis=-1), axis=-1))
    return mcmc_util.safe_sum([current_kinetic, -proposed_kinetic])
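`mcmc_util.safe_sum` appears throughout these examples. Judging from the comments above (a non-finite sum gets an "alt value" of -Inf so a numerically bad proposal is never accepted), a behavioral sketch in plain Python; the real utility operates on `Tensor`s, so this is only an assumption about its semantics:

import math

def safe_sum_sketch(xs, alt_value=-float('inf')):
    # Sum the terms; if the result is Inf or NaN, fall back to alt_value.
    total = sum(xs)
    return total if math.isfinite(total) else alt_value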
Example #7
  def one_step(self, current_state, previous_kernel_results):
    """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.

    Raises:
      ValueError: if `inner_kernel` results doesn't contain the member
        "target_log_prob".
    """
    with tf.name_scope(
        name=mcmc_util.make_name(self.name, 'mh', 'one_step'),
        values=[current_state, previous_kernel_results]):
      # Take one inner step.
      [
          proposed_state,
          proposed_results,
      ] = self.inner_kernel.one_step(
          current_state,
          previous_kernel_results.accepted_results)

      if (not has_target_log_prob(proposed_results) or
          not has_target_log_prob(previous_kernel_results.accepted_results)):
        raise ValueError('"target_log_prob" must be a member of '
                         '`inner_kernel` results.')

      # Compute log(acceptance_ratio).
      to_sum = [proposed_results.target_log_prob,
                -previous_kernel_results.accepted_results.target_log_prob]
      try:
        if (not mcmc_util.is_list_like(
            proposed_results.log_acceptance_correction)
            or proposed_results.log_acceptance_correction):
          to_sum.append(proposed_results.log_acceptance_correction)
      except AttributeError:
        warnings.warn('Supplied inner `TransitionKernel` does not have a '
                      '`log_acceptance_correction`. Assuming its value is `0.`')
      log_accept_ratio = mcmc_util.safe_sum(
          to_sum, name='compute_log_accept_ratio')

      # If proposed state reduces likelihood: randomly accept.
      # If proposed state increases likelihood: always accept.
      # I.e., u < min(1, accept_ratio),  where u ~ Uniform[0,1)
      #       ==> log(u) < log_accept_ratio
      log_uniform = tf.log(tf.random_uniform(
          shape=tf.shape(proposed_results.target_log_prob),
          dtype=proposed_results.target_log_prob.dtype.base_dtype,
          seed=self._seed_stream()))
      is_accepted = log_uniform < log_accept_ratio

      next_state = mcmc_util.choose(
          is_accepted,
          proposed_state,
          current_state,
          name='choose_next_state')

      kernel_results = MetropolisHastingsKernelResults(
          accepted_results=mcmc_util.choose(
              is_accepted,
              proposed_results,
              previous_kernel_results.accepted_results,
              name='choose_inner_results'),
          is_accepted=is_accepted,
          log_accept_ratio=log_accept_ratio,
          proposed_state=proposed_state,
          proposed_results=proposed_results,
          extra=[],
      )

      return next_state, kernel_results