示例#1
0
文件: ac_fig.py 项目: rueberger/MJHMC
def plot_ac(distribution,
            control_params,
            mjhmc_params,
            lahmc_params,
            max_steps=3000,
            sample_steps=1,
            truncate=False,
            truncate_at=0.0,
            nuts=False,
            truncate_idx=None):
    """Plot autocorrelation vs. gradient evaluations for ControlHMC and MJHMC.

    Runs each sampler for ``max_steps`` steps via ``calculate_autocorrelation``,
    optionally truncates the resulting curves, plots each curve against that
    sampler's own gradient-evaluation count, and saves the figure as
    ``<DistributionClass>_<ndims>_dim_ac_<max_steps>_steps.pdf``.

    Args:
        distribution: an instantiated distribution object; its ``ndims``
            attribute and class name are used for the title and file name.
        control_params: extra keyword arguments forwarded to
            ``calculate_autocorrelation`` for ``ControlHMC``.
        mjhmc_params: extra keyword arguments forwarded to
            ``calculate_autocorrelation`` for ``MarkovJumpHMC``.
        lahmc_params: currently unused (the LAHMC comparison is disabled);
            kept for backward compatibility of the signature.
        max_steps: number of sampler steps to run.
        sample_steps: forwarded as ``sample_steps`` to
            ``calculate_autocorrelation``.
        truncate: if True, cut every curve at a common row label.
        truncate_at: used when ``truncate`` is set and ``truncate_idx`` is
            None. The cut point is the later of the first rows at which the
            ControlHMC and MJHMC autocorrelations drop below this value.
        nuts: if True, also draw 100000 NUTS samples (10000 burn-in) and plot
            their autocorrelation.
        truncate_idx: explicit row label to cut at; overrides ``truncate_at``.

    Raises:
        IndexError: if ``truncate`` is set, ``truncate_idx`` is None, and a
            sampler's autocorrelation never drops below ``truncate_at``.
    """
    from mjhmc.samplers.markov_jump_hmc import MarkovJumpHMC, ControlHMC
    from mjhmc.misc.autocor import calculate_autocorrelation, autocorrelation
    from mjhmc.misc.nutshell import sample_nuts_to_df

    def first_index_below(ac_df, label):
        """Return the first index label whose autocorrelation is < truncate_at."""
        below = ac_df.loc[:, 'autocorrelation'] < truncate_at
        hits = below[below].index
        if len(hits) == 0:
            # Same exception type as the previous bare `.index[0]`, but with
            # a message that says which sampler failed to decorrelate.
            raise IndexError(
                '{} autocorrelation never drops below {} within {} steps'
                .format(label, truncate_at, max_steps))
        return hits[0]

    # TODO: bring up to speed of DF-less calc autocor
    plt.clf()
    print('Calculating AutoCorrelation for ControlHMC')
    control_ac = calculate_autocorrelation(ControlHMC,
                                           distribution,
                                           num_steps=max_steps,
                                           sample_steps=sample_steps,
                                           half_window=True,
                                           use_cached_var=True,
                                           **control_params)

    print('Calculating AutoCorrelation for MJHMC')
    mjhmc_ac = calculate_autocorrelation(MarkovJumpHMC,
                                         distribution,
                                         num_steps=max_steps,
                                         sample_steps=sample_steps,
                                         half_window=True,
                                         resample=False,
                                         use_cached_var=True,
                                         **mjhmc_params)

    if nuts:
        print('Calculating AutoCorrelation for NUTS')
        nuts_df = sample_nuts_to_df(distribution, 100000, n_burn_in=10000)
        nuts_ac = autocorrelation(nuts_df, half_window=True)

    if truncate:
        if truncate_idx is None:
            # Cut where both ControlHMC and MJHMC have fallen below
            # truncate_at, i.e. the later of their first crossings.
            # NUTS is intentionally not part of this decision.
            trunc_idx = max(first_index_below(control_ac, 'ControlHMC'),
                            first_index_below(mjhmc_ac, 'MJHMC'))
        else:
            trunc_idx = truncate_idx
        # .loc label slicing is inclusive of trunc_idx.
        control_ac = control_ac.loc[:trunc_idx]
        mjhmc_ac = mjhmc_ac.loc[:trunc_idx]
        if nuts:
            # Apply the same cut to NUTS whether it came from truncate_at or
            # an explicit truncate_idx.
            nuts_ac = nuts_ac.loc[:trunc_idx]

    # Each sampler is plotted against its OWN gradient-evaluation count; the
    # samplers use different numbers of gradients per sample.
    control_ac.index = control_ac['num grad']
    mjhmc_ac.index = mjhmc_ac['num grad']
    if nuts:
        nuts_ac.index = nuts_ac['num grad']
        nuts_ac['autocorrelation'].plot(label='NUTS')

    control_ac['autocorrelation'].plot(label='Control HMC')
    mjhmc_ac['autocorrelation'].plot(label='Markov Jump HMC')

    plt.xlabel("Gradient Evaluations")
    plt.ylabel("Autocorrelation")
    plt.title("{}D {}".format(distribution.ndims, type(distribution).__name__))
    plt.legend()
    # Save before showing: with interactive backends the figure may be
    # released once its window is closed, which would produce an empty PDF.
    plt.savefig("{}_{}_dim_ac_{}_steps.pdf".format(
        type(distribution).__name__, distribution.ndims, max_steps))
    plt.show()
