Esempio n. 1
0
def test_pyro_bayesian_regression_jit():
    """Train the Bayesian regression module with a JIT-compiled ELBO and draw predictive samples.

    The test fits for two epochs with ``PyroJitGuideWarmup`` as a trainer
    callback, checks the size of the first guide parameter and then runs the
    module's ``Predictive`` object over every training batch.
    """
    use_gpu = int(torch.cuda.is_available())
    adata = synthetic_iid()
    pyro.clear_param_store()
    model = BayesianRegressionModule(adata.shape[1], 1)
    # One loader is enough: it is shared by the warm-up callback, the training
    # loop and the predictive pass below.
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    plan = PyroTrainingPlan(model, loss_fn=pyro.infer.JitTrace_ELBO())
    trainer = Trainer(gpus=use_gpu,
                      max_epochs=2,
                      callbacks=[PyroJitGuideWarmup(train_dl)])
    trainer.fit(plan, train_dl)

    # 100 features, 1 for sigma, 1 for bias
    assert list(model.guide.parameters())[0].shape[0] == 102

    if use_gpu == 1:
        model.cuda()

    # test Predictive
    num_samples = 5
    predictive = model.create_predictive(num_samples=num_samples)
    for tensor_dict in train_dl:
        args, kwargs = model._get_fn_args_from_batch(tensor_dict)
        # Converting every non-observed site to numpy checks the samples are usable.
        _ = {
            k: v.detach().cpu().numpy()
            for k, v in predictive(*args, **kwargs).items() if k != "obs"
        }
Esempio n. 2
0
def test_pyro_bayesian_regression(save_path):
    """Check that saving and re-loading the module's state dict reproduces its outputs.

    Parameters
    ----------
    save_path
        Directory in which ``model_params.pt`` is written.
    """
    use_gpu = 0
    adata = synthetic_iid()
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegressionModule(adata.shape[1], 1)
    plan = PyroTrainingPlan(model)
    trainer = Trainer(gpus=use_gpu, max_epochs=2)
    trainer.fit(plan, train_dl)

    # test save and load
    post_dl = AnnDataLoader(adata, shuffle=False, batch_size=128)

    def _collect_outputs(module):
        # Run the module over the unshuffled loader and stack the batch outputs.
        outputs = []
        with torch.no_grad():
            for tensors in post_dl:
                fn_args, fn_kwargs = module._get_fn_args_from_batch(tensors)
                outputs.append(module(*fn_args, **fn_kwargs).cpu().numpy())
        return np.concatenate(outputs)

    mean1 = _collect_outputs(model)

    model_save_path = os.path.join(save_path, "model_params.pt")
    torch.save(model.state_dict(), model_save_path)

    pyro.clear_param_store()
    new_model = BayesianRegressionModule(adata.shape[1], 1)
    try:
        new_model.load_state_dict(torch.load(model_save_path))
    except RuntimeError as err:
        if not isinstance(new_model, PyroBaseModuleClass):
            raise err
        # run model one step to get autoguide params, then retry the load
        warmup_plan = PyroTrainingPlan(new_model)
        warmup_trainer = Trainer(gpus=use_gpu, max_steps=1)
        warmup_trainer.fit(warmup_plan, train_dl)
        new_model.load_state_dict(torch.load(model_save_path))

    mean2 = _collect_outputs(new_model)

    np.testing.assert_array_equal(mean1, mean2)
