def train_sleep(generative_model, inference_network, num_samples,
                num_iterations, log_interval):
    """Pretrain the inference network with sleep-phase updates only.

    Args:
        generative_model: model sampled from by the sleep loss; its
            parameters are not updated here.
        inference_network: network whose parameters are optimized (Adam).
        num_samples: number of samples per sleep-loss estimate.
        num_iterations: total number of optimizer steps.
        log_interval: log progress every `log_interval` iterations.

    Returns:
        None. `inference_network` is updated in place.
    """
    optimizer_phi = torch.optim.Adam(inference_network.parameters())
    sleep_losses = []
    device = next(generative_model.parameters()).device
    if device.type == "cuda":
        # NOTE(review): reset_max_memory_allocated is deprecated in newer
        # torch in favor of reset_peak_memory_stats; kept for compatibility
        # with the torch version this file already targets.
        torch.cuda.reset_max_memory_allocated(device=device)
    util.logging.info("Pretraining with sleep")
    # A plain for-loop replaces the original while/increment/break
    # bookkeeping; `iteration` counts completed steps (1-based), matching
    # the original post-increment logging.
    for iteration in range(1, num_iterations + 1):
        optimizer_phi.zero_grad()
        sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                               inference_network, num_samples)
        sleep_phi_loss.backward()
        optimizer_phi.step()
        sleep_losses.append(sleep_phi_loss.item())
        if iteration % log_interval == 0:
            util.logging.info(
                "it. {}/{} | sleep loss = {:.2f} | "
                "GPU memory = {:.2f} MB".format(
                    iteration,
                    num_iterations,
                    sleep_losses[-1],
                    (
                        torch.cuda.max_memory_allocated(device=device) / 1e6
                        if device.type == "cuda"
                        else 0
                    ),
                )
            )
def get_q_error(generative_model, inference_network, num_samples=100):
    """Expected KL(posterior || q) + const as a measure of q's quality.

    Returns:
        Detached scalar E_p(x)[KL(p(z | x) || q(z | x))] + H(z | x), where
        the second term is constant wrt the inference network.
    """
    sleep_loss = losses.get_sleep_loss(
        generative_model, inference_network, num_samples)
    return sleep_loss.detach()
def get_grads_correct_sleep(seed):
    """Compute reference RWS gradients: wake-theta and wake/sleep-mixed phi.

    Relies on enclosing-scope names: generative_model, inference_network,
    obs, num_particles, optimizer_phi, optimizer_theta (the optimizers are
    used only to zero grads; no optimizer steps are taken here).

    Returns:
        (theta_grads_correct, phi_grads_correct): lists of cloned gradient
        tensors for the generative model and the inference network.
    """
    util.set_seed(seed)

    log_weight, log_q = losses.get_log_weight_and_log_q(
        generative_model, inference_network, obs, num_particles)

    # Wake theta grads. retain_graph so log_weight's graph survives for the
    # wake-phi backward below.
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
        log_weight)
    wake_theta_loss.backward(retain_graph=True)
    theta_grads_correct = [
        parameter.grad.clone() for parameter in generative_model.parameters()
    ]
    # in rws, we step as we compute the grads
    # optimizer_theta.step()

    # Wake phi grads.
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
        log_weight, log_q)
    wake_phi_loss.backward()
    wake_phi_grads_correct = [
        parameter.grad.clone() for parameter in inference_network.parameters()
    ]
    # in rws, we step as we compute the grads
    # optimizer_phi.step()

    # Sleep phi grads.
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                           inference_network,
                                           num_samples=num_particles)
    sleep_phi_loss.backward()
    sleep_phi_grads_correct = [
        parameter.grad.clone() for parameter in inference_network.parameters()
    ]

    # Phi grads are a convex combination of the wake and sleep phi grads.
    # (The original also pre-initialized theta_grads_correct and
    # phi_grads_correct to empty lists that were immediately overwritten;
    # those dead assignments are removed.)
    wake_factor = 0.755
    phi_grads_correct = [
        wake_factor * wake_phi_grad_correct +
        (1 - wake_factor) * sleep_phi_grad_correct
        for wake_phi_grad_correct, sleep_phi_grad_correct in zip(
            wake_phi_grads_correct, sleep_phi_grads_correct)
    ]

    return theta_grads_correct, phi_grads_correct
