def adjacent_swaps(num_replica, batch_shape=(), step_count=None, seed=None):
  """Randomly propose swaps between neighbouring replicas.

  A parity bit is drawn uniformly at random; it decides whether replica pairs
  start at an even index (`[1, 0, 3, 2, 4]`) or at an odd index
  (`[0, 2, 1, 4, 3]`). A second uniform draw, compared against `prob_swap`,
  decides whether that pairing or the identity permutation is returned.

  `name` and `prob_swap` are not parameters of this function; they are read
  from the surrounding scope (not visible here).

  Args:
    num_replica: Number of replicas.
    batch_shape: Batch shape of the chains. It is cast to `int32` and used for
      the shape of both uniform draws.
    step_count: Accepted but ignored.
    seed: Seed handed to `samplers.split_seed`.

  Returns:
    An `int64` `Tensor` holding, for each replica index, the index it is
    proposed to swap with (its own index when no swap is proposed).
  """
  del step_count  # Unused for this function.
  with tf.name_scope(name or 'adjacent_swaps'):
    parity_seed, proposal_seed = samplers.split_seed(seed)

    # The parity draw has shape [1] + batch_shape.
    parity_shape = ps.concat(
        (ps.ones(1, dtype=tf.int32), ps.cast(batch_shape, tf.int32)),
        axis=0)
    # odd_parity == False ==> [1, 0, 3, 2, 4]  even parity swaps
    # odd_parity == True  ==> [0, 2, 1, 4, 3]  odd parity swaps
    odd_parity = samplers.uniform(parity_shape, seed=parity_seed) < 0.5
    # With exactly two replicas the odd-parity pairing is the identity, which
    # would contradict the user provided `prob_swap`; force even parity then.
    odd_parity = tf.where(num_replica > 2, odd_parity, False)

    # Identity permutation, expanded to the rank of the parity draw.
    replica_ids = ps.range(num_replica, dtype=tf.int64)
    identity = mcmc_util.left_justified_expand_dims_to(
        replica_ids, rank=ps.size(parity_shape))

    # Indices whose evenness matches the parity bit pair with the index above
    # them; the rest pair with the index below. Clipping turns an out-of-range
    # partner at either end into a self-pairing.
    pairs_upward = tf.equal(identity % 2, tf.cast(odd_parity, dtype=tf.int64))
    partner = tf.where(pairs_upward, identity + 1, identity - 1)
    partner = tf.clip_by_value(partner, 0, num_replica - 1)

    # TODO(b/142689785): Consider using tf.cond and returning an empty list
    # then in REMC consider using a tf.cond for short-circuiting.
    propose_swap = (
        samplers.uniform(batch_shape, seed=proposal_seed) < prob_swap)
    return tf.where(propose_swap, partner, identity)
def adjacent_swaps(num_replica, batch_shape=(), seed=None):
  """Randomly propose swaps between neighbouring replicas.

  A parity bit is drawn uniformly at random; it decides whether replica pairs
  look like `[1, 0, 3, 2, 4]` or like `[0, 2, 1, 4, 3]`. A second uniform
  draw, compared against `prob_swap`, decides whether that pairing or the
  identity permutation is returned.

  `name` and `prob_swap` are not parameters of this function; they are read
  from the surrounding scope (not visible here).

  Args:
    num_replica: Number of replicas.
    batch_shape: Batch shape of the chains. It is cast to `int32` and used for
      the shape of both uniform draws.
    seed: Seed used to build a `SeedStream` salted with
      `'random_adjacent_shuffle'`; two seeds are drawn from it.

  Returns:
    An `int64` `Tensor` holding, for each replica index, the index it is
    proposed to swap with (its own index when no swap is proposed).
  """
  with tf.name_scope(name or 'adjacent_swaps'):
    seed_stream = SeedStream(seed, salt='random_adjacent_shuffle')

    # The parity draw has shape [1] + batch_shape.
    parity_shape = prefer_static.concat(
        (tf.ones(1, dtype=tf.int32), tf.cast(batch_shape, tf.int32)),
        axis=0)
    # odd_parity == True  ==> [0, 2, 1, 4, 3] like swaps
    # odd_parity == False ==> [1, 0, 3, 2, 4] like swaps
    # First draw from the stream: the parity bit.
    odd_parity = tf.random.uniform(parity_shape, seed=seed_stream()) < 0.5
    # With exactly two replicas the "True" pairing is the identity, which
    # would contradict the user provided `prob_swap`; force False then.
    odd_parity = tf.where(num_replica > 2, odd_parity, False)

    # Identity permutation, expanded to the rank of the parity draw.
    replica_ids = tf.range(num_replica, dtype=tf.int64)
    identity = mcmc_util.left_justified_expand_dims_to(
        replica_ids, rank=prefer_static.size(parity_shape))

    # Indices whose evenness matches the parity bit pair with the index above
    # them; the rest pair with the index below. Clipping turns an out-of-range
    # partner at either end into a self-pairing.
    pairs_upward = tf.equal(identity % 2, tf.cast(odd_parity, dtype=tf.int64))
    partner = tf.where(pairs_upward, identity + 1, identity - 1)
    partner = tf.clip_by_value(partner, 0, num_replica - 1)

    # TODO(b/142689785): Consider using tf.cond and returning an empty list
    # then in REMC consider using a tf.cond for short-circuiting.
    # Second draw from the stream: whether to propose the swap at all.
    propose_swap = (
        tf.random.uniform(batch_shape, seed=seed_stream()) < prob_swap)
    return tf.where(propose_swap, partner, identity)
