def resample(particles, log_weights, resample_fn, seed=None):
  """Draws a resampled particle set using the supplied resampling scheme.

  Args:
    particles: Nested structure of `Tensor`s, each of shape
      `[num_particles, b1, ..., bN, ...]`, where `b1, ..., bN` are optional
      batch dimensions.
    log_weights: float `Tensor` of shape `[num_particles, b1, ..., bN]`, where
      `b1, ..., bN` are optional batch dimensions.
    resample_fn: callable implementing the resampling scheme; e.g.,
      `resample_independent` for independent resamples,
      `resample_stratified` for stratified resampling, or
      `resample_systematic` for systematic resampling.
    seed: Python `int` random seed.

  Returns:
    resampled_particles: Nested structure of `Tensor`s, matching `particles`.
    resample_indices: int `Tensor` of shape `[num_particles, b1, ..., bN]`.
  """
  with tf.name_scope('resample'):
    num_particles = ps.size0(log_weights)
    # Normalize weights to log-probabilities along the particle axis.
    normalized_log_probs = tf.math.log_softmax(log_weights, axis=0)
    ancestor_indices = resample_fn(
        normalized_log_probs, num_particles, (), seed=seed)

    def _gather_ancestors(x):
      return mcmc_util.index_remapping_gather(x, ancestor_indices, axis=0)

    resampled_particles = tf.nest.map_structure(_gather_ancestors, particles)
    return resampled_particles, ancestor_indices
def resample_particle_and_info(particles, log_weights, seed=None):
  """Draws a fresh particle set by sampling ancestors from `log_weights`.

  Args:
    particles: Nested structure of `Tensor`s, each of shape
      `[num_particles, b1, ..., bN, ...]`, where `b1, ..., bN` are optional
      batch dimensions.
    log_weights: float `Tensor` of shape `[num_particles, b1, ..., bN]`, where
      `b1, ..., bN` are optional batch dimensions.
    seed: Python `int` random seed.

  Returns:
    resampled_particles: Nested structure of `Tensor`s, matching `particles`.
    resample_indices: int `Tensor` of shape `[num_particles, b1, ..., bN]`.
  """
  with tf.name_scope('resample'):
    num_particles = ps.size0(log_weights)
    normalized_log_probs = tf.math.log_softmax(log_weights, axis=0)
    # TODO(junpenglao): use an `axis` specifiable categorical sampler to avoid
    # the transpose below.
    ancestor_dist = categorical.Categorical(
        logits=dist_util.move_dimension(normalized_log_probs, 0, -1))
    resample_indices = ancestor_dist.sample(num_particles, seed=seed)

    def _gather_ancestors(x):
      return mcmc_util.index_remapping_gather(x, resample_indices)

    resampled_particles = tf.nest.map_structure(_gather_ancestors, particles)
  return resampled_particles, resample_indices
def reconstruct_trajectories(particles, parent_indices, name=None):
  """Reconstructs the ancestor trajectory that generated each final particle."""
  with tf.name_scope(name or 'reconstruct_trajectories'):
    # Trace each final particle's lineage backwards through time: the
    # ancestor at step t is found by following parent pointers from t + 1.
    final_indices = smc_kernel._dummy_indices_like(parent_indices[-1])  # pylint: disable=protected-access

    def _follow_parents(ancestor, parent):
      return mcmc_util.index_remapping_gather(parent, ancestor, axis=0)

    ancestor_indices = tf.scan(
        fn=_follow_parents,
        elems=parent_indices[1:],
        initializer=final_indices,
        reverse=True)
    # The final step's ancestors are, by definition, the final particles.
    ancestor_indices = tf.concat([ancestor_indices, [final_indices]], axis=0)

    def _gather_trajectory(part):
      return mcmc_util.index_remapping_gather(
          part, ancestor_indices, axis=1, indices_axis=1)

    return tf.nest.map_structure(_gather_trajectory, particles)
def test_rank_1_same_as_gather(self):
  # With rank-1 params and rank-1 indices, index_remapping_gather
  # reduces to a plain gather.
  values = [10, 11, 12, 13]
  remap = [3, 2, 0]
  gathered = util.index_remapping_gather(values, remap)
  self.assertAllEqual(np.asarray(remap).shape, gathered.shape)
  self.assertAllEqual([13, 12, 10], self.evaluate(gathered))
def test_rank_2_and_axis_0(self):
  # Each output row i re-gathers along axis 0 using that row's indices.
  values = [[95, 46, 17],
            [46, 29, 55]]
  remap = [[0, 0, 1],
           [1, 0, 1]]
  expected = [[95, 46, 55],
              [46, 46, 55]]
  gathered = util.index_remapping_gather(values, remap)
  self.assertAllEqual(np.asarray(values).shape, gathered.shape)
  self.assertAllEqual(expected, self.evaluate(gathered))
def _resample(self, particles, log_weights, seed=None):
  """Chooses one out of `importance_sample_size` many weighted proposals."""
  # Sample a single ancestor index per batch member, then strip the
  # leading singleton sample dimension after gathering.
  winner_dist = categorical.Categorical(
      logits=distribution_util.move_dimension(log_weights, 0, -1))
  sampled_indices = winner_dist.sample(sample_shape=[1], seed=seed)

  def _select_winner(x):
    gathered = mcmc_util.index_remapping_gather(x, sampled_indices, axis=0)
    return gathered[0, ...]

  return tf.nest.map_structure(_select_winner, particles)