Esempio n. 3
0
def test_pyro_bayesian_regression(save_path):
    """Train the regression module, sample with Predictive and round-trip its state dict.

    Parameters
    ----------
    save_path
        Directory in which ``model_params.pt`` is written.
    """
    use_gpu = int(torch.cuda.is_available())
    adata = synthetic_iid()
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegressionModule(adata.shape[1], 1)
    plan = PyroTrainingPlan(model)
    plan.n_obs_training = len(train_dl.indices)
    trainer = Trainer(gpus=use_gpu, max_epochs=2)
    trainer.fit(plan, train_dl)
    if use_gpu == 1:
        model.cuda()

    # test Predictive
    num_samples = 5
    predictive = model.create_predictive(num_samples=num_samples)
    for tensor_dict in train_dl:
        fn_args, fn_kwargs = model._get_fn_args_from_batch(tensor_dict)
        draws = predictive(*fn_args, **fn_kwargs)
        for site, value in draws.items():
            if site != "obs":
                value.detach().cpu().numpy()

    def _guide_medians(module):
        # Median (0.5 quantile) of the guide for the two sites compared below.
        quants = module.guide.quantiles([0.5])
        sigma = quants["sigma"][0].detach().cpu().numpy()
        weight = quants["linear.weight"][0].detach().cpu().numpy()
        return sigma, weight

    # test save and load
    # cpu/gpu has minor difference
    model.cpu()
    sigma_median, linear_median = _guide_medians(model)

    model_save_path = os.path.join(save_path, "model_params.pt")
    torch.save(model.state_dict(), model_save_path)

    pyro.clear_param_store()
    new_model = BayesianRegressionModule(adata.shape[1], 1)
    try:
        new_model.load_state_dict(torch.load(model_save_path))
    except RuntimeError as err:
        if not isinstance(new_model, PyroBaseModuleClass):
            raise err
        # run model one step to get autoguide params, then retry the load
        warmup_plan = PyroTrainingPlan(new_model)
        warmup_plan.n_obs_training = len(train_dl.indices)
        warmup_trainer = Trainer(gpus=use_gpu, max_steps=1)
        warmup_trainer.fit(warmup_plan, train_dl)
        new_model.load_state_dict(torch.load(model_save_path))

    sigma_median_new, linear_median_new = _guide_medians(new_model)

    np.testing.assert_array_equal(sigma_median_new, sigma_median)
    np.testing.assert_array_equal(linear_median_new, linear_median)
Esempio n. 4
0
    def _posterior_quantile(self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = True):
        """
        Compute a quantile of the posterior distribution of each parameter
        (for pyro models trained without amortised inference).

        Only the first batch of the data loader is passed to the guide.

        Parameters
        ----------
        q
            quantile to compute
        batch_size
            number of observations in the batch passed to the guide
        use_gpu
            Bool, use gpu?

        Returns
        -------
        dictionary {variable_name: posterior quantile as numpy array}

        """

        self.module.eval()
        _, device = parse_use_gpu_arg(use_gpu)

        data_loader = AnnDataLoader(self.adata, shuffle=False, batch_size=batch_size)
        # sample global parameters
        first_batch = next(iter(data_loader))
        args, kwargs = self.module._get_fn_args_from_batch(first_batch)
        args = [arg.to(device) for arg in args]
        kwargs = {name: value.to(device) for name, value in kwargs.items()}
        self.to_device(device)

        quantiles = self.module.guide.quantiles([q], *args, **kwargs)
        return {
            name: quantiles[name].cpu().detach().numpy()
            for name in quantiles.keys()
        }
Esempio n. 5
0
    def generative(self, adata=None, indices=None, use_mean=True):
        """
        Generate new samples from input data (encode-decode).

        Parameters
        ----------
        adata
            scanpy single-cell dataset; ``self.adata`` is used when None
        indices
            indices of the subset of cells to be encoded
        use_mean
            whether to use the mean of the multivariate gaussian or samples

        Returns
        -------
        numpy array with the decoded output of all batches concatenated

        Raises
        ------
        RuntimeError
            if the model has not been trained yet
        """
        if self.is_trained_ is False:
            raise RuntimeError("Please train the model first.")
        # Test identity with None rather than truthiness: a dataset object that
        # evaluates as falsy (e.g. one reporting zero length) must not be
        # silently swapped for the training data.
        if adata is None:
            adata = self.adata
        sc_dl = AnnDataLoader(adata, indices=indices, batch_size=128)
        samples = []
        for tensors in sc_dl:
            input_encode = self._get_inference_input(tensors)
            z, mu, _ = self.encode(**input_encode)
            gen_input = mu if use_mean else z
            input_decode = self._get_generative_input(tensors, gen_input)
            x_rec = self.decode(**input_decode)
            samples.append(x_rec.cpu())
        return np.array(torch.cat(samples))
