def train_iwae(algorithm, generative_model, inference_network, data_loader,
               num_iterations, num_particles, optim_kwargs, callback=None):
    """Jointly train the generative model and inference network with an
    IWAE-style objective, cycling over `data_loader` until `num_iterations`
    optimizer steps have been taken.

    Args:
        algorithm: 'vimco' or 'reinforce'; selects the gradient estimator.
        generative_model: model whose parameters are optimized.
        inference_network: network whose parameters are optimized.
        data_loader: iterable yielding observation batches.
        num_iterations: total number of optimizer steps.
        num_particles: number of importance-sampling particles per step.
        optim_kwargs: dict of keyword arguments forwarded to torch.optim.Adam.
        callback: optional; invoked as callback(iteration, loss, elbo,
            generative_model, inference_network, optimizer) after each step.

    Returns:
        The Adam optimizer (matches the sibling train_iwae variants in this
        file, which return their optimizer).

    Raises:
        ValueError: if `algorithm` is not 'vimco' or 'reinforce'.
    """
    parameters = itertools.chain.from_iterable(
        [x.parameters() for x in [generative_model, inference_network]])
    optimizer = torch.optim.Adam(parameters, **optim_kwargs)
    iteration = 0
    while iteration < num_iterations:
        for obs in iter(data_loader):
            optimizer.zero_grad()
            if algorithm == 'vimco':
                loss, elbo = losses.get_vimco_loss(
                    generative_model, inference_network, obs, num_particles)
            elif algorithm == 'reinforce':
                loss, elbo = losses.get_reinforce_loss(
                    generative_model, inference_network, obs, num_particles)
            else:
                # Fail fast instead of raising a confusing NameError on
                # `loss` below when an unsupported algorithm is passed.
                raise ValueError(
                    'algorithm must be vimco or reinforce, got '
                    '{}'.format(algorithm))
            loss.backward()
            optimizer.step()
            if callback is not None:
                callback(iteration, loss.item(), elbo.item(),
                         generative_model, inference_network, optimizer)
            iteration += 1
            # by this time, we have gone through `iteration` iterations
            if iteration == num_iterations:
                break
    return optimizer
def train_iwae(algorithm, generative_model, inference_network,
               obss_data_loader, num_iterations, num_particles,
               callback=None):
    """Train using IWAE objective.

    Draws batches from `obss_data_loader`, restarting the iterator whenever
    it is exhausted, so exactly `num_iterations` optimizer steps are taken.

    Args:
        algorithm: 'reinforce', 'vimco' or 'concrete'; selects the gradient
            estimator.
        generative_model: model whose parameters are optimized.
        inference_network: network whose parameters are optimized.
        obss_data_loader: iterable yielding observation batches.
        num_iterations: total number of optimizer steps.
        num_particles: number of importance-sampling particles per step.
        callback: optional; invoked as callback(iteration, loss, elbo,
            generative_model, inference_network, optimizer) after each step.

    Returns:
        The Adam optimizer.

    Raises:
        ValueError: if `algorithm` is not one of the supported estimators.
    """
    parameters = itertools.chain.from_iterable(
        [x.parameters() for x in [generative_model, inference_network]])
    optimizer = torch.optim.Adam(parameters)
    obss_iter = iter(obss_data_loader)
    for iteration in range(num_iterations):
        # get obss; restart the loader when exhausted
        try:
            obss = next(obss_iter)
        except StopIteration:
            obss_iter = iter(obss_data_loader)
            obss = next(obss_iter)

        # wake theta
        optimizer.zero_grad()
        if algorithm == 'vimco':
            loss, elbo = losses.get_vimco_loss(
                generative_model, inference_network, obss, num_particles)
        elif algorithm == 'reinforce':
            loss, elbo = losses.get_reinforce_loss(
                generative_model, inference_network, obss, num_particles)
        elif algorithm == 'concrete':
            loss, elbo = losses.get_concrete_loss(
                generative_model, inference_network, obss, num_particles)
        else:
            # Fail fast instead of raising a confusing NameError on `loss`
            # below when an unsupported algorithm is passed.
            raise ValueError(
                'algorithm must be reinforce, vimco or concrete, got '
                '{}'.format(algorithm))
        loss.backward()
        optimizer.step()

        if callback is not None:
            callback(iteration, loss.item(), elbo.item(),
                     generative_model, inference_network, optimizer)

    return optimizer
