def get_grads_in_one_no_zeroing(seed):
    """Compute theta and phi grads with a single zeroing, reading grads last.

    Unlike `get_grads_correct`, phi grads are NOT re-zeroed between the
    wake-theta and wake-phi backward passes, so any phi contribution from
    the first backward accumulates into the second.

    Relies on module-level `generative_model`, `inference_network`, `obs`,
    `num_particles`, `optimizer_phi` and `optimizer_theta`.

    Args:
        seed: int random seed passed to util.set_seed.

    Returns:
        tuple: (theta_grads, phi_grads) — lists of cloned gradient tensors,
        ordered as the respective model's .parameters().
    """
    util.set_seed(seed)

    log_weight, log_q = losses.get_log_weight_and_log_q(
        generative_model, inference_network, obs, num_particles)

    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
        log_weight)
    # retain_graph: the same graph is reused for the phi loss below
    wake_theta_loss.backward(retain_graph=True)

    # deliberately NOT zeroing grads here (that is the point of this variant)
    wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
        log_weight, log_q)
    wake_phi_loss.backward()

    # only read the grads at the very end
    theta_grads_in_one = [parameter.grad.clone()
                          for parameter in generative_model.parameters()]
    phi_grads_in_one = [parameter.grad.clone()
                        for parameter in inference_network.parameters()]

    # in pyro, we want step to be in a different stage
    # optimizer_theta.step()
    # optimizer_phi.step()
    return theta_grads_in_one, phi_grads_in_one
def get_grads_correct(seed):
    """Reference RWS gradient computation: zero grads before each backward.

    theta grads come from the wake-theta loss, phi grads from the wake-phi
    loss, with optimizers zeroed before each backward pass so neither set of
    grads leaks into the other.

    Relies on module-level `generative_model`, `inference_network`, `obs`,
    `num_particles`, `optimizer_phi` and `optimizer_theta`.

    Args:
        seed: int random seed passed to util.set_seed.

    Returns:
        tuple: (theta_grads, phi_grads) — lists of cloned gradient tensors,
        ordered as the respective model's .parameters().
    """
    util.set_seed(seed)

    log_weight, log_q = losses.get_log_weight_and_log_q(
        generative_model, inference_network, obs, num_particles)

    # theta grads from the wake-theta loss
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
        log_weight)
    # retain_graph: the same graph is reused for the phi loss below
    wake_theta_loss.backward(retain_graph=True)
    theta_grads_correct = [parameter.grad.clone()
                           for parameter in generative_model.parameters()]
    # in rws, we step as we compute the grads
    # optimizer_theta.step()

    # phi grads from the wake-phi loss
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
        log_weight, log_q)
    wake_phi_loss.backward()
    phi_grads_correct = [parameter.grad.clone()
                         for parameter in inference_network.parameters()]
    # in rws, we step as we compute the grads
    # optimizer_phi.step()
    return theta_grads_correct, phi_grads_correct
def get_grads_correct_sleep(seed, wake_factor=0.755):
    """RWS-style gradients with a wake/sleep mixture for phi.

    theta grads come from the wake-theta loss; phi grads are a convex
    combination of the wake-phi and sleep-phi gradients, each computed with
    freshly zeroed optimizers.

    Relies on module-level `generative_model`, `inference_network`, `obs`,
    `num_particles`, `optimizer_phi` and `optimizer_theta`.

    Args:
        seed: int random seed passed to util.set_seed.
        wake_factor: weight of the wake-phi gradient in the mixture;
            (1 - wake_factor) weights the sleep-phi gradient. Defaults to
            the previously hard-coded 0.755.

    Returns:
        tuple: (theta_grads, phi_grads) — lists of cloned/combined gradient
        tensors, ordered as the respective model's .parameters().
    """
    util.set_seed(seed)

    log_weight, log_q = losses.get_log_weight_and_log_q(
        generative_model, inference_network, obs, num_particles)

    # theta grads from the wake-theta loss
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
        log_weight)
    # retain_graph: the same graph is reused for the wake-phi loss below
    wake_theta_loss.backward(retain_graph=True)
    theta_grads_correct = [
        parameter.grad.clone() for parameter in generative_model.parameters()
    ]
    # in rws, we step as we compute the grads
    # optimizer_theta.step()

    # wake-phi grads
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
        log_weight, log_q)
    wake_phi_loss.backward()
    wake_phi_grads_correct = [
        parameter.grad.clone() for parameter in inference_network.parameters()
    ]
    # in rws, we step as we compute the grads
    # optimizer_phi.step()

    # sleep-phi grads
    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                           inference_network,
                                           num_samples=num_particles)
    sleep_phi_loss.backward()
    sleep_phi_grads_correct = [
        parameter.grad.clone() for parameter in inference_network.parameters()
    ]

    # convex combination of wake and sleep phi grads
    phi_grads_correct = [
        wake_factor * wake_phi_grad + (1 - wake_factor) * sleep_phi_grad
        for wake_phi_grad, sleep_phi_grad in zip(wake_phi_grads_correct,
                                                 sleep_phi_grads_correct)
    ]
    return theta_grads_correct, phi_grads_correct