Esempio n. 6
0
    def to_latent(self, adata=None, indices=None, return_mean=False):
        """
        Project data into latent space. Inspired by SCVI.

        Parameters
        ----------
        adata
            scanpy single-cell dataset; ``self.adata`` is used when None
        indices
            indices of the subset of cells to be encoded
        return_mean
            whether to use the mean of the multivariate gaussian or samples

        Returns
        -------
        numpy array with the latent representation of all batches concatenated

        Raises
        ------
        RuntimeError
            if the model has not been trained yet
        """
        if self.is_trained_ is False:
            raise RuntimeError("Please train the model first.")
        # Test identity with None rather than truthiness: a dataset object that
        # evaluates as falsy (e.g. one reporting zero length) must not be
        # silently swapped for the training data.
        if adata is None:
            adata = self.adata
        sc_dl = AnnDataLoader(adata, indices=indices, batch_size=128)
        latent = []
        for tensors in sc_dl:
            input_encode = self._get_inference_input(tensors)
            z, mu, _ = self.encode(**input_encode)
            latent.append(mu.cpu() if return_mean else z.cpu())

        return np.array(torch.cat(latent))
Esempio n. 7
0
def test_ann_dataloader():
    """Check how ``AnnDataLoader`` treats a short final batch when ``drop_last=3``."""
    adata = scvi.data.synthetic_iid()
    assert adata.n_obs == 400

    # test that batch sampler drops the last batch if it has less than 3 cells
    # 400 cells at batch_size=397 leave a final batch of 3 cells: two batches.
    loader = AnnDataLoader(adata, batch_size=397, drop_last=3)
    assert len(loader) == 2
    for last_index, _batch in enumerate(loader):
        pass
    assert last_index == 1

    # 400 cells at batch_size=398 leave a final batch of 2 cells: one batch.
    loader = AnnDataLoader(adata, batch_size=398, drop_last=3)
    assert len(loader) == 1
    for last_index, _batch in enumerate(loader):
        pass
    assert last_index == 0

    with pytest.raises(ValueError):
        AnnDataLoader(adata, batch_size=1, drop_last=2)
Esempio n. 8
0
def test_pyro_bayesian_regression_jit():
    """Fit the regression module on CPU with ``JitTrace_ELBO`` after a manual guide warm-up."""
    use_gpu = 0
    adata = synthetic_iid()
    warmup_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegressionModule(adata.shape[1], 1)
    # warmup guide for JIT: call it once on the first batch only
    for warmup_tensors in warmup_dl:
        fn_args, fn_kwargs = model._get_fn_args_from_batch(warmup_tensors)
        model.guide(*fn_args, **fn_kwargs)
        break
    # training uses a freshly constructed loader
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    plan = PyroTrainingPlan(model, loss_fn=pyro.infer.JitTrace_ELBO())
    trainer = Trainer(gpus=use_gpu, max_epochs=2)
    trainer.fit(plan, train_dl)
Esempio n. 9
0
def test_pyro_bayesian_regression_jit():
    """Fit the regression module with ``JitTrace_ELBO`` on data carrying a per-cell index.

    After training, the shapes of the guide's weight and bias location
    parameters are checked and ``Predictive`` is run over every batch.
    """
    use_gpu = int(torch.cuda.is_available())
    adata = synthetic_iid()
    # add index for each cell (provided to pyro plate for correct minibatching)
    adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64")
    register_tensor_from_anndata(
        adata,
        registry_key="ind_x",
        adata_attr_name="obs",
        adata_key_name="_indices",
    )
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegressionModule(in_features=adata.shape[1], out_features=1)
    plan = PyroTrainingPlan(model, loss_fn=pyro.infer.JitTrace_ELBO())
    plan.n_obs_training = len(train_dl.indices)
    warmup = PyroJitGuideWarmup(train_dl)
    trainer = Trainer(gpus=use_gpu, max_epochs=2, callbacks=[warmup])
    trainer.fit(plan, train_dl)

    # 100 features
    weight_shape = list(
        model.guide.state_dict()["locs.linear.weight_unconstrained"].shape)
    assert weight_shape == [1, 100]
    # 1 bias
    bias_shape = list(
        model.guide.state_dict()["locs.linear.bias_unconstrained"].shape)
    assert bias_shape == [1]

    if use_gpu == 1:
        model.cuda()

    # test Predictive
    num_samples = 5
    predictive = model.create_predictive(num_samples=num_samples)
    for tensor_dict in train_dl:
        fn_args, fn_kwargs = model._get_fn_args_from_batch(tensor_dict)
        draws = predictive(*fn_args, **fn_kwargs)
        for site, value in draws.items():
            if site != "obs":
                value.detach().cpu().numpy()
