def get_reconstruction_loss(self, x):
    """Return the negative mean multinomial log-likelihood of ``x``.

    The counts are imputed and mapped to ILR coordinates with ``self.Psi``.
    They are then passed through the encoder and decoder; the encoder output
    is decoded directly, with no sampling. The decoder output is mapped back
    through ``Psi.T`` to give per-sample logits for a multinomial likelihood.
    """
    ilr_coords = ilr(self.imputer(x), self.Psi)
    decoded = self.decoder(self.encoder(ilr_coords))
    # Psi.T @ decoded.T, transposed back so each row holds one sample's logits.
    # NOTE(review): Multinomial is built with its default total_count; confirm
    # torch's argument validation accepts count vectors like ``x`` here.
    logits = (self.Psi.t() @ decoded.t()).t()
    log_lik = Multinomial(logits=logits).log_prob(x)
    return -log_lik.mean()
def encode(self, x):
    """Return the encoder output for count data ``x``.

    ``x`` is imputed and mapped to ILR coordinates with ``self.Psi`` before
    encoding. No batch-effect correction is applied in this method.
    """
    ilr_coords = ilr(self.imputer(x), self.Psi)
    latent = self.encoder(ilr_coords)
    return latent
def get_reconstruction_loss(self, x, B):
    """Return the negative reconstruction log-likelihood of ``x`` given ``B``.

    Batch effects ``(Psi @ B.T).T`` are removed from the ILR-transformed input
    before encoding and added back to the decoder output. The result is then
    scored against ``x`` with ``self.recon_model_loglik``. The encoder output
    is decoded directly, with no sampling.

    Args:
        x: count data to reconstruct.
        B: batch covariates. ``B`` must have ``self.Psi.shape[1]`` columns for
            ``Psi @ B.T`` to be defined. NOTE(review): presumably one row per
            sample in ``x``; confirm with caller.

    Returns:
        ``-self.recon_model_loglik(x, eta)``. No reduction is applied here.
    """
    hx = ilr(self.imputer(x), self.Psi)
    batch_effects = (self.Psi @ B.t()).t()
    # Out-of-place arithmetic on purpose: mutating autograd-tracked
    # activations in place can break backward when the producing op saved
    # its output, and in-place ops cannot broadcast to a larger shape.
    hx = hx - batch_effects  # subtract out batch effects
    z_mean = self.encoder(hx)
    eta = self.decoder(z_mean) + batch_effects  # add batch effects back in
    return -self.recon_model_loglik(x, eta)
def get_reconstruction_loss(self, x, B):
    """Return the negative mean multinomial log-likelihood of ``x`` given ``B``.

    Batch effects ``(Psi @ B.T).T`` are removed from the ILR-transformed input
    before encoding and added back to the decoder output. The encoder output
    is decoded directly, with no sampling. The corrected decoder output is
    mapped back through ``Psi.T`` to give per-sample multinomial logits.

    Args:
        x: count data to reconstruct.
        B: batch covariates. ``B`` must have ``self.Psi.shape[1]`` columns for
            ``Psi @ B.T`` to be defined. NOTE(review): presumably one row per
            sample in ``x``; confirm with caller.

    Returns:
        The negated mean of the per-sample multinomial log-probabilities.
    """
    hx = ilr(self.imputer(x), self.Psi)
    batch_effects = (self.Psi @ B.t()).t()
    # Out-of-place arithmetic on purpose: mutating autograd-tracked
    # activations in place can break backward when the producing op saved
    # its output, and in-place ops cannot broadcast to a larger shape.
    hx = hx - batch_effects  # subtract out batch effects
    z_mean = self.encoder(hx)
    eta = self.decoder(z_mean) + batch_effects  # add batch effects back in
    logp = self.Psi.t() @ eta.t()
    # NOTE(review): Multinomial is built with its default total_count; confirm
    # torch's argument validation accepts count vectors like ``x`` here.
    mult_loss = Multinomial(logits=logp.t()).log_prob(x).mean()
    return -mult_loss
def get_reconstruction_loss(self, x):
    """Return the negative reconstruction log-likelihood of ``x``.

    When ``self.use_analytic_elbo`` is truthy, this defers to
    ``self.analytic_exp_recon_loss`` on the ILR coordinates. Otherwise it
    draws a single reparameterised sample,
    ``z = mean + eps * exp(0.5 * variational_logvars)`` with standard-normal
    ``eps``, then decodes it and scores it with ``self.recon_model_loglik``.
    """
    coords = ilr(self.imputer(x), self.Psi)
    if self.use_analytic_elbo:
        return -self.analytic_exp_recon_loss(coords)
    mean = self.encoder(coords)
    noise = torch.normal(torch.zeros_like(mean), 1.0)
    std = torch.exp(0.5 * self.variational_logvars)
    decoded = self.decoder(mean + noise * std)
    return -self.recon_model_loglik(x, decoded)