def resample(particles, log_weights, resample_fn, target_log_weights=None,
             seed=None):
  """Resamples the current particles according to provided weights.

  Args:
    particles: Nested structure of `Tensor`s each of shape
      `[num_particles, b1, ..., bN, ...]`, where `b1, ..., bN` are optional
      batch dimensions.
    log_weights: float `Tensor` of shape `[num_particles, b1, ..., bN]`, where
      `b1, ..., bN` are optional batch dimensions.
    resample_fn: choose the function used for resampling. Use
      'resample_independent' for independent resamples, 'resample_stratified'
      for stratified resampling, or 'resample_systematic' for systematic
      resampling.
    target_log_weights: optional float `Tensor` of the same shape and dtype as
      `log_weights`, specifying the target measure on `particles` if this is
      different from that implied by normalizing `log_weights`. The returned
      `log_weights_after_resampling` will represent this measure. If `None`,
      the target measure is implicitly taken to be the normalized log weights
      (`log_weights - tf.reduce_logsumexp(log_weights, axis=0)`).
      Default value: `None`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    resampled_particles: Nested structure of `Tensor`s, matching `particles`.
    resample_indices: int `Tensor` of shape `[num_particles, b1, ..., bN]`.
    log_weights_after_resampling: float `Tensor` of same shape and dtype as
      `log_weights`, such that weighted sums of the resampled particles are
      equal (in expectation over the resampling step) to weighted sums of
      the original particles:
      `E [ exp(log_weights_after_resampling) * some_fn(resampled_particles) ]
      == exp(target_log_weights) * some_fn(particles)`.
      If no `target_log_weights` was specified, the log weights after
      resampling are uniformly equal to `-log(num_particles)`.
  """
  with tf.name_scope('resample'):
    num_particles = ps.size0(log_weights)
    log_num_particles = tf.math.log(
        tf.cast(num_particles, log_weights.dtype))

    # Normalize the weights and sample the ancestral indices.
    normalized_log_probs = tf.math.log_softmax(log_weights, axis=0)
    resampled_indices = resample_fn(
        normalized_log_probs, num_particles, (), seed=seed)

    def _gather_ancestors(x):
      return mcmc_util.index_remapping_gather(x, resampled_indices, axis=0)

    resampled_particles = tf.nest.map_structure(_gather_ancestors, particles)
    if target_log_weights is None:
      # Uniform weights: each surviving particle carries mass
      # 1 / num_particles.
      log_weights_after_resampling = tf.fill(
          ps.shape(log_weights), -log_num_particles)
    else:
      # Reweight the resampled particles to represent the requested
      # target measure.
      importance_weights = (
          target_log_weights - normalized_log_probs - log_num_particles)
      log_weights_after_resampling = tf.nest.map_structure(
          _gather_ancestors, importance_weights)
  return resampled_particles, resampled_indices, log_weights_after_resampling
def test_params_rank3_indices_rank2_axis_0(self):
  axis = 0
  params = np.random.randint(10, 100, size=(4, 5, 2))
  indices = np.random.randint(0, params.shape[axis], size=(6, 5))
  gathered = util.index_remapping_gather(params, indices)
  # Output keeps the indices' leading dims and the params' trailing dims.
  expected_shape = indices.shape[:axis + 1] + params.shape[axis + 1:]
  self.assertAllEqual(expected_shape, gathered.shape)
  gathered_ = self.evaluate(gathered)
  for i, j, k in np.ndindex(indices.shape[0], params.shape[1],
                            params.shape[2]):
    self.assertEqual(params[indices[i, j], j, k], gathered_[i, j, k])
def _maybe_embed_swaps_validation(swaps, null_swaps, validate_args):
  """Return `swaps`, possibly with embedded "once only" assertion.

  Args:
    swaps: int `Tensor` of proposed replica index permutations.
    null_swaps: int `Tensor` holding the identity permutation, broadcast
      compatibly with `swaps`.
    validate_args: Python `bool`. If `False`, `swaps` is returned unchanged.

  Returns:
    swaps: the input `swaps`, with a control-dependency assertion attached
      (when `validate_args` is `True`) checking that it is a self-inverse
      permutation, i.e. `gather(swaps, swaps)` recovers the identity.
  """
  if not validate_args:
    return swaps

  assertions = [
      assert_util.assert_equal(
          null_swaps,
          mcmc_util.index_remapping_gather(swaps, swaps),
          # Fixed grammar and unbalanced backtick in the original message
          # ('must be consist of', unterminated `...`).
          message=('Proposed replica swaps must consist of "once only '
                   'swaps," i.e., be a self-inverse permutation: '
                   '`range(swaps.shape[0]) == gather(swaps, swaps)`.')),
  ]
  with tf.control_dependencies(assertions):
    return tf.identity(swaps)
def test_params_rank3_indices_rank1_axis_1(self):
  axis = 1
  params = np.random.randint(10, 100, size=[4, 5, 2])
  indices = np.random.randint(0, params.shape[axis], size=[6])
  gathered = util.index_remapping_gather(params, indices, axis=axis)
  # The gathered axis takes its length from `indices`; all other axes keep
  # the params' extents.
  expected_shape = (
      params.shape[:axis] + indices.shape[:1] + params.shape[axis + 1:])
  self.assertAllEqual(expected_shape, gathered.shape)
  gathered_ = self.evaluate(gathered)
  for i, j, k in np.ndindex(params.shape[0], indices.shape[0],
                            params.shape[2]):
    self.assertEqual(params[i, indices[j], k], gathered_[i, j, k])
