def run_ag(self, gen=5000, markers=400, incr=0):
    """Run AutoGeneS marker selection on cluster-average expression profiles.

    Builds a genes x clusters matrix of mean expression (cells assigned to a
    cluster by obs_name prefix), runs the AutoGeneS genetic algorithm for one
    or several target marker counts, and appends the selected signature matrix
    (genes x clusters, restricted to the chosen markers) to ``self.ag`` for
    each run.

    Parameters
    ----------
    gen : int
        Number of generations for the genetic algorithm (``ngen``).
    markers : int or list
        Either a single number of marker genes to select, or a two-element
        list ``[start, stop]`` — in that case one run is performed for every
        count in ``range(start, stop + incr, incr)``.
    incr : int
        Step between marker counts when ``markers`` is a list. Must be
        non-zero in that case (``range`` raises ``ValueError`` otherwise).
    """
    # Cluster-average expression matrix: genes (rows) x clusters (columns).
    # NOTE(review): cells are assigned to a cluster by obs_name prefix —
    # a cluster name that is a prefix of another (e.g. "T" vs "Treg") would
    # capture both; verify cluster naming upstream.
    cm = pd.DataFrame(index=self.scc.var_names, columns=self.clusters)
    for cluster in self.clusters:
        cells = [x for x in self.scc.obs_names if x.startswith(cluster)]
        cm[cluster] = pd.DataFrame(np.mean(self.scc[cells, :].X, axis=0),
                                   index=self.scc.var_names)

    ag = AutoGenes(cm.T)

    # Normalise `markers` to a list of feature counts, then run once per count
    # (removes the duplicated run/append code of the two original branches).
    if isinstance(markers, list):
        marker_counts = list(range(markers[0], markers[1] + incr, incr))
    else:
        marker_counts = [markers]

    for n_features in marker_counts:
        ag.run(ngen=gen, seed=0, nfeatures=n_features, mode='fixed')
        # keep the last pareto-front solution (minimal correlation)
        self.ag.append(cm[ag.pareto[-1]])
