Ejemplo n.º 1
0
def test_ctrvae_decode(invariances):
    """Decode one latent point with a class label and check the output shape.

    Builds an ``iVAE`` for 8x8 data with ``c_dim=3``, decodes a single
    two-component latent vector of zeros together with the label that
    ``utils.to_onehot`` produces for class index 0, and asserts that the
    squeezed output has the data shape.

    ``invariances`` is forwarded to the model unchanged; it is presumably
    supplied by a pytest parametrization that is not visible here.
    """
    n_classes = 3
    image_shape = (8, 8)
    vae = models.iVAE(image_shape, c_dim=n_classes, invariances=invariances)
    # Batch holding a single latent vector, i.e. shape (1, 2).
    z_origin = torch.tensor([[0.0, 0.0]])
    # Batch holding a single class index (0), converted by the project helper.
    label = utils.to_onehot(torch.tensor([0]), n_classes)
    output = vae.decode(z_origin, label)
    assert_equal(output.squeeze().shape, image_shape)
Ejemplo n.º 2
0
def test_trvae_decoder_sampler(sampler, expected_dist):
    """Check the distribution type recorded at the "obs" site of the model trace.

    A model is built with ``sampler_d=sampler`` and traced on random input;
    the ``base_dist`` of the function stored at the "obs" node must be an
    instance of ``expected_dist``.

    Both arguments are presumably provided by a pytest parametrization that
    is not visible here.
    """
    batch_size = 2
    image_shape = (8, 8)
    x = torch.randn(batch_size, *image_shape)
    vae = models.iVAE(image_shape, coord=1, sampler_d=sampler)
    # Only the model trace is inspected; the guide trace is discarded.
    _, model_trace = get_traces(vae, x)
    obs_fn = model_trace.nodes["obs"]['fn']
    assert_(isinstance(obs_fn.base_dist, expected_dist))
Ejemplo n.º 3
0
def test_trvae_manifold2d(invariances, num_classes):
    """Check the shape of the grid returned by ``manifold2d``.

    The model is built for 8x8 data with ``c_dim=num_classes``. A label for
    class index 0 is passed only when ``num_classes`` is positive; otherwise
    ``None`` is passed. The squeezed result must have shape ``(16, 8, 8)``.

    Both arguments are presumably provided by a pytest parametrization that
    is not visible here.
    """
    image_shape = (8, 8)
    vae = models.iVAE(image_shape, c_dim=num_classes, invariances=invariances)
    label = (
        utils.to_onehot(torch.tensor([0]), num_classes)
        if num_classes > 0
        else None
    )
    # First argument is 4 and 16 images are expected back -- presumably a
    # 4 x 4 grid of latent points; confirm against manifold2d's signature.
    grid = vae.manifold2d(4, label, plot=True)
    assert_equal(grid.squeeze().shape, (16,) + image_shape)
Ejemplo n.º 4
0
def test_trvae_encode_1d(invariances):
    """Check the shapes of the encoder outputs for 1D inputs.

    Three random signals of length 8 are encoded. The first element of the
    result must have shape ``(3, coord + 2)``, where ``coord`` is the number
    of entries in ``invariances`` (0 when it is ``None``), and the second
    element must have the same shape as the first.

    ``invariances`` is presumably provided by a pytest parametrization that
    is not visible here.
    """
    n_samples = 3
    signal_shape = (8,)
    x = torch.randn(n_samples, *signal_shape)
    if invariances is None:
        n_coord = 0
    else:
        n_coord = len(invariances)
    vae = models.iVAE(signal_shape, 2, invariances=invariances)
    outputs = vae.encode(x)
    first_shape = outputs[0].shape
    assert_equal(first_shape, (n_samples, n_coord + 2))
    assert_equal(outputs[0].shape, outputs[1].shape)
Ejemplo n.º 5
0
def test_trvae_sites_fn(data_dim, invariances):
    """Check the distribution types stored at the traced sample sites.

    After tracing the model on random input of shape ``data_dim``, the
    "latent" site in both the model and the guide trace must wrap a
    ``dist.Normal`` and the "obs" site in the model trace must wrap a
    ``dist.Bernoulli`` (checked through the ``base_dist`` attribute).

    Both arguments are presumably provided by a pytest parametrization that
    is not visible here; the first entry of ``data_dim`` is dropped when the
    model is constructed.
    """
    def base_dist_at(trace, site):
        # Distribution underneath whatever wrapper is stored at the site.
        return trace.nodes[site]['fn'].base_dist

    x = torch.randn(*data_dim)
    vae = models.iVAE(data_dim[1:], invariances=invariances)
    guide_trace, model_trace = get_traces(vae, x)
    assert_(isinstance(base_dist_at(model_trace, "latent"), dist.Normal))
    assert_(isinstance(base_dist_at(guide_trace, "latent"), dist.Normal))
    assert_(isinstance(base_dist_at(model_trace, "obs"), dist.Bernoulli))
