def test_build_tree(step_size):
    """Check invariants of a single NUTS tree built on a 1-D quadratic potential.

    The potential energy is ``0.5 * q**2`` and the kinetic energy is
    ``0.5 * sum(m_inv * p**2)``; the tree is grown from ``q=0.0, p=1.0`` with a
    unit inverse mass matrix and a fixed PRNG key.

    :param step_size: integrator step size under test.
    """
    def kinetic_fn(m_inv, p):
        return 0.5 * np.sum(m_inv * p**2)

    def potential_fn(q):
        return 0.5 * q**2

    vv_init, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    start = vv_init(0.0, 1.0)
    inv_mass = np.array([1.])
    key = random.PRNGKey(0)

    @jit
    def grow(state):
        return build_tree(vv_update, kinetic_fn, state, inv_mass, step_size, key)

    tree = grow(start)

    # Size bound implied by the tree depth, and each acceptance probability
    # contributes at most 1 to the running sum.
    assert tree.num_proposals >= 2**(tree.depth - 1)
    assert tree.sum_accept_probs <= tree.num_proposals

    # Stopping below depth 10 (presumably build_tree's default max depth --
    # confirm) must be explained by a U-turn or a divergence.
    if tree.depth < 10:
        assert tree.turning | tree.diverging

    # For a large step size, the trajectory should diverge after one step.
    if step_size > 10:
        assert tree.diverging
        assert tree.num_proposals == 1

    # For a small step size, reaching the termination condition takes many steps.
    if step_size < 0.1:
        assert tree.num_proposals > 10
def get_final_state(model, step_size, num_steps, q_i, p_i):
    """Run ``num_steps`` velocity-Verlet steps of ``model`` starting at ``(q_i, p_i)``.

    The inverse mass matrix is read from ``args.m_inv``; ``args`` is a name from
    the enclosing scope (not visible in this block).

    :return: tuple ``(q_f, p_f)`` holding the final position and momentum.
    """
    integrator_init, integrator_step = velocity_verlet(model.potential_fn, model.kinetic_fn)

    def advance(_, state):
        return integrator_step(step_size, args.m_inv, state)

    final_state = fori_loop(0, num_steps, advance, integrator_init(q_i, p_i))
    q_f, p_f, _, _ = final_state
    return (q_f, p_f)
