示例#1
0
 def test_renyi_nonreparameterized_vectorized(self):
     """Non-reparameterized RenyiELBO (alpha=0.2), two vectorized particles."""
     elbo = RenyiELBO(max_plate_nesting=1,
                      num_particles=2,
                      vectorize_particles=True,
                      alpha=0.2)
     self.do_elbo_test(False, 5000, elbo)
示例#2
0
def test_non_nested_plating_sum():
    """Regression test for https://github.com/pyro-ppl/pyro/issues/2361.

    The latent ``x`` lives in its own plate and shares no plate with the
    observations. The ``data`` site sits inside two plates built from the
    last two dimensions of ``data``. A single SVI step with a vectorized
    RenyiELBO must complete without raising.
    """

    def model(data, weights):
        # Generative model: data = x @ weights + eps
        prior_mu = torch.tensor(1.0)
        noise = torch.tensor(0.1)

        # One latent per row of ``weights``; no dimensions shared with data.
        with pyro.plate('x_plate', weights.shape[0]):
            x = pyro.sample('x', pyro.distributions.Normal(prior_mu, noise))

        # Observations: mean ``x @ weights``, plated over data's last two dims.
        n_inner, n_outer = data.shape[-1], data.shape[-2]
        with pyro.plate('data_plate_1', n_inner), pyro.plate('data_plate_2', n_outer):
            likelihood = pyro.distributions.Normal(x @ weights, noise)
            pyro.sample('data', likelihood, obs=data)

    def guide(data, weights):
        x_loc = pyro.param('x_loc', torch.tensor(0.5))
        x_scale = torch.tensor(0.1)

        with pyro.plate('x_plate', weights.shape[0]):
            pyro.sample('x', pyro.distributions.Normal(x_loc, x_scale))

    data = torch.randn([5, 3])
    weights = torch.randn([2, 3])
    svi = SVI(model, guide, optim.Adam({"lr": 0.01}),
              RenyiELBO(num_particles=30, vectorize_particles=True))

    for step in range(1):
        step_loss = svi.step(data, weights)
        if not step % 20:
            logger.info("step {} loss = {:0.4g}".format(step, step_loss))
示例#3
0
 def test_renyi_reparameterized_vectorized(self):
     """Reparameterized RenyiELBO (default alpha), two vectorized particles."""
     elbo = RenyiELBO(max_plate_nesting=1,
                      vectorize_particles=True,
                      num_particles=2)
     self.do_elbo_test(True, 5000, elbo)
示例#4
0
def test_sequential_plating_sum():
    """Regression test for https://github.com/pyro-ppl/pyro/issues/2361.

    A single Bernoulli latent ``x`` drives a sequential (Python-loop) plate
    of Normal observations. One SVI step with a vectorized RenyiELBO
    (alpha=0) must complete without raising.
    """
    def model(data):
        x = pyro.sample('x', dist.Bernoulli(torch.tensor(0.5)))
        # Iterating a plate yields integer indices: one sample site per point.
        for idx in pyro.plate('data_plate', len(data)):
            site = 'data_{:d}'.format(idx)
            pyro.sample(site, dist.Normal(x, scale=torch.tensor(0.1)), obs=data[idx])

    def guide(data):
        prob = pyro.param('p', torch.tensor(0.5))
        pyro.sample('x', pyro.distributions.Bernoulli(prob))

    # Two groups of five points: standard normal, and standard normal shifted by 1.
    data = torch.cat([torch.randn([5]), 1. + torch.randn([5])])
    svi = SVI(model, guide, optim.Adam({"lr": 0.01}),
              RenyiELBO(alpha=0, num_particles=30, vectorize_particles=True))

    for step in range(1):
        step_loss = svi.step(data)
        if not step % 20:
            logger.info("step {} loss = {:0.4g}".format(step, step_loss))
示例#5
0
 def test_renyi_nonreparameterized(self):
     """Non-reparameterized RenyiELBO (default alpha) with three particles."""
     elbo = RenyiELBO(num_particles=3)
     self.do_elbo_test(False, 7500, elbo)
示例#6
0
 def test_renyi_reparameterized(self):
     """Reparameterized RenyiELBO (default alpha) with three particles."""
     elbo = RenyiELBO(num_particles=3)
     self.do_elbo_test(True, 2500, elbo)
