예제 #1
0
 def __init__(
     self,
     target_log_prob_fn,
     topology,
     cumulative_event_offset,
     nmax,
     t_range=None,
     seed=None,
     name=None,
 ):
     """Set up an uncalibrated occult-update step.

     The constructor only records its arguments; nothing is sampled here.

     :param target_log_prob_fn: the log density of the target distribution
     :param topology: the transition topology; kept as ``self.tx_topology``
     :param cumulative_event_offset: kept as ``self.initial_state``.  Note
                                     that it is *not* recorded in
                                     ``self._parameters``.
     :param nmax: recorded in ``self._parameters`` (presumably an upper
                  bound used by the proposal -- confirm against the rest
                  of the class)
     :param t_range: a tuple containing earliest and latest times between
                     which to update occults.
     :param seed: a random seed, used to build the kernel's ``SeedStream``
     :param name: the name of the update step
     """
     # Constructor arguments, keyed by parameter name.
     self._parameters = {
         "target_log_prob_fn": target_log_prob_fn,
         "topology": topology,
         "nmax": nmax,
         "t_range": t_range,
         "seed": seed,
         "name": name,
     }
     self._name = name
     self._seed_stream = SeedStream(seed, salt="UncalibratedOccultUpdate")

     self.tx_topology = topology
     self.initial_state = cumulative_event_offset
예제 #2
0
    def _apply_variational_kernel(self, inputs):
        """Applies the kernel to `inputs` with a Flipout-style perturbation.

        The result is `inputs @ loc` (the posterior mean kernel) plus a
        perturbation term. The perturbation multiplies `inputs` by a random
        sign tensor, passes it through a draw from a zero-mean copy of the
        kernel posterior, and multiplies the result by a second random sign
        tensor.

        As a side effect this sets `self.kernel_posterior_affine`,
        `self.kernel_posterior_affine_tensor` and resets
        `self.kernel_posterior_tensor` to `None`.

        Args:
          inputs: input `Tensor`; all but its last dimension are treated as
            batch dimensions.

        Returns:
          The sum of the mean output and the perturbed output.

        Raises:
          TypeError: if `self.kernel_posterior` is not an `Independent`
            distribution wrapping a `Normal`.
        """
        posterior = self.kernel_posterior
        is_independent_normal = (
            isinstance(posterior, independent_lib.Independent)
            and isinstance(posterior.distribution, normal_lib.Normal))
        if not is_independent_normal:
            raise TypeError('`DenseFlipout` requires '
                            '`kernel_posterior_fn` produce an instance of '
                            '`tfd.Independent(tfd.Normal)` '
                            '(saw: \"{}\").'.format(posterior.name))

        # Zero-mean copy of the posterior: same scale, `loc` replaced by zeros.
        base_normal = self.kernel_posterior.distribution
        self.kernel_posterior_affine = normal_lib.Normal(
            loc=tf.zeros_like(base_normal.loc), scale=base_normal.scale)
        self.kernel_posterior_affine_tensor = self.kernel_posterior_tensor_fn(
            self.kernel_posterior_affine)
        self.kernel_posterior_tensor = None

        input_shape = tf.shape(inputs)
        batch_shape = input_shape[:-1]

        # Two seeds are drawn from the stream, in this order: one for the
        # input signs and one for the output signs.
        seed_stream = SeedStream(self.seed, salt='DenseFlipout')
        sign_input = random_rademacher(
            input_shape, dtype=inputs.dtype, seed=seed_stream())

        output_shape = tf.concat(
            [batch_shape, tf.expand_dims(self.units, 0)], 0)
        sign_output = random_rademacher(
            output_shape, dtype=inputs.dtype, seed=seed_stream())

        signed_inputs = inputs * sign_input
        perturbation = tf.matmul(
            signed_inputs, self.kernel_posterior_affine_tensor) * sign_output

        mean_outputs = tf.matmul(inputs, self.kernel_posterior.distribution.loc)
        return mean_outputs + perturbation
  def iid_sample_fn(*args, **kwargs):
    """Draws `n` iid samples by running `sample_fn` under `parallel_for.pfor`.

    `sample_fn`, `n` and `unflatten` are taken from the enclosing scope.

    Args:
      *args: positional arguments forwarded unchanged to `sample_fn`.
      **kwargs: keyword arguments forwarded to `sample_fn`. An optional `seed`
        entry is removed and replaced: a stateful seed is passed through a
        `SeedStream`, anything else is split into `n` per-draw seeds.

    Returns:
      The `pfor` result with `unflatten` mapped over its structure
      (composites expanded).
    """
    with tf.name_scope('iid_sample_fn'):
      seed = kwargs.pop('seed', None)

      if not samplers.is_stateful_seed(seed):
        # Stateless seed: split it into `n` different stateless seeds, so
        # that we don't just get a bunch of copies of the same sample.
        if not JAX_MODE:
          warnings.warn(
              'Saw Tensor seed {}, implying stateless sampling. Autovectorized '
              'functions that use stateless sampling may be quite slow because '
              'the current implementation falls back to an explicit loop. This '
              'will be fixed in the future. For now, you will likely see '
              'better performance from stateful sampling, which you can invoke '
              'by passing a Python `int` seed.'.format(seed))
        per_draw_seeds = samplers.split_seed(
            seed, n=n, salt='iid_sample_stateless')

        def pfor_loop_body(i):
          with tf.name_scope('iid_sample_fn_stateless_body'):
            return sample_fn(
                *args, seed=tf.gather(per_draw_seeds, i), **kwargs)
      else:
        # Stateful seed: every iteration receives the same derived seed.
        seeded_kwargs = dict(
            kwargs, seed=SeedStream(seed, salt='iid_sample')())

        def pfor_loop_body(_):
          with tf.name_scope('iid_sample_fn_stateful_body'):
            return sample_fn(*args, **seeded_kwargs)

      draws = parallel_for.pfor(pfor_loop_body, n)
      return tf.nest.map_structure(unflatten, draws, expand_composites=True)
예제 #4
0
def _filter_one_step(step,
                     observation,
                     previous_particles,
                     transition_fn,
                     observation_fn,
                     proposal_fn,
                     seed=None):
    """Runs one propose / weight / resample cycle of the particle filter.

    Args:
      step: index of the step being computed. The proposal is made with
        `step - 1`, the observation weights with `step`.
      observation: observation for this step, forwarded to
        `_compute_observation_log_weights`.
      previous_particles: particles from the previous step.
      transition_fn: forwarded to `_propose_with_log_weights`.
      observation_fn: forwarded to `_compute_observation_log_weights`.
      proposal_fn: forwarded to `_propose_with_log_weights`.
      seed: seed used to build the `SeedStream` for this step.

    Returns:
      A tuple `(resampled_particles, resample_indices,
      step_log_marginal_likelihood)`. The last entry is the log-mean-exp of
      the combined log weights over their final axis.
    """
    with tf.name_scope('filter_one_step'):
        seed_stream = SeedStream(seed, 'filter_one_step')

        # First seed from the stream goes to the proposal.
        proposal = _propose_with_log_weights(
            step=step - 1,
            particles=previous_particles,
            transition_fn=transition_fn,
            proposal_fn=proposal_fn,
            seed=seed_stream())
        proposed_particles, proposal_log_weights = proposal

        log_weights = proposal_log_weights + _compute_observation_log_weights(
            step, proposed_particles, observation, observation_fn)

        # Second seed from the stream goes to resampling.
        resampled = _resample(
            proposed_particles, log_weights, seed=seed_stream())
        resampled_particles, resample_indices = resampled

        step_log_marginal_likelihood = tfp_math.reduce_logmeanexp(
            log_weights, axis=-1)

    return resampled_particles, resample_indices, step_log_marginal_likelihood
