def log_prob(self, ob, z, tau, mu, aggregate=False):
    """
    aggregate = False : return S * B * N
    aggregate = True  : return S * B
    """
    sigma = 1. / tau.sqrt()  # precision -> standard deviation
    labels = z.argmax(-1)  # cluster assignment of each point, S * B * N
    labels_flat = labels.unsqueeze(-1).repeat(1, 1, 1, ob.shape[-1])
    # gather the mean / std of the assigned cluster for every point
    mu_expand = torch.gather(mu, 2, labels_flat)
    sigma_expand = torch.gather(sigma, 2, labels_flat)
    ll = Normal(mu_expand, sigma_expand).log_prob(ob).sum(-1)  # S * B * N
    if aggregate:
        ll = ll.sum(-1)  # S * B
    return ll
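# A minimal shape-check sketch for log_prob, assuming ob is S * B * N * D,
# z is a one-hot S * B * N * K assignment, and tau / mu are S * B * K * D
# per-cluster precisions and means. Passing self=None works because the method
# never touches self; all names and shapes below are illustrative assumptions,
# not taken from this repo.
import torch
from torch.distributions import Normal, OneHotCategorical

S, B, N, K, D = 2, 3, 5, 4, 2
ob = torch.randn(S, B, N, D)
z = OneHotCategorical(logits=torch.zeros(S, B, N, K)).sample()
tau = torch.rand(S, B, K, D) + 0.5  # strictly positive precisions
mu = torch.randn(S, B, K, D)

assert log_prob(None, ob, z, tau, mu, aggregate=False).shape == (S, B, N)
assert log_prob(None, ob, z, tau, mu, aggregate=True).shape == (S, B)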
def train():
    train_loss = []
    for batch_idx, (x, _) in enumerate(train_loader):
        start_time = time.time()
        x = x.to(DEVICE)

        opt.zero_grad()
        x_tilde, z_e_x, z_q_x = model(x)
        z_q_x.retain_grad()

        loss_recons = F.mse_loss(x_tilde, x)
        loss_recons.backward(retain_graph=True)

        # Straight-through estimator: copy the decoder gradient from the
        # quantized code z_q_x onto the encoder output z_e_x
        z_e_x.backward(z_q_x.grad, retain_graph=True)

        # Vector quantization objective: move codebook entries toward the
        # (detached) encoder outputs
        model.codebook.zero_grad()
        loss_vq = F.mse_loss(z_q_x, z_e_x.detach())
        loss_vq.backward(retain_graph=True)

        # Commitment objective: keep encoder outputs close to their codes
        loss_commit = LAMDA * F.mse_loss(z_e_x, z_q_x.detach())
        loss_commit.backward()
        opt.step()

        N = x.numel()
        nll = Normal(x_tilde, torch.ones_like(x_tilde)).log_prob(x)
        log_px = nll.sum() / N + np.log(128) - np.log(K * 2)
        log_px /= np.log(2)  # nats -> bits per dimension

        train_loss.append([log_px.item()] + to_scalar([loss_recons, loss_vq]))

        if (batch_idx + 1) % PRINT_INTERVAL == 0:
            print('\tIter [{}/{} ({:.0f}%)]\tLoss: {} Time: {}'.format(
                batch_idx * len(x), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                np.asarray(train_loss)[-PRINT_INTERVAL:].mean(0),
                time.time() - start_time))
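# The three backward() calls above implement the straight-through estimator by
# hand (copying z_q_x.grad onto z_e_x). A common equivalent formulation -- a
# sketch only, assuming a hypothetical model with separate encode / quantize /
# decode steps rather than the single model(x) call above -- does the same in
# one backward pass by re-routing gradients with detach():
import torch.nn.functional as F

def vqvae_step(model, opt, x, beta=1.0):
    z_e = model.encode(x)              # hypothetical: continuous encoder output
    z_q = model.quantize(z_e)          # hypothetical: nearest codebook entry
    z_st = z_e + (z_q - z_e).detach()  # identity forward; gradient skips the codebook backward
    x_tilde = model.decode(z_st)
    loss = (F.mse_loss(x_tilde, x)                   # reconstruction
            + F.mse_loss(z_q, z_e.detach())          # vector quantization (moves codebook)
            + beta * F.mse_loss(z_e, z_q.detach()))  # commitment (moves encoder)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss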
def loss_function(self, *inputs, **kwargs):
    """loss function described in the paper (eq. (10))"""
    decoded = inputs[0]
    encoded = inputs[1]
    z = inputs[2]
    x = inputs[3]
    dataset_size = kwargs['dataset_size']
    batch_size = z.size(0)
    mu, logvar = encoded

    # compute the likelihood term
    if self.binary:
        # likelihood term under the Bernoulli MLP decoder
        MLD = F.binary_cross_entropy(decoded, x, reduction='sum').div(x.size(0))
    else:
        # likelihood term under the Gaussian MLP decoder
        mu_o, logvar_o = decoded
        recon_x_distribution = Normal(loc=mu_o, scale=torch.exp(0.5 * logvar_o))
        MLD = -recon_x_distribution.log_prob(x).sum(1).mean()

    log_q_z_n = Normal(loc=mu, scale=torch.exp(
        0.5 * logvar)).log_prob(z).sum(1)  # log q(z|n)
    log_p_z = Normal(loc=torch.zeros_like(z),
                     scale=torch.ones_like(z)).log_prob(z).sum(1)  # log p(z), p = N(0, I)

    # the log q(z(n_i) | n_j) matrix
    mat_log_q_z_n = Normal(loc=mu.unsqueeze(dim=0),
                           scale=torch.exp(0.5 * logvar).unsqueeze(
                               dim=0)).log_prob(z.unsqueeze(dim=1))

    # compute log q(z) and log prod_j q(z_j) according to the sampling method
    if self.sampling == "mws":  # MWS (Minibatch Weighted Sampling)
        log_q_z = torch.logsumexp(
            mat_log_q_z_n.sum(2), dim=1, keepdim=False) - math.log(
                batch_size * dataset_size)
        log_prod_q_z = (
            torch.logsumexp(mat_log_q_z_n, dim=1, keepdim=False) -
            math.log(batch_size * dataset_size)).sum(1)
    elif self.sampling == "mss":  # MSS (Minibatch Stratified Sampling)
        log_importance_weights = self.get_log_importance_weight_mat(
            batch_size, dataset_size).type_as(mat_log_q_z_n)
        log_q_z = torch.logsumexp(log_importance_weights +
                                  mat_log_q_z_n.sum(2),
                                  dim=1, keepdim=False)
        log_prod_q_z = torch.logsumexp(
            log_importance_weights.unsqueeze(dim=2) + mat_log_q_z_n,
            dim=1, keepdim=False).sum(1)
    else:
        raise NotImplementedError

    # decomposition: KL(q(z|n) || p(z)) = index-code MI + TC + dimension-wise KL
    index_code_MI = (log_q_z_n - log_q_z).mean()
    TC = (log_q_z - log_prod_q_z).mean()
    dim_wise_KL = (log_prod_q_z - log_p_z).mean()
    # print("MI: {}, TC: {}, KL: {}".format(index_code_MI, TC, dim_wise_KL))

    return {
        "loss": MLD + self.alpha * index_code_MI + self.beta * TC +
                self.gamma * dim_wise_KL,
        "MLD": MLD,
        "index_code_MI": index_code_MI,
        "TC": TC,
        "dim_wise_KL": dim_wise_KL
    }
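# get_log_importance_weight_mat is not defined in this excerpt. Below is a
# sketch of what it presumably computes, following the minibatch stratified
# sampling (MSS) weights from Chen et al. 2018 ("Isolating Sources of
# Disentanglement in Variational Autoencoders"); treat the exact form as an
# assumption rather than this class's verified implementation.
import torch

def get_log_importance_weight_mat(batch_size, dataset_size):
    N, M = dataset_size, batch_size - 1
    strat_weight = (N - M) / (N * M)
    W = torch.full((batch_size, batch_size), 1. / M)
    W.view(-1)[::M + 1] = 1. / N         # diagonal: each sample paired with itself
    W.view(-1)[1::M + 1] = strat_weight  # one stratified slot per row
    W[M - 1, 0] = strat_weight
    return W.log()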