def body(i, exchanged_states):
  """One iteration of the exchange loop: maybe swap one proposed replica pair.

  Reads the i-th proposed pair of replica indices, forms the log acceptance
  ratio from the inverse-temperature difference and the replicas' target log
  probs, and conditionally swaps every state part between the two replicas.
  """
  # Replica indices (m, n) proposed for exchange at loop step i.
  m, n = tf.unstack(exchange_proposed[i])

  # log_accept_ratio = -beta_diff * (log_prob[m] - log_prob[n]); equivalently
  # beta_diff * EnergyDiff under the usual energy convention.
  beta_diff = self.inverse_temperatures[m] - self.inverse_temperatures[n]

  # The target log prob difference may be +-Inf or NaN; safe_sum maps such
  # terms to -Inf so a degenerate proposal is rejected instead of yielding NaN.
  log_accept_ratio = mcmc_util.safe_sum(
      [-beta_diff * target_log_probs[m],
       beta_diff * target_log_probs[n]])

  accept_swap = log_uniforms[i] < log_accept_ratio

  for part in range(num_state_parts):
    state_m, state_n = _swap(accept_swap,
                             old_states[part].read(m),
                             old_states[part].read(n))
    exchanged_states[part] = exchanged_states[part].write(m, state_m)
    exchanged_states[part] = exchanged_states[part].write(n, state_n)

  return i + 1, exchanged_states
def _compute_log_acceptance_correction(current_state_parts,
                                       proposed_state_parts,
                                       current_volatility_parts,
                                       proposed_volatility_parts,
                                       current_drift_parts,
                                       proposed_drift_parts,
                                       step_size_parts,
                                       independent_chain_ndims,
                                       experimental_shard_axis_names=None,
                                       name=None):
  r"""Helper to `kernel` which computes the log acceptance-correction.

  Computes `log_acceptance_correction` as described in `MetropolisHastings`
  class. The proposal density is normal. More specifically,

  ```none
  q(proposed_state | current_state) \sim N(current_state + current_drift,
                                           step_size * current_volatility**2)
  q(current_state | proposed_state) \sim N(proposed_state + proposed_drift,
                                           step_size * proposed_volatility**2)
  ```

  The `log_acceptance_correction` is then

  ```none
  log_acceptance_correction = q(current_state | proposed_state)
                              - q(proposed_state | current_state)
  ```

  Args:
    current_state_parts: Python `list` of `Tensor`s representing the value(s)
      of the current state of the chain.
    proposed_state_parts: Python `list` of `Tensor`s representing the value(s)
      of the proposed state of the chain. Must broadcast with the shape of
      `current_state_parts`.
    current_volatility_parts: Python `list` of `Tensor`s representing the value
      of `volatility_fn(*current_state_parts)`. Must broadcast with the shape
      of `current_state_parts`.
    proposed_volatility_parts: Python `list` of `Tensor`s representing the
      value of `volatility_fn(*proposed_state_parts)`. Must broadcast with the
      shape of `current_state_parts`.
    current_drift_parts: Python `list` of `Tensor`s representing value of the
      drift `_get_drift(*current_state_parts, ..)`. Must broadcast with the
      shape of `current_state_parts`.
    proposed_drift_parts: Python `list` of `Tensor`s representing value of the
      drift `_get_drift(*proposed_state_parts, ..)`. Must broadcast with the
      shape of `current_state_parts`.
    step_size_parts: Python `list` of `Tensor`s representing the step size for
      Euler-Maruyama method. Must broadcast with the shape of
      `current_state_parts`.
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    experimental_shard_axis_names: A structure of string names indicating how
      members of the state are sharded.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'compute_log_acceptance_correction').

  Returns:
    log_acceptance_correction: `Tensor` representing the `log`
      acceptance-correction.  (See docstring for mathematical definition.)
  """
  with tf.name_scope(name or 'compute_log_acceptance_correction'):

    def _sharded_reduce_sum(shard_axes, x, axis=None):
      """Reduce-sums `x` over `axis`, then across shards when sharded."""
      x = tf.reduce_sum(x, axis)
      if shard_axes is not None:
        x = distribute_lib.psum(x, shard_axes)
      return x

    proposed_log_density_parts = []
    dual_log_density_parts = []

    if experimental_shard_axis_names is None:
      experimental_shard_axis_names = [None] * len(current_state_parts)

    for [
        current_state,
        proposed_state,
        current_volatility,
        proposed_volatility,
        current_drift,
        proposed_drift,
        step_size,
        shard_axes
    ] in zip(current_state_parts, proposed_state_parts,
             current_volatility_parts, proposed_volatility_parts,
             current_drift_parts, proposed_drift_parts,
             step_size_parts, experimental_shard_axis_names):
      # All dimensions to the right of the chain dimensions belong to one
      # event; sum the per-dimension log densities over them.
      axis = ps.range(independent_chain_ndims, ps.rank(current_state))

      state_diff = proposed_state - current_state

      # Scale the volatilities by sqrt(step_size): the proposal's standard
      # deviation is sqrt(step_size) * volatility.
      current_volatility *= tf.sqrt(step_size)
      proposed_volatility *= tf.sqrt(step_size)

      # Compute part of `q(proposed_state | current_state)`.
      # safe_sum guards against Inf/NaN partial terms producing NaN totals.
      proposed_energy = (state_diff - current_drift) / current_volatility
      proposed_energy = _sharded_reduce_sum(
          shard_axes,
          mcmc_util.safe_sum([
              tf.math.log(current_volatility),
              0.5 * (proposed_energy**2)
          ]),
          axis=axis)
      proposed_log_density_parts.append(-proposed_energy)

      # Compute part of `q(current_state | proposed_state)` (reverse move:
      # note the `+ proposed_drift`, since the drift enters with the opposite
      # sign when proposing from the proposed state back to the current one).
      dual_energy = (state_diff + proposed_drift) / proposed_volatility
      dual_energy = _sharded_reduce_sum(
          shard_axes,
          mcmc_util.safe_sum([
              tf.math.log(proposed_volatility),
              0.5 * (dual_energy**2)
          ]),
          axis=axis)
      dual_log_density_parts.append(-dual_energy)

    # Compute `q(proposed_state | current_state)`.
    proposed_log_density_reduce = tf.add_n(proposed_log_density_parts)
    # Compute `q(current_state | proposed_state)`.
    dual_log_density_reduce = tf.add_n(dual_log_density_parts)

    return mcmc_util.safe_sum(
        [dual_log_density_reduce, -proposed_log_density_reduce])
