示例#1
0
 def run_hmc():
   """Run `mcmc.sample_chain` with a step-size-adapting, bijector-wrapped HMC kernel.

   Takes no arguments: the target log-prob, initial state and step size,
   leapfrog/warmup/result counts, `model` and `seed` are all read from the
   enclosing scope.

   Returns:
     Whatever `mcmc.sample_chain` returns for the configured kernel.
   """
   # Innermost kernel: plain HMC on the target log-probability.
   hmc_kernel = mcmc.HamiltonianMonteCarlo(
       target_log_prob_fn=target_log_prob_fn,
       step_size=initial_step_size,
       num_leapfrog_steps=num_leapfrog_steps,
       state_gradients_are_stopped=True)

   # One bijector per model parameter, in `model.parameters` order.
   parameter_bijectors = [p.bijector for p in model.parameters]
   transformed_kernel = mcmc.TransformedTransitionKernel(
       inner_kernel=hmc_kernel,
       bijector=parameter_bijectors)

   # The number of adaptation steps is 80% of the warmup step count
   # (truncated to an int).
   adaptation_steps = int(num_warmup_steps * 0.8)
   adaptive_kernel = mcmc.DualAveragingStepSizeAdaptation(
       inner_kernel=transformed_kernel,
       num_adaptation_steps=adaptation_steps)

   return mcmc.sample_chain(
       num_results=num_results,
       current_state=initial_state,
       num_burnin_steps=num_warmup_steps,
       kernel=adaptive_kernel,
       seed=seed())
示例#2
0
def hmc(target, model_config, step_size_init, initial_states, reparam):
    """Sample from `target` with step-size-adapted HMC, compiled through XLA.

    Args:
        target: Log-joint function; it is wrapped with
            `vectorize_log_joint_fn` before being given to the HMC kernel.
        model_config: Object exposing `to_centered`, `make_to_centered` and
            `_replace` (presumably a namedtuple -- confirm against callers).
        step_size_init: Indexable collection of base step sizes; entry `i`
            is paired with `initial_states[i]`.
        initial_states: Iterable of initial states, each exposing `.shape`.
        reparam: `'CP'` selects the identity map, `'NCP'` selects
            `model_config.to_centered`; anything else is unpacked as keyword
            arguments to `model_config.make_to_centered`.

    Returns:
        A tuple `(states_orig, kernel_results, states_transformed, ess)`:
        the two outputs of the compiled `mcmc.sample_chain` call, the states
        after `transform_mcmc_states(states, to_centered)`, and the result
        of `tfp.mcmc.effective_sample_size` on the transformed states.
    """
    # Choose the map applied to the sampled states after the run.
    if reparam == 'CP':
        def to_centered(x):
            return x
    elif reparam == 'NCP':
        to_centered = model_config.to_centered
    else:
        to_centered = model_config.make_to_centered(**reparam)

    # NOTE(review): the updated config is not read again in this function.
    model_config = model_config._replace(to_centered=to_centered)

    # Materialise the states so they can be indexed below.  The original
    # note describes these as variational samples.
    initial_states = list(initial_states)
    vectorized_target = vectorize_log_joint_fn(target)

    num_leapfrog_steps = FLAGS.num_leapfrog_steps
    # Every base step size is divided by (num_leapfrog_steps / 4) ** 2.
    leapfrog_scale = (float(num_leapfrog_steps) / 4.) ** 2

    def _step_size_for(index):
        # Broadcast the base step size to the state's shape, then rescale.
        broadcast = step_size_init[index] * np.ones(
            initial_states[index].shape)
        return np.array(broadcast / leapfrog_scale).astype(np.float32)

    per_chain_initial_step_sizes = [
        _step_size_for(index) for index in range(len(step_size_init))
    ]

    adaptive_kernel = mcmc.DualAveragingStepSizeAdaptation(
        inner_kernel=mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=vectorized_target,
            step_size=per_chain_initial_step_sizes,
            state_gradients_are_stopped=True,
            num_leapfrog_steps=num_leapfrog_steps),
        num_adaptation_steps=FLAGS.num_adaptation_steps)

    def _sample():
        return mcmc.sample_chain(
            num_results=FLAGS.num_samples,
            num_burnin_steps=FLAGS.num_burnin_steps,
            current_state=initial_states,
            kernel=adaptive_kernel,
            num_steps_between_results=1)

    def _to_centered_states(states):
        return transform_mcmc_states(states, to_centered)

    states_orig, kernel_results = tf.xla.experimental.compile(_sample)
    states_transformed = tf.xla.experimental.compile(
        _to_centered_states, [states_orig])

    ess = tfp.mcmc.effective_sample_size(states_transformed)
    return states_orig, kernel_results, states_transformed, ess