예제 #1
0
파일: totalvi.py 프로젝트: yynst2/scVI
    def get_reconstruction_loss(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        px_: Dict[str, torch.Tensor],
        py_: Dict[str, torch.Tensor],
        pro_batch_mask_minibatch: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the gene and protein reconstruction losses.

        Parameters
        ----------
        x
            Observed gene counts.
        y
            Observed protein counts.
        px_
            Gene likelihood parameters. Reads ``"rate"`` and ``"r"``, and
            additionally ``"dropout"`` when the gene loss is ``"zinb"``.
        py_
            Protein likelihood parameters. Reads ``"rate_back"``,
            ``"rate_fore"``, ``"r"`` and ``"mixing"``.
        pro_batch_mask_minibatch
            Optional mask broadcastable to the per-entry protein loss. Entries
            where the mask is false contribute zero to the protein loss.

        Returns
        -------
        ``(reconst_loss_gene, reconst_loss_protein)``: negative
        log-likelihoods, each summed over the last dimension.
        """
        # Gene reconstruction loss: ZINB when configured, otherwise NB.
        if self.reconstruction_loss_gene == "zinb":
            reconst_loss_gene = (-ZeroInflatedNegativeBinomial(
                mu=px_["rate"], theta=px_["r"],
                zi_logits=px_["dropout"]).log_prob(x).sum(dim=-1))
        else:
            reconst_loss_gene = (-NegativeBinomial(
                mu=px_["rate"], theta=px_["r"]).log_prob(x).sum(dim=-1))

        # Per-entry protein loss under the background/foreground NB mixture.
        reconst_loss_protein_full = -log_mixture_nb(
            y, py_["rate_back"], py_["rate_fore"], py_["r"], None,
            py_["mixing"])
        if pro_batch_mask_minibatch is not None:
            # Zero excluded entries while keeping every kept entry at its own
            # position. ``masked_scatter_`` is not suitable here: it fills the
            # masked slots with source elements taken in order, which
            # misaligns losses as soon as the mask has a False entry.
            # ``torch.where`` also avoids ``nan * 0`` for excluded entries.
            reconst_loss_protein_full = torch.where(
                pro_batch_mask_minibatch.bool(),
                reconst_loss_protein_full,
                torch.zeros_like(reconst_loss_protein_full),
            )
        reconst_loss_protein = reconst_loss_protein_full.sum(dim=-1)

        return reconst_loss_gene, reconst_loss_protein
예제 #2
0
    def get_reconstruction_loss(
        self,
        x: torch.Tensor,
        px_rate: torch.Tensor,
        px_r: torch.Tensor,
        px_dropout: torch.Tensor,
        bernoulli_params: torch.Tensor,
        eps_log: float = 1e-8,
        **kwargs,
    ) -> torch.Tensor:
        """Negative log-likelihood of ``x`` under an NB / ZINB mixture.

        Each entry is scored as NB with weight ``bernoulli_params`` and as
        ZINB with weight ``1 - bernoulli_params``. The two weighted
        log-likelihoods are combined with a log-sum-exp.

        Parameters
        ----------
        x
            Observed counts.
        px_rate, px_r
            Mean and inverse-dispersion passed to both NB and ZINB.
        px_dropout
            Zero-inflation logits for the ZINB component.
        bernoulli_params
            Mixture weight of the NB component.
        eps_log
            Added inside the logs of the mixture weights so that a weight of
            exactly 0 does not produce ``log(0)``.
        **kwargs
            Accepted and ignored.

        Returns
        -------
        The negative mixture log-likelihood summed over the last dimension.
        """
        # Weighted log-likelihood of each mixture component.
        ll_zinb = torch.log(1.0 - bernoulli_params +
                            eps_log) + ZeroInflatedNegativeBinomial(
                                mu=px_rate, theta=px_r,
                                zi_logits=px_dropout).log_prob(x)
        ll_nb = torch.log(bernoulli_params + eps_log) + NegativeBinomial(
            mu=px_rate, theta=px_r).log_prob(x)

        # log(exp(ll_nb) + exp(ll_zinb)) computed stably. Unlike a manual
        # max-shift, torch.logsumexp yields -inf (not nan) when both
        # components are -inf. broadcast_tensors keeps the broadcasting
        # semantics of an elementwise combination before stacking.
        ll_tot = torch.logsumexp(
            torch.stack(torch.broadcast_tensors(ll_nb, ll_zinb)), dim=0)
        reconst_loss = -ll_tot.sum(dim=-1)

        return reconst_loss
예제 #3
0
파일: vae.py 프로젝트: yynst2/scVI
 def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout,
                             **kwargs) -> torch.Tensor:
     """Return the negative log-likelihood of ``x``, summed over the last dim.

     The likelihood is selected by ``self.reconstruction_loss``:
     ``"zinb"`` uses ``px_rate``, ``px_r`` and ``px_dropout``; ``"nb"`` uses
     ``px_rate`` and ``px_r``; ``"poisson"`` uses only ``px_rate``. Extra
     ``kwargs`` are ignored.

     Raises
     ------
     ValueError
         If ``self.reconstruction_loss`` is not one of the values above.
     """
     if self.reconstruction_loss == "zinb":
         reconst_loss = (-ZeroInflatedNegativeBinomial(
             mu=px_rate, theta=px_r,
             zi_logits=px_dropout).log_prob(x).sum(dim=-1))
     elif self.reconstruction_loss == "nb":
         reconst_loss = (-NegativeBinomial(
             mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1))
     elif self.reconstruction_loss == "poisson":
         reconst_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1)
     else:
         # Fail loudly instead of hitting an unbound local variable below.
         raise ValueError(
             f"Unknown reconstruction_loss {self.reconstruction_loss!r}; "
             "expected 'zinb', 'nb' or 'poisson'")
     return reconst_loss
예제 #4
0
 def reconstruction_loss(
     self,
     x: torch.Tensor,
     px_rate: torch.Tensor,
     px_r: torch.Tensor,
     px_dropout: torch.Tensor,
     mode: int,
 ) -> torch.Tensor:
     """Return the negative log-likelihood of ``x`` for data modality ``mode``.

     The likelihood is ``self.reconstruction_losses[mode]``: ``"zinb"``,
     ``"nb"`` or ``"poisson"``. Every branch sums over the last dimension.
     Returns ``None`` when the configured loss is not one of these values.
     """
     loss = None
     loss_name = self.reconstruction_losses[mode]
     if loss_name == "zinb":
         loss = (-ZeroInflatedNegativeBinomial(
             mu=px_rate, theta=px_r,
             zi_logits=px_dropout).log_prob(x).sum(dim=-1))
     elif loss_name == "nb":
         loss = (-NegativeBinomial(
             mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1))
     elif loss_name == "poisson":
         # Sum over the last dim like the other branches. dim=1 only agreed
         # with them for 2-D inputs.
         loss = -Poisson(px_rate).log_prob(x).sum(dim=-1)
     return loss
예제 #5
0
def test_zinb_distribution():
    """Check the ZINB and NB distributions against reference implementations.

    Covers ``log_prob`` agreement with ``log_zinb_positive`` /
    ``log_nb_positive``, sample shapes, NB sample mean/std, batched
    ``log_prob`` shapes, and that invalid parameters raise ``ValueError``
    and invalid observations emit a ``UserWarning``.
    """
    # Seed before any random draw so theta, pi and x are reproducible too;
    # otherwise the first half of this test runs on unseeded inputs.
    torch.manual_seed(0)
    theta = 100.0 + torch.rand(size=(2, ))
    mu = 15.0 * torch.ones_like(theta)
    pi = torch.randn_like(theta)
    x = torch.randint_like(mu, high=20)
    log_p_ref = log_zinb_positive(x, mu, theta, pi)

    dist = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=pi)
    log_p_zinb = dist.log_prob(x)
    assert (log_p_ref - log_p_zinb).abs().max().item() <= 1e-8

    # Re-seed so the sampling checks below draw from a fixed stream
    # regardless of how many draws preceded them.
    torch.manual_seed(0)
    s1 = dist.sample((100, ))
    assert s1.shape == (100, 2)
    s2 = dist.sample(sample_shape=(4, 3))
    assert s2.shape == (4, 3, 2)

    log_p_ref = log_nb_positive(x, mu, theta)
    dist = NegativeBinomial(mu=mu, theta=theta)
    log_p_nb = dist.log_prob(x)
    assert (log_p_ref - log_p_nb).abs().max().item() <= 1e-8

    # NB moments: mean mu, variance mu + mu^2 / theta.
    s1 = dist.sample((1000, ))
    assert s1.shape == (1000, 2)
    assert (s1.mean(0) - mu).abs().mean() <= 1e0
    assert (s1.std(0) - (mu + mu * mu / theta)**0.5).abs().mean() <= 1e0

    size = (50, 3)
    theta = 100.0 + torch.rand(size=size)
    mu = 15.0 * torch.ones_like(theta)
    pi = torch.randn_like(theta)
    x = torch.randint_like(mu, high=20)
    dist1 = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=pi)
    dist2 = NegativeBinomial(mu=mu, theta=theta)
    assert dist1.log_prob(x).shape == size
    assert dist2.log_prob(x).shape == size

    with pytest.raises(ValueError):
        ZeroInflatedNegativeBinomial(mu=-mu, theta=theta, zi_logits=pi)
    with pytest.warns(UserWarning):
        dist1.log_prob(-x)  # ensures neg values raise warning
    with pytest.warns(UserWarning):
        dist2.log_prob(0.5 * x)  # ensures float values raise warning