ld_evaluator = OfflineEvaluator(
        sampler,
        parameters_list=ld_parameters,
        parameters_times=ld_time,
        metric_functions=metric_functions,
        sample_functions=sample_functions,
        )
ld_evaluator.evaluate(40, tqdm=tqdm)


# Plot Traces, Metrics, and Compare
from sgmcmc_ssm.plotting_utils import (
        plot_trace_plot,
        plot_metrics,
        compare_metrics,
        )

plot_trace_plot(sgld_evaluator)
plot_metrics(sgld_evaluator)

plot_trace_plot(ld_evaluator)
plot_metrics(ld_evaluator)

compare_metrics(
    dict(
        SGLD=sgld_evaluator,
        LD=ld_evaluator,
    ),
    x='time',
)
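
# Minimal sketch (not part of the original example): save the SGLD-vs-LD
# comparison figure to disk. Assumes compare_metrics returns a plot object
# exposing `savefig`, as used in the checkpointing example further below;
# the output directory name is hypothetical.
import os
out_dir = "figures"
os.makedirs(out_dir, exist_ok=True)
g = compare_metrics(
    dict(SGLD=sgld_evaluator, LD=ld_evaluator),
    x='time',
)
g.savefig(os.path.join(out_dir, "sgld_vs_ld_metrics.png"))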


sgrld_evaluator = OfflineEvaluator(
    sampler,
    parameters_list=half_average_parameters_list(sgrld_parameters),
    parameters_times=sgrld_time,
    metric_functions=metric_functions,
    sample_functions=sample_functions,
)
sgrld_evaluator.evaluate(16, tqdm=tqdm)

# Plot Traces, Metrics, and Compare
from sgmcmc_ssm.plotting_utils import (
    plot_trace_plot,
    plot_metrics,
    compare_metrics,
)

plot_trace_plot(gibbs_evaluator)
plot_metrics(gibbs_evaluator)

plot_trace_plot(sgrld_evaluator)
plot_metrics(sgrld_evaluator)

compare_metrics(
    dict(
        Gibbs=gibbs_evaluator,
        SGRLD=sgrld_evaluator,
    ),
    x='time',
)

# EOF
Example #3
# Imports used by this snippet
import os
import logging
import joblib
import matplotlib.pyplot as plt
from tqdm import tqdm
from sgmcmc_ssm.plotting_utils import plot_trace_plot, compare_metrics

# Assumes `my_evaluators` (dict of evaluators), `sampler_steps` (dict of
# per-sampler step arguments), and `path_to_save` are defined earlier.
keys = my_evaluators.keys()
for step in tqdm(range(1000)):
    for ii, key in enumerate(keys):
        my_evaluators[key].evaluate_sampler_step(
            *sampler_steps[key.split("_")[0]])

    if (step % 25) == 0:
        logging.info("============= CHECKPOINT ================")
        if not os.path.isdir(path_to_save):
            os.makedirs(path_to_save)
        joblib.dump(
            {
                key: evaluator.get_state()
                for key, evaluator in my_evaluators.items()
            }, os.path.join(path_to_save, "slds_demo.p"))
        g = compare_metrics(my_evaluators)
        g.savefig(os.path.join(path_to_save, "metrics_compare.png"))
        if step > 50:
            g = compare_metrics(my_evaluators, full_trace=False)
            g.savefig(os.path.join(path_to_save, "metrics_compare_zoom.png"))

        for key in my_evaluators.keys():
            sampler = my_evaluators[key].sampler
            fig, axes = plot_trace_plot(
                my_evaluators[key],
                single_variables=['C', 'LRinv', 'R', 'Rinv'])
            fig.suptitle(key)
            fig.savefig(os.path.join(path_to_save,
                                     "{0}_trace.png".format(key)))
        plt.close('all')
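
# Minimal sketch (not part of the original example): reload the checkpoint
# written above. The checkpoint is a plain dict mapping evaluator names to
# the output of `get_state()`, so it can be inspected with joblib.load;
# `path_to_save` is the same directory used by the checkpointing loop.
checkpoint = joblib.load(os.path.join(path_to_save, "slds_demo.p"))
for key, state in checkpoint.items():
    print(key, type(state))
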
# Construct offline evaluators from each sampler's output
from sgmcmc_ssm.evaluator import OfflineEvaluator, half_average_parameters_list

evaluators = {}
evaluators['Gibbs'] = OfflineEvaluator(
    sampler,
    parameters_list=half_average_parameters_list(gibbs_parameters),
    parameters_times=gibbs_time,
    metric_functions=metric_functions,
    sample_functions=sample_functions,
)
evaluators['SGRLD No Buffer'] = OfflineEvaluator(
    sampler,
    parameters_list=half_average_parameters_list(nobuffer_parameters),
    parameters_times=nobuffer_time,
    metric_functions=metric_functions,
    sample_functions=sample_functions,
)
evaluators['SGRLD Buffer'] = OfflineEvaluator(
    sampler,
    parameters_list=half_average_parameters_list(buffer_parameters),
    parameters_times=buffer_time,
    metric_functions=metric_functions,
    sample_functions=sample_functions,
)

for evaluator in tqdm(evaluators.values()):
    evaluator.evaluate(40, tqdm=tqdm)

# Plot Results
from sgmcmc_ssm.plotting_utils import compare_metrics
compare_metrics(evaluators, x='time')
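
# Minimal sketch (not part of the original example): zoom in on the later
# samples of the buffered vs. non-buffered comparison. Assumes the
# `full_trace` keyword seen in the checkpointing example above can be
# combined with x='time'; the output filename is hypothetical.
g = compare_metrics(evaluators, x='time', full_trace=False)
g.savefig("buffer_comparison_zoom.png")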