def body(i, exchanged_states):
  """Loop body: maybe swap the states of the i-th proposed replica pair."""
  # The i-th proposal names the two replica indices considered for exchange.
  replica_a, replica_b = tf.unstack(exchange_proposed[i])
  # log_accept_ratio = -(beta_a - beta_b) * (log p_a - log p_b), i.e. the
  # (inverse-)temperature difference times the target-log-prob difference.
  # (With the energy convention, target_log_prob_diff = -EnergyDiff.)
  beta_diff = (
      self.inverse_temperatures[replica_a]
      - self.inverse_temperatures[replica_b])
  # The log-prob difference may be +-Inf or NaN; safe_sum ensures the product
  # with the temperature difference falls back to an "alt value" of -Inf,
  # i.e. the exchange is rejected.
  log_accept_ratio = mcmc_util.safe_sum([
      -beta_diff * target_log_probs[replica_a],
      beta_diff * target_log_probs[replica_b]
  ])
  accept_swap = log_uniforms[i] < log_accept_ratio
  for part in range(num_state_parts):
    # Conditionally swap this state part between the two replicas.
    state_a, state_b = _swap(accept_swap,
                             old_states[part].read(replica_a),
                             old_states[part].read(replica_b))
    exchanged_states[part] = exchanged_states[part].write(replica_a, state_a)
    exchanged_states[part] = exchanged_states[part].write(replica_b, state_b)
  return i + 1, exchanged_states
def one_step(self, current_state, previous_kernel_results):
  """Takes one step of the TransitionKernel.

  Runs the inner kernel once, then applies the Metropolis-Hastings
  accept/reject correction to the proposal it produced.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.

  Raises:
    ValueError: if `inner_kernel` results doesn't contain the member
      "target_log_prob".
  """
  with tf.name_scope(
      name=mcmc_util.make_name(self.name, 'mh', 'one_step'),
      values=[current_state, previous_kernel_results]):
    # Take one inner step.
    [
        proposed_state,
        proposed_results,
    ] = self.inner_kernel.one_step(
        current_state,
        previous_kernel_results.accepted_results)

    # Both the proposed and previously-accepted results must expose
    # `target_log_prob`; without it the accept ratio cannot be formed.
    if (not has_target_log_prob(proposed_results) or
        not has_target_log_prob(previous_kernel_results.accepted_results)):
      raise ValueError('"target_log_prob" must be a member of '
                       '`inner_kernel` results.')

    # Compute log(acceptance_ratio).
    to_sum = [proposed_results.target_log_prob,
              -previous_kernel_results.accepted_results.target_log_prob]
    try:
      # `log_acceptance_correction` is optional. Append it unless it is an
      # empty list (an empty list is the inner kernel's way of signaling a
      # correction of zero).
      if (not mcmc_util.is_list_like(
          proposed_results.log_acceptance_correction)
          or proposed_results.log_acceptance_correction):
        to_sum.append(proposed_results.log_acceptance_correction)
    except AttributeError:
      # Missing attribute entirely: warn once per call and assume zero.
      warnings.warn('Supplied inner `TransitionKernel` does not have a '
                    '`log_acceptance_correction`. Assuming its value is `0.`')

    # safe_sum propagates +-Inf/NaN terms to a well-defined "reject" value.
    log_accept_ratio = mcmc_util.safe_sum(
        to_sum, name='compute_log_accept_ratio')

    # If proposed state reduces likelihood: randomly accept.
    # If proposed state increases likelihood: always accept.
    # I.e., u < min(1, accept_ratio),  where u ~ Uniform[0,1)
    #       ==> log(u) < log_accept_ratio
    log_uniform = tf.log(tf.random_uniform(
        shape=tf.shape(proposed_results.target_log_prob),
        dtype=proposed_results.target_log_prob.dtype.base_dtype,
        seed=self._seed_stream()))
    is_accepted = log_uniform < log_accept_ratio

    # Per-chain select between the proposal and the previous state.
    next_state = mcmc_util.choose(
        is_accepted,
        proposed_state,
        current_state,
        name='choose_next_state')

    kernel_results = MetropolisHastingsKernelResults(
        # `choose` is applied structure-wise across the inner results so the
        # accepted results stay consistent with the accepted state.
        accepted_results=mcmc_util.choose(
            is_accepted,
            proposed_results,
            previous_kernel_results.accepted_results,
            name='choose_inner_results'),
        is_accepted=is_accepted,
        log_accept_ratio=log_accept_ratio,
        proposed_state=proposed_state,
        proposed_results=proposed_results,
        extra=[],
    )

    return next_state, kernel_results
