Example #1
0
def resample(particles, log_weights, resample_fn, seed=None):
    """Draws ancestor indices from the weights and gathers the particles.

    The weights are normalized with a log-softmax over axis 0 before being
    handed to `resample_fn`, so they do not need to be normalized by the
    caller.

    Args:
      particles: Nested structure of `Tensor`s each of shape
        `[num_particles, b1, ..., bN, ...]`, where
        `b1, ..., bN` are optional batch dimensions.
      log_weights: float `Tensor` of shape `[num_particles, b1, ..., bN]`,
        where `b1, ..., bN` are optional batch dimensions.
      resample_fn: callable used to draw the ancestor indices; it is invoked
        as `resample_fn(log_probs, num_particles, (), seed=seed)`.
        Use 'resample_independent' for independent resamples.
        Use 'resample_stratified' for stratified resampling.
        Use 'resample_systematic' for systematic resampling.
      seed: Python `int` random seed, forwarded to `resample_fn`.

    Returns:
      resampled_particles: Nested structure of `Tensor`s, matching
        `particles`.
      resampled_indices: the indices returned by `resample_fn`; int `Tensor`
        of shape `[num_particles, b1, ..., bN]`.
    """
    with tf.name_scope('resample'):
        particle_count = ps.size0(log_weights)
        normalized_log_weights = tf.math.log_softmax(log_weights, axis=0)
        ancestors = resample_fn(
            normalized_log_weights, particle_count, (), seed=seed)

        def _gather_ancestors(part):
            # Every leaf of `particles` is re-indexed along its leading axis.
            return mcmc_util.index_remapping_gather(part, ancestors, axis=0)

        gathered = tf.nest.map_structure(_gather_ancestors, particles)
    return gathered, ancestors
def resample_particle_and_info(particles, log_weights, seed=None):
    """Resamples particles by drawing ancestors from a categorical.

    The weights are normalized with a log-softmax over axis 0 and used as the
    logits of a `Categorical` distribution, from which `num_particles`
    ancestor indices are sampled.

    Args:
      particles: Nested structure of `Tensor`s each of shape
        `[num_particles, b1, ..., bN, ...]`, where
        `b1, ..., bN` are optional batch dimensions.
      log_weights: float `Tensor` of shape `[num_particles, b1, ..., bN]`,
        where `b1, ..., bN` are optional batch dimensions.
      seed: Python `int` random seed.

    Returns:
      resampled_particles: Nested structure of `Tensor`s, matching
        `particles`.
      resample_indices: int `Tensor` of shape `[num_particles, b1, ..., bN]`.
    """
    with tf.name_scope('resample'):
        particle_count = ps.size0(log_weights)

        normalized_log_weights = tf.math.log_softmax(log_weights, axis=0)
        # TODO(junpenglao): use an `axis` specifiable categorical sampler to
        # avoid the transpose below.
        ancestor_logits = dist_util.move_dimension(
            normalized_log_weights, 0, -1)
        ancestor_dist = categorical.Categorical(logits=ancestor_logits)
        ancestors = ancestor_dist.sample(particle_count, seed=seed)

        def _gather_ancestors(part):
            return mcmc_util.index_remapping_gather(part, ancestors)

        gathered = tf.nest.map_structure(_gather_ancestors, particles)
    return gathered, ancestors
Example #3
0
def reconstruct_trajectories(particles, parent_indices, name=None):
  """Reconstructs the ancestor trajectory that generated each final particle.

  Args:
    particles: Nested structure of `Tensor`s; each leaf is gathered along
      axis 1 using the reconstructed ancestor indices.
    parent_indices: `Tensor` of per-step parent indices, indexed by step along
      its leading axis.
    name: Optional name for the op scope.
      Default value: `None` (i.e., `'reconstruct_trajectories'`).

  Returns:
    Structure matching `particles` with every leaf re-indexed so that it
    follows the ancestry of the final particles.
  """

  def _step_back(ancestor, parent):
    # Map the ancestors known at the later step through this step's parents.
    return mcmc_util.index_remapping_gather(parent, ancestor, axis=0)

  with tf.name_scope(name or 'reconstruct_trajectories'):
    # Walk backwards to compute the ancestor of each final particle at time t.
    last_step_indices = smc_kernel._dummy_indices_like(parent_indices[-1])  # pylint: disable=protected-access
    earlier_indices = tf.scan(
        fn=_step_back,
        elems=parent_indices[1:],
        initializer=last_step_indices,
        reverse=True)
    lineage = tf.concat([earlier_indices, [last_step_indices]], axis=0)

  def _follow_lineage(part):
    return mcmc_util.index_remapping_gather(
        part, lineage, axis=1, indices_axis=1)

  return tf.nest.map_structure(_follow_lineage, particles)
Example #4
0
  def test_rank_1_same_as_gather(self):
    """Rank-1 params with rank-1 indices yields `params[indices]`."""
    source = [10, 11, 12, 13]
    picks = [3, 2, 0]

    gathered = util.index_remapping_gather(source, picks)

    # The static shape is checked before the values are evaluated.
    self.assertAllEqual(np.asarray(picks).shape, gathered.shape)
    self.assertAllEqual([13, 12, 10], self.evaluate(gathered))