Esempio n. 10
0
    def _posterior_quantile(self,
                            q: float = 0.5,
                            batch_size: int = None,
                            use_gpu: bool = None,
                            use_median: bool = False):
        """
        Compute a quantile of the posterior distribution of each parameter
        (for pyro models trained without amortised inference).

        Only the first batch of the data loader is passed to the guide.

        Parameters
        ----------
        q
            Quantile to compute
        batch_size
            Number of observations per batch; when None the number of
            observations of ``self.adata_manager.adata`` is used
        use_gpu
            Bool, use gpu?
        use_median
            Bool, when q=0.5 use median rather than quantile method of the guide

        Returns
        -------
        dictionary {variable_name: posterior quantile}

        """

        self.module.eval()
        _, device = parse_use_gpu_arg(use_gpu)
        if batch_size is None:
            batch_size = self.adata_manager.adata.n_obs
        data_loader = AnnDataLoader(self.adata_manager,
                                    shuffle=False,
                                    batch_size=batch_size)
        # sample global parameters
        first_batch = next(iter(data_loader))
        args, kwargs = self.module._get_fn_args_from_batch(first_batch)
        args = [arg.to(device) for arg in args]
        kwargs = {name: value.to(device) for name, value in kwargs.items()}
        self.to_device(device)

        take_median = use_median and q == 0.5
        if take_median:
            summary = self.module.guide.median(*args, **kwargs)
        else:
            summary = self.module.guide.quantiles([q], *args, **kwargs)

        return {
            name: summary[name].cpu().detach().numpy()
            for name in summary.keys()
        }
