def autoregressive_observation_fn(step, _, observation_history=None):
    loc = 0.
    if observation_history is not None:
        num_terms = prefer_static.minimum(step, len(weights))
        usable_weights = tf.convert_to_tensor(weights)[-num_terms:]
        loc = tf.reduce_sum(usable_weights * observation_history)
    return tfd.Normal(loc, 1.0)
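
# The snippet above closes over a `weights` vector of AR coefficients that is
# not shown; the values below are hypothetical, just to exercise the function
# once (imports follow TFP's usual aliases).
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import prefer_static

tfd = tfp.distributions
weights = [0.2, 0.8]          # illustrative AR(2) coefficients

obs_dist = autoregressive_observation_fn(
    2, None, observation_history=tf.constant([1.5, -0.3]))
# obs_dist.loc == 0.2 * 1.5 + 0.8 * (-0.3) = 0.06
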
Example 2
def _update_loop_variables(step, current_step_results,
                           accumulated_traced_results, trace_fn,
                           step_indices_to_trace, num_steps_traced):
    """Update the loop state to reflect a step of filtering."""

    # Write particles, indices, and likelihoods to their respective arrays.
    trace_this_step = True
    if step_indices_to_trace is not None:
        trace_this_step = ps.equal(
            step_indices_to_trace[ps.minimum(
                num_steps_traced,
                ps.cast(ps.size0(step_indices_to_trace) - 1, dtype=np.int32))],
            step)
    num_steps_traced, accumulated_traced_results = ps.cond(
        trace_this_step,
        lambda: (
            num_steps_traced + 1,  # pylint: disable=g-long-lambda
            tf.nest.map_structure(lambda x, y: x.write(num_steps_traced, y),
                                  accumulated_traced_results,
                                  trace_fn(current_step_results))),
        lambda: (num_steps_traced, accumulated_traced_results))

    return ParticleFilterLoopVariables(
        step=step + 1,
        previous_step_results=current_step_results,
        accumulated_traced_results=accumulated_traced_results,
        num_steps_traced=num_steps_traced)
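
# A small standalone sketch of the conditional-tracing pattern used above:
# a per-step result is appended to a `tf.TensorArray` only when the step is
# selected for tracing (plain `tf.cond` here instead of `ps.cond`; names are
# illustrative, not part of the original loop).
import tensorflow as tf

traced_results = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
num_steps_traced = tf.constant(0)
step_result = tf.constant(3.5)
trace_this_step = tf.constant(True)

num_steps_traced, traced_results = tf.cond(
    trace_this_step,
    lambda: (num_steps_traced + 1,
             traced_results.write(num_steps_traced, step_result)),
    lambda: (num_steps_traced, traced_results))
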
    def bootstrap_results(self, init_state):
        with tf.name_scope(
                mcmc_util.make_name(self.name,
                                    'dual_averaging_step_size_adaptation',
                                    'bootstrap_results')):
            inner_results = self.inner_kernel.bootstrap_results(init_state)
            step_size = self.step_size_getter_fn(inner_results)

            log_accept_prob = self.log_accept_prob_getter_fn(inner_results)

            state_parts = tf.nest.flatten(init_state)
            step_size_parts = tf.nest.flatten(step_size)
            dtype = dtype_util.common_dtype(step_size_parts, tf.float32)
            error_sum, log_averaging_step, log_shrinkage_target = [], [], []
            for state_part, step_size_part in zip(state_parts,
                                                  step_size_parts):
                num_reduce_dims = prefer_static.minimum(
                    prefer_static.rank(log_accept_prob),
                    prefer_static.rank(state_part) -
                    prefer_static.rank(step_size_part))
                reduced_log_accept_prob = reduce_logmeanexp(
                    log_accept_prob, axis=prefer_static.range(num_reduce_dims))
                reduce_indices = get_differing_dims(reduced_log_accept_prob,
                                                    step_size_part)
                reduced_log_accept_prob = reduce_logmeanexp(
                    reduced_log_accept_prob,
                    axis=reduce_indices,
                    keepdims=True)
                error_sum.append(
                    tf.zeros_like(reduced_log_accept_prob, dtype=dtype))
                log_averaging_step.append(
                    tf.zeros_like(step_size_part, dtype=dtype))

                if self._parameters['shrinkage_target'] is None:
                    log_shrinkage_target.append(
                        float(np.log(10.)) + tf.math.log(step_size_part))
                else:
                    log_shrinkage_target.append(
                        tf.math.log(
                            tf.cast(self._parameters['shrinkage_target'],
                                    dtype)))

            return DualAveragingStepSizeAdaptationResults(
                inner_results=inner_results,
                step=tf.constant(0, dtype=tf.int32),
                target_accept_prob=tf.cast(
                    self.parameters['target_accept_prob'],
                    log_accept_prob.dtype),
                log_shrinkage_target=log_shrinkage_target,
                exploration_shrinkage=tf.cast(
                    self.parameters['exploration_shrinkage'], dtype),
                step_count_smoothing=tf.cast(
                    self.parameters['step_count_smoothing'], dtype),
                decay_rate=tf.cast(self.parameters['decay_rate'], dtype),
                error_sum=error_sum,
                log_averaging_step=log_averaging_step,
                new_step_size=step_size)
def slice_batch_shape_tensor(base_shape, event_ndims):
  base_shape = ps.convert_to_shape_tensor(base_shape, dtype_hint=np.int32)
  event_ndims = ps.convert_to_shape_tensor(event_ndims, dtype_hint=np.int32)
  base_rank = ps.rank_from_shape(base_shape)
  return base_shape[:(base_rank -
                      # Don't try to slice away more ndims than the parameter
                      # actually has, if that's fewer than `event_ndims` (i.e.,
                      # if it relies on broadcasting).
                      ps.minimum(event_ndims, base_rank))]
Example 5
def _truncate_shape_tensor(shape, ndims_to_truncate):
    shape = ps.convert_to_shape_tensor(shape, dtype_hint=np.int32)
    ndims_to_truncate = ps.convert_to_shape_tensor(ndims_to_truncate,
                                                   dtype_hint=np.int32)
    base_rank = ps.rank_from_shape(shape)
    return shape[:(
        base_rank -
        # Don't try to slice away more ndims than the shape
        # actually has, if that's fewer than `ndims_to_truncate` (i.e.,
        # if it relies on broadcasting).
        ps.minimum(ndims_to_truncate, base_rank))]
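
# Hypothetical behaviour check for the helper above (assuming the `ps` and
# `np` imports used by the snippet, i.e. TFP's `prefer_static` and NumPy):
import numpy as np
from tensorflow_probability.python.internal import prefer_static as ps

print(_truncate_shape_tensor([3, 2, 5], 1))   # -> [3 2]
print(_truncate_shape_tensor([5], 2))         # -> []  (relies on broadcasting)
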
def initialize_state_history(state):
  """Build an initial state history by replicating the initial state."""
  with tf.name_scope('initialize_state_history'):
    initial_state_histories = tf.nest.map_structure(
        lambda x: tf.broadcast_to(  # pylint: disable=g-long-lambda
            tf.expand_dims(x, ps.minimum(ps.rank(x), 1)),
            ps.concat([ps.shape(x)[:1],
                       [history_size],
                       ps.shape(x)[1:]], axis=0)),
        state)
    return (joint_distribution_util
            .independent_joint_distribution_from_structure(
                _wrap_as_distributions(initial_state_histories)))
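
# A tiny sketch of the replication trick above: insert a `history_size` axis
# right after the leading (particle) axis by expanding and broadcasting.
# Shapes are illustrative.
import tensorflow as tf

history_size = 3
x = tf.random.normal([5, 2])                        # [num_particles, event_dim]
x_history = tf.broadcast_to(x[:, tf.newaxis, :],
                            [5, history_size, 2])   # [5, history_size, 2]
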
    def preprocess_state(init_state):
      """Initial preprocessing at Stage 0."""
      dimension = ps.reduce_sum([
          ps.reduce_prod(ps.shape(x)[1:]) for x in init_state])
      likelihood_log_prob = likelihood_log_prob_fn(*init_state)

      # Default to the optimal scaling for normally distributed targets.
      # TODO(b/152412213): Revisit this default parameter.
      scale_start = (
          tf.constant(2.38 ** 2, dtype=likelihood_log_prob.dtype) /
          tf.constant(dimension, dtype=likelihood_log_prob.dtype))
      # TODO(b/152412213): Enable batch of batches style by using non-scalar
      # inverse_temperature
      inverse_temperature = tf.zeros([], dtype=likelihood_log_prob.dtype)
      scalings = ps.ones_like(likelihood_log_prob) * ps.minimum(scale_start, 1.)
      kernel = make_kernel_fn(
          _make_tempered_target_log_prob_fn(
              prior_log_prob_fn,
              likelihood_log_prob_fn,
              inverse_temperature),
          init_state,
          scalings,
          seed=seed_stream())
      pkr = kernel.bootstrap_results(current_state)
      _, kernel_target_log_prob = gather_mh_like_result(pkr)

      particle_info = ParticleInfo(
          log_accept_prob=ps.zeros_like(likelihood_log_prob),
          log_scalings=tf.math.log(scalings),
          tempered_log_prob=kernel_target_log_prob,
          likelihood_log_prob=likelihood_log_prob,
      )

      return SMCResults(
          num_steps=tf.convert_to_tensor(
              max_num_steps, dtype=tf.int32, name='num_steps'),
          inverse_temperature=inverse_temperature,
          log_marginal_likelihood=tf.constant(
              0., dtype=likelihood_log_prob.dtype),
          particle_info=particle_info
      )
def initial_value_of_masked_time_series(time_series_tensor, broadcast_mask):
    """Get the first unmasked entry of each time series in the batch.

  If a batch element has no unmasked entries, the corresponding return value
  for that element is undefined.

  Args:
    time_series_tensor: float `Tensor` of shape `batch_shape + [num_timesteps]`.
    broadcast_mask: bool `Tensor` of same shape as `time_series`.
  Returns:
    initial_values: float `Tensor` of shape `batch_shape`.
  """

    num_timesteps = ps.shape(time_series_tensor)[-1]

    # Compute the index of the first unmasked entry for each series in the batch.
    unmasked_negindices = (ps.cast(~broadcast_mask, np.int32) *
                           ps.range(num_timesteps, 0, -1))
    first_unmasked_indices = num_timesteps - ps.reduce_max(unmasked_negindices,
                                                           axis=-1)
    # Avoid out-of-bounds errors if all indices are masked.
    safe_unmasked_indices = ps.minimum(first_unmasked_indices,
                                       num_timesteps - 1)

    batch_dims = tensorshape_util.rank(safe_unmasked_indices.shape)
    if batch_dims is None:
        raise NotImplementedError(
            'Cannot compute initial values of a masked time series with '
            'dynamic rank.')  # `batch_gather` requires static rank

    # Extract the initial value for each series in the batch.
    return tf.squeeze(
        tf.gather(params=time_series_tensor,
                  indices=safe_unmasked_indices[..., np.newaxis],
                  batch_dims=batch_dims,
                  axis=-1),
        # Since we've gathered exactly one step from the
        # `num_timesteps` axis, we can remove that axis entirely.
        axis=-1)
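
# A quick NumPy check of the index trick above on a toy batch: masked entries
# contribute zero to `unmasked_negindices`, so the row-wise max recovers the
# first unmasked position. Values here are illustrative.
import numpy as np

series = np.array([[9., 1., 2.],
                   [5., 6., 7.]])
mask = np.array([[True, False, False],
                 [False, False, False]])
num_timesteps = series.shape[-1]
negindices = (~mask).astype(np.int32) * np.arange(num_timesteps, 0, -1)
first_unmasked = num_timesteps - negindices.max(axis=-1)     # -> [1, 0]
initial_values = series[np.arange(2), first_unmasked]        # -> [1., 5.]
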
    def one_step(self, current_state, previous_kernel_results):
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'simple_step_size_adaptation',
                                    'one_step')):
            # Set the step_size.
            inner_results = self.step_size_setter_fn(
                previous_kernel_results.inner_results,
                previous_kernel_results.new_step_size)

            # Step the inner kernel.
            new_state, new_inner_results = self.inner_kernel.one_step(
                current_state, inner_results)

            # Get the new step size.
            log_accept_prob = self.log_accept_prob_getter_fn(new_inner_results)
            log_target_accept_prob = tf.math.log(
                tf.cast(previous_kernel_results.target_accept_prob,
                        dtype=log_accept_prob.dtype))

            state_parts = tf.nest.flatten(current_state)
            step_size = self.step_size_getter_fn(new_inner_results)
            step_size_parts = tf.nest.flatten(step_size)
            log_accept_prob_rank = prefer_static.rank(log_accept_prob)

            new_step_size_parts = []
            for step_size_part, state_part in zip(step_size_parts,
                                                  state_parts):
                # Compute new step sizes for each step size part. If step size part has
                # smaller rank than the corresponding state part, then the difference is
                # averaged away in the log accept prob.
                #
                # Example:
                #
                # state_part has shape      [2, 3, 4, 5]
                # step_size_part has shape     [1, 4, 1]
                # log_accept_prob has shape [2, 3, 4]
                #
                # Since step size has 1 rank fewer than the state, we reduce away the
                # leading dimension of log_accept_prob to get a Tensor with shape [3,
                # 4]. Next, since log_accept_prob must broadcast into step_size_part on
                # the left, we reduce the dimensions where their shapes differ, to get a
                # Tensor with shape [1, 4], which now is compatible with the leading
                # dimensions of step_size_part.
                #
                # There is a subtlety here in that step_size_parts might be a length-1
                # list, which means that we'll be "structure-broadcasting" it for all
                # the state parts (see logic in, e.g., hmc.py). In this case we must
                # assume that the lone step size provided broadcasts with the event
                # dims of each state part. This means that either step size has no
                # dimensions corresponding to chain dimensions, or all states are of the
                # same shape. For the former, we want to reduce over all chain
                # dimensions. For the latter, we want to use the same logic as in the
                # non-structure-broadcasted case.
                #
                # It turns out we can compute the reduction dimensions for both cases
                # uniformly by taking the rank of any state part. This obviously works
                # in the second case (where all state ranks are the same). In the first
                # case, all state parts have the rank L + D_i + B, where L is the rank
                # of log_accept_prob, D_i is the non-shared dimensions amongst all
                # states, and B are the shared dimensions of all the states, which are
                # equal to the step size. When we subtract B, we will always get a
                # number >= L, which means we'll get the full reduction we want.
                num_reduce_dims = prefer_static.minimum(
                    log_accept_prob_rank,
                    prefer_static.rank(state_part) -
                    prefer_static.rank(step_size_part))
                reduced_log_accept_prob = reduce_logmeanexp(
                    log_accept_prob, axis=prefer_static.range(num_reduce_dims))
                # reduced_log_accept_prob must broadcast into step_size_part on the
                # left, so we do an additional reduction over dimensions where their
                # shapes differ.
                reduce_indices = get_differing_dims(reduced_log_accept_prob,
                                                    step_size_part)
                reduced_log_accept_prob = reduce_logmeanexp(
                    reduced_log_accept_prob,
                    axis=reduce_indices,
                    keepdims=True)

                one_plus_adaptation_rate = 1. + tf.cast(
                    previous_kernel_results.adaptation_rate,
                    dtype=step_size_part.dtype)
                new_step_size_part = mcmc_util.choose(
                    reduced_log_accept_prob > log_target_accept_prob,
                    step_size_part * one_plus_adaptation_rate,
                    step_size_part / one_plus_adaptation_rate)

                new_step_size_parts.append(
                    tf.where(
                        previous_kernel_results.step <
                        self.num_adaptation_steps, new_step_size_part,
                        step_size_part))
            new_step_size = tf.nest.pack_sequence_as(step_size,
                                                     new_step_size_parts)

            return new_state, previous_kernel_results._replace(
                inner_results=new_inner_results,
                step=1 + previous_kernel_results.step,
                new_step_size=new_step_size)
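
# A scalar sketch of the adaptation rule applied above: multiply the step size
# by (1 + adaptation_rate) when the (reduced) acceptance probability exceeds
# the target, otherwise divide by the same factor. The default values shown
# are illustrative, not taken from the snippet.
import numpy as np

def simple_step_size_update(step_size, log_accept_prob,
                            target_accept_prob=0.75, adaptation_rate=0.01):
  factor = 1. + adaptation_rate
  if log_accept_prob > np.log(target_accept_prob):
    return step_size * factor
  return step_size / factor
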
def sample_sequential_monte_carlo(
        prior_log_prob_fn,
        likelihood_log_prob_fn,
        current_state,
        max_num_steps=25,
        max_stage=100,
        make_kernel_fn=make_rwmh_kernel_fn,
        tuning_fn=simple_heuristic_tuning,
        make_tempered_target_log_prob_fn=default_make_tempered_target_log_prob_fn,
        ess_threshold_ratio=0.5,
        parallel_iterations=10,
        seed=None,
        name=None):
    """Runs Sequential Monte Carlo to sample from the posterior distribution.

  This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo)
  to sample from a series of distributions that slowly interpolates between
  an initial 'prior' distribution:

    `exp(prior_log_prob_fn(x))`

  and the target 'posterior' distribution:

    `exp(prior_log_prob_fn(x) + target_log_prob_fn(x))`,

  by mutating a collection of MC samples (i.e., particles). The approach is
  also known as a particle filter in some literature. The current
  implementation is largely based on Del Moral et al. [1], which adapts both
  the tempering sequence (based on the effective sample size) and the scaling
  of the mutation kernel (based on the sample covariance of the particles) at
  each stage.

  Args:
    prior_log_prob_fn: Python callable that returns the log density of the
      prior distribution.
    likelihood_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the likelihood distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
    max_num_steps: The maximum number of kernel transition steps in one mutation
      of the MC samples. Note that the actual number of steps in one mutation is
      tuned during sampling and is typically lower than `max_num_steps`.
    max_stage: Integer maximum number of stages for increasing the inverse
      temperature from 0 to 1.
    make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like
      object. Must take one argument representing the `TransitionKernel`'s
      `target_log_prob_fn`. The `target_log_prob_fn` argument represents the
      `TransitionKernel`'s target log distribution.  Note:
      `sample_sequential_monte_carlo` creates a new `target_log_prob_fn`
      which is an interpolation between the supplied `target_log_prob_fn` and
      `proposal_log_prob_fn`; it is this interpolated function which is used as
      an argument to `make_kernel_fn`.
    tuning_fn: Python `callable` which takes the number of steps, the log
      scaling, and the log acceptance ratio from the last mutation, and outputs
      the number of steps and log scaling for the next mutation.
    make_tempered_target_log_prob_fn: Python `callable` that takes the
      `prior_log_prob_fn`, `likelihood_log_prob_fn`, and `inverse_temperatures`
      and creates a `target_log_prob_fn` `callable` that is passed to
      `make_kernel_fn`.
    ess_threshold_ratio: Target ratio for effective sample size.
    parallel_iterations: The number of iterations allowed to run in parallel.
        It must be a positive integer. See `tf.while_loop` for more details.
    seed: Python integer or TFP seedstream to seed the random number generator.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_sequential_monte_carlo').

  Returns:
    n_stage: Number of mutation stages SMC ran for.
    final_state: `Tensor` or Python `list` of `Tensor`s representing the
      final state(s) of the Markov chain(s). These are the posterior
      samples.
    final_kernel_results: `collections.namedtuple` of internal calculations used
      to advance the chain.

  #### References

  [1] Del Moral, Pierre, Arnaud Doucet, and Ajay Jasra. An adaptive sequential
      Monte Carlo method for approximate Bayesian computation.
      _Statistics and Computing_, 22(5):1009-1020, 2012.

  """

    with tf.name_scope(name or 'sample_sequential_monte_carlo'):
        seed_stream = SeedStream(seed, salt='smc_seed')

        unwrap_state_list = not tf.nest.is_nested(current_state)
        if unwrap_state_list:
            current_state = [current_state]
        current_state = [
            tf.convert_to_tensor(s, dtype_hint=tf.float32)
            for s in current_state
        ]

        # Initial preprocessing at Stage 0
        likelihood_log_prob = likelihood_log_prob_fn(*current_state)

        likelihood_rank = ps.rank(likelihood_log_prob)
        dimension = ps.reduce_sum([
            ps.reduce_prod(ps.shape(x)[likelihood_rank:])
            for x in current_state
        ])

        # We infer the particle shapes from the resulting likelihood:
        # [num_particles, b1, ..., bN]
        particle_shape = ps.shape(likelihood_log_prob)
        num_particles, batch_shape = particle_shape[0], particle_shape[1:]
        effective_sample_size_threshold = tf.cast(
            num_particles * ess_threshold_ratio, tf.int32)

        # TODO(b/152412213): Revisit this default parameter.
        # Default to the optimal scaling of a random walk kernel for
        # d-dimensional normally distributed targets: 2.38 ** 2 / d.
        # For more detail see:
        # Roberts GO, Gelman A, Gilks WR. Weak convergence and optimal scaling of
        # random walk Metropolis algorithms. _The Annals of Applied Probability_,
        # 7(1):110-120, 1997.
        scale_start = (tf.constant(2.38**2, dtype=likelihood_log_prob.dtype) /
                       tf.constant(dimension, dtype=likelihood_log_prob.dtype))

        inverse_temperature = tf.zeros(batch_shape,
                                       dtype=likelihood_log_prob.dtype)
        scalings = ps.ones_like(likelihood_log_prob) * ps.minimum(
            scale_start, 1.)
        kernel = make_kernel_fn(make_tempered_target_log_prob_fn(
            prior_log_prob_fn, likelihood_log_prob_fn, inverse_temperature),
                                current_state,
                                scalings,
                                seed=seed_stream)
        pkr = kernel.bootstrap_results(current_state)
        _, kernel_target_log_prob = gather_mh_like_result(pkr)

        particle_info = ParticleInfo(
            log_accept_prob=ps.zeros_like(likelihood_log_prob),
            log_scalings=tf.math.log(scalings),
            tempered_log_prob=kernel_target_log_prob,
            likelihood_log_prob=likelihood_log_prob,
        )

        current_pkr = SMCResults(
            num_steps=tf.convert_to_tensor(max_num_steps,
                                           dtype=tf.int32,
                                           name='num_steps'),
            inverse_temperature=inverse_temperature,
            log_marginal_likelihood=tf.zeros_like(inverse_temperature),
            particle_info=particle_info)

        def update_weights_temperature(inverse_temperature,
                                       likelihood_log_prob):
            """Calculate the next inverse temperature and update weights."""
            likelihood_diff = likelihood_log_prob - tf.reduce_max(
                likelihood_log_prob, axis=0)

            def _body_fn(new_beta, upper_beta, lower_beta, eff_size,
                         log_weights):
                """One iteration of the temperature and weight update."""
                new_beta = (lower_beta + upper_beta) / 2.0
                log_weights = (new_beta -
                               inverse_temperature) * likelihood_diff
                log_weights_norm = tf.math.log_softmax(log_weights, axis=0)
                eff_size = tf.cast(
                    tf.exp(-tf.math.reduce_logsumexp(2 * log_weights_norm,
                                                     axis=0)), tf.int32)
                upper_beta = tf.where(
                    eff_size < effective_sample_size_threshold, new_beta,
                    upper_beta)
                lower_beta = tf.where(
                    eff_size < effective_sample_size_threshold, lower_beta,
                    new_beta)
                return new_beta, upper_beta, lower_beta, eff_size, log_weights

            def _cond_fn(new_beta, upper_beta, lower_beta, eff_size, *_):  # pylint: disable=unused-argument
                # TODO(junpenglao): revisit threshold below to be dtype specific.
                threshold = 1e-6
                return (tf.math.reduce_any(upper_beta - lower_beta > threshold)
                        & tf.math.reduce_any(
                            eff_size != effective_sample_size_threshold))

            (new_beta, upper_beta, lower_beta, eff_size,
             log_weights) = tf.while_loop(  # pylint: disable=unused-variable
                 cond=_cond_fn,
                 body=_body_fn,
                 loop_vars=(tf.zeros_like(inverse_temperature),
                            tf.fill(ps.shape(inverse_temperature),
                                    tf.constant(2, inverse_temperature.dtype)),
                            inverse_temperature,
                            tf.zeros_like(inverse_temperature, dtype=tf.int32),
                            tf.zeros_like(likelihood_diff)),
                 parallel_iterations=parallel_iterations)

            log_weights = tf.where(new_beta < 1., log_weights,
                                   (1. - inverse_temperature) *
                                   likelihood_diff)
            marginal_loglike_ = reduce_logmeanexp(
                (new_beta - inverse_temperature) * likelihood_log_prob, axis=0)
            new_inverse_temperature = tf.clip_by_value(new_beta, 0., 1.)

            return marginal_loglike_, new_inverse_temperature, log_weights

        def mutate(current_state, log_scalings, num_steps,
                   inverse_temperature):
            """Mutate the state using a Transition kernel."""
            with tf.name_scope('mutate_states'):
                scalings = tf.exp(log_scalings)
                kernel = make_kernel_fn(make_tempered_target_log_prob_fn(
                    prior_log_prob_fn, likelihood_log_prob_fn,
                    inverse_temperature),
                                        current_state,
                                        scalings,
                                        seed=seed_stream)
                pkr = kernel.bootstrap_results(current_state)
                kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)

                def mutate_onestep(i, state, pkr, log_accept_prob_sum):
                    next_state, next_kernel_results = kernel.one_step(
                        state, pkr)
                    kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)
                    log_accept_prob = tf.minimum(kernel_log_accept_ratio, 0.)
                    log_accept_prob_sum = log_add_exp(log_accept_prob_sum,
                                                      log_accept_prob)
                    return i + 1, next_state, next_kernel_results, log_accept_prob_sum

                (
                    _, next_state, next_kernel_results, log_accept_prob_sum
                ) = tf.while_loop(
                    cond=lambda i, *args: i < num_steps,
                    body=mutate_onestep,
                    loop_vars=(
                        tf.zeros([], dtype=tf.int32),
                        current_state,
                        pkr,
                        # we accumulate the acceptance probability in log space.
                        tf.fill(
                            ps.shape(kernel_log_accept_ratio),
                            tf.constant(-np.inf,
                                        kernel_log_accept_ratio.dtype))),
                    parallel_iterations=parallel_iterations)
                _, kernel_target_log_prob = gather_mh_like_result(
                    next_kernel_results)
                avg_log_accept_prob_per_particle = log_accept_prob_sum - tf.math.log(
                    tf.cast(num_steps + 1, log_accept_prob_sum.dtype))
                return (next_state, avg_log_accept_prob_per_particle,
                        kernel_target_log_prob)

        # One SMC step.
        def smc_body_fn(stage, state, smc_kernel_result):
            """Run one stage of SMC with constant temperature."""
            (new_marginal, new_inv_temperature,
             log_weights) = update_weights_temperature(
                 smc_kernel_result.inverse_temperature,
                 smc_kernel_result.particle_info.likelihood_log_prob)
            # TODO(b/152412213) Use a tf.scan to better collect debug info.
            if PRINT_DEBUG:
                tf.print(
                    'Stage:', stage, 'Beta:', new_inv_temperature, 'n_steps:',
                    smc_kernel_result.num_steps, 'accept:',
                    tf.exp(
                        reduce_logmeanexp(
                            smc_kernel_result.particle_info.log_accept_prob,
                            axis=0)), 'scaling:',
                    tf.exp(
                        reduce_logmeanexp(
                            smc_kernel_result.particle_info.log_scalings,
                            axis=0)))
            (resampled_state,
             resampled_particle_info), _ = resample_particle_and_info(
                 (state, smc_kernel_result.particle_info),
                 log_weights,
                 seed=seed_stream)
            next_num_steps, next_log_scalings = tuning_fn(
                smc_kernel_result.num_steps,
                resampled_particle_info.log_scalings,
                resampled_particle_info.log_accept_prob)
            # Skip tuning at stage 0.
            next_num_steps = tf.where(stage == 0, smc_kernel_result.num_steps,
                                      next_num_steps)
            next_log_scalings = tf.where(stage == 0,
                                         resampled_particle_info.log_scalings,
                                         next_log_scalings)
            next_num_steps = tf.clip_by_value(next_num_steps, 2, max_num_steps)

            next_state, log_accept_prob, tempered_log_prob = mutate(
                resampled_state, next_log_scalings, next_num_steps,
                new_inv_temperature)
            next_pkr = SMCResults(
                num_steps=next_num_steps,
                inverse_temperature=new_inv_temperature,
                log_marginal_likelihood=(
                    new_marginal + smc_kernel_result.log_marginal_likelihood),
                particle_info=ParticleInfo(
                    log_accept_prob=log_accept_prob,
                    log_scalings=next_log_scalings,
                    tempered_log_prob=tempered_log_prob,
                    likelihood_log_prob=likelihood_log_prob_fn(*next_state),
                ))
            return stage + 1, next_state, next_pkr

        (n_stage, final_state, final_kernel_results) = tf.while_loop(
            cond=lambda i, state, pkr: (  # pylint: disable=g-long-lambda
                (i < max_stage) & tf.reduce_any(pkr.inverse_temperature < 1.)),
            body=smc_body_fn,
            loop_vars=(tf.zeros([],
                                dtype=tf.int32), current_state, current_pkr),
            parallel_iterations=parallel_iterations)
        if unwrap_state_list:
            final_state = final_state[0]
        return n_stage, final_state, final_kernel_results
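
# A minimal usage sketch, assuming this function is the one exposed as
# `tfp.experimental.mcmc.sample_sequential_monte_carlo`; the prior and
# likelihood below are illustrative placeholders, not from the snippet.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def prior_log_prob_fn(x):
  return tfd.Normal(0., 1.).log_prob(x)

def likelihood_log_prob_fn(x):
  # Unnormalized log-likelihood of a single observed datum y = 2.
  return tfd.Normal(x, 0.5).log_prob(2.)

num_particles = 1000
init_state = tfd.Normal(0., 1.).sample(num_particles, seed=42)

n_stage, posterior_samples, _ = (
    tfp.experimental.mcmc.sample_sequential_monte_carlo(
        prior_log_prob_fn, likelihood_log_prob_fn, init_state, seed=42))
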
Example 11
    def _bootstrap_from_inner_results(self, init_state, inner_results):
        step_size = self.step_size_getter_fn(inner_results)

        log_accept_prob = self.log_accept_prob_getter_fn(inner_results)

        state_parts = tf.nest.flatten(init_state)
        step_size_parts = tf.nest.flatten(step_size)

        if self._parameters['shrinkage_target'] is None:
            shrinkage_target_parts = [None] * len(step_size_parts)
        else:
            shrinkage_target_parts = tf.nest.flatten(
                self._parameters['shrinkage_target'])
            if len(shrinkage_target_parts) not in [1, len(step_size_parts)]:
                raise ValueError(
                    '`shrinkage_target` should be a Tensor or list of tensors of '
                    'same length as `step_size`. Found len(`step_size`) = {} and '
                    'len(shrinkage_target) = {}'.format(
                        len(step_size_parts), len(shrinkage_target_parts)))
            if len(shrinkage_target_parts) < len(step_size_parts):
                shrinkage_target_parts *= len(step_size_parts)

        dtype = dtype_util.common_dtype(step_size_parts, tf.float32)
        error_sum, log_averaging_step, log_shrinkage_target = [], [], []
        for state_part, step_size_part, shrinkage_target_part in zip(
                state_parts, step_size_parts, shrinkage_target_parts):
            num_reduce_dims = ps.minimum(
                ps.rank(log_accept_prob),
                ps.rank(state_part) - ps.rank(step_size_part))
            reduced_log_accept_prob = reduce_logmeanexp(
                log_accept_prob,
                axis=ps.range(num_reduce_dims),
                experimental_named_axis=self.
                experimental_reduce_chain_axis_names)
            reduce_indices = get_differing_dims(reduced_log_accept_prob,
                                                step_size_part)
            reduced_log_accept_prob = reduce_logmeanexp(
                reduced_log_accept_prob, axis=reduce_indices, keepdims=True)
            error_sum.append(
                tf.zeros_like(reduced_log_accept_prob, dtype=dtype))
            log_averaging_step.append(
                tf.zeros_like(step_size_part, dtype=dtype))

            if shrinkage_target_part is None:
                log_shrinkage_target.append(
                    float(np.log(10.)) + tf.math.log(step_size_part))
            else:
                log_shrinkage_target.append(
                    tf.math.log(tf.cast(shrinkage_target_part, dtype)))

        return DualAveragingStepSizeAdaptationResults(
            inner_results=inner_results,
            step=tf.constant(0, dtype=tf.int32),
            target_accept_prob=tf.cast(self.parameters['target_accept_prob'],
                                       log_accept_prob.dtype),
            log_shrinkage_target=log_shrinkage_target,
            exploration_shrinkage=tf.cast(
                self.parameters['exploration_shrinkage'], dtype),
            step_count_smoothing=tf.cast(
                self.parameters['step_count_smoothing'], dtype),
            decay_rate=tf.cast(self.parameters['decay_rate'], dtype),
            error_sum=error_sum,
            log_averaging_step=log_averaging_step,
            new_step_size=step_size,
            num_adaptation_steps=tf.cast(self.num_adaptation_steps,
                                         dtype=tf.int32))
Example 12
    def _one_step_part(self,
                       step_size,
                       state,
                       error_sum,
                       log_averaging_step,
                       log_shrinkage_target,
                       log_accept_prob_rank=None,
                       log_accept_prob=None,
                       target_accept_prob=None,
                       previous_kernel_results=None):
        """Compute new step sizes for each step size part.

    If step size part has smaller rank than the corresponding state part, then
    the difference is averaged away in the log accept prob.

    Example:

      state_part has shape      [2, 3, 4, 5]
      step_size_part has shape     [1, 4, 1]
      log_accept_prob has shape [2, 3, 4]

    Since step size has 1 rank fewer than the state, we reduce away the leading
    dimension of `log_accept_prob` to get a Tensor with shape [3, 4]. Next,
    since `log_accept_prob` must broadcast into step_size_part on the left, we
    reduce the dimensions where their shapes differ, to get a Tensor with shape
    [1, 4], which now is compatible with the leading dimensions of
    step_size_part.

    There is a subtlety here in that `step_size_parts` might be a length-1 list,
    which means that we'll be "structure-broadcasting" it for all the state
    parts (see logic in, e.g., hmc.py). In this case we must assume that
    the lone step size provided broadcasts with the event dims of each state
    part. This means that either step size has no dimensions corresponding to
    chain dimensions, or all states are of the same shape. For the former, we
    want to reduce over all chain dimensions. For the latter, we want to use
    the same logic as in the non-structure-broadcasted case.

    It turns out we can compute the reduction dimensions for both cases
    uniformly by taking the rank of any state part. This obviously works in
    the second case (where all state ranks are the same). In the first case,
    all state parts have the rank L + D_i + B, where L is the rank of
    log_accept_prob, D_i is the non-shared dimensions amongst all states, and
    B are the shared dimensions of all the states, which are equal to the step
    size. When we subtract B, we will always get a number >= L, which means
    we'll get the full reduction we want.

    Args:
      step_size: Previous step's step_size.
      state: Previous step's state value.
      error_sum: Previous step's error accumulator.
      log_averaging_step: Previous step's log_averaging_step.
      log_shrinkage_target: Floating point scalar `Tensor`. Logarithm of value
        the exploration step size is biased towards.
      log_accept_prob_rank: Rank of log_accept_prob.
      log_accept_prob: Floating point `Tensor`. Log acceptance probability of
        the inner kernel's previous step.
      target_accept_prob: A floating point `Tensor` representing desired
        acceptance probability. Must be a positive number less than 1.
      previous_kernel_results: Results struct from previous step.

    Returns:
      new_step_size: Updated `step_size`.
      new_log_averaging_step: Updated `log_averaging_step`.
      new_error_sum: Updated `error_sum`.
    """
        num_reduce_dims = ps.minimum(log_accept_prob_rank,
                                     (ps.rank(state) - ps.rank(step_size)))
        reduced_log_accept_prob = self.reduce_fn(
            log_accept_prob,
            axis=ps.range(num_reduce_dims),
            keepdims=False,
            experimental_named_axis=self.experimental_reduce_chain_axis_names)

        # reduced_log_accept_prob must broadcast into step_size on the
        # left, so we do an additional reduction over dimensions where their
        # shapes differ.
        reduce_indices = get_differing_dims(reduced_log_accept_prob, step_size)
        reduced_log_accept_prob = self.reduce_fn(
            reduced_log_accept_prob,
            axis=reduce_indices,
            keepdims=True,
            experimental_named_axis=self.experimental_reduce_chain_axis_names)
        new_error_sum = (error_sum + target_accept_prob -
                         tf.math.exp(reduced_log_accept_prob))
        num_ones_to_pad = ps.maximum(
            ps.rank(log_shrinkage_target) - ps.rank(new_error_sum), 0)
        new_error_sum_extend = tf.reshape(new_error_sum,
                                          shape=ps.pad(
                                              ps.shape(new_error_sum),
                                              paddings=[[0, num_ones_to_pad]],
                                              constant_values=1))

        step_count_smoothing = previous_kernel_results.step_count_smoothing
        step = tf.cast(previous_kernel_results.step,
                       step_count_smoothing.dtype) + 1.
        soft_t = step_count_smoothing + step

        new_log_step = (log_shrinkage_target - (
            (tf.cast(new_error_sum_extend, step.dtype) * tf.math.sqrt(step)) /
            (soft_t * previous_kernel_results.exploration_shrinkage)))

        eta = step**(-previous_kernel_results.decay_rate)
        new_log_averaging_step = (eta * new_log_step +
                                  (1. - eta) * log_averaging_step)

        # - If still adapting, return an exploring step size,
        # - If just finished, return the averaging step size
        # - Otherwise, do not update
        num_adaptation_steps = previous_kernel_results.num_adaptation_steps
        step = previous_kernel_results.step + 1
        new_step_size = tf.where(
            step < num_adaptation_steps, tf.math.exp(new_log_step),
            tf.where(step > num_adaptation_steps, step_size,
                     tf.math.exp(new_log_averaging_step)))
        new_log_averaging_step = tf.where(step > num_adaptation_steps,
                                          log_averaging_step,
                                          new_log_averaging_step)
        return new_step_size, new_log_averaging_step, new_error_sum
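
# A scalar NumPy sketch of the dual-averaging update carried out above;
# hyperparameter defaults here are illustrative, and `log_shrinkage_target`
# mirrors the code's default of log(10 * initial_step_size).
import numpy as np

def dual_averaging_update(error_sum, log_averaging_step, step, accept_prob,
                          target_accept_prob=0.75,
                          log_shrinkage_target=np.log(10. * 0.1),
                          exploration_shrinkage=0.05,
                          step_count_smoothing=10.,
                          decay_rate=0.75):
  error_sum = error_sum + (target_accept_prob - accept_prob)
  t = float(step) + 1.
  soft_t = step_count_smoothing + t
  new_log_step = log_shrinkage_target - (
      error_sum * np.sqrt(t) / (soft_t * exploration_shrinkage))
  eta = t ** -decay_rate
  new_log_averaging_step = eta * new_log_step + (1. - eta) * log_averaging_step
  # While adapting, the kernel uses exp(new_log_step); once adaptation ends it
  # switches to the averaged exp(new_log_averaging_step).
  return np.exp(new_log_step), error_sum, new_log_averaging_step
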
Example 13
def _hmc_like_log_accept_prob_getter_fn(kernel_results):
    return prefer_static.minimum(kernel_results.log_accept_ratio, 0.)
Example 14
        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)
            kernel_shape = ps.shape(kernel)
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):
                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x, kernel, filter_shape,
                                                  strides, padding, dilations,
                                                  kernel_shape[-1],
                                                  batch_shape, event_shape)

                idx, shape = im2row_index((xh * sh + sum(pad_values[0]),
                                           xw * sw + sum(pad_values[1]), c_in),
                                          block_shape=filter_shape,
                                          slice_step=(1, 1),
                                          dilations=dilations,
                                          dtype=dtype,
                                          transpose=True)

                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(pad_values,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)

                # Interleave the rows and columns of the input with rows and columns of
                # zeros equal to the number of strides.
                x_half_dilated = tf.concat([
                    tf.zeros(ps.concat([batch_shape, (xh * xw, sw - 1, c_in)],
                                       axis=0),
                             dtype=input_dtype),
                    tf.reshape(x,
                               shape=ps.concat(
                                   [batch_shape, (xh * xw, 1, c_in)], axis=0))
                ],
                                           axis=-2)
                y = tf.reshape(x_half_dilated,
                               shape=ps.concat(
                                   [batch_shape, (xh, 1, xw * sw, c_in)],
                                   axis=0))

                x = tf.reshape(tf.concat([
                    tf.zeros(ps.concat(
                        [batch_shape, (xh, sh - 1, xw * sw, c_in)], axis=0),
                             dtype=input_dtype), y
                ],
                                         axis=-3),
                               shape=ps.concat(
                                   [batch_shape, (xh * sh, xw * sw, c_in)],
                                   axis=0))

                truncations = -ps.minimum(ps.cast(paddings, dtype=tf.int32), 0)
                truncate_start, truncate_end = ps.unstack(truncations, axis=1)
                x_truncate = tf.slice(x,
                                      begin=truncate_start,
                                      size=ps.shape(x) -
                                      (truncate_start + truncate_end))

                x_pad = tf.pad(x_truncate,
                               paddings=ps.maximum(paddings, 0),
                               constant_values=0)

                flat_shape = ps.pad(batch_shape,
                                    paddings=[[0, 1]],
                                    constant_values=-1)
                flat_x = tf.gather(tf.reshape(x_pad, shape=flat_shape),
                                   indices=idx,
                                   axis=-1)
                im_x = tf.reshape(flat_x,
                                  shape=ps.concat([batch_shape, shape],
                                                  axis=0))
                return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])
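
# A 1-D illustration (not part of the original op) of the zero-interleaving
# used above for stride > 1 in a transposed convolution: each input element
# is followed by `stride - 1` zeros.
import tensorflow as tf

x = tf.constant([1., 2., 3.])
stride = 2
zeros = tf.zeros([3, stride - 1])
dilated = tf.reshape(tf.concat([x[:, tf.newaxis], zeros], axis=-1), [-1])
dilated = dilated[:-(stride - 1)]   # -> [1., 0., 2., 0., 3.]
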
Example 15
def _get_search_direction(state):
  """Computes the search direction to follow at the current state.

  On the `k`-th iteration of the main L-BFGS algorithm, the state has collected
  the most recent `m` correction pairs in position_deltas and gradient_deltas,
  where `k = state.num_iterations` and `m = min(k, num_correction_pairs)`.

  Assuming these, the code below is an implementation of the L-BFGS two-loop
  recursion algorithm given by [Nocedal and Wright(2006)][1]:

  ```None
    q_direction = objective_gradient
    for i in reversed(range(m)):  # First loop.
      inv_rho[i] = gradient_deltas[i]^T * position_deltas[i]
      alpha[i] = position_deltas[i]^T * q_direction / inv_rho[i]
      q_direction = q_direction - alpha[i] * gradient_deltas[i]

    kth_inv_hessian_factor = ((gradient_deltas[-1]^T * position_deltas[-1]) /
                              (gradient_deltas[-1]^T * gradient_deltas[-1]))
    r_direction = kth_inv_hessian_factor * I * q_direction

    for i in range(m):  # Second loop.
      beta = gradient_deltas[i]^T * r_direction / inv_rho[i]
      r_direction = r_direction + position_deltas[i] * (alpha[i] - beta)

    return -r_direction  # Approximates - H_k * objective_gradient.
  ```

  Args:
    state: A `LBfgsOptimizerResults` tuple with the current state of the
      search procedure.

  Returns:
    A real `Tensor` of the same shape as the `state.position`. The direction
    along which to perform line search.
  """
  # The number of correction pairs that have been collected so far.
  num_elements = ps.minimum(
      state.num_iterations,  # TODO(b/162733947): Change loop state -> closure.
      ps.shape(state.position_deltas)[0])

  def _two_loop_algorithm():
    """L-BFGS two-loop algorithm."""
    # Correction pairs are always appended to the end, so only the latest
    # `num_elements` vectors have valid position/gradient deltas. Vectors
    # that haven't been computed yet are zero.
    position_deltas = state.position_deltas
    gradient_deltas = state.gradient_deltas

    # Pre-compute all `inv_rho[i]`s.
    inv_rhos = tf.reduce_sum(
        gradient_deltas * position_deltas, axis=-1)

    def first_loop(acc, args):
      _, q_direction = acc
      position_delta, gradient_delta, inv_rho = args
      alpha = tf.math.divide_no_nan(
          tf.reduce_sum(position_delta * q_direction, axis=-1), inv_rho)
      direction_delta = alpha[..., tf.newaxis] * gradient_delta
      return (alpha, q_direction - direction_delta)

    # Run first loop body computing and collecting `alpha[i]`s, while also
    # computing the updated `q_direction` at each step.
    zero = tf.zeros_like(inv_rhos[-num_elements])
    alphas, q_directions = tf.scan(
        first_loop, [position_deltas, gradient_deltas, inv_rhos],
        initializer=(zero, state.objective_gradient), reverse=True)

    # We use `H^0_k = gamma_k * I` as an estimate for the initial inverse
    # hessian for the k-th iteration; then `r_direction = H^0_k * q_direction`.
    gamma_k = inv_rhos[-1] / tf.reduce_sum(
        gradient_deltas[-1] * gradient_deltas[-1], axis=-1)
    r_direction = gamma_k[..., tf.newaxis] * q_directions[-num_elements]

    def second_loop(r_direction, args):
      alpha, position_delta, gradient_delta, inv_rho = args
      beta = tf.math.divide_no_nan(
          tf.reduce_sum(gradient_delta * r_direction, axis=-1), inv_rho)
      direction_delta = (alpha - beta)[..., tf.newaxis] * position_delta
      return r_direction + direction_delta

    # Finally, run second loop body computing the updated `r_direction` at each
    # step.
    r_directions = tf.scan(
        second_loop, [alphas, position_deltas, gradient_deltas, inv_rhos],
        initializer=r_direction)
    return -r_directions[-1]

  return ps.cond(ps.equal(num_elements, 0),
                 lambda: -state.objective_gradient,
                 _two_loop_algorithm)
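
# A plain-NumPy rendering of the two-loop recursion from the docstring above,
# for a single (unbatched) problem; a sketch, not TFP's batched implementation.
import numpy as np

def lbfgs_two_loop(position_deltas, gradient_deltas, objective_gradient):
  """`position_deltas`/`gradient_deltas`: the m most recent correction pairs."""
  position_deltas = [np.asarray(s, dtype=float) for s in position_deltas]
  gradient_deltas = [np.asarray(g, dtype=float) for g in gradient_deltas]
  m = len(position_deltas)
  inv_rhos = [np.dot(g, s) for s, g in zip(position_deltas, gradient_deltas)]
  alphas = [0.0] * m
  q = np.asarray(objective_gradient, dtype=float)
  for i in reversed(range(m)):                    # first loop
    alphas[i] = np.dot(position_deltas[i], q) / inv_rhos[i]
    q = q - alphas[i] * gradient_deltas[i]
  gamma = inv_rhos[-1] / np.dot(gradient_deltas[-1], gradient_deltas[-1])
  r = gamma * q                                   # H^0_k = gamma * I
  for i in range(m):                              # second loop
    beta = np.dot(gradient_deltas[i], r) / inv_rhos[i]
    r = r + (alphas[i] - beta) * position_deltas[i]
  return -r                                       # approximates -H_k * gradient
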
Example 16
    def vectorized_fn(*args):
        """Vectorized version of `fn` that accepts arguments of any rank."""
        with tf.name_scope(name or 'make_rank_polymorphic'):
            # If we got a single value for core_ndims, tile it across all args.
            core_ndims_structure = (core_ndims if tf.nest.is_nested(core_ndims)
                                    else tf.nest.map_structure(
                                        lambda _: core_ndims, args))

            # Build flat lists of all argument parts and their corresponding core
            # ndims.
            flat_core_ndims = tf.nest.flatten(core_ndims_structure)
            flat_args = nest.flatten_up_to(core_ndims_structure,
                                           args,
                                           check_types=False)

            # Filter to only the `Tensor`-valued args (taken to be those with `None`
            # values for `core_ndims`). Other args will be passed through to `fn`
            # unmodified.
            (vectorized_arg_core_ndims, vectorized_args,
             fn_of_vectorized_args) = _lock_in_non_vectorized_args(
                 fn,
                 arg_structure=core_ndims_structure,
                 flat_core_ndims=flat_core_ndims,
                 flat_args=flat_args)

            # `vectorized_map` requires all inputs to have a single, common batch
            # dimension `[n]`. So we broadcast all input parts to a common
            # batch shape, then flatten it down to a single dimension.
            vectorized_arg_shapes = [ps.shape(arg) for arg in vectorized_args]

            vectorized_arg_actual_core_ndims = []
            batch_shapes, core_shapes = [], []
            for (arg_shape, core_nd) in zip(vectorized_arg_shapes,
                                            vectorized_arg_core_ndims):
                arg_nd = ps.rank_from_shape(arg_shape)
                # Shrink 'core' ndims of rank-deficient args. This guarantees that
                # `batch_ndims` is always nonnegative.
                actual_core_nd = ps.minimum(arg_nd, core_nd)
                vectorized_arg_actual_core_ndims.append(actual_core_nd)

                batch_ndims = arg_nd - actual_core_nd
                batch_shapes.append(arg_shape[:batch_ndims])
                core_shapes.append(arg_shape[batch_ndims:])

            # Flatten all of the batch dimensions into one.
            broadcast_batch_shape = (functools.reduce(ps.broadcast_shape,
                                                      batch_shapes, []))
            n = ps.cast(ps.reduce_prod(broadcast_batch_shape), tf.int32)

            static_n = tf.get_static_value(n)
            if static_n == 1:
                # We can bypass `vectorized_map` if the batch shape is `[]`, `[1]`,
                # `[1, 1]`, etc., just by flattening to batch shape `[]`.
                result_batch_dims = 0
                batched_result = fn_of_vectorized_args(
                    tf.nest.map_structure(lambda x, nd: tf.reshape(
                        x,
                        ps.shape(x)[ps.rank(x) - nd:]),
                                          vectorized_args,
                                          vectorized_arg_actual_core_ndims,
                                          check_types=False))
            else:
                # Pad all input parts to the common shape, then flatten
                # into the single leading dimension `[n]`.
                # TODO(b/145227909): If/when vmap supports broadcasting, use nested vmap
                # when batch rank is static so that we can exploit broadcasting.
                broadcast_vectorized_args = [
                    tf.broadcast_to(
                        part,
                        ps.concat([broadcast_batch_shape, core_shape], axis=0))
                    for (part, core_shape) in zip(vectorized_args, core_shapes)
                ]
                vectorized_args_with_flattened_batch_dim = [
                    tf.reshape(part, ps.concat([[n], core_shape], axis=0))
                    for (part, core_shape
                         ) in zip(broadcast_vectorized_args, core_shapes)
                ]
                result_batch_dims = 1
                batched_result = tf.vectorized_map(
                    fn_of_vectorized_args,
                    vectorized_args_with_flattened_batch_dim)

            # Unflatten any `Tensor`s in the result.
            unflatten = lambda x: tf.reshape(
                x,
                ps.concat(
                    [  # pylint: disable=g-long-lambda
                        broadcast_batch_shape,
                        ps.shape(x)[result_batch_dims:]
                    ],
                    axis=0))
            result = tf.nest.map_structure(lambda x: unflatten(x)
                                           if tf.is_tensor(x) else x,
                                           batched_result,
                                           expand_composites=True)
        return result
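
# A condensed sketch of the strategy in `vectorized_fn` above, written with
# plain TensorFlow for a single fixed signature (core ndims of 2 and 1):
# broadcast the batch shapes, flatten them to one leading axis, run
# `tf.vectorized_map`, then restore the broadcast batch shape. The helper
# name and shapes are illustrative.
import tensorflow as tf

def matvec_any_rank(matrix, vector):
  """Applies matvec over arbitrary, broadcastable batch dimensions."""
  batch_shape = tf.broadcast_dynamic_shape(tf.shape(matrix)[:-2],
                                           tf.shape(vector)[:-1])
  m = tf.broadcast_to(matrix,
                      tf.concat([batch_shape, tf.shape(matrix)[-2:]], axis=0))
  v = tf.broadcast_to(vector,
                      tf.concat([batch_shape, tf.shape(vector)[-1:]], axis=0))
  # Flatten all batch dimensions into a single leading axis of size `n`.
  n = tf.reduce_prod(batch_shape)
  flat_m = tf.reshape(m, tf.concat([n[tf.newaxis], tf.shape(matrix)[-2:]], 0))
  flat_v = tf.reshape(v, tf.concat([n[tf.newaxis], tf.shape(vector)[-1:]], 0))
  flat_out = tf.vectorized_map(lambda args: tf.linalg.matvec(*args),
                               (flat_m, flat_v))
  # Unflatten the result back to the broadcast batch shape.
  return tf.reshape(flat_out,
                    tf.concat([batch_shape, tf.shape(flat_out)[1:]], axis=0))

# e.g. matrix of shape [4, 1, 3, 3] and vector of shape [2, 3] -> result [4, 2, 3]
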