def __call__(self, iteration, wake_theta_loss, wake_phi_loss, elbo,
             generative_model, inference_network, optimizer_theta,
             optimizer_phi):
    """Per-iteration callback: log losses, checkpoint, and evaluate.

    Three independent actions are triggered whenever ``iteration`` is a
    multiple of the corresponding interval stored on ``self``:

    * ``logging_interval``: print the losses and append them to the
      ``wake_theta_loss_history`` / ``wake_phi_loss_history`` /
      ``elbo_history`` lists.
    * ``checkpoint_interval``: save this callback object and a model
      checkpoint under ``self.save_dir``.
    * ``eval_interval``: run ``eval_gen_inf`` on ``self.test_data_loader``,
      record ``log_p`` and ``kl``, and record a gradient-std figure
      computed from 10 fresh backward passes on ``self.test_obs``.

    Args:
        iteration: current training iteration number.
        wake_theta_loss: wake-theta loss value for this iteration.
        wake_phi_loss: wake-phi loss value for this iteration.
        elbo: ELBO value for this iteration.
        generative_model: model whose parameters receive the theta grads.
        inference_network: model whose parameters receive the phi grads.
        optimizer_theta: unused here.
        optimizer_phi: unused here.
    """
    if iteration % self.logging_interval == 0:
        message = (
            'Iteration {} losses: theta = {:.3f}, phi = {:.3f}, elbo = {:.3f}'
        ).format(iteration, wake_theta_loss, wake_phi_loss, elbo)
        util.print_with_time(message)
        self.wake_theta_loss_history.append(wake_theta_loss)
        self.wake_phi_loss_history.append(wake_phi_loss)
        self.elbo_history.append(elbo)

    if iteration % self.checkpoint_interval == 0:
        # The callback object itself (with its histories) is saved first,
        # then the model checkpoint.
        util.save_object(self, util.get_stats_path(self.save_dir))
        util.save_checkpoint(
            self.save_dir,
            iteration,
            generative_model=generative_model,
            inference_network=inference_network,
        )

    if iteration % self.eval_interval == 0:
        test_log_p, test_kl = eval_gen_inf(
            generative_model,
            inference_network,
            self.test_data_loader,
            self.eval_num_particles,
        )
        self.log_p_history.append(test_log_p)
        self.kl_history.append(test_kl)

        grad_stats = util.OnlineMeanStd()
        for _ in range(10):
            # Theta gradients from a fresh wake-theta loss on the test batch.
            generative_model.zero_grad()
            theta_objective, _test_elbo = losses.get_wake_theta_loss(
                generative_model,
                inference_network,
                self.test_obs,
                self.num_particles,
            )
            theta_objective.backward()
            # Cloned before the phi backward pass below; presumably so that
            # pass cannot modify the captured theta gradients.
            theta_grads = [
                param.grad.clone()
                for param in generative_model.parameters()
            ]

            # Phi gradients from a fresh wake-phi loss on the same batch.
            inference_network.zero_grad()
            phi_objective = losses.get_wake_phi_loss(
                generative_model,
                inference_network,
                self.test_obs,
                self.num_particles,
            )
            phi_objective.backward()
            phi_grads = [
                param.grad for param in inference_network.parameters()
            ]

            grad_stats.update(theta_grads + phi_grads)

        # Element [1] of avg_of_means_stds() is recorded as the gradient std.
        grad_std = grad_stats.avg_of_means_stds()[1].item()
        self.grad_std_history.append(grad_std)

        summary = 'Iteration {} log_p = {:.3f}, kl = {:.3f}'.format(
            iteration, self.log_p_history[-1], self.kl_history[-1])
        util.print_with_time(summary)
