def body(i, exchanged_states):
                """Run one iteration of the state-exchange while loop.

                Looks up the `i`-th proposed pair of replica indices, decides
                whether that exchange is accepted, and records the resulting
                states of both replicas for every state part.

                Args:
                  i: Loop counter; indexes `exchange_proposed` and
                    `log_uniforms` from the enclosing scope.
                  exchanged_states: Python `list` with one entry per state part.
                    Each entry must support `write(index, value)` (presumably a
                    `tf.TensorArray` -- confirm against the caller). The list
                    entries are reassigned in place.

                Returns:
                  The tuple `(i + 1, exchanged_states)`.
                """
                # The two replica indices proposed for exchange at this step.
                replica_pair = exchange_proposed[i]
                m, n = tf.unstack(replica_pair)

                # log_accept_ratio is
                #   -beta_diff * (target_log_probs[m] - target_log_probs[n]),
                # where the target log prob difference equals minus the energy
                # difference in the usual energy-based formulation.
                beta_diff = (self.inverse_temperatures[m] -
                             self.inverse_temperatures[n])
                # The target log prob difference may be +-Inf or NaN; the two
                # terms are combined with `safe_sum` so that the product with
                # the temperature difference takes the "alt value" of -Inf.
                weighted_terms = [
                    -beta_diff * target_log_probs[m],
                    beta_diff * target_log_probs[n],
                ]
                log_accept_ratio = mcmc_util.safe_sum(weighted_terms)

                is_exchange_accepted = log_uniforms[i] < log_accept_ratio

                for part in range(num_state_parts):
                    source = old_states[part]
                    swapped_m, swapped_n = _swap(is_exchange_accepted,
                                                 source.read(m),
                                                 source.read(n))
                    # Write replica m first, then replica n, as two chained
                    # writes on the same state part.
                    after_m = exchanged_states[part].write(m, swapped_m)
                    exchanged_states[part] = after_m.write(n, swapped_n)

                return i + 1, exchanged_states