示例#7
0
 def test_renyi_nonreparameterized(self):
     """Non-reparameterized RenyiELBO with alpha=0.2 and two particles."""
     elbo = RenyiELBO(num_particles=2, alpha=0.2)
     self.do_elbo_test(False, 5000, elbo)
示例#8
0
def _running_min(best, value):
    """Return ``min(best, value)``, treating ``best is None`` as "no value yet".

    Folding this over a sequence gives the same result as calling the
    built-in ``min`` on the whole sequence, in O(1) per element.
    """
    return value if best is None else min(best, value)


def _print_state(guide):
    """Print every Pyro param-store entry, then the values reported by the guide.

    Shared by the "Initial values" and "Final values" reports of ``infer_CS``.
    ``guide()[1]`` must support ``.items()``. NOTE(review): it is presumably
    a mapping of site name to value; confirm against ``init_guide``.
    """
    print("Parameter store:")
    for name, value in pyro.get_param_store().items():
        print(name + ": " + str(value))
    print()
    print("Guide:")
    for name, value in guide()[1].items():
        print(name + ": " + str(value))
    print()


def infer_CS(
    cond_model,
    guide_conf,
    guidefile,
    n_wake,
    n_sleep,
    gen_samples=True,
    n_rounds=1,
    n_simulate=1000,
    sleep_batch_size=3,
    wake_batch_size=1,
    alpha=1.,
    lr_wake=1e-3,
    lr_sleep=1e-3,
    n_write=300,
    device="cpu",
    verbose=True,
):
    """Contrastive inference: alternate wake (SVI) and sleep (ConLearn) phases.

    Each of ``n_rounds`` rounds has two phases:

    - Wake: ``n_wake`` SVI steps on ``cond_model``.
    - Sleep: when ``n_sleep > 0``, one ``ConLearn.simulate`` call followed by
      ``n_sleep`` contrastive-learning steps.

    The Pyro parameter store is saved to ``guidefile`` every ``n_write``
    iterations of either phase. Presumably this is so runs can be resumed
    through ``init_guide(guidefile=...)``. ``save_guide(guidefile)`` is
    called once training finishes.

    Parameters
    ----------
    cond_model : callable
        Lensing system model conditioned on an observed image.
    guide_conf : dict
        Guide configuration passed to ``init_guide``. It must contain a
        ``'sleep_sites'`` entry, which is forwarded to ``ConLearn`` as
        ``site_names``.
    guidefile : str
        Passed to ``init_guide``; also the target of the periodic
        parameter-store saves and of the final ``save_guide`` call.
    n_wake : int
        Number of wake (SVI) steps per round.
    n_sleep : int
        Number of sleep (``ConLearn.step``) steps per round. When 0, no
        simulation is run and no sleep steps are taken.
    gen_samples : bool
        Forwarded to the first ``ConLearn.simulate`` call. Later calls
        receive ``False``.
    n_rounds : int
        Number of wake/sleep rounds.
    n_simulate : int
        First argument to ``ConLearn.simulate`` (called with
        ``replace=True``). Presumably the number of simulated samples.
    sleep_batch_size : int
        ``training_batch_size`` of the ``ConLearn`` sleep trainer.
    wake_batch_size : int
        ``num_particles`` of the wake loss.
    alpha : float
        Order of the Renyi divergence used by the wake loss. ``1.`` selects
        the standard ``Trace_ELBO``; any other value uses ``RenyiELBO``.
    lr_wake, lr_sleep : float
        Adam learning rates of the wake and sleep optimizers.
    n_write : int
        Save the parameter store every ``n_write`` iterations of each phase.
        Must be non-zero.
    device : str
        Device passed to ``init_guide``.
    verbose : bool
        If true, print the parameter store and guide values before and after
        training.

    Returns
    -------
    None
        Results are persisted to ``guidefile``; nothing is returned.
    """

    # Initialize VI model and guide
    guide = init_guide(cond_model,
                       guide_conf,
                       guidefile=guidefile,
                       device=device,
                       default_observations=True)

    wake_optimizer = Adam({
        "lr": lr_wake,
        "amsgrad": False,
        "weight_decay": 0.0
    })
    sleep_optimizer = Adam({
        "lr": lr_sleep,
        "amsgrad": False,
        "weight_decay": 0.0
    })

    site_names = guide_conf['sleep_sites']

    conlearn = ConLearn(cond_model,
                        guide,
                        sleep_optimizer,
                        training_batch_size=sleep_batch_size,
                        site_names=site_names)

    # Default is standard ELBO loss. RenyiELBO does not accept alpha == 1
    # (that order reduces to the standard ELBO), so use Trace_ELBO there.
    if alpha == 1.:
        wake_loss = Trace_ELBO(num_particles=wake_batch_size)
    else:
        wake_loss = RenyiELBO(alpha=alpha, num_particles=wake_batch_size)

    svi = SVI(cond_model, guide, wake_optimizer, loss=wake_loss)

    if verbose:
        print()
        print("##################")
        print("# Initial values #")
        print("##################")
        _print_state(guide)

    print("###################")
    print("# Wake and Sleep. #")
    print("###################")

    # Lowest loss seen so far in each phase, across all rounds. A running
    # minimum replaces re-scanning the full loss history on every step.
    wake_min = None
    sleep_min = None

    # Rounds
    for r in range(n_rounds):
        print("\nRound %i:" % r)

        # Wake phase
        with tqdm(total=n_wake, desc='Wake') as t:
            guide.wake()
            for i in range(n_wake):
                # Checkpoint before the step on every n_write-th iteration.
                if (i + 1) % n_write == 0:
                    pyro.get_param_store().save(guidefile)

                loss = svi.step()

                wake_min = _running_min(wake_min, loss)
                t.postfix = "loss=%.3f (%.3f)" % (loss, wake_min)
                t.update()

        # Sleep phase
        if n_sleep > 0:
            conlearn.simulate(n_simulate,
                              replace=True,
                              gen_samples=gen_samples)
            gen_samples = False  # Only in first round

        with tqdm(total=n_sleep, desc='Sleep') as t:
            guide.sleep()
            for i in range(n_sleep):
                if (i + 1) % n_write == 0:
                    pyro.get_param_store().save(guidefile)

                loss = conlearn.step()
                sleep_min = _running_min(sleep_min, loss)
                t.postfix = "loss=%.3f (%.3f)" % (loss, sleep_min)
                t.update()

    if verbose:
        print()
        print("################")
        print("# Final values #")
        print("################")
        _print_state(guide)

    save_guide(guidefile)
