Example #1
0
 def get_reconstruction_loss(self, x):
     """Return the negated mean multinomial log-likelihood of ``x``.

     ``x`` is imputed, mapped to ilr coordinates with ``self.Psi``, encoded
     to its mean latent code (no sampling), decoded, and mapped back with
     ``self.Psi.t()`` to per-row multinomial logits.

     Parameters
     ----------
     x : torch.Tensor
         Observations scored by the multinomial; presumably a
         (batch, features) count matrix -- TODO confirm against callers.

     Returns
     -------
     torch.Tensor
         Scalar: the negated ``.mean()`` of the per-row log-probabilities.
     """
     hx = ilr(self.imputer(x), self.Psi)
     z_mean = self.encoder(hx)
     eta = self.decoder(z_mean)
     logp = self.Psi.t() @ eta.t()
     # validate_args=False: with argument validation on (the PyTorch
     # default), the default total_count=1 support check rejects rows that
     # sum to more than one. log_prob derives the count from value.sum(-1),
     # so the computed likelihood is the same either way.
     mult_loss = Multinomial(
         logits=logp.t(), validate_args=False).log_prob(x).mean()
     return -mult_loss
Example #2
0
 def encode(self, x):
     """Return the encoder output for the ilr coordinates of ``x``.

     ``x`` is passed through ``self.imputer`` and mapped to ilr coordinates
     with ``self.Psi`` before encoding. No batch-effect correction is
     applied; the method takes no batch-covariate argument.

     Parameters
     ----------
     x : torch.Tensor
         Raw observations accepted by ``self.imputer``.

     Returns
     -------
     torch.Tensor
         Whatever ``self.encoder`` produces for the transformed input.
     """
     hx = ilr(self.imputer(x), self.Psi)
     z = self.encoder(hx)
     return z
Example #3
0
 def get_reconstruction_loss(self, x, B):
     """Return the batch-corrected reconstruction loss of ``x``.

     Batch effects ``(self.Psi @ B.t()).t()`` are subtracted from the ilr
     coordinates before encoding and added back to the decoder output
     before it is scored with ``self.recon_model_loglik``. The mean latent
     code is used (no sampling).

     Parameters
     ----------
     x : torch.Tensor
         Raw observations accepted by ``self.imputer``.
     B : torch.Tensor
         Batch covariates; must be compatible with ``self.Psi @ B.t()``
         (shape not checked here -- TODO confirm against callers).

     Returns
     -------
     torch.Tensor
         The negated output of ``self.recon_model_loglik``; no reduction
         is applied here.
     """
     hx = ilr(self.imputer(x), self.Psi)
     batch_effects = (self.Psi @ B.t()).t()
     # Out-of-place arithmetic: in-place updates (-=, +=) on tensors that
     # are part of the autograd graph can invalidate values saved for
     # backward and raise when backward() runs. The values are identical.
     hx = hx - batch_effects  # Subtract out batch effects
     z_mean = self.encoder(hx)
     eta = self.decoder(z_mean)
     eta = eta + batch_effects  # Add batch effects back in
     recon_loss = -self.recon_model_loglik(x, eta)
     return recon_loss
Example #4
0
 def get_reconstruction_loss(self, x, B):
     """Return the negated mean multinomial log-likelihood of ``x``,
     with batch-effect correction.

     Batch effects ``(self.Psi @ B.t()).t()`` are subtracted from the ilr
     coordinates before encoding and added back to the decoder output. The
     corrected output is mapped back with ``self.Psi.t()`` to per-row
     multinomial logits. The mean latent code is used (no sampling).

     Parameters
     ----------
     x : torch.Tensor
         Observations scored by the multinomial; presumably a
         (batch, features) count matrix -- TODO confirm against callers.
     B : torch.Tensor
         Batch covariates; must be compatible with ``self.Psi @ B.t()``.

     Returns
     -------
     torch.Tensor
         Scalar: the negated ``.mean()`` of the per-row log-probabilities.
     """
     hx = ilr(self.imputer(x), self.Psi)
     batch_effects = (self.Psi @ B.t()).t()
     # Out-of-place arithmetic: in-place updates on autograd-graph tensors
     # can invalidate values saved for backward. The values are identical.
     hx = hx - batch_effects  # Subtract out batch effects
     z_mean = self.encoder(hx)
     eta = self.decoder(z_mean)
     eta = eta + batch_effects  # Add batch effects back in
     logp = self.Psi.t() @ eta.t()
     # validate_args=False: the default total_count=1 support check
     # (active under PyTorch's default validation) rejects rows that sum to
     # more than one. log_prob uses value.sum(-1) as the count regardless.
     mult_loss = Multinomial(
         logits=logp.t(), validate_args=False).log_prob(x).mean()
     return -mult_loss
Example #5
0
 def get_reconstruction_loss(self, x):
     """Negated reconstruction log-likelihood of ``x``.

     With ``self.use_analytic_elbo`` set, this defers to
     ``self.analytic_exp_recon_loss`` on the ilr-transformed input.
     Otherwise it draws one reparameterised latent sample around the encoder
     mean, with standard deviation ``exp(0.5 * self.variational_logvars)``,
     decodes it, and scores ``x`` against the decoded output with
     ``self.recon_model_loglik``.
     """
     coords = ilr(self.imputer(x), self.Psi)
     if self.use_analytic_elbo:
         return -self.analytic_exp_recon_loss(coords)

     mu = self.encoder(coords)
     noise = torch.normal(torch.zeros_like(mu), 1.0)
     std = torch.exp(0.5 * self.variational_logvars)
     decoded = self.decoder(mu + noise * std)
     return -self.recon_model_loglik(x, decoded)
