def sample_kernel(hmc_state, model_args=(), model_kwargs=None):
    """
    Advance the sampler by one transition.

    Starting from an existing :data:`~numpyro.infer.mcmc.HMCState`, draw a
    fresh momentum, run one transition with the current (possibly adapted)
    step size and inverse mass matrix, and package the outcome as a new
    :data:`~numpyro.infer.mcmc.HMCState`.

    :param hmc_state: Current sample (and associated state).
    :param tuple model_args: Model arguments if `potential_fn_gen` is specified.
    :param dict model_kwargs: Model keyword arguments if `potential_fn_gen`
        is specified.
    :return: new proposed :data:`~numpyro.infer.mcmc.HMCState` from simulating
        Hamiltonian dynamics given existing state.
    """
    if model_kwargs is None:
        model_kwargs = {}

    prev_adapt = hmc_state.adapt_state
    step_index = hmc_state.i

    # Three-way split: the first key is carried forward in the returned state,
    # the other two drive the momentum draw and the transition itself.
    next_key, momentum_key, transition_key = random.split(hmc_state.rng_key, 3)

    momentum = momentum_generator(
        hmc_state.z, prev_adapt.mass_matrix_sqrt, momentum_key)
    start_state = IntegratorState(
        hmc_state.z, momentum, hmc_state.potential_energy, hmc_state.z_grad)
    proposal, energy, num_steps, accept_prob, diverging = _next(
        prev_adapt.step_size,
        prev_adapt.inverse_mass_matrix,
        start_state,
        model_args,
        model_kwargs,
        transition_key,
    )

    in_warmup = step_index < wa_steps
    # The adaptation state is only updated while still in warmup; afterwards
    # the previous adapt state is passed through unchanged.
    new_adapt = cond(
        in_warmup,
        (step_index, accept_prob, proposal, prev_adapt),
        lambda args: wa_update(*args),
        prev_adapt,
        identity,
    )

    next_index = step_index + 1
    # Running-mean denominator: counts from the start during warmup and
    # restarts at 1 on the first post-warmup step.
    count = jnp.where(in_warmup, next_index, next_index - wa_steps)
    prev_mean = hmc_state.mean_accept_prob
    running_mean = prev_mean + (accept_prob - prev_mean) / count

    return HMCState(
        next_index,
        proposal.z,
        proposal.z_grad,
        proposal.potential_energy,
        energy,
        num_steps,
        accept_prob,
        running_mean,
        diverging,
        new_adapt,
        next_key,
    )
def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng_key):
    """
    Run one NUTS transition by building a binary tree from ``vv_state``.

    :param step_size: Integrator step size handed to ``build_tree``.
    :param inverse_mass_matrix: Inverse mass matrix handed to ``build_tree``.
    :param vv_state: Integrator state the tree is grown from.
    :param rng_key: Random key handed to ``build_tree``.
    :return: tuple ``(vv_state, energy, num_steps, accept_prob, diverging)``
        where ``vv_state`` holds the tree's proposal, ``energy`` is the
        proposal's energy, ``num_steps`` is the tree's number of proposals,
        ``accept_prob`` is the summed acceptance probability divided by the
        number of proposals, and ``diverging`` is the tree's divergence flag.
    """
    tree = build_tree(
        vv_update,
        kinetic_fn,
        vv_state,
        inverse_mass_matrix,
        step_size,
        rng_key,
        max_delta_energy=max_delta_energy,
        max_tree_depth=max_treedepth,
    )
    proposal_count = tree.num_proposals
    mean_accept = tree.sum_accept_probs / proposal_count
    # Position, potential energy and gradient come from the tree's proposal;
    # the momentum field is copied from the incoming state.
    proposed_state = IntegratorState(
        z=tree.z_proposal,
        r=vv_state.r,
        potential_energy=tree.z_proposal_pe,
        z_grad=tree.z_proposal_grad,
    )
    return (
        proposed_state,
        tree.z_proposal_energy,
        proposal_count,
        mean_accept,
        tree.diverging,
    )