def __call__(self, iteration, loss, elbo, generative_model,
             inference_network, optimizer):
    """Training callback: logs, checkpoints and evaluates at the configured
    intervals.

    Invoked by the training loop after each optimizer step with the current
    iteration number, scalar loss/elbo, and the live models/optimizer.
    """
    if iteration % self.logging_interval == 0:
        util.print_with_time(
            'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
                iteration, loss, elbo))
        # NOTE(review): history is only appended on logging iterations, so
        # these lists are sampled at `logging_interval`, not every step.
        self.loss_history.append(loss)
        self.elbo_history.append(elbo)
    if iteration % self.checkpoint_interval == 0:
        # Persist this stats object itself plus both model checkpoints.
        stats_path = util.get_stats_path(self.save_dir)
        util.save_object(self, stats_path)
        util.save_checkpoint(self.save_dir, iteration,
                             generative_model=generative_model,
                             inference_network=inference_network)
    if iteration % self.eval_interval == 0:
        # Held-out evaluation: log marginal likelihood and KL on test data.
        log_p, kl = eval_gen_inf(generative_model, inference_network,
                                 self.test_data_loader,
                                 self.eval_num_particles)
        self.log_p_history.append(log_p)
        self.kl_history.append(kl)

        # Estimate gradient-noise statistics: re-run the training loss 10
        # times on the fixed test batch and track the std of the gradients
        # of BOTH models' parameters.
        stats = util.OnlineMeanStd()
        for _ in range(10):
            # Zero both models so each backward() sees fresh gradients.
            generative_model.zero_grad()
            inference_network.zero_grad()
            if self.train_mode == 'vimco':
                loss, elbo = losses.get_vimco_loss(generative_model,
                                                   inference_network,
                                                   self.test_obs,
                                                   self.num_particles)
            elif self.train_mode == 'reinforce':
                loss, elbo = losses.get_reinforce_loss(
                    generative_model, inference_network, self.test_obs,
                    self.num_particles)
            loss.backward()
            stats.update([p.grad for p in generative_model.parameters()] +
                         [p.grad for p in inference_network.parameters()])
        # [1] presumably selects the std component of (mean, std) — confirm
        # against util.OnlineMeanStd.avg_of_means_stds.
        self.grad_std_history.append(stats.avg_of_means_stds()[1].item())
        util.print_with_time(
            'Iteration {} log_p = {:.3f}, kl = {:.3f}'.format(
                iteration, self.log_p_history[-1], self.kl_history[-1]))
def train_iwae(algorithm, generative_model, inference_network,
               true_generative_model, batch_size, num_iterations,
               num_particles, callback=None):
    """Train using IWAE objective on synthetic data.

    Each iteration samples a fresh batch of observations from
    `true_generative_model` rather than reading from a data loader.

    Args:
        algorithm: 'reinforce' or 'vimco'; selects the gradient estimator.
        generative_model: model whose parameters are optimized.
        inference_network: network whose parameters are optimized.
        true_generative_model: ground-truth model used to sample training
            observations via sample_obs().
        batch_size: number of observations sampled per iteration.
        num_iterations: total number of optimizer steps.
        num_particles: number of importance-sampling particles per step.
        callback: optional; invoked as callback(iteration, loss, elbo,
            generative_model, inference_network, optimizer) after each step.

    Returns:
        The Adam optimizer.

    Raises:
        ValueError: if `algorithm` is not 'reinforce' or 'vimco'.
    """
    parameters = itertools.chain.from_iterable(
        [x.parameters() for x in [generative_model, inference_network]])
    optimizer = torch.optim.Adam(parameters)
    for iteration in range(num_iterations):
        # generate synthetic data
        obss = [true_generative_model.sample_obs()
                for _ in range(batch_size)]

        # wake theta
        optimizer.zero_grad()
        if algorithm == 'vimco':
            loss, elbo = losses.get_vimco_loss(
                generative_model, inference_network, obss, num_particles)
        elif algorithm == 'reinforce':
            loss, elbo = losses.get_reinforce_loss(
                generative_model, inference_network, obss, num_particles)
        else:
            # Fail fast instead of raising a confusing NameError on `loss`
            # below when an unsupported algorithm is passed.
            raise ValueError(
                'algorithm must be reinforce or vimco, got '
                '{}'.format(algorithm))
        loss.backward()
        optimizer.step()

        if callback is not None:
            callback(iteration, loss.item(), elbo.item(),
                     generative_model, inference_network, optimizer)

    return optimizer
