Example #1
    def __call__(self, iteration, wake_theta_loss, wake_phi_loss, elbo,
                 generative_model, inference_network, optimizer_theta,
                 optimizer_phi):
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} losses: theta = {:.3f}, phi = {:.3f}, elbo = '
                '{:.3f}'.format(iteration, wake_theta_loss, wake_phi_loss,
                                elbo))
            self.wake_theta_loss_history.append(wake_theta_loss)
            self.wake_phi_loss_history.append(wake_phi_loss)
            self.elbo_history.append(elbo)

        if iteration % self.checkpoint_interval == 0:
            stats_path = util.get_stats_path(self.save_dir)
            util.save_object(self, stats_path)
            util.save_checkpoint(self.save_dir,
                                 iteration,
                                 generative_model=generative_model,
                                 inference_network=inference_network)

        if iteration % self.eval_interval == 0:
            log_p, kl = eval_gen_inf(generative_model, inference_network,
                                     self.test_data_loader,
                                     self.eval_num_particles)
            self.log_p_history.append(log_p)
            self.kl_history.append(kl)

            # Estimate gradient noise: std of the theta/phi gradients across
            # 10 Monte Carlo re-evaluations of the losses on test_obs.
            stats = util.OnlineMeanStd()
            for _ in range(10):
                generative_model.zero_grad()
                wake_theta_loss, elbo = losses.get_wake_theta_loss(
                    generative_model, inference_network, self.test_obs,
                    self.num_particles)
                wake_theta_loss.backward()
                # Clone so the stored theta gradients cannot be mutated by
                # the subsequent backward/zero_grad calls.
                theta_grads = [
                    p.grad.clone() for p in generative_model.parameters()
                ]

                inference_network.zero_grad()
                wake_phi_loss = losses.get_wake_phi_loss(
                    generative_model, inference_network, self.test_obs,
                    self.num_particles)
                wake_phi_loss.backward()
                # Clone for symmetry with theta_grads, in case the stats
                # object keeps references rather than copies.
                phi_grads = [
                    p.grad.clone() for p in inference_network.parameters()
                ]

                stats.update(theta_grads + phi_grads)
            self.grad_std_history.append(stats.avg_of_means_stds()[1].item())
            util.print_with_time(
                'Iteration {} log_p = {:.3f}, kl = {:.3f}'.format(
                    iteration, self.log_p_history[-1], self.kl_history[-1]))
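
For context, this callback fits a wake-wake style training loop in which the
generative model (theta) and the inference network (phi) are updated by
separate optimizers. Below is a minimal driver sketch: only the
losses.get_wake_theta_loss / losses.get_wake_phi_loss calls are taken from
the example itself; num_iterations, obs, num_particles and the callback
object are hypothetical stand-ins.

# Hypothetical driver loop for the callback in Example #1.
for iteration in range(num_iterations):
    # Wake-theta: update the generative model on the multi-particle bound.
    optimizer_theta.zero_grad()
    wake_theta_loss, elbo = losses.get_wake_theta_loss(
        generative_model, inference_network, obs, num_particles)
    wake_theta_loss.backward()
    optimizer_theta.step()

    # Wake-phi: update the inference network on the wake-phi objective.
    optimizer_phi.zero_grad()
    wake_phi_loss = losses.get_wake_phi_loss(
        generative_model, inference_network, obs, num_particles)
    wake_phi_loss.backward()
    optimizer_phi.step()

    callback(iteration, wake_theta_loss.item(), wake_phi_loss.item(),
             elbo.item(), generative_model, inference_network,
             optimizer_theta, optimizer_phi)
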
Example #2
    def __call__(self, iteration, loss, elbo, generative_model,
                 inference_network, optimizer):
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
                    iteration, loss, elbo))
            self.loss_history.append(loss)
            self.elbo_history.append(elbo)

        if iteration % self.checkpoint_interval == 0:
            stats_path = util.get_stats_path(self.save_dir)
            util.save_object(self, stats_path)
            util.save_checkpoint(self.save_dir,
                                 iteration,
                                 generative_model=generative_model,
                                 inference_network=inference_network)

        if iteration % self.eval_interval == 0:
            log_p, kl = eval_gen_inf(generative_model, inference_network,
                                     self.test_data_loader,
                                     self.eval_num_particles)
            _, renyi = eval_gen_inf_alpha(generative_model, inference_network,
                                          self.test_data_loader,
                                          self.eval_num_particles, self.alpha)
            self.log_p_history.append(log_p)
            self.kl_history.append(kl)
            self.renyi_history.append(renyi)

            # Estimate gradient noise of the joint thermo-alpha objective
            # over 10 Monte Carlo re-evaluations on test_obs.
            stats = util.OnlineMeanStd()
            for _ in range(10):
                generative_model.zero_grad()
                inference_network.zero_grad()
                loss, elbo = losses.get_thermo_alpha_loss(
                    generative_model, inference_network, self.test_obs,
                    self.partition, self.num_particles, self.alpha,
                    self.integration)
                loss.backward()
                stats.update([p.grad for p in generative_model.parameters()] +
                             [p.grad for p in inference_network.parameters()])
            self.grad_std_history.append(stats.avg_of_means_stds()[1].item())
            util.print_with_time(
                'Iteration {} log_p = {:.3f}, kl = {:.3f}, renyi = {:.3f}'.
                format(iteration, self.log_p_history[-1], self.kl_history[-1],
                       self.renyi_history[-1]))
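
Example #2 trains both models jointly through a single thermodynamic/alpha
objective. A matching setup might look as follows; it assumes partition is a
monotone schedule of points in [0, 1], and every name except the
losses.get_thermo_alpha_loss call is a hypothetical stand-in.

# Hypothetical single-optimizer driver for Example #2.
partition = torch.linspace(0, 1, steps=num_partitions)  # assumed schedule
optimizer = torch.optim.Adam(
    list(generative_model.parameters()) +
    list(inference_network.parameters()))

for iteration in range(num_iterations):
    optimizer.zero_grad()
    loss, elbo = losses.get_thermo_alpha_loss(
        generative_model, inference_network, obs, partition,
        num_particles, alpha, integration)
    loss.backward()
    optimizer.step()
    callback(iteration, loss.item(), elbo.item(), generative_model,
             inference_network, optimizer)
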
Example #3
    def __call__(self, iteration, loss, elbo, generative_model,
                 inference_network, optimizer):
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
                    iteration, loss, elbo))
            self.loss_history.append(loss)
            self.elbo_history.append(elbo)

        if iteration % self.checkpoint_interval == 0:
            stats_filename = util.get_stats_path(self.model_folder)
            util.save_object(self, stats_filename)
            util.save_models(generative_model, inference_network,
                             self.model_folder, iteration)

        if iteration % self.eval_interval == 0:
            self.p_error_history.append(
                util.get_p_error(self.true_generative_model, generative_model))
            self.q_error_history.append(
                util.get_q_error(self.true_generative_model, inference_network,
                                 self.test_obss))
            # Estimate gradient noise of the inference-network objective
            # over 10 Monte Carlo re-evaluations on test_obss.
            stats = util.OnlineMeanStd()
            for _ in range(10):
                inference_network.zero_grad()
                if self.train_mode == 'vimco':
                    loss, elbo = losses.get_vimco_loss(generative_model,
                                                       inference_network,
                                                       self.test_obss,
                                                       self.num_particles)
                elif self.train_mode == 'reinforce':
                    loss, elbo = losses.get_reinforce_loss(
                        generative_model, inference_network, self.test_obss,
                        self.num_particles)
                else:
                    # Guard against silently leaving `loss` unbound.
                    raise ValueError(
                        'unknown train_mode: {}'.format(self.train_mode))
                loss.backward()
                stats.update([p.grad for p in inference_network.parameters()])
            self.grad_std_history.append(stats.avg_of_means_stds()[1].item())
            util.print_with_time(
                'Iteration {} p_error = {:.3f}, q_error_to_true = '
                '{:.3f}'.format(iteration, self.p_error_history[-1],
                                self.q_error_history[-1]))
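
The REINFORCE-style estimators in Examples #3 and #4 rest on the standard
score-function trick: backpropagating through detach(signal) * log_q yields
the REINFORCE gradient of the expected signal even though the signal itself
contributes no gradient path. A self-contained toy illustration, independent
of the repo's losses module:

import torch

# Toy score-function (REINFORCE) surrogate, mirroring reinforce_one in
# Example #4: the learning signal is detached, so gradients reach the
# parameters only through log_q.
log_q = torch.randn(4, requires_grad=True)      # stand-in for sum_k log q_k
signal = torch.randn(4)                         # stand-in for log evidence
surrogate = torch.mean(signal.detach() * log_q)
surrogate.backward()
print(log_q.grad)                               # equals signal / 4
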
Example #4
import numpy as np
import torch

import losses
import util


def get_mean_stds(generative_model, inference_network, num_mc_samples, obss,
                  num_particles):
    """Estimate mean/std statistics of several gradient estimators over
    num_mc_samples Monte Carlo re-evaluations on obss."""
    # One online accumulator per gradient estimator / statistic of interest.
    vimco_grad = util.OnlineMeanStd()
    vimco_one_grad = util.OnlineMeanStd()
    reinforce_grad = util.OnlineMeanStd()
    reinforce_one_grad = util.OnlineMeanStd()
    two_grad = util.OnlineMeanStd()
    log_evidence_stats = util.OnlineMeanStd()
    log_evidence_grad = util.OnlineMeanStd()
    wake_phi_loss_grad = util.OnlineMeanStd()
    log_Q_grad = util.OnlineMeanStd()
    sleep_loss_grad = util.OnlineMeanStd()

    for mc_sample_idx in range(num_mc_samples):
        util.print_with_time('MC sample {}'.format(mc_sample_idx))
        # Per-observation IWAE-style log evidence estimate log Z_hat.
        log_weight, log_q = losses.get_log_weight_and_log_q(
            generative_model, inference_network, obss, num_particles)
        log_evidence = (torch.logsumexp(log_weight, dim=1) -
                        np.log(num_particles))
        avg_log_evidence = torch.mean(log_evidence)
        log_Q = torch.sum(log_q, dim=1)
        avg_log_Q = torch.mean(log_Q)
        # REINFORCE surrogate: the detached learning signal multiplies log Q,
        # so gradients flow only through the inference network's log q terms.
        reinforce_one = torch.mean(log_evidence.detach() * log_Q)
        reinforce = reinforce_one + avg_log_evidence
        # VIMCO surrogate: a per-particle leave-one-out control variate is
        # subtracted from the learning signal before multiplying by log q.
        vimco_one = 0
        for i in range(num_particles):
            log_weight_ = log_weight[:, util.range_except(num_particles, i)]
            control_variate = torch.logsumexp(
                torch.cat([log_weight_, torch.mean(log_weight_, dim=1,
                                                   keepdim=True)], dim=1),
                dim=1)
            vimco_one = vimco_one + (log_evidence.detach() -
                                     control_variate.detach()) * log_q[:, i]
        vimco_one = torch.mean(vimco_one)
        vimco = vimco_one + avg_log_evidence
        normalized_weight = util.exponentiate_and_normalize(log_weight, dim=1)
        wake_phi_loss = torch.mean(
            -torch.sum(normalized_weight.detach() * log_q, dim=1))

        # For each surrogate: clear both models' grads, backprop (retaining
        # the graph so it can be reused), and record the parameter gradients.
        inference_network.zero_grad()
        generative_model.zero_grad()
        vimco.backward(retain_graph=True)
        vimco_grad.update([param.grad for param in
                           inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        vimco_one.backward(retain_graph=True)
        vimco_one_grad.update([param.grad for param in
                               inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        reinforce.backward(retain_graph=True)
        reinforce_grad.update([param.grad for param in
                               inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        reinforce_one.backward(retain_graph=True)
        reinforce_one_grad.update([param.grad for param in
                                   inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        avg_log_evidence.backward(retain_graph=True)
        two_grad.update([param.grad for param in
                         inference_network.parameters()])
        log_evidence_grad.update([param.grad for param in
                                  generative_model.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        wake_phi_loss.backward(retain_graph=True)
        wake_phi_loss_grad.update([param.grad for param in
                                   inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        avg_log_Q.backward(retain_graph=True)
        log_Q_grad.update([param.grad for param in
                           inference_network.parameters()])

        log_evidence_stats.update([avg_log_evidence.unsqueeze(0)])

        sleep_loss = losses.get_sleep_loss(
            generative_model, inference_network, num_particles * len(obss))
        inference_network.zero_grad()
        generative_model.zero_grad()
        sleep_loss.backward()
        sleep_loss_grad.update([param.grad for param in
                                inference_network.parameters()])

    return [x.avg_of_means_stds() for x in
            [vimco_grad, vimco_one_grad, reinforce_grad, reinforce_one_grad,
             two_grad, log_evidence_stats, log_evidence_grad,
             wake_phi_loss_grad, log_Q_grad, sleep_loss_grad]]
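
All four examples rely on util.OnlineMeanStd, whose source is not shown.
From its usage (update takes a list of tensors for one Monte Carlo sample;
avg_of_means_stds() returns a (mean, std) pair whose entries support
.item()), one plausible reconstruction, offered purely as an assumption
about the real class:

import torch


class OnlineMeanStd:
    """Plausible sketch of util.OnlineMeanStd, inferred from usage: store
    one flattened vector per update() call and report the average of the
    elementwise means and stds across samples. The real implementation may
    instead be a streaming (Welford-style) estimator."""

    def __init__(self):
        self.samples = []

    def update(self, tensors):
        # tensors: list of per-parameter tensors (e.g. gradients) for one
        # Monte Carlo sample.
        self.samples.append(
            torch.cat([t.detach().reshape(-1) for t in tensors]))

    def avg_of_means_stds(self):
        stacked = torch.stack(self.samples)   # [num_samples, total_dim]
        return (stacked.mean(dim=0).mean(),   # average of per-dim means
                stacked.std(dim=0).mean())    # average of per-dim stds

Under this reading, stats.avg_of_means_stds()[1].item() is a single scalar
summarizing gradient noise across the Monte Carlo re-evaluations.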