def test_diagonal_mass_matrix_no_distribute(self):
  """Unsharded sanity check: adaptation statistics absorb EchoKernel draws.

  The adaptation starts from prior statistics equivalent to 10 samples with
  zero mean and unit variance; 10 more draws are fed through `EchoKernel`, and
  the merged running mean / sum of squared residuals are compared to a NumPy
  computation of the same pooled quantities.
  """
  echo = EchoKernel()
  prior_stats = tfp.experimental.stats.RunningVariance.from_stats(
      num_samples=10., mean=tf.zeros(3), variance=tf.ones(3))
  kernel = tfp.experimental.mcmc.DiagonalMassMatrixAdaptation(echo, prior_stats)
  initial_state = tf.zeros(3)
  results = kernel.bootstrap_results(initial_state)
  draws = np.random.randn(10, 3).astype(np.float32)

  def _step(carry, draw):
    # Thread the PRNG seed through the scan alongside the kernel results.
    prev_results, carry_seed = carry
    carry_seed, step_seed = samplers.split_seed(carry_seed)
    _, new_results = kernel.one_step(draw, prev_results, seed=step_seed)
    return (new_results, carry_seed)

  start = (results, samplers.sanitize_seed(self.key))
  (results, _), _ = mcmc_util.trace_scan(_step, start, draws, lambda _: ())

  running_variance = results.running_variance[0]
  # 20 = 10 prior pseudo-samples + 10 draws; the prior mean is zero.
  emp_mean = draws.sum(axis=0) / 20.
  deviations = draws - emp_mean
  # Pooled residuals: draws about the pooled mean, plus the prior block's
  # shift term (10 * mean**2) and its own residuals (10 * variance 1).
  emp_squared_residuals = (
      np.sum(deviations**2, axis=0) + 10 * emp_mean**2 + 10)
  self.assertAllClose(emp_mean, running_variance.mean)
  self.assertAllClose(emp_squared_residuals,
                      running_variance.sum_squared_residuals)
def testComposite(self):
  """Traced values may include composite tensors (an auto-CT `Normal`)."""
  normal_ct = auto_composite_tensor.auto_composite_tensor(
      tfd.Normal, omit_kwargs=('name',))

  def _accumulate(total, increment):
    return total + increment

  def _snapshot(total):
    doubled = 2 * total
    return [total, doubled, normal_ct(total, 0.1)]

  final_state, trace = util.trace_scan(
      loop_fn=_accumulate,
      initial_state=0.,
      elems=[1., 2.],
      trace_fn=_snapshot)

  self.assertAllClose([], tensorshape_util.as_list(final_state.shape))
  for idx in (0, 1):
    self.assertAllClose([2], tensorshape_util.as_list(trace[idx].shape))
  self.assertAllClose(3, final_state)
  self.assertAllClose([1, 3], trace[0])
  self.assertAllClose([2, 6], trace[1])

  traced_normal = trace[2]
  self.assertIsInstance(traced_normal, tfd.Normal)
  self.assertAllClose([1., 3.], traced_normal.loc)
  self.assertAllClose([0.1, 0.1], traced_normal.scale)
def testTraceCriterion(self, static_length):
  """Only steps whose running total is even are traced."""
  allocation_size = 3 if static_length else None

  def _is_even(total):
    return tf.equal(total % 2, 0)

  scan_outputs = util.trace_scan(
      loop_fn=lambda total, x: total + x,
      initial_state=0,
      elems=list(range(1, 8)),
      trace_fn=lambda total: total / 2,
      trace_criterion_fn=_is_even,
      static_trace_allocation_size=allocation_size)
  final_state, trace = self.evaluate(scan_outputs)

  # Running totals are 1, 3, 6, 10, 15, 21, 28; the even ones are halved.
  self.assertAllClose(sum(range(1, 8)), final_state)
  self.assertAllClose([3, 5, 14], trace)
