Example #1
def sample(n_samples, init_state, thin=0, previous_results=None):
    with tf.name_scope("main_mcmc_sample_loop"):

        init_state = init_state.copy()

        gibbs_schema = GibbsKernel(
            target_log_prob_fn=joint_log_prob,
            kernel_list=[
                # Each entry pairs a state-part index with a kernel
                # builder. Entries 0 and 1 call a factory that returns
                # the builder; entry 2 passes the builder itself, since
                # it already has the (target_log_prob_fn, _) signature
                # GibbsKernel expects (see Example #3).
                (0, make_blk0_kernel(init_state[0].shape, "block0")),
                (1, make_blk1_kernel(init_state[1].shape, "block1")),
                (2, make_event_multiscan_kernel),
            ],
            name="gibbs0",
        )

        samples, results, final_results = tfp.mcmc.sample_chain(
            n_samples,
            init_state,
            kernel=gibbs_schema,
            num_steps_between_results=thin,
            previous_kernel_results=previous_results,
            return_final_kernel_results=True,
            trace_fn=trace_results_fn,
        )

        return samples, results, final_results
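
Because `sample` threads `previous_kernel_results` through `tfp.mcmc.sample_chain`, a chain can be resumed where it left off rather than restarted. A minimal sketch, assuming `sample` and a list-of-tensors `init_state` as above; the variable names are illustrative:

# First run: no previous kernel results yet.
samples, results, fkr = sample(1000, init_state, thin=10)

# Resume from the last draw, reusing the final kernel results so that
# per-kernel state (e.g. adapted quantities) is not re-initialised.
last_state = [part[-1] for part in samples]
more_samples, more_results, fkr = sample(
    1000, last_state, thin=10, previous_results=fkr
)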
Example #2
def _fast_adapt_window(
    num_draws,
    joint_log_prob_fn,
    initial_position,
    hmc_kernel_kwargs,
    dual_averaging_kwargs,
    event_kernel_kwargs,
    trace_fn=None,
    seed=None,
):
    """
    In the fast adaptation window, we use the
    `DualAveragingStepSizeAdaptation` kernel
    to wrap an HMC kernel.

    :param num_draws: Number of MCMC draws in window
    :param joint_log_prob_fn: joint log posterior function
    :param initial_position: initial state of the Markov chain
    :param hmc_kernel_kwargs: `HamiltonianMonteCarlo` kernel keywords args
    :param dual_averaging_kwargs: `DualAveragingStepSizeAdaptation` keyword args
    :param event_kernel_kwargs: EventTimesMH and Occult kernel args
    :param trace_fn: function to trace kernel results
    :param seed: optional random seed
    :returns: draws, kernel results, the adapted HMC step size, and variance
              accumulator
    """
    kernel_list = [
        (
            0,
            make_hmc_fast_adapt_kernel(
                hmc_kernel_kwargs=hmc_kernel_kwargs,
                dual_averaging_kwargs=dual_averaging_kwargs,
            ),
        ),
        (1, make_event_multiscan_gibbs_step(**event_kernel_kwargs)),
    ]

    kernel = GibbsKernel(
        target_log_prob_fn=joint_log_prob_fn,
        kernel_list=kernel_list,
        name="fast_adapt",
    )

    draws, trace, fkr = tfp.mcmc.sample_chain(
        num_draws,
        initial_position,
        kernel=kernel,
        return_final_kernel_results=True,
        trace_fn=trace_fn,
        seed=seed,
    )

    weighted_running_variance = get_weighted_running_variance(draws[0])
    step_size = unnest.get_outermost(fkr.inner_results[0], "step_size")
    return draws, trace, step_size, weighted_running_variance
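
The returned step size and running variance seed the next (slow) adaptation window. A hypothetical call follows; the kwargs dictionaries are illustrative and assume `make_hmc_fast_adapt_kernel` forwards them to `HamiltonianMonteCarlo` and `DualAveragingStepSizeAdaptation` respectively, with `joint_log_prob`, `current_state`, `event_kernel_kwargs`, and `trace_results_fn` taken as given:

draws, trace, step_size, running_variance = _fast_adapt_window(
    num_draws=200,
    joint_log_prob_fn=joint_log_prob,
    initial_position=current_state,
    hmc_kernel_kwargs={"step_size": 0.1, "num_leapfrog_steps": 16},
    dual_averaging_kwargs={"num_adaptation_steps": 200},
    event_kernel_kwargs=event_kernel_kwargs,
    trace_fn=trace_results_fn,
)
# Carry the adapted step size forward into the next window.
hmc_kernel_kwargs = {"step_size": step_size, "num_leapfrog_steps": 16}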
Example #3
def make_event_multiscan_kernel(target_log_prob_fn, _):
    # Repeat the inner Gibbs sweep config["num_event_time_updates"]
    # times per outer step; all four sub-steps update state part 0.
    return MultiScanKernel(
        config["num_event_time_updates"],
        GibbsKernel(
            target_log_prob_fn=target_log_prob_fn,
            kernel_list=[
                (0, make_partially_observed_step(0, None, 1, "se_events")),
                (0, make_partially_observed_step(1, 0, 2, "ei_events")),
                (0, make_occults_step(None, 0, 1, "se_occults")),
                (0, make_occults_step(0, 1, 2, "ei_occults")),
            ],
            name="gibbs1",
        ),
    )
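
As the signature above shows, `GibbsKernel` calls each builder in `kernel_list` as `(target_log_prob_fn, state_part)`, so any callable matching that signature can be slotted in. A minimal hypothetical builder wrapping a stock TFP kernel:

def make_rwm_kernel(target_log_prob_fn, _):
    # Stand-in builder: ignores the state part and returns a plain
    # random-walk Metropolis kernel, for illustration only.
    return tfp.mcmc.RandomWalkMetropolis(
        target_log_prob_fn=target_log_prob_fn,
        new_state_fn=tfp.mcmc.random_walk_normal_fn(scale=0.1),
    )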
Example #4
def make_kernel_fn(target_log_prob_fn, _):
    # `initial_state`, `t_range`, and `config` are free variables
    # captured from the enclosing scope.
    return MultiScanKernel(
        config["num_event_time_updates"],
        GibbsKernel(
            target_log_prob_fn=target_log_prob_fn,
            kernel_list=[
                (
                    0,
                    make_partially_observed_step(initial_state, 0, None, 1,
                                                 config, "se_events"),
                ),
                (
                    0,
                    make_partially_observed_step(initial_state, 1, 0, 2,
                                                 config, "ei_events"),
                ),
                (
                    0,
                    make_occults_step(
                        initial_state,
                        t_range,
                        None,
                        0,
                        1,
                        config,
                        "se_occults",
                    ),
                ),
                (
                    0,
                    make_occults_step(
                        initial_state,
                        t_range,
                        0,
                        1,
                        2,
                        config,
                        "ei_occults",
                    ),
                ),
            ],
            name="gibbs1",
        ),
    )
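
Since `make_kernel_fn` closes over `initial_state`, `t_range`, and `config`, one way to make those dependencies explicit is a small factory that binds them and returns the builder. A hypothetical refactoring sketch, abridged to a single sub-step:

def make_kernel_fn_factory(initial_state, t_range, config):
    # Returns a builder with the (target_log_prob_fn, state_part)
    # signature GibbsKernel expects, with its data dependencies bound.
    def make_kernel_fn(target_log_prob_fn, _):
        return MultiScanKernel(
            config["num_event_time_updates"],
            GibbsKernel(
                target_log_prob_fn=target_log_prob_fn,
                kernel_list=[
                    (0, make_partially_observed_step(
                        initial_state, 0, None, 1, config, "se_events")),
                ],
                name="gibbs1",
            ),
        )
    return make_kernel_fn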
Example #5
def _fixed_window(
    num_draws,
    joint_log_prob_fn,
    initial_position,
    hmc_kernel_kwargs,
    event_kernel_kwargs,
    trace_fn=None,
    seed=None,
):
    """Fixed step size and mass matrix HMC.

    :param num_draws: number of MCMC iterations
    :param joint_log_prob_fn: joint log posterior density function
    :param initial_position: initial Markov chain state
    :param hmc_kernel_kwargs: `HamiltonianMonteCarlo` kwargs
    :param event_kernel_kwargs: Event and Occults kwargs
    :param trace_fn: results trace function
    :param seed: optional random seed
    :returns: (draws, trace, final_kernel_results)
    """
    kernel_list = [
        (0, make_hmc_base_kernel(**hmc_kernel_kwargs)),
        (1, make_event_multiscan_gibbs_step(**event_kernel_kwargs)),
    ]

    kernel = GibbsKernel(
        target_log_prob_fn=joint_log_prob_fn,
        kernel_list=kernel_list,
        name="fixed",
    )

    return tfp.mcmc.sample_chain(
        num_draws,
        current_state=initial_position,
        kernel=kernel,
        return_final_kernel_results=True,
        trace_fn=trace_fn,
        seed=seed,
    )
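
With `return_final_kernel_results=True`, the call above forwards `tfp.mcmc.sample_chain`'s three-tuple of draws, trace, and final kernel results directly to the caller. A hypothetical invocation, reusing the illustrative names from the earlier sketches:

draws, trace, final_kernel_results = _fixed_window(
    num_draws=1000,
    joint_log_prob_fn=joint_log_prob,
    initial_position=current_state,
    hmc_kernel_kwargs=hmc_kernel_kwargs,
    event_kernel_kwargs=event_kernel_kwargs,
    trace_fn=trace_results_fn,
)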
Example #6
def _slow_adapt_window(
    num_draws,
    joint_log_prob_fn,
    initial_position,
    initial_running_variance,
    hmc_kernel_kwargs,
    dual_averaging_kwargs,
    event_kernel_kwargs,
    trace_fn=None,
    seed=None,
):
    """In the slow adaptation phase, we adapt the HMC
    step size and mass matrix together.

    :param num_draws: number of MCMC iterations
    :param joint_log_prob_fn: the joint posterior density function
    :param initial_position: initial Markov chain state
    :param initial_running_variance: initial variance accumulator
    :param hmc_kernel_kwargs: `HamiltonianMonteCarlo` kernel kwargs
    :param dual_averaging_kwargs: `DualAveragingStepSizeAdaptation` kwargs
    :param event_kernel_kwargs: EventTimesMH and Occults kwargs
    :param trace_fn: result trace function
    :param seed: optional random seed
    :returns: draws, kernel results, adapted step size, the variance accumulator,
              and "learned" momentum distribution for the HMC.
    """
    kernel_list = [
        (
            0,
            make_hmc_slow_adapt_kernel(
                initial_running_variance,
                hmc_kernel_kwargs,
                dual_averaging_kwargs,
            ),
        ),
        (1, make_event_multiscan_gibbs_step(**event_kernel_kwargs)),
    ]

    kernel = GibbsKernel(
        target_log_prob_fn=joint_log_prob_fn,
        kernel_list=kernel_list,
        name="slow_adapt",
    )

    draws, trace, fkr = tfp.mcmc.sample_chain(
        num_draws,
        current_state=initial_position,
        kernel=kernel,
        return_final_kernel_results=True,
        trace_fn=trace_fn,
        seed=seed,
    )

    step_size = unnest.get_outermost(fkr.inner_results[0], "step_size")
    momentum_distribution = unnest.get_outermost(
        fkr.inner_results[0], "momentum_distribution"
    )

    weighted_running_variance = get_weighted_running_variance(draws[0])

    return (
        draws,
        trace,
        step_size,
        weighted_running_variance,
        momentum_distribution,
    )
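
Taken together, `_fast_adapt_window`, `_slow_adapt_window`, and `_fixed_window` form a Stan-style windowed warm-up: the fast window tunes the step size, the slow window additionally learns the mass matrix, and the fixed window samples with both held constant. A sketch of how the three might be chained; the draw counts and kwargs dicts are illustrative, and it assumes `make_hmc_base_kernel` accepts the learned `momentum_distribution` as a keyword:

# Window 1: adapt the step size only.
draws, trace, step_size, running_var = _fast_adapt_window(
    100, joint_log_prob, init_state,
    hmc_kernel_kwargs, dual_averaging_kwargs, event_kernel_kwargs,
)
hmc_kernel_kwargs["step_size"] = step_size

# Window 2: adapt step size and mass matrix together, warm-started
# from the last draw and the accumulated variance.
(draws, trace, step_size,
 running_var, momentum_distribution) = _slow_adapt_window(
    200, joint_log_prob, [part[-1] for part in draws], running_var,
    hmc_kernel_kwargs, dual_averaging_kwargs, event_kernel_kwargs,
)
hmc_kernel_kwargs["step_size"] = step_size
hmc_kernel_kwargs["momentum_distribution"] = momentum_distribution  # assumed kwarg

# Window 3: sample with step size and mass matrix held fixed.
draws, trace, fkr = _fixed_window(
    1000, joint_log_prob, [part[-1] for part in draws],
    hmc_kernel_kwargs, event_kernel_kwargs,
)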