Ejemplo n.º 1
0
Archivo: hmc.py Proyecto: gully/numpyro
    def sample_kernel(hmc_state, model_args=(), model_kwargs=None):
        """
        Advance the chain by one transition: starting from an existing
        :data:`~numpyro.infer.mcmc.HMCState`, draw a fresh momentum, simulate the
        dynamics with the current (possibly adapted) step size and pack the
        outcome into a new :data:`~numpyro.infer.mcmc.HMCState`.

        :param hmc_state: Current sample (and associated state).
        :param tuple model_args: Model arguments if `potential_fn_gen` is specified.
        :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified.
        :return: the next :data:`~numpyro.infer.mcmc.HMCState`, obtained by simulating
            Hamiltonian dynamics from the given state.

        """
        if model_kwargs is None:
            model_kwargs = {}

        adaptation = hmc_state.adapt_state
        # Three sub-keys: one carried over to the next state, one for the
        # momentum draw and one for the transition itself.
        next_key, momentum_key, transition_key = random.split(hmc_state.rng_key, 3)

        momentum = momentum_generator(hmc_state.z, adaptation.mass_matrix_sqrt, momentum_key)
        start = IntegratorState(hmc_state.z, momentum, hmc_state.potential_energy, hmc_state.z_grad)
        proposal, energy, num_steps, accept_prob, diverging = _next(
            adaptation.step_size,
            adaptation.inverse_mass_matrix,
            start,
            model_args,
            model_kwargs,
            transition_key,
        )

        in_warmup = hmc_state.i < wa_steps

        def _adapt(operands):
            # `wa_update` is looked up at call time, exactly like the closure it replaces.
            return wa_update(*operands)

        # The adaptation state is only updated while warming up; afterwards it
        # is passed through unchanged.
        new_adaptation = cond(
            in_warmup,
            (hmc_state.i, accept_prob, proposal, adaptation),
            _adapt,
            adaptation,
            identity,
        )

        step = hmc_state.i + 1
        # Denominator of the running mean: the iteration count during warmup,
        # and the count since the end of warmup afterwards (so the mean restarts).
        count = jnp.where(in_warmup, step, step - wa_steps)
        running_mean = hmc_state.mean_accept_prob
        running_mean = running_mean + (accept_prob - running_mean) / count

        return HMCState(
            step,
            proposal.z,
            proposal.z_grad,
            proposal.potential_energy,
            energy,
            num_steps,
            accept_prob,
            running_mean,
            diverging,
            new_adaptation,
            next_key,
        )
Ejemplo n.º 2
0
 def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng_key):
     """
     Build a tree from ``vv_state`` with ``build_tree`` and return its proposal.

     :param step_size: step size handed to ``build_tree``.
     :param inverse_mass_matrix: inverse mass matrix handed to ``build_tree``.
     :param vv_state: integrator state the tree is grown from.
     :param rng_key: random key handed to ``build_tree``.
     :return: tuple ``(proposed integrator state, proposal energy, number of
         proposals, average acceptance probability, diverging flag)``.
     """
     tree = build_tree(
         vv_update, kinetic_fn, vv_state,
         inverse_mass_matrix, step_size, rng_key,
         max_delta_energy=max_delta_energy,
         max_tree_depth=max_treedepth,
     )
     proposal_count = tree.num_proposals
     # Average acceptance probability over every proposal made in the tree.
     mean_accept = tree.sum_accept_probs / proposal_count
     # Position-related fields come from the tree's proposal; the momentum
     # is carried over from the incoming state.
     proposed = IntegratorState(
         z=tree.z_proposal,
         r=vv_state.r,
         potential_energy=tree.z_proposal_pe,
         z_grad=tree.z_proposal_grad,
     )
     return proposed, tree.z_proposal_energy, proposal_count, mean_accept, tree.diverging
Ejemplo n.º 3
0
Archivo: hmc.py Proyecto: gully/numpyro
    def _nuts_next(step_size, inverse_mass_matrix, vv_state,
                   model_args, model_kwargs, rng_key):
        if potential_fn_gen:
            nonlocal vv_update
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
            _, vv_update = velocity_verlet(pe_fn, kinetic_fn)

        binary_tree = build_tree(vv_update, kinetic_fn, vv_state,
                                 inverse_mass_matrix, step_size, rng_key,
                                 max_delta_energy=max_delta_energy,
                                 max_tree_depth=max_treedepth)
        accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals
        num_steps = binary_tree.num_proposals
        vv_state = IntegratorState(z=binary_tree.z_proposal,
                                   r=vv_state.r,
                                   potential_energy=binary_tree.z_proposal_pe,
                                   z_grad=binary_tree.z_proposal_grad)
        return vv_state, binary_tree.z_proposal_energy, num_steps, accept_prob, binary_tree.diverging