def _compute_log_acceptance_correction(current_momentums,
                                       proposed_momentums,
                                       independent_chain_ndims,
                                       name=None):
  """Computes the log acceptance-correction used by `MetropolisHastings`.

  HMC augments the state space with momentum `z` distributed as a standard
  Gaussian, so the stationary distribution is:

  ```none
  p([x, z]) propto= target_prob(x) exp(-0.5 z**2)
  ```

  The leapfrog proposal is symmetric in the augmented space, so the
  Metropolis-Hastings bracketed proposal-ratio term is `1` and the correction
  reduces to the momentum probability ratio, i.e. the (negated) change in
  kinetic energy:

  ```none
  log_acceptance_correction = 0.5 * (||current_momentum||**2
                                     - ||proposed_momentum||**2)
  ```

  The squared norms are accumulated in log-space (via `_log_sum_sq` and
  `tf.reduce_logsumexp`) for numerical stability before exponentiating.

  Args:
    current_momentums: Python `list` of `Tensor`s representing the value(s)
      of the current momentum(s) of the state (parts).
    proposed_momentums: Python `list` of `Tensor`s representing the value(s)
      of the proposed momentum(s) of the state (parts).
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'compute_log_acceptance_correction').

  Returns:
    log_acceptance_correction: `Tensor` representing the `log`
      acceptance-correction. (See docstring for mathematical definition.)
  """
  with tf.name_scope(
      name, 'compute_log_acceptance_correction',
      [independent_chain_ndims, current_momentums, proposed_momentums]):
    log_current_kinetic = []
    log_proposed_kinetic = []
    for cur_momentum, prop_momentum in zip(current_momentums,
                                           proposed_momentums):
      # Reduce over every dimension to the right of the chain dimensions;
      # the axis range is derived once from the current momentum and shared
      # by both reductions (the pair is assumed to have matching rank).
      reduce_axes = tf.range(independent_chain_ndims, tf.rank(cur_momentum))
      log_current_kinetic.append(_log_sum_sq(cur_momentum, reduce_axes))
      log_proposed_kinetic.append(_log_sum_sq(prop_momentum, reduce_axes))

    def _kinetic_energy(log_sum_sq_parts):
      # 0.5 * sum_k exp(log_sum_sq_parts[k]), combined via logsumexp so the
      # accumulation across state parts stays in log-space.
      return 0.5 * tf.exp(
          tf.reduce_logsumexp(tf.stack(log_sum_sq_parts, axis=-1), axis=-1))

    # safe_sum handles non-finite kinetic energies with a -Inf "alt value".
    return mcmc_util.safe_sum([_kinetic_energy(log_current_kinetic),
                               -_kinetic_energy(log_proposed_kinetic)])