예제 #5
0
    def iid_sample_fn(*args, **kwargs):
        """Draws `n` iid samples by running `sample_fn` under `pfor`.

        `sample_fn`, `n` and `unflatten` are taken from the enclosing scope.

        Args:
          *args: positional arguments forwarded unchanged to `sample_fn`.
          **kwargs: keyword arguments forwarded to `sample_fn`. An optional
            `seed` entry is removed and replaced as described below.

        Returns:
          The `pfor` result with `unflatten` mapped over its structure
          (composites expanded).

        Raises:
          TypeError: re-raised from `SeedStream` when its message does not
            contain `TENSOR_SEED_MSG_PREFIX`.
        """
        seed = kwargs.pop('seed', None)
        try:  # Assume that `seed` is a valid stateful seed (Python `int`).
            # Only the seed derivation sits in the `try`, so that a
            # `TypeError` from anywhere else is never misread as a seed error.
            stateful_seed = SeedStream(seed, salt='iid_sample')()
        except TypeError as e:
            # If a stateless seed arg is passed, split it into `n` different
            # stateless seeds, so that we don't just get a bunch of copies of
            # the same sample.
            if TENSOR_SEED_MSG_PREFIX not in str(e):
                raise
            warnings.warn(
                'Saw non-`int` seed {}, implying stateless sampling. '
                'Autovectorized functions that use stateless sampling '
                'may be quite slow because the current implementation '
                'falls back to an explicit loop. This will be fixed in the '
                'future. For now, you will likely see better performance '
                'from stateful sampling, which you can invoke by passing a '
                'traditional Python `int` seed.'.format(seed))
            per_draw_seeds = samplers.split_seed(
                seed, n=n, salt='iid_sample_stateless')

            def pfor_loop_body(i):
                return sample_fn(
                    *args, seed=tf.gather(per_draw_seeds, i), **kwargs)
        else:
            seeded_kwargs = dict(kwargs, seed=stateful_seed)

            def pfor_loop_body(_):
                return sample_fn(*args, **seeded_kwargs)

        draws = parallel_for.pfor(pfor_loop_body, n)
        return tf.nest.map_structure(unflatten, draws, expand_composites=True)
예제 #6
0
def _filter_one_step(step,
                     observation,
                     previous_particles,
                     log_weights,
                     transition_fn,
                     observation_fn,
                     proposal_fn,
                     resample_criterion_fn,
                     seed=None):
    """Advances the particle filter by one step with adaptive resampling.

    New particles are proposed, the proposal and observation log weights are
    added to the incoming `log_weights`, and the sum is normalized over the
    last axis. Resampling is always computed, but its result is only kept
    where `resample_criterion_fn` returns true.

    Args:
      step: index of the step being computed. The proposal is made with
        `step - 1`, the observation weights with `step`.
      observation: observation for this step, forwarded to
        `_compute_observation_log_weights`.
      previous_particles: particles from the previous step.
      log_weights: log weights carried over from the previous step; the last
        axis indexes particles.
      transition_fn: forwarded to `_propose_with_log_weights`.
      observation_fn: forwarded to `_compute_observation_log_weights`.
      proposal_fn: forwarded to `_propose_with_log_weights`.
      resample_criterion_fn: callable applied to the unnormalized log
        weights; its result is converted to a tensor and used as the
        per-batch-element resampling mask.
      seed: seed used to build the `SeedStream` for this step.

    Returns:
      A `ParticleFilterStepResults` with fields `particles`, `log_weights`,
      `parent_indices` and `step_log_marginal_likelihood`.
    """
    with tf.name_scope('filter_one_step'):
        seed_stream = SeedStream(seed, 'filter_one_step')
        num_particles = prefer_static.shape(log_weights)[-1]

        # The stream object itself (not a drawn seed) is handed downstream.
        proposal = _propose_with_log_weights(
            step=step - 1,
            particles=previous_particles,
            transition_fn=transition_fn,
            proposal_fn=proposal_fn,
            seed=seed_stream)
        proposed_particles, proposal_log_weights = proposal

        observation_log_weights = _compute_observation_log_weights(
            step, proposed_particles, observation, observation_fn)
        unnormalized_log_weights = (
            log_weights + proposal_log_weights + observation_log_weights)

        # The normalizer over the particle axis is this step's log marginal
        # likelihood estimate.
        step_log_marginal_likelihood = tf.math.reduce_logsumexp(
            unnormalized_log_weights, axis=-1)
        normalized_log_weights = (
            unnormalized_log_weights -
            step_log_marginal_likelihood[..., tf.newaxis])

        # Adaptive resampling: one decision per batch element, with a
        # trailing axis added so that it broadcasts over particles.
        resample_decision = tf.convert_to_tensor(
            resample_criterion_fn(unnormalized_log_weights))
        do_resample = resample_decision[..., tf.newaxis]

        # Batch elements may disagree on whether to resample, so resampling is
        # computed for all of them and the result is selected elementwise
        # afterwards. Without batching a `tf.cond` could skip the unneeded
        # work, but adaptive resampling is wanted here for statistical rather
        # than computational reasons.
        resampled = _resample(
            proposed_particles, normalized_log_weights, seed=seed_stream)
        resampled_particles, resample_indices = resampled

        # Where no resampling happens each particle is its own parent and
        # keeps its weight; where it does, the weights become uniform.
        identity_indices = tf.broadcast_to(
            prefer_static.range(num_particles),
            prefer_static.shape(resample_indices))
        uniform_weights = (
            prefer_static.zeros_like(normalized_log_weights) -
            prefer_static.log(num_particles))

        def _select(if_resampled, if_kept):
            return prefer_static.where(do_resample, if_resampled, if_kept)

        chosen = tf.nest.map_structure(
            _select,
            (resampled_particles, resample_indices, uniform_weights),
            (proposed_particles, identity_indices, normalized_log_weights))
        final_particles, final_indices, final_log_weights = chosen

    return ParticleFilterStepResults(
        particles=final_particles,
        log_weights=final_log_weights,
        parent_indices=final_indices,
        step_log_marginal_likelihood=step_log_marginal_likelihood)
예제 #7
0
 def __init__(
     self,
     target_log_prob_fn,
     target_event_id,
     prev_event_id,
     next_event_id,
     initial_state,
     dmax,
     mmax,
     nmax,
     seed=None,
     name=None,
 ):
     """Set up an uncalibrated random walk for event times.

     The constructor records its arguments, builds the transition topology
     and precomputes the time offsets ``0 .. dmax - 1``.

     :param target_log_prob_fn: the log density of the target distribution
     :param target_event_id: the position in the first dimension of the
                             events tensor that we wish to move
     :param prev_event_id: the position of the previous event in the events
                           tensor
     :param next_event_id: the position of the next event in the events
                           tensor
     :param initial_state: the initial state tensor
     :param dmax: exclusive upper bound of ``self.time_offsets``
     :param mmax: recorded in ``self._parameters`` (meaning not visible
                  here -- confirm against the rest of the class)
     :param nmax: recorded in ``self._parameters`` (meaning not visible
                  here -- confirm against the rest of the class)
     :param seed: a random seed, used to build the kernel's ``SeedStream``
     :param name: the name of the update step
     """
     # Constructor arguments, keyed by parameter name.
     self._parameters = {
         "target_log_prob_fn": target_log_prob_fn,
         "target_event_id": target_event_id,
         "prev_event_id": prev_event_id,
         "next_event_id": next_event_id,
         "initial_state": initial_state,
         "dmax": dmax,
         "mmax": mmax,
         "nmax": nmax,
         "seed": seed,
         "name": name,
     }
     self._name = name
     self._seed_stream = SeedStream(
         seed, salt="UncalibratedEventTimesUpdate")

     # Topology is ordered (previous, target, next).
     self.tx_topology = TransitionTopology(
         prev_event_id, target_event_id, next_event_id)
     self.time_offsets = tf.range(self.parameters["dmax"])
