def get_defensive_wake_phi_loss(generative_model, inference_network, obs,
                                delta, num_particles=1):
    """Wake-phi loss with a defensive importance-sampling proposal: with
    probability delta each particle is drawn from a uniform distribution over
    mixture components instead of from the inference network.

    Returns:
        loss: scalar that we call .backward() on and step the optimizer.
    """
    num_mixtures = inference_network.num_mixtures
    batch_size = len(obs)
    latent_dist = inference_network.get_latent_dist(obs)
    latent = inference_network.sample_from_latent_dist(latent_dist,
                                                       num_particles)
    latent_uniform = torch.distributions.OneHotCategorical(
        logits=torch.ones((num_mixtures,))).sample(
            (num_particles, batch_size))
    catted = torch.cat([x.unsqueeze(0) for x in [latent, latent_uniform]],
                       dim=0)
    # Bernoulli(delta) picks the uniform sample (index 1) with prob delta,
    # otherwise the inference-network sample (index 0)
    indices = torch.distributions.Bernoulli(probs=torch.tensor(delta)).sample(
        (num_particles, batch_size)).unsqueeze(0).unsqueeze(-1).expand(
            1, num_particles, batch_size, num_mixtures).long()
    latent_mixture = torch.gather(catted, 0, indices).squeeze(0)
    log_p = generative_model.get_log_prob(latent_mixture, obs).transpose(0, 1)
    log_latent = inference_network.get_log_prob_from_latent_dist(
        latent_dist, latent).transpose(0, 1)
    log_uniform = torch.ones_like(log_latent) * (-np.log(num_mixtures))
    log_q_mixture = util.logaddexp(log_latent + np.log(1 - delta),
                                   log_uniform + np.log(delta))
    log_q = inference_network.get_log_prob_from_latent_dist(
        latent_dist, latent_mixture).transpose(0, 1)
    log_weight = log_p - log_q_mixture
    normalized_weight = util.exponentiate_and_normalize(log_weight, dim=1)
    return torch.mean(-torch.sum(normalized_weight.detach() * log_q, dim=1))
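
# Illustrative sketch (not part of the original module): the defensive
# proposal above is the mixture q_mix = (1 - delta) * q_phi + delta * Uniform,
# whose log density is logaddexp(log q_phi + log(1 - delta),
# log uniform + log delta). This sketch uses torch.logaddexp instead of
# util.logaddexp and assumes the module-level torch / numpy imports.
def _example_defensive_log_q(log_q_phi, num_mixtures, delta):
    """log_q_phi: tensor of inference-network log probs (any shape)."""
    log_uniform = torch.full_like(log_q_phi, -np.log(num_mixtures))
    return torch.logaddexp(log_q_phi + np.log(1 - delta),
                           log_uniform + np.log(delta))
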
def get_wake_phi_loss_from_log_weight_and_log_q(log_weight, log_q):
    """Args:
        log_weight: tensor of shape [batch_size, num_particles]
        log_q: tensor of shape [batch_size, num_particles]

    Returns:
        loss: scalar that we call .backward() on and step the optimizer.
    """
    normalized_weight = util.exponentiate_and_normalize(log_weight, dim=1)
    return torch.mean(-torch.sum(normalized_weight.detach() * log_q, dim=1))
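
# Minimal usage sketch (assumed shapes, random data): the wake-phi loss is a
# self-normalized importance-sampling estimate of E_p[-log q]; the weights are
# detached so gradients flow only into log_q. torch.softmax over the particle
# dimension plays the role of util.exponentiate_and_normalize here.
def _example_wake_phi_loss():
    batch_size, num_particles = 3, 5
    log_weight = torch.randn(batch_size, num_particles)
    log_q = torch.randn(batch_size, num_particles, requires_grad=True)
    normalized_weight = torch.softmax(log_weight, dim=1)
    loss = torch.mean(-torch.sum(normalized_weight.detach() * log_q, dim=1))
    return loss  # loss.backward() populates log_q.grad only
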
def get_thermo_loss_different_samples(generative_model, inference_network,
                                      obs, partition=None, num_particles=1,
                                      integration='left'):
    """Thermo loss gradient estimator computed using two sets of importance
    samples.

    Args:
        generative_model: models.GenerativeModel object
        inference_network: models.InferenceNetwork object
        obs: tensor of shape [batch_size]
        partition: partition of [0, 1];
            tensor of shape [num_partitions + 1] where partition[0] is zero
            and partition[-1] is one;
            see https://en.wikipedia.org/wiki/Partition_of_an_interval
        num_particles: int
        integration: left, right or trapz

    Returns:
        loss: scalar that we call .backward() on and step the optimizer.
        elbo: average elbo over data
    """
    log_weights, log_ps, log_qs, heated_normalized_weights = [], [], [], []
    for _ in range(2):
        log_weight, log_p, log_q = get_log_weight_log_p_log_q(
            generative_model, inference_network, obs, num_particles)
        log_weights.append(log_weight)
        log_ps.append(log_p)
        log_qs.append(log_q)
        heated_log_weight = log_weight.unsqueeze(-1) * partition
        heated_normalized_weights.append(
            util.exponentiate_and_normalize(heated_log_weight, dim=1))

    # note: log_weight below refers to the second sample set (the last loop
    # iteration); the first set is accessed through the lists
    w_detached = heated_normalized_weights[0].detach()
    thermo_logp = partition * log_ps[0].unsqueeze(-1) + \
        (1 - partition) * log_qs[0].unsqueeze(-1)
    wf = heated_normalized_weights[1] * log_weights[1].unsqueeze(-1)
    if num_particles == 1:
        correction = 1
    else:
        correction = num_particles / (num_particles - 1)

    thing_to_add = correction * torch.sum(
        w_detached *
        (log_weight.unsqueeze(-1) -
         torch.sum(wf, dim=1, keepdim=True)).detach() *
        thermo_logp,
        dim=1)

    multiplier = torch.zeros_like(partition)
    if integration == 'trapz':
        multiplier[0] = 0.5 * (partition[1] - partition[0])
        multiplier[1:-1] = 0.5 * (partition[2:] - partition[0:-2])
        multiplier[-1] = 0.5 * (partition[-1] - partition[-2])
    elif integration == 'left':
        multiplier[:-1] = partition[1:] - partition[:-1]
    elif integration == 'right':
        multiplier[1:] = partition[1:] - partition[:-1]

    loss = -torch.mean(
        torch.sum(
            multiplier * (thing_to_add +
                          torch.sum(w_detached * log_weight.unsqueeze(-1),
                                    dim=1)),
            dim=1))
    log_evidence = torch.logsumexp(log_weight, dim=1) - np.log(num_particles)
    elbo = torch.mean(log_evidence)

    return loss, elbo
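
# Sketch of the "heated" weights used above (assumed shapes, random data):
# for each inverse temperature beta in the partition, the weights are
# w_k^beta normalized over particles, i.e. softmax(beta * log w) along the
# particle dimension. Broadcasting log_weight against the partition gives a
# [batch_size, num_particles, num_partitions + 1] tensor in one shot.
def _example_heated_weights():
    batch_size, num_particles = 2, 4
    partition = torch.linspace(0, 1, steps=6)  # [num_partitions + 1]
    log_weight = torch.randn(batch_size, num_particles)
    heated_log_weight = log_weight.unsqueeze(-1) * partition
    heated_normalized_weight = torch.softmax(heated_log_weight, dim=1)
    # each temperature's column sums to one over the particle dimension
    assert torch.allclose(heated_normalized_weight.sum(dim=1),
                          torch.ones(batch_size, 6))
    return heated_normalized_weight
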
def get_thermo_loss_from_log_weight_log_p_log_q(log_weight, log_p, log_q,
                                                partition, num_particles=1,
                                                integration='left'):
    """Args:
        log_weight: tensor of shape [batch_size, num_particles]
        log_p: tensor of shape [batch_size, num_particles]
        log_q: tensor of shape [batch_size, num_particles]
        partition: partition of [0, 1];
            tensor of shape [num_partitions + 1] where partition[0] is zero
            and partition[-1] is one;
            see https://en.wikipedia.org/wiki/Partition_of_an_interval
        num_particles: int
        integration: left, right or trapz

    Returns:
        loss: scalar that we call .backward() on and step the optimizer.
        elbo: average elbo over data
    """
    heated_log_weight = log_weight.unsqueeze(-1) * partition
    heated_normalized_weight = util.exponentiate_and_normalize(
        heated_log_weight, dim=1)
    thermo_logp = partition * log_p.unsqueeze(-1) + \
        (1 - partition) * log_q.unsqueeze(-1)
    wf = heated_normalized_weight * log_weight.unsqueeze(-1)
    w_detached = heated_normalized_weight.detach()
    if num_particles == 1:
        correction = 1
    else:
        correction = num_particles / (num_particles - 1)

    thing_to_add = correction * torch.sum(
        w_detached *
        (log_weight.unsqueeze(-1) -
         torch.sum(wf, dim=1, keepdim=True)).detach() *
        (thermo_logp -
         torch.sum(thermo_logp * w_detached, dim=1, keepdim=True)),
        dim=1)

    multiplier = torch.zeros_like(partition)
    if integration == 'trapz':
        multiplier[0] = 0.5 * (partition[1] - partition[0])
        multiplier[1:-1] = 0.5 * (partition[2:] - partition[0:-2])
        multiplier[-1] = 0.5 * (partition[-1] - partition[-2])
    elif integration == 'left':
        multiplier[:-1] = partition[1:] - partition[:-1]
    elif integration == 'right':
        multiplier[1:] = partition[1:] - partition[:-1]

    loss = -torch.mean(
        torch.sum(
            multiplier * (thing_to_add +
                          torch.sum(w_detached * log_weight.unsqueeze(-1),
                                    dim=1)),
            dim=1))
    log_evidence = torch.logsumexp(log_weight, dim=1) - np.log(num_particles)
    elbo = torch.mean(log_evidence)

    return loss, elbo
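
# Sketch of the integration multipliers built above (partition assumed to run
# from 0 to 1): 'left'/'right' give left/right Riemann-sum weights and 'trapz'
# gives trapezoidal weights; in all three cases the multipliers sum to
# partition[-1] - partition[0] = 1. Extracted here purely for illustration.
def _example_integration_multiplier(partition, integration='left'):
    multiplier = torch.zeros_like(partition)
    if integration == 'trapz':
        multiplier[0] = 0.5 * (partition[1] - partition[0])
        multiplier[1:-1] = 0.5 * (partition[2:] - partition[0:-2])
        multiplier[-1] = 0.5 * (partition[-1] - partition[-2])
    elif integration == 'left':
        multiplier[:-1] = partition[1:] - partition[:-1]
    elif integration == 'right':
        multiplier[1:] = partition[1:] - partition[:-1]
    return multiplier

# e.g. _example_integration_multiplier(torch.linspace(0, 1, steps=6), 'trapz')
# returns weights that sum to 1, so the loss above is a weighted sum of the
# per-temperature expectations of log w.
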
def get_mean_stds(generative_model, inference_network, num_mc_samples, obss,
                  num_particles):
    """Estimates the mean and std of several gradient estimators (VIMCO,
    REINFORCE, wake-phi, sleep, ...) over num_mc_samples Monte Carlo
    repetitions.
    """
    vimco_grad = util.OnlineMeanStd()
    vimco_one_grad = util.OnlineMeanStd()
    reinforce_grad = util.OnlineMeanStd()
    reinforce_one_grad = util.OnlineMeanStd()
    two_grad = util.OnlineMeanStd()
    log_evidence_stats = util.OnlineMeanStd()
    log_evidence_grad = util.OnlineMeanStd()
    wake_phi_loss_grad = util.OnlineMeanStd()
    log_Q_grad = util.OnlineMeanStd()
    sleep_loss_grad = util.OnlineMeanStd()

    for mc_sample_idx in range(num_mc_samples):
        util.print_with_time('MC sample {}'.format(mc_sample_idx))
        log_weight, log_q = losses.get_log_weight_and_log_q(
            generative_model, inference_network, obss, num_particles)
        log_evidence = torch.logsumexp(log_weight, dim=1) - \
            np.log(num_particles)
        avg_log_evidence = torch.mean(log_evidence)
        log_Q = torch.sum(log_q, dim=1)
        avg_log_Q = torch.mean(log_Q)
        reinforce_one = torch.mean(log_evidence.detach() * log_Q)
        reinforce = reinforce_one + avg_log_evidence

        vimco_one = 0
        for i in range(num_particles):
            log_weight_ = log_weight[:, util.range_except(num_particles, i)]
            control_variate = torch.logsumexp(
                torch.cat([log_weight_,
                           torch.mean(log_weight_, dim=1, keepdim=True)],
                          dim=1),
                dim=1)
            vimco_one = vimco_one + (log_evidence.detach() -
                                     control_variate.detach()) * log_q[:, i]
        vimco_one = torch.mean(vimco_one)
        vimco = vimco_one + avg_log_evidence

        normalized_weight = util.exponentiate_and_normalize(log_weight, dim=1)
        wake_phi_loss = torch.mean(
            -torch.sum(normalized_weight.detach() * log_q, dim=1))

        inference_network.zero_grad()
        generative_model.zero_grad()
        vimco.backward(retain_graph=True)
        vimco_grad.update([param.grad for param in
                           inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        vimco_one.backward(retain_graph=True)
        vimco_one_grad.update([param.grad for param in
                               inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        reinforce.backward(retain_graph=True)
        reinforce_grad.update([param.grad for param in
                               inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        reinforce_one.backward(retain_graph=True)
        reinforce_one_grad.update([param.grad for param in
                                   inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        avg_log_evidence.backward(retain_graph=True)
        two_grad.update([param.grad for param in
                         inference_network.parameters()])
        log_evidence_grad.update([param.grad for param in
                                  generative_model.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        wake_phi_loss.backward(retain_graph=True)
        wake_phi_loss_grad.update([param.grad for param in
                                   inference_network.parameters()])

        inference_network.zero_grad()
        generative_model.zero_grad()
        avg_log_Q.backward(retain_graph=True)
        log_Q_grad.update([param.grad for param in
                           inference_network.parameters()])

        log_evidence_stats.update([avg_log_evidence.unsqueeze(0)])

        sleep_loss = losses.get_sleep_loss(
            generative_model, inference_network, num_particles * len(obss))
        inference_network.zero_grad()
        generative_model.zero_grad()
        sleep_loss.backward()
        sleep_loss_grad.update([param.grad for param in
                                inference_network.parameters()])

    return list(map(
        lambda x: x.avg_of_means_stds(),
        [vimco_grad, vimco_one_grad, reinforce_grad, reinforce_one_grad,
         two_grad, log_evidence_stats, log_evidence_grad, wake_phi_loss_grad,
         log_Q_grad, sleep_loss_grad]))
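
# Sketch of the VIMCO leave-one-out control variate computed above (assumed
# shapes, random data): for particle i, the other particles' log weights plus
# their mean (the log of their geometric mean) replace log w_i inside the
# logsumexp, giving a baseline that does not depend on particle i's sample.
# The list comprehension stands in for util.range_except.
def _example_vimco_control_variate():
    batch_size, num_particles = 3, 5
    log_weight = torch.randn(batch_size, num_particles)
    control_variates = []
    for i in range(num_particles):
        others = [j for j in range(num_particles) if j != i]
        log_weight_ = log_weight[:, others]
        control_variates.append(torch.logsumexp(
            torch.cat([log_weight_,
                       torch.mean(log_weight_, dim=1, keepdim=True)], dim=1),
            dim=1))
    return torch.stack(control_variates, dim=1)  # [batch_size, num_particles]
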
def train_mws(
    generative_model,
    inference_network,
    data_loader,
    num_iterations,
    memory_size,
    true_cluster_cov,
    test_data_loader,
    test_num_particles,
    true_generative_model,
    checkpoint_path,
    reweighted=False,
):
    optimizer = torch.optim.Adam(
        itertools.chain(generative_model.parameters(),
                        inference_network.parameters()))
    (
        theta_losses,
        phi_losses,
        cluster_cov_distances,
        test_log_ps,
        test_log_ps_true,
        test_kl_qps,
        test_kl_pqs,
        test_kl_qps_true,
        test_kl_pqs_true,
        train_log_ps,
        train_log_ps_true,
        train_kl_qps,
        train_kl_pqs,
        train_kl_qps_true,
        train_kl_pqs_true,
        train_kl_memory_ps,
        train_kl_memory_ps_true,
    ) = ([], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [])
    memory = {}

    data_loader_iter = iter(data_loader)

    for iteration in range(num_iterations):
        # get obs
        try:
            obs = next(data_loader_iter)
        except StopIteration:
            data_loader_iter = iter(data_loader)
            obs = next(data_loader_iter)

        theta_loss = 0
        phi_loss = 0
        for single_obs in obs:
            # key to index memory
            single_obs_key = tuple(single_obs.tolist())

            # populate memory if empty
            if (single_obs_key not in memory) or len(
                    memory[single_obs_key]) == 0:
                # batch shape [1] and event shape [num_data]
                latent_dist = inference_network.get_latent_dist(
                    single_obs.unsqueeze(0))

                # HACK: resample until memory_size distinct latents are found
                while True:
                    # [memory_size, num_data]
                    latent = inference_network.sample_from_latent_dist(
                        latent_dist, memory_size).squeeze(1)

                    # list of M in {1, ..., memory_size} elements;
                    # can be shorter than memory_size because sampled
                    # elements may be duplicates
                    memory[single_obs_key] = list(
                        set([tuple(x.tolist()) for x in latent]))
                    if len(memory[single_obs_key]) == memory_size:
                        break

            # WAKE
            # batch shape [1] and event shape [num_data]
            latent_dist = inference_network.get_latent_dist(
                single_obs.unsqueeze(0))

            # [1, 1, num_data] -> [num_data]
            latent = inference_network.sample_from_latent_dist(
                latent_dist, 1).view(-1)

            # set (of size up to memory_size + 1) of tuples (of length
            # num_data)
            memoized_latent_plus_current_latent = set(
                memory.get(single_obs_key, []) + [tuple(latent.tolist())])

            # [memory_size + 1, 1, num_data]
            memoized_latent_plus_current_latent_tensor = torch.tensor(
                list(memoized_latent_plus_current_latent),
                device=single_obs.device).unsqueeze(1)

            # [memory_size + 1]; this takes the longest
            log_p_tensor = generative_model.get_log_prob(
                memoized_latent_plus_current_latent_tensor,
                single_obs.unsqueeze(0)).squeeze(-1)

            # {latent tuple: log prob, ...}
            log_p = {
                mem_latent: lp
                for mem_latent, lp in zip(
                    memoized_latent_plus_current_latent, log_p_tensor)
            }

            # update memory: keep the memory_size latents with the highest
            # log prob under the generative model
            memory[single_obs_key] = sorted(
                memoized_latent_plus_current_latent,
                key=log_p.get)[-memory_size:]

            # REMEMBER
            if reweighted:
                # [memory_size]
                memory_log_weight = torch.stack(
                    list(map(log_p.get, memory[single_obs_key])))
                # [memory_size]
                memory_weight_normalized = util.exponentiate_and_normalize(
                    memory_log_weight, dim=0)
                # [memory_size, num_data]
                memory_latent = torch.tensor(memory[single_obs_key])
                # [memory_size]
                inference_network_log_prob = \
                    inference_network.get_log_prob_from_latent_dist(
                        latent_dist, memory_latent[:, None, :]).squeeze(-1)

                # []
                theta_loss += -torch.sum(
                    memory_log_weight *
                    memory_weight_normalized.detach()) / len(obs)
                # []
                phi_loss += -torch.sum(
                    inference_network_log_prob *
                    memory_weight_normalized.detach()) / len(obs)
            else:
                remembered_latent_id_dist = torch.distributions.Categorical(
                    logits=torch.tensor(
                        list(map(log_p.get, memory[single_obs_key]))))
                remembered_latent_id = remembered_latent_id_dist.sample()
                remembered_latent_id_log_prob = \
                    remembered_latent_id_dist.log_prob(remembered_latent_id)
                remembered_latent = memory[single_obs_key][
                    remembered_latent_id]
                remembered_latent_tensor = torch.tensor(
                    [remembered_latent], device=single_obs.device)
                # []
                theta_loss += -(log_p.get(remembered_latent) -
                                remembered_latent_id_log_prob.detach()
                                ) / len(obs)
                # []
                phi_loss += -inference_network.get_log_prob_from_latent_dist(
                    latent_dist, remembered_latent_tensor).view(()) / len(obs)

            # SLEEP
            # TODO

        optimizer.zero_grad()
        theta_loss.backward()
        phi_loss.backward()
        optimizer.step()

        theta_losses.append(theta_loss.item())
        phi_losses.append(phi_loss.item())
        cluster_cov_distances.append(
            torch.norm(true_cluster_cov -
                       generative_model.get_cluster_cov()).item())

        if iteration % 100 == 0:  # test every 100 iterations
            (
                test_log_p,
                test_log_p_true,
                test_kl_qp,
                test_kl_pq,
                test_kl_qp_true,
                test_kl_pq_true,
                _,
                _,
            ) = models.eval_gen_inf(true_generative_model, generative_model,
                                    inference_network, None, test_data_loader)
            test_log_ps.append(test_log_p)
            test_log_ps_true.append(test_log_p_true)
            test_kl_qps.append(test_kl_qp)
            test_kl_pqs.append(test_kl_pq)
            test_kl_qps_true.append(test_kl_qp_true)
            test_kl_pqs_true.append(test_kl_pq_true)

            (
                train_log_p,
                train_log_p_true,
                train_kl_qp,
                train_kl_pq,
                train_kl_qp_true,
                train_kl_pq_true,
                train_kl_memory_p,
                train_kl_memory_p_true,
            ) = models.eval_gen_inf(true_generative_model, generative_model,
                                    inference_network, memory, data_loader)
            train_log_ps.append(train_log_p)
            train_log_ps_true.append(train_log_p_true)
            train_kl_qps.append(train_kl_qp)
            train_kl_pqs.append(train_kl_pq)
            train_kl_qps_true.append(train_kl_qp_true)
            train_kl_pqs_true.append(train_kl_pq_true)
            train_kl_memory_ps.append(train_kl_memory_p)
            train_kl_memory_ps_true.append(train_kl_memory_p_true)

            util.save_checkpoint(
                checkpoint_path,
                generative_model,
                inference_network,
                theta_losses,
                phi_losses,
                cluster_cov_distances,
                test_log_ps,
                test_log_ps_true,
                test_kl_qps,
                test_kl_pqs,
                test_kl_qps_true,
                test_kl_pqs_true,
                train_log_ps,
                train_log_ps_true,
                train_kl_qps,
                train_kl_pqs,
                train_kl_qps_true,
                train_kl_pqs_true,
                train_kl_memory_ps,
                train_kl_memory_ps_true,
                memory,
                None,
                None,
            )

        util.print_with_time(
            "it. {} | theta loss = {:.2f} | phi loss = {:.2f}".format(
                iteration, theta_loss, phi_loss))

        # if iteration % 200 == 0:
        #     z = inference_network.get_latent_dist(obs).sample()
        #     util.save_plot("images/mws/iteration_{}.png".format(iteration),
        #                    obs[:3], z[:3])

    return (
        theta_losses,
        phi_losses,
        cluster_cov_distances,
        test_log_ps,
        test_log_ps_true,
        test_kl_qps,
        test_kl_pqs,
        test_kl_qps_true,
        test_kl_pqs_true,
        train_log_ps,
        train_log_ps_true,
        train_kl_qps,
        train_kl_pqs,
        train_kl_qps_true,
        train_kl_pqs_true,
        train_kl_memory_ps,
        train_kl_memory_ps_true,
        memory,
        None,
        None,
    )
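
# Illustrative sketch (hypothetical helper, not called above) of the memory
# update that train_mws performs per observation: merge the current proposal
# with the memoized latents, deduplicate, and keep the memory_size latents
# with the highest log joint probability under the generative model.
def _example_update_memory(memory_latents, proposed_latent, log_joint_fn,
                           memory_size):
    """memory_latents: list of hashable latents (e.g. tuples)
    proposed_latent: hashable latent
    log_joint_fn: maps a latent to a float log p(latent, obs)
    """
    candidates = set(memory_latents) | {proposed_latent}
    return sorted(candidates, key=log_joint_fn)[-memory_size:]

# e.g. _example_update_memory([(0, 0), (1, 1)], (1, 0),
#                             lambda z: float(sum(z)), memory_size=2)
# keeps the two highest-scoring candidates, here (1, 0) and (1, 1).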