def test_selects_batch_members_from_list_of_arrays(self):
  # Shape of each array: [2, 3] = [batch_size, event_size]
  # This test verifies that is_accepted selects batch members, despite the
  # "usual" broadcasting being applied on the right first (event first).
  zeros_states = [np.zeros((2, 3))]
  ones_states = [np.ones((2, 3))]
  chosen = util.choose(
      tf.constant([True, False]), zeros_states, ones_states)
  chosen_ = self.evaluate(chosen)

  # Make sure the outer list wasn't interpreted as a dimension of an array.
  self.assertIsInstance(chosen_, list)
  expected_array = np.array([
      [0., 0., 0.],  # zeros_states selected for first batch
      [1., 1., 1.],  # ones_states selected for second
  ])
  expected = [expected_array]
  self.assertAllEqual(expected, chosen_)
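
# For intuition, a minimal sketch (not the library implementation) of what
# `util.choose` must do to pass the test above: the boolean mask indexes the
# *batch* (leftmost) dimension, so it is padded with trailing singleton dims
# to broadcast on the left before a `tf.where`-style select, and the outer
# list structure is mapped over rather than stacked. Names are illustrative.
import tensorflow as tf

def _choose_sketch(is_accepted, accepted, rejected):
  def _select(a, r):
    a = tf.convert_to_tensor(a)
    r = tf.convert_to_tensor(r)
    # Append singleton dims so a [batch] mask broadcasts on the left,
    # e.g. [2] -> [2, 1] against a [2, 3] state.
    extra_dims = tf.rank(a) - tf.rank(is_accepted)
    mask_shape = tf.concat(
        [tf.shape(is_accepted), tf.ones([extra_dims], tf.int32)], axis=0)
    return tf.where(tf.reshape(is_accepted, mask_shape), a, r)
  # Map over the outer list so it stays a list instead of becoming an array.
  return [_select(a, r) for a, r in zip(accepted, rejected)]

# _choose_sketch(tf.constant([True, False]), zeros_states, ones_states)
# reproduces the `expected` value checked by the test.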
def _swap_tensor(x):
  return mcmc_util.choose(
      is_swap_accepted_mask,
      mcmc_util.index_remapping_gather(x, swaps),
      x)
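
# Usage sketch for the swap above, with illustrative values: `swaps` holds,
# for each replica, the index of the replica it would trade states with, and
# `is_swap_accepted_mask` gates the trade per replica. `tf.gather` along the
# replica axis stands in for `index_remapping_gather` here; the choose-style
# select then accepts or rejects the post-swap candidate per replica.
import tensorflow as tf

state = tf.constant([[0.], [1.], [2.]])           # 3 replicas, 1 chain each
swaps = tf.constant([1, 0, 2])                    # replicas 0 and 1 trade
is_swap_accepted_mask = tf.constant([True, True, False])
swapped = tf.gather(state, swaps, axis=0)         # [[1.], [0.], [2.]]
result = tf.where(is_swap_accepted_mask[:, tf.newaxis], swapped, state)
# result == [[1.], [0.], [2.]]: the accepted pair traded, replica 2 kept.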
def one_step(self, current_state, previous_kernel_results):
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'simple_step_size_adaptation',
                          'one_step')):
    # Set the step_size.
    inner_results = self.step_size_setter_fn(
        previous_kernel_results.inner_results,
        previous_kernel_results.new_step_size)

    # Step the inner kernel.
    new_state, new_inner_results = self.inner_kernel.one_step(
        current_state, inner_results)

    # Get the new step size.
    log_accept_prob = self.log_accept_prob_getter_fn(new_inner_results)
    log_target_accept_prob = tf.math.log(
        tf.cast(previous_kernel_results.target_accept_prob,
                dtype=log_accept_prob.dtype))

    state_parts = tf.nest.flatten(current_state)
    step_size = self.step_size_getter_fn(new_inner_results)
    step_size_parts = tf.nest.flatten(step_size)
    log_accept_prob_rank = prefer_static.rank(log_accept_prob)

    new_step_size_parts = []
    for step_size_part, state_part in zip(step_size_parts, state_parts):
      # Compute new step sizes for each step size part. If a step size part
      # has smaller rank than the corresponding state part, then the
      # difference is averaged away in the log accept prob.
      #
      # Example:
      #
      #   state_part has shape      [2, 3, 4, 5]
      #   step_size_part has shape     [1, 4, 1]
      #   log_accept_prob has shape [2, 3, 4]
      #
      # Since the step size has 1 rank fewer than the state, we reduce away
      # the leading dimension of log_accept_prob to get a Tensor with shape
      # [3, 4]. Next, since log_accept_prob must broadcast into
      # step_size_part on the left, we reduce the dimensions where their
      # shapes differ, to get a Tensor with shape [1, 4], which now is
      # compatible with the leading dimensions of step_size_part.
      #
      # There is a subtlety here in that step_size_parts might be a length-1
      # list, which means that we'll be "structure-broadcasting" it for all
      # the state parts (see logic in, e.g., hmc.py). In this case we must
      # assume that the lone step size provided broadcasts with the event
      # dims of each state part. This means that either the step size has no
      # dimensions corresponding to chain dimensions, or all states have the
      # same shape. For the former, we want to reduce over all chain
      # dimensions. For the latter, we want to use the same logic as in the
      # non-structure-broadcasted case.
      #
      # It turns out we can compute the reduction dimensions for both cases
      # uniformly by taking the rank of any state part. This obviously works
      # in the second case (where all state ranks are the same). In the
      # first case, all state parts have the rank L + D_i + B, where L is
      # the rank of log_accept_prob, D_i is the non-shared dimensions
      # amongst all states, and B are the shared dimensions of all the
      # states, which are equal to the step size. When we subtract B, we
      # will always get a number >= L, which means we'll get the full
      # reduction we want.
      num_reduce_dims = prefer_static.minimum(
          log_accept_prob_rank,
          prefer_static.rank(state_part) -
          prefer_static.rank(step_size_part))
      reduced_log_accept_prob = reduce_logmeanexp(
          log_accept_prob, axis=prefer_static.range(num_reduce_dims))
      # reduced_log_accept_prob must broadcast into step_size_part on the
      # left, so we do an additional reduction over dimensions where their
      # shapes differ.
      reduce_indices = get_differing_dims(reduced_log_accept_prob,
                                          step_size_part)
      reduced_log_accept_prob = reduce_logmeanexp(
          reduced_log_accept_prob, axis=reduce_indices, keepdims=True)
      one_plus_adaptation_rate = 1. + tf.cast(
          previous_kernel_results.adaptation_rate,
          dtype=step_size_part.dtype)
      new_step_size_part = mcmc_util.choose(
          reduced_log_accept_prob > log_target_accept_prob,
          step_size_part * one_plus_adaptation_rate,
          step_size_part / one_plus_adaptation_rate)
      new_step_size_parts.append(
          tf.where(previous_kernel_results.step < self.num_adaptation_steps,
                   new_step_size_part,
                   step_size_part))
    new_step_size = tf.nest.pack_sequence_as(step_size, new_step_size_parts)

    return new_state, previous_kernel_results._replace(
        inner_results=new_inner_results,
        step=1 + previous_kernel_results.step,
        new_step_size=new_step_size)
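
# A typical usage sketch of the public wrapper this method belongs to,
# patterned on the TFP docstring example (argument defaults may differ
# across versions): wrap an HMC kernel, adapt during burnin only, and aim
# for the canonical 0.75 acceptance probability.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=tfd.Normal(0., 1.).log_prob,
    step_size=0.1,
    num_leapfrog_steps=3)
adaptive_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=kernel,
    num_adaptation_steps=400,   # stop adapting after burnin
    target_accept_prob=0.75)
samples = tfp.mcmc.sample_chain(
    num_results=500,
    num_burnin_steps=400,
    current_state=0.,
    kernel=adaptive_kernel,
    trace_fn=None)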
def _swap(is_exchange_accepted, x, y):
  """Swap batches of x, y where accepted."""
  with tf.compat.v1.name_scope('swap_where_exchange_accepted'):
    new_x = mcmc_util.choose(is_exchange_accepted, y, x)
    new_y = mcmc_util.choose(is_exchange_accepted, x, y)
    return new_x, new_y
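
# Quick numeric check of the swap semantics, with illustrative values. Note
# that both `choose` calls read the pre-swap x and y, so the pair stays
# consistent: where accepted, (x, y) -> (y, x); elsewhere both are unchanged.
import tensorflow as tf

x = tf.constant([0., 1., 2.])
y = tf.constant([10., 11., 12.])
accepted = tf.constant([True, False, True])
new_x = tf.where(accepted, y, x)   # [10., 1., 12.]
new_y = tf.where(accepted, x, y)   # [0., 11., 2.]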
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Takes one step of the TransitionKernel.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.

  Raises:
    ValueError: if `inner_kernel` results doesn't contain the member
      "target_log_prob".
  """
  is_seeded = seed is not None
  seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
  proposal_seed, acceptance_seed = samplers.split_seed(seed)

  with tf.name_scope(mcmc_util.make_name(self.name, 'mh', 'one_step')):
    # Take one inner step.
    inner_kwargs = dict(seed=proposal_seed) if is_seeded else {}
    [
        proposed_state,
        proposed_results,
    ] = self.inner_kernel.one_step(
        current_state,
        previous_kernel_results.accepted_results,
        **inner_kwargs)
    if mcmc_util.is_list_like(current_state):
      proposed_state = tf.nest.pack_sequence_as(current_state,
                                                proposed_state)

    if (not has_target_log_prob(proposed_results) or
        not has_target_log_prob(previous_kernel_results.accepted_results)):
      raise ValueError('"target_log_prob" must be a member of '
                       '`inner_kernel` results.')

    # Compute log(acceptance_ratio).
    to_sum = [
        proposed_results.target_log_prob,
        -previous_kernel_results.accepted_results.target_log_prob
    ]
    try:
      if (not mcmc_util.is_list_like(
          proposed_results.log_acceptance_correction)
          or proposed_results.log_acceptance_correction):
        to_sum.append(proposed_results.log_acceptance_correction)
    except AttributeError:
      warnings.warn('Supplied inner `TransitionKernel` does not have a '
                    '`log_acceptance_correction`. Assuming its value is '
                    '`0.`')

    log_accept_ratio = mcmc_util.safe_sum(
        to_sum, name='compute_log_accept_ratio')

    # If proposed state reduces likelihood: randomly accept.
    # If proposed state increases likelihood: always accept.
    # I.e., u < min(1, accept_ratio),  where u ~ Uniform[0,1)
    #       ==> log(u) < log_accept_ratio
    log_uniform = tf.math.log(
        samplers.uniform(
            shape=prefer_static.shape(proposed_results.target_log_prob),
            dtype=dtype_util.base_dtype(
                proposed_results.target_log_prob.dtype),
            seed=acceptance_seed))
    is_accepted = log_uniform < log_accept_ratio

    next_state = mcmc_util.choose(
        is_accepted, proposed_state, current_state,
        name='choose_next_state')

    kernel_results = MetropolisHastingsKernelResults(
        accepted_results=mcmc_util.choose(
            is_accepted,
            # We strip seeds when populating `accepted_results` because
            # unlike other kernel result fields, seeds are not a per-chain
            # value. Thus it is impossible to choose between a previously
            # accepted seed value and a proposed seed, since said choice
            # would need to be made on a per-chain basis.
            mcmc_util.strip_seeds(proposed_results),
            previous_kernel_results.accepted_results,
            name='choose_inner_results'),
        is_accepted=is_accepted,
        log_accept_ratio=log_accept_ratio,
        proposed_state=proposed_state,
        proposed_results=proposed_results,
        extra=[],
        seed=seed,
    )

    return next_state, kernel_results
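
# Sanity check of the log-space acceptance test above (illustrative, not
# library code): since log is monotone and log(u) < 0 always holds,
# log(u) < log_accept_ratio accepts with probability min(1, accept_ratio),
# which is exactly the textbook Metropolis-Hastings rule.
import numpy as np

rng = np.random.default_rng(0)
log_accept_ratio = np.log(0.3)            # accept_ratio = 0.3 < 1
u = rng.uniform(size=100_000)
accept_rate = np.mean(np.log(u) < log_accept_ratio)
# accept_rate is ~0.3, matching min(1, accept_ratio).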
def one_step(self, current_state, previous_kernel_results):
  """Takes one step of the TransitionKernel.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.

  Raises:
    ValueError: if `inner_kernel` results doesn't contain the member
      "target_log_prob".
  """
  with tf1.name_scope(
      name=mcmc_util.make_name(self.name, 'mh', 'one_step'),
      values=[current_state, previous_kernel_results]):
    # Take one inner step.
    [
        proposed_state,
        proposed_results,
    ] = self.inner_kernel.one_step(
        current_state,
        previous_kernel_results.accepted_results)

    if (not has_target_log_prob(proposed_results) or
        not has_target_log_prob(previous_kernel_results.accepted_results)):
      raise ValueError('"target_log_prob" must be a member of '
                       '`inner_kernel` results.')

    # Compute log(acceptance_ratio).
    to_sum = [
        proposed_results.target_log_prob,
        -previous_kernel_results.accepted_results.target_log_prob
    ]
    try:
      if (not mcmc_util.is_list_like(
          proposed_results.log_acceptance_correction)
          or proposed_results.log_acceptance_correction):
        to_sum.append(proposed_results.log_acceptance_correction)
    except AttributeError:
      warnings.warn('Supplied inner `TransitionKernel` does not have a '
                    '`log_acceptance_correction`. Assuming its value is '
                    '`0.`')

    log_accept_ratio = mcmc_util.safe_sum(
        to_sum, name='compute_log_accept_ratio')

    # If proposed state reduces likelihood: randomly accept.
    # If proposed state increases likelihood: always accept.
    # I.e., u < min(1, accept_ratio),  where u ~ Uniform[0,1)
    #       ==> log(u) < log_accept_ratio
    log_uniform = tf.math.log(
        tf.random.uniform(
            shape=tf.shape(input=proposed_results.target_log_prob),
            dtype=proposed_results.target_log_prob.dtype.base_dtype,
            seed=self._seed_stream()))
    is_accepted = log_uniform < log_accept_ratio

    next_state = mcmc_util.choose(
        is_accepted, proposed_state, current_state,
        name='choose_next_state')

    kernel_results = MetropolisHastingsKernelResults(
        accepted_results=mcmc_util.choose(
            is_accepted,
            proposed_results,
            previous_kernel_results.accepted_results,
            name='choose_inner_results'),
        is_accepted=is_accepted,
        log_accept_ratio=log_accept_ratio,
        proposed_state=proposed_state,
        proposed_results=proposed_results,
        extra=[],
    )

    return next_state, kernel_results
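
# Usage sketch for the wrapper above, patterned on the TFP docstrings: pair
# MetropolisHastings with any uncalibrated proposal kernel. The public
# RandomWalkMetropolis kernel is essentially this same composition,
# pre-packaged.
import tensorflow_probability as tfp

kernel = tfp.mcmc.MetropolisHastings(
    inner_kernel=tfp.mcmc.UncalibratedRandomWalk(
        target_log_prob_fn=lambda x: -x**2 / 2.))
# `one_step` then proposes with the inner kernel and applies the
# accept/reject logic defined above.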
def _do_flip(state, i):
  new_state = sampler._flip_feature(state, tf.gather(flip_idxs, i))
  return mcmc_util.choose(tf.gather(should_flip, i), new_state, state)
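
# The pattern above is a scalar-conditioned update: compute the candidate
# unconditionally, then select the whole state with a scalar boolean. A
# minimal standalone sketch, where the flip itself is hypothetical (the real
# `_flip_feature` is sampler-specific):
import tensorflow as tf

def _do_flip_sketch(state, flip_idx, should_flip):
  # Toggle one binary feature of a 0/1-valued state vector.
  flipped = tf.tensor_scatter_nd_update(
      state, [[flip_idx]], [1. - state[flip_idx]])
  # Scalar boolean broadcasts over the whole state in the select.
  return tf.where(should_flip, flipped, state)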
def one_step(self, current_state, previous_kernel_results, seed=None):
  with tf.name_scope(
      mcmc_util.make_name(self.name,
                          'gradient_based_trajectory_length_adaptation',
                          'one_step')):
    jitter_seed, inner_seed = samplers.split_seed(seed)

    dtype = previous_kernel_results.adaptation_rate.dtype
    current_state = tf.nest.map_structure(
        lambda x: tf.convert_to_tensor(x, dtype=dtype), current_state)
    step_f = tf.cast(previous_kernel_results.step, dtype)
    if self.use_halton_sequence_jitter:
      trajectory_jitter = _halton_sequence(step_f)
    else:
      trajectory_jitter = samplers.uniform((), seed=jitter_seed, dtype=dtype)

    jitter_amount = previous_kernel_results.jitter_amount
    trajectory_jitter = (
        trajectory_jitter * jitter_amount + (1. - jitter_amount))

    adapting = previous_kernel_results.step < self.num_adaptation_steps
    max_trajectory_length = tf.where(
        adapting, previous_kernel_results.max_trajectory_length,
        previous_kernel_results.averaged_max_trajectory_length)
    jittered_trajectory_length = (max_trajectory_length * trajectory_jitter)

    step_size = _ensure_step_size_is_scalar(
        self.step_size_getter_fn(previous_kernel_results),
        self.validate_args)
    num_leapfrog_steps = tf.cast(
        tf.maximum(
            tf.ones([], dtype),
            tf.math.ceil(jittered_trajectory_length / step_size)), tf.int32)

    previous_kernel_results_with_jitter = self.num_leapfrog_steps_setter_fn(
        previous_kernel_results, num_leapfrog_steps)

    new_state, new_inner_results = self.inner_kernel.one_step(
        current_state, previous_kernel_results_with_jitter.inner_results,
        inner_seed)

    proposed_state = self.proposed_state_getter_fn(new_inner_results)
    proposed_velocity = self.proposed_velocity_getter_fn(new_inner_results)
    accept_prob = tf.exp(self.log_accept_prob_getter_fn(new_inner_results))

    new_kernel_results = _update_trajectory_grad(
        previous_kernel_results_with_jitter,
        previous_state=current_state,
        proposed_state=proposed_state,
        proposed_velocity=proposed_velocity,
        trajectory_jitter=trajectory_jitter,
        accept_prob=accept_prob,
        step_size=step_size,
        criterion_fn=self.criterion_fn,
        max_leapfrog_steps=self.max_leapfrog_steps)

    # Undo the effect of adaptation if we're not in the burnin phase. We
    # keep the criterion, however, as that's a diagnostic. We also keep the
    # leapfrog steps setting, as that's an effect of jitter (and also
    # doubles as a diagnostic).
    criterion = new_kernel_results.criterion
    new_kernel_results = mcmc_util.choose(
        adapting, new_kernel_results, previous_kernel_results_with_jitter)
    new_kernel_results = new_kernel_results._replace(
        inner_results=new_inner_results,
        step=previous_kernel_results.step + 1,
        criterion=criterion)

    return new_state, new_kernel_results
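
# For intuition about the jitter: a Halton-style sequence fills [0, 1) more
# evenly than i.i.d. uniforms, which reduces the variance of the jittered
# trajectory lengths. The classic base-2 radical inverse (van der Corput) is
# the simplest instance; the library's `_halton_sequence` is an internal
# detail and may differ from this sketch:
def radical_inverse_base2(i):
  # Mirror the binary digits of i across the radix point.
  result, denom = 0.0, 1.0
  while i > 0:
    denom *= 2.0
    result += (i % 2) / denom
    i //= 2
  return result

# radical_inverse_base2(1..4) -> 0.5, 0.25, 0.75, 0.125
# The jitter is then mapped into [1 - jitter_amount, 1), so it only ever
# scales the trajectory length down from the current maximum.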
def one_step(self, current_state, previous_kernel_results, seed=None):
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'diagonal_mass_matrix_adaptation',
                          'one_step')):
    variance_parts = previous_kernel_results.running_variance
    inner_results = previous_kernel_results.inner_results

    # Step the inner kernel.
    inner_kwargs = {} if seed is None else dict(seed=seed)
    new_state, new_inner_results = self.inner_kernel.one_step(
        current_state, inner_results, **inner_kwargs)

    def update_running_variance():
      diags = [
          variance_part.variance() for variance_part in variance_parts
      ]
      new_state_parts = tf.nest.flatten(new_state)
      new_variance_parts = []
      for variance_part, diag, state_part in zip(variance_parts, diags,
                                                 new_state_parts):
        # Compute new variance for each variance part, accounting for
        # partial batching of the variance calculation across chains (i.e.,
        # some, all, or none of the chains may share the estimated mass
        # matrix).
        #
        # For example, say
        #
        #   state_part has shape    [2, 3, 4] + [5, 6]  (batch + event)
        #   variance_part has shape       [4] + [5, 6]
        #   log_prob has shape      [2, 3, 4]
        #
        # i.e., we have a batch of chains of shape [2, 3, 4], and 4 mass
        # matrices, each being shared across a [2, 3]-batch of chains. Note
        # this division is inferred from the shapes of the state part, the
        # log_prob, and the user-provided initial running variances.
        #
        # Until RunningVariance supports rank > 1 chunking, we need to
        # flatten the states that go into updating the variance estimates.
        # In the above example, `state_part` will be reshaped to
        # `[6, 4, 5, 6]`, and fed to
        # `RunningVariance.update(state_part, axis=0)`, recording 6 new
        # observations in the running variance calculation.
        # `RunningVariance.variance()` will then be of shape `[4, 5, 6]`,
        # and the resulting momentum distribution will have batch shape
        # `[2, 3, 4]` and event shape `[5, 6]`, matching the state_part.
        state_rank = ps.rank(state_part)
        variance_rank = ps.rank(diag)
        num_reduce_dims = state_rank - variance_rank

        state_part_shape = ps.shape(state_part)
        # This reshape adds a 1 when reduce_dims==0, and collapses all the
        # lead dimensions to a single one otherwise.
        reshaped_state = ps.reshape(
            state_part,
            ps.concat(
                [[ps.reduce_prod(state_part_shape[:num_reduce_dims])],
                 state_part_shape[num_reduce_dims:]],
                axis=0))

        # The `axis=0` here removes the leading dimension we got from the
        # reshape above, so the new_variance_parts have the correct shape
        # again.
        new_variance_parts.append(
            variance_part.update(reshaped_state, axis=0))
      return new_variance_parts

    def update_momentum():
      diags = [
          variance_part.variance() for variance_part in new_variance_parts
      ]
      # Update the momentum.
      prev_momentum_distribution = self.momentum_distribution_getter_fn(
          new_inner_results)
      new_momentum_distribution = (
          preconditioning_utils.update_momentum_distribution(
              prev_momentum_distribution, diags))
      updated_new_inner_results = self.momentum_distribution_setter_fn(
          new_inner_results, new_momentum_distribution)
      return updated_new_inner_results

    step = previous_kernel_results.step + 1
    if self.num_estimation_steps is None:
      new_variance_parts = update_running_variance()
      new_inner_results = update_momentum()
    else:
      new_variance_parts = mcmc_util.choose(
          step <= previous_kernel_results.num_estimation_steps,
          update_running_variance(), variance_parts)
      new_inner_results = mcmc_util.choose(
          tf.equal(step, previous_kernel_results.num_estimation_steps),
          update_momentum(), new_inner_results)
    new_kernel_results = previous_kernel_results._replace(
        inner_results=new_inner_results,
        running_variance=new_variance_parts,
        step=step)

    return new_state, new_kernel_results
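
# Shape check for the reshape logic above, using the shapes from the comment
# (NumPy stands in for prefer_static; illustrative only):
import numpy as np

state_part = np.zeros([2, 3, 4, 5, 6])  # batch [2, 3, 4] + event [5, 6]
diag = np.zeros([4, 5, 6])               # variance shared over [2, 3] chains
num_reduce_dims = state_part.ndim - diag.ndim           # 2
lead = int(np.prod(state_part.shape[:num_reduce_dims]))  # 2 * 3 = 6
reshaped = state_part.reshape((lead,) + state_part.shape[num_reduce_dims:])
assert reshaped.shape == (6, 4, 5, 6)   # 6 observations per variance update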