def get_reconstruction_loss(
    self,
    x: torch.Tensor,
    px_rate: torch.Tensor,
    px_r: torch.Tensor,
    px_dropout: torch.Tensor,
    bernoulli_params: torch.Tensor,
    eps_log: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    """
    Negative log-likelihood of ``x`` under a two-component NB / ZINB mixture.

    Each entry is modelled as negative binomial with weight
    ``bernoulli_params`` and zero-inflated negative binomial with weight
    ``1 - bernoulli_params``.

    Parameters
    ----------
    x
        Observed values passed to ``log_prob`` of both distributions.
    px_rate
        Mean (``mu``) shared by both components.
    px_r
        Inverse dispersion (``theta``) shared by both components.
    px_dropout
        Zero-inflation logits (``zi_logits``) of the ZINB component.
    bernoulli_params
        Mixture weight of the NB component.
    eps_log
        Constant added inside the logarithm of each mixture weight so that a
        weight of exactly 0 does not produce ``-inf``.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    Negative mixture log-likelihood summed over the last dimension.
    """
    zinb = ZeroInflatedNegativeBinomial(
        mu=px_rate, theta=px_r, zi_logits=px_dropout
    )
    nb = NegativeBinomial(mu=px_rate, theta=px_r)

    # Weighted component log-likelihoods.
    ll_zinb = torch.log(1.0 - bernoulli_params + eps_log) + zinb.log_prob(x)
    ll_nb = torch.log(bernoulli_params + eps_log) + nb.log_prob(x)

    # log(exp(ll_zinb) + exp(ll_nb)). torch.logsumexp is used instead of a
    # manual max-shift so that entries where both terms are -inf yield -inf
    # rather than NaN (-inf - (-inf)).
    ll_tot = torch.logsumexp(torch.stack((ll_zinb, ll_nb)), dim=0)
    return -ll_tot.sum(dim=-1)
def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout, **kwargs) -> torch.Tensor:
    """
    Negative log-likelihood of ``x`` under the configured gene likelihood.

    Parameters
    ----------
    x
        Observed values passed to ``log_prob``.
    px_rate
        Mean of the distribution (``mu`` for NB / ZINB, rate for Poisson).
    px_r
        Inverse dispersion (``theta``); used by ``"nb"`` and ``"zinb"``.
    px_dropout
        Zero-inflation logits; used by ``"zinb"`` only.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    Negative log-likelihood summed over the last dimension.

    Raises
    ------
    ValueError
        If ``self.gene_likelihood`` is not ``"zinb"``, ``"nb"`` or
        ``"poisson"``.
    """
    likelihood = self.gene_likelihood
    if likelihood == "zinb":
        dist = ZeroInflatedNegativeBinomial(
            mu=px_rate, theta=px_r, zi_logits=px_dropout
        )
    elif likelihood == "nb":
        dist = NegativeBinomial(mu=px_rate, theta=px_r)
    elif likelihood == "poisson":
        dist = Poisson(px_rate)
    else:
        # Fail with a clear message instead of an UnboundLocalError at return.
        raise ValueError(
            "{} reconstruction error not handled right now".format(likelihood)
        )
    return -dist.log_prob(x).sum(dim=-1)
def reconstruction_loss(
    self,
    x: torch.Tensor,
    px_rate: torch.Tensor,
    px_r: torch.Tensor,
    px_dropout: torch.Tensor,
    mode: int,
) -> torch.Tensor:
    """
    Negative log-likelihood of ``x`` under the likelihood configured for ``mode``.

    Parameters
    ----------
    x
        Observed values passed to ``log_prob``.
    px_rate
        Mean of the distribution (``mu`` for NB / ZINB, rate for Poisson).
    px_r
        Inverse dispersion (``theta``); used by ``"nb"`` and ``"zinb"``.
    px_dropout
        Zero-inflation logits; used by ``"zinb"`` only.
    mode
        Index into ``self.gene_likelihoods`` selecting the likelihood.

    Returns
    -------
    Negative log-likelihood summed over the last dimension, or ``None`` when
    ``self.gene_likelihoods[mode]`` is not ``"zinb"``, ``"nb"`` or
    ``"poisson"``.
    """
    likelihood = self.gene_likelihoods[mode]
    if likelihood == "zinb":
        dist = ZeroInflatedNegativeBinomial(
            mu=px_rate, theta=px_r, zi_logits=px_dropout
        )
    elif likelihood == "nb":
        dist = NegativeBinomial(mu=px_rate, theta=px_r)
    elif likelihood == "poisson":
        dist = Poisson(px_rate)
    else:
        # Unknown likelihoods yield None, as before.
        return None
    # Every likelihood reduces over the last axis. The Poisson branch used to
    # reduce over dim=1, which only matches for 2-D inputs.
    return -dist.log_prob(x).sum(dim=-1)