def _compute_log_acceptance_correction(current_momentums,
                                       proposed_momentums,
                                       independent_chain_ndims,
                                       shard_axis_names=None,
                                       name=None):
  """Helper to `kernel` which computes the log acceptance-correction.

  A sufficient but not necessary condition for the existence of a stationary
  distribution, `p(x)`, is "detailed balance", i.e.:

  ```none
  p(x'|x) p(x) = p(x|x') p(x')
  ```

  In the Metropolis-Hastings algorithm, a state is proposed according to
  `g(x'|x)` and accepted according to `a(x'|x)`, hence
  `p(x'|x) = g(x'|x) a(x'|x)`.

  Inserting this into the detailed balance equation implies:

  ```none
      g(x'|x) a(x'|x) p(x) = g(x|x') a(x|x') p(x')
  ==> a(x'|x) / a(x|x') = p(x') / p(x) [g(x|x') / g(x'|x)]    (*)
  ```

  One definition of `a(x'|x)` which satisfies (*) is:

  ```none
  a(x'|x) = min(1, p(x') / p(x) [g(x|x') / g(x'|x)])
  ```

  (To see that this satisfies (*), notice that under this definition only at
  most one `a(x'|x)` and `a(x|x')` can be other than one.)

  We call the bracketed term the "acceptance correction".

  In the case of UncalibratedHMC, the log acceptance-correction is not the log
  proposal-ratio. UncalibratedHMC augments the state-space with momentum, z.
  Assuming a standard Gaussian distribution for momentums, the chain eventually
  converges to:

  ```none
  p([x, z]) propto= target_prob(x) exp(-0.5 z**2)
  ```

  Relating this back to Metropolis-Hastings parlance, for HMC we have:

  ```none
  p([x, z]) propto= target_prob(x) exp(-0.5 z**2)
  g([x, z] | [x', z']) = g([x', z'] | [x, z])
  ```

  In other words, the MH bracketed term is `1`. However, because we desire to
  use a general MH framework, we can place the momentum probability ratio
  inside the metropolis-correction factor thus getting an acceptance
  probability:

  ```none
                       target_prob(x')
  accept_prob(x'|x) = -----------------  [exp(-0.5 z**2) / exp(-0.5 z'**2)]
                       target_prob(x)
  ```

  (Note: we actually need to handle the kinetic energy change at each leapfrog
  step, but this is the idea.)

  Args:
    current_momentums: `Tensor` representing the value(s) of the current
      momentum(s) of the state (parts).
    proposed_momentums: `Tensor` representing the value(s) of the proposed
      momentum(s) of the state (parts).
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    shard_axis_names: A structure of string names indicating how members of the
      state are sharded.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'compute_log_acceptance_correction').

  Returns:
    log_acceptance_correction: `Tensor` representing the `log`
      acceptance-correction.  (See docstring for mathematical definition.)
  """
  with tf.name_scope(name or 'compute_log_acceptance_correction'):

    def compute_sum_sq(v, shard_axes):
      # Sum of squared momentum over all event dimensions (everything to the
      # right of the chain dimensions), optionally reduced across shards.
      sum_sq = tf.reduce_sum(v**2., axis=ps.range(independent_chain_ndims,
                                                  ps.rank(v)))
      if shard_axes is not None:
        sum_sq = distribute_lib.psum(sum_sq, shard_axes)
      return sum_sq

    # Default to unsharded: one `None` entry per momentum part.
    shard_axis_names = (shard_axis_names or ([None] * len(current_momentums)))
    current_kinetic = tf.add_n([
        compute_sum_sq(v, axes) for v, axes
        in zip(current_momentums, shard_axis_names)])
    proposed_kinetic = tf.add_n([
        compute_sum_sq(v, axes) for v, axes
        in zip(proposed_momentums, shard_axis_names)])
    # K(z) = 0.5 * sum(z**2); correction = K(z) - K(z').  safe_sum guards
    # against Inf/NaN terms producing NaN instead of -Inf.
    return 0.5 * mcmc_util.safe_sum([current_kinetic, -proposed_kinetic])
