def run_hmc():
  """Draws HMC samples with dual-averaging step-size adaptation.

  Builds an HMC kernel, constrains it through the model parameters'
  bijectors via a TransformedTransitionKernel, adapts the step size over
  the first 80% of the warmup steps, and returns the output of
  `mcmc.sample_chain`.
  """
  hmc_kernel = mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=target_log_prob_fn,
      step_size=initial_step_size,
      num_leapfrog_steps=num_leapfrog_steps,
      state_gradients_are_stopped=True)
  # One constraining bijector per model parameter, in parameter order.
  constraining_bijectors = [param.bijector for param in model.parameters]
  transformed_kernel = mcmc.TransformedTransitionKernel(
      inner_kernel=hmc_kernel,
      bijector=constraining_bijectors)
  adaptive_kernel = mcmc.DualAveragingStepSizeAdaptation(
      inner_kernel=transformed_kernel,
      num_adaptation_steps=int(num_warmup_steps * 0.8))
  return mcmc.sample_chain(
      num_results=num_results,
      current_state=initial_state,
      num_burnin_steps=num_warmup_steps,
      kernel=adaptive_kernel,
      seed=seed())
def hmc(target, model_config, step_size_init, initial_states, reparam):
  """Runs HMC to sample from the given target distribution."""
  # Pick the map from the sampling parameterization back to centered
  # coordinates. 'CP' samples are already centered; 'NCP' uses the
  # model's own transform; anything else is treated as kwargs for
  # `make_to_centered`.
  if reparam == 'CP':
    to_centered = lambda x: x
  elif reparam == 'NCP':
    to_centered = model_config.to_centered
  else:
    to_centered = model_config.make_to_centered(**reparam)

  model_config = model_config._replace(to_centered=to_centered)

  initial_states = list(initial_states)  # Variational samples.
  vectorized_target = vectorize_log_joint_fn(target)

  # Per-chain initial step sizes: scale each provided value down by
  # (num_leapfrog_steps / 4)^2 and broadcast to the matching state shape.
  leapfrog_scale = (float(FLAGS.num_leapfrog_steps) / 4.)**2
  per_chain_initial_step_sizes = [
      np.array(size * np.ones(state.shape) / leapfrog_scale).astype(np.float32)
      for size, state in zip(step_size_init, initial_states)
  ]

  inner_kernel = mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=vectorized_target,
      step_size=per_chain_initial_step_sizes,
      state_gradients_are_stopped=True,
      num_leapfrog_steps=FLAGS.num_leapfrog_steps)
  kernel = mcmc.DualAveragingStepSizeAdaptation(
      inner_kernel=inner_kernel,
      num_adaptation_steps=FLAGS.num_adaptation_steps)

  def do_sampling():
    return mcmc.sample_chain(
        num_results=FLAGS.num_samples,
        num_burnin_steps=FLAGS.num_burnin_steps,
        current_state=initial_states,
        kernel=kernel,
        num_steps_between_results=1)

  # Run the chain (and the post-hoc recentering) under XLA compilation.
  states_orig, kernel_results = tf.xla.experimental.compile(do_sampling)
  states_transformed = tf.xla.experimental.compile(
      lambda states: transform_mcmc_states(states, to_centered),
      [states_orig])
  ess = tfp.mcmc.effective_sample_size(states_transformed)
  return states_orig, kernel_results, states_transformed, ess