def get_reconstruction_loss(
    self,
    x: torch.Tensor,
    y: torch.Tensor,
    px_dict: Dict[str, torch.Tensor],
    py_dict: Dict[str, torch.Tensor],
    pro_batch_mask_minibatch: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the gene and protein reconstruction losses.

    Parameters
    ----------
    x
        Observed gene values.
    y
        Observed protein values.
    px_dict
        Gene distribution parameters; keys ``"rate"`` and ``"r"`` are read,
        plus ``"dropout"`` when ``self.gene_likelihood == "zinb"``.
    py_dict
        Protein mixture parameters; keys ``"rate_back"``, ``"rate_fore"``,
        ``"r"`` and ``"mixing"`` are read.
    pro_batch_mask_minibatch
        Optional mask, converted with ``.bool()``. Protein entries where it
        is false contribute zero loss. If ``None``, every entry contributes.

    Returns
    -------
    Tuple ``(reconst_loss_gene, reconst_loss_protein)``, each summed over the
    last dimension.
    """
    # Any likelihood other than "zinb" falls back to a negative binomial.
    if self.gene_likelihood == "zinb":
        gene_dist = ZeroInflatedNegativeBinomial(
            mu=px_dict["rate"],
            theta=px_dict["r"],
            zi_logits=px_dict["dropout"],
        )
    else:
        gene_dist = NegativeBinomial(mu=px_dict["rate"], theta=px_dict["r"])
    reconst_loss_gene = -gene_dist.log_prob(x).sum(dim=-1)

    py_conditional = NegativeBinomialMixture(
        mu1=py_dict["rate_back"],
        mu2=py_dict["rate_fore"],
        theta1=py_dict["r"],
        mixture_logits=py_dict["mixing"],
    )
    reconst_loss_protein_full = -py_conditional.log_prob(y)

    if pro_batch_mask_minibatch is not None:
        # Zero masked-out entries position by position. masked_scatter_ must
        # not be used here: it fills the True positions with the *first*
        # elements of the source in order, which misaligns the loss as soon
        # as the mask has a False entry.
        reconst_loss_protein_full = torch.where(
            pro_batch_mask_minibatch.bool(),
            reconst_loss_protein_full,
            torch.zeros_like(reconst_loss_protein_full),
        )
    reconst_loss_protein = reconst_loss_protein_full.sum(dim=-1)

    return reconst_loss_gene, reconst_loss_protein
def test_zinb_distribution():
    """Check NB / ZINB ``log_prob`` against the reference functions and sanity-check sampling."""
    # Seed before any random input is drawn so that the whole test, not just
    # the sampling part, is reproducible.
    torch.manual_seed(0)

    theta = 100.0 + torch.rand(size=(2,))
    mu = 15.0 * torch.ones_like(theta)
    pi = torch.randn_like(theta)
    x = torch.randint_like(mu, high=20)

    # ZINB log-probability must match the reference implementation.
    log_p_ref = log_zinb_positive(x, mu, theta, pi)
    dist = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=pi)
    log_p_zinb = dist.log_prob(x)
    assert (log_p_ref - log_p_zinb).abs().max().item() <= 1e-8

    # The sample shape is prepended to the batch shape.
    s1 = dist.sample((100,))
    assert s1.shape == (100, 2)
    s2 = dist.sample(sample_shape=(4, 3))
    assert s2.shape == (4, 3, 2)

    # NB log-probability must match the reference implementation.
    log_p_ref = log_nb_positive(x, mu, theta)
    dist = NegativeBinomial(mu=mu, theta=theta)
    log_p_nb = dist.log_prob(x)
    assert (log_p_ref - log_p_nb).abs().max().item() <= 1e-8

    # Empirical moments of NB samples: mean mu, variance mu + mu^2 / theta.
    s1 = dist.sample((1000,))
    assert s1.shape == (1000, 2)
    assert (s1.mean(0) - mu).abs().mean() <= 1e0
    assert (s1.std(0) - (mu + mu * mu / theta) ** 0.5).abs().mean() <= 1e0

    # log_prob keeps the shape of 2-D parameters.
    size = (50, 3)
    theta = 100.0 + torch.rand(size=size)
    mu = 15.0 * torch.ones_like(theta)
    pi = torch.randn_like(theta)
    x = torch.randint_like(mu, high=20)
    dist1 = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=pi)
    dist2 = NegativeBinomial(mu=mu, theta=theta)
    assert dist1.log_prob(x).shape == size
    assert dist2.log_prob(x).shape == size

    # Invalid parameters raise; invalid observations only warn.
    with pytest.raises(ValueError):
        ZeroInflatedNegativeBinomial(mu=-mu, theta=theta, zi_logits=pi)
    with pytest.warns(UserWarning):
        dist1.log_prob(-x)  # ensures neg values raise warning
    with pytest.warns(UserWarning):
        dist2.log_prob(0.5 * x)  # ensures float values raise warning
def posterior_predictive_sample(
    self,
    adata: Optional[AnnData] = None,
    indices: Optional[Sequence[int]] = None,
    n_samples: int = 1,
    gene_list: Optional[Sequence[str]] = None,
    batch_size: Optional[int] = None,
) -> np.ndarray:
    r"""
    Generate observation samples from the posterior predictive distribution.

    The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`.

    Parameters
    ----------
    adata
        AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
        AnnData object used to initialize the model.
    indices
        Indices of cells in adata to use. If `None`, all cells are used.
    n_samples
        Number of samples for each cell.
    gene_list
        Names of genes of interest. If `None`, all genes are returned.
    batch_size
        Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

    Returns
    -------
    x_new : :py:class:`numpy.ndarray`
        Sampled values. When ``n_samples > 1`` each minibatch is permuted to
        (n_cells, n_genes, n_samples); otherwise the array has whatever shape
        a single draw from the distribution has.

    Raises
    ------
    ValueError
        If ``self.model.gene_likelihood`` is not ``"zinb"``, ``"nb"`` or
        ``"poisson"``.
    """
    if self.model.gene_likelihood not in ["zinb", "nb", "poisson"]:
        raise ValueError("Invalid gene_likelihood.")

    adata = self._validate_anndata(adata)
    scdl = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)

    if gene_list is None:
        gene_mask = slice(None)
    else:
        all_genes = _get_var_names_from_setup_anndata(adata)
        # Set lookup keeps mask construction linear in the number of genes.
        wanted_genes = set(gene_list)
        gene_mask = [gene in wanted_genes for gene in all_genes]

    x_new = []
    for tensors in scdl:
        x = tensors[_CONSTANTS.X_KEY]
        batch_idx = tensors[_CONSTANTS.BATCH_KEY]
        labels = tensors[_CONSTANTS.LABELS_KEY]
        outputs = self.model.inference(
            x, batch_index=batch_idx, y=labels, n_samples=n_samples
        )
        px_r = outputs["px_r"]
        px_rate = outputs["px_rate"]
        px_dropout = outputs["px_dropout"]

        if self.model.gene_likelihood == "poisson":
            # Cap the rate at 1e8 before building the Poisson distribution.
            dist = torch.distributions.Poisson(torch.clamp(px_rate, max=1e8))
        elif self.model.gene_likelihood == "nb":
            dist = NegativeBinomial(mu=px_rate, theta=px_r)
        elif self.model.gene_likelihood == "zinb":
            dist = ZeroInflatedNegativeBinomial(
                mu=px_rate, theta=px_r, zi_logits=px_dropout
            )
        else:
            raise ValueError(
                "{} reconstruction error not handled right now".format(
                    self.model.gene_likelihood
                )
            )

        if n_samples > 1:
            # (n_samples, n_cells_batch, n_genes) -> (n_cells_batch, n_genes, n_samples)
            exprs = dist.sample().permute([1, 2, 0])
        else:
            exprs = dist.sample()

        if gene_list is not None:
            exprs = exprs[:, gene_mask, ...]

        x_new.append(exprs.cpu())

    # Concatenate minibatches along the cell axis.
    x_new = torch.cat(x_new)
    return x_new.numpy()