def __call__(self, iteration, loss, elbo, generative_model,
             inference_network, optimizer):
    """Training callback: logs, checkpoints and evaluates model/inference
    errors against the true generative model at the configured intervals.

    Invoked by the training loop after each optimizer step with the current
    iteration number, scalar loss/elbo, and the live models/optimizer.
    """
    if iteration % self.logging_interval == 0:
        util.print_with_time(
            'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
                iteration, loss, elbo))
        # NOTE(review): history is only appended on logging iterations, so
        # these lists are sampled at `logging_interval`, not every step.
        self.loss_history.append(loss)
        self.elbo_history.append(elbo)
    if iteration % self.checkpoint_interval == 0:
        # Persist this stats object itself plus both model checkpoints.
        stats_filename = util.get_stats_path(self.model_folder)
        util.save_object(self, stats_filename)
        util.save_models(generative_model, inference_network,
                         self.model_folder, iteration)
    if iteration % self.eval_interval == 0:
        # Distance of the learned model / inference network from the truth.
        self.p_error_history.append(
            util.get_p_error(self.true_generative_model, generative_model))
        self.q_error_history.append(
            util.get_q_error(self.true_generative_model, inference_network,
                             self.test_obss))

        # Estimate gradient-noise statistics: re-run the training loss 10
        # times on the fixed test batch; only the inference network's
        # gradients are tracked here (unlike the other callback variant).
        stats = util.OnlineMeanStd()
        for _ in range(10):
            # NOTE(review): only inference_network is zeroed, so the
            # generative model's .grad accumulates across these 10 passes.
            # Harmless for `stats` (which reads only inference grads) and
            # for training if the loop zeroes via optimizer.zero_grad() —
            # confirm against the training loop.
            inference_network.zero_grad()
            if self.train_mode == 'vimco':
                loss, elbo = losses.get_vimco_loss(generative_model,
                                                   inference_network,
                                                   self.test_obss,
                                                   self.num_particles)
            elif self.train_mode == 'reinforce':
                loss, elbo = losses.get_reinforce_loss(
                    generative_model, inference_network, self.test_obss,
                    self.num_particles)
            loss.backward()
            stats.update([p.grad for p in inference_network.parameters()])
        # [1] presumably selects the std component of (mean, std) — confirm
        # against util.OnlineMeanStd.avg_of_means_stds.
        self.grad_std_history.append(stats.avg_of_means_stds()[1])
        util.print_with_time(
            'Iteration {} p_error = {:.3f}, q_error_to_true = '
            '{:.3f}'.format(iteration, self.p_error_history[-1],
                            self.q_error_history[-1]))