def even_odd_swaps(num_replica, batch_shape=(), step_count=None, seed=None):
  """Deterministically propose alternating even/odd adjacent swaps.

  The swap period is `ceil(1 / swap_frequency)`. A swap is proposed only on
  steps where `step_count` is a multiple of that period, and the parity of the
  proposed pairing alternates with `step_count // period`. When
  `swap_frequency == 0` the identity permutation is always returned.

  `name` and `swap_frequency` are not parameters of this function; they are
  read from the surrounding scope (not visible here).

  Args:
    num_replica: Number of replicas.
    batch_shape: Batch shape of the chains; cast to `int32`.
    step_count: Index of the current step. Required.
    seed: Accepted but ignored.

  Returns:
    An `int64` `Tensor` holding, for each replica index, the index it is
    proposed to swap with (its own index when no swap is proposed).

  Raises:
    ValueError: If `step_count` is `None`.
  """
  if step_count is None:
    raise ValueError('`step_count` must be supplied. Found `None`.')
  del seed  # Unused for this function.
  with tf.name_scope(name or 'even_odd_swaps'):
    # Period is 1 / frequency, and we want period = Inf if frequency = 0.
    # For swap_frequency > 0 the period below is the real one. For
    # swap_frequency == 0 it is set to 1 purely to avoid an integer division
    # by zero; that case is hard-set to the "null swap" in the return.
    swap_freq = tf.convert_to_tensor(swap_frequency, name='swap_frequency')
    # Although period = 1 / frequency may have roundoff error, and result in a
    # period different than what the user intended, the user still ends up
    # with a single integer period, and thus well defined deterministic swaps.
    float_period = tf.where(
        swap_freq > 0,
        tf.math.ceil(tf.math.reciprocal_no_nan(swap_freq)),
        1)
    safe_swap_period = tf.cast(float_period, tf.int32)

    # The parity bit has shape [1] + batch_shape.
    # odd_parity == False ==> [1, 0, 3, 2, 4]  even parity swaps
    # odd_parity == True  ==> [0, 2, 1, 4, 3]  odd parity swaps
    parity_shape = ps.concat(
        (ps.ones(1, dtype=tf.int32), ps.cast(batch_shape, tf.int32)),
        axis=0)
    period_index_is_odd = tf.cast(
        (step_count // safe_swap_period) % 2, tf.bool)
    odd_parity = tf.fill(parity_shape, period_index_is_odd)
    # With exactly two replicas the odd-parity pairing is the identity, which
    # would contradict the user provided `swap_frequency`; force even parity.
    odd_parity = tf.where(num_replica > 2, odd_parity, False)

    # Identity permutation, expanded to the rank of the parity bit.
    replica_ids = tf.range(num_replica, dtype=tf.int64)
    identity = mcmc_util.left_justified_expand_dims_to(
        replica_ids, rank=ps.size(parity_shape))

    # Indices whose evenness matches the parity bit pair with the index above
    # them; the rest pair with the index below. Clipping turns an out-of-range
    # partner at either end into a self-pairing.
    pairs_upward = tf.equal(identity % 2, tf.cast(odd_parity, dtype=tf.int64))
    partner = tf.where(pairs_upward, identity + 1, identity - 1)
    partner = tf.clip_by_value(partner, 0, num_replica - 1)

    # TODO(b/142689785): Consider using tf.cond and returning an empty list
    # then in REMC consider using a tf.cond for short-circuiting.
    # Skip the swap on off-period steps and whenever the frequency is zero.
    skip_swap = (
        tf.cast(step_count % safe_swap_period, tf.bool) |
        tf.math.equal(swap_freq, 0))
    return tf.where(
        skip_swap,
        identity,  # Don't swap
        partner,  # Swap
    )
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Takes one step of the TransitionKernel.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
    seed: Optional, a seed for reproducible sampling.

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.

  Raises:
    TypeError: If calling `make_kernel_fn` with a single argument raises a
      `TypeError` whose message mentions 'argument'. Other `TypeError`s from
      `make_kernel_fn` or `swap_proposal_fn` are re-raised unchanged.
  """
  # The code below propagates one step states of shape
  #   [n_replica] + batch_shape + event_shape.
  #
  # The step is done in three parts:
  #  1) Call one_step to transition states via a tempered version of
  #     self.target_log_prob_fn (see _replica_target_log_prob).
  #  2) Permute values in states
  #  3) Update state-dependent values, such as log_probs.
  #
  # We chose to swap states, rather than temperatures, because...
  # (i)  If swapping temperatures, you *still* have to swap log_probs to
  #      determine acceptance, as well as states (for kernel results).
  #      So it's just as difficult to swap temperatures.
  # (ii) If swapping temperatures, you have to take care to swap any user-
  #      supplied temperature related things (like step size).
  #      A-priori, we don't know what else will need to be swapped!
  # (iii)In both cases, the kernel results need to be updated in a non-trivial
  #      manner....so we either special-case, or use bootstrap.

  with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
    # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
    inverse_temperatures = tf.convert_to_tensor(
        previous_kernel_results.inverse_temperatures,
        name='inverse_temperatures')

    target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
        target_log_prob_fn=self.target_log_prob_fn,
        inverse_temperatures=inverse_temperatures,
        untempered_log_prob_fn=self.untempered_log_prob_fn,
        tempered_log_prob_fn=self.tempered_log_prob_fn,
    )
    # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel)
    except TypeError as e:
      # Only a TypeError whose message mentions 'argument' is rewritten into
      # the friendlier message below; anything else propagates untouched.
      if 'argument' not in str(e):
        raise
      raise TypeError(
          '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a `seed` '
          'argument. `TransitionKernel` instances now receive seeds via '
          '`one_step`.')

    seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
    # Three independent seeds: inner kernel step, swap proposal, and the
    # log-uniform draws used for swap acceptance.
    inner_seed, swap_seed, logu_seed = samplers.split_seed(seed, n=3)
    # Step the inner TransitionKernel.
    [
        pre_swap_replica_states,
        pre_swap_replica_results,
    ] = inner_kernel.one_step(
        previous_kernel_results.post_swap_replica_states,
        previous_kernel_results.post_swap_replica_results,
        seed=inner_seed)

    pre_swap_replica_target_log_prob = _get_field(
        # These are tempered log probs (have been divided by temperature).
        pre_swap_replica_results, 'target_log_prob')

    dtype = pre_swap_replica_target_log_prob.dtype

    replica_and_batch_shape = ps.shape(
        pre_swap_replica_target_log_prob)
    # Everything after the leading axis is treated as the batch shape.
    batch_shape = replica_and_batch_shape[1:]
    replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
    num_replica = ps.size0(inverse_temperatures)

    inverse_temperatures = mcmc_util.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Now that each replica has done one_step, it is time to consider swaps.

    # swap.shape = [n_replica], and is a "once only" permutation, meaning it
    # is achievable by a sequence of pairwise permutations, where each element
    # is moved at most once.
    # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
    # 1, keeping 2 fixed.  This exact same swap is considered for *every*
    # batch member.  Of course some batch members may accept and some reject.
    try:
      swaps = tf.cast(
          self.swap_proposal_fn(  # pylint: disable=not-callable
              num_replica,
              batch_shape=batch_shape,
              seed=swap_seed,
              step_count=previous_kernel_results.step_count),
          dtype=tf.int32)
    except TypeError as e:
      # Backward-compatibility path for proposal functions that do not accept
      # `step_count`; unrelated TypeErrors are re-raised.
      if 'step_count' not in str(e):
        raise
      warnings.warn(
          'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
          'the `step_count` argument. Falling back to omitting the '
          'argument. This fallback will be removed after 24-Oct-2020.'
      )
      swaps = tf.cast(
          self.swap_proposal_fn(  # pylint: disable=not-callable
              num_replica,
              batch_shape=batch_shape,
              seed=swap_seed),
          dtype=tf.int32)

    # The identity permutation, with dims expanded to match `swaps`.
    null_swaps = mcmc_util.left_justified_expand_dims_like(
        tf.range(num_replica, dtype=swaps.dtype), swaps)
    swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                          self.validate_args)

    # Un-temper the log probs for use in the swap acceptance ratio.
    if self.tempered_log_prob_fn is None:
      # Efficient way of re-evaluating target_log_prob_fn on the
      # pre_swap_replica_states.
      untempered_energy_ignoring_ulp = (
          # Since untempered_log_prob_fn is None, we may assume
          # inverse_temperatures > 0 (else the target is improper).
          pre_swap_replica_target_log_prob / inverse_temperatures)
    else:
      # The untempered_log_prob_fn does not factor into the acceptance ratio.
      # Proof: Suppose the tempered target is
      #   p_k(x) = f(x)^{beta_k} g(x),
      # So f(x) is tempered, and g(x) is not.  Then, the acceptance ratio for
      # a 1 <--> 2 swap is...
      #   (p_1(x_2) p_2(x_1)) / (p_1(x_1) p_2(x_2))
      # which depends only on f(x), since terms involving g(x) cancel.
      untempered_energy_ignoring_ulp = self.tempered_log_prob_fn(
          *pre_swap_replica_states)

    # Since `swaps` is its own inverse permutation we automatically know the
    # swap counterpart: range(num_replica). We use this idea to compute the
    # acceptance in a vectorized manner at the cost of wasting roughly half
    # our computation. Although we could use `unique` to solve this problem,
    # we expect the cost of `unique` to be higher than the dozens of wasted
    # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
    # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

    # Note: diffs would normally be "proposed - current" however energy is
    # flipped since `energy == -log_prob`.
    # Note: The untempered_log_prob_fn (if provided) is not included in
    # untempered_pre_swap_replica_target_log_prob, and hence does not factor
    # into energy_diff. Why? Because, it cancels out in the acceptance ratio.
    energy_diff = (
        untempered_energy_ignoring_ulp -
        mcmc_util.index_remapping_gather(
            untempered_energy_ignoring_ulp,
            swaps, name='gather_swap_tlp'))
    swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
        inverse_temperatures, swaps, name='gather_swap_temps')
    inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

    # If i and j are swapping, log_accept_ratio[] i and j are equal.
    log_accept_ratio = (
        energy_diff * mcmc_util.left_justified_expand_dims_to(
            inverse_temp_diff, replica_and_batch_rank))

    # Any non-finite ratio (NaN or +/-inf) becomes -inf, so the `tf.less`
    # comparison below is False there and the swap is rejected.
    log_accept_ratio = tf.where(
        tf.math.is_finite(log_accept_ratio),
        log_accept_ratio, tf.constant(-np.inf, dtype=dtype))

    # Produce log[Uniform] draws that are identical at swapped indices.
    log_uniform = tf.math.log(
        samplers.uniform(shape=replica_and_batch_shape,
                         dtype=dtype,
                         seed=logu_seed))
    # Both members of a swapping pair map to the smaller index of the pair,
    # so after the gather they share one draw.
    anchor_swaps = tf.minimum(swaps, null_swaps)
    log_uniform = mcmc_util.index_remapping_gather(
        log_uniform, anchor_swaps)

    is_swap_accepted_mask = tf.less(
        log_uniform,
        log_accept_ratio,
        name='is_swap_accepted_mask')

    def _swap_tensor(x):
      # Where the mask is True take the value gathered through `swaps`,
      # elsewhere keep `x` (presumably elementwise along the leading axes --
      # see `mcmc_util.choose`).
      return mcmc_util.choose(
          is_swap_accepted_mask,
          mcmc_util.index_remapping_gather(x, swaps), x)

    post_swap_replica_states = [
        _swap_tensor(s) for s in pre_swap_replica_states
    ]

    expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
        null_swaps, replica_and_batch_shape)
    is_swap_proposed = _compute_swap_notmatrix(
        # Broadcast both so they have shape [num_replica] + batch_shape.
        # This (i) makes them have same shape as is_swap_accepted, and
        # (ii) keeps shape consistent if someday swaps has a batch shape.
        expanded_null_swaps,
        mcmc_util.left_justified_broadcast_to(swaps,
                                              replica_and_batch_shape))

    # To get is_swap_accepted in ordered position, we use
    # _compute_swap_notmatrix on current and next replica positions.
    post_swap_replica_position = _swap_tensor(expanded_null_swaps)

    is_swap_accepted = _compute_swap_notmatrix(
        post_swap_replica_position,
        expanded_null_swaps)

    if self._state_includes_replicas:
      post_swap_states = post_swap_replica_states
    else:
      # Return only index 0 along the leading (replica) axis of each part.
      post_swap_states = [s[0] for s in post_swap_replica_states]

    post_swap_replica_results = _set_swapped_fields_to_nan(
        _swap_log_prob_and_maybe_grads(pre_swap_replica_results,
                                       post_swap_replica_states,
                                       inner_kernel))

    if mcmc_util.is_list_like(current_state):
      # We *always* canonicalize the states in the kernel results.
      states = post_swap_states
    else:
      # A non-list `current_state` gets a single state part back.
      states = post_swap_states[0]

    post_swap_kernel_results = ReplicaExchangeMCKernelResults(
        post_swap_replica_states=post_swap_replica_states,
        pre_swap_replica_results=pre_swap_replica_results,
        post_swap_replica_results=post_swap_replica_results,
        is_swap_proposed=is_swap_proposed,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        # Store the original pkr.inverse_temperatures in case it's a
        # `tf.Variable`.
        inverse_temperatures=previous_kernel_results.inverse_temperatures,
        swaps=swaps,
        step_count=previous_kernel_results.step_count + 1,
        seed=seed,
    )

    return states, post_swap_kernel_results
def _make_post_swap_replica_results(pre_swap_replica_results,
                                    inverse_temperatures,
                                    swapped_inverse_temperatures,
                                    is_swap_accepted_mask,
                                    swap_tensor_fn):
  """Build inner-kernel results that are consistent with the post-swap states.

  Fields that cannot be updated unambiguously (`proposed_results` and
  `proposed_state`) are overwritten with NaN. The accepted `target_log_prob`
  (and, for HMC results, `grads_target_log_prob`) is swapped with
  `swap_tensor_fn` and then rescaled to the inverse temperature of the slot
  it lands in.

  Args:
    pre_swap_replica_results: Kernel results obtained by running
      `inner_kernel.one_step` before swapping. Must be a
      `MetropolisHastingsKernelResults`.
    inverse_temperatures: `Tensor` of inverse temperatures.
    swapped_inverse_temperatures: `Tensor` of inverse temperatures, permuted
      by swaps.
    is_swap_accepted_mask: Boolean `Tensor` of shape
      `[num_replica] + batch_shape` telling which swaps were accepted.
    swap_tensor_fn: Callable. For `x.shape = [num_replica] + batch_shape`,
      `swap_tensor_fn(x)` performs swaps where they are accepted, and does
      not swap otherwise.

  Returns:
    new_replica_results: Kernel results of the same type as
      `pre_swap_replica_results`.

  Raises:
    NotImplementedError: If `pre_swap_replica_results` is not a
      `MetropolisHastingsKernelResults`, or if its `accepted_results` is
      neither HMC nor random-walk Metropolis results.
  """
  is_mh_results = isinstance(
      pre_swap_replica_results,
      metropolis_hastings.MetropolisHastingsKernelResults)
  if not is_mh_results:
    # TODO(b/143702650) Handle other kernels.
    raise NotImplementedError(
        '`pre_swap_replica_results` currently works only for '
        'MetropolisHastingsKernelResults. Found: {}. '
        'Please file a request with the TensorFlow Probability team.'.format(
            type(pre_swap_replica_results)))

  dtype = swapped_inverse_temperatures.dtype

  # Hard to modify the proposed_* fields in an unambiguous manner.
  # ...we also don't need to.
  results = pre_swap_replica_results._replace(
      proposed_results=tf.convert_to_tensor(np.nan, dtype=dtype),
      proposed_state=tf.convert_to_tensor(np.nan, dtype=dtype),
  )

  replica_and_batch_rank = ps.rank(results.log_accept_ratio)

  # After swap_tensor_fn moves a value, that value is still scaled by the
  # inverse temperature of the slot it came from (the swapped one). Multiply
  # by this ratio so it is scaled by the inverse temperature of its new index.
  # Where no swap was accepted nothing moved, so the ratio is 1.
  raw_ratio = inverse_temperatures / swapped_inverse_temperatures
  retemper_ratio = tf.where(
      is_swap_accepted_mask,
      mcmc_util.left_justified_expand_dims_to(raw_ratio,
                                              replica_and_batch_rank),
      tf.convert_to_tensor(1.0, dtype=dtype))

  def _swap_and_rescale(value):
    """Swaps `value` (a tensor or list of parts) and fixes its tempering."""
    parts, is_multipart = mcmc_util.prepare_state_parts(value)
    ratio = mcmc_util.left_justified_expand_dims_like(retemper_ratio,
                                                      parts[0])
    rescaled = [swap_tensor_fn(part) * ratio for part in parts]
    return rescaled if is_multipart else rescaled[0]

  accepted = results.accepted_results
  if isinstance(accepted, hmc.UncalibratedHamiltonianMonteCarloKernelResults):
    new_accepted = accepted._replace(
        target_log_prob=_swap_and_rescale(accepted.target_log_prob),
        grads_target_log_prob=_swap_and_rescale(
            accepted.grads_target_log_prob))
  elif isinstance(accepted,
                  random_walk_metropolis.UncalibratedRandomWalkResults):
    new_accepted = accepted._replace(
        target_log_prob=_swap_and_rescale(accepted.target_log_prob))
  else:
    # TODO(b/143702650) Handle other kernels.
    raise NotImplementedError(
        'Only HMC and RWMH Kernels are handled at this time. Please file a '
        'request with the TensorFlow Probability team.')

  return results._replace(accepted_results=new_accepted)
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Takes one step of the TransitionKernel.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
    seed: Optional, a seed for reproducible sampling. When `None`, seeds are
      drawn from `self._seed_stream` instead and no `seed` is passed to the
      inner kernel's `one_step`.

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.
  """
  # The code below propagates one step states of shape
  #   [n_replica] + batch_shape + event_shape.
  #
  # The step is done in three parts:
  #  1) Call one_step to transition states via a tempered version of
  #     self.target_log_prob_fn (see _replica_target_log_prob).
  #  2) Permute values in states
  #  3) Update state-dependent values, such as log_probs.
  #
  # We chose to swap states, rather than temperatures, because...
  # (i)  If swapping temperatures, you *still* have to swap log_probs to
  #      determine acceptance, as well as states (for kernel results).
  #      So it's just as difficult to swap temperatures.
  # (ii) If swapping temperatures, you have to take care to swap any user-
  #      supplied temperature related things (like step size).
  #      A-priori, we don't know what else will need to be swapped!
  # (iii)In both cases, the kernel results need to be updated in a non-trivial
  #      manner....so we either special-case, or use bootstrap.

  with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
    # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
    inverse_temperatures = tf.convert_to_tensor(
        previous_kernel_results.inverse_temperatures,
        name='inverse_temperatures')

    target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
        self.target_log_prob_fn, inverse_temperatures)
    # Seed handling complexity is due to users possibly expecting an old-style
    # stateful seed to be passed to `self.make_kernel_fn`, and no seed
    # expected by `kernel.one_step`.
    # In other words:
    # - We try `make_kernel_fn` without a seed first; this is the future. The
    #   kernel will receive a seed later, as part of `one_step`.
    # - If the user code doesn't like that (Python complains about a missing
    #   required argument), we warn and fall back to the previous behavior.
    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel)
    except TypeError as e:
      # A TypeError whose message does not mention 'argument' is not the
      # legacy-signature case; let it propagate.
      if 'argument' not in str(e):
        raise
      warnings.warn(
          'The `seed` argument to `ReplicaExchangeMC`s `make_kernel_fn` is '
          'deprecated. `TransitionKernel` instances now receive seeds via '
          '`one_step`.')
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel, self._seed_stream())

    # Now that we've constructed the TransitionKernel instance:
    # - If we were given a seed, we sanitize it to stateless and pass along
    #   to `kernel.one_step`. If it doesn't like that, we crash and propagate
    #   the error.  Rationale: The contract is stateless sampling given
    #   seed, and doing otherwise would not meet it.
    # - If not given a seed, we don't pass one along. This avoids breaking
    #   underlying kernels lacking a `seed` arg on `one_step`.
    # TODO(b/159636942): Clean up after 2020-09-20.
    if seed is not None:
      seed = samplers.sanitize_seed(seed)
      # Three seeds: inner kernel step, swap proposal, acceptance draws.
      inner_seed, swap_seed, logu_seed = samplers.split_seed(
          seed, n=3, salt='remc_one_step')
      inner_kwargs = dict(seed=inner_seed)
    else:
      if self._seed_stream.original_seed is not None:
        warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
      inner_kwargs = {}
      # No caller seed: take one from the seed stream and split it for the
      # swap proposal and the acceptance draws only.
      swap_seed, logu_seed = samplers.split_seed(self._seed_stream())

    [
        pre_swap_replica_states,
        pre_swap_replica_results,
    ] = inner_kernel.one_step(
        previous_kernel_results.post_swap_replica_states,
        previous_kernel_results.post_swap_replica_results,
        **inner_kwargs)

    pre_swap_replica_target_log_prob = _get_field(
        # These are tempered log probs (have been divided by temperature).
        pre_swap_replica_results, 'target_log_prob')

    dtype = pre_swap_replica_target_log_prob.dtype

    replica_and_batch_shape = ps.shape(
        pre_swap_replica_target_log_prob)
    # Everything after the leading axis is treated as the batch shape.
    batch_shape = replica_and_batch_shape[1:]
    replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
    num_replica = ps.size0(inverse_temperatures)

    inverse_temperatures = mcmc_util.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Now that each replica has done one_step, it is time to consider swaps.

    # swap.shape = [n_replica], and is a "once only" permutation, meaning it
    # is achievable by a sequence of pairwise permutations, where each element
    # is moved at most once.
    # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
    # 1, keeping 2 fixed.  This exact same swap is considered for *every*
    # batch member.  Of course some batch members may accept and some reject.
    try:
      swaps = tf.cast(
          self.swap_proposal_fn(  # pylint: disable=not-callable
              num_replica,
              batch_shape=batch_shape,
              seed=swap_seed,
              step_count=previous_kernel_results.step_count),
          dtype=tf.int32)
    except TypeError as e:
      # Backward-compatibility path for proposal functions that do not accept
      # `step_count`; unrelated TypeErrors are re-raised.
      if 'step_count' not in str(e):
        raise
      warnings.warn(
          'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
          'the `step_count` argument. Falling back to omitting the '
          'argument. This fallback will be removed after 24-Oct-2020.'
      )
      swaps = tf.cast(
          self.swap_proposal_fn(  # pylint: disable=not-callable
              num_replica,
              batch_shape=batch_shape,
              seed=swap_seed),
          dtype=tf.int32)

    # The identity permutation, with dims expanded to match `swaps`.
    null_swaps = mcmc_util.left_justified_expand_dims_like(
        tf.range(num_replica, dtype=swaps.dtype), swaps)
    swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                          self.validate_args)

    # Un-temper the log probs. E.g., for replica k, at point x_k, this is
    # Log[p(x_k)], and *not* Log[p_x(x_k)] = Log[p(x_k)] * beta_k.
    untempered_pre_swap_replica_target_log_prob = (
        pre_swap_replica_target_log_prob / inverse_temperatures)

    # Since `swaps` is its own inverse permutation we automatically know the
    # swap counterpart: range(num_replica). We use this idea to compute the
    # acceptance in a vectorized manner at the cost of wasting roughly half
    # our computation. Although we could use `unique` to solve this problem,
    # we expect the cost of `unique` to be higher than the dozens of wasted
    # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
    # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

    # Note: diffs would normally be "proposed - current" however energy is
    # flipped since `energy == -log_prob`.
    energy_diff = (
        untempered_pre_swap_replica_target_log_prob -
        mcmc_util.index_remapping_gather(
            untempered_pre_swap_replica_target_log_prob,
            swaps, name='gather_swap_tlp'))
    swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
        inverse_temperatures, swaps, name='gather_swap_temps')
    inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

    # If i and j are swapping, log_accept_ratio[] i and j are equal.
    log_accept_ratio = (
        energy_diff * mcmc_util.left_justified_expand_dims_to(
            inverse_temp_diff, replica_and_batch_rank))

    # Any non-finite ratio (NaN or +/-inf) becomes -inf, so the `tf.less`
    # comparison below is False there and the swap is rejected.
    log_accept_ratio = tf.where(
        tf.math.is_finite(log_accept_ratio),
        log_accept_ratio, tf.constant(-np.inf, dtype=dtype))

    # Produce Log[Uniform] draws that are identical at swapped indices.
    log_uniform = tf.math.log(
        samplers.uniform(shape=replica_and_batch_shape,
                         dtype=dtype,
                         seed=logu_seed))
    # Both members of a swapping pair map to the smaller index of the pair,
    # so after the gather they share one draw.
    anchor_swaps = tf.minimum(swaps, null_swaps)
    log_uniform = mcmc_util.index_remapping_gather(
        log_uniform, anchor_swaps)

    is_swap_accepted_mask = tf.less(
        log_uniform,
        log_accept_ratio,
        name='is_swap_accepted_mask')

    def _swap_tensor(x):
      # Where the mask is True take the value gathered through `swaps`,
      # elsewhere keep `x` (presumably elementwise along the leading axes --
      # see `mcmc_util.choose`).
      return mcmc_util.choose(
          is_swap_accepted_mask,
          mcmc_util.index_remapping_gather(x, swaps), x)

    post_swap_replica_states = [
        _swap_tensor(s) for s in pre_swap_replica_states
    ]

    expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
        null_swaps, replica_and_batch_shape)
    is_swap_proposed = _compute_swap_notmatrix(
        # Broadcast both so they have shape [num_replica] + batch_shape.
        # This (i) makes them have same shape as is_swap_accepted, and
        # (ii) keeps shape consistent if someday swaps has a batch shape.
        expanded_null_swaps,
        mcmc_util.left_justified_broadcast_to(swaps,
                                              replica_and_batch_shape))

    # To get is_swap_accepted in ordered position, we use
    # _compute_swap_notmatrix on current and next replica positions.
    post_swap_replica_position = _swap_tensor(expanded_null_swaps)

    is_swap_accepted = _compute_swap_notmatrix(
        post_swap_replica_position,
        expanded_null_swaps)

    if self._state_includes_replicas:
      post_swap_states = post_swap_replica_states
    else:
      # Return only index 0 along the leading (replica) axis of each part.
      post_swap_states = [s[0] for s in post_swap_replica_states]

    post_swap_replica_results = _make_post_swap_replica_results(
        pre_swap_replica_results,
        inverse_temperatures,
        swapped_inverse_temperatures,
        is_swap_accepted_mask,
        _swap_tensor)

    if mcmc_util.is_list_like(current_state):
      # We *always* canonicalize the states in the kernel results.
      states = post_swap_states
    else:
      # A non-list `current_state` gets a single state part back.
      states = post_swap_states[0]

    post_swap_kernel_results = ReplicaExchangeMCKernelResults(
        post_swap_replica_states=post_swap_replica_states,
        pre_swap_replica_results=pre_swap_replica_results,
        post_swap_replica_results=post_swap_replica_results,
        is_swap_proposed=is_swap_proposed,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        # Store the original pkr.inverse_temperatures in case it's a
        # `tf.Variable`.
        inverse_temperatures=previous_kernel_results.inverse_temperatures,
        swaps=swaps,
        step_count=previous_kernel_results.step_count + 1,
        # When no seed was supplied, record `samplers.zeros_seed()` instead
        # of `None`.
        seed=samplers.zeros_seed() if seed is None else seed,
    )

    return states, post_swap_kernel_results
