Ejemplo n.º 1
0
def test_ssregvae_disc_sites_dims(invariances):
    """Check the shape of the "y" sample site of ``ss_reg_iVAE``.

    Despite the "disc" in the name, this test only inspects the "y" site.
    A batch of 3 random flattened 8x8 inputs is traced through the model
    and the guide. The "y" value in both traces must have shape
    ``(batch, 3)``, where 3 is the third positional argument passed to
    ``ss_reg_iVAE`` (presumably the regression output dimension — TODO
    confirm).

    Args:
        invariances: forwarded to ``ss_reg_iVAE`` by keyword. Presumably
            supplied by a pytest parametrization outside this view
            (``None`` or a collection of invariance codes) — TODO confirm.
    """
    data_dim = (3, 8, 8)
    # Flatten each 8x8 input into a single 64-element row: x is (3, 64).
    x = torch.randn(data_dim[0], torch.prod(tt(data_dim[1:])).item())
    model = models.ss_reg_iVAE(data_dim[1:], 2, 3, invariances=invariances)
    guide_trace, model_trace = get_traces(model, x)
    expected_shape = (data_dim[0], 3)
    assert_equal(model_trace.nodes["y"]['value'].shape, expected_shape)
    assert_equal(guide_trace.nodes["y"]['value'].shape, expected_shape)
Ejemplo n.º 2
0
def test_ssregvae_reg_sites_fn(invariances):
    """Check the distribution type of the "y" site of ``ss_reg_iVAE``.

    A batch of 3 random flattened 8x8 inputs is traced through the model
    and the guide. In both traces, the distribution recorded for the "y"
    site must wrap a ``dist.Normal`` as its ``base_dist``.

    Args:
        invariances: forwarded to ``ss_reg_iVAE`` by keyword. Presumably
            supplied by a pytest parametrization outside this view — TODO
            confirm.
    """
    data_dim = (3, 8, 8)
    # Flatten each 8x8 input into a single 64-element row: x is (3, 64).
    x = torch.randn(data_dim[0], torch.prod(tt(data_dim[1:])).item())
    model = models.ss_reg_iVAE(data_dim[1:], 2, 3, invariances=invariances)
    guide_trace, model_trace = get_traces(model, x)
    # The site's 'fn' exposes a `base_dist` attribute, which is the one checked.
    assert_(isinstance(model_trace.nodes["y"]['fn'].base_dist, dist.Normal))
    assert_(isinstance(guide_trace.nodes["y"]['fn'].base_dist, dist.Normal))
Ejemplo n.º 3
0
def test_auxsvi_trainer_reg(c_dim, invariances):
    """Smoke-test ``auxSVItrainer`` on ``ss_reg_iVAE`` in regression mode.

    Trains for two steps on small random data, then checks two things:
    the recorded training loss contains no NaNs, and the model weights
    changed during training.

    Args:
        c_dim: width of the random target matrix ``gt``. It is also passed
            as the third argument of ``ss_reg_iVAE``.
        invariances: forwarded to ``ss_reg_iVAE`` by keyword, as in the
            other ``ss_reg_iVAE`` tests. A positional argument would bind
            to whatever the 4th constructor parameter happens to be.
    """
    data_dim = (5, 8, 8)
    n_samples = data_dim[0]
    # Each 8x8 input is flattened into one 64-element row.
    train_unsup = torch.randn(n_samples, torch.prod(tt(data_dim[1:])).item())
    # The "supervised" inputs are a slightly noisy copy of the unlabeled ones.
    train_sup = train_unsup + .1 * torch.randn_like(train_unsup)
    gt = torch.randn(n_samples, c_dim)

    # The same (inputs, targets) pair is reused for the validation loader.
    loader_unsup, loader_sup, loader_val = utils.init_ssvae_dataloaders(
        train_unsup, (train_sup, gt), (train_sup, gt), batch_size=2)
    vae = models.ss_reg_iVAE(data_dim[1:], 2, c_dim, invariances=invariances)
    trainer = trainers.auxSVItrainer(vae, task="regression")
    # Deep-copy the snapshot. state_dict() returns references to the live
    # parameter tensors, so an uncopied snapshot would track the updates.
    weights_before = dc(vae.state_dict())
    for _ in range(2):
        trainer.step(loader_unsup, loader_sup, loader_val)
    weights_after = vae.state_dict()
    assert_(not torch.isnan(tt(trainer.history["training_loss"])).any())
    # assert_weights_equal is defined elsewhere. Here it is used as a
    # predicate that is truthy when the two state dicts match.
    assert_(not assert_weights_equal(weights_before, weights_after))