def step_size_adaptation_step(
    state: 'StepSizeAdaptationState',
    log_accept_ratio: 'fun_mc.FloatTensor',
    num_adaptation_steps: 'Optional[fun_mc.IntTensor]',
    target_accept_prob: 'fun_mc.FloatTensor' = 0.8,
    adaptation_rate: 'fun_mc.FloatTensor' = 0.05,
    adaptation_rate_decay_power: 'fun_mc.FloatTensor' = 0.1,
    averaging_window_steps: 'fun_mc.IntTensor' = 100,
    min_log_accept_prob: 'fun_mc.FloatTensor' = np.log(1e-5),
    reduce_fn: 'Callable[[fun_mc.FloatTensor], fun_mc.FloatTensor]' = (
        tfp.math.reduce_logmeanexp),
) -> 'Tuple[StepSizeAdaptationState, StepSizeAdaptationExtra]':
  """Gradient based step size adaptation using ADAM.

  Given the `log_accept_ratio` statistic from a Metropolis-Hastings algorithm,
  this adapts the step size hyperparameter to make that statistic hit some
  `target_accept_prob`. The step size can be extracted using the `step_size`
  method on the state structure.

  Args:
    state: `StepSizeAdaptationState`
    log_accept_ratio: Float tensor. The logarithm of the accept ratio.
    num_adaptation_steps: Number of adaptation steps, can be `None`. When not
      `None`, the adaptation rate is decayed to 0 over this many steps (via
      `_polynomial_decay`), and the averaged step size stops being updated
      once `state.step` reaches `num_adaptation_steps`.
    target_accept_prob: Target acceptance probability.
    adaptation_rate: Step size adaptation rate.
    adaptation_rate_decay_power: Power of the polynomial schedule used to decay
      `adaptation_rate` over `num_adaptation_steps`. Unused when
      `num_adaptation_steps` is `None`.
    averaging_window_steps: Number of steps to compute the averaged step size.
    min_log_accept_prob: Floor applied to the log acceptance probability for
      numerical stability; non-finite values are also replaced by it.
    reduce_fn: A function that reduces `log_accept_ratio` in log-space. By
      default, this computes the log-mean-exp.

  Returns:
    step_size_adaptation_state: `StepSizeAdaptationState`
    step_size_adaptation_extra: `StepSizeAdaptationExtra`
  """
  dtype = log_accept_ratio.dtype
  adaptation_rate = tf.convert_to_tensor(adaptation_rate, dtype=dtype)
  target_accept_prob = tf.convert_to_tensor(target_accept_prob, dtype=dtype)
  adaptation_rate_decay_power = tf.convert_to_tensor(
      adaptation_rate_decay_power, dtype=dtype)
  # Materialize the floor with the same shape/dtype as `log_accept_ratio`.
  # Deriving it from the tensor itself (instead of
  # `tf.fill(log_accept_ratio.shape, ...)`) also works when the static shape
  # is only partially known, e.g. an unknown batch dimension in graph mode.
  log_accept_prob_floor = tf.zeros_like(log_accept_ratio) + tf.constant(
      min_log_accept_prob, dtype)

  # log(min(1, ratio)), floored at `min_log_accept_prob`; non-finite values
  # (e.g. NaN, which survives min/max) are replaced with the floor as well.
  log_accept_prob = tf.minimum(log_accept_ratio, tf.zeros([], dtype))
  log_accept_prob = tf.maximum(log_accept_prob, log_accept_prob_floor)
  log_accept_prob = tf.where(
      tf.math.is_finite(log_accept_prob), log_accept_prob,
      log_accept_prob_floor)
  accept_prob = tf.exp(reduce_fn(log_accept_prob))

  # The surrogate loss is driven by the gap between target and observed accept
  # probability; see `fun_mc.make_surrogate_loss_fn` for how this value is
  # turned into the signal Adam follows.
  loss_fn = fun_mc.make_surrogate_loss_fn(lambda _:  # pylint: disable=g-long-lambda
                                          (target_accept_prob - accept_prob, ()))

  if num_adaptation_steps is not None:
    adaptation_rate = _polynomial_decay(
        step=state.step,
        step_size=adaptation_rate,
        decay_steps=num_adaptation_steps,
        final_step_size=0.,
        power=adaptation_rate_decay_power,
    )

  # Optimize step size.
  opt_state, opt_extra = fun_mc.adam_step(state.opt_state, loss_fn,
                                          adaptation_rate)

  # Do iterate averaging. The optimizer state is exponentiated before
  # averaging (it presumably holds the log step size).
  old_rms_state = state.rms_state
  rms_state, _ = fun_mc.running_mean_step(
      old_rms_state,
      tf.exp(opt_state.state),
      window_size=averaging_window_steps)

  if num_adaptation_steps is not None:
    # Freeze the averaged step size once the adaptation phase is over.
    rms_state = util.map_tree(
        lambda n, o: tf.where(state.step < num_adaptation_steps, n, o),
        rms_state, old_rms_state)

  state = state._replace(
      opt_state=opt_state, rms_state=rms_state, step=state.step + 1)
  extra = StepSizeAdaptationExtra(opt_extra=opt_extra, accept_prob=accept_prob)
  return state, extra
