def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
    """
    Create the initial ``MixedHMCState`` for this kernel.

    The parent class's ``init`` is called first; its result is then augmented
    with an explicitly drawn momentum ``r`` and a zero scalar as last field.

    :param rng_key: random key; split into one key forwarded to the parent
        ``init`` and one key used to draw the initial momentum.
    :param num_warmup: number of warmup steps; stored on the kernel and handed
        to ``warmup_adapter``.
    :param init_params: forwarded unchanged to the parent ``init``.
    :param model_args: forwarded unchanged to the parent ``init``.
    :param model_kwargs: forwarded unchanged to the parent ``init``.
    :return: a ``MixedHMCState`` whose inner HMC state carries the momentum.
    """
    rng_key, rng_r = random.split(rng_key)
    state = super().init(rng_key, num_warmup, init_params, model_args, model_kwargs)

    # Flatten the support sizes of the Gibbs sites into a single vector.
    gibbs_support_sizes = {k: self._support_sizes[k] for k in self._gibbs_sites}
    self._support_sizes_flat = ravel_pytree(gibbs_support_sizes)[0]
    if self._num_discrete_updates is None:
        # Default: one discrete update per entry of the flattened vector.
        self._num_discrete_updates = self._support_sizes_flat.shape[0]
    self._num_warmup = num_warmup

    # NB: warmup adaptation cannot be performed inside sub-trajectories (i.e.
    # the HMC trajectory between two discrete updates), so it is carried out
    # at the end of each MixedHMC step with the update function built here.
    inner = self.inner_kernel
    self._wa_update = warmup_adapter(
        num_warmup,
        adapt_step_size=inner._adapt_step_size,
        adapt_mass_matrix=inner._adapt_mass_matrix,
        dense_mass=inner._dense_mass,
        target_accept_prob=inner._target_accept_prob,
        find_reasonable_step_size=None,
    )[1]

    # In HMC, when `hmc_state.r` is not None, drawing a random momentum at the
    # beginning of an HMC step is skipped. `r` has to be maintained between
    # sub-trajectories, hence it is drawn once here.
    hmc_state = state.hmc_state
    r = momentum_generator(
        hmc_state.z, hmc_state.adapt_state.mass_matrix_sqrt, rng_r
    )
    return MixedHMCState(
        state.z, hmc_state._replace(r=r), state.rng_key, jnp.zeros(())
    )
def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
    """
    Create the initial ``BarkerMHState`` for this kernel.

    :param rng_key: a single random key of shape ``(2,)``; batched keys are
        rejected by an assertion.
    :param num_warmup: number of warmup steps; stored on the kernel and handed
        to ``warmup_adapter``.
    :param init_params: initial parameters, passed through ``self._init_state``.
    :param model_args: forwarded to ``self._init_state``.
    :param model_kwargs: forwarded to ``self._init_state``.
    :return: the initial state, placed on device with ``jax.device_put``.
    :raises ValueError: if ``self._potential_fn`` is truthy and no initial
        parameters are available after ``self._init_state``.
    """
    self._num_warmup = num_warmup
    # TODO (low-priority): support chain_method="vectorized", i.e. rng_key is a batch of keys
    assert rng_key.shape == (2,), ("BarkerMH only supports chain_method='parallel' or chain_method='sequential'."
                                   " Please put in a feature request if you think it would be useful to be able "
                                   "to use BarkerMH in vectorized mode.")

    keys = random.split(rng_key, 3)
    rng_key, rng_key_init_model, rng_key_wa = keys
    init_params = self._init_state(
        rng_key_init_model, model_args, model_kwargs, init_params
    )
    if self._potential_fn and init_params is None:
        raise ValueError('Valid value of `init_params` must be provided with'
                         ' `potential_fn`.')

    # Potential energy and its gradient at the starting point.
    potential_and_grad = jax.value_and_grad(self._potential_fn)
    pe, grad = potential_and_grad(init_params)

    wa_init, self._wa_update = warmup_adapter(
        num_warmup,
        adapt_step_size=self._adapt_step_size,
        adapt_mass_matrix=self._adapt_mass_matrix,
        dense_mass=self._dense_mass,
        target_accept_prob=self._target_accept_prob,
    )
    flat_params, _ = ravel_pytree(init_params)
    size = len(flat_params)
    # The adaptation state's own rng_key is cleared right after construction.
    wa_state = wa_init(
        None, rng_key_wa, self._step_size, mass_matrix_size=size
    )._replace(rng_key=None)

    init_state = BarkerMHState(
        jnp.array(0),
        init_params,
        pe,
        grad,
        jnp.array(0.),
        jnp.array(0.),
        wa_state,
        rng_key,
    )
    return jax.device_put(init_state)
