def build_kernel(logpdf: Callable, parameters: HMCParameters) -> Callable:
    """Build the kernel that moves the chain from one point to the next.

    Arguments
    ---------
    logpdf
        Used as the potential, both by the integrator and by the kernel.
    parameters
        Object exposing the ``inverse_mass_matrix``, ``num_integration_steps``
        and ``step_size`` attributes.

    Returns
    -------
    Callable
        The transition kernel built by ``hmc_kernel``.

    Raises
    ------
    AttributeError
        If ``parameters`` lacks any of the three required attributes.
    """
    potential = logpdf

    try:
        inverse_mass_matrix = parameters.inverse_mass_matrix
        num_integration_steps = parameters.num_integration_steps
        step_size = parameters.step_size
    except AttributeError as err:
        # Fail fast with a message that lists every required attribute,
        # instead of letting an unbound local surface further down.
        raise AttributeError(
            "The Hamiltonian Monte Carlo algorithm requires the following parameters: inverse mass matrix, number of integration steps and step size."
        ) from err

    momentum_generator, kinetic_energy = gaussian_euclidean_metric(
        inverse_mass_matrix,
    )
    integrator_step = integrator(potential, kinetic_energy)
    proposal = hmc_proposal(integrator_step, step_size, num_integration_steps)
    kernel = hmc_kernel(proposal, momentum_generator, kinetic_energy, potential)

    return kernel
def build_kernel(num_integration_steps, step_size, inverse_mass_matrix):
    """Assemble an HMC transition kernel from explicit tuning values.

    The Gaussian euclidean metric derived from ``inverse_mass_matrix`` supplies
    the momentum generator and the kinetic energy. These feed the integrator,
    the proposal (``step_size``, ``num_integration_steps``) and finally
    ``hmc_kernel``.
    """
    # NOTE(review): `self` and `potential` are not parameters of this function;
    # they are presumably captured from an enclosing method scope — confirm.
    metric = gaussian_euclidean_metric(inverse_mass_matrix)
    momentum_generator, kinetic_energy = metric

    step_fn = self.integrator(potential, kinetic_energy)

    return hmc_kernel(
        hmc_proposal(step_fn, step_size, num_integration_steps),
        momentum_generator,
        kinetic_energy,
        potential,
    )
def kernel_generator(step_size, inv_mass_matrix):
    """Return an HMC kernel that performs a single velocity-Verlet step.

    ``potential_fn`` is not a parameter; it is read from the enclosing scope.
    """
    momentum_generator, kinetic_energy = gaussian_euclidean_metric(inv_mass_matrix)

    single_step_proposal = hmc_proposal(
        velocity_verlet(potential_fn, kinetic_energy),
        step_size,
        1,  # exactly one integration step per proposal
    )
    return hmc_kernel(
        single_step_proposal, momentum_generator, kinetic_energy, potential_fn
    )
def test_find_reasonable_step_size():
    """`find_reasonable_step_size` must move away from its initial guess and
    respond to the requested target acceptance rate."""

    def potential_fn(x):
        return np.sum(0.5 * np.square(x))

    rng_key = jax.random.PRNGKey(0)
    init_position = np.array([3.0])
    inv_mass_matrix = np.array([1.0])

    init_state = hmc_init(init_position, potential_fn)
    momentum_generator, kinetic_energy = gaussian_euclidean_metric(inv_mass_matrix)
    integrator_step = velocity_verlet(potential_fn, kinetic_energy)

    def search(target_acceptance):
        # Every search starts from the same key, state and initial guess (1.0);
        # only the target acceptance rate varies.
        return find_reasonable_step_size(
            rng_key,
            momentum_generator,
            kinetic_energy,
            potential_fn,
            integrator_step,
            init_state,
            1.0,
            target_acceptance,
        )

    # The search must actually change the initial step size.
    high_target_step = search(0.95)
    assert high_target_step != 1.0

    # A different target acceptance rate must yield a different step size.
    low_target_step = search(0.05)
    assert low_target_step != high_target_step