# Example #2
# 0
def _compute_log_acceptance_correction(current_state_parts,
                                       proposed_state_parts,
                                       current_volatility_parts,
                                       proposed_volatility_parts,
                                       current_drift_parts,
                                       proposed_drift_parts,
                                       step_size_parts,
                                       independent_chain_ndims,
                                       experimental_shard_axis_names=None,
                                       name=None):
    r"""Helper to `kernel` which computes the log acceptance-correction.

  Computes `log_acceptance_correction` as described in `MetropolisHastings`
  class. The proposal density is normal. More specifically,

  ```none
  q(proposed_state | current_state) \sim N(current_state + current_drift,
  step_size * current_volatility**2)

  q(current_state | proposed_state) \sim N(proposed_state + proposed_drift,
  step_size * proposed_volatility**2)
  ```

  The `log_acceptance_correction` is then

  ```none
  log_acceptance_correction = log q(current_state | proposed_state)
  - log q(proposed_state | current_state)
  ```

  Each log density is evaluated without the additive `-0.5 * log(2 * pi)`
  per-element constant; it is the same in both terms and cancels in the
  difference. Within each state part the log density is summed over every
  dimension to the right of the first `independent_chain_ndims` dimensions
  (and additionally reduced with `distribute_lib.psum` over that part's shard
  axes, when given); the per-part results are then added together.

  None of the input lists or their elements are modified.

  Args:
    current_state_parts: Python `list` of `Tensor`s representing the value(s) of
      the current state of the chain.
    proposed_state_parts:  Python `list` of `Tensor`s representing the value(s)
      of the proposed state of the chain. Must broadcast with the shape of
      `current_state_parts`.
    current_volatility_parts: Python `list` of `Tensor`s representing the
      volatility paired with the current state (presumably
      `volatility_fn(*current_state_parts)` -- confirm against the caller).
      Must broadcast with the shape of `current_state_parts`.
    proposed_volatility_parts: Python `list` of `Tensor`s representing the
      volatility paired with the proposed state (presumably
      `volatility_fn(*proposed_state_parts)` -- confirm against the caller).
      Must broadcast with the shape of `current_state_parts`.
    current_drift_parts: Python `list` of `Tensor`s representing value of the
      drift at the current state. Must broadcast with the shape of
      `current_state_parts`.
    proposed_drift_parts: Python `list` of `Tensor`s representing value of the
      drift at the proposed state. Must broadcast with the shape of
      `current_state_parts`.
    step_size_parts: Python `list` of `Tensor`s representing the step size for
      Euler-Maruyama method. Must broadcast with the shape of
      `current_state_parts`.
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    experimental_shard_axis_names: A structure of string names indicating how
      members of the state are sharded.
      Default value: `None` (i.e., no state part is sharded).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'compute_log_acceptance_correction').

  Returns:
    log_acceptance_correction: `Tensor` representing the `log`
      acceptance-correction.  (See docstring for mathematical definition.)
  """

    with tf.name_scope(name or 'compute_log_acceptance_correction'):

        if experimental_shard_axis_names is None:
            experimental_shard_axis_names = [None] * len(current_state_parts)

        # Defined once, outside the loop: it does not depend on any per-part
        # value other than its own arguments.
        def reduce_sum(shard_axes, x, axis=None):
            """Sums `x` over `axis`, then across `shard_axes` if present."""
            x = tf.reduce_sum(x, axis)
            if shard_axes is not None:
                x = distribute_lib.psum(x, shard_axes)
            return x

        proposed_log_density_parts = []
        dual_log_density_parts = []

        for (current_state, proposed_state, current_volatility,
             proposed_volatility, current_drift, proposed_drift, step_size,
             shard_axes) in zip(current_state_parts, proposed_state_parts,
                                current_volatility_parts,
                                proposed_volatility_parts,
                                current_drift_parts, proposed_drift_parts,
                                step_size_parts,
                                experimental_shard_axis_names):
            # Every dimension that does not index an independent chain.
            axis = ps.range(independent_chain_ndims, ps.rank(current_state))

            state_diff = proposed_state - current_state

            # Standard deviations of the two normal proposals. These are
            # deliberately built with out-of-place products bound to new names:
            # an augmented `*=` on a caller-supplied array may modify that
            # array in place (or fail when broadcasting would enlarge it).
            sqrt_step_size = tf.sqrt(step_size)
            current_scale = current_volatility * sqrt_step_size
            proposed_scale = proposed_volatility * sqrt_step_size

            # Compute part of `log q(proposed_state | current_state)`
            proposed_residual = (state_diff - current_drift) / current_scale
            proposed_energy = reduce_sum(
                shard_axes,
                mcmc_util.safe_sum([
                    tf.math.log(current_scale),
                    0.5 * (proposed_residual**2),
                ]),
                axis=axis)
            proposed_log_density_parts.append(-proposed_energy)

            # Compute part of `log q(current_state | proposed_state)`.
            # `current_state - (proposed_state + proposed_drift)` equals
            # `-(state_diff + proposed_drift)`; the sign vanishes on squaring.
            dual_residual = (state_diff + proposed_drift) / proposed_scale
            dual_energy = reduce_sum(
                shard_axes,
                mcmc_util.safe_sum([
                    tf.math.log(proposed_scale),
                    0.5 * (dual_residual**2),
                ]),
                axis=axis)
            dual_log_density_parts.append(-dual_energy)

        # Compute `log q(proposed_state | current_state)`
        proposed_log_density_reduce = tf.add_n(proposed_log_density_parts)
        # Compute `log q(current_state | proposed_state)`
        dual_log_density_reduce = tf.add_n(dual_log_density_parts)

        return mcmc_util.safe_sum(
            [dual_log_density_reduce, -proposed_log_density_reduce])
