def loss(self, outputs, weight):
    c_outputs, z_outputs, x_outputs = outputs

    # 1. Reconstruction loss (a log-likelihood, so higher is better)
    x, x_mean = x_outputs
    recon_loss = bernoulli_log_likelihood(x.view(-1, 1, 28, 28), x_mean)
    recon_loss /= (self.batch_size * self.sample_size)

    # 2. KL divergence terms
    kl = 0

    # a) Context divergence
    c_mean, c_logvar = c_outputs
    kl_c = kl_diagnormal_stdnormal(c_mean, c_logvar)
    kl += kl_c

    # b) Latent divergences
    qz_params, pz_params = z_outputs
    shapes = ((self.batch_size, self.sample_size, self.z_dim),
              (self.batch_size, 1, self.z_dim))
    for i in range(self.n_stochastic):
        # the first prior is conditioned on the context alone, so its
        # parameters carry a singleton sample dimension and broadcast
        # across the sample_size samples
        args = (qz_params[i][0].view(shapes[0]),
                qz_params[i][1].view(shapes[0]),
                pz_params[i][0].view(shapes[1] if i == 0 else shapes[0]),
                pz_params[i][1].view(shapes[1] if i == 0 else shapes[0]))
        kl_z = kl_diagnormal_diagnormal(*args)
        kl += kl_z

    kl /= (self.batch_size * self.sample_size)

    # Variational lower bound and weighted loss; at weight = 1 the loss
    # reduces to -vlb, otherwise reconstruction and KL are re-weighted
    # against each other
    vlb = recon_loss - kl
    loss = -((weight * recon_loss) - (kl / weight))

    return loss, vlb
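# --- Hedged sketch, not part of the original file: minimal reference
# implementations of the helpers that loss() assumes. These are the standard
# closed-form expressions; shapes follow the usage above, and the combined
# sums rely on PyTorch broadcasting (e.g. a (batch, 1, z_dim) prior against
# a (batch, sample_size, z_dim) posterior).

import torch


def bernoulli_log_likelihood(x, x_mean, eps=1e-7):
    # log p(x | x_mean) summed over all pixels, under independent Bernoullis
    return torch.sum(x * torch.log(x_mean + eps)
                     + (1 - x) * torch.log(1 - x_mean + eps))


def kl_diagnormal_stdnormal(mean, logvar):
    # KL( N(mean, diag(exp(logvar))) || N(0, I) )
    return 0.5 * torch.sum(mean ** 2 + torch.exp(logvar) - 1 - logvar)


def kl_diagnormal_diagnormal(q_mean, q_logvar, p_mean, p_logvar):
    # KL( N(q_mean, diag(exp(q_logvar))) || N(p_mean, diag(exp(p_logvar))) )
    return 0.5 * torch.sum(p_logvar - q_logvar
                           + torch.exp(q_logvar - p_logvar)
                           + (q_mean - p_mean) ** 2 / torch.exp(p_logvar)
                           - 1)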
def summarize(self, dataset, output_size=6):
    """
    There's some nasty indexing going on here because pytorch doesn't
    have numpy indexing yet. This will be fixed soon.
    """
    # cast to torch Cuda Variable and reshape
    dataset = dataset.view(1, self.sample_size, self.n_features)

    # get approximate posterior over full dataset
    c_mean_full, c_logvar_full = self.statistic_network(dataset,
                                                        summarize=True)

    # iteratively discard until dataset is of required size
    while dataset.size(1) != output_size:
        kl_divergences = []

        # need KL divergence between full approximate posterior and all
        # subsets of given size
        subset_indices = list(
            combinations(range(dataset.size(1)), dataset.size(1) - 1))

        for subset_index in subset_indices:
            # pull out subset, numpy indexing will make this much easier
            ix = Variable(torch.LongTensor(subset_index).cuda())
            subset = dataset.index_select(1, ix)

            # calculate approximate posterior over subset
            c_mean, c_logvar = self.statistic_network(subset,
                                                      summarize=True)
            kl = kl_diagnormal_diagnormal(c_mean_full, c_logvar_full,
                                          c_mean, c_logvar)
            kl_divergences.append(kl.data[0])

        # determine which sample we want to remove
        best_index = kl_divergences.index(min(kl_divergences))

        # determine which samples to keep
        to_keep = subset_indices[best_index]
        to_keep = Variable(torch.LongTensor(to_keep).cuda())

        # keep only desired samples
        dataset = dataset.index_select(1, to_keep)

    # return pruned dataset
    return dataset
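# --- Hedged usage sketch, not part of the original file. Assuming a trained
# model on an older (pre-0.4) PyTorch build where Variable and .data[0] are
# still current, pruning a 10-sample set of binarized MNIST digits down to
# the 6 most representative exemplars might look like:
#
#   samples = Variable(torch.bernoulli(torch.rand(10, 784)).cuda())
#   model.sample_size = 10   # summarize() reshapes using self.sample_size
#   exemplars = model.summarize(samples, output_size=6)
#   # exemplars has shape (1, 6, 784): the subset whose approximate
#   # posterior over c is closest in KL to the posterior over the full set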