예제 #8
0
def _filter_one_step(step,
                     observation,
                     previous_particles,
                     log_weights,
                     transition_fn,
                     observation_fn,
                     proposal_fn,
                     resample_criterion_fn,
                     has_observation=True,
                     seed=None):
    """Advances the particle filter by one step, optionally without data.

    New particles are proposed and the proposal log weights are folded into
    `log_weights` (normalized with a log-softmax over the last axis). If
    `has_observation` holds, the observation log weights are added as well;
    otherwise zeros are added. The result is normalized over axis 0, and
    resampled values are kept only where `resample_criterion_fn` says so.

    Args:
      step: index of the step being computed. The proposal is made with
        `step - 1`, the observation weights with `step`.
      observation: observation for this step, forwarded to
        `_compute_observation_log_weights`.
      previous_particles: particles from the previous step.
      log_weights: log weights carried over from the previous step; axis 0
        is read as the particle axis.
      transition_fn: forwarded to `_propose_with_log_weights`.
      observation_fn: forwarded to `_compute_observation_log_weights`.
      proposal_fn: forwarded to `_propose_with_log_weights`.
      resample_criterion_fn: callable applied to the unnormalized log
        weights; its result is the resampling mask.
      has_observation: predicate passed to `prefer_static.cond` selecting
        whether observation weights are computed.
      seed: seed used to build the `SeedStream` for this step.

    Returns:
      A `ParticleFilterStepResults` with fields `particles`, `log_weights`,
      `parent_indices` and `step_log_marginal_likelihood`.
    """
    with tf.name_scope('filter_one_step'):
        seed_stream = SeedStream(seed, 'filter_one_step')
        num_particles = prefer_static.shape(log_weights)[0]

        # The stream object itself (not a drawn seed) is handed downstream.
        proposal = _propose_with_log_weights(
            step=step - 1,
            particles=previous_particles,
            transition_fn=transition_fn,
            proposal_fn=proposal_fn,
            seed=seed_stream)
        proposed_particles, proposal_log_weights = proposal
        predicted_log_weights = tf.nn.log_softmax(
            proposal_log_weights + log_weights, axis=-1)

        # With an observation, weight the particles by it (broadcast to the
        # weights' shape); without one, leave the weights unchanged.
        def _observed_weights():
            return prefer_static.broadcast_to(
                _compute_observation_log_weights(step, proposed_particles,
                                                 observation, observation_fn),
                prefer_static.shape(predicted_log_weights))

        def _no_observation_weights():
            return tf.zeros_like(predicted_log_weights)

        observation_log_weights = prefer_static.cond(
            has_observation, _observed_weights, _no_observation_weights)

        unnormalized_log_weights = (
            predicted_log_weights + observation_log_weights)
        step_log_marginal_likelihood = tf.math.reduce_logsumexp(
            unnormalized_log_weights, axis=0)
        normalized_log_weights = (
            unnormalized_log_weights - step_log_marginal_likelihood)

        # Adaptive resampling: resample only where the criterion holds.
        do_resample = resample_criterion_fn(unnormalized_log_weights)

        # Batch elements may disagree on whether to resample, so resampling is
        # computed for all of them and the result is selected elementwise
        # afterwards. Without batching a `tf.cond` could skip the unneeded
        # work, but adaptive resampling is wanted here for statistical rather
        # than computational reasons.
        resampled = _resample(proposed_particles,
                              normalized_log_weights,
                              resample_independent,
                              seed=seed_stream)
        resampled_particles, resample_indices = resampled

        uniform_weights = (
            prefer_static.zeros_like(normalized_log_weights) -
            prefer_static.log(num_particles))
        placeholder_indices = _dummy_indices_like(resample_indices)

        def _select(if_resampled, if_kept):
            return prefer_static.where(do_resample, if_resampled, if_kept)

        chosen = tf.nest.map_structure(
            _select,
            (resampled_particles, resample_indices, uniform_weights),
            (proposed_particles, placeholder_indices, normalized_log_weights))
        final_particles, final_indices, final_log_weights = chosen

    return ParticleFilterStepResults(
        particles=final_particles,
        log_weights=final_log_weights,
        parent_indices=final_indices,
        step_log_marginal_likelihood=step_log_marginal_likelihood)
예제 #9
0
def particle_filter(
        observations,
        initial_state_prior,
        transition_fn,
        observation_fn,
        num_particles,
        initial_state_proposal=None,
        proposal_fn=None,
        resample_criterion_fn=ess_below_threshold,
        rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
        num_transitions_per_observation=1,
        num_steps_state_history_to_pass=None,
        num_steps_observation_history_to_pass=None,
        seed=None,
        name=None):  # pylint: disable=g-doc-args
    """Samples a series of particles representing filtered latent states.

  The particle filter samples from the sequence of "filtering" distributions
  `p(state[t] | observations[:t])` over latent
  states: at each point in time, this is the distribution conditioned on all
  observations *up to that time*. Because particles may be resampled, a particle
  at time `t` may be different from the particle with the same index at time
  `t + 1`. To reconstruct trajectories by tracing back through the resampling
  process, see `tfp.experimental.mcmc.reconstruct_trajectories`.

  ${particle_filter_arg_str}
  Returns:
    particles: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing (possibly weighted) samples from the series of filtering
      distributions `p(latent_states[t] | observations[:t])`.
    log_weights: `float` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`, such that
      `log_weights[t, :]` are the logarithms of normalized importance weights
      (such that `exp(reduce_logsumexp(log_weights[t], axis=0)) == 1.`) of
      the particles at time `t`. These may be used in conjunction with
      `particles` to compute expectations under the series of filtering
      distributions.
    parent_indices: `int` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`,
      such that `parent_indices[t, k]` gives the index of the particle at
      time `t - 1` that the `k`th particle at time `t` is immediately descended
      from. See also
      `tfp.experimental.mcmc.reconstruct_trajectories`.
    step_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each observed timestep `t`.
      Note that (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  ${non_markovian_specification_str}
  """
    # Wrap the user seed once; this same stream object is handed to every
    # `_filter_one_step` call below.
    seed = SeedStream(seed, 'particle_filter')
    with tf.name_scope(name or 'particle_filter'):
        # The number of observations is read from the leading dimension of
        # the first flattened component of `observations`.
        num_observation_steps = prefer_static.shape(
            tf.nest.flatten(observations)[0])[0]
        # One step for the first observation, plus
        # `num_transitions_per_observation` steps for each later observation.
        num_timesteps = (1 + num_transitions_per_observation *
                         (num_observation_steps - 1))

        # If no criterion is specified, default is to resample at every step.
        if not resample_criterion_fn:
            resample_criterion_fn = lambda _: True

        # Dress up the prior and prior proposal as a fake `transition_fn` and
        # `proposal_fn` respectively. Both ignore their two arguments.
        prior_fn = lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
            initial_state_prior, num_particles)
        prior_proposal_fn = (
            None if initial_state_proposal is None else
            lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
                initial_state_proposal, num_particles))

        # Initially the particles all have the same weight, `1. / num_particles`.
        # The weights have shape `[num_particles] + batch_shape`, where the
        # batch shape is the broadcast of all of the prior's batch shapes.
        # NOTE(review): the dtype is hard-coded to float32 -- confirm this is
        # intended when the prior is float64.
        broadcast_batch_shape = tf.convert_to_tensor(functools.reduce(
            prefer_static.broadcast_shape,
            tf.nest.flatten(initial_state_prior.batch_shape_tensor()), []),
                                                     dtype=tf.int32)
        log_uniform_weights = prefer_static.zeros(
            prefer_static.concat([[num_particles], broadcast_batch_shape],
                                 axis=0),
            dtype=tf.float32) - prefer_static.log(num_particles)

        # Initialize from the prior, and incorporate the first observation.
        initial_step_results = _filter_one_step(
            step=0,
            # `previous_particles` at the first step is a dummy quantity, used only
            # to convey state structure and num_particles to an optional
            # proposal fn.
            previous_particles=prior_fn(0, []).sample(),
            log_weights=log_uniform_weights,
            observation=tf.nest.map_structure(lambda x: tf.gather(x, 0),
                                              observations),
            transition_fn=prior_fn,
            observation_fn=observation_fn,
            proposal_fn=prior_proposal_fn,
            resample_criterion_fn=resample_criterion_fn,
            seed=seed)

        def _loop_body(step, previous_step_results, accumulated_step_results,
                       state_history):
            """Take one step in dynamics and accumulate marginal likelihood."""

            step_has_observation = (
                # The second of these conditions subsumes the first, but both are
                # useful because the first can often be evaluated statically.
                prefer_static.equal(num_transitions_per_observation, 1) |
                prefer_static.equal(step % num_transitions_per_observation, 0))
            # Index of the observation this step falls under (floor division,
            # so steps between observations reuse the preceding index).
            observation_idx = step // num_transitions_per_observation
            # NOTE(review): the `step=step` default below is never used by the
            # lambda body; only `observation_idx` is read.
            current_observation = tf.nest.map_structure(
                lambda x, step=step: tf.gather(x, observation_idx),
                observations)

            # Optional history is passed to the user callables as extra
            # keyword arguments, and only when a history length was requested.
            history_to_pass_into_fns = {}
            if num_steps_observation_history_to_pass:
                history_to_pass_into_fns[
                    'observation_history'] = _gather_history(
                        observations, observation_idx,
                        num_steps_observation_history_to_pass)
            if num_steps_state_history_to_pass:
                history_to_pass_into_fns['state_history'] = state_history

            new_step_results = _filter_one_step(
                step=step,
                previous_particles=previous_step_results.particles,
                log_weights=previous_step_results.log_weights,
                observation=current_observation,
                transition_fn=functools.partial(transition_fn,
                                                **history_to_pass_into_fns),
                observation_fn=functools.partial(observation_fn,
                                                 **history_to_pass_into_fns),
                proposal_fn=(None
                             if proposal_fn is None else functools.partial(
                                 proposal_fn, **history_to_pass_into_fns)),
                resample_criterion_fn=resample_criterion_fn,
                has_observation=step_has_observation,
                seed=seed)

            return _update_loop_variables(step, new_step_results,
                                          accumulated_step_results,
                                          state_history)

        # Run the remaining steps; the loop variables are seeded from the
        # initial step's results.
        loop_results = tf.while_loop(
            cond=lambda step, *_: step < num_timesteps,
            body=_loop_body,
            loop_vars=_initialize_loop_variables(
                initial_step_results, num_steps_state_history_to_pass,
                num_timesteps))

        # Stack each accumulated per-step container into a single value
        # (presumably `TensorArray`s with a leading time axis -- confirm
        # against `_initialize_loop_variables`).
        results = tf.nest.map_structure(lambda ta: ta.stack(),
                                        loop_results.accumulated_step_results)
        if num_transitions_per_observation != 1:
            # Return a log-prob for each observed step.
            observed_steps = prefer_static.range(
                0, num_timesteps, num_transitions_per_observation)
            results = results._replace(step_log_marginal_likelihood=tf.gather(
                results.step_log_marginal_likelihood, observed_steps))
        return results
