def test_ssregvae_disc_sites_dims(invariances):
    """Check the shape of the value recorded at the "y" site in both traces.

    Builds a ``models.ss_reg_iVAE`` for per-sample shape ``data_dim[1:]``,
    obtains guide and model traces via ``get_traces`` on a random batch of
    flattened samples, and asserts that the "y" node value has shape
    ``(batch, 3)`` in each trace.

    Args:
        invariances: Forwarded unchanged as the ``invariances`` keyword of
            ``models.ss_reg_iVAE``.
    """
    # First entry is the batch size; the remaining entries are the
    # per-sample shape handed to the model constructor.
    data_dim = (3, 8, 8)
    batch_size = data_dim[0]
    sample_shape = data_dim[1:]

    # Each sample is flattened into one row of prod(sample_shape) features.
    x = torch.randn(batch_size, torch.prod(tt(sample_shape)).item())

    # NOTE(review): the third positional argument (3) presumably sets the
    # width of the "y" site, since it matches the asserted shape below --
    # confirm against the ss_reg_iVAE signature.
    model = models.ss_reg_iVAE(sample_shape, 2, 3, invariances=invariances)
    guide_trace, model_trace = get_traces(model, x)

    expected_shape = (batch_size, 3)
    assert_equal(model_trace.nodes["y"]['value'].shape, expected_shape)
    assert_equal(guide_trace.nodes["y"]['value'].shape, expected_shape)
def test_ssregvae_reg_sites_fn(invariances):
    """Check the distribution type attached to the "y" site in both traces.

    Builds a ``models.ss_reg_iVAE`` for per-sample shape ``data_dim[1:]``,
    obtains guide and model traces via ``get_traces`` on a random batch of
    flattened samples, and asserts that the ``base_dist`` of the "y" node's
    ``fn`` is a ``dist.Normal`` instance in each trace.

    Args:
        invariances: Forwarded unchanged as the ``invariances`` keyword of
            ``models.ss_reg_iVAE``.
    """
    # First entry is the batch size; the remaining entries are the
    # per-sample shape handed to the model constructor.
    data_dim = (3, 8, 8)
    sample_shape = data_dim[1:]

    # Each sample is flattened into one row of prod(sample_shape) features.
    x = torch.randn(data_dim[0], torch.prod(tt(sample_shape)).item())

    model = models.ss_reg_iVAE(sample_shape, 2, 3, invariances=invariances)
    guide_trace, model_trace = get_traces(model, x)

    # The site's 'fn' is expected to wrap its underlying distribution and
    # expose it as ``base_dist``; that wrapped object is what gets checked.
    assert_(isinstance(model_trace.nodes["y"]['fn'].base_dist, dist.Normal))
    assert_(isinstance(guide_trace.nodes["y"]['fn'].base_dist, dist.Normal))
def test_auxsvi_trainer_reg(c_dim, invariances):
    """Smoke-test ``trainers.auxSVItrainer`` in regression mode.

    Runs two trainer steps on random data and asserts that the recorded
    training loss contains no NaN and that the model's state dict no longer
    compares equal to the snapshot taken before training.

    Args:
        c_dim: Number of columns in the random target tensor; also passed as
            the third positional argument of ``models.ss_reg_iVAE``.
        invariances: Passed as the fourth positional argument of
            ``models.ss_reg_iVAE``.
    """
    data_dim = (5, 8, 8)
    n_samples = data_dim[0]
    sample_shape = data_dim[1:]
    n_features = torch.prod(tt(sample_shape)).item()

    # Random draws are kept in this exact order so RNG consumption is stable:
    # unlabeled data, noise for the labeled copy, then the targets.
    unlabeled = torch.randn(n_samples, n_features)
    labeled = unlabeled + .1 * torch.randn_like(unlabeled)
    targets = torch.randn(n_samples, c_dim)

    # The same (labeled, targets) pair is used for supervised and validation.
    labeled_pair = (labeled, targets)
    loader_unsup, loader_sup, loader_val = utils.init_ssvae_dataloaders(
        unlabeled, labeled_pair, labeled_pair, batch_size=2)

    net = models.ss_reg_iVAE(sample_shape, 2, c_dim, invariances)
    svi_trainer = trainers.auxSVItrainer(net, task="regression")

    # Deep-copy the snapshot so later parameter updates cannot alias into it.
    initial_weights = dc(net.state_dict())
    n_steps = 2
    for _ in range(n_steps):
        svi_trainer.step(loader_unsup, loader_sup, loader_val)
    trained_weights = net.state_dict()

    loss_history = tt(svi_trainer.history["training_loss"])
    has_nan = torch.isnan(loss_history).any()
    assert_(not has_nan)

    weights_unchanged = assert_weights_equal(initial_weights, trained_weights)
    assert_(not weights_unchanged)