def train_sleep(generative_model, inference_network, num_samples,
                num_iterations, callback=None):
    """Run `num_iterations` sleep-phase updates on the inference network.

    When given, `callback(iteration, loss, generative_model,
    inference_network, optimizer)` is invoked after every step.

    Returns:
        The Adam optimizer used for the inference network.
    """
    optimizer = torch.optim.Adam(inference_network.parameters())
    iteration = 0
    while iteration < num_iterations:
        optimizer.zero_grad()
        loss = losses.get_sleep_loss(
            generative_model, inference_network, num_samples=num_samples)
        loss.backward()
        optimizer.step()
        if callback is not None:
            callback(iteration, loss.item(), generative_model,
                     inference_network, optimizer)
        iteration += 1
    return optimizer
def train_wake_sleep(generative_model, inference_network, data_loader,
                     num_iterations, num_particles, optim_kwargs,
                     callback=None):
    """Alternate wake-theta and sleep-phi updates over a data loader.

    Args:
        generative_model: model updated by the wake-theta step.
        inference_network: network updated by the sleep-phi step.
        data_loader: iterable of observation batches; cycled until
            `num_iterations` total steps are taken.
        num_iterations: total number of (theta, phi) update pairs.
        num_particles: particles per observation for the wake-theta loss;
            the sleep step draws `obs.shape[0] * num_particles` samples.
        optim_kwargs: keyword args forwarded to both Adam optimizers.
        callback: optional hook called after each iteration with losses,
            elbo, the models, and both optimizers.

    Returns:
        None. Both models are updated in place.
    """
    optimizer_phi = torch.optim.Adam(inference_network.parameters(),
                                     **optim_kwargs)
    optimizer_theta = torch.optim.Adam(generative_model.parameters(),
                                       **optim_kwargs)

    iteration = 0
    while iteration < num_iterations:
        # `for` calls iter() on the loader itself; the original's explicit
        # iter(data_loader) was redundant.
        for obs in data_loader:
            # wake theta
            optimizer_phi.zero_grad()
            optimizer_theta.zero_grad()
            wake_theta_loss, elbo = losses.get_wake_theta_loss(
                generative_model, inference_network, obs, num_particles)
            wake_theta_loss.backward()
            optimizer_theta.step()

            # sleep phi
            optimizer_phi.zero_grad()
            optimizer_theta.zero_grad()
            sleep_phi_loss = losses.get_sleep_loss(
                generative_model, inference_network,
                num_samples=obs.shape[0] * num_particles)
            sleep_phi_loss.backward()
            optimizer_phi.step()

            if callback is not None:
                callback(iteration, wake_theta_loss.item(),
                         sleep_phi_loss.item(), elbo.item(),
                         generative_model, inference_network,
                         optimizer_theta, optimizer_phi)

            iteration += 1
            # by this time, we have gone through `iteration` iterations
            if iteration == num_iterations:
                break
def get_grads_weird_detach_sleep(seed):
    """Compute theta and phi grads from a single combined pair of backward
    passes, without zeroing grads in between, so contributions accumulate.

    Relies on enclosing-scope names: generative_model, inference_network,
    obs, num_particles, optimizer_phi, optimizer_theta,
    get_log_weight_and_log_q_weird_detach. No optimizer steps are taken.

    Returns:
        (theta_grads_in_one, phi_grads_in_one): lists of cloned gradient
        tensors for the generative model and the inference network.
    """
    util.set_seed(seed)
    theta_grads_in_one = []
    phi_grads_in_one = []
    log_weight, log_q = get_log_weight_and_log_q_weird_detach(
        generative_model, inference_network, obs, num_particles)
    # Zero both parameter sets once up front; everything after this
    # accumulates into .grad.
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
        log_weight)
    # optimizer_phi.zero_grad() -> don't zero phi grads
    # optimizer_theta.zero_grad()
    wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
        log_weight, log_q)
    sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                           inference_network,
                                           num_samples=num_particles)
    # Mix wake and sleep phi losses into one scalar before backward.
    wake_factor = 0.755
    phi_loss = wake_factor * wake_phi_loss + (1 - wake_factor) * sleep_phi_loss
    # retain_graph so phi_loss.backward() can reuse log_weight's graph.
    wake_theta_loss.backward(retain_graph=True)
    phi_loss.backward()
    # only get the grads in the end!
    theta_grads_in_one = [
        parameter.grad.clone() for parameter in generative_model.parameters()
    ]
    phi_grads_in_one = [
        parameter.grad.clone() for parameter in inference_network.parameters()
    ]
    # in pyro, we want step to be in a different stage
    # optimizer_theta.step()
    # optimizer_phi.step()
    return theta_grads_in_one, phi_grads_in_one