예제 #10
0
def infer_trajectories(observations,
                       initial_state_prior,
                       transition_fn,
                       observation_fn,
                       num_particles,
                       initial_state_proposal=None,
                       proposal_fn=None,
                       resample_criterion_fn=ess_below_threshold,
                       rejuvenation_kernel_fn=None,
                       num_transitions_per_observation=1,
                       num_steps_state_history_to_pass=None,
                       num_steps_observation_history_to_pass=None,
                       seed=None,
                       name=None):  # pylint: disable=g-doc-args
    """Use particle filtering to sample from the posterior over trajectories.

  ${particle_filter_arg_str}
  Returns:
    trajectories: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing unbiased samples from the posterior distribution
      `p(latent_states | observations)`.
    step_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each timestep `t`. Note that
      (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  ${non_markovian_specification_str}

  #### Examples

  **Tracking unknown position and velocity**: Let's consider tracking an object
  moving in a one-dimensional space. We'll define a dynamical system
  by specifying an `initial_state_prior`, a `transition_fn`,
  and `observation_fn`.

  The structure of the latent state space is determined by the prior
  distribution. Here, we'll define a state space that includes the object's
  current position and velocity:

  ```python
  initial_state_prior = tfd.JointDistributionNamed({
      'position': tfd.Normal(loc=0., scale=1.),
      'velocity': tfd.Normal(loc=0., scale=0.1)})
  ```

  The `transition_fn` specifies the evolution of the system. It should
  return a distribution over latent states of the same structure as the prior.
  Here, we'll assume that the position evolves according to the velocity,
  with a small random drift, and the velocity also changes slowly, following
  a random drift:

  ```python
  def transition_fn(_, previous_state):
    return tfd.JointDistributionNamed({
        'position': tfd.Normal(
            loc=previous_state['position'] + previous_state['velocity'],
            scale=0.1),
        'velocity': tfd.Normal(loc=previous_state['velocity'], scale=0.01)})
  ```

  The `observation_fn` specifies the process by which the system is observed
  at each time step. Let's suppose we observe only a noisy version of the
  current position.

  ```python
  def observation_fn(_, state):
    return tfd.Normal(loc=state['position'], scale=0.1)
  ```

  Now let's track our object. Suppose we've been given observations
  corresponding to an initial position of `0.4` and constant velocity of `0.01`:

  ```python
  # Generate simulated observations.
  observed_positions = tfd.Normal(loc=tf.linspace(0.4, 0.8, 40),
                                  scale=0.1).sample()

  # Run particle filtering to sample plausible trajectories.
  (trajectories,  # {'position': [40, 1000], 'velocity': [40, 1000]}
   lps) = tfp.experimental.mcmc.infer_trajectories(
            observations=observed_positions,
            initial_state_prior=initial_state_prior,
            transition_fn=transition_fn,
            observation_fn=observation_fn,
            num_particles=1000)
  ```

  For all `i`, `trajectories['position'][:, i]` is a sample from the
  posterior over position sequences, given the observations:
  `p(state[0:T] | observations[0:T])`. Often, the sampled trajectories
  will be highly redundant in their earlier timesteps, because most
  of the initial particles have been discarded through resampling
  (this problem is known as 'particle degeneracy'; see section 3.5 of
  [Doucet and Johansen][1]).
  In such cases it may be useful to also consider the series of *filtering*
  distributions `p(state[t] | observations[:t])`, in which each latent state
  is inferred conditioned only on observations up to that point in time; these
  may be computed using `tfp.experimental.mcmc.particle_filter`.

  #### References

  [1] Arnaud Doucet and Adam M. Johansen. A tutorial on particle
      filtering and smoothing: Fifteen years later.
      _Handbook of nonlinear filtering_, 12(656-704), 2009.
      https://www.stats.ox.ac.uk/~doucet/doucet_johansen_tutorialPF2011.pdf

  """
    with tf.name_scope(name or 'infer_trajectories') as scope_name:
        seed_stream = SeedStream(seed, 'infer_trajectories')

        # Run the filter; every argument is forwarded unchanged, and the
        # stream object itself is passed on as the seed.
        filter_results = particle_filter(
            observations=observations,
            initial_state_prior=initial_state_prior,
            transition_fn=transition_fn,
            observation_fn=observation_fn,
            num_particles=num_particles,
            initial_state_proposal=initial_state_proposal,
            proposal_fn=proposal_fn,
            resample_criterion_fn=resample_criterion_fn,
            rejuvenation_kernel_fn=rejuvenation_kernel_fn,
            num_transitions_per_observation=num_transitions_per_observation,
            num_steps_state_history_to_pass=num_steps_state_history_to_pass,
            num_steps_observation_history_to_pass=(
                num_steps_observation_history_to_pass),
            seed=seed_stream,
            name=scope_name)
        (particles, log_weights, parent_indices,
         step_log_marginal_likelihoods) = filter_results

        weighted_trajectories = reconstruct_trajectories(
            particles, parent_indices)

        # Resample all steps of the trajectories using the final weights.
        # The particle axis of the last step's weights is moved from the
        # front to the back before being given to `Categorical`.
        final_log_weights = dist_util.move_dimension(
            log_weights[-1, ...], source_idx=0, dest_idx=-1)
        resample_indices = categorical.Categorical(final_log_weights).sample(
            num_particles, seed=seed_stream)

        def _gather_resampled(x):
            return _batch_gather(x, resample_indices, axis=1)

        trajectories = tf.nest.map_structure(
            _gather_resampled, weighted_trajectories)

        return trajectories, step_log_marginal_likelihoods