def _compute_log_acceptance_correction(current_state_parts,
                                       proposed_state_parts,
                                       current_volatility_parts,
                                       proposed_volatility_parts,
                                       current_drift_parts,
                                       proposed_drift_parts,
                                       step_size_parts,
                                       independent_chain_ndims,
                                       name=None):
  r"""Helper to `kernel` which computes the log acceptance-correction.

  Computes `log_acceptance_correction` as described in `MetropolisHastings`
  class. The proposal density is normal. More specifically,

  ```none
  q(proposed_state | current_state) \sim N(current_state + current_drift,
                                           step_size * current_volatility**2)
  q(current_state | proposed_state) \sim N(proposed_state + proposed_drift,
                                           step_size * proposed_volatility**2)
  ```

  The `log_acceptance_correction` is then

  ```none
  log_acceptance_correction = q(current_state | proposed_state)
                              - q(proposed_state | current_state)
  ```

  Args:
    current_state_parts: Python `list` of `Tensor`s representing the value(s)
      of the current state of the chain.
    proposed_state_parts: Python `list` of `Tensor`s representing the value(s)
      of the proposed state of the chain. Must broadcast with the shape of
      `current_state_parts`.
    current_volatility_parts: Python `list` of `Tensor`s representing the value
      of `volatility_fn(*current_volatility_parts)`. Must broadcast with the
      shape of `current_state_parts`.
    proposed_volatility_parts: Python `list` of `Tensor`s representing the
      value of `volatility_fn(*proposed_volatility_parts)`. Must broadcast with
      the shape of `current_state_parts`.
    current_drift_parts: Python `list` of `Tensor`s representing value of the
      drift `_get_drift(*current_state_parts, ..)`. Must broadcast with the
      shape of `current_state_parts`.
    proposed_drift_parts: Python `list` of `Tensor`s representing value of the
      drift `_get_drift(*proposed_drift_parts, ..)`. Must broadcast with the
      shape of `current_state_parts`.
    step_size_parts: Python `list` of `Tensor`s representing the step size for
      Euler-Maruyama method. Must broadcast with the shape of
      `current_state_parts`.
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'compute_log_acceptance_correction').

  Returns:
    log_acceptance_correction: `Tensor` representing the `log`
      acceptance-correction. (See docstring for mathematical definition.)
  """
  with tf.name_scope(name, 'compute_log_acceptance_correction', [
      current_state_parts, proposed_state_parts, current_volatility_parts,
      proposed_volatility_parts, current_drift_parts, proposed_drift_parts,
      step_size_parts, independent_chain_ndims
  ]):
    proposed_log_density_parts = []
    dual_log_density_parts = []

    # Each state part contributes independently; accumulate per-part log
    # densities for both proposal directions.
    for [
        current_state,
        proposed_state,
        current_volatility,
        proposed_volatility,
        current_drift,
        proposed_drift,
        step_size,
    ] in zip(
        current_state_parts,
        proposed_state_parts,
        current_volatility_parts,
        proposed_volatility_parts,
        current_drift_parts,
        proposed_drift_parts,
        step_size_parts,
    ):
      # Reduce over all dimensions to the right of the chain dimensions.
      axis = tf.range(independent_chain_ndims, tf.rank(current_state))

      state_diff = proposed_state - current_state

      # Scale volatility to the proposal's standard deviation:
      # sqrt(step_size) * volatility, so variance = step_size * volatility**2.
      current_volatility *= tf.sqrt(step_size)

      # Standardized residual of the forward proposal.
      proposed_energy = (state_diff - current_drift) / current_volatility

      proposed_volatility *= tf.sqrt(step_size)
      # Compute part of `q(proposed_state | current_state)`:
      # log(sigma) + 0.5 * z**2 is the per-element negative log density (up
      # to a constant which cancels in the correction).
      proposed_energy = (tf.reduce_sum(mcmc_util.safe_sum(
          [tf.log(current_volatility), 0.5 * (proposed_energy**2)]),
                                       axis=axis))
      proposed_log_density_parts.append(-proposed_energy)

      # Compute part of `q(current_state | proposed_state)`.
      # Note: (state_diff + proposed_drift) squares to the same value as the
      # reverse residual -(state_diff + proposed_drift).
      dual_energy = (state_diff + proposed_drift) / proposed_volatility
      dual_energy = (tf.reduce_sum(mcmc_util.safe_sum(
          [tf.log(proposed_volatility), 0.5 * (dual_energy**2)]),
                                   axis=axis))
      dual_log_density_parts.append(-dual_energy)

    # Compute `q(proposed_state | current_state)` by summing over parts.
    proposed_log_density_reduce = tf.reduce_sum(tf.stack(
        proposed_log_density_parts, axis=-1), axis=-1)
    # Compute `q(current_state | proposed_state)` by summing over parts.
    dual_log_density_reduce = tf.reduce_sum(tf.stack(
        dual_log_density_parts, axis=-1), axis=-1)

    return mcmc_util.safe_sum(
        [dual_log_density_reduce, -proposed_log_density_reduce])
