def __init__(
    self,
    target_log_prob_fn,
    topology,
    cumulative_event_offset,
    nmax,
    t_range=None,
    seed=None,
    name=None,
):
    """Configure an uncalibrated occult update step.

    :param target_log_prob_fn: the log density of the target distribution
    :param topology: kept on the instance as ``tx_topology`` and recorded
                     in the parameter dictionary
    :param cumulative_event_offset: kept on the instance as
                                    ``initial_state``
    :param nmax: recorded in the parameter dictionary
    :param t_range: a tuple containing earliest and latest times between
                    which to update occults.
    :param seed: a random seed
    :param name: the name of the update step
    """
    self._name = name
    self._seed_stream = SeedStream(seed, salt="UncalibratedOccultUpdate")

    # `cumulative_event_offset` is not part of the recorded parameters; it
    # is only reachable through `self.initial_state`.
    self._parameters = {
        "target_log_prob_fn": target_log_prob_fn,
        "topology": topology,
        "nmax": nmax,
        "t_range": t_range,
        "seed": seed,
        "name": name,
    }

    self.initial_state = cumulative_event_offset
    self.tx_topology = topology
def _apply_variational_kernel(self, inputs):
  """Applies the kernel posterior to `inputs` with random sign perturbations.

  The output is `inputs @ loc` plus a perturbation term built from a
  zero-mean copy of the kernel posterior and two Rademacher sign tensors.

  Args:
    inputs: input `Tensor`; its last dimension is contracted against the
      kernel.

  Returns:
    The sum of the mean output and the perturbed output.

  Raises:
    TypeError: if `self.kernel_posterior` is not an `Independent`
      distribution wrapping a `Normal`.
  """
  posterior = self.kernel_posterior
  is_independent_normal = (
      isinstance(posterior, independent_lib.Independent) and
      isinstance(posterior.distribution, normal_lib.Normal))
  if not is_independent_normal:
    raise TypeError(
        '`DenseFlipout` requires '
        '`kernel_posterior_fn` produce an instance of '
        '`tfd.Independent(tfd.Normal)` '
        '(saw: \"{}\").'.format(posterior.name))

  # Zero-mean distribution sharing the posterior's scale; a tensor drawn from
  # it (via `kernel_posterior_tensor_fn`) is the kernel perturbation.
  normal = posterior.distribution
  self.kernel_posterior_affine = normal_lib.Normal(
      loc=tf.zeros_like(normal.loc), scale=normal.scale)
  self.kernel_posterior_affine_tensor = self.kernel_posterior_tensor_fn(
      self.kernel_posterior_affine)
  self.kernel_posterior_tensor = None

  input_shape = tf.shape(inputs)
  batch_shape = input_shape[:-1]

  # Two seeds are drawn from the stream, in this order: input signs first,
  # output signs second.
  seed_stream = SeedStream(self.seed, salt='DenseFlipout')
  sign_input = random_rademacher(
      input_shape, dtype=inputs.dtype, seed=seed_stream())
  output_shape = tf.concat([batch_shape, tf.expand_dims(self.units, 0)], 0)
  sign_output = random_rademacher(
      output_shape, dtype=inputs.dtype, seed=seed_stream())

  perturbed_inputs = tf.matmul(
      inputs * sign_input, self.kernel_posterior_affine_tensor) * sign_output
  mean_outputs = tf.matmul(inputs, normal.loc)
  return mean_outputs + perturbed_inputs
def iid_sample_fn(*args, **kwargs):
  """Calls `sample_fn` `n` times under `pfor` and unflattens the draws.

  A `seed` keyword, if present, is removed from `kwargs` and handled here:
  a stateful seed is replaced by one value drawn from a `SeedStream`, while
  any other seed is split into `n` seeds, one per loop iteration.
  """
  with tf.name_scope('iid_sample_fn'):
    seed = kwargs.pop('seed', None)
    use_stateful_sampling = samplers.is_stateful_seed(seed)

    if not use_stateful_sampling:
      # A stateless seed must be split into `n` distinct seeds; reusing the
      # same one would just yield `n` copies of a single sample.
      if not JAX_MODE:
        warnings.warn(
            'Saw Tensor seed {}, implying stateless sampling. Autovectorized '
            'functions that use stateless sampling may be quite slow because '
            'the current implementation falls back to an explicit loop. This '
            'will be fixed in the future. For now, you will likely see '
            'better performance from stateful sampling, which you can invoke '
            'by passing a Python `int` seed.'.format(seed))
      seed = samplers.split_seed(seed, n=n, salt='iid_sample_stateless')

      def pfor_loop_body(i):
        with tf.name_scope('iid_sample_fn_stateless_body'):
          return sample_fn(*args, seed=tf.gather(seed, i), **kwargs)
    else:
      kwargs = dict(kwargs, seed=SeedStream(seed, salt='iid_sample')())

      def pfor_loop_body(_):
        with tf.name_scope('iid_sample_fn_stateful_body'):
          return sample_fn(*args, **kwargs)

    draws = parallel_for.pfor(pfor_loop_body, n)
    return tf.nest.map_structure(unflatten, draws, expand_composites=True)
def _filter_one_step(step,
                     observation,
                     previous_particles,
                     transition_fn,
                     observation_fn,
                     proposal_fn,
                     seed=None):
  """Runs one propose / weight / resample cycle of the particle filter.

  Args:
    step: index of the step being computed; particles are proposed from
      step `step - 1`.
    observation: the observation for this step.
    previous_particles: particles from the previous step.
    transition_fn: forwarded to `_propose_with_log_weights`.
    observation_fn: forwarded to `_compute_observation_log_weights`.
    proposal_fn: forwarded to `_propose_with_log_weights`.
    seed: seed used to build the `SeedStream` for this step.

  Returns:
    A tuple `(resampled_particles, resample_indices,
    step_log_marginal_likelihood)`.
  """
  with tf.name_scope('filter_one_step'):
    seed_stream = SeedStream(seed, 'filter_one_step')

    # First seed draw: proposal.
    proposed_particles, proposal_log_weights = _propose_with_log_weights(
        step=step - 1,
        particles=previous_particles,
        transition_fn=transition_fn,
        proposal_fn=proposal_fn,
        seed=seed_stream())

    log_weights = proposal_log_weights + _compute_observation_log_weights(
        step, proposed_particles, observation, observation_fn)

    # Second seed draw: resampling.
    resampled_particles, resample_indices = _resample(
        proposed_particles, log_weights, seed=seed_stream())

    return (resampled_particles,
            resample_indices,
            tfp_math.reduce_logmeanexp(log_weights, axis=-1))
def iid_sample_fn(*args, **kwargs):
  """Draws iid samples from `fn`.

  Calls `sample_fn` `n` times via `parallel_for.pfor` and restructures the
  stacked draws with `unflatten`. A `seed` keyword, if present, is removed
  from `kwargs` and handled here.

  Raises:
    TypeError: re-raised from `SeedStream` when the error message does not
      contain `TENSOR_SEED_MSG_PREFIX`.
  """
  seed = kwargs.pop('seed', None)
  try:
    # Assume that `seed` is a valid stateful seed (Python `int`).
    kwargs = dict(kwargs, seed=SeedStream(seed, salt='iid_sample')())
  except TypeError as e:
    # If a stateless seed arg is passed, split it into `n` different stateless
    # seeds, so that we don't just get a bunch of copies of the same sample.
    if TENSOR_SEED_MSG_PREFIX not in str(e):
      raise
    warnings.warn(
        'Saw non-`int` seed {}, implying stateless sampling. '
        'Autovectorized functions that use stateless sampling '
        'may be quite slow because the current implementation '
        'falls back to an explicit loop. This will be fixed in the '
        'future. For now, you will likely see better performance '
        'from stateful sampling, which you can invoke by passing a '
        'traditional Python `int` seed.'.format(seed))
    seed = samplers.split_seed(seed, n=n, salt='iid_sample_stateless')
    pfor_loop_body = (
        lambda i: sample_fn(*args, seed=tf.gather(seed, i), **kwargs))
  else:
    # Stateful path: every iteration shares the single seed stored in
    # `kwargs`.
    pfor_loop_body = lambda _: sample_fn(*args, **kwargs)

  draws = parallel_for.pfor(pfor_loop_body, n)
  return tf.nest.map_structure(unflatten, draws, expand_composites=True)
