Example #1
    def sample(
        self,
        tensors,
        n_samples=1,
        library_size=1,
    ) -> torch.Tensor:
        r"""
        Generate observation samples from the posterior predictive distribution.

        The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`.

        Parameters
        ----------
        tensors
            Tensors dict
        n_samples
            Number of required samples for each cell
        library_size
            Library size to scale samples to

        Returns
        -------
        x_new : :py:class:`torch.Tensor`
            tensor with shape (n_cells, n_genes, n_samples)
        """
        inference_kwargs = dict(n_samples=n_samples)
        inference_outputs, generative_outputs = self.forward(
            tensors,
            inference_kwargs=inference_kwargs,
            compute_loss=False,
        )

        px_r = generative_outputs["px_r"]
        px_rate = generative_outputs["px_rate"]
        px_dropout = generative_outputs["px_dropout"]

        if self.gene_likelihood == "poisson":
            l_train = px_rate
            l_train = torch.clamp(l_train, max=1e8)
            dist = torch.distributions.Poisson(
                l_train)  # Shape : (n_samples, n_cells_batch, n_genes)
        elif self.gene_likelihood == "nb":
            dist = NegativeBinomial(mu=px_rate, theta=px_r)
        elif self.gene_likelihood == "zinb":
            dist = ZeroInflatedNegativeBinomial(mu=px_rate,
                                                theta=px_r,
                                                zi_logits=px_dropout)
        else:
            raise ValueError(
                "{} reconstruction error not handled right now".format(
                    self.gene_likelihood))
        if n_samples > 1:
            exprs = dist.sample().permute(
                [1, 2, 0])  # Shape : (n_cells_batch, n_genes, n_samples)
        else:
            exprs = dist.sample()

        return exprs.cpu()
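
The branch above just maps the configured `gene_likelihood` to a distribution object and draws counts from it. As a minimal standalone sketch of the same sampling call, assuming the distributions are importable from `scvi.distributions` and using made-up decoder outputs (the shapes and values below are illustrative, not taken from the example):

    import torch
    from scvi.distributions import NegativeBinomial

    # Hypothetical decoder outputs for 4 cells and 3 genes.
    px_rate = 10.0 * torch.rand(4, 3)   # NB mean per cell and gene
    px_r = 5.0 * torch.ones(3)          # inverse dispersion (theta), shared across cells

    dist = NegativeBinomial(mu=px_rate, theta=px_r)
    counts = dist.sample()              # count sample with shape (4, 3)
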
Example #2
    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
        kl_weight: float = 1.0,
    ):
        x = tensors[REGISTRY_KEYS.X_KEY]
        y = tensors[REGISTRY_KEYS.LABELS_KEY]
        qz_m = inference_outputs["qz_m"]
        qz_v = inference_outputs["qz_v"]
        px_rate = generative_outputs["px_rate"]
        px_r = generative_outputs["px_r"]

        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)

        reconst_loss = -NegativeBinomial(px_rate,
                                         logits=px_r).log_prob(x).sum(-1)
        scaling_factor = self.ct_weight[y.long()[:, 0]]
        loss = torch.mean(scaling_factor *
                          (reconst_loss + kl_weight * kl_divergence_z))

        return LossRecorder(loss, reconst_loss, kl_divergence_z,
                            torch.tensor(0.0))
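
The KL term in this loss is the closed-form KL divergence between the variational posterior N(qz_m, qz_v) and a standard normal prior, summed over latent dimensions. A small self-contained sketch of just that term, with hypothetical shapes:

    import torch
    from torch.distributions import Normal
    from torch.distributions import kl_divergence as kl

    qz_m = torch.zeros(8, 10)         # posterior means, (n_cells, n_latent)
    qz_v = 0.5 * torch.ones(8, 10)    # posterior variances

    prior = Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v))
    kl_z = kl(Normal(qz_m, torch.sqrt(qz_v)), prior).sum(dim=1)  # one value per cell
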
Example #3
    def get_reconstruction_loss(
        self,
        x: torch.Tensor,
        px_rate: torch.Tensor,
        px_r: torch.Tensor,
        px_dropout: torch.Tensor,
        bernoulli_params: torch.Tensor,
        eps_log: float = 1e-8,
        **kwargs,
    ) -> torch.Tensor:

        # LLs for NB and ZINB
        ll_zinb = torch.log(1.0 - bernoulli_params +
                            eps_log) + ZeroInflatedNegativeBinomial(
                                mu=px_rate, theta=px_r,
                                zi_logits=px_dropout).log_prob(x)
        ll_nb = torch.log(bernoulli_params + eps_log) + NegativeBinomial(
            mu=px_rate, theta=px_r).log_prob(x)

        # Reconstruction loss using a logsumexp-type computation
        ll_max = torch.max(ll_zinb, ll_nb)
        ll_tot = ll_max + torch.log(
            torch.exp(ll_nb - ll_max) + torch.exp(ll_zinb - ll_max))
        reconst_loss = -ll_tot.sum(dim=-1)

        return reconst_loss
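
The max-shift above is a hand-rolled two-term logsumexp over the NB and ZINB components. Recent PyTorch versions expose the same numerically stable operation as `torch.logaddexp`; a sketch with hypothetical log-likelihood values:

    import torch

    ll_nb = torch.tensor([-3.2, -7.5])
    ll_zinb = torch.tensor([-2.9, -8.1])

    # Stable log(exp(ll_nb) + exp(ll_zinb)) without an explicit max-shift.
    ll_tot = torch.logaddexp(ll_nb, ll_zinb)
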
Example #4
    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
        kl_weight: float = 1.0,
        n_obs: float = 1.0,
    ):
        x = tensors[REGISTRY_KEYS.X_KEY]
        px_rate = generative_outputs["px_rate"]
        px_o = generative_outputs["px_o"]
        gamma = generative_outputs["gamma"]

        reconst_loss = -NegativeBinomial(px_rate,
                                         logits=px_o).log_prob(x).sum(-1)

        # eta prior likelihood
        mean = torch.zeros_like(self.eta)
        scale = torch.ones_like(self.eta)
        glo_neg_log_likelihood_prior = -Normal(mean, scale).log_prob(
            self.eta).sum()
        glo_neg_log_likelihood_prior += torch.var(self.beta)

        # gamma prior likelihood
        if self.mean_vprior is None:
            # isotropic normal prior
            mean = torch.zeros_like(gamma)
            scale = torch.ones_like(gamma)
            neg_log_likelihood_prior = (
                -Normal(mean, scale).log_prob(gamma).sum(2).sum(1))
        else:
            # vampprior
            # gamma is of shape (minibatch_size, n_labels, n_latent)
            gamma = gamma.unsqueeze(1)  # minibatch_size, 1, n_labels, n_latent
            mean_vprior = torch.transpose(self.mean_vprior, 0, 1).unsqueeze(
                0)  # 1, p, n_labels, n_latent
            var_vprior = torch.transpose(self.var_vprior, 0, 1).unsqueeze(
                0)  # 1, p, n_labels, n_latent
            pre_lse = (Normal(mean_vprior,
                              torch.sqrt(var_vprior)).log_prob(gamma).sum(-1)
                       )  # minibatch, p, n_labels
            log_likelihood_prior = torch.logsumexp(pre_lse, 1) - np.log(
                self.p)  # minibatch, n_labels
            neg_log_likelihood_prior = -log_likelihood_prior.sum(
                1)  # minibatch
            # mean_vprior is of shape n_labels, p, n_latent

        loss = (
            n_obs *
            torch.mean(reconst_loss + kl_weight * neg_log_likelihood_prior) +
            glo_neg_log_likelihood_prior)

        return LossRecorder(loss, reconst_loss, neg_log_likelihood_prior,
                            glo_neg_log_likelihood_prior)
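
In the VampPrior branch, each cell's gamma is scored under every one of the `p` pseudo-input components and the results are averaged in log space. The mixture step in isolation, as a sketch with hypothetical shapes (2 cells, 5 components):

    import torch
    from torch.distributions import Normal

    p, n_labels, n_latent = 5, 3, 10
    gamma = torch.randn(2, 1, n_labels, n_latent)        # (minibatch, 1, n_labels, n_latent)
    mean_vprior = torch.randn(1, p, n_labels, n_latent)
    var_vprior = 0.1 + torch.rand(1, p, n_labels, n_latent)

    pre_lse = Normal(mean_vprior, torch.sqrt(var_vprior)).log_prob(gamma).sum(-1)
    # Equally weighted mixture over the p components, in log space.
    log_prior = torch.logsumexp(pre_lse, dim=1) - torch.log(torch.tensor(float(p)))
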
Example #5
 def get_reconstruction_loss_expression(self, x, px_rate, px_r, px_dropout):
     rl = 0.0
     if self.gene_likelihood == "zinb":
         rl = (-ZeroInflatedNegativeBinomial(
             mu=px_rate, theta=px_r,
             zi_logits=px_dropout).log_prob(x).sum(dim=-1))
     elif self.gene_likelihood == "nb":
         rl = -NegativeBinomial(mu=px_rate,
                                theta=px_r).log_prob(x).sum(dim=-1)
     elif self.gene_likelihood == "poisson":
         rl = -Poisson(px_rate).log_prob(x).sum(dim=-1)
     return rl
Example #6
 def get_reconstruction_loss(self, x, px_rate, px_r,
                             px_dropout) -> torch.Tensor:
     if self.gene_likelihood == "zinb":
         reconst_loss = (-ZeroInflatedNegativeBinomial(
             mu=px_rate, theta=px_r,
             zi_logits=px_dropout).log_prob(x).sum(dim=-1))
     elif self.gene_likelihood == "nb":
         reconst_loss = (-NegativeBinomial(
             mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1))
     elif self.gene_likelihood == "poisson":
         reconst_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1)
     return reconst_loss
Example #7
    def sample(
        self,
        tensors,
        n_samples=1,
    ) -> torch.Tensor:
        r"""
        Generate observation samples from the posterior predictive distribution.

        The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`.

        Parameters
        ----------
        tensors
            Tensors dict
        n_samples
            Number of required samples for each cell

        Returns
        -------
        x_new : :py:class:`torch.Tensor`
            tensor with shape (n_cells, n_genes, n_samples)
        """
        inference_kwargs = dict(n_samples=n_samples)
        generative_outputs = self.forward(
            tensors,
            inference_kwargs=inference_kwargs,
            compute_loss=False,
        )[1]

        px_r = generative_outputs["px_r"]
        px_rate = generative_outputs["px_rate"]

        dist = NegativeBinomial(px_rate, logits=px_r)
        if n_samples > 1:
            exprs = dist.sample().permute(
                [1, 2, 0])  # Shape : (n_cells_batch, n_genes, n_samples)
        else:
            exprs = dist.sample()

        return exprs.cpu()
Example #8
    def sample(self, tensors, n_samples=1):
        inference_kwargs = dict(n_samples=n_samples)
        with torch.no_grad():
            inference_outputs, generative_outputs = self.forward(
                tensors,
                inference_kwargs=inference_kwargs,
                compute_loss=False,
            )

        px_ = generative_outputs["px_"]
        py_ = generative_outputs["py_"]

        rna_dist = NegativeBinomial(mu=px_["rate"], theta=px_["r"])
        protein_dist = NegativeBinomialMixture(
            mu1=py_["rate_back"],
            mu2=py_["rate_fore"],
            theta1=py_["r"],
            mixture_logits=py_["mixing"],
        )
        rna_sample = rna_dist.sample().cpu()
        protein_sample = protein_dist.sample().cpu()

        return rna_sample, protein_sample
Example #9
    def get_reconstruction_loss(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        px_dict: Dict[str, torch.Tensor],
        py_dict: Dict[str, torch.Tensor],
        pro_batch_mask_minibatch: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute reconstruction loss."""
        px_ = px_dict
        py_ = py_dict
        # Reconstruction Loss
        if self.gene_likelihood == "zinb":
            reconst_loss_gene = (
                -ZeroInflatedNegativeBinomial(
                    mu=px_["rate"], theta=px_["r"], zi_logits=px_["dropout"]
                )
                .log_prob(x)
                .sum(dim=-1)
            )
        else:
            reconst_loss_gene = (
                -NegativeBinomial(mu=px_["rate"], theta=px_["r"])
                .log_prob(x)
                .sum(dim=-1)
            )

        py_conditional = NegativeBinomialMixture(
            mu1=py_["rate_back"],
            mu2=py_["rate_fore"],
            theta1=py_["r"],
            mixture_logits=py_["mixing"],
        )
        reconst_loss_protein_full = -py_conditional.log_prob(y)
        if pro_batch_mask_minibatch is not None:
            temp_pro_loss_full = torch.zeros_like(reconst_loss_protein_full)
            temp_pro_loss_full.masked_scatter_(
                pro_batch_mask_minibatch.bool(), reconst_loss_protein_full
            )

            reconst_loss_protein = temp_pro_loss_full.sum(dim=-1)
        else:
            reconst_loss_protein = reconst_loss_protein_full.sum(dim=-1)

        return reconst_loss_gene, reconst_loss_protein
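
The `masked_scatter_` call keeps protein reconstruction terms only where a protein was actually measured in the batch; unmeasured entries stay zero before the sum. A minimal illustration of that in-place call with hypothetical values (it fills the True positions of the mask, in order, from the source tensor):

    import torch

    loss_full = torch.tensor([[1.0, 2.0, 3.0],
                              [4.0, 5.0, 6.0]])
    mask = torch.tensor([[True, False, True],
                         [True, True, False]])     # measured proteins per cell

    masked = torch.zeros_like(loss_full)
    masked.masked_scatter_(mask, loss_full[mask])  # unmeasured entries remain 0
    per_cell = masked.sum(dim=-1)                  # tensor([4., 9.])
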
Example #10
 def reconstruction_loss(
     self,
     x: torch.Tensor,
     px_rate: torch.Tensor,
     px_r: torch.Tensor,
     px_dropout: torch.Tensor,
     mode: int,
 ) -> torch.Tensor:
     reconstruction_loss = None
     if self.gene_likelihoods[mode] == "zinb":
         reconstruction_loss = (-ZeroInflatedNegativeBinomial(
             mu=px_rate, theta=px_r,
             zi_logits=px_dropout).log_prob(x).sum(dim=-1))
     elif self.gene_likelihoods[mode] == "nb":
         reconstruction_loss = (-NegativeBinomial(
             mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1))
     elif self.gene_likelihoods[mode] == "poisson":
         reconstruction_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1)
     return reconstruction_loss
Example #11
def test_zinb_distribution():
    theta = 100.0 + torch.rand(size=(2, ))
    mu = 15.0 * torch.ones_like(theta)
    pi = torch.randn_like(theta)
    x = torch.randint_like(mu, high=20)
    log_p_ref = log_zinb_positive(x, mu, theta, pi)

    dist = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=pi)
    log_p_zinb = dist.log_prob(x)
    assert (log_p_ref - log_p_zinb).abs().max().item() <= 1e-8

    torch.manual_seed(0)
    s1 = dist.sample((100, ))
    assert s1.shape == (100, 2)
    s2 = dist.sample(sample_shape=(4, 3))
    assert s2.shape == (4, 3, 2)

    log_p_ref = log_nb_positive(x, mu, theta)
    dist = NegativeBinomial(mu=mu, theta=theta)
    log_p_nb = dist.log_prob(x)
    assert (log_p_ref - log_p_nb).abs().max().item() <= 1e-8

    s1 = dist.sample((1000, ))
    assert s1.shape == (1000, 2)
    assert (s1.mean(0) - mu).abs().mean() <= 1e0
    assert (s1.std(0) - (mu + mu * mu / theta)**0.5).abs().mean() <= 1e0

    size = (50, 3)
    theta = 100.0 + torch.rand(size=size)
    mu = 15.0 * torch.ones_like(theta)
    pi = torch.randn_like(theta)
    x = torch.randint_like(mu, high=20)
    dist1 = ZeroInflatedNegativeBinomial(mu=mu,
                                         theta=theta,
                                         zi_logits=pi,
                                         validate_args=True)
    dist2 = NegativeBinomial(mu=mu, theta=theta, validate_args=True)
    assert dist1.log_prob(x).shape == size
    assert dist2.log_prob(x).shape == size

    with pytest.raises(ValueError):
        ZeroInflatedNegativeBinomial(mu=-mu,
                                     theta=theta,
                                     zi_logits=pi,
                                     validate_args=True)
    with pytest.warns(UserWarning):
        dist1.log_prob(-x)  # ensures neg values raise warning
    with pytest.warns(UserWarning):
        dist2.log_prob(0.5 * x)  # ensures float values raise warning
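
For reference, the standard-deviation assertion above checks the negative binomial mean-variance relation in the (mu, theta) parameterization: :math:`\mathrm{Var}[X] = \mu + \mu^2 / \theta`, so the expected standard deviation is :math:`\sqrt{\mu + \mu^2 / \theta}`.
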
Example #12
    def generative(self, x, size_factor, design_matrix=None):
        # x has shape (n, g)
        delta = torch.exp(self.delta_log)  # (g, c)
        theta_log = F.log_softmax(self.theta_logit, dim=-1)  # (c)

        # compute mean of NegBin - shape (n_cells, n_genes, n_labels)
        n_cells = size_factor.shape[0]
        base_mean = torch.log(size_factor)  # (n, 1)
        base_mean = base_mean.unsqueeze(-1).expand(n_cells, self.n_genes,
                                                   self.n_labels)  # (n, g, c)

        # compute beta (covariate coefficient)
        # design_matrix has shape (n,p)
        if design_matrix is not None:
            covariates = torch.einsum("np,gp->gn", design_matrix,
                                      self.beta)  # (g, n)
            covariates = torch.transpose(covariates, 0,
                                         1).unsqueeze(-1)  # (n, g, 1)
            covariates = covariates.expand(n_cells, self.n_genes,
                                           self.n_labels)
            base_mean = base_mean + covariates

        # base gene expression
        b_g_0 = self.b_g_0.unsqueeze(-1).expand(n_cells, self.n_genes,
                                                self.n_labels)
        delta_rho = delta * self.rho
        delta_rho = delta_rho.expand(n_cells, self.n_genes,
                                     self.n_labels)  # (n, g, c)
        log_mu_ngc = base_mean + delta_rho + b_g_0
        mu_ngc = torch.exp(log_mu_ngc)  # (n, g, c)

        # compute phi of NegBin - shape (n_cells, n_genes, n_labels)
        a = torch.exp(self.log_a)  # (B)
        a = a.expand(n_cells, self.n_genes, self.n_labels, B)
        b_init = 2 * ((self.basis_means[1] - self.basis_means[0])**2)
        b = torch.exp(torch.ones(B, device=x.device) *
                      (-torch.log(b_init)))  # (B)
        b = b.expand(n_cells, self.n_genes, self.n_labels, B)
        mu_ngcb = mu_ngc.unsqueeze(-1).expand(n_cells, self.n_genes,
                                              self.n_labels, B)  # (n, g, c, B)
        basis_means = self.basis_means.expand(n_cells, self.n_genes,
                                              self.n_labels, B)  # (n, g, c, B)
        phi = (  # (n, g, c)
            torch.sum(a * torch.exp(-b * torch.square(mu_ngcb - basis_means)),
                      3) + LOWER_BOUND)

        # compute gamma
        nb_pdf = NegativeBinomial(mu=mu_ngc, theta=phi)
        x_ = x.unsqueeze(-1).expand(n_cells, self.n_genes, self.n_labels)
        x_log_prob_raw = nb_pdf.log_prob(x_)  # (n, g, c)
        theta_log = theta_log.expand(n_cells, self.n_labels)
        p_x_c = torch.sum(x_log_prob_raw, 1) + theta_log  # (n, c)
        normalizer_over_c = torch.logsumexp(p_x_c, 1)
        normalizer_over_c = normalizer_over_c.unsqueeze(-1).expand(
            n_cells, self.n_labels)
        gamma = torch.exp(p_x_c - normalizer_over_c)  # (n, c)

        return dict(
            mu=mu_ngc,
            phi=phi,
            gamma=gamma,
            p_x_c=p_x_c,
            s=size_factor,
        )
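
The final block turns the per-label log-probabilities `p_x_c` into assignment probabilities by subtracting their logsumexp over labels; the same normalization can be written as a softmax over labels. A sketch with hypothetical values:

    import torch

    p_x_c = torch.tensor([[-10.0, -12.0, -11.0],
                          [-20.0, -18.0, -19.0]])   # (n_cells, n_labels)

    gamma_manual = torch.exp(p_x_c - torch.logsumexp(p_x_c, dim=1, keepdim=True))
    gamma_softmax = torch.softmax(p_x_c, dim=1)
    assert torch.allclose(gamma_manual, gamma_softmax)
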