Example #1
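Both examples are methods of a VAE module (they reference self.encoder, self.decoder and helper methods on self) and assume the usual imports at module level:

import torch
import torch.nn.functional as F
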
    def _run_step(self, batch):
        x, _ = batch
        z_mu, z_log_var = self.encoder(x)

        # the KL divergence is estimated with Monte Carlo sampling
        num_samples = 32

        # expand dims to sample all at once
        # (batch, z_dim) -> (batch, num_samples, z_dim)
        z_mu = z_mu.unsqueeze(1).expand(-1, num_samples, -1)

        # (batch, z_dim) -> (batch, num_samples, z_dim)
        z_log_var = z_log_var.unsqueeze(1).expand(-1, num_samples, -1)

        # convert log-variance to standard deviation
        z_std = torch.exp(z_log_var / 2)

        P = self.get_prior(z_mu, z_std)
        Q = self.get_approx_posterior(z_mu, z_std)

        # flatten the input to (batch, x_dim) to match the decoder output
        x = x.view(x.size(0), -1)

        loss, recon_loss, kl_div, pxz = self.elbo_loss(x, P, Q, num_samples)

        return loss, recon_loss, kl_div, pxz
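
The helper methods get_prior and get_approx_posterior are not shown. A minimal sketch of what they plausibly look like, assuming a standard-normal prior and a diagonal-Gaussian posterior built from torch.distributions (with plain Normal, log_prob stays per-dimension, which is what the .sum(dim=2) in Example #2 expects):

    def get_prior(self, z_mu, z_std):
        # standard normal N(0, I); z_mu/z_std only fix shape, dtype and device
        return torch.distributions.Normal(torch.zeros_like(z_mu), torch.ones_like(z_std))

    def get_approx_posterior(self, z_mu, z_std):
        # diagonal Gaussian q(z|x) parameterized by the encoder outputs
        return torch.distributions.Normal(z_mu, z_std)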
Example #2
    def elbo_loss(self, x, P, Q, num_samples):
        # reparameterized sample from q(z|x): (batch, num_samples, z_dim)
        z = Q.rsample()

        # ----------------------
        # KL divergence loss (Monte Carlo estimate)
        # ----------------------
        log_qz = Q.log_prob(z)
        log_pz = P.log_prob(z)

        # (batch, num_samples, z_dim) -> (batch, num_samples)
        kl_div = (log_qz - log_pz).sum(dim=2)

        # the Monte Carlo average over samples is not taken here; the final
        # .mean() below reduces over both the batch and sample dimensions

        # ----------------------
        # Reconstruction loss
        # ----------------------
        # flatten samples into the batch dim for the decoder:
        # (batch, num_samples, z_dim) -> (batch * num_samples, z_dim)
        z = z.reshape(-1, z.size(-1))
        pxz = self.decoder(z)

        # (batch * num_samples, x_dim) -> (batch, num_samples, x_dim)
        pxz = pxz.view(-1, num_samples, pxz.size(-1))

        # match the target shape: (batch, x_dim) -> (batch, num_samples, x_dim)
        x = x.unsqueeze(1).expand(-1, num_samples, -1)

        # Bernoulli likelihood: per-pixel negative log p(x|z)
        pxz = torch.sigmoid(pxz)
        recon_loss = F.binary_cross_entropy(pxz, x, reduction="none")

        # sum over pixel dimensions: with independent Bernoulli pixels, the
        # log-likelihood of x is the sum of the per-pixel log probabilities
        recon_loss = recon_loss.sum(dim=-1)

        # as with the KL term, the average over Monte Carlo samples is taken
        # by the final .mean() below

        # loss = negative ELBO = reconstruction loss + KL divergence
        loss = recon_loss + kl_div

        # average over the batch and the Monte Carlo samples
        loss = loss.mean()
        recon_loss = recon_loss.mean()
        kl_div = kl_div.mean()

        return loss, recon_loss, kl_div, pxz
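
As a sanity check on the Monte Carlo estimator used above: for diagonal Gaussians the estimate (1/S) * sum_s [log q(z_s) - log p(z_s)] should approach the closed-form KL from torch.distributions.kl_divergence as the number of samples S grows. A self-contained sketch (standalone, not part of the module above):

import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
batch, z_dim, num_samples = 4, 8, 10_000

z_mu = torch.randn(batch, z_dim)
z_std = torch.rand(batch, z_dim) + 0.5

# tile the params so all samples are drawn at once, as in _run_step
mu = z_mu.unsqueeze(1).expand(-1, num_samples, -1)
std = z_std.unsqueeze(1).expand(-1, num_samples, -1)

P = Normal(torch.zeros_like(mu), torch.ones_like(std))  # prior N(0, I)
Q = Normal(mu, std)                                     # approximate posterior

z = Q.rsample()
mc_kl = (Q.log_prob(z) - P.log_prob(z)).sum(dim=2).mean(dim=1)

analytic_kl = kl_divergence(
    Normal(z_mu, z_std),
    Normal(torch.zeros_like(z_mu), torch.ones_like(z_std)),
).sum(dim=1)

print(mc_kl)        # Monte Carlo estimate, one value per batch element
print(analytic_kl)  # closed-form KL, should be close to the estimate

The sampling-based estimator is noisier than the closed-form KL, but it generalizes to priors and posteriors for which no analytic KL exists.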