parameters_list=ld_parameters, parameters_times=ld_time, metric_functions = metric_functions, sample_functions = sample_functions, ) ld_evaluator.evaluate(40, tqdm=tqdm) # Plot Traces, Metrics, and Compare from sgmcmc_ssm.plotting_utils import ( plot_trace_plot, plot_metrics, compare_metrics, ) plot_trace_plot(sgld_evaluator) plot_metrics(sgld_evaluator) plot_trace_plot(ld_evaluator) plot_metrics(ld_evaluator) compare_metrics(dict( SGLD=sgld_evaluator, LD=ld_evaluator, ), x='time', ) #
sampler, parameters_list=half_average_parameters_list(sgrld_parameters), parameters_times=sgrld_time, metric_functions=metric_functions, sample_functions=sample_functions, ) sgrld_evaluator.evaluate(16, tqdm=tqdm) # Plot Traces, Metrics, and Compare from sgmcmc_ssm.plotting_utils import ( plot_trace_plot, plot_metrics, compare_metrics, ) plot_trace_plot(gibbs_evaluator) plot_metrics(gibbs_evaluator) plot_trace_plot(sgrld_evaluator) plot_metrics(sgrld_evaluator) compare_metrics( dict( Gibbs=gibbs_evaluator, SGRLD=sgrld_evaluator, ), x='time', ) # EOF
# Run every evaluator for a fixed number of sampler steps, periodically
# checkpointing evaluator state and diagnostic plots under `path_to_save`.
#
# NOTE(review): block nesting was reconstructed from a whitespace-collapsed
# source; all saving/plotting is assumed to belong to the checkpoint branch.
NUM_STEPS = 1000
CHECKPOINT_EVERY = 25

keys = my_evaluators.keys()
for step in tqdm(range(NUM_STEPS)):
    # Advance each evaluator by one step. The text before the first "_" in
    # the evaluator's key selects the argument tuple in `sampler_steps`,
    # which is unpacked as positional arguments.
    for key in keys:
        my_evaluators[key].evaluate_sampler_step(
            *sampler_steps[key.split("_")[0]])

    # Checkpoint on the regular schedule and also after the final step;
    # otherwise the work done after the last multiple of CHECKPOINT_EVERY
    # would never be written to disk.
    if (step % CHECKPOINT_EVERY) == 0 or step == NUM_STEPS - 1:
        logging.info("============= CHECKPOINT ================")

        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(path_to_save, exist_ok=True)

        # Persist the state of every evaluator in a single file.
        joblib.dump(
            {
                key: evaluator.get_state()
                for key, evaluator in my_evaluators.items()
            },
            os.path.join(path_to_save, "slds_demo.p"))

        # Metric comparison across all evaluators.
        g = compare_metrics(my_evaluators)
        g.savefig(os.path.join(path_to_save, "metrics_compare.png"))
        if step > 50:
            # Second comparison figure with full_trace disabled, only
            # produced once more than 50 steps have run.
            g = compare_metrics(my_evaluators, full_trace=False)
            g.savefig(os.path.join(path_to_save, "metrics_compare_zoom.png"))

        # One trace plot per evaluator, titled and named after its key.
        for key in my_evaluators:
            # NOTE(review): `sampler` is not used in this loop; the binding
            # is kept in case code outside this block reads it - confirm
            # and drop if not.
            sampler = my_evaluators[key].sampler
            fig, axes = plot_trace_plot(
                my_evaluators[key],
                single_variables=['C', 'LRinv', 'R', 'Rinv'])
            fig.suptitle(key)
            fig.savefig(
                os.path.join(path_to_save, "{0}_trace.png".format(key)))

        # Release every figure created during this checkpoint.
        plt.close('all')
from sgmcmc_ssm.evaluator import half_average_parameters_list


def _build_offline_evaluator(run_parameters, run_times):
    """Return an OfflineEvaluator for one run's parameter trace.

    The trace is passed through `half_average_parameters_list` before
    being handed to the evaluator. The sampler, metric functions and
    sample functions are the module-level ones shared by every run.
    """
    return OfflineEvaluator(
        sampler,
        parameters_list=half_average_parameters_list(run_parameters),
        parameters_times=run_times,
        metric_functions=metric_functions,
        sample_functions=sample_functions,
    )


# One evaluator per sampling run, keyed by the label used when comparing.
evaluators = {}
evaluators['Gibbs'] = _build_offline_evaluator(
    gibbs_parameters, gibbs_time)
evaluators['SGRLD No Buffer'] = _build_offline_evaluator(
    nobuffer_parameters, nobuffer_time)
evaluators['SGRLD Buffer'] = _build_offline_evaluator(
    buffer_parameters, buffer_time)

# Evaluate each run; the outer tqdm tracks runs, and tqdm is also passed
# to evaluate() itself.
for evaluator in tqdm(evaluators.values()):
    evaluator.evaluate(40, tqdm=tqdm)

# Plot Results
from sgmcmc_ssm.plotting_utils import compare_metrics

compare_metrics(evaluators, x='time')