def __call__(self, iteration, loss, elbo, generative_model,
             inference_network, optimizer):
    """Per-iteration callback for the thermo-alpha loss: log, save, evaluate.

    Three independent actions are triggered whenever ``iteration`` is a
    multiple of the corresponding interval stored on ``self``:

    * ``logging_interval``: print ``loss`` / ``elbo`` and append them to
      ``loss_history`` / ``elbo_history``.
    * ``checkpoint_interval``: save this callback object and a model
      checkpoint under ``self.save_dir``.
    * ``eval_interval``: run ``eval_gen_inf`` and ``eval_gen_inf_alpha`` on
      ``self.test_data_loader``, record ``log_p``, ``kl`` and ``renyi``,
      and record a gradient-std figure computed from 10 fresh backward
      passes of ``losses.get_thermo_alpha_loss`` on ``self.test_obs``.

    Args:
        iteration: current training iteration number.
        loss: training loss value for this iteration.
        elbo: ELBO value for this iteration.
        generative_model: generative model being trained.
        inference_network: inference network being trained.
        optimizer: unused here.
    """
    if iteration % self.logging_interval == 0:
        message = 'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
            iteration, loss, elbo)
        util.print_with_time(message)
        self.loss_history.append(loss)
        self.elbo_history.append(elbo)

    if iteration % self.checkpoint_interval == 0:
        # The callback object itself (with its histories) is saved first,
        # then the model checkpoint.
        util.save_object(self, util.get_stats_path(self.save_dir))
        util.save_checkpoint(
            self.save_dir,
            iteration,
            generative_model=generative_model,
            inference_network=inference_network,
        )

    if iteration % self.eval_interval == 0:
        test_log_p, test_kl = eval_gen_inf(
            generative_model,
            inference_network,
            self.test_data_loader,
            self.eval_num_particles,
        )
        # Only the second value returned by eval_gen_inf_alpha is kept.
        _, test_renyi = eval_gen_inf_alpha(
            generative_model,
            inference_network,
            self.test_data_loader,
            self.eval_num_particles,
            self.alpha,
        )
        self.log_p_history.append(test_log_p)
        self.kl_history.append(test_kl)
        self.renyi_history.append(test_renyi)

        grad_stats = util.OnlineMeanStd()
        for _ in range(10):
            generative_model.zero_grad()
            inference_network.zero_grad()
            test_loss, _test_elbo = losses.get_thermo_alpha_loss(
                generative_model,
                inference_network,
                self.test_obs,
                self.partition,
                self.num_particles,
                self.alpha,
                self.integration,
            )
            test_loss.backward()

            # Generative-model gradients first, then inference-network ones.
            theta_grads = [
                param.grad for param in generative_model.parameters()
            ]
            phi_grads = [
                param.grad for param in inference_network.parameters()
            ]
            grad_stats.update(theta_grads + phi_grads)

        # Element [1] of avg_of_means_stds() is recorded as the gradient std.
        grad_std = grad_stats.avg_of_means_stds()[1].item()
        self.grad_std_history.append(grad_std)

        summary = (
            'Iteration {} log_p = {:.3f}, kl = {:.3f}, renyi = {:.3f}'
        ).format(
            iteration,
            self.log_p_history[-1],
            self.kl_history[-1],
            self.renyi_history[-1],
        )
        util.print_with_time(summary)
def __call__(self, iteration, loss, elbo, generative_model,
             inference_network, optimizer):
    """Per-iteration callback for VIMCO / REINFORCE training.

    Three independent actions are triggered whenever ``iteration`` is a
    multiple of the corresponding interval stored on ``self``:

    * ``logging_interval``: print ``loss`` / ``elbo`` and append them to
      ``loss_history`` / ``elbo_history``.
    * ``checkpoint_interval``: save this callback object and both models
      under ``self.model_folder``.
    * ``eval_interval``: record ``util.get_p_error`` and
      ``util.get_q_error`` against ``self.true_generative_model``, and
      record a gradient-std figure computed from 10 fresh backward passes
      of the loss selected by ``self.train_mode`` on ``self.test_obss``.

    Args:
        iteration: current training iteration number.
        loss: training loss value for this iteration.
        elbo: ELBO value for this iteration.
        generative_model: generative model being trained.
        inference_network: inference network being trained.
        optimizer: unused here.

    Raises:
        ValueError: on an evaluation iteration, if ``self.train_mode`` is
            neither ``'vimco'`` nor ``'reinforce'``.
    """
    if iteration % self.logging_interval == 0:
        util.print_with_time(
            'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
                iteration, loss, elbo))
        self.loss_history.append(loss)
        self.elbo_history.append(elbo)

    if iteration % self.checkpoint_interval == 0:
        stats_filename = util.get_stats_path(self.model_folder)
        util.save_object(self, stats_filename)
        util.save_models(generative_model, inference_network,
                         self.model_folder, iteration)

    if iteration % self.eval_interval == 0:
        # Resolve the loss function once, before any history is touched.
        # Without the else branch an unknown mode would fall through and
        # call .backward() on the caller's `loss` argument instead of a
        # freshly computed test loss.
        if self.train_mode == 'vimco':
            get_loss = losses.get_vimco_loss
        elif self.train_mode == 'reinforce':
            get_loss = losses.get_reinforce_loss
        else:
            raise ValueError(
                "Unsupported train_mode {!r}; expected 'vimco' or "
                "'reinforce'".format(self.train_mode))

        self.p_error_history.append(
            util.get_p_error(self.true_generative_model, generative_model))
        self.q_error_history.append(
            util.get_q_error(self.true_generative_model, inference_network,
                             self.test_obss))

        stats = util.OnlineMeanStd()
        for _ in range(10):
            # Only the inference network's gradients are reset and sampled.
            inference_network.zero_grad()
            # Distinct local names keep the `loss` / `elbo` arguments intact.
            test_loss, _test_elbo = get_loss(
                generative_model, inference_network, self.test_obss,
                self.num_particles)
            test_loss.backward()
            stats.update([p.grad for p in inference_network.parameters()])

        # NOTE(review): stored exactly as returned (no .item() call);
        # confirm whether a plain Python float is expected here.
        self.grad_std_history.append(stats.avg_of_means_stds()[1])
        util.print_with_time(
            'Iteration {} p_error = {:.3f}, q_error_to_true = '
            '{:.3f}'.format(iteration, self.p_error_history[-1],
                            self.q_error_history[-1]))
