import torch

import losses
import util


# __call__ method of the training statistics callback; the class definition
# that sets the attributes used below is elided here, and eval_gen_inf is
# defined elsewhere in this module.
def __call__(self, iteration, wake_theta_loss, wake_phi_loss, elbo,
             generative_model, inference_network, optimizer_theta,
             optimizer_phi):
    if iteration % self.logging_interval == 0:
        util.print_with_time(
            'Iteration {} losses: theta = {:.3f}, phi = {:.3f}, elbo = '
            '{:.3f}'.format(iteration, wake_theta_loss, wake_phi_loss, elbo))
        self.wake_theta_loss_history.append(wake_theta_loss)
        self.wake_phi_loss_history.append(wake_phi_loss)
        self.elbo_history.append(elbo)
    if iteration % self.checkpoint_interval == 0:
        # persist both the accumulated statistics and the model weights
        stats_path = util.get_stats_path(self.save_dir)
        util.save_object(self, stats_path)
        util.save_checkpoint(self.save_dir, iteration,
                             generative_model=generative_model,
                             inference_network=inference_network)
    if iteration % self.eval_interval == 0:
        log_p, kl = eval_gen_inf(generative_model, inference_network,
                                 self.test_data_loader,
                                 self.eval_num_particles)
        self.log_p_history.append(log_p)
        self.kl_history.append(kl)

        # estimate gradient noise: recompute both losses ten times on the
        # same test observations and track the std of all gradient entries
        stats = util.OnlineMeanStd()
        for _ in range(10):
            generative_model.zero_grad()
            wake_theta_loss, elbo = losses.get_wake_theta_loss(
                generative_model, inference_network, self.test_obs,
                self.num_particles)
            wake_theta_loss.backward()
            # clone: the phi backward pass below can also write into the
            # generative model's .grad buffers
            theta_grads = [p.grad.clone()
                           for p in generative_model.parameters()]

            inference_network.zero_grad()
            wake_phi_loss = losses.get_wake_phi_loss(
                generative_model, inference_network, self.test_obs,
                self.num_particles)
            wake_phi_loss.backward()
            phi_grads = [p.grad for p in inference_network.parameters()]

            stats.update(theta_grads + phi_grads)
        self.grad_std_history.append(stats.avg_of_means_stds()[1].item())

        util.print_with_time(
            'Iteration {} log_p = {:.3f}, kl = {:.3f}'.format(
                iteration, self.log_p_history[-1], self.kl_history[-1]))
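# A minimal sketch of how the callback above might be initialized. The class
# name, constructor signature, and default interval values are assumptions;
# only the attribute names are taken from the __call__ method above.
class TrainStatsCallback():
    def __init__(self, save_dir, test_data_loader, test_obs, num_particles,
                 eval_num_particles, logging_interval=10,
                 checkpoint_interval=100, eval_interval=10):
        self.save_dir = save_dir
        self.test_data_loader = test_data_loader
        self.test_obs = test_obs
        self.num_particles = num_particles
        self.eval_num_particles = eval_num_particles
        self.logging_interval = logging_interval
        self.checkpoint_interval = checkpoint_interval
        self.eval_interval = eval_interval
        # history buffers appended to by __call__
        self.wake_theta_loss_history = []
        self.wake_phi_loss_history = []
        self.elbo_history = []
        self.log_p_history = []
        self.kl_history = []
        self.grad_std_history = []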
# Variant that takes optimizer kwargs and epochs over the data loader until
# num_iterations gradient steps have been taken.
def train_wake_sleep(generative_model, inference_network, data_loader,
                     num_iterations, num_particles, optim_kwargs,
                     callback=None):
    optimizer_phi = torch.optim.Adam(inference_network.parameters(),
                                     **optim_kwargs)
    optimizer_theta = torch.optim.Adam(generative_model.parameters(),
                                       **optim_kwargs)
    iteration = 0
    while iteration < num_iterations:
        for obs in data_loader:
            # wake theta: fit the generative model on real observations
            optimizer_phi.zero_grad()
            optimizer_theta.zero_grad()
            wake_theta_loss, elbo = losses.get_wake_theta_loss(
                generative_model, inference_network, obs, num_particles)
            wake_theta_loss.backward()
            optimizer_theta.step()

            # sleep phi: fit the inference network on dreamed samples
            optimizer_phi.zero_grad()
            optimizer_theta.zero_grad()
            sleep_phi_loss = losses.get_sleep_loss(
                generative_model, inference_network,
                num_samples=obs.shape[0] * num_particles)
            sleep_phi_loss.backward()
            optimizer_phi.step()

            if callback is not None:
                callback(iteration, wake_theta_loss.item(),
                         sleep_phi_loss.item(), elbo.item(),
                         generative_model, inference_network,
                         optimizer_theta, optimizer_phi)

            iteration += 1
            # by this time, we have gone through `iteration` iterations
            if iteration == num_iterations:
                break
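# For reference, a sketch of what losses.get_sleep_loss could compute; the
# actual implementation lives in the repo's losses module and may differ.
# Sleep-phase phi update: dream (latent, obs) pairs from the generative model
# and fit the inference network to them by maximum likelihood, i.e. minimize
# E_{z, x ~ p_theta}[-log q_phi(z | x)]. The model methods used here
# (sample_latent_and_obs, get_log_prob) are assumed interfaces.
def get_sleep_loss_sketch(generative_model, inference_network, num_samples):
    latent, obs = generative_model.sample_latent_and_obs(num_samples)
    # detach so no gradient flows into theta during the sleep phase
    log_q = inference_network.get_log_prob(latent.detach(), obs.detach())
    return -torch.mean(log_q)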
# Variant driven by a fixed number of iterations, restarting the data loader
# whenever it is exhausted.
def train_wake_sleep(generative_model, inference_network, obss_data_loader,
                     num_iterations, num_particles, callback=None):
    num_samples = obss_data_loader.batch_size * num_particles
    optimizer_phi = torch.optim.Adam(inference_network.parameters())
    optimizer_theta = torch.optim.Adam(generative_model.parameters())
    obss_iter = iter(obss_data_loader)

    for iteration in range(num_iterations):
        # get obss
        try:
            obss = next(obss_iter)
        except StopIteration:
            obss_iter = iter(obss_data_loader)
            obss = next(obss_iter)

        # wake theta
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        wake_theta_loss, elbo = losses.get_wake_theta_loss(
            generative_model, inference_network, obss, num_particles)
        wake_theta_loss.backward()
        optimizer_theta.step()

        # sleep phi
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                               inference_network, num_samples)
        sleep_phi_loss.backward()
        optimizer_phi.step()

        if callback is not None:
            callback(iteration, wake_theta_loss.item(),
                     sleep_phi_loss.item(), elbo.item(), generative_model,
                     inference_network, optimizer_theta, optimizer_phi)

    return optimizer_theta, optimizer_phi
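import math


# For orientation, a sketch of the wake-theta objective used in reweighted
# wake-sleep: the negative multi-sample (IWAE) lower bound. This mirrors what
# losses.get_wake_theta_loss plausibly computes once the per-particle log
# weights log w_k = log p_theta(z_k, x) - log q_phi(z_k | x) are available;
# the [batch_size, num_particles] shape convention is an assumption.
def wake_theta_loss_from_log_weight_sketch(log_weight):
    _, num_particles = log_weight.shape
    elbo = torch.mean(
        torch.logsumexp(log_weight, dim=1) - math.log(num_particles))
    return -elbo, elbo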
# Variant that trains on synthetic observations sampled from a known true
# generative model instead of a data loader.
def train_wake_sleep(generative_model, inference_network,
                     true_generative_model, batch_size, num_iterations,
                     num_particles, callback=None):
    num_samples = batch_size * num_particles
    optimizer_phi = torch.optim.Adam(inference_network.parameters())
    optimizer_theta = torch.optim.Adam(generative_model.parameters())

    for iteration in range(num_iterations):
        # generate synthetic data
        obss = [true_generative_model.sample_obs()
                for _ in range(batch_size)]

        # wake theta
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        wake_theta_loss, elbo = losses.get_wake_theta_loss(
            generative_model, inference_network, obss, num_particles)
        wake_theta_loss.backward()
        optimizer_theta.step()

        # sleep phi
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                               inference_network, num_samples)
        sleep_phi_loss.backward()
        optimizer_phi.step()

        if callback is not None:
            callback(iteration, wake_theta_loss.item(),
                     sleep_phi_loss.item(), elbo.item(), generative_model,
                     inference_network, optimizer_theta, optimizer_phi)

    return optimizer_theta, optimizer_phi
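# A minimal, hedged usage sketch for the synthetic-data variant above.
# GenerativeModel, InferenceNetwork, and TrueGenerativeModel are hypothetical
# stand-ins for the repo's model classes, so the call is left commented out;
# any classes exposing the interfaces used by the losses module would do.
def print_callback(iteration, wake_theta_loss, sleep_phi_loss, elbo,
                   generative_model, inference_network, optimizer_theta,
                   optimizer_phi):
    if iteration % 100 == 0:
        print('it {}: theta loss {:.3f}, phi loss {:.3f}, elbo {:.3f}'.format(
            iteration, wake_theta_loss, sleep_phi_loss, elbo))


# optimizer_theta, optimizer_phi = train_wake_sleep(
#     GenerativeModel(), InferenceNetwork(), TrueGenerativeModel(),
#     batch_size=100, num_iterations=10000, num_particles=10,
#     callback=print_callback)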