Example #6
0
 def forward(self, x, B):
     """Return the batch-corrected training loss for ``x``.

     Batch effects ``(self.Psi @ B.t()).t()`` are subtracted from the ilr
     coordinates before encoding. One reparameterised latent sample is
     drawn around the encoder mean, with standard deviation
     ``exp(0.5 * self.variational_logvars)``, and decoded. The batch effects
     are then added back to the decoded output.

     The loss is the sum of two terms, each averaged over dim 0 and then
     summed:
       * the negated ``self.gaussian_kl(z_mean, self.variational_logvars)``
         (the sign convention of ``gaussian_kl`` is not visible here --
         TODO confirm);
       * the negated ``self.recon_model_loglik(x, x_out)``.

     Parameters
     ----------
     x : torch.Tensor
         Raw observations accepted by ``self.imputer``.
     B : torch.Tensor
         Batch covariates; must be compatible with ``self.Psi @ B.t()``.

     Returns
     -------
     torch.Tensor
         Scalar loss.
     """
     hx = ilr(self.imputer(x), self.Psi)
     batch_effects = (self.Psi @ B.t()).t()
     # Out-of-place arithmetic: in-place updates on autograd-graph tensors
     # (here the ilr output and the decoder output) can invalidate values
     # saved for backward and raise in backward(). The values are identical.
     hx = hx - batch_effects  # Subtract out batch effects
     z_mean = self.encoder(hx)
     eps = torch.normal(torch.zeros_like(z_mean), 1.0)
     z_sample = z_mean + eps * torch.exp(0.5 * self.variational_logvars)
     x_out = self.decoder(z_sample)
     x_out = x_out + batch_effects  # Add batch effects back in
     kl_div = (
         -self.gaussian_kl(z_mean, self.variational_logvars)).mean(0).sum()
     recon_loss = (-self.recon_model_loglik(x, x_out)).mean(0).sum()
     loss = kl_div + recon_loss
     return loss
Example #7
0
 def forward(self, x):
     """Return the negated sum of three mean log-likelihood terms for ``x``.

     Terms (each reduced with ``.mean()``):
       * ``prior_loss``: log-density of the encoder mean under
         ``Normal(self.zm, self.zI)``;
       * ``logit_loss``: log-density of ``self.eta`` under
         ``MultivariateNormalFactorIdentity(mu, var, D, W)``, where ``mu`` is
         the decoder output for the encoder mean, ``W`` the decoder weight,
         ``D = exp(self.variational_logvars)`` and
         ``var = exp(self.log_sigma_sq)``;
       * ``mult_loss``: multinomial log-probability of ``x`` with logits
         ``(self.Psi.t() @ self.eta.t()).t()``.

     Parameters
     ----------
     x : torch.Tensor
         Observations scored by the multinomial; presumably counts --
         TODO confirm against callers.

     Returns
     -------
     torch.Tensor
         Scalar: ``-(mult_loss + logit_loss + prior_loss)``.
     """
     hx = ilr(self.imputer(x), self.Psi)
     z_mean = self.encoder(hx)
     mu = self.decoder(z_mean)
     W = self.decoder.weight
     # Parameters of the distribution that scores self.eta.
     D = torch.exp(self.variational_logvars)
     var = torch.exp(self.log_sigma_sq)
     qdist = MultivariateNormalFactorIdentity(mu, var, D, W)
     logp = self.Psi.t() @ self.eta.t()
     prior_loss = Normal(self.zm, self.zI).log_prob(z_mean).mean()
     logit_loss = qdist.log_prob(self.eta).mean()
     # validate_args=False: the default total_count=1 support check
     # (active under PyTorch's default validation) rejects rows that sum to
     # more than one. log_prob uses value.sum(-1) as the count regardless.
     mult_loss = Multinomial(
         logits=logp.t(), validate_args=False).log_prob(x).mean()
     loglike = mult_loss + logit_loss + prior_loss
     return -loglike
Example #8
0
    def forward(self, x):
        """Return the training loss for ``x``.

        The input is imputed, mapped to ilr coordinates with ``self.Psi``
        and encoded to a mean latent code. If ``self.use_analytic_elbo`` is
        set, ``self.analytic_elbo`` computes the loss from those
        coordinates and the mean.

        Otherwise one reparameterised latent sample is drawn, with standard
        deviation ``exp(0.5 * self.variational_logvars)``, and decoded. The
        loss is the negated ``gaussian_kl`` term plus the negated
        ``recon_model_loglik`` term, each averaged over dim 0 and then
        summed.
        """
        coords = ilr(self.imputer(x), self.Psi)
        mu = self.encoder(coords)

        if self.use_analytic_elbo:
            return self.analytic_elbo(coords, mu)

        noise = torch.normal(torch.zeros_like(mu), 1.0)
        sample = mu + noise * torch.exp(0.5 * self.variational_logvars)
        reconstruction = self.decoder(sample)
        kl_term = (-self.gaussian_kl(
            mu, self.variational_logvars)).mean(0).sum()
        recon_term = (
            -self.recon_model_loglik(x, reconstruction)).mean(0).sum()
        return kl_term + recon_term
Example #9
0
 def encode(self, x):
     """Encode ``x``: impute, map to ilr coordinates via ``self.Psi``, then
     return whatever ``self.encoder`` produces for those coordinates."""
     return self.encoder(ilr(self.imputer(x), self.Psi))
Example #10
0
 def reset(self, x):
     """Reset ``self.eta`` from ``x``.

     ``self.eta.data`` is rebound to the ``.data`` of the imputed ilr
     coordinates of ``x``. Going through ``.data`` keeps the update out of
     autograd's view. Returns None.
     """
     self.eta.data = ilr(self.imputer(x), self.Psi).data