Example #1
0
def test_ved_decoder_sampler(sampler, expected_dist):
    """Check that the decoder sampler choice controls the "obs" distribution.

    A VED built with ``sampler_d=sampler`` should record, at the "obs" site
    of its model trace, a distribution whose ``base_dist`` is an instance of
    ``expected_dist``.
    """
    in_shape, out_shape = (8, 8), (8, )
    inputs = torch.randn(2, 1, *in_shape)
    targets = torch.randn(2, 1, *out_shape)
    ved = models.VED(in_shape, out_shape, sampler_d=sampler)
    _guide_tr, model_tr = get_traces(ved, inputs, targets)
    obs_base = model_tr.nodes["obs"]['fn'].base_dist
    assert_(isinstance(obs_base, expected_dist))
Example #2
0
def test_ved_sites_fn(input_dim, output_dim):
    """Check the distribution types at the VED sample sites.

    With a default-constructed VED:
    - the "z" site in both the model and guide traces has a Normal base
      distribution;
    - the model's "obs" site has a Bernoulli base distribution.
    """
    inputs = torch.randn(2, 1, *input_dim)
    targets = torch.randn(2, 1, *output_dim)
    ved = models.VED(input_dim, output_dim)
    guide_tr, model_tr = get_traces(ved, inputs, targets)
    # (trace, site name, expected base distribution class), checked in order.
    expectations = (
        (model_tr, "z", dist.Normal),
        (guide_tr, "z", dist.Normal),
        (model_tr, "obs", dist.Bernoulli),
    )
    for trace, site, dist_cls in expectations:
        assert_(isinstance(trace.nodes[site]['fn'].base_dist, dist_cls))
Example #3
0
def test_ved_sites_dims(input_dim, output_dim):
    """Check the shapes of the values recorded at the "z" and "obs" sites.

    "z" is expected to be ``(batch, 2)`` in both the model and guide traces.
    "obs" in the model trace is expected to be
    ``(batch, prod(output_dim))``, i.e. one flattened target per sample.
    """
    inputs = torch.randn(2, 1, *input_dim)
    targets = torch.randn(2, 1, *output_dim)
    ved = models.VED(input_dim, output_dim)
    guide_tr, model_tr = get_traces(ved, inputs, targets)
    latent_shape = (inputs.shape[0], 2)
    for trace in (model_tr, guide_tr):
        assert_equal(trace.nodes["z"]['value'].shape, latent_shape)
    n_flat = torch.prod(tt(output_dim)).item()
    assert_equal(model_tr.nodes["obs"]['value'].shape,
                 (targets.shape[0], n_flat))
def test_svi_trainer_ved(input_dim, output_dim):
    """Smoke-test SVI training of a VED model.

    Runs two ``SVItrainer.step`` passes over a small random dataset, then
    checks two things:
    - the recorded training losses contain no NaNs;
    - the model weights differ from a snapshot taken before training.
    """
    inputs = torch.randn(5, 1, *input_dim)
    targets = torch.randn(5, 1, *output_dim)
    loader = utils.init_dataloader(inputs,
                                   targets,
                                   batch_size=2)
    model = models.VED(input_dim, output_dim)
    trainer = trainers.SVItrainer(model)
    # Deep-copy so the "before" snapshot is independent of later updates.
    snapshot = dc(model.state_dict())
    n_epochs = 2
    for _ in range(n_epochs):
        trainer.step(loader)
    trained = model.state_dict()
    losses = tt(trainer.loss_history["training_loss"])
    assert_(not torch.isnan(losses).any())
    assert_(not assert_weights_equal(snapshot, trained))
Example #5
0
def test_ved_manifold2d(input_dim, output_dim):
    """Check the output shape of ``manifold2d(4, plot=True)``.

    The squeezed result should hold 4**2 = 16 decoded samples, each of
    shape ``output_dim``.
    """
    grid_side = 4
    ved = models.VED(input_dim, output_dim)
    grid = ved.manifold2d(grid_side, plot=True)
    assert_equal(grid.squeeze().shape, (grid_side ** 2, *output_dim))
Example #6
0
def test_ved_encode(input_dim, output_dim):
    """Check the shapes returned by ``encode``.

    The first returned item should have shape ``(batch, 2)``, and the second
    item should have the same shape as the first.
    """
    inputs = torch.randn(2, 1, *input_dim)
    ved = models.VED(input_dim, output_dim)
    encoded = ved.encode(inputs)
    first_shape = encoded[0].shape
    assert_equal(first_shape, (inputs.shape[0], 2))
    assert_equal(first_shape, encoded[1].shape)
Example #7
0
def test_ved_predict(input_dim, output_dim):
    """Check the shape of the first item returned by ``predict``.

    ``predict`` returns a 2-item result. After squeezing, its first item
    should have shape ``(batch, *output_dim)``.
    """
    batch_size = 2
    inputs = torch.randn(batch_size, 1, *input_dim)
    ved = models.VED(input_dim, output_dim)
    prediction, _unused = ved.predict(inputs)
    assert_equal(prediction.squeeze().shape, (batch_size, *output_dim))
Example #8
0
def test_ved_decode(input_dim, output_dim):
    """Check that decoding a single 2-D latent point yields an ``output_dim`` sample.

    The latent input is the origin, a ``(1, 2)`` tensor of zeros; the
    squeezed decoded output should have shape ``output_dim``.
    """
    origin = torch.zeros(1, 2)
    ved = models.VED(input_dim, output_dim)
    sample = ved.decode(origin).squeeze()
    assert_equal(sample.shape, output_dim)