Code example #1
def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout):
    # Reconstruction Loss
    if self.reconstruction_loss == "zinb":
        reconst_loss = -log_zinb_positive(x, px_rate, px_r, px_dropout).sum(dim=-1)
    elif self.reconstruction_loss == "nb":
        reconst_loss = -log_nb_positive(x, px_rate, px_r).sum(dim=-1)
    return reconst_loss
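
Every snippet on this page delegates to log_zinb_positive / log_nb_positive. For orientation, here is a minimal sketch of the quantity being computed, assuming scVI's usual (mu, theta, logits) parameterization; the function name and exact numerics are illustrative, not scVI's implementation.

import torch
import torch.nn.functional as F

def log_zinb_positive_sketch(x, mu, theta, pi, eps=1e-8):
    # Element-wise ZINB log-likelihood: mu is the NB mean, theta the
    # inverse dispersion, pi the zero-inflation *logits*.
    log_theta_mu_eps = torch.log(theta + mu + eps)
    # NB log-pmf in the (mu, theta) parameterization
    log_nb = (
        torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1.0)
        + theta * (torch.log(theta + eps) - log_theta_mu_eps)
        + x * (torch.log(mu + eps) - log_theta_mu_eps)
    )
    # At x == 0 the NB log-pmf reduces to theta * log(theta / (theta + mu))
    log_nb_zero = theta * (torch.log(theta + eps) - log_theta_mu_eps)
    # x == 0: log(sigmoid(pi) + (1 - sigmoid(pi)) * NB(0))
    #       = logaddexp(pi, log NB(0)) - softplus(pi)
    case_zero = torch.logaddexp(pi, log_nb_zero) - F.softplus(pi)
    # x > 0:  log(1 - sigmoid(pi)) + log NB(x)
    case_nonzero = -F.softplus(pi) + log_nb
    return torch.where(x < eps, case_zero, case_nonzero)
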
Code example #2
File: vae.py Project: shunsunsun/HarmonizationSCANVI
def _reconstruction_loss(self, x, px_rate, px_r, px_dropout):
    # Reconstruction Loss
    if self.reconstruction_loss == 'zinb':
        reconst_loss = -log_zinb_positive(x, px_rate, px_r, px_dropout)
    elif self.reconstruction_loss == 'nb':
        reconst_loss = -log_nb_positive(x, px_rate, px_r)
    return reconst_loss
Code example #3
File: vae_fish.py Project: Edouard360/scVI
    def _reconstruction_loss(self,
                             x,
                             px_rate,
                             px_r,
                             px_dropout,
                             batch_index,
                             y,
                             mode="scRNA",
                             weighting=1):
        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r

        # Reconstruction Loss
        if mode == "scRNA":
            if self.reconstruction_loss == 'zinb':
                reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r),
                                                  px_dropout)
            elif self.reconstruction_loss == 'nb':
                reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))

        else:
            if self.reconstruction_loss_fish == 'poisson':
                reconst_loss = -torch.sum(Poisson(px_rate).log_prob(x), dim=1)
            elif self.reconstruction_loss_fish == 'gaussian':
                reconst_loss = -torch.sum(Normal(px_rate, 10).log_prob(x),
                                          dim=1)
        return reconst_loss
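
The dispersion branches above select a per-label (or per-batch) dispersion row with a one-hot matmul. A small sketch of that selection, with hypothetical shapes and torch.nn.functional.one_hot standing in for scVI's one_hot helper:

import torch
import torch.nn.functional as F

n_labels, n_genes, batch = 4, 10, 8
px_r_param = torch.randn(n_genes, n_labels)   # F.linear weight: (out, in)
y = torch.randint(0, n_labels, (batch,))
one_hot_y = F.one_hot(y, n_labels).float()    # (batch, n_labels)
px_r = F.linear(one_hot_y, px_r_param)        # (batch, n_genes)
# Row i of px_r is column y[i] of px_r_param, i.e. that label's per-gene
# dispersion; hence the "last dimension is nb genes" comment above.
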
Code example #4
    def get_reconstruction_loss(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        px_: Dict[str, torch.Tensor],
        py_: Dict[str, torch.Tensor],
        pro_batch_mask_minibatch: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute reconstruction loss
        """
        # Reconstruction Loss
        if self.reconstruction_loss_gene == "zinb":
            reconst_loss_gene = -log_zinb_positive(x, px_["rate"], px_["r"],
                                                   px_["dropout"]).sum(dim=-1)
        else:
            reconst_loss_gene = -log_nb_positive(x, px_["rate"],
                                                 px_["r"]).sum(dim=-1)

        reconst_loss_protein_full = -log_mixture_nb(
            y, py_["rate_back"], py_["rate_fore"], py_["r"], None,
            py_["mixing"])
        if pro_batch_mask_minibatch is not None:
            temp_pro_loss_full = torch.zeros_like(reconst_loss_protein_full)
            temp_pro_loss_full.masked_scatter_(pro_batch_mask_minibatch.bool(),
                                               reconst_loss_protein_full)

            reconst_loss_protein = temp_pro_loss_full.sum(dim=-1)
        else:
            reconst_loss_protein = reconst_loss_protein_full.sum(dim=-1)

        return reconst_loss_gene, reconst_loss_protein
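
The masked_scatter_ step above zeroes out reconstruction terms for proteins that were not measured in a given batch before summing. A minimal sketch of the same zero-out intent, assuming a 0/1 mask with the same shape as the per-protein loss; this is an illustration, not the totalVI source:

import torch

reconst_loss_protein_full = torch.rand(8, 12)   # hypothetical (batch, n_proteins)
pro_batch_mask_minibatch = torch.randint(0, 2, (8, 12)).float()
reconst_loss_protein = (pro_batch_mask_minibatch
                        * reconst_loss_protein_full).sum(dim=-1)
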
Code example #5
File: autozivae.py Project: maichmueller/scVI_ma
    def get_reconstruction_loss(
        self,
        x: torch.Tensor,
        px_rate: torch.Tensor,
        px_r: torch.Tensor,
        px_dropout: torch.Tensor,
        bernoulli_params: torch.Tensor,
        eps_log: float = 1e-8,
        **kwargs,
    ) -> torch.Tensor:

        # LLs for NB and ZINB
        ll_zinb = torch.log(1.0 - bernoulli_params + eps_log) + log_zinb_positive(
            x, px_rate, px_r, px_dropout
        )
        ll_nb = torch.log(bernoulli_params + eps_log) + log_nb_positive(
            x, px_rate, px_r
        )

        # Reconstruction loss using a logsumexp-type computation
        ll_max = torch.max(ll_zinb, ll_nb)
        ll_tot = ll_max + torch.log(
            torch.exp(ll_nb - ll_max) + torch.exp(ll_zinb - ll_max)
        )
        reconst_loss = -ll_tot.sum(dim=-1)

        return reconst_loss
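
The max-shift above is the logsumexp trick written out by hand; an equivalent formulation (a sketch, not the AutoZI source) stacks the two log-likelihoods and lets torch.logsumexp do the shift:

import torch

ll_zinb = torch.randn(8, 100)   # placeholders for the tensors computed above
ll_nb = torch.randn(8, 100)
ll_tot = torch.logsumexp(torch.stack([ll_zinb, ll_nb], dim=0), dim=0)
reconst_loss = -ll_tot.sum(dim=-1)
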
Code example #6
    def get_reconstruction_loss(self,
                                x,
                                px_rate,
                                px_r,
                                px_dropout,
                                mode="scRNA",
                                weighting=1):

        # Reconstruction Loss
        if mode == "scRNA":
            if self.reconstruction_loss == 'zinb':
                reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r),
                                                  px_dropout)
            elif self.reconstruction_loss == 'nb':
                reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))

        else:
            if self.reconstruction_loss_fish == 'poisson':
                reconst_loss = -torch.sum(Poisson(px_rate).log_prob(
                    x[:, self.indexes_to_keep]),
                                          dim=1)
            elif self.reconstruction_loss_fish == 'gaussian':
                reconst_loss = -torch.sum(Normal(px_rate, 10).log_prob(
                    x[:, self.indexes_to_keep]),
                                          dim=1)
        return reconst_loss
Code example #7
def get_reconstruction_atac_loss(self, x, mu, dispersion, dropout, type="zinb", **kwargs):
    if type == "zinb":
        reconst_loss = -log_zinb_positive(x, mu, dispersion, dropout).sum(dim=-1)
    elif type == "zip":
        reconst_loss = -log_zip_positive(x, mu, dropout).sum(dim=-1)
    else:
        reconst_loss = -binary_cross_entropy(x, mu).sum(dim=-1)
    return reconst_loss
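
log_zip_positive itself is not shown in these snippets; a minimal sketch of a zero-inflated Poisson log-likelihood, assuming dropout holds zero-inflation logits as in the ZINB case (the actual signature in that project may differ):

import torch
import torch.nn.functional as F

def log_zip_positive_sketch(x, mu, pi, eps=1e-8):
    # log Poisson(x | mu) = x * log(mu) - mu - lgamma(x + 1)
    log_pois = x * torch.log(mu + eps) - mu - torch.lgamma(x + 1.0)
    # x == 0: log(sigmoid(pi) + (1 - sigmoid(pi)) * exp(-mu))
    case_zero = torch.logaddexp(pi, -mu) - F.softplus(pi)
    # x > 0:  log(1 - sigmoid(pi)) + log Poisson(x)
    case_nonzero = -F.softplus(pi) + log_pois
    return torch.where(x < eps, case_zero, case_nonzero)
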
Code example #8
def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout, **kwargs):
    # Reconstruction Loss
    if self.reconstruction_loss == "zinb":
        reconst_loss = -log_zinb_positive(x, px_rate, px_r, px_dropout).sum(dim=-1)
    elif self.reconstruction_loss == "nb":
        reconst_loss = -log_nb_positive(x, px_rate, px_r).sum(dim=-1)
    elif self.reconstruction_loss == "poisson":
        reconst_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1)
    return reconst_loss
Code example #9
File: distributions.py Project: yynst2/scVI
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
    try:
        self._validate_sample(value)
    except ValueError:
        warnings.warn(
            "The value argument must be within the support of the distribution",
            UserWarning,
        )
    return log_zinb_positive(value,
                             self.mu,
                             self.theta,
                             self.zi_logits,
                             eps=1e-08)
Code example #10
File: vae.py Project: zhuy16/scVI
def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout,
                            **kwargs) -> torch.Tensor:
    """Return the reconstruction loss (for a minibatch)."""
    # Reconstruction Loss
    if self.reconstruction_loss == "zinb":
        reconst_loss = -log_zinb_positive(x, px_rate, px_r,
                                          px_dropout).sum(dim=-1)
    elif self.reconstruction_loss == "nb":
        reconst_loss = -log_nb_positive(x, px_rate, px_r).sum(dim=-1)
    elif self.reconstruction_loss == "poisson":
        reconst_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1)
    return reconst_loss
Code example #11
File: jvae.py Project: zihao12/scVI
def reconstruction_loss(
    self,
    x: torch.Tensor,
    px_rate: torch.Tensor,
    px_r: torch.Tensor,
    px_dropout: torch.Tensor,
    mode: int,
) -> torch.Tensor:
    reconstruction_loss = None
    if self.reconstruction_losses[mode] == "zinb":
        reconstruction_loss = -log_zinb_positive(x, px_rate, px_r,
                                                 px_dropout)
    elif self.reconstruction_losses[mode] == "nb":
        reconstruction_loss = -log_nb_positive(x, px_rate, px_r)
    elif self.reconstruction_losses[mode] == "poisson":
        reconstruction_loss = -torch.sum(Poisson(px_rate).log_prob(x),
                                         dim=1)
    return reconstruction_loss
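
Note that the reduction convention varies across these projects: some functions sum over genes internally with .sum(dim=-1), while others return the element-wise log-likelihood and leave the reduction to the caller (this jvae.py variant mixes the two, summing only in the poisson branch). A caller-side reduction looks like:

import torch

reconst_loss_elementwise = torch.rand(8, 100)        # hypothetical per-gene losses
reconst_loss = reconst_loss_elementwise.sum(dim=-1)  # reduce over genes at the call site
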
Code example #12
    def _reconstruction_loss(self, x, px_rate, px_r, px_dropout, batch_index,
                             y):
        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r

        # Reconstruction Loss
        if self.reconstruction_loss == 'zinb':
            reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r),
                                              px_dropout)
        elif self.reconstruction_loss == 'nb':
            reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))
        return reconst_loss
Code example #13
File: test_scvi.py Project: shaoxin0801/scVI
def test_zinb_distribution():
    theta = 100.0 + torch.rand(size=(2, ))
    mu = 15.0 * torch.ones_like(theta)
    pi = torch.randn_like(theta)
    x = torch.randint_like(mu, high=20)
    log_p_ref = log_zinb_positive(x, mu, theta, pi)

    dist = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=pi)
    log_p_zinb = dist.log_prob(x)
    assert (log_p_ref - log_p_zinb).abs().max().item() <= 1e-8

    torch.manual_seed(0)
    s1 = dist.sample((100, ))
    assert s1.shape == (100, 2)
    s2 = dist.sample(sample_shape=(4, 3))
    assert s2.shape == (4, 3, 2)

    log_p_ref = log_nb_positive(x, mu, theta)
    dist = NegativeBinomial(mu=mu, theta=theta)
    log_p_nb = dist.log_prob(x)
    assert (log_p_ref - log_p_nb).abs().max().item() <= 1e-8

    s1 = dist.sample((1000, ))
    assert s1.shape == (1000, 2)
    assert (s1.mean(0) - mu).abs().mean() <= 1e0
    assert (s1.std(0) - (mu + mu * mu / theta)**0.5).abs().mean() <= 1e0

    size = (50, 3)
    theta = 100.0 + torch.rand(size=size)
    mu = 15.0 * torch.ones_like(theta)
    pi = torch.randn_like(theta)
    x = torch.randint_like(mu, high=20)
    dist1 = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=pi)
    dist2 = NegativeBinomial(mu=mu, theta=theta)
    assert dist1.log_prob(x).shape == size
    assert dist2.log_prob(x).shape == size

    with pytest.raises(ValueError):
        ZeroInflatedNegativeBinomial(mu=-mu, theta=theta, zi_logits=pi)
    with pytest.warns(UserWarning):
        dist1.log_prob(-x)  # ensures neg values raise warning
    with pytest.warns(UserWarning):
        dist2.log_prob(0.5 * x)  # ensures float values raise warning
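
The sampling assertions above rely on the moments of NB(mu, theta): mean mu and variance mu + mu**2 / theta, so for mu = 15 and theta ≈ 100 the expected standard deviation is about 4.15. A quick check:

import torch

mu, theta = torch.tensor(15.0), torch.tensor(100.0)
std = torch.sqrt(mu + mu * mu / theta)   # ≈ 4.15
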
Code example #14
File: totalvi.py Project: maichmueller/scVI_ma
    def get_reconstruction_loss(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        px_: Dict[str, torch.Tensor],
        py_: Dict[str, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Reconstruction Loss
        if self.reconstruction_loss_gene == "zinb":
            reconst_loss_gene = -log_zinb_positive(
                x, px_["rate"], px_["r"], px_["dropout"]
            ).sum(dim=-1)
        else:
            reconst_loss_gene = -log_nb_positive(x, px_["rate"], px_["r"]).sum(dim=-1)

        reconst_loss_protein = -log_mixture_nb(
            y, py_["rate_back"], py_["rate_fore"], py_["r"], None, py_["mixing"]
        ).sum(dim=-1)

        return reconst_loss_gene, reconst_loss_protein
Code example #15
try:
    nb = NegativeBinomialDisp(loc=mean, disp=disp_row)
    llk1 = tf.reduce_sum(nb.log_prob(x), axis=1).numpy()
    llk2 = log_nb_positive(x=torch.Tensor(x),
                           mu=torch.Tensor(mean),
                           theta=torch.Tensor(disp_row)).numpy()
    print(np.all(np.isclose(llk1, llk2)))
except:
    print("NOT POSSIBLE TO BROADCAST the first dimension")

# all disp available
nb = NegativeBinomialDisp(loc=mean, disp=disp)
llk1 = tf.reduce_sum(nb.log_prob(x), axis=1).numpy()
llk2 = log_nb_positive(x=torch.Tensor(x),
                       mu=torch.Tensor(mean),
                       theta=torch.Tensor(disp)).numpy()
print(np.all(np.isclose(llk1, llk2)))

s1 = nb.sample().numpy()
s2 = torch_nb(mean, disp).numpy()
print(describe(s1))
print(describe(s2))

zinb = ZeroInflated(nb, probs=pi)
llk1 = tf.reduce_sum(zinb.log_prob(x), axis=1).numpy()
llk2 = log_zinb_positive(x=torch.Tensor(x),
                         mu=torch.Tensor(mean),
                         theta=torch.Tensor(disp),
                         pi=torch.Tensor(pi)).numpy()
print(llk1)
print(llk2)