def test_build_tree(step_size):
    """Check structural invariants of the NUTS binary tree for a 1-d
    standard-normal target at a given ``step_size``."""
    def kinetic_fn(m_inv, p):
        return 0.5 * np.sum(m_inv * p ** 2)

    def potential_fn(q):
        return 0.5 * q ** 2

    vv_init, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    state = vv_init(0.0, 1.0)
    m_inv = np.array([1.])
    rng = random.PRNGKey(0)

    @jit
    def make_tree(s):
        return build_tree(vv_update, kinetic_fn, s, m_inv, step_size, rng)

    tree = make_tree(state)

    # a tree of depth d doubles at most d times, so it holds at least
    # 2**(d-1) proposals
    assert tree.num_proposals >= 2 ** (tree.depth - 1)
    # each per-proposal accept prob is <= 1, so the sum is bounded by the count
    assert tree.sum_accept_probs <= tree.num_proposals
    if tree.depth < 10:
        # tree building only stops early on a U-turn or a divergence
        assert tree.turning | tree.diverging
    # for large step_size, assert that diverging will happen in 1 step
    if step_size > 10:
        assert tree.diverging
        assert tree.num_proposals == 1
    # for small step_size, assert that it should take a while to meet the
    # terminate condition
    if step_size < 0.1:
        assert tree.num_proposals > 10
def get_final_state(model, step_size, num_steps, q_i, p_i):
    """Run ``num_steps`` velocity-verlet steps from ``(q_i, p_i)`` and return
    the final position/momentum pair ``(q_f, p_f)``."""
    vv_init, vv_update = velocity_verlet(model.potential_fn, model.kinetic_fn)
    state = vv_init(q_i, p_i)
    # NOTE(review): `args.m_inv` refers to a name not defined in this function;
    # presumably a module-level namespace holding the inverse mass matrix —
    # confirm against the rest of the file.
    final_state = lax.fori_loop(
        0, num_steps,
        lambda i, val: vv_update(step_size, args.m_inv, val),
        state)
    q_f, p_f = final_state[0], final_state[1]
    return (q_f, p_f)