# MNIST loading (disabled). NOTE(review): the commented-out lines below appear
# to be left over from an MNIST version of this script; the data actually used
# comes from create_simple_classification_dataset further down.

# train = MNIST("MNIST", train=True, download=True, transform=transforms.Compose([transforms.ToTensor(),]), )
# test = MNIST("MNIST", train=False, download=True, transform=transforms.Compose([transforms.ToTensor(),]), )
# dataloader_args = dict(shuffle=True, batch_size=batch_size, num_workers=1, pin_memory=False)
# train_loader = dataloader.DataLoader(train, **dataloader_args)
# test_loader = dataloader.DataLoader(test, **dataloader_args)
num_epochs = 500  # iteration count of the outer `for epoch` loop below
num_samples = 100  # NOTE(review): not read in the visible code; presumably used later, verify


## Inference

# We can now launch the inference.

# SVI driver: fits `guide` to `model` with Adam (lr=1e-3) and a Renyi
# divergence objective of order alpha=0.5.
inference = SVI(model, guide, Adam({"lr": 0.001}), loss=RenyiELBO(alpha=.5))
num_schedules = 50  # number of schedules requested from create_simple_classification_dataset
data, labels = create_simple_classification_dataset(num_schedules)
# Start offset of each schedule: 0, 20, ..., 20 * (num_schedules - 1), as floats.
# NOTE(review): assumes every schedule spans 20 time steps; confirm against
# create_simple_classification_dataset.
schedule_starts = np.linspace(0, 20 * (num_schedules-1), num=num_schedules)
not_first_time = False  # NOTE(review): not read in the visible code; presumably a flag used later
distributions = [np.array([.5, .1], dtype=float) for _ in range(num_schedules)]  # one [mean, sigma] pair per schedule, initialized to [0.5, 0.1]

print('Inference')
for epoch in range(num_epochs):
    # for j, (imgs, lbls) in enumerate(train_loader, 0):
    #     loss = inference.step(imgs.to(device), lbls.to(device))
    for _ in range(num_schedules):
        x_data = []
        y_data = []
        chosen_schedule_start = int(np.random.choice(schedule_starts))
        schedule_num = int(chosen_schedule_start / 20)