def testConditionFn(self, static_length):
  """The scan stops early once the running total reaches 9 or more."""
  allocation_size = 4 if static_length else None

  def _should_continue(step, state, num_traced, trace):
    del step, num_traced, trace  # Only the running total matters here.
    return state < 9

  scan_outputs = util.trace_scan(
      loop_fn=lambda total, x: total + x,
      initial_state=0,
      elems=list(range(1, 8)),
      trace_fn=lambda total: total / 2,
      condition_fn=_should_continue,
      static_trace_allocation_size=allocation_size)
  final_state, trace = self.evaluate(scan_outputs)

  # Totals 1, 3, 6, 10 are produced before the condition fails.
  self.assertAllClose(10, final_state)
  self.assertAllClose([.5, 1.5, 3, 5], trace)
def testBasic(self):
  """A two-element scan traces both the total and twice the total."""

  def _accumulate(total, increment):
    return total + increment

  def _snapshot(total):
    return [total, total * 2]

  final_state, trace = util.trace_scan(
      loop_fn=_accumulate,
      initial_state=0,
      elems=[1, 2],
      trace_fn=_snapshot)

  # Static shapes are checked before evaluation.
  self.assertAllClose([], tensorshape_util.as_list(final_state.shape))
  for idx in (0, 1):
    self.assertAllClose([2], tensorshape_util.as_list(trace[idx].shape))

  final_state, trace = self.evaluate([final_state, trace])
  self.assertAllClose(3, final_state)
  self.assertAllClose([1, 3], trace[0])
  self.assertAllClose([2, 6], trace[1])
def run(seed):
  """Feeds 10 sharded draws through the adaptation kernel.

  Returns the traced draws and the final kernel results.
  """
  # One seed for the initial state, ten for the scan steps.
  dist_seed, *step_seeds = samplers.split_seed(seed, 11)
  dist = tfp_dist.Sharded(
      tfd.Sample(tfd.Normal(0., 1.), 3), shard_axis_name=self.axis_name)
  initial_state = dist.sample(seed=dist_seed)
  kernel = tfp.experimental.mcmc.DiagonalMassMatrixAdaptation(
      EchoKernel(),
      tfp.experimental.stats.RunningVariance.from_stats(
          num_samples=10., mean=tf.zeros(3), variance=tf.ones(3)))
  initial_results = kernel.bootstrap_results(initial_state)

  def _step(carry, step_seed):
    # The previous draw in `carry[0]` is only a placeholder for tracing.
    prev_results = carry[1]
    sample_seed, kernel_seed = samplers.split_seed(step_seed)
    new_draw = dist.sample(seed=sample_seed)
    _, new_results = kernel.one_step(new_draw, prev_results, seed=kernel_seed)
    return new_draw, new_results

  start = (tf.zeros(dist.event_shape), initial_results)
  (_, final_results), draws = mcmc_util.trace_scan(
      _step, start, step_seeds, lambda carry: carry[0])
  return draws, final_results
