Exemplo n.º 1
0
    def test_diagonal_mass_matrix_no_distribute(self):
        """Checks mass-matrix adaptation statistics when nothing is sharded.

        `DiagonalMassMatrixAdaptation` wraps `EchoKernel` and starts from a
        running variance built from 10 pseudo-samples with zero mean and unit
        variance. Ten NumPy draws are then fed through the kernel, and the
        resulting running mean and sum of squared residuals are compared
        against the same quantities computed directly in NumPy.
        """
        inner_kernel = EchoKernel()
        prior_variance = tfp.experimental.stats.RunningVariance.from_stats(
            num_samples=10., mean=tf.zeros(3), variance=tf.ones(3))
        kernel = tfp.experimental.mcmc.DiagonalMassMatrixAdaptation(
            inner_kernel, prior_variance)
        initial_results = kernel.bootstrap_results(tf.zeros(3))
        draws = np.random.randn(10, 3).astype(np.float32)

        def step(carry, draw):
            # The carry threads the kernel results and the PRNG seed.
            results, seed = carry
            next_seed, step_seed = samplers.split_seed(seed)
            _, next_results = kernel.one_step(draw, results, seed=step_seed)
            return (next_results, next_seed)

        initial_carry = (initial_results, samplers.sanitize_seed(self.key))
        (final_results, _), _ = mcmc_util.trace_scan(
            step, initial_carry, draws, lambda _: ())

        adapted = final_results.running_variance[0]

        # 10 pseudo-samples plus 10 real draws give 20 samples in total; the
        # pseudo-samples have mean zero so they add nothing to the sum.
        expected_mean = draws.sum(axis=0) / 20.
        # Residuals of the real draws, plus the pseudo-samples' residuals
        # re-centred on the new mean (10 * mean**2) and their own unit
        # variance (10 * 1).
        expected_residuals = (np.sum(
            (draws - expected_mean)**2, axis=0) + 10 * expected_mean**2 + 10)

        self.assertAllClose(expected_mean, adapted.mean)
        self.assertAllClose(expected_residuals, adapted.sum_squared_residuals)
Exemplo n.º 2
0
    def testComposite(self):
        """`trace_scan` can trace a composite-tensor distribution.

        The trace function returns two plain tensors plus a `tfd.Normal`
        wrapped by `auto_composite_tensor`; the traced distribution is
        expected to come back as a `tfd.Normal` whose parameters are stacked
        across steps.
        """
        composite_normal_cls = auto_composite_tensor.auto_composite_tensor(
            tfd.Normal, omit_kwargs=('name', ))

        final_state, trace = util.trace_scan(
            loop_fn=lambda state, element: state + element,
            initial_state=0.,
            elems=[1., 2.],
            trace_fn=lambda state: [
                state, 2 * state, composite_normal_cls(state, 0.1)
            ])
        traced_state, traced_double, traced_normal = trace

        # Static shapes: scalar final state, one traced entry per element.
        self.assertAllClose([], tensorshape_util.as_list(final_state.shape))
        for traced_tensor in (traced_state, traced_double):
            self.assertAllClose(
                [2], tensorshape_util.as_list(traced_tensor.shape))

        # Values: running sums 1 and 3, and their doubles.
        self.assertAllClose(3, final_state)
        self.assertAllClose([1, 3], traced_state)
        self.assertAllClose([2, 6], traced_double)

        # The distribution survives tracing with stacked parameters.
        self.assertIsInstance(traced_normal, tfd.Normal)
        self.assertAllClose([1., 3.], traced_normal.loc)
        self.assertAllClose([0.1, 0.1], traced_normal.scale)
Exemplo n.º 3
0
 def testTraceCriterion(self, static_length):
   """Only states accepted by `trace_criterion_fn` end up in the trace.

   The running sums of 1..7 are 1, 3, 6, 10, 15, 21, 28. The even ones
   (6, 10, 28) are expected in the trace after being halved by `trace_fn`.

   Args:
     static_length: If truthy, `static_trace_allocation_size=3` is passed;
       otherwise `None` is passed.
   """
   def _accumulate(state, element):
     return state + element

   def _half(state):
     return state / 2

   def _is_even(state):
     return tf.equal(state % 2, 0)

   allocation_size = 3 if static_length else None
   scan_outputs = util.trace_scan(
       loop_fn=_accumulate,
       initial_state=0,
       elems=[1, 2, 3, 4, 5, 6, 7],
       trace_fn=_half,
       trace_criterion_fn=_is_even,
       static_trace_allocation_size=allocation_size)
   final_state, trace = self.evaluate(scan_outputs)

   # 1 + 2 + ... + 7 == 28.
   self.assertAllClose(28, final_state)
   self.assertAllClose([3, 5, 14], trace)