Ejemplo n.º 4
0
    def init_kernel(init_params,
                    num_warmup,
                    step_size=1.0,
                    inverse_mass_matrix=None,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    dense_mass=False,
                    target_accept_prob=0.8,
                    trajectory_length=2 * math.pi,
                    max_tree_depth=10,
                    find_heuristic_step_size=False,
                    model_args=(),
                    model_kwargs=None,
                    rng_key=random.PRNGKey(0)):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param numpy.ndarray inverse_mass_matrix: Initial value for inverse mass matrix.
            This may be adapted during warmup if adapt_mass_matrix = True.
            If no value is specified, then it is initialized to the identity matrix.
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param bool dense_mass: A flag to decide if mass matrix is dense or
            diagonal (default when ``dense_mass=False``)
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Default to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10.
        :param bool find_heuristic_step_size: whether to a heuristic function to adjust the
            step size at the beginning of each adaptation window. Defaults to False.
        :param tuple model_args: Model arguments if `potential_fn_gen` is specified.
        :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified.
        :param jax.random.PRNGKey rng_key: random key to be used as the source of
            randomness.

        """
        step_size = lax.convert_element_type(step_size,
                                             canonicalize_dtype(jnp.float64))
        nonlocal wa_update, trajectory_len, max_treedepth, vv_update, wa_steps
        wa_steps = num_warmup
        trajectory_len = trajectory_length
        max_treedepth = max_tree_depth
        if isinstance(init_params, ParamInfo):
            z, pe, z_grad = init_params
        else:
            z, pe, z_grad = init_params, None, None
        pe_fn = potential_fn
        if potential_fn_gen:
            if pe_fn is not None:
                raise ValueError(
                    'Only one of `potential_fn` or `potential_fn_gen` must be provided.'
                )
            else:
                kwargs = {} if model_kwargs is None else model_kwargs
                pe_fn = potential_fn_gen(*model_args, **kwargs)

        find_reasonable_ss = None
        if find_heuristic_step_size:
            find_reasonable_ss = partial(find_reasonable_step_size, pe_fn,
                                         kinetic_fn, momentum_generator)

        wa_init, wa_update = warmup_adapter(
            num_warmup,
            adapt_step_size=adapt_step_size,
            adapt_mass_matrix=adapt_mass_matrix,
            dense_mass=dense_mass,
            target_accept_prob=target_accept_prob,
            find_reasonable_step_size=find_reasonable_ss)

        rng_key_hmc, rng_key_wa, rng_key_momentum = random.split(rng_key, 3)
        z_info = IntegratorState(z=z, potential_energy=pe, z_grad=z_grad)
        wa_state = wa_init(z_info,
                           rng_key_wa,
                           step_size,
                           inverse_mass_matrix=inverse_mass_matrix,
                           mass_matrix_size=jnp.size(ravel_pytree(z)[0]))
        r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
        vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn)
        vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
        energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
        hmc_state = HMCState(0, vv_state.z, vv_state.z_grad,
                             vv_state.potential_energy, energy, 0, 0., 0.,
                             False, wa_state, rng_key_hmc)
        return device_put(hmc_state)
Ejemplo n.º 5
0
    def init_kernel(
            init_params,
            num_warmup,
            *,
            step_size=1.0,
            inverse_mass_matrix=None,
            adapt_step_size=True,
            adapt_mass_matrix=True,
            dense_mass=False,
            target_accept_prob=0.8,
            trajectory_length=2 * math.pi,
            max_tree_depth=10,
            find_heuristic_step_size=False,
            forward_mode_differentiation=False,
            regularize_mass_matrix=True,
            model_args=(),
            model_kwargs=None,
            rng_key=random.PRNGKey(0),
    ):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param inverse_mass_matrix: Initial value for inverse mass matrix.
            This may be adapted during warmup if adapt_mass_matrix = True.
            If no value is specified, then it is initialized to the identity matrix.
            For a potential_fn with general JAX pytree parameters, the order of entries
            of the mass matrix is the order of the flattened version of pytree parameters
            obtained with `jax.tree_flatten`, which is a bit ambiguous (see more at
            https://jax.readthedocs.io/en/latest/pytrees.html). If `model` is not None,
            here we can specify a structured block mass matrix as a dictionary, where
            keys are tuple of site names and values are the corresponding block of the
            mass matrix.
            For more information about structured mass matrix, see `dense_mass` argument.
        :type inverse_mass_matrix: numpy.ndarray or dict
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param dense_mass:  This flag controls whether mass matrix is dense (i.e. full-rank) or
            diagonal (defaults to ``dense_mass=False``). To specify a structured mass matrix,
            users can provide a list of tuples of site names. Each tuple represents
            a block in the joint mass matrix. For example, assuming that the model
            has latent variables "x", "y", "z" (where each variable can be multi-dimensional),
            possible specifications and corresponding mass matrix structures are as follows:

                + dense_mass=[("x", "y")]: use a dense mass matrix for the joint
                  (x, y) and a diagonal mass matrix for z
                + dense_mass=[] (equivalent to dense_mass=False): use a diagonal mass
                  matrix for the joint (x, y, z)
                + dense_mass=[("x", "y", "z")] (equivalent to full_mass=True):
                  use a dense mass matrix for the joint (x, y, z)
                + dense_mass=[("x",), ("y",), ("z")]: use dense mass matrices for
                  each of x, y, and z (i.e. block-diagonal with 3 blocks)

        :type dense_mass: bool or list
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Defaults to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10. This argument also accepts a tuple of
            integers `(d1, d2)`, where `d1` is the max tree depth during warmup phase and
            `d2` is the max tree depth during post warmup phase.
        :param bool find_heuristic_step_size: whether to a heuristic function to adjust the
            step size at the beginning of each adaptation window. Defaults to False.
        :param bool regularize_mass_matrix: whether or not to regularize the estimated mass
            matrix for numerical stability during warmup phase. Defaults to True. This flag
            does not take effect if ``adapt_mass_matrix == False``.
        :param tuple model_args: Model arguments if `potential_fn_gen` is specified.
        :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified.
        :param jax.random.PRNGKey rng_key: random key to be used as the source of
            randomness.

        """
        step_size = lax.convert_element_type(step_size, jnp.result_type(float))
        if trajectory_length is not None:
            trajectory_length = lax.convert_element_type(
                trajectory_length, jnp.result_type(float))
        nonlocal wa_update, max_treedepth, vv_update, wa_steps, forward_mode_ad
        forward_mode_ad = forward_mode_differentiation
        wa_steps = num_warmup
        max_treedepth = (max_tree_depth if isinstance(max_tree_depth, tuple)
                         else (max_tree_depth, max_tree_depth))
        if isinstance(init_params, ParamInfo):
            z, pe, z_grad = init_params
        else:
            z, pe, z_grad = init_params, None, None
        pe_fn = potential_fn
        if potential_fn_gen:
            if pe_fn is not None:
                raise ValueError(
                    "Only one of `potential_fn` or `potential_fn_gen` must be provided."
                )
            else:
                kwargs = {} if model_kwargs is None else model_kwargs
                pe_fn = potential_fn_gen(*model_args, **kwargs)

        find_reasonable_ss = None
        if find_heuristic_step_size:
            find_reasonable_ss = partial(find_reasonable_step_size, pe_fn,
                                         kinetic_fn, momentum_generator)

        wa_init, wa_update = warmup_adapter(
            num_warmup,
            adapt_step_size=adapt_step_size,
            adapt_mass_matrix=adapt_mass_matrix,
            dense_mass=dense_mass,
            target_accept_prob=target_accept_prob,
            find_reasonable_step_size=find_reasonable_ss,
            regularize_mass_matrix=regularize_mass_matrix,
        )

        rng_key_hmc, rng_key_wa, rng_key_momentum = random.split(rng_key, 3)
        z_info = IntegratorState(z=z, potential_energy=pe, z_grad=z_grad)
        wa_state = wa_init(z_info,
                           rng_key_wa,
                           step_size,
                           inverse_mass_matrix=inverse_mass_matrix)
        r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
        vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn,
                                             forward_mode_ad)
        vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
        energy = vv_state.potential_energy + kinetic_fn(
            wa_state.inverse_mass_matrix, vv_state.r)
        zero_int = jnp.array(0, dtype=jnp.result_type(int))
        hmc_state = HMCState(
            zero_int,
            vv_state.z,
            vv_state.z_grad,
            vv_state.potential_energy,
            energy,
            None,
            trajectory_length,
            zero_int,
            jnp.zeros(()),
            jnp.zeros(()),
            jnp.array(False),
            wa_state,
            rng_key_hmc,
        )
        return device_put(hmc_state)