def test_flat_replace(self):
  """Checks replace_{inner,out}ermost on a results tuple with no nesting."""
  results = FakeResults(0)
  # With a single level, innermost and outermost replacement coincide.
  self.assertEqual(
      unnest.replace_innermost(results, unique_core=1).unique_core, 1)
  self.assertEqual(
      unnest.replace_innermost(
          results, return_unused=True, unique_core=1, foo=2),
      (FakeResults(1), {'foo': 2}))
  self.assertEqual(
      unnest.replace_outermost(results, unique_core=2).unique_core, 2)
  self.assertEqual(
      unnest.replace_outermost(
          results, return_unused=True, unique_core=2, foo=3),
      (FakeResults(2), {'foo': 3}))
  # Leftover kwargs without return_unused=True are an error.
  with self.assertRaises(ValueError):
    unnest.replace_innermost(results, unique_core=1, foo=1)
  with self.assertRaises(ValueError):
    unnest.replace_outermost(results, unique_core=1, foo=1)
def test_atypical_nesting_replace(self):
  """Checks replacement through a results type with an atypical inner field."""

  def build(nesting_value, core_value):
    return FakeAtypicalNestingResults(
        unique_atypical_nesting=nesting_value,
        atypical_inner_results=FakeResults(core_value))

  results = build(0, 1)
  # On this two-level structure, innermost and outermost replacement are
  # expected to produce identical results, so exercise both in one pass.
  for replace_fn in (unnest.replace_innermost, unnest.replace_outermost):
    self.assertEqual(
        replace_fn(results, unique_atypical_nesting=2), build(2, 1))
    self.assertEqual(replace_fn(results, unique_core=2), build(0, 2))
    self.assertEqual(
        replace_fn(results, unique_atypical_nesting=2, unique_core=3),
        build(2, 3))
    self.assertEqual(
        replace_fn(results, return_unused=True, unique_atypical_nesting=2,
                   unique_core=3, foo=4),
        (build(2, 3), {'foo': 4}))
    # Leftover kwargs without return_unused=True are an error.
    with self.assertRaises(ValueError):
      replace_fn(results, unique_core=1, foo=1)
def hmc_like_momentum_distribution_setter_fn(kernel_results, new_distribution):
  """Setter for `momentum_distribution` so it can be adapted."""
  # `unnest.replace_innermost` has a special path that descends into
  # `accepted_results` preferentially, so this ends up setting
  # `accepted_results.momentum_distribution`.
  updates = {'momentum_distribution': new_distribution}
  return unnest.replace_innermost(kernel_results, **updates)
def bootstrap_results(self, init_state):
  """Bootstraps inner results and installs the preconditioning momentum."""
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'diagonal_mass_matrix_adaptation',
                          'bootstrap_results')):
    # Accept either a single RunningVariance or an iterable of them.
    initial = self.initial_running_variance
    if isinstance(initial, sample_stats.RunningVariance):
      variance_parts = [initial]
    else:
      variance_parts = initial
    variance_diags = [part.variance() for part in variance_parts]

    # Step inner results.
    inner_results = self.inner_kernel.bootstrap_results(init_state)

    # Build the momentum distribution from the variance estimates and set it
    # on the inner kernel results.
    target_log_prob = unnest.get_innermost(inner_results, 'target_log_prob')
    momentum_distribution = _make_momentum_distribution(
        variance_diags, tf.nest.flatten(init_state), ps.rank(target_log_prob))
    inner_results = self.momentum_distribution_setter_fn(
        inner_results, momentum_distribution)

    # Mirror the update into `proposed_results`, when that field exists.
    proposed = unnest.get_innermost(
        inner_results, 'proposed_results', default=None)
    if proposed is not None:
      inner_results = unnest.replace_innermost(
          inner_results,
          proposed_results=proposed._replace(
              momentum_distribution=momentum_distribution))

    return DiagonalMassMatrixAdaptationResults(
        inner_results=inner_results,
        running_variance=variance_parts)
