import pytest
import torch

from scvi.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial

# NOTE: the reference log-likelihood helpers live alongside the distributions
# in scvi-tools; the exact import path may vary between versions.
from scvi.distributions._negative_binomial import log_nb_positive, log_zinb_positive


def test_zinb_distribution():
    theta = 100.0 + torch.rand(size=(2,))
    mu = 15.0 * torch.ones_like(theta)
    pi = torch.randn_like(theta)  # zero-inflation logits
    x = torch.randint_like(mu, high=20)

    # ZINB log-likelihood matches the reference implementation
    log_p_ref = log_zinb_positive(x, mu, theta, pi)
    dist = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=pi)
    log_p_zinb = dist.log_prob(x)
    assert (log_p_ref - log_p_zinb).abs().max().item() <= 1e-8

    # sample_shape is prepended to the batch shape
    torch.manual_seed(0)
    s1 = dist.sample((100,))
    assert s1.shape == (100, 2)
    s2 = dist.sample(sample_shape=(4, 3))
    assert s2.shape == (4, 3, 2)

    # NB log-likelihood matches the reference implementation
    log_p_ref = log_nb_positive(x, mu, theta)
    dist = NegativeBinomial(mu=mu, theta=theta)
    log_p_nb = dist.log_prob(x)
    assert (log_p_ref - log_p_nb).abs().max().item() <= 1e-8

    # empirical moments: mean ~ mu, std ~ sqrt(mu + mu^2 / theta)
    s1 = dist.sample((1000,))
    assert s1.shape == (1000, 2)
    assert (s1.mean(0) - mu).abs().mean() <= 1e0
    assert (s1.std(0) - (mu + mu * mu / theta) ** 0.5).abs().mean() <= 1e0

    # log_prob broadcasts over arbitrary batch shapes
    size = (50, 3)
    theta = 100.0 + torch.rand(size=size)
    mu = 15.0 * torch.ones_like(theta)
    pi = torch.randn_like(theta)
    x = torch.randint_like(mu, high=20)
    dist1 = ZeroInflatedNegativeBinomial(
        mu=mu, theta=theta, zi_logits=pi, validate_args=True
    )
    dist2 = NegativeBinomial(mu=mu, theta=theta, validate_args=True)
    assert dist1.log_prob(x).shape == size
    assert dist2.log_prob(x).shape == size

    # argument validation: a negative mean is rejected outright
    with pytest.raises(ValueError):
        ZeroInflatedNegativeBinomial(
            mu=-mu, theta=theta, zi_logits=pi, validate_args=True
        )
    with pytest.warns(UserWarning):
        dist1.log_prob(-x)  # ensures negative values raise a warning
    with pytest.warns(UserWarning):
        dist2.log_prob(0.5 * x)  # ensures non-integer values raise a warning
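
# A minimal usage sketch of the ZINB distribution exercised by the test above.
# The constructor arguments (mu, theta, zi_logits) and the sample/log_prob calls
# are taken from the test; the concrete tensor values here are illustrative.
# A ZINB mixes a point mass at zero (with probability sigmoid(pi)) with an
# NB(mu, theta) component, whose variance is mu + mu^2 / theta.
import torch

from scvi.distributions import ZeroInflatedNegativeBinomial

mu = torch.tensor([5.0, 10.0])    # NB means
theta = torch.tensor([2.0, 2.0])  # NB inverse dispersions
pi = torch.tensor([0.0, -2.0])    # zero-inflation logits
dist = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=pi)
counts = dist.sample((256,))           # (256, 2) integer draws
log_liks = dist.log_prob(counts)       # (256, 2) elementwise log-likelihoods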
import torch
from torch.nn import functional as F

from scvi.distributions import NegativeBinomial

# Module-level constants used below: B is the number of Gaussian basis
# functions in the dispersion spline and LOWER_BOUND keeps phi positive.
# The values are assumptions matching scvi-tools' CellAssign module.
B = 10
LOWER_BOUND = 1e-10


def generative(self, x, size_factor, design_matrix=None):
    # x has shape (n, g)
    delta = torch.exp(self.delta_log)  # (g, c)
    theta_log = F.log_softmax(self.theta_logit, dim=-1)  # (c,)

    # compute mean of NegBin - shape (n_cells, n_genes, n_labels)
    n_cells = size_factor.shape[0]
    base_mean = torch.log(size_factor)  # (n, 1)
    base_mean = base_mean.unsqueeze(-1).expand(
        n_cells, self.n_genes, self.n_labels
    )  # (n, g, c)

    # compute beta (covariate coefficient); design_matrix has shape (n, p)
    if design_matrix is not None:
        covariates = torch.einsum("np,gp->gn", design_matrix, self.beta)  # (g, n)
        covariates = torch.transpose(covariates, 0, 1).unsqueeze(-1)  # (n, g, 1)
        covariates = covariates.expand(n_cells, self.n_genes, self.n_labels)
        base_mean = base_mean + covariates

    # base gene expression
    b_g_0 = self.b_g_0.unsqueeze(-1).expand(n_cells, self.n_genes, self.n_labels)
    delta_rho = delta * self.rho
    delta_rho = delta_rho.expand(n_cells, self.n_genes, self.n_labels)  # (n, g, c)
    log_mu_ngc = base_mean + delta_rho + b_g_0
    mu_ngc = torch.exp(log_mu_ngc)  # (n, g, c)

    # compute phi of NegBin as a sum of Gaussian basis functions of the mean
    # - shape (n_cells, n_genes, n_labels)
    a = torch.exp(self.log_a)  # (B,)
    a = a.expand(n_cells, self.n_genes, self.n_labels, B)
    b_init = 2 * ((self.basis_means[1] - self.basis_means[0]) ** 2)
    b = torch.exp(torch.ones(B, device=x.device) * (-torch.log(b_init)))  # (B,)
    b = b.expand(n_cells, self.n_genes, self.n_labels, B)
    mu_ngcb = mu_ngc.unsqueeze(-1).expand(
        n_cells, self.n_genes, self.n_labels, B
    )  # (n, g, c, B)
    basis_means = self.basis_means.expand(
        n_cells, self.n_genes, self.n_labels, B
    )  # (n, g, c, B)
    phi = (
        torch.sum(a * torch.exp(-b * torch.square(mu_ngcb - basis_means)), 3)
        + LOWER_BOUND
    )  # (n, g, c)

    # compute gamma, the per-cell posterior over labels
    nb_pdf = NegativeBinomial(mu=mu_ngc, theta=phi)
    x_ = x.unsqueeze(-1).expand(n_cells, self.n_genes, self.n_labels)
    x_log_prob_raw = nb_pdf.log_prob(x_)  # (n, g, c)
    theta_log = theta_log.expand(n_cells, self.n_labels)
    p_x_c = torch.sum(x_log_prob_raw, 1) + theta_log  # (n, c)
    normalizer_over_c = torch.logsumexp(p_x_c, 1)
    normalizer_over_c = normalizer_over_c.unsqueeze(-1).expand(
        n_cells, self.n_labels
    )
    gamma = torch.exp(p_x_c - normalizer_over_c)  # (n, c)

    return dict(
        mu=mu_ngc,
        phi=phi,
        gamma=gamma,
        p_x_c=p_x_c,
        s=size_factor,
    )
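
# A minimal standalone sketch of the gamma computation at the end of
# generative(): p_x_c[n, c] = sum_g log NB(x[n, g] | mu[n, g, c], phi[n, g, c])
# + log theta[c], and gamma is the softmax of p_x_c over labels, which equals
# exp(p_x_c - logsumexp(p_x_c)) as written above. Shapes and the random
# stand-in log-likelihoods here are illustrative only.
import torch

n_cells, n_genes, n_labels = 4, 6, 3
x_log_prob_raw = torch.randn(n_cells, n_genes, n_labels)  # stand-in (n, g, c)
theta_log = torch.log_softmax(torch.zeros(n_labels), dim=-1)  # uniform prior

p_x_c = x_log_prob_raw.sum(dim=1) + theta_log  # (n, c) unnormalized log posterior
gamma = torch.softmax(p_x_c, dim=-1)           # (n, c) posterior over labels
assert torch.allclose(gamma.sum(-1), torch.ones(n_cells))  # rows sum to one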