Example #1
def step_size_adaptation_step(
    state: 'StepSizeAdaptationState',
    log_accept_ratio: 'fun_mc.FloatTensor',
    num_adaptation_steps: 'Optional[fun_mc.IntTensor]',
    target_accept_prob: 'fun_mc.FloatTensor' = 0.8,
    adaptation_rate: 'fun_mc.FloatTensor' = 0.05,
    adaptation_rate_decay_power: 'fun_mc.FloatTensor' = 0.1,
    averaging_window_steps: 'fun_mc.IntTensor' = 100,
    min_log_accept_prob: 'fun_mc.FloatTensor' = np.log(1e-5),
    reduce_fn: 'Callable[[fun_mc.FloatTensor], fun_mc.FloatTensor]' = (
        tfp.math.reduce_logmeanexp),
) -> 'Tuple[StepSizeAdaptationState, StepSizeAdaptationExtra]':
  """Gradient based step size adaptation using ADAM.

  Given the `log_accept_ratio` statistic from a Metropolis-Hastings algorithm,
  this adapts the step size hyperparameter to drive the corresponding
  acceptance probability toward `target_accept_prob`. The current step size
  can be extracted via the `step_size` method on the state structure.

  Args:
    state: `StepSizeAdaptationState`
    log_accept_ratio: Float tensor. The logarithm of the accept ratio.
    num_adaptation_steps: Number of adaptation steps, can be `None`.
    target_accept_prob: Target acceptance probability.
    adaptation_rate: Step size adaptation rate.
    adaptation_rate_decay_power: Power of the polynomial decay schedule applied
      to `adaptation_rate`.
    averaging_window_steps: Number of steps over which the averaged step size
      is computed.
    min_log_accept_prob: The log acceptance probability is clamped from below
      at this value for numerical stability.
    reduce_fn: A function that reduces `log_accept_ratio` in log-space. By
      default, this computes the log-mean-exp.

  Returns:
    step_size_adaptation_state: `StepSizeAdaptationState`
    step_size_adaptation_extra: `StepSizeAdaptationExtra`
  """
  dtype = log_accept_ratio.dtype
  adaptation_rate = tf.convert_to_tensor(adaptation_rate, dtype=dtype)
  target_accept_prob = tf.convert_to_tensor(target_accept_prob, dtype=dtype)
  adaptation_rate_decay_power = tf.convert_to_tensor(
      adaptation_rate_decay_power, dtype=dtype)
  min_log_accept_prob = tf.fill(log_accept_ratio.shape,
                                tf.constant(min_log_accept_prob, dtype))

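  # Clamp the log accept probability to [min_log_accept_prob, 0] and replace
  # non-finite values so the reduction below stays numerically stable.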
  log_accept_prob = tf.minimum(log_accept_ratio, tf.zeros([], dtype))
  log_accept_prob = tf.maximum(log_accept_prob, min_log_accept_prob)
  log_accept_prob = tf.where(
      tf.math.is_finite(log_accept_prob), log_accept_prob, min_log_accept_prob)
  accept_prob = tf.exp(reduce_fn(log_accept_prob))

  # Surrogate loss whose gradient with respect to the optimized variable (the
  # log step size) is `target_accept_prob - accept_prob`.
  loss_fn = fun_mc.make_surrogate_loss_fn(
      lambda _: (target_accept_prob - accept_prob, ()))

  if num_adaptation_steps is not None:
    adaptation_rate = _polynomial_decay(
        step=state.step,
        step_size=adaptation_rate,
        decay_steps=num_adaptation_steps,
        final_step_size=0.,
        power=adaptation_rate_decay_power,
    )

  # Optimize step size.
  opt_state, opt_extra = fun_mc.adam_step(state.opt_state, loss_fn,
                                          adaptation_rate)

  # Do iterate averaging.
  old_rms_state = state.rms_state
  rms_state, _ = fun_mc.running_mean_step(
      old_rms_state,
      tf.exp(opt_state.state),
      window_size=averaging_window_steps)

  if num_adaptation_steps is not None:
    rms_state = util.map_tree(
        lambda n, o: tf.where(state.step < num_adaptation_steps, n, o),
        rms_state, old_rms_state)

  state = state._replace(
      opt_state=opt_state, rms_state=rms_state, step=state.step + 1)
  extra = StepSizeAdaptationExtra(opt_extra=opt_extra, accept_prob=accept_prob)
  return state, extra
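
Usage sketch (not part of the original example): the adaptation step above is
meant to be composed with an HMC transition kernel inside a `fun_mc.trace`
loop. The sketch assumes `prefab.step_size_adaptation_init`, the standard
`fun_mc.hamiltonian_monte_carlo_init`/`_step` kernel, a no-argument
`step_size()` method, and a stateless-seed convention; treat those names and
the seed handling as assumptions rather than the library's documented API.

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from fun_mc import using_tensorflow as fun_mc
from fun_mc import prefab

def target_log_prob_fn(x):
  # Standard normal target; fun_mc log-prob fns return (log_prob, extra).
  return -0.5 * tf.reduce_sum(tf.square(x), -1), ()

def kernel(hmc_state, ssa_state, seed):
  hmc_seed, seed = tfp.random.split_seed(seed)
  hmc_state, hmc_extra = fun_mc.hamiltonian_monte_carlo_step(
      hmc_state,
      target_log_prob_fn=target_log_prob_fn,
      step_size=ssa_state.step_size(),  # Current adapted step size.
      num_integrator_steps=8,
      seed=hmc_seed)
  # Feed the accept ratio back into the adaptation.
  ssa_state, ssa_extra = prefab.step_size_adaptation_step(
      ssa_state,
      hmc_extra.log_accept_ratio,
      num_adaptation_steps=500)
  return (hmc_state, ssa_state, seed), ssa_extra.accept_prob

hmc_state = fun_mc.hamiltonian_monte_carlo_init(
    tf.zeros([16, 3]), target_log_prob_fn)  # 16 chains in 3 dimensions.
ssa_state = prefab.step_size_adaptation_init(tf.constant(0.1))
seed = tf.constant([0, 42], tf.int32)  # Stateless seed (assumed convention).
_, accept_prob_trace = fun_mc.trace(
    (hmc_state, ssa_state, seed), kernel, num_steps=1000)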
Example #2
def stochastic_gradient_ascent_hmc_step(
    sga_hmc_state: 'StochasticGradientAscentHMCState',
    scalar_step_size: 'fun_mc.FloatNest',
    criterion_fn: 'Callable[[fun_mc.State, fun_mc.State, fun_mc.FloatTensor, '
    'fun_mc.FloatTensor], Tuple[fun_mc.FloatTensor, Any]]',
    trajectory_length_adaptation_rate: 'fun_mc.FloatTensor' = 0.05,
    trajectory_length_sample_fn: 'Callable[[Any, fun_mc.IntTensor, Any], '
    'fun_mc.FloatTensor]' = (default_trajectory_length_sample),
    trajectory_length_constrain_fn: 'Callable[[Any], Any]' = (
        default_trajectory_length_constrain),
    adam_kwargs: 'Mapping[str, Any]' = immutabledict({
        'beta_1': 0.,
        'beta_2': 0.5
    }),
    averaging_window_steps: 'fun_mc.IntTensor' = 100,
    adapt: 'fun_mc.BooleanTensor' = True,
    seed: 'Any' = None,
    **hmc_kwargs: 'Mapping[str, Any]',
):
  """Stochastic gradient ascent Hamiltonian Monte Carlo step.

  SGA-HMC posits the existence of a parametric distribution over trajectory
  lengths, and uses stochastic gradient ascent to adapt that distribution's
  parameters by maximizing the expected value of some criterion.

  The gradients are computed via
  `hamiltonian_monte_carlo_with_state_grads_step`, with expectations estimated
  by Monte Carlo averaging across separate Markov chains and across Markov
  chain iterates.

  The ChEES criterion [1] is the prototypical example.

  The trajectory length distribution is parameterized via
  `trajectory_length_sample_fn` and `trajectory_length_constrain_fn`. While
  `adapt` is `True`, the parameters are adapted using Adam (controlled via
  `trajectory_length_adaptation_rate` and `adam_kwargs`). When `adapt` is
  `False`, averaged parameters are used instead, computed via an exponential
  moving average taken while `adapt` was `True`. The degree of averaging is
  controlled via `averaging_window_steps`. The parameters actually used for
  this step are returned in the `trajectory_length_params` field of the
  `sga_hmc_extra` return value.

  Args:
    sga_hmc_state: `StochasticGradientAscentHMCState`
    scalar_step_size: Scalar step size (see
      `hamiltonian_monte_carlo_with_state_grads_step` for details).
    criterion_fn: Callable with signature `(previous_state, proposed_state,
      accept_prob, trajectory_length) -> (criterion, criterion_extra)`. The
      criterion to maximize.
    trajectory_length_adaptation_rate: Adaptation rate for the trajectory
      length parameters.
    trajectory_length_sample_fn: Callable with signature
      `(trajectory_length_params, step, seed) -> trajectory_length`. Used to
      sample a new trajectory length.
    trajectory_length_constrain_fn: Used to constrain the trajectory length
      parameters by projecting them into the allowed set.
    adam_kwargs: Additional keyword arguments for Adam optimizer used to adapt
      the trajectory length parameters.
    averaging_window_steps: Window size for averaging the trajectory parameters.
      See `fun_mc.running_mean_step` for the meaning of this argument.
    adapt: Whether to adapt the trajectory parameters and whether to use the
      adapted parameters or the averaged parameters when sampling the trajectory
      for this step.
    seed: PRNG seed.
    **hmc_kwargs: Passed to `hamiltonian_monte_carlo_with_state_grads_step`.

  Returns:
    sga_hmc_state: `StochasticGradientAscentHMCState`.
    sga_hmc_extra: `StochasticGradientAscentHMCExtra`.

  #### References

  [1]: Hoffman, M., Radul, A., & Sountsov, P. (2020). An Adaptive MCMC Scheme
       for Setting Trajectory Lengths in Hamiltonian Monte Carlo. In
       preparation.
  """

  seed, sample_seed, hmc_seed = util.split_seed(seed, 3)

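  # The surrogate loss runs one HMC step at a trajectory length sampled from
  # the current (or averaged) parameters and returns the negated criterion, so
  # minimizing it with `fun_mc.adam_step` maximizes the criterion.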
  @util.named_call
  def loss_fn(*args, **kwargs):
    rmean_params = sga_hmc_state.trajectory_length_params_rmean_state.mean
    adapting_params = fun_mc.recover_state_from_args(args, kwargs, rmean_params)
    params = fun_mc.choose(adapt, adapting_params, rmean_params)
    trajectory_length = trajectory_length_sample_fn(params, sga_hmc_state.step,
                                                    sample_seed)

    hmc_state, hmc_extra = hamiltonian_monte_carlo_with_state_grads_step(
        sga_hmc_state.hmc_state,
        trajectory_length=trajectory_length,
        scalar_step_size=scalar_step_size,
        seed=hmc_seed,
        **hmc_kwargs)

    accept_prob = tf.exp(
        tf.minimum(
            tf.zeros_like(hmc_extra.hmc_extra.log_accept_ratio),
            hmc_extra.hmc_extra.log_accept_ratio))
    accept_prob = tf.where(
        tf.math.is_finite(accept_prob), accept_prob, tf.zeros_like(accept_prob))

    criterion, criterion_extra = criterion_fn(
        sga_hmc_state.hmc_state.state,
        hmc_extra.proposed_state,
        accept_prob,
        # + step_size because we're effectively doing floor(traj / step_size)
        # when computing the number of leapfrog steps.
        trajectory_length + scalar_step_size,
    )

    return -criterion, (hmc_state, hmc_extra, criterion, criterion_extra,
                        params)

  # Adapt trajectory.
  (trajectory_length_params_opt_state,
   trajectory_length_params_opt_extra) = fun_mc.adam_step(
       sga_hmc_state.trajectory_length_params_opt_state,
       loss_fn,
       learning_rate=trajectory_length_adaptation_rate,
       **adam_kwargs,
   )

  (hmc_state, hmc_extra, criterion, criterion_extra,
   trajectory_length_params) = trajectory_length_params_opt_extra.loss_extra

  # Constrain trajectory params. When not adapting, first discard the Adam
  # update and keep the previous parameters.
  trajectory_length_params_opt_state = fun_mc.choose(
      adapt, trajectory_length_params_opt_state,
      sga_hmc_state.trajectory_length_params_opt_state)
  constrained_trajectory_length_params = trajectory_length_constrain_fn(
      trajectory_length_params_opt_state.state)
  trajectory_length_params_opt_state = (
      trajectory_length_params_opt_state._replace(
          state=constrained_trajectory_length_params))

  # Update the running mean of the trajectory params; the update is discarded
  # when not adapting.
  trajectory_length_params_rmean_state, _ = fun_mc.running_mean_step(
      sga_hmc_state.trajectory_length_params_rmean_state,
      trajectory_length_params_opt_state.state,
      window_size=averaging_window_steps)
  trajectory_length_params_rmean_state = fun_mc.choose(
      adapt, trajectory_length_params_rmean_state,
      sga_hmc_state.trajectory_length_params_rmean_state)

  sga_hmc_state = sga_hmc_state._replace(
      hmc_state=hmc_state,
      step=sga_hmc_state.step + 1,
      trajectory_length_params_rmean_state=trajectory_length_params_rmean_state,
      trajectory_length_params_opt_state=trajectory_length_params_opt_state,
  )
  sga_hmc_extra = StochasticGradientAscentHMCExtra(
      hmc_extra=hmc_extra.hmc_extra,
      num_integrator_steps=hmc_extra.num_integrator_steps,
      trajectory_length_params_opt_extra=trajectory_length_params_opt_extra,
      criterion=criterion,
      criterion_extra=criterion_extra,
      trajectory_length_params=trajectory_length_params,
  )
  return sga_hmc_state, sga_hmc_extra
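
Usage sketch (not from the original example): driving the SGA-HMC step with
the ChEES criterion. The init helper name, its signature, and the forwarding
of `target_log_prob_fn` through `**hmc_kwargs` are assumptions here;
`prefab.chees_criterion` is assumed to match the `criterion_fn` signature
documented above.

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from fun_mc import using_tensorflow as fun_mc
from fun_mc import prefab

def target_log_prob_fn(x):
  return -0.5 * tf.reduce_sum(tf.square(x), -1), ()

def kernel(sga_hmc_state, seed):
  step_seed, seed = tfp.random.split_seed(seed)
  sga_hmc_state, sga_hmc_extra = prefab.stochastic_gradient_ascent_hmc_step(
      sga_hmc_state,
      scalar_step_size=0.1,
      criterion_fn=prefab.chees_criterion,  # ChEES, from reference [1].
      # Forwarded to `hamiltonian_monte_carlo_with_state_grads_step` via
      # `**hmc_kwargs` (assumed).
      target_log_prob_fn=target_log_prob_fn,
      seed=step_seed)
  return (sga_hmc_state, seed), sga_hmc_extra.criterion

# Hypothetical init helper; many chains improve the Monte Carlo gradient
# estimates used for the adaptation.
sga_hmc_state = prefab.stochastic_gradient_ascent_hmc_init(
    tf.zeros([64, 3]), target_log_prob_fn, init_trajectory_length=1.)
seed = tf.constant([0, 42], tf.int32)
_, criterion_trace = fun_mc.trace(
    (sga_hmc_state, seed), kernel, num_steps=1000)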