Exemplo n.º 4
0
 def testConditionFn(self, static_length):
     """`condition_fn` ends the scan before all elements are consumed.

     With the condition `state < 9`, the expected running sums are
     1, 3, 6, 10: the final state is 10 and the trace holds those sums
     halved.

     Args:
       static_length: If truthy, `static_trace_allocation_size=4` is passed;
         otherwise `None` is passed.
     """
     def _accumulate(state, element):
         return state + element

     def _half(state):
         return state / 2

     def _below_nine(step, state, num_traced, trace):
         del step, num_traced, trace  # Only the state decides.
         return state < 9

     allocation_size = 4 if static_length else None
     scan_outputs = util.trace_scan(
         loop_fn=_accumulate,
         initial_state=0,
         elems=[1, 2, 3, 4, 5, 6, 7],
         trace_fn=_half,
         condition_fn=_below_nine,
         static_trace_allocation_size=allocation_size)
     final_state, trace = self.evaluate(scan_outputs)

     self.assertAllClose(10, final_state)
     self.assertAllClose([.5, 1.5, 3, 5], trace)
Exemplo n.º 5
0
  def testBasic(self):
    """`trace_scan` accumulates state and stacks a list-valued trace.

    Scanning addition over `[1, 2]` from 0 should end at 3, with the traced
    state `[1, 3]` and the traced doubled state `[2, 6]`.
    """
    final_state, trace = util.trace_scan(
        loop_fn=lambda state, element: state + element,
        initial_state=0,
        elems=[1, 2],
        trace_fn=lambda state: [state, state * 2])

    # Static shapes are checked before evaluation.
    self.assertAllClose([], tensorshape_util.as_list(final_state.shape))
    for traced_index in (0, 1):
      self.assertAllClose(
          [2], tensorshape_util.as_list(trace[traced_index].shape))

    final_value, trace_values = self.evaluate([final_state, trace])

    self.assertAllClose(3, final_value)
    self.assertAllClose([1, 3], trace_values[0])
    self.assertAllClose([2, 6], trace_values[1])
Exemplo n.º 6
0
        def run(seed):
            """Adapts a diagonal mass matrix on draws from a sharded normal.

            Args:
              seed: PRNG seed. It is split eleven ways: one seed for the
                initial state and ten seeds that drive the scan, one per
                step.

            Returns:
              draws: The draw produced at each scan step, as traced by
                `trace_scan`.
              pkr: The adaptation kernel's results after the last step.
            """
            init_seed, *step_seeds = samplers.split_seed(seed, 11)
            sharded_normal = tfp_dist.Sharded(
                tfd.Sample(tfd.Normal(0., 1.), 3),
                shard_axis_name=self.axis_name)
            initial_state = sharded_normal.sample(seed=init_seed)
            adaptation_kernel = (
                tfp.experimental.mcmc.DiagonalMassMatrixAdaptation(
                    EchoKernel(),
                    tfp.experimental.stats.RunningVariance.from_stats(
                        num_samples=10.,
                        mean=tf.zeros(3),
                        variance=tf.ones(3))))
            initial_results = adaptation_kernel.bootstrap_results(
                initial_state)

            def scan_step(carry, step_seed):
                # The carry is (latest draw, kernel results); only the
                # results feed the next step.
                _, previous_results = carry
                sample_seed, kernel_seed = samplers.split_seed(step_seed)
                new_draw = sharded_normal.sample(seed=sample_seed)
                _, updated_results = adaptation_kernel.one_step(
                    new_draw, previous_results, seed=kernel_seed)
                return new_draw, updated_results

            initial_carry = (tf.zeros(sharded_normal.event_shape),
                             initial_results)
            (_, final_results), draws = mcmc_util.trace_scan(
                scan_step, initial_carry, step_seeds, lambda carry: carry[0])
            return draws, final_results