def train_wake_wake(generative_model, inference_network, data_loader,
                    num_iterations, num_particles, optim_kwargs,
                    callback=None):
    """Train with wake-wake (WW): alternate wake-theta and wake-phi updates.

    Cycles over `data_loader` until exactly `num_iterations` updates have
    been performed, restarting the loader as needed.

    Args:
        generative_model: model whose parameters theta are updated.
        inference_network: network whose parameters phi are updated.
        data_loader: iterable of observation batches.
        num_iterations: total number of optimization iterations.
        num_particles: number of importance-sampling particles.
        optim_kwargs: dict of kwargs forwarded to both Adam optimizers.
        callback: optional fn(iteration, theta_loss, phi_loss, elbo,
            generative_model, inference_network, optimizer_theta,
            optimizer_phi) called every iteration.

    Returns:
        tuple: (optimizer_theta, optimizer_phi) — returned for consistency
        with the other train_* functions in this file (previously None).
    """
    optimizer_phi = torch.optim.Adam(inference_network.parameters(),
                                     **optim_kwargs)
    optimizer_theta = torch.optim.Adam(generative_model.parameters(),
                                       **optim_kwargs)
    iteration = 0
    while iteration < num_iterations:
        for obs in data_loader:  # redundant iter() wrapper removed
            log_weight, log_q = losses.get_log_weight_and_log_q(
                generative_model, inference_network, obs, num_particles)

            # wake theta
            optimizer_phi.zero_grad()
            optimizer_theta.zero_grad()
            wake_theta_loss, elbo = \
                losses.get_wake_theta_loss_from_log_weight(log_weight)
            # retain_graph: the same graph is reused for the phi loss below
            wake_theta_loss.backward(retain_graph=True)
            optimizer_theta.step()

            # wake phi
            optimizer_phi.zero_grad()
            optimizer_theta.zero_grad()
            wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
                log_weight, log_q)
            wake_phi_loss.backward()
            optimizer_phi.step()

            if callback is not None:
                callback(iteration, wake_theta_loss.item(),
                         wake_phi_loss.item(), elbo.item(), generative_model,
                         inference_network, optimizer_theta, optimizer_phi)

            iteration += 1
            # by this time, we have gone through `iteration` iterations
            if iteration == num_iterations:
                break
    return optimizer_theta, optimizer_phi