예제 #11
0
def particle_filter(
        observations,
        initial_state_prior,
        transition_fn,
        observation_fn,
        num_particles,
        initial_state_proposal=None,
        proposal_fn=None,
        resample_fn=weighted_resampling.resample_systematic,
        resample_criterion_fn=ess_below_threshold,
        rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
        num_transitions_per_observation=1,
        trace_fn=_default_trace_fn,
        step_indices_to_trace=None,
        seed=None,
        name=None):  # pylint: disable=g-doc-args
    """Runs a particle filter and traces values from the requested steps.

  At every timestep the filter draws particles targeting the "filtering"
  distribution `p(state[t] | observations[:t])`, i.e. the distribution over
  the latent state conditioned on all observations *up to that time*. Because
  particles may be resampled, the particle at a given index at time `t` need
  not be related to the particle with the same index at time `t + 1`. To
  follow ancestry back through the resampling steps, see
  `tfp.experimental.mcmc.reconstruct_trajectories`.

  ${particle_filter_arg_str}
    trace_fn: Python `callable` selecting the values recorded at each traced
      step. It receives a `ParticleFilterStepResults` tuple and returns a
      structure of `Tensor`s. The default function returns
      `(particles, log_weights, parent_indices, step_log_likelihood)`.
    step_indices_to_trace: optional `int` `Tensor` listing, in increasing
      order, the indices of the steps whose traced values should be recorded.
      If `None`, every timestep is traced, which is equivalent to passing
      `step_indices_to_trace=tf.range(num_timesteps)`. If a single scalar
      step is given, the returned values carry no leading time axis.
    seed: Python `int` seed for random ops.
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'particle_filter'`).
  Returns:
    The structure produced by `trace_fn`, stacked over the traced steps. With
    the default `trace_fn` this is the tuple:

    particles: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing (possibly weighted) samples from the series of filtering
      distributions `p(latent_states[t] | observations[:t])`.
    log_weights: `float` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`, such that
      `log_weights[t, :]` are the logarithms of normalized importance weights
      of the particles at time `t` (exponentiating and summing them over the
      particles yields `1.`). Together with `particles` these can be used to
      compute expectations under the series of filtering distributions.
    parent_indices: `int` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`,
      such that `parent_indices[t, k]` gives the index of the particle at
      time `t - 1` from which the `k`th particle at time `t` immediately
      descends. See also `tfp.experimental.mcmc.reconstruct_trajectories`.
    incremental_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each observed timestep `t`.
      Note that (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  """
    seed = SeedStream(seed, 'particle_filter')
    with tf.name_scope(name or 'particle_filter'):
        first_observation_part = tf.nest.flatten(observations)[0]
        num_observation_steps = ps.size0(first_observation_part)
        # One step for the first observation, then
        # `num_transitions_per_observation` steps per remaining observation.
        num_timesteps = (
            1 + num_transitions_per_observation * (num_observation_steps - 1))

        # A missing criterion means "resample at every step".
        resample_criterion_fn = resample_criterion_fn or (lambda _: True)

        # Normalize the requested steps into a rank-1 tensor of (sorted)
        # non-negative indices, e.g. `3` -> `[3]`,
        # `[-2, -1]` -> `[N - 2, N - 1]`.
        if step_indices_to_trace is not None:
            canonicalized = _canonicalize_steps_to_trace(
                step_indices_to_trace, num_timesteps)
            step_indices_to_trace, traced_steps_have_rank_zero = canonicalized

        # The prior (and the optional prior proposal) are disguised as a
        # `transition_fn` / `proposal_fn` pair so that the initial step can
        # reuse the generic single-step filtering routine.
        def prior_fn(_1, _2):
            return SampleParticles(initial_state_prior, num_particles)

        if initial_state_proposal is None:
            prior_proposal_fn = None
        else:
            def prior_proposal_fn(_1, _2):
                return SampleParticles(initial_state_proposal, num_particles)

        # Every particle starts with the same weight, `1. / num_particles`.
        prior_batch_shapes = tf.nest.flatten(
            initial_state_prior.batch_shape_tensor())
        broadcast_batch_shape = tf.convert_to_tensor(
            functools.reduce(ps.broadcast_shape, prior_batch_shapes, []),
            dtype=tf.int32)
        weights_shape = ps.concat(
            [[num_particles], broadcast_batch_shape], axis=0)
        log_uniform_weights = (
            ps.zeros(weights_shape, dtype=tf.float32) - ps.log(num_particles))

        # The "previous" results fed to the first step are placeholders: the
        # particles only convey the state structure and `num_particles` to an
        # optional proposal fn.
        dummy_previous_step = ParticleFilterStepResults(
            particles=prior_fn(0, []).sample(),
            log_weights=log_uniform_weights,
            parent_indices=None,
            incremental_log_marginal_likelihood=0.,
            accumulated_log_marginal_likelihood=0.)
        first_observation = tf.nest.map_structure(
            lambda x: tf.gather(x, 0), observations)

        # Sample from the prior and fold in the first observation.
        initial_step_results = _filter_one_step(
            step=0,
            previous_step_results=dummy_previous_step,
            observation=first_observation,
            transition_fn=prior_fn,
            observation_fn=observation_fn,
            proposal_fn=prior_proposal_fn,
            resample_fn=resample_fn,
            resample_criterion_fn=resample_criterion_fn,
            seed=seed)

        def _loop_body(step, previous_step_results, accumulated_traced_results,
                       num_steps_traced):
            """Advances the filter by one step and records traced values."""
            # The modulo test alone would suffice; the equality test is kept
            # because it can often be resolved statically.
            step_has_observation = (
                ps.equal(num_transitions_per_observation, 1)
                | ps.equal(step % num_transitions_per_observation, 0))
            observation_idx = step // num_transitions_per_observation
            current_observation = tf.nest.map_structure(
                lambda x: tf.gather(x, observation_idx), observations)

            new_step_results = _filter_one_step(
                step=step,
                previous_step_results=previous_step_results,
                observation=current_observation,
                transition_fn=transition_fn,
                observation_fn=observation_fn,
                proposal_fn=proposal_fn,
                resample_criterion_fn=resample_criterion_fn,
                resample_fn=resample_fn,
                has_observation=step_has_observation,
                seed=seed)

            return _update_loop_variables(
                step=step,
                current_step_results=new_step_results,
                accumulated_traced_results=accumulated_traced_results,
                trace_fn=trace_fn,
                step_indices_to_trace=step_indices_to_trace,
                num_steps_traced=num_steps_traced)

        initial_loop_variables = _initialize_loop_variables(
            initial_step_results=initial_step_results,
            num_timesteps=num_timesteps,
            trace_fn=trace_fn,
            step_indices_to_trace=step_indices_to_trace)
        loop_results = tf.while_loop(
            cond=lambda step, *_: step < num_timesteps,
            body=_loop_body,
            loop_vars=initial_loop_variables)

        results = tf.nest.map_structure(
            lambda ta: ta.stack(), loop_results.accumulated_traced_results)
        if step_indices_to_trace is None:
            return results

        # A rank-0 (single scalar) step request gets results without a time
        # axis.
        return ps.cond(
            traced_steps_have_rank_zero,
            lambda: tf.nest.map_structure(lambda x: x[0, ...], results),
            lambda: results)
