Exemplo n.º 1
0
    layer_units=layer_units,
    layer_activations=layer_activations,
    initial_unconstrained_scale=initial_unconstrained_scale,
    transform_unconstrained_scale_factor=transform_unconstrained_scale_factor,
    weight_prior=weight_prior,
    bias_prior=bias_prior,
    noise_scale_prior=noise_scale_prior,
    n_train=n_train,
    learning_rate=learning_rate,
    names=layer_names,
    seed=train_seed,
)

# Progress marker printed once the preceding call has returned.
print("Done initializing")
# Sanity check: stop here unless check_posterior_equivalence returns a truthy
# value for the first entry of `large_ensemble.networks` and `hmc_net` on the
# training data.
# NOTE(review): `assert` statements are removed under `python -O`, so this
# check is skipped entirely in optimized mode.
assert check_posterior_equivalence(
    large_ensemble.networks[0], hmc_net, x_train, y_train, n_train
)

# Fit the ensemble on the training data.
# NOTE(review): verbose=0 presumably silences per-epoch output — confirm
# against the signature of `fit`.
large_ensemble.fit(
    x_train=x_train,
    y_train=y_train,
    batch_size=batch_size,
    epochs=epochs,
    early_stop_callback=early_stop_callback,
    verbose=0,
)


# %%
# Save the large ensemble to <save_dir>/large_map_ensemble.
# NOTE(review): `save_dir` must already be bound (to an object providing
# `joinpath`) by an earlier cell; the only assignment visible in this part of
# the file comes after this use.
large_ensemble.save(save_dir.joinpath("large_map_ensemble"))
# Plot `mog_prediction` together with the training data and the ground truth.
# NOTE(review): `mog_prediction` is not computed in this cell — presumably it
# comes from an earlier cell; confirm it reflects the freshly fitted ensemble.
plot_moment_matched_predictive_normal_distribution(
    x_plot=_x_plot,
    predictive_distribution=mog_prediction,
    x_train=_x_train,
    y_train=y_train,
    y_ground_truth=y_ground_truth,
    y_lim=y_lim,
)


# %%
# Directory that receives the toy ensemble; created on demand (missing parents
# included), and an already existing directory is not an error.
save_dir = Path("._toy_network_saving/")
save_dir.mkdir(parents=True, exist_ok=True)
# Save the ensemble to <save_dir>/toy_map_ensemble.
save_path = save_dir / "toy_map_ensemble"
ensemble.save(save_path)

# %%
# Load the ensemble back from `save_path` and bind the loaded object to both
# names, so later uses of `ensemble` refer to the reloaded object rather than
# the one that was saved.
ensemble = loaded_ensemble = map_density_ensemble_from_save_path(save_path)

# %% codecell
# Request n_predictions=3 from `predict_list_of_gaussians` at `x_plot` and plot
# the result alongside the training data and the ground truth.
# NOTE(review): predictions are evaluated on `x_plot` but plotted against
# `_x_plot` (and `_x_train` is paired with `y_train`) — presumably the
# underscore variants are differently shaped/scaled copies; confirm.
gaussian_predictions = ensemble.predict_list_of_gaussians(x_plot, n_predictions=3)
plot_distribution_samples(
    x_plot=_x_plot,
    distribution_samples=gaussian_predictions,
    x_train=_x_train,
    y_train=y_train,
    y_ground_truth=y_ground_truth,
    y_lim=y_lim,
)