def stochastic_gradient_ascent_hmc_step(
    sga_hmc_state: 'StochasticGradientAscentHMCState',
    scalar_step_size: 'fun_mc.FloatNest',
    criterion_fn: 'Callable[[fun_mc.State, fun_mc.State, fun_mc.FloatTensor, '
    'fun_mc.FloatTensor], Tuple[fun_mc.FloatTensor, Any]]',
    trajectory_length_adaptation_rate: 'fun_mc.FloatTensor' = 0.05,
    trajectory_length_sample_fn: 'Callable[[Any, fun_mc.IntTensor, Any], '
    'fun_mc.FloatTensor]' = (default_trajectory_length_sample),
    trajectory_length_constrain_fn: 'Callable[[Any], Any]' = (
        default_trajectory_length_constrain),
    adam_kwargs: 'Mapping[str, Any]' = immutabledict({
        'beta_1': 0.,
        'beta_2': 0.5
    }),
    averaging_window_steps: 'fun_mc.IntTensor' = 100,
    adapt: 'fun_mc.BooleanTensor' = True,
    seed: 'Any' = None,
    **hmc_kwargs: 'Any',
):
  """Stochastic gradient ascent Hamiltonian Monte Carlo step.

  SGA-HMC posits the existence of a parametric distribution over trajectory
  lengths. It then uses stochastic gradients to adapt those parameters by
  maximizing the expected value of some criterion. The gradients are computed
  by the use of `hamiltonian_monte_carlo_with_state_grads_step` and then using
  Monte-Carlo averages across separate Markov Chains and Markov Chain iterates.
  The ChEES [1] criterion is the prototypical example.

  The trajectory distribution is parameterized via
  `trajectory_length_sample_fn` and `trajectory_length_constrain_fn`. While
  `adapt` is `True`, the parameters are adapted using Adam (controlled using
  `trajectory_length_adaptation_rate` and `adam_kwargs`). When `adapt` is
  `False`, averaged parameters are used instead; they are accumulated via
  `fun_mc.running_mean_step` only on steps where `adapt` was `True`. The degree
  of averaging is controlled via `averaging_window_steps`. The parameters that
  were actually used for this step are returned in the
  `trajectory_length_params` field in the `sga_hmc_extra` return.

  Args:
    sga_hmc_state: `StochasticGradientAscentHMCState`
    scalar_step_size: Scalar step size (see
      `hamiltonian_monte_carlo_with_state_grads_step` for details).
    criterion_fn: Callable with signature `(previous_state, proposed_state,
      accept_prob, trajectory_length) -> (criterion, criterion_extra)`. The
      criterion to maximize.
    trajectory_length_adaptation_rate: Adaptation rate for the trajectory
      length parameters.
    trajectory_length_sample_fn: Callable with signature
      `(trajectory_length_params, step, seed) -> trajectory_length`. Used to
      sample a new trajectory length.
    trajectory_length_constrain_fn: Used to constrain the trajectory length
      parameters by projecting them into the allowed set.
    adam_kwargs: Additional keyword arguments for Adam optimizer used to adapt
      the trajectory length parameters.
    averaging_window_steps: Window size for averaging the trajectory
      parameters. See `fun_mc.running_mean_step` for the meaning of this
      argument.
    adapt: Whether to adapt the trajectory parameters and whether to use the
      adapted parameters or the averaged parameters when sampling the
      trajectory for this step.
    seed: PRNG seed.
    **hmc_kwargs: Passed to `hamiltonian_monte_carlo_with_state_grads_step`.

  Returns:
    sga_hmc_state: `StochasticGradientAscentHMCState`.
    sga_hmc_extra: `StochasticGradientAscentHMCExtra`.

  #### References

  [1]: Hoffman, M., Radul, A., & Sountsov, P. (2020). An Adaptive MCMC Scheme
       for Setting Trajectory Lengths in Hamiltonian Monte Carlo. In
       preparation.
  """
  # Only the trajectory-sampling and HMC seeds are consumed below; the first
  # split is discarded.
  _, sample_seed, hmc_seed = util.split_seed(seed, 3)

  @util.named_call
  def loss_fn(*args, **kwargs):
    """Runs one HMC step and returns the negated criterion plus extras."""
    rmean_params = sga_hmc_state.trajectory_length_params_rmean_state.mean
    adapting_params = fun_mc.recover_state_from_args(args, kwargs,
                                                     rmean_params)
    # Sample with the parameters being optimized while adapting, otherwise
    # with the running-mean parameters.
    params = fun_mc.choose(adapt, adapting_params, rmean_params)
    trajectory_length = trajectory_length_sample_fn(params, sga_hmc_state.step,
                                                    sample_seed)
    hmc_state, hmc_extra = hamiltonian_monte_carlo_with_state_grads_step(
        sga_hmc_state.hmc_state,
        trajectory_length=trajectory_length,
        scalar_step_size=scalar_step_size,
        seed=hmc_seed,
        **hmc_kwargs)
    # accept_prob = exp(min(0, log_accept_ratio)); non-finite values (e.g.
    # NaN) are mapped to 0.
    accept_prob = tf.exp(
        tf.minimum(
            tf.zeros_like(hmc_extra.hmc_extra.log_accept_ratio),
            hmc_extra.hmc_extra.log_accept_ratio))
    accept_prob = tf.where(
        tf.math.is_finite(accept_prob), accept_prob,
        tf.zeros_like(accept_prob))
    criterion, criterion_extra = criterion_fn(
        sga_hmc_state.hmc_state.state,
        hmc_extra.proposed_state,
        accept_prob,
        # + step_size because we're effectively doing floor(traj / step_size)
        # when computing the number of leapfrog steps.
        trajectory_length + scalar_step_size,
    )
    # Adam minimizes, so negate the criterion we want to maximize.
    return -criterion, (hmc_state, hmc_extra, criterion, criterion_extra,
                        params)

  # Adapt trajectory.
  (trajectory_length_params_opt_state,
   trajectory_length_params_opt_extra) = fun_mc.adam_step(
       sga_hmc_state.trajectory_length_params_opt_state,
       loss_fn,
       learning_rate=trajectory_length_adaptation_rate,
       **adam_kwargs,
   )
  (hmc_state, hmc_extra, criterion, criterion_extra,
   trajectory_length_params) = trajectory_length_params_opt_extra.loss_extra

  # Keep the optimizer update only while adapting, then constrain the params.
  trajectory_length_params_opt_state = fun_mc.choose(
      adapt, trajectory_length_params_opt_state,
      sga_hmc_state.trajectory_length_params_opt_state)
  constrained_trajectory_length_params = trajectory_length_constrain_fn(
      trajectory_length_params_opt_state.state)
  trajectory_length_params_opt_state = (
      trajectory_length_params_opt_state._replace(
          state=constrained_trajectory_length_params))

  # Update the running mean for trajectory params (only while adapting).
  trajectory_length_params_rmean_state, _ = fun_mc.running_mean_step(
      sga_hmc_state.trajectory_length_params_rmean_state,
      trajectory_length_params_opt_state.state,
      window_size=averaging_window_steps)
  trajectory_length_params_rmean_state = fun_mc.choose(
      adapt, trajectory_length_params_rmean_state,
      sga_hmc_state.trajectory_length_params_rmean_state)

  sga_hmc_state = sga_hmc_state._replace(
      hmc_state=hmc_state,
      step=sga_hmc_state.step + 1,
      trajectory_length_params_rmean_state=trajectory_length_params_rmean_state,
      trajectory_length_params_opt_state=trajectory_length_params_opt_state,
  )
  sga_hmc_extra = StochasticGradientAscentHMCExtra(
      hmc_extra=hmc_extra.hmc_extra,
      num_integrator_steps=hmc_extra.num_integrator_steps,
      trajectory_length_params_opt_extra=trajectory_length_params_opt_extra,
      criterion=criterion,
      criterion_extra=criterion_extra,
      trajectory_length_params=trajectory_length_params,
  )
  return sga_hmc_state, sga_hmc_extra