Example #1
0
 def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng_key):
     """Run one NUTS transition starting from ``vv_state``.

     Builds a trajectory tree with ``build_tree`` using the enclosing-scope
     ``vv_update``, ``kinetic_fn``, ``max_delta_energy`` and ``max_treedepth``,
     then packages the tree's proposal as a new ``IntegratorState``.

     Returns:
         Tuple ``(new_state, proposal_energy, num_steps, accept_prob,
         diverging)``, where ``num_steps`` is the tree's number of proposals
         and ``accept_prob`` is its summed acceptance probability divided by
         that count.
     """
     tree = build_tree(vv_update, kinetic_fn, vv_state, inverse_mass_matrix,
                       step_size, rng_key,
                       max_delta_energy=max_delta_energy,
                       max_tree_depth=max_treedepth)
     num_steps = tree.num_proposals
     # Average the accumulated acceptance probability over all proposals.
     accept_prob = tree.sum_accept_probs / num_steps
     # NOTE(review): the momentum is copied from the incoming state rather than
     # taken from the tree; presumably the caller resamples it before the next
     # transition -- confirm against the caller.
     new_state = IntegratorState(z=tree.z_proposal,
                                 r=vv_state.r,
                                 potential_energy=tree.z_proposal_pe,
                                 z_grad=tree.z_proposal_grad)
     return (new_state, tree.z_proposal_energy, num_steps, accept_prob,
             tree.diverging)
Example #2
0
File: hmc.py Project: gully/numpyro
    def _nuts_next(step_size, inverse_mass_matrix, vv_state,
                   model_args, model_kwargs, rng_key):
        if potential_fn_gen:
            nonlocal vv_update
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
            _, vv_update = velocity_verlet(pe_fn, kinetic_fn)

        binary_tree = build_tree(vv_update, kinetic_fn, vv_state,
                                 inverse_mass_matrix, step_size, rng_key,
                                 max_delta_energy=max_delta_energy,
                                 max_tree_depth=max_treedepth)
        accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals
        num_steps = binary_tree.num_proposals
        vv_state = IntegratorState(z=binary_tree.z_proposal,
                                   r=vv_state.r,
                                   potential_energy=binary_tree.z_proposal_pe,
                                   z_grad=binary_tree.z_proposal_grad)
        return vv_state, binary_tree.z_proposal_energy, num_steps, accept_prob, binary_tree.diverging
Example #3
0
 def fn(vv_state):
     """Build and return a trajectory tree rooted at ``vv_state``.

     The integrator step (``vv_update``), ``kinetic_fn``, mass matrix, step
     size and RNG key are all taken from the enclosing scope; ``build_tree``
     is called with its default depth/energy limits.
     """
     return build_tree(vv_update, kinetic_fn, vv_state, inverse_mass_matrix,
                       step_size, rng_key)