Example 1
    def get_reconstruction_loss(
        self,
        x: torch.Tensor,
        px_rate: torch.Tensor,
        px_r: torch.Tensor,
        px_dropout: torch.Tensor,
        bernoulli_params: torch.Tensor,
        eps_log: float = 1e-8,
        **kwargs,
    ) -> torch.Tensor:

        # Log-likelihoods of the two mixture components, each weighted by
        # its (log) Bernoulli mixture probability.
        ll_zinb = torch.log(1.0 - bernoulli_params + eps_log) + ZeroInflatedNegativeBinomial(
            mu=px_rate, theta=px_r, zi_logits=px_dropout
        ).log_prob(x)
        ll_nb = torch.log(bernoulli_params + eps_log) + NegativeBinomial(
            mu=px_rate, theta=px_r
        ).log_prob(x)

        # Combine the components with a numerically stable log-sum-exp.
        ll_max = torch.max(ll_zinb, ll_nb)
        ll_tot = ll_max + torch.log(
            torch.exp(ll_nb - ll_max) + torch.exp(ll_zinb - ll_max)
        )
        reconst_loss = -ll_tot.sum(dim=-1)

        return reconst_loss
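The max-shift above is the standard log-sum-exp trick; `torch.logsumexp` performs the same stabilized reduction. A minimal self-contained sketch with toy stand-ins for the decoder outputs (the names `x`, `px_rate`, etc. simply mirror the method's arguments; the two likelihoods come from `scvi.distributions`):

import torch
from scvi.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial

# Toy decoder outputs for 8 cells x 5 genes.
x = torch.randint(0, 20, (8, 5)).float()
px_rate = torch.rand(8, 5) * 10 + 1e-4
px_r = torch.full((5,), 2.0)
px_dropout = torch.randn(8, 5)
bernoulli_params = torch.sigmoid(torch.randn(8, 5))
eps_log = 1e-8

ll_zinb = torch.log(1.0 - bernoulli_params + eps_log) + ZeroInflatedNegativeBinomial(
    mu=px_rate, theta=px_r, zi_logits=px_dropout
).log_prob(x)
ll_nb = torch.log(bernoulli_params + eps_log) + NegativeBinomial(
    mu=px_rate, theta=px_r
).log_prob(x)

# torch.logsumexp applies the same max-shift internally.
ll_tot = torch.logsumexp(torch.stack([ll_zinb, ll_nb]), dim=0)
reconst_loss = -ll_tot.sum(dim=-1)  # shape (8,): one loss per cell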
Example 2
 def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout,
                             **kwargs) -> torch.Tensor:
     # Reconstruction Loss
     if self.gene_likelihood == "zinb":
         reconst_loss = (-ZeroInflatedNegativeBinomial(
             mu=px_rate, theta=px_r,
             zi_logits=px_dropout).log_prob(x).sum(dim=-1))
     elif self.gene_likelihood == "nb":
         reconst_loss = (-NegativeBinomial(
             mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1))
     elif self.gene_likelihood == "poisson":
         reconst_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1)
     return reconst_loss
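For reference, the `(mu, theta)` parameterization used above maps onto PyTorch's native `NegativeBinomial(total_count, probs)` via `total_count = theta` and `probs = mu / (mu + theta)`. A quick sanity-check sketch (not scvi's own test suite):

import torch
from scvi.distributions import NegativeBinomial

mu = torch.tensor([5.0, 20.0])
theta = torch.tensor([2.0, 10.0])
x = torch.tensor([3.0, 17.0])

ref = torch.distributions.NegativeBinomial(total_count=theta, probs=mu / (mu + theta))
ours = NegativeBinomial(mu=mu, theta=theta)
# The two log-probabilities should agree up to the small eps used internally.
assert torch.allclose(ours.log_prob(x), ref.log_prob(x), atol=1e-5)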
Example 3
 def reconstruction_loss(
     self,
     x: torch.Tensor,
     px_rate: torch.Tensor,
     px_r: torch.Tensor,
     px_dropout: torch.Tensor,
     mode: int,
 ) -> torch.Tensor:
     if self.gene_likelihoods[mode] == "zinb":
         reconstruction_loss = (-ZeroInflatedNegativeBinomial(
             mu=px_rate, theta=px_r,
             zi_logits=px_dropout).log_prob(x).sum(dim=-1))
     elif self.gene_likelihoods[mode] == "nb":
         reconstruction_loss = (-NegativeBinomial(
             mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1))
     elif self.gene_likelihoods[mode] == "poisson":
         reconstruction_loss = -Poisson(px_rate).log_prob(x).sum(dim=1)
     return reconstruction_loss
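Note that the Poisson branch originally summed over `dim=1`, which coincides with `dim=-1` only for 2-D input; on a 3-D `(n_samples, n_cells, n_genes)` tensor it would collapse the cell axis instead of the gene axis. A quick illustration:

import torch

ll = torch.zeros(4, 8, 5)    # (n_samples, n_cells, n_genes)
print(ll.sum(dim=1).shape)   # torch.Size([4, 5])  -- collapses cells (wrong axis)
print(ll.sum(dim=-1).shape)  # torch.Size([4, 8])  -- per-sample, per-cell loss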
Example 4
    def get_reconstruction_loss(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        px_dict: Dict[str, torch.Tensor],
        py_dict: Dict[str, torch.Tensor],
        pro_batch_mask_minibatch: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute reconstruction loss."""
        px_ = px_dict
        py_ = py_dict
        # Reconstruction Loss
        if self.gene_likelihood == "zinb":
            reconst_loss_gene = (-ZeroInflatedNegativeBinomial(
                mu=px_["rate"], theta=px_["r"],
                zi_logits=px_["dropout"]).log_prob(x).sum(dim=-1))
        else:
            reconst_loss_gene = (-NegativeBinomial(
                mu=px_["rate"], theta=px_["r"]).log_prob(x).sum(dim=-1))

        py_conditional = NegativeBinomialMixture(
            mu1=py_["rate_back"],
            mu2=py_["rate_fore"],
            theta1=py_["r"],
            mixture_logits=py_["mixing"],
        )
        reconst_loss_protein_full = -py_conditional.log_prob(y)
        if pro_batch_mask_minibatch is not None:
            # Keep the protein loss only where the protein was measured in
            # the cell's batch; unmeasured entries contribute zero.
            temp_pro_loss_full = torch.zeros_like(reconst_loss_protein_full)
            temp_pro_loss_full.masked_scatter_(pro_batch_mask_minibatch.bool(),
                                               reconst_loss_protein_full)
            reconst_loss_protein = temp_pro_loss_full.sum(dim=-1)
        else:
            reconst_loss_protein = reconst_loss_protein_full.sum(dim=-1)

        return reconst_loss_gene, reconst_loss_protein
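The protein term uses scvi's `NegativeBinomialMixture`, a two-component negative binomial (background vs. foreground rate) with learned mixture logits. A toy sketch of the per-entry log-probability it exposes (all tensors below are hypothetical stand-ins for the `py_dict` entries):

import torch
from scvi.distributions import NegativeBinomialMixture

y = torch.randint(0, 50, (4, 3)).float()                 # 4 cells x 3 proteins
rate_back = torch.rand(4, 3) * 5 + 0.1                   # background rate
rate_fore = rate_back * (1.0 + 4.0 * torch.rand(4, 3))   # foreground > background
r = torch.full((3,), 3.0)                                # dispersion
mixing = torch.randn(4, 3)                               # mixture logits

mix = NegativeBinomialMixture(mu1=rate_back, mu2=rate_fore,
                              theta1=r, mixture_logits=mixing)
log_p = mix.log_prob(y)            # shape (4, 3), one term per (cell, protein)
loss_protein = -log_p.sum(dim=-1)  # per-cell protein reconstruction loss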
Example 5
def test_zinb_distribution():
    theta = 100.0 + torch.rand(size=(2, ))
    mu = 15.0 * torch.ones_like(theta)
    pi = torch.randn_like(theta)
    x = torch.randint_like(mu, high=20)
    log_p_ref = log_zinb_positive(x, mu, theta, pi)

    dist = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=pi)
    log_p_zinb = dist.log_prob(x)
    assert (log_p_ref - log_p_zinb).abs().max().item() <= 1e-8

    torch.manual_seed(0)
    s1 = dist.sample((100, ))
    assert s1.shape == (100, 2)
    s2 = dist.sample(sample_shape=(4, 3))
    assert s2.shape == (4, 3, 2)

    log_p_ref = log_nb_positive(x, mu, theta)
    dist = NegativeBinomial(mu=mu, theta=theta)
    log_p_nb = dist.log_prob(x)
    assert (log_p_ref - log_p_nb).abs().max().item() <= 1e-8

    s1 = dist.sample((1000, ))
    assert s1.shape == (1000, 2)
    assert (s1.mean(0) - mu).abs().mean() <= 1e0
    assert (s1.std(0) - (mu + mu * mu / theta)**0.5).abs().mean() <= 1e0

    size = (50, 3)
    theta = 100.0 + torch.rand(size=size)
    mu = 15.0 * torch.ones_like(theta)
    pi = torch.randn_like(theta)
    x = torch.randint_like(mu, high=20)
    dist1 = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=pi)
    dist2 = NegativeBinomial(mu=mu, theta=theta)
    assert dist1.log_prob(x).shape == size
    assert dist2.log_prob(x).shape == size

    with pytest.raises(ValueError):
        ZeroInflatedNegativeBinomial(mu=-mu, theta=theta, zi_logits=pi)
    with pytest.warns(UserWarning):
        dist1.log_prob(-x)  # negative values should trigger a warning
    with pytest.warns(UserWarning):
        dist2.log_prob(0.5 * x)  # non-integer values should trigger a warning
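For reference, the quantity `log_zinb_positive` checks against can be written directly from the ZINB definition: with dropout probability `p = sigmoid(zi_logits)`, `P(x = 0) = p + (1 - p) * NB(0)` and `P(x = k) = (1 - p) * NB(k)` for `k > 0`. A minimal (less numerically hardened) sketch, not scvi's implementation:

import torch
from scvi.distributions import ZeroInflatedNegativeBinomial

def zinb_log_prob(x, mu, theta, zi_logits, eps=1e-8):
    """ZINB log-likelihood from the definition; a sketch for illustration."""
    p = torch.sigmoid(zi_logits)
    nb = torch.distributions.NegativeBinomial(total_count=theta,
                                              probs=mu / (mu + theta))
    zero_case = torch.log(
        p + (1.0 - p) * torch.exp(nb.log_prob(torch.zeros_like(x))) + eps)
    nonzero_case = torch.log(1.0 - p + eps) + nb.log_prob(x)
    return torch.where(x < eps, zero_case, nonzero_case)

mu = torch.tensor([5.0, 12.0])
theta = torch.tensor([3.0, 8.0])
zi_logits = torch.tensor([-1.0, 0.5])
x = torch.tensor([0.0, 7.0])
dist = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=zi_logits)
# Should agree with scvi's stabilized form up to eps-level differences.
assert torch.allclose(zinb_log_prob(x, mu, theta, zi_logits),
                      dist.log_prob(x), atol=1e-4)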
Example 6
    def posterior_predictive_sample(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        n_samples: int = 1,
        gene_list: Optional[Sequence[str]] = None,
        batch_size: Optional[int] = None,
    ) -> np.ndarray:
        r"""
        Generate observation samples from the posterior predictive distribution.

        The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        n_samples
            Number of samples for each cell.
        gene_list
            Names of genes of interest.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns
        -------
        x_new : :py:class:`numpy.ndarray`
            Array with shape (n_cells, n_genes, n_samples); the sample
            dimension is omitted when ``n_samples == 1``.
        """
        if self.model.gene_likelihood not in ["zinb", "nb", "poisson"]:
            raise ValueError("Invalid gene_likelihood.")

        adata = self._validate_anndata(adata)
        scdl = self._make_scvi_dl(adata=adata,
                                  indices=indices,
                                  batch_size=batch_size)

        if indices is None:
            indices = np.arange(adata.n_obs)

        if gene_list is None:
            gene_mask = slice(None)
        else:
            all_genes = _get_var_names_from_setup_anndata(adata)
            gene_mask = [gene in gene_list for gene in all_genes]

        x_new = []
        for tensors in scdl:
            x = tensors[_CONSTANTS.X_KEY]
            batch_idx = tensors[_CONSTANTS.BATCH_KEY]
            labels = tensors[_CONSTANTS.LABELS_KEY]
            outputs = self.model.inference(x,
                                           batch_index=batch_idx,
                                           y=labels,
                                           n_samples=n_samples)
            px_r = outputs["px_r"]
            px_rate = outputs["px_rate"]
            px_dropout = outputs["px_dropout"]

            if self.model.gene_likelihood == "poisson":
                l_train = px_rate
                l_train = torch.clamp(l_train, max=1e8)
                dist = torch.distributions.Poisson(
                    l_train)  # Shape : (n_samples, n_cells_batch, n_genes)
            elif self.model.gene_likelihood == "nb":
                dist = NegativeBinomial(mu=px_rate, theta=px_r)
            elif self.model.gene_likelihood == "zinb":
                dist = ZeroInflatedNegativeBinomial(mu=px_rate,
                                                    theta=px_r,
                                                    zi_logits=px_dropout)
            else:
                raise ValueError(
                    f"gene_likelihood {self.model.gene_likelihood!r} is not handled"
                )
            if n_samples > 1:
                exprs = dist.sample().permute(
                    [1, 2, 0])  # Shape : (n_cells_batch, n_genes, n_samples)
            else:
                exprs = dist.sample()

            if gene_list is not None:
                exprs = exprs[:, gene_mask, ...]

            x_new.append(exprs.cpu())
        x_new = torch.cat(x_new)  # (n_cells, n_genes, n_samples) when n_samples > 1

        return x_new.numpy()
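A minimal usage sketch with hypothetical toy data. The `setup_anndata` call below follows newer scvi-tools releases; the snippet above targets an older API generation where setup lived in `scvi.data`, so the exact call may differ by version:

import numpy as np
from anndata import AnnData
import scvi

# Hypothetical toy dataset: 100 cells x 50 genes of counts.
adata = AnnData(np.random.poisson(2.0, size=(100, 50)).astype(np.float32))
scvi.model.SCVI.setup_anndata(adata)

model = scvi.model.SCVI(adata, gene_likelihood="zinb")
model.train(max_epochs=5)

# Three posterior predictive draws per cell; per the docstring above, the
# result has shape (n_cells, n_genes, n_samples) when n_samples > 1.
samples = model.posterior_predictive_sample(n_samples=3)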