def test_params_rank5_indices_rank3_axis_2_iaxis_1(self):
  axis = 2
  indices_axis = 1
  params = np.random.randint(10, 100, size=[4, 5, 2, 3, 4])
  indices = np.random.randint(0, params.shape[axis], size=[5, 6, 3])
  gathered = util.index_remapping_gather(
      params, indices, axis=axis, indices_axis=indices_axis)
  # Axis `axis` of the output takes its length from axis `indices_axis`
  # of `indices`; remaining axes keep the params' extents.
  expected_shape = (
      params.shape[:axis]
      + indices.shape[indices_axis:indices_axis + 1]
      + params.shape[axis + 1:])
  self.assertAllEqual(expected_shape, gathered.shape)
  gathered_ = self.evaluate(gathered)
  for i, j, k, l, m in np.ndindex(
      params.shape[0], params.shape[1], indices.shape[1],
      params.shape[3], params.shape[4]):
    self.assertEqual(params[i, j, indices[j, k, l], l, m],
                     gathered_[i, j, k, l, m])
def _swap_tensor(x):
  # Closure over `is_swap_accepted_mask` and `swaps` from the enclosing
  # `one_step` scope: where a swap was accepted, take the value from the
  # swap partner (via `index_remapping_gather`); otherwise keep `x`.
  return mcmc_util.choose(
      is_swap_accepted_mask,
      mcmc_util.index_remapping_gather(x, swaps), x)
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Takes one step of the TransitionKernel.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
    seed: Optional, a seed for reproducible sampling.

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.
  """
  # The code below propagates one step states of shape
  # [n_replica] + batch_shape + event_shape.
  #
  # The step is done in three parts:
  # 1) Call one_step to transition states via a tempered version of
  #    self.target_log_prob_fn (see _replica_target_log_prob).
  # 2) Permute values in states
  # 3) Update state-dependent values, such as log_probs.
  #
  # We chose to swap states, rather than temperatures, because...
  # (i) If swapping temperatures, you *still* have to swap log_probs to
  #    determine acceptance, as well as states (for kernel results).
  #    So it's just as difficult to swap temperatures.
  # (ii) If swapping temperatures, you have to take care to swap any user-
  #    supplied temperature related things (like step size).
  #    A-priori, we don't know what else will need to be swapped!
  # (iii) In both cases, the kernel results need to be updated in a
  #    non-trivial manner... so we either special-case, or use bootstrap.

  with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
    # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
    inverse_temperatures = tf.convert_to_tensor(
        previous_kernel_results.inverse_temperatures,
        name='inverse_temperatures')

    target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
        target_log_prob_fn=self.target_log_prob_fn,
        inverse_temperatures=inverse_temperatures,
        untempered_log_prob_fn=self.untempered_log_prob_fn,
        tempered_log_prob_fn=self.tempered_log_prob_fn,
    )
    # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel)
    except TypeError as e:
      if 'argument' not in str(e):
        raise
      raise TypeError(
          '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a `seed` '
          'argument. `TransitionKernel` instances now receive seeds via '
          '`one_step`.')

    seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
    inner_seed, swap_seed, logu_seed = samplers.split_seed(seed, n=3)
    # Step the inner TransitionKernel.
    [
        pre_swap_replica_states,
        pre_swap_replica_results,
    ] = inner_kernel.one_step(
        previous_kernel_results.post_swap_replica_states,
        previous_kernel_results.post_swap_replica_results,
        seed=inner_seed)

    pre_swap_replica_target_log_prob = _get_field(
        # These are tempered log probs (have been divided by temperature).
        pre_swap_replica_results, 'target_log_prob')

    dtype = pre_swap_replica_target_log_prob.dtype
    replica_and_batch_shape = ps.shape(
        pre_swap_replica_target_log_prob)
    batch_shape = replica_and_batch_shape[1:]
    replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
    num_replica = ps.size0(inverse_temperatures)

    inverse_temperatures = bu.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Now that each replica has done one_step, it is time to consider swaps.

    # swap.shape = [n_replica], and is a "once only" permutation, meaning it
    # is achievable by a sequence of pairwise permutations, where each element
    # is moved at most once.
    # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
    # 1, keeping 2 fixed. This exact same swap is considered for *every*
    # batch member. Of course some batch members may accept and some reject.
    try:
      swaps = tf.cast(
          self.swap_proposal_fn(  # pylint: disable=not-callable
              num_replica,
              batch_shape=batch_shape,
              seed=swap_seed,
              step_count=previous_kernel_results.step_count),
          dtype=tf.int32)
    except TypeError as e:
      if 'step_count' not in str(e):
        raise
      warnings.warn(
          'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
          'the `step_count` argument. Falling back to omitting the '
          'argument. This fallback will be removed after 24-Oct-2020.')
      swaps = tf.cast(
          self.swap_proposal_fn(  # pylint: disable=not-callable
              num_replica,
              batch_shape=batch_shape,
              seed=swap_seed),
          dtype=tf.int32)

    null_swaps = bu.left_justified_expand_dims_like(
        tf.range(num_replica, dtype=swaps.dtype), swaps)
    swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                          self.validate_args)

    # Un-temper the log probs for use in the swap acceptance ratio.
    if self.tempered_log_prob_fn is None:
      # Efficient way of re-evaluating target_log_prob_fn on the
      # pre_swap_replica_states.
      untempered_negative_energy_ignoring_ulp = (
          # Since untempered_log_prob_fn is None, we may assume
          # inverse_temperatures > 0 (else the target is improper).
          pre_swap_replica_target_log_prob / inverse_temperatures)
    else:
      # The untempered_log_prob_fn does not factor into the acceptance ratio.
      # Proof: Suppose the tempered target is
      #   p_k(x) = f(x)^{beta_k} g(x),
      # So f(x) is tempered, and g(x) is not. Then, the acceptance ratio for
      # a 1 <--> 2 swap is...
      #   (p_1(x_2) p_2(x_1)) / (p_1(x_1) p_2(x_2))
      # which depends only on f(x), since terms involving g(x) cancel.
      untempered_negative_energy_ignoring_ulp = self.tempered_log_prob_fn(
          *pre_swap_replica_states)

    # Since `swaps` is its own inverse permutation we automatically know the
    # swap counterpart: range(num_replica). We use this idea to compute the
    # acceptance in a vectorized manner at the cost of wasting roughly half
    # our computation. Although we could use `unique` to solve this problem,
    # we expect the cost of `unique` to be higher than the dozens of wasted
    # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
    # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

    # Note: diffs would normally be "proposed - current" however energy is
    # flipped since `energy == -log_prob`.
    # Note: The untempered_log_prob_fn (if provided) is not included in
    # untempered_pre_swap_replica_target_log_prob, and hence does not factor
    # into energy_diff. Why? Because, it cancels out in the acceptance ratio.
    energy_diff = (
        untempered_negative_energy_ignoring_ulp -
        mcmc_util.index_remapping_gather(
            untempered_negative_energy_ignoring_ulp,
            swaps, name='gather_swap_tlp'))
    swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
        inverse_temperatures, swaps, name='gather_swap_temps')
    inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

    # If replicas i and j are swapping, entries i and j of log_accept_ratio
    # are equal.
    log_accept_ratio = (
        energy_diff * bu.left_justified_expand_dims_to(
            inverse_temp_diff, replica_and_batch_rank))

    # Mask out non-finite ratios (e.g. from 0 * inf) so they are rejected.
    log_accept_ratio = tf.where(
        tf.math.is_finite(log_accept_ratio),
        log_accept_ratio, tf.constant(-np.inf, dtype=dtype))

    # Produce log[Uniform] draws that are identical at swapped indices.
    log_uniform = tf.math.log(
        samplers.uniform(shape=replica_and_batch_shape,
                         dtype=dtype,
                         seed=logu_seed))
    # Anchor each swap pair at its smaller index so both members share the
    # same uniform draw (accept/reject decisions must agree pairwise).
    anchor_swaps = tf.minimum(swaps, null_swaps)
    log_uniform = mcmc_util.index_remapping_gather(log_uniform, anchor_swaps)

    is_swap_accepted_mask = tf.less(
        log_uniform,
        log_accept_ratio,
        name='is_swap_accepted_mask')

    def _swap_tensor(x):
      # Where accepted, take the swap partner's value; else keep x.
      return mcmc_util.choose(
          is_swap_accepted_mask,
          mcmc_util.index_remapping_gather(x, swaps), x)

    post_swap_replica_states = [
        _swap_tensor(s) for s in pre_swap_replica_states]

    expanded_null_swaps = bu.left_justified_broadcast_to(
        null_swaps, replica_and_batch_shape)
    is_swap_proposed = _compute_swap_notmatrix(
        # Broadcast both so they have shape [num_replica] + batch_shape.
        # This (i) makes them have same shape as is_swap_accepted, and
        # (ii) keeps shape consistent if someday swaps has a batch shape.
        expanded_null_swaps,
        bu.left_justified_broadcast_to(swaps, replica_and_batch_shape))

    # To get is_swap_accepted in ordered position, we use
    # _compute_swap_notmatrix on current and next replica positions.
    post_swap_replica_position = _swap_tensor(expanded_null_swaps)

    is_swap_accepted = _compute_swap_notmatrix(
        post_swap_replica_position,
        expanded_null_swaps)

    if self._state_includes_replicas:
      post_swap_states = post_swap_replica_states
    else:
      # Canonical (user-facing) state is replica 0.
      post_swap_states = [s[0] for s in post_swap_replica_states]

    post_swap_replica_results = _set_swapped_fields_to_nan(
        _swap_log_prob_and_maybe_grads(
            pre_swap_replica_results,
            post_swap_replica_states,
            inner_kernel))

    if mcmc_util.is_list_like(current_state):
      # We *always* canonicalize the states in the kernel results.
      states = post_swap_states
    else:
      states = post_swap_states[0]

    post_swap_kernel_results = ReplicaExchangeMCKernelResults(
        post_swap_replica_states=post_swap_replica_states,
        pre_swap_replica_results=pre_swap_replica_results,
        post_swap_replica_results=post_swap_replica_results,
        is_swap_proposed=is_swap_proposed,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        # Store the original pkr.inverse_temperatures in case its a
        # `tf.Variable`.
        inverse_temperatures=previous_kernel_results.inverse_temperatures,
        swaps=swaps,
        step_count=previous_kernel_results.step_count + 1,
        seed=seed,
        potential_energy=-untempered_negative_energy_ignoring_ulp,
    )

    return states, post_swap_kernel_results
def infer_trajectories(observations,
                       initial_state_prior,
                       transition_fn,
                       observation_fn,
                       num_particles,
                       initial_state_proposal=None,
                       proposal_fn=None,
                       resample_fn=weighted_resampling.resample_systematic,
                       resample_criterion_fn=ess_below_threshold,
                       rejuvenation_kernel_fn=None,
                       num_transitions_per_observation=1,
                       seed=None,
                       name=None):  # pylint: disable=g-doc-args
  """Use particle filtering to sample from the posterior over trajectories.

  ${particle_filter_arg_str}
    seed: Python `int` seed for random ops.
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'infer_trajectories'`).

  Returns:
    trajectories: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing unbiased samples from the posterior distribution
      `p(latent_states | observations)`.
    incremental_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each timestep `t`. Note that
      (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  #### Examples

  **Tracking unknown position and velocity**: Let's consider tracking an
  object moving in a one-dimensional space. We'll define a dynamical system
  by specifying an `initial_state_prior`, a `transition_fn`,
  and `observation_fn`.

  The structure of the latent state space is determined by the prior
  distribution. Here, we'll define a state space that includes the object's
  current position and velocity:

  ```python
  initial_state_prior = tfd.JointDistributionNamed({
      'position': tfd.Normal(loc=0., scale=1.),
      'velocity': tfd.Normal(loc=0., scale=0.1)})
  ```

  The `transition_fn` specifies the evolution of the system. It should
  return a distribution over latent states of the same structure as the prior.
  Here, we'll assume that the position evolves according to the velocity,
  with a small random drift, and the velocity also changes slowly, following
  a random drift:

  ```python
  def transition_fn(_, previous_state):
    return tfd.JointDistributionNamed({
        'position': tfd.Normal(
            loc=previous_state['position'] + previous_state['velocity'],
            scale=0.1),
        'velocity': tfd.Normal(loc=previous_state['velocity'], scale=0.01)})
  ```

  The `observation_fn` specifies the process by which the system is observed
  at each time step. Let's suppose we observe only a noisy version of the
  current position.

  ```python
  def observation_fn(_, state):
    return tfd.Normal(loc=state['position'], scale=0.1)
  ```

  Now let's track our object. Suppose we've been given observations
  corresponding to an initial position of `0.4` and constant velocity of
  `0.01`:

  ```python
  # Generate simulated observations.
  observed_positions = tfd.Normal(
      loc=tf.linspace(0.4, 0.8, 40), scale=0.1).sample()

  # Run particle filtering to sample plausible trajectories.
  (trajectories,  # {'position': [40, 1000], 'velocity': [40, 1000]}
   lps) = tfp.experimental.mcmc.infer_trajectories(
            observations=observed_positions,
            initial_state_prior=initial_state_prior,
            transition_fn=transition_fn,
            observation_fn=observation_fn,
            num_particles=1000)
  ```

  For all `i`, `trajectories['position'][:, i]` is a sample from the
  posterior over position sequences, given the observations:
  `p(state[0:T] | observations[0:T])`. Often, the sampled trajectories
  will be highly redundant in their earlier timesteps, because most
  of the initial particles have been discarded through resampling
  (this problem is known as 'particle degeneracy'; see section 3.5 of
  [Doucet and Johansen][1]).
  In such cases it may be useful to also consider the series of *filtering*
  distributions `p(state[t] | observations[:t])`, in which each latent state
  is inferred conditioned only on observations up to that point in time; these
  may be computed using `tfp.mcmc.experimental.particle_filter`.

  #### References

  [1] Arnaud Doucet and Adam M. Johansen. A tutorial on particle
      filtering and smoothing: Fifteen years later.
      _Handbook of nonlinear filtering_, 12(656-704), 2009.
      https://www.stats.ox.ac.uk/~doucet/doucet_johansen_tutorialPF2011.pdf
  """
  with tf.name_scope(name or 'infer_trajectories') as name:
    seed = SeedStream(seed, 'infer_trajectories')
    # Run the forward filtering pass, tracing particles, weights, and
    # ancestor (parent) indices at every step.
    (particles,
     log_weights,
     parent_indices,
     incremental_log_marginal_likelihoods) = particle_filter(
         observations=observations,
         initial_state_prior=initial_state_prior,
         transition_fn=transition_fn,
         observation_fn=observation_fn,
         num_particles=num_particles,
         initial_state_proposal=initial_state_proposal,
         proposal_fn=proposal_fn,
         resample_fn=resample_fn,
         resample_criterion_fn=resample_criterion_fn,
         rejuvenation_kernel_fn=rejuvenation_kernel_fn,
         num_transitions_per_observation=num_transitions_per_observation,
         trace_fn=_default_trace_fn,
         seed=seed,
         name=name)
    # Follow ancestor indices backwards so each final particle is paired with
    # the full trajectory that produced it.
    weighted_trajectories = reconstruct_trajectories(particles, parent_indices)

    # Resample all steps of the trajectories using the final weights.
    resample_indices = resample_fn(log_probs=log_weights[-1],
                                   event_size=num_particles,
                                   sample_shape=(),
                                   seed=seed)
    trajectories = tf.nest.map_structure(
        lambda x: mcmc_util.index_remapping_gather(x,  # pylint: disable=g-long-lambda
                                                   resample_indices,
                                                   axis=1),
        weighted_trajectories)

    return trajectories, incremental_log_marginal_likelihoods
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Takes one step of the TransitionKernel.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
    seed: Optional, a seed for reproducible sampling.

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.
  """
  # The code below propagates one step states of shape
  # [n_replica] + batch_shape + event_shape.
  #
  # The step is done in three parts:
  # 1) Call one_step to transition states via a tempered version of
  #    self.target_log_prob_fn (see _replica_target_log_prob).
  # 2) Permute values in states
  # 3) Update state-dependent values, such as log_probs.
  #
  # We chose to swap states, rather than temperatures, because...
  # (i) If swapping temperatures, you *still* have to swap log_probs to
  #    determine acceptance, as well as states (for kernel results).
  #    So it's just as difficult to swap temperatures.
  # (ii) If swapping temperatures, you have to take care to swap any user-
  #    supplied temperature related things (like step size).
  #    A-priori, we don't know what else will need to be swapped!
  # (iii) In both cases, the kernel results need to be updated in a
  #    non-trivial manner... so we either special-case, or use bootstrap.

  with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
    # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
    inverse_temperatures = tf.convert_to_tensor(
        previous_kernel_results.inverse_temperatures,
        name='inverse_temperatures')

    target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
        self.target_log_prob_fn,
        inverse_temperatures)

    # Seed handling complexity is due to users possibly expecting an old-style
    # stateful seed to be passed to `self.make_kernel_fn`, and no seed
    # expected by `kernel.one_step`.
    # In other words:
    # - We try `make_kernel_fn` without a seed first; this is the future. The
    #   kernel will receive a seed later, as part of `one_step`.
    # - If the user code doesn't like that (Python complains about a missing
    #   required argument), we warn and fall back to the previous behavior.
    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel)
    except TypeError as e:
      if 'argument' not in str(e):
        raise
      warnings.warn(
          'The `seed` argument to `ReplicaExchangeMC`s `make_kernel_fn` is '
          'deprecated. `TransitionKernel` instances now receive seeds via '
          '`one_step`.')
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel, self._seed_stream())

    # Now that we've constructed the TransitionKernel instance:
    # - If we were given a seed, we sanitize it to stateless and pass along
    #   to `kernel.one_step`. If it doesn't like that, we crash and propagate
    #   the error. Rationale: The contract is stateless sampling given
    #   seed, and doing otherwise would not meet it.
    # - If not given a seed, we don't pass one along. This avoids breaking
    #   underlying kernels lacking a `seed` arg on `one_step`.
    # TODO(b/159636942): Clean up after 2020-09-20.
    if seed is not None:
      seed = samplers.sanitize_seed(seed)
      inner_seed, swap_seed, logu_seed = samplers.split_seed(
          seed, n=3, salt='remc_one_step')
      inner_kwargs = dict(seed=inner_seed)
    else:
      if self._seed_stream.original_seed is not None:
        warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
      inner_kwargs = {}
      swap_seed, logu_seed = samplers.split_seed(self._seed_stream())
    # Step the inner TransitionKernel.
    [
        pre_swap_replica_states,
        pre_swap_replica_results,
    ] = inner_kernel.one_step(
        previous_kernel_results.post_swap_replica_states,
        previous_kernel_results.post_swap_replica_results,
        **inner_kwargs)

    pre_swap_replica_target_log_prob = _get_field(
        # These are tempered log probs (have been divided by temperature).
        pre_swap_replica_results, 'target_log_prob')

    dtype = pre_swap_replica_target_log_prob.dtype
    replica_and_batch_shape = ps.shape(
        pre_swap_replica_target_log_prob)
    batch_shape = replica_and_batch_shape[1:]
    replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
    num_replica = ps.size0(inverse_temperatures)

    inverse_temperatures = mcmc_util.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Now that each replica has done one_step, it is time to consider swaps.

    # swap.shape = [n_replica], and is a "once only" permutation, meaning it
    # is achievable by a sequence of pairwise permutations, where each element
    # is moved at most once.
    # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
    # 1, keeping 2 fixed. This exact same swap is considered for *every*
    # batch member. Of course some batch members may accept and some reject.
    try:
      swaps = tf.cast(
          self.swap_proposal_fn(  # pylint: disable=not-callable
              num_replica,
              batch_shape=batch_shape,
              seed=swap_seed,
              step_count=previous_kernel_results.step_count),
          dtype=tf.int32)
    except TypeError as e:
      if 'step_count' not in str(e):
        raise
      warnings.warn(
          'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
          'the `step_count` argument. Falling back to omitting the '
          'argument. This fallback will be removed after 24-Oct-2020.')
      swaps = tf.cast(
          self.swap_proposal_fn(  # pylint: disable=not-callable
              num_replica,
              batch_shape=batch_shape,
              seed=swap_seed),
          dtype=tf.int32)

    null_swaps = mcmc_util.left_justified_expand_dims_like(
        tf.range(num_replica, dtype=swaps.dtype), swaps)
    swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                          self.validate_args)

    # Un-temper the log probs. E.g., for replica k, at point x_k, this is
    # Log[p(x_k)], and *not* Log[p_x(x_k)] = Log[p(x_k)] * beta_k.
    untempered_pre_swap_replica_target_log_prob = (
        pre_swap_replica_target_log_prob / inverse_temperatures)

    # Since `swaps` is its own inverse permutation we automatically know the
    # swap counterpart: range(num_replica). We use this idea to compute the
    # acceptance in a vectorized manner at the cost of wasting roughly half
    # our computation. Although we could use `unique` to solve this problem,
    # we expect the cost of `unique` to be higher than the dozens of wasted
    # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
    # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

    # Note: diffs would normally be "proposed - current" however energy is
    # flipped since `energy == -log_prob`.
    energy_diff = (
        untempered_pre_swap_replica_target_log_prob -
        mcmc_util.index_remapping_gather(
            untempered_pre_swap_replica_target_log_prob,
            swaps, name='gather_swap_tlp'))
    swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
        inverse_temperatures, swaps, name='gather_swap_temps')
    inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

    # If replicas i and j are swapping, entries i and j of log_accept_ratio
    # are equal.
    log_accept_ratio = (
        energy_diff * mcmc_util.left_justified_expand_dims_to(
            inverse_temp_diff, replica_and_batch_rank))

    # Mask out non-finite ratios (e.g. from 0 * inf) so they are rejected.
    log_accept_ratio = tf.where(
        tf.math.is_finite(log_accept_ratio),
        log_accept_ratio, tf.constant(-np.inf, dtype=dtype))

    # Produce Log[Uniform] draws that are identical at swapped indices.
    log_uniform = tf.math.log(
        samplers.uniform(shape=replica_and_batch_shape,
                         dtype=dtype,
                         seed=logu_seed))
    # Anchor each swap pair at its smaller index so both members share the
    # same uniform draw (accept/reject decisions must agree pairwise).
    anchor_swaps = tf.minimum(swaps, null_swaps)
    log_uniform = mcmc_util.index_remapping_gather(log_uniform, anchor_swaps)

    is_swap_accepted_mask = tf.less(
        log_uniform,
        log_accept_ratio,
        name='is_swap_accepted_mask')

    def _swap_tensor(x):
      # Where accepted, take the swap partner's value; else keep x.
      return mcmc_util.choose(
          is_swap_accepted_mask,
          mcmc_util.index_remapping_gather(x, swaps), x)

    post_swap_replica_states = [
        _swap_tensor(s) for s in pre_swap_replica_states]

    expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
        null_swaps, replica_and_batch_shape)
    is_swap_proposed = _compute_swap_notmatrix(
        # Broadcast both so they have shape [num_replica] + batch_shape.
        # This (i) makes them have same shape as is_swap_accepted, and
        # (ii) keeps shape consistent if someday swaps has a batch shape.
        expanded_null_swaps,
        mcmc_util.left_justified_broadcast_to(swaps, replica_and_batch_shape))

    # To get is_swap_accepted in ordered position, we use
    # _compute_swap_notmatrix on current and next replica positions.
    post_swap_replica_position = _swap_tensor(expanded_null_swaps)

    is_swap_accepted = _compute_swap_notmatrix(
        post_swap_replica_position,
        expanded_null_swaps)

    if self._state_includes_replicas:
      post_swap_states = post_swap_replica_states
    else:
      # Canonical (user-facing) state is replica 0.
      post_swap_states = [s[0] for s in post_swap_replica_states]

    post_swap_replica_results = _make_post_swap_replica_results(
        pre_swap_replica_results,
        inverse_temperatures,
        swapped_inverse_temperatures,
        is_swap_accepted_mask,
        _swap_tensor)

    if mcmc_util.is_list_like(current_state):
      # We *always* canonicalize the states in the kernel results.
      states = post_swap_states
    else:
      states = post_swap_states[0]

    post_swap_kernel_results = ReplicaExchangeMCKernelResults(
        post_swap_replica_states=post_swap_replica_states,
        pre_swap_replica_results=pre_swap_replica_results,
        post_swap_replica_results=post_swap_replica_results,
        is_swap_proposed=is_swap_proposed,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        # Store the original pkr.inverse_temperatures in case its a
        # `tf.Variable`.
        inverse_temperatures=previous_kernel_results.inverse_temperatures,
        swaps=swaps,
        step_count=previous_kernel_results.step_count + 1,
        seed=samplers.zeros_seed() if seed is None else seed,
    )

    return states, post_swap_kernel_results
def one_step(self, current_state, previous_kernel_results):
  """Takes one step of the TransitionKernel.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.
  """
  # The code below propagates one step states of shape
  # [n_replica] + batch_shape + event_shape.
  #
  # The step is done in three parts:
  # 1) Call one_step to transition states via a tempered version of
  #    self.target_log_prob_fn (see _replica_target_log_prob).
  # 2) Permute values in states
  # 3) Update state-dependent values, such as log_probs.
  #
  # We chose to swap states, rather than temperatures, because...
  # (i) If swapping temperatures, you *still* have to swap log_probs to
  #    determine acceptance, as well as states (for kernel results).
  #    So it's just as difficult to swap temperatures.
  # (ii) If swapping temperatures, you have to take care to swap any user-
  #    supplied temperature related things (like step size).
  #    A-priori, we don't know what else will need to be swapped!
  # (iii)In both cases, the kernel results need to be updated in a
  #    non-trivial manner....so we either special-case, or use bootstrap.

  with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
    # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
    inverse_temperatures = tf.convert_to_tensor(
        previous_kernel_results.inverse_temperatures,
        name='inverse_temperatures')

    # Build the inner kernel over the tempered target, then advance every
    # replica one step before any swap is considered.
    inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
        _make_replica_target_log_prob_fn(self.target_log_prob_fn,
                                         inverse_temperatures),
        self._seed_stream())

    [
        pre_swap_replica_states,
        pre_swap_replica_results,
    ] = inner_kernel.one_step(
        previous_kernel_results.post_swap_replica_states,
        previous_kernel_results.post_swap_replica_results)

    pre_swap_replica_target_log_prob = _get_field(
        # These are tempered log probs (have been divided by temperature).
        pre_swap_replica_results, 'target_log_prob')

    # Static-where-possible shapes; leading axis is the replica axis.
    dtype = pre_swap_replica_target_log_prob.dtype
    replica_and_batch_shape = prefer_static.shape(
        pre_swap_replica_target_log_prob)
    batch_shape = replica_and_batch_shape[1:]
    replica_and_batch_rank = prefer_static.rank(
        pre_swap_replica_target_log_prob)
    num_replica = prefer_static.size0(inverse_temperatures)

    inverse_temperatures = mcmc_util.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Now that each replica has done one_step, it is time to consider swaps.

    # swap.shape = [n_replica], and is a "once only" permutation, meaning it
    # is achievable by a sequence of pairwise permutations, where each
    # element is moved at most once.
    # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0
    # and 1, keeping 2 fixed. This exact same swap is considered for *every*
    # batch member. Of course some batch members may accept and some reject.
    swaps = tf.cast(
        self.swap_proposal_fn(  # pylint: disable=not-callable
            num_replica,
            batch_shape=batch_shape,
            seed=self._seed_stream()),
        dtype=tf.int32)
    # null_swaps is the identity permutation [0, 1, ..., n_replica - 1],
    # broadcast to align with `swaps`.
    null_swaps = mcmc_util.left_justified_expand_dims_like(
        tf.range(num_replica, dtype=swaps.dtype), swaps)
    swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                          self.validate_args)

    # Un-temper the log probs. E.g., for replica k, at point x_k, this is
    # Log[p(x_k)], and *not* Log[p_x(x_k)] = Log[p(x_k)] * beta_k.
    untempered_pre_swap_replica_target_log_prob = (
        pre_swap_replica_target_log_prob / inverse_temperatures)

    # Since `swaps` is its own inverse permutation we automatically know the
    # swap counterpart: range(num_replica). We use this idea to compute the
    # acceptance in a vectorized manner at the cost of wasting roughly half
    # our computation. Although we could use `unique` to solve this problem,
    # we expect the cost of `unique` to be higher than the dozens of wasted
    # arithmetic calculations. Worse, it'd mean we need dynamic sized
    # Tensors (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA
    # compile.

    # Note: diffs would normally be "proposed - current" however energy is
    # flipped since `energy == -log_prob`.
    energy_diff = (untempered_pre_swap_replica_target_log_prob -
                   mcmc_util.index_remapping_gather(
                       untempered_pre_swap_replica_target_log_prob,
                       swaps, name='gather_swap_tlp'))
    swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
        inverse_temperatures, swaps, name='gather_swap_temps')
    inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

    # If i and j are swapping, log_accept_ratio[] i and j are equal.
    log_accept_ratio = (
        energy_diff * mcmc_util.left_justified_expand_dims_to(
            inverse_temp_diff, replica_and_batch_rank))

    # Map any non-finite (NaN/Inf) accept ratio to -inf so that the
    # corresponding swap can never be accepted by the comparison below.
    log_accept_ratio = tf.where(
        tf.math.is_finite(log_accept_ratio),
        log_accept_ratio, tf.constant(-np.inf, dtype=dtype))

    # Produce Log[Uniform] draws that are identical at swapped indices.
    log_uniform = tf.math.log(
        tf.random.uniform(shape=replica_and_batch_shape,
                          dtype=dtype,
                          seed=self._seed_stream()))
    # Gathering from the smaller index of each swapped pair makes both
    # members of a pair see the same uniform draw.
    anchor_swaps = tf.minimum(swaps, null_swaps)
    log_uniform = mcmc_util.index_remapping_gather(
        log_uniform, anchor_swaps)

    is_swap_accepted_mask = tf.less(
        log_uniform,
        log_accept_ratio,
        name='is_swap_accepted_mask')

    def _swap_tensor(x):
      # Apply the swap permutation to `x` only where the swap was accepted.
      return mcmc_util.choose(
          is_swap_accepted_mask,
          mcmc_util.index_remapping_gather(x, swaps), x)

    post_swap_replica_states = [
        _swap_tensor(s) for s in pre_swap_replica_states
    ]

    expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
        null_swaps, replica_and_batch_shape)
    is_swap_proposed = _compute_swap_notmatrix(
        # Broadcast both so they have shape [num_replica] + batch_shape.
        # This (i) makes them have same shape as is_swap_accepted, and
        # (ii) keeps shape consistent if someday swaps has a batch shape.
        expanded_null_swaps,
        mcmc_util.left_justified_broadcast_to(swaps,
                                              replica_and_batch_shape))

    # To get is_swap_accepted in ordered position, we use
    # _compute_swap_notmatrix on current and next replica positions.
    post_swap_replica_position = _swap_tensor(expanded_null_swaps)

    is_swap_accepted = _compute_swap_notmatrix(
        post_swap_replica_position,
        expanded_null_swaps)

    # The user-visible chain is the first (inverse_temperature == beta_0)
    # component of each replica state.
    post_swap_states = [s[0] for s in post_swap_replica_states]

    post_swap_replica_results = _make_post_swap_replica_results(
        pre_swap_replica_results,
        inverse_temperatures,
        swapped_inverse_temperatures,
        is_swap_accepted_mask,
        _swap_tensor)

    if mcmc_util.is_list_like(current_state):
      # We *always* canonicalize the states in the kernel results.
      states = post_swap_states
    else:
      states = post_swap_states[0]

    post_swap_kernel_results = ReplicaExchangeMCKernelResults(
        post_swap_replica_states=post_swap_replica_states,
        pre_swap_replica_results=pre_swap_replica_results,
        post_swap_replica_results=post_swap_replica_results,
        is_swap_proposed=is_swap_proposed,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        # Store the original pkr.inverse_temperatures in case its a
        # `tf.Variable`.
        inverse_temperatures=previous_kernel_results.inverse_temperatures,
        swaps=swaps,
    )

    return states, post_swap_kernel_results