class VAE(AbstractVAE): """ Variational auto-encoder model. This is an implementation of the scVI model descibed in [Lopez18]_ Parameters ---------- n_input Number of input genes n_batch Number of batches, if 0, no batch correction is performed. n_labels Number of labels n_hidden Number of nodes per hidden layer n_latent Dimensionality of the latent space n_layers Number of hidden layers used for encoder and decoder NNs n_continuous_cov Number of continuous covarites n_cats_per_cov Number of categories for each extra categorical covariate dropout_rate Dropout rate for neural networks dispersion One of the following * ``'gene'`` - dispersion parameter of NB is constant per gene across cells * ``'gene-batch'`` - dispersion can differ between different batches * ``'gene-label'`` - dispersion can differ between different labels * ``'gene-cell'`` - dispersion can differ for every gene in every cell log_variational Log(data+1) prior to encoding for numerical stability. Not normalization. gene_likelihood One of * ``'nb'`` - Negative binomial distribution * ``'zinb'`` - Zero-inflated negative binomial distribution * ``'poisson'`` - Poisson distribution latent_distribution One of * ``'normal'`` - Isotropic normal * ``'ln'`` - Logistic normal with normal params N(0, 1) encode_covariates Whether to concatenate covariates to expression in encoder deeply_inject_covariates Whether to concatenate covariates into output of hidden layers in encoder/decoder. This option only applies when `n_layers` > 1. The covariates are concatenated to the input of subsequent hidden layers. use_layer_norm Whether to use layer norm in layers use_observed_lib_size Use observed library size for RNA as scaling factor in mean of conditional distribution """ def __init__( self, n_input: int, n_batch: int = 0, n_labels: int = 0, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, n_continuous_cov: int = 0, n_cats_per_cov: Optional[Iterable[int]] = None, dropout_rate: float = 0.1, dispersion: str = "gene", log_variational: bool = True, gene_likelihood: str = "zinb", latent_distribution: str = "normal", encode_covariates: bool = False, deeply_inject_covariates: bool = True, use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both", use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none", use_observed_lib_size: bool = True, ): super().__init__() self.dispersion = dispersion self.n_latent = n_latent self.log_variational = log_variational self.gene_likelihood = gene_likelihood # Automatically deactivate if useless self.n_batch = n_batch self.n_labels = n_labels self.latent_distribution = latent_distribution self.encode_covariates = encode_covariates self.use_observed_lib_size = use_observed_lib_size if self.dispersion == "gene": self.px_r = torch.nn.Parameter(torch.randn(n_input)) elif self.dispersion == "gene-batch": self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch)) elif self.dispersion == "gene-label": self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels)) elif self.dispersion == "gene-cell": pass else: raise ValueError( "dispersion must be one of ['gene', 'gene-batch'," " 'gene-label', 'gene-cell'], but input was " "{}.format(self.dispersion)" ) use_batch_norm_encoder = use_batch_norm == "encoder" or use_batch_norm == "both" use_batch_norm_decoder = use_batch_norm == "decoder" or use_batch_norm == "both" use_layer_norm_encoder = use_layer_norm == "encoder" or use_layer_norm == "both" use_layer_norm_decoder = use_layer_norm == "decoder" or use_layer_norm == "both" # z encoder goes from the n_input-dimensional data to an n_latent-d # latent space representation n_input_encoder = n_input + n_continuous_cov * encode_covariates cat_list = [n_batch] + list([] if n_cats_per_cov is None else n_cats_per_cov) encoder_cat_list = cat_list if encode_covariates else None self.z_encoder = Encoder( n_input_encoder, n_latent, n_cat_list=encoder_cat_list, n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_rate, distribution=latent_distribution, inject_covariates=deeply_inject_covariates, use_batch_norm=use_batch_norm_encoder, use_layer_norm=use_layer_norm_encoder, ) # l encoder goes from n_input-dimensional data to 1-d library size self.l_encoder = Encoder( n_input_encoder, 1, n_layers=1, n_cat_list=encoder_cat_list, n_hidden=n_hidden, dropout_rate=dropout_rate, inject_covariates=deeply_inject_covariates, use_batch_norm=use_batch_norm_encoder, use_layer_norm=use_layer_norm_encoder, ) # decoder goes from n_latent-dimensional space to n_input-d data n_input_decoder = n_latent + n_continuous_cov self.decoder = DecoderSCVI( n_input_decoder, n_input, n_cat_list=cat_list, n_layers=n_layers, n_hidden=n_hidden, inject_covariates=deeply_inject_covariates, use_batch_norm=use_batch_norm_decoder, use_layer_norm=use_layer_norm_decoder, ) def _get_inference_input(self, tensors): x = tensors[_CONSTANTS.X_KEY] batch_index = tensors[_CONSTANTS.BATCH_KEY] cont_key = _CONSTANTS.CONT_COVS_KEY cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None cat_key = _CONSTANTS.CAT_COVS_KEY cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None input_dict = dict( x=x, batch_index=batch_index, cont_covs=cont_covs, cat_covs=cat_covs ) return input_dict def _get_generative_input(self, tensors, inference_outputs): z = inference_outputs["z"] library = inference_outputs["library"] batch_index = tensors[_CONSTANTS.BATCH_KEY] y = tensors[_CONSTANTS.LABELS_KEY] cont_key = _CONSTANTS.CONT_COVS_KEY cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None cat_key = _CONSTANTS.CAT_COVS_KEY cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None input_dict = { "z": z, "library": library, "batch_index": batch_index, "y": y, "cont_covs": cont_covs, "cat_covs": cat_covs, } return input_dict @auto_move_data def inference(self, x, batch_index, cont_covs=None, cat_covs=None, n_samples=1): """ High level inference method. Runs the inference (encoder) model. """ x_ = x if self.use_observed_lib_size: library = torch.log(x.sum(1)).unsqueeze(1) if self.log_variational: x_ = torch.log(1 + x_) if cont_covs is not None and self.encode_covariates is True: encoder_input = torch.cat((x_, cont_covs), dim=-1) else: encoder_input = x_ if cat_covs is not None and self.encode_covariates is True: categorical_input = torch.split(cat_covs, 1, dim=1) else: categorical_input = tuple() qz_m, qz_v, z = self.z_encoder(encoder_input, batch_index, *categorical_input) ql_m, ql_v, library_encoded = self.l_encoder( encoder_input, batch_index, *categorical_input ) if not self.use_observed_lib_size: library = library_encoded if n_samples > 1: qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1))) qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1))) # when z is normal, untran_z == z untran_z = Normal(qz_m, qz_v.sqrt()).sample() z = self.z_encoder.z_transformation(untran_z) ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1))) ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1))) if self.use_observed_lib_size: library = library.unsqueeze(0).expand( (n_samples, library.size(0), library.size(1)) ) else: library = Normal(ql_m, ql_v.sqrt()).sample() outputs = dict(z=z, qz_m=qz_m, qz_v=qz_v, ql_m=ql_m, ql_v=ql_v, library=library) return outputs @auto_move_data def generative( self, z, library, batch_index, cont_covs=None, cat_covs=None, y=None ): """Runs the generative model.""" # TODO: refactor forward function to not rely on y decoder_input = z if cont_covs is None else torch.cat([z, cont_covs], dim=-1) if cat_covs is not None: categorical_input = torch.split(cat_covs, 1, dim=1) else: categorical_input = tuple() px_scale, px_r, px_rate, px_dropout = self.decoder( self.dispersion, decoder_input, library, batch_index, *categorical_input, y ) if self.dispersion == "gene-label": px_r = F.linear( one_hot(y, self.n_labels), self.px_r ) # px_r gets transposed - last dimension is nb genes elif self.dispersion == "gene-batch": px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r) elif self.dispersion == "gene": px_r = self.px_r px_r = torch.exp(px_r) return dict( px_scale=px_scale, px_r=px_r, px_rate=px_rate, px_dropout=px_dropout ) def loss( self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0, ): x = tensors[_CONSTANTS.X_KEY] local_l_mean = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY] local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY] qz_m = inference_outputs["qz_m"] qz_v = inference_outputs["qz_v"] ql_m = inference_outputs["ql_m"] ql_v = inference_outputs["ql_v"] px_rate = generative_outputs["px_rate"] px_r = generative_outputs["px_r"] px_dropout = generative_outputs["px_dropout"] mean = torch.zeros_like(qz_m) scale = torch.ones_like(qz_v) kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum( dim=1 ) if not self.use_observed_lib_size: kl_divergence_l = kl( Normal(ql_m, torch.sqrt(ql_v)), Normal(local_l_mean, torch.sqrt(local_l_var)), ).sum(dim=1) else: kl_divergence_l = 0.0 reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout) kl_local_for_warmup = kl_divergence_l kl_local_no_warmup = kl_divergence_z weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup loss = torch.mean(reconst_loss + weighted_kl_local) kl_local = dict( kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z ) kl_global = 0.0 return SCVILoss(loss, reconst_loss, kl_local, kl_global) @torch.no_grad() def sample( self, tensors, n_samples=1, library_size=1, ) -> np.ndarray: r""" Generate observation samples from the posterior predictive distribution. The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`. Parameters ---------- tensors Tensors dict n_samples Number of required samples for each cell library_size Library size to scale scamples to Returns ------- x_new : :py:class:`torch.Tensor` tensor with shape (n_cells, n_genes, n_samples) """ inference_kwargs = dict(n_samples=n_samples) inference_outputs, generative_outputs, = self.forward( tensors, inference_kwargs=inference_kwargs, compute_loss=False, ) px_r = generative_outputs["px_r"] px_rate = generative_outputs["px_rate"] px_dropout = generative_outputs["px_dropout"] if self.gene_likelihood == "poisson": l_train = px_rate l_train = torch.clamp(l_train, max=1e8) dist = torch.distributions.Poisson( l_train ) # Shape : (n_samples, n_cells_batch, n_genes) elif self.gene_likelihood == "nb": dist = NegativeBinomial(mu=px_rate, theta=px_r) elif self.gene_likelihood == "zinb": dist = ZeroInflatedNegativeBinomial( mu=px_rate, theta=px_r, zi_logits=px_dropout ) else: raise ValueError( "{} reconstruction error not handled right now".format( self.model.gene_likelihood ) ) if n_samples > 1: exprs = dist.sample().permute( [1, 2, 0] ) # Shape : (n_cells_batch, n_genes, n_samples) else: exprs = dist.sample() return exprs.cpu() def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout) -> torch.Tensor: if self.gene_likelihood == "zinb": reconst_loss = ( -ZeroInflatedNegativeBinomial( mu=px_rate, theta=px_r, zi_logits=px_dropout ) .log_prob(x) .sum(dim=-1) ) elif self.gene_likelihood == "nb": reconst_loss = ( -NegativeBinomial(mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1) ) elif self.gene_likelihood == "poisson": reconst_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1) return reconst_loss @torch.no_grad() def marginal_ll(self, tensors, n_mc_samples): sample_batch = tensors[_CONSTANTS.X_KEY] local_l_mean = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY] local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY] to_sum = torch.zeros(sample_batch.size()[0], n_mc_samples) for i in range(n_mc_samples): # Distribution parameters and sampled variables inference_outputs, generative_outputs, losses = self.forward(tensors) qz_m = inference_outputs["qz_m"] qz_v = inference_outputs["qz_v"] z = inference_outputs["z"] ql_m = inference_outputs["ql_m"] ql_v = inference_outputs["ql_v"] library = inference_outputs["library"] # Reconstruction Loss reconst_loss = losses.reconstruction_loss # Log-probabilities p_l = Normal(local_l_mean, local_l_var.sqrt()).log_prob(library).sum(dim=-1) p_z = ( Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v)) .log_prob(z) .sum(dim=-1) ) p_x_zl = -reconst_loss q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1) q_l_x = Normal(ql_m, ql_v.sqrt()).log_prob(library).sum(dim=-1) to_sum[:, i] = p_z + p_l + p_x_zl - q_z_x - q_l_x batch_log_lkl = logsumexp(to_sum, dim=-1) - np.log(n_mc_samples) log_lkl = torch.sum(batch_log_lkl).item() return log_lkl
class PEAKVAE(BaseModuleClass): """ Variational auto-encoder model for ATAC-seq data. This is an implementation of the peakVI model descibed in. Parameters ---------- n_input_regions Number of input regions. n_batch Number of batches, if 0, no batch correction is performed. n_hidden Number of nodes per hidden layer. If `None`, defaults to square root of number of regions. n_latent Dimensionality of the latent space. If `None`, defaults to square root of `n_hidden`. n_layers_encoder Number of hidden layers used for encoder NN. n_layers_decoder Number of hidden layers used for decoder NN. dropout_rate Dropout rate for neural networks model_depth Model library size factors or not. region_factors Include region-specific factors in the model use_batch_norm One of the following * ``'encoder'`` - use batch normalization in the encoder only * ``'decoder'`` - use batch normalization in the decoder only * ``'none'`` - do not use batch normalization (default) * ``'both'`` - use batch normalization in both the encoder and decoder use_layer_norm One of the following * ``'encoder'`` - use layer normalization in the encoder only * ``'decoder'`` - use layer normalization in the decoder only * ``'none'`` - do not use layer normalization * ``'both'`` - use layer normalization in both the encoder and decoder (default) latent_distribution which latent distribution to use, options are * ``'normal'`` - Normal distribution (default) * ``'ln'`` - Logistic normal distribution (Normal(0, I) transformed by softmax) deeply_inject_covariates Whether to deeply inject covariates into all layers of the decoder. If False (default), covairates will only be included in the input layer. """ def __init__( self, n_input_regions: int, n_batch: int = 0, n_hidden: Optional[int] = None, n_latent: Optional[int] = None, n_layers_encoder: int = 2, n_layers_decoder: int = 2, n_continuous_cov: int = 0, n_cats_per_cov: Optional[Iterable[int]] = None, dropout_rate: float = 0.1, model_depth: bool = True, region_factors: bool = True, use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "none", use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "both", latent_distribution: str = "normal", deeply_inject_covariates: bool = False, encode_covariates: bool = False, ): super().__init__() self.n_input_regions = n_input_regions self.n_hidden = (int(np.sqrt(self.n_input_regions)) if n_hidden is None else n_hidden) self.n_latent = int(np.sqrt( self.n_hidden)) if n_latent is None else n_latent self.n_layers_encoder = n_layers_encoder self.n_layers_decoder = n_layers_decoder self.n_cats_per_cov = n_cats_per_cov self.n_continuous_cov = n_continuous_cov self.model_depth = model_depth self.dropout_rate = dropout_rate self.latent_distribution = latent_distribution self.use_batch_norm_encoder = use_batch_norm in ("encoder", "both") self.use_batch_norm_decoder = use_batch_norm in ("decoder", "both") self.use_layer_norm_encoder = use_layer_norm in ("encoder", "both") self.use_layer_norm_decoder = use_layer_norm in ("decoder", "both") self.deeply_inject_covariates = deeply_inject_covariates self.encode_covariates = encode_covariates cat_list = ([n_batch] + list(n_cats_per_cov) if n_cats_per_cov is not None else []) n_input_encoder = self.n_input_regions + n_continuous_cov * encode_covariates encoder_cat_list = cat_list if encode_covariates else None self.z_encoder = Encoder( n_input=n_input_encoder, n_layers=self.n_layers_encoder, n_output=self.n_latent, n_hidden=self.n_hidden, n_cat_list=encoder_cat_list, dropout_rate=self.dropout_rate, activation_fn=torch.nn.LeakyReLU, distribution=self.latent_distribution, var_eps=0, use_batch_norm=self.use_batch_norm_encoder, use_layer_norm=self.use_layer_norm_encoder, ) self.z_decoder = Decoder( n_input=self.n_latent + self.n_continuous_cov, n_output=n_input_regions, n_hidden=self.n_hidden, n_cat_list=cat_list, n_layers=self.n_layers_decoder, use_batch_norm=self.use_batch_norm_decoder, use_layer_norm=self.use_layer_norm_decoder, deep_inject_covariates=self.deeply_inject_covariates, ) self.d_encoder = None if self.model_depth: # Decoder class to avoid variational split self.d_encoder = Decoder( n_input=n_input_encoder, n_output=1, n_hidden=self.n_hidden, n_cat_list=encoder_cat_list, n_layers=self.n_layers_encoder, ) self.region_factors = None if region_factors: self.region_factors = torch.nn.Parameter( torch.zeros(self.n_input_regions)) def _get_inference_input(self, tensors): x = tensors[_CONSTANTS.X_KEY] batch_index = tensors[_CONSTANTS.BATCH_KEY] cont_covs = tensors.get(_CONSTANTS.CONT_COVS_KEY) cat_covs = tensors.get(_CONSTANTS.CAT_COVS_KEY) input_dict = dict( x=x, batch_index=batch_index, cont_covs=cont_covs, cat_covs=cat_covs, ) return input_dict def _get_generative_input(self, tensors, inference_outputs, transform_batch=None): z = inference_outputs["z"] qz_m = inference_outputs["qz_m"] batch_index = tensors[_CONSTANTS.BATCH_KEY] cont_covs = tensors.get(_CONSTANTS.CONT_COVS_KEY) cat_covs = tensors.get(_CONSTANTS.CAT_COVS_KEY) if transform_batch is not None: batch_index = torch.ones_like(batch_index) * transform_batch input_dict = { "z": z, "qz_m": qz_m, "batch_index": batch_index, "cont_covs": cont_covs, "cat_covs": cat_covs, } return input_dict def get_reconstruction_loss(self, p, d, f, x): rl = torch.nn.BCELoss(reduction="none")(p * d * f, (x > 0).float()).sum(dim=-1) return rl @auto_move_data def inference( self, x, batch_index, cont_covs, cat_covs, n_samples=1, ) -> Dict[str, torch.Tensor]: """Helper function used in forward pass.""" if cat_covs is not None and self.encode_covariates is True: categorical_input = torch.split(cat_covs, 1, dim=1) else: categorical_input = tuple() if cont_covs is not None and self.encode_covariates is True: encoder_input = torch.cat([x, cont_covs], dim=-1) else: encoder_input = x # if encode_covariates is False, cat_list to init encoder is None, so # batch_index is not used (or categorical_input, but it's empty) qz_m, qz_v, z = self.z_encoder(encoder_input, batch_index, *categorical_input) d = (self.d_encoder(encoder_input, batch_index, *categorical_input) if self.model_depth else 1) if n_samples > 1: qz_m = qz_m.unsqueeze(0).expand( (n_samples, qz_m.size(0), qz_m.size(1))) qz_v = qz_v.unsqueeze(0).expand( (n_samples, qz_v.size(0), qz_v.size(1))) # when z is normal, untran_z == z untran_z = Normal(qz_m, qz_v.sqrt()).sample() z = self.z_encoder.z_transformation(untran_z) return dict(d=d, qz_m=qz_m, qz_v=qz_v, z=z) @auto_move_data def generative( self, z, qz_m, batch_index, cont_covs=None, cat_covs=None, use_z_mean=False, ): """Runs the generative model.""" if cat_covs is not None: categorical_input = torch.split(cat_covs, 1, dim=1) else: categorical_input = tuple() latent = z if not use_z_mean else qz_m decoder_input = (latent if cont_covs is None else torch.cat( [latent, cont_covs], dim=-1)) p = self.z_decoder(decoder_input, batch_index, *categorical_input) return dict(p=p) def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0): x = tensors[_CONSTANTS.X_KEY] qz_m = inference_outputs["qz_m"] qz_v = inference_outputs["qz_v"] d = inference_outputs["d"] p = generative_outputs["p"] kld = kl_divergence( Normal(qz_m, torch.sqrt(qz_v)), Normal(0, 1), ).sum(dim=1) f = torch.sigmoid( self.region_factors) if self.region_factors is not None else 1 rl = self.get_reconstruction_loss(p, d, f, x) loss = (rl.sum() + kld * kl_weight).sum() return LossRecorder(loss, rl, kld, kl_global=0.0)