Example #5
0
    def test_rank_2_and_axis_0(self):
        """Rank-2 params and indices, called without an explicit `axis`."""
        source = [[95, 46, 17], [46, 29, 55]]
        picks = [[0, 0, 1], [1, 0, 1]]

        gathered = util.index_remapping_gather(source, picks)

        # Entry [i][j] of the result should be source[picks[i][j]][j].
        self.assertAllEqual(np.asarray(source).shape, gathered.shape)
        self.assertAllEqual([[95, 46, 55], [46, 46, 55]],
                            self.evaluate(gathered))
 def _resample(self, particles, log_weights, seed=None):
     """Chooses one out of `importance_sample_size` many weighted proposals.

     A single index is sampled from a `Categorical` whose logits are
     `log_weights` with axis 0 moved to the end; every leaf of `particles` is
     gathered at that index and the leading size-1 sample axis is dropped.
     """
     proposal_logits = distribution_util.move_dimension(log_weights, 0, -1)
     chosen = categorical.Categorical(logits=proposal_logits).sample(
         sample_shape=[1], seed=seed)

     def _select(part):
         gathered = mcmc_util.index_remapping_gather(part, chosen, axis=0)
         # Strip the leading axis introduced by `sample_shape=[1]`.
         return gathered[0, ...]

     return tf.nest.map_structure(_select, particles)
def resample(particles, log_weights, resample_fn, target_log_weights=None,
             seed=None):
  """Resamples particles and returns the weights that apply afterwards.

  Args:
    particles: Nested structure of `Tensor`s each of shape
      `[num_particles, b1, ..., bN, ...]`, where
      `b1, ..., bN` are optional batch dimensions.
    log_weights: float `Tensor` of shape `[num_particles, b1, ..., bN]`, where
      `b1, ..., bN` are optional batch dimensions. Normalized here with a
      log-softmax over axis 0 before sampling.
    resample_fn: callable used to draw the ancestor indices; it is invoked as
      `resample_fn(log_probs, num_particles, (), seed=seed)`.
      Use 'resample_independent' for independent resamples.
      Use 'resample_stratified' for stratified resampling.
      Use 'resample_systematic' for systematic resampling.
    target_log_weights: optional float `Tensor` of the same shape and dtype as
      `log_weights`, specifying the target measure on `particles` if this is
      different from that implied by normalizing `log_weights`. The
      returned `log_weights_after_resampling` will represent this measure. If
      `None`, the target measure is implicitly taken to be the normalized
      log weights (`log_weights - tf.reduce_logsumexp(log_weights, axis=0)`).
      Default value: `None`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    resampled_particles: Nested structure of `Tensor`s, matching `particles`.
    resample_indices: int `Tensor` of shape `[num_particles, b1, ..., bN]`.
    log_weights_after_resampling: float `Tensor` of same shape and dtype as
      `log_weights`, such that weighted sums of the resampled particles are
      equal (in expectation over the resampling step) to weighted sums of
      the original particles:
      `E [ exp(log_weights_after_resampling) * some_fn(resampled_particles) ]
      == exp(target_log_weights) * some_fn(particles)`.
      If no `target_log_weights` was specified, the log weights after
      resampling are uniformly equal to `-log(num_particles)`.
  """
  with tf.name_scope('resample'):
    particle_count = ps.size0(log_weights)
    log_particle_count = tf.math.log(
        tf.cast(particle_count, log_weights.dtype))

    # Normalize the weights and sample the ancestral indices.
    normalized_log_weights = tf.math.log_softmax(log_weights, axis=0)
    ancestors = resample_fn(
        normalized_log_weights, particle_count, (), seed=seed)

    def _gather_ancestors(x):
      return mcmc_util.index_remapping_gather(x, ancestors, axis=0)

    gathered_particles = tf.nest.map_structure(_gather_ancestors, particles)
    if target_log_weights is not None:
      # Importance correction of the target against the sampling
      # distribution, carried along with each selected ancestor.
      correction = (
          target_log_weights - normalized_log_weights - log_particle_count)
      new_log_weights = tf.nest.map_structure(_gather_ancestors, correction)
    else:
      new_log_weights = tf.fill(ps.shape(log_weights), -log_particle_count)
  return gathered_particles, ancestors, new_log_weights
Example #8
0
  def test_params_rank3_indices_rank2_axis_0(self):
    """Rank-3 params gathered along axis 0 with rank-2 indices."""
    axis = 0
    params = np.random.randint(10, 100, size=(4, 5, 2))
    indices = np.random.randint(0, params.shape[axis], size=(6, 5))

    result = util.index_remapping_gather(params, indices)
    expected_shape = indices.shape[:axis + 1] + params.shape[axis + 1:]
    self.assertAllEqual(expected_shape, result.shape)
    result_ = self.evaluate(result)

    # Compare element by element against plain NumPy indexing.
    positions = np.ndindex(
        indices.shape[0], params.shape[1], params.shape[2])
    for i, j, k in positions:
      self.assertEqual(params[indices[i, j], j, k], result_[i, j, k])
Example #9
0
def _maybe_embed_swaps_validation(swaps, null_swaps, validate_args):
    """Return `swaps`, possibly with embedded "once only" assertion.

    When `validate_args` is truthy, the returned tensor carries a control
    dependency on an assertion that gathering `swaps` by itself reproduces
    `null_swaps`; otherwise `swaps` is returned untouched.
    """
    if validate_args:
        once_only_check = assert_util.assert_equal(
            null_swaps,
            mcmc_util.index_remapping_gather(swaps, swaps),
            message=('Proposed replica swaps must be consist of "once only '
                     'swaps," i.e., be a self-inverse permutation, '
                     '`range(swaps.shape[0]) == gather(swaps, swaps).'))
        with tf.control_dependencies([once_only_check]):
            return tf.identity(swaps)
    return swaps
Example #10
0
    def test_params_rank3_indices_rank1_axis_1(self):
        """Rank-3 params gathered along axis 1 with rank-1 indices."""
        axis = 1
        params = np.random.randint(10, 100, size=[4, 5, 2])
        indices = np.random.randint(0, params.shape[axis], size=[6])

        result = util.index_remapping_gather(params, indices, axis=axis)
        expected_shape = (
            params.shape[:axis] + indices.shape[:1] + params.shape[axis + 1:])
        self.assertAllEqual(expected_shape, result.shape)
        result_ = self.evaluate(result)

        # Compare element by element against plain NumPy indexing.
        positions = np.ndindex(
            params.shape[0], indices.shape[0], params.shape[2])
        for i, j, k in positions:
            self.assertEqual(params[i, indices[j], k], result_[i, j, k])