def train_defensive_wake_wake(delta, generative_model, inference_network,
                              obss_data_loader, num_iterations, num_particles,
                              callback=None):
    """Wake-wake training using the defensive wake-phi loss.

    Each iteration draws a batch from `obss_data_loader` (restarting the
    loader when exhausted), performs one wake-theta update, then one
    defensive wake-phi update.

    Args:
        delta: defensiveness parameter forwarded to the wake-phi loss.
        generative_model, inference_network: modules being trained.
        obss_data_loader: iterable of observation batches.
        num_iterations: total number of optimization iterations.
        num_particles: number of importance-sampling particles.
        callback: optional fn(iteration, theta_loss, phi_loss, elbo,
            generative_model, inference_network, optimizer_theta,
            optimizer_phi) called every iteration.

    Returns:
        tuple: (optimizer_theta, optimizer_phi).
    """
    optimizer_theta = torch.optim.Adam(generative_model.parameters())
    optimizer_phi = torch.optim.Adam(inference_network.parameters())

    obss_iter = iter(obss_data_loader)

    def _next_obss():
        # Pull the next batch, restarting the loader when it runs out.
        nonlocal obss_iter
        try:
            return next(obss_iter)
        except StopIteration:
            obss_iter = iter(obss_data_loader)
            return next(obss_iter)

    for iteration in range(num_iterations):
        obss = _next_obss()
        log_weight, log_q = losses.get_log_weight_and_log_q(
            generative_model, inference_network, obss, num_particles)

        # --- wake theta ---
        optimizer_theta.zero_grad()
        optimizer_phi.zero_grad()
        wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
            log_weight)
        wake_theta_loss.backward(retain_graph=True)
        optimizer_theta.step()

        # --- defensive wake phi ---
        optimizer_theta.zero_grad()
        optimizer_phi.zero_grad()
        wake_phi_loss = losses.get_defensive_wake_phi_loss(
            generative_model, inference_network, obss, delta, num_particles)
        wake_phi_loss.backward()
        optimizer_phi.step()

        if callback is not None:
            callback(iteration, wake_theta_loss.item(), wake_phi_loss.item(),
                     elbo.item(), generative_model, inference_network,
                     optimizer_theta, optimizer_phi)

    return optimizer_theta, optimizer_phi
def get_grads_weird_detach_sleep(seed, wake_factor=0.755):
    """Compute grads from a combined phi loss, without re-zeroing in between.

    Uses the "weird detach" log-weight variant, builds a single phi loss
    mixing wake-phi and sleep-phi, and reads all grads only at the end
    (grads are zeroed once, up front).

    Relies on module-level `generative_model`, `inference_network`, `obs`,
    `num_particles`, `optimizer_phi`, `optimizer_theta` and the sibling
    `get_log_weight_and_log_q_weird_detach`.

    Args:
        seed: int random seed passed to util.set_seed.
        wake_factor: weight of the wake-phi loss in the phi mixture;
            (1 - wake_factor) weights the sleep-phi loss. Defaults to the
            previously hard-coded 0.755.

    Returns:
        tuple: (theta_grads, phi_grads) — lists of cloned gradient tensors,
        ordered as the respective model's .parameters().
    """
    util.set_seed(seed)

    log_weight, log_q = get_log_weight_and_log_q_weird_detach(
        generative_model, inference_network, obs, num_particles)

    optimizer_phi.zero_grad()
    optimizer_theta.zero_grad()
    wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
        log_weight)

    # deliberately not re-zeroing grads between the theta and phi losses
    wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
        log_weight, log_q)
    sleep_phi_loss = losses.get_sleep_loss(generative_model,
                                           inference_network,
                                           num_samples=num_particles)
    phi_loss = (wake_factor * wake_phi_loss +
                (1 - wake_factor) * sleep_phi_loss)

    # retain_graph: the graph is reused by phi_loss.backward()
    wake_theta_loss.backward(retain_graph=True)
    phi_loss.backward()

    # only read the grads at the very end
    theta_grads_in_one = [
        parameter.grad.clone() for parameter in generative_model.parameters()
    ]
    phi_grads_in_one = [
        parameter.grad.clone() for parameter in inference_network.parameters()
    ]
    # in pyro, we want step to be in a different stage
    # optimizer_theta.step()
    # optimizer_phi.step()
    return theta_grads_in_one, phi_grads_in_one