def train_wake_sleep(generative_model, inference_network, obss_data_loader,
                     num_iterations, num_particles, callback=None):
    """Alternate wake-theta / sleep-phi steps, cycling the data loader.

    The sleep step uses `batch_size * num_particles` samples per update.
    `callback`, when given, is invoked after every iteration.

    Returns:
        (optimizer_theta, optimizer_phi): the two Adam optimizers.
    """
    num_samples = obss_data_loader.batch_size * num_particles
    optimizer_phi = torch.optim.Adam(inference_network.parameters())
    optimizer_theta = torch.optim.Adam(generative_model.parameters())
    obss_iter = iter(obss_data_loader)

    def _next_obss():
        # Cycle: restart the loader when it is exhausted.
        nonlocal obss_iter
        try:
            return next(obss_iter)
        except StopIteration:
            obss_iter = iter(obss_data_loader)
            return next(obss_iter)

    for iteration in range(num_iterations):
        obss = _next_obss()

        # wake theta
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        wake_theta_loss, elbo = losses.get_wake_theta_loss(
            generative_model, inference_network, obss, num_particles)
        wake_theta_loss.backward()
        optimizer_theta.step()

        # sleep phi
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        sleep_phi_loss = losses.get_sleep_loss(
            generative_model, inference_network, num_samples)
        sleep_phi_loss.backward()
        optimizer_phi.step()

        if callback is not None:
            callback(iteration, wake_theta_loss.item(),
                     sleep_phi_loss.item(), elbo.item(), generative_model,
                     inference_network, optimizer_theta, optimizer_phi)

    return optimizer_theta, optimizer_phi
def __call__(self, iteration, wake_theta_loss, sleep_phi_loss, elbo,
             generative_model, inference_network, optimizer_theta,
             optimizer_phi):
    """Per-iteration training hook: log losses, checkpoint, and evaluate.

    NOTE(review): nesting reconstructed from a whitespace-mangled source —
    the appends are assumed to sit inside the logging-interval branch and
    the evaluation work inside the eval-interval branch; confirm against
    the original file.
    """
    if iteration % self.logging_interval == 0:
        util.print_with_time(
            'Iteration {} losses: theta = {:.3f}, phi = {:.3f}, elbo = '
            '{:.3f}'.format(iteration, wake_theta_loss, sleep_phi_loss,
                            elbo))
        self.wake_theta_loss_history.append(wake_theta_loss)
        self.sleep_phi_loss_history.append(sleep_phi_loss)
        self.elbo_history.append(elbo)

    if iteration % self.checkpoint_interval == 0:
        stats_path = util.get_stats_path(self.save_dir)
        util.save_object(self, stats_path)
        util.save_checkpoint(self.save_dir, iteration,
                             generative_model=generative_model,
                             inference_network=inference_network)

    if iteration % self.eval_interval == 0:
        log_p, kl = eval_gen_inf(generative_model, inference_network,
                                 self.test_data_loader,
                                 self.eval_num_particles)
        self.log_p_history.append(log_p)
        self.kl_history.append(kl)

        # Estimate the std of the sleep-loss gradient over 10 probes.
        grad_stats = util.OnlineMeanStd()
        for _ in range(10):
            inference_network.zero_grad()
            probe_loss = losses.get_sleep_loss(
                generative_model, inference_network, self.num_samples)
            probe_loss.backward()
            grad_stats.update(
                [p.grad for p in inference_network.parameters()])
        self.grad_std_history.append(
            grad_stats.avg_of_means_stds()[1].item())

        util.print_with_time(
            'Iteration {} log_p = {:.3f}, kl = {:.3f}'.format(
                iteration, self.log_p_history[-1], self.kl_history[-1]))
