def sample(n_samples, init_state, thin=0, previous_results=None):
    """Run the main MCMC loop with a three-block Gibbs kernel.

    State part 0 is updated by the kernel built by ``make_blk0_kernel``,
    part 1 by the kernel built by ``make_blk1_kernel``, and part 2 by
    ``make_event_multiscan_kernel``. All of them target the module-level
    ``joint_log_prob``.

    :param n_samples: number of results to draw from the chain
    :param init_state: initial chain state; must support ``.copy()`` and
        indexing by state-part number (presumably a list of tensors)
    :param thin: forwarded as ``num_steps_between_results``
    :param previous_results: kernel results to resume from, forwarded as
        ``previous_kernel_results`` (``None`` lets the kernel bootstrap)
    :returns: ``(samples, results, final_results)``, where ``results`` is
        whatever ``trace_results_fn`` traces and ``final_results`` are the
        final kernel results
    """
    with tf.name_scope("main_mcmc_sample_loop"):
        # Work on a shallow copy so the caller's container is not rebound.
        state = init_state.copy()

        block0_kernel_fn = make_blk0_kernel(state[0].shape, "block0")
        block1_kernel_fn = make_blk1_kernel(state[1].shape, "block1")
        gibbs_schema = GibbsKernel(
            target_log_prob_fn=joint_log_prob,
            kernel_list=[
                (0, block0_kernel_fn),
                (1, block1_kernel_fn),
                (2, make_event_multiscan_kernel),
            ],
            name="gibbs0",
        )

        samples, results, final_results = tfp.mcmc.sample_chain(
            n_samples,
            state,
            kernel=gibbs_schema,
            num_steps_between_results=thin,
            previous_kernel_results=previous_results,
            return_final_kernel_results=True,
            trace_fn=trace_results_fn,
        )
    return samples, results, final_results
def _fast_adapt_window(
    num_draws,
    joint_log_prob_fn,
    initial_position,
    hmc_kernel_kwargs,
    dual_averaging_kwargs,
    event_kernel_kwargs,
    trace_fn=None,
    seed=None,
):
    """Run the fast adaptation window (step size only).

    The HMC kernel for state part 0 is built by
    ``make_hmc_fast_adapt_kernel`` (which wraps HMC in
    `DualAveragingStepSizeAdaptation`). State part 1 is updated by the
    multiscan event/occult Gibbs step.

    :param num_draws: number of MCMC draws in the window
    :param joint_log_prob_fn: joint log posterior function
    :param initial_position: initial state of the Markov chain
    :param hmc_kernel_kwargs: `HamiltonianMonteCarlo` kernel keyword args
    :param dual_averaging_kwargs: `DualAveragingStepSizeAdaptation` keyword
        args
    :param event_kernel_kwargs: EventTimesMH and Occult kernel args
    :param trace_fn: function to trace kernel results
    :param seed: optional random seed
    :returns: draws, kernel results trace, the adapted HMC step size, and
        the weighted running variance of ``draws[0]``
    """
    hmc_step = make_hmc_fast_adapt_kernel(
        hmc_kernel_kwargs=hmc_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
    )
    event_step = make_event_multiscan_gibbs_step(**event_kernel_kwargs)

    gibbs = GibbsKernel(
        target_log_prob_fn=joint_log_prob_fn,
        kernel_list=[(0, hmc_step), (1, event_step)],
        name="fast_adapt",
    )

    draws, trace, final_kernel_results = tfp.mcmc.sample_chain(
        num_draws,
        initial_position,
        kernel=gibbs,
        return_final_kernel_results=True,
        trace_fn=trace_fn,
        seed=seed,
    )

    running_variance = get_weighted_running_variance(draws[0])
    # inner_results[0] belongs to the HMC step (first entry of kernel_list).
    adapted_step_size = unnest.get_outermost(
        final_kernel_results.inner_results[0], "step_size"
    )
    return draws, trace, adapted_step_size, running_variance
def make_event_multiscan_kernel(target_log_prob_fn, _):
    """Build the event-time/occult update kernel for one Gibbs block.

    An inner Gibbs kernel of four steps, all acting on its state part 0,
    is wrapped in a ``MultiScanKernel`` with
    ``config["num_event_time_updates"]`` (presumably the number of inner
    scans per outer iteration; confirm against ``MultiScanKernel``).

    :param target_log_prob_fn: log density supplied by the enclosing
        ``GibbsKernel``
    :param _: second argument of the Gibbs kernel-maker protocol; unused
    :returns: a ``MultiScanKernel`` wrapping the inner ``GibbsKernel``
    """
    num_updates = config["num_event_time_updates"]

    # NOTE(review): positional args to the step makers are assumed to be
    # (target transition, previous transition, next transition, name).
    event_steps = [
        (0, make_partially_observed_step(0, None, 1, "se_events")),
        (0, make_partially_observed_step(1, 0, 2, "ei_events")),
        (0, make_occults_step(None, 0, 1, "se_occults")),
        (0, make_occults_step(0, 1, 2, "ei_occults")),
    ]
    inner_gibbs = GibbsKernel(
        target_log_prob_fn=target_log_prob_fn,
        kernel_list=event_steps,
        name="gibbs1",
    )
    return MultiScanKernel(num_updates, inner_gibbs)