def train_wake_wake(generative_model, inference_network, true_generative_model,
                    batch_size, num_iterations, num_particles, callback=None):
    """Wake-wake training on synthetic data from `true_generative_model`.

    Each iteration samples a fresh batch of observations, then performs one
    wake-theta update followed by one wake-phi update.

    Args:
        generative_model, inference_network: modules being trained.
        true_generative_model: source of synthetic observations
            (via .sample_obs()).
        batch_size: number of observations sampled per iteration.
        num_iterations: total number of optimization iterations.
        num_particles: number of importance-sampling particles.
        callback: optional fn(iteration, theta_loss, phi_loss, elbo,
            generative_model, inference_network, optimizer_theta,
            optimizer_phi) called every iteration.

    Returns:
        tuple: (optimizer_theta, optimizer_phi).
    """
    optimizer_theta = torch.optim.Adam(generative_model.parameters())
    optimizer_phi = torch.optim.Adam(inference_network.parameters())

    for iteration in range(num_iterations):
        # fresh synthetic batch each iteration
        obss = [true_generative_model.sample_obs() for _ in range(batch_size)]
        log_weight, log_q = losses.get_log_weight_and_log_q(
            generative_model, inference_network, obss, num_particles)

        # --- wake theta ---
        optimizer_theta.zero_grad()
        optimizer_phi.zero_grad()
        wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
            log_weight)
        wake_theta_loss.backward(retain_graph=True)
        optimizer_theta.step()

        # --- wake phi ---
        optimizer_theta.zero_grad()
        optimizer_phi.zero_grad()
        wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
            log_weight, log_q)
        wake_phi_loss.backward()
        optimizer_phi.step()

        if callback is not None:
            callback(iteration, wake_theta_loss.item(), wake_phi_loss.item(),
                     elbo.item(), generative_model, inference_network,
                     optimizer_theta, optimizer_phi)

    return optimizer_theta, optimizer_phi
