def __call__(self, iteration, loss, elbo, generative_model,
             inference_network, control_variate):
    """Record training progress for one iteration.

    Each of the three phases below runs only when ``iteration`` is a
    multiple of the corresponding interval attribute on ``self``.

    Args:
        iteration: Current training iteration number.
        loss: Training loss for this iteration (formatted with ``:.3f``).
        elbo: ELBO estimate for this iteration (formatted with ``:.3f``).
        generative_model: Model being trained; checkpointed and compared
            against ``self.true_generative_model``.
        inference_network: Inference network being trained; checkpointed
            and evaluated against both the true and the learned model.
        control_variate: Passed to ``util.save_control_variate`` at
            checkpoint iterations.
    """
    # Phase 1: print and record the scalar training signals.
    if not iteration % self.logging_interval:
        message = 'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
            iteration, loss, elbo)
        util.print_with_time(message)
        self.loss_history.append(loss)
        self.elbo_history.append(elbo)

    # Phase 2: persist this stats object together with the models.
    if not iteration % self.checkpoint_interval:
        folder = self.model_folder
        util.save_object(self, util.get_stats_filename(folder))
        util.save_models(generative_model, inference_network,
                         self.pcfg_path, folder)
        util.save_control_variate(control_variate, folder)

    # Phase 3: measure model/inference errors and report them.
    if not iteration % self.eval_interval:
        true_model = self.true_generative_model

        p_error = util.get_p_error(true_model, generative_model)
        self.p_error_history.append(p_error)

        q_error_to_true = util.get_q_error(true_model, inference_network)
        self.q_error_to_true_history.append(q_error_to_true)

        q_error_to_model = util.get_q_error(generative_model,
                                            inference_network)
        self.q_error_to_model_history.append(q_error_to_model)

        util.print_with_time(
            'Iteration {} p_error = {:.3f}, q_error_to_true = {:.3f}, '
            'q_error_to_model = {:.3f}'.format(
                iteration, p_error, q_error_to_true, q_error_to_model))
def __call__(self, iteration, wake_theta_loss, wake_phi_loss, elbo,
             generative_model, inference_network, optimizer_theta,
             optimizer_phi):
    """Record training progress for one iteration (theta/phi wake losses).

    Each of the three phases below runs only when ``iteration`` is a
    multiple of the corresponding interval attribute on ``self``.

    Args:
        iteration: Current training iteration number.
        wake_theta_loss: Theta loss for this iteration (``:.3f``).
        wake_phi_loss: Phi loss for this iteration (``:.3f``).
        elbo: ELBO estimate for this iteration (``:.3f``).
        generative_model: Model being trained; checkpointed and compared
            against ``self.true_generative_model``.
        inference_network: Inference network being trained; checkpointed
            and evaluated against both the true and the learned model.
        optimizer_theta: Accepted for the callback interface; unused here.
        optimizer_phi: Accepted for the callback interface; unused here.
    """
    # Phase 1: print and record the scalar training signals.
    if not iteration % self.logging_interval:
        message = ('Iteration {} losses: theta = {:.3f}, phi = {:.3f}, '
                   'elbo = {:.3f}').format(iteration, wake_theta_loss,
                                           wake_phi_loss, elbo)
        util.print_with_time(message)
        self.wake_theta_loss_history.append(wake_theta_loss)
        self.wake_phi_loss_history.append(wake_phi_loss)
        self.elbo_history.append(elbo)

    # Phase 2: persist this stats object together with the models.
    if not iteration % self.checkpoint_interval:
        folder = self.model_folder
        util.save_object(self, util.get_stats_filename(folder))
        util.save_models(generative_model, inference_network,
                         self.pcfg_path, folder)

    # Phase 3: measure model/inference errors and report them.
    if not iteration % self.eval_interval:
        true_model = self.true_generative_model

        p_error = util.get_p_error(true_model, generative_model)
        self.p_error_history.append(p_error)

        q_error_to_true = util.get_q_error(true_model, inference_network)
        self.q_error_to_true_history.append(q_error_to_true)

        q_error_to_model = util.get_q_error(generative_model,
                                            inference_network)
        self.q_error_to_model_history.append(q_error_to_model)

        util.print_with_time(
            'Iteration {} p_error = {:.3f}, q_error_to_true = {:.3f}, '
            'q_error_to_model = {:.3f}'.format(
                iteration, p_error, q_error_to_true, q_error_to_model))