def get_mean_stds(generative_model, inference_network, num_mc_samples, obss,
                  num_particles):
    """Collect Monte Carlo statistics of several gradient estimators.

    For each of ``num_mc_samples`` draws, log weights and log q values are
    obtained from ``losses.get_log_weight_and_log_q``. Several objectives
    are built from them, and each objective's gradients are fed into its
    own ``util.OnlineMeanStd`` tracker. A sleep loss from
    ``losses.get_sleep_loss`` is tracked the same way.

    Args:
        generative_model: model whose gradients are tracked for the
            average log evidence.
        inference_network: model whose gradients are tracked for every
            other objective.
        num_mc_samples: number of Monte Carlo repetitions.
        obss: observations passed to the loss helpers; ``len(obss)`` also
            scales the number of samples requested for the sleep loss.
        num_particles: number of particles per observation.

    Returns:
        A list of ``avg_of_means_stds()`` results, in this order: vimco,
        vimco_one, reinforce, reinforce_one, two (inference-network
        gradient of the average log evidence), log_evidence_stats (the
        value itself), log_evidence_grad (generative-model gradient),
        wake_phi_loss, log_Q, sleep_loss.
    """
    vimco_grad = util.OnlineMeanStd()
    vimco_one_grad = util.OnlineMeanStd()
    reinforce_grad = util.OnlineMeanStd()
    reinforce_one_grad = util.OnlineMeanStd()
    two_grad = util.OnlineMeanStd()
    log_evidence_stats = util.OnlineMeanStd()
    log_evidence_grad = util.OnlineMeanStd()
    wake_phi_loss_grad = util.OnlineMeanStd()
    log_Q_grad = util.OnlineMeanStd()
    sleep_loss_grad = util.OnlineMeanStd()

    def record(objective, phi_tracker, theta_tracker=None, retain_graph=True):
        """Zero both models' grads, backprop `objective`, update trackers."""
        inference_network.zero_grad()
        generative_model.zero_grad()
        objective.backward(retain_graph=retain_graph)
        phi_tracker.update(
            [param.grad for param in inference_network.parameters()])
        if theta_tracker is not None:
            theta_tracker.update(
                [param.grad for param in generative_model.parameters()])

    for sample_idx in range(num_mc_samples):
        util.print_with_time('MC sample {}'.format(sample_idx))

        log_weight, log_q = losses.get_log_weight_and_log_q(
            generative_model, inference_network, obss, num_particles)

        # Log of the particle-averaged weight, reduced over dim 1.
        log_evidence = (
            torch.logsumexp(log_weight, dim=1) - np.log(num_particles)
        )
        avg_log_evidence = torch.mean(log_evidence)
        log_Q = torch.sum(log_q, dim=1)
        avg_log_Q = torch.mean(log_Q)

        reinforce_one = torch.mean(log_evidence.detach() * log_Q)
        reinforce = reinforce_one + avg_log_evidence

        # One term per particle. The baseline is built from the other
        # particles' log weights plus their mean along dim 1.
        vimco_terms = []
        for particle in range(num_particles):
            others = log_weight[:, util.range_except(num_particles, particle)]
            others_with_mean = torch.cat(
                [others, torch.mean(others, dim=1, keepdim=True)], dim=1)
            baseline = torch.logsumexp(others_with_mean, dim=1)
            vimco_terms.append(
                (log_evidence.detach() - baseline.detach())
                * log_q[:, particle])
        # sum() starts from 0 and folds left, like the accumulation loop.
        vimco_one = torch.mean(sum(vimco_terms))
        vimco = vimco_one + avg_log_evidence

        normalized_weight = util.exponentiate_and_normalize(log_weight, dim=1)
        wake_phi_loss = torch.mean(
            -torch.sum(normalized_weight.detach() * log_q, dim=1))

        # All of these objectives share one graph, so every backward pass
        # keeps it alive with retain_graph=True.
        record(vimco, vimco_grad)
        record(vimco_one, vimco_one_grad)
        record(reinforce, reinforce_grad)
        record(reinforce_one, reinforce_one_grad)
        record(avg_log_evidence, two_grad, theta_tracker=log_evidence_grad)
        record(wake_phi_loss, wake_phi_loss_grad)
        record(avg_log_Q, log_Q_grad)

        log_evidence_stats.update([avg_log_evidence.unsqueeze(0)])

        # The sleep loss comes from a separate helper call and is
        # backpropagated with backward()'s default retain_graph (None).
        sleep_loss = losses.get_sleep_loss(
            generative_model, inference_network, num_particles * len(obss))
        record(sleep_loss, sleep_loss_grad, retain_graph=None)

    trackers = [
        vimco_grad, vimco_one_grad, reinforce_grad, reinforce_one_grad,
        two_grad, log_evidence_stats, log_evidence_grad, wake_phi_loss_grad,
        log_Q_grad, sleep_loss_grad,
    ]
    return [tracker.avg_of_means_stds() for tracker in trackers]