Example 1
def train_sleep(generative_model, inference_network, num_samples, num_iterations, log_interval):
    """Pretrain the inference network with sleep-phase updates only, logging the
    sleep loss (and peak GPU memory when on CUDA) every `log_interval` iterations."""
    optimizer_phi = torch.optim.Adam(inference_network.parameters())
    sleep_losses = []
    device = next(generative_model.parameters()).device
    if device.type == "cuda":
        torch.cuda.reset_max_memory_allocated(device=device)

    util.logging.info("Pretraining with sleep")
    iteration = 0
    while iteration < num_iterations:
        optimizer_phi.zero_grad()
        sleep_phi_loss = losses.get_sleep_loss(generative_model, inference_network, num_samples)
        sleep_phi_loss.backward()
        optimizer_phi.step()

        sleep_losses.append(sleep_phi_loss.item())
        iteration += 1
        # by this time, we have gone through `iteration` iterations
        if iteration % log_interval == 0:
            util.logging.info(
                "it. {}/{} | sleep loss = {:.2f} | "
                "GPU memory = {:.2f} MB".format(
                    iteration,
                    num_iterations,
                    sleep_losses[-1],
                    (
                        torch.cuda.max_memory_allocated(device=device) / 1e6
                        if device.type == "cuda"
                        else 0
                    ),
                )
            )
        if iteration == num_iterations:
            break
Example 2
File: util.py Project: yyht/rrws
def get_q_error(generative_model, inference_network, num_samples=100):
    """Expected KL(posterior || q) + const as a measure of q's quality.

    Returns: detached scalar E_p(x)[KL(p(z | x) || q(z | x))] + H(z | x) where
        the second term is constant wrt the inference network.
    """

    return losses.get_sleep_loss(generative_model, inference_network,
                                 num_samples).detach()
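For reference, the identity behind this docstring, assuming get_sleep_loss estimates the standard sleep objective E_{p_theta(x,z)}[-log q_phi(z | x)], can be written as

\mathbb{E}_{p_\theta(x,z)}\!\left[-\log q_\phi(z \mid x)\right]
    = \mathbb{E}_{p_\theta(x)}\!\left[\mathrm{KL}\!\left(p_\theta(z \mid x) \,\|\, q_\phi(z \mid x)\right)\right]
    + \mathbb{E}_{p_\theta(x)}\!\left[\mathrm{H}\!\left(p_\theta(z \mid x)\right)\right],

where the conditional entropy term does not depend on the inference network parameters, which is why the detached sleep loss doubles as a (shifted) measure of q's quality.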
Example 3
def get_grads_correct_sleep(seed):
    """Reference gradients: one backward pass per loss, with the phi gradient taken
    as a `wake_factor`-weighted mixture of the wake-phi and sleep-phi gradients."""
    util.set_seed(seed)

    theta_grads_correct = []
    phi_grads_correct = []

    log_weight, log_q = losses.get_log_weight_and_log_q(
        generative_model, inference_network, obs, num_particles)

    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
        log_weight)
    wake_theta_loss.backward(retain_graph=True)
    theta_grads_correct = [
        parameter.grad.clone() for parameter in generative_model.parameters()
    ]
    # in rws, we step as we compute the grads
    # optimizer_theta.step()

    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
        log_weight, log_q)
    wake_phi_loss.backward()
    wake_phi_grads_correct = [
        parameter.grad.clone() for parameter in inference_network.parameters()
    ]
    # in rws, we step as we compute the grads
    # optimizer_phi.step()

    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                           inference_network,
                                           num_samples=num_particles)
    sleep_phi_loss.backward()
    sleep_phi_grads_correct = [
        parameter.grad.clone() for parameter in inference_network.parameters()
    ]

    wake_factor = 0.755
    phi_grads_correct = [
        wake_factor * wake_phi_grad_correct +
        (1 - wake_factor) * sleep_phi_grad_correct
        for wake_phi_grad_correct, sleep_phi_grad_correct in zip(
            wake_phi_grads_correct, sleep_phi_grads_correct)
    ]

    return theta_grads_correct, phi_grads_correct
Example 4
File: train.py Project: yyht/rrws
def train_sleep(generative_model,
                inference_network,
                num_samples,
                num_iterations,
                callback=None):
    optimizer = torch.optim.Adam(inference_network.parameters())
    for iteration in range(num_iterations):
        optimizer.zero_grad()
        sleep_loss = losses.get_sleep_loss(generative_model,
                                           inference_network,
                                           num_samples=num_samples)
        sleep_loss.backward()
        optimizer.step()
        if callback is not None:
            callback(iteration, sleep_loss.item(), generative_model,
                     inference_network, optimizer)

    return optimizer
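A minimal usage sketch for this variant of train_sleep; GenerativeModel, InferenceNetwork and the logging callback below are hypothetical stand-ins, not names from the yyht/rrws repo:

generative_model = GenerativeModel()    # hypothetical nn.Module defining p(z) and p(x | z)
inference_network = InferenceNetwork()  # hypothetical nn.Module defining q(z | x)

def print_sleep_loss(iteration, sleep_loss, generative_model, inference_network, optimizer):
    # matches the callback signature used above; just prints every 100 iterations
    if iteration % 100 == 0:
        print('it. {} | sleep loss = {:.3f}'.format(iteration, sleep_loss))

optimizer = train_sleep(generative_model, inference_network,
                        num_samples=100, num_iterations=1000,
                        callback=print_sleep_loss)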
Example 5
def train_wake_sleep(generative_model,
                     inference_network,
                     data_loader,
                     num_iterations,
                     num_particles,
                     optim_kwargs,
                     callback=None):
    optimizer_phi = torch.optim.Adam(inference_network.parameters(),
                                     **optim_kwargs)
    optimizer_theta = torch.optim.Adam(generative_model.parameters(),
                                       **optim_kwargs)

    iteration = 0
    while iteration < num_iterations:
        for obs in iter(data_loader):
            # wake theta
            optimizer_phi.zero_grad()
            optimizer_theta.zero_grad()
            wake_theta_loss, elbo = losses.get_wake_theta_loss(
                generative_model, inference_network, obs, num_particles)
            wake_theta_loss.backward()
            optimizer_theta.step()

            # sleep phi
            optimizer_phi.zero_grad()
            optimizer_theta.zero_grad()
            sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                                   inference_network,
                                                   num_samples=obs.shape[0] *
                                                   num_particles)
            sleep_phi_loss.backward()
            optimizer_phi.step()

            if callback is not None:
                callback(iteration, wake_theta_loss.item(),
                         sleep_phi_loss.item(), elbo.item(), generative_model,
                         inference_network, optimizer_theta, optimizer_phi)

            iteration += 1
            # by this time, we have gone through `iteration` iterations
            if iteration == num_iterations:
                break
Example 6
def get_grads_weird_detach_sleep(seed):
    """Same quantities computed via a combined phi loss, with gradients read off
    only after both backward passes, using the 'weird detach' log_weight/log_q."""
    util.set_seed(seed)

    theta_grads_in_one = []
    phi_grads_in_one = []

    log_weight, log_q = get_log_weight_and_log_q_weird_detach(
        generative_model, inference_network, obs, num_particles)

    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
        log_weight)

    # optimizer_phi.zero_grad() -> don't zero phi grads
    # optimizer_theta.zero_grad()
    wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
        log_weight, log_q)

    sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                           inference_network,
                                           num_samples=num_particles)

    wake_factor = 0.755
    phi_loss = wake_factor * wake_phi_loss + (1 - wake_factor) * sleep_phi_loss

    wake_theta_loss.backward(retain_graph=True)
    phi_loss.backward()

    # only get the grads in the end!
    theta_grads_in_one = [
        parameter.grad.clone() for parameter in generative_model.parameters()
    ]
    phi_grads_in_one = [
        parameter.grad.clone() for parameter in inference_network.parameters()
    ]

    # in pyro, we want step to be in a different stage
    # optimizer_theta.step()
    # optimizer_phi.step()
    return theta_grads_in_one, phi_grads_in_one
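Examples 3 and 6 compute the same quantities in two ways (one backward pass per loss vs. gradients read off after combined backward passes), so a test would presumably compare their outputs; a sketch of such a check, assuming both functions and the shared globals are in scope:

seed = 0
theta_correct, phi_correct = get_grads_correct_sleep(seed)
theta_in_one, phi_in_one = get_grads_weird_detach_sleep(seed)

# both functions reseed via util.set_seed, so the sampled particles should line up;
# whether the gradients agree depends on what the 'weird detach' variant detaches
theta_match = all(torch.allclose(a, b) for a, b in zip(theta_correct, theta_in_one))
phi_match = all(torch.allclose(a, b) for a, b in zip(phi_correct, phi_in_one))
print('theta grads match: {}, phi grads match: {}'.format(theta_match, phi_match))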
Example 7
def train_wake_sleep(generative_model,
                     inference_network,
                     obss_data_loader,
                     num_iterations,
                     num_particles,
                     callback=None):
    num_samples = obss_data_loader.batch_size * num_particles
    optimizer_phi = torch.optim.Adam(inference_network.parameters())
    optimizer_theta = torch.optim.Adam(generative_model.parameters())
    obss_iter = iter(obss_data_loader)

    for iteration in range(num_iterations):
        # get obss
        try:
            obss = next(obss_iter)
        except StopIteration:
            obss_iter = iter(obss_data_loader)
            obss = next(obss_iter)

        # wake theta
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        wake_theta_loss, elbo = losses.get_wake_theta_loss(
            generative_model, inference_network, obss, num_particles)
        wake_theta_loss.backward()
        optimizer_theta.step()

        # sleep phi
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                               inference_network, num_samples)
        sleep_phi_loss.backward()
        optimizer_phi.step()

        if callback is not None:
            callback(iteration, wake_theta_loss.item(), sleep_phi_loss.item(),
                     elbo.item(), generative_model, inference_network,
                     optimizer_theta, optimizer_phi)

    return optimizer_theta, optimizer_phi
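A sketch of driving this train_wake_sleep from a tensor of observations; the observations are placeholders and the model objects are assumed to exist, while the DataLoader construction itself is plain PyTorch:

import torch.utils.data

obss_tensor = torch.randn(1000, 10)  # placeholder observations, shape (num_obs, obs_dim)
obss_data_loader = torch.utils.data.DataLoader(obss_tensor, batch_size=25, shuffle=True)

optimizer_theta, optimizer_phi = train_wake_sleep(
    generative_model, inference_network, obss_data_loader,
    num_iterations=5000, num_particles=10)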
Example 8
    def __call__(self, iteration, wake_theta_loss, sleep_phi_loss, elbo,
                 generative_model, inference_network, optimizer_theta,
                 optimizer_phi):
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} losses: theta = {:.3f}, phi = {:.3f}, elbo = '
                '{:.3f}'.format(iteration, wake_theta_loss, sleep_phi_loss,
                                elbo))
            self.wake_theta_loss_history.append(wake_theta_loss)
            self.sleep_phi_loss_history.append(sleep_phi_loss)
            self.elbo_history.append(elbo)

        if iteration % self.checkpoint_interval == 0:
            stats_path = util.get_stats_path(self.save_dir)
            util.save_object(self, stats_path)
            util.save_checkpoint(self.save_dir,
                                 iteration,
                                 generative_model=generative_model,
                                 inference_network=inference_network)

        if iteration % self.eval_interval == 0:
            log_p, kl = eval_gen_inf(generative_model, inference_network,
                                     self.test_data_loader,
                                     self.eval_num_particles)
            self.log_p_history.append(log_p)
            self.kl_history.append(kl)

            stats = util.OnlineMeanStd()
            for _ in range(10):
                inference_network.zero_grad()
                sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                                       inference_network,
                                                       self.num_samples)
                sleep_phi_loss.backward()
                stats.update([p.grad for p in inference_network.parameters()])
            self.grad_std_history.append(stats.avg_of_means_stds()[1].item())
            util.print_with_time(
                'Iteration {} log_p = {:.3f}, kl = {:.3f}'.format(
                    iteration, self.log_p_history[-1], self.kl_history[-1]))
Example 9
File: train.py Project: yyht/rrws
def train_wake_sleep(generative_model,
                     inference_network,
                     true_generative_model,
                     batch_size,
                     num_iterations,
                     num_particles,
                     callback=None):
    num_samples = batch_size * num_particles
    optimizer_phi = torch.optim.Adam(inference_network.parameters())
    optimizer_theta = torch.optim.Adam(generative_model.parameters())

    for iteration in range(num_iterations):
        # generate synthetic data
        obss = [true_generative_model.sample_obs() for _ in range(batch_size)]

        # wake theta
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        wake_theta_loss, elbo = losses.get_wake_theta_loss(
            generative_model, inference_network, obss, num_particles)
        wake_theta_loss.backward()
        optimizer_theta.step()

        # sleep phi
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                               inference_network, num_samples)
        sleep_phi_loss.backward()
        optimizer_phi.step()

        if callback is not None:
            callback(iteration, wake_theta_loss.item(), sleep_phi_loss.item(),
                     elbo.item(), generative_model, inference_network,
                     optimizer_theta, optimizer_phi)

    return optimizer_theta, optimizer_phi
Example 10
    def __call__(self, iteration, wake_theta_loss, sleep_phi_loss, elbo,
                 generative_model, inference_network, optimizer_theta,
                 optimizer_phi):
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} losses: theta = {:.3f}, phi = {:.3f}, elbo = '
                '{:.3f}'.format(iteration, wake_theta_loss, sleep_phi_loss,
                                elbo))
            self.wake_theta_loss_history.append(wake_theta_loss)
            self.sleep_phi_loss_history.append(sleep_phi_loss)
            self.elbo_history.append(elbo)

        if iteration % self.checkpoint_interval == 0:
            stats_filename = util.get_stats_path(self.model_folder)
            util.save_object(self, stats_filename)
            util.save_models(generative_model, inference_network,
                             self.model_folder, iteration)

        if iteration % self.eval_interval == 0:
            self.p_error_history.append(
                util.get_p_error(self.true_generative_model, generative_model))
            self.q_error_history.append(
                util.get_q_error(self.true_generative_model, inference_network,
                                 self.test_obss))
            stats = util.OnlineMeanStd()
            for _ in range(10):
                inference_network.zero_grad()
                sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                                       inference_network,
                                                       self.num_samples)
                sleep_phi_loss.backward()
                stats.update([p.grad for p in inference_network.parameters()])
            self.grad_std_history.append(stats.avg_of_means_stds()[1])
            util.print_with_time(
                'Iteration {} p_error = {:.3f}, q_error_to_true = '
                '{:.3f}'.format(iteration, self.p_error_history[-1],
                                self.q_error_history[-1]))
Example 11
def get_mean_stds(generative_model, inference_network, num_mc_samples, obss,
                  num_particles):
    """Estimate the mean and std of several gradient estimators (VIMCO, REINFORCE,
    wake-phi, sleep, ...) over `num_mc_samples` Monte Carlo repetitions."""
    vimco_grad = util.OnlineMeanStd()
    vimco_one_grad = util.OnlineMeanStd()
    reinforce_grad = util.OnlineMeanStd()
    reinforce_one_grad = util.OnlineMeanStd()
    two_grad = util.OnlineMeanStd()
    log_evidence_stats = util.OnlineMeanStd()
    log_evidence_grad = util.OnlineMeanStd()
    wake_phi_loss_grad = util.OnlineMeanStd()
    log_Q_grad = util.OnlineMeanStd()
    sleep_loss_grad = util.OnlineMeanStd()

    for mc_sample_idx in range(num_mc_samples):
        util.print_with_time('MC sample {}'.format(mc_sample_idx))
        log_weight, log_q = losses.get_log_weight_and_log_q(
            generative_model, inference_network, obss, num_particles)
        log_evidence = torch.logsumexp(log_weight, dim=1) - \
            np.log(num_particles)
        avg_log_evidence = torch.mean(log_evidence)
        log_Q = torch.sum(log_q, dim=1)
        avg_log_Q = torch.mean(log_Q)
        reinforce_one = torch.mean(log_evidence.detach() * log_Q)
        reinforce = reinforce_one + avg_log_evidence
        vimco_one = 0
        for i in range(num_particles):
            log_weight_ = log_weight[:, util.range_except(num_particles, i)]
            control_variate = torch.logsumexp(
                torch.cat([log_weight_, torch.mean(log_weight_, dim=1,
                                                   keepdim=True)], dim=1),
                dim=1)
            vimco_one = vimco_one + (log_evidence.detach() -
                                     control_variate.detach()) * log_q[:, i]
        vimco_one = torch.mean(vimco_one)
        vimco = vimco_one + avg_log_evidence
        normalized_weight = util.exponentiate_and_normalize(log_weight, dim=1)
        wake_phi_loss = torch.mean(
            -torch.sum(normalized_weight.detach() * log_q, dim=1))

        inference_network.zero_grad()
        generative_model.zero_grad()
        vimco.backward(retain_graph=True)
        vimco_grad.update([param.grad for param in
                           inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        vimco_one.backward(retain_graph=True)
        vimco_one_grad.update([param.grad for param in
                               inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        reinforce.backward(retain_graph=True)
        reinforce_grad.update([param.grad for param in
                               inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        reinforce_one.backward(retain_graph=True)
        reinforce_one_grad.update([param.grad for param in
                                   inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        avg_log_evidence.backward(retain_graph=True)
        two_grad.update([param.grad for param in
                         inference_network.parameters()])
        log_evidence_grad.update([param.grad for param in
                                  generative_model.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        wake_phi_loss.backward(retain_graph=True)
        wake_phi_loss_grad.update([param.grad for param in
                                   inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        avg_log_Q.backward(retain_graph=True)
        log_Q_grad.update([param.grad for param in
                           inference_network.parameters()])

        log_evidence_stats.update([avg_log_evidence.unsqueeze(0)])

        sleep_loss = losses.get_sleep_loss(
            generative_model, inference_network, num_particles * len(obss))
        inference_network.zero_grad()
        generative_model.zero_grad()
        sleep_loss.backward()
        sleep_loss_grad.update([param.grad for param in
                                inference_network.parameters()])

    return list(map(
        lambda x: x.avg_of_means_stds(),
        [vimco_grad, vimco_one_grad, reinforce_grad, reinforce_one_grad,
         two_grad, log_evidence_stats, log_evidence_grad, wake_phi_loss_grad,
         log_Q_grad, sleep_loss_grad]))
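The returned list contains ten summaries in the same order as the trackers above; a usage sketch, assuming avg_of_means_stds() yields a pair indexable as in Examples 8 and 10:

stats_list = get_mean_stds(generative_model, inference_network,
                           num_mc_samples=10, obss=obss, num_particles=5)
sleep_loss_grad_stats = stats_list[-1]  # sleep_loss_grad is the last tracker
print('sleep-loss grad std:', sleep_loss_grad_stats[1])  # index 1 is the std, as in Examples 8 and 10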