def _update_field(kernel_results, field_name, value):
  """Sets `field_name` to `value` in the innermost matching kernel results.

  Args:
    kernel_results: (Possibly nested) kernel results namedtuple.
    field_name: Python `str` name of the field to set.
    value: New value for the field.

  Returns:
    Kernel results with the innermost `field_name` replaced by `value`.

  Raises:
    REMCFieldNotFoundError: If no nesting level carries `field_name`.
  """
  try:
    return unnest.replace_innermost(kernel_results, **{field_name: value})
  except ValueError as e:
    msg = _kernel_result_not_implemented_message_template.format(
        kernel_results, field_name)
    # Chain explicitly so the underlying `unnest` error stays attached as the
    # cause rather than as an implicit "during handling of" context.
    raise REMCFieldNotFoundError(msg) from e
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Runs one iteration of NeuTra.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions
      index independent chains,
      `r = tf.rank(target_log_prob_fn(*current_state))`.
    previous_kernel_results: `collections.namedtuple` containing `Tensor`s
      representing values from previous calls to this function (or from the
      `bootstrap_results` function.)
    seed: Optional seed for reproducible sampling.

  Returns:
    next_state: Tensor or Python list of `Tensor`s representing the state(s)
      of the Markov chain(s) after taking exactly one step. Has same type
      and shape as `current_state`.
    kernel_results: `collections.namedtuple` of internal calculations used
      to advance the chain.
  """
  # Recompute the leapfrog step count from the adapted step size, and write
  # it into the innermost kernel results before stepping.
  adapted_step_size = previous_kernel_results.new_step_size
  previous_kernel_results = unnest.replace_innermost(
      previous_kernel_results,
      num_leapfrog_steps=self._num_leapfrog_steps(adapted_step_size))
  flat_state = self._flatten_state(current_state)
  new_flat_state, kernel_results = self._kernel.one_step(
      flat_state, previous_kernel_results, seed=seed)
  return self._unflatten_state(new_flat_state), kernel_results
def test_sets_kinetic_energy(self):
  """Checks a kernel step matches a manual leapfrog step with the set momentum.

  Installs a custom `momentum_distribution` via `unnest.replace_innermost`,
  takes one kernel step, then replays the same leapfrog step by hand and
  compares states, momenta, and gradients.
  """
  dist = tfd.MultivariateNormalDiag(scale_diag=tf.constant([0.1, 10.]))
  step_size = 0.1
  kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
      target_log_prob_fn=dist.log_prob,
      step_size=step_size,
      num_leapfrog_steps=1,
      store_parameters_in_results=True)
  init_state = tf.constant([0.1, 0.1])
  kr = kernel.bootstrap_results(init_state)

  # Manually set the momentum distribution.
  kr = unnest.replace_innermost(kr, momentum_distribution=dist)

  # Take one leapfrog step using the kernel.
  _, nkr = kernel.one_step(init_state, kr, seed=test_util.test_seed())

  # Need to evaluate here for consistency in graph mode.
  (momentum_parts,
   target_grad_parts,
   proposed_state,
   final_momentum,
   target_log_prob,
   grads_target_log_prob) = self.evaluate([
       nkr.proposed_results.initial_momentum,
       nkr.accepted_results.grads_target_log_prob,
       nkr.proposed_state,
       nkr.proposed_results.final_momentum,
       nkr.proposed_results.target_log_prob,
       nkr.proposed_results.grads_target_log_prob])

  # Take one leapfrog step manually.
  leapfrog = tfp.mcmc.internal.leapfrog_integrator.SimpleLeapfrogIntegrator(
      target_fn=dist.log_prob,
      step_sizes=[step_size],
      num_steps=1)

  # Again, need to evaluate here for graph mode consistency.
  # NOTE(review): `kinetic_energy_fn` mirrors the installed momentum
  # distribution (-log_prob), which is presumably what the kernel uses
  # internally once `momentum_distribution` is set — confirm against the
  # kernel implementation.
  (next_momentum,
   next_state,
   next_target_log_prob,
   grads_next_target_log_prob) = self.evaluate(leapfrog(
       momentum_parts=momentum_parts,
       state_parts=[init_state],
       target=dist.log_prob(init_state),
       target_grad_parts=target_grad_parts,
       kinetic_energy_fn=lambda x: -dist.log_prob(x)))

  # Verify resulting states are the same
  self.assertAllClose(proposed_state, next_state[0])
  self.assertAllClose(final_momentum, next_momentum)
  self.assertAllClose(target_log_prob, next_target_log_prob)
  self.assertAllClose(grads_target_log_prob, grads_next_target_log_prob)
def test_deeply_nested_replace(self):
  """Checks replacement at each level of a deeply nested results tuple."""
  results = _build_deeply_nested(0, 1, 2, 3, 4)
  # Bug fix: these assertions previously used `assertTrue(actual, expected)`,
  # which treats `expected` as a failure *message* and passes for any truthy
  # `actual`, so no comparison was performed. `assertEqual` (as used by the
  # sibling replace tests) does the intended check.
  self.assertEqual(unnest.replace_innermost(results, unique_nesting=5),
                   _build_deeply_nested(0, 1, 2, 5, 4))
  self.assertEqual(unnest.replace_outermost(results, unique_nesting=5),
                   _build_deeply_nested(5, 1, 2, 3, 4))
  self.assertEqual(
      unnest.replace_innermost(results, unique_atypical_nesting=5),
      _build_deeply_nested(0, 1, 5, 3, 4))
  self.assertEqual(
      unnest.replace_outermost(results, unique_atypical_nesting=5),
      _build_deeply_nested(0, 5, 2, 3, 4))
  self.assertEqual(unnest.replace_innermost(results, unique_core=5),
                   _build_deeply_nested(0, 1, 2, 3, 5))
  self.assertEqual(unnest.replace_outermost(results, unique_core=5),
                   _build_deeply_nested(0, 1, 2, 3, 5))
  self.assertEqual(
      unnest.replace_innermost(
          results, unique_nesting=5, unique_atypical_nesting=6,
          unique_core=7),
      _build_deeply_nested(0, 1, 6, 5, 7))
  self.assertEqual(
      unnest.replace_innermost(
          results, return_unused=True, unique_nesting=5,
          unique_atypical_nesting=6, unique_core=7, foo=8),
      (_build_deeply_nested(0, 1, 6, 5, 7), {'foo': 8}))
  self.assertEqual(
      unnest.replace_outermost(
          results, unique_nesting=5, unique_atypical_nesting=6,
          unique_core=7),
      _build_deeply_nested(5, 6, 2, 3, 7))
  self.assertEqual(
      unnest.replace_outermost(
          results, return_unused=True, unique_nesting=5,
          unique_atypical_nesting=6, unique_core=7, foo=8),
      (_build_deeply_nested(5, 6, 2, 3, 7), {'foo': 8}))
  # Leftover kwargs without return_unused=True are an error.
  self.assertRaises(
      ValueError,
      lambda: unnest.replace_innermost(  # pylint: disable=g-long-lambda
          results, unique_core=1, foo=1))
  self.assertRaises(
      ValueError,
      lambda: unnest.replace_outermost(  # pylint: disable=g-long-lambda
          results, unique_core=1, foo=1))
def bootstrap_results(self, init_state):
  """Bootstraps results, setting the momentum unless estimation is deferred."""
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'diagonal_mass_matrix_adaptation',
                          'bootstrap_results')):
    # Step inner results, then bootstrap the adaptation state around them.
    inner_results = self.inner_kernel.bootstrap_results(init_state)
    results = self._bootstrap_from_inner_results(init_state, inner_results)
    if self.num_estimation_steps is not None:
      # The momentum is only updated at the end of the adaptation phase, so
      # there is nothing more to set here.
      return results

    # Set the momentum from the current running-variance estimates.
    variance_diags = [
        part.variance() for part in results.running_variance
    ]
    inner_results = results.inner_results
    batch_shape = ps.shape(
        unnest.get_innermost(inner_results, 'target_log_prob'))
    momentum_distribution = preconditioning_utils.make_momentum_distribution(
        tf.nest.flatten(init_state),
        batch_shape,
        variance_diags,
        shard_axis_names=self.experimental_shard_axis_names)
    inner_results = self.momentum_distribution_setter_fn(
        inner_results, momentum_distribution)

    # Keep `proposed_results` consistent with the update, when present.
    proposed = unnest.get_innermost(
        inner_results, 'proposed_results', default=None)
    if proposed is not None:
      inner_results = unnest.replace_innermost(
          inner_results,
          proposed_results=proposed._replace(
              momentum_distribution=momentum_distribution))

    return results._replace(inner_results=inner_results)
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Advances the chain one step, updating the adapted preconditioner EMAs.

  Builds an inner kernel from the current EMA statistics, steps it, then
  folds the new state into the state and principal-component exponential
  moving averages carried in the kernel results.
  """
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'snaper_hamiltonian_monte_carlo',
                          'one_step')):
    inner_results = previous_kernel_results.inner_results

    # Batch shape of the chain(s), taken from the innermost target_log_prob.
    batch_shape = ps.shape(
        unnest.get_innermost(previous_kernel_results, 'target_log_prob'))
    # Axes over which the EMA statistics are reduced (all batch axes).
    reduce_axes = ps.range(0, ps.size(batch_shape))
    step = inner_results.step
    state_ema_points = previous_kernel_results.state_ema_points

    # Construct the inner kernel using the current EMA estimates of the
    # state mean, variance, and principal component.
    kernel = self._make_kernel(
        batch_shape=batch_shape,
        step=step,
        state_ema_points=state_ema_points,
        state=current_state,
        mean=previous_kernel_results.ema_mean,
        variance=previous_kernel_results.ema_variance,
        principal_component=previous_kernel_results.ema_principal_component,
    )

    # Write the freshly-built momentum distribution into the innermost
    # kernel results before stepping.
    inner_results = unnest.replace_innermost(
        inner_results,
        momentum_distribution=(
            kernel.inner_kernel.parameters['momentum_distribution']),  # pylint: disable=protected-access
    )

    seed = samplers.sanitize_seed(seed)
    state_parts, inner_results = kernel.one_step(
        tf.nest.flatten(current_state),
        inner_results,
        seed=seed,
    )
    state = tf.nest.pack_sequence_as(current_state, state_parts)

    # Update the state mean/variance EMAs with the new state.
    state_ema_points, ema_mean, ema_variance = self._update_state_ema(
        reduce_axes=reduce_axes,
        state=state,
        step=step,
        state_ema_points=state_ema_points,
        ema_mean=previous_kernel_results.ema_mean,
        ema_variance=previous_kernel_results.ema_variance,
    )

    # Update the principal-component EMA with the new state.
    (principal_component_ema_points,
     ema_principal_component) = self._update_principal_component_ema(
         reduce_axes=reduce_axes,
         state=state,
         step=step,
         principal_component_ema_points=(
             previous_kernel_results.principal_component_ema_points),
         ema_principal_component=(
             previous_kernel_results.ema_principal_component),
     )

    kernel_results = previous_kernel_results._replace(
        inner_results=inner_results,
        ema_mean=ema_mean,
        ema_variance=ema_variance,
        state_ema_points=state_ema_points,
        ema_principal_component=ema_principal_component,
        principal_component_ema_points=principal_component_ema_points,
        seed=seed,
    )

    return state, kernel_results
def hmc_like_step_size_setter_fn(kernel_results, new_step_size):
  """Setter for `step_size` so it can be adapted."""
  # Replace the deepest `step_size` field in the (possibly nested) results.
  updates = {'step_size': new_step_size}
  return unnest.replace_innermost(kernel_results, **updates)
def hmc_like_num_leapfrog_steps_setter_fn(kernel_results,
                                          new_num_leapfrog_steps):
  """Setter for `num_leapfrog_steps` so it can be adapted."""
  # Replace the deepest `num_leapfrog_steps` field in the nested results.
  updates = {'num_leapfrog_steps': new_num_leapfrog_steps}
  return unnest.replace_innermost(kernel_results, **updates)