def one_step(self, current_state, previous_kernel_results):
  """Takes one step of the TransitionKernel.

  Runs the inner kernel once, then applies the Metropolis-Hastings
  accept/reject correction to the proposal it produced. Mutates `self._seed`
  as a side effect (see inline notes).

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.

  Raises:
    ValueError: if `inner_kernel` results doesn't contain the member
      "target_log_prob".
  """
  # Take one inner step.
  [
      proposed_state,
      proposed_results,
  ] = self.inner_kernel.one_step(
      current_state,
      previous_kernel_results.accepted_results)

  # Both the proposed and previously-accepted results must expose
  # `target_log_prob`; without it the accept ratio cannot be formed.
  if (not has_target_log_prob(proposed_results) or
      not has_target_log_prob(
          previous_kernel_results.accepted_results)):
    raise ValueError('"target_log_prob" must be a member of '
                     '`inner_kernel` results.')

  # Compute log(acceptance_ratio).
  to_sum = [
      proposed_results.target_log_prob,
      -previous_kernel_results.accepted_results.target_log_prob
  ]
  try:
    # `log_acceptance_correction` is optional on the inner kernel results;
    # a missing attribute is treated as a correction of zero.
    to_sum.append(proposed_results.log_acceptance_correction)
  except AttributeError:
    warnings.warn(
        'Supplied inner `TransitionKernel` does not have a '
        '`log_acceptance_correction`. Assuming its value is `0.`')

  # safe_sum propagates +-Inf/NaN terms to a well-defined "reject" value.
  log_accept_ratio = mcmc_util.safe_sum(
      to_sum, name='compute_log_accept_ratio')

  # If proposed state reduces likelihood: randomly accept.
  # If proposed state increases likelihood: always accept.
  # I.e., u < min(1, accept_ratio),  where u ~ Uniform[0,1)
  #       ==> log(u) < log_accept_ratio
  # Note:
  # - We mutate seed state so subsequent calls are not correlated.
  # - We mutate seed BEFORE using it just in case users supplied the
  #   same seed to the inner kernel.
  self._seed = distributions_util.gen_new_seed(
      self.seed, salt='metropolis_hastings_one_step')
  log_uniform = tf.log(
      tf.random_uniform(
          shape=tf.shape(proposed_results.target_log_prob),
          dtype=proposed_results.target_log_prob.dtype.base_dtype,
          seed=self.seed))
  is_accepted = log_uniform < log_accept_ratio

  # One accept/reject decision per independent chain; the chain dims are the
  # leading dims of `target_log_prob`.
  independent_chain_ndims = distributions_util.prefer_static_rank(
      proposed_results.target_log_prob)

  next_state = mcmc_util.choose(is_accepted, proposed_state, current_state,
                                independent_chain_ndims)

  # Rebuild the inner-kernel results namedtuple field-by-field, choosing
  # between proposed and previously-accepted values per chain.
  accepted_results = type(proposed_results)(
      **dict([(fn,
               mcmc_util.choose(
                   is_accepted,
                   getattr(proposed_results, fn),
                   getattr(previous_kernel_results.accepted_results, fn),
                   independent_chain_ndims)) for fn in
              proposed_results._fields]))

  return [
      next_state,
      MetropolisHastingsKernelResults(
          accepted_results=accepted_results,
          is_accepted=is_accepted,
          log_accept_ratio=log_accept_ratio,
          proposed_state=proposed_state,
          proposed_results=proposed_results,
      )
  ]
def _compute_log_acceptance_correction(current_momentums,
                                       proposed_momentums,
                                       independent_chain_ndims,
                                       name=None):
  """Returns the MH log acceptance-correction for an HMC-style proposal.

  The state space is augmented with a standard-Gaussian momentum `z`, so the
  chain targets:

  ```none
  p([x, z]) propto= target_prob(x) exp(-0.5 z**2)
  ```

  Since the leapfrog proposal is symmetric in `[x, z]`, the proposal-ratio
  term of the Metropolis-Hastings correction is `1`, leaving only the ratio
  of momentum densities. In log-space that is the negated kinetic-energy
  change:

  ```none
  log_acceptance_correction = 0.5 * (||current_momentum||**2
                                     - ||proposed_momentum||**2)
  ```

  Squared norms are accumulated in log-space via `_log_sum_sq` /
  `tf.reduce_logsumexp` for numerical stability.

  Args:
    current_momentums: Python `list` of `Tensor`s representing the value(s)
      of the current momentum(s) of the state (parts).
    proposed_momentums: Python `list` of `Tensor`s representing the value(s)
      of the proposed momentum(s) of the state (parts).
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'compute_log_acceptance_correction').

  Returns:
    log_acceptance_correction: `Tensor` representing the `log`
      acceptance-correction. (See docstring for mathematical definition.)
  """
  with tf.name_scope(
      name, 'compute_log_acceptance_correction',
      [independent_chain_ndims, current_momentums, proposed_momentums]):
    momentum_pairs = list(zip(current_momentums, proposed_momentums))
    # Per-part reduction axes: every dimension to the right of the chain
    # dimensions, derived from the current momentum of each pair (the pair
    # is assumed to have matching rank).
    reduce_axes = [tf.range(independent_chain_ndims, tf.rank(cur))
                   for cur, _ in momentum_pairs]
    log_sq_current = [
        _log_sum_sq(cur, ax)
        for (cur, _), ax in zip(momentum_pairs, reduce_axes)]
    log_sq_proposed = [
        _log_sum_sq(prop, ax)
        for (_, prop), ax in zip(momentum_pairs, reduce_axes)]
    # Combine state parts in log-space, then form 0.5 * sum-of-squares.
    kinetic_current = 0.5 * tf.exp(
        tf.reduce_logsumexp(tf.stack(log_sq_current, axis=-1), axis=-1))
    kinetic_proposed = 0.5 * tf.exp(
        tf.reduce_logsumexp(tf.stack(log_sq_proposed, axis=-1), axis=-1))
    # safe_sum handles non-finite kinetic energies with a -Inf "alt value".
    return mcmc_util.safe_sum([kinetic_current, -kinetic_proposed])