Exemplo n.º 7
0
def sample_chain(
    num_results,
    current_state,
    previous_kernel_results=None,
    kernel=None,
    num_burnin_steps=0,
    num_steps_between_results=0,
    trace_fn=lambda current_state, kernel_results: kernel_results,
    return_final_kernel_results=False,
    parallel_iterations=10,
    name=None,
):
    """Implements Markov chain Monte Carlo via repeated `TransitionKernel` steps.

    This function samples from a Markov chain that starts at `current_state`
    and whose stationary distribution is governed by the supplied
    `TransitionKernel` instance (`kernel`).

    This function can sample from multiple chains, in parallel. (Whether or not
    there are multiple chains is dictated by the `kernel`.)

    The `current_state` can be represented as a single `Tensor` or a `list` of
    `Tensors` which collectively represent the current state.

    Since MCMC states are correlated, it is sometimes desirable to produce
    additional intermediate states, and then discard them, ending up with a set
    of states with decreased autocorrelation.  See [Owen (2017)][1]. Such
    "thinning" is made possible by setting `num_steps_between_results > 0`. The
    chain then takes `num_steps_between_results` extra steps between the steps
    that make it into the results. The extra steps are never materialized, and
    thus do not increase memory requirements.

    Warning: when setting a `seed` in the `kernel`, ensure that `sample_chain`'s
    `parallel_iterations=1`, otherwise results will not be reproducible.

    In addition to returning the chain state, this function supports tracing of
    auxiliary variables used by the kernel. The traced values are selected by
    specifying `trace_fn`. By default, all kernel results are traced but in the
    future the default will be changed to no results being traced, so plan
    accordingly. See below for some examples of this feature.

    Args:
      num_results: Integer number of Markov chain draws.
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A `Tensor` or a nested collection of `Tensor`s
        representing internal calculations made within the previous call to
        this function (or as returned by `bootstrap_results`).
        Default value: `None` (i.e., use
        `kernel.bootstrap_results(current_state)`).
      kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one
        step of the Markov chain. Required, despite the `None` default in the
        signature.
      num_burnin_steps: Integer number of chain steps to take before starting
        to collect results.
        Default value: 0 (i.e., no burn-in).
      num_steps_between_results: Integer number of chain steps between
        collecting a result. Only one out of every
        `num_steps_between_results + 1` steps is included in the returned
        results.  The number of returned chain states is still equal to
        `num_results`.  Default value: 0 (i.e., no thinning).
      trace_fn: A callable that takes in the current chain state and the
        kernel results of the step that produced it, and returns a `Tensor` or
        a nested collection of `Tensor`s that is then traced along with the
        chain state. If `None`, nothing is traced. Leaving the default in
        place traces all kernel results and emits a deprecation warning.
      return_final_kernel_results: If `True`, then the final kernel results
        are returned alongside the chain state and the trace specified by the
        `trace_fn`.
      parallel_iterations: The number of iterations allowed to run in
        parallel. It must be a positive integer. See `tf.while_loop` for more
        details.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., "mcmc_sample_chain").

    Returns:
      checkpointable_states_and_trace: if `return_final_kernel_results` is
        `True`. The return value is an instance of
        `CheckpointableStatesAndTrace`.
      all_states: if `return_final_kernel_results` is `False` and `trace_fn` is
        `None`. The return value is a `Tensor` or Python list of `Tensor`s
        representing the state(s) of the Markov chain(s) at each result step.
        Has same shape as input `current_state` but with a prepended
        `num_results`-size dimension.
      states_and_trace: if `return_final_kernel_results` is `False` and
        `trace_fn` is not `None`. The return value is an instance of
        `StatesAndTrace`.

    Raises:
      ValueError: if `kernel` is `None`.

    #### Examples

    ##### Sample from a diagonal-variance Gaussian.

    I.e.,

    ```none
    for i=1..n:
      x[i] ~ MultivariateNormal(loc=0, scale=diag(true_stddev))  # likelihood
    ```

    ```python
    import numpy as np
    import tensorflow as tf
    import tensorflow_probability as tfp
    tfd = tfp.distributions

    dims = 10
    true_stddev = np.sqrt(np.linspace(1., 3., dims)).astype(np.float32)
    likelihood = tfd.MultivariateNormalDiag(loc=0., scale_diag=true_stddev)

    states = tfp.mcmc.sample_chain(
        num_results=1000,
        num_burnin_steps=500,
        current_state=tf.zeros(dims),
        kernel=tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=likelihood.log_prob,
          step_size=0.5,
          num_leapfrog_steps=2),
        trace_fn=None)

    sample_mean = tf.reduce_mean(states, axis=0)
    # ==> approx all zeros

    sample_stddev = tf.sqrt(tf.reduce_mean(
        tf.math.squared_difference(states, sample_mean),
        axis=0))
    # ==> approx equal true_stddev
    ```

    ##### Sampling from factor-analysis posteriors with known factors.

    I.e.,

    ```none
    # prior
    w ~ MultivariateNormal(loc=0, scale=eye(d))
    for i=1..n:
      # likelihood
      x[i] ~ Normal(loc=w^T F[i], scale=1)
    ```

    where `F` denotes factors.

    ```python
    import tensorflow as tf
    import tensorflow_probability as tfp
    tfd = tfp.distributions

    # Specify model.
    def make_prior(dims):
      return tfd.MultivariateNormalDiag(
          loc=tf.zeros(dims))

    def make_likelihood(weights, factors):
      return tfd.MultivariateNormalDiag(
          loc=tf.matmul(weights, factors, adjoint_b=True))

    def joint_log_prob(num_weights, factors, x, w):
      return (make_prior(num_weights).log_prob(w) +
              make_likelihood(w, factors).log_prob(x))

    def unnormalized_log_posterior(w):
      # Posterior is proportional to: `p(W, X=x | factors)`.
      return joint_log_prob(num_weights, factors, x, w)

    # Setup data.
    num_weights = 10 # == d
    num_factors = 40 # == n
    num_chains = 100

    weights = make_prior(num_weights).sample(1)
    factors = tf.random.normal([num_factors, num_weights])
    x = make_likelihood(weights, factors).sample()

    # Sample from Hamiltonian Monte Carlo Markov Chain.

    # Get `num_results` samples from `num_chains` independent chains.
    chains_states, kernels_results = tfp.mcmc.sample_chain(
        num_results=1000,
        num_burnin_steps=500,
        current_state=tf.zeros([num_chains, num_weights], name='init_weights'),
        kernel=tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=unnormalized_log_posterior,
          step_size=0.1,
          num_leapfrog_steps=2))

    # Compute sample stats.
    sample_mean = tf.reduce_mean(chains_states, axis=[0, 1])
    # ==> approx equal to weights

    sample_var = tf.reduce_mean(
        tf.math.squared_difference(chains_states, sample_mean),
        axis=[0, 1])
    # ==> less than 1
    ```

    ##### Custom tracing functions.

    ```python
    import tensorflow as tf
    import tensorflow_probability as tfp
    tfd = tfp.distributions

    likelihood = tfd.Normal(loc=0., scale=1.)

    def sample_chain(trace_fn):
      return tfp.mcmc.sample_chain(
        num_results=1000,
        num_burnin_steps=500,
        current_state=0.,
        kernel=tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=likelihood.log_prob,
          step_size=0.5,
          num_leapfrog_steps=2),
        trace_fn=trace_fn)

    def trace_log_accept_ratio(states, previous_kernel_results):
      return previous_kernel_results.log_accept_ratio

    def trace_everything(states, previous_kernel_results):
      return previous_kernel_results

    _, log_accept_ratio = sample_chain(trace_fn=trace_log_accept_ratio)
    _, kernel_results = sample_chain(trace_fn=trace_everything)

    acceptance_prob = tf.math.exp(tf.minimum(log_accept_ratio, 0.))
    # Equivalent to, but more efficient than:
    acceptance_prob = tf.math.exp(tf.minimum(
        kernel_results.log_accept_ratio, 0.))
    ```

    #### References

    [1]: Art B. Owen. Statistically efficient thinning of a Markov chain sampler.
         _Technical Report_, 2017.
         http://statweb.stanford.edu/~owen/reports/bestthinning.pdf
    """
    if kernel is None:
        # `kernel` has a `None` default in the signature but is dereferenced
        # unconditionally below; fail with a clear message instead of an
        # `AttributeError` on `NoneType`.
        raise ValueError(
            "`kernel` must be a `tfp.mcmc.TransitionKernel` instance; "
            "got `None`.")
    if not kernel.is_calibrated:
        warnings.warn(
            "supplied `TransitionKernel` is not calibrated. Markov "
            "chain may not converge to intended target distribution.")
    with tf.name_scope(name or "mcmc_sample_chain"):
        num_results = tf.convert_to_tensor(num_results,
                                           dtype=tf.int32,
                                           name="num_results")
        num_burnin_steps = tf.convert_to_tensor(num_burnin_steps,
                                                dtype=tf.int32,
                                                name="num_burnin_steps")
        num_steps_between_results = tf.convert_to_tensor(
            num_steps_between_results,
            dtype=tf.int32,
            name="num_steps_between_results")
        current_state = tf.nest.map_structure(
            lambda x: tf.convert_to_tensor(x, name="current_state"),
            current_state)
        if previous_kernel_results is None:
            previous_kernel_results = kernel.bootstrap_results(current_state)

        no_trace = trace_fn is None
        if no_trace:
            # It simplifies the logic to use a dummy function here.
            trace_fn = lambda *args: ()
        elif trace_fn is sample_chain.__defaults__[4]:
            # `__defaults__[4]` is the signature's default `trace_fn` lambda
            # (the fifth defaulted parameter), so identity with it means the
            # caller did not pass `trace_fn` explicitly.
            warnings.warn(
                "Tracing all kernel results by default is deprecated. Set "
                "the `trace_fn` argument to None (the future default "
                "value) or an explicit callback that traces the values "
                "you are interested in.")

        def _trace_scan_fn(state_and_results, num_steps):
            # Run `num_steps` kernel steps; only the outcome of the last one
            # is handed back to `trace_scan`. The result is unpacked and
            # returned as a tuple, matching the tuple used for
            # `initial_state` below.
            next_state, current_kernel_results = mcmc_util.smart_for_loop(
                loop_num_iter=num_steps,
                body_fn=kernel.one_step,
                initial_loop_vars=list(state_and_results),
                parallel_iterations=parallel_iterations)
            return next_state, current_kernel_results

        # One scan element per result: the first takes `1 + num_burnin_steps`
        # kernel steps, every later one takes
        # `1 + num_steps_between_results`.
        steps_per_result = tf.one_hot(
            indices=0,
            depth=num_results,
            on_value=1 + num_burnin_steps,
            off_value=1 + num_steps_between_results,
            dtype=tf.int32)

        (_, final_kernel_results), (all_states, trace) = mcmc_util.trace_scan(
            loop_fn=_trace_scan_fn,
            initial_state=(current_state, previous_kernel_results),
            elems=steps_per_result,
            # pylint: disable=g-long-lambda
            trace_fn=lambda state_and_results:
            (state_and_results[0], trace_fn(*state_and_results)),
            # pylint: enable=g-long-lambda
            parallel_iterations=parallel_iterations)

        if return_final_kernel_results:
            return CheckpointableStatesAndTrace(
                all_states=all_states,
                trace=trace,
                final_kernel_results=final_kernel_results)
        if no_trace:
            return all_states
        return StatesAndTrace(all_states=all_states, trace=trace)