def _filter_one_step(step,
                     observation,
                     previous_particles,
                     log_weights,
                     transition_fn,
                     observation_fn,
                     proposal_fn,
                     resample_criterion_fn,
                     seed=None):
  """Runs one propose / reweight / (maybe) resample cycle of the filter.

  Particles are indexed along the last axis of `log_weights`.

  Args:
    step: index of the step being computed; particles are proposed from
      step `step - 1`.
    observation: the observation for this step.
    previous_particles: particles from the previous step.
    log_weights: log weights carried over from the previous step.
    transition_fn: forwarded to `_propose_with_log_weights`.
    observation_fn: forwarded to `_compute_observation_log_weights`.
    proposal_fn: forwarded to `_propose_with_log_weights`.
    resample_criterion_fn: callable applied to the unnormalized log weights;
      its result selects, per batch element, whether the resampled values
      are used.
    seed: seed used to build the `SeedStream` for this step.

  Returns:
    A `ParticleFilterStepResults`.
  """
  with tf.name_scope('filter_one_step'):
    seed_stream = SeedStream(seed, 'filter_one_step')
    num_particles = prefer_static.shape(log_weights)[-1]

    proposed_particles, proposal_log_weights = _propose_with_log_weights(
        step=step - 1,
        particles=previous_particles,
        transition_fn=transition_fn,
        proposal_fn=proposal_fn,
        seed=seed_stream)
    observation_log_weights = _compute_observation_log_weights(
        step, proposed_particles, observation, observation_fn)

    unnormalized_log_weights = (
        log_weights + proposal_log_weights + observation_log_weights)
    step_log_marginal_likelihood = tf.math.reduce_logsumexp(
        unnormalized_log_weights, axis=-1)
    normalized_log_weights = (
        unnormalized_log_weights -
        step_log_marginal_likelihood[..., tf.newaxis])

    # Adaptive resampling. The criterion yields one value per batch element,
    # so a trailing axis is added to broadcast it over particles.
    criterion = resample_criterion_fn(unnormalized_log_weights)
    do_resample = tf.convert_to_tensor(criterion)[..., tf.newaxis]

    # Batch elements may disagree on whether to resample, so resampling is
    # always computed and then selected element-wise below. Without batching
    # a `tf.cond` could skip the unused work, but adaptive resampling is
    # wanted here for statistical rather than computational reasons.
    resampled_particles, resample_indices = _resample(
        proposed_particles, normalized_log_weights, seed=seed_stream)

    # Values to fall back on where no resampling happens: each particle is
    # its own parent and keeps its normalized weight.
    identity_indices = tf.broadcast_to(
        prefer_static.range(num_particles),
        prefer_static.shape(resample_indices))
    uniform_weights = (
        prefer_static.zeros_like(normalized_log_weights) -
        prefer_static.log(num_particles))

    def _choose(if_resampled, if_not_resampled):
      return prefer_static.where(do_resample, if_resampled, if_not_resampled)

    particles, parent_indices, new_log_weights = tf.nest.map_structure(
        _choose,
        (resampled_particles, resample_indices, uniform_weights),
        (proposed_particles, identity_indices, normalized_log_weights))

    return ParticleFilterStepResults(
        particles=particles,
        log_weights=new_log_weights,
        parent_indices=parent_indices,
        step_log_marginal_likelihood=step_log_marginal_likelihood)
def __init__(
    self,
    target_log_prob_fn,
    target_event_id,
    prev_event_id,
    next_event_id,
    initial_state,
    dmax,
    mmax,
    nmax,
    seed=None,
    name=None,
):
    """Configure an uncalibrated random walk for event times.

    :param target_log_prob_fn: the log density of the target distribution
    :param target_event_id: the position in the first dimension of the
                            events tensor that we wish to move
    :param prev_event_id: the position of the previous event in the events
                          tensor
    :param next_event_id: the position of the next event in the events
                          tensor
    :param initial_state: the initial state tensor
    :param dmax: recorded in the parameter dictionary; also the limit of
                 ``self.time_offsets`` (``tf.range(dmax)``, read back
                 through ``self.parameters``)
    :param mmax: recorded in the parameter dictionary
    :param nmax: recorded in the parameter dictionary
    :param seed: a random seed
    :param name: the name of the update step
    """
    self._name = name
    self._seed_stream = SeedStream(seed, salt="UncalibratedEventTimesUpdate")
    self._parameters = {
        "target_log_prob_fn": target_log_prob_fn,
        "target_event_id": target_event_id,
        "prev_event_id": prev_event_id,
        "next_event_id": next_event_id,
        "initial_state": initial_state,
        "dmax": dmax,
        "mmax": mmax,
        "nmax": nmax,
        "seed": seed,
        "name": name,
    }

    # Must come after `_parameters` is populated: `dmax` is looked up via
    # `self.parameters`.
    self.time_offsets = tf.range(self.parameters["dmax"])
    self.tx_topology = TransitionTopology(
        prev_event_id, target_event_id, next_event_id
    )
def _filter_one_step(step,
                     observation,
                     previous_particles,
                     log_weights,
                     transition_fn,
                     observation_fn,
                     proposal_fn,
                     resample_criterion_fn,
                     has_observation=True,
                     seed=None):
  """Advances the particle filter by a single time step.

  Particles are indexed along axis 0 of `log_weights`: `num_particles` is
  read from `shape(log_weights)[0]` and all normalizations reduce over
  axis 0.

  Args:
    step: index of the step being computed; particles are proposed from
      step `step - 1`.
    observation: the observation for this step (ignored when
      `has_observation` is false).
    previous_particles: particles from the previous step.
    log_weights: log weights carried over from the previous step.
    transition_fn: forwarded to `_propose_with_log_weights`.
    observation_fn: forwarded to `_compute_observation_log_weights`.
    proposal_fn: forwarded to `_propose_with_log_weights`.
    resample_criterion_fn: callable applied to the unnormalized log weights;
      its result selects whether the resampled values are used.
    has_observation: whether this step has an observation. If not, the
      observation contributes zero log weight.
      Default value: `True`.
    seed: seed used to build the `SeedStream` for this step.

  Returns:
    A `ParticleFilterStepResults`.
  """
  with tf.name_scope('filter_one_step'):
    seed = SeedStream(seed, 'filter_one_step')
    num_particles = prefer_static.shape(log_weights)[0]

    proposed_particles, proposal_log_weights = _propose_with_log_weights(
        step=step - 1,
        particles=previous_particles,
        transition_fn=transition_fn,
        proposal_fn=proposal_fn,
        seed=seed)
    # Normalize over the particle axis (axis 0). Using the last axis would
    # normalize across a batch dimension whenever batch dimensions exist.
    log_weights = tf.nn.log_softmax(proposal_log_weights + log_weights,
                                    axis=0)

    # If this step has an observation, compute its weights and marginal
    # likelihood (and otherwise, leave weights unchanged).
    observation_log_weights = prefer_static.cond(
        has_observation,
        lambda: prefer_static.broadcast_to(  # pylint: disable=g-long-lambda
            _compute_observation_log_weights(
                step, proposed_particles, observation, observation_fn),
            prefer_static.shape(log_weights)),
        lambda: tf.zeros_like(log_weights))

    unnormalized_log_weights = log_weights + observation_log_weights
    step_log_marginal_likelihood = tf.math.reduce_logsumexp(
        unnormalized_log_weights, axis=0)
    log_weights = (unnormalized_log_weights - step_log_marginal_likelihood)

    # Adaptive resampling: resample particles iff the specified criterion is
    # met.
    do_resample = resample_criterion_fn(unnormalized_log_weights)

    # Some batch elements may require resampling and others not, so
    # we first do the resampling for all elements, then select whether to use
    # the resampled values for each batch element according to
    # `do_resample`. If there were no batching, we might prefer to use
    # `tf.cond` to avoid the resampling computation on steps where it's not
    # needed---but we're ultimately interested in adaptive resampling
    # for statistical (not computational) purposes, so this isn't a dealbreaker.
    resampled_particles, resample_indices = _resample(
        proposed_particles, log_weights, resample_independent, seed=seed)

    uniform_weights = (prefer_static.zeros_like(log_weights) -
                       prefer_static.log(num_particles))
    (resampled_particles,
     resample_indices,
     log_weights) = tf.nest.map_structure(
         lambda r, p: prefer_static.where(do_resample, r, p),
         (resampled_particles, resample_indices, uniform_weights),
         (proposed_particles,
          _dummy_indices_like(resample_indices),
          log_weights))

    return ParticleFilterStepResults(
        particles=resampled_particles,
        log_weights=log_weights,
        parent_indices=resample_indices,
        step_log_marginal_likelihood=step_log_marginal_likelihood)
