def test_ctrvae_decode(invariances):
    """Decode one latent point with a class label and check the output shape.

    A conditional iVAE with 3 classes decodes the 2-D latent origin,
    paired with the one-hot label for class 0. After squeezing, the
    result must have the (8, 8) shape the model was built for.
    """
    n_classes = 3
    image_shape = (8, 8)
    conditional_vae = models.iVAE(
        image_shape, c_dim=n_classes, invariances=invariances)
    # A single latent point at the origin, with a batch axis of size 1.
    latent_point = torch.tensor([[0.0, 0.0]])
    label = utils.to_onehot(torch.tensor([0]), n_classes)
    reconstruction = conditional_vae.decode(latent_point, label)
    assert_equal(reconstruction.squeeze().shape, image_shape)
def test_trvae_decoder_sampler(sampler, expected_dist):
    """Check that ``sampler_d`` selects the observation distribution.

    A model built with ``sampler_d=sampler`` is traced on a random batch.
    The base distribution of its ``"obs"`` site must be an instance of
    ``expected_dist``.
    """
    batch_size, height, width = 2, 8, 8
    inputs = torch.randn(batch_size, height, width)
    vae = models.iVAE((height, width), coord=1, sampler_d=sampler)
    _, model_trace = get_traces(vae, inputs)
    obs_fn = model_trace.nodes["obs"]["fn"]
    assert_(isinstance(obs_fn.base_dist, expected_dist))
def test_trvae_manifold2d(invariances, num_classes):
    """Check the shape of the decoded 2-D latent grid from ``manifold2d``.

    When ``num_classes`` is positive, the grid is conditioned on the one-hot
    label for class 0; otherwise no label is passed. With ``d=4``, the test
    expects 16 decoded (8, 8) images.
    """
    image_shape = (8, 8)
    vae = models.iVAE(image_shape, c_dim=num_classes, invariances=invariances)
    label = (utils.to_onehot(torch.tensor([0]), num_classes)
             if num_classes > 0 else None)
    grid = vae.manifold2d(4, label, plot=True)
    assert_equal(grid.squeeze().shape, (16,) + image_shape)
def test_trvae_encode_1d(invariances):
    """Check the encoder output shapes for 1-D inputs.

    For 1-D data the test expects one extra latent dimension per invariance
    on top of the 2 requested latent dimensions. The first two elements
    returned by ``encode`` must have matching shapes.
    """
    n_samples, length = 3, 8
    signals = torch.randn(n_samples, length)
    n_coord = len(invariances) if invariances is not None else 0
    vae = models.iVAE((length,), 2, invariances=invariances)
    encoded = vae.encode(signals)
    first, second = encoded[0], encoded[1]
    assert_equal(first.shape, (n_samples, n_coord + 2))
    assert_equal(first.shape, second.shape)
def test_trvae_sites_fn(data_dim, invariances):
    """Check the distribution family at each traced Pyro site.

    The model and guide ``"latent"`` sites must be Normal. The model
    ``"obs"`` site must be Bernoulli.
    """
    inputs = torch.randn(*data_dim)
    vae = models.iVAE(data_dim[1:], invariances=invariances)
    guide_trace, model_trace = get_traces(vae, inputs)
    expectations = (
        (model_trace, "latent", dist.Normal),
        (guide_trace, "latent", dist.Normal),
        (model_trace, "obs", dist.Bernoulli),
    )
    for trace, site, family in expectations:
        assert_(isinstance(trace.nodes[site]["fn"].base_dist, family))
def test_trvae_encode_2d(invariances):
    """Check the encoder output shapes for 2-D (image) inputs.

    The test expects one extra latent dimension per invariance. For 2-D
    inputs, a translation invariance ('t') adds one more. These come on
    top of the 2 requested latent dimensions.
    """
    n_samples, height, width = 3, 8, 8
    images = torch.randn(n_samples, height, width)
    if invariances is None:
        n_coord = 0
    else:
        # Inputs here are always 2-D, so 't' always contributes the extra
        # coordinate.
        n_coord = len(invariances) + int('t' in invariances)
    vae = models.iVAE((height, width), 2, invariances=invariances)
    encoded = vae.encode(images)
    assert_equal(encoded[0].shape, (n_samples, n_coord + 2))
    assert_equal(encoded[0].shape, encoded[1].shape)
def test_trvae_sites_dims_1d(invariances):
    """Check the shapes of values sampled at each traced site for 1-D data.

    The model and guide ``"latent"`` values must be
    ``(n_samples, n_coord + 2)``. The ``"obs"`` value must have one column
    per input feature.
    """
    n_samples, length = 3, 8
    signals = torch.randn(n_samples, length)
    n_coord = 0 if invariances is None else len(invariances)
    vae = models.iVAE((length,), invariances=invariances)
    guide_trace, model_trace = get_traces(vae, signals)
    latent_shape = (n_samples, n_coord + 2)
    assert_equal(model_trace.nodes["latent"]["value"].shape, latent_shape)
    assert_equal(guide_trace.nodes["latent"]["value"].shape, latent_shape)
    # For a single feature axis, the flattened feature count is just `length`.
    assert_equal(model_trace.nodes["obs"]["value"].shape, (n_samples, length))
def test_svi_trainer_trvae(invariances):
    """Run two SVI training steps and check the loss and weight updates.

    The recorded training losses must contain no NaN. The model's
    parameters must differ from their pre-training snapshot.
    """
    data_dim = (5, 8, 8)
    train_data = torch.randn(*data_dim)
    test_data = torch.randn(*data_dim)
    train_loader, test_loader = [
        utils.init_dataloader(data, batch_size=2)
        for data in (train_data, test_data)
    ]
    vae = models.iVAE(data_dim[1:], 2, invariances)
    trainer = trainers.SVItrainer(vae)
    # state_dict() tensors alias the live parameters, so take a deep copy
    # before training mutates them.
    snapshot = dc(vae.state_dict())
    for _step in range(2):
        trainer.step(train_loader, test_loader)
    training_losses = tt(trainer.loss_history["training_loss"])
    assert_(not torch.isnan(training_losses).any())
    # NOTE(review): assert_weights_equal is used as a predicate here
    # (assumed to return a bool rather than raise) - confirm in the helper.
    assert_(not assert_weights_equal(snapshot, vae.state_dict()))
def test_trvae_decode_1d(invariances):
    """Decode the 2-D latent origin and check the 1-D output shape.

    No class label is passed. After squeezing, the decoded output must
    have the model's (8,) data shape.
    """
    signal_shape = (8,)
    vae = models.iVAE(signal_shape, invariances=invariances)
    latent_point = torch.tensor([[0.0, 0.0]])
    reconstruction = vae.decode(latent_point)
    assert_equal(reconstruction.squeeze().shape, signal_shape)