Example #1
def ess_below_threshold(unnormalized_log_weights, threshold=0.5):
    """Determines if the effective sample size is below `threshold * num_particles`."""
    with tf.name_scope('ess_below_threshold'):
        num_particles = ps.size0(unnormalized_log_weights)
        log_weights = tf.math.log_softmax(unnormalized_log_weights, axis=0)
        # Effective sample size ESS = 1 / sum(w_i**2) for normalized weights w,
        # computed here in log space.
        log_ess = -tf.math.reduce_logsumexp(2 * log_weights, axis=0)
        return log_ess < (ps.log(num_particles) + ps.log(threshold))
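A minimal standalone sketch (plain NumPy; the values are illustrative) of the same criterion: for normalized weights `w`, the effective sample size is `1 / sum(w_i**2)`, and resampling is triggered when it falls below `threshold * num_particles`.

import numpy as np

unnormalized_log_weights = np.log([0.7, 0.1, 0.1, 0.1])
log_w = unnormalized_log_weights - np.logaddexp.reduce(unnormalized_log_weights)
log_ess = -np.logaddexp.reduce(2. * log_w)            # log(1 / sum(w_i**2))
num_particles = log_w.shape[0]
print(np.exp(log_ess))                                # ~1.92 effective particles
print(log_ess < np.log(num_particles) + np.log(0.5))  # True: ESS < 0.5 * 4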
Example #2
def _filter_one_step(step,
                     observation,
                     previous_particles,
                     log_weights,
                     transition_fn,
                     observation_fn,
                     proposal_fn,
                     resample_criterion_fn,
                     seed=None):
    """Advances the particle filter by a single time step."""
    with tf.name_scope('filter_one_step'):
        seed = SeedStream(seed, 'filter_one_step')
        num_particles = prefer_static.shape(log_weights)[-1]

        proposed_particles, proposal_log_weights = _propose_with_log_weights(
            step=step - 1,
            particles=previous_particles,
            transition_fn=transition_fn,
            proposal_fn=proposal_fn,
            seed=seed)

        observation_log_weights = _compute_observation_log_weights(
            step, proposed_particles, observation, observation_fn)
        unnormalized_log_weights = (log_weights + proposal_log_weights +
                                    observation_log_weights)
        step_log_marginal_likelihood = tf.math.reduce_logsumexp(
            unnormalized_log_weights, axis=-1)
        log_weights = (unnormalized_log_weights -
                       step_log_marginal_likelihood[..., tf.newaxis])

        # Adaptive resampling: resample particles iff the specified criterion is met.
        do_resample = tf.convert_to_tensor(
            resample_criterion_fn(unnormalized_log_weights))[
                ..., tf.newaxis]  # Broadcast over particles.

        # Some batch elements may require resampling and others not, so
        # we first do the resampling for all elements, then select whether to use
        # the resampled values for each batch element according to
        # `do_resample`. If there were no batching, we might prefer to use
        # `tf.cond` to avoid the resampling computation on steps where it's not
        # needed---but we're ultimately interested in adaptive resampling
        # for statistical (not computational) purposes, so this isn't a dealbreaker.
        resampled_particles, resample_indices = _resample(proposed_particles,
                                                          log_weights,
                                                          seed=seed)
        dummy_indices = tf.broadcast_to(prefer_static.range(num_particles),
                                        prefer_static.shape(resample_indices))
        uniform_weights = (prefer_static.zeros_like(log_weights) -
                           prefer_static.log(num_particles))
        (resampled_particles, resample_indices,
         log_weights) = tf.nest.map_structure(
             lambda r, p: prefer_static.where(do_resample, r, p),
             (resampled_particles, resample_indices, uniform_weights),
             (proposed_particles, dummy_indices, log_weights))

    return ParticleFilterStepResults(
        particles=resampled_particles,
        log_weights=log_weights,
        parent_indices=resample_indices,
        step_log_marginal_likelihood=step_log_marginal_likelihood)
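The comment above describes resampling every batch member and then selecting per batch element. A small hedged sketch (NumPy, illustrative shapes) of that select pattern, with particles on the last axis as in this version:

import numpy as np

num_particles, batch = 3, 2
original = np.arange(batch * num_particles, dtype=np.float32).reshape(batch, num_particles)
resampled = -original                       # stand-in for the resampled particles
do_resample = np.array([[True], [False]])   # shape [batch, 1]; broadcasts over particles
print(np.where(do_resample, resampled, original))
# Row 0 takes the resampled values; row 1 keeps the originals.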
Example #3
def _forward_log_det_jacobian(self, x):
    # This code is similar to tf.math.log_softmax but different because we have
    # an implicit zero column to handle. I.e., instead of:
    #   reduce_sum(logits - log(reduce_sum(exp(logits), dim)))
    # we must do:
    #   log_normalization = log(1 + reduce_sum(exp(logits)))
    #   -log_normalization + reduce_sum(logits - log_normalization)
    # The extra 0.5 * log(np1) term below is the constant Jacobian of the lift
    # onto the simplex in R^{n+1}; see `_inverse_log_det_jacobian` in Example #5.
    np1 = prefer_static.cast(1 + prefer_static.shape(x)[-1], dtype=x.dtype)
    return (0.5 * prefer_static.log(np1) + tf.reduce_sum(x, axis=-1) -
            np1 * tf.math.softplus(tf.reduce_logsumexp(x, axis=-1)))
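A quick numeric check (NumPy) of the identity this code relies on: `softplus(logsumexp(x))` equals the log-normalization `log(1 + sum(exp(x)))` that accounts for the implicit zero column.

import numpy as np

x = np.array([0.3, -1.2, 2.0])
lse = np.logaddexp.reduce(x)
softplus_of_lse = np.log1p(np.exp(lse))    # softplus(v) = log(1 + exp(v))
print(np.allclose(softplus_of_lse, np.log(1. + np.exp(x).sum())))   # True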
Example #4
def _log_soosum_exp_impl(logx, axis, keepdims, compute_mean):
    """Implementation for `*soosum*` functions."""
    with tf.name_scope('log_soosum_exp_impl'):
        logx = tf.convert_to_tensor(logx, name='logx')
        log_loosum_x, log_sum_x, n = _log_loosum_exp_impl(logx,
                                                          axis,
                                                          keepdims,
                                                          compute_mean=False)
        # The swap-one-out-sum ('soosum') is n different sums, each of which
        # replaces the i-th item with the i-th-left-out average (or a
        # user-specified value), i.e.,
        # soo_sum_x[i] = [sum(exp(logx)) - exp(logx[i])] + exp(mean(logx[!=i]))
        #              =  exp(log_loosum_x[i])           + exp(loo_log_swap_in[i])
        loo_log_swap_in = (
            (tf.reduce_sum(logx, axis=axis, keepdims=True) - logx) / (n - 1.))
        log_soosum_x = log_add_exp(log_loosum_x, loo_log_swap_in)
        if not compute_mean:
            return log_soosum_x, log_sum_x
        log_n = prefer_static.log(n)
        return log_soosum_x - log_n, log_sum_x - log_n
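A hedged brute-force check (NumPy; array values are illustrative) of the 'soosum' definition in the comment: each entry replaces the i-th term of the sum with the exponentiated leave-one-out mean of `logx`.

import numpy as np

logx = np.array([0.1, -0.4, 0.7, 0.2])
n = logx.size
loo_mean = (logx.sum() - logx) / (n - 1.)             # mean of logx leaving out i
brute = np.array([np.exp(np.delete(logx, i)).sum() + np.exp(loo_mean[i])
                  for i in range(n)])
log_loosum = np.log([np.exp(np.delete(logx, i)).sum() for i in range(n)])
log_soosum = np.logaddexp(log_loosum, loo_mean)       # mirrors `log_add_exp` above
print(np.allclose(np.exp(log_soosum), brute))         # True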
Example #5
def _inverse_log_det_jacobian(self, y):
    # Let B be the forward map defined by the bijector. Consider the map
    # F : R^n -> R^n obtained by restricting the image of B in R^{n+1} to the
    # first n coordinates.
    #
    # Claim: det{ dF(X)/dX } = prod(Y) where Y = B(X).
    # Proof: WLOG, in vector notation:
    #     X = log(Y[:-1]) - log(Y[-1])
    #   where,
    #     Y[-1] = 1 - sum(Y[:-1]).
    #   We have:
    #     det{dF} = 1 / det{ dX/dF(X) }                                      (1)
    #             = 1 / det{ diag(1 / Y[:-1]) + 1 / Y[-1] }
    #             = 1 / det{ inv{ diag(Y[:-1]) - Y[:-1]' Y[:-1] } }
    #             = det{ diag(Y[:-1]) - Y[:-1]' Y[:-1] }
    #             = (1 - Y[:-1]' inv{diag(Y[:-1])} Y[:-1]) det{diag(Y[:-1])} (2)
    #             = Y[-1] prod(Y[:-1])
    #             = prod(Y)
    #
    # Let P be the image of R^n under F. Define the lift G, from P to R^{n+1},
    # which appends the last coordinate, Y[-1] := 1 - \sum_k Y_k. G is affine,
    # so its Jacobian is constant.
    #
    # The differential of G, DG, is eye(n) with a row of -1s appended to the
    # bottom. To compute the Jacobian sqrt{det{(DG)^T(DG)}}, one can see that
    # (DG)^T(DG) = A + eye(n), where A is the n x n matrix of 1s. This has
    # eigenvalues (n + 1, 1,...,1), so the determinant is (n + 1). Hence, the
    # Jacobian of G is sqrt{n + 1} everywhere.
    #
    # Putting it all together, the forward bijective map B can be written as
    # B(X) = G(F(X)) and has Jacobian sqrt{n + 1} * prod(F(X)).
    #
    # (1) - https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula
    #       or by noting that det{ dX/dY } = 1 / det{ dY/dX } from Bijector
    #       docstring "Tip".
    # (2) - https://en.wikipedia.org/wiki/Matrix_determinant_lemma
    np1 = ps.cast(ps.shape(y)[-1], dtype=y.dtype)
    return -(0.5 * ps.log(np1) +
             tf.reduce_sum(tf.math.log(y), axis=-1))
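A hedged numeric spot-check (NumPy) of the determinant claim in the derivation: for a point `Y` on the simplex, `det{ diag(Y[:-1]) - Y[:-1]' Y[:-1] }` equals `Y[-1] * prod(Y[:-1]) = prod(Y)`.

import numpy as np

y = np.array([0.2, 0.5, 0.1, 0.2])   # sums to 1
p = y[:-1]
det = np.linalg.det(np.diag(p) - np.outer(p, p))
print(np.allclose(det, np.prod(y)))   # True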
Example #6
def _particle_filter_initial_weighted_particles(observations,
                                                observation_fn,
                                                initial_state_prior,
                                                initial_state_proposal,
                                                num_particles,
                                                seed=None):
  """Initialize a set of weighted particles including the first observation."""
  # Initial particles all have the same weight, `1. / num_particles`.
  broadcast_batch_shape = tf.convert_to_tensor(
      functools.reduce(
          ps.broadcast_shape,
          tf.nest.flatten(initial_state_prior.batch_shape_tensor()),
          []), dtype=tf.int32)
  initial_log_weights = ps.zeros(
      ps.concat([[num_particles], broadcast_batch_shape], axis=0),
      dtype=tf.float32) - ps.log(num_particles)

  # Propose an initial state.
  if initial_state_proposal is None:
    initial_state = initial_state_prior.sample(num_particles, seed=seed)
  else:
    initial_state = initial_state_proposal.sample(num_particles, seed=seed)
    initial_log_weights += (initial_state_prior.log_prob(initial_state) -
                            initial_state_proposal.log_prob(initial_state))
    # The initial proposal weights are normalized in expectation, but actually
    # normalizing them reduces variance in the initial marginal
    # likelihood.
    initial_log_weights = tf.nn.log_softmax(initial_log_weights, axis=0)

  # Return particles weighted by the initial observation.
  return smc_kernel.WeightedParticles(
      particles=initial_state,
      log_weights=initial_log_weights + _compute_observation_log_weights(
          step=0,
          particles=initial_state,
          observations=observations,
          observation_fn=observation_fn))
Example #7
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
    """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as  (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x:  `float32` or `complex64` `Tensor`.
    axis:  Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags:  Positive `int` tensor.  The maximum value of `m` to consider (in
      equation above).  If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center:  Python `bool`.  If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError:  If `x` is not a supported type.
  """
    # Implementation details:
    # Extend length N / 2 1-D array x to length N by zero padding onto the end.
    # Then, set
    #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
    # It is not hard to see that
    #   F[x]_k Conj(F[x]_k) = F[R]_k, where
    #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
    # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

    # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
    # based version of estimating RXX.
    # Note that this is a special case of the Wiener-Khinchin Theorem.
    with tf.name_scope(name):
        x = tf.convert_to_tensor(x, name='x')

        # Rotate dimensions of x in order to put axis at the rightmost dim.
        # FFT op requires this.
        rank = ps.rank(x)
        if axis < 0:
            axis = rank + axis
        shift = rank - 1 - axis
        # Suppose x.shape[axis] = T, so there are T 'time' steps.
        #   ==> x_rotated.shape = B + [T],
        # where B is x_rotated's batch shape.
        x_rotated = distribution_util.rotate_transpose(x, shift)

        if center:
            x_rotated = x_rotated - tf.reduce_mean(
                x_rotated, axis=-1, keepdims=True)

        # x_len = N / 2 from above explanation.  The length of x along axis.
        # Get a value for x_len that works in all cases.
        x_len = ps.shape(x_rotated)[-1]

        # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
        # the moment it is necessary so that all FFT implementations work.
        # Zero pad to the next power of 2 greater than 2 * x_len, which equals
        # 2**(ceil(Log_2(2 * x_len))).  Note: Log_2(X) = Log_e(X) / Log_e(2).
        x_len_float64 = ps.cast(x_len, np.float64)
        target_length = ps.pow(np.float64(2.),
                               ps.ceil(ps.log(x_len_float64 * 2) / np.log(2.)))
        pad_length = ps.cast(target_length - x_len_float64, np.int32)

        # We should have:
        # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
        #                     = B + [T + pad_length]
        x_rotated_pad = distribution_util.pad(x_rotated,
                                              axis=-1,
                                              back=True,
                                              count=pad_length)

        dtype = x.dtype
        if not dtype_util.is_complex(dtype):
            if not dtype_util.is_floating(dtype):
                raise TypeError(
                    'Argument x must have either float or complex dtype'
                    ' found: {}'.format(dtype))
            x_rotated_pad = tf.complex(
                x_rotated_pad,
                dtype_util.as_numpy_dtype(dtype_util.real_dtype(dtype))(0.))

        # Autocorrelation is IFFT of power-spectral density (up to some scaling).
        fft_x_rotated_pad = tf.signal.fft(x_rotated_pad)
        spectral_density = fft_x_rotated_pad * tf.math.conj(fft_x_rotated_pad)
        # shifted_product is R[m] from above detailed explanation.
        # It is the inner product sum_n X[n] * Conj(X[n - m]).
        shifted_product = tf.signal.ifft(spectral_density)

        # Cast back to real-valued if x was real to begin with.
        shifted_product = tf.cast(shifted_product, dtype)

        # Figure out if we can deduce the final static shape, and set max_lags.
        # Use x_rotated as a reference, because it has the time dimension in the far
        # right, and was created before we performed all sorts of crazy shape
        # manipulations.
        know_static_shape = True
        if not tensorshape_util.is_fully_defined(x_rotated.shape):
            know_static_shape = False
        if max_lags is None:
            max_lags = x_len - 1
        else:
            max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
            max_lags_ = tf.get_static_value(max_lags)
            if max_lags_ is None or not know_static_shape:
                know_static_shape = False
                max_lags = tf.minimum(x_len - 1, max_lags)
            else:
                max_lags = min(x_len - 1, max_lags_)

        # Chop off the padding.
        # We allow users to provide a huge max_lags, but cut it off here.
        # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags]
        shifted_product_chopped = shifted_product[..., :max_lags + 1]

        # If possible, set shape.
        if know_static_shape:
            chopped_shape = tensorshape_util.as_list(x_rotated.shape)
            chopped_shape[-1] = min(x_len, max_lags + 1)
            tensorshape_util.set_shape(shifted_product_chopped, chopped_shape)

        # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
        # other terms were zeros arising only due to zero padding.
        # `denominator = (N / 2 - m)` (defined below) is the proper term to
        # divide by to make this an unbiased estimate of the expectation
        # E[X[n] Conj(X[n - m])].
        x_len = ps.cast(x_len, dtype_util.real_dtype(dtype))
        max_lags = ps.cast(max_lags, dtype_util.real_dtype(dtype))
        denominator = x_len - ps.range(0., max_lags + 1.)
        denominator = ps.cast(denominator, dtype)
        shifted_product_rotated = shifted_product_chopped / denominator

        if normalize:
            shifted_product_rotated /= shifted_product_rotated[..., :1]

        # Transpose dimensions back to those of x.
        return distribution_util.rotate_transpose(shifted_product_rotated,
                                                  -shift)
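A standalone sketch (NumPy; sizes are illustrative) of the estimator described in the docstring, computed by brute force rather than via FFT: `rxx[m] = (L - m)**-1 sum_n w[n + m] Conj(w[n])` with `w = (x - mu) / s`.

import numpy as np

x = np.random.randn(100)
L = x.size
mu = x.mean()
s = np.sqrt(((x - mu) * np.conj(x - mu)).mean())
w = (x - mu) / s
max_lags = 5
rxx = np.array([(w[m:] * np.conj(w[:L - m])).sum() / (L - m)
                for m in range(max_lags + 1)])
print(rxx[0])   # ~1.0: lag-0 autocorrelation of the centered, normalized sequence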
Example #8
def _filter_one_step(step,
                     observation,
                     previous_particles,
                     log_weights,
                     transition_fn,
                     observation_fn,
                     proposal_fn,
                     resample_criterion_fn,
                     has_observation=True,
                     seed=None):
    """Advances the particle filter by a single time step."""
    with tf.name_scope('filter_one_step'):
        seed = SeedStream(seed, 'filter_one_step')
        num_particles = prefer_static.shape(log_weights)[0]

        proposed_particles, proposal_log_weights = _propose_with_log_weights(
            step=step - 1,
            particles=previous_particles,
            transition_fn=transition_fn,
            proposal_fn=proposal_fn,
            seed=seed)
        # Normalize along the particle axis (axis 0 in this implementation).
        log_weights = tf.nn.log_softmax(proposal_log_weights + log_weights,
                                        axis=0)

        # If this step has an observation, compute its weights and marginal
        # likelihood (and otherwise, leave weights unchanged).
        observation_log_weights = prefer_static.cond(
            has_observation,
            lambda: prefer_static.broadcast_to(  # pylint: disable=g-long-lambda
                _compute_observation_log_weights(step, proposed_particles,
                                                 observation, observation_fn),
                prefer_static.shape(log_weights)),
            lambda: tf.zeros_like(log_weights))

        unnormalized_log_weights = log_weights + observation_log_weights
        step_log_marginal_likelihood = tf.math.reduce_logsumexp(
            unnormalized_log_weights, axis=0)
        log_weights = (unnormalized_log_weights - step_log_marginal_likelihood)

        # Adaptive resampling: resample particles iff the specified criterion is met.
        do_resample = resample_criterion_fn(unnormalized_log_weights)

        # Some batch elements may require resampling and others not, so
        # we first do the resampling for all elements, then select whether to use
        # the resampled values for each batch element according to
        # `do_resample`. If there were no batching, we might prefer to use
        # `tf.cond` to avoid the resampling computation on steps where it's not
        # needed---but we're ultimately interested in adaptive resampling
        # for statistical (not computational) purposes, so this isn't a dealbreaker.
        resampled_particles, resample_indices = _resample(proposed_particles,
                                                          log_weights,
                                                          resample_independent,
                                                          seed=seed)

        uniform_weights = (prefer_static.zeros_like(log_weights) -
                           prefer_static.log(num_particles))
        (resampled_particles, resample_indices,
         log_weights) = tf.nest.map_structure(
             lambda r, p: prefer_static.where(do_resample, r, p),
             (resampled_particles, resample_indices, uniform_weights),
             (proposed_particles, _dummy_indices_like(resample_indices),
              log_weights))

    return ParticleFilterStepResults(
        particles=resampled_particles,
        log_weights=log_weights,
        parent_indices=resample_indices,
        step_log_marginal_likelihood=step_log_marginal_likelihood)
Example #9
def particle_filter(
        observations,
        initial_state_prior,
        transition_fn,
        observation_fn,
        num_particles,
        initial_state_proposal=None,
        proposal_fn=None,
        resample_criterion_fn=ess_below_threshold,
        rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
        num_transitions_per_observation=1,
        num_steps_state_history_to_pass=None,
        num_steps_observation_history_to_pass=None,
        seed=None,
        name=None):  # pylint: disable=g-doc-args
    """Samples a series of particles representing filtered latent states.

  The particle filter samples from the sequence of "filtering" distributions
  `p(state[t] | observations[:t])` over latent
  states: at each point in time, this is the distribution conditioned on all
  observations *up to that time*. Because particles may be resampled, a particle
  at time `t` may be different from the particle with the same index at time
  `t + 1`. To reconstruct trajectories by tracing back through the resampling
  process, see `tfp.experimental.mcmc.reconstruct_trajectories`.

  ${particle_filter_arg_str}
  Returns:
    particles: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing (possibly weighted) samples from the series of filtering
      distributions `p(latent_states[t] | observations[:t])`.
    log_weights: `float` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`, such that
      `log_weights[t, :]` are the logarithms of normalized importance weights
      (such that `exp(reduce_logsumexp(log_weights[t], axis=0)) == 1.`) of
      the particles at time `t`. These may be used in conjunction with
      `particles` to compute expectations under the series of filtering
      distributions.
    parent_indices: `int` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`,
      such that `parent_indices[t, k]` gives the index of the particle at
      time `t - 1` that the `k`th particle at time `t` is immediately descended
      from. See also
      `tfp.experimental.mcmc.reconstruct_trajectories`.
    step_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each observed timestep `t`.
      Note that (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  ${non_markovian_specification_str}
  """
    seed = SeedStream(seed, 'particle_filter')
    with tf.name_scope(name or 'particle_filter'):
        num_observation_steps = prefer_static.shape(
            tf.nest.flatten(observations)[0])[0]
        num_timesteps = (1 + num_transitions_per_observation *
                         (num_observation_steps - 1))

        # If no criterion is specified, default is to resample at every step.
        if not resample_criterion_fn:
            resample_criterion_fn = lambda _: True

        # Dress up the prior and prior proposal as a fake `transition_fn` and
        # `proposal_fn` respectively.
        prior_fn = lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
            initial_state_prior, num_particles)
        prior_proposal_fn = (
            None if initial_state_proposal is None else
            lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
                initial_state_proposal, num_particles))

        # Initially the particles all have the same weight, `1. / num_particles`.
        broadcast_batch_shape = tf.convert_to_tensor(functools.reduce(
            prefer_static.broadcast_shape,
            tf.nest.flatten(initial_state_prior.batch_shape_tensor()), []),
                                                     dtype=tf.int32)
        log_uniform_weights = prefer_static.zeros(
            prefer_static.concat([[num_particles], broadcast_batch_shape],
                                 axis=0),
            dtype=tf.float32) - prefer_static.log(num_particles)

        # Initialize from the prior, and incorporate the first observation.
        initial_step_results = _filter_one_step(
            step=0,
            # `previous_particles` at the first step is a dummy quantity, used only
            # to convey state structure and num_particles to an optional
            # proposal fn.
            previous_particles=prior_fn(0, []).sample(),
            log_weights=log_uniform_weights,
            observation=tf.nest.map_structure(lambda x: tf.gather(x, 0),
                                              observations),
            transition_fn=prior_fn,
            observation_fn=observation_fn,
            proposal_fn=prior_proposal_fn,
            resample_criterion_fn=resample_criterion_fn,
            seed=seed)

        def _loop_body(step, previous_step_results, accumulated_step_results,
                       state_history):
            """Take one step in dynamics and accumulate marginal likelihood."""

            step_has_observation = (
                # The second of these conditions subsumes the first, but both are
                # useful because the first can often be evaluated statically.
                prefer_static.equal(num_transitions_per_observation, 1) |
                prefer_static.equal(step % num_transitions_per_observation, 0))
            observation_idx = step // num_transitions_per_observation
            current_observation = tf.nest.map_structure(
                lambda x, step=step: tf.gather(x, observation_idx),
                observations)

            history_to_pass_into_fns = {}
            if num_steps_observation_history_to_pass:
                history_to_pass_into_fns[
                    'observation_history'] = _gather_history(
                        observations, observation_idx,
                        num_steps_observation_history_to_pass)
            if num_steps_state_history_to_pass:
                history_to_pass_into_fns['state_history'] = state_history

            new_step_results = _filter_one_step(
                step=step,
                previous_particles=previous_step_results.particles,
                log_weights=previous_step_results.log_weights,
                observation=current_observation,
                transition_fn=functools.partial(transition_fn,
                                                **history_to_pass_into_fns),
                observation_fn=functools.partial(observation_fn,
                                                 **history_to_pass_into_fns),
                proposal_fn=(None
                             if proposal_fn is None else functools.partial(
                                 proposal_fn, **history_to_pass_into_fns)),
                resample_criterion_fn=resample_criterion_fn,
                has_observation=step_has_observation,
                seed=seed)

            return _update_loop_variables(step, new_step_results,
                                          accumulated_step_results,
                                          state_history)

        loop_results = tf.while_loop(
            cond=lambda step, *_: step < num_timesteps,
            body=_loop_body,
            loop_vars=_initialize_loop_variables(
                initial_step_results, num_steps_state_history_to_pass,
                num_timesteps))

        results = tf.nest.map_structure(lambda ta: ta.stack(),
                                        loop_results.accumulated_step_results)
        if num_transitions_per_observation != 1:
            # Return a log-prob for each observed step.
            observed_steps = prefer_static.range(
                0, num_timesteps, num_transitions_per_observation)
            results = results._replace(step_log_marginal_likelihood=tf.gather(
                results.step_log_marginal_likelihood, observed_steps))
        return results
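A minimal usage sketch for the function above, filtering a 1-D Gaussian random walk. It assumes the usual TFP convention (not spelled out in the docstring excerpt here) that `transition_fn(step, state)` and `observation_fn(step, state)` each return a `tfd.Distribution`; treat the exact call as illustrative.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
observations = tf.convert_to_tensor([1.3, 0.2, -0.7, 1.1, 2.4])

results = particle_filter(
    observations=observations,
    initial_state_prior=tfd.Normal(loc=0., scale=1.),
    transition_fn=lambda step, state: tfd.Normal(loc=state, scale=0.5),
    observation_fn=lambda step, state: tfd.Normal(loc=state, scale=1.),
    num_particles=1024,
    seed=42)
# `results` stacks, over time, the particles, log_weights, parent_indices and
# step_log_marginal_likelihood values described in the Returns section above.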
Example #10
def _log_loosum_exp_impl(logx, axis, keepdims, compute_mean):
    """Implementation for `*loosum*` functions."""
    with tf.name_scope('log_loosum_exp_impl'):
        logx = tf.convert_to_tensor(logx, name='logx')
        dtype = dtype_util.as_numpy_dtype(logx.dtype)

        if axis is not None:
            axis_arr = np.array(axis)
            axis = (tf.convert_to_tensor(
                axis, name='axis', dtype_hint=tf.int32)
                    if axis_arr.dtype == np.object_ else axis_arr.astype(np.int32))

        log_sum_x = tf.reduce_logsumexp(logx, axis=axis, keepdims=True)

        # Later we'll want to compute the mean from a sum so we calculate the number
        # of reduced elements, n.
        n = prefer_static.size(logx) // prefer_static.size(log_sum_x)
        n = prefer_static.cast(n, dtype)

        # log_loosum_x[i] =
        # = logsumexp(logx[j] : j != i)
        # = log( exp(logsumexp(logx)) - exp(logx[i]) )
        # = log( exp(logsumexp(logx - logx[i])) exp(logx[i])  - exp(logx[i]))
        # = logx[i] + log(exp(logsumexp(logx - logx[i])) - 1)
        # = logx[i] + log(exp(logsumexp(logx) - logx[i]) - 1)
        # = logx[i] + softplus_inverse(logsumexp(logx) - logx[i])
        d = log_sum_x - logx
        # We use `d != 0` rather than `d > 0.` because `d < 0.` should never happen;
        # if it does we want to complain loudly (which `softplus_inverse` will).
        d_ok = tf.not_equal(d, 0.)
        safe_d = tf.where(d_ok, d, 1.)
        d_ok_result = logx + softplus_inverse(safe_d)

        neg_inf = tf.constant(-np.inf, dtype=dtype)

        # When not(d_ok) and is_positive_and_largest then we manually compute the
        # log_loosum_x. (We can efficiently do this for any one point but not all,
        # hence we still need the above calculation.) This is good because when
        # this condition is met, we cannot use the above calculation; it's -inf.
        # We now compute the log-leave-out-max-sum, replicate it to every
        # point and make sure to select it only when we need to.
        max_logx = tf.reduce_max(logx, axis=axis, keepdims=True)
        is_positive_and_largest = (logx > 0.) & tf.equal(logx, max_logx)
        log_lomsum_x = tf.reduce_logsumexp(tf.where(is_positive_and_largest,
                                                    neg_inf, logx),
                                           axis=axis,
                                           keepdims=True)
        d_not_ok_result = tf.where(is_positive_and_largest, log_lomsum_x,
                                   neg_inf)

        log_loosum_x = tf.where(d_ok, d_ok_result, d_not_ok_result)

        # We now squeeze log_sum_x so as if we used `keepdims=False`.
        # TODO(b/136176077): These mental gymnastics could all be replaced with
        # `tf.squeeze(log_sum_x, axis)` if tf.squeeze supported Tensor valued `axis`
        # arguments.
        if not keepdims:
            if axis is None:
                keepdims = np.array([], dtype=np.int32)
            else:
                rank = prefer_static.rank(logx)
                keepdims = prefer_static.setdiff1d(
                    prefer_static.range(rank),
                    prefer_static.non_negative_axis(axis, rank))
            squeeze_shape = tf.gather(prefer_static.shape(logx),
                                      indices=keepdims)
            log_sum_x = tf.reshape(log_sum_x, shape=squeeze_shape)
            if prefer_static.is_numpy(keepdims):
                tensorshape_util.set_shape(log_sum_x,
                                           np.array(logx.shape)[keepdims])

        # Set static shapes just in case we lost them.
        tensorshape_util.set_shape(n, [])
        tensorshape_util.set_shape(log_loosum_x, logx.shape)

        if not compute_mean:
            return log_loosum_x, log_sum_x, n

        log_nm1 = prefer_static.log(max(1., n - 1.))
        log_n = prefer_static.log(n)
        return log_loosum_x - log_nm1, log_sum_x - log_n, n
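A hedged brute-force check (NumPy) of the leave-one-out identity derived in the comment: `log_loosum_x[i] = logsumexp(logx[j] : j != i) = logx[i] + softplus_inverse(logsumexp(logx) - logx[i])`.

import numpy as np

def softplus_inverse(v):
    # log(exp(v) - 1), written stably as v + log(-expm1(-v)).
    return v + np.log(-np.expm1(-v))

logx = np.array([0.3, -1.0, 0.8, 0.1])
lse = np.logaddexp.reduce(logx)
derived = logx + softplus_inverse(lse - logx)
brute = np.array([np.logaddexp.reduce(np.delete(logx, i)) for i in range(logx.size)])
print(np.allclose(derived, brute))   # True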
Example #11
def particle_filter(
        observations,
        initial_state_prior,
        transition_fn,
        observation_fn,
        num_particles,
        initial_state_proposal=None,
        proposal_fn=None,
        resample_fn=weighted_resampling.resample_systematic,
        resample_criterion_fn=ess_below_threshold,
        rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
        num_transitions_per_observation=1,
        trace_fn=_default_trace_fn,
        step_indices_to_trace=None,
        seed=None,
        name=None):  # pylint: disable=g-doc-args
    """Samples a series of particles representing filtered latent states.

  The particle filter samples from the sequence of "filtering" distributions
  `p(state[t] | observations[:t])` over latent
  states: at each point in time, this is the distribution conditioned on all
  observations *up to that time*. Because particles may be resampled, a particle
  at time `t` may be different from the particle with the same index at time
  `t + 1`. To reconstruct trajectories by tracing back through the resampling
  process, see `tfp.experimental.mcmc.reconstruct_trajectories`.

  ${particle_filter_arg_str}
    trace_fn: Python `callable` defining the values to be traced at each step.
      It takes a `ParticleFilterStepResults` tuple and returns a structure of
      `Tensor`s. The default function returns
      `(particles, log_weights, parent_indices, step_log_likelihood)`.
    step_indices_to_trace: optional `int` `Tensor` listing, in increasing order,
      the indices of steps at which to record the values traced by `trace_fn`.
      If `None`, the default behavior is to trace at every timestep,
      equivalent to specifying `step_indices_to_trace=tf.range(num_timesteps)`.
    seed: Python `int` seed for random ops.
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'particle_filter'`).
  Returns:
    particles: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing (possibly weighted) samples from the series of filtering
      distributions `p(latent_states[t] | observations[:t])`.
    log_weights: `float` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`, such that
      `log_weights[t, :]` are the logarithms of normalized importance weights
      (such that `exp(reduce_logsumexp(log_weights[t], axis=0)) == 1.`) of
      the particles at time `t`. These may be used in conjunction with
      `particles` to compute expectations under the series of filtering
      distributions.
    parent_indices: `int` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`,
      such that `parent_indices[t, k]` gives the index of the particle at
      time `t - 1` that the `k`th particle at time `t` is immediately descended
      from. See also
      `tfp.experimental.mcmc.reconstruct_trajectories`.
    incremental_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each observed timestep `t`.
      Note that (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  """
    seed = SeedStream(seed, 'particle_filter')
    with tf.name_scope(name or 'particle_filter'):
        num_observation_steps = ps.size0(tf.nest.flatten(observations)[0])
        num_timesteps = (1 + num_transitions_per_observation *
                         (num_observation_steps - 1))

        # If no criterion is specified, default is to resample at every step.
        if not resample_criterion_fn:
            resample_criterion_fn = lambda _: True

        # Canonicalize the list of steps to trace as a rank-1 tensor of (sorted)
        # positive integers. E.g., `3` -> `[3]`, `[-2, -1]` -> `[N - 2, N - 1]`.
        if step_indices_to_trace is not None:
            (step_indices_to_trace,
             traced_steps_have_rank_zero) = _canonicalize_steps_to_trace(
                 step_indices_to_trace, num_timesteps)

        # Dress up the prior and prior proposal as a fake `transition_fn` and
        # `proposal_fn` respectively.
        prior_fn = lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
            initial_state_prior, num_particles)
        prior_proposal_fn = (
            None if initial_state_proposal is None else
            lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
                initial_state_proposal, num_particles))

        # Initially the particles all have the same weight, `1. / num_particles`.
        broadcast_batch_shape = tf.convert_to_tensor(functools.reduce(
            ps.broadcast_shape,
            tf.nest.flatten(initial_state_prior.batch_shape_tensor()), []),
                                                     dtype=tf.int32)
        log_uniform_weights = ps.zeros(
            ps.concat([[num_particles], broadcast_batch_shape], axis=0),
            dtype=tf.float32) - ps.log(num_particles)

        # Initialize from the prior and incorporate the first observation.
        dummy_previous_step = ParticleFilterStepResults(
            particles=prior_fn(0, []).sample(),
            log_weights=log_uniform_weights,
            parent_indices=None,
            incremental_log_marginal_likelihood=0.,
            accumulated_log_marginal_likelihood=0.)
        initial_step_results = _filter_one_step(
            step=0,
            # `previous_particles` at the first step is a dummy quantity, used only
            # to convey state structure and num_particles to an optional
            # proposal fn.
            previous_step_results=dummy_previous_step,
            observation=tf.nest.map_structure(lambda x: tf.gather(x, 0),
                                              observations),
            transition_fn=prior_fn,
            observation_fn=observation_fn,
            proposal_fn=prior_proposal_fn,
            resample_fn=resample_fn,
            resample_criterion_fn=resample_criterion_fn,
            seed=seed)

        def _loop_body(step, previous_step_results, accumulated_traced_results,
                       num_steps_traced):
            """Take one step in dynamics and accumulate marginal likelihood."""

            step_has_observation = (
                # The second of these conditions subsumes the first, but both are
                # useful because the first can often be evaluated statically.
                ps.equal(num_transitions_per_observation, 1)
                | ps.equal(step % num_transitions_per_observation, 0))
            observation_idx = step // num_transitions_per_observation
            current_observation = tf.nest.map_structure(
                lambda x, step=step: tf.gather(x, observation_idx),
                observations)

            new_step_results = _filter_one_step(
                step=step,
                previous_step_results=previous_step_results,
                observation=current_observation,
                transition_fn=transition_fn,
                observation_fn=observation_fn,
                proposal_fn=proposal_fn,
                resample_criterion_fn=resample_criterion_fn,
                resample_fn=resample_fn,
                has_observation=step_has_observation,
                seed=seed)

            return _update_loop_variables(
                step=step,
                current_step_results=new_step_results,
                accumulated_traced_results=accumulated_traced_results,
                trace_fn=trace_fn,
                step_indices_to_trace=step_indices_to_trace,
                num_steps_traced=num_steps_traced)

        loop_results = tf.while_loop(
            cond=lambda step, *_: step < num_timesteps,
            body=_loop_body,
            loop_vars=_initialize_loop_variables(
                initial_step_results=initial_step_results,
                num_timesteps=num_timesteps,
                trace_fn=trace_fn,
                step_indices_to_trace=step_indices_to_trace))

        results = tf.nest.map_structure(
            lambda ta: ta.stack(), loop_results.accumulated_traced_results)
        if step_indices_to_trace is not None:
            # If we were passed a rank-0 (single scalar) step to trace, don't
            # return a time axis in the returned results.
            results = ps.cond(
                traced_steps_have_rank_zero,
                lambda: tf.nest.map_structure(lambda x: x[0, ...], results),
                lambda: results)

        return results
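A small hedged sketch of a custom `trace_fn` matching the contract documented above (it receives the per-step results tuple and returns the structure of tensors to stack over time); the field names follow the `ParticleFilterStepResults` tuples used earlier in these examples.

custom_trace_fn = lambda step_results: (step_results.particles,
                                        step_results.log_weights)
# Passing `trace_fn=custom_trace_fn` (optionally with `step_indices_to_trace`)
# to `particle_filter` would then return only these two stacked tensors.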