def particle_filter(
    observations,
    initial_state_prior,
    transition_fn,
    observation_fn,
    num_particles,
    initial_state_proposal=None,
    proposal_fn=None,
    resample_criterion_fn=ess_below_threshold,
    rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
    num_transitions_per_observation=1,
    num_steps_state_history_to_pass=None,
    num_steps_observation_history_to_pass=None,
    seed=None,
    name=None):  # pylint: disable=g-doc-args
  """Samples a series of particles representing filtered latent states.

  The particle filter samples from the sequence of "filtering" distributions
  `p(state[t] | observations[:t])` over latent states: at each point in time,
  this is the distribution conditioned on all observations *up to that time*.
  Because particles may be resampled, a particle at time `t` may be different
  from the particle with the same index at time `t + 1`. To reconstruct
  trajectories by tracing back through the resampling process, see
  `tfp.experimental.mcmc.reconstruct_trajectories`.

  ${particle_filter_arg_str}

  Returns:
    particles: a (structure of) Tensor(s) matching the latent
      state, each of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing (possibly weighted) samples from the series of filtering
      distributions `p(latent_states[t] | observations[:t])`.
    log_weights: `float` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`, such that
      `log_weights[t, :]` are the logarithms of normalized importance weights
      (such that `exp(reduce_logsumexp(log_weights[t], axis=0)) == 1.`) of
      the particles at time `t`. These may be used in conjunction with
      `particles` to compute expectations under the series of filtering
      distributions.
    parent_indices: `int` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`,
      such that `parent_indices[t, k]` gives the index of the particle at
      time `t - 1` that the `k`th particle at time `t` is immediately descended
      from. See also
      `tfp.experimental.mcmc.reconstruct_trajectories`.
    step_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each observed timestep `t`.
      Note that (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  ${non_markovian_specification_str}
  """
  # NOTE(review): this same `SeedStream` object (not a fresh `seed()` draw) is
  # passed to every `_filter_one_step` call below -- confirm the callee
  # derives distinct seeds from it.
  seed = SeedStream(seed, 'particle_filter')
  with tf.name_scope(name or 'particle_filter'):
    # The leading axis of the first flattened observation component is taken
    # as the number of observation steps.
    num_observation_steps = prefer_static.shape(
        tf.nest.flatten(observations)[0])[0]
    # One filter step for the first observation, then
    # `num_transitions_per_observation` steps for each later observation.
    num_timesteps = (
        1 + num_transitions_per_observation * (num_observation_steps - 1))

    # If no criterion is specified, default is to resample at every step.
    if not resample_criterion_fn:
      resample_criterion_fn = lambda _: True

    # Dress up the prior and prior proposal as a fake `transition_fn` and
    # `proposal_fn` respectively.
    prior_fn = lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
        initial_state_prior, num_particles)
    prior_proposal_fn = (
        None if initial_state_proposal is None
        else lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
            initial_state_proposal, num_particles))

    # Initially the particles all have the same weight, `1. / num_particles`.
    broadcast_batch_shape = tf.convert_to_tensor(
        functools.reduce(
            prefer_static.broadcast_shape,
            tf.nest.flatten(initial_state_prior.batch_shape_tensor()),
            []), dtype=tf.int32)
    # NOTE(review): the weights are created as float32 regardless of the
    # prior's dtype -- confirm this is intended.
    log_uniform_weights = prefer_static.zeros(
        prefer_static.concat([
            [num_particles],
            broadcast_batch_shape], axis=0),
        dtype=tf.float32) - prefer_static.log(num_particles)

    # Initialize from the prior, and incorporate the first observation.
    initial_step_results = _filter_one_step(
        step=0,
        # `previous_particles` at the first step is a dummy quantity, used only
        # to convey state structure and num_particles to an optional
        # proposal fn.
        previous_particles=prior_fn(0, []).sample(),
        log_weights=log_uniform_weights,
        observation=tf.nest.map_structure(
            lambda x: tf.gather(x, 0), observations),
        transition_fn=prior_fn,
        observation_fn=observation_fn,
        proposal_fn=prior_proposal_fn,
        resample_criterion_fn=resample_criterion_fn,
        seed=seed)

    def _loop_body(step,
                   previous_step_results,
                   accumulated_step_results,
                   state_history):
      """Take one step in dynamics and accumulate marginal likelihood."""

      step_has_observation = (
          # The second of these conditions subsumes the first, but both are
          # useful because the first can often be evaluated statically.
          prefer_static.equal(num_transitions_per_observation, 1) |
          prefer_static.equal(step % num_transitions_per_observation, 0))
      observation_idx = step // num_transitions_per_observation
      # NOTE(review): the `step=step` default is not used inside the lambda.
      current_observation = tf.nest.map_structure(
          lambda x, step=step: tf.gather(x, observation_idx), observations)

      # Optional keyword arguments bound onto the user-supplied callables;
      # each is only added when the matching `num_steps_*` argument is truthy.
      history_to_pass_into_fns = {}
      if num_steps_observation_history_to_pass:
        history_to_pass_into_fns[
            'observation_history'] = _gather_history(
                observations,
                observation_idx,
                num_steps_observation_history_to_pass)
      if num_steps_state_history_to_pass:
        history_to_pass_into_fns['state_history'] = state_history

      new_step_results = _filter_one_step(
          step=step,
          previous_particles=previous_step_results.particles,
          log_weights=previous_step_results.log_weights,
          observation=current_observation,
          transition_fn=functools.partial(
              transition_fn, **history_to_pass_into_fns),
          observation_fn=functools.partial(
              observation_fn, **history_to_pass_into_fns),
          proposal_fn=(
              None if proposal_fn is None else
              functools.partial(proposal_fn, **history_to_pass_into_fns)),
          resample_criterion_fn=resample_criterion_fn,
          has_observation=step_has_observation,
          seed=seed)

      return _update_loop_variables(
          step, new_step_results, accumulated_step_results, state_history)

    loop_results = tf.while_loop(
        cond=lambda step, *_: step < num_timesteps,
        body=_loop_body,
        loop_vars=_initialize_loop_variables(
            initial_step_results,
            num_steps_state_history_to_pass,
            num_timesteps))

    # The per-step results are held in objects exposing `.stack()`; stacking
    # yields one result per step along a new leading axis.
    results = tf.nest.map_structure(lambda ta: ta.stack(),
                                    loop_results.accumulated_step_results)
    if num_transitions_per_observation != 1:
      # Return a log-prob for each observed step.
      observed_steps = prefer_static.range(
          0, num_timesteps, num_transitions_per_observation)
      results = results._replace(
          step_log_marginal_likelihood=tf.gather(
              results.step_log_marginal_likelihood, observed_steps))
    return results
