def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng_key):
    """Run one NUTS transition from ``vv_state`` and package the result.

    Grows a trajectory tree with ``build_tree`` using the integrator update
    ``vv_update`` and ``kinetic_fn`` taken from the enclosing scope, capped by
    the enclosing ``max_delta_energy`` and ``max_treedepth`` settings.

    Returns a 5-tuple:
      * a new ``IntegratorState`` holding the tree's proposed position,
        potential energy and gradient, with the momentum copied unchanged
        from the incoming ``vv_state``;
      * the proposal's energy (``z_proposal_energy``);
      * the number of proposals in the tree, reported as the step count;
      * the mean acceptance probability, ``sum_accept_probs / num_proposals``;
      * the tree's divergence flag.
    """
    tree = build_tree(
        vv_update,
        kinetic_fn,
        vv_state,
        inverse_mass_matrix,
        step_size,
        rng_key,
        max_delta_energy=max_delta_energy,
        max_tree_depth=max_treedepth,
    )
    n_proposals = tree.num_proposals
    # NOTE(review): momentum is carried over from the input state rather than
    # taken from the tree; presumably the caller resamples it before the next
    # transition -- confirm against the sampling loop.
    proposal_state = IntegratorState(
        z=tree.z_proposal,
        r=vv_state.r,
        potential_energy=tree.z_proposal_pe,
        z_grad=tree.z_proposal_grad,
    )
    return (
        proposal_state,
        tree.z_proposal_energy,
        n_proposals,
        tree.sum_accept_probs / n_proposals,
        tree.diverging,
    )
def _nuts_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key): if potential_fn_gen: nonlocal vv_update pe_fn = potential_fn_gen(*model_args, **model_kwargs) _, vv_update = velocity_verlet(pe_fn, kinetic_fn) binary_tree = build_tree(vv_update, kinetic_fn, vv_state, inverse_mass_matrix, step_size, rng_key, max_delta_energy=max_delta_energy, max_tree_depth=max_treedepth) accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals num_steps = binary_tree.num_proposals vv_state = IntegratorState(z=binary_tree.z_proposal, r=vv_state.r, potential_energy=binary_tree.z_proposal_pe, z_grad=binary_tree.z_proposal_grad) return vv_state, binary_tree.z_proposal_energy, num_steps, accept_prob, binary_tree.diverging
def fn(vv_state):
    """Grow and return a trajectory tree starting from ``vv_state``.

    Passes ``inverse_mass_matrix``, ``step_size`` and ``rng_key`` from the
    enclosing scope to ``build_tree``. No ``max_delta_energy`` or
    ``max_tree_depth`` is passed, so ``build_tree``'s own defaults for those
    apply.
    """
    return build_tree(
        vv_update,
        kinetic_fn,
        vv_state,
        inverse_mass_matrix,
        step_size,
        rng_key,
    )