def train_wake_sleep(generative_model, inference_network,
                     true_generative_model, batch_size, num_iterations,
                     num_particles, callback=None):
    """Wake-sleep training on synthetic data from a true generative model.

    Each iteration samples a fresh batch of `batch_size` observations from
    `true_generative_model`, takes one wake-theta step and one sleep-phi
    step (the latter with `batch_size * num_particles` samples), then
    invokes `callback` if provided.

    Returns:
        (optimizer_theta, optimizer_phi): the two Adam optimizers.
    """
    num_samples = batch_size * num_particles
    optimizer_phi = torch.optim.Adam(inference_network.parameters())
    optimizer_theta = torch.optim.Adam(generative_model.parameters())

    for iteration in range(num_iterations):
        # generate synthetic data
        obss = [true_generative_model.sample_obs()
                for _ in range(batch_size)]

        # wake theta
        optimizer_theta.zero_grad()
        optimizer_phi.zero_grad()
        wake_theta_loss, elbo = losses.get_wake_theta_loss(
            generative_model, inference_network, obss, num_particles)
        wake_theta_loss.backward()
        optimizer_theta.step()

        # sleep phi
        optimizer_theta.zero_grad()
        optimizer_phi.zero_grad()
        sleep_phi_loss = losses.get_sleep_loss(
            generative_model, inference_network, num_samples)
        sleep_phi_loss.backward()
        optimizer_phi.step()

        if callback is not None:
            callback(iteration, wake_theta_loss.item(),
                     sleep_phi_loss.item(), elbo.item(), generative_model,
                     inference_network, optimizer_theta, optimizer_phi)

    return optimizer_theta, optimizer_phi
def __call__(self, iteration, wake_theta_loss, sleep_phi_loss, elbo,
             generative_model, inference_network, optimizer_theta,
             optimizer_phi):
    """Per-iteration training hook: log losses, save models, and evaluate
    p/q errors against the true generative model.

    NOTE(review): nesting reconstructed from a whitespace-mangled source —
    appends assumed inside the logging-interval branch and evaluation work
    inside the eval-interval branch; confirm against the original file.
    """
    if iteration % self.logging_interval == 0:
        util.print_with_time(
            'Iteration {} losses: theta = {:.3f}, phi = {:.3f}, elbo = '
            '{:.3f}'.format(iteration, wake_theta_loss, sleep_phi_loss,
                            elbo))
        self.wake_theta_loss_history.append(wake_theta_loss)
        self.sleep_phi_loss_history.append(sleep_phi_loss)
        self.elbo_history.append(elbo)

    if iteration % self.checkpoint_interval == 0:
        stats_filename = util.get_stats_path(self.model_folder)
        util.save_object(self, stats_filename)
        util.save_models(generative_model, inference_network,
                         self.model_folder, iteration)

    if iteration % self.eval_interval == 0:
        self.p_error_history.append(
            util.get_p_error(self.true_generative_model, generative_model))
        self.q_error_history.append(
            util.get_q_error(self.true_generative_model, inference_network,
                             self.test_obss))

        # Estimate the std of the sleep-loss gradient over 10 probes.
        stats = util.OnlineMeanStd()
        for _ in range(10):
            inference_network.zero_grad()
            sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                                   inference_network,
                                                   self.num_samples)
            sleep_phi_loss.backward()
            stats.update([p.grad for p in inference_network.parameters()])
        # Fix: call .item() so a plain float (not a tensor) is stored in
        # the history, consistent with the sibling callback in this file
        # and safe to pickle via util.save_object.
        self.grad_std_history.append(stats.avg_of_means_stds()[1].item())

        util.print_with_time(
            'Iteration {} p_error = {:.3f}, q_error_to_true = '
            '{:.3f}'.format(iteration, self.p_error_history[-1],
                            self.q_error_history[-1]))