def infer_trajectories(observations,
                       initial_state_prior,
                       transition_fn,
                       observation_fn,
                       num_particles,
                       initial_state_proposal=None,
                       proposal_fn=None,
                       resample_criterion_fn=ess_below_threshold,
                       rejuvenation_kernel_fn=None,
                       num_transitions_per_observation=1,
                       num_steps_state_history_to_pass=None,
                       num_steps_observation_history_to_pass=None,
                       seed=None,
                       name=None):  # pylint: disable=g-doc-args
  """Use particle filtering to sample from the posterior over trajectories.

  ${particle_filter_arg_str}

  Returns:
    trajectories: a (structure of) Tensor(s) matching the latent
      state, each of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing unbiased samples from the posterior distribution
      `p(latent_states | observations)`.
    step_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each timestep `t`. Note that
      (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  ${non_markovian_specification_str}

  #### Examples

  **Tracking unknown position and velocity**: Let's consider tracking an object
  moving in a one-dimensional space. We'll define a dynamical system
  by specifying an `initial_state_prior`, a `transition_fn`,
  and `observation_fn`.

  The structure of the latent state space is determined by the prior
  distribution. Here, we'll define a state space that includes the object's
  current position and velocity:

  ```python
  initial_state_prior = tfd.JointDistributionNamed({
      'position': tfd.Normal(loc=0., scale=1.),
      'velocity': tfd.Normal(loc=0., scale=0.1)})
  ```

  The `transition_fn` specifies the evolution of the system. It should
  return a distribution over latent states of the same structure as the prior.
  Here, we'll assume that the position evolves according to the velocity,
  with a small random drift, and the velocity also changes slowly, following
  a random drift:

  ```python
  def transition_fn(_, previous_state):
    return tfd.JointDistributionNamed({
        'position': tfd.Normal(
            loc=previous_state['position'] + previous_state['velocity'],
            scale=0.1),
        'velocity': tfd.Normal(loc=previous_state['velocity'], scale=0.01)})
  ```

  The `observation_fn` specifies the process by which the system is observed
  at each time step. Let's suppose we observe only a noisy version of the
  current position.

  ```python
    def observation_fn(_, state):
      return tfd.Normal(loc=state['position'], scale=0.1)
  ```

  Now let's track our object. Suppose we've been given observations
  corresponding to an initial position of `0.4` and constant velocity of `0.01`:

  ```python
  # Generate simulated observations (40 steps of size 0.01, starting at 0.4).
  observed_positions = tfd.Normal(loc=tf.linspace(0.4, 0.79, 40),
                                  scale=0.1).sample()

  # Run particle filtering to sample plausible trajectories.
  (trajectories,  # {'position': [40, 1000], 'velocity': [40, 1000]}
   lps) = tfp.experimental.mcmc.infer_trajectories(
            observations=observed_positions,
            initial_state_prior=initial_state_prior,
            transition_fn=transition_fn,
            observation_fn=observation_fn,
            num_particles=1000)
  ```

  For all `i`, `trajectories['position'][:, i]` is a sample from the
  posterior over position sequences, given the observations:
  `p(state[0:T] | observations[0:T])`. Often, the sampled trajectories
  will be highly redundant in their earlier timesteps, because most
  of the initial particles have been discarded through resampling
  (this problem is known as 'particle degeneracy'; see section 3.5 of
  [Doucet and Johansen][1]).
  In such cases it may be useful to also consider the series of *filtering*
  distributions `p(state[t] | observations[:t])`, in which each latent state
  is inferred conditioned only on observations up to that point in time; these
  may be computed using `tfp.experimental.mcmc.particle_filter`.

  #### References

  [1] Arnaud Doucet and Adam M. Johansen. A tutorial on particle
      filtering and smoothing: Fifteen years later.
      _Handbook of nonlinear filtering_, 12(656-704), 2009.
      https://www.stats.ox.ac.uk/~doucet/doucet_johansen_tutorialPF2011.pdf

  """
  with tf.name_scope(name or 'infer_trajectories') as name:
    # NOTE(review): the `SeedStream` object itself (not a `seed()` draw) is
    # passed both to `particle_filter` and to `sample` below -- confirm both
    # callees accept it and derive distinct seeds.
    seed = SeedStream(seed, 'infer_trajectories')
    (particles,
     log_weights,
     parent_indices,
     step_log_marginal_likelihoods) = particle_filter(
         observations=observations,
         initial_state_prior=initial_state_prior,
         transition_fn=transition_fn,
         observation_fn=observation_fn,
         num_particles=num_particles,
         initial_state_proposal=initial_state_proposal,
         proposal_fn=proposal_fn,
         resample_criterion_fn=resample_criterion_fn,
         rejuvenation_kernel_fn=rejuvenation_kernel_fn,
         num_transitions_per_observation=num_transitions_per_observation,
         num_steps_state_history_to_pass=num_steps_state_history_to_pass,
         num_steps_observation_history_to_pass=(
             num_steps_observation_history_to_pass),
         seed=seed,
         name=name)
    weighted_trajectories = reconstruct_trajectories(particles, parent_indices)

    # Resample all steps of the trajectories using the final weights.
    # The final-step weights have their leading axis moved to the end before
    # being used as the `Categorical` parameters, and `num_particles` indices
    # are drawn.
    resample_indices = categorical.Categorical(
        dist_util.move_dimension(
            log_weights[-1, ...],
            source_idx=0,
            dest_idx=-1)).sample(num_particles, seed=seed)
    # Gather along axis 1 of each trajectory tensor using the drawn indices.
    trajectories = tf.nest.map_structure(
        lambda x: _batch_gather(x, resample_indices, axis=1),
        weighted_trajectories)

    return trajectories, step_log_marginal_likelihoods
def particle_filter(
    observations,
    initial_state_prior,
    transition_fn,
    observation_fn,
    num_particles,
    initial_state_proposal=None,
    proposal_fn=None,
    resample_fn=weighted_resampling.resample_systematic,
    resample_criterion_fn=ess_below_threshold,
    rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
    num_transitions_per_observation=1,
    trace_fn=_default_trace_fn,
    step_indices_to_trace=None,
    seed=None,
    name=None):  # pylint: disable=g-doc-args
  """Samples a series of particles representing filtered latent states.

  The particle filter samples from the sequence of "filtering" distributions
  `p(state[t] | observations[:t])` over latent states: at each point in time,
  this is the distribution conditioned on all observations *up to that time*.
  Because particles may be resampled, a particle at time `t` may be different
  from the particle with the same index at time `t + 1`. To reconstruct
  trajectories by tracing back through the resampling process, see
  `tfp.experimental.mcmc.reconstruct_trajectories`.

  ${particle_filter_arg_str}
    trace_fn: Python `callable` defining the values to be traced at each step.
      It takes a `ParticleFilterStepResults` tuple and returns a structure of
      `Tensor`s. The default function returns
      `(particles, log_weights, parent_indices, step_log_likelihood)`.
    step_indices_to_trace: optional `int` `Tensor` listing, in increasing order,
      the indices of steps at which to record the values traced by `trace_fn`.
      If `None`, the default behavior is to trace at every timestep,
      equivalent to specifying `step_indices_to_trace=tf.range(num_timesteps)`.
    seed: Python `int` seed for random ops.
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'particle_filter'`).
  Returns:
    particles: a (structure of) Tensor(s) matching the latent
      state, each of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing (possibly weighted) samples from the series of filtering
      distributions `p(latent_states[t] | observations[:t])`.
    log_weights: `float` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`, such that
      `log_weights[t, :]` are the logarithms of normalized importance weights
      (such that `exp(reduce_logsumexp(log_weights[t], axis=0)) == 1.`) of
      the particles at time `t`. These may be used in conjunction with
      `particles` to compute expectations under the series of filtering
      distributions.
    parent_indices: `int` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`,
      such that `parent_indices[t, k]` gives the index of the particle at
      time `t - 1` that the `k`th particle at time `t` is immediately descended
      from. See also
      `tfp.experimental.mcmc.reconstruct_trajectories`.
    incremental_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each observed timestep `t`.
      Note that (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  """
  # NOTE(review): this same `SeedStream` object (not a fresh `seed()` draw) is
  # passed to every `_filter_one_step` call below -- confirm the callee
  # derives distinct seeds from it.
  seed = SeedStream(seed, 'particle_filter')
  with tf.name_scope(name or 'particle_filter'):
    num_observation_steps = ps.size0(tf.nest.flatten(observations)[0])
    # One filter step for the first observation, then
    # `num_transitions_per_observation` steps for each later observation.
    num_timesteps = (
        1 + num_transitions_per_observation * (num_observation_steps - 1))

    # If no criterion is specified, default is to resample at every step.
    if not resample_criterion_fn:
      resample_criterion_fn = lambda _: True

    # Canonicalize the list of steps to trace as a rank-1 tensor of (sorted)
    # positive integers. E.g., `3` -> `[3]`, `[-2, -1]` -> `[N - 2, N - 1]`.
    # `traced_steps_have_rank_zero` is only bound on this path and is only
    # read again under the same `is not None` check at the end.
    if step_indices_to_trace is not None:
      (step_indices_to_trace,
       traced_steps_have_rank_zero) = _canonicalize_steps_to_trace(
           step_indices_to_trace, num_timesteps)

    # Dress up the prior and prior proposal as a fake `transition_fn` and
    # `proposal_fn` respectively.
    prior_fn = lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
        initial_state_prior, num_particles)
    prior_proposal_fn = (
        None if initial_state_proposal is None
        else lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
            initial_state_proposal, num_particles))

    # Initially the particles all have the same weight, `1. / num_particles`.
    broadcast_batch_shape = tf.convert_to_tensor(
        functools.reduce(
            ps.broadcast_shape,
            tf.nest.flatten(initial_state_prior.batch_shape_tensor()),
            []), dtype=tf.int32)
    # NOTE(review): the weights are created as float32 regardless of the
    # prior's dtype -- confirm this is intended.
    log_uniform_weights = ps.zeros(
        ps.concat([
            [num_particles],
            broadcast_batch_shape], axis=0),
        dtype=tf.float32) - ps.log(num_particles)

    # Initialize from the prior and incorporate the first observation.
    dummy_previous_step = ParticleFilterStepResults(
        particles=prior_fn(0, []).sample(),
        log_weights=log_uniform_weights,
        parent_indices=None,
        incremental_log_marginal_likelihood=0.,
        accumulated_log_marginal_likelihood=0.)
    initial_step_results = _filter_one_step(
        step=0,
        # `previous_step_results` at the first step is a dummy quantity, used
        # only to convey state structure and num_particles to an optional
        # proposal fn.
        previous_step_results=dummy_previous_step,
        observation=tf.nest.map_structure(
            lambda x: tf.gather(x, 0), observations),
        transition_fn=prior_fn,
        observation_fn=observation_fn,
        proposal_fn=prior_proposal_fn,
        resample_fn=resample_fn,
        resample_criterion_fn=resample_criterion_fn,
        seed=seed)

    def _loop_body(step,
                   previous_step_results,
                   accumulated_traced_results,
                   num_steps_traced):
      """Take one step in dynamics and accumulate marginal likelihood."""

      step_has_observation = (
          # The second of these conditions subsumes the first, but both are
          # useful because the first can often be evaluated statically.
          ps.equal(num_transitions_per_observation, 1) |
          ps.equal(step % num_transitions_per_observation, 0))
      observation_idx = step // num_transitions_per_observation
      # NOTE(review): the `step=step` default is not used inside the lambda.
      current_observation = tf.nest.map_structure(
          lambda x, step=step: tf.gather(x, observation_idx), observations)

      new_step_results = _filter_one_step(
          step=step,
          previous_step_results=previous_step_results,
          observation=current_observation,
          transition_fn=transition_fn,
          observation_fn=observation_fn,
          proposal_fn=proposal_fn,
          resample_criterion_fn=resample_criterion_fn,
          resample_fn=resample_fn,
          has_observation=step_has_observation,
          seed=seed)

      return _update_loop_variables(
          step=step,
          current_step_results=new_step_results,
          accumulated_traced_results=accumulated_traced_results,
          trace_fn=trace_fn,
          step_indices_to_trace=step_indices_to_trace,
          num_steps_traced=num_steps_traced)

    loop_results = tf.while_loop(
        cond=lambda step, *_: step < num_timesteps,
        body=_loop_body,
        loop_vars=_initialize_loop_variables(
            initial_step_results=initial_step_results,
            num_timesteps=num_timesteps,
            trace_fn=trace_fn,
            step_indices_to_trace=step_indices_to_trace))

    # The traced values are held in objects exposing `.stack()`; stacking
    # yields one entry per traced step along a new leading axis.
    results = tf.nest.map_structure(
        lambda ta: ta.stack(), loop_results.accumulated_traced_results)
    if step_indices_to_trace is not None:
      # If we were passed a rank-0 (single scalar) step to trace, don't
      # return a time axis in the returned results.
      results = ps.cond(
          traced_steps_have_rank_zero,
          lambda: tf.nest.map_structure(lambda x: x[0, ...], results),
          lambda: results)

    return results
