def __init__( self, n_input: int, n_batch: int = 0, n_labels: int = 0, n_hidden: int = 128, n_latent: int = 10, n_layers_encoder: int = 1, dropout_rate: float = 0.1, dispersion: str = "gene", log_variational: bool = True, gene_likelihood: str = "nb", use_batch_norm: bool = True, bias: bool = False, latent_distribution: str = "normal", **vae_kwargs, ): super().__init__( n_input=n_input, n_batch=n_batch, n_labels=n_labels, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers_encoder, dropout_rate=dropout_rate, dispersion=dispersion, log_variational=log_variational, gene_likelihood=gene_likelihood, latent_distribution=latent_distribution, use_observed_lib_size=False, **vae_kwargs, ) self.use_batch_norm = use_batch_norm self.z_encoder = Encoder( n_input, n_latent, n_layers=n_layers_encoder, n_hidden=n_hidden, dropout_rate=dropout_rate, distribution=latent_distribution, use_batch_norm=True, use_layer_norm=False, ) self.l_encoder = Encoder( n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate, use_batch_norm=True, use_layer_norm=False, ) self.decoder = LinearDecoderSCVI( n_latent, n_input, n_cat_list=[n_batch], use_batch_norm=use_batch_norm, use_layer_norm=False, bias=bias, )
def __init__( self, n_input: int, gmv_mask: np.ndarray, n_batch: int = 0, n_hidden: int = 600, n_layers: int = 3, n_continuous_cov: int = 0, n_cats_per_cov: Optional[Iterable[int]] = None, dropout_rate: float = 0.1, z_dropout: float = 0, gene_likelihood: str = "zinb", encode_covariates: bool = False ): super().__init__(n_input=n_input) self.mask = gmv_mask self.n_genes = gmv_mask.shape[0] self.n_gmvs = gmv_mask.shape[1] self.n_batch = n_batch self.gene_likelihood = gene_likelihood n_input_encoder = self.n_genes + n_continuous_cov * encode_covariates cat_list = [n_batch] + list([] if n_cats_per_cov is None else n_cats_per_cov) encoder_cat_list = cat_list if encode_covariates else None self.latent_distribution = "normal" # Model parameters (gene dispersions) self.px_r = torch.nn.Parameter(torch.randn(self.n_genes)) # GMV activity encoder self.z_encoder = Encoder( n_input_encoder, self.n_gmvs, n_cat_list = [n_batch] if encode_covariates else None, n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_rate, use_batch_norm=True ) # Add dropout to mean and var of z_encoder self.z_encoder.mean_encoder = torch.nn.Sequential(self.z_encoder.mean_encoder, torch.nn.Dropout(p=z_dropout)) self.z_encoder.var_encoder = torch.nn.Sequential(self.z_encoder.var_encoder, torch.nn.Dropout(p=z_dropout)) # Scaling factor encoder self.l_encoder = Encoder( n_input_encoder, 1, n_cat_list=[n_batch] if encode_covariates else None, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate, use_batch_norm=True ) # Sparse decoder to decode GMV activities # TO DO: Add continuous covariates to dim self.decoder = DecoderVEGACount( self.mask.T, n_cat_list = cat_list, n_continuous_cov = n_continuous_cov, use_batch_norm = False, use_layer_norm = False, bias = False )
def __init__(
    self,
    n_input: int,
    n_batch: int = 0,
    n_hidden: int = 128,
    n_latent: int = 10,
    n_layers: int = 1,
    dropout_rate: float = 0.1,
):
    super().__init__()
    self.n_latent = n_latent
    self.n_batch = n_batch
    # this is needed to comply with some requirement of the VAEMixin class
    self.latent_distribution = "normal"

    # setup the parameters of your generative model, as well as your inference model
    # z encoder goes from the n_input-dimensional data to an n_latent-d
    # latent space representation
    self.z_encoder = Encoder(
        n_input,
        n_latent,
        n_layers=n_layers,
        n_hidden=n_hidden,
        dropout_rate=dropout_rate,
    )
    # decoder goes from n_latent-dimensional space to n_input-d data
    self.decoder = BinaryDecoder(
        n_latent,
        n_input,
        n_layers=n_layers,
        n_hidden=n_hidden,
    )
def __init__(self, n_input: int, n_latent: int, n_hidden: int, n_layers: int):
    super().__init__()
    self.n_input = n_input
    self.n_latent = n_latent
    self.epsilon = 5.0e-3
    # z encoder goes from the n_input-dimensional data to an n_latent-d
    # latent space representation
    self.encoder = Encoder(
        n_input,
        n_latent,
        n_layers=n_layers,
        n_hidden=n_hidden,
        dropout_rate=0.1,
    )
    # decoder goes from n_latent-dimensional space to n_input-d data
    self.decoder = DecoderSCVI(
        n_latent,
        n_input,
        n_layers=n_layers,
        n_hidden=n_hidden,
    )
    # This gene-level parameter modulates the variance of the observation distribution
    self.px_r = torch.nn.Parameter(torch.ones(self.n_input))
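# A minimal standalone sketch (not part of the module above) of how a gene-level
# inverse-dispersion parameter such as `px_r` modulates the negative binomial
# observation model: with mean `mu` and inverse dispersion `theta`, the NB variance is
# mu + mu**2 / theta, so larger `theta` moves the likelihood toward Poisson behavior.
# The tensor names below are illustrative, not taken from the code above.
import torch

mu = torch.tensor([2.0, 10.0, 50.0])   # per-gene means
theta = torch.ones(3)                   # plays the role of a positive transform of px_r
nb_variance = mu + mu.pow(2) / theta    # -> [6., 110., 2550.], vs. Poisson variance = mu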
def __init__( self, n_input, n_batch, n_labels, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1, y_prior=None, dispersion="gene", log_variational=True, gene_likelihood="zinb", ): super().__init__( n_input, n_batch, n_labels, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers, dropout_rate=dropout_rate, dispersion=dispersion, log_variational=log_variational, gene_likelihood=gene_likelihood, use_observed_lib_size=False, ) self.z_encoder = Encoder( n_input, n_latent, n_cat_list=[n_batch, n_labels], n_hidden=n_hidden, n_layers=n_layers, dropout_rate=dropout_rate, ) self.decoder = DecoderSCVI( n_latent, n_input, n_cat_list=[n_batch, n_labels], n_layers=n_layers, n_hidden=n_hidden, ) self.y_prior = torch.nn.Parameter( y_prior if y_prior is not None else (1 / n_labels) * torch.ones(1, n_labels), requires_grad=False, ) self.classifier = Classifier(n_input, n_hidden, n_labels, n_layers=n_layers, dropout_rate=dropout_rate)
def __init__(
    self,
    n_input: int,
    n_batch: int = 0,
    n_hidden: int = 128,
    n_latent: int = 10,
    n_layers: int = 1,
    dropout_rate: float = 0.1,
):
    super().__init__()
    self.n_latent = n_latent
    self.n_batch = n_batch
    self.kl_factor = 1.0
    print("USING KL FACTOR:", self.kl_factor)
    # this is needed to comply with some requirement of the VAEMixin class
    self.latent_distribution = "normal"

    # setup the parameters of your generative model, as well as your inference model
    self.px_r = torch.nn.Parameter(torch.randn(n_input))
    # z encoder goes from the n_input-dimensional data to an n_latent-d
    # latent space representation
    self.z_encoder = Encoder(
        n_input,
        n_latent,
        n_layers=n_layers,
        n_hidden=n_hidden,
        dropout_rate=dropout_rate,
    )
    # l encoder goes from n_input-dimensional data to 1-d library size
    self.l_encoder = Encoder(
        n_input,
        1,
        n_layers=1,
        n_hidden=n_hidden,
        dropout_rate=dropout_rate,
    )
    # decoder goes from n_latent-dimensional space to n_input-d data
    self.decoder = LinearDecoderSCVI(n_input=n_latent, n_output=n_input)
def __init__(self, n_input: int, n_topics: int, n_hidden: int):
    super().__init__(_AMORTIZED_LDA_PYRO_MODULE_NAME)

    self.n_input = n_input
    self.n_topics = n_topics
    self.n_hidden = n_hidden
    # Populated by PyroTrainingPlan.
    self.n_obs = None

    self.encoder = Encoder(n_input, n_topics, distribution="ln")
    (
        topic_feature_posterior_mu,
        topic_feature_posterior_sigma,
    ) = logistic_normal_approximation(torch.ones(self.n_input))
    self.topic_feature_posterior_mu = torch.nn.Parameter(
        topic_feature_posterior_mu.repeat(self.n_topics, 1)
    )
    self.unconstrained_topic_feature_posterior_sigma = torch.nn.Parameter(
        topic_feature_posterior_sigma.repeat(self.n_topics, 1)
    )
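# Hedged sketch of what a helper like `logistic_normal_approximation` typically computes
# in amortized LDA: the softmax-basis Laplace approximation of a Dirichlet(alpha) prior
# (as popularized by ProdLDA). This is an assumption about the helper used above, written
# as an illustrative standalone function, not a copy of the actual implementation.
import torch

def logistic_normal_approximation_sketch(alpha: torch.Tensor):
    # mu_k = log(alpha_k) - mean_i log(alpha_i)
    # sigma_k^2 = (1/alpha_k) * (1 - 2/K) + (1/K^2) * sum_i 1/alpha_i
    k = alpha.shape[-1]
    mu = alpha.log() - alpha.log().mean(dim=-1, keepdim=True)
    sigma_sq = (1.0 / alpha) * (1.0 - 2.0 / k) + (1.0 / k**2) * (1.0 / alpha).sum(
        dim=-1, keepdim=True
    )
    return mu, sigma_sq.sqrt()

# For a symmetric Dirichlet(1) prior over 8 components, mu is all zeros and sigma is
# constant across components, mirroring the flat `torch.ones(n_input)` call above.
mu, sigma = logistic_normal_approximation_sketch(torch.ones(8))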
def __init__(
    self,
    n_input: int,
    n_hidden: int = 800,
    n_latent: int = 10,
    n_layers: int = 2,
    dropout_rate: float = 0.1,
    log_variational: bool = False,
    latent_distribution: str = "normal",
    use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both",
    use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none",
    kl_weight: float = 0.00005,
):
    super().__init__()
    self.n_layers = n_layers
    self.n_latent = n_latent
    self.log_variational = log_variational
    self.latent_distribution = "normal"
    self.kl_weight = kl_weight

    use_batch_norm_encoder = use_batch_norm == "encoder" or use_batch_norm == "both"
    use_layer_norm_encoder = use_layer_norm == "encoder" or use_layer_norm == "both"

    self.z_encoder = Encoder(
        n_input,
        n_latent,
        n_layers=n_layers,
        n_hidden=n_hidden,
        dropout_rate=dropout_rate,
        distribution=latent_distribution,
        use_batch_norm=use_batch_norm_encoder,
        use_layer_norm=use_layer_norm_encoder,
        activation_fn=torch.nn.LeakyReLU,
    )
    n_input_decoder = n_latent
    self.decoder = DecoderSCGEN(
        n_input_decoder,
        n_input,
        n_layers=n_layers,
        n_hidden=n_hidden,
        activation_fn=torch.nn.LeakyReLU,
        dropout_rate=dropout_rate,
    )
class MULTIVAE(BaseModuleClass):
    """
    Variational auto-encoder model for joint paired + unpaired RNA-seq and ATAC-seq data.

    Parameters
    ----------
    n_input_regions
        Number of input regions.
    n_input_genes
        Number of input genes.
    n_batch
        Number of batches, if 0, no batch correction is performed.
    n_labels
        Number of labels, if 0, all cells are assumed to have the same label.
    gene_likelihood
        The distribution to use for gene expression data. One of the following

        * ``'zinb'`` - Zero-Inflated Negative Binomial
        * ``'nb'`` - Negative Binomial
        * ``'poisson'`` - Poisson
    n_hidden
        Number of nodes per hidden layer. If `None`, defaults to square root of number of regions.
    n_latent
        Dimensionality of the latent space. If `None`, defaults to square root of `n_hidden`.
    n_layers_encoder
        Number of hidden layers used for encoder NN.
    n_layers_decoder
        Number of hidden layers used for decoder NN.
    dropout_rate
        Dropout rate for neural networks.
    region_factors
        Include region-specific factors in the model.
    use_batch_norm
        One of the following

        * ``'encoder'`` - use batch normalization in the encoder only
        * ``'decoder'`` - use batch normalization in the decoder only
        * ``'none'`` - do not use batch normalization
        * ``'both'`` - use batch normalization in both the encoder and decoder
    use_layer_norm
        One of the following

        * ``'encoder'`` - use layer normalization in the encoder only
        * ``'decoder'`` - use layer normalization in the decoder only
        * ``'none'`` - do not use layer normalization
        * ``'both'`` - use layer normalization in both the encoder and decoder
    latent_distribution
        Which latent distribution to use, options are

        * ``'normal'`` - Normal distribution
        * ``'ln'`` - Logistic normal distribution (Normal(0, I) transformed by softmax)
    deeply_inject_covariates
        Whether to deeply inject covariates into all layers of the decoder. If False,
        covariates will only be included in the input layer.
    encode_covariates
        If True, include covariates in the input to the encoder.
""" ## TODO: replace n_input_regions and n_input_genes with a gene/region mask (we don't dictate which comes forst or that they're even contiguous) def __init__( self, n_input_regions: int = 0, n_input_genes: int = 0, n_batch: int = 0, n_labels: int = 0, gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb", n_hidden: Optional[int] = None, n_latent: Optional[int] = None, n_layers_encoder: int = 2, n_layers_decoder: int = 2, n_continuous_cov: int = 0, n_cats_per_cov: Optional[Iterable[int]] = None, dropout_rate: float = 0.1, region_factors: bool = True, use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "none", use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "both", latent_distribution: str = "normal", deeply_inject_covariates: bool = False, encode_covariates: bool = False, ): super().__init__() # INIT PARAMS self.n_input_regions = n_input_regions self.n_input_genes = n_input_genes self.n_hidden = (int(np.sqrt(self.n_input_regions + self.n_input_genes)) if n_hidden is None else n_hidden) self.n_batch = n_batch self.n_labels = n_labels self.gene_likelihood = gene_likelihood self.latent_distribution = latent_distribution self.n_latent = int(np.sqrt( self.n_hidden)) if n_latent is None else n_latent self.n_layers_encoder = n_layers_encoder self.n_layers_decoder = n_layers_decoder self.n_cats_per_cov = n_cats_per_cov self.n_continuous_cov = n_continuous_cov self.dropout_rate = dropout_rate self.use_batch_norm_encoder = use_batch_norm in ("encoder", "both") self.use_batch_norm_decoder = use_batch_norm in ("decoder", "both") self.use_layer_norm_encoder = use_layer_norm in ("encoder", "both") self.use_layer_norm_decoder = use_layer_norm in ("decoder", "both") self.encode_covariates = encode_covariates self.deeply_inject_covariates = deeply_inject_covariates cat_list = ([n_batch] + list(n_cats_per_cov) if n_cats_per_cov is not None else []) n_input_encoder_acc = (self.n_input_regions + n_continuous_cov * encode_covariates) n_input_encoder_exp = self.n_input_genes + n_continuous_cov * encode_covariates encoder_cat_list = cat_list if encode_covariates else None ## accessibility encoder self.z_encoder_accessibility = Encoder( n_input=n_input_encoder_acc, n_layers=self.n_layers_encoder, n_output=self.n_latent, n_hidden=self.n_hidden, n_cat_list=encoder_cat_list, dropout_rate=self.dropout_rate, activation_fn=torch.nn.LeakyReLU, distribution=self.latent_distribution, var_eps=0, use_batch_norm=self.use_batch_norm_encoder, use_layer_norm=self.use_layer_norm_encoder, ) ## expression encoder self.z_encoder_expression = Encoder( n_input=n_input_encoder_exp, n_layers=self.n_layers_encoder, n_output=self.n_latent, n_hidden=self.n_hidden, n_cat_list=encoder_cat_list, dropout_rate=self.dropout_rate, activation_fn=torch.nn.LeakyReLU, distribution=self.latent_distribution, var_eps=0, use_batch_norm=self.use_batch_norm_encoder, use_layer_norm=self.use_layer_norm_encoder, ) # expression decoder self.z_decoder_expression = DecoderSCVI( self.n_latent + self.n_continuous_cov, n_input_genes, n_cat_list=cat_list, n_layers=n_layers_decoder, n_hidden=self.n_hidden, inject_covariates=self.deeply_inject_covariates, use_batch_norm=self.use_batch_norm_decoder, use_layer_norm=self.use_layer_norm_decoder, ) # accessibility decoder self.z_decoder_accessibility = DecoderPeakVI( n_input=self.n_latent + self.n_continuous_cov, n_output=n_input_regions, n_hidden=self.n_hidden, n_cat_list=cat_list, n_layers=self.n_layers_decoder, use_batch_norm=self.use_batch_norm_decoder, 
use_layer_norm=self.use_layer_norm_decoder, deep_inject_covariates=self.deeply_inject_covariates, ) ## accessibility region-specific factors self.region_factors = None if region_factors: self.region_factors = torch.nn.Parameter( torch.zeros(self.n_input_regions)) ## expression dispersion parameters self.px_r = torch.nn.Parameter(torch.randn(n_input_genes)) ## expression library size encoder self.l_encoder_expression = LibrarySizeEncoder( n_input_encoder_exp, n_cat_list=encoder_cat_list, n_layers=self.n_layers_encoder, n_hidden=self.n_hidden, use_batch_norm=self.use_batch_norm_encoder, use_layer_norm=self.use_layer_norm_encoder, deep_inject_covariates=self.deeply_inject_covariates, ) ## accessibility library size encoder self.l_encoder_accessibility = DecoderPeakVI( n_input=n_input_encoder_acc, n_output=1, n_hidden=self.n_hidden, n_cat_list=encoder_cat_list, n_layers=self.n_layers_encoder, use_batch_norm=self.use_batch_norm_encoder, use_layer_norm=self.use_layer_norm_encoder, deep_inject_covariates=self.deeply_inject_covariates, ) def _get_inference_input(self, tensors): x = tensors[_CONSTANTS.X_KEY] batch_index = tensors[_CONSTANTS.BATCH_KEY] cont_covs = tensors.get(_CONSTANTS.CONT_COVS_KEY) cat_covs = tensors.get(_CONSTANTS.CAT_COVS_KEY) input_dict = dict( x=x, batch_index=batch_index, cont_covs=cont_covs, cat_covs=cat_covs, ) return input_dict @auto_move_data def inference( self, x, batch_index, cont_covs, cat_covs, n_samples=1, ) -> Dict[str, torch.Tensor]: # Get Data and Additional Covs x_rna = x[:, :self.n_input_genes] x_chr = x[:, self.n_input_genes:] mask_expr = x_rna.sum(dim=1) > 0 mask_acc = x_chr.sum(dim=1) > 0 if cont_covs is not None and self.encode_covariates: encoder_input_expression = torch.cat((x_rna, cont_covs), dim=-1) encoder_input_accessibility = torch.cat((x_chr, cont_covs), dim=-1) else: encoder_input_expression = x_rna encoder_input_accessibility = x_chr if cat_covs is not None and self.encode_covariates: categorical_input = torch.split(cat_covs, 1, dim=1) else: categorical_input = tuple() # Z Encoders qzm_acc, qzv_acc, z_acc = self.z_encoder_accessibility( encoder_input_accessibility, batch_index, *categorical_input) qzm_expr, qzv_expr, z_expr = self.z_encoder_expression( encoder_input_expression, batch_index, *categorical_input) # L encoders libsize_expr = self.l_encoder_expression(encoder_input_expression, batch_index, *categorical_input) libsize_acc = self.l_encoder_accessibility(encoder_input_accessibility, batch_index, *categorical_input) # ReFormat Outputs if n_samples > 1: qzm_acc = qzm_acc.unsqueeze(0).expand( (n_samples, qzm_acc.size(0), qzm_acc.size(1))) qzv_acc = qzv_acc.unsqueeze(0).expand( (n_samples, qzv_acc.size(0), qzv_acc.size(1))) untran_za = Normal(qzm_acc, qzv_acc.sqrt()).sample() z_acc = self.z_encoder_accessibility.z_transformation(untran_za) qzm_expr = qzm_expr.unsqueeze(0).expand( (n_samples, qzm_expr.size(0), qzm_expr.size(1))) qzv_expr = qzv_expr.unsqueeze(0).expand( (n_samples, qzv_expr.size(0), qzv_expr.size(1))) untran_zr = Normal(qzm_expr, qzv_expr.sqrt()).sample() z_expr = self.z_encoder_expression.z_transformation(untran_zr) libsize_expr = libsize_expr.unsqueeze(0).expand( (n_samples, libsize_expr.size(0), libsize_expr.size(1))) libsize_acc = libsize_acc.unsqueeze(0).expand( (n_samples, libsize_acc.size(0), libsize_acc.size(1))) ## Sample from the average distribution qzp_m = (qzm_acc + qzm_expr) / 2 qzp_v = (qzv_acc + qzv_expr) / (2**0.5) zp = Normal(qzp_m, qzp_v.sqrt()).rsample() ## choose the correct latent representation based on 
the modality qz_m = self._mix_modalities(qzp_m, qzm_expr, qzm_acc, mask_expr, mask_acc) qz_v = self._mix_modalities(qzp_v, qzv_expr, qzv_acc, mask_expr, mask_acc) z = self._mix_modalities(zp, z_expr, z_acc, mask_expr, mask_acc) outputs = dict( z=z, qz_m=qz_m, qz_v=qz_v, z_expr=z_expr, qzm_expr=qzm_expr, qzv_expr=qzv_expr, z_acc=z_acc, qzm_acc=qzm_acc, qzv_acc=qzv_acc, libsize_expr=libsize_expr, libsize_acc=libsize_acc, ) return outputs def _get_generative_input(self, tensors, inference_outputs, transform_batch=None): z = inference_outputs["z"] qz_m = inference_outputs["qz_m"] libsize_expr = inference_outputs["libsize_expr"] labels = tensors[_CONSTANTS.LABELS_KEY] batch_index = tensors[_CONSTANTS.BATCH_KEY] cont_key = _CONSTANTS.CONT_COVS_KEY cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None cat_key = _CONSTANTS.CAT_COVS_KEY cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None if transform_batch is not None: batch_index = torch.ones_like(batch_index) * transform_batch input_dict = dict( z=z, qz_m=qz_m, batch_index=batch_index, cont_covs=cont_covs, cat_covs=cat_covs, libsize_expr=libsize_expr, labels=labels, ) return input_dict @auto_move_data def generative( self, z, qz_m, batch_index, cont_covs=None, cat_covs=None, libsize_expr=None, labels=None, use_z_mean=False, ): """Runs the generative model.""" if cat_covs is not None: categorical_input = torch.split(cat_covs, 1, dim=1) else: categorical_input = tuple() latent = z if not use_z_mean else qz_m decoder_input = (latent if cont_covs is None else torch.cat( [latent, cont_covs], dim=-1)) # Accessibility Decoder p = self.z_decoder_accessibility(decoder_input, batch_index, *categorical_input) # Expression Decoder px_scale, _, px_rate, px_dropout = self.z_decoder_expression( "gene", decoder_input, libsize_expr, batch_index, *categorical_input, labels) return dict( p=p, px_scale=px_scale, px_r=torch.exp(self.px_r), px_rate=px_rate, px_dropout=px_dropout, ) def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0): # Get the data x = tensors[_CONSTANTS.X_KEY] x_rna = x[:, :self.n_input_genes] x_chr = x[:, self.n_input_genes:] mask_expr = x_rna.sum(dim=1) > 0 mask_acc = x_chr.sum(dim=1) > 0 # Compute Accessibility loss x_accessibility = x[:, self.n_input_genes:] p = generative_outputs["p"] libsize_acc = inference_outputs["libsize_acc"] rl_accessibility = self.get_reconstruction_loss_accessibility( x_accessibility, p, libsize_acc) # Compute Expression loss px_rate = generative_outputs["px_rate"] px_r = generative_outputs["px_r"] px_dropout = generative_outputs["px_dropout"] x_expression = x[:, :self.n_input_genes] rl_expression = self.get_reconstruction_loss_expression( x_expression, px_rate, px_r, px_dropout) # mix losses to get the correct loss for each cell recon_loss = self._mix_modalities( rl_accessibility + rl_expression, # paired rl_expression, # expression rl_accessibility, # accessibility mask_expr, mask_acc, ) # Compute KLD between Z and N(0,I) qz_m = inference_outputs["qz_m"] qz_v = inference_outputs["qz_v"] kl_div_z = kld( Normal(qz_m, torch.sqrt(qz_v)), Normal(0, 1), ).sum(dim=1) # Compute KLD between distributions for paired data qzm_expr = inference_outputs["qzm_expr"] qzv_expr = inference_outputs["qzv_expr"] qzm_acc = inference_outputs["qzm_acc"] qzv_acc = inference_outputs["qzv_acc"] kld_paired = kld(Normal(qzm_expr, torch.sqrt(qzv_expr)), Normal(qzm_acc, torch.sqrt(qzv_acc))) + kld( Normal(qzm_acc, torch.sqrt(qzv_acc)), Normal(qzm_expr, torch.sqrt(qzv_expr))) 
kld_paired = torch.where( torch.logical_and(mask_acc, mask_expr), kld_paired.T, torch.zeros_like(kld_paired).T, ).sum(dim=0) # KL WARMUP kl_local_for_warmup = kl_div_z weighted_kl_local = kl_weight * kl_local_for_warmup # PENALTY # distance_penalty = kl_weight * torch.pow(z_acc - z_expr, 2).sum(dim=1) # TOTAL LOSS loss = torch.mean(recon_loss + weighted_kl_local + kld_paired) kl_local = dict(kl_divergence_z=kl_div_z) kl_global = torch.tensor(0.0) return LossRecorder(loss, recon_loss, kl_local, kl_global) def get_reconstruction_loss_expression(self, x, px_rate, px_r, px_dropout): rl = 0.0 if self.gene_likelihood == "zinb": rl = (-ZeroInflatedNegativeBinomial( mu=px_rate, theta=px_r, zi_logits=px_dropout).log_prob(x).sum(dim=-1)) elif self.gene_likelihood == "nb": rl = -NegativeBinomial(mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1) elif self.gene_likelihood == "poisson": rl = -Poisson(px_rate).log_prob(x).sum(dim=-1) return rl def get_reconstruction_loss_accessibility(self, x, p, d): f = torch.sigmoid( self.region_factors) if self.region_factors is not None else 1 return torch.nn.BCELoss(reduction="none")(p * d * f, (x > 0).float()).sum(dim=-1) @staticmethod def _mix_modalities(x_paired, x_expr, x_acc, mask_expr, mask_acc): """ Mixes modality-specific vectors according to the modality masks. in positions where both `mask_expr` and `mask_acc` are True (corresponding to cell for which both expression and accessibility data is available), values from `x_paired` will be used. If only `mask_expr` is True, use values from `x_expr`, and if only `mask_acc` is True, use values from `x_acc`. Parameters ---------- x_paired the values for paired cells (both modalities available), will be used in positions where both `mask_expr` and `mask_acc` are True. x_expr the values for expression-only cells, will be used in positions where only `mask_expr` is True. x_acc the values for accessibility-only cells, will be used on positions where only `mask_acc` is True. mask_expr the expression mask, indicating which cells have expression data mask_acc the accessibility mask, indicating which cells have accessibility data """ x = torch.where(mask_expr.T, x_expr.T, x_acc.T).T x = torch.where(torch.logical_and(mask_acc, mask_expr), x_paired.T, x.T).T return x
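# Toy demonstration of the masking logic used by `_mix_modalities` above: paired cells
# take the joint value, expression-only cells take the expression value, and
# accessibility-only cells take the accessibility value. This standalone sketch uses
# made-up tensors and mirrors the torch.where pattern rather than calling the class method.
import torch

x_paired = torch.full((3, 2), 0.0)   # value used when both modalities are observed
x_expr = torch.full((3, 2), 1.0)     # value used for expression-only cells
x_acc = torch.full((3, 2), 2.0)      # value used for accessibility-only cells
mask_expr = torch.tensor([True, True, False])
mask_acc = torch.tensor([True, False, True])

x = torch.where(mask_expr, x_expr.T, x_acc.T).T
x = torch.where(torch.logical_and(mask_acc, mask_expr), x_paired.T, x.T).T
# rows of x: [0., 0.] (paired), [1., 1.] (expression only), [2., 2.] (accessibility only)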
def __init__(self, n_input: int, n_batch: int = 0, n_labels: int = 0, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, n_continuous_cov: int = 0, n_cats_per_cov: Optional[Iterable[int]] = None, dropout_rate: float = 0.1, dispersion: str = "gene", log_variational: bool = True, gene_likelihood: str = "zinb", y_prior=None, labels_groups: Sequence[int] = None, use_labels_groups: bool = False, classifier_parameters: dict = dict(), use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both", use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none", **vae_kwargs): super().__init__(n_input, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers, n_continuous_cov=n_continuous_cov, n_cats_per_cov=n_cats_per_cov, dropout_rate=dropout_rate, n_batch=n_batch, dispersion=dispersion, log_variational=log_variational, gene_likelihood=gene_likelihood, use_batch_norm=use_batch_norm, use_layer_norm=use_layer_norm, **vae_kwargs) use_batch_norm_encoder = use_batch_norm == "encoder" or use_batch_norm == "both" use_batch_norm_decoder = use_batch_norm == "decoder" or use_batch_norm == "both" use_layer_norm_encoder = use_layer_norm == "encoder" or use_layer_norm == "both" use_layer_norm_decoder = use_layer_norm == "decoder" or use_layer_norm == "both" self.n_labels = n_labels # Classifier takes n_latent as input cls_parameters = { "n_layers": n_layers, "n_hidden": n_hidden, "dropout_rate": dropout_rate, } cls_parameters.update(classifier_parameters) self.classifier = Classifier(n_latent, n_labels=n_labels, use_batch_norm=use_batch_norm_encoder, use_layer_norm=use_layer_norm_encoder, **cls_parameters) self.encoder_z2_z1 = Encoder( n_latent, n_latent, n_cat_list=[self.n_labels], n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_rate, use_batch_norm=use_batch_norm_encoder, use_layer_norm=use_layer_norm_encoder, ) self.decoder_z1_z2 = Decoder( n_latent, n_latent, n_cat_list=[self.n_labels], n_layers=n_layers, n_hidden=n_hidden, use_batch_norm=use_batch_norm_decoder, use_layer_norm=use_layer_norm_decoder, ) self.y_prior = torch.nn.Parameter( y_prior if y_prior is not None else (1 / n_labels) * torch.ones(1, n_labels), requires_grad=False, ) self.use_labels_groups = use_labels_groups self.labels_groups = (np.array(labels_groups) if labels_groups is not None else None) if self.use_labels_groups: if labels_groups is None: raise ValueError("Specify label groups") unique_groups = np.unique(self.labels_groups) self.n_groups = len(unique_groups) if not (unique_groups == np.arange(self.n_groups)).all(): raise ValueError() self.classifier_groups = Classifier(n_latent, n_hidden, self.n_groups, n_layers, dropout_rate) self.groups_index = torch.nn.ParameterList([ torch.nn.Parameter( torch.tensor( (self.labels_groups == i).astype(np.uint8), dtype=torch.uint8, ), requires_grad=False, ) for i in range(self.n_groups) ])
class VAEC(BaseModuleClass): """ Conditional Variational auto-encoder model. This is an implementation of the CondSCVI model Parameters ---------- n_input Number of input genes n_labels Number of labels n_hidden Number of nodes per hidden layer n_latent Dimensionality of the latent space n_layers Number of hidden layers used for encoder and decoder NNs dropout_rate Dropout rate for the encoder neural network log_variational Log(data+1) prior to encoding for numerical stability. Not normalization. """ def __init__( self, n_input: int, n_labels: int = 0, n_hidden: int = 128, n_latent: int = 5, n_layers: int = 2, dropout_rate: float = 0.1, log_variational: bool = True, ct_weight: np.ndarray = None, **module_kwargs, ): super().__init__() self.dispersion = "gene" self.n_latent = n_latent self.n_layers = n_layers self.n_hidden = n_hidden self.log_variational = log_variational self.gene_likelihood = "nb" self.latent_distribution = "normal" # Automatically deactivate if useless self.n_batch = 0 self.n_labels = n_labels # gene dispersion self.px_r = torch.nn.Parameter(torch.randn(n_input)) # z encoder goes from the n_input-dimensional data to an n_latent-d self.z_encoder = Encoder( n_input, n_latent, n_cat_list=[n_labels], n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_rate, inject_covariates=True, use_batch_norm=False, use_layer_norm=True, ) # decoder goes from n_latent-dimensional space to n_input-d data self.decoder = FCLayers( n_in=n_latent, n_out=n_hidden, n_cat_list=[n_labels], n_layers=n_layers, n_hidden=n_hidden, dropout_rate=0, inject_covariates=True, use_batch_norm=False, use_layer_norm=True, ) self.px_decoder = torch.nn.Sequential( torch.nn.Linear(n_hidden, n_input), torch.nn.Softplus()) if ct_weight is not None: ct_weight = torch.tensor(ct_weight, dtype=torch.float32) else: ct_weight = torch.ones((self.n_labels, ), dtype=torch.float32) self.register_buffer("ct_weight", ct_weight) def _get_inference_input(self, tensors): x = tensors[REGISTRY_KEYS.X_KEY] y = tensors[REGISTRY_KEYS.LABELS_KEY] input_dict = dict( x=x, y=y, ) return input_dict def _get_generative_input(self, tensors, inference_outputs): z = inference_outputs["z"] library = inference_outputs["library"] y = tensors[REGISTRY_KEYS.LABELS_KEY] input_dict = { "z": z, "library": library, "y": y, } return input_dict @auto_move_data def inference(self, x, y, n_samples=1): """ High level inference method. Runs the inference (encoder) model. 
""" x_ = x library = torch.log(x.sum(1)).unsqueeze(1) if self.log_variational: x_ = torch.log(1 + x_) qz_m, qz_v, z = self.z_encoder(x_, y) if n_samples > 1: qz_m = qz_m.unsqueeze(0).expand( (n_samples, qz_m.size(0), qz_m.size(1))) qz_v = qz_v.unsqueeze(0).expand( (n_samples, qz_v.size(0), qz_v.size(1))) # when z is normal, untran_z == z untran_z = Normal(qz_m, qz_v.sqrt()).sample() z = self.z_encoder.z_transformation(untran_z) library = library.unsqueeze(0).expand( (n_samples, library.size(0), library.size(1))) outputs = dict(z=z, qz_m=qz_m, qz_v=qz_v, library=library) return outputs @auto_move_data def generative(self, z, library, y): """Runs the generative model.""" h = self.decoder(z, y) px_scale = self.px_decoder(h) px_rate = library * px_scale return dict(px_scale=px_scale, px_r=self.px_r, px_rate=px_rate) def loss( self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0, ): x = tensors[REGISTRY_KEYS.X_KEY] y = tensors[REGISTRY_KEYS.LABELS_KEY] qz_m = inference_outputs["qz_m"] qz_v = inference_outputs["qz_v"] px_rate = generative_outputs["px_rate"] px_r = generative_outputs["px_r"] mean = torch.zeros_like(qz_m) scale = torch.ones_like(qz_v) kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1) reconst_loss = -NegativeBinomial(px_rate, logits=px_r).log_prob(x).sum(-1) scaling_factor = self.ct_weight[y.long()[:, 0]] loss = torch.mean(scaling_factor * (reconst_loss + kl_weight * kl_divergence_z)) return LossRecorder(loss, reconst_loss, kl_divergence_z, torch.tensor(0.0)) @torch.no_grad() def sample( self, tensors, n_samples=1, ) -> np.ndarray: r""" Generate observation samples from the posterior predictive distribution. The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`. Parameters ---------- tensors Tensors dict n_samples Number of required samples for each cell Returns ------- x_new : :py:class:`torch.Tensor` tensor with shape (n_cells, n_genes, n_samples) """ inference_kwargs = dict(n_samples=n_samples) generative_outputs = self.forward( tensors, inference_kwargs=inference_kwargs, compute_loss=False, )[1] px_r = generative_outputs["px_r"] px_rate = generative_outputs["px_rate"] dist = NegativeBinomial(px_rate, logits=px_r) if n_samples > 1: exprs = dist.sample().permute( [1, 2, 0]) # Shape : (n_cells_batch, n_genes, n_samples) else: exprs = dist.sample() return exprs.cpu()
def __init__( self, dim_input_list: List[int], total_genes: int, indices_mappings: List[Union[np.ndarray, slice]], gene_likelihoods: List[str], model_library_bools: List[bool], library_log_means: List[Optional[np.ndarray]], library_log_vars: List[Optional[np.ndarray]], n_latent: int = 10, n_layers_encoder_individual: int = 1, n_layers_encoder_shared: int = 1, dim_hidden_encoder: int = 64, n_layers_decoder_individual: int = 0, n_layers_decoder_shared: int = 0, dim_hidden_decoder_individual: int = 64, dim_hidden_decoder_shared: int = 64, dropout_rate_encoder: float = 0.2, dropout_rate_decoder: float = 0.2, n_batch: int = 0, n_labels: int = 0, dispersion: str = "gene-batch", log_variational: bool = True, ): super().__init__() self.n_input_list = dim_input_list self.total_genes = total_genes self.indices_mappings = indices_mappings self.gene_likelihoods = gene_likelihoods self.model_library_bools = model_library_bools for mode in range(len(dim_input_list)): if self.model_library_bools[mode]: self.register_buffer( f"library_log_means_{mode}", torch.from_numpy(library_log_means[mode]).float(), ) self.register_buffer( f"library_log_vars_{mode}", torch.from_numpy(library_log_vars[mode]).float(), ) self.n_latent = n_latent self.n_batch = n_batch self.n_labels = n_labels self.dispersion = dispersion self.log_variational = log_variational self.z_encoder = MultiEncoder( n_heads=len(dim_input_list), n_input_list=dim_input_list, n_output=self.n_latent, n_hidden=dim_hidden_encoder, n_layers_individual=n_layers_encoder_individual, n_layers_shared=n_layers_encoder_shared, dropout_rate=dropout_rate_encoder, ) self.l_encoders = ModuleList( [ Encoder( self.n_input_list[i], 1, n_layers=1, dropout_rate=dropout_rate_encoder, ) if self.model_library_bools[i] else None for i in range(len(self.n_input_list)) ] ) self.decoder = MultiDecoder( self.n_latent, self.total_genes, n_hidden_conditioned=dim_hidden_decoder_individual, n_hidden_shared=dim_hidden_decoder_shared, n_layers_conditioned=n_layers_decoder_individual, n_layers_shared=n_layers_decoder_shared, n_cat_list=[self.n_batch], dropout_rate=dropout_rate_decoder, ) if self.dispersion == "gene": self.px_r = torch.nn.Parameter(torch.randn(self.total_genes)) elif self.dispersion == "gene-batch": self.px_r = torch.nn.Parameter(torch.randn(self.total_genes, n_batch)) elif self.dispersion == "gene-label": self.px_r = torch.nn.Parameter(torch.randn(self.total_genes, n_labels)) else: # gene-cell pass
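# Small standalone illustration (toy sizes) of the dispersion parameter shapes set up in
# the __init__ above: "gene" keeps one value per gene, while "gene-batch" / "gene-label"
# keep one value per gene per batch or label, later selected via one-hot indexing.
import torch

total_genes, n_batch, n_labels = 100, 3, 4
px_r_gene = torch.nn.Parameter(torch.randn(total_genes))
px_r_gene_batch = torch.nn.Parameter(torch.randn(total_genes, n_batch))
px_r_gene_label = torch.nn.Parameter(torch.randn(total_genes, n_labels))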
class VAE(BaseModuleClass):
    """
    Variational auto-encoder model.

    This is an implementation of the scVI model described in [Lopez18]_

    Parameters
    ----------
    n_input
        Number of input genes
    n_batch
        Number of batches, if 0, no batch correction is performed.
    n_labels
        Number of labels
    n_hidden
        Number of nodes per hidden layer
    n_latent
        Dimensionality of the latent space
    n_layers
        Number of hidden layers used for encoder and decoder NNs
    n_continuous_cov
        Number of continuous covariates
    n_cats_per_cov
        Number of categories for each extra categorical covariate
    dropout_rate
        Dropout rate for neural networks
    dispersion
        One of the following

        * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
        * ``'gene-batch'`` - dispersion can differ between different batches
        * ``'gene-label'`` - dispersion can differ between different labels
        * ``'gene-cell'`` - dispersion can differ for every gene in every cell
    log_variational
        Log(data+1) prior to encoding for numerical stability. Not normalization.
    gene_likelihood
        One of

        * ``'nb'`` - Negative binomial distribution
        * ``'zinb'`` - Zero-inflated negative binomial distribution
        * ``'poisson'`` - Poisson distribution
    latent_distribution
        One of

        * ``'normal'`` - Isotropic normal
        * ``'ln'`` - Logistic normal with normal params N(0, 1)
    encode_covariates
        Whether to concatenate covariates to expression in encoder
    deeply_inject_covariates
        Whether to concatenate covariates into output of hidden layers in encoder/decoder.
        This option only applies when `n_layers` > 1. The covariates are concatenated to
        the input of subsequent hidden layers.
    use_batch_norm
        Whether to use batch norm in layers
    use_layer_norm
        Whether to use layer norm in layers
    use_size_factor_key
        Use size_factor AnnDataField defined by the user as scaling factor in mean of
        conditional distribution. Takes priority over `use_observed_lib_size`.
    use_observed_lib_size
        Use observed library size for RNA as scaling factor in mean of conditional distribution
    library_log_means
        1 x n_batch array of means of the log library sizes. Parameterizes prior on library
        size if not using observed library size.
    library_log_vars
        1 x n_batch array of variances of the log library sizes. Parameterizes prior on
        library size if not using observed library size.
    var_activation
        Callable used to ensure positivity of the variational distributions' variance.
        When `None`, defaults to `torch.exp`.
""" def __init__( self, n_input: int, n_batch: int = 0, n_labels: int = 0, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, n_continuous_cov: int = 0, n_cats_per_cov: Optional[Iterable[int]] = None, dropout_rate: float = 0.1, dispersion: str = "gene", log_variational: bool = True, gene_likelihood: str = "zinb", latent_distribution: str = "normal", encode_covariates: bool = False, deeply_inject_covariates: bool = True, use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both", use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none", use_size_factor_key: bool = False, use_observed_lib_size: bool = True, library_log_means: Optional[np.ndarray] = None, library_log_vars: Optional[np.ndarray] = None, var_activation: Optional[Callable] = None, ): super().__init__() self.dispersion = dispersion self.n_latent = n_latent self.log_variational = log_variational self.gene_likelihood = gene_likelihood # Automatically deactivate if useless self.n_batch = n_batch self.n_labels = n_labels self.latent_distribution = latent_distribution self.encode_covariates = encode_covariates self.use_size_factor_key = use_size_factor_key self.use_observed_lib_size = use_size_factor_key or use_observed_lib_size if not self.use_observed_lib_size: if library_log_means is None or library_log_means is None: raise ValueError( "If not using observed_lib_size, " "must provide library_log_means and library_log_vars." ) self.register_buffer( "library_log_means", torch.from_numpy(library_log_means).float() ) self.register_buffer( "library_log_vars", torch.from_numpy(library_log_vars).float() ) if self.dispersion == "gene": self.px_r = torch.nn.Parameter(torch.randn(n_input)) elif self.dispersion == "gene-batch": self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch)) elif self.dispersion == "gene-label": self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels)) elif self.dispersion == "gene-cell": pass else: raise ValueError( "dispersion must be one of ['gene', 'gene-batch'," " 'gene-label', 'gene-cell'], but input was " "{}.format(self.dispersion)" ) use_batch_norm_encoder = use_batch_norm == "encoder" or use_batch_norm == "both" use_batch_norm_decoder = use_batch_norm == "decoder" or use_batch_norm == "both" use_layer_norm_encoder = use_layer_norm == "encoder" or use_layer_norm == "both" use_layer_norm_decoder = use_layer_norm == "decoder" or use_layer_norm == "both" # z encoder goes from the n_input-dimensional data to an n_latent-d # latent space representation n_input_encoder = n_input + n_continuous_cov * encode_covariates cat_list = [n_batch] + list([] if n_cats_per_cov is None else n_cats_per_cov) encoder_cat_list = cat_list if encode_covariates else None self.z_encoder = Encoder( n_input_encoder, n_latent, n_cat_list=encoder_cat_list, n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_rate, distribution=latent_distribution, inject_covariates=deeply_inject_covariates, use_batch_norm=use_batch_norm_encoder, use_layer_norm=use_layer_norm_encoder, var_activation=var_activation, ) # l encoder goes from n_input-dimensional data to 1-d library size self.l_encoder = Encoder( n_input_encoder, 1, n_layers=1, n_cat_list=encoder_cat_list, n_hidden=n_hidden, dropout_rate=dropout_rate, inject_covariates=deeply_inject_covariates, use_batch_norm=use_batch_norm_encoder, use_layer_norm=use_layer_norm_encoder, var_activation=var_activation, ) # decoder goes from n_latent-dimensional space to n_input-d data n_input_decoder = n_latent + n_continuous_cov self.decoder = DecoderSCVI( 
n_input_decoder, n_input, n_cat_list=cat_list, n_layers=n_layers, n_hidden=n_hidden, inject_covariates=deeply_inject_covariates, use_batch_norm=use_batch_norm_decoder, use_layer_norm=use_layer_norm_decoder, scale_activation="softplus" if use_size_factor_key else "softmax", ) def _get_inference_input(self, tensors): x = tensors[REGISTRY_KEYS.X_KEY] batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] cont_key = REGISTRY_KEYS.CONT_COVS_KEY cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None cat_key = REGISTRY_KEYS.CAT_COVS_KEY cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None input_dict = dict( x=x, batch_index=batch_index, cont_covs=cont_covs, cat_covs=cat_covs ) return input_dict def _get_generative_input(self, tensors, inference_outputs): z = inference_outputs["z"] library = inference_outputs["library"] batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] y = tensors[REGISTRY_KEYS.LABELS_KEY] cont_key = REGISTRY_KEYS.CONT_COVS_KEY cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None cat_key = REGISTRY_KEYS.CAT_COVS_KEY cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY size_factor = ( torch.log(tensors[size_factor_key]) if size_factor_key in tensors.keys() else None ) input_dict = dict( z=z, library=library, batch_index=batch_index, y=y, cont_covs=cont_covs, cat_covs=cat_covs, size_factor=size_factor, ) return input_dict def _compute_local_library_params(self, batch_index): """ Computes local library parameters. Compute two tensors of shape (batch_index.shape[0], 1) where each element corresponds to the mean and variances, respectively, of the log library sizes in the batch the cell corresponds to. """ n_batch = self.library_log_means.shape[1] local_library_log_means = F.linear( one_hot(batch_index, n_batch), self.library_log_means ) local_library_log_vars = F.linear( one_hot(batch_index, n_batch), self.library_log_vars ) return local_library_log_means, local_library_log_vars @auto_move_data def inference(self, x, batch_index, cont_covs=None, cat_covs=None, n_samples=1): """ High level inference method. Runs the inference (encoder) model. 
""" x_ = x if self.use_observed_lib_size: library = torch.log(x.sum(1)).unsqueeze(1) if self.log_variational: x_ = torch.log(1 + x_) if cont_covs is not None and self.encode_covariates is True: encoder_input = torch.cat((x_, cont_covs), dim=-1) else: encoder_input = x_ if cat_covs is not None and self.encode_covariates is True: categorical_input = torch.split(cat_covs, 1, dim=1) else: categorical_input = tuple() qz_m, qz_v, z = self.z_encoder(encoder_input, batch_index, *categorical_input) ql_m, ql_v = None, None if not self.use_observed_lib_size: ql_m, ql_v, library_encoded = self.l_encoder( encoder_input, batch_index, *categorical_input ) library = library_encoded if n_samples > 1: qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1))) qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1))) # when z is normal, untran_z == z untran_z = Normal(qz_m, qz_v.sqrt()).sample() z = self.z_encoder.z_transformation(untran_z) if self.use_observed_lib_size: library = library.unsqueeze(0).expand( (n_samples, library.size(0), library.size(1)) ) else: ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1))) ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1))) library = Normal(ql_m, ql_v.sqrt()).sample() outputs = dict(z=z, qz_m=qz_m, qz_v=qz_v, ql_m=ql_m, ql_v=ql_v, library=library) return outputs @auto_move_data def generative( self, z, library, batch_index, cont_covs=None, cat_covs=None, size_factor=None, y=None, transform_batch=None, ): """Runs the generative model.""" # TODO: refactor forward function to not rely on y decoder_input = z if cont_covs is None else torch.cat([z, cont_covs], dim=-1) if cat_covs is not None: categorical_input = torch.split(cat_covs, 1, dim=1) else: categorical_input = tuple() if transform_batch is not None: batch_index = torch.ones_like(batch_index) * transform_batch if not self.use_size_factor_key: size_factor = library px_scale, px_r, px_rate, px_dropout = self.decoder( self.dispersion, decoder_input, size_factor, batch_index, *categorical_input, y, ) if self.dispersion == "gene-label": px_r = F.linear( one_hot(y, self.n_labels), self.px_r ) # px_r gets transposed - last dimension is nb genes elif self.dispersion == "gene-batch": px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r) elif self.dispersion == "gene": px_r = self.px_r px_r = torch.exp(px_r) return dict( px_scale=px_scale, px_r=px_r, px_rate=px_rate, px_dropout=px_dropout ) def loss( self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0, ): x = tensors[REGISTRY_KEYS.X_KEY] batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] qz_m = inference_outputs["qz_m"] qz_v = inference_outputs["qz_v"] px_rate = generative_outputs["px_rate"] px_r = generative_outputs["px_r"] px_dropout = generative_outputs["px_dropout"] mean = torch.zeros_like(qz_m) scale = torch.ones_like(qz_v) kl_divergence_z = kl(Normal(qz_m, qz_v.sqrt()), Normal(mean, scale)).sum(dim=1) if not self.use_observed_lib_size: ql_m = inference_outputs["ql_m"] ql_v = inference_outputs["ql_v"] ( local_library_log_means, local_library_log_vars, ) = self._compute_local_library_params(batch_index) kl_divergence_l = kl( Normal(ql_m, ql_v.sqrt()), Normal(local_library_log_means, local_library_log_vars.sqrt()), ).sum(dim=1) else: kl_divergence_l = 0.0 reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout) kl_local_for_warmup = kl_divergence_z kl_local_no_warmup = kl_divergence_l weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup 
loss = torch.mean(reconst_loss + weighted_kl_local) kl_local = dict( kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z ) kl_global = torch.tensor(0.0) return LossRecorder(loss, reconst_loss, kl_local, kl_global) @torch.no_grad() def sample( self, tensors, n_samples=1, library_size=1, ) -> np.ndarray: r""" Generate observation samples from the posterior predictive distribution. The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`. Parameters ---------- tensors Tensors dict n_samples Number of required samples for each cell library_size Library size to scale scamples to Returns ------- x_new : :py:class:`torch.Tensor` tensor with shape (n_cells, n_genes, n_samples) """ inference_kwargs = dict(n_samples=n_samples) inference_outputs, generative_outputs, = self.forward( tensors, inference_kwargs=inference_kwargs, compute_loss=False, ) px_r = generative_outputs["px_r"] px_rate = generative_outputs["px_rate"] px_dropout = generative_outputs["px_dropout"] if self.gene_likelihood == "poisson": l_train = px_rate l_train = torch.clamp(l_train, max=1e8) dist = torch.distributions.Poisson( l_train ) # Shape : (n_samples, n_cells_batch, n_genes) elif self.gene_likelihood == "nb": dist = NegativeBinomial(mu=px_rate, theta=px_r) elif self.gene_likelihood == "zinb": dist = ZeroInflatedNegativeBinomial( mu=px_rate, theta=px_r, zi_logits=px_dropout ) else: raise ValueError( "{} reconstruction error not handled right now".format( self.module.gene_likelihood ) ) if n_samples > 1: exprs = dist.sample().permute( [1, 2, 0] ) # Shape : (n_cells_batch, n_genes, n_samples) else: exprs = dist.sample() return exprs.cpu() def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout) -> torch.Tensor: if self.gene_likelihood == "zinb": reconst_loss = ( -ZeroInflatedNegativeBinomial( mu=px_rate, theta=px_r, zi_logits=px_dropout ) .log_prob(x) .sum(dim=-1) ) elif self.gene_likelihood == "nb": reconst_loss = ( -NegativeBinomial(mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1) ) elif self.gene_likelihood == "poisson": reconst_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1) return reconst_loss @torch.no_grad() @auto_move_data def marginal_ll(self, tensors, n_mc_samples): sample_batch = tensors[REGISTRY_KEYS.X_KEY] batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] to_sum = torch.zeros(sample_batch.size()[0], n_mc_samples) for i in range(n_mc_samples): # Distribution parameters and sampled variables inference_outputs, _, losses = self.forward(tensors) qz_m = inference_outputs["qz_m"] qz_v = inference_outputs["qz_v"] z = inference_outputs["z"] library = inference_outputs["library"] # Reconstruction Loss reconst_loss = losses.reconstruction_loss # Log-probabilities log_prob_sum = torch.zeros(qz_m.shape[0]).to(self.device) p_z = ( Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v)) .log_prob(z) .sum(dim=-1) ) p_x_zl = -reconst_loss log_prob_sum += p_z + p_x_zl q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1) log_prob_sum -= q_z_x if not self.use_observed_lib_size: ( local_library_log_means, local_library_log_vars, ) = self._compute_local_library_params(batch_index) p_l = ( Normal(local_library_log_means, local_library_log_vars.sqrt()) .log_prob(library) .sum(dim=-1) ) ql_m = inference_outputs["ql_m"] ql_v = inference_outputs["ql_v"] q_l_x = Normal(ql_m, ql_v.sqrt()).log_prob(library).sum(dim=-1) log_prob_sum += p_l - q_l_x to_sum[:, i] = log_prob_sum batch_log_lkl = logsumexp(to_sum, dim=-1) - np.log(n_mc_samples) log_lkl = torch.sum(batch_log_lkl).item() return log_lkl
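# Minimal sketch of the importance-sampling estimate used by `marginal_ll` above:
# log p(x) is approximated by a log-mean-exp over Monte Carlo samples of
# log p(x, z, l) - log q(z, l | x), i.e. logsumexp(to_sum, dim=-1) - log(n_mc_samples).
# The numbers below only illustrate the reduction step, not a real model.
import numpy as np
import torch

n_mc_samples = 4
to_sum = torch.tensor([[-10.0, -9.5, -11.0, -10.2],
                       [-20.0, -19.0, -21.0, -20.5]])      # shape (n_cells, n_mc_samples)
batch_log_lkl = torch.logsumexp(to_sum, dim=-1) - np.log(n_mc_samples)
log_lkl = torch.sum(batch_log_lkl).item()                  # scalar estimate over the minibatch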
class PEAKVAE(BaseModuleClass):
    """
    Variational auto-encoder model for ATAC-seq data.

    This is an implementation of the peakVI model.

    Parameters
    ----------
    n_input_regions
        Number of input regions.
    n_batch
        Number of batches; if 0, no batch correction is performed.
    n_hidden
        Number of nodes per hidden layer. If `None`, defaults to the square root
        of the number of regions.
    n_latent
        Dimensionality of the latent space. If `None`, defaults to the square root
        of `n_hidden`.
    n_layers_encoder
        Number of hidden layers used for the encoder NN.
    n_layers_decoder
        Number of hidden layers used for the decoder NN.
    dropout_rate
        Dropout rate for neural networks.
    model_depth
        Whether to model library size (depth) factors.
    region_factors
        Include region-specific factors in the model.
    use_batch_norm
        One of the following

        * ``'encoder'`` - use batch normalization in the encoder only
        * ``'decoder'`` - use batch normalization in the decoder only
        * ``'none'`` - do not use batch normalization (default)
        * ``'both'`` - use batch normalization in both the encoder and decoder
    use_layer_norm
        One of the following

        * ``'encoder'`` - use layer normalization in the encoder only
        * ``'decoder'`` - use layer normalization in the decoder only
        * ``'none'`` - do not use layer normalization
        * ``'both'`` - use layer normalization in both the encoder and decoder (default)
    latent_distribution
        Which latent distribution to use, options are

        * ``'normal'`` - Normal distribution (default)
        * ``'ln'`` - Logistic normal distribution (Normal(0, I) transformed by softmax)
    deeply_inject_covariates
        Whether to deeply inject covariates into all layers of the decoder. If False (default),
        covariates will only be included in the input layer.
    """

    def __init__(
        self,
        n_input_regions: int,
        n_batch: int = 0,
        n_hidden: Optional[int] = None,
        n_latent: Optional[int] = None,
        n_layers_encoder: int = 2,
        n_layers_decoder: int = 2,
        n_continuous_cov: int = 0,
        n_cats_per_cov: Optional[Iterable[int]] = None,
        dropout_rate: float = 0.1,
        model_depth: bool = True,
        region_factors: bool = True,
        use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "none",
        use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "both",
        latent_distribution: str = "normal",
        deeply_inject_covariates: bool = False,
        encode_covariates: bool = False,
    ):
        super().__init__()

        self.n_input_regions = n_input_regions
        self.n_hidden = (
            int(np.sqrt(self.n_input_regions)) if n_hidden is None else n_hidden
        )
        self.n_latent = int(np.sqrt(self.n_hidden)) if n_latent is None else n_latent
        self.n_layers_encoder = n_layers_encoder
        self.n_layers_decoder = n_layers_decoder
        self.n_cats_per_cov = n_cats_per_cov
        self.n_continuous_cov = n_continuous_cov
        self.model_depth = model_depth
        self.dropout_rate = dropout_rate
        self.latent_distribution = latent_distribution
        self.use_batch_norm_encoder = use_batch_norm in ("encoder", "both")
        self.use_batch_norm_decoder = use_batch_norm in ("decoder", "both")
        self.use_layer_norm_encoder = use_layer_norm in ("encoder", "both")
        self.use_layer_norm_decoder = use_layer_norm in ("decoder", "both")
        self.deeply_inject_covariates = deeply_inject_covariates
        self.encode_covariates = encode_covariates

        cat_list = (
            [n_batch] + list(n_cats_per_cov) if n_cats_per_cov is not None else []
        )
        n_input_encoder = self.n_input_regions + n_continuous_cov * encode_covariates
        encoder_cat_list = cat_list if encode_covariates else None

        self.z_encoder = Encoder(
            n_input=n_input_encoder,
            n_layers=self.n_layers_encoder,
            n_output=self.n_latent,
            n_hidden=self.n_hidden,
            n_cat_list=encoder_cat_list,
            dropout_rate=self.dropout_rate,
            activation_fn=torch.nn.LeakyReLU,
            distribution=self.latent_distribution,
            var_eps=0,
            use_batch_norm=self.use_batch_norm_encoder,
            use_layer_norm=self.use_layer_norm_encoder,
        )

        self.z_decoder = Decoder(
            n_input=self.n_latent + self.n_continuous_cov,
            n_output=n_input_regions,
            n_hidden=self.n_hidden,
            n_cat_list=cat_list,
            n_layers=self.n_layers_decoder,
            use_batch_norm=self.use_batch_norm_decoder,
            use_layer_norm=self.use_layer_norm_decoder,
            deep_inject_covariates=self.deeply_inject_covariates,
        )

        self.d_encoder = None
        if self.model_depth:
            # Decoder class to avoid variational split
            self.d_encoder = Decoder(
                n_input=n_input_encoder,
                n_output=1,
                n_hidden=self.n_hidden,
                n_cat_list=encoder_cat_list,
                n_layers=self.n_layers_encoder,
            )
        self.region_factors = None
        if region_factors:
            self.region_factors = torch.nn.Parameter(torch.zeros(self.n_input_regions))

    def _get_inference_input(self, tensors):
        x = tensors[_CONSTANTS.X_KEY]
        batch_index = tensors[_CONSTANTS.BATCH_KEY]
        cont_covs = tensors.get(_CONSTANTS.CONT_COVS_KEY)
        cat_covs = tensors.get(_CONSTANTS.CAT_COVS_KEY)
        input_dict = dict(
            x=x,
            batch_index=batch_index,
            cont_covs=cont_covs,
            cat_covs=cat_covs,
        )
        return input_dict

    def _get_generative_input(self, tensors, inference_outputs, transform_batch=None):
        z = inference_outputs["z"]
        qz_m = inference_outputs["qz_m"]
        batch_index = tensors[_CONSTANTS.BATCH_KEY]
        cont_covs = tensors.get(_CONSTANTS.CONT_COVS_KEY)
        cat_covs = tensors.get(_CONSTANTS.CAT_COVS_KEY)

        if transform_batch is not None:
            batch_index = torch.ones_like(batch_index) * transform_batch

        input_dict = {
            "z": z,
            "qz_m": qz_m,
            "batch_index": batch_index,
            "cont_covs": cont_covs,
            "cat_covs": cat_covs,
        }
        return input_dict

    def get_reconstruction_loss(self, p, d, f, x):
        rl = torch.nn.BCELoss(reduction="none")(p * d * f, (x > 0).float()).sum(dim=-1)
        return rl

    @auto_move_data
    def inference(
        self,
        x,
        batch_index,
        cont_covs,
        cat_covs,
        n_samples=1,
    ) -> Dict[str, torch.Tensor]:
        """Helper function used in forward pass."""
        if cat_covs is not None and self.encode_covariates is True:
            categorical_input = torch.split(cat_covs, 1, dim=1)
        else:
            categorical_input = tuple()
        if cont_covs is not None and self.encode_covariates is True:
            encoder_input = torch.cat([x, cont_covs], dim=-1)
        else:
            encoder_input = x

        # if encode_covariates is False, cat_list to init encoder is None, so
        # batch_index is not used (or categorical_input, but it's empty)
        qz_m, qz_v, z = self.z_encoder(encoder_input, batch_index, *categorical_input)
        d = (
            self.d_encoder(encoder_input, batch_index, *categorical_input)
            if self.model_depth
            else 1
        )

        if n_samples > 1:
            qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1)))
            qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1)))
            # when z is normal, untran_z == z
            untran_z = Normal(qz_m, qz_v.sqrt()).sample()
            z = self.z_encoder.z_transformation(untran_z)

        return dict(d=d, qz_m=qz_m, qz_v=qz_v, z=z)

    @auto_move_data
    def generative(
        self,
        z,
        qz_m,
        batch_index,
        cont_covs=None,
        cat_covs=None,
        use_z_mean=False,
    ):
        """Runs the generative model."""
        if cat_covs is not None:
            categorical_input = torch.split(cat_covs, 1, dim=1)
        else:
            categorical_input = tuple()

        latent = z if not use_z_mean else qz_m
        decoder_input = (
            latent if cont_covs is None else torch.cat([latent, cont_covs], dim=-1)
        )

        p = self.z_decoder(decoder_input, batch_index, *categorical_input)

        return dict(p=p)

    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
        kl_weight: float = 1.0,
    ):
        x = tensors[_CONSTANTS.X_KEY]
inference_outputs["qz_m"] qz_v = inference_outputs["qz_v"] d = inference_outputs["d"] p = generative_outputs["p"] kld = kl_divergence( Normal(qz_m, torch.sqrt(qz_v)), Normal(0, 1), ).sum(dim=1) f = torch.sigmoid( self.region_factors) if self.region_factors is not None else 1 rl = self.get_reconstruction_loss(p, d, f, x) loss = (rl.sum() + kld * kl_weight).sum() return LossRecorder(loss, rl, kld, kl_global=0.0)