Esempio n. 11
0
def test_pyro_bayesian_regression_jit():
    """Fit the regression module with ``JitTrace_ELBO`` using an adata manager as data source.

    After training, the shapes of the guide's weight and bias location
    parameters are checked and ``Predictive`` is run over every batch.
    """
    use_gpu = int(torch.cuda.is_available())
    adata = synthetic_iid()
    adata_manager = _create_indices_adata_manager(adata)
    train_dl = AnnDataLoader(adata_manager, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegressionModule(in_features=adata.shape[1], out_features=1)
    plan = PyroTrainingPlan(model, loss_fn=pyro.infer.JitTrace_ELBO())
    plan.n_obs_training = len(train_dl.indices)
    warmup = PyroJitGuideWarmup(train_dl)
    trainer = Trainer(gpus=use_gpu, max_epochs=2, callbacks=[warmup])
    trainer.fit(plan, train_dl)

    # 100 features
    weight_shape = list(
        model.guide.state_dict()["locs.linear.weight_unconstrained"].shape)
    assert weight_shape == [1, 100]
    # 1 bias
    bias_shape = list(
        model.guide.state_dict()["locs.linear.bias_unconstrained"].shape)
    assert bias_shape == [1]

    if use_gpu == 1:
        model.cuda()

    # test Predictive
    num_samples = 5
    predictive = model.create_predictive(num_samples=num_samples)
    for tensor_dict in train_dl:
        fn_args, fn_kwargs = model._get_fn_args_from_batch(tensor_dict)
        draws = predictive(*fn_args, **fn_kwargs)
        for site, value in draws.items():
            if site != "obs":
                value.detach().cpu().numpy()
Esempio n. 12
0
    def _posterior_quantile_minibatch(self,
                                      q: float = 0.5,
                                      batch_size: int = 2048,
                                      use_gpu: bool = None,
                                      use_median: bool = False):
        """
        Compute a quantile of the posterior distribution of each parameter in minibatches.

        Local (observation/minibatch plate) variables are computed batch by batch
        and concatenated along the plate dimension; global variables are computed
        once from the first batch. This separation is necessary when performing
        amortised inference.

        Note for developers: the plate sites are looked up with
        ``self._get_obs_plate_sites(args, kwargs, return_observed=True)``.

        Parameters
        ----------
        q
            quantile to compute
        batch_size
            number of observations per batch
        use_gpu
            Bool, use gpu?
        use_median
            Bool, when q=0.5 use median rather than quantile method of the guide

        Returns
        -------
        dictionary {variable_name: posterior quantile}

        """

        _, device = parse_use_gpu_arg(use_gpu)

        self.module.eval()

        train_dl = AnnDataLoader(self.adata_manager,
                                 shuffle=False,
                                 batch_size=batch_size)

        def _batch_on_device(tensor_dict):
            # Unpack one batch into guide arguments and move everything to `device`.
            args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
            args = [a.to(device) for a in args]
            kwargs = {k: v.to(device) for k, v in kwargs.items()}
            self.to_device(device)
            return args, kwargs

        def _guide_summary(args, kwargs):
            # Posterior summary of every site for one batch.
            if use_median and q == 0.5:
                return self.module.guide.median(*args, **kwargs)
            return self.module.guide.quantiles([q], *args, **kwargs)

        # Initialised up front so that a model without local variables (the
        # early `break` below) still reaches the global section with a valid,
        # empty result instead of an unbound name.
        means = {}
        obs_plate_sites = {}
        obs_plate_dim = None
        local_batches = []

        # sample local parameters
        for i, tensor_dict in enumerate(train_dl):
            args, kwargs = _batch_on_device(tensor_dict)

            if i == 0:
                # find plate sites
                obs_plate_sites = self._get_obs_plate_sites(
                    args, kwargs, return_observed=True)
                if len(obs_plate_sites) == 0:
                    # if no local variables - don't sample
                    break
                # find plate dimension
                obs_plate_dim = list(obs_plate_sites.values())[0]

            batch_summary = _guide_summary(args, kwargs)
            # detach() is a no-op for tensors without grad and keeps .numpy()
            # from failing on tensors that do track gradients.
            local_batches.append({
                k: v.detach().cpu().numpy()
                for k, v in batch_summary.items() if k in obs_plate_sites
            })

        if local_batches:
            # Concatenate all batches once along the observation plate dimension.
            means = {
                k: np.concatenate([batch[k] for batch in local_batches],
                                  axis=obs_plate_dim)
                for k in local_batches[0]
            }

        # sample global parameters (taken from the first batch)
        args, kwargs = _batch_on_device(next(iter(train_dl)))
        global_summary = _guide_summary(args, kwargs)
        for k, v in global_summary.items():
            if k not in obs_plate_sites:
                means[k] = v.detach().cpu().numpy()

        self.module.to(device)

        return means
Esempio n. 13
0
    def _posterior_samples_minibatch(
        self, use_gpu: bool = None, batch_size: Optional[int] = None, **sample_kwargs
    ):
        """
        Generate samples of the posterior distribution in minibatches.

        Generate samples of the posterior distribution of each parameter, separating local (minibatch) variables
        and global variables, which is necessary when performing minibatch inference.

        Parameters
        ----------
        use_gpu
            Load model on default GPU if available (if None or True),
            or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
        **sample_kwargs
            Forwarded to ``self._get_posterior_samples``. ``return_sites`` must be
            present; ``return_observed`` (default False) also controls whether
            observed sites are included among the observation-plate sites.

        Returns
        -------
        dictionary {variable_name: [array with samples in 0 dimension]}
        """
        samples = {}

        _, device = parse_use_gpu_arg(use_gpu)

        batch_size = batch_size if batch_size is not None else settings.batch_size

        train_dl = AnnDataLoader(self.adata, shuffle=False, batch_size=batch_size)
        # sample local parameters
        i = 0
        for tensor_dict in track(
            train_dl,
            style="tqdm",
            description="Sampling local variables, batch: ",
        ):
            args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
            args = [a.to(device) for a in args]
            kwargs = {k: v.to(device) for k, v in kwargs.items()}
            self.to_device(device)

            if i == 0:
                # sample_kwargs is a dict, so the option has to be read with
                # .get(); getattr() on a dict never finds the key and would
                # always fall back to False.
                return_observed = sample_kwargs.get("return_observed", False)
                obs_plate_sites = self._get_obs_plate_sites(
                    args, kwargs, return_observed=return_observed
                )
                if len(obs_plate_sites) == 0:
                    # if no local variables - don't sample
                    break
                obs_plate_dim = list(obs_plate_sites.values())[0]

                # Restrict sampling in the batch loop to the plate sites only.
                sample_kwargs_obs_plate = sample_kwargs.copy()
                sample_kwargs_obs_plate[
                    "return_sites"
                ] = self._get_obs_plate_return_sites(
                    sample_kwargs["return_sites"], list(obs_plate_sites.keys())
                )
                sample_kwargs_obs_plate["show_progress"] = False

                samples = self._get_posterior_samples(
                    args, kwargs, **sample_kwargs_obs_plate
                )
            else:
                samples_ = self._get_posterior_samples(
                    args, kwargs, **sample_kwargs_obs_plate
                )

                samples = {
                    k: np.array(
                        [
                            np.concatenate(
                                [samples[k][j], samples_[k][j]],
                                axis=obs_plate_dim,
                            )
                            for j in range(
                                len(samples[k])
                            )  # for each sample (in 0 dimension)
                        ]
                    )
                    for k in samples.keys()  # for each variable
                }
            i += 1

        # sample global parameters (uses the arguments of the last processed batch)
        global_samples = self._get_posterior_samples(args, kwargs, **sample_kwargs)
        global_samples = {
            k: v for k, v in global_samples.items() if k not in obs_plate_sites
        }

        for k in global_samples.keys():
            samples[k] = global_samples[k]

        self.module.to(device)

        return samples
Esempio n. 14
0
    def _posterior_quantile_amortised(self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = True):
        """
        Compute a quantile of the posterior distribution of each parameter, separating local (minibatch) variable
        and global variables, which is necessary when performing amortised inference.

        Local variables are computed for every batch and concatenated along the
        observation plate dimension; global variables are computed once from the
        first batch of the data loader.

        Note for developers: requires model class method which lists observation/minibatch plate
        variables (self.module.model.list_obs_plate_vars()); its result is indexed
        with the keys "sites" and "name" below.

        Parameters
        ----------
        q
            quantile to compute
        batch_size
            number of observations per batch
        use_gpu
            Bool, use gpu?

        Returns
        -------
        dictionary {variable_name: posterior quantile}

        """

        gpus, device = parse_use_gpu_arg(use_gpu)

        self.module.eval()

        train_dl = AnnDataLoader(self.adata, shuffle=False, batch_size=batch_size)

        # sample local parameters
        i = 0
        for tensor_dict in train_dl:

            # move the batch and the model to the target device
            args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
            args = [a.to(device) for a in args]
            kwargs = {k: v.to(device) for k, v in kwargs.items()}
            self.to_device(device)

            if i == 0:

                # first batch: keep only the observation plate (local) sites
                means = self.module.guide.quantiles([q], *args, **kwargs)
                means = {
                    k: means[k].cpu().numpy()
                    for k in means.keys()
                    if k in self.module.model.list_obs_plate_vars()["sites"]
                }

                # find plate dimension
                trace = poutine.trace(self.module.model).get_trace(*args, **kwargs)
                # map each sample site inside the observation plate to the dim of
                # the first frame in its cond_indep_stack
                obs_plate = {
                    name: site["cond_indep_stack"][0].dim
                    for name, site in trace.nodes.items()
                    if site["type"] == "sample"
                    if any(f.name == self.module.model.list_obs_plate_vars()["name"] for f in site["cond_indep_stack"])
                }

            else:

                # later batches: append the local sites along the plate dimension
                means_ = self.module.guide.quantiles([q], *args, **kwargs)
                means_ = {
                    k: means_[k].cpu().numpy()
                    for k in means_.keys()
                    if k in list(self.module.model.list_obs_plate_vars()["sites"].keys())
                }
                means = {
                    k: np.concatenate([means[k], means_[k]], axis=list(obs_plate.values())[0]) for k in means.keys()
                }
            i += 1

        # sample global parameters (from the first batch only)
        tensor_dict = next(iter(train_dl))
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)

        global_means = self.module.guide.quantiles([q], *args, **kwargs)
        global_means = {
            k: global_means[k].cpu().numpy()
            for k in global_means.keys()
            if k not in list(self.module.model.list_obs_plate_vars()["sites"].keys())
        }

        # merge global sites into the dictionary of local sites
        for k in global_means.keys():
            means[k] = global_means[k]

        self.module.to(device)

        return means