Example #11
0
  def test_params_rank5_indices_rank3_axis_2_iaxis_1(self):
    """Rank-5 params, rank-3 indices, `axis=2` and `indices_axis=1`."""
    axis = 2
    indices_axis = 1
    params = np.random.randint(10, 100, size=[4, 5, 2, 3, 4])
    indices = np.random.randint(0, params.shape[axis], size=[5, 6, 3])

    result = util.index_remapping_gather(
        params, indices, axis=axis, indices_axis=indices_axis)
    expected_shape = (
        params.shape[:axis] +
        indices.shape[indices_axis:indices_axis + 1] +
        params.shape[axis + 1:])
    self.assertAllEqual(expected_shape, result.shape)
    result_ = self.evaluate(result)

    # Compare element by element against plain NumPy indexing.
    positions = np.ndindex(
        params.shape[0], params.shape[1], indices.shape[1],
        params.shape[3], params.shape[4])
    for i, j, k, l, m in positions:
      self.assertEqual(params[i, j, indices[j, k, l], l, m],
                       result_[i, j, k, l, m])
Example #12
0
 def _swap_tensor(x):
     # `swaps` and `is_swap_accepted_mask` come from the enclosing scope.
     # NOTE(review): presumably `choose` keeps the gathered value where the
     # mask is set and `x` elsewhere -- confirm against `mcmc_util.choose`.
     swapped = mcmc_util.index_remapping_gather(x, swaps)
     return mcmc_util.choose(is_swap_accepted_mask, swapped, x)