def sample_chain(
    num_results,
    current_state,
    previous_kernel_results=None,
    kernel=None,
    num_burnin_steps=0,
    num_steps_between_results=0,
    trace_fn=lambda current_state, kernel_results: kernel_results,
    return_final_kernel_results=False,
    parallel_iterations=10,
    name=None,
):
  """Implements Markov chain Monte Carlo via repeated `TransitionKernel` steps.

  This function samples from a Markov chain at `current_state` whose
  stationary distribution is governed by the supplied `TransitionKernel`
  instance (`kernel`).

  This function can sample from multiple chains, in parallel. (Whether or not
  there are multiple chains is dictated by the `kernel`.)

  The `current_state` can be represented as a single `Tensor` or a `list` of
  `Tensors` which collectively represent the current state.

  Since MCMC states are correlated, it is sometimes desirable to produce
  additional intermediate states, and then discard them, ending up with a set
  of states with decreased autocorrelation. See [Owen (2017)][1]. Such
  "thinning" is made possible by setting `num_steps_between_results > 0`. The
  chain then takes `num_steps_between_results` extra steps between the steps
  that make it into the results. The extra steps are never materialized (in
  calls to `sess.run`), and thus do not increase memory requirements.

  Warning: when setting a `seed` in the `kernel`, ensure that `sample_chain`'s
  `parallel_iterations=1`, otherwise results will not be reproducible.

  In addition to returning the chain state, this function supports tracing of
  auxiliary variables used by the kernel. The traced values are selected by
  specifying `trace_fn`. By default, all kernel results are traced but in the
  future the default will be changed to no results being traced, so plan
  accordingly. See below for some examples of this feature.

  Args:
    num_results: Integer number of Markov chain draws.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A `Tensor` or a nested collection of `Tensor`s
      representing internal calculations made within the previous call to this
      function (or as returned by `bootstrap_results`). If `None`, they are
      computed via `kernel.bootstrap_results(current_state)`.
    kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one step
      of the Markov chain. Must be provided; `None` raises `ValueError`.
    num_burnin_steps: Integer number of chain steps to take before starting to
      collect results.
      Default value: 0 (i.e., no burn-in).
    num_steps_between_results: Integer number of chain steps between collecting
      a result. Only one out of every `num_steps_between_results + 1` steps is
      included in the returned results. The number of returned chain states is
      still equal to `num_results`.
      Default value: 0 (i.e., no thinning).
    trace_fn: A callable that takes in the current chain state and the previous
      kernel results and returns a `Tensor` or a nested collection of `Tensor`s
      that is then traced along with the chain state. If `None`, nothing beyond
      the chain states is traced.
    return_final_kernel_results: If `True`, then the final kernel results are
      returned alongside the chain state and the trace specified by the
      `trace_fn`.
    parallel_iterations: The number of iterations allowed to run in parallel. It
      must be a positive integer. See `tf.while_loop` for more details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "mcmc_sample_chain").

  Returns:
    checkpointable_states_and_trace: if `return_final_kernel_results` is
      `True`. The return value is an instance of
      `CheckpointableStatesAndTrace`.
    all_states: if `return_final_kernel_results` is `False` and `trace_fn` is
      `None`. The return value is a `Tensor` or Python list of `Tensor`s
      representing the state(s) of the Markov chain(s) at each result step. Has
      same shape as input `current_state` but with a prepended
      `num_results`-size dimension.
    states_and_trace: if `return_final_kernel_results` is `False` and
      `trace_fn` is not `None`. The return value is an instance of
      `StatesAndTrace`.

  Raises:
    ValueError: if `kernel` is `None`.

  #### Examples

  ##### Sample from a diagonal-variance Gaussian.

  I.e.,

  ```none
  for i=1..n:
    x[i] ~ MultivariateNormal(loc=0, scale=diag(true_stddev))  # likelihood
  ```

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  dims = 10
  true_stddev = np.sqrt(np.linspace(1., 3., dims))
  likelihood = tfd.MultivariateNormalDiag(loc=0., scale_diag=true_stddev)

  states = tfp.mcmc.sample_chain(
      num_results=1000,
      num_burnin_steps=500,
      current_state=tf.zeros(dims),
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=likelihood.log_prob,
        step_size=0.5,
        num_leapfrog_steps=2),
      trace_fn=None)

  sample_mean = tf.reduce_mean(states, axis=0)
  # ==> approx all zeros

  sample_stddev = tf.sqrt(tf.reduce_mean(
      tf.math.squared_difference(states, sample_mean),
      axis=0))
  # ==> approx equal true_stddev
  ```

  ##### Sampling from factor-analysis posteriors with known factors.

  I.e.,

  ```none
  # prior
  w ~ MultivariateNormal(loc=0, scale=eye(d))
  for i=1..n:
    # likelihood
    x[i] ~ Normal(loc=w^T F[i], scale=1)
  ```

  where `F` denotes factors.

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Specify model.
  def make_prior(dims):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims))

  def make_likelihood(weights, factors):
    return tfd.MultivariateNormalDiag(
        loc=tf.matmul(weights, factors, adjoint_b=True))

  def joint_log_prob(num_weights, factors, x, w):
    return (make_prior(num_weights).log_prob(w) +
            make_likelihood(w, factors).log_prob(x))

  def unnormalized_log_posterior(w):
    # Posterior is proportional to: `p(W, X=x | factors)`.
    return joint_log_prob(num_weights, factors, x, w)

  # Setup data.
  num_weights = 10 # == d
  num_factors = 40 # == n
  num_chains = 100

  weights = make_prior(num_weights).sample(1)
  factors = tf.random.normal([num_factors, num_weights])
  x = make_likelihood(weights, factors).sample()

  # Sample from Hamiltonian Monte Carlo Markov Chain.

  # Get `num_results` samples from `num_chains` independent chains.
  chains_states, kernels_results = tfp.mcmc.sample_chain(
      num_results=1000,
      num_burnin_steps=500,
      current_state=tf.zeros([num_chains, num_weights], name='init_weights'),
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=unnormalized_log_posterior,
        step_size=0.1,
        num_leapfrog_steps=2))

  # Compute sample stats.
  sample_mean = tf.reduce_mean(chains_states, axis=[0, 1])
  # ==> approx equal to weights

  sample_var = tf.reduce_mean(
      tf.math.squared_difference(chains_states, sample_mean),
      axis=[0, 1])
  # ==> less than 1
  ```

  ##### Custom tracing functions.

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  likelihood = tfd.Normal(loc=0., scale=1.)

  def sample_chain(trace_fn):
    return tfp.mcmc.sample_chain(
      num_results=1000,
      num_burnin_steps=500,
      current_state=0.,
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=likelihood.log_prob,
        step_size=0.5,
        num_leapfrog_steps=2),
      trace_fn=trace_fn)

  def trace_log_accept_ratio(states, previous_kernel_results):
    return previous_kernel_results.log_accept_ratio

  def trace_everything(states, previous_kernel_results):
    return previous_kernel_results

  _, log_accept_ratio = sample_chain(trace_fn=trace_log_accept_ratio)
  _, kernel_results = sample_chain(trace_fn=trace_everything)

  acceptance_prob = tf.math.exp(tf.minimum(log_accept_ratio, 0.))
  # Equivalent to, but more efficient than:
  acceptance_prob = tf.math.exp(tf.minimum(
      kernel_results.log_accept_ratio, 0.))
  ```

  #### References

  [1]: Art B. Owen. Statistically efficient thinning of a Markov chain sampler.
       _Technical Report_, 2017.
       http://statweb.stanford.edu/~owen/reports/bestthinning.pdf
  """
  # `kernel` has a `None` default in the signature, but every code path below
  # dereferences it; fail early with a clear message instead of an opaque
  # `AttributeError` on `None.is_calibrated`.
  if kernel is None:
    raise ValueError(
        "`kernel` must be a `tfp.mcmc.TransitionKernel`; got `None`.")
  if not kernel.is_calibrated:
    warnings.warn("supplied `TransitionKernel` is not calibrated. Markov "
                  "chain may not converge to intended target distribution.")
  with tf.name_scope(name or "mcmc_sample_chain"):
    num_results = tf.convert_to_tensor(
        num_results, dtype=tf.int32, name="num_results")
    num_burnin_steps = tf.convert_to_tensor(
        num_burnin_steps, dtype=tf.int32, name="num_burnin_steps")
    num_steps_between_results = tf.convert_to_tensor(
        num_steps_between_results,
        dtype=tf.int32,
        name="num_steps_between_results")
    current_state = tf.nest.map_structure(
        lambda x: tf.convert_to_tensor(x, name="current_state"),
        current_state)
    if previous_kernel_results is None:
      previous_kernel_results = kernel.bootstrap_results(current_state)

    if trace_fn is None:
      # It simplifies the logic to use a dummy function here.
      trace_fn = lambda *args: ()
      no_trace = True
    else:
      no_trace = False
    # `__defaults__[4]` is the default `trace_fn` (5th defaulted parameter);
    # warn only when the caller relied on that implicit default.
    if trace_fn is sample_chain.__defaults__[4]:
      warnings.warn("Tracing all kernel results by default is deprecated. Set "
                    "the `trace_fn` argument to None (the future default "
                    "value) or an explicit callback that traces the values "
                    "you are interested in.")

    def _trace_scan_fn(state_and_results, num_steps):
      # Advance the chain `num_steps` kernel steps between two traced results.
      next_state, current_kernel_results = mcmc_util.smart_for_loop(
          loop_num_iter=num_steps,
          body_fn=kernel.one_step,
          initial_loop_vars=list(state_and_results),
          parallel_iterations=parallel_iterations)
      return next_state, current_kernel_results

    (_, final_kernel_results), (all_states, trace) = mcmc_util.trace_scan(
        loop_fn=_trace_scan_fn,
        initial_state=(current_state, previous_kernel_results),
        # Step counts per traced result: the first result is preceded by
        # burn-in plus one step, every later one by thinning plus one step.
        elems=tf.one_hot(
            indices=0,
            depth=num_results,
            on_value=1 + num_burnin_steps,
            off_value=1 + num_steps_between_results,
            dtype=tf.int32),
        # pylint: disable=g-long-lambda
        trace_fn=lambda state_and_results: (state_and_results[0],
                                            trace_fn(*state_and_results)),
        # pylint: enable=g-long-lambda
        parallel_iterations=parallel_iterations)

    if return_final_kernel_results:
      return CheckpointableStatesAndTrace(
          all_states=all_states,
          trace=trace,
          final_kernel_results=final_kernel_results)
    else:
      if no_trace:
        return all_states
      else:
        return StatesAndTrace(all_states=all_states, trace=trace)
