Example 1
    def loss(self, X, X_hat, flow):
        """
        Computes the VAE loss objective and collects some training statistics

        X - data tensor, torch.LongTensor(batch_size, num_parameters=155)
        X_hat - data tensor, torch.FloatTensor(batch_size, num_parameters=155, max_value=128)
        flow - the namedtuple returned by TriangularSylvesterFlow
        
        for reference, the namedtuple is ('Flow', ('q_z', 'log_det', 'z_0', 'z_k', 'flow'))
        """
    
        # log p(z_k) under the standard-normal prior and log q(z_0) under the
        # base posterior, each summed over the latent dimensions
        p_z_k = Normal(0, 1).log_prob(flow.z_k).sum(-1)
        q_z_0 = flow.q_z.log_prob(flow.z_0).sum(-1)
        # single-sample KL estimate with the flow's log-determinant
        # correction, normalized per latent dimension
        kl = (q_z_0 - p_z_k - flow.log_det).mean() / flow.z_k.shape[-1]

        # KL-annealing weight in [0, 1]
        beta = sigmoidal_annealing(self.iter, self.beta_temp).item()

        # cross_entropy expects the class dimension second, so swap the
        # parameter and value dimensions of the logits
        reconstruction_loss = F.cross_entropy(X_hat.transpose(-1, -2), X)
        accuracy = (X_hat.argmax(-1) == X).float().mean()

        loss = reconstruction_loss + self.max_beta * beta * kl

        return loss, {
            'accuracy': accuracy,
            'reconstruction_loss': reconstruction_loss,
            'kl': kl,
            'beta': beta,
            'log_det': flow.log_det.mean(),
            'p_z_k': p_z_k.mean(),
            'q_z_0': q_z_0.mean(),
            # 'iter': self.iter // self.
        }
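The `sigmoidal_annealing` schedule used above is not shown. A minimal sketch of a compatible helper, assuming it maps the training iteration to a KL weight in [0, 1] with a shifted sigmoid; the name and call signature come from the snippet, while the curve shape and the `midpoint` constant are assumptions:

import torch

def sigmoidal_annealing(iteration, temperature, midpoint=10_000):
    # Shifted sigmoid over the iteration count: close to 0 early in
    # training, saturating at 1 after roughly `midpoint` iterations;
    # `temperature` controls the slope (assumed semantics).
    t = torch.tensor(float(iteration - midpoint) * temperature)
    return torch.sigmoid(t)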
Example 2
    def forward(self, x):
        pred_result = self.predict(x)
        # unsqueeze to broadcast the input across the sample dimension (L)
        x = x.unsqueeze(0)
        # log-likelihood of the input under the reconstruction distribution,
        # averaged over the L Monte Carlo samples
        log_lik = Normal(pred_result['recon_mu'],
                         pred_result['recon_sigma']).log_prob(x).mean(dim=0)
        # average over the batch, sum over the remaining dimensions
        log_lik = log_lik.mean(dim=0).sum()
        kl = kl_divergence(pred_result['latent_dist'], self.prior).mean(dim=0).sum()
        # negative ELBO; note that 'recon_loss' below actually holds the
        # (higher-is-better) log-likelihood, not a loss
        loss = kl - log_lik
        return dict(loss=loss, kl=kl, recon_loss=log_lik, **pred_result)
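For context, `forward` assumes `self.predict` draws L latent samples and returns the per-sample reconstruction parameters together with the latent distribution. A minimal sketch of a compatible `predict`, with shapes and key names inferred from the call sites above; the encoder/decoder submodules and the value of L are assumptions:

def predict(self, x, L=10):
    # Encode to a diagonal Gaussian posterior (hypothetical submodules).
    mu, log_var = self.encoder(x)
    latent_dist = Normal(mu, torch.exp(0.5 * log_var))
    z = latent_dist.rsample((L,))            # (L, batch, latent_dim)
    recon_mu, recon_sigma = self.decoder(z)  # (L, batch, feature_dim)
    return dict(latent_dist=latent_dist,
                recon_mu=recon_mu,
                recon_sigma=recon_sigma)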
Example 3
import torch
from torch.distributions import Normal


def elbo_rvae(data, p_mu, p_sigma, z, q_mu, q_t, model, beta):
    if model._mean_warmup:
        # during mean warm-up only the reconstruction term is optimized;
        # zero placeholders stand in for the other statistics
        return -Normal(p_mu, p_sigma).log_prob(data).sum(
            -1).mean(), torch.zeros(1), torch.zeros(1)
    else:
        pr_mu, pr_t = model.pr_means, model.pr_t

        # log p(x|z) under the Gaussian decoder
        log_pxz = Normal(p_mu, p_sigma).log_prob(data).sum(-1)
        # Brownian-motion kernel log-densities of the sampled z under the
        # approximate posterior and the prior
        log_qzx = log_bm_krn(z, q_mu, q_t, model)
        log_pz = log_bm_krn(z, pr_mu.expand_as(z), pr_t, model)

        # single-sample Monte Carlo estimate of KL(q(z|x) || p(z))
        KL = log_qzx - log_pz

        return (-log_pxz + beta * KL.abs()).mean(), -log_pxz.mean(), KL.mean()
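In the non-warm-up branch the quantity being minimized is the beta-weighted negative ELBO with a single-sample Monte Carlo estimate of the KL term (the `.abs()` additionally keeps the noisy estimate non-negative):

-\mathcal{L}_\beta(x) = -\,\mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big] + \beta\,\mathrm{KL}\big(q(z \mid x)\,\|\,p(z)\big),
\qquad \mathrm{KL} \approx \log q(z \mid x) - \log p(z), \quad z \sim q(z \mid x).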
Example 4
from torch.distributions import Normal, kl_divergence


def elbo_vae(data,
             p_mu,
             p_var,
             z,
             q_mu,
             q_var,
             pr_mu,
             pr_var,
             beta,
             vampprior=False):
    # log p(x|z) under the Gaussian decoder (variances -> standard deviations)
    log_pxz = Normal(p_mu, p_var.sqrt()).log_prob(data).sum(-1)
    qzx = Normal(q_mu, q_var.sqrt())
    if vampprior:
        # the VampPrior is a mixture, so the KL has no closed form;
        # use a single-sample Monte Carlo estimate instead
        log_qzx = qzx.log_prob(z).sum(-1)
        log_pz = log_gauss_mix(z, pr_mu, pr_var)

        KL = log_qzx - log_pz
    else:
        # pr_var is a variance, so take the square root for the scale
        pz = Normal(pr_mu, pr_var.sqrt())
        KL = kl_divergence(qzx, pz).sum(-1)

    return (-log_pxz + beta * KL).mean(), -log_pxz.mean(), KL.mean()
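`log_gauss_mix` is not defined above. A minimal sketch of a compatible implementation, assuming a uniform-weight mixture of diagonal Gaussians (the usual VampPrior aggregate) with `pr_mu` and `pr_var` of shape (K, latent_dim) and `z` of shape (batch, latent_dim); the shape conventions and uniform weights are assumptions:

import math

import torch
from torch.distributions import Normal

def log_gauss_mix(z, mu, var):
    # Per-component diagonal-Gaussian log-densities: (batch, K)
    log_comp = Normal(mu, var.sqrt()).log_prob(z.unsqueeze(1)).sum(-1)
    # Uniform mixture weights: logsumexp over components minus log K
    return torch.logsumexp(log_comp, dim=1) - math.log(mu.shape[0])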