def test_warmup_adapter(jitted):
    """
    Drive ``warmup_adapter`` through the first three windows of its schedule.

    After each window the step size, the inverse mass matrix and the window
    index stored in the adaptation state are checked.

    :param jitted: when truthy, the update function is wrapped with ``jit``.
    """

    def find_reasonable_step_size(step_size, m_inv, z, rng_key):
        # Deterministic stand-in: only `step_size` influences the result.
        return jnp.where(step_size < 1, step_size * 4, step_size / 4)

    def summarize(state):
        # The adaptation state unpacks into seven fields; three are inspected.
        step, m_inv, _, _, _, idx, _ = state
        return step, m_inv, idx

    num_steps = 150
    adaptation_schedule = build_adaptation_schedule(num_steps)
    init_step_size = 1.
    mass_matrix_size = 3

    wa_init, wa_update = warmup_adapter(num_steps, find_reasonable_step_size)
    if jitted:
        wa_update = jit(wa_update)

    rng_key = random.PRNGKey(0)
    z = jnp.ones(3)
    wa_state = wa_init(
        (z, None, None, None),
        rng_key,
        init_step_size,
        mass_matrix_size=mass_matrix_size,
    )
    step_size, inverse_mass_matrix, window_idx = summarize(wa_state)
    expected_step_size = find_reasonable_step_size(
        init_step_size, inverse_mass_matrix, z, rng_key
    )
    assert step_size == expected_step_size
    assert_allclose(inverse_mass_matrix, jnp.ones(mass_matrix_size))
    assert window_idx == 0

    # ---- first window -------------------------------------------------
    window = adaptation_schedule[0]
    for t in range(window.start, window.end + 1):
        accept_prob = 0.7 + 0.1 * t / (window.end - window.start)
        wa_state = wa_update(t, accept_prob, z, wa_state)
    last_step_size = step_size
    step_size, inverse_mass_matrix, window_idx = summarize(wa_state)
    assert window_idx == 1
    # step_size is decreased because accept_prob < target_accept_prob
    assert step_size < last_step_size
    # inverse_mass_matrix does not change at the end of the first window
    assert_allclose(inverse_mass_matrix, jnp.ones(mass_matrix_size))

    # ---- second window ------------------------------------------------
    window = adaptation_schedule[1]
    window_len = window.end - window.start
    for t in range(window.start, window.end + 1):
        accept_prob = 0.8 + 0.1 * (t - window.start) / window_len
        wa_state = wa_update(t, accept_prob, 2 * z, wa_state)
    last_step_size = step_size
    step_size, inverse_mass_matrix, window_idx = summarize(wa_state)
    assert window_idx == 2
    # step_size is increased because accept_prob > target_accept_prob
    assert step_size > last_step_size
    # inverse_mass_matrix must change at the end of the second window.
    # z_flat is constant during this window, so the covariance is 0 and only
    # the regularization term of the Welford scheme contributes. This also
    # shows that z_flat values seen in the first window do not leak into the
    # second one.
    welford_regularize_term = 1e-3 * (5 / (window.end + 1 - window.start + 5))
    expected_inverse_mass_matrix = jnp.full(
        (mass_matrix_size, ), welford_regularize_term
    )
    assert_allclose(inverse_mass_matrix, expected_inverse_mass_matrix, atol=1e-7)

    # ---- third window -------------------------------------------------
    window = adaptation_schedule[2]
    for t in range(window.start, window.end + 1):
        wa_state = wa_update(t, 0.8, t * z, wa_state)
    last_step_size = step_size
    step_size, final_inverse_mass_matrix, window_idx = summarize(wa_state)
    assert window_idx == 3
    # during the last window, because target_accept_prob=0.8,
    # log_step_size will be equal to the constant prox_center=log(10*last_step_size)
    assert_allclose(step_size, last_step_size * 10, atol=1e-6)
    # inverse_mass_matrix stays fixed during the last window even though
    # z_flat varies with t
    assert_allclose(final_inverse_mass_matrix, inverse_mass_matrix)