예제 #12
0
def particle_filter(
        observations,
        initial_state_prior,
        transition_fn,
        observation_fn,
        num_particles,
        initial_state_proposal=None,
        proposal_fn=None,
        rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
        seed=None,
        name=None):
    """Samples a series of particles representing filtered latent states.

  A latent state is a `Tensor` or a nested structure of `Tensor`s; its
  structure is dictated by `initial_state_prior`.

  The `transition_fn`, `observation_fn`, and (if given) `proposal_fn`
  callables are each invoked as `fn(step, state)`, where `state` is the
  latent state at timestep `step`.

  Args:
    observations: a (structure of) Tensors, each of shape
      `concat([[num_timesteps, b1, ..., bN], event_shape])` with optional
      batch dimensions `b1, ..., bN`.
    initial_state_prior: a (joint) distribution over the initial latent state,
      with optional batch shape `[b1, ..., bN]`.
    transition_fn: callable returning a (joint) distribution over the next
      latent state.
    observation_fn: callable returning a (joint) distribution over the current
      observation.
    num_particles: `int` `Tensor` number of particles.
    initial_state_proposal: a (joint) distribution over the initial latent
      state, with optional batch shape `[b1, ..., bN]`. If `None`, the initial
      particles are proposed from the `initial_state_prior`.
      Default value: `None`.
    proposal_fn: callable returning a (joint) proposal distribution over the
      next latent state. If `None`, the dynamics model is used (
      `proposal_fn == transition_fn`).
      Default value: `None`.
    rejuvenation_kernel_fn: optional Python `callable` with signature
      `transition_kernel = rejuvenation_kernel_fn(target_log_prob_fn)`
      where `target_log_prob_fn` is a provided callable evaluating
      `p(x[t] | y[t], x[t-1])` at each step `t`, and `transition_kernel`
      should be an instance of `tfp.mcmc.TransitionKernel`.
      Default value: `None`.  # TODO(davmre): not yet supported.
    seed: Python `int` seed for random ops.
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'particle_filter'`).
  Returns:
    particles: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, b1, ..., bN, num_particles], event_shape])`,
      representing unbiased samples from the series of (filtering) distributions
      `p(latent_states[t] | observations[:t])`.
    parent_indices: `int` `Tensor` of shape
      `[num_timesteps, b1, ..., bN, num_particles]`,
      such that `parent_indices[t, k]` gives the index of the particle at
      time `t - 1` from which the `k`th particle at time `t` immediately
      descends. See also `tfp.experimental.mcmc.reconstruct_trajectories`.
    step_log_marginal_likelihoods: float `Tensor` of shape
      `[num_timesteps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each timestep `t`. Note that (
      by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  #### Examples

  **Tracking unknown position and velocity**: Consider tracking an object
  moving in a one-dimensional space. The dynamical system is defined by an
  `initial_state_prior`, a `transition_fn`, and an `observation_fn`.

  The prior distribution determines the structure of the latent state space.
  Here the state holds the object's current position and velocity:

  ```python
  initial_state_prior = tfd.JointDistributionNamed({
      'position': tfd.Normal(loc=0., scale=1.),
      'velocity': tfd.Normal(loc=0., scale=0.1)})
  ```

  The `transition_fn` describes how the system evolves. It must return a
  distribution over latent states with the same structure as the prior.
  Here the position advances by the velocity plus a small random drift, and
  the velocity itself drifts slowly:

  ```python
  def transition_fn(_, previous_state):
    return tfd.JointDistributionNamed({
        'position': tfd.Normal(
            loc=previous_state['position'] + previous_state['velocity'],
            scale=0.1),
        'velocity': tfd.Normal(loc=previous_state['velocity'], scale=0.01)})
  ```

  The `observation_fn` describes how the system is observed at each time
  step. Suppose only a noisy version of the current position is observed:

  ```python
  def observation_fn(_, state):
    return tfd.Normal(loc=state['position'], scale=0.1)
  ```

  Now track the object. Suppose we are given observations corresponding to
  an initial position of `0.4` and a constant velocity of `0.01`:

  ```python
  # Generate simulated observations (40 steps of size 0.01 starting at 0.4).
  observed_positions = tfd.Normal(loc=tf.linspace(0.4, 0.79, 40),
                                  scale=0.1).sample()

  # Run particle filtering.
  (particles,       # {'position': [40, 1000], 'velocity': [40, 1000]}
   parent_indices,  #  [40, 1000]
   _) = tfp.experimental.mcmc.particle_filter(
          observations=observed_positions,
          initial_state_prior=initial_state_prior,
          transition_fn=transition_fn,
          observation_fn=observation_fn,
          num_particles=1000)
  ```

  The particle filter samples from the "filtering" distribution over latent
  states: at each point in time, this is the distribution conditioned on all
  observations *up to that time*. For example,
  `particles['position'][t]` contains `num_particles` samples from the
  distribution `p(position[t] | observed_positions[:t])`. Because
  particles may be resampled, there is no relationship between a particle
  at time `t` and the particle with the same index at time `t + 1`.

  It is still possible to trace back through the resampling steps to
  reconstruct samples of entire latent trajectories:

  ```python
  trajectories = tfp.experimental.mcmc.reconstruct_trajectories(
      particles, parent_indices)
  ```

  Here, `trajectories['position'][:, i]` contains the history of positions
  sampled for what became the `i`th particle at the final timestep. These
  are samples from the 'smoothed' posterior over trajectories, given all
  observations: `p(position[0:T] | observed_position[0:T])`.

  """
    seed = SeedStream(seed, 'particle_filter')
    with tf.name_scope(name or 'particle_filter'):
        leading_observation = tf.nest.flatten(observations)[0]
        num_timesteps = prefer_static.shape(leading_observation)[0]

        # The prior (and the optional prior proposal) are disguised as a
        # `transition_fn` / `proposal_fn` pair so that the initial step can
        # reuse the generic single-step filtering routine.
        def prior_fn(_1, _2):
            return SampleParticles(initial_state_prior, num_particles)

        if initial_state_proposal is None:
            prior_proposal_fn = None
        else:
            def prior_proposal_fn(_1, _2):
                return SampleParticles(initial_state_proposal, num_particles)

        # The "previous" particles of the first step are placeholders that
        # only convey the state structure and `num_particles` to an optional
        # proposal fn.
        placeholder_particles = prior_fn(0, []).sample()
        first_observation = tf.nest.map_structure(
            lambda x: tf.gather(x, 0), observations)

        # Sample from the prior and fold in the first observation.
        first_step_outputs = _filter_one_step(
            step=0,
            previous_particles=placeholder_particles,
            observation=first_observation,
            transition_fn=prior_fn,
            observation_fn=observation_fn,
            proposal_fn=prior_proposal_fn,
            seed=seed())
        (initial_resampled_particles, initial_parent_indices,
         initial_step_log_marginal_likelihood) = first_step_outputs

        def _history_starting_with(first_value, dtype):
            """Allocates a per-timestep TensorArray holding `first_value` at 0."""
            history = tf.TensorArray(dtype=dtype, size=num_timesteps)
            return history.write(0, first_value)

        # Seed the per-step histories with the results of the initial step.
        particle_history = tf.nest.map_structure(
            lambda x: _history_starting_with(x, x.dtype),
            initial_resampled_particles)
        parent_index_history = _history_starting_with(
            initial_parent_indices, tf.int32)
        likelihood_history = _history_starting_with(
            initial_step_log_marginal_likelihood,
            initial_step_log_marginal_likelihood.dtype)

        def _loop_body(step, particles, particle_history,
                       parent_index_history, likelihood_history):
            """Advances the filter one step and appends to the histories."""
            observation_now = tf.nest.map_structure(
                lambda x: tf.gather(x, step), observations)
            step_outputs = _filter_one_step(
                step=step,
                previous_particles=particles,
                observation=observation_now,
                transition_fn=transition_fn,
                observation_fn=observation_fn,
                proposal_fn=proposal_fn,
                seed=seed())
            particles, parent_indices, step_log_marginal_likelihood = (
                step_outputs)

            particle_history = tf.nest.map_structure(
                lambda history, value: history.write(step, value),
                particle_history, particles)
            parent_index_history = parent_index_history.write(
                step, parent_indices)
            likelihood_history = likelihood_history.write(
                step, step_log_marginal_likelihood)
            return (step + 1, particles, particle_history,
                    parent_index_history, likelihood_history)

        # A `tf.scan` would also express this loop; the explicit `while_loop`
        # is kept to leave room for anticipated changes that may not fit the
        # form of a scan.
        final_loop_state = tf.while_loop(
            cond=lambda step, *_: step < num_timesteps,
            body=_loop_body,
            loop_vars=(1, initial_resampled_particles, particle_history,
                       parent_index_history, likelihood_history))
        (_, _, particle_history, parent_index_history,
         likelihood_history) = final_loop_state

        stacked_particles = tf.nest.map_structure(
            lambda ta: ta.stack(), particle_history)
        return (stacked_particles,
                parent_index_history.stack(),
                likelihood_history.stack())