def get_mean_stds(generative_model, inference_network, num_mc_samples, obss,
                  num_particles):
    """Collect gradient mean/std statistics for several estimators.

    For each of `num_mc_samples` Monte Carlo repetitions, builds VIMCO,
    REINFORCE, wake-phi, log-evidence, log-Q and sleep losses from a shared
    (log_weight, log_q), backprops each (zeroing grads in between,
    retain_graph so the shared graph survives), and accumulates the
    resulting inference-network (and, for log-evidence, generative-model)
    gradients into util.OnlineMeanStd trackers.

    Returns:
        List of `avg_of_means_stds()` results, in order: vimco, vimco_one,
        reinforce, reinforce_one, two (log-evidence phi), log-evidence
        value stats, log-evidence theta grads, wake-phi, log-Q, sleep.
    """
    vimco_grad = util.OnlineMeanStd()
    vimco_one_grad = util.OnlineMeanStd()
    reinforce_grad = util.OnlineMeanStd()
    reinforce_one_grad = util.OnlineMeanStd()
    two_grad = util.OnlineMeanStd()
    log_evidence_stats = util.OnlineMeanStd()
    log_evidence_grad = util.OnlineMeanStd()
    wake_phi_loss_grad = util.OnlineMeanStd()
    log_Q_grad = util.OnlineMeanStd()
    sleep_loss_grad = util.OnlineMeanStd()
    for mc_sample_idx in range(num_mc_samples):
        util.print_with_time('MC sample {}'.format(mc_sample_idx))
        log_weight, log_q = losses.get_log_weight_and_log_q(
            generative_model, inference_network, obss, num_particles)
        # IWAE-style log evidence estimate per observation.
        log_evidence = torch.logsumexp(log_weight, dim=1) - \
            np.log(num_particles)
        avg_log_evidence = torch.mean(log_evidence)
        log_Q = torch.sum(log_q, dim=1)
        avg_log_Q = torch.mean(log_Q)
        # REINFORCE surrogate: score-function term (+ pathwise term below).
        reinforce_one = torch.mean(log_evidence.detach() * log_Q)
        reinforce = reinforce_one + avg_log_evidence
        # VIMCO surrogate: per-particle leave-one-out control variates.
        vimco_one = 0
        for i in range(num_particles):
            log_weight_ = log_weight[:, util.range_except(num_particles, i)]
            control_variate = torch.logsumexp(
                torch.cat([log_weight_,
                           torch.mean(log_weight_, dim=1, keepdim=True)],
                          dim=1),
                dim=1)
            vimco_one = vimco_one + (
                log_evidence.detach() - control_variate.detach()) * log_q[:, i]
        vimco_one = torch.mean(vimco_one)
        vimco = vimco_one + avg_log_evidence
        normalized_weight = util.exponentiate_and_normalize(log_weight, dim=1)
        wake_phi_loss = torch.mean(
            -torch.sum(normalized_weight.detach() * log_q, dim=1))

        # One backward per estimator; zero both models' grads before each,
        # retain_graph=True so the shared graph can be reused.
        inference_network.zero_grad()
        generative_model.zero_grad()
        vimco.backward(retain_graph=True)
        vimco_grad.update([param.grad
                           for param in inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        vimco_one.backward(retain_graph=True)
        vimco_one_grad.update([param.grad
                               for param in inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        reinforce.backward(retain_graph=True)
        reinforce_grad.update([param.grad
                               for param in inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        reinforce_one.backward(retain_graph=True)
        reinforce_one_grad.update(
            [param.grad for param in inference_network.parameters()])

        # This single backward feeds two trackers: phi grads ("two") and
        # theta grads (log-evidence wrt the generative model).
        inference_network.zero_grad()
        generative_model.zero_grad()
        avg_log_evidence.backward(retain_graph=True)
        two_grad.update([param.grad
                         for param in inference_network.parameters()])
        log_evidence_grad.update([param.grad
                                  for param in generative_model.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        wake_phi_loss.backward(retain_graph=True)
        wake_phi_loss_grad.update(
            [param.grad for param in inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        avg_log_Q.backward(retain_graph=True)
        log_Q_grad.update([param.grad
                           for param in inference_network.parameters()])

        log_evidence_stats.update([avg_log_evidence.unsqueeze(0)])

        # Sleep loss uses its own graph, so no retain_graph needed.
        sleep_loss = losses.get_sleep_loss(
            generative_model, inference_network, num_particles * len(obss))
        inference_network.zero_grad()
        generative_model.zero_grad()
        sleep_loss.backward()
        sleep_loss_grad.update([param.grad
                                for param in inference_network.parameters()])

    return list(map(
        lambda x: x.avg_of_means_stds(),
        [vimco_grad, vimco_one_grad, reinforce_grad, reinforce_one_grad,
         two_grad, log_evidence_stats, log_evidence_grad, wake_phi_loss_grad,
         log_Q_grad, sleep_loss_grad]))