import pytest
import torch

from scvi.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial

# the reference log-density helpers live alongside the distributions in scvi-tools;
# this import path is an assumption about the library layout
from scvi.distributions._negative_binomial import log_nb_positive, log_zinb_positive


def test_zinb_distribution():
    theta = 100.0 + torch.rand(size=(2, ))
    mu = 15.0 * torch.ones_like(theta)
    pi = torch.randn_like(theta)
    x = torch.randint_like(mu, high=20)
    log_p_ref = log_zinb_positive(x, mu, theta, pi)

    dist = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=pi)
    log_p_zinb = dist.log_prob(x)
    assert (log_p_ref - log_p_zinb).abs().max().item() <= 1e-8

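    # sample() should prepend sample_shape to the batch shape (2,)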
    torch.manual_seed(0)
    s1 = dist.sample((100, ))
    assert s1.shape == (100, 2)
    s2 = dist.sample(sample_shape=(4, 3))
    assert s2.shape == (4, 3, 2)

    log_p_ref = log_nb_positive(x, mu, theta)
    dist = NegativeBinomial(mu=mu, theta=theta)
    log_p_nb = dist.log_prob(x)
    assert (log_p_ref - log_p_nb).abs().max().item() <= 1e-8

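    # moment check: NB(mu, theta) has mean mu and variance mu + mu**2 / theta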
    s1 = dist.sample((1000, ))
    assert s1.shape == (1000, 2)
    assert (s1.mean(0) - mu).abs().mean() <= 1e0
    assert (s1.std(0) - (mu + mu * mu / theta)**0.5).abs().mean() <= 1e0

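    # batched parameters: log_prob should broadcast elementwise to shape (50, 3)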
    size = (50, 3)
    theta = 100.0 + torch.rand(size=size)
    mu = 15.0 * torch.ones_like(theta)
    pi = torch.randn_like(theta)
    x = torch.randint_like(mu, high=20)
    dist1 = ZeroInflatedNegativeBinomial(mu=mu,
                                         theta=theta,
                                         zi_logits=pi,
                                         validate_args=True)
    dist2 = NegativeBinomial(mu=mu, theta=theta, validate_args=True)
    assert dist1.log_prob(x).shape == size
    assert dist2.log_prob(x).shape == size

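    # validate_args=True rejects a negative mean at construction time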
    with pytest.raises(ValueError):
        ZeroInflatedNegativeBinomial(mu=-mu,
                                     theta=theta,
                                     zi_logits=pi,
                                     validate_args=True)
    with pytest.warns(UserWarning):
        dist1.log_prob(-x)  # negative values are outside the support and warn
    with pytest.warns(UserWarning):
        dist2.log_prob(0.5 * x)  # non-integer values are outside the support and warn
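
For reference, the density that log_zinb_positive evaluates is a mixture: a point mass at zero with probability sigmoid(pi), and a negative binomial with mean mu and inverse dispersion theta otherwise. Below is a minimal sketch of that computation; the eps constant and this exact softplus arrangement are assumptions for numerical stability, not necessarily the library's implementation.

import torch
import torch.nn.functional as F


def log_zinb_positive_sketch(x, mu, theta, pi, eps=1e-8):
    # log P(x=0) = log(sigmoid(pi) + (1 - sigmoid(pi)) * (theta / (theta + mu)) ** theta)
    # log P(x>0) = log(1 - sigmoid(pi)) + NB(x; mu, theta) log-pmf
    softplus_pi = F.softplus(-pi)  # log(1 + exp(-pi))
    log_theta_mu_eps = torch.log(theta + mu + eps)
    # log of the NB zero probability, shifted by -pi for the stable mixture form
    pi_theta_log = -pi + theta * (torch.log(theta + eps) - log_theta_mu_eps)

    case_zero = F.softplus(pi_theta_log) - softplus_pi
    case_non_zero = (
        -F.softplus(pi)  # log(1 - sigmoid(pi))
        + theta * (torch.log(theta + eps) - log_theta_mu_eps)
        + x * (torch.log(mu + eps) - log_theta_mu_eps)
        + torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1)
    )
    return torch.where(x < eps, case_zero, case_non_zero)
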
Example #2
    def generative(self, x, size_factor, design_matrix=None):
        # x has shape (n, g)
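        # B (the number of radial basis functions) and LOWER_BOUND are
        # module-level constants in the source file, not defined in this snippet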
        delta = torch.exp(self.delta_log)  # (g, c)
        theta_log = F.log_softmax(self.theta_logit, dim=-1)  # (c)

        # compute mean of NegBin - shape (n_cells, n_genes, n_labels)
        n_cells = size_factor.shape[0]
        base_mean = torch.log(size_factor)  # (n, 1)
        base_mean = base_mean.unsqueeze(-1).expand(n_cells, self.n_genes,
                                                   self.n_labels)  # (n, g, c)

        # compute beta (covariate coefficient)
        # design_matrix has shape (n, p)
        if design_matrix is not None:
            covariates = torch.einsum("np,gp->gn", design_matrix,
                                      self.beta)  # (g, n)
            covariates = torch.transpose(covariates, 0,
                                         1).unsqueeze(-1)  # (n, g, 1)
            covariates = covariates.expand(n_cells, self.n_genes,
                                           self.n_labels)
            base_mean = base_mean + covariates

        # base gene expression
        b_g_0 = self.b_g_0.unsqueeze(-1).expand(n_cells, self.n_genes,
                                                self.n_labels)
        delta_rho = delta * self.rho
        delta_rho = delta_rho.expand(n_cells, self.n_genes,
                                     self.n_labels)  # (n, g, c)
        log_mu_ngc = base_mean + delta_rho + b_g_0
        mu_ngc = torch.exp(log_mu_ngc)  # (n, g, c)

        # compute phi of NegBin - shape (n_cells, n_genes, n_labels)
        a = torch.exp(self.log_a)  # (B)
        a = a.expand(n_cells, self.n_genes, self.n_labels, B)
        b_init = 2 * ((self.basis_means[1] - self.basis_means[0])**2)
        # exp(-log(b_init)) == 1 / b_init: all basis functions share one bandwidth
        b = torch.exp(torch.ones(B, device=x.device) *
                      (-torch.log(b_init)))  # (B)
        b = b.expand(n_cells, self.n_genes, self.n_labels, B)
        mu_ngcb = mu_ngc.unsqueeze(-1).expand(n_cells, self.n_genes,
                                              self.n_labels, B)  # (n, g, c, B)
        basis_means = self.basis_means.expand(n_cells, self.n_genes,
                                              self.n_labels, B)  # (n, g, c, B)
        phi = (  # (n, g, c)
            torch.sum(a * torch.exp(-b * torch.square(mu_ngcb - basis_means)),
                      3) + LOWER_BOUND)

        # compute gamma
        nb_pdf = NegativeBinomial(mu=mu_ngc, theta=phi)
        x_ = x.unsqueeze(-1).expand(n_cells, self.n_genes, self.n_labels)
        x_log_prob_raw = nb_pdf.log_prob(x_)  # (n, g, c)
        theta_log = theta_log.expand(n_cells, self.n_labels)
        p_x_c = torch.sum(x_log_prob_raw, 1) + theta_log  # (n, c)
        normalizer_over_c = torch.logsumexp(p_x_c, 1)
        normalizer_over_c = normalizer_over_c.unsqueeze(-1).expand(
            n_cells, self.n_labels)
        gamma = torch.exp(p_x_c - normalizer_over_c)  # (n, c)
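        # gamma[i] = softmax(p_x_c[i]): posterior probability of each label for cell i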

        return dict(
            mu=mu_ngc,
            phi=phi,
            gamma=gamma,
            p_x_c=p_x_c,
            s=size_factor,
        )
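
A hypothetical downstream use of these outputs (the module handle and tensor shapes here are assumptions): each row of gamma is a posterior over the n_labels cell types, so hard assignments fall out of an argmax.

# x: (n, g) count matrix, size_factor: (n, 1) per-cell library size -- assumed shapes
outputs = module.generative(x, size_factor)
gamma = outputs["gamma"]           # (n, c); rows sum to 1
labels = gamma.argmax(dim=-1)      # most probable label per cell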