def train_vimco(
    generative_model,
    inference_network,
    data_loader,
    num_iterations,
    num_particles,
    true_cluster_cov,
    test_data_loader,
    test_num_particles,
    true_generative_model,
    checkpoint_path,
):
    """Train with the VIMCO estimator, periodically evaluating against the
    true generative model and checkpointing all tracked statistics.

    Returns a 20-tuple of history lists mirroring the arguments of
    util.save_checkpoint (three slots are always None).
    """
    # Single optimizer over both models' parameters.
    optimizer = torch.optim.Adam(
        itertools.chain(generative_model.parameters(),
                        inference_network.parameters()))
    # 17 parallel history lists, appended to on each evaluation pass.
    (
        theta_losses,
        phi_losses,
        cluster_cov_distances,
        test_log_ps,
        test_log_ps_true,
        test_kl_qps,
        test_kl_pqs,
        test_kl_qps_true,
        test_kl_pqs_true,
        train_log_ps,
        train_log_ps_true,
        train_kl_qps,
        train_kl_pqs,
        train_kl_qps_true,
        train_kl_pqs_true,
        reweighted_train_kl_qps,
        reweighted_train_kl_qps_true,
    ) = ([], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [])
    data_loader_iter = iter(data_loader)

    for iteration in range(num_iterations):
        # get obs; restart the loader when exhausted
        try:
            obs = next(data_loader_iter)
        except StopIteration:
            data_loader_iter = iter(data_loader)
            obs = next(data_loader_iter)

        # loss: a single VIMCO loss drives both theta and phi updates.
        optimizer.zero_grad()
        loss, elbo = losses.get_vimco_loss(generative_model,
                                           inference_network, obs,
                                           num_particles)
        # NOTE(review): retain_graph=True keeps the graph after backward —
        # presumably needed by downstream evaluation; confirm it is not a
        # leftover (it retains memory across iterations otherwise).
        loss.backward(retain_graph=True)
        optimizer.step()

        # Same scalar recorded for both theta and phi histories, since one
        # joint loss is used here.
        theta_losses.append(loss.item())
        phi_losses.append(loss.item())
        cluster_cov_distances.append(
            torch.norm(true_cluster_cov -
                       generative_model.get_cluster_cov()).item())

        if iteration % 100 == 0:  # test every 100 iterations
            # Held-out evaluation against the true generative model.
            (
                test_log_p,
                test_log_p_true,
                test_kl_qp,
                test_kl_pq,
                test_kl_qp_true,
                test_kl_pq_true,
                _,
                _,
            ) = models.eval_gen_inf(true_generative_model, generative_model,
                                    inference_network, None,
                                    test_data_loader)
            test_log_ps.append(test_log_p)
            test_log_ps_true.append(test_log_p_true)
            test_kl_qps.append(test_kl_qp)
            test_kl_pqs.append(test_kl_pq)
            test_kl_qps_true.append(test_kl_qp_true)
            test_kl_pqs_true.append(test_kl_pq_true)

            # Training-set evaluation, including reweighted KL variants.
            (
                train_log_p,
                train_log_p_true,
                train_kl_qp,
                train_kl_pq,
                train_kl_qp_true,
                train_kl_pq_true,
                _,
                _,
                reweighted_train_kl_qp,
                reweighted_train_kl_qp_true,
            ) = models.eval_gen_inf(
                true_generative_model, generative_model, inference_network,
                None, data_loader, num_particles=num_particles,
                reweighted_kl=True,
            )
            train_log_ps.append(train_log_p)
            train_log_ps_true.append(train_log_p_true)
            train_kl_qps.append(train_kl_qp)
            train_kl_pqs.append(train_kl_pq)
            train_kl_qps_true.append(train_kl_qp_true)
            train_kl_pqs_true.append(train_kl_pq_true)
            reweighted_train_kl_qps.append(reweighted_train_kl_qp)
            reweighted_train_kl_qps_true.append(reweighted_train_kl_qp_true)

            # Checkpoint everything; the three None slots presumably
            # correspond to histories tracked by other training modes —
            # confirm against util.save_checkpoint's signature.
            util.save_checkpoint(
                checkpoint_path,
                generative_model,
                inference_network,
                theta_losses,
                phi_losses,
                cluster_cov_distances,
                test_log_ps,
                test_log_ps_true,
                test_kl_qps,
                test_kl_pqs,
                test_kl_qps_true,
                test_kl_pqs_true,
                train_log_ps,
                train_log_ps_true,
                train_kl_qps,
                train_kl_pqs,
                train_kl_qps_true,
                train_kl_pqs_true,
                None,
                None,
                None,
                reweighted_train_kl_qps,
                reweighted_train_kl_qps_true,
            )

        util.print_with_time("it. {} | theta loss = {:.2f}".format(
            iteration, loss))

        # if iteration % 200 == 0:
        #     z = inference_network.get_latent_dist(obs).sample()
        #     util.save_plot("images/rws/iteration_{}.png".format(iteration),
        #                    obs[:3], z[:3])

    return (
        theta_losses,
        phi_losses,
        cluster_cov_distances,
        test_log_ps,
        test_log_ps_true,
        test_kl_qps,
        test_kl_pqs,
        test_kl_qps_true,
        test_kl_pqs_true,
        train_log_ps,
        train_log_ps_true,
        train_kl_qps,
        train_kl_pqs,
        train_kl_qps_true,
        train_kl_pqs_true,
        None,
        None,
        None,
        reweighted_train_kl_qps,
        reweighted_train_kl_qps_true,
    )