def forward(self, x, B):
    """Return the sampled variational loss for ``x`` with batch correction.

    Batch effects ``(Psi @ B.T).T`` are removed from the ILR-transformed input
    before encoding and added back to the decoder output. One reparameterised
    latent sample is drawn,
    ``z = mean + eps * exp(0.5 * variational_logvars)`` with standard-normal
    ``eps``.

    The loss is ``kl_div + recon_loss``, where each term is averaged over
    dim 0 and then summed:

    - ``kl_div`` is ``-self.gaussian_kl(...)``. NOTE(review): this sign
      implies ``gaussian_kl`` returns the negative KL; confirm.
    - ``recon_loss`` is ``-self.recon_model_loglik(x, x_out)``.

    Args:
        x: count data.
        B: batch covariates. ``B`` must have ``self.Psi.shape[1]`` columns for
            ``Psi @ B.T`` to be defined.
    """
    hx = ilr(self.imputer(x), self.Psi)
    batch_effects = (self.Psi @ B.t()).t()
    # Out-of-place arithmetic on purpose: mutating autograd-tracked
    # activations in place can break backward when the producing op saved
    # its output, and in-place ops cannot broadcast to a larger shape.
    hx = hx - batch_effects  # subtract out batch effects
    z_mean = self.encoder(hx)
    eps = torch.normal(torch.zeros_like(z_mean), 1.0)
    z_sample = z_mean + eps * torch.exp(0.5 * self.variational_logvars)
    x_out = self.decoder(z_sample) + batch_effects  # add batch effects back in
    kl_div = (
        -self.gaussian_kl(z_mean, self.variational_logvars)).mean(0).sum()
    recon_loss = (-self.recon_model_loglik(x, x_out)).mean(0).sum()
    return kl_div + recon_loss
def forward(self, x):
    """Return the negated sum of three mean log-likelihood terms.

    The three terms, added in this order, are:

    1. The multinomial log-probability of ``x``, using logits
       ``(Psi.T @ eta.T).T``.
    2. The log-density of ``self.eta`` under a
       ``MultivariateNormalFactorIdentity``. That distribution is built from
       the decoded encoder output, ``exp(log_sigma_sq)``,
       ``exp(variational_logvars)`` and the decoder weight matrix.
    3. The ``Normal(self.zm, self.zI)`` log-density of the encoder output.
    """
    latent = self.encoder(ilr(self.imputer(x), self.Psi))
    decoded_mean = self.decoder(latent)
    factor_weight = self.decoder.weight
    diag_scale = torch.exp(self.variational_logvars)
    noise_var = torch.exp(self.log_sigma_sq)
    q = MultivariateNormalFactorIdentity(
        decoded_mean, noise_var, diag_scale, factor_weight)

    eta_logits = (self.Psi.t() @ self.eta.t()).t()
    log_prior = Normal(self.zm, self.zI).log_prob(latent).mean()
    log_lik_eta = q.log_prob(self.eta).mean()
    # NOTE(review): Multinomial is built with its default total_count; confirm
    # torch's argument validation accepts count vectors like ``x`` here.
    log_lik_counts = Multinomial(logits=eta_logits).log_prob(x).mean()
    return -(log_lik_counts + log_lik_eta + log_prior)
def forward(self, x):
    """Return the training loss for count data ``x``.

    When ``self.use_analytic_elbo`` is truthy, the loss is
    ``self.analytic_elbo(ilr_coords, encoder_mean)``. Otherwise one
    reparameterised latent sample is drawn, decoded, and scored. The loss is
    then ``(-gaussian_kl).mean(0).sum()`` plus
    ``(-recon_model_loglik).mean(0).sum()``.
    """
    coords = ilr(self.imputer(x), self.Psi)
    mean = self.encoder(coords)
    if self.use_analytic_elbo:
        return self.analytic_elbo(coords, mean)

    noise = torch.normal(torch.zeros_like(mean), 1.0)
    sample = mean + noise * torch.exp(0.5 * self.variational_logvars)
    decoded = self.decoder(sample)
    kl_term = (-self.gaussian_kl(mean, self.variational_logvars)).mean(0).sum()
    recon_term = (-self.recon_model_loglik(x, decoded)).mean(0).sum()
    return kl_term + recon_term
def encode(self, x):
    """Return the encoder output for the ILR transform of imputed ``x``."""
    return self.encoder(ilr(self.imputer(x), self.Psi))
def reset(self, x):
    """Re-initialise ``self.eta`` with the ILR coordinates of imputed ``x``.

    Only the values are kept, because assigning through ``.data`` detaches
    them from autograd. The computation therefore runs under
    ``torch.no_grad()`` so no graph is recorded for the imputer and ILR ops.
    Assigning ``.data`` replaces the parameter's tensor wholesale rather than
    copying into it.
    """
    import torch  # local import keeps this method self-contained

    with torch.no_grad():
        hx = ilr(self.imputer(x), self.Psi)
    self.eta.data = hx.data