def particle_filter(
    observations,
    initial_state_prior,
    transition_fn,
    observation_fn,
    num_particles,
    initial_state_proposal=None,
    proposal_fn=None,
    rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
    seed=None,
    name=None):
  """Runs a particle filter and returns the particles sampled at every step.

  A latent state is a `Tensor` or nested structure of `Tensor`s whose
  structure is defined by `initial_state_prior`. Each of `transition_fn`,
  `observation_fn`, and `proposal_fn`, when given, is called with arguments
  `(step, state)`, where `state` is the latent state at timestep `step`.

  Args:
    observations: a (structure of) Tensors, each of shape
      `concat([[num_timesteps, b1, ..., bN], event_shape])` with optional
      batch dimensions `b1, ..., bN`. The number of timesteps is read from the
      leading dimension of the first flattened component.
    initial_state_prior: a (joint) distribution over the initial latent state,
      with optional batch shape `[b1, ..., bN]`.
    transition_fn: callable returning a (joint) distribution over the next
      latent state.
    observation_fn: callable returning a (joint) distribution over the current
      observation.
    num_particles: `int` `Tensor` number of particles.
    initial_state_proposal: a (joint) distribution over the initial latent
      state, with optional batch shape `[b1, ..., bN]`. If `None`, the initial
      particles are proposed from the `initial_state_prior`.
      Default value: `None`.
    proposal_fn: callable returning a (joint) proposal distribution over the
      next latent state. If `None`, the dynamics model is used (
      `proposal_fn == transition_fn`).
      Default value: `None`.
    rejuvenation_kernel_fn: optional Python `callable` with signature
      `transition_kernel = rejuvenation_kernel_fn(target_log_prob_fn)`
      where `target_log_prob_fn` is a provided callable evaluating
      `p(x[t] | y[t], x[t-1])` at each step `t`, and `transition_kernel`
      should be an instance of `tfp.mcmc.TransitionKernel`. Not yet
      supported: the argument is accepted but currently ignored.
      Default value: `None`.
    seed: Python `int` seed for random ops. It is wrapped in a `SeedStream`
      salted with `'particle_filter'`, from which one seed is drawn per
      filtering step.
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'particle_filter'`).

  Returns:
    particles: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, b1, ..., bN, num_particles], event_shape])`,
      representing unbiased samples from the series of (filtering)
      distributions `p(latent_states[t] | observations[:t])`.
    parent_indices: `int` `Tensor` of shape
      `[num_timesteps, b1, ..., bN, num_particles]`,
      such that `parent_indices[t, k]` gives the index of the particle at
      time `t - 1` that the `k`th particle at time `t` is immediately
      descended from. See also
      `tfp.experimental.mcmc.reconstruct_trajectories`.
    step_log_marginal_likelihoods: float `Tensor` of shape
      `[num_timesteps, b1, ..., bN]`, giving the natural logarithm of an
      unbiased estimate of `p(observations[t] | observations[:t])` at each
      timestep `t`. Note that (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  #### Examples

  **Tracking unknown position and velocity**: Let's consider tracking an
  object moving in a one-dimensional space. We'll define a dynamical system
  by specifying an `initial_state_prior`, a `transition_fn`, and an
  `observation_fn`.

  The structure of the latent state space is determined by the prior
  distribution. Here, we'll define a state space that includes the object's
  current position and velocity:

  ```python
  initial_state_prior = tfd.JointDistributionNamed({
      'position': tfd.Normal(loc=0., scale=1.),
      'velocity': tfd.Normal(loc=0., scale=0.1)})
  ```

  The `transition_fn` specifies the evolution of the system. It should
  return a distribution over latent states of the same structure as the
  prior. Here, we'll assume that the position evolves according to the
  velocity, with a small random drift, and the velocity also changes slowly,
  following a random drift:

  ```python
  def transition_fn(_, previous_state):
    return tfd.JointDistributionNamed({
        'position': tfd.Normal(
            loc=previous_state['position'] + previous_state['velocity'],
            scale=0.1),
        'velocity': tfd.Normal(loc=previous_state['velocity'], scale=0.01)})
  ```

  The `observation_fn` specifies the process by which the system is observed
  at each time step. Let's suppose we observe only a noisy version of the
  current position.

  ```python
  def observation_fn(_, state):
    return tfd.Normal(loc=state['position'], scale=0.1)
  ```

  Now let's track our object. Suppose we've been given observations
  corresponding to an initial position of `0.4` and constant velocity of
  `0.01`:

  ```python
  # Generate simulated observations: 40 steps, 0.01 apart, starting at 0.4.
  observed_positions = tfd.Normal(loc=tf.linspace(0.4, 0.79, 40),
                                  scale=0.1).sample()

  # Run particle filtering.
  (particles,       # {'position': [40, 1000], 'velocity': [40, 1000]}
   parent_indices,  # [40, 1000]
   _) = tfp.experimental.mcmc.particle_filter(
          observations=observed_positions,
          initial_state_prior=initial_state_prior,
          transition_fn=transition_fn,
          observation_fn=observation_fn,
          num_particles=1000)
  ```

  The particle filter samples from the "filtering" distribution over latent
  states: at each point in time, this is the distribution conditioned on all
  observations *up to that time*. For example,
  `particles['position'][t]` contains `num_particles` samples from the
  distribution `p(position[t] | observed_positions[:t])`. Because
  particles may be resampled, there is no relationship between a particle
  at time `t` and the particle with the same index at time `t + 1`.

  We may, however, trace back through the resampling steps to reconstruct
  samples of entire latent trajectories.

  ```python
  trajectories = tfp.experimental.mcmc.reconstruct_trajectories(
      particles, parent_indices)
  ```

  Here, `trajectories['position'][:, i]` contains the history of positions
  sampled for what became the `i`th particle at the final timestep. These
  are samples from the 'smoothed' posterior over trajectories, given all
  observations: `p(position[0:T] | observed_position[0:T])`.

  """
  seed_stream = SeedStream(seed, 'particle_filter')
  with tf.name_scope(name or 'particle_filter'):
    leading_observation = tf.nest.flatten(observations)[0]
    num_timesteps = prefer_static.shape(leading_observation)[0]

    # The prior (and its optional proposal) are wrapped so that they can be
    # handed to `_filter_one_step` as if they were a `transition_fn` /
    # `proposal_fn`; both wrappers ignore their `(step, state)` arguments.
    def prior_fn(_1, _2):
      return SampleParticles(initial_state_prior, num_particles)

    if initial_state_proposal is None:
      prior_proposal_fn = None
    else:
      def prior_proposal_fn(_1, _2):
        return SampleParticles(initial_state_proposal, num_particles)

    # Step 0: draw from the prior and condition on the first observation.
    # The `previous_particles` passed here are a dummy quantity, used only to
    # convey state structure and num_particles to an optional proposal fn.
    placeholder_particles = prior_fn(0, []).sample()
    first_observation = tf.nest.map_structure(
        lambda x: tf.gather(x, 0), observations)
    first_step_outputs = _filter_one_step(
        step=0,
        previous_particles=placeholder_particles,
        observation=first_observation,
        transition_fn=prior_fn,
        observation_fn=observation_fn,
        proposal_fn=prior_proposal_fn,
        seed=seed_stream())
    (initial_resampled_particles,
     initial_parent_indices,
     initial_step_log_marginal_likelihood) = first_step_outputs

    # One `TensorArray` per traced quantity, each seeded with the step-0 value.
    def _array_starting_with(value):
      array = tf.TensorArray(dtype=value.dtype, size=num_timesteps)
      return array.write(0, value)

    particle_arrays = tf.nest.map_structure(
        _array_starting_with, initial_resampled_particles)
    parent_index_array = tf.TensorArray(
        dtype=tf.int32, size=num_timesteps).write(0, initial_parent_indices)
    log_marginal_array = tf.TensorArray(
        dtype=initial_step_log_marginal_likelihood.dtype,
        size=num_timesteps).write(0, initial_step_log_marginal_likelihood)

    def _loop_body(step,
                   previous_particles,
                   particle_arrays,
                   parent_index_array,
                   log_marginal_array):
      """Take one step in dynamics and accumulate marginal likelihood."""
      current_observation = tf.nest.map_structure(
          lambda x: tf.gather(x, step), observations)
      (new_particles,
       parent_indices,
       step_log_marginal_likelihood) = _filter_one_step(
           step=step,
           previous_particles=previous_particles,
           observation=current_observation,
           transition_fn=transition_fn,
           observation_fn=observation_fn,
           proposal_fn=proposal_fn,
           seed=seed_stream())

      particle_arrays = tf.nest.map_structure(
          lambda array, value: array.write(step, value),
          particle_arrays, new_particles)
      parent_index_array = parent_index_array.write(step, parent_indices)
      log_marginal_array = log_marginal_array.write(
          step, step_log_marginal_likelihood)
      return (step + 1,
              new_particles,
              particle_arrays,
              parent_index_array,
              log_marginal_array)

    # This loop could (and perhaps should) be written as a tf.scan, rather
    # than an explicit while_loop. It is written as an explicit while_loop to
    # allow for anticipated future changes that may not fit the form of a
    # scan loop.
    loop_outputs = tf.while_loop(
        cond=lambda step, *_: step < num_timesteps,
        body=_loop_body,
        loop_vars=(1,
                   initial_resampled_particles,
                   particle_arrays,
                   parent_index_array,
                   log_marginal_array))
    (_, _, particle_arrays, parent_index_array, log_marginal_array) = (
        loop_outputs)

    stacked_particles = tf.nest.map_structure(
        lambda array: array.stack(), particle_arrays)
    return (stacked_particles,
            parent_index_array.stack(),
            log_marginal_array.stack())
