parameters_list=ld_parameters,
        parameters_times=ld_time,
        metric_functions = metric_functions,
        sample_functions = sample_functions,
        )
ld_evaluator.evaluate(40, tqdm=tqdm)


# Plot Traces, Metrics, and Compare
from sgmcmc_ssm.plotting_utils import (
    compare_metrics,
    plot_metrics,
    plot_trace_plot,
)

# One entry per inference method; insertion order (SGLD, then LD) fixes the
# order in which the per-method figures are produced.
evaluators_by_method = {'SGLD': sgld_evaluator, 'LD': ld_evaluator}
for method_evaluator in evaluators_by_method.values():
    plot_trace_plot(method_evaluator)
    plot_metrics(method_evaluator)

# Overlay both methods' metrics using 'time' as the x-axis.
compare_metrics(evaluators_by_method, x='time')


# NOTE(review): `evaluator` must already be bound here and must support
# `evaluate_sampler_step`; in this file it is only assigned later (as an
# OfflineEvaluator), so this snippet presumably expects an evaluator created
# earlier — confirm.
print(evaluator.samples)

## Run a few ADA_GRAD sampler steps
# The two lists appear to be parallel: sampler method names and their kwargs
# (the empty dict pairs with 'project_parameters') — confirm against
# evaluate_sampler_step.
for _ in range(10):
    evaluator.evaluate_sampler_step(
        ['step_adagrad', 'project_parameters'],
        [{'epsilon': 0.1, 'subsequence_length': 10, 'buffer_length': 5}, {}],
    )
print(evaluator.metrics)
print(evaluator.samples)


## Run a few SGRLD Steps
# Same layout as above, with an explicit `preconditioner` (defined elsewhere)
# for the SGRLD step.
for _ in range(10):
    evaluator.evaluate_sampler_step(
        ['sample_sgrld', 'project_parameters'],
        [
            {
                'preconditioner': preconditioner,
                'epsilon': 0.1,
                'subsequence_length': 10,
                'buffer_length': 5,
            },
            {},
        ],
    )
print(evaluator.metrics)
print(evaluator.samples)

from sgmcmc_ssm.plotting_utils import plot_metrics, plot_trace_plot

plot_metrics(evaluator)
# Trace only the listed parameter variables.
plot_trace_plot(evaluator, single_variables=['C', 'LRinv', 'R', 'Rinv'])





# Offline Evaluation
from sgmcmc_ssm.evaluator import OfflineEvaluator

# Score an already-collected list of parameter samples with the configured
# metric and sample functions (both defined elsewhere in the script).
evaluator = OfflineEvaluator(
    sampler=sampler,
    parameters_list=parameters_list,
    metric_functions=metric_functions,
    sample_functions=sample_functions,
)
evaluator.evaluate(num_to_eval=40, tqdm=tqdm)
print(evaluator.get_metrics())
print(evaluator.get_samples())

# Plot Results
# Both plots receive burnin=10 (presumably skipping the first 10 evaluated
# samples — confirm in plotting_utils). Metrics are drawn before traces.
from sgmcmc_ssm.plotting_utils import plot_metrics, plot_trace_plot

for plot_fn in (plot_metrics, plot_trace_plot):
    plot_fn(evaluator, burnin=10)

###############################################################################
# Compare Multiple Inference Methods
###############################################################################
# Construct the sampler first so that `init` is drawn from the prior of the
# model actually being fit, not from a sampler left over from earlier code.
sampler = LGSSMSampler(n=2, m=2, observations=data['observations'])
init = sampler.prior_init()

# Per-method time budget passed to fit_timed (same for both methods).
max_time = 60

## Fit Gibbs saving sample every second
gibbs_parameters, gibbs_time = sampler.fit_timed(
    iter_type='Gibbs',
    init_parameters=init,
    max_time=max_time,
    min_save_time=1,
)

## Fit SGRLD from the same starting point with the same time budget
# NOTE(review): step settings mirror the `sample_sgrld` calls earlier in this
# script; confirm `fit_timed(iter_type='SGRLD', ...)` accepts/forwards them.
sgrld_parameters, sgrld_time = sampler.fit_timed(
    iter_type='SGRLD',
    init_parameters=init,
    max_time=max_time,
    min_save_time=1,
    epsilon=0.1,
    subsequence_length=10,
    buffer_length=5,
)

# Evaluate Gibbs samples (draws are evaluated as-is).
gibbs_evaluator = OfflineEvaluator(
    sampler,
    parameters_list=gibbs_parameters,
    parameters_times=gibbs_time,
    metric_functions=metric_functions,
    sample_functions=sample_functions,
)
gibbs_evaluator.evaluate(16, tqdm=tqdm)

# Evaluate SGRLD samples (passed through half_average_parameters_list, as in
# the SGRLD-only evaluation elsewhere in this script).
sgrld_evaluator = OfflineEvaluator(
    sampler,
    parameters_list=half_average_parameters_list(sgrld_parameters),
    parameters_times=sgrld_time,
    metric_functions=metric_functions,
    sample_functions=sample_functions,
)
sgrld_evaluator.evaluate(16, tqdm=tqdm)

# Plot Traces, Metrics, and Compare
from sgmcmc_ssm.plotting_utils import (
    compare_metrics,
    plot_metrics,
    plot_trace_plot,
)

# One entry per inference method; insertion order (Gibbs, then SGRLD) fixes
# the order in which the per-method figures are produced.
evaluators_by_method = {'Gibbs': gibbs_evaluator, 'SGRLD': sgrld_evaluator}
for method_evaluator in evaluators_by_method.values():
    plot_trace_plot(method_evaluator)
    plot_metrics(method_evaluator)

# Overlay both methods' metrics using 'time' as the x-axis.
compare_metrics(evaluators_by_method, x='time')

# EOF
# Example #5
# Single metric: a noisy predictive log-joint log-likelihood, 3 steps ahead,
# computed on Y_test (presumably held-out observations — confirm).
metric_functions = [
    noisy_predictive_logjoint_loglike_metric(
        num_steps_ahead=3,
        observations=Y_test,
        tqdm=tqdm,
    ),
]
# Record the values of these named parameters for each evaluated sample.
sample_functions = sample_function_parameters(
    ['pi', 'logit_pi', 'mu', 'R'],
)

# Evaluate SGRLD samples
averaged_sgrld_parameters = half_average_parameters_list(sgrld_parameters)
sgrld_evaluator = OfflineEvaluator(
    sampler,
    parameters_list=averaged_sgrld_parameters,
    parameters_times=sgrld_time,
    metric_functions=metric_functions,
    sample_functions=sample_functions,
)
sgrld_evaluator.evaluate(16, tqdm=tqdm)

# Plot Traces, Metrics, and Compare
from sgmcmc_ssm.plotting_utils import (
    compare_metrics,
    plot_metrics,
    plot_trace_plot,
)

plot_trace_plot(sgrld_evaluator)

# Metric plots twice: first with no explicit x, then with x='time'.
for metric_axis_kwargs in ({}, {'x': 'time'}):
    plot_metrics(sgrld_evaluator, **metric_axis_kwargs)

# EOF