def __init__(self, n_input: int, n_latent: int, n_hidden: int, n_layers: int):
    super().__init__()
    self.n_input = n_input
    self.n_latent = n_latent
    self.epsilon = 5.0e-3
    # z encoder goes from the n_input-dimensional data to an n_latent-d
    # latent space representation
    self.encoder = Encoder(
        n_input,
        n_latent,
        n_layers=n_layers,
        n_hidden=n_hidden,
        dropout_rate=0.1,
    )
    # decoder goes from n_latent-dimensional space to n_input-d data
    self.decoder = DecoderSCVI(
        n_latent,
        n_input,
        n_layers=n_layers,
        n_hidden=n_hidden,
    )
    # This gene-level parameter modulates the variance of the observation distribution
    self.px_r = torch.nn.Parameter(torch.ones(self.n_input))
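# --- Usage illustration (not part of the original module) ---
# A minimal, self-contained sketch with an assumed gene count of 2000. It mirrors
# the px_r initialization above; exponentiating the unconstrained parameter to get
# a positive per-gene dispersion is how such parameters are typically consumed
# downstream, not something shown in this constructor itself.
import torch

n_input = 2000  # assumed number of genes
px_r = torch.nn.Parameter(torch.ones(n_input))  # same initialization as in __init__
dispersion = torch.exp(px_r)  # one positive dispersion value per gene
print(dispersion.shape)  # torch.Size([2000])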
def __init__(
    self,
    n_input,
    n_batch,
    n_labels,
    n_hidden=128,
    n_latent=10,
    n_layers=1,
    dropout_rate=0.1,
    y_prior=None,
    dispersion="gene",
    log_variational=True,
    gene_likelihood="zinb",
):
    super().__init__(
        n_input,
        n_batch,
        n_labels,
        n_hidden=n_hidden,
        n_latent=n_latent,
        n_layers=n_layers,
        dropout_rate=dropout_rate,
        dispersion=dispersion,
        log_variational=log_variational,
        gene_likelihood=gene_likelihood,
        use_observed_lib_size=False,
    )
    self.z_encoder = Encoder(
        n_input,
        n_latent,
        n_cat_list=[n_batch, n_labels],
        n_hidden=n_hidden,
        n_layers=n_layers,
        dropout_rate=dropout_rate,
    )
    self.decoder = DecoderSCVI(
        n_latent,
        n_input,
        n_cat_list=[n_batch, n_labels],
        n_layers=n_layers,
        n_hidden=n_hidden,
    )
    self.y_prior = torch.nn.Parameter(
        y_prior if y_prior is not None else (1 / n_labels) * torch.ones(1, n_labels),
        requires_grad=False,
    )
    self.classifier = Classifier(
        n_input,
        n_hidden,
        n_labels,
        n_layers=n_layers,
        dropout_rate=dropout_rate,
    )
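# --- Illustration (assumed sizes, not part of the original module) ---
# With y_prior=None, the constructor above falls back to a uniform prior over the
# n_labels classes, stored as a fixed (requires_grad=False) parameter of shape
# (1, n_labels). A standalone check:
import torch

n_labels = 5  # assumption for the example
y_prior = torch.nn.Parameter(
    (1 / n_labels) * torch.ones(1, n_labels), requires_grad=False
)
print(y_prior)                # tensor([[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]])
print(y_prior.requires_grad)  # False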
def __init__(
    self,
    n_input_regions: int = 0,
    n_input_genes: int = 0,
    n_batch: int = 0,
    n_labels: int = 0,
    gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb",
    n_hidden: Optional[int] = None,
    n_latent: Optional[int] = None,
    n_layers_encoder: int = 2,
    n_layers_decoder: int = 2,
    n_continuous_cov: int = 0,
    n_cats_per_cov: Optional[Iterable[int]] = None,
    dropout_rate: float = 0.1,
    region_factors: bool = True,
    use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "none",
    use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "both",
    latent_distribution: str = "normal",
    deeply_inject_covariates: bool = False,
    encode_covariates: bool = False,
):
    super().__init__()
    # INIT PARAMS
    self.n_input_regions = n_input_regions
    self.n_input_genes = n_input_genes
    self.n_hidden = (
        int(np.sqrt(self.n_input_regions + self.n_input_genes))
        if n_hidden is None
        else n_hidden
    )
    self.n_batch = n_batch
    self.n_labels = n_labels
    self.gene_likelihood = gene_likelihood
    self.latent_distribution = latent_distribution
    self.n_latent = int(np.sqrt(self.n_hidden)) if n_latent is None else n_latent
    self.n_layers_encoder = n_layers_encoder
    self.n_layers_decoder = n_layers_decoder
    self.n_cats_per_cov = n_cats_per_cov
    self.n_continuous_cov = n_continuous_cov
    self.dropout_rate = dropout_rate

    self.use_batch_norm_encoder = use_batch_norm in ("encoder", "both")
    self.use_batch_norm_decoder = use_batch_norm in ("decoder", "both")
    self.use_layer_norm_encoder = use_layer_norm in ("encoder", "both")
    self.use_layer_norm_decoder = use_layer_norm in ("decoder", "both")
    self.encode_covariates = encode_covariates
    self.deeply_inject_covariates = deeply_inject_covariates

    cat_list = (
        [n_batch] + list(n_cats_per_cov) if n_cats_per_cov is not None else []
    )
    n_input_encoder_acc = self.n_input_regions + n_continuous_cov * encode_covariates
    n_input_encoder_exp = self.n_input_genes + n_continuous_cov * encode_covariates
    encoder_cat_list = cat_list if encode_covariates else None

    # accessibility encoder
    self.z_encoder_accessibility = Encoder(
        n_input=n_input_encoder_acc,
        n_layers=self.n_layers_encoder,
        n_output=self.n_latent,
        n_hidden=self.n_hidden,
        n_cat_list=encoder_cat_list,
        dropout_rate=self.dropout_rate,
        activation_fn=torch.nn.LeakyReLU,
        distribution=self.latent_distribution,
        var_eps=0,
        use_batch_norm=self.use_batch_norm_encoder,
        use_layer_norm=self.use_layer_norm_encoder,
    )

    # expression encoder
    self.z_encoder_expression = Encoder(
        n_input=n_input_encoder_exp,
        n_layers=self.n_layers_encoder,
        n_output=self.n_latent,
        n_hidden=self.n_hidden,
        n_cat_list=encoder_cat_list,
        dropout_rate=self.dropout_rate,
        activation_fn=torch.nn.LeakyReLU,
        distribution=self.latent_distribution,
        var_eps=0,
        use_batch_norm=self.use_batch_norm_encoder,
        use_layer_norm=self.use_layer_norm_encoder,
    )

    # expression decoder
    self.z_decoder_expression = DecoderSCVI(
        self.n_latent + self.n_continuous_cov,
        n_input_genes,
        n_cat_list=cat_list,
        n_layers=n_layers_decoder,
        n_hidden=self.n_hidden,
        inject_covariates=self.deeply_inject_covariates,
        use_batch_norm=self.use_batch_norm_decoder,
        use_layer_norm=self.use_layer_norm_decoder,
    )

    # accessibility decoder
    self.z_decoder_accessibility = DecoderPeakVI(
        n_input=self.n_latent + self.n_continuous_cov,
        n_output=n_input_regions,
        n_hidden=self.n_hidden,
        n_cat_list=cat_list,
        n_layers=self.n_layers_decoder,
        use_batch_norm=self.use_batch_norm_decoder,
        use_layer_norm=self.use_layer_norm_decoder,
        deep_inject_covariates=self.deeply_inject_covariates,
    )

    # accessibility region-specific factors
    self.region_factors = None
    if region_factors:
        self.region_factors = torch.nn.Parameter(torch.zeros(self.n_input_regions))

    # expression dispersion parameters
    self.px_r = torch.nn.Parameter(torch.randn(n_input_genes))

    # expression library size encoder
    self.l_encoder_expression = LibrarySizeEncoder(
        n_input_encoder_exp,
        n_cat_list=encoder_cat_list,
        n_layers=self.n_layers_encoder,
        n_hidden=self.n_hidden,
        use_batch_norm=self.use_batch_norm_encoder,
        use_layer_norm=self.use_layer_norm_encoder,
        deep_inject_covariates=self.deeply_inject_covariates,
    )

    # accessibility library size encoder
    self.l_encoder_accessibility = DecoderPeakVI(
        n_input=n_input_encoder_acc,
        n_output=1,
        n_hidden=self.n_hidden,
        n_cat_list=encoder_cat_list,
        n_layers=self.n_layers_encoder,
        use_batch_norm=self.use_batch_norm_encoder,
        use_layer_norm=self.use_layer_norm_encoder,
        deep_inject_covariates=self.deeply_inject_covariates,
    )
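# --- Illustration (assumed dataset sizes, not part of the original module) ---
# When n_hidden / n_latent are left as None, the constructor above sizes them from
# the data: n_hidden = sqrt(n_regions + n_genes), then n_latent = sqrt(n_hidden).
import numpy as np

n_input_regions, n_input_genes = 80_000, 2_000  # assumed sizes
n_hidden = int(np.sqrt(n_input_regions + n_input_genes))  # 286
n_latent = int(np.sqrt(n_hidden))                         # 16
print(n_hidden, n_latent)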
def __init__(
    self,
    n_input: int,
    n_batch: int = 0,
    n_labels: int = 0,
    n_hidden: int = 128,
    n_latent: int = 10,
    n_layers: int = 1,
    n_continuous_cov: int = 0,
    n_cats_per_cov: Optional[Iterable[int]] = None,
    dropout_rate: float = 0.1,
    dispersion: str = "gene",
    log_variational: bool = True,
    gene_likelihood: str = "zinb",
    latent_distribution: str = "normal",
    encode_covariates: bool = False,
    deeply_inject_covariates: bool = True,
    use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both",
    use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none",
    use_size_factor_key: bool = False,
    use_observed_lib_size: bool = True,
    library_log_means: Optional[np.ndarray] = None,
    library_log_vars: Optional[np.ndarray] = None,
    var_activation: Optional[Callable] = None,
):
    super().__init__()
    self.dispersion = dispersion
    self.n_latent = n_latent
    self.log_variational = log_variational
    self.gene_likelihood = gene_likelihood
    # Automatically deactivate if useless
    self.n_batch = n_batch
    self.n_labels = n_labels
    self.latent_distribution = latent_distribution
    self.encode_covariates = encode_covariates
    self.use_size_factor_key = use_size_factor_key
    self.use_observed_lib_size = use_size_factor_key or use_observed_lib_size

    if not self.use_observed_lib_size:
        if library_log_means is None or library_log_vars is None:
            raise ValueError(
                "If not using observed_lib_size, "
                "must provide library_log_means and library_log_vars."
            )
        self.register_buffer(
            "library_log_means", torch.from_numpy(library_log_means).float()
        )
        self.register_buffer(
            "library_log_vars", torch.from_numpy(library_log_vars).float()
        )

    if self.dispersion == "gene":
        self.px_r = torch.nn.Parameter(torch.randn(n_input))
    elif self.dispersion == "gene-batch":
        self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
    elif self.dispersion == "gene-label":
        self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
    elif self.dispersion == "gene-cell":
        pass
    else:
        raise ValueError(
            "dispersion must be one of ['gene', 'gene-batch', 'gene-label',"
            " 'gene-cell'], but input was {}.".format(self.dispersion)
        )

    use_batch_norm_encoder = use_batch_norm == "encoder" or use_batch_norm == "both"
    use_batch_norm_decoder = use_batch_norm == "decoder" or use_batch_norm == "both"
    use_layer_norm_encoder = use_layer_norm == "encoder" or use_layer_norm == "both"
    use_layer_norm_decoder = use_layer_norm == "decoder" or use_layer_norm == "both"

    # z encoder goes from the n_input-dimensional data to an n_latent-d
    # latent space representation
    n_input_encoder = n_input + n_continuous_cov * encode_covariates
    cat_list = [n_batch] + list([] if n_cats_per_cov is None else n_cats_per_cov)
    encoder_cat_list = cat_list if encode_covariates else None
    self.z_encoder = Encoder(
        n_input_encoder,
        n_latent,
        n_cat_list=encoder_cat_list,
        n_layers=n_layers,
        n_hidden=n_hidden,
        dropout_rate=dropout_rate,
        distribution=latent_distribution,
        inject_covariates=deeply_inject_covariates,
        use_batch_norm=use_batch_norm_encoder,
        use_layer_norm=use_layer_norm_encoder,
        var_activation=var_activation,
    )
    # l encoder goes from n_input-dimensional data to 1-d library size
    self.l_encoder = Encoder(
        n_input_encoder,
        1,
        n_layers=1,
        n_cat_list=encoder_cat_list,
        n_hidden=n_hidden,
        dropout_rate=dropout_rate,
        inject_covariates=deeply_inject_covariates,
        use_batch_norm=use_batch_norm_encoder,
        use_layer_norm=use_layer_norm_encoder,
        var_activation=var_activation,
    )
    # decoder goes from n_latent-dimensional space to n_input-d data
    n_input_decoder = n_latent + n_continuous_cov
    self.decoder = DecoderSCVI(
        n_input_decoder,
        n_input,
        n_cat_list=cat_list,
        n_layers=n_layers,
        n_hidden=n_hidden,
        inject_covariates=deeply_inject_covariates,
        use_batch_norm=use_batch_norm_decoder,
        use_layer_norm=use_layer_norm_decoder,
        scale_activation="softplus" if use_size_factor_key else "softmax",
    )
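# --- Illustration (assumed sizes, not part of the original module) ---
# The shape of the dispersion parameter px_r depends on the `dispersion` mode
# handled above; "gene-cell" registers no parameter here, since dispersion is then
# expected per cell from the decoder's output rather than from a global parameter.
import torch

n_input, n_batch, n_labels = 2_000, 4, 6  # assumptions for the example
px_r_by_mode = {
    "gene": torch.nn.Parameter(torch.randn(n_input)),
    "gene-batch": torch.nn.Parameter(torch.randn(n_input, n_batch)),
    "gene-label": torch.nn.Parameter(torch.randn(n_input, n_labels)),
}
for mode, p in px_r_by_mode.items():
    print(mode, tuple(p.shape))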