Example #13
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
      seed: Optional, a seed for reproducible sampling.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.

    Raises:
      TypeError: if calling `make_kernel_fn` with only the target log prob
        raises a `TypeError` whose message mentions 'argument'; it is
        replaced with a message about the removed `seed` argument.
    """

        # The code below propagates one step states of shape
        #  [n_replica] + batch_shape + event_shape.
        #
        # The step is done in three parts:
        #  1) Call one_step to transition states via a tempered version of
        #     self.target_log_prob_fn (see _replica_target_log_prob).
        #  2) Permute values in states
        #  3) Update state-dependent values, such as log_probs.
        #
        # We chose to swap states, rather than temperatures, because...
        # (i)  If swapping temperatures, you *still* have to swap log_probs to
        #      determine acceptance, as well as states (for kernel results).
        #      So it's just as difficult to swap temperatures.
        # (ii) If swapping temperatures, you have to take care to swap any user-
        #      supplied temperature related things (like step size).
        #      A-priori, we don't know what else will need to be swapped!
        # (iii)In both cases, the kernel results need to be updated in a non-trivial
        #      manner....so we either special-case, or use bootstrap.

        with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.inverse_temperatures,
                name='inverse_temperatures')

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                target_log_prob_fn=self.target_log_prob_fn,
                inverse_temperatures=inverse_temperatures,
                untempered_log_prob_fn=self.untempered_log_prob_fn,
                tempered_log_prob_fn=self.tempered_log_prob_fn,
            )
            # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                # Only a `TypeError` that mentions 'argument' is reworded;
                # any other `TypeError` propagates unchanged.
                if 'argument' not in str(e):
                    raise
                raise TypeError(
                    '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a `seed` '
                    'argument. `TransitionKernel` instances now receive seeds via '
                    '`one_step`.')

            seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
            # Separate seeds for the inner kernel step, the swap proposal and
            # the uniform draws used for swap acceptance.
            inner_seed, swap_seed, logu_seed = samplers.split_seed(seed, n=3)
            # Step the inner TransitionKernel.
            [
                pre_swap_replica_states,
                pre_swap_replica_results,
            ] = inner_kernel.one_step(
                previous_kernel_results.post_swap_replica_states,
                previous_kernel_results.post_swap_replica_results,
                seed=inner_seed)

            pre_swap_replica_target_log_prob = _get_field(
                # These are tempered log probs (have been divided by temperature).
                pre_swap_replica_results,
                'target_log_prob')

            dtype = pre_swap_replica_target_log_prob.dtype
            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]
            replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
            num_replica = ps.size0(inverse_temperatures)

            inverse_temperatures = bu.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Now that each replica has done one_step, it is time to consider swaps.

            # swap.shape = [n_replica], and is a "once only" permutation, meaning it
            # is achievable by a sequence of pairwise permutations, where each element
            # is moved at most once.
            # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
            # 1, keeping 2 fixed.  This exact same swap is considered for *every*
            # batch member.  Of course some batch members may accept and some reject.
            try:
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed,
                        step_count=previous_kernel_results.step_count),
                    dtype=tf.int32)
            except TypeError as e:
                # Fallback for proposal functions that do not take `step_count`;
                # a `TypeError` not mentioning 'step_count' is re-raised.
                if 'step_count' not in str(e):
                    raise
                warnings.warn(
                    'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
                    'the `step_count` argument. Falling back to omitting the '
                    'argument. This fallback will be removed after 24-Oct-2020.'
                )
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed),
                    dtype=tf.int32)

            # `null_swaps` is `range(num_replica)`, expanded to the rank of
            # `swaps`.
            null_swaps = bu.left_justified_expand_dims_like(
                tf.range(num_replica, dtype=swaps.dtype), swaps)
            swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                                  self.validate_args)

            # Un-temper the log probs for use in the swap acceptance ratio.
            if self.tempered_log_prob_fn is None:
                # Efficient way of re-evaluating target_log_prob_fn on the
                # pre_swap_replica_states.
                untempered_negative_energy_ignoring_ulp = (
                    # Since untempered_log_prob_fn is None, we may assume
                    # inverse_temperatures > 0 (else the target is improper).
                    pre_swap_replica_target_log_prob / inverse_temperatures)
            else:
                # The untempered_log_prob_fn does not factor into the acceptance ratio.
                # Proof: Suppose the tempered target is
                #   p_k(x) = f(x)^{beta_k} g(x),
                # So f(x) is tempered, and g(x) is not.  Then, the acceptance ratio for
                # a 1 <--> 2 swap is...
                #   (p_1(x_2) p_2(x_1)) / (p_1(x_1) p_2(x_2))
                # which depends only on f(x), since terms involving g(x) cancel.
                untempered_negative_energy_ignoring_ulp = self.tempered_log_prob_fn(
                    *pre_swap_replica_states)

            # Since `swaps` is its own inverse permutation we automatically know the
            # swap counterpart: range(num_replica). We use this idea to compute the
            # acceptance in a vectorized manner at the cost of wasting roughly half
            # our computation. Although we could use `unique` to solve this problem,
            # we expect the cost of `unique` to be higher than the dozens of wasted
            # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
            # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

            # Note: diffs would normally be "proposed - current" however energy is
            # flipped since `energy == -log_prob`.
            # Note: The untempered_log_prob_fn (if provided) is not included in
            # untempered_pre_swap_replica_target_log_prob, and hence does not factor
            # into energy_diff. Why? Because, it cancels out in the acceptance ratio.
            energy_diff = (untempered_negative_energy_ignoring_ulp -
                           mcmc_util.index_remapping_gather(
                               untempered_negative_energy_ignoring_ulp,
                               swaps,
                               name='gather_swap_tlp'))
            swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
                inverse_temperatures, swaps, name='gather_swap_temps')
            inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

            # If i and j are swapping, log_accept_ratio[] i and j are equal.
            log_accept_ratio = (energy_diff * bu.left_justified_expand_dims_to(
                inverse_temp_diff, replica_and_batch_rank))

            # Any non-finite ratio (NaN, +inf or -inf) is replaced by -inf, so
            # the strict `tf.less` comparison below can never accept it.
            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                samplers.uniform(shape=replica_and_batch_shape,
                                 dtype=dtype,
                                 seed=logu_seed))
            # Each replica reads the draw stored at the smaller of its own
            # index and its swap partner's index.
            anchor_swaps = tf.minimum(swaps, null_swaps)
            log_uniform = mcmc_util.index_remapping_gather(
                log_uniform, anchor_swaps)

            is_swap_accepted_mask = tf.less(log_uniform,
                                            log_accept_ratio,
                                            name='is_swap_accepted_mask')

            def _swap_tensor(x):
                return mcmc_util.choose(
                    is_swap_accepted_mask,
                    mcmc_util.index_remapping_gather(x, swaps), x)

            post_swap_replica_states = [
                _swap_tensor(s) for s in pre_swap_replica_states
            ]

            expanded_null_swaps = bu.left_justified_broadcast_to(
                null_swaps, replica_and_batch_shape)
            is_swap_proposed = _compute_swap_notmatrix(
                # Broadcast both so they have shape [num_replica] + batch_shape.
                # This (i) makes them have same shape as is_swap_accepted, and
                # (ii) keeps shape consistent if someday swaps has a batch shape.
                expanded_null_swaps,
                bu.left_justified_broadcast_to(swaps, replica_and_batch_shape))

            # To get is_swap_accepted in ordered position, we use
            # _compute_swap_notmatrix on current and next replica positions.
            post_swap_replica_position = _swap_tensor(expanded_null_swaps)

            is_swap_accepted = _compute_swap_notmatrix(
                post_swap_replica_position, expanded_null_swaps)

            # When the state does not include replicas, only the entry at
            # index 0 of each state part's leading axis is returned.
            if self._state_includes_replicas:
                post_swap_states = post_swap_replica_states
            else:
                post_swap_states = [s[0] for s in post_swap_replica_states]

            post_swap_replica_results = _set_swapped_fields_to_nan(
                _swap_log_prob_and_maybe_grads(pre_swap_replica_results,
                                               post_swap_replica_states,
                                               inner_kernel))

            # Return the same container kind (list vs. single `Tensor`) that
            # the caller passed in as `current_state`.
            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = post_swap_states
            else:
                states = post_swap_states[0]

            post_swap_kernel_results = ReplicaExchangeMCKernelResults(
                post_swap_replica_states=post_swap_replica_states,
                pre_swap_replica_results=pre_swap_replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_proposed,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the original pkr.inverse_temperatures in case its a
                # `tf.Variable`.
                inverse_temperatures=previous_kernel_results.
                inverse_temperatures,
                swaps=swaps,
                step_count=previous_kernel_results.step_count + 1,
                seed=seed,
                potential_energy=-untempered_negative_energy_ignoring_ulp,
            )

            return states, post_swap_kernel_results
Example #14
0
def infer_trajectories(observations,
                       initial_state_prior,
                       transition_fn,
                       observation_fn,
                       num_particles,
                       initial_state_proposal=None,
                       proposal_fn=None,
                       resample_fn=weighted_resampling.resample_systematic,
                       resample_criterion_fn=ess_below_threshold,
                       rejuvenation_kernel_fn=None,
                       num_transitions_per_observation=1,
                       seed=None,
                       name=None):  # pylint: disable=g-doc-args
    """Use particle filtering to sample from the posterior over trajectories.

  ${particle_filter_arg_str}
    seed: Python `int` seed for random ops.
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'infer_trajectories'`).
  Returns:
    trajectories: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing unbiased samples from the posterior distribution
      `p(latent_states | observations)`.
    incremental_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each timestep `t`. Note that
      (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  #### Examples

  **Tracking unknown position and velocity**: Let's consider tracking an object
  moving in a one-dimensional space. We'll define a dynamical system
  by specifying an `initial_state_prior`, a `transition_fn`,
  and `observation_fn`.

  The structure of the latent state space is determined by the prior
  distribution. Here, we'll define a state space that includes the object's
  current position and velocity:

  ```python
  initial_state_prior = tfd.JointDistributionNamed({
      'position': tfd.Normal(loc=0., scale=1.),
      'velocity': tfd.Normal(loc=0., scale=0.1)})
  ```

  The `transition_fn` specifies the evolution of the system. It should
  return a distribution over latent states of the same structure as the prior.
  Here, we'll assume that the position evolves according to the velocity,
  with a small random drift, and the velocity also changes slowly, following
  a random drift:

  ```python
  def transition_fn(_, previous_state):
    return tfd.JointDistributionNamed({
        'position': tfd.Normal(
            loc=previous_state['position'] + previous_state['velocity'],
            scale=0.1),
        'velocity': tfd.Normal(loc=previous_state['velocity'], scale=0.01)})
  ```

  The `observation_fn` specifies the process by which the system is observed
  at each time step. Let's suppose we observe only a noisy version of the
  current position.

  ```python
  def observation_fn(_, state):
    return tfd.Normal(loc=state['position'], scale=0.1)
  ```

  Now let's track our object. Suppose we've been given observations
  corresponding to an initial position of `0.4` and constant velocity of
  roughly `0.01`, over 40 timesteps:

  ```python
  # Generate simulated observations.
  observed_positions = tfd.Normal(loc=tf.linspace(0.4, 0.8, 40),
                                  scale=0.1).sample()

  # Run particle filtering to sample plausible trajectories.
  (trajectories,  # {'position': [40, 1000], 'velocity': [40, 1000]}
   lps) = tfp.experimental.mcmc.infer_trajectories(
            observations=observed_positions,
            initial_state_prior=initial_state_prior,
            transition_fn=transition_fn,
            observation_fn=observation_fn,
            num_particles=1000)
  ```

  For all `i`, `trajectories['position'][:, i]` is a sample from the
  posterior over position sequences, given the observations:
  `p(state[0:T] | observations[0:T])`. Often, the sampled trajectories
  will be highly redundant in their earlier timesteps, because most
  of the initial particles have been discarded through resampling
  (this problem is known as 'particle degeneracy'; see section 3.5 of
  [Doucet and Johansen][1]).
  In such cases it may be useful to also consider the series of *filtering*
  distributions `p(state[t] | observations[:t])`, in which each latent state
  is inferred conditioned only on observations up to that point in time; these
  may be computed using `tfp.experimental.mcmc.particle_filter`.

  #### References

  [1] Arnaud Doucet and Adam M. Johansen. A tutorial on particle
      filtering and smoothing: Fifteen years later.
      _Handbook of nonlinear filtering_, 12(656-704), 2009.
      https://www.stats.ox.ac.uk/~doucet/doucet_johansen_tutorialPF2011.pdf

  """
    with tf.name_scope(name or 'infer_trajectories') as name:
        seed_stream = SeedStream(seed, 'infer_trajectories')

        # Everything the filter needs, forwarded from this function's own
        # arguments; the default trace function is always used here.
        filter_kwargs = dict(
            observations=observations,
            initial_state_prior=initial_state_prior,
            transition_fn=transition_fn,
            observation_fn=observation_fn,
            num_particles=num_particles,
            initial_state_proposal=initial_state_proposal,
            proposal_fn=proposal_fn,
            resample_fn=resample_fn,
            resample_criterion_fn=resample_criterion_fn,
            rejuvenation_kernel_fn=rejuvenation_kernel_fn,
            num_transitions_per_observation=num_transitions_per_observation,
            trace_fn=_default_trace_fn,
            seed=seed_stream,
            name=name)
        (particles,
         log_weights,
         parent_indices,
         incremental_log_marginal_likelihoods) = particle_filter(
             **filter_kwargs)

        weighted_trajectories = reconstruct_trajectories(
            particles, parent_indices)

        # Resample all steps of the trajectories using the final weights.
        final_indices = resample_fn(log_probs=log_weights[-1],
                                    event_size=num_particles,
                                    sample_shape=(),
                                    seed=seed_stream)

        def _select_trajectories(x):
            # Axis 1 is the particle axis; axis 0 indexes timesteps.
            return mcmc_util.index_remapping_gather(x, final_indices, axis=1)

        trajectories = tf.nest.map_structure(
            _select_trajectories, weighted_trajectories)

        return trajectories, incremental_log_marginal_likelihoods
Example #15
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one step of the TransitionKernel.

    Every replica is first advanced by the inner kernel (built by
    `self.make_kernel_fn` from a tempered target log prob), then a swap
    permutation obtained from `self.swap_proposal_fn` is accepted or rejected
    for each replica/batch member.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s). Here it is only inspected to
        decide whether the returned state is a list or a single `Tensor`; the
        replica states that are actually advanced come from
        `previous_kernel_results.post_swap_replica_states`.
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
      seed: Optional, a seed for reproducible sampling. When given, it is
        sanitized and split three ways (inner kernel, swap proposal, acceptance
        uniforms). When `None`, the swap and acceptance seeds are drawn from
        `self._seed_stream` and the inner kernel is called without a `seed`
        argument.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s). A list is returned iff
        `current_state` is list-like. Unless `self._state_includes_replicas`
        is set, each state part is the slice at replica index 0.
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """

        # The code below propagates one step states of shape
        #  [n_replica] + batch_shape + event_shape.
        #
        # The step is done in three parts:
        #  1) Call one_step to transition states via a tempered version of
        #     self.target_log_prob_fn (see _replica_target_log_prob).
        #  2) Permute values in states
        #  3) Update state-dependent values, such as log_probs.
        #
        # We chose to swap states, rather than temperatures, because...
        # (i)  If swapping temperatures, you *still* have to swap log_probs to
        #      determine acceptance, as well as states (for kernel results).
        #      So it's just as difficult to swap temperatures.
        # (ii) If swapping temperatures, you have to take care to swap any user-
        #      supplied temperature related things (like step size).
        #      A-priori, we don't know what else will need to be swapped!
        # (iii)In both cases, the kernel results need to be updated in a non-trivial
        #      manner....so we either special-case, or use bootstrap.

        with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.inverse_temperatures,
                name='inverse_temperatures')

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                self.target_log_prob_fn, inverse_temperatures)
            # Seed handling complexity is due to users possibly expecting an old-style
            # stateful seed to be passed to `self.make_kernel_fn`, and no seed
            # expected by `kernel.one_step`.
            # In other words:
            # - We try `make_kernel_fn` without a seed first; this is the future. The
            #   kernel will receive a seed later, as part of `one_step`.
            # - If the user code doesn't like that (Python complains about a missing
            #   required argument), we warn and fall back to the previous behavior.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                # Only a TypeError whose message mentions 'argument' is treated as
                # "missing seed argument"; any other TypeError is re-raised.
                if 'argument' not in str(e):
                    raise
                warnings.warn(
                    'The `seed` argument to `ReplicaExchangeMC`s `make_kernel_fn` is '
                    'deprecated. `TransitionKernel` instances now receive seeds via '
                    '`one_step`.')
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel, self._seed_stream())

            # Now that we've constructed the TransitionKernel instance:
            # - If we were given a seed, we sanitize it to stateless and pass along
            #   to `kernel.one_step`. If it doesn't like that, we crash and propagate
            #   the error.  Rationale: The contract is stateless sampling given
            #   seed, and doing otherwise would not meet it.
            # - If not given a seed, we don't pass one along. This avoids breaking
            #   underlying kernels lacking a `seed` arg on `one_step`.
            # TODO(b/159636942): Clean up after 2020-09-20.
            if seed is not None:
                seed = samplers.sanitize_seed(seed)
                inner_seed, swap_seed, logu_seed = samplers.split_seed(
                    seed, n=3, salt='remc_one_step')
                inner_kwargs = dict(seed=inner_seed)
            else:
                # A seed was supplied to the constructor rather than to `one_step`;
                # emit the deprecation warning for that usage.
                if self._seed_stream.original_seed is not None:
                    warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
                inner_kwargs = {}
                swap_seed, logu_seed = samplers.split_seed(self._seed_stream())
            # Advance every replica, starting from the post-swap states/results
            # stored by the previous step.
            [
                pre_swap_replica_states,
                pre_swap_replica_results,
            ] = inner_kernel.one_step(
                previous_kernel_results.post_swap_replica_states,
                previous_kernel_results.post_swap_replica_results,
                **inner_kwargs)

            pre_swap_replica_target_log_prob = _get_field(
                # These are tempered log probs (have been divided by temperature).
                pre_swap_replica_results,
                'target_log_prob')

            # Shape bookkeeping is derived from the inner kernel's target log prob;
            # its leading dimension is treated as the replica dimension.
            dtype = pre_swap_replica_target_log_prob.dtype
            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]
            replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
            num_replica = ps.size0(inverse_temperatures)

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Now that each replica has done one_step, it is time to consider swaps.

            # swap.shape = [n_replica], and is a "once only" permutation, meaning it
            # is achievable by a sequence of pairwise permutations, where each element
            # is moved at most once.
            # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
            # 1, keeping 2 fixed.  This exact same swap is considered for *every*
            # batch member.  Of course some batch members may accept and some reject.
            #
            # `step_count` is passed first; if the user's `swap_proposal_fn` rejects
            # that keyword (TypeError mentioning 'step_count'), retry without it.
            try:
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed,
                        step_count=previous_kernel_results.step_count),
                    dtype=tf.int32)
            except TypeError as e:
                if 'step_count' not in str(e):
                    raise
                warnings.warn(
                    'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
                    'the `step_count` argument. Falling back to omitting the '
                    'argument. This fallback will be removed after 24-Oct-2020.'
                )
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed),
                    dtype=tf.int32)

            # `null_swaps` is the identity permutation range(num_replica), expanded
            # to the rank of `swaps`.
            null_swaps = mcmc_util.left_justified_expand_dims_like(
                tf.range(num_replica, dtype=swaps.dtype), swaps)
            swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                                  self.validate_args)

            # Un-temper the log probs.  E.g., for replica k, at point x_k, this is
            # Log[p(x_k)], and *not* Log[p_x(x_k)] = Log[p(x_k)] * beta_k.
            untempered_pre_swap_replica_target_log_prob = (
                pre_swap_replica_target_log_prob / inverse_temperatures)

            # Since `swaps` is its own inverse permutation we automatically know the
            # swap counterpart: range(num_replica). We use this idea to compute the
            # acceptance in a vectorized manner at the cost of wasting roughly half
            # our computation. Although we could use `unique` to solve this problem,
            # we expect the cost of `unique` to be higher than the dozens of wasted
            # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
            # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

            # Note: diffs would normally be "proposed - current" however energy is
            # flipped since `energy == -log_prob`.
            energy_diff = (untempered_pre_swap_replica_target_log_prob -
                           mcmc_util.index_remapping_gather(
                               untempered_pre_swap_replica_target_log_prob,
                               swaps,
                               name='gather_swap_tlp'))
            swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
                inverse_temperatures, swaps, name='gather_swap_temps')
            inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

            # If i and j are swapping, log_accept_ratio[] i and j are equal.
            log_accept_ratio = (energy_diff *
                                mcmc_util.left_justified_expand_dims_to(
                                    inverse_temp_diff, replica_and_batch_rank))

            # Map any non-finite ratio (NaN or +/-inf) to -inf. The strict `less`
            # comparison below can never be true against -inf, so such swaps are
            # always rejected.
            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce Log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                samplers.uniform(shape=replica_and_batch_shape,
                                 dtype=dtype,
                                 seed=logu_seed))
            # Both members of a swap pair read the draw stored at the smaller of
            # their two indices, so they make the same accept/reject decision.
            anchor_swaps = tf.minimum(swaps, null_swaps)
            log_uniform = mcmc_util.index_remapping_gather(
                log_uniform, anchor_swaps)

            is_swap_accepted_mask = tf.less(log_uniform,
                                            log_accept_ratio,
                                            name='is_swap_accepted_mask')

            def _swap_tensor(x):
                # Where the swap was accepted take the permuted value of `x`,
                # otherwise keep `x` unchanged.
                return mcmc_util.choose(
                    is_swap_accepted_mask,
                    mcmc_util.index_remapping_gather(x, swaps), x)

            post_swap_replica_states = [
                _swap_tensor(s) for s in pre_swap_replica_states
            ]

            expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
                null_swaps, replica_and_batch_shape)
            is_swap_proposed = _compute_swap_notmatrix(
                # Broadcast both so they have shape [num_replica] + batch_shape.
                # This (i) makes them have same shape as is_swap_accepted, and
                # (ii) keeps shape consistent if someday swaps has a batch shape.
                expanded_null_swaps,
                mcmc_util.left_justified_broadcast_to(swaps,
                                                      replica_and_batch_shape))

            # To get is_swap_accepted in ordered position, we use
            # _compute_swap_notmatrix on current and next replica positions.
            post_swap_replica_position = _swap_tensor(expanded_null_swaps)

            is_swap_accepted = _compute_swap_notmatrix(
                post_swap_replica_position, expanded_null_swaps)

            # Either expose all replicas as the chain state, or only the slice at
            # replica index 0 of each state part.
            if self._state_includes_replicas:
                post_swap_states = post_swap_replica_states
            else:
                post_swap_states = [s[0] for s in post_swap_replica_states]

            post_swap_replica_results = _make_post_swap_replica_results(
                pre_swap_replica_results, inverse_temperatures,
                swapped_inverse_temperatures, is_swap_accepted_mask,
                _swap_tensor)

            # Return the state in the same structure (list vs. single Tensor) as
            # the `current_state` the caller passed in.
            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = post_swap_states
            else:
                states = post_swap_states[0]

            post_swap_kernel_results = ReplicaExchangeMCKernelResults(
                post_swap_replica_states=post_swap_replica_states,
                pre_swap_replica_results=pre_swap_replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_proposed,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the original pkr.inverse_temperatures in case its a
                # `tf.Variable`.
                inverse_temperatures=previous_kernel_results.
                inverse_temperatures,
                swaps=swaps,
                step_count=previous_kernel_results.step_count + 1,
                # Record the (sanitized) seed used for this step, or a zeros seed
                # placeholder when the caller did not supply one.
                seed=samplers.zeros_seed() if seed is None else seed,
            )

            return states, post_swap_kernel_results
Example #16
0
    def one_step(self, current_state, previous_kernel_results):
        """Takes one step of the TransitionKernel.

    Every replica is first advanced by the inner kernel (built by
    `self.make_kernel_fn` from a tempered target log prob and a seed drawn
    from `self._seed_stream`), then a swap permutation obtained from
    `self.swap_proposal_fn` is accepted or rejected for each replica/batch
    member.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s). Here it is only inspected to
        decide whether the returned state is a list or a single `Tensor`; the
        replica states that are actually advanced come from
        `previous_kernel_results.post_swap_replica_states`.
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s). A list is returned iff
        `current_state` is list-like. Each state part is the slice at replica
        index 0 of the post-swap replica states.
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """
        # The code below propagates one step states of shape
        #  [n_replica] + batch_shape + event_shape.
        #
        # The step is done in three parts:
        #  1) Call one_step to transition states via a tempered version of
        #     self.target_log_prob_fn (see _replica_target_log_prob).
        #  2) Permute values in states
        #  3) Update state-dependent values, such as log_probs.
        #
        # We chose to swap states, rather than temperatures, because...
        # (i)  If swapping temperatures, you *still* have to swap log_probs to
        #      determine acceptance, as well as states (for kernel results).
        #      So it's just as difficult to swap temperatures.
        # (ii) If swapping temperatures, you have to take care to swap any user-
        #      supplied temperature related things (like step size).
        #      A-priori, we don't know what else will need to be swapped!
        # (iii)In both cases, the kernel results need to be updated in a non-trivial
        #      manner....so we either special-case, or use bootstrap.
        #
        # NOTE: `self._seed_stream()` is called three times below (inner kernel,
        # swap proposal, log-uniform draws). Each call is a separate draw from the
        # stream, so the relative order of those calls should be preserved.

        with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.inverse_temperatures,
                name='inverse_temperatures')

            # Build the inner kernel on the tempered target, handing it a seed
            # from the seed stream as the second positional argument.
            inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                _make_replica_target_log_prob_fn(self.target_log_prob_fn,
                                                 inverse_temperatures),
                self._seed_stream())

            # Advance every replica, starting from the post-swap states/results
            # stored by the previous step.
            [
                pre_swap_replica_states,
                pre_swap_replica_results,
            ] = inner_kernel.one_step(
                previous_kernel_results.post_swap_replica_states,
                previous_kernel_results.post_swap_replica_results)

            pre_swap_replica_target_log_prob = _get_field(
                # These are tempered log probs (have been divided by temperature).
                pre_swap_replica_results,
                'target_log_prob')

            # Shape bookkeeping is derived from the inner kernel's target log prob;
            # its leading dimension is treated as the replica dimension.
            dtype = pre_swap_replica_target_log_prob.dtype
            replica_and_batch_shape = prefer_static.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]
            replica_and_batch_rank = prefer_static.rank(
                pre_swap_replica_target_log_prob)
            num_replica = prefer_static.size0(inverse_temperatures)

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Now that each replica has done one_step, it is time to consider swaps.

            # swap.shape = [n_replica], and is a "once only" permutation, meaning it
            # is achievable by a sequence of pairwise permutations, where each element
            # is moved at most once.
            # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
            # 1, keeping 2 fixed.  This exact same swap is considered for *every*
            # batch member.  Of course some batch members may accept and some reject.
            swaps = tf.cast(
                self.swap_proposal_fn(  # pylint: disable=not-callable
                    num_replica,
                    batch_shape=batch_shape,
                    seed=self._seed_stream()),
                dtype=tf.int32)
            # `null_swaps` is the identity permutation range(num_replica), expanded
            # to the rank of `swaps`.
            null_swaps = mcmc_util.left_justified_expand_dims_like(
                tf.range(num_replica, dtype=swaps.dtype), swaps)
            swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                                  self.validate_args)

            # Un-temper the log probs.  E.g., for replica k, at point x_k, this is
            # Log[p(x_k)], and *not* Log[p_x(x_k)] = Log[p(x_k)] * beta_k.
            untempered_pre_swap_replica_target_log_prob = (
                pre_swap_replica_target_log_prob / inverse_temperatures)

            # Since `swaps` is its own inverse permutation we automatically know the
            # swap counterpart: range(num_replica). We use this idea to compute the
            # acceptance in a vectorized manner at the cost of wasting roughly half
            # our computation. Although we could use `unique` to solve this problem,
            # we expect the cost of `unique` to be higher than the dozens of wasted
            # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
            # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

            # Note: diffs would normally be "proposed - current" however energy is
            # flipped since `energy == -log_prob`.
            energy_diff = (untempered_pre_swap_replica_target_log_prob -
                           mcmc_util.index_remapping_gather(
                               untempered_pre_swap_replica_target_log_prob,
                               swaps,
                               name='gather_swap_tlp'))
            swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
                inverse_temperatures, swaps, name='gather_swap_temps')
            inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

            # If i and j are swapping, log_accept_ratio[] i and j are equal.
            log_accept_ratio = (energy_diff *
                                mcmc_util.left_justified_expand_dims_to(
                                    inverse_temp_diff, replica_and_batch_rank))

            # Map any non-finite ratio (NaN or +/-inf) to -inf. The strict `less`
            # comparison below can never be true against -inf, so such swaps are
            # always rejected.
            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce Log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                tf.random.uniform(shape=replica_and_batch_shape,
                                  dtype=dtype,
                                  seed=self._seed_stream()))
            # Both members of a swap pair read the draw stored at the smaller of
            # their two indices, so they make the same accept/reject decision.
            anchor_swaps = tf.minimum(swaps, null_swaps)
            log_uniform = mcmc_util.index_remapping_gather(
                log_uniform, anchor_swaps)

            is_swap_accepted_mask = tf.less(log_uniform,
                                            log_accept_ratio,
                                            name='is_swap_accepted_mask')

            def _swap_tensor(x):
                # Where the swap was accepted take the permuted value of `x`,
                # otherwise keep `x` unchanged.
                return mcmc_util.choose(
                    is_swap_accepted_mask,
                    mcmc_util.index_remapping_gather(x, swaps), x)

            post_swap_replica_states = [
                _swap_tensor(s) for s in pre_swap_replica_states
            ]

            expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
                null_swaps, replica_and_batch_shape)
            is_swap_proposed = _compute_swap_notmatrix(
                # Broadcast both so they have shape [num_replica] + batch_shape.
                # This (i) makes them have same shape as is_swap_accepted, and
                # (ii) keeps shape consistent if someday swaps has a batch shape.
                expanded_null_swaps,
                mcmc_util.left_justified_broadcast_to(swaps,
                                                      replica_and_batch_shape))

            # To get is_swap_accepted in ordered position, we use
            # _compute_swap_notmatrix on current and next replica positions.
            post_swap_replica_position = _swap_tensor(expanded_null_swaps)

            is_swap_accepted = _compute_swap_notmatrix(
                post_swap_replica_position, expanded_null_swaps)

            # The chain state exposed to the caller is the slice at replica index 0
            # of each state part.
            post_swap_states = [s[0] for s in post_swap_replica_states]

            post_swap_replica_results = _make_post_swap_replica_results(
                pre_swap_replica_results, inverse_temperatures,
                swapped_inverse_temperatures, is_swap_accepted_mask,
                _swap_tensor)

            # Return the state in the same structure (list vs. single Tensor) as
            # the `current_state` the caller passed in.
            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = post_swap_states
            else:
                states = post_swap_states[0]

            post_swap_kernel_results = ReplicaExchangeMCKernelResults(
                post_swap_replica_states=post_swap_replica_states,
                pre_swap_replica_results=pre_swap_replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_proposed,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the original pkr.inverse_temperatures in case its a
                # `tf.Variable`.
                inverse_temperatures=previous_kernel_results.
                inverse_temperatures,
                swaps=swaps,
            )

            return states, post_swap_kernel_results