def init_kernel(init_params, num_warmup, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=2 * math.pi, max_tree_depth=10, find_heuristic_step_size=False, model_args=(), model_kwargs=None, rng_key=random.PRNGKey(0)): """ Initializes the HMC sampler. :param init_params: Initial parameters to begin sampling. The type must be consistent with the input type to `potential_fn`. :param int num_warmup: Number of warmup steps; samples generated during warmup are discarded. :param float step_size: Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1. :param numpy.ndarray inverse_mass_matrix: Initial value for inverse mass matrix. This may be adapted during warmup if adapt_mass_matrix = True. If no value is specified, then it is initialized to the identity matrix. :param bool adapt_step_size: A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme. :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme. :param bool dense_mass: A flag to decide if mass matrix is dense or diagonal (default when ``dense_mass=False``) :param float target_accept_prob: Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Default to 0.8. :param float trajectory_length: Length of a MCMC trajectory for HMC. Default value is :math:`2\\pi`. :param int max_tree_depth: Max depth of the binary tree created during the doubling scheme of NUTS sampler. Defaults to 10. :param bool find_heuristic_step_size: whether to a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False. 
:param tuple model_args: Model arguments if `potential_fn_gen` is specified. :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified. :param jax.random.PRNGKey rng_key: random key to be used as the source of randomness. """ step_size = lax.convert_element_type(step_size, canonicalize_dtype(jnp.float64)) nonlocal wa_update, trajectory_len, max_treedepth, vv_update, wa_steps wa_steps = num_warmup trajectory_len = trajectory_length max_treedepth = max_tree_depth if isinstance(init_params, ParamInfo): z, pe, z_grad = init_params else: z, pe, z_grad = init_params, None, None pe_fn = potential_fn if potential_fn_gen: if pe_fn is not None: raise ValueError( 'Only one of `potential_fn` or `potential_fn_gen` must be provided.' ) else: kwargs = {} if model_kwargs is None else model_kwargs pe_fn = potential_fn_gen(*model_args, **kwargs) find_reasonable_ss = None if find_heuristic_step_size: find_reasonable_ss = partial(find_reasonable_step_size, pe_fn, kinetic_fn, momentum_generator) wa_init, wa_update = warmup_adapter( num_warmup, adapt_step_size=adapt_step_size, adapt_mass_matrix=adapt_mass_matrix, dense_mass=dense_mass, target_accept_prob=target_accept_prob, find_reasonable_step_size=find_reasonable_ss) rng_key_hmc, rng_key_wa, rng_key_momentum = random.split(rng_key, 3) z_info = IntegratorState(z=z, potential_energy=pe, z_grad=z_grad) wa_state = wa_init(z_info, rng_key_wa, step_size, inverse_mass_matrix=inverse_mass_matrix, mass_matrix_size=jnp.size(ravel_pytree(z)[0])) r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum) vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn) vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad) energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r) hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, 0, 0., 0., False, wa_state, rng_key_hmc) return device_put(hmc_state)
def init_kernel( init_params, num_warmup, *, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=2 * math.pi, max_tree_depth=10, find_heuristic_step_size=False, forward_mode_differentiation=False, regularize_mass_matrix=True, model_args=(), model_kwargs=None, rng_key=random.PRNGKey(0), ): """ Initializes the HMC sampler. :param init_params: Initial parameters to begin sampling. The type must be consistent with the input type to `potential_fn`. :param int num_warmup: Number of warmup steps; samples generated during warmup are discarded. :param float step_size: Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1. :param inverse_mass_matrix: Initial value for inverse mass matrix. This may be adapted during warmup if adapt_mass_matrix = True. If no value is specified, then it is initialized to the identity matrix. For a potential_fn with general JAX pytree parameters, the order of entries of the mass matrix is the order of the flattened version of pytree parameters obtained with `jax.tree_flatten`, which is a bit ambiguous (see more at https://jax.readthedocs.io/en/latest/pytrees.html). If `model` is not None, here we can specify a structured block mass matrix as a dictionary, where keys are tuple of site names and values are the corresponding block of the mass matrix. For more information about structured mass matrix, see `dense_mass` argument. :type inverse_mass_matrix: numpy.ndarray or dict :param bool adapt_step_size: A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme. :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme. :param dense_mass: This flag controls whether mass matrix is dense (i.e. 
full-rank) or diagonal (defaults to ``dense_mass=False``). To specify a structured mass matrix, users can provide a list of tuples of site names. Each tuple represents a block in the joint mass matrix. For example, assuming that the model has latent variables "x", "y", "z" (where each variable can be multi-dimensional), possible specifications and corresponding mass matrix structures are as follows: + dense_mass=[("x", "y")]: use a dense mass matrix for the joint (x, y) and a diagonal mass matrix for z + dense_mass=[] (equivalent to dense_mass=False): use a diagonal mass matrix for the joint (x, y, z) + dense_mass=[("x", "y", "z")] (equivalent to full_mass=True): use a dense mass matrix for the joint (x, y, z) + dense_mass=[("x",), ("y",), ("z")]: use dense mass matrices for each of x, y, and z (i.e. block-diagonal with 3 blocks) :type dense_mass: bool or list :param float target_accept_prob: Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Defaults to 0.8. :param float trajectory_length: Length of a MCMC trajectory for HMC. Default value is :math:`2\\pi`. :param int max_tree_depth: Max depth of the binary tree created during the doubling scheme of NUTS sampler. Defaults to 10. This argument also accepts a tuple of integers `(d1, d2)`, where `d1` is the max tree depth during warmup phase and `d2` is the max tree depth during post warmup phase. :param bool find_heuristic_step_size: whether to a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False. :param bool regularize_mass_matrix: whether or not to regularize the estimated mass matrix for numerical stability during warmup phase. Defaults to True. This flag does not take effect if ``adapt_mass_matrix == False``. :param tuple model_args: Model arguments if `potential_fn_gen` is specified. 
:param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified. :param jax.random.PRNGKey rng_key: random key to be used as the source of randomness. """ step_size = lax.convert_element_type(step_size, jnp.result_type(float)) if trajectory_length is not None: trajectory_length = lax.convert_element_type( trajectory_length, jnp.result_type(float)) nonlocal wa_update, max_treedepth, vv_update, wa_steps, forward_mode_ad forward_mode_ad = forward_mode_differentiation wa_steps = num_warmup max_treedepth = (max_tree_depth if isinstance(max_tree_depth, tuple) else (max_tree_depth, max_tree_depth)) if isinstance(init_params, ParamInfo): z, pe, z_grad = init_params else: z, pe, z_grad = init_params, None, None pe_fn = potential_fn if potential_fn_gen: if pe_fn is not None: raise ValueError( "Only one of `potential_fn` or `potential_fn_gen` must be provided." ) else: kwargs = {} if model_kwargs is None else model_kwargs pe_fn = potential_fn_gen(*model_args, **kwargs) find_reasonable_ss = None if find_heuristic_step_size: find_reasonable_ss = partial(find_reasonable_step_size, pe_fn, kinetic_fn, momentum_generator) wa_init, wa_update = warmup_adapter( num_warmup, adapt_step_size=adapt_step_size, adapt_mass_matrix=adapt_mass_matrix, dense_mass=dense_mass, target_accept_prob=target_accept_prob, find_reasonable_step_size=find_reasonable_ss, regularize_mass_matrix=regularize_mass_matrix, ) rng_key_hmc, rng_key_wa, rng_key_momentum = random.split(rng_key, 3) z_info = IntegratorState(z=z, potential_energy=pe, z_grad=z_grad) wa_state = wa_init(z_info, rng_key_wa, step_size, inverse_mass_matrix=inverse_mass_matrix) r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum) vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad) vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad) energy = vv_state.potential_energy + kinetic_fn( wa_state.inverse_mass_matrix, vv_state.r) zero_int = jnp.array(0, dtype=jnp.result_type(int)) 
hmc_state = HMCState( zero_int, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, None, trajectory_length, zero_int, jnp.zeros(()), jnp.zeros(()), jnp.array(False), wa_state, rng_key_hmc, ) return device_put(hmc_state)
def init_kernel(init_params, num_warmup, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=2 * math.pi, max_tree_depth=10, run_warmup=True, progbar=True, rng_key=PRNGKey(0)): """ Initializes the HMC sampler. :param init_params: Initial parameters to begin sampling. The type must be consistent with the input type to `potential_fn`. :param int num_warmup: Number of warmup steps; samples generated during warmup are discarded. :param float step_size: Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1. :param bool adapt_step_size: A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme. :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme. :param bool dense_mass: A flag to decide if mass matrix is dense or diagonal (default when ``dense_mass=False``) :param float target_accept_prob: Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Default to 0.8. :param float trajectory_length: Length of a MCMC trajectory for HMC. Default value is :math:`2\\pi`. :param int max_tree_depth: Max depth of the binary tree created during the doubling scheme of NUTS sampler. Defaults to 10. :param bool run_warmup: Flag to decide whether warmup is run. If ``True``, `init_kernel` returns an initial :data:`~numpyro.infer.mcmc.HMCState` that can be used to generate samples using MCMC. Else, returns the arguments and callable that does the initial adaptation. :param bool progbar: Whether to enable progress bar updates. Defaults to ``True``. :param jax.random.PRNGKey rng_key: random key to be used as the source of randomness. 
""" step_size = lax.convert_element_type( step_size, xla_bridge.canonicalize_dtype(np.float64)) nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps wa_steps = num_warmup trajectory_len = trajectory_length max_treedepth = max_tree_depth z = init_params z_flat, unravel_fn = ravel_pytree(z) momentum_generator = partial(_sample_momentum, unravel_fn) find_reasonable_ss = partial(find_reasonable_step_size, potential_fn, kinetic_fn, momentum_generator) wa_init, wa_update = warmup_adapter( num_warmup, adapt_step_size=adapt_step_size, adapt_mass_matrix=adapt_mass_matrix, dense_mass=dense_mass, target_accept_prob=target_accept_prob, find_reasonable_step_size=find_reasonable_ss) rng_key_hmc, rng_key_wa = random.split(rng_key) wa_state = wa_init(z, rng_key_wa, step_size, mass_matrix_size=np.size(z_flat)) r = momentum_generator(wa_state.mass_matrix_sqrt, rng_key) vv_state = vv_init(z, r) energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r) hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, 0, 0., 0., False, wa_state, rng_key_hmc) # TODO: Remove; this should be the responsibility of the MCMC class. if run_warmup and num_warmup > 0: # JIT if progress bar updates not required if not progbar: hmc_state = fori_loop(0, num_warmup, lambda *args: sample_kernel(args[1]), hmc_state) else: with tqdm.trange(num_warmup, desc='warmup') as t: for i in t: hmc_state = jit(sample_kernel)(hmc_state) t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=False) return hmc_state
def init_kernel(init_params, num_warmup, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=2 * math.pi, max_tree_depth=10, rng_key=PRNGKey(0)): """ Initializes the HMC sampler. :param init_params: Initial parameters to begin sampling. The type must be consistent with the input type to `potential_fn`. :param int num_warmup: Number of warmup steps; samples generated during warmup are discarded. :param float step_size: Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1. :param numpy.ndarray inverse_mass_matrix: Initial value for inverse mass matrix. This may be adapted during warmup if adapt_mass_matrix = True. If no value is specified, then it is initialized to the identity matrix. :param bool adapt_step_size: A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme. :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme. :param bool dense_mass: A flag to decide if mass matrix is dense or diagonal (default when ``dense_mass=False``) :param float target_accept_prob: Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Default to 0.8. :param float trajectory_length: Length of a MCMC trajectory for HMC. Default value is :math:`2\\pi`. :param int max_tree_depth: Max depth of the binary tree created during the doubling scheme of NUTS sampler. Defaults to 10. :param jax.random.PRNGKey rng_key: random key to be used as the source of randomness. 
""" step_size = lax.convert_element_type( step_size, xla_bridge.canonicalize_dtype(np.float64)) nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps wa_steps = num_warmup trajectory_len = trajectory_length max_treedepth = max_tree_depth z = init_params z_flat, unravel_fn = ravel_pytree(z) momentum_generator = partial(_sample_momentum, unravel_fn) find_reasonable_ss = partial(find_reasonable_step_size, potential_fn, kinetic_fn, momentum_generator) wa_init, wa_update = warmup_adapter( num_warmup, adapt_step_size=adapt_step_size, adapt_mass_matrix=adapt_mass_matrix, dense_mass=dense_mass, target_accept_prob=target_accept_prob, find_reasonable_step_size=find_reasonable_ss) rng_key_hmc, rng_key_wa = random.split(rng_key) wa_state = wa_init(z, rng_key_wa, step_size, inverse_mass_matrix=inverse_mass_matrix, mass_matrix_size=np.size(z_flat)) r = momentum_generator(wa_state.mass_matrix_sqrt, rng_key) vv_state = vv_init(z, r) energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r) hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, 0, 0., 0., False, wa_state, rng_key_hmc) return hmc_state