Example 1
    def __call__(self, iteration, wake_theta_loss, wake_phi_loss, elbo,
                 generative_model, inference_network, optimizer_theta,
                 optimizer_phi):
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} losses: theta = {:.3f}, phi = {:.3f}, elbo = '
                '{:.3f}'.format(iteration, wake_theta_loss, wake_phi_loss,
                                elbo))
            self.wake_theta_loss_history.append(wake_theta_loss)
            self.wake_phi_loss_history.append(wake_phi_loss)
            self.elbo_history.append(elbo)

        if iteration % self.checkpoint_interval == 0:
            stats_path = util.get_stats_path(self.save_dir)
            util.save_object(self, stats_path)
            util.save_checkpoint(self.save_dir,
                                 iteration,
                                 generative_model=generative_model,
                                 inference_network=inference_network)

        if iteration % self.eval_interval == 0:
            log_p, kl = eval_gen_inf(generative_model, inference_network,
                                     self.test_data_loader,
                                     self.eval_num_particles)
            self.log_p_history.append(log_p)
            self.kl_history.append(kl)

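            # estimate the spread of the stochastic gradients over 10
            # independent loss evaluations on the held-out observations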
            stats = util.OnlineMeanStd()
            for _ in range(10):
                generative_model.zero_grad()
                wake_theta_loss, elbo = losses.get_wake_theta_loss(
                    generative_model, inference_network, self.test_obs,
                    self.num_particles)
                wake_theta_loss.backward()
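                # clone the theta grads: the wake-phi backward below also
                # accumulates gradients into the generative model's .grad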
                theta_grads = [
                    p.grad.clone() for p in generative_model.parameters()
                ]

                inference_network.zero_grad()
                wake_phi_loss = losses.get_wake_phi_loss(
                    generative_model, inference_network, self.test_obs,
                    self.num_particles)
                wake_phi_loss.backward()
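                # consumed immediately by stats.update below, so no clone needed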
                phi_grads = [p.grad for p in inference_network.parameters()]

                stats.update(theta_grads + phi_grads)
            self.grad_std_history.append(stats.avg_of_means_stds()[1].item())
            util.print_with_time(
                'Iteration {} log_p = {:.3f}, kl = {:.3f}'.format(
                    iteration, self.log_p_history[-1], self.kl_history[-1]))
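This `__call__` matches the callback signature used by the `train_wake_sleep` loops below. A minimal wiring sketch follows; the `TrainingStats` class name and its constructor arguments are assumptions, since only the `__call__` body appears in the source.

# Hypothetical wiring: the class name and constructor arguments are
# assumptions; only the __call__ body above comes from the source.
stats = TrainingStats(logging_interval=10, checkpoint_interval=1000,
                      eval_interval=100, save_dir='save/',
                      test_data_loader=test_data_loader, test_obs=test_obs,
                      num_particles=10, eval_num_particles=100)
train_wake_sleep(generative_model, inference_network, data_loader,
                 num_iterations=10000, num_particles=10,
                 optim_kwargs={'lr': 1e-3}, callback=stats)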
Example 2
def train_wake_sleep(generative_model,
                     inference_network,
                     data_loader,
                     num_iterations,
                     num_particles,
                     optim_kwargs,
                     callback=None):
    optimizer_phi = torch.optim.Adam(inference_network.parameters(),
                                     **optim_kwargs)
    optimizer_theta = torch.optim.Adam(generative_model.parameters(),
                                       **optim_kwargs)

    iteration = 0
    while iteration < num_iterations:
        for obs in data_loader:
            # wake theta
            optimizer_phi.zero_grad()
            optimizer_theta.zero_grad()
            wake_theta_loss, elbo = losses.get_wake_theta_loss(
                generative_model, inference_network, obs, num_particles)
            wake_theta_loss.backward()
            optimizer_theta.step()

            # sleep phi
            optimizer_phi.zero_grad()
            optimizer_theta.zero_grad()
            sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                                   inference_network,
                                                   num_samples=obs.shape[0] *
                                                   num_particles)
            sleep_phi_loss.backward()
            optimizer_phi.step()

            if callback is not None:
                callback(iteration, wake_theta_loss.item(),
                         sleep_phi_loss.item(), elbo.item(), generative_model,
                         inference_network, optimizer_theta, optimizer_phi)

            iteration += 1
            # by this time, we have gone through `iteration` iterations
            if iteration == num_iterations:
                break
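For orientation, the wake-theta objective in reweighted wake-sleep is typically the negative importance-weighted ELBO. The sketch below is a generic illustration, not the body of `losses.get_wake_theta_loss`; the `.sample` and `.get_log_prob` methods are assumed model interfaces, not confirmed by the source.

import math

import torch


def wake_theta_loss_sketch(generative_model, inference_network, obs,
                           num_particles):
    # Importance-weighted ELBO: draw latents from q, weight them by p/q,
    # and log-average over particles. Method names are assumptions.
    latent = inference_network.sample(obs, num_particles)
    log_q = inference_network.get_log_prob(latent, obs)  # [num_particles, batch]
    log_p = generative_model.get_log_prob(latent, obs)   # [num_particles, batch]
    log_weight = log_p - log_q
    elbo = torch.mean(torch.logsumexp(log_weight, dim=0)
                      - math.log(num_particles))
    return -elbo, elbo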
Example 3
def train_wake_sleep(generative_model,
                     inference_network,
                     obss_data_loader,
                     num_iterations,
                     num_particles,
                     callback=None):
    num_samples = obss_data_loader.batch_size * num_particles
    optimizer_phi = torch.optim.Adam(inference_network.parameters())
    optimizer_theta = torch.optim.Adam(generative_model.parameters())
    obss_iter = iter(obss_data_loader)

    for iteration in range(num_iterations):
        # get the next batch, restarting the iterator once the loader is exhausted
        try:
            obss = next(obss_iter)
        except StopIteration:
            obss_iter = iter(obss_data_loader)
            obss = next(obss_iter)

        # wake theta
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        wake_theta_loss, elbo = losses.get_wake_theta_loss(
            generative_model, inference_network, obss, num_particles)
        wake_theta_loss.backward()
        optimizer_theta.step()

        # sleep phi
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                               inference_network, num_samples)
        sleep_phi_loss.backward()
        optimizer_phi.step()

        if callback is not None:
            callback(iteration, wake_theta_loss.item(), sleep_phi_loss.item(),
                     elbo.item(), generative_model, inference_network,
                     optimizer_theta, optimizer_phi)

    return optimizer_theta, optimizer_phi
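The sleep phase trains the inference network on data dreamed from the current generative model. Below is a generic sketch of this objective, not the actual `losses.get_sleep_loss`; the model methods named here are assumptions.

import torch


def sleep_phi_loss_sketch(generative_model, inference_network, num_samples):
    # Sample (latent, obs) pairs from the model's joint distribution,
    # then train q to recover the latents by maximizing log q(latent | obs).
    # The .sample_latent_and_obs / .get_log_prob methods are assumptions.
    latent, obs = generative_model.sample_latent_and_obs(num_samples)
    return -torch.mean(inference_network.get_log_prob(latent, obs))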
Example 4
File: train.py Project: yyht/rrws
def train_wake_sleep(generative_model,
                     inference_network,
                     true_generative_model,
                     batch_size,
                     num_iterations,
                     num_particles,
                     callback=None):
    num_samples = batch_size * num_particles
    optimizer_phi = torch.optim.Adam(inference_network.parameters())
    optimizer_theta = torch.optim.Adam(generative_model.parameters())

    for iteration in range(num_iterations):
        # generate synthetic data
        obss = [true_generative_model.sample_obs() for _ in range(batch_size)]

        # wake theta
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        wake_theta_loss, elbo = losses.get_wake_theta_loss(
            generative_model, inference_network, obss, num_particles)
        wake_theta_loss.backward()
        optimizer_theta.step()

        # sleep phi
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                               inference_network, num_samples)
        sleep_phi_loss.backward()
        optimizer_phi.step()

        if callback is not None:
            callback(iteration, wake_theta_loss.item(), sleep_phi_loss.item(),
                     elbo.item(), generative_model, inference_network,
                     optimizer_theta, optimizer_phi)

    return optimizer_theta, optimizer_phi
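Returning the optimizers lets a caller checkpoint optimizer state along with the models. A minimal sketch using standard PyTorch state dicts; the model variables and the filename are illustrative, not from the source.

import torch

optimizer_theta, optimizer_phi = train_wake_sleep(
    generative_model, inference_network, true_generative_model,
    batch_size=100, num_iterations=1000, num_particles=10)

# Standard PyTorch checkpointing of model and optimizer state;
# the filename is illustrative only.
torch.save({'generative_model': generative_model.state_dict(),
            'inference_network': inference_network.state_dict(),
            'optimizer_theta': optimizer_theta.state_dict(),
            'optimizer_phi': optimizer_phi.state_dict()},
           'checkpoint.pt')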