def _nuts_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key): if potential_fn_gen: nonlocal vv_update pe_fn = potential_fn_gen(*model_args, **model_kwargs) _, vv_update = velocity_verlet(pe_fn, kinetic_fn) binary_tree = build_tree(vv_update, kinetic_fn, vv_state, inverse_mass_matrix, step_size, rng_key, max_delta_energy=max_delta_energy, max_tree_depth=max_treedepth) accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals num_steps = binary_tree.num_proposals vv_state = IntegratorState(z=binary_tree.z_proposal, r=vv_state.r, potential_energy=binary_tree.z_proposal_pe, z_grad=binary_tree.z_proposal_grad) return vv_state, binary_tree.z_proposal_energy, num_steps, accept_prob, binary_tree.diverging
def init_kernel(init_params, num_warmup, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=2 * math.pi, max_tree_depth=10, find_heuristic_step_size=False, model_args=(), model_kwargs=None, rng_key=random.PRNGKey(0)): """ Initializes the HMC sampler. :param init_params: Initial parameters to begin sampling. The type must be consistent with the input type to `potential_fn`. :param int num_warmup: Number of warmup steps; samples generated during warmup are discarded. :param float step_size: Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1. :param numpy.ndarray inverse_mass_matrix: Initial value for inverse mass matrix. This may be adapted during warmup if adapt_mass_matrix = True. If no value is specified, then it is initialized to the identity matrix. :param bool adapt_step_size: A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme. :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme. :param bool dense_mass: A flag to decide if mass matrix is dense or diagonal (default when ``dense_mass=False``) :param float target_accept_prob: Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Default to 0.8. :param float trajectory_length: Length of a MCMC trajectory for HMC. Default value is :math:`2\\pi`. :param int max_tree_depth: Max depth of the binary tree created during the doubling scheme of NUTS sampler. Defaults to 10. :param bool find_heuristic_step_size: whether to a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False. 
:param tuple model_args: Model arguments if `potential_fn_gen` is specified. :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified. :param jax.random.PRNGKey rng_key: random key to be used as the source of randomness. """ step_size = lax.convert_element_type(step_size, canonicalize_dtype(jnp.float64)) nonlocal wa_update, trajectory_len, max_treedepth, vv_update, wa_steps wa_steps = num_warmup trajectory_len = trajectory_length max_treedepth = max_tree_depth if isinstance(init_params, ParamInfo): z, pe, z_grad = init_params else: z, pe, z_grad = init_params, None, None pe_fn = potential_fn if potential_fn_gen: if pe_fn is not None: raise ValueError( 'Only one of `potential_fn` or `potential_fn_gen` must be provided.' ) else: kwargs = {} if model_kwargs is None else model_kwargs pe_fn = potential_fn_gen(*model_args, **kwargs) find_reasonable_ss = None if find_heuristic_step_size: find_reasonable_ss = partial(find_reasonable_step_size, pe_fn, kinetic_fn, momentum_generator) wa_init, wa_update = warmup_adapter( num_warmup, adapt_step_size=adapt_step_size, adapt_mass_matrix=adapt_mass_matrix, dense_mass=dense_mass, target_accept_prob=target_accept_prob, find_reasonable_step_size=find_reasonable_ss) rng_key_hmc, rng_key_wa, rng_key_momentum = random.split(rng_key, 3) z_info = IntegratorState(z=z, potential_energy=pe, z_grad=z_grad) wa_state = wa_init(z_info, rng_key_wa, step_size, inverse_mass_matrix=inverse_mass_matrix, mass_matrix_size=jnp.size(ravel_pytree(z)[0])) r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum) vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn) vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad) energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r) hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, 0, 0., 0., False, wa_state, rng_key_hmc) return device_put(hmc_state)
def init_kernel( init_params, num_warmup, *, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=2 * math.pi, max_tree_depth=10, find_heuristic_step_size=False, forward_mode_differentiation=False, regularize_mass_matrix=True, model_args=(), model_kwargs=None, rng_key=random.PRNGKey(0), ): """ Initializes the HMC sampler. :param init_params: Initial parameters to begin sampling. The type must be consistent with the input type to `potential_fn`. :param int num_warmup: Number of warmup steps; samples generated during warmup are discarded. :param float step_size: Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1. :param inverse_mass_matrix: Initial value for inverse mass matrix. This may be adapted during warmup if adapt_mass_matrix = True. If no value is specified, then it is initialized to the identity matrix. For a potential_fn with general JAX pytree parameters, the order of entries of the mass matrix is the order of the flattened version of pytree parameters obtained with `jax.tree_flatten`, which is a bit ambiguous (see more at https://jax.readthedocs.io/en/latest/pytrees.html). If `model` is not None, here we can specify a structured block mass matrix as a dictionary, where keys are tuple of site names and values are the corresponding block of the mass matrix. For more information about structured mass matrix, see `dense_mass` argument. :type inverse_mass_matrix: numpy.ndarray or dict :param bool adapt_step_size: A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme. :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme. :param dense_mass: This flag controls whether mass matrix is dense (i.e. 
full-rank) or diagonal (defaults to ``dense_mass=False``). To specify a structured mass matrix, users can provide a list of tuples of site names. Each tuple represents a block in the joint mass matrix. For example, assuming that the model has latent variables "x", "y", "z" (where each variable can be multi-dimensional), possible specifications and corresponding mass matrix structures are as follows: + dense_mass=[("x", "y")]: use a dense mass matrix for the joint (x, y) and a diagonal mass matrix for z + dense_mass=[] (equivalent to dense_mass=False): use a diagonal mass matrix for the joint (x, y, z) + dense_mass=[("x", "y", "z")] (equivalent to full_mass=True): use a dense mass matrix for the joint (x, y, z) + dense_mass=[("x",), ("y",), ("z")]: use dense mass matrices for each of x, y, and z (i.e. block-diagonal with 3 blocks) :type dense_mass: bool or list :param float target_accept_prob: Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Defaults to 0.8. :param float trajectory_length: Length of a MCMC trajectory for HMC. Default value is :math:`2\\pi`. :param int max_tree_depth: Max depth of the binary tree created during the doubling scheme of NUTS sampler. Defaults to 10. This argument also accepts a tuple of integers `(d1, d2)`, where `d1` is the max tree depth during warmup phase and `d2` is the max tree depth during post warmup phase. :param bool find_heuristic_step_size: whether to a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False. :param bool regularize_mass_matrix: whether or not to regularize the estimated mass matrix for numerical stability during warmup phase. Defaults to True. This flag does not take effect if ``adapt_mass_matrix == False``. :param tuple model_args: Model arguments if `potential_fn_gen` is specified. 
:param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified. :param jax.random.PRNGKey rng_key: random key to be used as the source of randomness. """ step_size = lax.convert_element_type(step_size, jnp.result_type(float)) if trajectory_length is not None: trajectory_length = lax.convert_element_type( trajectory_length, jnp.result_type(float)) nonlocal wa_update, max_treedepth, vv_update, wa_steps, forward_mode_ad forward_mode_ad = forward_mode_differentiation wa_steps = num_warmup max_treedepth = (max_tree_depth if isinstance(max_tree_depth, tuple) else (max_tree_depth, max_tree_depth)) if isinstance(init_params, ParamInfo): z, pe, z_grad = init_params else: z, pe, z_grad = init_params, None, None pe_fn = potential_fn if potential_fn_gen: if pe_fn is not None: raise ValueError( "Only one of `potential_fn` or `potential_fn_gen` must be provided." ) else: kwargs = {} if model_kwargs is None else model_kwargs pe_fn = potential_fn_gen(*model_args, **kwargs) find_reasonable_ss = None if find_heuristic_step_size: find_reasonable_ss = partial(find_reasonable_step_size, pe_fn, kinetic_fn, momentum_generator) wa_init, wa_update = warmup_adapter( num_warmup, adapt_step_size=adapt_step_size, adapt_mass_matrix=adapt_mass_matrix, dense_mass=dense_mass, target_accept_prob=target_accept_prob, find_reasonable_step_size=find_reasonable_ss, regularize_mass_matrix=regularize_mass_matrix, ) rng_key_hmc, rng_key_wa, rng_key_momentum = random.split(rng_key, 3) z_info = IntegratorState(z=z, potential_energy=pe, z_grad=z_grad) wa_state = wa_init(z_info, rng_key_wa, step_size, inverse_mass_matrix=inverse_mass_matrix) r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum) vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad) vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad) energy = vv_state.potential_energy + kinetic_fn( wa_state.inverse_mass_matrix, vv_state.r) zero_int = jnp.array(0, dtype=jnp.result_type(int)) 
hmc_state = HMCState( zero_int, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, None, trajectory_length, zero_int, jnp.zeros(()), jnp.zeros(()), jnp.array(False), wa_state, rng_key_hmc, ) return device_put(hmc_state)