Exemplo n.º 8
0
def particle_filter(
        observations,
        initial_state_prior,
        transition_fn,
        observation_fn,
        num_particles,
        initial_state_proposal=None,
        proposal_fn=None,
        resample_fn=weighted_resampling.resample_systematic,
        resample_criterion_fn=smc_kernel.ess_below_threshold,
        unbiased_gradients=True,
        rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
        num_transitions_per_observation=1,
        trace_fn=_default_trace_fn,
        trace_criterion_fn=_always_trace,
        static_trace_allocation_size=None,
        parallel_iterations=1,
        seed=None,
        name=None):  # pylint: disable=g-doc-args
    """Samples a series of particles representing filtered latent states.

  The particle filter samples from the sequence of "filtering" distributions
  `p(state[t] | observations[:t])` over latent
  states: at each point in time, this is the distribution conditioned on all
  observations *up to that time*. Because particles may be resampled, a particle
  at time `t` may be different from the particle with the same index at time
  `t + 1`. To reconstruct trajectories by tracing back through the resampling
  process, see `tfp.experimental.mcmc.reconstruct_trajectories`.

  ${particle_filter_arg_str}
    trace_fn: Python `callable` defining the values to be traced at each step,
      with signature `traced_values = trace_fn(weighted_particles, results)`
      in which the first argument is an instance of
      `tfp.experimental.mcmc.WeightedParticles` and the second an instance of
      `SequentialMonteCarloResults` tuple, and the return value is a structure
      of `Tensor`s.
      Default value: `lambda s, r: (s.particles, s.log_weights,
      r.parent_indices, r.incremental_log_marginal_likelihood)`
    trace_criterion_fn: optional Python `callable` with signature
      `trace_this_step = trace_criterion_fn(weighted_particles, results)` taking
      the same arguments as `trace_fn` and returning a boolean `Tensor`. If
      `None`, only values from the final step are returned.
      Default value: `lambda *_: True` (trace every step).
    static_trace_allocation_size: Optional Python `int` size of trace to
      allocate statically. This should be an upper bound on the number of steps
      traced and is used only when the length cannot be
      statically inferred (for example, if a `trace_criterion_fn` is specified).
      It is primarily intended for contexts where static shapes are required,
      such as in XLA-compiled code.
      Default value: `None`.
    parallel_iterations: Passed to the internal `tf.while_loop`.
      Default value: `1`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'particle_filter'`).
  Returns:
    traced_results: A structure of Tensors as returned by `trace_fn`. If
      `trace_criterion_fn==None`, this is computed from the final step;
      otherwise, each Tensor will have initial dimension `num_steps_traced`
      and stacks the traced results across all steps.

  #### References

  [1] Adam Scibior, Vaden Masrani, and Frank Wood. Differentiable Particle
      Filtering without Modifying the Forward Pass. _arXiv preprint
      arXiv:2106.10314_, 2021. https://arxiv.org/abs/2106.10314
  """
    # NOTE(review): `${particle_filter_arg_str}` in the docstring is presumably
    # filled in by a docstring-templating decorator that is outside this view;
    # the docstring layout is therefore left exactly as it was.

    init_seed, loop_seed = samplers.split_seed(seed, salt='particle_filter')
    with tf.name_scope(name or 'particle_filter'):
        first_observation = tf.nest.flatten(observations)[0]
        num_observation_steps = ps.size0(first_observation)
        num_timesteps = (
            1 + num_transitions_per_observation * (num_observation_steps - 1))

        # Without a trace criterion nothing is traced inside the loop; the
        # trace is instead computed once, from the final step, after the scan.
        final_step_only = trace_criterion_fn is None
        if final_step_only:
            static_trace_allocation_size = 0
            trace_criterion_fn = lambda *_: False

        initial_weighted_particles = _particle_filter_initial_weighted_particles(
            observations=observations,
            observation_fn=observation_fn,
            initial_state_prior=initial_state_prior,
            initial_state_proposal=initial_state_proposal,
            num_particles=num_particles,
            seed=init_seed)

        kernel = smc_kernel.SequentialMonteCarlo(
            propose_and_update_log_weights_fn=(
                _particle_filter_propose_and_update_log_weights_fn(
                    observations=observations,
                    transition_fn=transition_fn,
                    proposal_fn=proposal_fn,
                    observation_fn=observation_fn,
                    num_transitions_per_observation=(
                        num_transitions_per_observation))),
            resample_fn=resample_fn,
            resample_criterion_fn=resample_criterion_fn,
            unbiased_gradients=unbiased_gradients)

        # `trace_scan` is used instead of `sample_chain` because the latter
        # would force tracing of the state history (with or without
        # thinning), which is not always appropriate.
        def _seeded_step(carry, _):
            # The carry is (seed, weighted particles, kernel results); the
            # scanned element itself is ignored.
            carried_seed, particles, results = carry
            one_step_seed, next_seed = samplers.split_seed(carried_seed)
            next_particles, next_results = kernel.one_step(
                particles, results, seed=one_step_seed)
            return next_seed, next_particles, next_results

        def _trace_without_seed(carry):
            # User callbacks see only (weighted particles, results).
            return trace_fn(*carry[1:])

        def _criterion_without_seed(carry):
            return trace_criterion_fn(*carry[1:])

        initial_carry = (
            loop_seed, initial_weighted_particles,
            kernel.bootstrap_results(initial_weighted_particles))
        final_carry, traced_results = mcmc_util.trace_scan(
            loop_fn=_seeded_step,
            initial_state=initial_carry,
            elems=tf.ones([num_timesteps]),
            trace_fn=_trace_without_seed,
            trace_criterion_fn=_criterion_without_seed,
            static_trace_allocation_size=static_trace_allocation_size,
            parallel_iterations=parallel_iterations)

        if final_step_only:
            # Return results from just the final step.
            traced_results = trace_fn(*final_carry[1:])

        return traced_results