def train_rws(
    generative_model,
    inference_network,
    data_loader,
    num_iterations,
    num_particles,
    true_cluster_cov,
    test_data_loader,
    test_num_particles,
    true_generative_model,
    checkpoint_path,
):
    """Train with RWS (wake-theta / wake-phi), tracking metrics + checkpoints.

    Cycles over `data_loader` for `num_iterations` updates. Every 100
    iterations it evaluates on the test and train loaders, appends the
    resulting metrics, and saves a checkpoint.

    Args:
        generative_model: model whose parameters theta are updated; must
            expose .get_cluster_cov() for the covariance-distance metric.
        inference_network: network whose parameters phi are updated.
        data_loader: iterable of observation batches (restarted on
            exhaustion).
        num_iterations: total number of optimization iterations.
        num_particles: number of importance-sampling particles.
        true_cluster_cov: ground-truth cluster covariance tensor, compared
            against generative_model.get_cluster_cov() each iteration.
        test_data_loader: loader passed to models.eval_gen_inf for test-set
            metrics.
        test_num_particles: unused in this body — presumably consumed by an
            earlier revision; TODO confirm before removing.
        true_generative_model: reference model for the *_true metrics.
        checkpoint_path: destination path for util.save_checkpoint.

    Returns:
        20-tuple of metric lists (three slots are always None), in the same
        order as the trailing arguments of util.save_checkpoint.
    """
    optimizer_phi = torch.optim.Adam(inference_network.parameters())
    optimizer_theta = torch.optim.Adam(generative_model.parameters())
    # 17 parallel metric lists, one entry appended per logging event.
    (
        theta_losses,
        phi_losses,
        cluster_cov_distances,
        test_log_ps,
        test_log_ps_true,
        test_kl_qps,
        test_kl_pqs,
        test_kl_qps_true,
        test_kl_pqs_true,
        train_log_ps,
        train_log_ps_true,
        train_kl_qps,
        train_kl_pqs,
        train_kl_qps_true,
        train_kl_pqs_true,
        reweighted_train_kl_qps,
        reweighted_train_kl_qps_true,
    ) = ([], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [])
    data_loader_iter = iter(data_loader)
    for iteration in range(num_iterations):
        # get obs; restart the loader when it is exhausted
        try:
            obs = next(data_loader_iter)
        except StopIteration:
            data_loader_iter = iter(data_loader)
            obs = next(data_loader_iter)
        log_weight, log_q = losses.get_log_weight_and_log_q(
            generative_model, inference_network, obs, num_particles)
        # wake theta
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        wake_theta_loss, elbo = losses.get_wake_theta_loss_from_log_weight(
            log_weight)
        # retain_graph: the same graph is reused for the phi loss below
        wake_theta_loss.backward(retain_graph=True)
        optimizer_theta.step()
        # wake phi
        optimizer_phi.zero_grad()
        optimizer_theta.zero_grad()
        wake_phi_loss = losses.get_wake_phi_loss_from_log_weight_and_log_q(
            log_weight, log_q)
        wake_phi_loss.backward()
        optimizer_phi.step()
        theta_losses.append(wake_theta_loss.item())
        phi_losses.append(wake_phi_loss.item())
        # distance between learned and true cluster covariance
        cluster_cov_distances.append(
            torch.norm(true_cluster_cov -
                       generative_model.get_cluster_cov()).item())
        if iteration % 100 == 0:  # test every 100 iterations
            # test-set metrics (last two return slots unused here)
            (
                test_log_p,
                test_log_p_true,
                test_kl_qp,
                test_kl_pq,
                test_kl_qp_true,
                test_kl_pq_true,
                _,
                _,
            ) = models.eval_gen_inf(true_generative_model, generative_model,
                                    inference_network, None, test_data_loader)
            test_log_ps.append(test_log_p)
            test_log_ps_true.append(test_log_p_true)
            test_kl_qps.append(test_kl_qp)
            test_kl_pqs.append(test_kl_pq)
            test_kl_qps_true.append(test_kl_qp_true)
            test_kl_pqs_true.append(test_kl_pq_true)
            # train-set metrics, including reweighted KLs
            (
                train_log_p,
                train_log_p_true,
                train_kl_qp,
                train_kl_pq,
                train_kl_qp_true,
                train_kl_pq_true,
                _,
                _,
                reweighted_train_kl_qp,
                reweighted_train_kl_qp_true,
            ) = models.eval_gen_inf(
                true_generative_model,
                generative_model,
                inference_network,
                None,
                data_loader,
                num_particles=num_particles,
                reweighted_kl=True,
            )
            train_log_ps.append(train_log_p)
            train_log_ps_true.append(train_log_p_true)
            train_kl_qps.append(train_kl_qp)
            train_kl_pqs.append(train_kl_pq)
            train_kl_qps_true.append(train_kl_qp_true)
            train_kl_pqs_true.append(train_kl_pq_true)
            reweighted_train_kl_qps.append(reweighted_train_kl_qp)
            reweighted_train_kl_qps_true.append(reweighted_train_kl_qp_true)
            # NOTE(review): checkpoint/print reconstructed as part of the
            # every-100-iterations branch — confirm against original layout.
            util.save_checkpoint(
                checkpoint_path,
                generative_model,
                inference_network,
                theta_losses,
                phi_losses,
                cluster_cov_distances,
                test_log_ps,
                test_log_ps_true,
                test_kl_qps,
                test_kl_pqs,
                test_kl_qps_true,
                test_kl_pqs_true,
                train_log_ps,
                train_log_ps_true,
                train_kl_qps,
                train_kl_pqs,
                train_kl_qps_true,
                train_kl_pqs_true,
                None,
                None,
                None,
                reweighted_train_kl_qps,
                reweighted_train_kl_qps_true,
            )
            util.print_with_time(
                "it. {} | theta loss = {:.2f} | phi loss = {:.2f}".format(
                    iteration, wake_theta_loss, wake_phi_loss))
        # if iteration % 200 == 0:
        #     z = inference_network.get_latent_dist(obs).sample()
        #     util.save_plot("images/rws/iteration_{}.png".format(iteration),
        #                    obs[:3], z[:3])
    # Same ordering as the metric arguments to util.save_checkpoint.
    return (
        theta_losses,
        phi_losses,
        cluster_cov_distances,
        test_log_ps,
        test_log_ps_true,
        test_kl_qps,
        test_kl_pqs,
        test_kl_qps_true,
        test_kl_pqs_true,
        train_log_ps,
        train_log_ps_true,
        train_kl_qps,
        train_kl_pqs,
        train_kl_qps_true,
        train_kl_pqs_true,
        None,
        None,
        None,
        reweighted_train_kl_qps,
        reweighted_train_kl_qps_true,
    )