def estimate_parameters(self,
                        observations,
                        num_iterations,
                        num_particles,
                        initial_perturbation_scale,
                        cooling_schedule,
                        seed=None,
                        name=None,
                        **kwargs):
  """Runs repeated filtering passes whose perturbations shrink over time.

  The first pass calls `self.one_step` with `initial_perturbation_scale`.
  Every later pass calls `self.one_step` again, starting from the previous
  pass's unconstrained parameters and using the initial scale multiplied by
  the cooling factor for that iteration.

  Args:
    observations: observed `Tensor` value(s) on which to condition the
      parameter estimate.
    num_iterations: `int` `Tensor` number of filtering iterations to run.
    num_particles: scalar int `Tensor` number of particles to use.
    initial_perturbation_scale: scalar float `Tensor`, or any structure of
      float `Tensor`s broadcasting to the same shape as the (unconstrained)
      parameters, specifying the scale (standard deviation) of Gaussian
      perturbations to each parameter at the first timestep. It is passed
      through `tf.convert_to_tensor` before use.
    cooling_schedule: callable with signature
      `cooling_factor = cooling_schedule(iteration)` for `iteration` in
      `[0, ..., num_iterations - 1]`. The filter is invoked with
      perturbations of scale
      `initial_perturbation_scale * cooling_schedule(iteration)`. Here it is
      called once, on the whole `tf.range(1, num_iterations)` vector.
    seed: optional seed. It is wrapped in a `SeedStream` salted with
      `'iterated_filter_estimate_parameters'`, and that stream object is what
      gets passed as `seed` to every `self.one_step` call.
    name: `str` name for ops constructed by this method.
    **kwargs: additional keyword arguments forwarded to every
      `self.one_step` call (documented upstream as being passed on to
      `tfp.experimental.mcmc.infer_trajectories`).

  Returns:
    final_parameter_particles: structure of `Tensor`s matching
      `self.parameter_prior`, each with batch shape
      `[num_iterations, num_particles]`. These are the populations of
      particles representing the parameter estimate after each iteration of
      filtering.
      NOTE(review): the scan below only covers iterations
      `1, ..., num_iterations - 1`, and the first iteration's result is used
      solely as the scan initializer, so the leading dimension looks like
      `num_iterations - 1` rather than `num_iterations` -- confirm.
  """
  seed_stream = SeedStream(seed, 'iterated_filter_estimate_parameters')
  with self._name_scope(name or 'estimate_parameters'):
    initial_perturbation_scale = tf.convert_to_tensor(
        initial_perturbation_scale, name='initial_perturbation_scale')

    # Iteration 0: produces the starting point for the scan below.
    first_pass_parameters = self.one_step(
        observations=observations,
        num_particles=num_particles,
        perturbation_scale=initial_perturbation_scale,
        seed=seed_stream,
        **kwargs)

    @tf.function(autograph=False)
    def _cooled_pass(unconstrained_parameters, cooling_fraction):
      # Shrink the initial scale by this iteration's cooling factor.
      cooled_scale = tf.nest.map_structure(
          lambda scale: cooling_fraction * scale,
          initial_perturbation_scale)
      return self.one_step(
          observations=observations,
          num_particles=num_particles,
          perturbation_scale=cooled_scale,
          initial_unconstrained_parameters=unconstrained_parameters,
          seed=seed_stream,
          **kwargs)

    # Iterations 1 .. num_iterations - 1, each fed the previous pass's output.
    remaining_cooling_factors = cooling_schedule(
        tf.range(1, num_iterations))
    estimated_unconstrained_parameters = tf.scan(
        fn=_cooled_pass,
        elems=remaining_cooling_factors,
        initializer=first_pass_parameters)

    # Map the unconstrained estimates back through the constraining bijector.
    return self.parameter_constraining_bijector.forward(
        estimated_unconstrained_parameters)