def run_cell2location(
        sc_data,
        sp_data,
        model_name=None,
        verbose=True,
        show_locations=False,
        return_all=True,
        summ_sc_data_args={'cluster_col': "annotation_1",
                           'selection': "cluster_specificity",
                           'selection_specificity': 0.07},
        train_args={'n_iter': 20000, 'learning_rate': 0.005,
                    'sample_prior': False,
                    'readable_var_name_col': None,
                    'sample_name_col': None},
        model_kwargs={},
        posterior_args={'n_samples': 1000},
        export_args={'path': "./results", 'save_model': False,
                     'save_spatial_plots': True, 'run_name_suffix': '',
                     'scanpy_coords_name': 'coords'}):
    r"""Run the cell2location pipeline: train the model, sample prior and
    posterior, export results and save diagnostic plots.

    The model decomposes multi-cell spatial transcriptomics data into cell type
    abundance estimates in a spatially resolved manner. It uses a hierarchical
    non-negative decomposition of the gene expression profiles at spatial
    locations (each with multiple cells) into reference signatures. The
    reference signatures are estimated from scRNA-seq profiles as average gene
    expression profiles for each cell cluster. The cell2location model
    parameters are estimated using Variational Bayesian Inference (ADVI)
    implemented in the pymc3 framework, optimised with the ADAM algorithm and
    taking advantage of GPU acceleration.

    An anndata object with exported results, W weights representing cell state
    densities, and the trained model object are saved to `export_args['path']`.

    .. note::
        This pipeline is not specific to a particular model so please look up
        the relevant parameters in model documentation (the default is
        CoLocationModelNB4V2).

    Parameters
    ----------
    sc_data :
        anndata object with single cell / nucleus data, or pd.DataFrame with
        genes in rows and cell type signatures (factors) in columns.
    sp_data :
        anndata object with spatial data, variable names should match sc_data.
    model_name :
        model class object or None. If None, the multi-sample model is selected
        based on the sp_data.obs column given by `train_args['sample_name_col']`.
    verbose :
        boolean, print diagnostic messages? (Default value = True)
    show_locations :
        boolean, print the plot of factor locations? If False the plots are
        saved but not shown (reduces notebook size). (Default value = False)
    return_all :
        boolean. If True return the results dictionary (model, annotated
        `sp_data`, arguments); if False return only the run time as a string.
        Results are saved to `export_args['path']` either way.
        (Default value = True)
    summ_sc_data_args :
        arguments for summarising single cell data:

        * **cluster_col** - sc_data.obs column with cluster annotations
        * **which_genes** - select intersect or union of genes between single
          cell and spatial? (Default: intersect)
        * **selection** - gene selection mode (Default: None - do not select):
          None, **cluster_specificity** (marker specificity to clusters, see
          **selection_specificity**), **high_cv** (rank by coefficient of
          variation), **cluster_markers** (sc.tl.rank_genes_groups markers),
          **AutoGeneS** (https://github.com/theislab/AutoGeneS)
        * **select_n** - how many variable genes to select (Default: 5000)
        * **select_n_AutoGeneS** - how many genes to select with AutoGeneS,
          lower than select_n (Default: 1000)
        * **selection_specificity** - expression specificity cut-off
          (Default: 0.1); signatures are normalised to sum to 1 gene-wise,
          then the cutoff is applied. Decrease to include more genes.
    train_args :
        arguments for training methods (see help(c2l.LocationModel)):

        * **mode** - "normal" or "tracking" parameters? (Default: normal)
        * **use_raw** - extract spatial data from the RAW slot of anndata?
          For the single cell reference sc_data.raw is always used.
          (Default: True)
        * **data_type** - array precision (Default: 'float32')
        * **n_iter** - number of training iterations (Default: 20000)
        * **learning_rate** - ADAM learning rate (Default: 0.005); reduce if
          training is unstable
        * **total_grad_norm_constraint** - ADAM gradient norm constraint
          (Default: 200); reduce if loss jumps to Inf/NaN
        * **method** - posterior inference method, keep default 'advi'
        * **sample_prior** - sample from the prior to evaluate whether priors
          are reasonable for your data (Default: False). Caution: takes a lot
          of RAM (10s-100s of GB) depending on data size.
        * **n_prior_samples** - number of prior samples (Default: 10)
        * **n_restarts** - training restarts to evaluate stability (Default: 2)
        * **n_type** - restart type: 'restart' 'cv' 'bootstrap'
          (Default: restart)
        * **tracking_every**, **tracking_n_samples** - "tracking" mode:
          posterior samples saved every 'tracking_every' iterations
          (Default: 1000), repeated 'tracking_n_samples' times (Default: 50)
        * **readable_var_name_col** - sp_data.var column with readable gene ID
          (e.g. HGNC symbol) (Default: None)
        * **sample_name_col** - sp_data.obs column with sample / slide ID
          (e.g. Visium section) (Default: None)
        * **fact_names** - optional list of factor names, by default taken
          from cluster names in sc_data
        * **minibatch_size** - passed through to the model when not None
    posterior_args :
        arguments for sampling the posterior. Mean, SD, 5% and 95% quantiles
        are stored in `mod.samples` and exported in `adata.uns['mod']`:

        * **n_samples** - number of posterior samples (Default: 1000)
        * **evaluate_stability_align** - align restarts in stability plots?
          (Default: False)
        * **mean_field_slot** - which training restart to sample from
          (Default: 'init_1')
    export_args :
        arguments for exporting results:

        * **path** - where to save results (Default: "./results")
        * **plot_extension** - file extension of saved plots (Default: "png")
        * **scanpy_plot_vmax**, **scanpy_plot_size** - scanpy.pl.spatial
          plotting settings (Defaults: 'p99.2', 1.3)
        * **save_model** - save trained model? Takes 10s of GB of disk space.
          (Default: False)
        * **save_spatial_plots** - save per-sample W location plots?
          (Default: True)
        * **run_name_suffix** - optional suffix for the run name (Default: '')
        * **export_q05** / **export_mean** - save plots of the 5% quantile /
          mean of parameters (Defaults: True / False)
        * **scanpy_coords_name** - sp_data.obsm entry with X/Y coordinates;
          if None no spatial plot is produced (Default: 'spatial')
        * **img_key** - image to use for scanpy plotting
          ('hires', 'lowres', None)
    model_kwargs :
        Keyword arguments for the model class (see CoLocationModelNB4V2).
        (Default value = {})

    Returns
    -------
    dict
        Results as a dictionary (use dict.keys() to find the elements); also
        saved to `export_args['path']`. If `return_all` is False only the run
        time string is returned.
    """

    # ---- default parameters; user-supplied dict entries override these ----
    d_summ_sc_data_args = {
        'cluster_col': "annotation_1",
        'which_genes': "intersect",
        'selection': None,
        'select_n': 5000,
        'select_n_AutoGeneS': 1000,
        'selection_specificity': 0.1,
        'cluster_markers_kwargs': {}
    }

    d_train_args = {
        'mode': "normal",
        'use_raw': True,
        'data_type': "float32",
        'n_iter': 20000,
        'learning_rate': 0.005,
        'total_grad_norm_constraint': 200,
        'method': 'advi',
        'sample_prior': False,
        'n_prior_samples': 10,
        'n_restarts': 2,
        'n_type': "restart",
        'tracking_every': 1000,
        'tracking_n_samples': 50,
        'readable_var_name_col': None,
        'sample_name_col': None,
        'fact_names': None,
        'minibatch_size': None
    }

    d_posterior_args = {
        'n_samples': 1000,
        'evaluate_stability_align': False,
        'mean_field_slot': "init_1"
    }

    d_export_args = {
        'path': "./results",
        'plot_extension': "png",
        'scanpy_plot_vmax': 'p99.2',
        'scanpy_plot_size': 1.3,
        'save_model': False,
        'save_spatial_plots': True,
        'run_name_suffix': '',
        'export_q05': True,
        'export_mean': False,
        'scanpy_coords_name': 'spatial',
        'img_key': 'hires'
    }

    # merge user-supplied parameters over the defaults
    d_summ_sc_data_args.update(summ_sc_data_args)
    summ_sc_data_args = d_summ_sc_data_args
    d_train_args.update(train_args)
    train_args = d_train_args
    d_posterior_args.update(posterior_args)
    posterior_args = d_posterior_args
    d_export_args.update(export_args)
    export_args = d_export_args

    # BUGFIX: copy so that injecting 'minibatch_size' below does not mutate
    # the caller's dict (or the mutable default argument shared across calls)
    model_kwargs = dict(model_kwargs)

    theano.config.allow_gc = True

    # start timing
    start = time.time()

    if np.sum(sp_data.obs_names.duplicated()) > 0:
        raise ValueError(
            'Make sure location names are unique (`sp_data.obs_names`)')

    sp_data = sp_data.copy()

    ####### Summarising single cell clusters #######
    if verbose:
        print('### Summarising single cell clusters ###')

    if not isinstance(sc_data, pd.DataFrame):
        # anndata: compute per-cluster average expression profiles
        cell_state_df = get_cluster_averages(
            sc_data, cluster_col=summ_sc_data_args['cluster_col'])
        obs = sc_data.obs
    else:
        # dataframe: assume signatures with .index matching sp_data.var_names
        cell_state_df = sc_data.copy()
        obs = cell_state_df

    # select features if requested
    if summ_sc_data_args['selection'] == 'high_cv':
        # rank genes by coefficient of variation and keep the top select_n
        cv = cell_state_df.var(1) / cell_state_df.mean(1)
        cv = cv.sort_values(ascending=False)
        cell_state_df = cell_state_df.loc[
            cv[np.arange(summ_sc_data_args['select_n'])].index, :]

    elif summ_sc_data_args['selection'] == 'cluster_markers':
        if isinstance(sc_data, pd.DataFrame):
            raise ValueError(
                'summ_sc_data_args["selection"] = "cluster_markers" can only be used with type(sc_data) Anndata'
            )
        sel_feat = select_features(
            sc_data, summ_sc_data_args['cluster_col'],
            n_features=summ_sc_data_args['select_n'],
            use_raw=train_args['use_raw'],
            sc_kwargs=summ_sc_data_args['cluster_markers_kwargs'])
        cell_state_df = cell_state_df.loc[
            cell_state_df.index.isin(sel_feat), :]

    elif summ_sc_data_args['selection'] == 'cluster_specificity':
        # normalise signatures to sum to 1 gene-wise, then keep genes whose
        # specificity to at least one cluster exceeds the cut-off
        cell_state_df_norm = (cell_state_df.T / cell_state_df.sum(1)).T
        cell_state_df_norm = (cell_state_df_norm
                              > summ_sc_data_args['selection_specificity'])
        sel_feat = cell_state_df_norm.index[cell_state_df_norm.sum(1) > 0]
        cell_state_df = cell_state_df.loc[
            cell_state_df.index.isin(sel_feat), :]

    elif summ_sc_data_args['selection'] == 'AutoGeneS':
        # pre-select by coefficient of variation, then run AutoGeneS on top
        cv = cell_state_df.var(1) / cell_state_df.mean(1)
        cv = cv.sort_values(ascending=False)
        cell_state_df = cell_state_df.loc[
            cv[np.arange(summ_sc_data_args['select_n'])].index, :]

        from autogenes import AutoGenes
        ag = AutoGenes(cell_state_df.T)
        ag.run(ngen=5000, seed=0,
               nfeatures=summ_sc_data_args['select_n_AutoGeneS'],
               mode='fixed')
        pareto = ag.pareto
        print(ag.plot(size='large', weights=(1, -1)))
        # pick the pareto solution with minimal correlation between markers
        cell_state_df = cell_state_df[pareto[-1]]

    elif summ_sc_data_args['selection'] is not None:
        # BUGFIX: error message now lists all valid options
        # (cluster_specificity was missing)
        raise ValueError(
            "summ_sc_data_args['selection'] can be only None, high_cv, "
            "cluster_specificity, cluster_markers or AutoGeneS")

    # extract spatial expression as a dense matrix
    # BUGFIX: the use_raw=False branch previously tested sparsity of
    # sp_data.raw.X while converting sp_data.X
    if train_args['use_raw']:
        X_source = sp_data.raw.X
    else:
        X_source = sp_data.X
    if scipy.sparse.issparse(X_source):
        X_data = np.array(X_source.toarray())
    else:
        X_data = np.array(X_source)

    # restrict cell signatures and X_data to common genes
    sp_ind = sp_data.var_names.isin(cell_state_df.index)
    if np.sum(sp_ind) == 0:
        raise ValueError(
            'No overlapping genes found, check that `sc_data` and `sp_data` use the same variable names'
        )
    X_data = X_data[:, sp_ind]
    cell_state_df = cell_state_df.loc[sp_data.var_names[sp_ind], :]

    # prepare cell state matrix
    cell_state_mat = cell_state_df.values

    # if factor names not provided use cluster names
    if train_args['fact_names'] is None:
        fact_names = cell_state_df.columns
    else:  # useful for unobserved factor models
        fact_names = train_args['fact_names']

    if train_args['sample_name_col'] is None:
        sp_data.obs['sample'] = 'sample'
        train_args['sample_name_col'] = 'sample'

    if train_args['readable_var_name_col'] is not None:
        readable_var_name_col = sp_data.var[
            train_args['readable_var_name_col']][sp_ind]
    else:
        readable_var_name_col = None

    if train_args['minibatch_size'] is not None:
        model_kwargs['minibatch_size'] = train_args['minibatch_size']

    ####### Creating model #######
    # choose model when not specified
    if model_name is None:
        sample_name = sp_data.obs[train_args['sample_name_col']]
        sample_n = len(np.unique(sample_name))
        if sample_n < sp_data.shape[0]:
            Model = models.LocationModelLinearDependentWMultiExperiment
            model_name = 'LocationModelLinearDependentWMultiExperiment'
        else:
            # BUGFIX: `raise` was missing — the error was silently discarded
            raise ValueError(
                "train_args['sample_name_col'] points to non-existing column or the number of samples(batches) is equal to the number of observations `adata.n_obs`"
            )
    else:  # use supplied class
        Model = model_name

    n_exper = len(sp_data.obs[train_args['sample_name_col']].unique())
    if n_exper == X_data.shape[0]:
        raise ValueError(
            "The number of samples is equal to the number of locations, aborting... (check 'sample_name_col')"
        )

    if verbose:
        print('### Creating model ### - time '
              + str(np.around((time.time() - start) / 60, 2)) + ' min')
    mod = Model(
        cell_state_mat,
        X_data,
        data_type=train_args['data_type'],
        n_iter=train_args['n_iter'],
        learning_rate=train_args['learning_rate'],
        total_grad_norm_constraint=train_args['total_grad_norm_constraint'],
        verbose=verbose,
        var_names=sp_data.var_names[sp_ind],
        var_names_read=readable_var_name_col,
        obs_names=sp_data.obs_names,
        fact_names=fact_names,
        sample_id=sp_data.obs[train_args['sample_name_col']],
        **model_kwargs)

    ####### Print run name #######
    run_name = (f"{mod.__class__.__name__}_{n_exper}experiments"
                f"_{mod.n_fact}clusters_{mod.n_obs}locations"
                f"_{mod.n_var}genes{export_args['run_name_suffix']}")
    print('### Analysis name: ' + run_name)  # analysis name is always printed

    # create the export directories (makedirs also creates missing parents,
    # unlike the previous exists()+mkdir which failed on a missing parent)
    path = export_args['path'] + run_name + '/'
    fig_path = path + 'plots/'
    os.makedirs(fig_path, exist_ok=True)

    ####### Sampling prior #######
    if train_args['sample_prior']:
        if verbose:
            print('### Sampling prior ###')
        mod.sample_prior(samples=train_args['n_prior_samples'])
        # plot & save plot
        mod.plot_prior_vs_data()
        plt.tight_layout()
        save_plot(fig_path, filename='evaluating_prior',
                  extension=export_args['plot_extension'])
        plt.close()

    ####### Training model #######
    if verbose:
        print('### Training model ###')
    if train_args['mode'] == 'normal':
        mod.fit_advi_iterative(n=train_args['n_restarts'],
                               method=train_args['method'],
                               n_type=train_args['n_type'],
                               progressbar=verbose)
    elif train_args['mode'] == 'tracking':
        mod.verbose = False
        if isinstance(mod, Pymc3LocModel):
            mod.track_parameters(n=train_args['n_restarts'],
                                 every=train_args['tracking_every'],
                                 n_samples=train_args['tracking_n_samples'],
                                 n_type=train_args['n_type'],
                                 df_node_name1='nUMI_factors',
                                 df_node_df_name1='spot_factors_df',
                                 df_prior_node_name1='spot_fact_mu_hyp',
                                 df_prior_node_name2='spot_fact_sd_hyp',
                                 mu_node_name='mu',
                                 data_node='X_data',
                                 extra_df_parameters=['spot_add'],
                                 sample_type='post_sample_means')
        else:  # TODO add isinstance(mod, PyroModel)
            mod.fit_advi_iterative(n=train_args['n_restarts'],
                                   method=train_args['method'],
                                   n_type=train_args['n_type'],
                                   tracking=True)
    else:
        raise ValueError(
            "train_args['mode'] can be only 'normal' or 'tracking'")

    theano.config.compute_test_value = 'ignore'

    ####### Evaluate stability of training #######
    if train_args['n_restarts'] > 1:
        n_plots = train_args['n_restarts'] - 1
        ncol = int(np.min((n_plots, 3)))
        nrow = np.ceil(n_plots / ncol)
        # NOTE(review): figsize=(5*nrow, 5*ncol) looks like width/height may
        # be swapped — kept as in the original to preserve output
        plt.figure(figsize=(5 * nrow, 5 * ncol))
        mod.evaluate_stability(
            n_samples=posterior_args['n_samples'],
            align=posterior_args['evaluate_stability_align'])
        plt.tight_layout()
        save_plot(fig_path, filename='evaluate_stability',
                  extension=export_args['plot_extension'])
        plt.close()

    ####### Sampling posterior #######
    if verbose:
        print('### Sampling posterior ### - time '
              + str(np.around((time.time() - start) / 60, 2)) + ' min')
    mod.sample_posterior(node='all',
                         n_samples=posterior_args['n_samples'],
                         save_samples=False,
                         mean_field_slot=posterior_args['mean_field_slot'])

    # evaluate predictive accuracy
    mod.compute_expected()

    ####### Export summarised posterior & Saving results #######
    if verbose:
        print('### Saving results ###')

    # W cell locations (cell density): convert to a dataframe, annotate
    # `sp_data.obs` and save to csv
    mod.sample2df(node_name='spot_factors')
    sp_data = mod.annotate_spot_adata(sp_data)
    mod.spot_factors_df.to_csv(path + 'W_cell_density.csv')
    if export_args['export_q05']:
        mod.spot_factors_q05.to_csv(path + 'W_cell_density_q05.csv')

    # W cell locations (mRNA count), see help(mod.sample2df) for details
    mod.sample2df(node_name='nUMI_factors')
    sp_data = mod.annotate_spot_adata(sp_data)
    mod.spot_factors_df.to_csv(path + 'W_mRNA_count.csv')
    if export_args['export_q05']:
        mod.spot_factors_q05.to_csv(path + 'W_mRNA_count_q05.csv')

    # add posterior of all parameters to `sp_data.uns['mod']`
    mod.fact_filt = None
    sp_data = mod.export2adata(adata=sp_data, slot_name='mod')
    sp_data.uns['mod']['fact_names'] = list(sp_data.uns['mod']['fact_names'])
    sp_data.uns['mod']['var_names'] = list(sp_data.uns['mod']['var_names'])
    sp_data.uns['mod']['obs_names'] = list(sp_data.uns['mod']['obs_names'])

    # save spatial anndata with exported posterior
    sp_data.write(filename=path + 'sp.h5ad', compression='gzip')

    # save model object and related annotations
    run_time = str(np.around((time.time() - start) / 60, 2)) + ' min'
    if export_args['save_model']:
        # save the model and other objects
        res_dict = {
            'mod': mod,
            'sp_data': sp_data,
            'sc_obs': obs,
            'model_name': model_name,
            'summ_sc_data_args': summ_sc_data_args,
            'train_args': train_args,
            'posterior_args': posterior_args,
            'export_args': export_args,
            'run_name': run_name,
            'run_time': run_time
        }
        # BUGFIX: use a context manager so the pickle file handle is closed
        with open(path + 'model_.p', "wb") as f:
            pickle.dump(res_dict, file=f)
    else:
        # just save the settings
        res_dict = {
            'sc_obs': obs,
            'model_name': model_name,
            'summ_sc_data_args': summ_sc_data_args,
            'train_args': train_args,
            'posterior_args': posterior_args,
            'export_args': export_args,
            'run_name': run_name,
            'run_time': run_time
        }
        with open(path + 'model_.p', "wb") as f:
            pickle.dump(res_dict, file=f)
        # add model for returning
        res_dict['mod'] = mod

    ####### Plotting #######
    if verbose:
        print('### Plotting results ###')

    # training history: full, then without the first 20% of iterations
    # (original had identical if/else branches on `verbose` — collapsed)
    mod.plot_history(0)
    plt.tight_layout()
    save_plot(fig_path, filename='training_history_all',
              extension=export_args['plot_extension'])
    plt.close()

    mod.plot_history(int(np.ceil(train_args['n_iter'] * 0.2)))
    plt.tight_layout()
    save_plot(fig_path, filename='training_history_without_first_20perc',
              extension=export_args['plot_extension'])
    plt.close()

    # predictive accuracy
    try:
        mod.plot_posterior_mu_vs_data()
        plt.tight_layout()
        save_plot(fig_path, filename='data_vs_posterior_mean',
                  extension=export_args['plot_extension'])
        plt.close()
    except Exception as e:
        print('Some error in plotting `mod.plot_posterior_mu_vs_data()`\n '
              + str(e))

    ####### Plotting posterior of W / cell locations #######
    if verbose:
        print('### Plotting posterior of W / cell locations ###')

    data_samples = sp_data.obs[train_args['sample_name_col']].unique()
    # strip the 'mean_nUMI_factors' (17 chars) prefix for readable plot names
    cluster_plot_names = pd.Series(
        [i[17:] for i in mod.spot_factors_df.columns])

    def _plot_scanpy_pair(s, s_ind, s_spatial, obs_prefix, file_stem,
                          show_first=False):
        """Save a pair of sc.pl.spatial figures (without / with the histology
        image) for one parameter slot: `obs_prefix` + factor-name columns of
        sp_data.obs are copied to nicely-named columns and plotted."""
        clust_names_orig = [
            obs_prefix + i for i in sp_data.uns['mod']['fact_names']
        ]
        clust_names = sp_data.uns['mod']['fact_names']
        # copy to transform & assign nice names for plotting
        sp_data.obs[clust_names] = sp_data.obs[clust_names_orig]
        for alpha_img, histo in ((0, ''), (1, 'histo_')):
            fig = sc.pl.spatial(sp_data[s_ind, :],
                                cmap='magma',
                                color=clust_names,
                                ncols=5,
                                library_id=s_spatial,
                                size=export_args['scanpy_plot_size'],
                                img_key=export_args['img_key'],
                                alpha_img=alpha_img,
                                vmin=0,
                                vmax=export_args['scanpy_plot_vmax'],
                                return_fig=True,
                                show=False)
            fig.savefig(
                f"{fig_path}/spatial/{histo}{file_stem}_{s}"
                f"_{export_args['scanpy_plot_vmax']}"
                f".{export_args['plot_extension']}",
                bbox_inches='tight')
            fig.clear()
            plt.close(fig)
            # original behaviour: show only once, after the very first plot
            if show_first and alpha_img == 0 and show_locations:
                plt.show()

    if export_args['save_spatial_plots']:
        try:
            for s in data_samples:
                # if slots needed for scanpy plots are present, use scanpy
                sc_spatial_present = np.any(
                    np.isin(list(sp_data.uns.keys()), ['spatial']))

                if sc_spatial_present:
                    sc.settings.figdir = fig_path + 'spatial/'
                    os.makedirs(fig_path + 'spatial/', exist_ok=True)

                    # plot on a black background, restore facecolor after
                    facecolor = rcParams["axes.facecolor"]
                    rcParams["axes.facecolor"] = "black"

                    s_ind = sp_data.obs[train_args['sample_name_col']] == s
                    s_keys = list(sp_data.uns['spatial'].keys())
                    s_spatial = np.array(s_keys)[[s in i for i in s_keys]][0]

                    if export_args['export_mean']:
                        # mRNA_count = nUMI and cell density, posterior mean
                        _plot_scanpy_pair(s, s_ind, s_spatial,
                                          'mean_nUMI_factors',
                                          'W_mRNA_count_mean',
                                          show_first=True)
                        _plot_scanpy_pair(s, s_ind, s_spatial,
                                          'mean_spot_factors',
                                          'W_cell_density_mean')

                    if export_args['export_q05']:
                        # mRNA_count = nUMI and cell density, 5% quantile
                        _plot_scanpy_pair(s, s_ind, s_spatial,
                                          'q05_nUMI_factors',
                                          'W_mRNA_count_q05')
                        _plot_scanpy_pair(s, s_ind, s_spatial,
                                          'q05_spot_factors',
                                          'W_cell_density_q05')

                    rcParams["axes.facecolor"] = facecolor

                else:
                    # no scanpy spatial slot: fall back to our plotter
                    if export_args['scanpy_coords_name'] is not None:
                        # move spatial coordinates to obs for compatibility
                        # with our plotter
                        sp_data.obs['imagecol'] = sp_data.obsm[
                            export_args['scanpy_coords_name']][:, 0]
                        sp_data.obs['imagerow'] = sp_data.obsm[
                            export_args['scanpy_coords_name']][:, 1]

                        if export_args['export_mean']:
                            p = c2lpl.plot_factor_spatial(
                                adata=sp_data,
                                fact_ind=np.arange(
                                    mod.spot_factors_df.shape[1]),
                                fact=mod.spot_factors_df,
                                cluster_names=cluster_plot_names,
                                n_columns=6,
                                trans='identity',
                                sample_name=s,
                                samples_col=train_args['sample_name_col'],
                                obs_x='imagecol',
                                obs_y='imagerow')
                            p.save(filename=fig_path
                                   + 'cell_locations_W_mRNA_count_'
                                   + str(s) + '.'
                                   + export_args['plot_extension'])

                        if export_args['export_q05']:
                            p = c2lpl.plot_factor_spatial(
                                adata=sp_data,
                                fact_ind=np.arange(
                                    mod.spot_factors_q05.shape[1]),
                                fact=mod.spot_factors_q05,
                                cluster_names=cluster_plot_names,
                                n_columns=6,
                                trans='identity',
                                sample_name=s,
                                samples_col=train_args['sample_name_col'],
                                obs_x='imagecol',
                                obs_y='imagerow')
                            p.save(filename=fig_path
                                   + 'cell_locations_W_mRNA_count_q05_'
                                   + str(s) + '.'
                                   + export_args['plot_extension'])

                        if show_locations:
                            print(p)
        except Exception as e:
            print(
                'Some error in plotting with scanpy or `cell2location.plt.plot_factor_spatial()`\n '
                + repr(e))

    if verbose:
        print('### Done ### - time '
              + str(np.around((time.time() - start) / 60, 2)) + ' min')

    if return_all:
        return res_dict
    else:
        del res_dict
        del mod
        gc.collect()
        return str((time.time() - start) / 60) + ' min'
