def sample(
    self,
    tensors,
    n_samples=1,
    library_size=1,
) -> np.ndarray:
    r"""
    Generate observation samples from the posterior predictive distribution.

    The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`.

    Parameters
    ----------
    tensors
        Tensors dict
    n_samples
        Number of required samples for each cell
    library_size
        Library size to scale samples to

    Returns
    -------
    x_new : :py:class:`torch.Tensor`
        tensor with shape (n_cells, n_genes, n_samples)
    """
    inference_kwargs = dict(n_samples=n_samples)
    inference_outputs, generative_outputs = self.forward(
        tensors,
        inference_kwargs=inference_kwargs,
        compute_loss=False,
    )
    px_r = generative_outputs["px_r"]
    px_rate = generative_outputs["px_rate"]
    px_dropout = generative_outputs["px_dropout"]

    if self.gene_likelihood == "poisson":
        l_train = torch.clamp(px_rate, max=1e8)
        dist = torch.distributions.Poisson(
            l_train
        )  # Shape : (n_samples, n_cells_batch, n_genes)
    elif self.gene_likelihood == "nb":
        dist = NegativeBinomial(mu=px_rate, theta=px_r)
    elif self.gene_likelihood == "zinb":
        dist = ZeroInflatedNegativeBinomial(
            mu=px_rate, theta=px_r, zi_logits=px_dropout
        )
    else:
        raise ValueError(
            "{} reconstruction error not handled right now".format(
                self.gene_likelihood
            )
        )

    if n_samples > 1:
        exprs = dist.sample().permute(
            [1, 2, 0]
        )  # Shape : (n_cells_batch, n_genes, n_samples)
    else:
        exprs = dist.sample()

    return exprs.cpu()
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    kl_weight: float = 1.0,
):
    x = tensors[REGISTRY_KEYS.X_KEY]
    y = tensors[REGISTRY_KEYS.LABELS_KEY]
    qz_m = inference_outputs["qz_m"]
    qz_v = inference_outputs["qz_v"]
    px_rate = generative_outputs["px_rate"]
    px_r = generative_outputs["px_r"]

    # KL divergence between q(z | x) and the standard normal prior
    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(
        Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)
    ).sum(dim=1)

    # Negative binomial reconstruction loss, summed over genes
    reconst_loss = -NegativeBinomial(px_rate, logits=px_r).log_prob(x).sum(-1)

    # Weight each cell's ELBO by its cell-type weight
    scaling_factor = self.ct_weight[y.long()[:, 0]]
    loss = torch.mean(scaling_factor * (reconst_loss + kl_weight * kl_divergence_z))

    return LossRecorder(loss, reconst_loss, kl_divergence_z, torch.tensor(0.0))
def get_reconstruction_loss(
    self,
    x: torch.Tensor,
    px_rate: torch.Tensor,
    px_r: torch.Tensor,
    px_dropout: torch.Tensor,
    bernoulli_params: torch.Tensor,
    eps_log: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    # LLs for NB and ZINB
    ll_zinb = torch.log(
        1.0 - bernoulli_params + eps_log
    ) + ZeroInflatedNegativeBinomial(
        mu=px_rate, theta=px_r, zi_logits=px_dropout
    ).log_prob(x)
    ll_nb = torch.log(bernoulli_params + eps_log) + NegativeBinomial(
        mu=px_rate, theta=px_r
    ).log_prob(x)

    # Reconstruction loss using a logsumexp-type computation
    ll_max = torch.max(ll_zinb, ll_nb)
    ll_tot = ll_max + torch.log(
        torch.exp(ll_nb - ll_max) + torch.exp(ll_zinb - ll_max)
    )
    reconst_loss = -ll_tot.sum(dim=-1)

    return reconst_loss
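# A minimal, self-contained sketch (illustrative tensors only, not model outputs) of the
# log-sum-exp trick used above: the manual max-shift over the two weighted component
# log-likelihoods matches torch.logsumexp over a stacked axis.
import torch

torch.manual_seed(0)
ll_nb = torch.randn(4, 10) - 5.0    # stand-in for log(w) + NB log-likelihood
ll_zinb = torch.randn(4, 10) - 5.0  # stand-in for log(1 - w) + ZINB log-likelihood

ll_max = torch.max(ll_zinb, ll_nb)
ll_tot_manual = ll_max + torch.log(torch.exp(ll_nb - ll_max) + torch.exp(ll_zinb - ll_max))
ll_tot_builtin = torch.logsumexp(torch.stack([ll_nb, ll_zinb]), dim=0)
assert torch.allclose(ll_tot_manual, ll_tot_builtin, atol=1e-6)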
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    kl_weight: float = 1.0,
    n_obs: float = 1.0,
):
    x = tensors[REGISTRY_KEYS.X_KEY]
    px_rate = generative_outputs["px_rate"]
    px_o = generative_outputs["px_o"]
    gamma = generative_outputs["gamma"]

    reconst_loss = -NegativeBinomial(px_rate, logits=px_o).log_prob(x).sum(-1)

    # eta prior likelihood
    mean = torch.zeros_like(self.eta)
    scale = torch.ones_like(self.eta)
    glo_neg_log_likelihood_prior = -Normal(mean, scale).log_prob(self.eta).sum()
    glo_neg_log_likelihood_prior += torch.var(self.beta)

    # gamma prior likelihood
    if self.mean_vprior is None:
        # isotropic normal prior
        mean = torch.zeros_like(gamma)
        scale = torch.ones_like(gamma)
        neg_log_likelihood_prior = -Normal(mean, scale).log_prob(gamma).sum(2).sum(1)
    else:
        # vampprior
        # gamma is of shape n_latent, n_labels, minibatch_size
        gamma = gamma.unsqueeze(1)  # minibatch_size, 1, n_labels, n_latent
        # mean_vprior is of shape n_labels, p, n_latent
        mean_vprior = torch.transpose(self.mean_vprior, 0, 1).unsqueeze(
            0
        )  # 1, p, n_labels, n_latent
        var_vprior = torch.transpose(self.var_vprior, 0, 1).unsqueeze(
            0
        )  # 1, p, n_labels, n_latent
        pre_lse = (
            Normal(mean_vprior, torch.sqrt(var_vprior)).log_prob(gamma).sum(-1)
        )  # minibatch, p, n_labels
        log_likelihood_prior = torch.logsumexp(pre_lse, 1) - np.log(
            self.p
        )  # minibatch, n_labels
        neg_log_likelihood_prior = -log_likelihood_prior.sum(1)  # minibatch

    loss = (
        n_obs * torch.mean(reconst_loss + kl_weight * neg_log_likelihood_prior)
        + glo_neg_log_likelihood_prior
    )

    return LossRecorder(
        loss, reconst_loss, neg_log_likelihood_prior, glo_neg_log_likelihood_prior
    )
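# A shape-level sketch of the VampPrior branch above (toy sizes are assumptions for
# illustration, not values from the model): the prior is a uniform mixture over p
# pseudo-input components, evaluated by broadcasting and a log-sum-exp over dim 1.
import numpy as np
import torch
from torch.distributions import Normal

minibatch, p, n_labels, n_latent = 8, 5, 3, 2
gamma = torch.randn(minibatch, 1, n_labels, n_latent)    # latent sample, broadcast over p
mean_vprior = torch.randn(1, p, n_labels, n_latent)      # pseudo-input means
var_vprior = torch.rand(1, p, n_labels, n_latent) + 0.1  # pseudo-input variances

pre_lse = Normal(mean_vprior, torch.sqrt(var_vprior)).log_prob(gamma).sum(-1)  # (minibatch, p, n_labels)
log_likelihood_prior = torch.logsumexp(pre_lse, dim=1) - np.log(p)             # (minibatch, n_labels)
assert log_likelihood_prior.shape == (minibatch, n_labels)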
def get_reconstruction_loss_expression(self, x, px_rate, px_r, px_dropout):
    rl = 0.0
    if self.gene_likelihood == "zinb":
        rl = (
            -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout)
            .log_prob(x)
            .sum(dim=-1)
        )
    elif self.gene_likelihood == "nb":
        rl = -NegativeBinomial(mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1)
    elif self.gene_likelihood == "poisson":
        rl = -Poisson(px_rate).log_prob(x).sum(dim=-1)
    return rl
def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout) -> torch.Tensor:
    if self.gene_likelihood == "zinb":
        reconst_loss = (
            -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout)
            .log_prob(x)
            .sum(dim=-1)
        )
    elif self.gene_likelihood == "nb":
        reconst_loss = (
            -NegativeBinomial(mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1)
        )
    elif self.gene_likelihood == "poisson":
        reconst_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1)
    return reconst_loss
def sample(
    self,
    tensors,
    n_samples=1,
) -> np.ndarray:
    r"""
    Generate observation samples from the posterior predictive distribution.

    The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`.

    Parameters
    ----------
    tensors
        Tensors dict
    n_samples
        Number of required samples for each cell

    Returns
    -------
    x_new : :py:class:`torch.Tensor`
        tensor with shape (n_cells, n_genes, n_samples)
    """
    inference_kwargs = dict(n_samples=n_samples)
    generative_outputs = self.forward(
        tensors,
        inference_kwargs=inference_kwargs,
        compute_loss=False,
    )[1]

    px_r = generative_outputs["px_r"]
    px_rate = generative_outputs["px_rate"]

    dist = NegativeBinomial(px_rate, logits=px_r)
    if n_samples > 1:
        exprs = dist.sample().permute(
            [1, 2, 0]
        )  # Shape : (n_cells_batch, n_genes, n_samples)
    else:
        exprs = dist.sample()
    return exprs.cpu()
def sample(self, tensors, n_samples=1):
    inference_kwargs = dict(n_samples=n_samples)
    with torch.no_grad():
        inference_outputs, generative_outputs = self.forward(
            tensors,
            inference_kwargs=inference_kwargs,
            compute_loss=False,
        )

    px_ = generative_outputs["px_"]
    py_ = generative_outputs["py_"]

    rna_dist = NegativeBinomial(mu=px_["rate"], theta=px_["r"])
    protein_dist = NegativeBinomialMixture(
        mu1=py_["rate_back"],
        mu2=py_["rate_fore"],
        theta1=py_["r"],
        mixture_logits=py_["mixing"],
    )
    rna_sample = rna_dist.sample().cpu()
    protein_sample = protein_dist.sample().cpu()

    return rna_sample, protein_sample
def get_reconstruction_loss(
    self,
    x: torch.Tensor,
    y: torch.Tensor,
    px_dict: Dict[str, torch.Tensor],
    py_dict: Dict[str, torch.Tensor],
    pro_batch_mask_minibatch: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute reconstruction loss."""
    px_ = px_dict
    py_ = py_dict

    # Reconstruction Loss
    if self.gene_likelihood == "zinb":
        reconst_loss_gene = (
            -ZeroInflatedNegativeBinomial(
                mu=px_["rate"], theta=px_["r"], zi_logits=px_["dropout"]
            )
            .log_prob(x)
            .sum(dim=-1)
        )
    else:
        reconst_loss_gene = (
            -NegativeBinomial(mu=px_["rate"], theta=px_["r"]).log_prob(x).sum(dim=-1)
        )

    py_conditional = NegativeBinomialMixture(
        mu1=py_["rate_back"],
        mu2=py_["rate_fore"],
        theta1=py_["r"],
        mixture_logits=py_["mixing"],
    )
    reconst_loss_protein_full = -py_conditional.log_prob(y)
    if pro_batch_mask_minibatch is not None:
        temp_pro_loss_full = torch.zeros_like(reconst_loss_protein_full)
        temp_pro_loss_full.masked_scatter_(
            pro_batch_mask_minibatch.bool(), reconst_loss_protein_full
        )
        reconst_loss_protein = temp_pro_loss_full.sum(dim=-1)
    else:
        reconst_loss_protein = reconst_loss_protein_full.sum(dim=-1)

    return reconst_loss_gene, reconst_loss_protein
def reconstruction_loss(
    self,
    x: torch.Tensor,
    px_rate: torch.Tensor,
    px_r: torch.Tensor,
    px_dropout: torch.Tensor,
    mode: int,
) -> torch.Tensor:
    reconstruction_loss = None
    if self.gene_likelihoods[mode] == "zinb":
        reconstruction_loss = (
            -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout)
            .log_prob(x)
            .sum(dim=-1)
        )
    elif self.gene_likelihoods[mode] == "nb":
        reconstruction_loss = (
            -NegativeBinomial(mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1)
        )
    elif self.gene_likelihoods[mode] == "poisson":
        reconstruction_loss = -Poisson(px_rate).log_prob(x).sum(dim=1)
    return reconstruction_loss
def test_zinb_distribution():
    theta = 100.0 + torch.rand(size=(2,))
    mu = 15.0 * torch.ones_like(theta)
    pi = torch.randn_like(theta)
    x = torch.randint_like(mu, high=20)
    log_p_ref = log_zinb_positive(x, mu, theta, pi)

    dist = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=pi)
    log_p_zinb = dist.log_prob(x)
    assert (log_p_ref - log_p_zinb).abs().max().item() <= 1e-8

    torch.manual_seed(0)
    s1 = dist.sample((100,))
    assert s1.shape == (100, 2)
    s2 = dist.sample(sample_shape=(4, 3))
    assert s2.shape == (4, 3, 2)

    log_p_ref = log_nb_positive(x, mu, theta)
    dist = NegativeBinomial(mu=mu, theta=theta)
    log_p_nb = dist.log_prob(x)
    assert (log_p_ref - log_p_nb).abs().max().item() <= 1e-8

    s1 = dist.sample((1000,))
    assert s1.shape == (1000, 2)
    assert (s1.mean(0) - mu).abs().mean() <= 1e0
    assert (s1.std(0) - (mu + mu * mu / theta) ** 0.5).abs().mean() <= 1e0

    size = (50, 3)
    theta = 100.0 + torch.rand(size=size)
    mu = 15.0 * torch.ones_like(theta)
    pi = torch.randn_like(theta)
    x = torch.randint_like(mu, high=20)
    dist1 = ZeroInflatedNegativeBinomial(
        mu=mu, theta=theta, zi_logits=pi, validate_args=True
    )
    dist2 = NegativeBinomial(mu=mu, theta=theta, validate_args=True)
    assert dist1.log_prob(x).shape == size
    assert dist2.log_prob(x).shape == size

    with pytest.raises(ValueError):
        ZeroInflatedNegativeBinomial(
            mu=-mu, theta=theta, zi_logits=pi, validate_args=True
        )
    with pytest.warns(UserWarning):
        dist1.log_prob(-x)  # ensures neg values raise warning
    with pytest.warns(UserWarning):
        dist2.log_prob(0.5 * x)  # ensures float values raise warning
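# A hedged sketch of the (mu, theta) parameterization the test above exercises:
# assuming the scvi.distributions import path (it may differ across versions), the
# mean/inverse-dispersion form maps to torch's NegativeBinomial via total_count = theta
# and logits = log(mu) - log(theta), giving mean mu and variance mu + mu^2 / theta.
import torch
from torch.distributions import NegativeBinomial as TorchNB
from scvi.distributions import NegativeBinomial

mu = torch.tensor([5.0, 20.0, 100.0])
theta = torch.tensor([2.0, 10.0, 50.0])
x = torch.tensor([3.0, 25.0, 90.0])

scvi_nb = NegativeBinomial(mu=mu, theta=theta)
torch_nb = TorchNB(total_count=theta, logits=mu.log() - theta.log())

assert torch.allclose(scvi_nb.log_prob(x), torch_nb.log_prob(x), atol=1e-4)
assert torch.allclose(torch_nb.mean, mu)
assert torch.allclose(torch_nb.variance, mu + mu ** 2 / theta)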
def generative(self, x, size_factor, design_matrix=None):
    # x has shape (n, g); B and LOWER_BOUND are module-level constants
    # (number of basis functions and a numerical floor for phi, respectively)
    delta = torch.exp(self.delta_log)  # (g, c)
    theta_log = F.log_softmax(self.theta_logit, dim=-1)  # (c)

    # compute mean of NegBin - shape (n_cells, n_genes, n_labels)
    n_cells = size_factor.shape[0]
    base_mean = torch.log(size_factor)  # (n, 1)
    base_mean = base_mean.unsqueeze(-1).expand(
        n_cells, self.n_genes, self.n_labels
    )  # (n, g, c)

    # compute beta (covariate coefficient)
    # design_matrix has shape (n, p)
    if design_matrix is not None:
        covariates = torch.einsum("np,gp->gn", design_matrix, self.beta)  # (g, n)
        covariates = torch.transpose(covariates, 0, 1).unsqueeze(-1)  # (n, g, 1)
        covariates = covariates.expand(n_cells, self.n_genes, self.n_labels)
        base_mean = base_mean + covariates

    # base gene expression
    b_g_0 = self.b_g_0.unsqueeze(-1).expand(n_cells, self.n_genes, self.n_labels)
    delta_rho = delta * self.rho
    delta_rho = delta_rho.expand(n_cells, self.n_genes, self.n_labels)  # (n, g, c)
    log_mu_ngc = base_mean + delta_rho + b_g_0
    mu_ngc = torch.exp(log_mu_ngc)  # (n, g, c)

    # compute phi of NegBin - shape (n_cells, n_genes, n_labels)
    a = torch.exp(self.log_a)  # (B)
    a = a.expand(n_cells, self.n_genes, self.n_labels, B)
    b_init = 2 * ((self.basis_means[1] - self.basis_means[0]) ** 2)
    b = torch.exp(torch.ones(B, device=x.device) * (-torch.log(b_init)))  # (B)
    b = b.expand(n_cells, self.n_genes, self.n_labels, B)
    mu_ngcb = mu_ngc.unsqueeze(-1).expand(
        n_cells, self.n_genes, self.n_labels, B
    )  # (n, g, c, B)
    basis_means = self.basis_means.expand(
        n_cells, self.n_genes, self.n_labels, B
    )  # (n, g, c, B)
    phi = (  # (n, g, c)
        torch.sum(a * torch.exp(-b * torch.square(mu_ngcb - basis_means)), 3)
        + LOWER_BOUND
    )

    # compute gamma: posterior probability of each cell type given x
    nb_pdf = NegativeBinomial(mu=mu_ngc, theta=phi)
    x_ = x.unsqueeze(-1).expand(n_cells, self.n_genes, self.n_labels)
    x_log_prob_raw = nb_pdf.log_prob(x_)  # (n, g, c)
    theta_log = theta_log.expand(n_cells, self.n_labels)
    p_x_c = torch.sum(x_log_prob_raw, 1) + theta_log  # (n, c)
    normalizer_over_c = torch.logsumexp(p_x_c, 1)
    normalizer_over_c = normalizer_over_c.unsqueeze(-1).expand(n_cells, self.n_labels)
    gamma = torch.exp(p_x_c - normalizer_over_c)  # (n, c)

    return dict(
        mu=mu_ngc,
        phi=phi,
        gamma=gamma,
        p_x_c=p_x_c,
        s=size_factor,
    )
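# A small, self-contained check (toy values, not model outputs) that the gamma
# computation in generative() is just a softmax of p_x_c over the cell-type axis,
# written out with an explicit log-sum-exp normalizer.
import torch

torch.manual_seed(0)
p_x_c = torch.randn(6, 4)  # (n_cells, n_labels) unnormalized log-probabilities

normalizer_over_c = torch.logsumexp(p_x_c, dim=1, keepdim=True)
gamma_manual = torch.exp(p_x_c - normalizer_over_c)
gamma_softmax = torch.softmax(p_x_c, dim=-1)

assert torch.allclose(gamma_manual, gamma_softmax, atol=1e-6)
assert torch.allclose(gamma_manual.sum(-1), torch.ones(6))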