# NOTE: these fragments come from two codebases; the imports below are the
# safe, known ones. The scvi-tools / VEGA-internal imports are given as
# comments because their exact module paths depend on the package version.
from collections import OrderedDict
from typing import Iterable, Literal

import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal

# from scvi import REGISTRY_KEYS                                   # assumed scvi-tools paths
# from scvi.distributions import NegativeBinomial
# from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data
# from scvi.nn import Encoder, FCLayers
# SparseLayer, GelNet, LassoRegularizer, create_mask come from VEGA's own modules.


def __init__(self,
             mask: np.ndarray,
             n_cat_list: Iterable[int] = None,
             regularizer: str = 'mask',
             positive_decoder: bool = True,
             reg_kwargs=None):
    super(DecoderVEGA, self).__init__()
    self.n_input = mask.shape[0]
    self.n_output = mask.shape[1]
    self.reg_method = regularizer
    if reg_kwargs is None:
        reg_kwargs = {}
    if reg_kwargs and reg_kwargs.get('d', None) is None:
        # default penalty structure: off-mask entries of the weight matrix
        reg_kwargs['d'] = ~mask.T.astype(bool)
    if regularizer == 'mask':
        print('Using masked decoder', flush=True)
        self.decoder = SparseLayer(mask,
                                   n_cat_list=n_cat_list,
                                   use_batch_norm=False,
                                   use_layer_norm=False,
                                   bias=True,
                                   dropout_rate=0)
    elif regularizer == 'gelnet':
        print('Using GelNet-regularized decoder', flush=True)
        self.decoder = FCLayers(n_in=self.n_input,
                                n_out=self.n_output,
                                n_layers=1,
                                use_batch_norm=False,
                                use_activation=False,
                                use_layer_norm=False,
                                bias=True,
                                dropout_rate=0)
        self.regularizer = GelNet(**reg_kwargs)
    elif regularizer == 'l1':
        print('Using L1-regularized decoder', flush=True)
        self.decoder = FCLayers(n_in=self.n_input,
                                n_out=self.n_output,
                                n_layers=1,
                                use_batch_norm=False,
                                use_activation=False,
                                use_layer_norm=False,
                                bias=True,
                                dropout_rate=0)
        self.regularizer = LassoRegularizer(**reg_kwargs)
    else:
        raise ValueError(
            "Regularizer not recognized. Choose one of ['mask', 'gelnet', 'l1']"
        )
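# ---------------------------------------------------------------------------
# Hedged usage sketch (added for illustration; not part of the original code).
# `toy_mask` and its shapes are made up; SparseLayer / GelNet / LassoRegularizer
# are the classes referenced above.
#
#     toy_mask = np.array([[1, 1, 0, 0],
#                          [0, 0, 1, 1]])  # GMVs x genes: n_input=2, n_output=4
#     decoder = DecoderVEGA(mask=toy_mask, regularizer='mask')
#
# With regularizer='gelnet' or 'l1', a dense FCLayers decoder is built instead
# and reg_kwargs is forwarded to GelNet / LassoRegularizer; when 'd' is not
# supplied, it defaults to ~mask.T.astype(bool), i.e. penalizing off-mask weights.
# ---------------------------------------------------------------------------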
def __init__(
    self,
    n_input: int,
    n_hidden: int = 128,
    n_labels: int = 5,
    n_layers: int = 1,
    dropout_rate: float = 0.1,
    logits: bool = False,
    use_batch_norm: bool = True,
    use_layer_norm: bool = False,
    activation_fn: nn.Module = nn.ReLU,
):
    super().__init__()
    self.logits = logits
    layers = [
        FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            use_batch_norm=use_batch_norm,
            use_layer_norm=use_layer_norm,
            activation_fn=activation_fn,
        ),
        nn.Linear(n_hidden, n_labels),
    ]
    if not logits:
        layers.append(nn.Softmax(dim=-1))
    self.classifier = nn.Sequential(*layers)
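# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative; not part of the original code). With
# logits=False (the default) the head ends in Softmax, so each row of the
# output is a probability vector. forward() is not shown in this excerpt, so
# the sketch calls the underlying nn.Sequential directly.
#
#     clf = Classifier(n_input=50, n_labels=5)
#     probs = clf.classifier(torch.randn(8, 50))
#     assert probs.shape == (8, 5)
#     assert torch.allclose(probs.sum(-1), torch.ones(8))
# ---------------------------------------------------------------------------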
def __init__(
    self,
    n_input: int,
    n_cat_list: Iterable[int] = None,
    n_layers: int = 2,
    n_hidden: int = 128,
    use_batch_norm: bool = False,
    use_layer_norm: bool = True,
    deep_inject_covariates: bool = False,
):
    super().__init__()
    self.px_decoder = FCLayers(
        n_in=n_input,
        n_out=n_hidden,
        n_cat_list=n_cat_list,
        n_layers=n_layers,
        n_hidden=n_hidden,
        dropout_rate=0,
        activation_fn=torch.nn.LeakyReLU,
        use_batch_norm=use_batch_norm,
        use_layer_norm=use_layer_norm,
        inject_covariates=deep_inject_covariates,
    )
    self.output = torch.nn.Sequential(
        torch.nn.Linear(n_hidden, 1), torch.nn.LeakyReLU()
    )
def __init__(
    self,
    n_input: int,
    n_labels: int = 0,
    n_hidden: int = 128,
    n_latent: int = 5,
    n_layers: int = 2,
    dropout_rate: float = 0.1,
    log_variational: bool = True,
    ct_weight: np.ndarray = None,
    **module_kwargs,
):
    super().__init__()
    self.dispersion = "gene"
    self.n_latent = n_latent
    self.n_layers = n_layers
    self.n_hidden = n_hidden
    self.log_variational = log_variational
    self.gene_likelihood = "nb"
    self.latent_distribution = "normal"
    # Automatically deactivate if useless
    self.n_batch = 0
    self.n_labels = n_labels
    # gene dispersion
    self.px_r = torch.nn.Parameter(torch.randn(n_input))
    # z encoder goes from the n_input-dimensional data to an n_latent-d
    # latent space representation
    self.z_encoder = Encoder(
        n_input,
        n_latent,
        n_cat_list=[n_labels],
        n_layers=n_layers,
        n_hidden=n_hidden,
        dropout_rate=dropout_rate,
        inject_covariates=True,
        use_batch_norm=False,
        use_layer_norm=True,
    )
    # decoder goes from n_latent-dimensional space to n_input-d data
    self.decoder = FCLayers(
        n_in=n_latent,
        n_out=n_hidden,
        n_cat_list=[n_labels],
        n_layers=n_layers,
        n_hidden=n_hidden,
        dropout_rate=0,
        inject_covariates=True,
        use_batch_norm=False,
        use_layer_norm=True,
    )
    self.px_decoder = torch.nn.Sequential(
        torch.nn.Linear(n_hidden, n_input), torch.nn.Softplus()
    )
    if ct_weight is not None:
        ct_weight = torch.tensor(ct_weight, dtype=torch.float32)
    else:
        ct_weight = torch.ones((self.n_labels,), dtype=torch.float32)
    self.register_buffer("ct_weight", ct_weight)
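# ---------------------------------------------------------------------------
# Hedged shape sketch (illustrative; not part of the original code). The class
# name VAEC is assumed from context (this matches scvi-tools' CondSCVI module);
# it also assumes Encoder.forward returns (q_m, q_v, z), as in the scvi-tools
# versions this code targets.
#
#     module = VAEC(n_input=100, n_labels=4)
#     x = torch.rand(8, 100)
#     y = torch.randint(0, 4, (8, 1))
#     q_m, q_v, z = module.z_encoder(torch.log(1 + x), y)  # z: (8, n_latent)
#     h = module.decoder(z, y)                             # (8, n_hidden)
#     px_scale = module.px_decoder(h)                      # (8, 100), >= 0 via Softplus
# ---------------------------------------------------------------------------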
def __init__(
    self,
    n_input: int,
    n_output: int,
    n_cat_list: Iterable[int] = None,
    n_layers: int = 1,
    n_hidden: int = 128,
    dropout_rate: float = 0.2,
    **kwargs,
):
    super().__init__()
    self.decoder = FCLayers(
        n_in=n_input,
        n_out=n_hidden,
        n_cat_list=n_cat_list,
        n_layers=n_layers,
        n_hidden=n_hidden,
        dropout_rate=dropout_rate,
        **kwargs,
    )
    self.linear_out = nn.Linear(n_hidden, n_output)
class MRDeconv(BaseModuleClass):
    """
    Model for multi-resolution deconvolution of spatial transcriptomics.

    Parameters
    ----------
    n_spots
        Number of input spots
    n_labels
        Number of cell types
    n_hidden
        Number of neurons in the hidden layers
    n_layers
        Number of layers used in the encoder networks
    n_latent
        Number of dimensions used in the latent variables
    n_genes
        Number of genes used in the decoder
    decoder_state_dict
        state_dict from the decoder of the CondSCVI model
    px_decoder_state_dict
        state_dict from the px_decoder of the CondSCVI model
    px_r
        parameters for the px_r tensor in the CondSCVI model
    mean_vprior
        Mean parameter for each component in the empirical prior over the latent space
    var_vprior
        Diagonal variance parameter for each component in the empirical prior over the latent space
    amortization
        which of the latent variables to amortize inference over (gamma, proportions, both or none)
    """

    def __init__(
        self,
        n_spots: int,
        n_labels: int,
        n_hidden: int,
        n_layers: int,
        n_latent: int,
        n_genes: int,
        decoder_state_dict: OrderedDict,
        px_decoder_state_dict: OrderedDict,
        px_r: np.ndarray,
        mean_vprior: np.ndarray = None,
        var_vprior: np.ndarray = None,
        amortization: Literal["none", "latent", "proportion", "both"] = "latent",
    ):
        super().__init__()
        self.n_spots = n_spots
        self.n_labels = n_labels
        self.n_hidden = n_hidden
        self.n_latent = n_latent
        self.n_genes = n_genes
        self.amortization = amortization
        # unpack and copy parameters
        self.decoder = FCLayers(
            n_in=n_latent,
            n_out=n_hidden,
            n_cat_list=[n_labels],
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=0,
            use_layer_norm=True,
            use_batch_norm=False,
        )
        self.px_decoder = torch.nn.Sequential(
            torch.nn.Linear(n_hidden, n_genes), torch.nn.Softplus()
        )
        # don't compute gradients for those parameters
        self.decoder.load_state_dict(decoder_state_dict)
        for param in self.decoder.parameters():
            param.requires_grad = False
        self.px_decoder.load_state_dict(px_decoder_state_dict)
        for param in self.px_decoder.parameters():
            param.requires_grad = False
        self.register_buffer("px_o", torch.tensor(px_r))
        # cell_type specific factor loadings
        self.V = torch.nn.Parameter(torch.randn(self.n_labels + 1, self.n_spots))
        # within cell_type factor loadings
        self.gamma = torch.nn.Parameter(
            torch.randn(n_latent, self.n_labels, self.n_spots)
        )
        if mean_vprior is not None:
            self.p = mean_vprior.shape[1]
            self.register_buffer("mean_vprior", torch.tensor(mean_vprior))
            self.register_buffer("var_vprior", torch.tensor(var_vprior))
        else:
            self.mean_vprior = None
            self.var_vprior = None
        # noise from data
        self.eta = torch.nn.Parameter(torch.randn(self.n_genes))
        # additive gene bias
        self.beta = torch.nn.Parameter(0.01 * torch.randn(self.n_genes))
        # create additional neural nets for amortization
        # within cell_type factor loadings
        self.gamma_encoder = torch.nn.Sequential(
            FCLayers(
                n_in=self.n_genes,
                n_out=n_hidden,
                n_cat_list=None,
                n_layers=2,
                n_hidden=n_hidden,
                dropout_rate=0.1,
            ),
            torch.nn.Linear(n_hidden, n_latent * n_labels),
        )
        # cell type loadings
        self.V_encoder = FCLayers(
            n_in=self.n_genes,
            n_out=self.n_labels + 1,
            n_layers=2,
            n_hidden=n_hidden,
            dropout_rate=0.1,
        )

    def _get_inference_input(self, tensors):
        # we perform MAP here, so we just need to subsample the variables
        return {}

    def _get_generative_input(self, tensors, inference_outputs):
        x = tensors[REGISTRY_KEYS.X_KEY]
        # spot indices (not the expression tensor) index into V and gamma
        ind_x = tensors[REGISTRY_KEYS.INDICES_KEY].long()
        input_dict = dict(x=x, ind_x=ind_x)
        return input_dict

    @auto_move_data
    def inference(self):
        return {}

    @auto_move_data
    def generative(self, x, ind_x):
        """Build the deconvolution model for every cell in the minibatch."""
        m = x.shape[0]
        library = torch.sum(x, dim=1, keepdim=True)
        # setup all non-linearities
        beta = torch.nn.functional.softplus(self.beta)  # n_genes
        eps = torch.nn.functional.softplus(self.eta)  # n_genes
        x_ = torch.log(1 + x)
        # subsample parameters
        if self.amortization in ["both", "latent"]:
            gamma_ind = torch.transpose(self.gamma_encoder(x_), 0, 1).reshape(
                (self.n_latent, self.n_labels, -1)
            )
        else:
            gamma_ind = self.gamma[:, :, ind_x[:, 0]]  # n_latent, n_labels, minibatch_size
        if self.amortization in ["both", "proportion"]:
            v_ind = self.V_encoder(x_)
        else:
            v_ind = self.V[:, ind_x[:, 0]].T  # minibatch_size, n_labels + 1
        v_ind = torch.nn.functional.softplus(v_ind)
        # reshape and get gene expression value for all minibatch
        gamma_ind = torch.transpose(gamma_ind, 2, 0)  # minibatch_size, n_labels, n_latent
        gamma_reshape = gamma_ind.reshape(
            (-1, self.n_latent)
        )  # minibatch_size * n_labels, n_latent
        enum_label = (
            torch.arange(0, self.n_labels).repeat(m).view((-1, 1))
        )  # minibatch_size * n_labels, 1
        h = self.decoder(gamma_reshape, enum_label.to(x.device))
        px_rate = self.px_decoder(h).reshape(
            (m, self.n_labels, -1)
        )  # (minibatch, n_labels, n_genes)
        # add the dummy cell type
        eps = eps.repeat((m, 1)).view(
            m, 1, -1
        )  # (M, 1, n_genes) <- this is the dummy cell type
        # account for gene-specific bias and add noise
        r_hat = torch.cat(
            [beta.unsqueeze(0).unsqueeze(1) * px_rate, eps], dim=1
        )  # M, n_labels + 1, n_genes
        # now combine them for convolution
        px_scale = torch.sum(v_ind.unsqueeze(2) * r_hat, dim=1)  # batch_size, n_genes
        px_rate = library * px_scale
        return dict(
            px_o=self.px_o,
            px_rate=px_rate,
            px_scale=px_scale,
            gamma=gamma_ind,
            v=v_ind,
        )

    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
        kl_weight: float = 1.0,
        n_obs: float = 1.0,
    ):
        x = tensors[REGISTRY_KEYS.X_KEY]
        px_rate = generative_outputs["px_rate"]
        px_o = generative_outputs["px_o"]
        gamma = generative_outputs["gamma"]

        reconst_loss = -NegativeBinomial(px_rate, logits=px_o).log_prob(x).sum(-1)

        # eta prior likelihood
        mean = torch.zeros_like(self.eta)
        scale = torch.ones_like(self.eta)
        glo_neg_log_likelihood_prior = -Normal(mean, scale).log_prob(self.eta).sum()
        glo_neg_log_likelihood_prior += torch.var(self.beta)

        # gamma prior likelihood
        if self.mean_vprior is None:
            # isotropic normal prior
            mean = torch.zeros_like(gamma)
            scale = torch.ones_like(gamma)
            neg_log_likelihood_prior = (
                -Normal(mean, scale).log_prob(gamma).sum(2).sum(1)
            )
        else:
            # vampprior
            # gamma is of shape minibatch_size, n_labels, n_latent
            gamma = gamma.unsqueeze(1)  # minibatch_size, 1, n_labels, n_latent
            # mean_vprior is of shape n_labels, p, n_latent
            mean_vprior = torch.transpose(self.mean_vprior, 0, 1).unsqueeze(
                0
            )  # 1, p, n_labels, n_latent
            var_vprior = torch.transpose(self.var_vprior, 0, 1).unsqueeze(
                0
            )  # 1, p, n_labels, n_latent
            pre_lse = (
                Normal(mean_vprior, torch.sqrt(var_vprior)).log_prob(gamma).sum(-1)
            )  # minibatch, p, n_labels
            log_likelihood_prior = torch.logsumexp(pre_lse, 1) - np.log(
                self.p
            )  # minibatch, n_labels
            neg_log_likelihood_prior = -log_likelihood_prior.sum(1)  # minibatch

        loss = (
            n_obs * torch.mean(reconst_loss + kl_weight * neg_log_likelihood_prior)
            + glo_neg_log_likelihood_prior
        )

        return LossRecorder(
            loss, reconst_loss, neg_log_likelihood_prior, glo_neg_log_likelihood_prior
        )

    @torch.no_grad()
    def sample(
        self,
        tensors,
        n_samples=1,
        library_size=1,
    ):
        raise NotImplementedError("No sampling method for DestVI")

    @torch.no_grad()
    @auto_move_data
    def get_proportions(self, x=None, keep_noise=False) -> np.ndarray:
        """Returns the estimated cell type proportions."""
        if self.amortization in ["both", "proportion"]:
            # get estimated unadjusted proportions
            x_ = torch.log(1 + x)
            res = torch.nn.functional.softplus(self.V_encoder(x_))
        else:
            res = (
                torch.nn.functional.softplus(self.V).cpu().numpy().T
            )  # n_spots, n_labels + 1
        # remove dummy cell type proportion values
        if not keep_noise:
            res = res[:, :-1]
        # normalize to obtain adjusted proportions
        res = res / res.sum(axis=1).reshape(-1, 1)
        return res

    @torch.no_grad()
    @auto_move_data
    def get_gamma(self, x: torch.Tensor = None) -> torch.Tensor:
        """
        Returns the loadings.

        Returns
        -------
        type
            tensor
        """
        # get estimated gamma (within cell-type loadings)
        if self.amortization in ["latent", "both"]:
            x_ = torch.log(1 + x)
            gamma = self.gamma_encoder(x_)
            return torch.transpose(gamma, 0, 1).reshape(
                (self.n_latent, self.n_labels, -1)
            )  # n_latent, n_labels, minibatch
        else:
            return self.gamma.cpu().numpy()  # (n_latent, n_labels, n_spots)

    @torch.no_grad()
    @auto_move_data
    def get_ct_specific_expression(
        self, x: torch.Tensor = None, ind_x: torch.Tensor = None, y: int = None
    ):
        """
        Returns cell type specific gene expression at the queried spots.

        Parameters
        ----------
        x
            tensor of data
        ind_x
            tensor of indices
        y
            integer index of the queried cell type
        """
        # cell-type specific gene expression, shape (minibatch, n_genes)
        beta = torch.nn.functional.softplus(self.beta)  # n_genes
        y_torch = y * torch.ones_like(ind_x)
        # obtain the relevant gammas
        if self.amortization in ["both", "latent"]:
            x_ = torch.log(1 + x)
            gamma_ind = torch.transpose(self.gamma_encoder(x_), 0, 1).reshape(
                (self.n_latent, self.n_labels, -1)
            )
        else:
            gamma_ind = self.gamma[:, :, ind_x[:, 0]]  # n_latent, n_labels, minibatch_size
        # calculate cell type specific expression
        gamma_select = gamma_ind[
            :, y_torch[:, 0], torch.arange(ind_x.shape[0])
        ].T  # minibatch_size, n_latent
        h = self.decoder(gamma_select, y_torch)
        px_scale = self.px_decoder(h)  # (minibatch, n_genes)
        px_ct = torch.exp(self.px_o).unsqueeze(0) * beta.unsqueeze(0) * px_scale
        return px_ct  # shape (minibatch, genes)
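# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative; not part of the original code). Assumes a
# MRDeconv `module` built from a trained CondSCVI model (state dicts and px_r),
# plus hypothetical tensors `spot_counts` (n_spots x n_genes) and `idx`
# (spot indices, shape (n_spots, 1)).
#
#     props = module.get_proportions(x=spot_counts)                # (n_spots, n_labels)
#     props_all = module.get_proportions(x=spot_counts, keep_noise=True)  # + dummy column
#     gamma = module.get_gamma(x=spot_counts)                      # (n_latent, n_labels, n_spots)
#     expr_ct0 = module.get_ct_specific_expression(x=spot_counts, ind_x=idx, y=0)
# ---------------------------------------------------------------------------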
def __init__(self,
             adata,
             gmt_paths=None,
             add_nodes=1,
             min_genes=0,
             max_genes=5000,
             positive_decoder=True,
             encode_covariates=False,
             regularizer='mask',
             reg_kwargs=None,
             **kwargs):
    """
    Constructor for class VEGA (VAE Enhanced by Gene Annotations).

    Parameters
    ----------
    adata
        scanpy single-cell object. Please run setup_anndata() before passing to VEGA.
    gmt_paths
        one or more paths to .gmt files for GMV initialization.
    add_nodes
        additional fully-connected nodes in the mask.
    min_genes
        minimum gene size for GMVs.
    max_genes
        maximum gene size for GMVs.
    positive_decoder
        whether to constrain the decoder to positive weights.
    encode_covariates
        whether to encode covariates along with gene expression.
    regularizer
        which regularization strategy to use ('l1', 'gelnet', 'mask'). Default: 'mask'.
    reg_kwargs
        parameters for the regularizer.
    **kwargs
        use_cuda: whether to use CUDA (True) or CPU (False).
        beta: weight for the KL divergence.
        dropout: dropout rate in the model.
        z_dropout: dropout rate for the latent space (to reduce correlation between GMVs).
    """
    super(VEGA, self).__init__()
    self.adata = adata
    self.add_nodes_ = add_nodes
    self.min_genes_ = min_genes
    self.max_genes_ = max_genes
    # Check for setup and mask existence
    if '_vega' not in self.adata.uns.keys():
        raise ValueError(
            'Please run vega.utils.setup_anndata(adata) before initializing VEGA.'
        )
    if 'mask' not in self.adata.uns['_vega'].keys() and not gmt_paths:
        raise ValueError(
            'No existing mask found in the AnnData object and no .gmt files were '
            'passed to VEGA. Please provide .gmt file paths to initialize a new '
            'mask, or use an AnnData object from the training of a previous VEGA '
            'model.'
        )
    elif gmt_paths:
        create_mask(self.adata, gmt_paths, add_nodes, self.min_genes_,
                    self.max_genes_)
    self.gmv_mask = adata.uns['_vega']['mask']
    self.n_gmvs = self.gmv_mask.shape[1]
    self.n_genes = self.gmv_mask.shape[0]
    self.use_cuda = kwargs.get('use_cuda', False)
    self.beta_ = kwargs.get('beta', 0.0001)
    self.dropout_ = kwargs.get('dropout', 0.1)
    self.z_dropout_ = kwargs.get('z_dropout', 0.3)
    self.pos_dec_ = positive_decoder
    self.regularizer_ = regularizer
    self.encode_covariates = encode_covariates
    self.epoch_history = {}
    # Categorical covariates
    n_cats_per_cov = (
        adata.uns['_scvi']['extra_categoricals']['n_cats_per_key']
        if 'extra_categoricals' in adata.uns['_scvi']
        else None
    )
    n_batch = adata.uns['_scvi']['summary_stats']['n_batch']
    cat_list = [n_batch] + list([] if n_cats_per_cov is None else n_cats_per_cov)
    # Model architecture
    self.encoder = FCLayers(
        n_in=self.n_genes,
        n_out=800,
        n_cat_list=cat_list if encode_covariates else None,
        n_layers=2,
        n_hidden=800,
        dropout_rate=self.dropout_,
    )
    self.mean = nn.Sequential(
        nn.Linear(800, self.n_gmvs), nn.Dropout(self.z_dropout_)
    )
    self.logvar = nn.Sequential(
        nn.Linear(800, self.n_gmvs), nn.Dropout(self.z_dropout_)
    )
    # Setup decoder
    self.decoder = DecoderVEGA(
        mask=self.gmv_mask.T,
        n_cat_list=cat_list,
        regularizer=self.regularizer_,
        positive_decoder=self.pos_dec_,
        reg_kwargs=reg_kwargs,
    )
    # Other hyperparams
    self.is_trained_ = kwargs.get('is_trained', False)
    # Constrain decoder to positive weights if requested
    if self.pos_dec_:
        print('Constraining decoder to positive weights', flush=True)
        self.decoder._positive_weights()
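# ---------------------------------------------------------------------------
# Hedged end-to-end sketch (illustrative; not part of the original code). File
# names are placeholders; the training API is outside this excerpt.
#
#     import scanpy as sc
#     import vega
#
#     adata = sc.read_h5ad('data.h5ad')        # hypothetical dataset
#     vega.utils.setup_anndata(adata)          # required before construction
#     model = VEGA(adata,
#                  gmt_paths='reactome.gmt',   # hypothetical .gmt annotation file
#                  regularizer='mask',
#                  positive_decoder=True)
# ---------------------------------------------------------------------------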