# Shared imports for the snippets below
import math
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence


def loss(self, X, X_hat, flow):
    """
    Computes the VAE loss objective and collects some training statistics.

    X     - data tensor, torch.LongTensor(batch_size, num_parameters=155)
    X_hat - reconstruction logits, torch.FloatTensor(batch_size, num_parameters=155, max_value=128)
    flow  - the namedtuple returned by TriangularSylvesterFlow; for reference,
            the namedtuple is ('Flow', ('q_z', 'log_det', 'z_0', 'z_k', 'flow'))
    """
    # Monte-Carlo KL estimate E_q[log q(z_0) - log p(z_k) - log|det J|],
    # normalized by the latent dimensionality
    p_z_k = Normal(0, 1).log_prob(flow.z_k).sum(-1)
    q_z_0 = flow.q_z.log_prob(flow.z_0).sum(-1)
    kl = (q_z_0 - p_z_k - flow.log_det).mean() / flow.z_k.shape[-1]
    beta = sigmoidal_annealing(self.iter, self.beta_temp).item()
    # F.cross_entropy expects class logits in dim 1, hence the transpose
    reconstruction_loss = F.cross_entropy(X_hat.transpose(-1, -2), X)
    accuracy = (X_hat.argmax(-1) == X).float().mean()
    loss = reconstruction_loss + self.max_beta * beta * kl
    return loss, {
        'accuracy': accuracy,
        'reconstruction_loss': reconstruction_loss,
        'kl': kl,
        'beta': beta,
        'log_det': flow.log_det.mean(),
        'p_z_k': p_z_k.mean(),
        'q_z_0': q_z_0.mean(),
    }
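# sigmoidal_annealing is referenced above but not defined in this snippet.
# A minimal sketch, assuming it maps the training iteration to a beta warm-up
# factor in [0, 1] via a logistic curve; the `midpoint` parameter is hypothetical.
def sigmoidal_annealing(iteration, temp, midpoint=10_000):
    # Rises smoothly from ~0 to ~1 around `midpoint`; `temp` controls how
    # sharp the transition is. Returns a tensor so the caller's .item() works.
    return torch.sigmoid(torch.tensor((iteration - midpoint) / temp))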
def forward(self, x):
    pred_result = self.predict(x)
    # unsqueeze to broadcast the input across the sample dimension (L)
    x = x.unsqueeze(0)
    # per-element log-likelihood under the reconstruction distribution,
    # averaged over the L Monte-Carlo samples
    log_lik = Normal(pred_result['recon_mu'], pred_result['recon_sigma']).log_prob(x).mean(dim=0)
    log_lik = log_lik.mean(dim=0).sum()  # mean over the batch, sum over features
    kl = kl_divergence(pred_result['latent_dist'], self.prior).mean(dim=0).sum()
    loss = kl - log_lik  # negative ELBO
    # note: recon_loss holds the (positive) log-likelihood, not its negation
    return dict(loss=loss, kl=kl, recon_loss=log_lik, **pred_result)
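# The shapes above are implied rather than stated: self.predict is assumed to
# return recon_mu and recon_sigma with a leading sample dimension L, so that
# x.unsqueeze(0) broadcasts against them. A standalone sanity check of that
# contract (all sizes below are assumptions, not from the source):
L, B, D, Z = 8, 4, 155, 16                    # samples, batch, features, latent dim
recon_mu, recon_sigma = torch.zeros(L, B, D), torch.ones(L, B, D)
x = torch.randn(B, D)
log_lik = Normal(recon_mu, recon_sigma).log_prob(x.unsqueeze(0)).mean(dim=0)  # (B, D)
log_lik = log_lik.mean(dim=0).sum()           # scalar
latent_dist = Normal(torch.zeros(B, Z), torch.ones(B, Z))
prior = Normal(torch.zeros(Z), torch.ones(Z))
kl = kl_divergence(latent_dist, prior).mean(dim=0).sum()
loss = kl - log_lik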
def elbo_rvae(data, p_mu, p_sigma, z, q_mu, q_t, model, beta):
    if model._mean_warmup:
        # during mean warm-up only the reconstruction term is optimized;
        # zeros are returned in place of the two KL statistics
        return (-Normal(p_mu, p_sigma).log_prob(data).sum(-1).mean(),
                torch.zeros(1), torch.zeros(1))
    else:
        pr_mu, pr_t = model.pr_means, model.pr_t
        log_pxz = Normal(p_mu, p_sigma).log_prob(data).sum(-1)
        # Brownian-motion kernel log-densities of z under the posterior and the prior
        log_qzx = log_bm_krn(z, q_mu, q_t, model)
        log_pz = log_bm_krn(z, pr_mu.expand_as(z), pr_t, model)
        KL = log_qzx - log_pz
        return (-log_pxz + beta * KL.abs()).mean(), -log_pxz.mean(), KL.mean()
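# log_bm_krn evaluates the log-density of the Brownian-motion kernel on the
# model's learned manifold and has no closed form in general. As an
# illustrative stand-in only (not the actual implementation), it reduces in
# flat Euclidean space to an isotropic Gaussian with variance t per dimension:
def log_bm_krn_flat(z, mu, t, model=None):
    # log N(z; mu, t*I) = -d/2 * log(2*pi*t) - ||z - mu||^2 / (2t)
    d = z.shape[-1]
    t = torch.as_tensor(t, dtype=z.dtype)
    return -0.5 * (d * torch.log(2 * math.pi * t) + ((z - mu) ** 2).sum(-1) / t)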
def elbo_vae(data, p_mu, p_var, z, q_mu, q_var, pr_mu, pr_var, beta, vampprior=False):
    log_pxz = Normal(p_mu, p_var.sqrt()).log_prob(data).sum(-1)
    qzx = Normal(q_mu, q_var.sqrt())
    pz = Normal(pr_mu, pr_var.sqrt())  # Normal takes a std, not a variance
    if vampprior:
        # Monte-Carlo KL estimate against the VampPrior mixture
        log_qzx = qzx.log_prob(z).sum(-1)
        log_pz = log_gauss_mix(z, pr_mu, pr_var)
        KL = log_qzx - log_pz
    else:
        # closed-form KL between diagonal Gaussians
        KL = kl_divergence(qzx, pz).sum(-1)
    return (-log_pxz + beta * KL).mean(), -log_pxz.mean(), KL.mean()
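# log_gauss_mix is assumed to score z under an equally weighted mixture of K
# diagonal Gaussians (the VampPrior components). A minimal sketch under that
# assumption:
def log_gauss_mix(z, mu, var):
    # z: (batch, dim); mu, var: (K, dim) component means and variances
    # log p(z) = logsumexp_k log N(z; mu_k, var_k) - log K
    log_comp = Normal(mu, var.sqrt()).log_prob(z.unsqueeze(1)).sum(-1)  # (batch, K)
    return torch.logsumexp(log_comp, dim=1) - math.log(mu.shape[0])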