def resample_with_target_distribution(self):
    particles = np.linspace(0., 500., num=2500, dtype=np.float32)
    log_weights = tfd.Poisson(20.).log_prob(particles)

    # Resample particles to target a Poisson(20.) distribution.
    new_particles, _, new_log_weights = resample(
        particles, log_weights,
        resample_fn=resample_systematic,
        seed=test_util.test_seed(sampler_type='stateless'))
    self.assertAllClose(tf.reduce_mean(new_particles), 20., atol=1e-2)
    self.assertAllClose(
        tf.reduce_sum(tf.nn.softmax(new_log_weights) * new_particles),
        20.,
        atol=1e-2)

    # Reweight the resampled particles to target a Poisson(30.) distribution.
    new_particles, _, new_log_weights = resample(
        particles, log_weights,
        resample_fn=resample_systematic,
        target_log_weights=tfd.Poisson(30.).log_prob(particles),
        seed=test_util.test_seed(sampler_type='stateless'))
    self.assertAllClose(tf.reduce_mean(new_particles), 20., atol=1e-2)
    self.assertAllClose(
        tf.reduce_sum(tf.nn.softmax(new_log_weights) * new_particles),
        30., atol=1.)
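For intuition, here is a self-contained NumPy sketch of the same idea; `poisson_log_prob` and `systematic_resample` are illustrative stand-ins for the TFP functions used above, not their implementations.

import numpy as np
from scipy.special import gammaln

def poisson_log_prob(x, rate):
  # Continuous interpolation of the Poisson log-pmf, which is what
  # `tfd.Poisson.log_prob` returns for non-integer inputs by default.
  return x * np.log(rate) - gammaln(x + 1.) - rate

def systematic_resample(particles, log_weights, rng):
  # Systematic resampling: one uniform offset, then evenly spaced picks
  # through the cumulative normalized weights.
  n = len(particles)
  weights = np.exp(log_weights - np.logaddexp.reduce(log_weights))
  positions = (rng.uniform() + np.arange(n)) / n
  return particles[np.searchsorted(np.cumsum(weights), positions)]

particles = np.linspace(0., 500., num=2500)
log_weights = poisson_log_prob(particles, 20.)
resampled = systematic_resample(particles, log_weights, np.random.default_rng(0))
print(resampled.mean())  # ~20, the mean of the targeted Poisson(20.)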
Example 2
    def smc_body_fn(stage, state, smc_kernel_result):
      """Run one stage of SMC with constant temperature."""
      (
          new_marginal,
          new_inv_temperature,
          log_weights
      ) = update_weights_temperature(
          smc_kernel_result.inverse_temperature,
          smc_kernel_result.particle_info.likelihood_log_prob)
      # TODO(b/152412213) Use a tf.scan to better collect debug info.
      if PRINT_DEBUG:
        tf.print(
            'Stage:', stage,
            'Beta:', new_inv_temperature,
            'n_steps:', smc_kernel_result.num_steps,
            'accept:', tf.exp(reduce_logmeanexp(
                smc_kernel_result.particle_info.log_accept_prob, axis=0)),
            'scaling:', tf.exp(reduce_logmeanexp(
                smc_kernel_result.particle_info.log_scalings, axis=0))
            )
      (resampled_state,
       resampled_particle_info), _ = weighted_resampling.resample(
           particles=(state, smc_kernel_result.particle_info),
           log_weights=log_weights,
           resample_fn=resample_fn,
           seed=seed_stream)
      next_num_steps, next_log_scalings = tuning_fn(
          smc_kernel_result.num_steps,
          resampled_particle_info.log_scalings,
          resampled_particle_info.log_accept_prob)
      # Skip tuning at stage 0.
      next_num_steps = tf.where(stage == 0,
                                smc_kernel_result.num_steps,
                                next_num_steps)
      next_log_scalings = tf.where(stage == 0,
                                   resampled_particle_info.log_scalings,
                                   next_log_scalings)
      next_num_steps = tf.clip_by_value(
          next_num_steps, min_num_steps, max_num_steps)

      next_state, log_accept_prob, tempered_log_prob = mutate(
          resampled_state,
          next_log_scalings,
          next_num_steps,
          new_inv_temperature)
      next_pkr = SMCResults(
          num_steps=next_num_steps,
          inverse_temperature=new_inv_temperature,
          log_marginal_likelihood=(new_marginal +
                                   smc_kernel_result.log_marginal_likelihood),
          particle_info=ParticleInfo(
              log_accept_prob=log_accept_prob,
              log_scalings=next_log_scalings,
              tempered_log_prob=tempered_log_prob,
              likelihood_log_prob=likelihood_log_prob_fn(*next_state),
          ))
      return stage + 1, next_state, next_pkr
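A minimal sketch of how a stage function like `smc_body_fn` is typically driven. The names `initial_state`, `initial_kernel_results`, and `max_stage` are placeholders rather than values defined above, and the stopping rule (anneal until the inverse temperature reaches 1) is the usual tempered-SMC convention, assumed here rather than taken from the snippet.

import tensorflow as tf

num_stages, final_state, final_kernel_results = tf.while_loop(
    cond=lambda stage, state, pkr: (
        (stage < max_stage) &
        tf.reduce_any(pkr.inverse_temperature < 1.)),
    body=smc_body_fn,
    loop_vars=(tf.zeros([], dtype=tf.int32),
               initial_state,
               initial_kernel_results))

Example 3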
    def one_step(self, state, kernel_results, seed=None):
        """Takes one Sequential Monte Carlo inference step.

        Args:
          state: instance of `tfp.experimental.mcmc.WeightedParticles` representing
            the current particles with (log) weights. The `log_weights` must be
            a float `Tensor` of shape `[num_particles, b1, ..., bN]`. The
            `particles` may be any structure of `Tensor`s, each of which
            must have shape `concat([log_weights.shape, event_shape])` for some
            `event_shape`, which may vary across components.
          kernel_results: instance of
            `tfp.experimental.mcmc.SequentialMonteCarloResults` representing results
            from a previous step.
          seed: Optional seed for reproducible sampling.

        Returns:
          state: instance of `tfp.experimental.mcmc.WeightedParticles` representing
            new particles with (log) weights.
          kernel_results: instance of
            `tfp.experimental.mcmc.SequentialMonteCarloResults`.
        """
        with tf.name_scope(self.name):
            with tf.name_scope('one_step'):
                seed = samplers.sanitize_seed(seed)
                proposal_seed, resample_seed = samplers.split_seed(seed)

                state = WeightedParticles(*state)  # Canonicalize.
                num_particles = ps.size0(state.log_weights)

                # Propose new particles and update weights for this step, unless it's
                # the initial step, in which case, use the user-provided initial
                # particles and weights.
                proposed_state = self.propose_and_update_log_weights_fn(
                    # Propose state[t] from state[t - 1].
                    ps.maximum(0, kernel_results.steps - 1),
                    state,
                    seed=proposal_seed)
                is_initial_step = ps.equal(kernel_results.steps, 0)
                # TODO(davmre): this `where` assumes the state size didn't change.
                state = tf.nest.map_structure(
                    lambda a, b: tf.where(is_initial_step, a, b), state,
                    proposed_state)

                normalized_log_weights = tf.nn.log_softmax(state.log_weights,
                                                           axis=0)
                # Every entry of `log_weights` differs from `normalized_log_weights`
                # by the same normalizing constant. We extract that constant by
                # examining an arbitrary entry.
                incremental_log_marginal_likelihood = (
                    state.log_weights[0] - normalized_log_weights[0])

                do_resample = self.resample_criterion_fn(state)

                # Some batch elements may require resampling and others not, so
                # we first do the resampling for all elements, then select whether to
                # use the resampled values for each batch element according to
                # `do_resample`. If there were no batching, we might prefer to use
                # `tf.cond` to avoid the resampling computation on steps where it's not
                # needed---but we're ultimately interested in adaptive resampling
                # for statistical (not computational) purposes, so this isn't a
                # dealbreaker.
                resampled_particles, resample_indices = weighted_resampling.resample(
                    state.particles,
                    state.log_weights,
                    self.resample_fn,
                    seed=resample_seed)
                uniform_weights = tf.fill(
                    ps.shape(state.log_weights),
                    value=-tf.math.log(
                        tf.cast(num_particles, state.log_weights.dtype)))
                (resampled_particles, resample_indices,
                 log_weights) = tf.nest.map_structure(
                     lambda r, p: ps.where(do_resample, r, p),
                     (resampled_particles, resample_indices, uniform_weights),
                     (state.particles, _dummy_indices_like(resample_indices),
                      normalized_log_weights))

            return (
                WeightedParticles(particles=resampled_particles,
                                  log_weights=log_weights),
                SequentialMonteCarloResults(
                    steps=kernel_results.steps + 1,
                    parent_indices=resample_indices,
                    incremental_log_marginal_likelihood=(
                        incremental_log_marginal_likelihood),
                    accumulated_log_marginal_likelihood=(
                        kernel_results.accumulated_log_marginal_likelihood +
                        incremental_log_marginal_likelihood),
                    seed=seed))
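The normalizing-constant comment in `one_step` is easy to check numerically: since log_softmax(x) = x - logsumexp(x), subtracting any entry of the normalized weights from the corresponding raw entry recovers logsumexp(log_weights), which is the incremental log marginal likelihood. A minimal sketch:

import tensorflow as tf

log_weights = tf.constant([-1.2, 0.3, -0.5])
normalized = tf.nn.log_softmax(log_weights)
# Both print the same value: the log normalizing constant.
print(log_weights[0] - normalized[0])
print(tf.reduce_logsumexp(log_weights))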
Example 4
def _filter_one_step(step,
                     observation,
                     previous_step_results,
                     transition_fn,
                     observation_fn,
                     proposal_fn,
                     resample_fn,
                     resample_criterion_fn,
                     has_observation=True,
                     seed=None):
    """Advances the particle filter by a single time step."""
    with tf.name_scope('filter_one_step'):
        seed = SeedStream(seed, 'filter_one_step')
        num_particles = ps.size0(previous_step_results.log_weights)

        proposed_particles, proposal_log_weights = _propose_with_log_weights(
            step=step - 1,
            particles=previous_step_results.particles,
            transition_fn=transition_fn,
            proposal_fn=proposal_fn,
            seed=seed)
        log_weights = tf.nn.log_softmax(proposal_log_weights +
                                        previous_step_results.log_weights,
                                        axis=-1)

        # If this step has an observation, compute its weights and marginal
        # likelihood (and otherwise, leave weights unchanged).
        observation_log_weights = ps.cond(
            has_observation,
            lambda: ps.broadcast_to(  # pylint: disable=g-long-lambda
                _compute_observation_log_weights(step, proposed_particles,
                                                 observation, observation_fn),
                ps.shape(log_weights)),
            lambda: tf.zeros_like(log_weights))

        unnormalized_log_weights = log_weights + observation_log_weights
        log_weights = tf.nn.log_softmax(unnormalized_log_weights, axis=0)
        # Every entry of `log_weights` differs from `unnormalized_log_weights`
        # by the same normalizing constant. We extract that constant by examining
        # an arbitrary entry.
        incremental_log_marginal_likelihood = (unnormalized_log_weights[0] -
                                               log_weights[0])

        # Adaptive resampling: resample particles only if the specified
        # criterion is met.
        do_resample = resample_criterion_fn(unnormalized_log_weights)

        # Some batch elements may require resampling and others not, so
        # we first do the resampling for all elements, then select whether to use
        # the resampled values for each batch element according to
        # `do_resample`. If there were no batching, we might prefer to use
        # `tf.cond` to avoid the resampling computation on steps where it's not
        # needed---but we're ultimately interested in adaptive resampling
        # for statistical (not computational) purposes, so this isn't a dealbreaker.
        resampled_particles, resample_indices = weighted_resampling.resample(
            proposed_particles, log_weights, resample_fn, seed=seed)

        uniform_weights = (ps.zeros_like(log_weights) - ps.log(num_particles))
        (resampled_particles, resample_indices,
         log_weights) = tf.nest.map_structure(
             lambda r, p: ps.where(do_resample, r, p),
             (resampled_particles, resample_indices, uniform_weights),
             (proposed_particles, _dummy_indices_like(resample_indices),
              log_weights))

    return ParticleFilterStepResults(
        particles=resampled_particles,
        log_weights=log_weights,
        parent_indices=resample_indices,
        incremental_log_marginal_likelihood=incremental_log_marginal_likelihood,
        accumulated_log_marginal_likelihood=(
            previous_step_results.accumulated_log_marginal_likelihood +
            incremental_log_marginal_likelihood))
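`resample_criterion_fn` above decides, per batch member, whether to resample at this step. A common choice is an effective-sample-size test; the sketch below assumes that rule (resample when ESS falls below half the number of particles), and the helper name `ess_below_half` is illustrative rather than a TFP API.

import tensorflow as tf

def ess_below_half(log_weights, particles_axis=0):
  # Effective sample size ESS = 1 / sum_i(w_i^2) for normalized weights w,
  # computed in log space for stability.
  num_particles = tf.cast(
      tf.shape(log_weights)[particles_axis], log_weights.dtype)
  normalized = tf.nn.log_softmax(log_weights, axis=particles_axis)
  log_ess = -tf.reduce_logsumexp(2. * normalized, axis=particles_axis)
  return log_ess < tf.math.log(num_particles / 2.)

With uniform weights the ESS equals the number of particles, so no resampling is triggered; as the weights degenerate onto a few particles the ESS falls toward 1 and resampling kicks in.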