def _hmc_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key): if potential_fn_gen: nonlocal vv_update, forward_mode_ad pe_fn = potential_fn_gen(*model_args, **model_kwargs) _, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad) num_steps = _get_num_steps(step_size, trajectory_len) vv_state_new = fori_loop( 0, num_steps, lambda i, val: vv_update(step_size, inverse_mass_matrix, val), vv_state) energy_old = vv_state.potential_energy + kinetic_fn( inverse_mass_matrix, vv_state.r) energy_new = vv_state_new.potential_energy + kinetic_fn( inverse_mass_matrix, vv_state_new.r) delta_energy = energy_new - energy_old delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy) accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0) diverging = delta_energy > max_delta_energy transition = random.bernoulli(rng_key, accept_prob) vv_state, energy = cond(transition, (vv_state_new, energy_new), identity, (vv_state, energy_old), identity) return vv_state, energy, num_steps, accept_prob, diverging
def _hmc_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key, trajectory_length): if potential_fn_gen: nonlocal vv_update, forward_mode_ad pe_fn = potential_fn_gen(*model_args, **model_kwargs) _, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad) # no need to spend too many steps if the state z has 0 size (i.e. z is empty) if len(inverse_mass_matrix) == 0: num_steps = 1 else: num_steps = _get_num_steps(step_size, trajectory_length) # makes sure trajectory length is constant, rather than step_size * num_steps step_size = trajectory_length / num_steps vv_state_new = fori_loop( 0, num_steps, lambda i, val: vv_update(step_size, inverse_mass_matrix, val), vv_state) energy_old = vv_state.potential_energy + kinetic_fn( inverse_mass_matrix, vv_state.r) energy_new = vv_state_new.potential_energy + kinetic_fn( inverse_mass_matrix, vv_state_new.r) delta_energy = energy_new - energy_old delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy) accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0) diverging = delta_energy > max_delta_energy transition = random.bernoulli(rng_key, accept_prob) vv_state, energy = cond(transition, (vv_state_new, energy_new), identity, (vv_state, energy_old), identity) return vv_state, energy, num_steps, accept_prob, diverging
def _nuts_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key): if potential_fn_gen: nonlocal vv_update pe_fn = potential_fn_gen(*model_args, **model_kwargs) _, vv_update = velocity_verlet(pe_fn, kinetic_fn) binary_tree = build_tree(vv_update, kinetic_fn, vv_state, inverse_mass_matrix, step_size, rng_key, max_delta_energy=max_delta_energy, max_tree_depth=max_treedepth) accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals num_steps = binary_tree.num_proposals vv_state = IntegratorState(z=binary_tree.z_proposal, r=vv_state.r, potential_energy=binary_tree.z_proposal_pe, z_grad=binary_tree.z_proposal_grad) return vv_state, binary_tree.z_proposal_energy, num_steps, accept_prob, binary_tree.diverging
def init_kernel(init_params, num_warmup, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=2 * math.pi, max_tree_depth=10, find_heuristic_step_size=False, model_args=(), model_kwargs=None, rng_key=random.PRNGKey(0)): """ Initializes the HMC sampler. :param init_params: Initial parameters to begin sampling. The type must be consistent with the input type to `potential_fn`. :param int num_warmup: Number of warmup steps; samples generated during warmup are discarded. :param float step_size: Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1. :param numpy.ndarray inverse_mass_matrix: Initial value for inverse mass matrix. This may be adapted during warmup if adapt_mass_matrix = True. If no value is specified, then it is initialized to the identity matrix. :param bool adapt_step_size: A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme. :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme. :param bool dense_mass: A flag to decide if mass matrix is dense or diagonal (default when ``dense_mass=False``) :param float target_accept_prob: Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Default to 0.8. :param float trajectory_length: Length of a MCMC trajectory for HMC. Default value is :math:`2\\pi`. :param int max_tree_depth: Max depth of the binary tree created during the doubling scheme of NUTS sampler. Defaults to 10. :param bool find_heuristic_step_size: whether to a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False. 
:param tuple model_args: Model arguments if `potential_fn_gen` is specified. :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified. :param jax.random.PRNGKey rng_key: random key to be used as the source of randomness. """ step_size = lax.convert_element_type(step_size, canonicalize_dtype(jnp.float64)) nonlocal wa_update, trajectory_len, max_treedepth, vv_update, wa_steps wa_steps = num_warmup trajectory_len = trajectory_length max_treedepth = max_tree_depth if isinstance(init_params, ParamInfo): z, pe, z_grad = init_params else: z, pe, z_grad = init_params, None, None pe_fn = potential_fn if potential_fn_gen: if pe_fn is not None: raise ValueError( 'Only one of `potential_fn` or `potential_fn_gen` must be provided.' ) else: kwargs = {} if model_kwargs is None else model_kwargs pe_fn = potential_fn_gen(*model_args, **kwargs) find_reasonable_ss = None if find_heuristic_step_size: find_reasonable_ss = partial(find_reasonable_step_size, pe_fn, kinetic_fn, momentum_generator) wa_init, wa_update = warmup_adapter( num_warmup, adapt_step_size=adapt_step_size, adapt_mass_matrix=adapt_mass_matrix, dense_mass=dense_mass, target_accept_prob=target_accept_prob, find_reasonable_step_size=find_reasonable_ss) rng_key_hmc, rng_key_wa, rng_key_momentum = random.split(rng_key, 3) z_info = IntegratorState(z=z, potential_energy=pe, z_grad=z_grad) wa_state = wa_init(z_info, rng_key_wa, step_size, inverse_mass_matrix=inverse_mass_matrix, mass_matrix_size=jnp.size(ravel_pytree(z)[0])) r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum) vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn) vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad) energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r) hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, 0, 0., 0., False, wa_state, rng_key_hmc) return device_put(hmc_state)
def init_kernel( init_params, num_warmup, *, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=2 * math.pi, max_tree_depth=10, find_heuristic_step_size=False, forward_mode_differentiation=False, regularize_mass_matrix=True, model_args=(), model_kwargs=None, rng_key=random.PRNGKey(0), ): """ Initializes the HMC sampler. :param init_params: Initial parameters to begin sampling. The type must be consistent with the input type to `potential_fn`. :param int num_warmup: Number of warmup steps; samples generated during warmup are discarded. :param float step_size: Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1. :param inverse_mass_matrix: Initial value for inverse mass matrix. This may be adapted during warmup if adapt_mass_matrix = True. If no value is specified, then it is initialized to the identity matrix. For a potential_fn with general JAX pytree parameters, the order of entries of the mass matrix is the order of the flattened version of pytree parameters obtained with `jax.tree_flatten`, which is a bit ambiguous (see more at https://jax.readthedocs.io/en/latest/pytrees.html). If `model` is not None, here we can specify a structured block mass matrix as a dictionary, where keys are tuple of site names and values are the corresponding block of the mass matrix. For more information about structured mass matrix, see `dense_mass` argument. :type inverse_mass_matrix: numpy.ndarray or dict :param bool adapt_step_size: A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme. :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme. :param dense_mass: This flag controls whether mass matrix is dense (i.e. 
full-rank) or diagonal (defaults to ``dense_mass=False``). To specify a structured mass matrix, users can provide a list of tuples of site names. Each tuple represents a block in the joint mass matrix. For example, assuming that the model has latent variables "x", "y", "z" (where each variable can be multi-dimensional), possible specifications and corresponding mass matrix structures are as follows: + dense_mass=[("x", "y")]: use a dense mass matrix for the joint (x, y) and a diagonal mass matrix for z + dense_mass=[] (equivalent to dense_mass=False): use a diagonal mass matrix for the joint (x, y, z) + dense_mass=[("x", "y", "z")] (equivalent to full_mass=True): use a dense mass matrix for the joint (x, y, z) + dense_mass=[("x",), ("y",), ("z")]: use dense mass matrices for each of x, y, and z (i.e. block-diagonal with 3 blocks) :type dense_mass: bool or list :param float target_accept_prob: Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Defaults to 0.8. :param float trajectory_length: Length of a MCMC trajectory for HMC. Default value is :math:`2\\pi`. :param int max_tree_depth: Max depth of the binary tree created during the doubling scheme of NUTS sampler. Defaults to 10. This argument also accepts a tuple of integers `(d1, d2)`, where `d1` is the max tree depth during warmup phase and `d2` is the max tree depth during post warmup phase. :param bool find_heuristic_step_size: whether to a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False. :param bool regularize_mass_matrix: whether or not to regularize the estimated mass matrix for numerical stability during warmup phase. Defaults to True. This flag does not take effect if ``adapt_mass_matrix == False``. :param tuple model_args: Model arguments if `potential_fn_gen` is specified. 
:param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified. :param jax.random.PRNGKey rng_key: random key to be used as the source of randomness. """ step_size = lax.convert_element_type(step_size, jnp.result_type(float)) if trajectory_length is not None: trajectory_length = lax.convert_element_type( trajectory_length, jnp.result_type(float)) nonlocal wa_update, max_treedepth, vv_update, wa_steps, forward_mode_ad forward_mode_ad = forward_mode_differentiation wa_steps = num_warmup max_treedepth = (max_tree_depth if isinstance(max_tree_depth, tuple) else (max_tree_depth, max_tree_depth)) if isinstance(init_params, ParamInfo): z, pe, z_grad = init_params else: z, pe, z_grad = init_params, None, None pe_fn = potential_fn if potential_fn_gen: if pe_fn is not None: raise ValueError( "Only one of `potential_fn` or `potential_fn_gen` must be provided." ) else: kwargs = {} if model_kwargs is None else model_kwargs pe_fn = potential_fn_gen(*model_args, **kwargs) find_reasonable_ss = None if find_heuristic_step_size: find_reasonable_ss = partial(find_reasonable_step_size, pe_fn, kinetic_fn, momentum_generator) wa_init, wa_update = warmup_adapter( num_warmup, adapt_step_size=adapt_step_size, adapt_mass_matrix=adapt_mass_matrix, dense_mass=dense_mass, target_accept_prob=target_accept_prob, find_reasonable_step_size=find_reasonable_ss, regularize_mass_matrix=regularize_mass_matrix, ) rng_key_hmc, rng_key_wa, rng_key_momentum = random.split(rng_key, 3) z_info = IntegratorState(z=z, potential_energy=pe, z_grad=z_grad) wa_state = wa_init(z_info, rng_key_wa, step_size, inverse_mass_matrix=inverse_mass_matrix) r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum) vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad) vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad) energy = vv_state.potential_energy + kinetic_fn( wa_state.inverse_mass_matrix, vv_state.r) zero_int = jnp.array(0, dtype=jnp.result_type(int)) 
hmc_state = HMCState( zero_int, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, None, trajectory_length, zero_int, jnp.zeros(()), jnp.zeros(()), jnp.array(False), wa_state, rng_key_hmc, ) return device_put(hmc_state)
def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
    r"""
    Hamiltonian Monte Carlo inference, using either fixed number of steps
    or the No U-Turn Sampler (NUTS) with adaptive path length.

    **References:**

    1. *MCMC Using Hamiltonian Dynamics*, Radford M. Neal
    2. *The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, and Andrew Gelman.
    3. *A Conceptual Introduction to Hamiltonian Monte Carlo*,
       Michael Betancourt

    :param potential_fn: Python callable that computes the potential energy
        given input parameters. The input parameters to `potential_fn` can be
        any python collection type, provided that `init_params` argument to
        `init_kernel` has the same type.
    :param kinetic_fn: Python callable that returns the kinetic energy given
        inverse mass matrix and momentum. If not provided, the default is
        euclidean kinetic energy.
    :param str algo: Whether to run ``HMC`` with fixed number of steps or ``NUTS``
        with adaptive path length. Default is ``NUTS``.
    :return: a tuple of callables (`init_kernel`, `sample_kernel`), the first
        one to initialize the sampler, and the second one to generate samples
        given an existing one.
    :raises ValueError: if `algo` is not one of ``HMC`` or ``NUTS``.

    .. warning::
        Instead of using this interface directly, we would highly recommend you
        to use the higher level :class:`numpyro.infer.MCMC` API instead.

    **Example**

    .. testsetup::

        import jax
        from jax import random
        import jax.numpy as np
        import numpyro
        import numpyro.distributions as dist
        from numpyro.infer.mcmc import hmc
        from numpyro.infer.util import initialize_model
        from numpyro.util import fori_collect

    .. doctest::

        >>> true_coefs = np.array([1., 2., 3.])
        >>> data = random.normal(random.PRNGKey(2), (2000, 3))
        >>> dim = 3
        >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
        >>>
        >>> def model(data, labels):
        ...     coefs_mean = np.zeros(dim)
        ...     coefs = numpyro.sample('beta', dist.Normal(coefs_mean, np.ones(3)))
        ...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
        ...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
        >>>
        >>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0),
        ...                                                            model, data, labels)
        >>> init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
        >>> hmc_state = init_kernel(init_params,
        ...                         trajectory_length=10,
        ...                         num_warmup=300)
        >>> samples = fori_collect(0, 500, sample_kernel, hmc_state,
        ...                        transform=lambda state: constrain_fn(state.z))
        >>> print(np.mean(samples['beta'], axis=0))  # doctest: +SKIP
        [0.9153987 2.0754058 2.9621222]
    """
    if kinetic_fn is None:
        kinetic_fn = euclidean_kinetic_energy
    vv_init, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    # Closure state filled in by init_kernel and read by the transition functions.
    trajectory_len = None
    max_treedepth = None
    momentum_generator = None
    wa_update = None
    wa_steps = None
    max_delta_energy = 1000.
    if algo not in {'HMC', 'NUTS'}:
        raise ValueError('`algo` must be one of `HMC` or `NUTS`.')

    def init_kernel(init_params,
                    num_warmup,
                    step_size=1.0,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    dense_mass=False,
                    target_accept_prob=0.8,
                    trajectory_length=2 * math.pi,
                    max_tree_depth=10,
                    run_warmup=True,
                    progbar=True,
                    rng_key=PRNGKey(0)):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param bool dense_mass: A flag to decide if mass matrix is dense or
            diagonal (default when ``dense_mass=False``)
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Defaults to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10.
        :param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
            `init_kernel` returns an initial :data:`~numpyro.infer.mcmc.HMCState` that
            can be used to generate samples using MCMC. Else, returns the arguments
            and callable that does the initial adaptation.
        :param bool progbar: Whether to enable progress bar updates. Defaults to
            ``True``.
        :param jax.random.PRNGKey rng_key: random key to be used as the source of
            randomness.
        """
        step_size = lax.convert_element_type(
            step_size, xla_bridge.canonicalize_dtype(np.float64))
        nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps
        wa_steps = num_warmup
        trajectory_len = trajectory_length
        max_treedepth = max_tree_depth
        z = init_params
        z_flat, unravel_fn = ravel_pytree(z)
        momentum_generator = partial(_sample_momentum, unravel_fn)

        find_reasonable_ss = partial(find_reasonable_step_size,
                                     potential_fn, kinetic_fn,
                                     momentum_generator)

        wa_init, wa_update = warmup_adapter(num_warmup,
                                            adapt_step_size=adapt_step_size,
                                            adapt_mass_matrix=adapt_mass_matrix,
                                            dense_mass=dense_mass,
                                            target_accept_prob=target_accept_prob,
                                            find_reasonable_step_size=find_reasonable_ss)

        # Give every consumer its own subkey: once `rng_key` has been split, reusing
        # it for momentum sampling would correlate those draws with the subkeys.
        rng_key_hmc, rng_key_wa, rng_key_momentum = random.split(rng_key, 3)
        wa_state = wa_init(z, rng_key_wa, step_size, mass_matrix_size=np.size(z_flat))
        r = momentum_generator(wa_state.mass_matrix_sqrt, rng_key_momentum)
        vv_state = vv_init(z, r)
        # Total energy (potential + kinetic), the same quantity _hmc_next and
        # _nuts_next report for every later state.
        energy = vv_state.potential_energy + kinetic_fn(
            wa_state.inverse_mass_matrix, vv_state.r)
        hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy,
                             0, 0., 0., False, wa_state, rng_key_hmc)

        # TODO: Remove; this should be the responsibility of the MCMC class.
        if run_warmup and num_warmup > 0:
            # JIT if progress bar updates not required
            if not progbar:
                hmc_state = fori_loop(0, num_warmup,
                                      lambda *args: sample_kernel(args[1]),
                                      hmc_state)
            else:
                # Wrap the kernel once rather than re-wrapping it on every iteration.
                jitted_sample_kernel = jit(sample_kernel)
                with tqdm.trange(num_warmup, desc='warmup') as t:
                    for _ in t:
                        hmc_state = jitted_sample_kernel(hmc_state)
                        t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=False)
        return hmc_state

    def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng_key):
        # Fixed-length trajectory followed by a Metropolis accept/reject step.
        num_steps = _get_num_steps(step_size, trajectory_len)
        vv_state_new = fori_loop(0, num_steps,
                                 lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
                                 vv_state)
        energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r)
        energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r)
        delta_energy = energy_new - energy_old
        # NaN energy errors become +inf, giving the proposal zero acceptance probability.
        delta_energy = np.where(np.isnan(delta_energy), np.inf, delta_energy)
        accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)
        diverging = delta_energy > max_delta_energy
        transition = random.bernoulli(rng_key, accept_prob)
        vv_state, energy = cond(transition,
                                (vv_state_new, energy_new), lambda args: args,
                                (vv_state, energy_old), lambda args: args)
        return vv_state, energy, num_steps, accept_prob, diverging

    def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng_key):
        # Grow a NUTS binary tree and report its proposal plus tree statistics.
        binary_tree = build_tree(vv_update, kinetic_fn, vv_state,
                                 inverse_mass_matrix, step_size, rng_key,
                                 max_delta_energy=max_delta_energy,
                                 max_tree_depth=max_treedepth)
        accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals
        num_steps = binary_tree.num_proposals
        vv_state = IntegratorState(z=binary_tree.z_proposal,
                                   r=vv_state.r,
                                   potential_energy=binary_tree.z_proposal_pe,
                                   z_grad=binary_tree.z_proposal_grad)
        return vv_state, binary_tree.z_proposal_energy, num_steps, accept_prob, binary_tree.diverging

    _next = _nuts_next if algo == 'NUTS' else _hmc_next

    def sample_kernel(hmc_state):
        """
        Given an existing :data:`~numpyro.infer.mcmc.HMCState`, run HMC with fixed (possibly adapted)
        step size and return a new :data:`~numpyro.infer.mcmc.HMCState`.

        :param hmc_state: Current sample (and associated state).
        :return: new proposed :data:`~numpyro.infer.mcmc.HMCState` from simulating
            Hamiltonian dynamics given existing state.
        """
        rng_key, rng_key_momentum, rng_key_transition = random.split(hmc_state.rng_key, 3)
        r = momentum_generator(hmc_state.adapt_state.mass_matrix_sqrt, rng_key_momentum)
        vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)
        vv_state, energy, num_steps, accept_prob, diverging = _next(
            hmc_state.adapt_state.step_size,
            hmc_state.adapt_state.inverse_mass_matrix,
            vv_state, rng_key_transition)
        # not update adapt_state after warmup phase
        adapt_state = cond(hmc_state.i < wa_steps,
                           (hmc_state.i, accept_prob, vv_state.z, hmc_state.adapt_state),
                           lambda args: wa_update(*args),
                           hmc_state.adapt_state,
                           lambda x: x)

        itr = hmc_state.i + 1
        # Running mean of the acceptance probability; the count restarts after warmup.
        n = np.where(hmc_state.i < wa_steps, itr, itr - wa_steps)
        mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n

        return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps,
                        accept_prob, mean_accept_prob, diverging, adapt_state, rng_key)

    # Make `init_kernel` and `sample_kernel` visible from the global scope once
    # `hmc` is called for sphinx doc generation.
    if 'SPHINX_BUILD' in os.environ:
        hmc.init_kernel = init_kernel
        hmc.sample_kernel = sample_kernel

    return init_kernel, sample_kernel