def kernel(rvs, idx):
  """One scan step: fold `data[idx]` into the running-variance state.

  Returns the updated `(state, next_index)` pair plus a trace of the
  current mean/variance estimates.
  """
  updated_rvs, _ = fun_mcmc.running_variance_step(
      rvs, data[idx], window_size=window_size)
  trace = (updated_rvs.mean, updated_rvs.variance)
  return (updated_rvs, idx + 1), trace
def adaptive_hamiltonian_monte_carlo_step(
    adaptive_hmc_state: 'AdaptiveHamiltonianMonteCarloState',
    target_log_prob_fn: 'fun_mc.PotentialFn',
    num_adaptation_steps: 'Optional[int]',
    variance_window_steps: 'int' = 100,
    trajectory_length_factor: 'fun_mc.FloatTensor' = 1.,
    num_trajectory_ramp_steps: 'Optional[int]' = None,
    trajectory_warmup_power=1.,
    step_size_adaptation_rate: 'fun_mc.FloatTensor' = 1e-2,
    step_size_adaptation_rate_decay_power=0.1,
    target_accept_prob: 'float' = 0.8,
    seed: 'Any' = None,
) -> ('Tuple[AdaptiveHamiltonianMonteCarloState, '
      'AdaptiveHamiltonianMonteCarloExtra]'):
  """Adaptive Hamiltonian Monte Carlo `TransitionOperator`.

  This implements a relatively straightforward adaptive HMC algorithm with
  diagonal mass-matrix adaptation and step size adaptation. The algorithm also
  estimates the trajectory length based on the variance of the chain.

  All adaptation stops after `num_adaptation_steps`, after which this algorithm
  becomes regular HMC with fixed number of leapfrog steps, and fixed
  per-component step size. Typically, chain samples after `num_adaptation_steps
  + num_warmup_steps` are discarded, where `num_warmup_steps` is typically
  heuristically chosen to be `0.25 * num_adaptation_steps`. For maximum
  efficiency, however, it's recommended to actually use
  `fun_mc.hamiltonian_monte_carlo` initialized with the relevant
  hyperparameters, as that `TransitionOperator` won't have the overhead of the
  adaptation logic. `num_adaptation_steps` can be set to `None`, in which case
  adaptation never stops (and the algorithm ceases to be calibrated).

  The mass matrix is adapted by computing an exponential moving variance of the
  chain. The averaging window is controlled by the `variance_window_steps`,
  with larger values leading to a smoother estimate.

  Trajectory length is computed as `max(sqrt(chain_variance)) *
  trajectory_length_factor`. To be resilient to poor initialization,
  `trajectory_length_factor` can be increased from 0 based on a polynomial
  schedule, controlled by `num_trajectory_ramp_steps` and
  `trajectory_warmup_power`.

  The step size is adapted to make the acceptance probability close to
  `target_accept_prob`, using the `Adam` optimizer. This is controlled by the
  `step_size_adaptation_rate`. If `num_adaptation_steps` is not `None`, this
  rate is decayed using a polynomial schedule controlled by
  `step_size_adaptation_rate_decay_power`.

  Args:
    adaptive_hmc_state: `AdaptiveHamiltonianMonteCarloState`
    target_log_prob_fn: Target log prob fn.
    num_adaptation_steps: Number of adaptation steps, can be `None`.
    variance_window_steps: Window to compute the chain variance over.
    trajectory_length_factor: Trajectory length factor.
    num_trajectory_ramp_steps: Number of steps to warmup the
      `trajectory_length_factor`.
    trajectory_warmup_power: Power of the polynomial schedule for
      `trajectory_length_factor` warmup.
    step_size_adaptation_rate: Step size adaptation rate.
    step_size_adaptation_rate_decay_power: Power of the polynomial schedule for
      `step_size_adaptation_rate` decay.
    target_accept_prob: Target acceptance probability.
    seed: Random seed to use.

  Returns:
    adaptive_hmc_state: `AdaptiveHamiltonianMonteCarloState`.
    adaptive_hmc_extra: `AdaptiveHamiltonianMonteCarloExtra`.

  #### Examples

  Here's an example using Adaptive HMC and TensorFlow Probability to sample
  from a simple model.

  ```python
  num_chains = 16
  num_steps = 2000
  num_warmup_steps = num_steps // 2
  num_adapt_steps = int(0.8 * num_warmup_steps)

  # Setup the model and state constraints.
  model = tfp.distributions.JointDistributionSequential([
      tfp.distributions.Normal(loc=0., scale=1.),
      tfp.distributions.Independent(
          tfp.distributions.LogNormal(loc=[1., 1.], scale=0.5), 1),
  ])
  bijector = [tfp.bijectors.Identity(), tfp.bijectors.Exp()]
  transform_fn = fun_mcmc.util_tfp.bijector_to_transform_fn(
      bijector, model.dtype, batch_ndims=1)

  def target_log_prob_fn(*x):
    return model.log_prob(x), ()

  # Start out at zeros (in the unconstrained space).
  state, _ = transform_fn(
      *map(lambda e: tf.zeros([num_chains] + list(e)), model.event_shape))

  reparam_log_prob_fn, reparam_state = fun_mcmc.reparameterize_potential_fn(
      target_log_prob_fn, transform_fn, state)

  # Define the kernel.
  def kernel(adaptive_hmc_state):
    adaptive_hmc_state, adaptive_hmc_extra = (
        fun_mcmc.prefab.adaptive_hamiltonian_monte_carlo_step(
            adaptive_hmc_state,
            target_log_prob_fn=reparam_log_prob_fn,
            num_adaptation_steps=num_adapt_steps))
    return adaptive_hmc_state, (adaptive_hmc_extra.state,
                                adaptive_hmc_extra.is_accepted,
                                adaptive_hmc_extra.step_size)

  _, (state_chain, is_accepted_chain, step_size_chain) = tf.function(
      lambda: fun_mcmc.trace(
          state=fun_mcmc.prefab.adaptive_hamiltonian_monte_carlo_init(
              reparam_state, reparam_log_prob_fn),
          fn=kernel,
          num_steps=num_steps),
      autograph=False)()

  # Discard the warmup samples.
  state_chain = [s[num_warmup_steps:] for s in state_chain]
  is_accepted_chain = is_accepted_chain[num_warmup_steps:]

  # Compute diagnostics.
  accept_rate = tf.reduce_mean(tf.cast(is_accepted_chain, tf.float32))
  ess = tfp.mcmc.effective_sample_size(
      state_chain, filter_beyond_positive_pairs=True, cross_chain_dims=[1, 1])
  rhat = tfp.mcmc.potential_scale_reduction(state_chain)

  # Compute relevant quantities.
  sample_mean = [tf.reduce_mean(s, axis=[0, 1]) for s in state_chain]
  sample_var = [tf.math.reduce_variance(s, axis=[0, 1]) for s in state_chain]

  # It's also important to look at the `step_size_chain` (e.g. via a plot), to
  # verify that adaptation succeeded.
  ```
  """
  # All hyperparameters are converted to the dtype of the (first leaf of the)
  # chain state so that TF ops below don't hit dtype mismatches.
  dtype = util.flatten_tree(adaptive_hmc_state.hmc_state.state)[0].dtype
  step_size_adaptation_rate = tf.convert_to_tensor(step_size_adaptation_rate,
                                                   dtype=dtype)
  trajectory_length_factor = tf.convert_to_tensor(trajectory_length_factor,
                                                  dtype=dtype)
  target_accept_prob = tf.convert_to_tensor(target_accept_prob, dtype=dtype)
  step_size_adaptation_rate_decay_power = tf.convert_to_tensor(
      step_size_adaptation_rate_decay_power, dtype=dtype)
  trajectory_warmup_power = tf.convert_to_tensor(trajectory_warmup_power,
                                                 dtype=dtype)

  # Unpack the composite state into its components.
  hmc_state = adaptive_hmc_state.hmc_state
  running_var_state = adaptive_hmc_state.running_var_state
  log_step_size_opt_state = adaptive_hmc_state.log_step_size_opt_state
  step = adaptive_hmc_state.step

  # Warmup the trajectory length, if requested: ramp the factor from 0 up to
  # `trajectory_length_factor` over `num_trajectory_ramp_steps` steps.
  if num_trajectory_ramp_steps is not None:
    trajectory_length_factor = _polynomial_decay(
        step=step,
        step_size=tf.constant(0., dtype),
        decay_steps=num_trajectory_ramp_steps,
        final_step_size=trajectory_length_factor,
        power=trajectory_warmup_power,
    )

  # Compute the per-component step_size and num_leapfrog_steps from the
  # variance estimate. The scalar step size lives in log-space inside the Adam
  # optimizer state, hence the `tf.exp`.
  scale = util.map_tree(tf.math.sqrt, running_var_state.variance)
  max_scale = functools.reduce(
      tf.maximum, util.flatten_tree(util.map_tree(tf.reduce_max, scale)))
  step_size = tf.exp(log_step_size_opt_state.state)
  # Trajectory length = max(sqrt(variance)) * factor; leapfrog count is that
  # length divided by the step size, rounded up.
  num_leapfrog_steps = tf.cast(
      tf.math.ceil(max_scale * trajectory_length_factor / step_size), tf.int32)
  # We implement mass-matrix adaptation via step size rescaling, as this is a
  # little bit simpler to code up.
  step_size = util.map_tree(lambda scale: scale / max_scale * step_size, scale)

  # Run a step of HMC.
  hmc_state, hmc_extra = fun_mc.hamiltonian_monte_carlo(
      hmc_state,
      target_log_prob_fn=target_log_prob_fn,
      step_size=step_size,
      num_integrator_steps=num_leapfrog_steps,
      seed=seed,
  )

  # Update the running variance estimate. The old state is kept so the update
  # can be undone below once adaptation is over.
  chain_ndims = len(hmc_state.target_log_prob.shape)
  old_running_var_state = running_var_state
  running_var_state, _ = fun_mc.running_variance_step(
      running_var_state,
      hmc_state.state,
      axis=tuple(range(chain_ndims)),
      # Window is scaled by the total number of chains being averaged over.
      window_size=int(np.prod(hmc_state.target_log_prob.shape)) *
      variance_window_steps)

  if num_adaptation_steps is not None:
    # Take care of adaptation for variance and step size: freeze both once
    # `step` reaches `num_adaptation_steps`.
    running_var_state = util.map_tree(
        lambda n, o: tf.where(step < num_adaptation_steps, n, o),  # pylint: disable=g-long-lambda
        running_var_state,
        old_running_var_state)
    # Decay the step size adaptation rate to 0 so the step size converges.
    step_size_adaptation_rate = _polynomial_decay(
        step=step,
        step_size=step_size_adaptation_rate,
        decay_steps=num_adaptation_steps,
        final_step_size=0.,
        power=step_size_adaptation_rate_decay_power,
    )

  # Update the scalar step size as a function of acceptance rate.
  # Non-finite log-accept ratios (e.g. divergences) are treated as 0
  # acceptance probability.
  p_accept = tf.reduce_mean(
      tf.exp(tf.minimum(hmc_extra.log_accept_ratio, 0.)))
  p_accept = tf.where(
      tf.math.is_finite(p_accept), p_accept, tf.zeros_like(p_accept))
  # Surrogate loss whose gradient w.r.t. the log step size is
  # `target_accept_prob - p_accept`, driving acceptance toward the target.
  loss_fn = fun_mc.make_surrogate_loss_fn(lambda _:  # pylint: disable=g-long-lambda
                                          (target_accept_prob - p_accept, ()))
  log_step_size_opt_state, _ = fun_mc.adam_step(log_step_size_opt_state,
                                                loss_fn,
                                                step_size_adaptation_rate)

  # Repackage the updated components into the composite state.
  adaptive_hmc_state = AdaptiveHamiltonianMonteCarloState(
      hmc_state=hmc_state,
      running_var_state=running_var_state,
      log_step_size_opt_state=log_step_size_opt_state,
      step=step + 1,
  )

  extra = AdaptiveHamiltonianMonteCarloExtra(
      hmc_state=hmc_state,
      hmc_extra=hmc_extra,
      step_size=step_size,
      num_leapfrog_steps=num_leapfrog_steps)
  return adaptive_hmc_state, extra
def kernel(rvs, idx):
  """One scan step: fold `data[idx]` into the running-variance state.

  Reduces over the `aggregation` axes; traces nothing.
  """
  next_rvs, _ = fun_mcmc.running_variance_step(
      rvs, data[idx], axis=aggregation)
  return (next_rvs, idx + 1), ()