def test_sstrvae_disc_sites_fn(invariances):
    """Check that the discrete ``y`` site is a OneHotCategorical in both traces.

    Builds an ``ssiVAE`` for 8x8 inputs (2 latent dimensions, 3 classes),
    runs it through ``get_enum_traces`` on a random batch of 3 flattened
    images, and asserts that the ``y`` site of both the model trace and the
    guide trace is backed by ``dist.OneHotCategorical``.

    Args:
        invariances: Forwarded unchanged to ``models.ssiVAE`` (``None`` or a
            collection of invariance codes; presumably supplied by a pytest
            parametrize decorator).
    """
    data_dim = (3, 8, 8)
    x = torch.randn(data_dim[0], torch.prod(tt(data_dim[1:])).item())
    # Only the distribution type of ``y`` is checked, so no latent-size
    # bookkeeping for the invariances is needed here.
    model = models.ssiVAE(data_dim[1:], 2, 3, invariances=invariances)
    guide_trace, model_trace = get_enum_traces(model, x)
    assert_(isinstance(model_trace.nodes["y"]['fn'], dist.OneHotCategorical))
    assert_(isinstance(guide_trace.nodes["y"]['fn'], dist.OneHotCategorical))
def test_sstrvae_disc_sites_dims(invariances):
    """Check the shape of the enumerated discrete ``y`` site in both traces.

    Builds an ``ssiVAE`` for 8x8 inputs with ``num_classes`` classes, runs it
    through ``get_enum_traces`` on a random batch of 3 flattened images, and
    asserts that the value recorded at ``y`` has shape
    ``(num_classes, batch, num_classes)`` in both the model and guide traces.

    Args:
        invariances: Forwarded unchanged to ``models.ssiVAE`` (``None`` or a
            collection of invariance codes; presumably supplied by a pytest
            parametrize decorator).
    """
    data_dim = (3, 8, 8)
    num_classes = 3
    x = torch.randn(data_dim[0], torch.prod(tt(data_dim[1:])).item())
    model = models.ssiVAE(
        data_dim[1:], 2, num_classes, invariances=invariances)
    guide_trace, model_trace = get_enum_traces(model, x)
    # Leading axis is presumably the enumeration axis introduced by
    # ``get_enum_traces`` (one entry per class); the trailing axis is the
    # one-hot encoding. TODO confirm against get_enum_traces.
    expected_shape = (num_classes, data_dim[0], num_classes)
    assert_equal(model_trace.nodes["y"]['value'].shape, expected_shape)
    assert_equal(guide_trace.nodes["y"]['value'].shape, expected_shape)
def test_sstrvae_encode(invariances):
    """Check the shapes of the tensors returned by ``ssiVAE.encode``.

    The first two outputs must share the shape ``(batch, coord + latent_dim)``.
    Here ``coord`` is the number of extra latent dimensions implied by
    ``invariances``: one per invariance code, plus one more for ``'t'`` when
    the input is 2D. The third output must have shape ``(batch,)``.

    Args:
        invariances: Forwarded unchanged to ``models.ssiVAE`` (``None`` or a
            collection of invariance codes; presumably supplied by a pytest
            parametrize decorator).
    """
    data_dim = (3, 8, 8)
    latent_dim = 2
    num_classes = 5
    x = torch.randn(data_dim[0], torch.prod(tt(data_dim[1:])).item())
    coord = 0
    if invariances is not None:
        coord = len(invariances)
        # On 2D data 't' contributes one latent dimension beyond the single
        # slot ``len(invariances)`` already counted for it (presumably one
        # offset per image axis).
        if 't' in invariances and len(data_dim[1:]) == 2:
            coord += 1
    model = models.ssiVAE(
        data_dim[1:], latent_dim, num_classes, invariances=invariances)
    encoded = model.encode(x)
    assert_equal(encoded[0].shape, encoded[1].shape)
    assert_equal(encoded[0].shape, (data_dim[0], coord + latent_dim))
    assert_equal(encoded[2].shape, (data_dim[0], ))