Ejemplo n.º 6
0
def test_trvae_encode_2d(invariances):
    """Check the shapes of the encoder outputs for 2D inputs.

    Three random 8x8 images are encoded. The expected width of the first
    output is ``coord + 2``: ``coord`` is 0 when ``invariances`` is ``None``,
    otherwise the number of entries in ``invariances`` plus one extra when
    't' is among them and the data are two-dimensional. The second output
    must have the same shape as the first.

    ``invariances`` is presumably provided by a pytest parametrization that
    is not visible here.
    """
    n_samples = 3
    image_shape = (8, 8)
    x = torch.randn(n_samples, *image_shape)
    if invariances is None:
        n_coord = 0
    else:
        # The test expects 't' to add one more coordinate for 2D data.
        extra = int('t' in invariances and len(image_shape) == 2)
        n_coord = len(invariances) + extra
    vae = models.iVAE(image_shape, 2, invariances=invariances)
    outputs = vae.encode(x)
    first_shape = outputs[0].shape
    assert_equal(first_shape, (n_samples, n_coord + 2))
    assert_equal(outputs[0].shape, outputs[1].shape)
Ejemplo n.º 7
0
def test_trvae_sites_dims_1d(invariances):
    """Check the shapes of the values recorded at traced sites for 1D inputs.

    Three random signals of length 8 are traced. The "latent" value in both
    the model and the guide trace must have shape ``(3, coord + 2)``, where
    ``coord`` is the number of entries in ``invariances`` (0 when ``None``).
    The "obs" value in the model trace must have shape ``(3, 8)``, the
    second entry being the product of the per-sample dimensions.

    ``invariances`` is presumably provided by a pytest parametrization that
    is not visible here.
    """
    n_samples = 3
    signal_shape = (8,)
    x = torch.randn(n_samples, *signal_shape)
    if invariances is None:
        n_coord = 0
    else:
        n_coord = len(invariances)
    vae = models.iVAE(signal_shape, invariances=invariances)
    guide_trace, model_trace = get_traces(vae, x)
    latent_shape = (n_samples, n_coord + 2)
    assert_equal(model_trace.nodes["latent"]['value'].shape, latent_shape)
    assert_equal(guide_trace.nodes["latent"]['value'].shape, latent_shape)
    # Number of elements in one sample, computed the same way as the original.
    flat_size = torch.prod(tt(signal_shape)).item()
    assert_equal(model_trace.nodes["obs"]['value'].shape,
                 (n_samples, flat_size))
Ejemplo n.º 8
0
def test_svi_trainer_trvae(invariances):
    """Run two trainer steps and check the losses and the model weights.

    Random train and test sets of five 8x8 samples are wrapped in data
    loaders with ``batch_size=2``. An ``SVItrainer`` built around an ``iVAE``
    is stepped twice. Afterwards the recorded "training_loss" history must
    contain no NaN, and ``assert_weights_equal`` must return a falsy value
    for the state dict copied before training versus the one read after it.

    ``invariances`` is presumably provided by a pytest parametrization that
    is not visible here.
    """
    n_steps = 2
    data_dim = (5, 8, 8)
    # Draw the training tensor first, then the test tensor, and only then
    # build the loaders -- same order of operations as the original.
    train_data, test_data = (torch.randn(*data_dim) for _ in range(2))
    train_loader, test_loader = (
        utils.init_dataloader(data, batch_size=2)
        for data in (train_data, test_data)
    )
    vae = models.iVAE(data_dim[1:], 2, invariances)
    trainer = trainers.SVItrainer(vae)
    # Snapshot via `dc` so later updates cannot alias the "before" weights
    # (presumably copy.deepcopy; confirm against the file's imports).
    weights_before = dc(vae.state_dict())
    for _ in range(n_steps):
        trainer.step(train_loader, test_loader)
    weights_after = vae.state_dict()
    losses = tt(trainer.loss_history["training_loss"])
    has_nan = torch.isnan(losses).any()
    assert_(not has_nan)
    weights_unchanged = assert_weights_equal(weights_before, weights_after)
    assert_(not weights_unchanged)
Ejemplo n.º 9
0
def test_trvae_decode_1d(invariances):
    """Decode one latent point with a 1D model and check the output shape.

    Builds an ``iVAE`` for signals of length 8, decodes a single
    two-component latent vector of zeros, and asserts that the squeezed
    output has shape ``(8,)``.

    ``invariances`` is forwarded to the model unchanged; it is presumably
    supplied by a pytest parametrization that is not visible here.
    """
    signal_shape = (8,)
    vae = models.iVAE(signal_shape, invariances=invariances)
    # Batch holding a single latent vector, i.e. shape (1, 2).
    z_origin = torch.tensor([[0.0, 0.0]])
    output = vae.decode(z_origin)
    assert_equal(output.squeeze().shape, signal_shape)