def run_regression(
        sc_data,
        model_name='RegressionNBV4Torch',
        verbose=True,
        return_all=True,
        train_args={
            'covariate_col_names': None,
            'sample_name_col': "sample",
            'tech_name_col': None,
            'n_iter': 15000,
            'learning_rate': 0.005,
            'sample_prior': False,
            'readable_var_name_col': "SYMBOL"
        },
        model_kwargs={},
        posterior_args={'n_samples': 1000},
        export_args={
            'path': "./results",
            'save_model': True,
            'run_name_suffix': '',
            'scanpy_coords_name': 'coords'
        }):
    r"""Run a cell2location model pipeline: train, sample prior and posterior,
    export results and save diagnostic plots.

    An anndata object with exported results, W weights representing cell state
    densities, and the trained model object are saved to ``export_args['path']``.

    .. warning::
        NOTE(review): the body below references names that are never bound in
        this function's scope — ``sp_data``, ``summ_sc_data_args`` and
        ``d_summ_sc_data_args`` — so calling it as written raises
        ``NameError``/``UnboundLocalError``. It appears to be an older
        ``run_cell2location`` body placed under a ``run_regression``
        signature; confirm against the intended implementation before use.

    :param sc_data: anndata object with single cell / nucleus data
    :param model_name: model name as a string (resolved via
        ``cell2location.models``) or a model class object
    :param verbose: print progress messages?
    :param return_all: return the results dictionary (model + annotated data),
        or just the run time string after freeing memory?
    :param train_args: arguments for training methods
        (see help(c2l.LocationModel) for more details):
        'covariate_col_names' - list of obs columns containing covariates of interest;
        'sample_name_col' - obs column with sample / slide ID (e.g. Visium section),
        used for plotting W cell locations;
        'mode' - "normal" or "tracking" parameters;
        'use_raw' - extract data from the RAW slot of anndata
        (applies to spatial data, not the single cell reference);
        'sample_prior' - sample from the prior to evaluate how reasonable the
        priors are for your data (not essential for Visium but useful for
        troubleshooting; essential for other spatial technologies;
        caution: takes 10s-100s of GB of RAM);
        'n_prior_samples' - number of prior samples (more is better but costs RAM);
        'n_restarts' - number of training restarts to evaluate stability;
        'n_type' - restart type: 'restart', 'cv' or 'bootstrap'
        (see help(c2l.LocationModel.fit_advi_iterative));
        'method' - posterior method, keep the default 'advi';
        'readable_var_name_col' - var column with readable gene ID (e.g. HGNC symbol);
        'fact_names' - optional list of factor names, by default taken from sc_data.
    :param model_kwargs: extra keyword arguments passed to the model class
    :param posterior_args: arguments for sampling the posterior;
        'n_samples' - number of posterior samples used to compute mean, SD,
        5% and 95% quantiles (stored in ``mod.samples`` / ``adata.uns['mod']``)
    :param export_args: arguments for exporting results:
        'path' - where to save results; 'plot_extension' - plot file extension;
        'save_model' - save the trained model (can take GBs of disk);
        'export_q05' - save plots of the 5% posterior quantile.
    :return: results as a dictionary {'mod', 'sp_data', 'sc_obs', 'model_name',
        'summ_sc_data_args', 'train_args', 'posterior_args', 'export_args',
        'run_name', 'run_time'}
    """
    # Set default parameters; user-supplied dicts below overwrite these
    # key-by-key so partial overrides keep the remaining defaults.
    d_train_args = {
        'mode': "normal",
        'use_raw': True,
        'data_type': "float32",
        'n_iter': 100000,
        'learning_rate': 0.001,
        'total_grad_norm_constraint': 200,
        'method': 'advi',
        'sample_prior': True,
        'n_prior_samples': 10,
        'n_restarts': 2,
        'n_type': "restart",
        'tracking_every': 1000,
        'tracking_n_samples': 50,
        'readable_var_name_col': None,
        'sample_name_col': None,
        'fact_names': None
    }

    d_posterior_args = {
        'n_samples': 1000,
        'evaluate_stability_align': False,
        'mean_field_slot': "init_1"
    }

    d_export_args = {
        'path': "./results",
        'plot_extension': "png",
        'save_model': True,
        'run_name_suffix': '',
        'export_q05': True,
        'scanpy_coords_name': 'spatial'
    }

    # replace defaults with parameters supplied
    # NOTE(review): `summ_sc_data_args` and `d_summ_sc_data_args` are not
    # defined anywhere in this function or its signature — this loop raises
    # NameError as written. TODO confirm the intended source of these dicts.
    for k in summ_sc_data_args.keys():
        d_summ_sc_data_args[k] = summ_sc_data_args[k]
    summ_sc_data_args = d_summ_sc_data_args
    for k in train_args.keys():
        d_train_args[k] = train_args[k]
    train_args = d_train_args
    for k in posterior_args.keys():
        d_posterior_args[k] = posterior_args[k]
    posterior_args = d_posterior_args
    for k in export_args.keys():
        d_export_args[k] = export_args[k]
    export_args = d_export_args

    # allow theano to free GPU/CPU memory between operations
    theano.config.allow_gc = True

    # start timing
    start = time.time()

    # NOTE(review): `sp_data` is not a parameter of this function — this line
    # raises UnboundLocalError. Presumably a `sp_data` argument is missing
    # from the signature; verify against the caller.
    sp_data = sp_data.copy()

    # move spatial coordinates to obs for compatibility with our plotter
    sp_data.obs['imagecol'] = sp_data.obsm[
        export_args['scanpy_coords_name']][:, 0]
    sp_data.obs['imagerow'] = sp_data.obsm[
        export_args['scanpy_coords_name']][:, 1]

    # import the specified version of the model (string name is resolved
    # against cell2location.models; otherwise assume a class was passed)
    if type(model_name) is str:
        import cell2location.models as models
        Model = getattr(models, model_name)
    else:
        Model = model_name

    ####### Summarising single cell clusters #######
    if verbose:
        print('### Summarising single cell clusters ###')

    if not isinstance(sc_data, pd.DataFrame):
        # if scanpy compute cluster averages
        cell_state_df = get_cluster_averages(
            sc_data, cluster_col=summ_sc_data_args['cluster_col'])
        obs = sc_data.obs
    else:
        # if dataframe assume signature with matching .index to sp_data.var_names
        cell_state_df = sc_data.copy()
        obs = cell_state_df

    # select features if requested
    if summ_sc_data_args['selection'] == 'high_cv':
        # rank genes by coefficient of variation across clusters and keep the
        # top `select_n`
        cv = cell_state_df.var(1) / cell_state_df.mean(1)
        cv = cv.sort_values(ascending=False)
        cell_state_df = cell_state_df.loc[
            cv[np.arange(summ_sc_data_args['select_n'])].index, :]
    elif summ_sc_data_args['selection'] == 'cluster_markers':
        if isinstance(sc_data, pd.DataFrame):
            raise ValueError(
                'summ_sc_data_args["selection"] = "cluster_markers" can only be used with type(sc_data) Anndata'
            )
        sel_feat = select_features(sc_data,
                                   summ_sc_data_args['cluster_col'],
                                   n_features=summ_sc_data_args['select_n'],
                                   use_raw=train_args['use_raw'])
        cell_state_df = cell_state_df.loc[
            cell_state_df.index.isin(sel_feat), :]
    elif summ_sc_data_args['selection'] == 'AutoGeneS':
        # pre-filter by CV, then run the AutoGeneS genetic algorithm on the
        # remaining genes
        cv = cell_state_df.var(1) / cell_state_df.mean(1)
        cv = cv.sort_values(ascending=False)
        cell_state_df = cell_state_df.loc[
            cv[np.arange(summ_sc_data_args['select_n'])].index, :]
        from autogenes import AutoGenes
        ag = AutoGenes(cell_state_df.T)
        ag.run(ngen=5000,
               seed=0,
               nfeatures=summ_sc_data_args['select_n_AutoGeneS'],
               mode='fixed')
        pareto = ag.pareto
        print(ag.plot(size='large', weights=(1, -1)))
        # We then pick one solution and filter its corresponding marker genes:
        cell_state_df = cell_state_df[pareto[
            len(pareto) - 1]]  # select the solution with min correlation
    elif summ_sc_data_args['selection'] is not None:
        raise ValueError(
            "summ_sc_data_args['selection'] can be only None, high_cv, cluster_markers or AutoGeneS"
        )

    # extract data as a dense matrix
    # NOTE(review): assumes .X / .raw.X is a sparse matrix with .toarray();
    # a dense array here would fail — TODO confirm input format.
    if train_args['use_raw']:
        X_data = sp_data.raw.X.toarray()
    else:
        X_data = sp_data.X.toarray()

    # Filter cell states and X_data to common genes
    sp_ind = sp_data.var_names.isin(cell_state_df.index)
    if np.sum(sp_ind) == 0:
        raise ValueError(
            'No overlapping genes found, check that `sc_data` and `sp_data` use the same variable names'
        )
    X_data = X_data[:, sp_ind]
    cell_state_df = cell_state_df.loc[sp_data.var_names[sp_ind], :]

    # prepare cell state matrix
    cell_state_mat = cell_state_df.values

    # if factor names not provided use cluster names
    if train_args['fact_names'] is None:
        fact_names = cell_state_df.columns
    else:  # useful for unobserved factor models
        fact_names = train_args['fact_names']

    # when no sample column is given, treat all locations as one sample
    if train_args['sample_name_col'] is None:
        sp_data.obs['sample'] = 'sample'
        train_args['sample_name_col'] = 'sample'

    if train_args['readable_var_name_col'] is not None:
        readable_var_name_col = sp_data.var[
            train_args['readable_var_name_col']][sp_ind]
    else:
        readable_var_name_col = None

    ####### Creating model #######
    if verbose:
        print('### Creating model ### - time ' +
              str(np.around((time.time() - start) / 60, 2)) + ' min')
    mod = Model(
        cell_state_mat,
        X_data,
        data_type=train_args['data_type'],
        n_iter=train_args['n_iter'],
        learning_rate=train_args['learning_rate'],
        total_grad_norm_constraint=train_args['total_grad_norm_constraint'],
        verbose=verbose,
        var_names=sp_data.var_names[sp_ind],
        var_names_read=readable_var_name_col,
        obs_names=sp_data.obs_names,
        fact_names=fact_names,
        sample_id=sp_data.obs[train_args['sample_name_col']],
        **model_kwargs)

    ####### Print run name #######
    # run name doubles as the export sub-directory name
    run_name = str(mod.__class__.__name__) + '_' + str(mod.n_fact) + 'clusters_' \
               + str(mod.n_cells) + 'locations_' + str(mod.n_genes) + 'genes' \
               + export_args['run_name_suffix']

    print('### Analysis name: ' + run_name)  # analysis name is always printed

    # create the export directory
    path = export_args['path'] + run_name + '/'
    if not os.path.exists(path):
        mkdir(path)
    fig_path = path + 'plots/'
    if not os.path.exists(fig_path):
        mkdir(fig_path)

    ####### Sampling prior #######
    if train_args['sample_prior']:
        if verbose:
            print('### Sampling prior ###')
        mod.sample_prior(samples=train_args['n_prior_samples'])
        # plot & save plot
        mod.plot_prior_vs_data()
        save_plot(fig_path,
                  filename='evaluating_prior',
                  extension=export_args['plot_extension'])
        plt.close()

    ####### Training model #######
    if verbose:
        print('### Training model ###')
    if train_args['mode'] == 'normal':
        mod.fit_advi_iterative(n=train_args['n_restarts'],
                               method=train_args['method'],
                               n_type=train_args['n_type'],
                               progressbar=verbose)
    elif train_args['mode'] == 'tracking':
        # tracking mode periodically snapshots posterior samples during
        # optimisation for diagnostic purposes
        mod.verbose = False
        mod.track_parameters(n=train_args['n_restarts'],
                             every=train_args['tracking_every'],
                             n_samples=train_args['tracking_n_samples'],
                             n_type=train_args['n_type'],
                             df_node_name1='nUMI_factors',
                             df_node_df_name1='spot_factors_df',
                             df_prior_node_name1='spot_fact_mu_hyp',
                             df_prior_node_name2='spot_fact_sd_hyp',
                             mu_node_name='mu',
                             data_node='X_data',
                             extra_df_parameters=['spot_add'],
                             sample_type='post_sample_means')
    else:
        raise ValueError(
            "train_args['mode'] can be only 'normal' or 'tracking'")

    theano.config.compute_test_value = 'ignore'

    ####### Evaluate stability of training #######
    # only meaningful when more than one restart was trained
    if train_args['n_restarts'] > 1:
        mod.evaluate_stability(
            n_samples=posterior_args['n_samples'],
            align=posterior_args['evaluate_stability_align'])
        save_plot(fig_path,
                  filename='evaluate_stability',
                  extension=export_args['plot_extension'])
        plt.close()

    ####### Sampling posterior #######
    if verbose:
        print('### Sampling posterior ### - time ' +
              str(np.around((time.time() - start) / 60, 2)) + ' min')
    mod.sample_posterior(node='all',
                         n_samples=posterior_args['n_samples'],
                         save_samples=False,
                         mean_field_slot=posterior_args['mean_field_slot'])

    # evaluate predictive accuracy
    mod.compute_expected()

    ####### Export summarised posterior & Saving results #######
    if verbose:
        print('### Saving results ###')

    # save W cell locations (cell density)
    # convert cell location parameters (cell density) to a dataframe
    mod.sample2df(node_name='spot_factors')
    # add cell location parameters to `sp_data.obs`
    sp_data = mod.annotate_spot_adata(sp_data)
    # save
    mod.spot_factors_df.to_csv(path + 'W_cell_density.csv')
    if export_args['export_q05']:
        mod.spot_factors_q05.to_csv(path + 'W_cell_density_q05.csv')

    # convert cell location parameters (mRNA count) to a dataframe, see help(mod.sample2df) for details
    mod.sample2df(node_name='nUMI_factors')
    # add cell location parameters to `sp_data.obs`
    sp_data = mod.annotate_spot_adata(sp_data)
    # save W cell locations (mRNA count)
    mod.spot_factors_df.to_csv(path + 'W_mRNA_count.csv')
    if export_args['export_q05']:
        mod.spot_factors_q05.to_csv(path + 'W_mRNA_count_q05.csv')

    # add posterior of all parameters to `sp_data.uns['mod']`
    mod.fact_filt = None
    sp_data = mod.export2adata(adata=sp_data, slot_name='mod')
    # convert Index objects to plain lists so the anndata writes cleanly
    sp_data.uns['mod']['fact_names'] = list(sp_data.uns['mod']['fact_names'])
    sp_data.uns['mod']['var_names'] = list(sp_data.uns['mod']['var_names'])
    sp_data.uns['mod']['obs_names'] = list(sp_data.uns['mod']['obs_names'])

    # save spatial anndata with exported posterior
    sp_data.write(filename=path + 'sp.h5ad', compression='gzip')

    # save model object and related annotations
    if export_args['save_model']:
        # save the model and other objects
        res_dict = {
            'mod': mod,
            'sp_data': sp_data,
            'sc_obs': obs,
            'model_name': model_name,
            'summ_sc_data_args': summ_sc_data_args,
            'train_args': train_args,
            'posterior_args': posterior_args,
            'export_args': export_args,
            'run_name': run_name,
            'run_time': str(np.around((time.time() - start) / 60, 2)) + ' min'
        }
        pickle.dump(res_dict, file=open(path + 'model_.p', "wb"))
    else:
        # just save the settings (same dict minus the heavy model / anndata)
        res_dict = {
            'sc_obs': obs,
            'model_name': model_name,
            'summ_sc_data_args': summ_sc_data_args,
            'train_args': train_args,
            'posterior_args': posterior_args,
            'export_args': export_args,
            'run_name': run_name,
            'run_time': str(np.around((time.time() - start) / 60, 2)) + ' min'
        }
        pickle.dump(res_dict, file=open(path + 'model_.p', "wb"))

    ####### Plotting #######
    if verbose:
        print('### Ploting results ###')

    # Show training history #
    if verbose:
        print(mod.plot_history(0))
    else:
        mod.plot_history(0)
    save_plot(fig_path,
              filename='training_history_all',
              extension=export_args['plot_extension'])
    plt.close()

    # same history with the first 20% of iterations dropped (burn-in obscures
    # the later loss curve)
    if verbose:
        print(mod.plot_history(int(np.ceil(train_args['n_iter'] * 0.2))))
    else:
        mod.plot_history(int(np.ceil(train_args['n_iter'] * 0.2)))
    save_plot(fig_path,
              filename='training_history_without_first_20perc',
              extension=export_args['plot_extension'])
    plt.close()

    # Predictive accuracy
    try:
        mod.plot_posterior_mu_vs_data()
        save_plot(fig_path,
                  filename='data_vs_posterior_mean_Poisson_rate',
                  extension=export_args['plot_extension'])
        plt.close()
    except Exception as e:
        # plotting failures should not abort the pipeline after training
        print('Some error in plotting `mod.plot_posterior_mu_vs_data()`\n ' +
              str(e))

    ####### Ploting posterior of W / cell locations #######
    if verbose:
        print('### Ploting posterior of W / cell locations ###')
    data_samples = sp_data.obs[train_args['sample_name_col']].unique()
    cluster_plot_names = mod.spot_factors_df.columns
    # strip the 'mean_spot_factors' style prefix (first 17 characters) from
    # column names for plot labels — presumably matches the node-name prefix
    # length; TODO confirm
    cluster_plot_names = pd.Series(
        [i[17:] for i in mod.spot_factors_df.columns])

    try:
        for i in data_samples:
            # plot posterior mean cell locations per sample
            p = c2lpl.plot_factor_spatial(
                adata=sp_data,
                fact_ind=np.arange(mod.spot_factors_df.shape[1]),
                fact=mod.spot_factors_df,
                cluster_names=cluster_plot_names,
                n_columns=6,
                trans='log',
                sample_name=i,
                samples_col=train_args['sample_name_col'],
                obs_x='imagecol',
                obs_y='imagerow')
            p.save(filename=fig_path + 'cell_locations_W_mRNA_count_' +
                   str(i) + '.' + export_args['plot_extension'])

            if export_args['export_q05']:
                # same plot for the 5% posterior quantile
                p = c2lpl.plot_factor_spatial(
                    adata=sp_data,
                    fact_ind=np.arange(mod.spot_factors_q05.shape[1]),
                    fact=mod.spot_factors_q05,
                    cluster_names=cluster_plot_names,
                    n_columns=6,
                    trans='log',
                    sample_name=i,
                    samples_col=train_args['sample_name_col'],
                    obs_x='imagecol',
                    obs_y='imagerow')
                p.save(filename=fig_path +
                       'cell_locations_W_mRNA_count_q05_' + str(i) + '.' +
                       export_args['plot_extension'])

            if verbose:
                print(p)
    except Exception as e:
        print('Some error in plotting `mod.plot_factor_spatial()`\n ' +
              str(e))

    if verbose:
        print('### Done ### - time ' + res_dict['run_time'])

    if return_all:
        return res_dict
    else:
        # free the heavy objects before returning only the run time
        del res_dict
        del mod
        gc.collect()
        return str((time.time() - start) / 60) + ' min'
