def bootstrap_results(self, init_state):
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'diagonal_mass_matrix_adaptation',
                          'bootstrap_results')):
    if isinstance(self.initial_running_variance,
                  sample_stats.RunningVariance):
      variance_parts = [self.initial_running_variance]
    else:
      variance_parts = self.initial_running_variance
    diags = [variance_part.variance() for variance_part in variance_parts]

    # Step inner results.
    inner_results = self.inner_kernel.bootstrap_results(init_state)

    # Set the momentum.
    batch_ndims = ps.rank(
        unnest.get_innermost(inner_results, 'target_log_prob'))
    init_state_parts = tf.nest.flatten(init_state)
    momentum_distribution = _make_momentum_distribution(
        diags, init_state_parts, batch_ndims)
    inner_results = self.momentum_distribution_setter_fn(
        inner_results, momentum_distribution)
    proposed = unnest.get_innermost(
        inner_results, 'proposed_results', default=None)
    if proposed is not None:
      proposed = proposed._replace(
          momentum_distribution=momentum_distribution)
      inner_results = unnest.replace_innermost(
          inner_results, proposed_results=proposed)
    return DiagonalMassMatrixAdaptationResults(
        inner_results=inner_results,
        running_variance=variance_parts)
def test_atypical_nesting(self):
  results = FakeAtypicalNestingResults(
      unique_atypical_nesting=0,
      atypical_inner_results=FakeResults(1))
  self.assertTrue(unnest.has_nested(results, 'unique_atypical_nesting'))
  self.assertTrue(unnest.has_nested(results, 'unique_core'))
  self.assertFalse(hasattr(results, 'unique_core'))
  self.assertFalse(unnest.has_nested(results, 'foo'))
  self.assertEqual(
      results.unique_atypical_nesting,
      unnest.get_innermost(results, 'unique_atypical_nesting'))
  self.assertEqual(
      results.unique_atypical_nesting,
      unnest.get_outermost(results, 'unique_atypical_nesting'))
  self.assertEqual(results.atypical_inner_results.unique_core,
                   unnest.get_innermost(results, 'unique_core'))
  self.assertEqual(results.atypical_inner_results.unique_core,
                   unnest.get_outermost(results, 'unique_core'))
  self.assertRaises(AttributeError,
                    lambda: unnest.get_innermost(results, 'foo'))
  self.assertRaises(AttributeError,
                    lambda: unnest.get_outermost(results, 'foo'))
  self.assertIs(unnest.get_innermost(results, 'foo', SINGLETON), SINGLETON)
  self.assertIs(unnest.get_outermost(results, 'foo', SINGLETON), SINGLETON)
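# The test fixtures used above and below are defined elsewhere in the test
# module. A minimal sketch of what they plausibly look like (the exact
# definitions are assumptions): `unnest` walks any field whose name ends in
# `inner_results`, which is what `atypical_inner_results` exercises.
import collections

FakeResults = collections.namedtuple('FakeResults', ['unique_core'])
FakeAtypicalNestingResults = collections.namedtuple(
    'FakeAtypicalNestingResults',
    ['unique_atypical_nesting', 'atypical_inner_results'])
SINGLETON = object()  # Sentinel passed as the `default` argument.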
def hmc_like_proposed_velocity_getter_fn(kernel_results):
  """Getter for `proposed_velocity` so it can be inspected."""
  # TODO(b/169898033): This only works with the standard kinetic energy term.
  proposed_velocity = unnest.get_innermost(kernel_results, 'final_momentum')
  proposed_state = unnest.get_innermost(kernel_results, 'proposed_state')
  # proposed_velocity has the wrong structure when state is a scalar.
  return tf.nest.pack_sequence_as(proposed_state,
                                  tf.nest.flatten(proposed_velocity))
def test_change_core_params(self):
  builder = kernel_builder.KernelBuilder.make(lambda x: x)
  builder = builder.hmc(num_leapfrog_steps=3)
  kernel = builder.build()
  found_steps = unnest.get_innermost(kernel, 'num_leapfrog_steps')
  self.assertEqual(found_steps, 3)

  builder = builder.hmc(num_leapfrog_steps=5)
  kernel = builder.build()
  found_steps = unnest.get_innermost(kernel, 'num_leapfrog_steps')
  self.assertEqual(found_steps, 5)
def trace_fn(_, pkr):
  energy_diff = pkr.inner_results.inner_results.inner_results.log_accept_ratio
  has_divergence = jnp.abs(energy_diff) > max_energy_diff
  return (
      unnest.get_innermost(pkr, 'target_log_prob'),
      unnest.get_innermost(pkr, 'num_leapfrog_steps'),
      has_divergence,
      energy_diff,
      pkr.inner_results.inner_results.inner_results.log_accept_ratio,
      pkr.inner_results.inner_results.max_trajectory_length,
      unnest.get_innermost(pkr, 'step_size'),
  )
def trace_fn(_, pkr):
  return {
      'step_size': unnest.get_innermost(pkr, 'step_size'),
      'mean_trajectory_length':
          unnest.get_innermost(pkr, 'max_trajectory_length') / 2.,
      'principal_component':
          unnest.get_innermost(pkr, 'ema_principal_component'),
      'variance': unnest.get_innermost(pkr, 'ema_variance'),
      'num_leapfrog_steps': unnest.get_innermost(pkr, 'num_leapfrog_steps'),
  }
def test_flat(self):
  results = FakeResults(0)
  self.assertTrue(unnest.has_nested(results, 'unique_core'))
  self.assertFalse(unnest.has_nested(results, 'foo'))
  self.assertEqual(results.unique_core,
                   unnest.get_innermost(results, 'unique_core'))
  self.assertEqual(results.unique_core,
                   unnest.get_outermost(results, 'unique_core'))
  self.assertRaises(
      AttributeError, lambda: unnest.get_innermost(results, 'foo'))
  self.assertRaises(
      AttributeError, lambda: unnest.get_outermost(results, 'foo'))
  self.assertIs(unnest.get_innermost(results, 'foo', SINGLETON), SINGLETON)
  self.assertIs(unnest.get_outermost(results, 'foo', SINGLETON), SINGLETON)
def _get_field(kernel_results, field_name):
  try:
    return unnest.get_innermost(kernel_results, field_name)
  except AttributeError:
    msg = _kernel_result_not_implemented_message_template.format(
        kernel_results, field_name)
    raise REMCFieldNotFoundError(msg)
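# A minimal usage sketch for `_get_field`, assuming the helper above and its
# error type are in scope (`FakeKernelResults` is a hypothetical stand-in for
# real kernel results):
import collections

FakeKernelResults = collections.namedtuple('FakeKernelResults', ['step_size'])
kr = FakeKernelResults(step_size=0.5)
assert _get_field(kr, 'step_size') == 0.5
# _get_field(kr, 'num_leapfrog_steps')  # would raise REMCFieldNotFoundError.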
def test_supply_single_step_size(self):
  stream = test_util.test_seed_stream()
  jd_model = tfd.JointDistributionNamed({
      'a': tfd.Normal(0., 1.),
      'b': tfd.MultivariateNormalDiag(
          loc=tf.zeros(3), scale_diag=tf.constant([1., 2., 3.]))
  })
  init_step_size = 1.
  _, traced_step_size = self.evaluate(
      tfp.experimental.mcmc.windowed_adaptive_hmc(
          1,
          jd_model,
          num_adaptation_steps=25,
          n_chains=20,
          init_step_size=init_step_size,
          num_leapfrog_steps=5,
          discard_tuning=False,
          trace_fn=lambda *args: unnest.get_innermost(args[-1], 'step_size'),
          seed=stream()))

  self.assertEqual((25 + 1,), traced_step_size.shape)
  self.assertAllClose(1., traced_step_size[0])
def test_supply_full_step_size(self):
  stream = test_util.test_seed_stream()
  jd_model = tfd.JointDistributionNamed({
      'a': tfd.Normal(0., 1.),
      'b': tfd.MultivariateNormalDiag(
          loc=tf.zeros(3), scale_diag=tf.constant([1., 2., 3.]))
  })
  init_step_size = {
      'a': tf.reshape(tf.linspace(1., 2., 20), (20, 1)),
      'b': tf.reshape(tf.linspace(1., 2., 60), (20, 3))
  }
  _, actual_step_size = tfp.experimental.mcmc.windowed_adaptive_hmc(
      1,
      jd_model,
      num_adaptation_steps=100,
      n_chains=20,
      init_step_size=init_step_size,
      num_leapfrog_steps=5,
      discard_tuning=False,
      trace_fn=lambda *args: unnest.get_innermost(args[-1], 'step_size'),
      seed=stream(),
  )
  # Gets a newaxis because step size needs to have an event dimension.
  self.assertAllCloseNested([init_step_size['a'], init_step_size['b']],
                            [j[0] for j in actual_step_size])
def test_supply_single_step_size(self):
  stream = test_util.test_seed_stream()
  jd_model = tfd.JointDistributionNamed({
      'a': tfd.Normal(0., 1.),
      'b': tfd.MultivariateNormalDiag(
          loc=tf.zeros(3), scale_diag=tf.constant([1., 2., 3.]))
  })
  init_step_size = 1.
  _, actual_step_size = tfp.experimental.mcmc.windowed_adaptive_hmc(
      1,
      jd_model,
      num_adaptation_steps=100,
      n_chains=20,
      init_step_size=init_step_size,
      num_leapfrog_steps=5,
      discard_tuning=False,
      trace_fn=lambda *args: unnest.get_innermost(args[-1], 'step_size'),
      seed=stream(),
  )
  actual_step = [j[0] for j in actual_step_size]
  expected_step = [1., 1.]
  self.assertAllCloseNested(expected_step, actual_step)
def hmc_like_log_accept_prob_getter_fn(kernel_results):
  """Getter for the log acceptance probability, clamped to be at most 0."""
  log_accept_ratio = unnest.get_innermost(kernel_results, 'log_accept_ratio')
  # Replace non-finite ratios (e.g. NaN from a diverged trajectory) with -inf
  # so downstream step-size adaptation treats them as certain rejections.
  safe_accept_ratio = tf.where(
      tf.math.is_finite(log_accept_ratio),
      log_accept_ratio,
      tf.constant(-np.inf, dtype=log_accept_ratio.dtype))
  return tf.minimum(safe_accept_ratio, 0.)
def bootstrap_results(self, init_state):
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'diagonal_mass_matrix_adaptation',
                          'bootstrap_results')):
    # Step inner results.
    inner_results = self.inner_kernel.bootstrap_results(init_state)

    # Bootstrap the results.
    results = self._bootstrap_from_inner_results(init_state, inner_results)

    if self.num_estimation_steps is not None:
      # We only update the momentum at the end of the adaptation phase,
      # so we do not need to set the momentum here.
      return results

    # Set the momentum.
    diags = [
        variance_part.variance()
        for variance_part in results.running_variance
    ]
    inner_results = results.inner_results
    batch_shape = ps.shape(
        unnest.get_innermost(inner_results, 'target_log_prob'))
    init_state_parts = tf.nest.flatten(init_state)
    momentum_distribution = preconditioning_utils.make_momentum_distribution(
        init_state_parts,
        batch_shape,
        diags,
        shard_axis_names=self.experimental_shard_axis_names)
    inner_results = self.momentum_distribution_setter_fn(
        inner_results, momentum_distribution)
    proposed = unnest.get_innermost(
        inner_results, 'proposed_results', default=None)
    if proposed is not None:
      proposed = proposed._replace(
          momentum_distribution=momentum_distribution)
      inner_results = unnest.replace_innermost(
          inner_results, proposed_results=proposed)
    results = results._replace(inner_results=inner_results)
    return results
def hmc_like_proposed_velocity_getter_fn(kernel_results):
  """Getter for `proposed_velocity` so it can be inspected."""
  final_momentum = unnest.get_innermost(kernel_results, 'final_momentum')
  proposed_state = unnest.get_innermost(kernel_results, 'proposed_state')
  momentum_distribution = unnest.get_innermost(
      kernel_results, 'momentum_distribution', default=None)
  if momentum_distribution is None:
    proposed_velocity = final_momentum
  else:
    momentum_log_prob = getattr(momentum_distribution,
                                '_log_prob_unnormalized',
                                momentum_distribution.log_prob)
    kinetic_energy_fn = lambda *args: -momentum_log_prob(*args)
    _, proposed_velocity = mcmc_util.maybe_call_fn_and_grads(
        kinetic_energy_fn, final_momentum)
  # proposed_velocity has the wrong structure when state is a scalar.
  return tf.nest.pack_sequence_as(proposed_state,
                                  tf.nest.flatten(proposed_velocity))
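# Background for the getter above: with kinetic energy K(p) = -log q(p) for a
# momentum distribution q, the velocity is the gradient dK/dp. For the
# standard Gaussian momentum with mass matrix M, K(p) = 0.5 * p^T M^-1 p
# (up to a constant), so the velocity is M^-1 p; computing the gradient via
# `maybe_call_fn_and_grads` generalizes this to other momentum distributions.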
def test_deeply_nested(self):
  results = _build_deeply_nested(0, 1, 2, 3, 4)
  self.assertTrue(unnest.has_nested(results, 'unique_nesting'))
  self.assertTrue(unnest.has_nested(results, 'unique_atypical_nesting'))
  self.assertTrue(unnest.has_nested(results, 'unique_core'))
  self.assertFalse(hasattr(results, 'unique_core'))
  self.assertFalse(unnest.has_nested(results, 'foo'))
  self.assertEqual(unnest.get_innermost(results, 'unique_nesting'), 3)
  self.assertEqual(unnest.get_outermost(results, 'unique_nesting'), 0)
  self.assertEqual(
      unnest.get_innermost(results, 'unique_atypical_nesting'), 2)
  self.assertEqual(
      unnest.get_outermost(results, 'unique_atypical_nesting'), 1)
  self.assertEqual(unnest.get_innermost(results, 'unique_core'), 4)
  self.assertEqual(unnest.get_outermost(results, 'unique_core'), 4)
  self.assertRaises(
      AttributeError, lambda: unnest.get_innermost(results, 'foo'))
  self.assertRaises(
      AttributeError, lambda: unnest.get_outermost(results, 'foo'))
  self.assertIs(unnest.get_innermost(results, 'foo', SINGLETON), SINGLETON)
  self.assertIs(unnest.get_outermost(results, 'foo', SINGLETON), SINGLETON)
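# `_build_deeply_nested` is defined elsewhere; a structure consistent with
# the assertions above would be the following (the field layout is an
# assumption, reusing the fixture sketch from earlier):
import collections

FakeNestingResults = collections.namedtuple(
    'FakeNestingResults', ['unique_nesting', 'inner_results'])

def _build_deeply_nested(a, b, c, d, e):
  return FakeNestingResults(
      unique_nesting=a,
      inner_results=FakeAtypicalNestingResults(
          unique_atypical_nesting=b,
          atypical_inner_results=FakeAtypicalNestingResults(
              unique_atypical_nesting=c,
              atypical_inner_results=FakeNestingResults(
                  unique_nesting=d,
                  inner_results=FakeResults(e)))))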
def trace_results_fn(_, results):
  """Packs results into a dictionary."""
  results_dict = {}
  root_results = results.inner_results

  step_size = tf.convert_to_tensor(
      unnest.get_outermost(root_results[0], "step_size"))

  results_dict["hmc"] = {
      "is_accepted": unnest.get_innermost(root_results[0], "is_accepted"),
      "target_log_prob": unnest.get_innermost(
          root_results[0], "target_log_prob"),
      "step_size": step_size,
  }

  def get_move_results(results):
    return {
        "is_accepted": results.is_accepted,
        "target_log_prob": results.accepted_results.target_log_prob,
        "proposed_delta": tf.stack([
            results.accepted_results.m,
            results.accepted_results.t,
            results.accepted_results.delta_t,
            results.accepted_results.x_star,
        ]),
    }

  res1 = root_results[1].inner_results
  results_dict["move/S->E"] = get_move_results(res1[0])
  results_dict["move/E->I"] = get_move_results(res1[1])
  results_dict["occult/S->E"] = get_move_results(res1[2])
  results_dict["occult/E->I"] = get_move_results(res1[3])

  return results_dict
def sample_trace_fn(_, pkr):
  return (
      unnest.get_innermost(pkr, 'target_log_prob'),
      unnest.get_innermost(pkr, 'leapfrogs_taken'),
      unnest.get_innermost(pkr, 'has_divergence'),
      unnest.get_innermost(pkr, 'energy'),
      unnest.get_innermost(pkr, 'log_accept_ratio'),
      unnest.get_innermost(pkr, 'reach_max_depth'),
  )
def default_nuts_trace_fn(state, bijector, is_adapting, pkr):
  """Trace function for `windowed_adaptive_nuts` providing standard diagnostics.

  Specifically, these match up with a number of diagnostics used by ArviZ [1],
  to make diagnostics and analysis easier. The names used follow those used in
  TensorFlow Probability, and will need to be mapped to those used in the
  ArviZ schema.

  References:
    [1]: Kumar, R., Carroll, C., Hartikainen, A., & Martin, O. (2019). ArviZ a
    unified library for exploratory analysis of Bayesian models in Python.
    Journal of Open Source Software, 4(33), 1143.

  Args:
    state: tf.Tensor
      Current sampler state, flattened and unconstrained.
    bijector: tfb.Bijector
      This can be used to untransform the shape to something with the same
      shape as will be returned.
    is_adapting: bool
      Whether this is an adapting step, or may be treated as a valid MCMC
      draw.
    pkr: UncalibratedPreconditionedHamiltonianMonteCarloKernelResults
      Kernel results from this iteration.

  Returns:
    dict with sampler statistics.
  """
  del state, bijector  # Unused
  energy_diff = unnest.get_innermost(pkr, 'log_accept_ratio')
  return {
      'step_size': unnest.get_innermost(pkr, 'step_size'),
      'tune': is_adapting,
      'target_log_prob': unnest.get_innermost(pkr, 'target_log_prob'),
      'diverging': unnest.get_innermost(pkr, 'has_divergence'),
      'accept_ratio': tf.minimum(tf.ones_like(energy_diff),
                                 tf.exp(energy_diff)),
      'variance_scaling':
          unnest.get_innermost(pkr, 'momentum_distribution').variance(),
      'n_steps': unnest.get_innermost(pkr, 'leapfrogs_taken'),
      'is_accepted': unnest.get_innermost(pkr, 'is_accepted')
  }
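# Usage note: this function is meant to be passed as `trace_fn` to
# `tfp.experimental.mcmc.windowed_adaptive_nuts` (see the tests below), which
# invokes it once per draw; each traced quantity then comes back stacked
# along a leading draws axis.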
def test_sequential_step_size(self):
  stream = test_util.test_seed_stream()
  jd_model = tfd.JointDistributionSequential(
      [tfd.HalfNormal(scale=1., name=f'dist_{idx}') for idx in range(4)])
  init_step_size = [1., 2., 3.]
  _, actual_step_size = tfp.experimental.mcmc.windowed_adaptive_nuts(
      1,
      jd_model,
      num_adaptation_steps=100,
      n_chains=20,
      init_step_size=init_step_size,
      discard_tuning=False,
      trace_fn=lambda *args: unnest.get_innermost(args[-1], 'step_size'),
      dist_3=tf.constant(1.),
      seed=stream(),
  )

  self.assertAllCloseNested(init_step_size,
                            [j[0] for j in actual_step_size])
def testTooFewChains(self, use_static_shape):
  state = tf.constant([[0.1, 0.2]])

  def tlp_fn(x):
    return tf1.placeholder_with_default(
        tf.reduce_sum(x, -1), shape=[1] if use_static_shape else [None])

  kernel = tfp.experimental.mcmc.SNAPERHamiltonianMonteCarlo(
      tlp_fn,
      step_size=0.1,
      num_adaptation_steps=2,
      num_mala_steps=100,
      validate_args=True,
  )

  with self.assertRaisesRegex(Exception,
                              'SNAPERHMC requires at least 2 chains'):
    self.evaluate(
        unnest.get_innermost(kernel.bootstrap_results(state),
                             'target_log_prob'))
def hmc_like_step_size_getter_fn(kernel_results):
  """Getter for `step_size` so it can be inspected."""
  return unnest.get_innermost(kernel_results, 'step_size')
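# A sketch of why these getters work for arbitrarily wrapped kernels (the
# namedtuples here are hypothetical stand-ins for real kernel results):
# `unnest.get_innermost` follows `inner_results` chains down to the deepest
# result holding the requested field.
import collections

Inner = collections.namedtuple('Inner', ['step_size'])
Outer = collections.namedtuple('Outer', ['inner_results'])
wrapped = Outer(inner_results=Outer(inner_results=Inner(step_size=0.1)))
assert hmc_like_step_size_getter_fn(wrapped) == 0.1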
def one_step(self, current_state, previous_kernel_results, seed=None):
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'diagonal_mass_matrix_adaptation',
                          'one_step')):
    variance_parts = previous_kernel_results.running_variance
    diags = [variance_part.variance() for variance_part in variance_parts]
    # Set the momentum.
    batch_ndims = ps.rank(
        unnest.get_innermost(previous_kernel_results, 'target_log_prob'))
    state_parts = tf.nest.flatten(current_state)
    new_momentum_distribution = _make_momentum_distribution(
        diags, state_parts, batch_ndims)
    inner_results = self.momentum_distribution_setter_fn(
        previous_kernel_results.inner_results, new_momentum_distribution)

    # Step the inner kernel.
    inner_kwargs = {} if seed is None else dict(seed=seed)
    new_state, new_inner_results = self.inner_kernel.one_step(
        current_state, inner_results, **inner_kwargs)
    new_state_parts = tf.nest.flatten(new_state)
    new_variance_parts = []
    for variance_part, diag, state_part in zip(variance_parts, diags,
                                               new_state_parts):
      # Compute new variance for each variance part, accounting for partial
      # batching of the variance calculation across chains (i.e., some, all,
      # or none of the chains may share the estimated mass matrix).
      #
      # For example, say
      #
      #   state_part has shape     [2, 3, 4] + [5, 6]  (batch + event)
      #   variance_part has shape        [4] + [5, 6]
      #   log_prob has shape       [2, 3, 4]
      #
      # i.e., we have a batch of chains of shape [2, 3, 4], and 4 mass
      # matrices, each being shared across a [2, 3]-batch of chains. Note
      # this division is inferred from the shapes of the state part, the
      # log_prob, and the user-provided initial running variances.
      #
      # Until RunningVariance supports rank > 1 chunking, we need to flatten
      # the states that go into updating the variance estimates. In the above
      # example, `state_part` will be reshaped to `[6, 4, 5, 6]`, and
      # fed to `RunningVariance.update(state_part, axis=0)`, recording
      # 6 new observations in the running variance calculation.
      # `RunningVariance.variance()` will then be of shape `[4, 5, 6]`, and
      # the resulting momentum distribution will have batch shape of
      # `[2, 3, 4]` and event_shape of `[5, 6]`, matching the state_part.
      state_rank = ps.rank(state_part)
      variance_rank = ps.rank(diag)
      num_reduce_dims = state_rank - variance_rank

      state_part_shape = ps.shape(state_part)
      # This reshape adds a 1 when num_reduce_dims == 0, and collapses all
      # the lead dimensions to a single one otherwise.
      reshaped_state = ps.reshape(
          state_part,
          ps.concat(
              [[ps.reduce_prod(state_part_shape[:num_reduce_dims])],
               state_part_shape[num_reduce_dims:]], axis=0))

      # The `axis=0` here removes the leading dimension we got from the
      # reshape above, so the new_variance_parts have the correct shape
      # again.
      new_variance_parts.append(
          variance_part.update(reshaped_state, axis=0))

    new_kernel_results = previous_kernel_results._replace(
        inner_results=new_inner_results,
        running_variance=new_variance_parts)

    return new_state, new_kernel_results
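# A standalone numeric sketch of the reshape described in the comment above,
# using plain TensorFlow in eager mode (the shapes are the example's; this
# code is illustrative rather than part of the kernel):
import tensorflow as tf

state_part = tf.zeros([2, 3, 4, 5, 6])  # batch [2, 3, 4] + event [5, 6]
variance = tf.zeros([4, 5, 6])          # 4 mass matrices shared over [2, 3]
num_reduce_dims = len(state_part.shape) - len(variance.shape)  # == 2
reshaped = tf.reshape(
    state_part,
    tf.concat([[2 * 3], tf.shape(state_part)[num_reduce_dims:]], axis=0))
assert reshaped.shape == [6, 4, 5, 6]  # 6 observations per variance update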
def hmc_like_proposed_state_getter_fn(kernel_results):
  """Getter for `proposed_state` so it can be inspected."""
  return unnest.get_innermost(kernel_results, 'proposed_state')
def hmc_like_num_leapfrog_steps_getter_fn(kernel_results):
  """Getter for `num_leapfrog_steps` so it can be inspected."""
  return unnest.get_innermost(kernel_results, 'num_leapfrog_steps')
def hmc_like_momentum_distribution_getter_fn(kernel_results):
  """Getter for `momentum_distribution` so it can be updated."""
  return unnest.get_innermost(kernel_results, 'momentum_distribution')
def one_step(self, current_state, previous_kernel_results, seed=None):
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'snaper_hamiltonian_monte_carlo',
                          'one_step')):
    inner_results = previous_kernel_results.inner_results
    batch_shape = ps.shape(
        unnest.get_innermost(previous_kernel_results, 'target_log_prob'))
    reduce_axes = ps.range(0, ps.size(batch_shape))
    step = inner_results.step
    state_ema_points = previous_kernel_results.state_ema_points

    kernel = self._make_kernel(
        batch_shape=batch_shape,
        step=step,
        state_ema_points=state_ema_points,
        state=current_state,
        mean=previous_kernel_results.ema_mean,
        variance=previous_kernel_results.ema_variance,
        principal_component=previous_kernel_results.ema_principal_component,
    )

    inner_results = unnest.replace_innermost(
        inner_results,
        momentum_distribution=(
            kernel.inner_kernel.parameters['momentum_distribution']),  # pylint: disable=protected-access
    )

    seed = samplers.sanitize_seed(seed)
    state_parts, inner_results = kernel.one_step(
        tf.nest.flatten(current_state),
        inner_results,
        seed=seed,
    )
    state = tf.nest.pack_sequence_as(current_state, state_parts)

    state_ema_points, ema_mean, ema_variance = self._update_state_ema(
        reduce_axes=reduce_axes,
        state=state,
        step=step,
        state_ema_points=state_ema_points,
        ema_mean=previous_kernel_results.ema_mean,
        ema_variance=previous_kernel_results.ema_variance,
    )

    (principal_component_ema_points,
     ema_principal_component) = self._update_principal_component_ema(
         reduce_axes=reduce_axes,
         state=state,
         step=step,
         principal_component_ema_points=(
             previous_kernel_results.principal_component_ema_points),
         ema_principal_component=(
             previous_kernel_results.ema_principal_component),
     )

    kernel_results = previous_kernel_results._replace(
        inner_results=inner_results,
        ema_mean=ema_mean,
        ema_variance=ema_variance,
        state_ema_points=state_ema_points,
        ema_principal_component=ema_principal_component,
        principal_component_ema_points=principal_component_ema_points,
        seed=seed,
    )

    return state, kernel_results