def test_auxsvi_trainer_cls(invariances):
    """Two ``auxSVItrainer`` steps with a validation loader must train the model.

    The test builds three inputs: random unlabeled data, a slightly noisy
    labeled copy with random one-hot labels over 3 classes, and a validation
    set that reuses the labeled copy. It trains an ``ssiVAE`` for 8x8 inputs
    for two steps on these. It then requires that the recorded training loss
    contains no NaNs. It also requires that the model's ``state_dict``
    differs from the snapshot taken before training.

    Args:
        invariances: Passed positionally to ``models.ssiVAE`` as its fourth
            argument (presumably supplied by a pytest parametrize decorator).
    """
    data_dim = (5, 8, 8)
    n_samples, image_shape = data_dim[0], data_dim[1:]
    n_pixels = torch.prod(tt(image_shape)).item()

    train_unsup = torch.randn(n_samples, n_pixels)
    train_sup = train_unsup + .1 * torch.randn_like(train_unsup)
    labels = dist.OneHotCategorical(torch.ones(n_samples, 3)).sample()
    loaders = utils.init_ssvae_dataloaders(
        train_unsup, (train_sup, labels), (train_sup, labels), batch_size=2)
    loader_unsup, loader_sup, loader_val = loaders

    vae = models.ssiVAE(image_shape, 2, 3, invariances)
    trainer = trainers.auxSVItrainer(vae)
    # ``dc`` (presumably copy.deepcopy) snapshots the weights, because
    # state_dict() tensors share storage with the live parameters.
    weights_before = dc(vae.state_dict())
    for _step in range(2):
        trainer.step(loader_unsup, loader_sup, loader_val)
    weights_after = vae.state_dict()

    losses = tt(trainer.history["training_loss"])
    assert_(not torch.isnan(losses).any())
    # ``assert_weights_equal`` is used for its truth value: training must
    # have made the weights differ.
    assert_(not assert_weights_equal(weights_before, weights_after))
def test_auxsvi_trainer_swa(invariances):
    """Averaging saved ``encoder_y`` snapshots must change the encoder weights.

    An ``ssiVAE`` for 8x8 inputs is trained for three steps on random
    unlabeled data and a noisy labeled copy, with a running copy of the
    ``encoder_y`` weights saved after each step. ``average_weights`` is then
    applied to ``encoder_y``. The resulting weights must differ from the ones
    captured just before averaging.

    Args:
        invariances: Passed positionally to ``models.ssiVAE`` as its fourth
            argument (presumably supplied by a pytest parametrize decorator).
    """
    data_dim = (5, 8, 8)
    n_samples, image_shape = data_dim[0], data_dim[1:]
    n_pixels = torch.prod(tt(image_shape)).item()

    train_unsup = torch.randn(n_samples, n_pixels)
    train_sup = train_unsup + .1 * torch.randn_like(train_unsup)
    labels = dist.OneHotCategorical(torch.ones(n_samples, 3)).sample()
    # The third (validation) loader is not used by this test.
    loader_unsup, loader_sup, _ = utils.init_ssvae_dataloaders(
        train_unsup, (train_sup, labels), (train_sup, labels), batch_size=2)

    vae = models.ssiVAE(image_shape, 2, 3, invariances)
    trainer = trainers.auxSVItrainer(vae)
    for _step in range(3):
        trainer.step(loader_unsup, loader_sup)
        trainer.save_running_weights("encoder_y")

    # ``dc`` (presumably copy.deepcopy) is required: state_dict() tensors
    # share storage with the live parameters.
    weights_final = dc(vae.encoder_y.state_dict())
    trainer.average_weights("encoder_y")
    weights_aver = vae.encoder_y.state_dict()
    weights_unchanged = assert_weights_equal(weights_final, weights_aver)
    assert_(not weights_unchanged)
def test_sstrvae_decoder_sampler(sampler, expected_dist):
    """The observation site ``x`` must use the distribution chosen by ``sampler_d``.

    Builds an ``ssiVAE`` for 1D inputs of length 64 with the given decoder
    sampler. It then runs the model on a random batch of 2 samples through
    ``get_enum_traces``. Finally, it checks the base distribution of the
    ``x`` site in the model trace.

    Args:
        sampler: Value forwarded as ``sampler_d`` to ``models.ssiVAE``.
        expected_dist: Distribution class the ``x`` site's ``base_dist`` must
            be an instance of.
    """
    n_samples, n_features = 2, 64
    x = torch.randn(n_samples, n_features)
    # NOTE(review): ssiVAE is configured via ``invariances=`` elsewhere in
    # this file; confirm ``coord`` is still a keyword ssiVAE honours.
    model = models.ssiVAE((n_features,), 2, 3, coord=1, sampler_d=sampler)
    _, model_trace = get_enum_traces(model, x)
    obs_fn = model_trace.nodes["x"]['fn']
    # The site's distribution exposes ``base_dist``; the check targets that
    # underlying distribution.
    assert_(isinstance(obs_fn.base_dist, expected_dist))