def get_reconstruction_loss(
    self,
    x: torch.Tensor,
    y: torch.Tensor,
    px_: Dict[str, torch.Tensor],
    py_: Dict[str, torch.Tensor],
    pro_batch_mask_minibatch: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the gene and protein reconstruction losses.

    Parameters
    ----------
    x
        Observed gene counts.
    y
        Observed protein counts.
    px_
        Gene likelihood parameters. Reads ``"rate"`` and ``"r"``, plus
        ``"dropout"`` when ``self.reconstruction_loss_gene == "zinb"``.
    py_
        Protein likelihood parameters. Reads ``"rate_back"``, ``"rate_fore"``,
        ``"r"`` and ``"mixing"``.
    pro_batch_mask_minibatch
        Optional mask (any dtype; converted with ``.bool()``). Protein loss
        entries where the mask is False/zero are set to 0 before summing.

    Returns
    -------
    ``(reconst_loss_gene, reconst_loss_protein)``: each is a negative
    log-likelihood summed over the last dimension.
    """
    # Gene likelihood: zero-inflated NB when configured, otherwise plain NB.
    if self.reconstruction_loss_gene == "zinb":
        gene_dist = ZeroInflatedNegativeBinomial(
            mu=px_["rate"], theta=px_["r"], zi_logits=px_["dropout"]
        )
    else:
        gene_dist = NegativeBinomial(mu=px_["rate"], theta=px_["r"])
    reconst_loss_gene = -gene_dist.log_prob(x).sum(dim=-1)

    reconst_loss_protein_full = -log_mixture_nb(
        y, py_["rate_back"], py_["rate_fore"], py_["r"], None, py_["mixing"]
    )
    if pro_batch_mask_minibatch is not None:
        # Zero out masked-off entries *in place position-wise*. Do not use
        # ``masked_scatter_`` here: it copies source elements sequentially
        # into the True slots, so any mask that is not all-True would shift
        # losses onto the wrong entries.
        reconst_loss_protein_full = reconst_loss_protein_full.masked_fill(
            ~pro_batch_mask_minibatch.bool(), 0.0
        )
    reconst_loss_protein = reconst_loss_protein_full.sum(dim=-1)

    return reconst_loss_gene, reconst_loss_protein
def get_reconstruction_loss(
    self,
    x: torch.Tensor,
    px_rate: torch.Tensor,
    px_r: torch.Tensor,
    px_dropout: torch.Tensor,
    bernoulli_params: torch.Tensor,
    eps_log: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    """Compute the reconstruction loss under an NB / ZINB mixture.

    Each entry's likelihood mixes a negative binomial with weight
    ``bernoulli_params`` and a zero-inflated negative binomial with weight
    ``1 - bernoulli_params``. The two weighted log-likelihoods are combined
    in log space.

    Parameters
    ----------
    x
        Observed counts.
    px_rate, px_r
        Passed as ``mu`` and ``theta`` to both the NB and ZINB components.
    px_dropout
        Passed as ``zi_logits`` to the ZINB component.
    bernoulli_params
        Mixture weight of the NB component.
    eps_log
        Added inside the logs of the mixture weights to avoid ``log(0)``.
    **kwargs
        Accepted for interface compatibility and ignored.

    Returns
    -------
    Negative mixture log-likelihood summed over the last dimension.
    """
    # Weighted log-likelihoods of the two mixture components.
    ll_zinb = torch.log(1.0 - bernoulli_params + eps_log) + ZeroInflatedNegativeBinomial(
        mu=px_rate, theta=px_r, zi_logits=px_dropout
    ).log_prob(x)
    ll_nb = torch.log(bernoulli_params + eps_log) + NegativeBinomial(
        mu=px_rate, theta=px_r
    ).log_prob(x)

    # log(exp(ll_nb) + exp(ll_zinb)), computed stably. Unlike the manual
    # max-shift form, logaddexp gives -inf rather than NaN when both terms
    # are -inf, and it broadcasts its operands like the original torch.max.
    ll_tot = torch.logaddexp(ll_zinb, ll_nb)
    return -ll_tot.sum(dim=-1)
def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout, **kwargs) -> torch.Tensor:
    """Compute the reconstruction loss for the configured count likelihood.

    The likelihood is selected by ``self.reconstruction_loss``:
    ``"zinb"`` (zero-inflated NB), ``"nb"`` (negative binomial) or
    ``"poisson"``.

    Parameters
    ----------
    x
        Observed counts.
    px_rate
        Passed as ``mu`` (NB/ZINB) or as the Poisson rate.
    px_r
        Passed as ``theta`` (NB/ZINB only).
    px_dropout
        Passed as ``zi_logits`` (ZINB only).
    **kwargs
        Accepted for interface compatibility and ignored.

    Returns
    -------
    Negative log-likelihood summed over the last dimension.

    Raises
    ------
    ValueError
        If ``self.reconstruction_loss`` is not one of the supported names.
        (Previously this surfaced as an ``UnboundLocalError``.)
    """
    if self.reconstruction_loss == "zinb":
        dist = ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout)
    elif self.reconstruction_loss == "nb":
        dist = NegativeBinomial(mu=px_rate, theta=px_r)
    elif self.reconstruction_loss == "poisson":
        dist = Poisson(px_rate)
    else:
        raise ValueError(
            f"Unknown reconstruction_loss {self.reconstruction_loss!r}; "
            "expected 'zinb', 'nb' or 'poisson'."
        )
    return -dist.log_prob(x).sum(dim=-1)
def reconstruction_loss(
    self,
    x: torch.Tensor,
    px_rate: torch.Tensor,
    px_r: torch.Tensor,
    px_dropout: torch.Tensor,
    mode: int,
) -> torch.Tensor:
    """Compute the reconstruction loss using the likelihood for ``mode``.

    The likelihood name is read from ``self.reconstruction_losses[mode]`` and
    must be ``"zinb"``, ``"nb"`` or ``"poisson"``.

    Parameters
    ----------
    x
        Observed counts.
    px_rate
        Passed as ``mu`` (NB/ZINB) or as the Poisson rate.
    px_r
        Passed as ``theta`` (NB/ZINB only).
    px_dropout
        Passed as ``zi_logits`` (ZINB only).
    mode
        Index into ``self.reconstruction_losses``.

    Returns
    -------
    Negative log-likelihood summed over the last dimension. All three
    likelihoods now reduce over ``dim=-1``; the Poisson branch previously
    used ``dim=1``, which is wrong for inputs with more than 2 dimensions.

    Raises
    ------
    ValueError
        If the configured likelihood name is not supported. (Previously the
        function silently returned ``None``.)
    """
    loss_name = self.reconstruction_losses[mode]
    if loss_name == "zinb":
        dist = ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout)
    elif loss_name == "nb":
        dist = NegativeBinomial(mu=px_rate, theta=px_r)
    elif loss_name == "poisson":
        dist = Poisson(px_rate)
    else:
        raise ValueError(
            f"Unknown reconstruction loss {loss_name!r} for mode {mode}; "
            "expected 'zinb', 'nb' or 'poisson'."
        )
    return -dist.log_prob(x).sum(dim=-1)
def test_zinb_distribution():
    """Check the ZINB and NB distribution classes.

    Covers:

    * ``log_prob`` agreement with ``log_zinb_positive`` / ``log_nb_positive``;
    * sample shapes;
    * rough NB sample mean and std;
    * broadcasting of ``log_prob`` over a 2-D batch;
    * the ``ValueError`` and ``UserWarning`` raised for invalid parameters or
      observations.
    """
    # Seed before any random draw. The seed used to be set only after theta,
    # pi and x were generated, which left those inputs (and so the test)
    # non-reproducible.
    torch.manual_seed(0)

    theta = 100.0 + torch.rand(size=(2,))
    mu = 15.0 * torch.ones_like(theta)
    pi = torch.randn_like(theta)
    x = torch.randint_like(mu, high=20)

    # ZINB: log-likelihood matches the reference implementation.
    log_p_ref = log_zinb_positive(x, mu, theta, pi)
    dist = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=pi)
    log_p_zinb = dist.log_prob(x)
    assert (log_p_ref - log_p_zinb).abs().max().item() <= 1e-8

    s1 = dist.sample((100,))
    assert s1.shape == (100, 2)
    s2 = dist.sample(sample_shape=(4, 3))
    assert s2.shape == (4, 3, 2)

    # NB: log-likelihood matches the reference, and sample moments are close
    # to mean mu and variance mu + mu^2 / theta.
    log_p_ref = log_nb_positive(x, mu, theta)
    dist = NegativeBinomial(mu=mu, theta=theta)
    log_p_nb = dist.log_prob(x)
    assert (log_p_ref - log_p_nb).abs().max().item() <= 1e-8

    s1 = dist.sample((1000,))
    assert s1.shape == (1000, 2)
    assert (s1.mean(0) - mu).abs().mean() <= 1e0
    assert (s1.std(0) - (mu + mu * mu / theta) ** 0.5).abs().mean() <= 1e0

    # log_prob keeps the full batch shape for 2-D parameters.
    size = (50, 3)
    theta = 100.0 + torch.rand(size=size)
    mu = 15.0 * torch.ones_like(theta)
    pi = torch.randn_like(theta)
    x = torch.randint_like(mu, high=20)
    dist1 = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=pi)
    dist2 = NegativeBinomial(mu=mu, theta=theta)
    assert dist1.log_prob(x).shape == size
    assert dist2.log_prob(x).shape == size

    # Invalid parameters raise; invalid observations warn.
    with pytest.raises(ValueError):
        ZeroInflatedNegativeBinomial(mu=-mu, theta=theta, zi_logits=pi)
    with pytest.warns(UserWarning):
        dist1.log_prob(-x)  # ensures neg values raise warning
    with pytest.warns(UserWarning):
        dist2.log_prob(0.5 * x)  # ensures float values raise warning