def make_kernel_fn(target_log_prob_fn, _):
    """Build a multiscan Gibbs kernel of event-time and occult updates.

    This variant passes ``initial_state``, ``t_range`` and ``config``
    (names resolved outside this function) into the step makers. All four
    steps act on state part 0 of the inner Gibbs kernel.

    :param target_log_prob_fn: log density supplied by the enclosing
        ``GibbsKernel``
    :param _: second argument of the Gibbs kernel-maker protocol; unused
    :returns: a ``MultiScanKernel`` run ``config["num_event_time_updates"]``
        times over the inner ``GibbsKernel``
    """
    num_updates = config["num_event_time_updates"]

    se_events = make_partially_observed_step(
        initial_state, 0, None, 1, config, "se_events"
    )
    ei_events = make_partially_observed_step(
        initial_state, 1, 0, 2, config, "ei_events"
    )
    se_occults = make_occults_step(
        initial_state,
        t_range,
        None,
        0,
        1,
        config,
        "se_occults",
    )
    ei_occults = make_occults_step(
        initial_state,
        t_range,
        0,
        1,
        2,
        config,
        "ei_occults",
    )

    inner_gibbs = GibbsKernel(
        target_log_prob_fn=target_log_prob_fn,
        kernel_list=[
            (0, se_events),
            (0, ei_events),
            (0, se_occults),
            (0, ei_occults),
        ],
        name="gibbs1",
    )
    return MultiScanKernel(num_updates, inner_gibbs)
def _fixed_window(
    num_draws,
    joint_log_prob_fn,
    initial_position,
    hmc_kernel_kwargs,
    event_kernel_kwargs,
    trace_fn=None,
    seed=None,
):
    """Sample with a fixed HMC step size and mass matrix.

    State part 0 is updated by the kernel built by
    ``make_hmc_base_kernel``. State part 1 is updated by the multiscan
    event/occult Gibbs step. No adaptation kernel is used here.

    :param num_draws: number of MCMC iterations
    :param joint_log_prob_fn: joint log posterior density function
    :param initial_position: initial Markov chain state
    :param hmc_kernel_kwargs: `HamiltonianMonteCarlo` kwargs
    :param event_kernel_kwargs: Event and Occults kwargs
    :param trace_fn: results trace function
    :param seed: optional random seed
    :returns: (draws, trace, final_kernel_results) from ``sample_chain``
    """
    hmc_step = make_hmc_base_kernel(**hmc_kernel_kwargs)
    event_step = make_event_multiscan_gibbs_step(**event_kernel_kwargs)

    fixed_kernel = GibbsKernel(
        target_log_prob_fn=joint_log_prob_fn,
        kernel_list=[(0, hmc_step), (1, event_step)],
        name="fixed",
    )

    chain_output = tfp.mcmc.sample_chain(
        num_draws,
        current_state=initial_position,
        kernel=fixed_kernel,
        return_final_kernel_results=True,
        trace_fn=trace_fn,
        seed=seed,
    )
    return chain_output
def _slow_adapt_window(
    num_draws,
    joint_log_prob_fn,
    initial_position,
    initial_running_variance,
    hmc_kernel_kwargs,
    dual_averaging_kwargs,
    event_kernel_kwargs,
    trace_fn=None,
    seed=None,
):
    """In the slow adaptation phase, adapt the HMC step size and mass
    matrix together.

    State part 0 is updated by the kernel built by
    ``make_hmc_slow_adapt_kernel``. State part 1 is updated by the
    multiscan event/occult Gibbs step.

    :param num_draws: number of MCMC iterations
    :param joint_log_prob_fn: the joint posterior density function
    :param initial_position: initial Markov chain state
    :param initial_running_variance: initial variance accumulator
    :param hmc_kernel_kwargs: `HamiltonianMonteCarlo` kernel kwargs
    :param dual_averaging_kwargs: `DualAveragingStepSizeAdaptation` kwargs
    :param event_kernel_kwargs: EventTimesMH and Occults kwargs
    :param trace_fn: result trace function
    :param seed: optional random seed, forwarded to
        ``tfp.mcmc.sample_chain``
    :returns: draws, kernel results, adapted step size, the variance
        accumulator, and "learned" momentum distribution for the HMC.
    """
    kernel_list = [
        (
            0,
            make_hmc_slow_adapt_kernel(
                initial_running_variance,
                hmc_kernel_kwargs,
                dual_averaging_kwargs,
            ),
        ),
        (1, make_event_multiscan_gibbs_step(**event_kernel_kwargs)),
    ]

    kernel = GibbsKernel(
        target_log_prob_fn=joint_log_prob_fn,
        kernel_list=kernel_list,
        name="slow_adapt",
    )

    # Forward the seed, as the fast and fixed windows do. Previously it was
    # accepted but silently dropped, so seeded runs were not reproducible.
    draws, trace, fkr = tfp.mcmc.sample_chain(
        num_draws,
        current_state=initial_position,
        kernel=kernel,
        return_final_kernel_results=True,
        trace_fn=trace_fn,
        seed=seed,
    )

    # inner_results[0] belongs to the HMC step (first entry of kernel_list).
    step_size = unnest.get_outermost(fkr.inner_results[0], "step_size")
    momentum_distribution = unnest.get_outermost(
        fkr.inner_results[0], "momentum_distribution"
    )
    weighted_running_variance = get_weighted_running_variance(draws[0])

    return (
        draws,
        trace,
        step_size,
        weighted_running_variance,
        momentum_distribution,
    )