def particle_filter(observations,
                    initial_state_prior,
                    transition_fn,
                    observation_fn,
                    num_particles,
                    initial_state_proposal=None,
                    proposal_fn=None,
                    resample_fn=weighted_resampling.resample_systematic,
                    resample_criterion_fn=smc_kernel.ess_below_threshold,
                    unbiased_gradients=True,
                    rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
                    num_transitions_per_observation=1,
                    trace_fn=_default_trace_fn,
                    trace_criterion_fn=_always_trace,
                    static_trace_allocation_size=None,
                    parallel_iterations=1,
                    seed=None,
                    name=None):  # pylint: disable=g-doc-args
  """Samples a series of particles representing filtered latent states.

  The particle filter samples from the sequence of "filtering" distributions
  `p(state[t] | observations[:t])` over latent states: at each point in time,
  this is the distribution conditioned on all observations *up to that time*.
  Because particles may be resampled, a particle at time `t` may differ from
  the particle with the same index at time `t + 1`. To reconstruct
  trajectories by tracing back through the resampling process, see
  `tfp.mcmc.experimental.reconstruct_trajectories`.

  ${particle_filter_arg_str}
    trace_fn: Python `callable` defining the values to be traced at each step,
      with signature `traced_values = trace_fn(weighted_particles, results)`
      in which the first argument is an instance of
      `tfp.experimental.mcmc.WeightedParticles` and the second an instance of
      `SequentialMonteCarloResults` tuple, and the return value is a structure
      of `Tensor`s.
      Default value: `lambda s, r: (s.particles, s.log_weights,
      r.parent_indices, r.incremental_log_marginal_likelihood)`
    trace_criterion_fn: optional Python `callable` with signature
      `trace_this_step = trace_criterion_fn(weighted_particles, results)`
      taking the same arguments as `trace_fn` and returning a boolean `Tensor`.
      If `None`, only values from the final step are returned.
      Default value: `lambda *_: True` (trace every step).
    static_trace_allocation_size: Optional Python `int` size of trace to
      allocate statically. This should be an upper bound on the number of steps
      traced and is used only when the length cannot be statically inferred
      (for example, if a `trace_criterion_fn` is specified).
      It is primarily intended for contexts where static shapes are required,
      such as in XLA-compiled code.
      Default value: `None`.
    parallel_iterations: Passed to the internal `tf.while_loop`.
      Default value: `1`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'particle_filter'`).
  Returns:
    traced_results: A structure of Tensors as returned by `trace_fn`. If
      `trace_criterion_fn==None`, this is computed from the final step;
      otherwise, each Tensor will have initial dimension `num_steps_traced`
      and stacks the traced results across all steps.

  #### References

  [1] Adam Scibior, Vaden Masrani, and Frank Wood. Differentiable Particle
      Filtering without Modifying the Forward Pass. _arXiv preprint
      arXiv:2106.10314_, 2021. https://arxiv.org/abs/2106.10314
  """
  init_seed, loop_seed = samplers.split_seed(seed, salt='particle_filter')
  with tf.name_scope(name or 'particle_filter'):
    first_observation = tf.nest.flatten(observations)[0]
    num_observation_steps = ps.size0(first_observation)
    num_timesteps = (
        1 + num_transitions_per_observation * (num_observation_steps - 1))

    # With no trace criterion, nothing is traced inside the loop; the final
    # step's values are computed from the loop's final carry instead.
    trace_final_step_only = trace_criterion_fn is None
    if trace_final_step_only:
      static_trace_allocation_size = 0
      trace_criterion_fn = lambda *_: False

    initial_weighted_particles = _particle_filter_initial_weighted_particles(
        observations=observations,
        observation_fn=observation_fn,
        initial_state_prior=initial_state_prior,
        initial_state_proposal=initial_state_proposal,
        num_particles=num_particles,
        seed=init_seed)
    propose_and_update_log_weights_fn = (
        _particle_filter_propose_and_update_log_weights_fn(
            observations=observations,
            transition_fn=transition_fn,
            proposal_fn=proposal_fn,
            observation_fn=observation_fn,
            num_transitions_per_observation=num_transitions_per_observation))

    kernel = smc_kernel.SequentialMonteCarlo(
        propose_and_update_log_weights_fn=propose_and_update_log_weights_fn,
        resample_fn=resample_fn,
        resample_criterion_fn=resample_criterion_fn,
        unbiased_gradients=unbiased_gradients)

    # `trace_scan` is used instead of `sample_chain` because the latter would
    # force tracing of the state history (with or without thinning), which is
    # not always appropriate.
    def _advance(carry, unused_elem):
      # The carry is (seed, weighted particles, kernel results).
      step_seed, particles, kernel_results = carry
      kernel_seed, following_seed = samplers.split_seed(step_seed)
      new_particles, new_results = kernel.one_step(
          particles, kernel_results, seed=kernel_seed)
      return following_seed, new_particles, new_results

    def _trace_carry(carry):
      return trace_fn(*carry[1:])  # Drop the seed before tracing.

    def _should_trace(carry):
      return trace_criterion_fn(*carry[1:])  # Drop the seed before testing.

    initial_results = kernel.bootstrap_results(initial_weighted_particles)
    final_carry, traced_results = mcmc_util.trace_scan(
        loop_fn=_advance,
        initial_state=(loop_seed, initial_weighted_particles, initial_results),
        elems=tf.ones([num_timesteps]),
        trace_fn=_trace_carry,
        trace_criterion_fn=_should_trace,
        static_trace_allocation_size=static_trace_allocation_size,
        parallel_iterations=parallel_iterations)

    if trace_final_step_only:
      return trace_fn(*final_carry[1:])
    return traced_results