Example #1
0
def freeze_A_to_solution_and_fit(n_epochs=20, *, interactive=True):
    """Seed ``phi`` from ``gg_blocks()``, run CAVI sweeps, and snapshot the result.

    Debugging helper (originally used to debug infs). A dataset of 500 points
    is generated with ``generate_gg_blocks_dataset``. The first four rows of
    ``model.phi`` are overwritten with ``SCALE * gg_blocks()`` (presumably the
    generating features -- TODO confirm). ``n_epochs`` CAVI sweeps are then
    run, printing the ELBO after each one.

    Note that nothing here actually freezes ``phi``: ``model.cavi`` may still
    update it. Visualisation snapshots of ``phi`` and ``nu`` are saved at
    index 0 (before fitting) and at index ``n_epochs`` (after fitting).

    Args:
        n_epochs: number of CAVI sweeps; also used as the index of the
            post-fit snapshot so the two always agree (default 20).
        interactive: when True (the default, matching the original
            behaviour), drop into an ``ipdb`` session at the end so the
            fitted model can be inspected.
    """
    SCALE = 1.

    N = 500
    X = generate_gg_blocks_dataset(N, 0.05)

    model = InfiniteIBP(1.5, 6, 0.1, 0.05, 36)
    model.phi.data[:4] = SCALE * gg_blocks()
    model.init_z(N)
    model.train()

    visualize_A_save(model.phi.detach().numpy(), 0)
    visualize_nu_save(model.nu.detach().numpy(), 0)

    for i in range(n_epochs):
        model.cavi(X)
        print("[Epoch {:<3}] ELBO = {:.3f}".format(i + 1,
                                                   model.elbo(X).item()))

    print("CHANGE OF REGIME")

    visualize_A_save(model.phi.detach().numpy(), n_epochs)
    visualize_nu_save(model.nu.detach().numpy(), n_epochs)

    if interactive:
        # Imported lazily so non-interactive runs don't need ipdb installed.
        import ipdb
        ipdb.set_trace()
Example #2
0
def check_that_naive_doesnt_work(n_epochs=100, reset_every=10, plot_every=5):
    """Run CAVI from scratch while periodically re-randomising ``model._nu``.

    A 500-point dataset is generated with ``generate_gg_blocks_dataset``, and a
    fresh ``InfiniteIBP`` model is initialised and put in eval mode. CAVI
    sweeps are then run, printing the ELBO after each one.

    Within each epoch:

    * ``phi`` is visualised before the sweep every ``plot_every`` epochs,
      starting at epoch index 0.
    * ``_nu`` is overwritten with standard-normal noise just before the sweep
      on every ``reset_every``-th epoch.

    Args:
        n_epochs: number of CAVI sweeps (default 100).
        reset_every: positive period, in epochs, of the ``_nu`` re-randomisation
            (default 10).
        plot_every: positive period, in epochs, of the ``phi`` snapshots
            (default 5).
    """
    N = 500
    X = generate_gg_blocks_dataset(N, 0.05)

    model = InfiniteIBP(1.5, 6, 0.1, 0.05, 36)
    model.init_z(N)
    model.eval()

    for i in range(n_epochs):
        if i % plot_every == 0:
            visualize_A_save(model.phi.detach().numpy(), i)
        if (i + 1) % reset_every == 0:
            # randn_like keeps _nu's dtype/device/shape; plain torch.randn
            # would always produce a default-dtype CPU tensor.
            model._nu.data = torch.randn_like(model._nu)

        model.cavi(X)
        print("[Epoch {:<3}] ELBO = {:.3f}".format(i + 1,
                                                   model.elbo(X).item()))