# Example #3
# 0
def _compute_log_acceptance_correction(current_momentums,
                                       proposed_momentums,
                                       independent_chain_ndims,
                                       shard_axis_names=None,
                                       name=None):
    """Helper to `kernel` which computes the log acceptance-correction.

  Background. "Detailed balance" is a sufficient (but not necessary) condition
  for a transition density `p(x'|x)` to have stationary distribution `p(x)`:

  ```none
  p(x'|x) p(x) = p(x|x') p(x')
  ```

  Metropolis-Hastings proposes a state from `g(x'|x)` and accepts it with
  probability `a(x'|x)`, so `p(x'|x) = g(x'|x) a(x'|x)`. Substituting:

  ```none
      g(x'|x) a(x'|x) p(x) = g(x|x') a(x|x') p(x')
  ==> a(x'|x) / a(x|x') = p(x') / p(x) [g(x|x') / g(x'|x)]    (*)
  ```

  which is satisfied by

  ```none
  a(x'|x) = min(1, p(x') / p(x) [g(x|x') / g(x'|x)])
  ```

  (under this definition at most one of `a(x'|x)` and `a(x|x')` can differ
  from one). The bracketed term is called the "acceptance correction".

  For UncalibratedHMC the log acceptance-correction is not the log
  proposal-ratio. The state is augmented with a momentum `z`, assumed standard
  Gaussian, so that the chain targets

  ```none
  p([x, z]) propto target_prob(x) exp(-0.5 z**2)
  g([x, z] | [x', z']) = g([x', z'] | [x, z])
  ```

  The proposal is symmetric, so the MH bracketed term itself is `1`. To fit the
  general MH framework the momentum probability ratio is placed in the
  correction factor instead. What this function returns is the log of that
  ratio:

  ```none
  log_acceptance_correction = 0.5 * (sum(z**2) - sum(z'**2))
                            = log[exp(-0.5 z'**2) / exp(-0.5 z**2)]
  ```

  where `z` are the current momentums, `z'` the proposed momentums, and the
  sums run over every dimension to the right of the first
  `independent_chain_ndims` dimensions of every momentum part.

  (Note: we actually need to handle the kinetic energy change at each leapfrog
  step, but this is the idea.)

  Args:
    current_momentums: Python `list` of `Tensor`s representing the value(s) of
      the current momentum(s) of the state (parts).
    proposed_momentums: Python `list` of `Tensor`s representing the value(s) of
      the proposed momentum(s) of the state (parts).
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    shard_axis_names: A structure of string names indicating how
      members of the state are sharded. Any falsy value (including the default
      `None`) means that no momentum part is sharded.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'compute_log_acceptance_correction').

  Returns:
    log_acceptance_correction: `Tensor` representing the `log`
      acceptance-correction.  (See docstring for mathematical definition.)
  """
    with tf.name_scope(name or 'compute_log_acceptance_correction'):
        if not shard_axis_names:
            shard_axis_names = [None] * len(current_momentums)

        def total_sum_of_squares(momentums):
            """Adds up the squared entries of every momentum part."""
            per_part = []
            for momentum, shard_axes in zip(momentums, shard_axis_names):
                squared = momentum**2.
                # Reduce over everything except the independent-chain dims.
                event_axes = ps.range(independent_chain_ndims,
                                      ps.rank(momentum))
                part_total = tf.reduce_sum(squared, axis=event_axes)
                if shard_axes is not None:
                    part_total = distribute_lib.psum(part_total, shard_axes)
                per_part.append(part_total)
            return tf.add_n(per_part)

        current_kinetic = total_sum_of_squares(current_momentums)
        proposed_kinetic = total_sum_of_squares(proposed_momentums)
        return 0.5 * mcmc_util.safe_sum([current_kinetic, -proposed_kinetic])