def stan_hmc_warmup(
    rng_key: jax.random.PRNGKey,
    logpdf: Callable,
    initial_state: HMCState,
    euclidean_metric: Callable,
    integrator_step: Callable,
    inital_step_size: float,
    path_length: float,
    num_steps: int,
    is_mass_matrix_diagonal=True,
) -> Tuple[HMCState, DualAveragingState, MassMatrixAdaptationState]:
    """Warmup scheme for sampling procedures based on euclidean manifold HMC.

    The schedule and algorithms used match Stan's [1]_ as closely as possible.

    Unlike several other libraries, we separate the warmup and sampling phases
    explicitly. This ensures better modularity: a change in the warmup does not
    affect the sampling. It also allows users to run their own warmup should
    they want to.

    Stan's warmup consists of the three following phases:

    1. A fast adaptation window where only the step size is adapted using
       Nesterov's dual averaging scheme to match a target acceptance rate.
    2. A succession of slow adaptation windows (where the size of a window is
       double that of the previous window) where both the mass matrix and the
       step size are adapted. The mass matrix is recomputed at the end of each
       window; the step size is re-initialized to a "reasonable" value.
    3. A last fast adaptation window where only the step size is adapted.

    Arguments
    ---------
    rng_key
        PRNG key. It is split before every use so that no key is consumed twice.
    logpdf
        Passed to ``hmc_kernel`` as the potential.
    initial_state
        State the chain starts from; its ``position`` sets the dimension of
        the mass matrix.
    euclidean_metric
        Called with an inverse mass matrix; must return the pair
        ``(momentum_generator, kinetic_energy)``.
    integrator_step
        Integrator used by the proposals and by the step size search.
    inital_step_size
        Starting guess for the first reasonable step size search.
    path_length
        Forwarded to ``hmc_proposal``.
    num_steps
        Passed to ``warmup_schedule`` to build the adaptation windows.
    is_mass_matrix_diagonal
        Passed to ``mass_matrix_adaptation``.

    Returns
    -------
    Tuple
        The current state of the chain, of the dual averaging scheme and mass
        matrix adaptation scheme.
    """
    n_dims = np.shape(initial_state.position)[-1]  # `position` is a 1D array

    def make_kernel(step_size, momentum_generator, kinetic_energy):
        """Build the transition kernel for the current step size and metric."""
        # NOTE(review): elsewhere `hmc_proposal` is called as
        # (integrator_step, step_size, num_integration_steps); confirm this
        # (integrator_step, path_length, step_size) ordering matches the target.
        proposal = hmc_proposal(integrator_step, path_length, step_size)
        return hmc_kernel(proposal, momentum_generator, kinetic_energy, logpdf)

    # Initialize the mass matrix adaptation.
    mm_init, mm_update, mm_final = mass_matrix_adaptation(is_mass_matrix_diagonal)
    mm_state = mm_init(n_dims)

    # Initialize the metric with the injected `euclidean_metric`.
    momentum_generator, kinetic_energy = euclidean_metric(
        mm_state.inverse_mass_matrix
    )

    # Find a first reasonable step size and initialize dual averaging.
    # NOTE(review): the test for `find_reasonable_step_size` also passes the
    # potential and a target acceptance rate; confirm this call's signature.
    rng_key, step_size_key = jax.random.split(rng_key)
    step_size = find_reasonable_step_size(
        step_size_key,
        momentum_generator,
        kinetic_energy,
        integrator_step,
        initial_state,
        inital_step_size,
    )
    da_init, da_update = dual_averaging()
    da_state = da_init(step_size)

    kernel = make_kernel(step_size, momentum_generator, kinetic_energy)

    schedule = warmup_schedule(num_steps)
    last_window = len(schedule) - 1

    state = initial_state
    for i, window in enumerate(schedule):
        # Only the intermediate (slow) windows adapt the mass matrix.
        is_middle_window = 0 < i < last_window

        for _ in range(window):
            # Keep `rng_key` as the parent; consume only the fresh subkey.
            rng_key, sample_key = jax.random.split(rng_key)
            state, info = kernel(sample_key, state)

            da_state = da_update(info.acceptance_probability, da_state)
            step_size = np.exp(da_state.log_step_size)

            if is_middle_window:
                mm_state = mm_update(mm_state, state.position)

            kernel = make_kernel(step_size, momentum_generator, kinetic_energy)

        if is_middle_window:
            # End of a slow window: recompute the metric from the collected
            # samples, restart the mass matrix estimator and re-seed the step
            # size search from the current state.
            inverse_mass_matrix = mm_final(mm_state)
            momentum_generator, kinetic_energy = euclidean_metric(
                inverse_mass_matrix
            )
            # NOTE(review): the adapted `inverse_mass_matrix` is not stored in
            # the returned `mm_state`, which is re-initialized here; confirm
            # callers do not expect the adapted matrix from it.
            mm_state = mm_init(n_dims)

            rng_key, step_size_key = jax.random.split(rng_key)
            step_size = find_reasonable_step_size(
                step_size_key,
                momentum_generator,
                kinetic_energy,
                integrator_step,
                state,
                step_size,
            )
            da_state = da_init(step_size)
            kernel = make_kernel(step_size, momentum_generator, kinetic_energy)

    return state, da_state, mm_state