def particle_filter(
    observations,
    initial_state_prior,
    transition_fn,
    observation_fn,
    num_particles,
    initial_state_proposal=None,
    proposal_fn=None,
    resample_fn=weighted_resampling.resample_systematic,
    resample_criterion_fn=smc_kernel.ess_below_threshold,
    rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
    num_transitions_per_observation=1,
    trace_fn=_default_trace_fn,
    trace_criterion_fn=_always_trace,
    static_trace_allocation_size=None,
    parallel_iterations=1,
    seed=None,
    name=None):  # pylint: disable=g-doc-args
  """Samples a series of particles representing filtered latent states.

  The particle filter samples from the sequence of "filtering" distributions
  `p(state[t] | observations[:t])` over latent states: at each point
  in time, this is the distribution conditioned on all observations
  *up to that time*. Because particles may be resampled, a particle
  at time `t` may be different from the particle with the same index at time
  `t + 1`. To reconstruct trajectories by tracing back through the resampling
  process, see `tfp.experimental.mcmc.reconstruct_trajectories`.

  The filter is run by building a `smc_kernel.SequentialMonteCarlo` kernel and
  stepping it `num_timesteps` times with `mcmc_util.trace_scan`, where
  `num_timesteps = 1 + num_transitions_per_observation *
  (num_observation_steps - 1)`.

  ${particle_filter_arg_str}
    trace_fn: Python `callable` defining the values to be traced at each step,
      with signature `traced_values = trace_fn(weighted_particles, results)`
      in which the first argument is an instance of
      `tfp.experimental.mcmc.WeightedParticles` and the second an instance of
      `SequentialMonteCarloResults` tuple, and the return value is a structure
      of `Tensor`s.
      Default value: `lambda s, r: (s.particles, s.log_weights,
      r.parent_indices, r.incremental_log_marginal_likelihood)`
    trace_criterion_fn: optional Python `callable` with signature
      `trace_this_step = trace_criterion_fn(weighted_particles, results)`
      taking the same arguments as `trace_fn` and returning a boolean
      `Tensor`. If `None`, only values from the final step are returned (and
      `static_trace_allocation_size` is overridden to `0`).
      Default value: `lambda *_: True` (trace every step).
    static_trace_allocation_size: Optional Python `int` size of trace to
      allocate statically. This should be an upper bound on the number of
      steps traced and is used only when the length cannot be statically
      inferred (for example, if a `trace_criterion_fn` is specified). It is
      primarily intended for contexts where static shapes are required, such
      as in XLA-compiled code.
      Default value: `None`.
    parallel_iterations: Passed to the internal `tf.while_loop`.
      Default value: `1`.
    seed: Python `int` seed for random ops.
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'particle_filter'`).
  Returns:
    traced_results: A structure of Tensors as returned by `trace_fn`. If
      `trace_criterion_fn==None`, this is computed from the final step;
      otherwise, each Tensor will have initial dimension `num_steps_traced`
      and stacks the traced results across all steps.
  """
  seed_stream = SeedStream(seed, 'particle_filter')
  with tf.name_scope(name or 'particle_filter'):
    num_observation_steps = ps.size0(tf.nest.flatten(observations)[0])
    num_timesteps = (
        1 + num_transitions_per_observation * (num_observation_steps - 1))

    # A `None` criterion means "trace nothing during the scan and report only
    # the final step": remember that choice, then substitute a criterion that
    # is always false so the scan itself allocates no trace.
    only_final_step = trace_criterion_fn is None
    if only_final_step:
      static_trace_allocation_size = 0
      trace_criterion_fn = lambda *_: False

    initial_weighted_particles = _particle_filter_initial_weighted_particles(
        observations=observations,
        observation_fn=observation_fn,
        initial_state_prior=initial_state_prior,
        initial_state_proposal=initial_state_proposal,
        num_particles=num_particles,
        seed=seed_stream)
    propose_and_update_log_weights_fn = (
        _particle_filter_propose_and_update_log_weights_fn(
            observations=observations,
            transition_fn=transition_fn,
            proposal_fn=proposal_fn,
            observation_fn=observation_fn,
            num_transitions_per_observation=num_transitions_per_observation))

    # NOTE(review): `seed` is only forwarded to the initial-particle helper
    # above; the kernel is constructed and stepped without any seed. Confirm
    # whether the per-step proposals/resampling are meant to be seeded too.
    kernel = smc_kernel.SequentialMonteCarlo(
        propose_and_update_log_weights_fn=propose_and_update_log_weights_fn,
        resample_fn=resample_fn,
        resample_criterion_fn=resample_criterion_fn)

    # The scan carries a `(weighted_particles, kernel_results)` pair; these
    # adapters unpack that pair for the kernel and the user-supplied callables.
    def _advance(state_and_results, _):
      return kernel.one_step(*state_and_results)

    def _trace(state_and_results):
      return trace_fn(*state_and_results)

    def _should_trace(state_and_results):
      return trace_criterion_fn(*state_and_results)

    initial_kernel_results = kernel.bootstrap_results(
        initial_weighted_particles)

    # Use `trace_scan` rather than `sample_chain` directly because the latter
    # would force us to trace the state history (with or without thinning),
    # which is not always appropriate.
    final_result, traced_results = mcmc_util.trace_scan(
        loop_fn=_advance,
        initial_state=(initial_weighted_particles, initial_kernel_results),
        elems=tf.ones([num_timesteps]),
        trace_fn=_trace,
        trace_criterion_fn=_should_trace,
        static_trace_allocation_size=static_trace_allocation_size,
        parallel_iterations=parallel_iterations)

    if only_final_step:
      # Return results from just the final step.
      traced_results = trace_fn(*final_result)

    return traced_results