def run_cell2location( sc_data, sp_data, model_name=None, verbose=True, show_locations=False, return_all=True, summ_sc_data_args={ "cluster_col": "annotation_1", "selection": "cluster_specificity", "selection_specificity": 0.07, }, train_args={ "n_iter": 20000, "learning_rate": 0.005, "sample_prior": False, "readable_var_name_col": None, "sample_name_col": None, }, model_kwargs={}, posterior_args={"n_samples": 1000}, export_args={ "path": "./results", "save_model": False, "save_spatial_plots": True, "run_name_suffix": "", "scanpy_coords_name": "coords", }, ignore_deprecation_exception: bool = False, ): r"""Run cell2location model pipeline: train the model, sample prior and posterior, export results and save diagnostic plots Briefly, cell2location is a Bayesian model, which estimates absolute cell density of cell types by decomposing mRNA counts :math:`d_{s,g}` of each gene :math:`g = {1, .., G}` at locations :math:`s = {1, .., S}` into a set of predefined reference signatures of cell types :math:`g_{f,g}`. The cell2location software comes with two implementations for this estimation step: 1) a statistical method based on Negative Binomial regression (see `cell2location.run_regression`); 2) hard-coded computation of per-cluster average mRNA counts for individual genes (provided anndata object to this function, see below). Approximate Variational Inference is used to estimate all model parameters, implemented in the pymc3 framework, which supports GPU acceleration. Anndata object with exported results, W weights representing cell abundance for each cell type, and the trained model object are saved to `export_args['path']` .. note:: This pipeline is not specific to a particular model so please look up the relevant parameters in model documentation (the default is LocationModelLinearDependentWMultiExperiment). Parameters ---------- sc_data : anndata object with single cell / nucleus data, or pd.DataFrame with genes in rows and cell type signatures (factors) in columns. 
sp_data : anndata object with spatial data, variable names should match sc_data model_name : model class object or None. If none, the default single-sample or multi-sample model is selected depending on the number of samples detected in in the sp_data.obs columns specified by `train_args['sample_name_col']`. verbose : boolean, print diagnostic messages? (Default value = True) show_locations : boolean, print the plot of factor locations? If False the plots are saved but not shown. This reduces notebook size. (Default value = False) return_all : boolean, return the model and annotated `sp_data`. If True, both are saved but not returned? (Default value = True) summ_sc_data_args : arguments for summarising single cell data * **cluster_col** - name of sc_data.obs column containing cluster annotations * **which_genes** - select intersect or union genes between single cell and spatial? (Default: intersect) * **selection** - select highly variable genes in sc_data? (Default: None - do not select), available options: * None * **cluster_specificity** use marker specificity to clusters to select most informative genes (see **selection_specificity** below) * **high_cv** use rank by coefficient of variation (variance / mean) * **cluster_markers** use cluster markers (derived using sc.tl.rank_genes_groups) * **AutoGeneS** use https://github.com/theislab/AutoGeneS * **select_n** - how many variable genes to select? Used only when selection is not None (Default: 5000). * **select_n_AutoGeneS** - how many variable genes tor select with AutoGeneS (lower than select_n)? (Default: 1000). * **selection_specificity** expression specificity cut-off (default=0.1), expression signatures of cell types are normalised to sum to 1 gene-wise, then a cutoff is applied. Decrease cutoff to include more genes. train_args : arguments for training methods. See help(c2l.LocationModelLinearDependentWMultiExperiment) for more details * **mode** - "normal" or "tracking" parameters? 
(Default: normal) * **use_raw** - extract data from RAW slot of anndata? Applies only to spatial data, not single cell reference. For reference sc_data.raw is always used. (Default: True) * **data_type** - use arrays with which precision? (Default: 'float32'). * **n_iter** - number of training iterations (Default: 20000). * **learning_rate** - ADAM optimiser learning rate (Default: 0.005). If the training is very unstable try reducing the learning rate. * **total_grad_norm_constraint** - ADAM optimiser total_grad_norm_constraint (Default: 200). Prevents exploding gradient problem, of the training is very unstable (jumps to Inf/NaN loss) try reducing this parameter. * **method** - which method to use to find posterior. Use default 'advi' unless you know what you are doing. * **sample_prior** - use sampling from the prior to evaluate how reasonable priors are for your data. (Default: False) This is not essential for 10X Visium data but can be very useful for troubleshooting when the model fails training with NaN or Inf likelihood / ELBO loss. It is essential to do this first to choose priors appropriate for any other technology. Sample from the prior should be in the same range as data, span similar number of orders of magnitude, and ideally be weakly informative: results/plots/evaluating_prior.png plot should looks like a very wide diagonal. Caution ------- Sampling the prior takes a lot of RAM (10s-100s of GB) depending on data size so if needed for troubleshooting - reduce the number of locations in needed. * **n_prior_samples** - number of prior samples. The more the better but also takes a lot of RAM. 
(Default: 10) * **n_restarts** - number of training restarts to evaluate stability (Default: 2) * **n_type** - type of restart training: 'restart' 'cv' 'bootstrap' (Default: restart) (see help(cell2location.LocationModel.fit_advi_iterative) for details) * **tracking_every**, **tracking_n_samples** - parameters for "tracking" mode: Posterior samples are saved after 'tracking_every' iterations (Default: 1000), the process is repeated 'tracking_n_samples' times (Default: 50) * **readable_var_name_col** - column in sp_data.var that contains readable gene ID (e.g. HGNC symbol) (Default: None) * **sample_name_col** - column in sp_data.obs that contains sample / slide ID (e.g. Visium section), for plotting W cell locations from a multi-sample object correctly (Default: None) * **fact_names** - optional list of factor names, by default taken from cluster names in sc_data. posterior_args : arguments for sampling posterior The model generates samples from the posterior distribution of each parameter to compute mean, SD, 5% and 95% quantiles - stored in `mod.samples` and exported in `adata.uns['mod']`. * **n_samples** - number of samples to generate (the more the better but you run out of the GPU memory if too many). (Default: 1000) * **evaluate_stability_align** - when generating the model stability plots, align different restarts? (Default: False) * **mean_field_slot** - Posterior of all parameters is sampled only from one training restart due to stability of the model. - which training restart to use? (Default: 'init_1') export_args : arguments for exporting results * **path** - file path where to save results (Default: "./results") * **plot_extension** - file extension of saved plots (Default: "png") * **scanpy_plot_vmax** - 'p99.2', 'scanpy_plot_size': 1.3 - scanpy.pl.spatial plottin settings * **save_model** - boolean, save trained model? Could be useful but also takes up 10s of GB of disk space. 
(Default: False) * **run_name_suffix** - optinal suffix to modify the name the run. (Default: '') * **export_q05** - boolean, save plots of 5% quantile of parameters. (Default: True) * **scanpy_coords_name** - sp_data.obsm entry that stores X and Y coordinates of each location. If None - no spatial plot is produced. * **img_key** - which image to use for scanpy plotting ('hires', 'lowres', None) model_kwargs : Keyword arguments for the model class. See the list of relevant arguments for CoLocationModelNB4V2. (Default value = {}) Returns ------- dict Results as a dictionary, use dict.keys() to find the elements. Results are saved to `export_args['path']`. """ if ignore_deprecation_exception is not True: raise Exception( "This workflow is deprecated. Please use pyro/scvi-tools interface. \n See tutorial: https://cell2location.readthedocs.io/en/latest/notebooks/cell2location_tutorial.html \n If you are absolutely sure that you need to use the pymc3 version set `ignore_deprecation_exception=True`" ) # set default parameters d_summ_sc_data_args = { "cluster_col": "annotation_1", "which_genes": "intersect", "selection": None, "select_n": 5000, "select_n_AutoGeneS": 1000, "selection_specificity": 0.1, "cluster_markers_kwargs": {}, } d_train_args = { "mode": "normal", "use_raw": True, "data_type": "float32", "n_iter": 20000, "learning_rate": 0.005, "total_grad_norm_constraint": 200, "method": "advi", "sample_prior": False, "n_prior_samples": 10, "n_restarts": 2, "n_type": "restart", "tracking_every": 1000, "tracking_n_samples": 50, "readable_var_name_col": None, "sample_name_col": None, "fact_names": None, "minibatch_size": None, } d_posterior_args = { "n_samples": 1000, "evaluate_stability_align": False, "mean_field_slot": "init_1" } d_export_args = { "path": "./results", "plot_extension": "png", "scanpy_plot_vmax": "p99.2", "scanpy_plot_size": 1.3, "save_model": False, "save_spatial_plots": True, "run_name_suffix": "", "export_q05": True, "export_mean": False, 
"scanpy_coords_name": "spatial", "img_key": "hires", } # replace defaults with parameters supplied for k in summ_sc_data_args.keys(): d_summ_sc_data_args[k] = summ_sc_data_args[k] summ_sc_data_args = d_summ_sc_data_args for k in train_args.keys(): d_train_args[k] = train_args[k] train_args = d_train_args for k in posterior_args.keys(): d_posterior_args[k] = posterior_args[k] posterior_args = d_posterior_args for k in export_args.keys(): d_export_args[k] = export_args[k] export_args = d_export_args theano.config.allow_gc = True # start timing start = time.time() if np.sum(sp_data.obs_names.duplicated()) > 0: raise ValueError( "Make sure location names are unique (`sp_data.obs_names`)") sp_data = sp_data.copy() ####### Summarising single cell clusters ####### if verbose: print("### Summarising single cell clusters ###") if not isinstance(sc_data, pd.DataFrame): # if scanpy compute cluster averages cell_state_df = compute_cluster_averages( sc_data, labels=summ_sc_data_args["cluster_col"]) obs = sc_data.obs else: # if dataframe assume signature with matching .index to sp_data.var_names cell_state_df = sc_data.copy() obs = cell_state_df # select features if requested if summ_sc_data_args["selection"] == "high_cv": cv = cell_state_df.var(1) / cell_state_df.mean(1) cv = cv.sort_values(ascending=False) cell_state_df = cell_state_df.loc[ cv[np.arange(summ_sc_data_args["select_n"])].index, :] elif summ_sc_data_args["selection"] == "cluster_markers": if isinstance(sc_data, pd.DataFrame): raise ValueError( 'summ_sc_data_args["selection"] = "cluster_markers" can only be used with type(sc_data) Anndata' ) sel_feat = select_features( sc_data, summ_sc_data_args["cluster_col"], n_features=summ_sc_data_args["select_n"], use_raw=train_args["use_raw"], sc_kwargs=summ_sc_data_args["cluster_markers_kwargs"], ) cell_state_df = cell_state_df.loc[ cell_state_df.index.isin(sel_feat), :] elif summ_sc_data_args["selection"] == "cluster_specificity": # selecting most informative genes based on 
specificity cell_state_df_norm = (cell_state_df.T / cell_state_df.sum(1)).T cell_state_df_norm = cell_state_df_norm > summ_sc_data_args[ "selection_specificity"] sel_feat = cell_state_df_norm.index[cell_state_df_norm.sum(1) > 0] cell_state_df = cell_state_df.loc[ cell_state_df.index.isin(sel_feat), :] elif summ_sc_data_args["selection"] == "AutoGeneS": cv = cell_state_df.var(1) / cell_state_df.mean(1) cv = cv.sort_values(ascending=False) cell_state_df = cell_state_df.loc[ cv[np.arange(summ_sc_data_args["select_n"])].index, :] from autogenes import AutoGenes ag = AutoGenes(cell_state_df.T) ag.run(ngen=5000, seed=0, nfeatures=summ_sc_data_args["select_n_AutoGeneS"], mode="fixed") pareto = ag.pareto print(ag.plot(size="large", weights=(1, -1))) # We then pick one solution and filter its corresponding marker genes: cell_state_df = cell_state_df[pareto[ len(pareto) - 1]] # select the solution with min correlation elif summ_sc_data_args["selection"] is not None: raise ValueError( "summ_sc_data_args['selection'] can be only None, high_cv, cluster_markers or AutoGeneS" ) # extract data as a dense matrix if train_args["use_raw"]: if scipy.sparse.issparse(sp_data.raw.X): X_data = np.array(sp_data.raw.X.toarray()) else: X_data = np.array(sp_data.raw.X) else: if scipy.sparse.issparse(sp_data.raw.X): X_data = np.array(sp_data.X.toarray()) else: X_data = np.array(sp_data.X) # Filter cell states and X_data to common genes sp_ind = sp_data.var_names.isin(cell_state_df.index) if np.sum(sp_ind) == 0: raise ValueError( "No overlapping genes found, check that `sc_data` and `sp_data` use the same variable names" ) X_data = X_data[:, sp_ind] cell_state_df = cell_state_df.loc[sp_data.var_names[sp_ind], :] # prepare cell state matrix cell_state_mat = cell_state_df.values # if factor names not provided use cluster names if train_args["fact_names"] is None: fact_names = cell_state_df.columns else: # useful for unobserved factor models fact_names = train_args["fact_names"] if 
train_args["sample_name_col"] is None: sp_data.obs["sample"] = "sample" train_args["sample_name_col"] = "sample" if train_args["readable_var_name_col"] is not None: readable_var_name_col = sp_data.var[ train_args["readable_var_name_col"]][sp_ind] else: readable_var_name_col = None if train_args["minibatch_size"] is not None: model_kwargs["minibatch_size"] = train_args["minibatch_size"] ####### Creating model ####### # choose model when not specified if model_name is None: sample_name = sp_data.obs[train_args["sample_name_col"]] sample_n = len(np.unique(sample_name)) if sample_n < sp_data.shape[0]: Model = models.pymc3.LocationModelLinearDependentWMultiExperimentLocationBackgroundNormGeneAlpha model_name = "LocationModelLinearDependentWMultiExperimentLocationBackgroundNormGeneAlpha" else: ValueError( "train_args['sample_name_col'] points to non-existing column or the number of samples(batches) is equal to the number of observations `adata.n_obs`" ) else: # use supplied class Model = model_name n_exper = len(sp_data.obs[train_args["sample_name_col"]].unique()) if n_exper == X_data.shape[0]: raise ValueError( "The number of samples is equal to the number of locations, aborting... 
(check 'sample_name_col')" )

    # NOTE(review): this span is the tail of `run_cell2location` — the signature,
    # docstring and early validation/setup (Model resolution, cell_state_mat,
    # X_data, sp_ind, n_exper, obs, start) live above this chunk. The fragment
    # above closes a ValueError raised during input validation.

    ####### Creating model #######
    if verbose:
        print("### Creating model ### - time " + str(np.around((time.time() - start) / 60, 2)) + " min")

    # Instantiate the location model; all optimisation hyperparameters come from
    # `train_args`, extra model-specific options from `model_kwargs`.
    mod = Model(
        cell_state_mat,
        X_data,
        data_type=train_args["data_type"],
        n_iter=train_args["n_iter"],
        learning_rate=train_args["learning_rate"],
        total_grad_norm_constraint=train_args["total_grad_norm_constraint"],
        verbose=verbose,
        var_names=sp_data.var_names[sp_ind],
        var_names_read=readable_var_name_col,
        obs_names=sp_data.obs_names,
        fact_names=fact_names,
        sample_id=sp_data.obs[train_args["sample_name_col"]],
        **model_kwargs,
    )

    ####### Print run name #######
    # Run name encodes model class, data dimensions and a user suffix; it is also
    # used as the export directory name below.
    run_name = f"{mod.__class__.__name__}_{n_exper}experiments_{mod.n_fact}clusters_{mod.n_obs}locations_{mod.n_var}genes{export_args['run_name_suffix']}"
    print("### Analysis name: " + run_name)  # analysis name is always printed

    # create the export directory
    path = export_args["path"] + run_name + "/"
    if not os.path.exists(path):
        mkdir(path)
    fig_path = path + "plots/"
    if not os.path.exists(fig_path):
        mkdir(fig_path)

    ####### Sampling prior #######
    if train_args["sample_prior"]:
        if verbose:
            print("### Sampling prior ###")
        mod.sample_prior(samples=train_args["n_prior_samples"])
        # plot & save plot
        mod.plot_prior_vs_data()
        plt.tight_layout()
        save_plot(fig_path, filename="evaluating_prior", extension=export_args["plot_extension"])
        plt.close()

    ####### Training model #######
    if verbose:
        print("### Training model ###")
    if train_args["mode"] == "normal":
        mod.train(n=train_args["n_restarts"], method=train_args["method"], n_type=train_args["n_type"], progressbar=verbose)
    elif train_args["mode"] == "tracking":
        # tracking mode additionally records intermediate parameter values
        # during training (pymc3 models only)
        mod.verbose = False
        if isinstance(mod, Pymc3LocModel):
            mod.track_parameters(
                n=train_args["n_restarts"],
                every=train_args["tracking_every"],
                n_samples=train_args["tracking_n_samples"],
                n_type=train_args["n_type"],
                df_node_name1="nUMI_factors",
                df_node_df_name1="spot_factors_df",
                df_prior_node_name1="spot_fact_mu_hyp",
                df_prior_node_name2="spot_fact_sd_hyp",
                mu_node_name="mu",
                data_node="X_data",
                extra_df_parameters=["spot_add"],
                sample_type="post_sample_means",
            )
        else:
            # TODO add isinstance(mod, PyroModel)
            mod.train(n=train_args["n_restarts"], method=train_args["method"], n_type=train_args["n_type"], tracking=True)
    else:
        raise ValueError(
            "train_args['mode'] can be only 'normal' or 'tracking'")

    theano.config.compute_test_value = "ignore"

    ####### Evaluate stability of training #######
    # Compare factors across training restarts (only meaningful with > 1 restart).
    if train_args["n_restarts"] > 1:
        n_plots = train_args["n_restarts"] - 1
        ncol = int(np.min((n_plots, 3)))
        nrow = np.ceil(n_plots / ncol)
        # NOTE(review): figsize is (width, height) = (5*nrow, 5*ncol), which looks
        # swapped relative to the nrow/ncol grid — confirm intent.
        plt.figure(figsize=(5 * nrow, 5 * ncol))
        mod.evaluate_stability(
            n_samples=posterior_args["n_samples"],
            align=posterior_args["evaluate_stability_align"])
        plt.tight_layout()
        save_plot(fig_path, filename="evaluate_stability", extension=export_args["plot_extension"])
        plt.close()

    ####### Sampling posterior #######
    if verbose:
        print("### Sampling posterior ### - time " + str(np.around((time.time() - start) / 60, 2)) + " min")
    mod.sample_posterior(
        node="all",
        n_samples=posterior_args["n_samples"],
        save_samples=False,
        mean_field_slot=posterior_args["mean_field_slot"],
    )

    # evaluate predictive accuracy
    mod.compute_expected()

    ####### Export summarised posterior & Saving results #######
    if verbose:
        print("### Saving results ###")

    # save W cell locations (cell density)
    # convert cell location parameters (cell density) to a dataframe
    mod.sample2df(node_name="spot_factors")
    # add cell location parameters to `sp_data.obs`
    sp_data = mod.annotate_spot_adata(sp_data)
    # save
    mod.spot_factors_df.to_csv(path + "W_cell_density.csv")
    if export_args["export_q05"]:
        mod.spot_factors_q05.to_csv(path + "W_cell_density_q05.csv")

    # convert cell location parameters (mRNA count) to a dataframe, see help(mod.sample2df) for details
    mod.sample2df(node_name="nUMI_factors")
    # add cell location parameters to `sp_data.obs`
    sp_data = mod.annotate_spot_adata(sp_data)
    # save W cell locations (mRNA count)
    mod.spot_factors_df.to_csv(path + "W_mRNA_count.csv")
    if export_args["export_q05"]:
        mod.spot_factors_q05.to_csv(path + "W_mRNA_count_q05.csv")

    # add posterior of all parameters to `sp_data.uns['mod']`
    mod.fact_filt = None
    sp_data = mod.export2adata(adata=sp_data, slot_name="mod")
    # convert to plain lists so they serialise cleanly into h5ad
    sp_data.uns["mod"]["fact_names"] = list(sp_data.uns["mod"]["fact_names"])
    sp_data.uns["mod"]["var_names"] = list(sp_data.uns["mod"]["var_names"])
    sp_data.uns["mod"]["obs_names"] = list(sp_data.uns["mod"]["obs_names"])

    # save spatial anndata with exported posterior
    sp_data.write(filename=path + "sp.h5ad", compression="gzip")

    # save model object and related annotations
    if export_args["save_model"]:
        # save the model and other objects
        res_dict = {
            "mod": mod,
            "sp_data": sp_data,
            "sc_obs": obs,
            "model_name": model_name,
            "summ_sc_data_args": summ_sc_data_args,
            "train_args": train_args,
            "posterior_args": posterior_args,
            "export_args": export_args,
            "run_name": run_name,
            "run_time": str(np.around((time.time() - start) / 60, 2)) + " min",
        }
        # NOTE(review): file handle from open() is never closed explicitly —
        # relies on interpreter GC; consider a `with` block.
        pickle.dump(res_dict, file=open(path + "model_.p", "wb"))
    else:
        # just save the settings
        res_dict = {
            "sc_obs": obs,
            "model_name": model_name,
            "summ_sc_data_args": summ_sc_data_args,
            "train_args": train_args,
            "posterior_args": posterior_args,
            "export_args": export_args,
            "run_name": run_name,
            "run_time": str(np.around((time.time() - start) / 60, 2)) + " min",
        }
        pickle.dump(res_dict, file=open(path + "model_.p", "wb"))
        # add model for returning
        res_dict["mod"] = mod

    ####### Plotting #######
    if verbose:
        print("### Ploting results ###")

    # Show training history #
    # NOTE(review): both branches of this `if verbose` (and the one below) are
    # identical — the flag currently has no effect on these history plots.
    if verbose:
        mod.plot_history(0)
    else:
        mod.plot_history(0)
    plt.tight_layout()
    save_plot(fig_path, filename="training_history_all", extension=export_args["plot_extension"])
    plt.close()

    # history excluding the first 20% of iterations (burn-in dominates the scale)
    if verbose:
        mod.plot_history(int(np.ceil(train_args["n_iter"] * 0.2)))
    else:
        mod.plot_history(int(np.ceil(train_args["n_iter"] * 0.2)))
    plt.tight_layout()
    save_plot(fig_path, filename="training_history_without_first_20perc", extension=export_args["plot_extension"])
    plt.close()

    # Predictive accuracy
    try:
        mod.plot_posterior_mu_vs_data()
        plt.tight_layout()
        save_plot(fig_path, filename="data_vs_posterior_mean", extension=export_args["plot_extension"])
        plt.close()
    except Exception as e:
        # best-effort: a plotting failure must not abort the export
        print("Some error in plotting `mod.plot_posterior_mu_vs_data()`\n " + str(e))

    ####### Plotting posterior of W / cell locations #######
    if verbose:
        print("### Plotting posterior of W / cell locations ###")

    data_samples = sp_data.obs[train_args["sample_name_col"]].unique()
    # strip the node-name prefix from column names for plot labels
    # NOTE(review): assumes a fixed 17-character prefix (e.g. 'mean_spot_factors')
    # — verify against the column naming in mod.sample2df.
    cluster_plot_names = pd.Series(
        [i[17:] for i in mod.spot_factors_df.columns])

    if export_args["save_spatial_plots"]:
        try:
            for s in data_samples:
                # if slots needed to generate scanpy plots are present, scanpy:
                sc_spatial_present = np.any(
                    np.isin(list(sp_data.uns.keys()), ["spatial"]))

                if sc_spatial_present:
                    sc.settings.figdir = fig_path + "spatial/"
                    os.makedirs(fig_path + "spatial/", exist_ok=True)

                    # black background keeps the magma colormap legible
                    with matplotlib.rc_context({"axes.facecolor": "black"}):
                        s_ind = sp_data.obs[train_args["sample_name_col"]] == s
                        s_keys = list(sp_data.uns["spatial"].keys())
                        # pick the spatial library whose key contains the sample name
                        s_spatial = np.array(s_keys)[[s in i for i in s_keys]][0]

                        if export_args["export_mean"]:
                            # Visualize cell type locations - mRNA_count = nUMI #####
                            # making copy to transform to log & assign nice names
                            clust_names_orig = [
                                "mean_nUMI_factors" + i
                                for i in sp_data.uns["mod"]["fact_names"]
                            ]
                            clust_names = sp_data.uns["mod"]["fact_names"]
                            sp_data.obs[clust_names] = sp_data.obs[
                                clust_names_orig]

                            # panel without the histology image (alpha_img=0)
                            fig = sc.pl.spatial(
                                sp_data[s_ind, :],
                                cmap="magma",
                                color=clust_names,
                                ncols=5,
                                library_id=s_spatial,
                                size=export_args["scanpy_plot_size"],
                                img_key=export_args["img_key"],
                                alpha_img=0,
                                vmin=0,
                                vmax=export_args["scanpy_plot_vmax"],
                                return_fig=True,
                                show=False,
                            )
                            fig.savefig(
                                f"{fig_path}/spatial/W_mRNA_count_mean_{s}_{export_args['scanpy_plot_vmax']}"
                                f".{export_args['plot_extension']}",
                                bbox_inches="tight",
                            )
                            # fig.clear()
                            plt.close()
                            if show_locations:
                                plt.show()

                            # same panel over the histology image (alpha_img=1)
                            fig = sc.pl.spatial(
                                sp_data[s_ind, :],
                                cmap="magma",
                                color=clust_names,
                                ncols=5,
                                library_id=s_spatial,
                                size=export_args["scanpy_plot_size"],
                                img_key=export_args["img_key"],
                                alpha_img=1,
                                vmin=0,
                                vmax=export_args["scanpy_plot_vmax"],
                                show=False,
                                return_fig=True,
                            )
                            fig.savefig(
                                f"{fig_path}/spatial/histo_W_mRNA_count_mean_{s}_{export_args['scanpy_plot_vmax']}"
                                f".{export_args['plot_extension']}",
                                bbox_inches="tight",
                            )
                            # fig.clear()
                            plt.close()

                            # Visualize cell type locations #####
                            # making copy to transform to log & assign nice names
                            clust_names_orig = [
                                "mean_spot_factors" + i
                                for i in sp_data.uns["mod"]["fact_names"]
                            ]
                            clust_names = sp_data.uns["mod"]["fact_names"]
                            sp_data.obs[clust_names] = sp_data.obs[
                                clust_names_orig]

                            fig = sc.pl.spatial(
                                sp_data[s_ind, :],
                                cmap="magma",
                                color=clust_names,
                                ncols=5,
                                library_id=s_spatial,
                                size=export_args["scanpy_plot_size"],
                                img_key=export_args["img_key"],
                                alpha_img=0,
                                vmin=0,
                                vmax=export_args["scanpy_plot_vmax"],
                                show=False,
                                return_fig=True,
                            )
                            fig.savefig(
                                f"{fig_path}/spatial/W_cell_density_mean_{s}_{export_args['scanpy_plot_vmax']}."
                                f"{export_args['plot_extension']}",
                                bbox_inches="tight",
                            )
                            # fig.clear()
                            plt.close()

                            fig = sc.pl.spatial(
                                sp_data[s_ind, :],
                                cmap="magma",
                                color=clust_names,
                                ncols=5,
                                library_id=s_spatial,
                                size=export_args["scanpy_plot_size"],
                                img_key=export_args["img_key"],
                                alpha_img=1,
                                vmin=0,
                                vmax=export_args["scanpy_plot_vmax"],
                                show=False,
                                return_fig=True,
                            )
                            # NOTE(review): plt.savefig here (vs fig.savefig above)
                            # saves the current active figure — confirm equivalent.
                            plt.savefig(
                                f"{fig_path}/spatial/histo_W_cell_density_mean_{s}_{export_args['scanpy_plot_vmax']}"
                                f".{export_args['plot_extension']}",
                                bbox_inches="tight",
                            )
                            # fig.clear()
                            plt.close()

                        if export_args["export_q05"]:
                            # Visualize cell type locations - mRNA_count = nUMI #####
                            # making copy to transform to log & assign nice names
                            clust_names_orig = [
                                "q05_nUMI_factors" + i
                                for i in sp_data.uns["mod"]["fact_names"]
                            ]
                            clust_names = sp_data.uns["mod"]["fact_names"]
                            sp_data.obs[clust_names] = sp_data.obs[
                                clust_names_orig]

                            fig = sc.pl.spatial(
                                sp_data[s_ind, :],
                                cmap="magma",
                                color=clust_names,
                                ncols=5,
                                library_id=s_spatial,
                                size=export_args["scanpy_plot_size"],
                                img_key=export_args["img_key"],
                                alpha_img=0,
                                vmin=0,
                                vmax=export_args["scanpy_plot_vmax"],
                                show=False,
                                return_fig=True,
                            )
                            plt.savefig(
                                f"{fig_path}/spatial/W_mRNA_count_q05_{s}_{export_args['scanpy_plot_vmax']}"
                                f".{export_args['plot_extension']}",
                                bbox_inches="tight",
                            )
                            # fig.clear()
                            plt.close()

                            fig = sc.pl.spatial(
                                sp_data[s_ind, :],
                                cmap="magma",
                                color=clust_names,
                                ncols=5,
                                library_id=s_spatial,
                                size=export_args["scanpy_plot_size"],
                                img_key=export_args["img_key"],
                                alpha_img=1,
                                vmin=0,
                                vmax=export_args["scanpy_plot_vmax"],
                                show=False,
                                return_fig=True,
                            )
                            plt.savefig(
                                f"{fig_path}/spatial/histo_W_mRNA_count_q05_{s}_{export_args['scanpy_plot_vmax']}"
                                f".{export_args['plot_extension']}",
                                bbox_inches="tight",
                            )
                            # fig.clear()
                            plt.close()

                            # Visualize cell type locations #####
                            # making copy to transform to log & assign nice names
                            clust_names_orig = [
                                "q05_spot_factors" + i
                                for i in sp_data.uns["mod"]["fact_names"]
                            ]
                            clust_names = sp_data.uns["mod"]["fact_names"]
                            sp_data.obs[clust_names] = sp_data.obs[
                                clust_names_orig]

                            fig = sc.pl.spatial(
                                sp_data[s_ind, :],
                                cmap="magma",
                                color=clust_names,
                                ncols=5,
                                library_id=s_spatial,
                                size=export_args["scanpy_plot_size"],
                                img_key=export_args["img_key"],
                                alpha_img=0,
                                vmin=0,
                                vmax=export_args["scanpy_plot_vmax"],
                                show=False,
                                return_fig=True,
                            )
                            plt.savefig(
                                f"{fig_path}/spatial/W_cell_density_q05_{s}_{export_args['scanpy_plot_vmax']}"
                                f".{export_args['plot_extension']}",
                                bbox_inches="tight",
                            )
                            # fig.clear()
                            plt.close()

                            fig = sc.pl.spatial(
                                sp_data[s_ind, :],
                                cmap="magma",
                                color=clust_names,
                                ncols=5,
                                library_id=s_spatial,
                                size=export_args["scanpy_plot_size"],
                                img_key=export_args["img_key"],
                                alpha_img=1,
                                vmin=0,
                                vmax=export_args["scanpy_plot_vmax"],
                                show=False,
                                return_fig=True,
                            )
                            plt.savefig(
                                f"{fig_path}/spatial/histo_W_cell_density_q05_{s}_{export_args['scanpy_plot_vmax']}"
                                f".{export_args['plot_extension']}",
                                bbox_inches="tight",
                            )
                            # fig.clear()
                            plt.close()
                else:
                    # no scanpy 'spatial' slot: fall back to the package plotter
                    # if coordinates exist plot
                    if export_args["scanpy_coords_name"] is not None:
                        # move spatial coordinates to obs for compatibility with our plotter
                        sp_data.obs["imagecol"] = sp_data.obsm[
                            export_args["scanpy_coords_name"]][:, 0]
                        sp_data.obs["imagerow"] = sp_data.obsm[
                            export_args["scanpy_coords_name"]][:, 1]

                    if export_args["export_mean"]:
                        p = c2lpl.plot_factor_spatial(
                            adata=sp_data,
                            fact_ind=np.arange(
                                mod.spot_factors_df.shape[1]),
                            fact=mod.spot_factors_df,
                            cluster_names=cluster_plot_names,
                            n_columns=6,
                            trans="identity",
                            sample_name=s,
                            samples_col=train_args["sample_name_col"],
                            obs_x="imagecol",
                            obs_y="imagerow",
                        )
                        p.save(filename=fig_path + "cell_locations_W_mRNA_count_" + str(s) + "."
                               + export_args["plot_extension"])

                    if export_args["export_q05"]:
                        p = c2lpl.plot_factor_spatial(
                            adata=sp_data,
                            fact_ind=np.arange(
                                mod.spot_factors_q05.shape[1]),
                            fact=mod.spot_factors_q05,
                            cluster_names=cluster_plot_names,
                            n_columns=6,
                            trans="identity",
                            sample_name=s,
                            samples_col=train_args["sample_name_col"],
                            obs_x="imagecol",
                            obs_y="imagerow",
                        )
                        p.save(filename=fig_path + "cell_locations_W_mRNA_count_q05_" + str(s) + "." + export_args["plot_extension"])

                    if show_locations:
                        print(p)
        except Exception as e:
            # best-effort: plotting failures must not abort the whole pipeline
            print(
                "Some error in plotting with scanpy or `cell2location.plt.plot_factor_spatial()`\n "
                + repr(e))

    if verbose:
        print("### Done ### - time " + str(np.around((time.time() - start) / 60, 2)) + " min")

    if return_all:
        # res_dict carries the model, settings and (when saved) sp_data
        return res_dict
    else:
        # free memory before returning only the elapsed-time string
        del res_dict
        del mod
        gc.collect()
        return str((time.time() - start) / 60) + " min"