class CVAE(VAE): r"""A conditional variational autoencoder model, Args: :n_input: Number of input genes. :n_batch: Default: ``0``. :n_labels: Default: ``0``. :n_hidden: Number of hidden. Default: ``128``. :n_latent: Default: ``1``. :n_layers: Number of layers. Default: ``1``. :dropout_rate: Default: ``0.1``. :dispersion: Default: ``"gene"``. :log_variational: Default: ``True``. :reconstruction_loss: Default: ``"zinb"``. :y_prior: Default: None, but will be initialized to uniform probability over the cell types if not specified Examples: >>> gene_dataset = CortexDataset() >>> vaec = VAEC(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False, ... n_labels=gene_dataset.n_labels) >>> gene_dataset = SyntheticDataset(n_labels=3) >>> vaec = VAEC(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False, ... n_labels=3, y_prior=torch.tensor([[0.1,0.5,0.4]])) """ def __init__(self, n_input, n_batch, n_labels, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1, y_prior=None, dispersion="gene", log_variational=True, reconstruction_loss="zinb"): super(CVAE, self).__init__(n_input, n_batch, n_labels, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers, dropout_rate=dropout_rate, dispersion=dispersion, log_variational=log_variational, reconstruction_loss=reconstruction_loss) self.z_encoder = Encoder(n_input, n_latent, n_cat_list=[n_batch, n_labels], n_hidden=n_hidden, n_layers=n_layers, dropout_rate=dropout_rate) self.decoder = DecoderSCVI(n_latent, n_input, n_cat_list=[n_batch, n_labels], n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_rate) self.y_prior = torch.nn.Parameter( y_prior if y_prior is not None else (1 / n_labels) * torch.ones(1, n_labels), requires_grad=False ) def get_sample_scale(self, x, batch_index=None, y=None, n_samples=1): qz_m, qz_v, z = self.z_encoder(torch.log(1 + x), batch_index, y) qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1))) qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1))) z = Normal(qz_m, qz_v).sample() px = self.decoder.px_decoder(z, batch_index, y) # y only used in VAEC - won't work for batch index not None px_scale = self.decoder.px_scale_decoder(px) return px_scale def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None): # Prepare for sampling x_ = torch.log(1 + x) ql_m, ql_v, library = self.l_encoder(x_) # Enumerate choices of label ys, xs, library_s, batch_index_s = ( broadcast_labels( y, x, library, batch_index, n_broadcast=self.n_labels ) ) if self.log_variational: xs_ = torch.log(1 + xs) # Sampling qz_m, qz_v, zs = self.z_encoder(xs_, batch_index_s, ys) px_scale, px_r, px_rate, px_dropout = self.decoder(self.dispersion, zs, library_s, batch_index_s, ys) reconst_loss = self._reconstruction_loss(xs, px_rate, px_r, px_dropout, batch_index_s, ys) # KL Divergence mean = torch.zeros_like(qz_m) scale = torch.ones_like(qz_v) kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1) kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)), Normal(local_l_mean, torch.sqrt(local_l_var))).sum(dim=1) return reconst_loss, kl_divergence_z + kl_divergence_l
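
# Illustrative usage sketch (not part of the original module): it assumes the CVAE class
# above and the scVI building blocks it relies on (Encoder, DecoderSCVI, broadcast_labels)
# behave as used in its forward pass. All sizes and tensors below are made up.
def _example_cvae_forward():
    n_input, n_batch, n_labels, batch_size = 100, 2, 3, 8
    model = CVAE(n_input, n_batch, n_labels)
    x = torch.randint(1, 20, (batch_size, n_input)).float()     # toy count matrix
    log_library = torch.log(x.sum(dim=1, keepdim=True))
    local_l_mean = log_library                                   # per-cell prior mean of log-library size
    local_l_var = torch.ones_like(log_library)                   # per-cell prior variance
    batch_index = torch.randint(0, n_batch, (batch_size, 1))
    y = torch.randint(0, n_labels, (batch_size, 1))
    reconst_loss, kl_divergence = model(x, local_l_mean, local_l_var, batch_index=batch_index, y=y)
    # The minibatch ELBO is the mean of the two per-cell terms returned by forward().
    return torch.mean(reconst_loss + kl_divergence)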
class VAEC(nn.Module):
    def __init__(self, n_input, n_labels, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1,
                 dispersion="gene", log_variational=True, reconstruction_loss="zinb",
                 n_batch=0, y_prior=None, use_cuda=False):
        super(VAEC, self).__init__()
        self.dispersion = dispersion
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = 0 if n_batch == 1 else n_batch
        self.n_labels = 0 if n_labels == 1 else n_labels
        if self.n_labels == 0:
            raise ValueError("VAEC is only implemented for > 1 label dataset")

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, ))

        self.z_encoder = Encoder(n_input, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers,
                                 dropout_rate=dropout_rate, n_cat=n_labels)
        self.l_encoder = Encoder(n_input, n_hidden=n_hidden, n_latent=1, n_layers=1,
                                 dropout_rate=dropout_rate)
        self.decoder = DecoderSCVI(n_latent, n_input, n_hidden=n_hidden, n_layers=n_layers,
                                   dropout_rate=dropout_rate, n_batch=n_batch, n_labels=n_labels)

        self.y_prior = y_prior if y_prior is not None else (1 / n_labels) * torch.ones(n_labels)

        self.classifier = Classifier(n_input, n_hidden, n_labels, n_layers=n_layers,
                                     dropout_rate=dropout_rate)

        self.use_cuda = use_cuda and torch.cuda.is_available()
        if self.use_cuda:
            self.cuda()
            self.y_prior = self.y_prior.cuda()

    def classify(self, x):
        x = torch.log(1 + x)
        return self.classifier(x)

    def sample_from_posterior_z(self, x, y):
        x = torch.log(1 + x)
        # Here we compute as little as possible to have q(z|x)
        qz_m, qz_v, z = self.z_encoder(x, y)
        return z

    def sample_from_posterior_l(self, x):
        x = torch.log(1 + x)
        # Here we compute as little as possible to have q(l|x)
        ql_m, ql_v, library = self.l_encoder(x)
        return library

    def get_sample_scale(self, x, y=None, batch_index=None):
        x = torch.log(1 + x)
        z = self.sample_from_posterior_z(x, y)
        px = self.decoder.px_decoder(z, batch_index, y)
        px_scale = self.decoder.px_scale_decoder(px)
        return px_scale

    def get_sample_rate(self, x, y=None, batch_index=None):
        x = torch.log(1 + x)
        z = self.sample_from_posterior_z(x, y)
        library = self.sample_from_posterior_l(x)
        px = self.decoder.px_decoder(z, batch_index, y)
        return self.decoder.px_scale_decoder(px) * torch.exp(library)

    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        is_labelled = False if y is None else True

        # Prepare for sampling
        xs, ys = (x, y)

        # Enumerate choices of label
        if not is_labelled:
            ys = enumerate_discrete(xs, self.n_labels)
            xs = xs.repeat(self.n_labels, 1)
            if batch_index is not None:
                batch_index = batch_index.repeat(self.n_labels, 1)
            local_l_var = local_l_var.repeat(self.n_labels, 1)
            local_l_mean = local_l_mean.repeat(self.n_labels, 1)
        else:
            ys = one_hot(ys, self.n_labels)

        xs_ = xs
        if self.log_variational:
            xs_ = torch.log(1 + xs_)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(xs_, ys)
        ql_m, ql_v, library = self.l_encoder(xs_)

        if self.dispersion == "gene-cell":
            px_scale, self.px_r, px_rate, px_dropout = self.decoder(self.dispersion, z, library,
                                                                    batch_index, y=ys)
        elif self.dispersion == "gene":
            px_scale, px_rate, px_dropout = self.decoder(self.dispersion, z, library, batch_index, y=ys)

        # Reconstruction Loss
        if self.reconstruction_loss == 'zinb':
            reconst_loss = -log_zinb_positive(xs, px_rate, torch.exp(self.px_r), px_dropout)
        elif self.reconstruction_loss == 'nb':
            reconst_loss = -log_nb_positive(xs, px_rate, torch.exp(self.px_r))

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)
        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                             Normal(local_l_mean, torch.sqrt(local_l_var))).sum(dim=1)
        kl_divergence = kl_divergence_z + kl_divergence_l

        if is_labelled:
            return reconst_loss, kl_divergence

        reconst_loss = reconst_loss.view(self.n_labels, -1)
        kl_divergence = kl_divergence.view(self.n_labels, -1)

        x_ = torch.log(1 + x) if self.log_variational else x
        probs = self.classifier(x_)
        reconst_loss = (reconst_loss.t() * probs).sum(dim=1)
        kl_divergence = (kl_divergence.t() * probs).sum(dim=1)
        kl_divergence += kl(Multinomial(probs=probs), Multinomial(probs=self.y_prior))

        return reconst_loss, kl_divergence
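
# Self-contained sketch (pure torch, made-up values) of the unlabelled-loss mixing used in
# VAEC.forward: losses computed for every (label, cell) pair are reshaped to
# (n_labels, batch_size) and averaged under a stand-in for the classifier posterior q(y|x).
def _example_label_mixing():
    n_labels, batch_size = 3, 4
    # one loss value per enumerated label and per cell, flattened label-major
    # exactly as VAEC.forward produces it when y is None
    per_label_loss = torch.rand(n_labels * batch_size)
    probs = torch.softmax(torch.rand(batch_size, n_labels), dim=-1)  # stand-in for q(y|x)
    mixed = (per_label_loss.view(n_labels, -1).t() * probs).sum(dim=1)
    assert mixed.shape == (batch_size,)  # one expected loss per cell
    return mixed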
class VAEF(VAE): r"""Variational auto-encoder model. Args: :n_input: Number of input genes for scRNA-seq data. :n_input_fish: Number of input genes for smFISH data :n_batch: Default: ``0``. :n_labels: Default: ``0``. :n_hidden: Number of hidden. Default: ``128``. :n_latent: Default: ``1``. :n_layers: Number of layers. Default: ``1``. :dropout_rate: Default: ``0.1``. :dispersion: Default: ``"gene"``. :log_variational: Default: ``True``. :reconstruction_loss: Default: ``"zinb"``. :reconstruction_loss_fish: Default: ``"poisson"``. Examples: >>> gene_dataset_seq = CortexDataset() >>> gene_dataset_fish = SmfishDataset() >>> vae = VAE(gene_dataset_seq.nb_genes, gene_dataset_fish.nb_genes, ... n_labels=gene_dataset.n_labels, use_cuda=True ) """ def __init__(self, n_input, indexes_fish_train=None, n_batch=0, n_labels=0, n_hidden=128, n_latent=10, n_layers=1, n_layers_decoder=1, dropout_rate=0.3, dispersion="gene", log_variational=True, reconstruction_loss="zinb", reconstruction_loss_fish="poisson", model_library=False): super(VAEF, self).__init__(n_input, dispersion=dispersion, n_latent=n_hidden, n_hidden=n_hidden, log_variational=log_variational, dropout_rate=dropout_rate, n_layers=1, reconstruction_loss=reconstruction_loss, n_batch=n_batch, n_labels=n_labels) self.n_input = n_input self.n_input_fish = len(indexes_fish_train) self.indexes_to_keep = indexes_fish_train self.reconstruction_loss_fish = reconstruction_loss_fish self.model_library = model_library self.n_latent = n_latent # First layer of the encoder isn't shared self.z_encoder_fish = Encoder(self.n_input_fish, n_hidden, n_hidden=n_hidden, n_layers=1, dropout_rate=dropout_rate) # The last layers of the encoder are shared self.z_final_encoder = Encoder(n_hidden, n_latent, n_hidden=n_hidden, n_layers=n_layers, dropout_rate=dropout_rate) self.l_encoder_fish = Encoder(self.n_input_fish, 1, n_hidden=n_hidden, n_layers=1, dropout_rate=dropout_rate) self.l_encoder = Encoder(n_input, 1, n_hidden=n_hidden, n_layers=1, dropout_rate=dropout_rate) self.decoder = DecoderSCVI(n_latent, n_input, n_layers=n_layers_decoder, n_hidden=n_hidden, n_cat_list=[n_batch], dropout_rate=dropout_rate) self.classifier = Classifier(n_latent, n_labels=n_labels, n_hidden=128, n_layers=3) def get_latents(self, x, y=None): return [self.sample_from_posterior_z(x, y)] def sample_from_posterior_z(self, x, y=None, mode="scRNA"): x = torch.log(1 + x) # First layer isn't shared if mode == "scRNA": z, _, _ = self.z_encoder(x) elif mode == "smFISH": z, _, _ = self.z_encoder_fish(x[:, self.indexes_to_keep]) # The last layers of the encoder are shared qz_m, qz_v, z = self.z_final_encoder(z) if not self.training: z = qz_m return z def sample_from_posterior_l(self, x, mode="scRNA"): x = torch.log(1 + x) if mode == "scRNA": ql_m, ql_v, library = self.l_encoder(x) elif mode == "smFISH": ql_m, ql_v, library = self.l_encoder_fish(x) return library def get_sample_scale(self, x, mode="scRNA", batch_index=None, y=None): z = self.sample_from_posterior_z(x, y, mode) # y only used in VAEC px = self.decoder.px_decoder(z, batch_index, y) # y only used in VAEC px_scale = self.decoder.px_scale_decoder(px) return px_scale def get_sample_rate(self, x, y=None, mode="scRNA"): if mode == "scRNA": library = torch.log(torch.sum(x, dim=1)).view(-1, 1) batch_index = torch.zeros_like(library) else: library = torch.log(torch.sum(x[:, self.indexes_to_keep], dim=1)).view(-1, 1) batch_index = torch.ones_like(library) if self.model_library: library = self.sample_from_posterior_l(x, mode=mode) px_scale = 
self.get_sample_scale(x, batch_index=batch_index, y=y) return px_scale * torch.exp(library) def get_sample_rate_fish(self, x, y=None): library = torch.log(torch.sum(x[:, self.indexes_to_keep], dim=1)).view(-1, 1) batch_index = torch.ones_like(library) if self.model_library: library = self.sample_from_posterior_l(x, mode="smFISH") px_scale = self.get_sample_scale(x, mode="smFISH", batch_index=batch_index, y=y) px_scale = px_scale[:, self.indexes_to_keep] / torch.sum( px_scale[:, self.indexes_to_keep], dim=1).view(-1, 1) return px_scale * torch.exp(library) def classify(self, x, mode="scRNA"): z = self.sample_from_posterior_z(x, mode) return self.classifier(z) def _reconstruction_loss(self, x, px_rate, px_r, px_dropout, batch_index, y, mode="scRNA", weighting=1): if self.dispersion == "gene-label": px_r = F.linear( one_hot(y, self.n_labels), self.px_r) # px_r gets transposed - last dimension is nb genes elif self.dispersion == "gene-batch": px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r) elif self.dispersion == "gene": px_r = self.px_r # Reconstruction Loss if mode == "scRNA": if self.reconstruction_loss == 'zinb': reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r), px_dropout) elif self.reconstruction_loss == 'nb': reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r)) else: if self.reconstruction_loss_fish == 'poisson': reconst_loss = -torch.sum(Poisson(px_rate).log_prob(x), dim=1) elif self.reconstruction_loss_fish == 'gaussian': reconst_loss = -torch.sum(Normal(px_rate, 10).log_prob(x), dim=1) return reconst_loss def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None, mode="scRNA", weighting=1): x_ = x if self.log_variational: x_ = torch.log(1 + x_) # Sampling if mode == "scRNA": qz_m, qz_v, z = self.z_encoder(x_) library = torch.log(torch.sum(x, dim=1)).view(-1, 1) batch_index = torch.zeros_like(library) if mode == "smFISH": qz_m, qz_v, z = self.z_encoder_fish(x_[:, self.indexes_to_keep]) library = torch.log(torch.sum(x[:, self.indexes_to_keep], dim=1)).view(-1, 1) batch_index = torch.ones_like(library) if self.model_library: if mode == "scRNA": ql_m, ql_v, library = self.l_encoder(x_) elif mode == "smFISH": ql_m, ql_v, library = self.l_encoder_fish( x_[:, self.indexes_to_keep]) qz_m, qz_v, z = self.z_final_encoder(z) px_scale, px_r, px_rate, px_dropout = self.decoder( self.dispersion, z, library, batch_index) # rescaling the expected frequencies if mode == "smFISH": if self.model_library: px_rate = px_scale[:, self.indexes_to_keep] * torch.exp(library) reconst_loss = self._reconstruction_loss( x[:, self.indexes_to_keep], px_rate, px_r, px_dropout, batch_index, y, mode) else: px_scale = px_scale[:, self.indexes_to_keep] / torch.sum( px_scale[:, self.indexes_to_keep], dim=1).view(-1, 1) px_rate = px_scale * torch.exp(library) reconst_loss = self._reconstruction_loss( x[:, self.indexes_to_keep], px_rate, px_r, px_dropout, batch_index, y, mode) else: reconst_loss = self._reconstruction_loss(x, px_rate, px_r, px_dropout, batch_index, y, mode, weighting) # KL Divergence mean = torch.zeros_like(qz_m) scale = torch.ones_like(qz_v) kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1) if self.model_library: kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)), Normal(local_l_mean, torch.sqrt(local_l_var))).sum(dim=1) kl_divergence = kl_divergence_z + kl_divergence_l else: kl_divergence = kl_divergence_z return reconst_loss, kl_divergence
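
# Self-contained sketch (pure torch, toy shapes and indices) of the smFISH rescaling performed in
# VAEF.get_sample_rate_fish and VAEF.forward: expected frequencies are restricted to the genes
# measured by smFISH, renormalised per cell, and scaled by the observed library size.
def _example_fish_rescaling():
    n_genes, batch_size = 50, 4
    indexes_to_keep = torch.arange(0, 10)                        # hypothetical smFISH gene panel
    px_scale = torch.softmax(torch.rand(batch_size, n_genes), dim=-1)
    x = torch.randint(1, 20, (batch_size, n_genes)).float()
    library = torch.log(torch.sum(x[:, indexes_to_keep], dim=1)).view(-1, 1)
    kept = px_scale[:, indexes_to_keep]
    kept = kept / kept.sum(dim=1, keepdim=True)                  # renormalise over the panel
    px_rate = kept * torch.exp(library)                          # expected counts on the panel
    return px_rate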
class VAE(nn.Module): r"""Variational auto-encoder model. Args: :n_input: Number of input genes. :n_batch: Default: ``0``. :n_labels: Default: ``0``. :n_hidden: Number of hidden. Default: ``128``. :n_latent: Default: ``1``. :n_layers: Number of layers. Default: ``1``. :dropout_rate: Default: ``0.1``. :dispersion: Default: ``"gene"``. :log_variational: Default: ``True``. :reconstruction_loss: Default: ``"zinb"``. Examples: >>> gene_dataset = CortexDataset() >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False, ... n_labels=gene_dataset.n_labels, use_cuda=True ) """ def __init__(self, n_input, n_batch=0, n_labels=0, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1, dispersion="gene", log_variational=True, reconstruction_loss="zinb"): super(VAE, self).__init__() self.dispersion = dispersion self.n_latent = n_latent self.log_variational = log_variational self.reconstruction_loss = reconstruction_loss # Automatically desactivate if useless self.n_batch = n_batch self.n_labels = n_labels self.n_latent_layers = 1 if self.dispersion == "gene": self.px_r = torch.nn.Parameter(torch.randn(n_input, )) elif self.dispersion == "gene-batch": self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch)) elif self.dispersion == "gene-label": self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels)) else: # gene-cell pass self.z_encoder = Encoder(n_input, n_latent, n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_rate) self.l_encoder = Encoder(n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate) self.decoder = DecoderSCVI(n_latent, n_input, n_cat_list=[n_batch], n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_rate) def get_latents(self, x, y=None): return [self.sample_from_posterior_z(x, y)] def sample_from_posterior_z(self, x, y=None): x = torch.log(1 + x) qz_m, qz_v, z = self.z_encoder(x, y) # y only used in VAEC if not self.training: z = qz_m return z def sample_from_posterior_l(self, x): x = torch.log(1 + x) ql_m, ql_v, library = self.l_encoder(x) if not self.training: library = ql_m return library def get_sample_scale(self, x, batch_index=None, y=None, n_samples=1): qz_m, qz_v, z = self.z_encoder(torch.log(1 + x), y) qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1))) qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1))) z = Normal(qz_m, qz_v).sample() px = self.decoder.px_decoder(z, batch_index, y) # y only used in VAEC - won't work for batch index not None px_scale = self.decoder.px_scale_decoder(px) return px_scale def get_sample_rate(self, x, batch_index=None, y=None, n_samples=1): ql_m, ql_v, library = self.l_encoder(torch.log(1 + x), y) ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1))) ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1))) library = Normal(ql_m, ql_v).sample() px_scale = self.get_sample_scale(x, batch_index=batch_index, y=y, n_samples=n_samples) return px_scale * torch.exp(library) def _reconstruction_loss(self, x, px_rate, px_r, px_dropout, batch_index, y): if self.dispersion == "gene-label": px_r = F.linear(one_hot(y, self.n_labels), self.px_r) # px_r gets transposed - last dimension is nb genes elif self.dispersion == "gene-batch": px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r) elif self.dispersion == "gene": px_r = self.px_r # Reconstruction Loss if self.reconstruction_loss == 'zinb': reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r), px_dropout) elif self.reconstruction_loss == 'nb': reconst_loss 
= -log_nb_positive(x, px_rate, torch.exp(px_r)) return reconst_loss def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None): # Parameters for z latent distribution x_ = x if self.log_variational: x_ = torch.log(1 + x_) # Sampling qz_m, qz_v, z = self.z_encoder(x_) ql_m, ql_v, library = self.l_encoder(x_) px_scale, px_r, px_rate, px_dropout = self.decoder(self.dispersion, z, library, batch_index) reconst_loss = self._reconstruction_loss(x, px_rate, px_r, px_dropout, batch_index, y) # KL Divergence mean = torch.zeros_like(qz_m) scale = torch.ones_like(qz_v) kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1) kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)), Normal(local_l_mean, torch.sqrt(local_l_var))).sum(dim=1) kl_divergence = kl_divergence_z + kl_divergence_l return reconst_loss, kl_divergence
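
# Pure-torch sketch of the Monte Carlo pattern used in VAE.get_sample_scale / get_sample_rate:
# the posterior parameters get an extra leading sample dimension before drawing, so downstream
# tensors carry the shape (n_samples, batch_size, ...). Values are made up; note that
# get_sample_scale passes qz_v directly as the Normal scale, whereas forward() uses
# torch.sqrt(qz_v); the sketch follows forward() here.
def _example_posterior_sampling(n_samples=5):
    batch_size, n_latent = 4, 10
    qz_m = torch.randn(batch_size, n_latent)
    qz_v = torch.rand(batch_size, n_latent) + 1e-4               # positive variances
    qz_m = qz_m.unsqueeze(0).expand((n_samples, batch_size, n_latent))
    qz_v = qz_v.unsqueeze(0).expand((n_samples, batch_size, n_latent))
    z = Normal(qz_m, torch.sqrt(qz_v)).sample()
    assert z.shape == (n_samples, batch_size, n_latent)
    return z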
class VAEF(VAE): r"""Variational auto-encoder model. Args: :n_input: Number of input genes for scRNA-seq data. :n_input_fish: Number of input genes for smFISH data :n_batch: Default: ``0``. :n_labels: Default: ``0``. :n_hidden: Number of hidden. Default: ``128``. :n_latent: Default: ``1``. :n_layers: Number of layers. Default: ``1``. :dropout_rate: Default: ``0.1``. :dispersion: Default: ``"gene"``. :log_variational: Default: ``True``. :reconstruction_loss: Default: ``"zinb"``. :reconstruction_loss_fish: Default: ``"poisson"``. Examples: >>> gene_dataset_seq = CortexDataset() >>> gene_dataset_fish = SmfishDataset() >>> vae = VAE(gene_dataset_seq.nb_genes, gene_dataset_fish.nb_genes, ... n_labels=gene_dataset.n_labels, use_cuda=True ) """ def __init__(self, n_input, indexes_fish_train=None, n_batch=0, n_labels=0, n_hidden=128, n_latent=10, n_layers=1, n_layers_decoder=1, dropout_rate=0.3, dispersion="gene", log_variational=True, reconstruction_loss="zinb", reconstruction_loss_fish="poisson", model_library=False): super().__init__(n_input, dispersion=dispersion, n_latent=n_hidden, n_hidden=n_hidden, log_variational=log_variational, dropout_rate=dropout_rate, n_layers=1, reconstruction_loss=reconstruction_loss, n_batch=n_batch, n_labels=n_labels) self.n_input = n_input self.n_input_fish = len(indexes_fish_train) self.indexes_to_keep = indexes_fish_train self.reconstruction_loss_fish = reconstruction_loss_fish self.model_library = model_library self.n_latent = n_latent # First layer of the encoder isn't shared self.z_encoder_fish = Encoder(self.n_input_fish, n_hidden, n_hidden=n_hidden, n_layers=1, dropout_rate=dropout_rate) # The last layers of the encoder are shared self.z_final_encoder = Encoder(n_hidden, n_latent, n_hidden=n_hidden, n_layers=n_layers, dropout_rate=dropout_rate) self.l_encoder_fish = Encoder(self.n_input_fish, 1, n_hidden=n_hidden, n_layers=1, dropout_rate=dropout_rate) self.l_encoder = Encoder(n_input, 1, n_hidden=n_hidden, n_layers=1, dropout_rate=dropout_rate) self.decoder = DecoderSCVI(n_latent, n_input, n_layers=n_layers_decoder, n_hidden=n_hidden, n_cat_list=[n_batch]) self.classifier = Classifier(n_latent, n_labels=n_labels, n_hidden=128, n_layers=3) def get_latents(self, x, y=None): r""" returns the result of ``sample_from_posterior_z`` inside a list :param x: tensor of values with shape ``(batch_size, n_input)`` :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` :return: one element list of tensor :rtype: list of :py:class:`torch.Tensor` """ return [self.sample_from_posterior_z(x, y)] def sample_from_posterior_z(self, x, y=None, mode="scRNA"): r""" samples the tensor of latent values from the posterior #doesn't really sample, returns the mean of the posterior distribution :param x: tensor of values with shape ``(batch_size, n_input)`` or ``(batch_size, n_input_fish)`` depending on the mode :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` :param mode: string that indicates the type of data we analyse :return: tensor of shape ``(batch_size, n_latent)`` :rtype: :py:class:`torch.Tensor` """ x = torch.log(1 + x) # First layer isn't shared if mode == "scRNA": z, _, _ = self.z_encoder(x) elif mode == "smFISH": z, _, _ = self.z_encoder_fish(x[:, self.indexes_to_keep]) # The last layers of the encoder are shared qz_m, qz_v, z = self.z_final_encoder(z) if not self.training: z = qz_m return z def sample_from_posterior_l(self, x, mode="scRNA"): r""" samples the tensor of library sizes from the posterior #doesn't really 
sample, returns the tensor of the means of the posterior distribution :param x: tensor of values with shape ``(batch_size, n_input)`` or ``(batch_size, n_input_fish)`` depending on the mode :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` :param mode: string that indicates the type of data we analyse :return: tensor of shape ``(batch_size, 1)`` :rtype: :py:class:`torch.Tensor` """ x = torch.log(1 + x) if mode == "scRNA": ql_m, ql_v, library = self.l_encoder(x) elif mode == "smFISH": ql_m, ql_v, library = self.l_encoder_fish(x) return library def get_sample_scale(self, x, mode="scRNA", batch_index=None, y=None): r"""Returns the tensor of predicted frequencies of expression :param x: tensor of values with shape ``(batch_size, n_input)`` or ``(batch_size, n_input_fish)`` depending on the mode :param mode: string that indicates the type of data we analyse :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` :return: tensor of predicted frequencies of expression with shape ``(batch_size, n_input)`` or ``(batch_size, n_input_fish)`` depending on the mode :rtype: :py:class:`torch.Tensor` """ z = self.sample_from_posterior_z(x, y, mode) # y only used in VAEC px = self.decoder.px_decoder(z, batch_index, y) # y only used in VAEC px_scale = self.decoder.px_scale_decoder(px) return px_scale def get_sample_rate(self, x, y=None, mode="scRNA"): r"""Returns the tensor of means of the negative binomial distribution :param x: tensor of values with shape ``(batch_size, n_input)`` or ``(batch_size, n_input_fish)`` depending on the mode :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` :param mode: string that indicates the type of data we analyse :return: tensor of means of the negative binomial distribution with shape ``(batch_size, n_input)`` or ``(batch_size, n_input_fish)`` depending on the mode :rtype: :py:class:`torch.Tensor` """ if mode == "scRNA": library = torch.log(torch.sum(x, dim=1)).view(-1, 1) batch_index = torch.zeros_like(library) else: library = torch.log(torch.sum(x[:, self.indexes_to_keep], dim=1)).view(-1, 1) batch_index = torch.ones_like(library) if self.model_library: library = self.sample_from_posterior_l(x, mode=mode) px_scale = self.get_sample_scale(x, batch_index=batch_index, y=y) return px_scale * torch.exp(library) def get_sample_rate_fish(self, x, y=None): r"""Returns the tensor of means of the negative binomial distribution :param x: tensor of values with shape ``(batch_size, n_input_fish)`` :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` :return: tensor of means of the negative binomial distribution with shape ``(batch_size, n_input_fish)`` :rtype: :py:class:`torch.Tensor` """ library = torch.log(torch.sum(x[:, self.indexes_to_keep], dim=1)).view(-1, 1) batch_index = torch.ones_like(library) if self.model_library: library = self.sample_from_posterior_l(x, mode="smFISH") px_scale = self.get_sample_scale(x, mode="smFISH", batch_index=batch_index, y=y) px_scale = px_scale[:, self.indexes_to_keep] / torch.sum( px_scale[:, self.indexes_to_keep], dim=1).view(-1, 1) return px_scale * torch.exp(library) def classify(self, x, mode="scRNA"): r"""Classifies the cells based on their latent representation #for each cell, it gives the probability distribution over the different labels :param x: tensor of values with shape (batch_size, n_input) or ``(batch_size, n_input_fish)`` depending on the 
mode :param mode: string that indicates the type of data we analyse :return: tensor of probabilities with shape``(batch_size, n_labels)`` :rtype: :py:class:`torch.Tensor` """ z = self.sample_from_posterior_z(x, mode) return self.classifier(z) def _reconstruction_loss(self, x, px_rate, px_r, px_dropout, batch_index, y, mode="scRNA", weighting=1): if self.dispersion == "gene-label": px_r = F.linear( one_hot(y, self.n_labels), self.px_r) # px_r gets transposed - last dimension is nb genes elif self.dispersion == "gene-batch": px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r) elif self.dispersion == "gene": px_r = self.px_r # Reconstruction Loss if mode == "scRNA": if self.reconstruction_loss == 'zinb': reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r), px_dropout) elif self.reconstruction_loss == 'nb': reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r)) else: if self.reconstruction_loss_fish == 'poisson': reconst_loss = -torch.sum(Poisson(px_rate).log_prob(x), dim=1) elif self.reconstruction_loss_fish == 'gaussian': reconst_loss = -torch.sum(Normal(px_rate, 10).log_prob(x), dim=1) return reconst_loss def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None, mode="scRNA", weighting=1): r""" Returns the reconstruction loss and the Kullback divergences :param x: tensor of values with shape ``(batch_size, n_input)`` or ``(batch_size, n_input_fish)`` depending on the mode :param local_l_mean: tensor of means of the prior distribution of latent variable l with shape (batch_size, 1) :param local_l_var: tensor of variances of the prior distribution of latent variable l with shape (batch_size, 1) :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` :param y: tensor of cell-types labels with shape (batch_size, n_labels) :param mode: string that indicates the type of data we analyse :param weighting: used in none of these methods :return: the reconstruction loss and the Kullback divergences :rtype: 2-tuple of :py:class:`torch.FloatTensor` """ x_ = x if self.log_variational: x_ = torch.log(1 + x_) # Sampling if mode == "scRNA": qz_m, qz_v, z = self.z_encoder(x_) library = torch.log(torch.sum(x, dim=1)).view(-1, 1) batch_index = torch.zeros_like(library) if mode == "smFISH": qz_m, qz_v, z = self.z_encoder_fish(x_[:, self.indexes_to_keep]) library = torch.log(torch.sum(x[:, self.indexes_to_keep], dim=1)).view(-1, 1) batch_index = torch.ones_like(library) if self.model_library: if mode == "scRNA": ql_m, ql_v, library = self.l_encoder(x_) elif mode == "smFISH": ql_m, ql_v, library = self.l_encoder_fish( x_[:, self.indexes_to_keep]) qz_m, qz_v, z = self.z_final_encoder(z) px_scale, px_r, px_rate, px_dropout = self.decoder( self.dispersion, z, library, batch_index) # rescaling the expected frequencies if mode == "smFISH": if self.model_library: px_rate = px_scale[:, self.indexes_to_keep] * torch.exp(library) reconst_loss = self._reconstruction_loss( x[:, self.indexes_to_keep], px_rate, px_r, px_dropout, batch_index, y, mode) else: px_scale = px_scale[:, self.indexes_to_keep] / torch.sum( px_scale[:, self.indexes_to_keep], dim=1).view(-1, 1) px_rate = px_scale * torch.exp(library) reconst_loss = self._reconstruction_loss( x[:, self.indexes_to_keep], px_rate, px_r, px_dropout, batch_index, y, mode) else: reconst_loss = self._reconstruction_loss(x, px_rate, px_r, px_dropout, batch_index, y, mode, weighting) # KL Divergence mean = torch.zeros_like(qz_m) scale = torch.ones_like(qz_v) kl_divergence_z = kl(Normal(qz_m, 
torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1) if self.model_library: kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)), Normal(local_l_mean, torch.sqrt(local_l_var))).sum(dim=1) kl_divergence = kl_divergence_z + kl_divergence_l else: kl_divergence = kl_divergence_z return reconst_loss, kl_divergence
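
# Pure-torch sketch of how VAEF picks the observed library size and the batch covariate from the
# data mode: scRNA cells use all genes and batch index 0, smFISH cells use only the measured
# panel and batch index 1. Tensors and the panel indices are toy values.
def _example_vaef_library(mode="smFISH"):
    n_genes, batch_size = 50, 4
    indexes_to_keep = torch.arange(0, 10)                        # hypothetical smFISH panel
    x = torch.randint(1, 20, (batch_size, n_genes)).float()
    if mode == "scRNA":
        library = torch.log(torch.sum(x, dim=1)).view(-1, 1)
        batch_index = torch.zeros_like(library)
    else:
        library = torch.log(torch.sum(x[:, indexes_to_keep], dim=1)).view(-1, 1)
        batch_index = torch.ones_like(library)
    return library, batch_index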
class VAE(nn.Module): r"""Variational auto-encoder model. :param n_input: Number of input genes :param n_batch: Number of batches :param n_labels: Number of labels :param n_hidden: Number of nodes per hidden layer :param n_latent: Dimensionality of the latent space :param n_layers: Number of hidden layers used for encoder and decoder NNs :param dropout_rate: Dropout rate for neural networks :param dispersion: One of the following * ``'gene'`` - dispersion parameter of NB is constant per gene across cells * ``'gene-batch'`` - dispersion can differ between different batches * ``'gene-label'`` - dispersion can differ between different labels * ``'gene-cell'`` - dispersion can differ for every gene in every cell :param log_variational: Log variational distribution :param reconstruction_loss: One of * ``'nb'`` - Negative binomial distribution * ``'zinb'`` - Zero-inflated negative binomial distribution Examples: >>> gene_dataset = CortexDataset() >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False, ... n_labels=gene_dataset.n_labels, use_cuda=True ) """ def __init__(self, n_input: int, n_batch: int = 0, n_labels: int = 0, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, dropout_rate: float = 0.1, dispersion: str = "gene", log_variational: bool = True, reconstruction_loss: str = "zinb"): super(VAE, self).__init__() self.dispersion = dispersion self.n_latent = n_latent self.log_variational = log_variational self.reconstruction_loss = reconstruction_loss # Automatically deactivate if useless self.n_batch = n_batch self.n_labels = n_labels self.n_latent_layers = 1 # not sure what this is for, no usages? if self.dispersion == "gene": self.px_r = torch.nn.Parameter(torch.randn(n_input, )) elif self.dispersion == "gene-batch": self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch)) elif self.dispersion == "gene-label": self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels)) else: # gene-cell pass # z encoder goes from the n_input-dimensional data to an n_latent-d # latent space representation self.z_encoder = Encoder(n_input, n_latent, n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_rate) # l encoder goes from n_input-dimensional data to 1-d library size self.l_encoder = Encoder(n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate) # decoder goes from n_latent-dimensional space to n_input-d data self.decoder = DecoderSCVI(n_latent, n_input, n_cat_list=[n_batch], n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_rate) def get_latents(self, x, y=None): r""" returns the result of ``sample_from_posterior_z`` inside a list :param x: tensor of values with shape ``(batch_size, n_input)`` :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` :return: one element list of tensor :rtype: list of :py:class:`torch.Tensor` """ return [self.sample_from_posterior_z(x, y)] def sample_from_posterior_z(self, x, y=None): r""" samples the tensor of latent values from the posterior #doesn't really sample, returns the means of the posterior distribution :param x: tensor of values with shape ``(batch_size, n_input)`` :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` :return: tensor of shape ``(batch_size, n_latent)`` :rtype: :py:class:`torch.Tensor` """ x = torch.log(1 + x) qz_m, qz_v, z = self.z_encoder(x, y) # y only used in VAEC if not self.training: z = qz_m return z def sample_from_posterior_l(self, x): r""" samples the tensor of library sizes from the posterior #doesn't really sample, 
returns the tensor of the means of the posterior distribution :param x: tensor of values with shape ``(batch_size, n_input)`` :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` :return: tensor of shape ``(batch_size, 1)`` :rtype: :py:class:`torch.Tensor` """ x = torch.log(1 + x) ql_m, ql_v, library = self.l_encoder(x) if not self.training: library = ql_m return library def get_sample_scale(self, x, batch_index=None, y=None, n_samples=1): r"""Returns the tensor of predicted frequencies of expression :param x: tensor of values with shape ``(batch_size, n_input)`` :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` :param n_samples: number of samples :return: tensor of predicted frequencies of expression with shape ``(batch_size, n_input)`` :rtype: :py:class:`torch.Tensor` """ qz_m, qz_v, z = self.z_encoder(torch.log(1 + x), y) qz_m = qz_m.unsqueeze(0).expand( (n_samples, qz_m.size(0), qz_m.size(1))) qz_v = qz_v.unsqueeze(0).expand( (n_samples, qz_v.size(0), qz_v.size(1))) z = Normal(qz_m, qz_v).sample() px = self.decoder.px_decoder( z, batch_index, y) # y only used in VAEC - won't work for batch index not None px_scale = self.decoder.px_scale_decoder(px) return px_scale def get_sample_rate(self, x, batch_index=None, y=None, n_samples=1): r"""Returns the tensor of means of the negative binomial distribution :param x: tensor of values with shape ``(batch_size, n_input)`` :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` :param n_samples: number of samples :return: tensor of means of the negative binomial distribution with shape ``(batch_size, n_input)`` :rtype: :py:class:`torch.Tensor` """ ql_m, ql_v, library = self.l_encoder(torch.log(1 + x), y) ql_m = ql_m.unsqueeze(0).expand( (n_samples, ql_m.size(0), ql_m.size(1))) ql_v = ql_v.unsqueeze(0).expand( (n_samples, ql_v.size(0), ql_v.size(1))) library = Normal(ql_m, ql_v).sample() px_scale = self.get_sample_scale(x, batch_index=batch_index, y=y, n_samples=n_samples) return px_scale * torch.exp(library) def _reconstruction_loss(self, x, px_rate, px_r, px_dropout, batch_index, y): if self.dispersion == "gene-label": px_r = F.linear( one_hot(y, self.n_labels), self.px_r) # px_r gets transposed - last dimension is nb genes elif self.dispersion == "gene-batch": px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r) elif self.dispersion == "gene": px_r = self.px_r # Reconstruction Loss if self.reconstruction_loss == 'zinb': reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r), px_dropout) elif self.reconstruction_loss == 'nb': reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r)) return reconst_loss def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None): r""" Returns the reconstruction loss and the Kullback divergences :param x: tensor of values with shape (batch_size, n_input) :param local_l_mean: tensor of means of the prior distribution of latent variable l with shape (batch_size, 1) :param local_l_var: tensor of variancess of the prior distribution of latent variable l with shape (batch_size, 1) :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` :param y: tensor of cell-types labels with shape (batch_size, n_labels) :return: the reconstruction loss and the Kullback divergences :rtype: 
2-tuple of :py:class:`torch.FloatTensor` """ # Parameters for z latent distribution x_ = x if self.log_variational: x_ = torch.log(1 + x_) # Sampling qz_m, qz_v, z = self.z_encoder(x_) ql_m, ql_v, library = self.l_encoder(x_) px_scale, px_r, px_rate, px_dropout = self.decoder( self.dispersion, z, library, batch_index) reconst_loss = self._reconstruction_loss(x, px_rate, px_r, px_dropout, batch_index, y) # KL Divergence mean = torch.zeros_like(qz_m) scale = torch.ones_like(qz_v) kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1) kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)), Normal(local_l_mean, torch.sqrt(local_l_var))).sum(dim=1) kl_divergence = kl_divergence_z + kl_divergence_l return reconst_loss, kl_divergence
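
# Pure-torch check of the z-space KL term used in VAE.forward: kl_divergence between
# N(qz_m, sqrt(qz_v)) and N(0, 1) matches the usual closed form
# 0.5 * (qz_v + qz_m**2 - 1 - log(qz_v)). Input tensors are made up.
def _example_kl_closed_form():
    qz_m = torch.randn(4, 10)
    qz_v = torch.rand(4, 10) + 1e-4
    kl_torch = kl(Normal(qz_m, torch.sqrt(qz_v)),
                  Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v))).sum(dim=1)
    kl_manual = (0.5 * (qz_v + qz_m ** 2 - 1 - torch.log(qz_v))).sum(dim=1)
    assert torch.allclose(kl_torch, kl_manual, atol=1e-4)
    return kl_torch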
class SVAEC(nn.Module):
    '''
    "Stacked" variational autoencoder for classification - SVAEC
    (from the stacked generative model M1 + M2).
    '''

    def __init__(self, n_input, n_labels, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1,
                 n_batch=0, y_prior=None, use_cuda=False):
        super(SVAEC, self).__init__()
        self.n_labels = n_labels
        self.n_input = n_input

        self.y_prior = y_prior if y_prior is not None else (1 / self.n_labels) * torch.ones(self.n_labels)
        # Automatically deactivate if useless
        self.n_batch = 0 if n_batch == 1 else n_batch
        self.z_encoder = Encoder(n_input, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers,
                                 dropout_rate=dropout_rate)
        self.l_encoder = Encoder(n_input, n_hidden=n_hidden, n_latent=1, n_layers=1,
                                 dropout_rate=dropout_rate)
        self.decoder = DecoderSCVI(n_latent, n_input, n_hidden=n_hidden, n_layers=n_layers,
                                   dropout_rate=dropout_rate, n_batch=n_batch)

        self.dispersion = 'gene'
        self.px_r = torch.nn.Parameter(torch.randn(n_input, ))

        # Classifier takes n_latent as input
        self.classifier = Classifier(n_latent, n_hidden, self.n_labels, n_layers, dropout_rate)

        self.encoder_z2_z1 = Encoder(n_input=n_latent, n_cat=self.n_labels, n_latent=n_latent,
                                     n_layers=n_layers)
        self.decoder_z1_z2 = Decoder(n_latent, n_latent, n_cat=self.n_labels, n_layers=n_layers)

        self.use_cuda = use_cuda and torch.cuda.is_available()
        if self.use_cuda:
            self.cuda()
            self.y_prior = self.y_prior.cuda()

    def classify(self, x):
        x_ = torch.log(1 + x)
        qz_m, _, z = self.z_encoder(x_)
        if self.training:
            return self.classifier(z)
        else:
            return self.classifier(qz_m)

    def sample_from_posterior_z(self, x, y=None):
        x = torch.log(1 + x)
        # Here we compute as little as possible to have q(z|x)
        qz_m, qz_v, z = self.z_encoder(x)
        return z

    def sample_from_posterior_l(self, x):
        x = torch.log(1 + x)
        # Here we compute as little as possible to have q(l|x)
        ql_m, ql_v, library = self.l_encoder(x)
        return library

    def get_sample_scale(self, x, y=None, batch_index=None):
        x = torch.log(1 + x)
        z = self.sample_from_posterior_z(x, y)
        px = self.decoder.px_decoder(z, batch_index, y)
        px_scale = self.decoder.px_scale_decoder(px)
        return px_scale

    def get_sample_rate(self, x, y=None, batch_index=None):
        x = torch.log(1 + x)
        z = self.sample_from_posterior_z(x)
        library = self.sample_from_posterior_l(x)
        px = self.decoder.px_decoder(z, batch_index, y)
        return self.decoder.px_scale_decoder(px) * torch.exp(library)

    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        is_labelled = False if y is None else True

        xs, ys = (x, y)
        xs_ = torch.log(1 + xs)
        qz1_m, qz1_v, z1_ = self.z_encoder(xs_)
        z1 = z1_

        # Enumerate choices of label
        if not is_labelled:
            ys = enumerate_discrete(xs, self.n_labels)
            xs = xs.repeat(self.n_labels, 1)
            if batch_index is not None:
                batch_index = batch_index.repeat(self.n_labels, 1)
            local_l_var = local_l_var.repeat(self.n_labels, 1)
            local_l_mean = local_l_mean.repeat(self.n_labels, 1)
            qz1_m = qz1_m.repeat(self.n_labels, 1)
            qz1_v = qz1_v.repeat(self.n_labels, 1)
            z1 = z1.repeat(self.n_labels, 1)
        else:
            ys = one_hot(ys, self.n_labels)

        xs_ = torch.log(1 + xs)

        qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1, ys)
        pz1_m, pz1_v = self.decoder_z1_z2(z2, ys)

        # Sampling
        ql_m, ql_v, library = self.l_encoder(xs_)  # let's keep that ind. of y

        px_scale, px_rate, px_dropout = self.decoder(self.dispersion, z1, library, batch_index)
        reconst_loss = -log_zinb_positive(xs, px_rate, torch.exp(self.px_r), px_dropout)

        # KL Divergence
        mean = torch.zeros_like(qz2_m)
        scale = torch.ones_like(qz2_v)
        kl_divergence_z2 = kl(Normal(qz2_m, torch.sqrt(qz2_v)), Normal(mean, scale)).sum(dim=1)
        loss_z1 = (-Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1)
                   + Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1)).sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                             Normal(local_l_mean, torch.sqrt(local_l_var))).sum(dim=1)
        kl_divergence = kl_divergence_z2 + loss_z1 + kl_divergence_l

        if is_labelled:
            return reconst_loss, kl_divergence

        reconst_loss = reconst_loss.view(self.n_labels, -1)
        kl_divergence = kl_divergence.view(self.n_labels, -1)

        probs = self.classifier(z1_)
        reconst_loss = (reconst_loss.t() * probs).sum(dim=1)
        kl_divergence = (kl_divergence.t() * probs).sum(dim=1)
        kl_divergence += kl(Multinomial(probs=probs), Multinomial(probs=self.y_prior))

        return reconst_loss, kl_divergence
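
# Pure-torch sketch of the z1 correction term assembled in SVAEC.forward: because z1 is drawn
# from q(z1|x) rather than from the conditional prior p(z1|z2, y), the objective keeps the
# difference of the two log-densities evaluated at the sampled point. All tensors are toy values.
def _example_svaec_z1_term():
    batch_size, n_latent = 4, 10
    z1 = torch.randn(batch_size, n_latent)
    qz1_m, qz1_v = torch.randn(batch_size, n_latent), torch.rand(batch_size, n_latent) + 1e-4
    pz1_m, pz1_v = torch.randn(batch_size, n_latent), torch.rand(batch_size, n_latent) + 1e-4
    loss_z1 = (-Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1)
               + Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1)).sum(dim=1)
    return loss_z1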
class VAE(nn.Module):
    def __init__(self, n_input, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1,
                 dispersion="gene", log_variational=True, reconstruction_loss="zinb",
                 n_batch=0, n_labels=0, use_cuda=False):
        super(VAE, self).__init__()
        self.dispersion = dispersion
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = 0 if n_batch == 1 else n_batch
        self.n_labels = n_labels

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, ))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
        else:  # gene-cell
            pass

        self.z_encoder = Encoder(n_input, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers,
                                 dropout_rate=dropout_rate)
        self.l_encoder = Encoder(n_input, n_hidden=n_hidden, n_latent=1, n_layers=1,
                                 dropout_rate=dropout_rate)
        self.decoder = DecoderSCVI(n_latent, n_input, n_hidden=n_hidden, n_layers=n_layers,
                                   dropout_rate=dropout_rate, n_batch=n_batch)

        self.use_cuda = use_cuda and torch.cuda.is_available()
        if self.use_cuda:
            self.cuda()

    def sample_from_posterior_z(self, x, y=None):
        x = torch.log(1 + x)
        # Here we compute as little as possible to have q(z|x)
        qz_m, qz_v, z = self.z_encoder(x)
        return z

    def sample_from_posterior_l(self, x):
        x = torch.log(1 + x)
        # Here we compute as little as possible to have q(l|x)
        ql_m, ql_v, library = self.l_encoder(x)
        return library

    def get_sample_scale(self, x, y=None, batch_index=None):
        x = torch.log(1 + x)
        z = self.sample_from_posterior_z(x)
        px = self.decoder.px_decoder(z, batch_index)
        px_scale = self.decoder.px_scale_decoder(px)
        return px_scale

    def get_sample_rate(self, x, y=None, batch_index=None):
        x = torch.log(1 + x)
        z = self.sample_from_posterior_z(x)
        library = self.sample_from_posterior_l(x)
        px = self.decoder.px_decoder(z, batch_index)
        return self.decoder.px_scale_decoder(px) * torch.exp(library)

    def sample(self, z):
        return self.px_scale_decoder(z)

    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):  # same signature as loss
        # Parameters for z latent distribution
        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x_)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(x_)
        ql_m, ql_v, library = self.l_encoder(x_)

        if self.dispersion == "gene-cell":
            px_scale, self.px_r, px_rate, px_dropout = self.decoder(self.dispersion, z, library, batch_index)
        else:  # self.dispersion == "gene", "gene-batch", "gene-label"
            px_scale, px_rate, px_dropout = self.decoder(self.dispersion, z, library, batch_index)

        if self.dispersion == "gene-label":
            px_r = F.linear(one_hot(y, self.n_labels), self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        else:
            px_r = self.px_r

        # Reconstruction Loss
        if self.reconstruction_loss == 'zinb':
            reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r), px_dropout)
        elif self.reconstruction_loss == 'nb':
            reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)
        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                             Normal(local_l_mean, torch.sqrt(local_l_var))).sum(dim=1)
        kl_divergence = kl_divergence_z + kl_divergence_l

        return reconst_loss, kl_divergence
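
# Minimal training-loop sketch (not part of the original source): it assumes an iterable of
# minibatches yielding the tensors that forward() expects; `toy_batches` and the chosen
# optimizer settings are hypothetical. Any of the model classes above with this forward
# signature could be passed in.
def _example_training_loop(model, toy_batches, n_epochs=1, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(n_epochs):
        for x, local_l_mean, local_l_var, batch_index in toy_batches:
            reconst_loss, kl_divergence = model(x, local_l_mean, local_l_var, batch_index=batch_index)
            loss = torch.mean(reconst_loss + kl_divergence)  # minibatch ELBO (negated terms)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return model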