예제 #13
0
    def estimate_parameters(self,
                            observations,
                            num_iterations,
                            num_particles,
                            initial_perturbation_scale,
                            cooling_schedule,
                            seed=None,
                            name=None,
                            **kwargs):
        """Iterates the filter, shrinking perturbations per a cooling schedule.

    Args:
      observations: observed `Tensor` value(s) on which to condition the
        parameter estimate.
      num_iterations: `int` `Tensor` number of filtering iterations to run.
      num_particles: scalar int `Tensor` number of particles to use.
      initial_perturbation_scale: scalar float `Tensor`, or any structure of
        float `Tensor`s broadcasting to the same shape as the (unconstrained)
        parameters, specifying the scale (standard deviation) of Gaussian
        perturbations to each parameter at the first timestep.
      cooling_schedule: callable with signature
        `cooling_factor = cooling_schedule(iteration)` for `iteration` in
        `[0, ..., num_iterations - 1]`. The filter is
        invoked with perturbations of scale
        `initial_perturbation_scale * cooling_schedule(iteration)`.
      seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
      name: `str` name for ops constructed by this method.
      **kwargs: additional keyword arguments passed to
        `tfp.experimental.mcmc.infer_trajectories`.
    Returns:
      final_parameter_particles: structure of `Tensor`s matching
        `self.parameter_prior`, each with batch shape
        `[num_iterations, num_particles]`. These are the populations
        of particles representing the parameter estimate after each iteration
        of filtering.
        NOTE(review): the scan below only covers iterations
        `1, ..., num_iterations - 1`, so the leading dimension looks like
        `num_iterations - 1` -- confirm against callers.
    """
        seed = SeedStream(seed, 'iterated_filter_estimate_parameters')
        with self._name_scope(name or 'estimate_parameters'):

            initial_perturbation_scale = tf.convert_to_tensor(
                initial_perturbation_scale, name='initial_perturbation_scale')

            # Iteration 0: filter once at the full perturbation scale to
            # obtain the starting population of parameter particles.
            initial_unconstrained_parameters = self.one_step(
                observations=observations,
                num_particles=num_particles,
                perturbation_scale=initial_perturbation_scale,
                seed=seed,
                **kwargs)

            @tf.function(autograph=False)
            def loop_body(unconstrained_parameters, cooling_fraction):
                """Runs one further iteration at the cooled perturbation scale."""

                def _cooled(scale):
                    return cooling_fraction * scale

                cooled_scale = tf.nest.map_structure(
                    _cooled, initial_perturbation_scale)
                return self.one_step(
                    observations=observations,
                    num_particles=num_particles,
                    perturbation_scale=cooled_scale,
                    initial_unconstrained_parameters=unconstrained_parameters,
                    seed=seed,
                    **kwargs)

            # Iterations 1, ..., num_iterations - 1: each one starts from the
            # particles produced by the previous iteration.
            cooling_fractions = cooling_schedule(tf.range(1, num_iterations))
            estimated_unconstrained_parameters = tf.scan(
                fn=loop_body,
                elems=cooling_fractions,
                initializer=initial_unconstrained_parameters)

            # Map the estimates back to the constrained parameter space.
            constrain = self.parameter_constraining_bijector.forward
            return constrain(estimated_unconstrained_parameters)
예제 #14
0
def particle_filter(
        observations,
        initial_state_prior,
        transition_fn,
        observation_fn,
        num_particles,
        initial_state_proposal=None,
        proposal_fn=None,
        resample_fn=weighted_resampling.resample_systematic,
        resample_criterion_fn=smc_kernel.ess_below_threshold,
        rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
        num_transitions_per_observation=1,
        trace_fn=_default_trace_fn,
        trace_criterion_fn=_always_trace,
        static_trace_allocation_size=None,
        parallel_iterations=1,
        seed=None,
        name=None):  # pylint: disable=g-doc-args
    """Runs a Sequential Monte Carlo particle filter and traces its steps.

  At every timestep the filter draws particles targeting the "filtering"
  distribution `p(state[t] | observations[:t])`, i.e. the distribution over
  the latent state conditioned on all observations *up to that time*. Because
  particles may be resampled, the particle at a given index at time `t` need
  not be related to the particle with the same index at time `t + 1`. To
  follow ancestry back through the resampling steps, see
  `tfp.experimental.mcmc.reconstruct_trajectories`.

  ${particle_filter_arg_str}
    trace_fn: Python `callable` defining the values to be traced at each step,
      with signature `traced_values = trace_fn(weighted_particles, results)`
      in which the first argument is an instance of
      `tfp.experimental.mcmc.WeightedParticles` and the second an instance of
      `SequentialMonteCarloResults` tuple, and the return value is a structure
      of `Tensor`s.
      Default value: `lambda s, r: (s.particles, s.log_weights,
      r.parent_indices, r.incremental_log_marginal_likelihood)`
    trace_criterion_fn: optional Python `callable` with signature
      `trace_this_step = trace_criterion_fn(weighted_particles, results)` taking
      the same arguments as `trace_fn` and returning a boolean `Tensor`. If
      `None`, only values from the final step are returned.
      Default value: `lambda *_: True` (trace every step).
    static_trace_allocation_size: Optional Python `int` size of trace to
      allocate statically. This should be an upper bound on the number of steps
      traced and is used only when the length cannot be
      statically inferred (for example, if a `trace_criterion_fn` is specified).
      It is primarily intended for contexts where static shapes are required,
      such as in XLA-compiled code.
      Default value: `None`.
    parallel_iterations: Passed to the internal `tf.while_loop`.
      Default value: `1`.
    seed: Python `int` seed for random ops.
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'particle_filter'`).
  Returns:
    traced_results: A structure of Tensors as returned by `trace_fn`. If
      `trace_criterion_fn==None`, this is computed from the final step;
      otherwise, each Tensor will have initial dimension `num_steps_traced`
      and stacks the traced results across all steps.
  """

    seed = SeedStream(seed, 'particle_filter')
    with tf.name_scope(name or 'particle_filter'):
        first_observation_part = tf.nest.flatten(observations)[0]
        num_observation_steps = ps.size0(first_observation_part)
        # One step for the first observation, then
        # `num_transitions_per_observation` steps per remaining observation.
        num_timesteps = (
            1 + num_transitions_per_observation * (num_observation_steps - 1))

        # With no trace criterion nothing is traced during the scan; the
        # result is instead computed from the final state afterwards.
        return_final_step_only = trace_criterion_fn is None
        if return_final_step_only:
            static_trace_allocation_size = 0
            trace_criterion_fn = lambda *_: False

        initial_weighted_particles = (
            _particle_filter_initial_weighted_particles(
                observations=observations,
                observation_fn=observation_fn,
                initial_state_prior=initial_state_prior,
                initial_state_proposal=initial_state_proposal,
                num_particles=num_particles,
                seed=seed))
        propose_and_update_log_weights_fn = (
            _particle_filter_propose_and_update_log_weights_fn(
                observations=observations,
                transition_fn=transition_fn,
                proposal_fn=proposal_fn,
                observation_fn=observation_fn,
                num_transitions_per_observation=(
                    num_transitions_per_observation)))

        kernel = smc_kernel.SequentialMonteCarlo(
            propose_and_update_log_weights_fn=propose_and_update_log_weights_fn,
            resample_fn=resample_fn,
            resample_criterion_fn=resample_criterion_fn)
        initial_kernel_results = kernel.bootstrap_results(
            initial_weighted_particles)

        # The scan carries a `(state, kernel_results)` pair; these adapters
        # unpack it for the kernel and for the user-supplied callables.
        def _advance(state_and_results, _):
            return kernel.one_step(*state_and_results)

        def _trace(state_and_results):
            return trace_fn(*state_and_results)

        def _should_trace(state_and_results):
            return trace_criterion_fn(*state_and_results)

        # `trace_scan` is used in place of `sample_chain` because the latter
        # would force tracing of the state history (with or without
        # thinning), which is not always appropriate.
        final_result, traced_results = mcmc_util.trace_scan(
            loop_fn=_advance,
            initial_state=(initial_weighted_particles,
                           initial_kernel_results),
            elems=tf.ones([num_timesteps]),
            trace_fn=_trace,
            trace_criterion_fn=_should_trace,
            static_trace_allocation_size=static_trace_allocation_size,
            parallel_iterations=parallel_iterations)

        if return_final_step_only:
            # Report values from just the final step.
            return trace_fn(*final_result)

        return traced_results
