Example 1
    def smc_body_fn(stage, state, smc_kernel_result):
      """Run one stage of SMC with constant temperature."""
      (
          new_marginal,
          new_inv_temperature,
          log_weights
      ) = update_weights_temperature(
          smc_kernel_result.inverse_temperature,
          smc_kernel_result.particle_info.likelihood_log_prob)
      # TODO(b/152412213) Use a tf.scan to better collect debug info.
      if PRINT_DEBUG:
        tf.print(
            'Stage:', stage,
            'Beta:', new_inv_temperature,
            'n_steps:', smc_kernel_result.num_steps,
            'accept:', tf.exp(reduce_logmeanexp(
                smc_kernel_result.particle_info.log_accept_prob, axis=0)),
            'scaling:', tf.exp(reduce_logmeanexp(
                smc_kernel_result.particle_info.log_scalings, axis=0))
            )
      (resampled_state,
       resampled_particle_info), _ = weighted_resampling.resample(
           particles=(state, smc_kernel_result.particle_info),
           log_weights=log_weights,
           resample_fn=resample_fn,
           seed=seed_stream)
      next_num_steps, next_log_scalings = tuning_fn(
          smc_kernel_result.num_steps,
          resampled_particle_info.log_scalings,
          resampled_particle_info.log_accept_prob)
      # Skip tuning at stage 0.
      next_num_steps = tf.where(stage == 0,
                                smc_kernel_result.num_steps,
                                next_num_steps)
      next_log_scalings = tf.where(stage == 0,
                                   resampled_particle_info.log_scalings,
                                   next_log_scalings)
      next_num_steps = tf.clip_by_value(
          next_num_steps, min_num_steps, max_num_steps)

      next_state, log_accept_prob, tempered_log_prob = mutate(
          resampled_state,
          next_log_scalings,
          next_num_steps,
          new_inv_temperature)
      next_pkr = SMCResults(
          num_steps=next_num_steps,
          inverse_temperature=new_inv_temperature,
          log_marginal_likelihood=(new_marginal +
                                   smc_kernel_result.log_marginal_likelihood),
          particle_info=ParticleInfo(
              log_accept_prob=log_accept_prob,
              log_scalings=next_log_scalings,
              tempered_log_prob=tempered_log_prob,
              likelihood_log_prob=likelihood_log_prob_fn(*next_state),
          ))
      return stage + 1, next_state, next_pkr
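
The stage body above follows the usual SMC-sampler pattern: reweight the
particles for the new inverse temperature, resample them by those weights,
retune the mutation kernel, then mutate. Below is a minimal, self-contained
sketch of the reweight-and-resample half; the particle state, likelihood, and
temperature increment `d_beta` are toy stand-ins, not the library's API.

import tensorflow as tf

num_particles = 1000
particles = tf.random.normal([num_particles])           # toy particle state
likelihood_log_prob = -0.5 * tf.square(particles - 2.)  # toy likelihood

# d_beta stands in for the increment chosen by update_weights_temperature.
d_beta = 0.1
log_weights = d_beta * likelihood_log_prob              # incremental log-weights
log_weights -= tf.reduce_logsumexp(log_weights)         # self-normalize

# Multinomial resampling: draw particle indices in proportion to the weights
# (weighted_resampling.resample plays this role, with more options, above).
idx = tf.random.categorical(log_weights[tf.newaxis, :], num_particles)[0]
resampled_particles = tf.gather(particles, idx)
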
    def bootstrap_results(self, init_state):
        with tf.name_scope(
                mcmc_util.make_name(self.name,
                                    'dual_averaging_step_size_adaptation',
                                    'bootstrap_results')):
            inner_results = self.inner_kernel.bootstrap_results(init_state)
            step_size = self.step_size_getter_fn(inner_results)

            log_accept_prob = self.log_accept_prob_getter_fn(inner_results)

            state_parts = tf.nest.flatten(init_state)
            step_size_parts = tf.nest.flatten(step_size)
            dtype = dtype_util.common_dtype(step_size_parts, tf.float32)
            error_sum, log_averaging_step, log_shrinkage_target = [], [], []
            for state_part, step_size_part in zip(state_parts,
                                                  step_size_parts):
                num_reduce_dims = prefer_static.minimum(
                    prefer_static.rank(log_accept_prob),
                    prefer_static.rank(state_part) -
                    prefer_static.rank(step_size_part))
                reduced_log_accept_prob = reduce_logmeanexp(
                    log_accept_prob, axis=prefer_static.range(num_reduce_dims))
                reduce_indices = get_differing_dims(reduced_log_accept_prob,
                                                    step_size_part)
                reduced_log_accept_prob = reduce_logmeanexp(
                    reduced_log_accept_prob,
                    axis=reduce_indices,
                    keepdims=True)
                error_sum.append(
                    tf.zeros_like(reduced_log_accept_prob, dtype=dtype))
                log_averaging_step.append(
                    tf.zeros_like(step_size_part, dtype=dtype))

                if self._parameters['shrinkage_target'] is None:
                    log_shrinkage_target.append(
                        float(np.log(10.)) + tf.math.log(step_size_part))
                else:
                    log_shrinkage_target.append(
                        tf.math.log(
                            tf.cast(self._parameters['shrinkage_target'],
                                    dtype)))

            return DualAveragingStepSizeAdaptationResults(
                inner_results=inner_results,
                step=tf.constant(0, dtype=tf.int32),
                target_accept_prob=tf.cast(
                    self.parameters['target_accept_prob'],
                    log_accept_prob.dtype),
                log_shrinkage_target=log_shrinkage_target,
                exploration_shrinkage=tf.cast(
                    self.parameters['exploration_shrinkage'], dtype),
                step_count_smoothing=tf.cast(
                    self.parameters['step_count_smoothing'], dtype),
                decay_rate=tf.cast(self.parameters['decay_rate'], dtype),
                error_sum=error_sum,
                log_averaging_step=log_averaging_step,
                new_step_size=step_size)
def simple_heuristic_tuning(num_steps,
                            log_scalings,
                            log_accept_prob,
                            optimal_accept=0.234,
                            target_accept_prob=0.99,
                            name=None):
  """Tune the number of steps and scaling of one mutation.

  # TODO(b/152412213): Better explanation of the heuristic used here.

  This is a simple heuristic for tuning the number of steps of the next
  mutation, as well as the scaling of a transition kernel (e.g., step size in
  HMC, scale of a Normal proposal in RWMH) using the acceptance probability from
  the previous mutation stage in SMC.

  Args:
    num_steps: The initial number of steps for the next mutation, to be tuned.
    log_scalings: The log of the scale of the proposal kernel.
    log_accept_prob: The log of the acceptance ratio from the last mutation.
    optimal_accept: Optimal acceptance ratio for the transition kernel. Default
      value is 0.234 (optimal for a Random Walk Metropolis kernel).
    target_accept_prob: Target acceptance probability at the end of one mutation
      step. Default value: 0.99.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None`.

  Returns:
    num_steps: The number of steps for the next mutation.
    new_log_scalings: The log of the scale of the proposal kernel for the next
      mutation.

  """
  with tf.name_scope(name or 'simple_heuristic_tuning'):
    optimal_accept = tf.constant(optimal_accept, dtype=log_accept_prob.dtype)
    target_accept_prob = tf.constant(
        target_accept_prob, dtype=log_accept_prob.dtype)
    log_half_constant = tf.constant(np.log(.5), dtype=log_scalings.dtype)

    avg_log_scalings = reduce_logmeanexp(log_scalings)
    avg_log_accept_prob = reduce_logmeanexp(log_accept_prob)

    avg_log_scaling_target = avg_log_scalings + (
        tf.exp(avg_log_accept_prob) - optimal_accept)
    new_log_scalings = log_half_constant + log_add_exp(
        avg_log_scaling_target,
        log_scalings + (tf.exp(log_accept_prob) - optimal_accept)
        )

    num_replica = ps.size0(log_accept_prob)
    num_proposed = tf.cast(
        num_replica * num_steps, dtype=avg_log_accept_prob.dtype)
    log_avg_accept = tf.math.maximum(-tf.math.log(num_proposed),
                                     avg_log_accept_prob)
    num_steps = tf.cast(
        tf.math.log1p(-target_accept_prob) / log1mexp(log_avg_accept),
        dtype=num_steps.dtype)
    return num_steps, new_log_scalings
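
A hypothetical call, assuming the helpers the function body uses
(reduce_logmeanexp, log_add_exp, log1mexp, ps) are in scope; the replica
values below are made up for illustration.

import tensorflow as tf

num_replica = 100
num_steps = tf.constant(10, dtype=tf.int32)
log_scalings = tf.fill([num_replica], tf.math.log(0.5))
# Pretend the previous mutation accepted between 5% and 60% per replica.
log_accept_prob = tf.math.log(
    tf.random.uniform([num_replica], minval=0.05, maxval=0.6))

next_num_steps, next_log_scalings = simple_heuristic_tuning(
    num_steps, log_scalings, log_accept_prob)
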
        def update_weights_temperature(inverse_temperature,
                                       likelihood_log_prob):
            """Calculate the next inverse temperature and update weights."""
            likelihood_diff = likelihood_log_prob - tf.reduce_max(
                likelihood_log_prob, axis=0)

            def _body_fn(new_beta, upper_beta, lower_beta, eff_size,
                         log_weights):
                """One iteration of the temperature and weight update."""
                new_beta = (lower_beta + upper_beta) / 2.0
                log_weights = (new_beta -
                               inverse_temperature) * likelihood_diff
                log_weights_norm = tf.math.log_softmax(log_weights, axis=0)
                eff_size = tf.cast(
                    tf.exp(-tf.math.reduce_logsumexp(2 * log_weights_norm,
                                                     axis=0)), tf.int32)
                upper_beta = tf.where(
                    eff_size < effective_sample_size_threshold, new_beta,
                    upper_beta)
                lower_beta = tf.where(
                    eff_size < effective_sample_size_threshold, lower_beta,
                    new_beta)
                return new_beta, upper_beta, lower_beta, eff_size, log_weights

            def _cond_fn(new_beta, upper_beta, lower_beta, eff_size, *_):  # pylint: disable=unused-argument
                # TODO(junpenglao): revisit threshold below to be dtype specific.
                threshold = 1e-6
                return (tf.math.reduce_any(upper_beta - lower_beta > threshold)
                        & tf.math.reduce_any(
                            eff_size != effective_sample_size_threshold))

            (new_beta, upper_beta, lower_beta, eff_size,
             log_weights) = tf.while_loop(  # pylint: disable=unused-variable
                 cond=_cond_fn,
                 body=_body_fn,
                 loop_vars=(tf.zeros_like(inverse_temperature),
                            tf.fill(ps.shape(inverse_temperature),
                                    tf.constant(2, inverse_temperature.dtype)),
                            inverse_temperature,
                            tf.zeros_like(inverse_temperature, dtype=tf.int32),
                            tf.zeros_like(likelihood_diff)),
                 parallel_iterations=parallel_iterations)

            log_weights = tf.where(new_beta < 1., log_weights,
                                   (1. - inverse_temperature) *
                                   likelihood_diff)
            marginal_loglike_ = reduce_logmeanexp(
                (new_beta - inverse_temperature) * likelihood_log_prob, axis=0)
            new_inverse_temperature = tf.clip_by_value(new_beta, 0., 1.)

            return marginal_loglike_, new_inverse_temperature, log_weights
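
The tf.while_loop above is a bisection search: it pushes the inverse
temperature as high as it can while the effective sample size (ESS) of the
reweighted particles stays at the threshold. The same search, rewritten as a
standalone NumPy sketch with illustrative values:

import numpy as np

rng = np.random.default_rng(0)
likelihood_log_prob = rng.normal(size=1000)
beta_old, ess_threshold = 0.3, 500.

lower, upper = beta_old, 2.0  # the upper bracket starts above 1; clipped later
for _ in range(100):
    beta = 0.5 * (lower + upper)
    log_w = (beta - beta_old) * likelihood_log_prob
    w = np.exp(log_w - log_w.max())
    ess = w.sum() ** 2 / (w ** 2).sum()  # the ESS the TF code gets in log-space
    if ess < ess_threshold:
        upper = beta  # step too aggressive, particles degenerate: shrink it
    else:
        lower = beta  # weights still healthy: try a larger temperature jump
new_beta = min(beta, 1.0)
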
Example 5
def log_average_probs(logits, sample_axis=0, event_axis=None, keepdims=False,
                      validate_args=False, name=None):
  """Computes `log(average(to_probs(logits)))` in a numerically stable manner.

  The meaning of `to_probs` is controlled by the `event_axis` argument. When
  `event_axis` is `None`, `to_probs = tf.math.sigmoid` and otherwise
  `to_probs = lambda x: tf.math.log_softmax(x, axis=event_axis)`.

  `sample_axis` and `event_axis` should have a null intersection. This
  requirement is always verified when `validate_args` is `True`.

  Args:
    logits: A `float` `Tensor` representing logits.
    sample_axis: Scalar or vector `Tensor` designating axis holding samples, or
      `None` (meaning all axes hold samples).
      Default value: `0` (leftmost dimension).
    event_axis: Scalar or vector `Tensor` designating the axis representing
      categorical logits.
      Default value: `None` (i.e., Bernoulli logits).
    keepdims:  Boolean.  Whether to keep the sample axis as singletons.
      Default value: `False` (i.e., squeeze the reduced dimensions).
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
      Default value: `False` (i.e., do not validate args).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., `'log_average_probs'`).

  Returns:
    log_avg_probs: The natural log of the average of probs computed from logits.
  """
  with tf.name_scope(name or 'log_average_probs'):
    logits = tf.convert_to_tensor(logits, dtype_hint=tf.float32, name='logits')
    if sample_axis is not None:
      sample_axis = tf.convert_to_tensor(
          sample_axis, dtype_hint=tf.int32, name='sample_axis')
    if event_axis is not None:
      event_axis = tf.convert_to_tensor(
          event_axis, dtype_hint=tf.int32, name='event_axis')
    if event_axis is None:
      # log(sigmoid(x)) = log(1 / (1 + exp(-x))) = -log1p(exp(-x)) = -sp(-x)
      log_probs = -tf.math.softplus(-logits)
    else:
      sample_axis, event_axis = _log_average_probs_process_args(
          logits, validate_args, sample_axis, event_axis)
      with tf.control_dependencies(_log_average_probs_maybe_check_args(
          sample_axis, event_axis, validate_args)):
        log_probs = _log_softmax(logits, axis=event_axis)
    return reduce_logmeanexp(log_probs, axis=sample_axis, keepdims=keepdims)
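
A usage sketch for the Bernoulli (event_axis=None) branch, assuming
reduce_logmeanexp is in scope as above. The logits are contrived to show why
the log-space reduction matters: the naive mean of sigmoids underflows
float32 to zero and its log becomes -inf, while log_average_probs stays
finite.

import tensorflow as tf

# Three ensemble members' logits for a single Bernoulli event.
logits = tf.constant([[-120.], [-118.], [-121.]])

stable = log_average_probs(logits, sample_axis=0)                # finite, ~ -118.9
naive = tf.math.log(tf.reduce_mean(tf.sigmoid(logits), axis=0))  # -> -inf
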
        def update_weights_temperature(inverse_temperature,
                                       likelihood_log_prob):
            """Calculate the next inverse temperature and update weights."""

            likelihood_diff = likelihood_log_prob - tf.reduce_max(
                likelihood_log_prob)

            def _body_fn(new_beta, upper_beta, lower_beta, eff_size,
                         log_weights):
                """One iteration of the temperature and weight update."""
                new_beta = (lower_beta + upper_beta) / 2.0
                log_weights = (new_beta -
                               inverse_temperature) * likelihood_diff
                log_weights_norm = (log_weights -
                                    tf.math.reduce_logsumexp(log_weights))
                eff_size = tf.cast(
                    tf.exp(-tf.math.reduce_logsumexp(2 * log_weights_norm)),
                    tf.int32)
                upper_beta = tf.where(
                    eff_size < effective_sample_size_threshold, new_beta,
                    upper_beta)
                lower_beta = tf.where(
                    eff_size < effective_sample_size_threshold, lower_beta,
                    new_beta)
                return new_beta, upper_beta, lower_beta, eff_size, log_weights

            (new_beta, upper_beta, lower_beta, eff_size,
             log_weights) = tf.while_loop(  # pylint: disable=unused-variable
                 cond=lambda new_beta, upper_beta, lower_beta, eff_size, *_:  # pylint: disable=g-long-lambda
                 (upper_beta - lower_beta > 1e-6) &
                 (eff_size != effective_sample_size_threshold),
                 body=_body_fn,
                 loop_vars=(tf.zeros_like(inverse_temperature),
                            tf.cast(2.0, inverse_temperature.dtype),
                            inverse_temperature, tf.cast(0, tf.int32),
                            tf.zeros_like(likelihood_diff)),
                 parallel_iterations=parallel_iterations)

            log_weights = tf.where(new_beta < 1., log_weights,
                                   (1. - inverse_temperature) *
                                   likelihood_diff)
            marginal_loglike_ = reduce_logmeanexp(
                (new_beta - inverse_temperature) * likelihood_log_prob)

            return marginal_loglike_, tf.clip_by_value(new_beta, 0.,
                                                       1.), log_weights
    def one_step(self, current_state, previous_kernel_results):
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'simple_step_size_adaptation',
                                    'one_step')):
            # Set the step_size.
            inner_results = self.step_size_setter_fn(
                previous_kernel_results.inner_results,
                previous_kernel_results.new_step_size)

            # Step the inner kernel.
            new_state, new_inner_results = self.inner_kernel.one_step(
                current_state, inner_results)

            # Get the new step size.
            log_accept_prob = self.log_accept_prob_getter_fn(new_inner_results)
            log_target_accept_prob = tf.math.log(
                tf.cast(previous_kernel_results.target_accept_prob,
                        dtype=log_accept_prob.dtype))

            state_parts = tf.nest.flatten(current_state)
            step_size = self.step_size_getter_fn(new_inner_results)
            step_size_parts = tf.nest.flatten(step_size)
            log_accept_prob_rank = prefer_static.rank(log_accept_prob)

            new_step_size_parts = []
            for step_size_part, state_part in zip(step_size_parts,
                                                  state_parts):
                # Compute new step sizes for each step size part. If step size part has
                # smaller rank than the corresponding state part, then the difference is
                # averaged away in the log accept prob.
                #
                # Example:
                #
                # state_part has shape      [2, 3, 4, 5]
                # step_size_part has shape     [1, 4, 1]
                # log_accept_prob has shape [2, 3, 4]
                #
                # Since step size has 1 rank fewer than the state, we reduce away the
                # leading dimension of log_accept_prob to get a Tensor with shape [3,
                # 4]. Next, since log_accept_prob must broadcast into step_size_part on
                # the left, we reduce the dimensions where their shapes differ, to get a
                # Tensor with shape [1, 4], which now is compatible with the leading
                # dimensions of step_size_part.
                #
                # There is a subtlety here in that step_size_parts might be a length-1
                # list, which means that we'll be "structure-broadcasting" it for all
                # the state parts (see logic in, e.g., hmc.py). In this case we must
                # assume that the lone step size provided broadcasts with the event
                # dims of each state part. This means that either step size has no
                # dimensions corresponding to chain dimensions, or all states are of the
                # same shape. For the former, we want to reduce over all chain
                # dimensions. For the latter, we want to use the same logic as in the
                # non-structure-broadcasted case.
                #
                # It turns out we can compute the reduction dimensions for both cases
                # uniformly by taking the rank of any state part. This obviously works
                # in the second case (where all state ranks are the same). In the first
                # case, all state parts have the rank L + D_i + B, where L is the rank
                # of log_accept_prob, D_i is the non-shared dimensions amongst all
                # states, and B are the shared dimensions of all the states, which are
                # equal to the step size. When we subtract B, we will always get a
                # number >= L, which means we'll get the full reduction we want.
                num_reduce_dims = prefer_static.minimum(
                    log_accept_prob_rank,
                    prefer_static.rank(state_part) -
                    prefer_static.rank(step_size_part))
                reduced_log_accept_prob = reduce_logmeanexp(
                    log_accept_prob, axis=prefer_static.range(num_reduce_dims))
                # reduced_log_accept_prob must broadcast into step_size_part on the
                # left, so we do an additional reduction over dimensions where their
                # shapes differ.
                reduce_indices = get_differing_dims(reduced_log_accept_prob,
                                                    step_size_part)
                reduced_log_accept_prob = reduce_logmeanexp(
                    reduced_log_accept_prob,
                    axis=reduce_indices,
                    keepdims=True)

                one_plus_adaptation_rate = 1. + tf.cast(
                    previous_kernel_results.adaptation_rate,
                    dtype=step_size_part.dtype)
                new_step_size_part = mcmc_util.choose(
                    reduced_log_accept_prob > log_target_accept_prob,
                    step_size_part * one_plus_adaptation_rate,
                    step_size_part / one_plus_adaptation_rate)

                new_step_size_parts.append(
                    tf.where(
                        previous_kernel_results.step <
                        self.num_adaptation_steps, new_step_size_part,
                        step_size_part))
            new_step_size = tf.nest.pack_sequence_as(step_size,
                                                     new_step_size_parts)

            return new_state, previous_kernel_results._replace(
                inner_results=new_inner_results,
                step=1 + previous_kernel_results.step,
                new_step_size=new_step_size)
Example 8
    def _bootstrap_from_inner_results(self, init_state, inner_results):
        step_size = self.step_size_getter_fn(inner_results)

        log_accept_prob = self.log_accept_prob_getter_fn(inner_results)

        state_parts = tf.nest.flatten(init_state)
        step_size_parts = tf.nest.flatten(step_size)

        if self._parameters['shrinkage_target'] is None:
            shrinkage_target_parts = [None] * len(step_size_parts)
        else:
            shrinkage_target_parts = tf.nest.flatten(
                self._parameters['shrinkage_target'])
            if len(shrinkage_target_parts) not in [1, len(step_size_parts)]:
                raise ValueError(
                    '`shrinkage_target` should be a `Tensor` or a list of '
                    '`Tensor`s of the same length as `step_size`. Found '
                    'len(`step_size`) = {} and '
                    'len(`shrinkage_target`) = {}'.format(
                        len(step_size_parts), len(shrinkage_target_parts)))
            if len(shrinkage_target_parts) < len(step_size_parts):
                shrinkage_target_parts *= len(step_size_parts)

        dtype = dtype_util.common_dtype(step_size_parts, tf.float32)
        error_sum, log_averaging_step, log_shrinkage_target = [], [], []
        for state_part, step_size_part, shrinkage_target_part in zip(
                state_parts, step_size_parts, shrinkage_target_parts):
            num_reduce_dims = ps.minimum(
                ps.rank(log_accept_prob),
                ps.rank(state_part) - ps.rank(step_size_part))
            reduced_log_accept_prob = reduce_logmeanexp(
                log_accept_prob,
                axis=ps.range(num_reduce_dims),
                experimental_named_axis=self.
                experimental_reduce_chain_axis_names)
            reduce_indices = get_differing_dims(reduced_log_accept_prob,
                                                step_size_part)
            reduced_log_accept_prob = reduce_logmeanexp(
                reduced_log_accept_prob, axis=reduce_indices, keepdims=True)
            error_sum.append(
                tf.zeros_like(reduced_log_accept_prob, dtype=dtype))
            log_averaging_step.append(
                tf.zeros_like(step_size_part, dtype=dtype))

            if shrinkage_target_part is None:
                log_shrinkage_target.append(
                    float(np.log(10.)) + tf.math.log(step_size_part))
            else:
                log_shrinkage_target.append(
                    tf.math.log(tf.cast(shrinkage_target_part, dtype)))

        return DualAveragingStepSizeAdaptationResults(
            inner_results=inner_results,
            step=tf.constant(0, dtype=tf.int32),
            target_accept_prob=tf.cast(self.parameters['target_accept_prob'],
                                       log_accept_prob.dtype),
            log_shrinkage_target=log_shrinkage_target,
            exploration_shrinkage=tf.cast(
                self.parameters['exploration_shrinkage'], dtype),
            step_count_smoothing=tf.cast(
                self.parameters['step_count_smoothing'], dtype),
            decay_rate=tf.cast(self.parameters['decay_rate'], dtype),
            error_sum=error_sum,
            log_averaging_step=log_averaging_step,
            new_step_size=step_size,
            num_adaptation_steps=tf.cast(self.num_adaptation_steps,
                                         dtype=tf.int32))
Example 9
    def _one_step_part(self,
                       step_size,
                       state,
                       error_sum,
                       log_averaging_step,
                       shrinkage_target,
                       log_accept_prob_rank=None,
                       log_accept_prob=None,
                       target_accept_prob=None,
                       previous_kernel_results=None):
        """Compute new step sizes for each step size part.

    If step size part has smaller rank than the corresponding state part, then
    the difference is averaged away in the log accept prob.

    Example:

      state_part has shape      [2, 3, 4, 5]
      step_size_part has shape     [1, 4, 1]
      log_accept_prob has shape [2, 3, 4]

    Since step size has 1 rank fewer than the state, we reduce away the leading
    dimension of `log_accept_prob` to get a Tensor with shape [3, 4]. Next,
    since `log_accept_prob` must broadcast into step_size_part on the left, we
    reduce the dimensions where their shapes differ, to get a Tensor with shape
    [1, 4], which now is compatible with the leading dimensions of
    step_size_part.

    There is a subtlety here in that `step_size_parts` might be a length-1 list,
    which means that we'll be "structure-broadcasting" it for all the state
    parts (see logic in, e.g., hmc.py). In this case we must assume that
    the lone step size provided broadcasts with the event dims of each state
    part. This means that either step size has no dimensions corresponding to
    chain dimensions, or all states are of the same shape. For the former, we
    want to reduce over all chain dimensions. For the latter, we want to use
    the same logic as in the non-structure-broadcasted case.

    It turns out we can compute the reduction dimensions for both cases
    uniformly by taking the rank of any state part. This obviously works in
    the second case (where all state ranks are the same). In the first case,
    all state parts have the rank L + D_i + B, where L is the rank of
    log_accept_prob, D_i is the non-shared dimensions amongst all states, and
    B are the shared dimensions of all the states, which are equal to the step
    size. When we subtract B, we will always get a number >= L, which means
    we'll get the full reduction we want.

    Args:
      step_size: Previous step's step_size.
      state: Previous step's state value.
      error_sum: Previous step's error accumulator.
      log_averaging_step: Previous step's log_averaging_step.
      shrinkage_target: Floating point scalar `Tensor`. Arbitrary value the
        exploration step size is biased towards.
      log_accept_prob_rank: Rank of log_accept_prob.
      log_accept_prob: Floating point `Tensor`. Log of the acceptance
        probability from the previous step.
      target_accept_prob: A floating point `Tensor` representing desired
        acceptance probability. Must be a positive number less than 1.
      previous_kernel_results: Results struct from previous step.

    Returns:
      new_step_size: Updated `step_size`.
      new_log_averaging_step: Updated `log_averaging_step`.
      new_error_sum: Updated `error_sum`.
    """
        num_reduce_dims = prefer_static.minimum(
            log_accept_prob_rank,
            (prefer_static.rank(state) - prefer_static.rank(step_size)))
        reduced_log_accept_prob = reduce_logmeanexp(
            log_accept_prob, axis=prefer_static.range(num_reduce_dims))

        # reduced_log_accept_prob must broadcast into step_size on the
        # left, so we do an additional reduction over dimensions where their
        # shapes differ.
        reduce_indices = _get_differing_dims(reduced_log_accept_prob,
                                             step_size)
        reduced_log_accept_prob = reduce_logmeanexp(reduced_log_accept_prob,
                                                    axis=reduce_indices,
                                                    keepdims=True)
        new_error_sum = (error_sum + target_accept_prob -
                         tf.math.exp(reduced_log_accept_prob))
        num_ones_to_pad = prefer_static.maximum(
            prefer_static.rank(shrinkage_target) -
            prefer_static.rank(new_error_sum), 0)
        new_error_sum_extend = tf.reshape(
            new_error_sum,
            shape=tf.pad(prefer_static.shape(new_error_sum),
                         paddings=[[0, num_ones_to_pad]],
                         constant_values=1))

        step_count_smoothing = previous_kernel_results.step_count_smoothing
        step = tf.cast(previous_kernel_results.step,
                       step_count_smoothing.dtype) + 1.
        soft_t = step_count_smoothing + step

        new_log_step = (shrinkage_target - (
            (tf.cast(new_error_sum_extend, step.dtype) * tf.math.sqrt(step)) /
            (soft_t * previous_kernel_results.exploration_shrinkage)))

        eta = step**(-previous_kernel_results.decay_rate)
        new_log_averaging_step = (eta * new_log_step +
                                  (1. - eta) * log_averaging_step)

        # - If still adapting, return an exploring step size.
        # - If just finished, return the averaging step size.
        # - Otherwise, do not update.
        new_step_size = tf.where(
            previous_kernel_results.step < self.num_adaptation_steps,
            tf.math.exp(new_log_step),
            tf.where(previous_kernel_results.step > self.num_adaptation_steps,
                     step_size, tf.math.exp(new_log_averaging_step)))
        new_log_averaging_step = tf.where(
            previous_kernel_results.step > self.num_adaptation_steps,
            log_averaging_step, new_log_averaging_step)
        return new_step_size, new_log_averaging_step, new_error_sum
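
Stripped of the shape handling, the recursion above is the dual averaging
scheme of Hoffman and Gelman (2014): accumulate the acceptance-probability
error, take an exploration step biased toward the shrinkage target, and
Polyak-average the log step sizes. A condensed scalar sketch, with
illustrative hyperparameter values:

import math

target, gamma, t0, kappa = 0.75, 0.05, 10., 0.75  # illustrative values
mu = math.log(10. * 0.1)  # shrinkage target: 10x an assumed 0.1 initial step

error_sum, log_avg_step = 0., 0.
for step in range(1, 101):
    accept_prob = 0.7  # stand-in for the sampler's acceptance probability
    error_sum += target - accept_prob
    # Exploration step: biased toward mu, damped as the error accumulates.
    log_step = mu - error_sum * math.sqrt(step) / ((t0 + step) * gamma)
    # Polyak averaging with decaying weight eta = step**-kappa.
    eta = step ** -kappa
    log_avg_step = eta * log_step + (1. - eta) * log_avg_step
step_size = math.exp(log_avg_step)  # reported once adaptation finishes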