def test_pyro_bayesian_regression_jit():
    use_gpu = int(torch.cuda.is_available())
    adata = synthetic_iid()
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegressionModule(adata.shape[1], 1)
    plan = PyroTrainingPlan(model, loss_fn=pyro.infer.JitTrace_ELBO())
    trainer = Trainer(
        gpus=use_gpu,
        max_epochs=2,
        callbacks=[PyroJitGuideWarmup(train_dl)],
    )
    trainer.fit(plan, train_dl)

    # 100 features, 1 for sigma, 1 for bias
    assert list(model.guide.parameters())[0].shape[0] == 102

    if use_gpu == 1:
        model.cuda()

    # test Predictive
    num_samples = 5
    predictive = model.create_predictive(num_samples=num_samples)
    for tensor_dict in train_dl:
        args, kwargs = model._get_fn_args_from_batch(tensor_dict)
        _ = {
            k: v.detach().cpu().numpy()
            for k, v in predictive(*args, **kwargs).items()
            if k != "obs"
        }

def test_pyro_bayesian_regression(save_path):
    use_gpu = 0
    adata = synthetic_iid()
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegressionModule(adata.shape[1], 1)
    plan = PyroTrainingPlan(model)
    trainer = Trainer(
        gpus=use_gpu,
        max_epochs=2,
    )
    trainer.fit(plan, train_dl)

    # test save and load
    post_dl = AnnDataLoader(adata, shuffle=False, batch_size=128)
    mean1 = []
    with torch.no_grad():
        for tensors in post_dl:
            args, kwargs = model._get_fn_args_from_batch(tensors)
            mean1.append(model(*args, **kwargs).cpu().numpy())
    mean1 = np.concatenate(mean1)

    model_save_path = os.path.join(save_path, "model_params.pt")
    torch.save(model.state_dict(), model_save_path)

    pyro.clear_param_store()
    new_model = BayesianRegressionModule(adata.shape[1], 1)
    # run model one step to get autoguide params
    try:
        new_model.load_state_dict(torch.load(model_save_path))
    except RuntimeError as err:
        if isinstance(new_model, PyroBaseModuleClass):
            plan = PyroTrainingPlan(new_model)
            trainer = Trainer(
                gpus=use_gpu,
                max_steps=1,
            )
            trainer.fit(plan, train_dl)
            new_model.load_state_dict(torch.load(model_save_path))
        else:
            raise err

    mean2 = []
    with torch.no_grad():
        for tensors in post_dl:
            args, kwargs = new_model._get_fn_args_from_batch(tensors)
            mean2.append(new_model(*args, **kwargs).cpu().numpy())
    mean2 = np.concatenate(mean2)

    np.testing.assert_array_equal(mean1, mean2)

def test_pyro_bayesian_regression(save_path):
    use_gpu = int(torch.cuda.is_available())
    adata = synthetic_iid()
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegressionModule(adata.shape[1], 1)
    plan = PyroTrainingPlan(model)
    plan.n_obs_training = len(train_dl.indices)
    trainer = Trainer(
        gpus=use_gpu,
        max_epochs=2,
    )
    trainer.fit(plan, train_dl)

    if use_gpu == 1:
        model.cuda()

    # test Predictive
    num_samples = 5
    predictive = model.create_predictive(num_samples=num_samples)
    for tensor_dict in train_dl:
        args, kwargs = model._get_fn_args_from_batch(tensor_dict)
        _ = {
            k: v.detach().cpu().numpy()
            for k, v in predictive(*args, **kwargs).items()
            if k != "obs"
        }

    # test save and load
    # cpu/gpu has minor difference
    model.cpu()
    quants = model.guide.quantiles([0.5])
    sigma_median = quants["sigma"][0].detach().cpu().numpy()
    linear_median = quants["linear.weight"][0].detach().cpu().numpy()

    model_save_path = os.path.join(save_path, "model_params.pt")
    torch.save(model.state_dict(), model_save_path)

    pyro.clear_param_store()
    new_model = BayesianRegressionModule(adata.shape[1], 1)
    # run model one step to get autoguide params
    try:
        new_model.load_state_dict(torch.load(model_save_path))
    except RuntimeError as err:
        if isinstance(new_model, PyroBaseModuleClass):
            plan = PyroTrainingPlan(new_model)
            plan.n_obs_training = len(train_dl.indices)
            trainer = Trainer(
                gpus=use_gpu,
                max_steps=1,
            )
            trainer.fit(plan, train_dl)
            new_model.load_state_dict(torch.load(model_save_path))
        else:
            raise err

    quants = new_model.guide.quantiles([0.5])
    sigma_median_new = quants["sigma"][0].detach().cpu().numpy()
    linear_median_new = quants["linear.weight"][0].detach().cpu().numpy()

    np.testing.assert_array_equal(sigma_median_new, sigma_median)
    np.testing.assert_array_equal(linear_median_new, linear_median)

def _posterior_quantile(self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = True):
    """
    Compute the posterior quantile of each parameter for pyro models trained
    without amortised inference.

    Parameters
    ----------
    q
        quantile to compute
    batch_size
        number of observations per batch
    use_gpu
        Bool, use gpu?

    Returns
    -------
    dictionary {variable_name: posterior quantile}
    """
    self.module.eval()
    gpus, device = parse_use_gpu_arg(use_gpu)

    train_dl = AnnDataLoader(self.adata, shuffle=False, batch_size=batch_size)

    # sample global parameters
    tensor_dict = next(iter(train_dl))
    args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
    args = [a.to(device) for a in args]
    kwargs = {k: v.to(device) for k, v in kwargs.items()}
    self.to_device(device)

    means = self.module.guide.quantiles([q], *args, **kwargs)
    means = {k: means[k].cpu().detach().numpy() for k in means.keys()}

    return means

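# Usage sketch (illustrative only, not from the source): `model` is assumed to be a
# trained pyro-based model instance exposing the `_posterior_quantile` method above.
def example_posterior_medians(model):
    """Return the posterior median of every parameter as numpy arrays (hypothetical helper)."""
    return model._posterior_quantile(q=0.5, batch_size=2048, use_gpu=False)
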
def generative(self, adata=None, indices=None, use_mean=True):
    """
    Generate new samples from input data (encode-decode).

    Parameters
    ----------
    adata
        scanpy single-cell dataset
    indices
        indices of the subset of cells to be encoded
    use_mean
        whether to use the mean of the multivariate Gaussian or a sample
    """
    if self.is_trained_ is False:
        raise RuntimeError("Please train the model first.")

    if adata is None:
        adata = self.adata

    sc_dl = AnnDataLoader(adata, indices=indices, batch_size=128)

    samples = []
    for tensors in sc_dl:
        input_encode = self._get_inference_input(tensors)
        z, mu, logvar = self.encode(**input_encode)
        gen_input = mu if use_mean else z
        input_decode = self._get_generative_input(tensors, gen_input)
        x_rec = self.decode(**input_decode)
        samples += [x_rec.cpu()]

    return np.array(torch.cat(samples))

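# Usage sketch (illustrative only, not from the source): assumes `vae` is a trained
# model exposing the `generative` method above and `adata` is a compatible AnnData
# object; reconstructs expression for the first 100 cells via encode-decode.
def example_reconstruct_cells(vae, adata):
    """Encode-decode a subset of cells and return the reconstructed matrix (hypothetical helper)."""
    x_rec = vae.generative(adata=adata, indices=np.arange(100), use_mean=True)
    assert x_rec.shape[0] == 100
    return x_rec
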
def to_latent(self, adata=None, indices=None, return_mean=False):
    """
    Project data into latent space. Inspired by SCVI.

    Parameters
    ----------
    adata
        scanpy single-cell dataset
    indices
        indices of the subset of cells to be encoded
    return_mean
        whether to return the mean of the multivariate Gaussian or a sample
    """
    if self.is_trained_ is False:
        raise RuntimeError("Please train the model first.")

    if adata is None:
        adata = self.adata

    sc_dl = AnnDataLoader(adata, indices=indices, batch_size=128)

    latent = []
    for tensors in sc_dl:
        input_encode = self._get_inference_input(tensors)
        z, mu, logvar = self.encode(**input_encode)
        if return_mean:
            latent += [mu.cpu()]
        else:
            latent += [z.cpu()]

    return np.array(torch.cat(latent))

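# Usage sketch (illustrative only, not from the source): assumes `vae` is a trained
# model exposing the `to_latent` method above; stores latent means in `adata.obsm`
# for downstream plotting (the "X_latent" key is a hypothetical choice).
def example_store_latent(vae, adata):
    """Attach latent means for all cells to the AnnData object (hypothetical helper)."""
    adata.obsm["X_latent"] = vae.to_latent(adata=adata, return_mean=True)
    return adata
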
def test_ann_dataloader():
    a = scvi.data.synthetic_iid()

    # test that batch sampler drops the last batch if it has less than 3 cells
    assert a.n_obs == 400
    adl = AnnDataLoader(a, batch_size=397, drop_last=3)
    assert len(adl) == 2
    for i, x in enumerate(adl):
        pass
    assert i == 1

    adl = AnnDataLoader(a, batch_size=398, drop_last=3)
    assert len(adl) == 1
    for i, x in enumerate(adl):
        pass
    assert i == 0

    with pytest.raises(ValueError):
        AnnDataLoader(a, batch_size=1, drop_last=2)

def test_pyro_bayesian_regression_jit():
    use_gpu = 0
    adata = synthetic_iid()
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegressionModule(adata.shape[1], 1)

    # warmup guide for JIT
    for tensors in train_dl:
        args, kwargs = model._get_fn_args_from_batch(tensors)
        model.guide(*args, **kwargs)
        break

    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    plan = PyroTrainingPlan(model, loss_fn=pyro.infer.JitTrace_ELBO())
    trainer = Trainer(
        gpus=use_gpu,
        max_epochs=2,
    )
    trainer.fit(plan, train_dl)

def test_pyro_bayesian_regression_jit():
    use_gpu = int(torch.cuda.is_available())
    adata = synthetic_iid()
    # add index for each cell (provided to pyro plate for correct minibatching)
    adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64")
    register_tensor_from_anndata(
        adata,
        registry_key="ind_x",
        adata_attr_name="obs",
        adata_key_name="_indices",
    )
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegressionModule(in_features=adata.shape[1], out_features=1)
    plan = PyroTrainingPlan(model, loss_fn=pyro.infer.JitTrace_ELBO())
    plan.n_obs_training = len(train_dl.indices)
    trainer = Trainer(
        gpus=use_gpu,
        max_epochs=2,
        callbacks=[PyroJitGuideWarmup(train_dl)],
    )
    trainer.fit(plan, train_dl)

    # 100 features
    assert list(model.guide.state_dict()["locs.linear.weight_unconstrained"].shape) == [1, 100]
    # 1 bias
    assert list(model.guide.state_dict()["locs.linear.bias_unconstrained"].shape) == [1]

    if use_gpu == 1:
        model.cuda()

    # test Predictive
    num_samples = 5
    predictive = model.create_predictive(num_samples=num_samples)
    for tensor_dict in train_dl:
        args, kwargs = model._get_fn_args_from_batch(tensor_dict)
        _ = {
            k: v.detach().cpu().numpy()
            for k, v in predictive(*args, **kwargs).items()
            if k != "obs"
        }

def _posterior_quantile(
    self,
    q: float = 0.5,
    batch_size: int = None,
    use_gpu: bool = None,
    use_median: bool = False,
):
    """
    Compute the posterior quantile of each parameter for pyro models trained
    without amortised inference.

    Parameters
    ----------
    q
        Quantile to compute
    batch_size
        Number of observations per batch (defaults to the full dataset)
    use_gpu
        Bool, use gpu?
    use_median
        Bool, when q=0.5 use median rather than quantile method of the guide

    Returns
    -------
    dictionary {variable_name: posterior quantile}
    """
    self.module.eval()
    gpus, device = parse_use_gpu_arg(use_gpu)
    if batch_size is None:
        batch_size = self.adata_manager.adata.n_obs

    train_dl = AnnDataLoader(self.adata_manager, shuffle=False, batch_size=batch_size)

    # sample global parameters
    tensor_dict = next(iter(train_dl))
    args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
    args = [a.to(device) for a in args]
    kwargs = {k: v.to(device) for k, v in kwargs.items()}
    self.to_device(device)

    if use_median and q == 0.5:
        means = self.module.guide.median(*args, **kwargs)
    else:
        means = self.module.guide.quantiles([q], *args, **kwargs)
    means = {k: means[k].cpu().detach().numpy() for k in means.keys()}

    return means

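# Usage sketch (illustrative only, not from the source): with `use_median=True` the
# q=0.5 request above is answered by the guide's median() method instead of the
# generic quantiles() path.
def example_posterior_median_via_guide(model):
    """Return posterior medians using the guide's median() shortcut (hypothetical helper)."""
    return model._posterior_quantile(q=0.5, use_median=True, use_gpu=False)
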
def test_pyro_bayesian_regression_jit():
    use_gpu = int(torch.cuda.is_available())
    adata = synthetic_iid()
    adata_manager = _create_indices_adata_manager(adata)
    train_dl = AnnDataLoader(adata_manager, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegressionModule(in_features=adata.shape[1], out_features=1)
    plan = PyroTrainingPlan(model, loss_fn=pyro.infer.JitTrace_ELBO())
    plan.n_obs_training = len(train_dl.indices)
    trainer = Trainer(
        gpus=use_gpu,
        max_epochs=2,
        callbacks=[PyroJitGuideWarmup(train_dl)],
    )
    trainer.fit(plan, train_dl)

    # 100 features
    assert list(model.guide.state_dict()["locs.linear.weight_unconstrained"].shape) == [1, 100]
    # 1 bias
    assert list(model.guide.state_dict()["locs.linear.bias_unconstrained"].shape) == [1]

    if use_gpu == 1:
        model.cuda()

    # test Predictive
    num_samples = 5
    predictive = model.create_predictive(num_samples=num_samples)
    for tensor_dict in train_dl:
        args, kwargs = model._get_fn_args_from_batch(tensor_dict)
        _ = {
            k: v.detach().cpu().numpy()
            for k, v in predictive(*args, **kwargs).items()
            if k != "obs"
        }

def _posterior_quantile_minibatch(
    self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = None, use_median: bool = False
):
    """
    Compute the posterior quantile of each parameter, separating local (minibatch)
    variables and global variables, which is necessary when performing amortised
    inference.

    Note for developers: requires a model class method which lists observation/minibatch
    plate variables (self.module.model.list_obs_plate_vars()).

    Parameters
    ----------
    q
        quantile to compute
    batch_size
        number of observations per batch
    use_gpu
        Bool, use gpu?
    use_median
        Bool, when q=0.5 use median rather than quantile method of the guide

    Returns
    -------
    dictionary {variable_name: posterior quantile}
    """
    gpus, device = parse_use_gpu_arg(use_gpu)
    self.module.eval()

    train_dl = AnnDataLoader(self.adata_manager, shuffle=False, batch_size=batch_size)

    # sample local parameters
    means = {}  # stays empty when the model has no local (obs plate) variables
    i = 0
    for tensor_dict in train_dl:
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)

        if i == 0:
            # find plate sites
            obs_plate_sites = self._get_obs_plate_sites(args, kwargs, return_observed=True)
            if len(obs_plate_sites) == 0:
                # if no local variables - don't sample
                break
            # find plate dimension
            obs_plate_dim = list(obs_plate_sites.values())[0]
            if use_median and q == 0.5:
                means = self.module.guide.median(*args, **kwargs)
            else:
                means = self.module.guide.quantiles([q], *args, **kwargs)
            means = {
                k: means[k].cpu().numpy() for k in means.keys() if k in obs_plate_sites
            }
        else:
            if use_median and q == 0.5:
                means_ = self.module.guide.median(*args, **kwargs)
            else:
                means_ = self.module.guide.quantiles([q], *args, **kwargs)
            means_ = {
                k: means_[k].cpu().numpy() for k in means_.keys() if k in obs_plate_sites
            }
            means = {
                k: np.concatenate([means[k], means_[k]], axis=obs_plate_dim)
                for k in means.keys()
            }
        i += 1

    # sample global parameters
    tensor_dict = next(iter(train_dl))
    args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
    args = [a.to(device) for a in args]
    kwargs = {k: v.to(device) for k, v in kwargs.items()}
    self.to_device(device)

    if use_median and q == 0.5:
        global_means = self.module.guide.median(*args, **kwargs)
    else:
        global_means = self.module.guide.quantiles([q], *args, **kwargs)
    global_means = {
        k: global_means[k].cpu().numpy()
        for k in global_means.keys()
        if k not in obs_plate_sites
    }

    for k in global_means.keys():
        means[k] = global_means[k]

    self.module.to(device)

    return means

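# Usage sketch (illustrative only, not from the source): assumes `model` exposes the
# `_posterior_quantile_minibatch` method above; queries the 5%, 50% and 95% posterior
# quantiles in minibatches to form a simple credible interval.
def example_credible_interval(model):
    """Return {quantile: {variable_name: array}} for three quantiles (hypothetical helper)."""
    return {
        q: model._posterior_quantile_minibatch(q=q, batch_size=1024, use_gpu=False)
        for q in (0.05, 0.5, 0.95)
    }
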
def _posterior_samples_minibatch(
    self, use_gpu: bool = None, batch_size: Optional[int] = None, **sample_kwargs
):
    """
    Generate samples of the posterior distribution in minibatches.

    Generate samples of the posterior distribution of each parameter, separating local
    (minibatch) variables and global variables, which is necessary when performing
    minibatch inference.

    Parameters
    ----------
    use_gpu
        Load model on default GPU if available (if None or True), or index of
        GPU to use (if int), or name of GPU (if str), or use CPU (if False).
    batch_size
        Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

    Returns
    -------
    dictionary {variable_name: [array with samples in 0 dimension]}
    """
    samples = dict()

    _, device = parse_use_gpu_arg(use_gpu)

    batch_size = batch_size if batch_size is not None else settings.batch_size

    train_dl = AnnDataLoader(self.adata, shuffle=False, batch_size=batch_size)

    # sample local parameters
    i = 0
    for tensor_dict in track(
        train_dl,
        style="tqdm",
        description="Sampling local variables, batch: ",
    ):
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)

        if i == 0:
            return_observed = sample_kwargs.get("return_observed", False)
            obs_plate_sites = self._get_obs_plate_sites(
                args, kwargs, return_observed=return_observed
            )
            if len(obs_plate_sites) == 0:
                # if no local variables - don't sample
                break
            obs_plate_dim = list(obs_plate_sites.values())[0]

            sample_kwargs_obs_plate = sample_kwargs.copy()
            sample_kwargs_obs_plate["return_sites"] = self._get_obs_plate_return_sites(
                sample_kwargs["return_sites"], list(obs_plate_sites.keys())
            )
            sample_kwargs_obs_plate["show_progress"] = False

            samples = self._get_posterior_samples(args, kwargs, **sample_kwargs_obs_plate)
        else:
            samples_ = self._get_posterior_samples(args, kwargs, **sample_kwargs_obs_plate)
            samples = {
                k: np.array(
                    [
                        np.concatenate(
                            [samples[k][j], samples_[k][j]],
                            axis=obs_plate_dim,
                        )
                        for j in range(len(samples[k]))  # for each sample (in 0 dimension)
                    ]
                )
                for k in samples.keys()  # for each variable
            }
        i += 1

    # sample global parameters
    global_samples = self._get_posterior_samples(args, kwargs, **sample_kwargs)
    global_samples = {
        k: v for k, v in global_samples.items() if k not in list(obs_plate_sites.keys())
    }

    for k in global_samples.keys():
        samples[k] = global_samples[k]

    self.module.to(device)

    return samples

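# Usage sketch (illustrative only, not from the source): assumes `model` exposes the
# `_posterior_samples_minibatch` method above; the extra keyword arguments are forwarded
# to `_get_posterior_samples`, so the exact keys supported (e.g. `num_samples`) depend on
# that implementation; "w_sf" is a hypothetical site name used only for illustration.
def example_posterior_sample_means(model):
    """Draw posterior samples for one site and average over the sample dimension (hypothetical helper)."""
    samples = model._posterior_samples_minibatch(
        use_gpu=False, batch_size=1024, num_samples=25, return_sites=["w_sf"]
    )
    return {name: values.mean(axis=0) for name, values in samples.items()}
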
def _posterior_quantile_amortised(self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = True):
    """
    Compute the posterior quantile of each parameter, separating local (minibatch)
    variables and global variables, which is necessary when performing amortised
    inference.

    Note for developers: requires a model class method which lists observation/minibatch
    plate variables (self.module.model.list_obs_plate_vars()).

    Parameters
    ----------
    q
        quantile to compute
    batch_size
        number of observations per batch
    use_gpu
        Bool, use gpu?

    Returns
    -------
    dictionary {variable_name: posterior quantile}
    """
    gpus, device = parse_use_gpu_arg(use_gpu)
    self.module.eval()

    train_dl = AnnDataLoader(self.adata, shuffle=False, batch_size=batch_size)

    # sample local parameters
    i = 0
    for tensor_dict in train_dl:
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)

        if i == 0:
            means = self.module.guide.quantiles([q], *args, **kwargs)
            means = {
                k: means[k].cpu().numpy()
                for k in means.keys()
                if k in self.module.model.list_obs_plate_vars()["sites"]
            }
            # find plate dimension
            trace = poutine.trace(self.module.model).get_trace(*args, **kwargs)
            obs_plate = {
                name: site["cond_indep_stack"][0].dim
                for name, site in trace.nodes.items()
                if site["type"] == "sample"
                if any(
                    f.name == self.module.model.list_obs_plate_vars()["name"]
                    for f in site["cond_indep_stack"]
                )
            }
        else:
            means_ = self.module.guide.quantiles([q], *args, **kwargs)
            means_ = {
                k: means_[k].cpu().numpy()
                for k in means_.keys()
                if k in list(self.module.model.list_obs_plate_vars()["sites"].keys())
            }
            means = {
                k: np.concatenate([means[k], means_[k]], axis=list(obs_plate.values())[0])
                for k in means.keys()
            }
        i += 1

    # sample global parameters
    tensor_dict = next(iter(train_dl))
    args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
    args = [a.to(device) for a in args]
    kwargs = {k: v.to(device) for k, v in kwargs.items()}
    self.to_device(device)

    global_means = self.module.guide.quantiles([q], *args, **kwargs)
    global_means = {
        k: global_means[k].cpu().numpy()
        for k in global_means.keys()
        if k not in list(self.module.model.list_obs_plate_vars()["sites"].keys())
    }

    for k in global_means.keys():
        means[k] = global_means[k]

    self.module.to(device)

    return means

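# Usage sketch (illustrative only, not from the source): assumes `model` was trained with
# an amortised guide and that its module class implements list_obs_plate_vars() as the
# `_posterior_quantile_amortised` method above requires.
def example_amortised_medians(model):
    """Return posterior medians, stitching minibatch (local) variables back together (hypothetical helper)."""
    return model._posterior_quantile_amortised(q=0.5, batch_size=2048, use_gpu=False)
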
def test_cell2location():
    save_path = "./cell2location_model_test"
    if torch.cuda.is_available():
        use_gpu = int(torch.cuda.is_available())
    else:
        use_gpu = False
    dataset = synthetic_iid(n_labels=5)
    RegressionModel.setup_anndata(dataset, labels_key="labels", batch_key="batch")

    # train regression model to get signatures of cell types
    sc_model = RegressionModel(dataset)
    # test full data training
    sc_model.train(max_epochs=1, use_gpu=use_gpu)
    # test minibatch training
    sc_model.train(max_epochs=1, batch_size=1000, use_gpu=use_gpu)
    # export the estimated cell abundance (summary of the posterior distribution)
    dataset = sc_model.export_posterior(dataset, sample_kwargs={"num_samples": 10})
    # test plot_QC
    sc_model.plot_QC()
    # test save/load
    sc_model.save(save_path, overwrite=True, save_anndata=True)
    sc_model = RegressionModel.load(save_path)

    # export estimated expression in each cluster
    if "means_per_cluster_mu_fg" in dataset.varm.keys():
        inf_aver = dataset.varm["means_per_cluster_mu_fg"][
            [f"means_per_cluster_mu_fg_{i}" for i in dataset.uns["mod"]["factor_names"]]
        ].copy()
    else:
        inf_aver = dataset.var[
            [f"means_per_cluster_mu_fg_{i}" for i in dataset.uns["mod"]["factor_names"]]
        ].copy()
    inf_aver.columns = dataset.uns["mod"]["factor_names"]

    ### test default cell2location model ###
    Cell2location.setup_anndata(dataset, batch_key="batch")

    ## full data ##
    st_model = Cell2location(
        dataset, cell_state_df=inf_aver, N_cells_per_location=30, detection_alpha=200
    )
    # test full data training
    st_model.train(max_epochs=1, use_gpu=use_gpu)
    # export the estimated cell abundance (summary of the posterior distribution)
    # full data
    dataset = st_model.export_posterior(
        dataset, sample_kwargs={"num_samples": 10, "batch_size": st_model.adata.n_obs}
    )

    ## minibatches of locations ##
    Cell2location.setup_anndata(dataset, batch_key="batch")
    st_model = Cell2location(
        dataset, cell_state_df=inf_aver, N_cells_per_location=30, detection_alpha=200
    )
    # test minibatch training
    st_model.train(max_epochs=1, batch_size=50, use_gpu=use_gpu)
    # export the estimated cell abundance (summary of the posterior distribution)
    # minibatches of locations
    dataset = st_model.export_posterior(
        dataset, sample_kwargs={"num_samples": 10, "batch_size": 50}
    )
    # test plot_QC
    st_model.plot_QC()
    # test save/load
    st_model.save(save_path, overwrite=True, save_anndata=True)
    st_model = Cell2location.load(save_path)
    # export the estimated cell abundance (summary of the posterior distribution)
    # minibatches of locations
    dataset = st_model.export_posterior(
        dataset, sample_kwargs={"num_samples": 10, "batch_size": 50}
    )
    # test computing any quantile of the posterior distribution
    if not isinstance(st_model.module.guide, poutine.messenger.Messenger):
        st_model.posterior_quantile(q=0.5, use_gpu=use_gpu)

    # test computing median
    if True:
        if use_gpu:
            device = f"cuda:{use_gpu}"
        else:
            device = "cpu"
        train_dl = AnnDataLoader(st_model.adata_manager, shuffle=False, batch_size=50)
        for batch in train_dl:
            batch = {k: v.to(device) for k, v in batch.items()}
            args, kwargs = st_model.module._get_fn_args_from_batch(batch)
            break
        st_model.module.guide.median(*args, **kwargs)

    # test computing expected expression per cell type
    st_model.module.model.compute_expected_per_cell_type(
        st_model.samples["post_sample_q05"], st_model.adata_manager
    )

    ### test amortised inference with default cell2location model ###
    ## full data ##
    Cell2location.setup_anndata(dataset, batch_key="batch")
    st_model = Cell2location(
        dataset,
        cell_state_df=inf_aver,
        N_cells_per_location=30,
        detection_alpha=200,
        amortised=True,
encoder_mode="multiple", ) # test minibatch training st_model.train(max_epochs=1, batch_size=20, use_gpu=use_gpu) st_model.train_aggressive(max_epochs=3, batch_size=20, plan_kwargs={ "n_aggressive_epochs": 1, "n_aggressive_steps": 5 }, use_gpu=use_gpu) # test computing median if True: if use_gpu: device = f"cuda:{use_gpu}" else: device = "cpu" train_dl = AnnDataLoader(st_model.adata_manager, shuffle=False, batch_size=50) for batch in train_dl: batch = {k: v.to(device) for k, v in batch.items()} args, kwargs = st_model.module._get_fn_args_from_batch(batch) break st_model.module.guide.median(*args, **kwargs) st_model.module.guide.quantiles([0.5], *args, **kwargs) st_model.module.guide.mutual_information(*args, **kwargs) # export the estimated cell abundance (summary of the posterior distribution) # minibatches of locations dataset = st_model.export_posterior(dataset, sample_kwargs={ "num_samples": 10, "batch_size": 50 }) ### test downstream analysis ### _, _ = run_colocation( dataset, model_name="CoLocatedGroupsSklearnNMF", train_args={ "n_fact": np.arange( 3, 4 ), # IMPORTANT: use a wider range of the number of factors (5-30) "sample_name_col": "batch", # columns in adata_vis.obs that identifies sample "n_restarts": 2, # number of training restarts }, export_args={"path": f"{save_path}/CoLocatedComb/"}, ) ### test simplified cell2location models ### ## no m_g ## Cell2location.setup_anndata(dataset, batch_key="batch") st_model = Cell2location( dataset, cell_state_df=inf_aver, N_cells_per_location=30, detection_alpha=200, model_class= LocationModelMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel, ) # test full data training st_model.train(max_epochs=1, use_gpu=use_gpu) # export the estimated cell abundance (summary of the posterior distribution) # full data dataset = st_model.export_posterior(dataset, sample_kwargs={ "num_samples": 10, "batch_size": st_model.adata.n_obs }) ## no w_sf factorisation ## Cell2location.setup_anndata(dataset, batch_key="batch") st_model = Cell2location( dataset, cell_state_df=inf_aver, N_cells_per_location=30, detection_alpha=200, model_class= LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelNoMGPyroModel, ) # test full data training st_model.train(max_epochs=1, use_gpu=use_gpu) # export the estimated cell abundance (summary of the posterior distribution) # full data st_model.export_posterior(dataset, sample_kwargs={ "num_samples": 10, "batch_size": st_model.adata.n_obs })