def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
    r"""
    Hamiltonian Monte Carlo inference, using either fixed number of steps or
    the No U-Turn Sampler (NUTS) with adaptive path length.

    **References:**

    1. *MCMC Using Hamiltonian Dynamics*, Radford M. Neal
    2. *The No-U-turn sampler: adaptively setting path lengths in Hamiltonian
       Monte Carlo*, Matthew D. Hoffman, and Andrew Gelman.
    3. *A Conceptual Introduction to Hamiltonian Monte Carlo*, Michael Betancourt

    :param potential_fn: Python callable that computes the potential energy
        given input parameters. The input parameters to `potential_fn` can be
        any python collection type, provided that `init_params` argument to
        `init_kernel` has the same type.
    :param kinetic_fn: Python callable that returns the kinetic energy given
        inverse mass matrix and momentum. If not provided, the default is
        euclidean kinetic energy.
    :param str algo: Whether to run ``HMC`` with fixed number of steps or
        ``NUTS`` with adaptive path length. Default is ``NUTS``.
    :return: a tuple of callables (`init_kernel`, `sample_kernel`), the first
        one to initialize the sampler, and the second one to generate samples
        given an existing one.

    **Example**

    .. testsetup::

        import jax
        from jax import random
        import jax.numpy as np
        import numpyro.distributions as dist
        from numpyro.handlers import sample
        from numpyro.hmc_util import initialize_model
        from numpyro.mcmc import hmc
        from numpyro.util import fori_collect

    .. doctest::

        >>> true_coefs = np.array([1., 2., 3.])
        >>> data = random.normal(random.PRNGKey(2), (2000, 3))
        >>> dim = 3
        >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
        >>>
        >>> def model(data, labels):
        ...     coefs_mean = np.zeros(dim)
        ...     coefs = sample('beta', dist.Normal(coefs_mean, np.ones(3)))
        ...     intercept = sample('intercept', dist.Normal(0., 10.))
        ...     return sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
        >>>
        >>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0),
        ...                                                            model, data, labels)
        >>> init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
        >>> hmc_state = init_kernel(init_params,
        ...                         trajectory_length=10,
        ...                         num_warmup=300)
        >>> samples = fori_collect(0, 500, sample_kernel, hmc_state,
        ...                        transform=lambda state: constrain_fn(state.z))
        >>> print(np.mean(samples['beta'], axis=0))  # doctest: +SKIP
        [0.9153987 2.0754058 2.9621222]
    """
    if kinetic_fn is None:
        kinetic_fn = euclidean_kinetic_energy
    vv_init, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    # Closed-over configuration: populated by `init_kernel` (via `nonlocal`)
    # and read by `sample_kernel` / `_hmc_next` / `_nuts_next`.
    trajectory_len = None
    max_treedepth = None
    momentum_generator = None
    wa_update = None
    wa_steps = None
    if algo not in {'HMC', 'NUTS'}:
        raise ValueError('`algo` must be one of `HMC` or `NUTS`.')

    def init_kernel(init_params,
                    num_warmup,
                    step_size=1.0,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    dense_mass=False,
                    target_accept_prob=0.8,
                    trajectory_length=2*math.pi,
                    max_tree_depth=10,
                    run_warmup=True,
                    progbar=True,
                    rng=PRNGKey(0)):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by
            the verlet integrator while computing the trajectory using
            Hamiltonian dynamics. If not specified, it will be set to 1.
        :param bool adapt_step_size: A flag to decide if we want to adapt
            step_size during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt
            mass matrix during warm-up phase using Welford scheme.
        :param bool dense_mass: A flag to decide if mass matrix is dense or
            diagonal (default when ``dense_mass=False``)
        :param float target_accept_prob: Target acceptance probability for step
            size adaptation using Dual Averaging. Increasing this value will
            lead to a smaller step size, hence the sampling will be slower but
            more robust. Default to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC.
            Default value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during
            the doubling scheme of NUTS sampler. Defaults to 10.
        :param bool run_warmup: Flag to decide whether warmup is run. If
            ``True``, `init_kernel` returns an initial :data:`HMCState` that
            can be used to generate samples using MCMC. Else, returns the
            arguments and callable that does the initial adaptation.
        :param bool progbar: Whether to enable progress bar updates. Defaults
            to ``True``.
        :param jax.random.PRNGKey rng: random key to be used as the source of
            randomness.
        """
        # NOTE(review): the original docstring also described a
        # ``heuristic_step_size`` flag, but no such parameter appears in this
        # signature — confirm whether it was removed or is handled elsewhere.
        step_size = float(step_size)
        nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps
        wa_steps = num_warmup
        trajectory_len = float(trajectory_length)
        max_treedepth = max_tree_depth
        z = init_params
        # Flatten the (possibly nested) parameter pytree; the momentum sampler
        # works on the flat representation and unravels back to the pytree.
        z_flat, unravel_fn = ravel_pytree(z)
        momentum_generator = partial(_sample_momentum, unravel_fn)
        find_reasonable_ss = partial(find_reasonable_step_size,
                                     potential_fn, kinetic_fn,
                                     momentum_generator)
        wa_init, wa_update = warmup_adapter(num_warmup,
                                            adapt_step_size=adapt_step_size,
                                            adapt_mass_matrix=adapt_mass_matrix,
                                            dense_mass=dense_mass,
                                            target_accept_prob=target_accept_prob,
                                            find_reasonable_step_size=find_reasonable_ss)
        rng_hmc, rng_wa = random.split(rng)
        wa_state = wa_init(z, rng_wa, step_size, mass_matrix_size=np.size(z_flat))
        # NOTE(review): the unsplit `rng` is reused here after `random.split`;
        # presumably intentional for the initial momentum draw — confirm.
        r = momentum_generator(wa_state.mass_matrix_sqrt, rng)
        vv_state = vv_init(z, r)
        hmc_state = HMCState(0, vv_state.z, vv_state.z_grad,
                             vv_state.potential_energy, 0, 0., 0.,
                             wa_state, rng_hmc)

        if run_warmup and num_warmup > 0:
            # JIT if progress bar updates not required
            if not progbar:
                hmc_state = fori_loop(0, num_warmup,
                                      lambda *args: sample_kernel(args[1]),
                                      hmc_state)
            else:
                with tqdm.trange(num_warmup, desc='warmup') as t:
                    for i in t:
                        hmc_state = sample_kernel(hmc_state)
                        t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=False)
        return hmc_state

    def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng):
        # Fixed-length HMC transition: integrate, then Metropolis accept/reject.
        num_steps = _get_num_steps(step_size, trajectory_len)
        vv_state_new = fori_loop(0, num_steps,
                                 lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
                                 vv_state)
        energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r)
        energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r)
        delta_energy = energy_new - energy_old
        # A NaN energy difference (diverged trajectory) is treated as certain
        # rejection.
        delta_energy = np.where(np.isnan(delta_energy), np.inf, delta_energy)
        accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)
        transition = random.bernoulli(rng, accept_prob)
        # Select the new state on acceptance, keep the old one otherwise.
        vv_state = cond(transition,
                        vv_state_new, lambda state: state,
                        vv_state, lambda state: state)
        return vv_state, num_steps, accept_prob

    def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng):
        # NUTS transition: build the doubling binary tree and take its proposal.
        binary_tree = build_tree(vv_update, kinetic_fn, vv_state,
                                 inverse_mass_matrix, step_size, rng,
                                 max_tree_depth=max_treedepth)
        accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals
        num_steps = binary_tree.num_proposals
        vv_state = IntegratorState(z=binary_tree.z_proposal,
                                   r=vv_state.r,
                                   potential_energy=binary_tree.z_proposal_pe,
                                   z_grad=binary_tree.z_proposal_grad)
        return vv_state, num_steps, accept_prob

    _next = _nuts_next if algo == 'NUTS' else _hmc_next

    @jit
    def sample_kernel(hmc_state):
        """
        Given an existing :data:`HMCState`, run HMC with fixed (possibly
        adapted) step size and return a new :data:`HMCState`.

        :param hmc_state: Current sample (and associated state).
        :return: new proposed :data:`HMCState` from simulating Hamiltonian
            dynamics given existing state.
        """
        rng, rng_momentum, rng_transition = random.split(hmc_state.rng, 3)
        r = momentum_generator(hmc_state.adapt_state.mass_matrix_sqrt, rng_momentum)
        vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy,
                                   hmc_state.z_grad)
        vv_state, num_steps, accept_prob = _next(hmc_state.adapt_state.step_size,
                                                 hmc_state.adapt_state.inverse_mass_matrix,
                                                 vv_state, rng_transition)
        # not update adapt_state after warmup phase
        adapt_state = cond(hmc_state.i < wa_steps,
                           (hmc_state.i, accept_prob, vv_state.z, hmc_state.adapt_state),
                           lambda args: wa_update(*args),
                           hmc_state.adapt_state,
                           lambda x: x)
        itr = hmc_state.i + 1
        # Running-mean denominator restarts after warmup so that
        # `mean_accept_prob` reflects only post-warmup samples.
        n = np.where(hmc_state.i < wa_steps, itr, itr - wa_steps)
        mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n
        return HMCState(itr, vv_state.z, vv_state.z_grad,
                        vv_state.potential_energy, num_steps, accept_prob,
                        mean_accept_prob, adapt_state, rng)

    # Make `init_kernel` and `sample_kernel` visible from the global scope once
    # `hmc` is called for sphinx doc generation.
    if 'SPHINX_BUILD' in os.environ:
        hmc.init_kernel = init_kernel
        hmc.sample_kernel = sample_kernel

    return init_kernel, sample_kernel
def hmc_kernel(potential_fn, kinetic_fn=None, algo='NUTS'):
    """
    Build an HMC/NUTS sampler as a pair of callables.

    :param potential_fn: Python callable computing the potential energy of the
        input parameters.
    :param kinetic_fn: Python callable computing the kinetic energy given an
        inverse mass matrix and a momentum; defaults to the euclidean kinetic
        energy (``_euclidean_ke``).
    :param str algo: ``'NUTS'`` (adaptive path length) or anything else for
        fixed-steps HMC.
    :return: tuple ``(init_kernel, sample_kernel)``.
    """
    if kinetic_fn is None:
        kinetic_fn = _euclidean_ke
    vv_init, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    # Closed-over configuration, populated by `init_kernel` via `nonlocal`.
    trajectory_length = None
    momentum_generator = None
    wa_update = None

    def init_kernel(init_samples,
                    num_warmup_steps,
                    step_size=1.0,
                    num_steps=None,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    diag_mass=True,
                    target_accept_prob=0.8,
                    run_warmup=True,
                    rng=PRNGKey(0)):
        # Initialize the sampler and (optionally) run the warmup phase.
        step_size = float(step_size)
        nonlocal trajectory_length, momentum_generator, wa_update
        # A fixed `num_steps` implies trajectory length num_steps * step_size;
        # otherwise use the conventional default of 2*pi.
        if num_steps is None:
            trajectory_length = 2 * math.pi
        else:
            trajectory_length = num_steps * step_size
        z = init_samples
        # Flatten the parameter pytree; momentum is sampled in flat space and
        # unraveled back to the pytree structure.
        z_flat, unravel_fn = ravel_pytree(z)
        momentum_generator = partial(_sample_momentum, unravel_fn)
        find_reasonable_ss = partial(find_reasonable_step_size,
                                     potential_fn, kinetic_fn,
                                     momentum_generator)
        wa_init, wa_update = warmup_adapter(
            num_warmup_steps,
            find_reasonable_step_size=find_reasonable_ss,
            adapt_step_size=adapt_step_size,
            adapt_mass_matrix=adapt_mass_matrix,
            diag_mass=diag_mass,
            target_accept_prob=target_accept_prob)
        rng_hmc, rng_wa = random.split(rng)
        wa_state = wa_init(z, rng_wa, mass_matrix_size=np.size(z_flat))
        # NOTE(review): the unsplit `rng` is reused here after `random.split`;
        # presumably intentional for the initial momentum draw — confirm.
        r = momentum_generator(wa_state.inverse_mass_matrix, rng)
        vv_state = vv_init(z, r)
        hmc_state = HMCState(vv_state.z, vv_state.z_grad,
                             vv_state.potential_energy, 0, 0.,
                             wa_state.step_size, wa_state.inverse_mass_matrix,
                             rng_hmc)
        if run_warmup:
            # Warmup interleaves sampling with adaptation; only the HMC state
            # is returned, the adaptation state is discarded.
            hmc_state, _ = fori_loop(0, num_warmup_steps, warmup_update,
                                     (hmc_state, wa_state))
            return hmc_state
        else:
            # Caller drives warmup manually with the returned `warmup_update`.
            return hmc_state, wa_state, warmup_update

    def warmup_update(t, states):
        # One warmup iteration: sample, adapt, then push the adapted step size
        # and inverse mass matrix back into the HMC state.
        hmc_state, wa_state = states
        hmc_state = sample_kernel(hmc_state)
        wa_state = wa_update(t, hmc_state.accept_prob, hmc_state.z, wa_state)
        hmc_state = hmc_state.update(
            step_size=wa_state.step_size,
            inverse_mass_matrix=wa_state.inverse_mass_matrix)
        return hmc_state, wa_state

    def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng):
        # Fixed-length HMC transition: integrate, then Metropolis accept/reject.
        num_steps = _get_num_steps(step_size, trajectory_length)
        vv_state_new = fori_loop(
            0, num_steps,
            lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
            vv_state)
        energy_old = vv_state.potential_energy + kinetic_fn(
            inverse_mass_matrix, vv_state.r)
        energy_new = vv_state_new.potential_energy + kinetic_fn(
            inverse_mass_matrix, vv_state_new.r)
        delta_energy = energy_new - energy_old
        # A NaN energy difference (diverged trajectory) means certain rejection.
        delta_energy = np.where(np.isnan(delta_energy), np.inf, delta_energy)
        accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)
        transition = random.bernoulli(rng, accept_prob)
        # Select the new state on acceptance, keep the old one otherwise.
        vv_state = cond(transition,
                        vv_state_new, lambda state: state,
                        vv_state, lambda state: state)
        return vv_state, num_steps, accept_prob

    def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng):
        # NUTS transition: build the doubling binary tree and take its proposal.
        binary_tree = build_tree(vv_update, kinetic_fn, vv_state,
                                 inverse_mass_matrix, step_size, rng)
        accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals
        num_steps = binary_tree.num_proposals
        vv_state = vv_state.update(z=binary_tree.z_proposal,
                                   potential_energy=binary_tree.z_proposal_pe,
                                   z_grad=binary_tree.z_proposal_grad)
        return vv_state, num_steps, accept_prob

    _next = _nuts_next if algo == 'NUTS' else _hmc_next

    def sample_kernel(hmc_state):
        # Generate the next sample: draw a fresh momentum, run one HMC/NUTS
        # transition, and thread the PRNG key forward.
        rng, rng_momentum, rng_transition = random.split(hmc_state.rng, 3)
        r = momentum_generator(hmc_state.inverse_mass_matrix, rng_momentum)
        vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy,
                                   hmc_state.z_grad)
        vv_state, num_steps, accept_prob = _next(hmc_state.step_size,
                                                 hmc_state.inverse_mass_matrix,
                                                 vv_state, rng_transition)
        return HMCState(vv_state.z, vv_state.z_grad,
                        vv_state.potential_energy, num_steps, accept_prob,
                        hmc_state.step_size, hmc_state.inverse_mass_matrix,
                        rng)

    return init_kernel, sample_kernel