def one_step(self, state, kernel_results, seed=None):
  """Takes one Sequential Monte Carlo inference step.

  The step proposes new particles and weights (except on the initial step,
  where the supplied `state` is used as-is), computes the incremental log
  marginal likelihood from the unnormalized log weights, and then resamples
  those batch elements for which `self.resample_criterion_fn` says so.

  Args:
    state: instance of `tfp.experimental.mcmc.WeightedParticles` representing
      the current particles with (log) weights. The `log_weights` must be a
      float `Tensor` of shape `[num_particles, b1, ..., bN]`. The
      `particles` may be any structure of `Tensor`s, each of which
      must have shape `concat([log_weights.shape, event_shape])` for some
      `event_shape`, which may vary across components.
    kernel_results: instance of
      `tfp.experimental.mcmc.SequentialMonteCarloResults` representing results
      from a previous step.
    seed: Optional Python integer to seed the random number generator. Any
      value other than `None` (including `0`) overrides the class-level seed
      set in `__init__`.

  Returns:
    state: instance of `tfp.experimental.mcmc.WeightedParticles` representing
      new particles with (log) weights. Batch elements that were resampled
      carry uniform log weights `-log(num_particles)`; the others carry the
      normalized log weights.
    kernel_results: instance of
      `tfp.experimental.mcmc.SequentialMonteCarloResults`, with `steps`
      incremented by one and the accumulated log marginal likelihood
      increased by this step's increment.
  """
  with tf.name_scope(self.name):
    with tf.name_scope('one_step'):
      # Compare against `None` explicitly: a truthiness test would silently
      # discard a valid seed of `0` in favor of the class-level seed.
      seed = SeedStream(
          seed if seed is not None else self.seed, 'smc_one_step')

      state = WeightedParticles(*state)  # Canonicalize.
      num_particles = ps.size0(state.log_weights)

      # Propose new particles and update weights for this step, unless it's
      # the initial step, in which case, use the user-provided initial
      # particles and weights.
      proposed_state = self.propose_and_update_log_weights_fn(
          # Propose state[t] from state[t - 1].
          ps.maximum(0, kernel_results.steps - 1),
          state,
          seed=seed())
      is_initial_step = ps.equal(kernel_results.steps, 0)
      # TODO(davmre): this `where` assumes the state size didn't change.
      state = tf.nest.map_structure(
          lambda a, b: ps.where(is_initial_step, a, b),
          state, proposed_state)

      normalized_log_weights = tf.nn.log_softmax(state.log_weights, axis=0)
      # Every entry of `log_weights` differs from `normalized_log_weights`
      # by the same normalizing constant, `logsumexp(log_weights)`. Compute
      # it directly rather than as `log_weights[0] - normalized[0]`: that
      # difference is `-inf - (-inf) = nan` whenever particle 0 has zero
      # weight, even though the remaining particles define the constant.
      incremental_log_marginal_likelihood = tf.reduce_logsumexp(
          state.log_weights, axis=0)

      do_resample = self.resample_criterion_fn(state)
      # Some batch elements may require resampling and others not, so
      # we first do the resampling for all elements, then select whether to
      # use the resampled values for each batch element according to
      # `do_resample`. If there were no batching, we might prefer to use
      # `tf.cond` to avoid the resampling computation on steps where it's not
      # needed---but we're ultimately interested in adaptive resampling
      # for statistical (not computational) purposes, so this isn't a
      # dealbreaker.
      # NOTE(review): the seed stream object itself (not a drawn `seed()`)
      # is passed here, unlike the proposal call above -- confirm intended.
      resampled_particles, resample_indices = weighted_resampling.resample(
          state.particles, state.log_weights, self.resample_fn, seed=seed)
      uniform_weights = tf.fill(
          tf.shape(state.log_weights),
          value=-tf.math.log(
              tf.cast(num_particles, state.log_weights.dtype)))
      # Per batch element: keep the resampled values where `do_resample`
      # holds, otherwise keep the un-resampled particles, dummy parent
      # indices, and normalized weights.
      (resampled_particles,
       resample_indices,
       log_weights) = tf.nest.map_structure(
           lambda r, p: ps.where(do_resample, r, p),
           (resampled_particles, resample_indices, uniform_weights),
           (state.particles,
            _dummy_indices_like(resample_indices),
            normalized_log_weights))

    return (
        WeightedParticles(particles=resampled_particles,
                          log_weights=log_weights),
        SequentialMonteCarloResults(
            steps=kernel_results.steps + 1,
            parent_indices=resample_indices,
            incremental_log_marginal_likelihood=(
                incremental_log_marginal_likelihood),
            accumulated_log_marginal_likelihood=(
                kernel_results.accumulated_log_marginal_likelihood +
                incremental_log_marginal_likelihood)))
def particle_filter(observations,
                    initial_state_prior,
                    transition_fn,
                    observation_fn,
                    num_particles,
                    initial_state_proposal=None,
                    proposal_fn=None,
                    rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
                    num_steps_state_history_to_pass=None,
                    num_steps_observation_history_to_pass=None,
                    seed=None,
                    name=None):  # pylint: disable=g-doc-args
  """Samples a series of particles representing filtered latent states.

  The particle filter samples from the sequence of "filtering" distributions
  `p(state[t] | observations[:t])` over latent states: at each point
  in time, this is the distribution conditioned on all observations
  *up to that time*. Because particles may be resampled, a particle
  at time `t` may be different from the particle with the same index at time
  `t + 1`. To reconstruct trajectories by tracing back through the resampling
  process, see `tfp.experimental.mcmc.reconstruct_trajectories`.

  When `num_steps_observation_history_to_pass` and/or
  `num_steps_state_history_to_pass` is set, `transition_fn`, `observation_fn`
  and (if given) `proposal_fn` additionally receive `observation_history`
  and/or `state_history` keyword arguments on every step after the first.

  ${particle_filter_arg_str}
  Returns:
    particles: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, b1, ..., bN, num_particles], event_shape])`,
      representing unbiased samples from the series of (filtering)
      distributions `p(latent_states[t] | observations[:t])`.
    parent_indices: `int` `Tensor` of shape
      `[num_timesteps, b1, ..., bN, num_particles]`,
      such that `parent_indices[t, k]` gives the index of the particle at
      time `t - 1` that the `k`th particle at time `t` is immediately
      descended from. See also
      `tfp.experimental.mcmc.reconstruct_trajectories`.
    step_log_marginal_likelihoods: float `Tensor` of shape
      `[num_timesteps, b1, ..., bN]`, giving the natural logarithm of an
      unbiased estimate of `p(observations[t] | observations[:t])` at each
      timestep `t`. Note that (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  ${non_markovian_specification_str}
  """
  seed_stream = SeedStream(seed, 'particle_filter')
  with tf.name_scope(name or 'particle_filter'):
    leading_observation = tf.nest.flatten(observations)[0]
    num_timesteps = prefer_static.shape(leading_observation)[0]

    # The prior (and its optional proposal) are wrapped so that they can be
    # handed to `_filter_one_step` as if they were a `transition_fn` /
    # `proposal_fn`; both wrappers ignore their `(step, state)` arguments.
    def prior_fn(_1, _2):
      return SampleParticles(initial_state_prior, num_particles)

    if initial_state_proposal is None:
      prior_proposal_fn = None
    else:
      def prior_proposal_fn(_1, _2):
        return SampleParticles(initial_state_proposal, num_particles)

    # Step 0: draw from the prior and condition on the first observation.
    # The `previous_particles` passed here are a dummy quantity, used only to
    # convey state structure and num_particles to an optional proposal fn.
    placeholder_particles = prior_fn(0, []).sample()
    first_observation = tf.nest.map_structure(
        lambda x: tf.gather(x, 0), observations)
    first_step_outputs = _filter_one_step(
        step=0,
        previous_particles=placeholder_particles,
        observation=first_observation,
        transition_fn=prior_fn,
        observation_fn=observation_fn,
        proposal_fn=prior_proposal_fn,
        seed=seed_stream())
    (initial_resampled_particles,
     initial_parent_indices,
     initial_step_log_marginal_likelihood) = first_step_outputs

    initial_accumulated_quantities = _initialize_accumulated_quantities(
        initial_resampled_particles,
        initial_parent_indices,
        initial_step_log_marginal_likelihood,
        num_steps_state_history_to_pass,
        num_timesteps)

    def _loop_body(step, previous_particles, accumulated_quantities):
      """Take one step in dynamics and accumulate marginal likelihood."""
      current_observation = tf.nest.map_structure(
          lambda x: tf.gather(x, step), observations)

      # Collect whichever history kwargs were requested; an unset (or zero)
      # history length means the corresponding kwarg is not passed at all.
      history_kwargs = {}
      if num_steps_observation_history_to_pass:
        history_kwargs['observation_history'] = _gather_history(
            observations, step, num_steps_observation_history_to_pass)
      if num_steps_state_history_to_pass:
        history_kwargs['state_history'] = (
            accumulated_quantities.state_history)

      def _with_history(fn):
        return functools.partial(fn, **history_kwargs)

      step_transition_fn = _with_history(transition_fn)
      step_observation_fn = _with_history(observation_fn)
      if proposal_fn is None:
        step_proposal_fn = None
      else:
        step_proposal_fn = _with_history(proposal_fn)

      (new_particles,
       parent_indices,
       step_log_marginal_likelihood) = _filter_one_step(
           step=step,
           previous_particles=previous_particles,
           observation=current_observation,
           transition_fn=step_transition_fn,
           observation_fn=step_observation_fn,
           proposal_fn=step_proposal_fn,
           seed=seed_stream())

      updated_quantities = _write_accumulated_quantities(
          step,
          accumulated_quantities,
          new_particles,
          parent_indices,
          step_log_marginal_likelihood)
      return step + 1, new_particles, updated_quantities

    loop_outputs = tf.while_loop(
        cond=lambda step, *_: step < num_timesteps,
        body=_loop_body,
        loop_vars=(1,
                   initial_resampled_particles,
                   initial_accumulated_quantities))
    (_, _, loop_results) = loop_outputs

    stacked_particles = tf.nest.map_structure(
        lambda array: array.stack(), loop_results.all_resampled_particles)
    return (stacked_particles,
            loop_results.all_parent_indices.stack(),
            loop_results.all_step_log_marginal_likelihoods.stack())