def sample(
    self,
    tensors,
    n_samples=1,
    library_size=1,
) -> torch.Tensor:
    r"""
    Generate observation samples from the posterior predictive distribution.

    The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`.

    Parameters
    ----------
    tensors
        Tensors dict, forwarded unchanged to ``self.forward``.
    n_samples
        Number of required samples for each cell.
    library_size
        Library size to scale samples to. Currently not used by this method;
        kept for backward compatibility of the signature.

    Returns
    -------
    x_new : :py:class:`torch.Tensor`
        CPU tensor with shape (n_cells, n_genes, n_samples) when
        ``n_samples > 1``; otherwise it has the same shape as the
        generative ``px_rate`` output.

    Raises
    ------
    ValueError
        If ``self.gene_likelihood`` is not ``"poisson"``, ``"nb"`` or ``"zinb"``.
    """
    inference_kwargs = dict(n_samples=n_samples)
    # Only a sample is returned, so there is no need to record an autograd
    # graph for the forward pass.
    with torch.no_grad():
        _, generative_outputs = self.forward(
            tensors,
            inference_kwargs=inference_kwargs,
            compute_loss=False,
        )

    px_r = generative_outputs["px_r"]
    px_rate = generative_outputs["px_rate"]
    px_dropout = generative_outputs["px_dropout"]

    if self.gene_likelihood == "poisson":
        # Cap the rate at 1e8 (presumably to keep Poisson sampling numerically
        # well-behaved for extreme rates).
        l_train = torch.clamp(px_rate, max=1e8)
        # Shape : (n_samples, n_cells_batch, n_genes)
        dist = torch.distributions.Poisson(l_train)
    elif self.gene_likelihood == "nb":
        dist = NegativeBinomial(mu=px_rate, theta=px_r)
    elif self.gene_likelihood == "zinb":
        dist = ZeroInflatedNegativeBinomial(
            mu=px_rate, theta=px_r, zi_logits=px_dropout
        )
    else:
        # Read the likelihood from self (as the branches above do); the old
        # ``self.module`` lookup raised AttributeError instead of this error.
        raise ValueError(
            "{} reconstruction error not handled right now".format(
                self.gene_likelihood
            )
        )

    if n_samples > 1:
        # Move the leading sample axis last.
        # Shape : (n_cells_batch, n_genes, n_samples)
        exprs = dist.sample().permute([1, 2, 0])
    else:
        exprs = dist.sample()

    return exprs.cpu()
def sample(
    self,
    tensors,
    n_samples=1,
) -> torch.Tensor:
    r"""
    Generate observation samples from the posterior predictive distribution.

    The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`.

    Parameters
    ----------
    tensors
        Tensors dict, forwarded unchanged to ``self.forward``.
    n_samples
        Number of required samples for each cell.

    Returns
    -------
    x_new : :py:class:`torch.Tensor`
        CPU tensor with shape (n_cells, n_genes, n_samples) when
        ``n_samples > 1``; otherwise the raw shape of ``dist.sample()``.
    """
    inference_kwargs = dict(n_samples=n_samples)
    # Only a sample is returned, so skip recording an autograd graph for the
    # forward pass.
    with torch.no_grad():
        # Second element of the forward output holds the generative outputs.
        generative_outputs = self.forward(
            tensors,
            inference_kwargs=inference_kwargs,
            compute_loss=False,
        )[1]

    px_r = generative_outputs["px_r"]
    px_rate = generative_outputs["px_rate"]
    # NOTE(review): px_rate is passed positionally (NegativeBinomial's first
    # parameter, presumably total_count) and px_r as ``logits``. Elsewhere the
    # same keys are used as NegativeBinomial(mu=px_rate, theta=px_r); confirm
    # this parameterization matches what this model's generative step emits.
    dist = NegativeBinomial(px_rate, logits=px_r)

    if n_samples > 1:
        # Move the leading sample axis last.
        # Shape : (n_cells_batch, n_genes, n_samples)
        exprs = dist.sample().permute([1, 2, 0])
    else:
        exprs = dist.sample()

    return exprs.cpu()
def sample(self, tensors, n_samples=1):
    """Draw posterior-predictive samples for RNA and protein counts.

    Runs the forward pass without gradient tracking, then builds a
    NegativeBinomial for the RNA counts and a NegativeBinomialMixture for the
    protein counts and samples from each.

    Parameters
    ----------
    tensors
        Tensors dict, forwarded unchanged to ``self.forward``.
    n_samples
        Number of samples requested from the inference step.

    Returns
    -------
    tuple of torch.Tensor
        ``(rna_sample, protein_sample)``, both moved to CPU. No axis
        permutation is applied here, unlike samplers that move the sample
        axis last; presumably callers handle that. TODO confirm.
    """
    with torch.no_grad():
        _, generative_outputs = self.forward(
            tensors,
            inference_kwargs={"n_samples": n_samples},
            compute_loss=False,
        )

    rna_params = generative_outputs["px_"]
    protein_params = generative_outputs["py_"]

    # Build both distributions first, then sample RNA before protein.
    rna_distribution = NegativeBinomial(
        mu=rna_params["rate"], theta=rna_params["r"]
    )
    protein_distribution = NegativeBinomialMixture(
        mu1=protein_params["rate_back"],
        mu2=protein_params["rate_fore"],
        theta1=protein_params["r"],
        mixture_logits=protein_params["mixing"],
    )

    rna_draw = rna_distribution.sample().cpu()
    protein_draw = protein_distribution.sample().cpu()
    return rna_draw, protein_draw