예제 #15
0
    def one_step(self, state, kernel_results, seed=None):
        """Takes one Sequential Monte Carlo inference step.

    The step proposes new particles and updates their log weights (except on
    the very first step, where the provided particles and weights are used
    as-is), normalizes the weights, and then resamples those batch elements
    for which `self.resample_criterion_fn` requests it.

    Args:
      state: instance of `tfp.experimental.mcmc.WeightedParticles` representing
        the current particles with (log) weights. The `log_weights` must be
        a float `Tensor` of shape `[num_particles, b1, ..., bN]`. The
        `particles` may be any structure of `Tensor`s, each of which
        must have shape `concat([log_weights.shape, event_shape])` for some
        `event_shape`, which may vary across components.
      kernel_results: instance of
        `tfp.experimental.mcmc.SequentialMonteCarloResults` representing results
        from a previous step.
      seed: Optional Python integer to seed the random number generator.
        If provided (i.e. not `None`; a seed of `0` is honored), overrides the
        class-level seed set in `__init__`.
    Returns:
      state: instance of `tfp.experimental.mcmc.WeightedParticles` representing
        new particles with (log) weights.
      kernel_results: instance of
        `tfp.experimental.mcmc.SequentialMonteCarloResults`.
    """
        with tf.name_scope(self.name):
            with tf.name_scope('one_step'):
                # Compare against `None` rather than testing truthiness so
                # that an explicit `seed=0` is not silently replaced by the
                # class-level seed.
                seed = SeedStream(
                    seed if seed is not None else self.seed, 'smc_one_step')

                state = WeightedParticles(*state)  # Canonicalize.
                num_particles = ps.size0(state.log_weights)

                # Propose new particles and update weights for this step, unless it's
                # the initial step, in which case, use the user-provided initial
                # particles and weights.
                proposed_state = self.propose_and_update_log_weights_fn(
                    # Propose state[t] from state[t - 1].
                    ps.maximum(0, kernel_results.steps - 1),
                    state,
                    seed=seed())
                is_initial_step = ps.equal(kernel_results.steps, 0)
                # TODO(davmre): this `where` assumes the state size didn't change.
                state = tf.nest.map_structure(
                    lambda a, b: ps.where(is_initial_step, a, b), state,
                    proposed_state)

                normalized_log_weights = tf.nn.log_softmax(state.log_weights,
                                                           axis=0)
                # Every entry of `log_weights` differs from `normalized_log_weights`
                # by the same normalizing constant. We extract that constant by
                # examining an arbitrary entry.
                incremental_log_marginal_likelihood = (
                    state.log_weights[0] - normalized_log_weights[0])

                do_resample = self.resample_criterion_fn(state)

                # Some batch elements may require resampling and others not, so
                # we first do the resampling for all elements, then select whether to
                # use the resampled values for each batch element according to
                # `do_resample`. If there were no batching, we might prefer to use
                # `tf.cond` to avoid the resampling computation on steps where it's not
                # needed---but we're ultimately interested in adaptive resampling
                # for statistical (not computational) purposes, so this isn't a
                # dealbreaker.
                resampled_particles, resample_indices = weighted_resampling.resample(
                    state.particles,
                    state.log_weights,
                    self.resample_fn,
                    seed=seed)
                # Resampled particles all carry the weight `1 / num_particles`.
                uniform_weights = tf.fill(
                    tf.shape(state.log_weights),
                    value=-tf.math.log(
                        tf.cast(num_particles, state.log_weights.dtype)))
                # Where no resampling happens, keep the original particles,
                # placeholder parent indices, and the normalized weights.
                (resampled_particles, resample_indices,
                 log_weights) = tf.nest.map_structure(
                     lambda r, p: ps.where(do_resample, r, p),
                     (resampled_particles, resample_indices, uniform_weights),
                     (state.particles, _dummy_indices_like(resample_indices),
                      normalized_log_weights))

            return (
                WeightedParticles(particles=resampled_particles,
                                  log_weights=log_weights),
                SequentialMonteCarloResults(
                    steps=kernel_results.steps + 1,
                    parent_indices=resample_indices,
                    incremental_log_marginal_likelihood=(
                        incremental_log_marginal_likelihood),
                    accumulated_log_marginal_likelihood=(
                        kernel_results.accumulated_log_marginal_likelihood +
                        incremental_log_marginal_likelihood)))
예제 #16
0
def particle_filter(observations,
                    initial_state_prior,
                    transition_fn,
                    observation_fn,
                    num_particles,
                    initial_state_proposal=None,
                    proposal_fn=None,
                    rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
                    num_steps_state_history_to_pass=None,
                    num_steps_observation_history_to_pass=None,
                    seed=None,
                    name=None):  # pylint: disable=g-doc-args
  """Runs a particle filter and returns the particles from every timestep.

  At each timestep `t` the particles approximate the "filtering" distribution
  `p(state[t] | observations[:t])`, i.e., the distribution over the latent
  state conditioned on all observations *up to that time*. Since particles may
  be resampled between steps, the particle stored at a given index at time `t`
  is not necessarily an ancestor of the particle stored at the same index at
  time `t + 1`. To recover trajectories by tracing back through the resampling
  process, see `tfp.experimental.mcmc.reconstruct_trajectories`.

  ${particle_filter_arg_str}
  Returns:
    particles: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, b1, ..., bN, num_particles], event_shape])`,
      representing unbiased samples from the series of (filtering) distributions
      `p(latent_states[t] | observations[:t])`.
    parent_indices: `int` `Tensor` of shape
      `[num_timesteps, b1, ..., bN, num_particles]`,
      such that `parent_indices[t, k]` gives the index of the particle at
      time `t - 1` that the `k`th particle at time `t` is immediately descended
      from. See also
      `tfp.experimental.mcmc.reconstruct_trajectories`.
    step_log_marginal_likelihoods: float `Tensor` of shape
      `[num_timesteps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each timestep `t`. Note that (
      by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  ${non_markovian_specification_str}
  """
  seed_stream = SeedStream(seed, 'particle_filter')
  with tf.name_scope(name or 'particle_filter'):
    # The leading dimension of the first observation component is taken as
    # the number of timesteps.
    first_observation_part = tf.nest.flatten(observations)[0]
    num_timesteps = prefer_static.shape(first_observation_part)[0]

    # The initial step reuses the one-step filtering machinery by presenting
    # the prior (and the optional prior proposal) as if they were a
    # `transition_fn` / `proposal_fn` that ignore both of their arguments.
    def _sample_from_prior(_step, _particles):
      return SampleParticles(initial_state_prior, num_particles)

    def _sample_from_initial_proposal(_step, _particles):
      return SampleParticles(initial_state_proposal, num_particles)

    initial_proposal_fn = (
        _sample_from_initial_proposal
        if initial_state_proposal is not None else None)

    # Step 0: draw from the prior and fold in the first observation.
    # The "previous particles" here are a placeholder whose only role is to
    # convey state structure and num_particles to an optional proposal fn.
    placeholder_particles = _sample_from_prior(0, []).sample()
    first_observation = tf.nest.map_structure(
        lambda obs: tf.gather(obs, 0), observations)
    (first_particles,
     first_parent_indices,
     first_step_log_marginal_likelihood) = _filter_one_step(
         step=0,
         previous_particles=placeholder_particles,
         observation=first_observation,
         transition_fn=_sample_from_prior,
         observation_fn=observation_fn,
         proposal_fn=initial_proposal_fn,
         seed=seed_stream())

    initial_accumulators = _initialize_accumulated_quantities(
        first_particles,
        first_parent_indices,
        first_step_log_marginal_likelihood,
        num_steps_state_history_to_pass,
        num_timesteps)

    def _advance(step, particles, accumulators):
      """Runs one filtering step and records its outputs."""
      observation_at_step = tf.nest.map_structure(
          lambda obs: tf.gather(obs, step), observations)

      # Optional history is forwarded to the user-supplied callables as
      # keyword arguments; only the requested kinds of history are included.
      history_kwargs = {}
      if num_steps_observation_history_to_pass:
        history_kwargs['observation_history'] = _gather_history(
            observations, step, num_steps_observation_history_to_pass)
      if num_steps_state_history_to_pass:
        history_kwargs['state_history'] = accumulators.state_history

      bound_transition_fn = functools.partial(transition_fn, **history_kwargs)
      bound_observation_fn = functools.partial(
          observation_fn, **history_kwargs)
      bound_proposal_fn = None
      if proposal_fn is not None:
        bound_proposal_fn = functools.partial(proposal_fn, **history_kwargs)

      (next_particles,
       parent_indices,
       step_log_marginal_likelihood) = _filter_one_step(
           step=step,
           previous_particles=particles,
           observation=observation_at_step,
           transition_fn=bound_transition_fn,
           observation_fn=bound_observation_fn,
           proposal_fn=bound_proposal_fn,
           seed=seed_stream())

      updated_accumulators = _write_accumulated_quantities(
          step,
          accumulators,
          next_particles,
          parent_indices,
          step_log_marginal_likelihood)

      return step + 1, next_particles, updated_accumulators

    # Step 0 was handled above, so the loop starts at step 1.
    _, _, final_accumulators = tf.while_loop(
        cond=lambda step, *_: step < num_timesteps,
        body=_advance,
        loop_vars=(1, first_particles, initial_accumulators))

    stacked_particles = tf.nest.map_structure(
        lambda ta: ta.stack(), final_accumulators.all_resampled_particles)
    stacked_parent_indices = final_accumulators.all_parent_indices.stack()
    stacked_log_marginal_likelihoods = (
        final_accumulators.all_step_log_marginal_likelihoods.stack())
    return (stacked_particles,
            stacked_parent_indices,
            stacked_log_marginal_likelihoods)