def __call__(self, iteration, theta_loss, phi_loss, generative_model,
             inference_network, memory, optimizer):
    """Record training progress for one iteration (theta/phi losses).

    Each of the three phases below runs only when ``iteration`` is a
    multiple of the corresponding interval attribute on ``self``.

    Args:
        iteration: Current training iteration number.
        theta_loss: Theta loss for this iteration (``:.3f``).
        phi_loss: Phi loss for this iteration (``:.3f``).
        generative_model: Model being trained; checkpointed and compared
            against ``self.true_generative_model``.
        inference_network: Inference network being trained; checkpointed
            and evaluated on ``self.test_obss``.
        memory: Passed to ``util.save_models`` at checkpoint iterations.
        optimizer: Accepted for the callback interface; unused here.
    """
    # Phase 1: print and record the scalar training signals.
    if not iteration % self.logging_interval:
        message = 'Iteration {} losses: theta = {:.3f}, phi = {:.3f}'.format(
            iteration, theta_loss, phi_loss)
        util.print_with_time(message)
        self.theta_loss_history.append(theta_loss)
        self.phi_loss_history.append(phi_loss)

    # Phase 2: persist this stats object together with models and memory.
    if not iteration % self.checkpoint_interval:
        folder = self.model_folder
        util.save_object(self, util.get_stats_path(folder))
        util.save_models(generative_model, inference_network, folder,
                         iteration, memory)

    # Phase 3: measure model/inference errors and report them.
    if not iteration % self.eval_interval:
        true_model = self.true_generative_model

        p_error = util.get_p_error(true_model, generative_model)
        self.p_error_history.append(p_error)

        q_error = util.get_q_error(true_model, inference_network,
                                   self.test_obss)
        self.q_error_history.append(q_error)

        # TODO: also track a memory error (util.get_memory_error on
        # self.test_obss) in self.memory_error_history once supported.

        util.print_with_time(
            'Iteration {} p_error = {:.3f}, q_error_to_true = '
            '{:.3f}'.format(iteration, p_error, q_error))
def __call__(self, iteration, loss, elbo, generative_model,
             inference_network, optimizer):
    """Record training progress for one iteration (REINFORCE/VIMCO runs).

    Each of the three phases below runs only when ``iteration`` is a
    multiple of the corresponding interval attribute on ``self``. The
    evaluation phase additionally records a gradient-spread statistic for
    the inference network, computed from repeated loss/backward passes on
    ``self.test_obss``.

    Args:
        iteration: Current training iteration number.
        loss: Training loss for this iteration (formatted with ``:.3f``).
        elbo: ELBO estimate for this iteration (formatted with ``:.3f``).
        generative_model: Model being trained; checkpointed, compared
            against ``self.true_generative_model``, and passed to the loss
            used for the gradient statistic.
        inference_network: Inference network being trained; checkpointed,
            evaluated on ``self.test_obss``, and differentiated for the
            gradient statistic.
        optimizer: Accepted for the callback interface; unused here.

    Raises:
        ValueError: On an evaluation iteration, if ``self.train_mode`` is
            neither ``'vimco'`` nor ``'reinforce'``.
    """
    # Phase 1: print and record the scalar training signals.
    if iteration % self.logging_interval == 0:
        util.print_with_time(
            'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
                iteration, loss, elbo))
        self.loss_history.append(loss)
        self.elbo_history.append(elbo)

    # Phase 2: persist this stats object together with the models.
    if iteration % self.checkpoint_interval == 0:
        stats_filename = util.get_stats_path(self.model_folder)
        util.save_object(self, stats_filename)
        util.save_models(generative_model, inference_network,
                         self.model_folder, iteration)

    # Phase 3: errors plus gradient-spread diagnostic.
    if iteration % self.eval_interval == 0:
        # Resolve the estimator before touching any history. Previously an
        # unknown train_mode fell through and called backward() on the
        # caller's *training* `loss` argument instead of a fresh loss.
        if self.train_mode == 'vimco':
            get_loss = losses.get_vimco_loss
        elif self.train_mode == 'reinforce':
            get_loss = losses.get_reinforce_loss
        else:
            raise ValueError(
                "Unsupported train_mode {!r}; expected 'vimco' or "
                "'reinforce'".format(self.train_mode))

        self.p_error_history.append(
            util.get_p_error(self.true_generative_model, generative_model))
        self.q_error_history.append(
            util.get_q_error(self.true_generative_model, inference_network,
                             self.test_obss))

        num_grad_samples = 10  # loss/backward passes feeding the statistic
        stats = util.OnlineMeanStd()
        for _ in range(num_grad_samples):
            inference_network.zero_grad()
            # Fresh local names: do not shadow the `loss`/`elbo` arguments.
            test_loss, _ = get_loss(generative_model, inference_network,
                                    self.test_obss, self.num_particles)
            test_loss.backward()
            stats.update([p.grad for p in inference_network.parameters()])
        # NOTE(review): only inference_network grads are zeroed here; any
        # grads these backward passes leave on generative_model (and the
        # final pass's grads on inference_network) persist after this call.
        # Presumably harmless if the training loop zeroes grads before each
        # backward — confirm.
        # Second element of avg_of_means_stds(), presumably the averaged
        # gradient std — TODO confirm against util.OnlineMeanStd.
        self.grad_std_history.append(stats.avg_of_means_stds()[1])

        util.print_with_time(
            'Iteration {} p_error = {:.3f}, q_error_to_true = '
            '{:.3f}'.format(iteration, self.p_error_history[-1],
                            self.q_error_history[-1]))