Example #1
    # `fn_name`, `n_event_shapes`, `safe_value` and `make_arg0_safe` are
    # closure variables bound in the enclosing scope (this `fn` is built by a
    # factory function).
    def fn(self, *args, **kwargs):
        if safe_value == 'safe_sample' or make_arg0_safe:  # Only if needed.
            safe_val = tf.stop_gradient(self.safe_sample_fn(self.distribution))

        validity_mask = tf.convert_to_tensor(self.validity_mask)
        if make_arg0_safe:
            x = args[0]
            safe_x = tf.where(
                _add_event_dims_to_mask(validity_mask, dist=self), x, safe_val)
            args = (safe_x,) + args[1:]

        val = getattr(self.distribution, fn_name)(*args, **kwargs)
        if n_event_shapes:
            validity_mask = tf.reshape(
                validity_mask,
                ps.concat(
                    [ps.shape(validity_mask)] +
                    [ps.ones_like(self.event_shape_tensor())] * n_event_shapes,
                    axis=0))
        if safe_value == 'safe_sample':
            sentinel = tf.cast(safe_val, val.dtype)
        else:
            sentinel = tf.cast(safe_value, val.dtype)
        return tf.where(validity_mask, val, sentinel)
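# Illustrative sketch (not part of the source above): the heart of the masked
# method is the final `tf.where`, which swaps entries at invalid positions for
# a sentinel. The values below are made-up standalone tensors.
import tensorflow as tf

validity_mask = tf.constant([True, False, True])
val = tf.constant([0.1, float('nan'), 0.3])  # e.g. a log_prob with a bad entry
sentinel = tf.cast(0., val.dtype)            # corresponds to safe_value == 0.
masked_val = tf.where(validity_mask, val, sentinel)
# masked_val == [0.1, 0.0, 0.3]: invalid positions take the sentinel value.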
Example #2
  def _sample_n(self, n, seed=None):
    init_seed, scan_seed, observation_seed = samplers.split_seed(
        seed, n=3, salt='HiddenMarkovModel')

    transition_batch_shape = self.transition_distribution.batch_shape_tensor()
    num_states = transition_batch_shape[-1]

    batch_shape = self.batch_shape_tensor()
    batch_size = ps.reduce_prod(batch_shape)
    # The batch sizes of the underlying initial distributions and
    # transition distributions might not match the batch size of
    # the HMM distribution.
    # As a result we need to ask for more samples from the
    # underlying distributions and then reshape the results into
    # the correct batch size for the HMM.
    init_repeat = (
        ps.reduce_prod(batch_shape) //
        ps.reduce_prod(self._initial_distribution.batch_shape_tensor()))
    init_state = self._initial_distribution.sample(n * init_repeat,
                                                   seed=init_seed)
    init_state = tf.reshape(init_state, [n, batch_size])
    # init_state :: n batch_size

    transition_repeat = (
        ps.reduce_prod(batch_shape) // ps.reduce_prod(
            transition_batch_shape[:-1]))

    init_shape = init_state.shape

    def generate_step(state_and_seed, _):
      """Take a single step in Markov chain."""
      state, seed = state_and_seed
      sample_seed, next_seed = samplers.split_seed(seed)

      gen = self._transition_distribution.sample(n * transition_repeat,
                                                 seed=sample_seed)
      # gen :: (n * transition_repeat) transition_batch

      new_states = tf.reshape(gen,
                              [n, batch_size, num_states])

      # new_states :: n batch_size num_states

      old_states_one_hot = tf.one_hot(state, num_states, dtype=tf.int32)

      # old_states_one_hot :: n batch_size num_states

      result = tf.reduce_sum(old_states_one_hot * new_states, axis=-1)
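      # The one-hot product followed by `reduce_sum` acts as a batched gather:
      # for each chain it selects the freshly sampled next state corresponding
      # to that chain's current state.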
      # We know that `generate_step` must preserve the shape of the
      # tensor of states, because the transition matrix must be square.
      # But TensorFlow might not know this, so we explicitly tell it that
      # the result has the same shape.
      tensorshape_util.set_shape(result, init_shape)
      return result, next_seed

    def _scan_multiple_steps():
      """Take multiple steps with tf.scan."""
      dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
      hidden_states, _ = tf.scan(generate_step, dummy_index,
                                 initializer=(init_state, scan_seed))

      # TODO(b/115618503): add/use prepend_initializer to tf.scan
      return tf.concat([[init_state],
                        hidden_states], axis=0)
    hidden_states = ps.cond(
        self._num_steps > 1,
        _scan_multiple_steps,
        lambda: init_state[tf.newaxis, ...])

    hidden_one_hot = tf.one_hot(hidden_states, num_states,
                                dtype=self._observation_distribution.dtype)
    # hidden_one_hot :: num_steps n batch_size num_states

    # The observation distribution batch size might not match
    # the required batch size, so, as with the initial and
    # transition distributions, we generate more samples and
    # reshape.
    observation_repeat = tf.maximum(
        batch_size // ps.reduce_prod(
            self._observation_distribution.batch_shape_tensor()[:-1]),
        1)

    if self._time_varying_observation_distribution:
      possible_observations = self._observation_distribution.sample(
          [observation_repeat * n], seed=observation_seed)
      # `possible_observations` needs to have `num_steps` moved to the
      # beginning.
      possible_observations = distribution_util.move_dimension(
          possible_observations,
          -(tf.size(self._observation_distribution.event_shape_tensor()) + 2),
          0)
    else:
      possible_observations = self._observation_distribution.sample(
          [self._num_steps, observation_repeat * n], seed=observation_seed)

    inner_shape = self._observation_distribution.event_shape_tensor()

    # possible_observations :: num_steps (observation_repeat * n)
    #                          observation_batch[:-1] num_states inner_shape

    possible_observations = tf.reshape(
        possible_observations,
        ps.concat([[self._num_steps, n],
                   batch_shape,
                   [num_states],
                   inner_shape], axis=0))

    # possible_observations :: steps n batch_size num_states inner_shape

    hidden_one_hot = tf.reshape(hidden_one_hot,
                                ps.concat([[self._num_steps, n],
                                           batch_shape,
                                           [num_states],
                                           ps.ones_like(inner_shape)],
                                          axis=0))

    # hidden_one_hot :: steps n batch_size num_states "inner_shape"

    observations = tf.reduce_sum(
        hidden_one_hot * possible_observations,
        axis=-1 - ps.size(inner_shape))
    # observations :: steps n batch_size inner_shape

    observations = distribution_util.move_dimension(observations, 0,
                                                    1 + ps.size(batch_shape))
    # returned :: n batch_shape steps inner_shape

    return observations
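# Toy sketch (not from the HMM source above) of the repeat-and-reshape pattern
# described in the comments: when the underlying distribution's batch shape is
# smaller than the model's, draw `n * repeat` samples and reshape them to
# [n, batch_size]. The Normal distribution and shapes are assumptions chosen
# only for illustration.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

n = 4
underlying = tfd.Normal(loc=[0., 1.], scale=1.)   # underlying batch shape [2]
batch_size = 6                                    # desired flat batch size
repeat = batch_size // 2                          # == 3
samples = underlying.sample(n * repeat, seed=42)  # shape [12, 2]
samples = tf.reshape(samples, [n, batch_size])    # shape [4, 6]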
Example #3
def sample_sequential_monte_carlo(
        prior_log_prob_fn,
        likelihood_log_prob_fn,
        current_state,
        max_num_steps=25,
        max_stage=100,
        make_kernel_fn=make_rwmh_kernel_fn,
        tuning_fn=simple_heuristic_tuning,
        make_tempered_target_log_prob_fn=default_make_tempered_target_log_prob_fn,
        ess_threshold_ratio=0.5,
        parallel_iterations=10,
        seed=None,
        name=None):
    """Runs Sequential Monte Carlo to sample from the posterior distribution.

  This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo)
  to sample from a series of distributions that slowly interpolates between
  an initial 'prior' distribution:

    `exp(prior_log_prob_fn(x))`

  and the target 'posterior' distribution:

    `exp(prior_log_prob_fn(x) + target_log_prob_fn(x))`,

  by mutating a collection of MC samples (i.e., particles). The approach is also
  known as a particle filter in some of the literature. The current
  implementation is largely based on Del Moral et al. [1], which adapts both
  the tempering sequence (based on the effective sample size) and the scaling
  of the mutation kernel (based on the sample covariance of the particles) at
  each stage.

  Args:
    prior_log_prob_fn: Python callable that returns the log density of the
      prior distribution.
    likelihood_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the likelihood distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(likelihood_log_prob_fn(*current_state))`.
    max_num_steps: The maximum number of kernel transition steps in one mutation
      of the MC samples. Note that the actual number of steps in one mutation is
      tuned during sampling and is likely lower than `max_num_steps`.
    max_stage: Integer maximum number of tempering stages over which the
      inverse temperature is increased from 0 to 1.
    make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like
      object. Must take one argument representing the `TransitionKernel`'s
      `target_log_prob_fn`. The `target_log_prob_fn` argument represents the
      `TransitionKernel`'s target log distribution.  Note:
      `sample_sequential_monte_carlo` creates a new `target_log_prob_fn`
      which interpolates between the prior and the posterior defined by the
      supplied `prior_log_prob_fn` and `likelihood_log_prob_fn`; it is this
      interpolated function which is used as an argument to `make_kernel_fn`.
    tuning_fn: Python `callable` which takes the number of steps, the log
      scaling, and the log acceptance ratio from the last mutation and outputs
      the number of steps and log scaling for the next mutation.
    make_tempered_target_log_prob_fn: Python `callable` that takes the
      `prior_log_prob_fn`, `likelihood_log_prob_fn`, and `inverse_temperatures`
      and creates a `target_log_prob_fn` `callable` that is passed to
      `make_kernel_fn`.
    ess_threshold_ratio: Target ratio for effective sample size.
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
    seed: Python integer or TFP seedstream to seed the random number generator.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_sequential_monte_carlo').

  Returns:
    n_stage: Number of mutation stages that SMC ran.
    final_state: `Tensor` or Python `list` of `Tensor`s representing the
      final state(s) of the Markov chain(s). These are the posterior samples.
    final_kernel_results: `collections.namedtuple` of internal calculations used
      to advance the chain.

  #### References

  [1] Del Moral, Pierre, Arnaud Doucet, and Ajay Jasra. An adaptive sequential
      Monte Carlo method for approximate Bayesian computation.
      _Statistics and Computing_, 22(5):1009-1020, 2012.

  """

    with tf.name_scope(name or 'sample_sequential_monte_carlo'):
        seed_stream = SeedStream(seed, salt='smc_seed')

        unwrap_state_list = not tf.nest.is_nested(current_state)
        if unwrap_state_list:
            current_state = [current_state]
        current_state = [
            tf.convert_to_tensor(s, dtype_hint=tf.float32)
            for s in current_state
        ]

        # Initial preprocessing at Stage 0
        likelihood_log_prob = likelihood_log_prob_fn(*current_state)

        likelihood_rank = ps.rank(likelihood_log_prob)
        dimension = ps.reduce_sum([
            ps.reduce_prod(ps.shape(x)[likelihood_rank:])
            for x in current_state
        ])

        # We infer the particle shapes from the resulting likelihood:
        # [num_particles, b1, ..., bN]
        particle_shape = ps.shape(likelihood_log_prob)
        num_particles, batch_shape = particle_shape[0], particle_shape[1:]
        effective_sample_size_threshold = tf.cast(
            num_particles * ess_threshold_ratio, tf.int32)

        # TODO(b/152412213): Revisit this default parameter.
        # Default to the optimal scaling of a random walk kernel for a
        # d-dimensional normally distributed target: 2.38 ** 2 / d.
        # For more detail see:
        # Roberts GO, Gelman A, Gilks WR. Weak convergence and optimal scaling of
        # random walk Metropolis algorithms. _The annals of applied probability_.
        # 1997;7(1):110-20.
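        # As a concrete instance: for d = 4 this gives 2.38 ** 2 / 4 ~= 1.42
        # (clipped to 1. by the `minimum` below), while for d = 100 it gives
        # ~= 0.057.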
        scale_start = (tf.constant(2.38**2, dtype=likelihood_log_prob.dtype) /
                       tf.constant(dimension, dtype=likelihood_log_prob.dtype))

        inverse_temperature = tf.zeros(batch_shape,
                                       dtype=likelihood_log_prob.dtype)
        scalings = ps.ones_like(likelihood_log_prob) * ps.minimum(
            scale_start, 1.)
        kernel = make_kernel_fn(make_tempered_target_log_prob_fn(
            prior_log_prob_fn, likelihood_log_prob_fn, inverse_temperature),
                                current_state,
                                scalings,
                                seed=seed_stream)
        pkr = kernel.bootstrap_results(current_state)
        _, kernel_target_log_prob = gather_mh_like_result(pkr)

        particle_info = ParticleInfo(
            log_accept_prob=ps.zeros_like(likelihood_log_prob),
            log_scalings=tf.math.log(scalings),
            tempered_log_prob=kernel_target_log_prob,
            likelihood_log_prob=likelihood_log_prob,
        )

        current_pkr = SMCResults(
            num_steps=tf.convert_to_tensor(max_num_steps,
                                           dtype=tf.int32,
                                           name='num_steps'),
            inverse_temperature=inverse_temperature,
            log_marginal_likelihood=tf.zeros_like(inverse_temperature),
            particle_info=particle_info)

        def update_weights_temperature(inverse_temperature,
                                       likelihood_log_prob):
            """Calculate the next inverse temperature and update weights."""
            likelihood_diff = likelihood_log_prob - tf.reduce_max(
                likelihood_log_prob, axis=0)

            def _body_fn(new_beta, upper_beta, lower_beta, eff_size,
                         log_weights):
                """One iteration of the temperature and weight update."""
                new_beta = (lower_beta + upper_beta) / 2.0
                log_weights = (new_beta -
                               inverse_temperature) * likelihood_diff
                log_weights_norm = tf.math.log_softmax(log_weights, axis=0)
                eff_size = tf.cast(
                    tf.exp(-tf.math.reduce_logsumexp(2 * log_weights_norm,
                                                     axis=0)), tf.int32)
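                # Note: `eff_size` is the (Kish) effective sample size of the
                # normalized weights w_i, since exp(-logsumexp(2 * log(w_i)))
                # equals 1 / sum_i w_i**2.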
                upper_beta = tf.where(
                    eff_size < effective_sample_size_threshold, new_beta,
                    upper_beta)
                lower_beta = tf.where(
                    eff_size < effective_sample_size_threshold, lower_beta,
                    new_beta)
                return new_beta, upper_beta, lower_beta, eff_size, log_weights

            def _cond_fn(new_beta, upper_beta, lower_beta, eff_size, *_):  # pylint: disable=unused-argument
                # TODO(junpenglao): revisit threshold below to be dtype specific.
                threshold = 1e-6
                return (tf.math.reduce_any(upper_beta - lower_beta > threshold)
                        & tf.math.reduce_any(
                            eff_size != effective_sample_size_threshold))

            (new_beta, upper_beta, lower_beta, eff_size,
             log_weights) = tf.while_loop(  # pylint: disable=unused-variable
                 cond=_cond_fn,
                 body=_body_fn,
                 loop_vars=(tf.zeros_like(inverse_temperature),
                            tf.fill(ps.shape(inverse_temperature),
                                    tf.constant(2, inverse_temperature.dtype)),
                            inverse_temperature,
                            tf.zeros_like(inverse_temperature, dtype=tf.int32),
                            tf.zeros_like(likelihood_diff)),
                 parallel_iterations=parallel_iterations)

            log_weights = tf.where(new_beta < 1., log_weights,
                                   (1. - inverse_temperature) *
                                   likelihood_diff)
            marginal_loglike_ = reduce_logmeanexp(
                (new_beta - inverse_temperature) * likelihood_log_prob, axis=0)
            new_inverse_temperature = tf.clip_by_value(new_beta, 0., 1.)

            return marginal_loglike_, new_inverse_temperature, log_weights

        def mutate(current_state, log_scalings, num_steps,
                   inverse_temperature):
            """Mutate the state using a Transition kernel."""
            with tf.name_scope('mutate_states'):
                scalings = tf.exp(log_scalings)
                kernel = make_kernel_fn(make_tempered_target_log_prob_fn(
                    prior_log_prob_fn, likelihood_log_prob_fn,
                    inverse_temperature),
                                        current_state,
                                        scalings,
                                        seed=seed_stream)
                pkr = kernel.bootstrap_results(current_state)
                kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)

                def mutate_onestep(i, state, pkr, log_accept_prob_sum):
                    next_state, next_kernel_results = kernel.one_step(
                        state, pkr)
                    kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)
                    log_accept_prob = tf.minimum(kernel_log_accept_ratio, 0.)
                    log_accept_prob_sum = log_add_exp(log_accept_prob_sum,
                                                      log_accept_prob)
                    return i + 1, next_state, next_kernel_results, log_accept_prob_sum

                (
                    _, next_state, next_kernel_results, log_accept_prob_sum
                ) = tf.while_loop(
                    cond=lambda i, *args: i < num_steps,
                    body=mutate_onestep,
                    loop_vars=(
                        tf.zeros([], dtype=tf.int32),
                        current_state,
                        pkr,
                        # we accumulate the acceptance probability in log space.
                        tf.fill(
                            ps.shape(kernel_log_accept_ratio),
                            tf.constant(-np.inf,
                                        kernel_log_accept_ratio.dtype))),
                    parallel_iterations=parallel_iterations)
                _, kernel_target_log_prob = gather_mh_like_result(
                    next_kernel_results)
                avg_log_accept_prob_per_particle = log_accept_prob_sum - tf.math.log(
                    tf.cast(num_steps + 1, log_accept_prob_sum.dtype))
                return (next_state, avg_log_accept_prob_per_particle,
                        kernel_target_log_prob)

        # One SMC step.
        def smc_body_fn(stage, state, smc_kernel_result):
            """Run one stage of SMC with constant temperature."""
            (new_marginal, new_inv_temperature,
             log_weights) = update_weights_temperature(
                 smc_kernel_result.inverse_temperature,
                 smc_kernel_result.particle_info.likelihood_log_prob)
            # TODO(b/152412213) Use a tf.scan to better collect debug info.
            if PRINT_DEBUG:
                tf.print(
                    'Stage:', stage, 'Beta:', new_inv_temperature, 'n_steps:',
                    smc_kernel_result.num_steps, 'accept:',
                    tf.exp(
                        reduce_logmeanexp(
                            smc_kernel_result.particle_info.log_accept_prob,
                            axis=0)), 'scaling:',
                    tf.exp(
                        reduce_logmeanexp(
                            smc_kernel_result.particle_info.log_scalings,
                            axis=0)))
            (resampled_state,
             resampled_particle_info), _ = resample_particle_and_info(
                 (state, smc_kernel_result.particle_info),
                 log_weights,
                 seed=seed_stream)
            next_num_steps, next_log_scalings = tuning_fn(
                smc_kernel_result.num_steps,
                resampled_particle_info.log_scalings,
                resampled_particle_info.log_accept_prob)
            # Skip tuning at stage 0.
            next_num_steps = tf.where(stage == 0, smc_kernel_result.num_steps,
                                      next_num_steps)
            next_log_scalings = tf.where(stage == 0,
                                         resampled_particle_info.log_scalings,
                                         next_log_scalings)
            next_num_steps = tf.clip_by_value(next_num_steps, 2, max_num_steps)

            next_state, log_accept_prob, tempered_log_prob = mutate(
                resampled_state, next_log_scalings, next_num_steps,
                new_inv_temperature)
            next_pkr = SMCResults(
                num_steps=next_num_steps,
                inverse_temperature=new_inv_temperature,
                log_marginal_likelihood=(
                    new_marginal + smc_kernel_result.log_marginal_likelihood),
                particle_info=ParticleInfo(
                    log_accept_prob=log_accept_prob,
                    log_scalings=next_log_scalings,
                    tempered_log_prob=tempered_log_prob,
                    likelihood_log_prob=likelihood_log_prob_fn(*next_state),
                ))
            return stage + 1, next_state, next_pkr

        (n_stage, final_state, final_kernel_results) = tf.while_loop(
            cond=lambda i, state, pkr: (  # pylint: disable=g-long-lambda
                (i < max_stage) & tf.reduce_any(pkr.inverse_temperature < 1.)),
            body=smc_body_fn,
            loop_vars=(tf.zeros([],
                                dtype=tf.int32), current_state, current_pkr),
            parallel_iterations=parallel_iterations)
        if unwrap_state_list:
            final_state = final_state[0]
        return n_stage, final_state, final_kernel_results
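# Hedged usage sketch (not from the source above), assuming the function is
# exported as `tfp.experimental.mcmc.sample_sequential_monte_carlo`; the prior
# and likelihood below are illustrative choices.
import tensorflow_probability as tfp

tfd = tfp.distributions

prior = tfd.Normal(0., 10.)
likelihood = tfd.Normal(3., 1.)
init_state = prior.sample(1000, seed=42)  # 1000 particles

n_stage, final_state, _ = tfp.experimental.mcmc.sample_sequential_monte_carlo(
    prior_log_prob_fn=prior.log_prob,
    likelihood_log_prob_fn=likelihood.log_prob,
    current_state=init_state,
    seed=42)
# `final_state` holds particles approximately distributed as
# exp(prior.log_prob(x) + likelihood.log_prob(x)), i.e. the posterior.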
Example #4
  def test_ones_like(self):
    x = tf1.placeholder_with_default(
        tf.ones([2], dtype=tf.float32), shape=None)
    self.assertEqual(
        dtype_util.convert_to_dtype(ps.ones_like(x)), tf.float32)
Example #5
def _left_doubling_increments(batch_shape,
                              max_doublings,
                              step_size,
                              seed=None,
                              name=None):
    """Computes the doubling increments for the left end point.

  The doubling procedure expands an initial interval to find a superset of the
  true slice. At each doubling iteration, the interval width is doubled, with
  the new half added to either the left- or the right-hand side with equal
  probability.
  If, initially, the left end point is at `L(0)` and the width of the
  interval is `w(0)`, then the left end point and the width at the
  k-th iteration (denoted L(k) and w(k) respectively) are given by the following
  recursions:

  ```none
  w(k) = 2 * w(k-1)
  L(k) = L(k-1) - w(k-1) * X_k, X_k ~ Bernoulli(0.5)
  or, L(0) - L(k) = w(0) Sum(2^i * X(i+1), 0 <= i < k)
  ```

  This function computes the sequence of `L(0)-L(k)` and `w(k)` for k between 0
  and `max_doublings` independently for each chain.

  Args:
    batch_shape: Positive int32 `tf.Tensor`. The batch shape.
    max_doublings: Scalar positive int32 `tf.Tensor`. The maximum number of
      doublings to consider.
    step_size: A real `tf.Tensor` with shape compatible with [num_chains].
      The size of the initial interval.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'find_slice_bounds').

  Returns:
    left_increments: A tensor of shape (max_doublings+1, batch_shape). The
      relative position of the left end point after the doublings.
    widths: A tensor of shape (max_doublings+1, ones_like(batch_shape)). The
      widths of the intervals at each stage of the doubling.
  """
    with tf.name_scope(name or 'left_doubling_increments'):
        step_size = tf.convert_to_tensor(value=step_size)
        dtype = dtype_util.base_dtype(step_size.dtype)
        # Output shape of the left increments tensor.
        output_shape = ps.concat(([max_doublings + 1], batch_shape), axis=0)
        # A sample realization of X_k.
        expand_left = bernoulli_lib.Bernoulli(0.5, dtype=dtype).sample(
            sample_shape=output_shape, seed=seed)

        # The widths of the successive intervals. Starts with 1.0 and ends with
        # 2^max_doublings.
        width_multipliers = tf.cast(2**tf.range(0, max_doublings + 1),
                                    dtype=dtype)
        # Output shape of the `widths` tensor.
        widths_shape = ps.concat(
            ([max_doublings + 1], ps.ones_like(batch_shape)), axis=0)
        width_multipliers = tf.reshape(width_multipliers, shape=widths_shape)
        # Widths shape is [max_doublings + 1, 1, 1, 1...].
        widths = width_multipliers * step_size

        # Take the cumulative sum of the left side increments in slice width to give
        # the resulting distance from the initial lower bound.
        left_increments = tf.cumsum(widths * expand_left,
                                    exclusive=True,
                                    axis=0)
        return left_increments, widths
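# Worked sketch (illustration, not from the source above) of the doubling
# recursion in the docstring, for a single chain with an assumed, fixed
# sequence of Bernoulli draws X_k and unit initial width.
import numpy as np

w0 = 1.
x = np.array([1., 0., 1., 0.])             # draws of X_1 .. X_4
widths = w0 * 2. ** np.arange(4)           # w(0)..w(3) = [1., 2., 4., 8.]
left_increments = np.concatenate(
    [[0.], np.cumsum((widths * x)[:-1])])  # L(0)-L(k)  = [0., 1., 1., 5.]
# Matches L(0) - L(k) = w(0) * sum(2**i * X(i+1) for i in range(k)).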
Example #6
  def make_reference_values(event_shape):
    # Builds a grid of 5 reference values broadcast to shape
    # [5] + batch_shape + event_shape; `batch_shape` is a closure variable
    # from the enclosing scope.
    dist_shape = ps.concat([batch_shape, event_shape], axis=0)
    x = tf.reshape([-4., -2., 0., 2., 4.],
                   ps.concat([[5], ps.ones_like(dist_shape)], axis=0))
    return tf.broadcast_to(x, ps.concat([[5], dist_shape], axis=0))