def _compute_log_acceptance_correction(kinetic_energy_fn,
                                       current_momentums,
                                       proposed_momentums,
                                       name=None):
    """Helper to `kernel` which computes the log acceptance-correction.

  Background. "Detailed balance" is a sufficient (but not necessary) condition
  for a transition density `p(x'|x)` to have stationary distribution `p(x)`:

  ```none
  p(x'|x) p(x) = p(x|x') p(x')
  ```

  Metropolis-Hastings proposes a state from `g(x'|x)` and accepts it with
  probability `a(x'|x)`, so `p(x'|x) = g(x'|x) a(x'|x)`. Substituting:

  ```none
      g(x'|x) a(x'|x) p(x) = g(x|x') a(x|x') p(x')
  ==> a(x'|x) / a(x|x') = p(x') / p(x) [g(x|x') / g(x'|x)]    (*)
  ```

  which is satisfied by

  ```none
  a(x'|x) = min(1, p(x') / p(x) [g(x|x') / g(x'|x)])
  ```

  (under this definition at most one of `a(x'|x)` and `a(x|x')` can differ
  from one). The bracketed term is called the "acceptance correction".

  For UncalibratedHMC the log acceptance-correction is not the log
  proposal-ratio. The state is augmented with a momentum `z` with probability
  density `m(z)`, so that the chain targets

  ```none
  p([x, z]) propto target_prob(x) m(z)
  g([x, z] | [x', z']) = g([x', z'] | [x, z])
  ```

  The proposal is symmetric, so the MH bracketed term itself is `1`. To fit the
  general MH framework the momentum probability ratio is placed in the
  correction factor instead:

  ```none
                       target_prob(x')
  accept_prob(x'|x) = -----------------  [m(z') / m(z)]
                       target_prob(x)
  ```

  (Note: we actually need to handle the kinetic energy change at each leapfrog
  step, but this is the idea.)

  The correction is computed in log space through the kinetic energy function
  `K(z)`, the negative log probability of the momentum distribution:

  ```none
  log(correction) = log(m(z')) - log(m(z))
                  = K(z) - K(z')
  ```

  This is an equality because the normalization constants of `m` cancel.

  Args:
    kinetic_energy_fn: Python callable that can evaluate the kinetic energy
      of the given momentum. This is typically the negative log probability of
      the distribution over the momentum. It is called once with
      `current_momentums` and once with `proposed_momentums`.
    current_momentums: (List of) `Tensor`s representing the value(s) of the
      current momentum(s) of the state (parts).
    proposed_momentums: (List of) `Tensor`s representing the value(s) of the
      proposed momentum(s) of the state (parts).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'compute_log_acceptance_correction').

  Returns:
    log_acceptance_correction: `Tensor` representing the `log`
      acceptance-correction.  (See docstring for mathematical definition.)
  """
    scope_name = name or 'compute_log_acceptance_correction'
    with tf.name_scope(scope_name):
        # K(z) for the current momentums, then K(z') for the proposed ones.
        kinetic_before = kinetic_energy_fn(current_momentums)
        kinetic_after = kinetic_energy_fn(proposed_momentums)
        energy_terms = [kinetic_before, -kinetic_after]
        return mcmc_util.safe_sum(energy_terms)
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one step of the TransitionKernel.

    Runs one step of `inner_kernel` from the previously accepted results to get
    a proposal, forms the log acceptance ratio from the proposed and previously
    accepted `target_log_prob` (plus the proposal's
    `log_acceptance_correction`, when it has one), and accepts or rejects the
    proposal by comparing that ratio against the log of a uniform draw.

    If the proposed results have no `log_acceptance_correction` attribute, a
    warning is issued and the correction is treated as zero.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
      seed: PRNG seed; see `tfp.random.sanitize_seed` for details. The seed is
        forwarded to `inner_kernel.one_step` only when it is not `None`.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.

    Raises:
      ValueError: if `inner_kernel` results doesn't contain the member
        "target_log_prob".
    """
        is_seeded = seed is not None
        seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
        proposal_seed, acceptance_seed = samplers.split_seed(seed)

        with tf.name_scope(mcmc_util.make_name(self.name, 'mh', 'one_step')):
            accepted_results = previous_kernel_results.accepted_results

            # Take one inner step; only pass a seed along if we were given one.
            inner_kwargs = {}
            if is_seeded:
                inner_kwargs['seed'] = proposal_seed
            proposed_state, proposed_results = self.inner_kernel.one_step(
                current_state, accepted_results, **inner_kwargs)
            if mcmc_util.is_list_like(current_state):
                # Give the proposal the same structure as `current_state`.
                proposed_state = tf.nest.pack_sequence_as(
                    current_state, proposed_state)

            if not (has_target_log_prob(proposed_results)
                    and has_target_log_prob(accepted_results)):
                raise ValueError('"target_log_prob" must be a member of '
                                 '`inner_kernel` results.')

            # Compute log(acceptance_ratio).
            log_ratio_terms = [
                proposed_results.target_log_prob,
                -accepted_results.target_log_prob,
            ]
            try:
                correction = proposed_results.log_acceptance_correction
                # A list-like correction is only included when non-empty.
                if not mcmc_util.is_list_like(correction) or correction:
                    log_ratio_terms.append(correction)
            except AttributeError:
                warnings.warn(
                    'Supplied inner `TransitionKernel` does not have a '
                    '`log_acceptance_correction`. Assuming its value is `0.`')
            log_accept_ratio = mcmc_util.safe_sum(
                log_ratio_terms, name='compute_log_accept_ratio')

            # Accept with probability min(1, accept_ratio): a proposal that
            # raises the likelihood is always taken, one that lowers it is
            # taken at random. With u ~ Uniform[0, 1) the test
            # u < min(1, accept_ratio) is log(u) < log_accept_ratio.
            proposed_log_prob = proposed_results.target_log_prob
            uniform_draw = samplers.uniform(
                shape=prefer_static.shape(proposed_log_prob),
                dtype=dtype_util.base_dtype(proposed_log_prob.dtype),
                seed=acceptance_seed)
            log_uniform = tf.math.log(uniform_draw)
            is_accepted = log_uniform < log_accept_ratio

            next_state = mcmc_util.choose(is_accepted,
                                          proposed_state,
                                          current_state,
                                          name='choose_next_state')

            # Seeds are stripped from the proposed results before choosing:
            # unlike the other result fields a seed is not a per-chain value,
            # so there is no per-chain way to pick between the previously
            # accepted seed and the proposed one.
            next_accepted_results = mcmc_util.choose(
                is_accepted,
                mcmc_util.strip_seeds(proposed_results),
                accepted_results,
                name='choose_inner_results')

            kernel_results = MetropolisHastingsKernelResults(
                accepted_results=next_accepted_results,
                is_accepted=is_accepted,
                log_accept_ratio=log_accept_ratio,
                proposed_state=proposed_state,
                proposed_results=proposed_results,
                extra=[],
                seed=seed,
            )

            return next_state, kernel_results
# Example #6
# 0
    def one_step(self, current_state, previous_kernel_results):
        """Takes one step of the TransitionKernel.

    Runs one step of `inner_kernel` from the previously accepted results to get
    a proposal, forms the log acceptance ratio from the proposed and previously
    accepted `target_log_prob` (plus the proposal's
    `log_acceptance_correction`, when it has one), and accepts or rejects the
    proposal by comparing that ratio against the log of a uniform draw seeded
    from `self._seed_stream()`.

    If the proposed results have no `log_acceptance_correction` attribute, a
    warning is issued and the correction is treated as zero.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.

    Raises:
      ValueError: if `inner_kernel` results doesn't contain the member
        "target_log_prob".
    """
        scope_name = mcmc_util.make_name(self.name, 'mh', 'one_step')
        with tf1.name_scope(name=scope_name,
                            values=[current_state, previous_kernel_results]):
            accepted_results = previous_kernel_results.accepted_results

            # Take one inner step.
            proposed_state, proposed_results = self.inner_kernel.one_step(
                current_state, accepted_results)

            if not (has_target_log_prob(proposed_results)
                    and has_target_log_prob(accepted_results)):
                raise ValueError('"target_log_prob" must be a member of '
                                 '`inner_kernel` results.')

            # Compute log(acceptance_ratio).
            log_ratio_terms = [
                proposed_results.target_log_prob,
                -accepted_results.target_log_prob,
            ]
            try:
                correction = proposed_results.log_acceptance_correction
                # A list-like correction is only included when non-empty.
                if not mcmc_util.is_list_like(correction) or correction:
                    log_ratio_terms.append(correction)
            except AttributeError:
                warnings.warn(
                    'Supplied inner `TransitionKernel` does not have a '
                    '`log_acceptance_correction`. Assuming its value is `0.`')
            log_accept_ratio = mcmc_util.safe_sum(
                log_ratio_terms, name='compute_log_accept_ratio')

            # Accept with probability min(1, accept_ratio): a proposal that
            # raises the likelihood is always taken, one that lowers it is
            # taken at random. With u ~ Uniform[0, 1) the test
            # u < min(1, accept_ratio) is log(u) < log_accept_ratio.
            proposed_log_prob = proposed_results.target_log_prob
            uniform_draw = tf.random.uniform(
                shape=tf.shape(input=proposed_log_prob),
                dtype=proposed_log_prob.dtype.base_dtype,
                seed=self._seed_stream())
            log_uniform = tf.math.log(uniform_draw)
            is_accepted = log_uniform < log_accept_ratio

            next_state = mcmc_util.choose(is_accepted,
                                          proposed_state,
                                          current_state,
                                          name='choose_next_state')

            # Keep the proposed inner results where accepted, otherwise the
            # previously accepted ones.
            next_accepted_results = mcmc_util.choose(
                is_accepted,
                proposed_results,
                accepted_results,
                name='choose_inner_results')

            kernel_results = MetropolisHastingsKernelResults(
                accepted_results=next_accepted_results,
                is_accepted=is_accepted,
                log_accept_ratio=log_accept_ratio,
                proposed_state=proposed_state,
                proposed_results=proposed_results,
                extra=[],
            )

            return next_state, kernel_results