def one_step(self, current_state, previous_kernel_results):
  """Takes one step of the TransitionKernel.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.
  """
  # The code below propagates one step states of shape
  #   [n_replica] + batch_shape + event_shape.
  #
  # The step is done in three parts:
  #  1) Call one_step to transition states via a tempered version of
  #     self.target_log_prob_fn (see _replica_target_log_prob).
  #  2) Permute values in states
  #  3) Update state-dependent values, such as log_probs.
  #
  # We chose to swap states, rather than temperatures, because...
  # (i)  If swapping temperatures, you *still* have to swap log_probs to
  #      determine acceptance, as well as states (for kernel results).
  #      So it's just as difficult to swap temperatures.
  # (ii) If swapping temperatures, you have to take care to swap any user-
  #      supplied temperature related things (like step size).
  #      A-priori, we don't know what else will need to be swapped!
  # (iii)In both cases, the kernel results need to be updated in a non-trivial
  #      manner....so we either special-case, or use bootstrap.

  with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
    # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
    inverse_temperatures = tf.convert_to_tensor(
        previous_kernel_results.inverse_temperatures,
        name='inverse_temperatures')

    # The inner kernel is built from the tempered target and a seed drawn
    # from this kernel's seed stream.
    inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
        _make_replica_target_log_prob_fn(self.target_log_prob_fn,
                                         inverse_temperatures),
        self._seed_stream())

    [
        pre_swap_replica_states,
        pre_swap_replica_results,
    ] = inner_kernel.one_step(
        previous_kernel_results.post_swap_replica_states,
        previous_kernel_results.post_swap_replica_results)

    pre_swap_replica_target_log_prob = _get_field(
        # These are tempered log probs (have been divided by temperature).
        pre_swap_replica_results, 'target_log_prob')

    dtype = pre_swap_replica_target_log_prob.dtype

    replica_and_batch_shape = prefer_static.shape(
        pre_swap_replica_target_log_prob)
    # Everything after the leading axis is treated as the batch shape.
    batch_shape = replica_and_batch_shape[1:]
    replica_and_batch_rank = prefer_static.rank(
        pre_swap_replica_target_log_prob)
    num_replica = prefer_static.size0(inverse_temperatures)

    inverse_temperatures = mcmc_util.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Now that each replica has done one_step, it is time to consider swaps.

    # swap.shape = [n_replica], and is a "once only" permutation, meaning it
    # is achievable by a sequence of pairwise permutations, where each element
    # is moved at most once.
    # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
    # 1, keeping 2 fixed.  This exact same swap is considered for *every*
    # batch member.  Of course some batch members may accept and some reject.
    swaps = tf.cast(
        self.swap_proposal_fn(  # pylint: disable=not-callable
            num_replica,
            batch_shape=batch_shape,
            seed=self._seed_stream()),
        dtype=tf.int32)

    # The identity permutation, with dims expanded to match `swaps`.
    null_swaps = mcmc_util.left_justified_expand_dims_like(
        tf.range(num_replica, dtype=swaps.dtype), swaps)
    swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                          self.validate_args)

    # Un-temper the log probs. E.g., for replica k, at point x_k, this is
    # Log[p(x_k)], and *not* Log[p_x(x_k)] = Log[p(x_k)] * beta_k.
    untempered_pre_swap_replica_target_log_prob = (
        pre_swap_replica_target_log_prob / inverse_temperatures)

    # Since `swaps` is its own inverse permutation we automatically know the
    # swap counterpart: range(num_replica). We use this idea to compute the
    # acceptance in a vectorized manner at the cost of wasting roughly half
    # our computation. Although we could use `unique` to solve this problem,
    # we expect the cost of `unique` to be higher than the dozens of wasted
    # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
    # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

    # Note: diffs would normally be "proposed - current" however energy is
    # flipped since `energy == -log_prob`.
    energy_diff = (
        untempered_pre_swap_replica_target_log_prob -
        mcmc_util.index_remapping_gather(
            untempered_pre_swap_replica_target_log_prob,
            swaps, name='gather_swap_tlp'))
    swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
        inverse_temperatures, swaps, name='gather_swap_temps')
    inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

    # If i and j are swapping, log_accept_ratio[] i and j are equal.
    log_accept_ratio = (
        energy_diff * mcmc_util.left_justified_expand_dims_to(
            inverse_temp_diff, replica_and_batch_rank))

    # Any non-finite ratio (NaN or +/-inf) becomes -inf, so the `tf.less`
    # comparison below is False there and the swap is rejected.
    log_accept_ratio = tf.where(
        tf.math.is_finite(log_accept_ratio),
        log_accept_ratio, tf.constant(-np.inf, dtype=dtype))

    # Produce Log[Uniform] draws that are identical at swapped indices.
    log_uniform = tf.math.log(
        tf.random.uniform(shape=replica_and_batch_shape,
                          dtype=dtype,
                          seed=self._seed_stream()))
    # Both members of a swapping pair map to the smaller index of the pair,
    # so after the gather they share one draw.
    anchor_swaps = tf.minimum(swaps, null_swaps)
    log_uniform = mcmc_util.index_remapping_gather(
        log_uniform, anchor_swaps)

    is_swap_accepted_mask = tf.less(
        log_uniform,
        log_accept_ratio,
        name='is_swap_accepted_mask')

    def _swap_tensor(x):
      # Where the mask is True take the value gathered through `swaps`,
      # elsewhere keep `x` (presumably elementwise along the leading axes --
      # see `mcmc_util.choose`).
      return mcmc_util.choose(
          is_swap_accepted_mask,
          mcmc_util.index_remapping_gather(x, swaps), x)

    post_swap_replica_states = [
        _swap_tensor(s) for s in pre_swap_replica_states
    ]

    expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
        null_swaps, replica_and_batch_shape)
    is_swap_proposed = _compute_swap_notmatrix(
        # Broadcast both so they have shape [num_replica] + batch_shape.
        # This (i) makes them have same shape as is_swap_accepted, and
        # (ii) keeps shape consistent if someday swaps has a batch shape.
        expanded_null_swaps,
        mcmc_util.left_justified_broadcast_to(swaps,
                                              replica_and_batch_shape))

    # To get is_swap_accepted in ordered position, we use
    # _compute_swap_notmatrix on current and next replica positions.
    post_swap_replica_position = _swap_tensor(expanded_null_swaps)

    is_swap_accepted = _compute_swap_notmatrix(
        post_swap_replica_position,
        expanded_null_swaps)

    # Return only index 0 along the leading (replica) axis of each state part.
    post_swap_states = [s[0] for s in post_swap_replica_states]

    post_swap_replica_results = _make_post_swap_replica_results(
        pre_swap_replica_results,
        inverse_temperatures,
        swapped_inverse_temperatures,
        is_swap_accepted_mask,
        _swap_tensor)

    if mcmc_util.is_list_like(current_state):
      # We *always* canonicalize the states in the kernel results.
      states = post_swap_states
    else:
      # A non-list `current_state` gets a single state part back.
      states = post_swap_states[0]

    post_swap_kernel_results = ReplicaExchangeMCKernelResults(
        post_swap_replica_states=post_swap_replica_states,
        pre_swap_replica_results=pre_swap_replica_results,
        post_swap_replica_results=post_swap_replica_results,
        is_swap_proposed=is_swap_proposed,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        # Store the original pkr.inverse_temperatures in case it's a
        # `tf.Variable`.
        inverse_temperatures=previous_kernel_results.inverse_temperatures,
        swaps=swaps,
    )

    return states, post_swap_kernel_results