def test_cell2location():
    """End-to-end smoke test of the regression model and several Cell2location variants.

    The steps, in order: train ``RegressionModel`` on synthetic data and export
    its posterior; train the default ``Cell2location`` model on full data and in
    minibatches; save and re-load it; compute posterior quantiles/medians; train
    the amortised variant; run the NMF co-location analysis; and finally train
    two simplified model classes. The test passes when no step raises.
    """
    # NOTE(review): both models below are saved to this same directory with
    # overwrite=True, so the second save replaces the first.
    save_path = "./cell2location_model_test"
    # use_gpu becomes the integer 1 when CUDA is available, otherwise False
    if torch.cuda.is_available():
        use_gpu = int(torch.cuda.is_available())
    else:
        use_gpu = False
    dataset = synthetic_iid(n_labels=5)
    RegressionModel.setup_anndata(dataset,
                                  labels_key="labels",
                                  batch_key="batch")

    # train regression model to get signatures of cell types
    sc_model = RegressionModel(dataset)
    # test full data training
    sc_model.train(max_epochs=1, use_gpu=use_gpu)
    # test minibatch training
    sc_model.train(max_epochs=1, batch_size=1000, use_gpu=use_gpu)
    # export the estimated cell abundance (summary of the posterior distribution)
    dataset = sc_model.export_posterior(dataset,
                                        sample_kwargs={"num_samples": 10})
    # test plot_QC
    sc_model.plot_QC()
    # test save/load
    sc_model.save(save_path, overwrite=True, save_anndata=True)
    sc_model = RegressionModel.load(save_path)
    # export estimated expression in each cluster
    # (the signatures are read from .varm when present there, otherwise from .var)
    if "means_per_cluster_mu_fg" in dataset.varm.keys():
        inf_aver = dataset.varm["means_per_cluster_mu_fg"][[
            f"means_per_cluster_mu_fg_{i}"
            for i in dataset.uns["mod"]["factor_names"]
        ]].copy()
    else:
        inf_aver = dataset.var[[
            f"means_per_cluster_mu_fg_{i}"
            for i in dataset.uns["mod"]["factor_names"]
        ]].copy()
    inf_aver.columns = dataset.uns["mod"]["factor_names"]

    ### test default cell2location model ###
    Cell2location.setup_anndata(dataset, batch_key="batch")
    ##  full data  ##
    st_model = Cell2location(dataset,
                             cell_state_df=inf_aver,
                             N_cells_per_location=30,
                             detection_alpha=200)
    # test full data training
    st_model.train(max_epochs=1, use_gpu=use_gpu)
    # export the estimated cell abundance (summary of the posterior distribution)
    # full data
    dataset = st_model.export_posterior(dataset,
                                        sample_kwargs={
                                            "num_samples": 10,
                                            "batch_size": st_model.adata.n_obs
                                        })
    ##  minibatches of locations  ##
    Cell2location.setup_anndata(dataset, batch_key="batch")
    st_model = Cell2location(dataset,
                             cell_state_df=inf_aver,
                             N_cells_per_location=30,
                             detection_alpha=200)
    # test minibatch training
    st_model.train(max_epochs=1, batch_size=50, use_gpu=use_gpu)
    # export the estimated cell abundance (summary of the posterior distribution)
    # minibatches of locations
    dataset = st_model.export_posterior(dataset,
                                        sample_kwargs={
                                            "num_samples": 10,
                                            "batch_size": 50
                                        })
    # test plot_QC
    st_model.plot_QC()
    # test save/load
    st_model.save(save_path, overwrite=True, save_anndata=True)
    st_model = Cell2location.load(save_path)
    # export the estimated cell abundance (summary of the posterior distribution)
    # minibatches of locations
    dataset = st_model.export_posterior(dataset,
                                        sample_kwargs={
                                            "num_samples": 10,
                                            "batch_size": 50
                                        })
    # test computing any quantile of the posterior distribution
    if not isinstance(st_model.module.guide, poutine.messenger.Messenger):
        st_model.posterior_quantile(q=0.5, use_gpu=use_gpu)
    # test computing median
    if True:
        # NOTE(review): when CUDA is available use_gpu == 1, so this string is
        # "cuda:1" - confirm that is the device the model lives on.
        if use_gpu:
            device = f"cuda:{use_gpu}"
        else:
            device = "cpu"
        train_dl = AnnDataLoader(st_model.adata_manager,
                                 shuffle=False,
                                 batch_size=50)
        # only the first batch is used for the median call
        for batch in train_dl:
            batch = {k: v.to(device) for k, v in batch.items()}
            args, kwargs = st_model.module._get_fn_args_from_batch(batch)
            break
        st_model.module.guide.median(*args, **kwargs)
    # test computing expected expression per cell type
    st_model.module.model.compute_expected_per_cell_type(
        st_model.samples["post_sample_q05"], st_model.adata_manager)
    ### test amortised inference with default cell2location model ###
    ##  full data  ##
    Cell2location.setup_anndata(dataset, batch_key="batch")
    st_model = Cell2location(
        dataset,
        cell_state_df=inf_aver,
        N_cells_per_location=30,
        detection_alpha=200,
        amortised=True,
        encoder_mode="multiple",
    )
    # test minibatch training
    st_model.train(max_epochs=1, batch_size=20, use_gpu=use_gpu)
    st_model.train_aggressive(max_epochs=3,
                              batch_size=20,
                              plan_kwargs={
                                  "n_aggressive_epochs": 1,
                                  "n_aggressive_steps": 5
                              },
                              use_gpu=use_gpu)
    # test computing median
    if True:
        # NOTE(review): same device-string construction as above ("cuda:1" when
        # CUDA is available) - confirm.
        if use_gpu:
            device = f"cuda:{use_gpu}"
        else:
            device = "cpu"
        train_dl = AnnDataLoader(st_model.adata_manager,
                                 shuffle=False,
                                 batch_size=50)
        # only the first batch is used for the guide calls below
        for batch in train_dl:
            batch = {k: v.to(device) for k, v in batch.items()}
            args, kwargs = st_model.module._get_fn_args_from_batch(batch)
            break
        st_model.module.guide.median(*args, **kwargs)
        st_model.module.guide.quantiles([0.5], *args, **kwargs)
        st_model.module.guide.mutual_information(*args, **kwargs)
    # export the estimated cell abundance (summary of the posterior distribution)
    # minibatches of locations
    dataset = st_model.export_posterior(dataset,
                                        sample_kwargs={
                                            "num_samples": 10,
                                            "batch_size": 50
                                        })

    ### test downstream analysis ###
    _, _ = run_colocation(
        dataset,
        model_name="CoLocatedGroupsSklearnNMF",
        train_args={
            "n_fact": np.arange(
                3, 4
            ),  # IMPORTANT: use a wider range of the number of factors (5-30)
            "sample_name_col":
            "batch",  # columns in adata_vis.obs that identifies sample
            "n_restarts": 2,  # number of training restarts
        },
        export_args={"path": f"{save_path}/CoLocatedComb/"},
    )

    ### test simplified cell2location models ###
    ##  no m_g  ##
    Cell2location.setup_anndata(dataset, batch_key="batch")
    st_model = Cell2location(
        dataset,
        cell_state_df=inf_aver,
        N_cells_per_location=30,
        detection_alpha=200,
        model_class=
        LocationModelMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel,
    )
    # test full data training
    st_model.train(max_epochs=1, use_gpu=use_gpu)
    # export the estimated cell abundance (summary of the posterior distribution)
    # full data
    dataset = st_model.export_posterior(dataset,
                                        sample_kwargs={
                                            "num_samples": 10,
                                            "batch_size": st_model.adata.n_obs
                                        })
    ##  no w_sf factorisation  ##
    Cell2location.setup_anndata(dataset, batch_key="batch")
    st_model = Cell2location(
        dataset,
        cell_state_df=inf_aver,
        N_cells_per_location=30,
        detection_alpha=200,
        model_class=
        LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelNoMGPyroModel,
    )
    # test full data training
    st_model.train(max_epochs=1, use_gpu=use_gpu)
    # export the estimated cell abundance (summary of the posterior distribution)
    # full data
    st_model.export_posterior(dataset,
                              sample_kwargs={
                                  "num_samples": 10,
                                  "batch_size": st_model.adata.n_obs
                              })