示例#2
0
文件: ac_fig.py 项目: lingerlyn/MJHMC
def plot_ac(distribution, control_params, mjhmc_params, lahmc_params, max_steps=3000,
            sample_steps=1, truncate=False, truncate_at=0.0, nuts=False, truncate_idx=None):
    """Plot autocorrelation vs. gradient evaluations for ControlHMC and MJHMC.

    Runs each sampler for ``max_steps`` steps via ``calculate_autocorrelation``,
    optionally truncates the resulting curves, plots each curve against that
    sampler's own gradient-evaluation count, and saves the figure as
    ``<DistributionClass>_<ndims>_dim_ac_<max_steps>_steps.pdf``.

    Args:
        distribution: an instantiated distribution object; its ``ndims``
            attribute and class name are used for the title and file name.
        control_params: extra keyword arguments forwarded to
            ``calculate_autocorrelation`` for ``ControlHMC``.
        mjhmc_params: extra keyword arguments forwarded to
            ``calculate_autocorrelation`` for ``MarkovJumpHMC``.
        lahmc_params: currently unused (the LAHMC comparison is disabled);
            kept for backward compatibility of the signature.
        max_steps: number of sampler steps to run.
        sample_steps: forwarded as ``sample_steps`` to
            ``calculate_autocorrelation``.
        truncate: if True, cut every curve at a common row label.
        truncate_at: used when ``truncate`` is set and ``truncate_idx`` is
            None. The cut point is the later of the first rows at which the
            ControlHMC and MJHMC autocorrelations drop below this value.
        nuts: if True, also draw 100000 NUTS samples (10000 burn-in) and plot
            their autocorrelation.
        truncate_idx: explicit row label to cut at; overrides ``truncate_at``.

    Raises:
        IndexError: if ``truncate`` is set, ``truncate_idx`` is None, and a
            sampler's autocorrelation never drops below ``truncate_at``.
    """
    def first_index_below(ac_df, label):
        """Return the first index label whose autocorrelation is < truncate_at."""
        below = ac_df.loc[:, 'autocorrelation'] < truncate_at
        hits = below[below].index
        if len(hits) == 0:
            # Same exception type as the previous bare `.index[0]`, but with
            # a message that says which sampler failed to decorrelate.
            raise IndexError(
                '{} autocorrelation never drops below {} within {} steps'
                .format(label, truncate_at, max_steps))
        return hits[0]

    plt.clf()
    print('Calculating AutoCorrelation for ControlHMC')
    control_ac = calculate_autocorrelation(ControlHMC, distribution,
                                           num_steps=max_steps,
                                           sample_steps=sample_steps,
                                           half_window=True,
                                           use_cached_var=True,
                                           **control_params)

    print('Calculating AutoCorrelation for MJHMC')
    mjhmc_ac = calculate_autocorrelation(MarkovJumpHMC, distribution,
                                         num_steps=max_steps,
                                         sample_steps=sample_steps,
                                         half_window=True,
                                         resample=False,
                                         use_cached_var=True,
                                         **mjhmc_params)

    if nuts:
        print('Calculating AutoCorrelation for NUTS')
        nuts_df = sample_nuts_to_df(distribution, 100000, n_burn_in=10000)
        nuts_ac = autocorrelation(nuts_df, half_window=True)

    if truncate:
        if truncate_idx is None:
            # Cut where both ControlHMC and MJHMC have fallen below
            # truncate_at, i.e. the later of their first crossings.
            # NUTS is intentionally not part of this decision.
            trunc_idx = max(first_index_below(control_ac, 'ControlHMC'),
                            first_index_below(mjhmc_ac, 'MJHMC'))
        else:
            trunc_idx = truncate_idx
        # .loc label slicing is inclusive of trunc_idx.
        control_ac = control_ac.loc[:trunc_idx]
        mjhmc_ac = mjhmc_ac.loc[:trunc_idx]
        if nuts:
            # Apply the same cut to NUTS whether it came from truncate_at or
            # an explicit truncate_idx.
            nuts_ac = nuts_ac.loc[:trunc_idx]

    # Each sampler is plotted against its OWN gradient-evaluation count; the
    # samplers use different numbers of gradients per sample.
    control_ac.index = control_ac['num grad']
    mjhmc_ac.index = mjhmc_ac['num grad']
    if nuts:
        nuts_ac.index = nuts_ac['num grad']
        nuts_ac['autocorrelation'].plot(label='NUTS')

    control_ac['autocorrelation'].plot(label='Control HMC')
    mjhmc_ac['autocorrelation'].plot(label='Markov Jump HMC')

    plt.xlabel("Gradient Evaluations")
    plt.ylabel("Autocorrelation")
    plt.title("{}D {}".format(distribution.ndims, type(distribution).__name__))
    plt.legend()
    # Save before showing: with interactive backends the figure may be
    # released once its window is closed, which would produce an empty PDF.
    plt.savefig("{}_{}_dim_ac_{}_steps.pdf".format(
        type(distribution).__name__, distribution.ndims, max_steps))
    plt.show()