def _run_step(self, batch):
    x, _ = batch
    z_mu, z_log_var = self.encoder(x)

    # we're estimating the KL divergence using sampling
    num_samples = 32

    # expand dims to sample all at once
    # (batch, z_dim) -> (batch, num_samples, z_dim)
    z_mu = z_mu.unsqueeze(1)
    z_mu = shaping.tile(z_mu, 1, num_samples)

    # (batch, z_dim) -> (batch, num_samples, z_dim)
    z_log_var = z_log_var.unsqueeze(1)
    z_log_var = shaping.tile(z_log_var, 1, num_samples)

    # convert log-variance to standard deviation
    z_std = torch.exp(z_log_var / 2)

    # P is the prior p(z), Q is the approximate posterior q(z|x)
    P = self.get_prior(z_mu, z_std)
    Q = self.get_approx_posterior(z_mu, z_std)

    # flatten images into vectors for the reconstruction loss
    x = x.view(x.size(0), -1)

    loss, recon_loss, kl_div, pxz = self.elbo_loss(x, P, Q, num_samples)

    return loss, recon_loss, kl_div, pxz
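
# The helpers used above (shaping.tile, self.get_prior, self.get_approx_posterior) are defined
# elsewhere in the codebase. Below is a minimal sketch of what this step assumes they do --
# illustrative assumptions only, not the actual implementations (and in the real model the
# last two would be methods on the module):

import torch
from torch import distributions


def tile(a, dim, n_tile):
    # repeat tensor `a` n_tile times along dimension `dim`
    # (here `dim` has size 1 after unsqueeze, so this just makes num_samples copies)
    return a.repeat_interleave(n_tile, dim=dim)


def get_prior(z_mu, z_std):
    # unit gaussian prior p(z), shaped like the tiled posterior parameters
    return distributions.Normal(torch.zeros_like(z_mu), torch.ones_like(z_std))


def get_approx_posterior(z_mu, z_std):
    # diagonal gaussian approximate posterior q(z|x) built from the encoder outputs
    return distributions.Normal(z_mu, z_std)
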
def elbo_loss(self, x, P, Q, num_samples):
    z = Q.rsample()

    # ----------------------
    # KL divergence loss (using monte carlo sampling)
    # ----------------------
    log_qz = Q.log_prob(z)
    log_pz = P.log_prob(z)

    # sum over latent dimensions
    # (batch, num_samples, z_dim) -> (batch, num_samples)
    kl_div = (log_qz - log_pz).sum(dim=2)

    # averaging the monte carlo samples is folded into the final .mean() below

    # ----------------------
    # Reconstruction loss
    # ----------------------
    # fold the sample dimension into the batch so the decoder sees (batch * num_samples, z_dim)
    z = z.view(-1, z.size(-1)).contiguous()
    pxz = self.decoder(z)
    pxz = pxz.view(-1, num_samples, pxz.size(-1))

    # tile x so each monte carlo sample is compared against its own copy of the input
    x = shaping.tile(x.unsqueeze(1), 1, num_samples)

    pxz = torch.sigmoid(pxz)
    recon_loss = F.binary_cross_entropy(pxz, x, reduction="none")

    # sum across pixel dimensions: the log-likelihood of independent per-pixel Bernoullis
    # is the sum of the per-pixel log probabilities
    recon_loss = recon_loss.sum(dim=-1)

    # ELBO = reconstruction + KL
    loss = recon_loss + kl_div

    # average over the batch and the monte carlo samples
    loss = loss.mean()
    recon_loss = recon_loss.mean()
    kl_div = kl_div.mean()

    return loss, recon_loss, kl_div, pxz
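
# Sanity check (an aside, not part of the model): for two diagonal gaussians the KL term has a
# closed form, so the monte carlo estimate used in elbo_loss can be compared against
# torch.distributions.kl_divergence. Shapes and names below mirror the snippet above;
# the specific sizes are arbitrary.

import torch
from torch import distributions

batch, num_samples, z_dim = 4, 32, 8
z_mu = torch.randn(batch, 1, z_dim).expand(batch, num_samples, z_dim)
z_std = torch.rand(batch, 1, z_dim).expand(batch, num_samples, z_dim) + 0.5

P = distributions.Normal(torch.zeros_like(z_mu), torch.ones_like(z_std))
Q = distributions.Normal(z_mu, z_std)

# monte carlo estimate, as in elbo_loss
z = Q.rsample()
kl_mc = (Q.log_prob(z) - P.log_prob(z)).sum(dim=2).mean()

# closed-form KL, summed over latent dims and averaged over batch and samples
kl_exact = distributions.kl_divergence(Q, P).sum(dim=2).mean()

# with enough samples the two roughly agree: kl_mc ≈ kl_exact
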