def freeze_A_to_solution_and_fit(*, n_epochs=20, stop_in_debugger=True):
    """Seed ``model.phi`` with the gg-blocks solution, run CAVI, then inspect.

    Debugging harness, originally used to debug infs. Steps:

    1. Build a 500-point gg-blocks dataset and an ``InfiniteIBP`` model.
    2. Overwrite the first four entries of ``model.phi`` (along its leading
       dimension) with ``SCALE * gg_blocks()``.
    3. Save the A/nu visualizations (index 0).
    4. Run ``n_epochs`` CAVI updates, printing the ELBO after each one.
    5. Save the A/nu visualizations again (index ``n_epochs``).
    6. Optionally drop into ``ipdb``.

    A commented-out import of ``test_elbo_components`` and
    ``test_q_E_logstick`` from ``tests.test_vi`` used to live here, for use
    from the debugger session.

    NOTE(review): nothing here stops ``model.cavi`` from updating ``phi``,
    so "freeze" may only describe the initialization. Verify against
    ``InfiniteIBP.cavi``.

    Args:
        n_epochs: Number of ``model.cavi`` calls. It is also the index passed
            to the post-fit visualizations. Default 20, as before.
        stop_in_debugger: If True (the default, matching the original
            behavior), call ``ipdb.set_trace()`` at the end so ``model`` and
            ``X`` can be inspected interactively.
    """
    SCALE = 1.
    N = 500
    X = generate_gg_blocks_dataset(N, 0.05)
    model = InfiniteIBP(1.5, 6, 0.1, 0.05, 36)
    # Presumably the ground-truth features for the gg-blocks data; TODO confirm.
    model.phi.data[:4] = SCALE * gg_blocks()
    model.init_z(N)
    model.train()
    visualize_A_save(model.phi.detach().numpy(), 0)
    visualize_nu_save(model.nu.detach().numpy(), 0)
    # The unused torch.optim.Adam optimizer was removed: only CAVI updates run.
    for i in range(n_epochs):
        model.cavi(X)
        print("[Epoch {:<3}] ELBO = {:.3f}".format(i + 1, model.elbo(X).item()))
    print("CHANGE OF REGIME")
    visualize_A_save(model.phi.detach().numpy(), n_epochs)
    visualize_nu_save(model.nu.detach().numpy(), n_epochs)
    if stop_in_debugger:
        # Imported lazily so non-interactive runs do not need ipdb installed.
        import ipdb
        ipdb.set_trace()
def check_that_naive_doesnt_work():
    """Run 100 CAVI epochs on a fresh InfiniteIBP, periodically scrambling nu.

    Steps:

    1. Build a 500-point gg-blocks dataset and an ``InfiniteIBP`` model, then
       call ``init_z`` and ``eval()`` on the model.
    2. Loop over 100 epochs. In each epoch:
       - every 5th epoch (0, 5, 10, ...), save a visualization of
         ``model.phi``;
       - every 10th epoch (1-based), replace ``model._nu.data`` with fresh
         ``torch.randn`` noise of the same shape;
       - then run one ``cavi`` step and print the ELBO.

    Judging by the name, this presumably demonstrates that a naive scheme
    fails to recover; TODO confirm the intended conclusion.
    """
    num_points = 500
    data = generate_gg_blocks_dataset(num_points, 0.05)
    ibp = InfiniteIBP(1.5, 6, 0.1, 0.05, 36)
    ibp.init_z(num_points)
    ibp.eval()

    for epoch in range(100):
        step = epoch + 1  # 1-based epoch counter used for reset and logging
        if not epoch % 5:
            visualize_A_save(ibp.phi.detach().numpy(), epoch)
        if not step % 10:
            # Discard the current nu and restart from random noise.
            ibp._nu.data = torch.randn(ibp._nu.shape)
        ibp.cavi(data)
        elbo_value = ibp.elbo(data).item()
        print("[Epoch {:<3}] ELBO = {:.3f}".format(step, elbo_value))