def _compute_log_acceptance_correction(kinetic_energy_fn,
                                       current_momentums,
                                       proposed_momentums,
                                       name=None):
  """Helper to `kernel` which computes the log acceptance-correction.

  A sufficient but not necessary condition for the existence of a stationary
  distribution, `p(x)`, is "detailed balance":

  ```none
  p(x'|x) p(x) = p(x|x') p(x')
  ```

  Metropolis-Hastings proposes a state via `g(x'|x)` and accepts it via
  `a(x'|x)`, so `p(x'|x) = g(x'|x) a(x'|x)`. Substituting into detailed
  balance gives:

  ```none
      g(x'|x) a(x'|x) p(x) = g(x|x') a(x|x') p(x')
  ==> a(x'|x) / a(x|x') = p(x') / p(x) [g(x|x') / g(x'|x)]    (*)
  ```

  One definition of `a(x'|x)` satisfying (*) is:

  ```none
  a(x'|x) = min(1, p(x') / p(x) [g(x|x') / g(x'|x)])
  ```

  (This works because under it at most one of `a(x'|x)` and `a(x|x')` can
  differ from one.) The bracketed term is the "acceptance correction".

  UncalibratedHMC augments the state-space with momentum `z` distributed with
  density `m(z)`, so the chain converges to:

  ```none
  p([x, z]) propto= target_prob(x) m(z)
  ```

  and the HMC proposal is symmetric, `g([x, z] | [x', z']) =
  g([x', z'] | [x, z])`, making the MH bracketed term `1`. To reuse the
  general MH framework we fold the momentum probability ratio into the
  correction factor, yielding:

  ```none
                       target_prob(x')
  accept_prob(x'|x) = -----------------  [m(z') / m(z)]
                       target_prob(x)
  ```

  (In practice the kinetic energy change must be handled at each leapfrog
  step, but this is the idea.)

  Working in log space with the kinetic energy `K(z) = -log m(z)`:

  ```none
  log(correction) = log(m(z')) - log(m(z)) = K(z) - K(z')
  ```

  This is an exact equality: the normalization constants of `m` cancel.

  Args:
    kinetic_energy_fn: Python callable that can evaluate the kinetic energy
      of the given momentum. This is typically the negative log probability of
      the distribution over the momentum.
    current_momentums: (List of) `Tensor`s representing the value(s) of the
      current momentum(s) of the state (parts).
    proposed_momentums: (List of) `Tensor`s representing the value(s) of the
      proposed momentum(s) of the state (parts).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'compute_log_acceptance_correction').

  Returns:
    log_acceptance_correction: `Tensor` representing the `log`
      acceptance-correction.  (See docstring for mathematical definition.)
  """
  with tf.name_scope(name or 'compute_log_acceptance_correction'):
    energy_before = kinetic_energy_fn(current_momentums)
    energy_after = kinetic_energy_fn(proposed_momentums)
    # K(z) - K(z'); safe_sum keeps Inf/NaN terms from producing NaN.
    return mcmc_util.safe_sum([energy_before, -energy_after])
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Takes one step of the TransitionKernel.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.

  Raises:
    ValueError: if `inner_kernel` results doesn't contain the member
      "target_log_prob".
  """
  is_seeded = seed is not None
  seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
  # One seed drives the inner kernel's proposal, the other the accept/reject
  # uniform draw, so the two sources of randomness stay independent.
  proposal_seed, acceptance_seed = samplers.split_seed(seed)

  with tf.name_scope(mcmc_util.make_name(self.name, 'mh', 'one_step')):
    # Take one inner step.
    # Only forward a seed if the caller supplied one, so unseeded inner
    # kernels keep working unchanged.
    inner_kwargs = dict(seed=proposal_seed) if is_seeded else {}
    [
        proposed_state,
        proposed_results,
    ] = self.inner_kernel.one_step(
        current_state,
        previous_kernel_results.accepted_results,
        **inner_kwargs)
    if mcmc_util.is_list_like(current_state):
      # Restore the caller's nest structure (the inner kernel may flatten it).
      proposed_state = tf.nest.pack_sequence_as(current_state, proposed_state)

    if (not has_target_log_prob(proposed_results) or
        not has_target_log_prob(previous_kernel_results.accepted_results)):
      raise ValueError('"target_log_prob" must be a member of '
                       '`inner_kernel` results.')

    # Compute log(acceptance_ratio).
    to_sum = [proposed_results.target_log_prob,
              -previous_kernel_results.accepted_results.target_log_prob]
    try:
      # Skip the correction only when it is an empty list (treated as 0);
      # a non-list value or a non-empty list is included in the sum.
      if (not mcmc_util.is_list_like(
          proposed_results.log_acceptance_correction)
          or proposed_results.log_acceptance_correction):
        to_sum.append(proposed_results.log_acceptance_correction)
    except AttributeError:
      # Best-effort: an inner kernel without the field is assumed calibrated.
      warnings.warn('Supplied inner `TransitionKernel` does not have a '
                    '`log_acceptance_correction`. Assuming its value is `0.`')

    log_accept_ratio = mcmc_util.safe_sum(
        to_sum, name='compute_log_accept_ratio')
    # If proposed state reduces likelihood: randomly accept.
    # If proposed state increases likelihood: always accept.
    # I.e., u < min(1, accept_ratio),  where u ~ Uniform[0,1)
    #       ==> log(u) < log_accept_ratio
    log_uniform = tf.math.log(
        samplers.uniform(
            shape=prefer_static.shape(proposed_results.target_log_prob),
            dtype=dtype_util.base_dtype(
                proposed_results.target_log_prob.dtype),
            seed=acceptance_seed))
    is_accepted = log_uniform < log_accept_ratio

    next_state = mcmc_util.choose(
        is_accepted,
        proposed_state,
        current_state,
        name='choose_next_state')

    kernel_results = MetropolisHastingsKernelResults(
        accepted_results=mcmc_util.choose(
            is_accepted,
            # We strip seeds when populating `accepted_results` because unlike
            # other kernel result fields, seeds are not a per-chain value.
            # Thus it is impossible to choose between a previously accepted
            # seed value and a proposed seed, since said choice would need to
            # be made on a per-chain basis.
            mcmc_util.strip_seeds(proposed_results),
            previous_kernel_results.accepted_results,
            name='choose_inner_results'),
        is_accepted=is_accepted,
        log_accept_ratio=log_accept_ratio,
        proposed_state=proposed_state,
        proposed_results=proposed_results,
        extra=[],
        seed=seed,
    )

    return next_state, kernel_results
def one_step(self, current_state, previous_kernel_results):
  """Takes one step of the TransitionKernel.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.

  Raises:
    ValueError: if `inner_kernel` results doesn't contain the member
      "target_log_prob".
  """
  # TF1-compat variant: relies on tf1.name_scope `values=` and an internal
  # seed stream rather than an explicit per-call seed.
  with tf1.name_scope(
      name=mcmc_util.make_name(self.name, 'mh', 'one_step'),
      values=[current_state, previous_kernel_results]):
    # Take one inner step.
    [
        proposed_state,
        proposed_results,
    ] = self.inner_kernel.one_step(
        current_state,
        previous_kernel_results.accepted_results)

    if (not has_target_log_prob(proposed_results) or
        not has_target_log_prob(previous_kernel_results.accepted_results)):
      raise ValueError('"target_log_prob" must be a member of '
                       '`inner_kernel` results.')

    # Compute log(acceptance_ratio).
    to_sum = [proposed_results.target_log_prob,
              -previous_kernel_results.accepted_results.target_log_prob]
    try:
      # Skip the correction only when it is an empty list (treated as 0);
      # a non-list value or a non-empty list is included in the sum.
      if (not mcmc_util.is_list_like(
          proposed_results.log_acceptance_correction)
          or proposed_results.log_acceptance_correction):
        to_sum.append(proposed_results.log_acceptance_correction)
    except AttributeError:
      # Best-effort: an inner kernel without the field is assumed calibrated.
      warnings.warn('Supplied inner `TransitionKernel` does not have a '
                    '`log_acceptance_correction`. Assuming its value is `0.`')

    log_accept_ratio = mcmc_util.safe_sum(
        to_sum, name='compute_log_accept_ratio')

    # If proposed state reduces likelihood: randomly accept.
    # If proposed state increases likelihood: always accept.
    # I.e., u < min(1, accept_ratio),  where u ~ Uniform[0,1)
    #       ==> log(u) < log_accept_ratio
    log_uniform = tf.math.log(
        tf.random.uniform(
            shape=tf.shape(input=proposed_results.target_log_prob),
            dtype=proposed_results.target_log_prob.dtype.base_dtype,
            seed=self._seed_stream()))
    is_accepted = log_uniform < log_accept_ratio

    next_state = mcmc_util.choose(
        is_accepted,
        proposed_state,
        current_state,
        name='choose_next_state')

    kernel_results = MetropolisHastingsKernelResults(
        accepted_results=mcmc_util.choose(
            is_accepted,
            proposed_results,
            previous_kernel_results.accepted_results,
            name='choose_inner_results'),
        is_accepted=is_accepted,
        log_accept_ratio=log_accept_ratio,
        proposed_state=proposed_state,
        proposed_results=proposed_results,
        extra=[],
    )

    return next_state, kernel_results