def _make_post_swap_replica_results(pre_swap_replica_results,
                                    inverse_temperatures,
                                    swapped_inverse_temperatures,
                                    is_swap_accepted_mask,
                                    swap_tensor_fn):
  """Return kernel results that are valid for the post-swap states.

  Fields are blanked out (set to NaN) if they cannot be updated in an
  unambiguous manner. Concretely this function:

    * replaces `proposed_results` and `proposed_state` with a scalar NaN of
      the same dtype as `swapped_inverse_temperatures`;
    * swaps `target_log_prob` where swaps were accepted and rescales it so it
      is tempered by the inverse temperature of the replica index it now
      occupies;
    * does the same for `grads_target_log_prob` when that field exists.

  Args:
    pre_swap_replica_results: Kernel results obtained by running
      inner_kernel.one_step before swapping.
    inverse_temperatures: Tensor of inverse temperatures.
    swapped_inverse_temperatures: Tensor of inverse temperatures, permuted by
      swaps.
    is_swap_accepted_mask: Shape [num_replica] + batch_shape boolean Tensor
      telling which swaps were accepted.
    swap_tensor_fn: Callable. For `x.shape = [num_replica] + batch_shape`,
      swap_tensor_fn(x) performs swaps where they are accepted, and does not
      swap otherwise.

  Returns:
    new_replica_results: Same type as pre_swap_replica_results.

  Raises:
    NotImplementedError: If type of [nested] results is not handled.
    REMCFieldNotFoundError: Propagated from `_get_field` / `_update_field` if
      one of the required fields (`proposed_results`, `proposed_state`,
      `log_accept_ratio`, `target_log_prob`) is missing. A missing
      `grads_target_log_prob` is tolerated.
  """
  kr = pre_swap_replica_results
  dtype = swapped_inverse_temperatures.dtype

  # Hard to modify the proposed_* fields in an unambiguous manner.
  # ...we also don't need to, so overwrite them with a NaN placeholder.
  kr = _update_field(kr, 'proposed_results',
                     tf.convert_to_tensor(np.nan, dtype=dtype))
  kr = _update_field(kr, 'proposed_state',
                     tf.convert_to_tensor(np.nan, dtype=dtype))

  # log_accept_ratio is used only for its rank, which the inverse temperature
  # ratio is padded up to before it is combined with is_swap_accepted_mask.
  replica_and_batch_rank = ps.rank(_get_field(kr, 'log_accept_ratio'))

  # After using swap_tensor_fn on "values", values will be multiplied by the
  # swapped_inverse_temperatures.  We need it to be multiplied instead by the
  # inverse temperature corresponding to its index.
  it_ratio_raw = inverse_temperatures / swapped_inverse_temperatures
  it_ratio = tf.where(
      is_swap_accepted_mask,
      mcmc_util.left_justified_expand_dims_to(it_ratio_raw,
                                              replica_and_batch_rank),
      # Replicas whose swap was rejected keep their own temperature.
      tf.convert_to_tensor(1.0, dtype=dtype))

  def _swap_then_retemper(x):
    """Swap `x` (a Tensor or list of Tensors) and fix up its tempering."""
    parts, is_multipart = mcmc_util.prepare_state_parts(x)
    # Expand the ratio against *each* part rather than only the first one:
    # parts may differ in rank (e.g. gradients w.r.t. state parts with
    # different event shapes), and a ratio padded to the rank of parts[0]
    # would broadcast against the wrong axes of the other parts.
    retempered = [
        swap_tensor_fn(part) *
        mcmc_util.left_justified_expand_dims_like(it_ratio, part)
        for part in parts
    ]
    return retempered if is_multipart else retempered[0]

  kr = _update_field(
      kr, 'target_log_prob',
      _swap_then_retemper(_get_field(kr, 'target_log_prob')))

  try:
    new_grads_target_log_prob = _swap_then_retemper(
        _get_field(kr, 'grads_target_log_prob'))
    kr = _update_field(kr, 'grads_target_log_prob',
                       new_grads_target_log_prob)
  # For transition kernels not involving the gradient of the log-probability,
  # grads_target_log_prob will not exist in the (possibly multiply wrapped)
  # kernel results and that's perfectly fine. But _get_field() /
  # _update_field() will throw a REMCFieldNotFoundError, which we thus catch
  # silently.
  except REMCFieldNotFoundError:
    pass

  return kr