Example #1
    def compute_expected_per_cell_type(self, samples, adata, ind_x=None):
        r"""
        Compute expected expression of each gene in each location for each cell type.

        Parameters
        ----------
        samples
            Posterior distribution summary, e.g. self.samples["post_sample_q05"]
            (one of 'means', 'stds', 'q05', 'q95'), produced by export_posterior().
        adata
            Registered anndata object (self.adata).
        ind_x
            Location/observation indices for which to compute expected count
            (if None all locations are used).

        Returns
        -------
        dictionary with:
        1) list of expected expression counts (sparse, shape=(N locations, N genes))
        for each cell type, in the same order as mod.factor_names_;
        2) np.array with location indices
        """
        if ind_x is None:
            ind_x = np.arange(adata.n_obs).astype(int)
        else:
            ind_x = ind_x.astype(int)

        # fetch data
        x_data = get_from_registry(adata, _CONSTANTS.X_KEY)[ind_x, :]
        x_data = csr_matrix(x_data)

        # compute total expected expression
        obs2sample = get_from_registry(adata, _CONSTANTS.BATCH_KEY)
        obs2sample = pd.get_dummies(obs2sample.flatten()).values[ind_x, :]
        mu = np.dot(samples["w_sf"][ind_x, :], self.cell_state_mat.T) * samples["m_g"] + np.dot(
            obs2sample, samples["s_g_gene_add"]
        )

        # compute conditional expected expression per cell type
        mu_ct = [
            x_data.multiply(
                (
                    np.dot(
                        samples["w_sf"][ind_x, i, np.newaxis],
                        self.cell_state_mat.T[np.newaxis, i, :],
                    )
                    * samples["m_g"]
                )
                / mu
            )
            for i in range(self.n_factors)
        ]
        mu_ct = [csr_matrix(x) for x in mu_ct]

        return {"mu": mu_ct, "ind_x": ind_x}
Example #2
    def plot_QC(self, summary_name: str = "means", use_n_obs: int = 1000):
        """
        Show quality control plots:
        1. Reconstruction accuracy, to assess whether there are any issues with model training.
            The plot should be roughly diagonal; strong deviations signal problems that need to be investigated.
            Plotting is slow because the expected mRNA count needs to be computed from model parameters, so a
            random subset of observations is used to speed up computation.

        Parameters
        ----------
        summary_name
            posterior distribution summary to use ('means', 'stds', 'q05', 'q95')
        use_n_obs
            number of random observations to use for the plot (all observations if None)

        Returns
        -------

        """

        if getattr(self, "samples", False) is False:
            raise RuntimeError("self.samples is missing, please run self.export_posterior() first")
        if use_n_obs is not None:
            ind_x = np.random.choice(self.adata.n_obs, np.min((use_n_obs, self.adata.n_obs)), replace=False)
        else:
            ind_x = None

        self.expected_nb_param = self.module.model.compute_expected(
            self.samples[f"post_sample_{summary_name}"], self.adata, ind_x=ind_x
        )
        x_data = get_from_registry(self.adata, _CONSTANTS.X_KEY)
        if ind_x is not None:
            x_data = x_data[ind_x, :]
        if issparse(x_data):
            x_data = np.asarray(x_data.toarray())
        self.plot_posterior_mu_vs_data(self.expected_nb_param["mu"], x_data)
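
The only non-obvious step above is the observation subsampling; in isolation it is plain NumPy (sizes here are arbitrary):

import numpy as np

n_obs, use_n_obs = 2500, 1000
# sample at most use_n_obs distinct observations, as plot_QC does
ind_x = np.random.choice(n_obs, np.min((use_n_obs, n_obs)), replace=False)
assert len(ind_x) == min(use_n_obs, n_obs) and len(np.unique(ind_x)) == len(ind_x)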
    def compute_expected_subset(self, samples, adata, fact_ind, cell_ind):
        r"""Compute expected expression of each gene in each cell that comes from
        a subset of factors (cell types) or cells.

        Useful for evaluating how well the model learned the expression pattern of all genes in the data.

        Parameters
        ----------
        samples
            dictionary with values of the posterior
        adata
            registered anndata
        fact_ind
            indices of factors/cell types to use
        cell_ind
            indices of cells to use
        """
        obs2sample = get_from_registry(adata, _CONSTANTS.BATCH_KEY)
        obs2sample = pd.get_dummies(obs2sample.flatten()).values[cell_ind, :]
        obs2label = get_from_registry(adata, _CONSTANTS.LABELS_KEY)
        obs2label = pd.get_dummies(obs2label.flatten()).values[np.ix_(cell_ind, fact_ind)]
        if self.n_extra_categoricals is not None:
            extra_categoricals = get_from_registry(adata, _CONSTANTS.CAT_COVS_KEY)
            obs2extra_categoricals = np.concatenate(
                [
                    pd.get_dummies(extra_categoricals.iloc[:, i])
                    for i, n_cat in enumerate(self.n_extra_categoricals)
                ],
                axis=1,
            )[cell_ind, :]

        alpha = 1 / np.power(samples["alpha_g_inverse"], 2)

        mu = (
            np.dot(obs2label, samples["per_cluster_mu_fg"][fact_ind, :])
            + np.dot(obs2sample, samples["s_g_gene_add"])
        ) * np.dot(obs2sample, samples["detection_mean_y_e"])  # samples["detection_y_c"]
        if self.n_extra_categoricals is not None:
            mu = mu * np.dot(obs2extra_categoricals, samples["detection_tech_gene_tg"])

        return {"mu": mu, "alpha": alpha}
Example #4
    def setup_data_attr(self):
        """
        Sets data attribute.

        Reduces the number of times anndata needs to be accessed.
        """
        self.data = {
            key: get_from_registry(self.adata, key)
            for key, _ in self.attributes_and_types.items()
        }
    def compute_expected(self, samples, adata, ind_x=None):
        r"""Compute expected expression of each gene in each cell. Useful for evaluating how well
        the model learned expression pattern of all genes in the data.

        Parameters
        ----------
        samples
            dictionary with values of the posterior
        adata
            registered anndata
        ind_x
            indices of cells to use (to reduce data size)
        """
        if ind_x is None:
            ind_x = np.arange(adata.n_obs).astype(int)
        else:
            ind_x = ind_x.astype(int)
        obs2sample = get_from_registry(adata, _CONSTANTS.BATCH_KEY)
        obs2sample = pd.get_dummies(obs2sample.flatten()).values[ind_x, :]
        obs2label = get_from_registry(adata, _CONSTANTS.LABELS_KEY)
        obs2label = pd.get_dummies(obs2label.flatten()).values[ind_x, :]
        if self.n_extra_categoricals is not None:
            extra_categoricals = get_from_registry(adata,
                                                   _CONSTANTS.CAT_COVS_KEY)
            obs2extra_categoricals = np.concatenate(
                [
                    pd.get_dummies(extra_categoricals.iloc[ind_x, i])
                    for i, n_cat in enumerate(self.n_extra_categoricals)
                ],
                axis=1,
            )

        alpha = 1 / np.power(samples["alpha_g_inverse"], 2)

        mu = (np.dot(obs2label, samples["per_cluster_mu_fg"]) +
              np.dot(obs2sample, samples["s_g_gene_add"])) * np.dot(
                  obs2sample, samples["detection_mean_y_e"]
              )  # samples["detection_y_c"][ind_x, :]
        if self.n_extra_categoricals is not None:
            mu = mu * np.dot(obs2extra_categoricals,
                             samples["detection_tech_gene_tg"])

        return {"mu": mu, "alpha": alpha}
Example #6
def test_data_format():
    # if data was a dense np array, check that after setup_anndata the data is C_CONTIGUOUS
    adata = synthetic_iid(run_setup_anndata=False)

    old_x = adata.X
    old_pro = adata.obsm["protein_expression"]
    old_obs = adata.obs
    adata.X = np.asfortranarray(old_x)
    adata.obsm["protein_expression"] = np.asfortranarray(old_pro)
    assert adata.X.flags["C_CONTIGUOUS"] is False
    assert adata.obsm["protein_expression"].flags["C_CONTIGUOUS"] is False

    setup_anndata(adata, protein_expression_obsm_key="protein_expression")
    assert adata.X.flags["C_CONTIGUOUS"] is True
    assert adata.obsm["protein_expression"].flags["C_CONTIGUOUS"] is True

    assert np.array_equal(old_x, adata.X)
    assert np.array_equal(old_pro, adata.obsm["protein_expression"])
    assert np.array_equal(old_obs, adata.obs)

    assert np.array_equal(adata.X, get_from_registry(adata, _CONSTANTS.X_KEY))
    assert np.array_equal(
        adata.obsm["protein_expression"],
        get_from_registry(adata, _CONSTANTS.PROTEIN_EXP_KEY),
    )

    # if obsm is a dataframe, make it C_CONTIGUOUS if it isn't
    adata = synthetic_iid()
    pe = np.asfortranarray(adata.obsm["protein_expression"])
    adata.obsm["protein_expression"] = pd.DataFrame(pe, index=adata.obs_names)
    assert adata.obsm["protein_expression"].to_numpy(
    ).flags["C_CONTIGUOUS"] is False
    setup_anndata(adata, protein_expression_obsm_key="protein_expression")
    new_pe = get_from_registry(adata, "protein_expression")
    assert new_pe.to_numpy().flags["C_CONTIGUOUS"] is True
    assert np.array_equal(pe, new_pe)
    assert np.array_equal(adata.X, get_from_registry(adata, _CONSTANTS.X_KEY))
    assert np.array_equal(
        adata.obsm["protein_expression"],
        get_from_registry(adata, _CONSTANTS.PROTEIN_EXP_KEY),
    )
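
The conversion this test expects from setup_anndata amounts to np.ascontiguousarray, which copies a Fortran-ordered array into row-major order without changing its values; a minimal sketch:

import numpy as np

x = np.asfortranarray(np.arange(12.0).reshape(3, 4))  # column-major copy
assert not x.flags["C_CONTIGUOUS"]

x_c = np.ascontiguousarray(x)      # row-major copy, values unchanged
assert x_c.flags["C_CONTIGUOUS"]
assert np.array_equal(x, x_c)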
    def normalise(self, samples, adata):
        r"""Normalise expression data by estimated technical variables.

        Parameters
        ----------
        samples
            dictionary with values of the posterior
        adata
            registered anndata

        """
        obs2sample = get_from_registry(adata, _CONSTANTS.BATCH_KEY)
        obs2sample = pd.get_dummies(obs2sample.flatten())
        if self.n_extra_categoricals is not None:
            extra_categoricals = get_from_registry(adata,
                                                   _CONSTANTS.CAT_COVS_KEY)
            obs2extra_categoricals = np.concatenate(
                [
                    pd.get_dummies(extra_categoricals.iloc[:, i])
                    for i, n_cat in enumerate(self.n_extra_categoricals)
                ],
                axis=1,
            )
        # get counts matrix
        corrected = get_from_registry(adata, _CONSTANTS.X_KEY)
        # normalise per-sample scaling
        corrected = corrected / np.dot(obs2sample,
                                       samples["detection_mean_y_e"])
        # normalise per gene effects
        if self.n_extra_categoricals is not None:
            corrected = corrected / np.dot(obs2extra_categoricals,
                                           samples["detection_tech_gene_tg"])

        # remove additive sample effects
        corrected = corrected - np.dot(obs2sample, samples["s_g_gene_add"])

        # shift so the minimum value is 0 (a hack to avoid negative values)
        corrected = corrected - corrected.min()

        return corrected
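
The same correction on toy arrays, as a self-contained sketch: divide out the multiplicative detection effect, subtract the additive background, then shift so no value is negative.

import numpy as np

rng = np.random.default_rng(1)
counts = rng.poisson(10, size=(4, 3)).astype(float)   # (n_obs, n_genes)
detection = np.array([[0.5], [0.5], [2.0], [2.0]])    # per-observation detection effect
background = np.array([[1.0, 0.5, 2.0]])              # per-gene additive effect

corrected = counts / detection           # undo multiplicative detection effect
corrected = corrected - background       # remove additive background
corrected = corrected - corrected.min()  # shift so the minimum value is 0
assert corrected.min() >= 0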
Example #8
    def __init__(
        self,
        model,
        adata,
        n_labelled_samples_per_class=50,
        n_epochs_classifier=1,
        lr_classification=5 * 1e-3,
        classification_ratio=50,
        seed=0,
        **kwargs
    ):
        super().__init__(model, adata, **kwargs)
        self.model = model
        self.adata = adata
        self.n_epochs_classifier = n_epochs_classifier
        self.lr_classification = lr_classification
        self.classification_ratio = classification_ratio
        n_labelled_samples_per_class_array = [
            n_labelled_samples_per_class
        ] * self.adata.uns["_scvi"]["summary_stats"]["n_labels"]
        labels = np.array(get_from_registry(self.adata, _CONSTANTS.LABELS_KEY)).ravel()
        np.random.seed(seed=seed)
        permutation_idx = np.random.permutation(len(labels))
        labels = labels[permutation_idx]
        indices = []
        current_nbrs = np.zeros(len(n_labelled_samples_per_class_array))
        for idx, label in enumerate(labels):
            label = int(label)
            if current_nbrs[label] < n_labelled_samples_per_class_array[label]:
                indices.insert(0, idx)
                current_nbrs[label] += 1
            else:
                indices.append(idx)
        indices = np.array(indices)
        total_labelled = sum(n_labelled_samples_per_class_array)
        indices_labelled = permutation_idx[indices[:total_labelled]]
        indices_unlabelled = permutation_idx[indices[total_labelled:]]

        self.classifier_trainer = ClassifierTrainer(
            model.classifier,
            self.adata,
            metrics_to_monitor=[],
            silent=True,
            frequency=0,
            sampling_model=self.model,
        )
        self.full_dataset = self.create_scvi_dl(shuffle=True)
        self.labelled_set = self.create_scvi_dl(indices=indices_labelled)
        self.unlabelled_set = self.create_scvi_dl(indices=indices_unlabelled)

        for scdl in [self.labelled_set, self.unlabelled_set]:
            scdl.to_monitor = ["reconstruction_error", "accuracy"]
    def compute_expected(self, samples, adata, ind_x=None):
        r"""Compute expected expression of each gene in each location. Useful for evaluating how well
        the model learned expression pattern of all genes in the data.
        """
        if ind_x is None:
            ind_x = np.arange(adata.n_obs).astype(int)
        else:
            ind_x = ind_x.astype(int)
        obs2sample = get_from_registry(adata, _CONSTANTS.BATCH_KEY)
        obs2sample = pd.get_dummies(obs2sample.flatten()).values[ind_x, :]
        mu = (np.dot(samples["w_sf"][ind_x, :], self.cell_state_mat.T) +
              np.dot(obs2sample, samples["s_g_gene_add"])
              ) * samples["detection_y_s"][ind_x, :]
        alpha = np.dot(obs2sample, 1 / np.power(samples["alpha_g_inverse"], 2))

        return {"mu": mu, "alpha": alpha, "ind_x": ind_x}
Example #10
    def create_doublets(
        adata: AnnData,
        doublet_ratio: int,
        indices: Optional[Sequence[int]] = None,
        seed: int = 1,
    ) -> AnnData:
        """Simulate doublets.

        Parameters
        ----------
        adata
            AnnData object setup with setup_anndata.
        doublet_ratio
            Ratio of generated doublets to produce relative to number of
            cells in adata or length of indices, if not `None`.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        seed
            Seed for reproducibility
        """
        n_obs = adata.n_obs if indices is None else len(indices)
        num_doublets = doublet_ratio * n_obs

        # counts can be in many locations, this uses where it was registered in setup
        x = get_from_registry(adata, _CONSTANTS.X_KEY)
        if indices is not None:
            x = x[indices]

        random_state = np.random.RandomState(seed=seed)
        parent_inds = random_state.choice(n_obs, size=(num_doublets, 2))
        doublets = x[parent_inds[:, 0]] + x[parent_inds[:, 1]]

        doublets_ad = AnnData(doublets)
        doublets_ad.var_names = adata.var_names
        doublets_ad.obs_names = [
            "sim_doublet_{}".format(i) for i in range(num_doublets)
        ]

        # if adata setup with a layer, need to add layer to doublets adata
        data_registry = adata.uns["_scvi"]["data_registry"]
        x_loc = data_registry[_CONSTANTS.X_KEY]["attr_name"]
        layer = (data_registry[_CONSTANTS.X_KEY]["attr_key"]
                 if x_loc == "layers" else None)
        if layer is not None:
            doublets_ad.layers[layer] = doublets

        return doublets_ad
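
The core of the simulation is summing the count vectors of randomly chosen parent pairs; stripped of the AnnData bookkeeping it is a few lines of NumPy:

import numpy as np

rng = np.random.RandomState(1)
x = rng.poisson(5, size=(100, 20))              # counts, (n_cells, n_genes)
doublet_ratio = 2
num_doublets = doublet_ratio * x.shape[0]

parents = rng.choice(x.shape[0], size=(num_doublets, 2))
doublets = x[parents[:, 0]] + x[parents[:, 1]]  # sum parent count vectors
assert doublets.shape == (num_doublets, x.shape[1])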
Example #11
def test_setup_anndata():
    # test regular setup
    adata = synthetic_iid(run_setup_anndata=False)
    setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
        protein_expression_obsm_key="protein_expression",
        protein_names_uns_key="protein_names",
    )
    np.testing.assert_array_equal(
        get_from_registry(adata, "batch_indices"),
        np.array(adata.obs["_scvi_batch"]).reshape((-1, 1)),
    )
    np.testing.assert_array_equal(
        get_from_registry(adata, "labels"),
        np.array(adata.obs["labels"].cat.codes).reshape((-1, 1)),
    )
    np.testing.assert_array_equal(get_from_registry(adata, "X"), adata.X)
    np.testing.assert_array_equal(
        get_from_registry(adata, "protein_expression"),
        adata.obsm["protein_expression"],
    )
    np.testing.assert_array_equal(adata.uns["_scvi"]["protein_names"],
                                  adata.uns["protein_names"])

    # test that error is thrown if it's a view:
    adata = synthetic_iid()
    with pytest.raises(ValueError):
        setup_anndata(adata[1])

    # if obsm is a df and protein_names_uns_key is None, protein names should be grabbed from the columns of the df
    adata = synthetic_iid()
    new_protein_names = np.array(random.sample(range(100), 100)).astype("str")
    df = pd.DataFrame(
        adata.obsm["protein_expression"],
        index=adata.obs_names,
        columns=new_protein_names,
    )
    adata.obsm["protein_expression"] = df
    setup_anndata(adata, protein_expression_obsm_key="protein_expression")
    np.testing.assert_array_equal(adata.uns["_scvi"]["protein_names"],
                                  new_protein_names)

    # test that layer is working properly
    adata = synthetic_iid()
    true_x = adata.X
    adata.layers["X"] = true_x
    adata.X = np.ones_like(adata.X)
    setup_anndata(adata, layer="X")
    np.testing.assert_array_equal(get_from_registry(adata, "X"), true_x)

    # test that it creates layers and batch if no layers_key is passed
    adata = synthetic_iid()
    setup_anndata(
        adata,
        protein_expression_obsm_key="protein_expression",
        protein_names_uns_key="protein_names",
    )
    np.testing.assert_array_equal(get_from_registry(adata, "batch_indices"),
                                  np.zeros((adata.shape[0], 1)))
    np.testing.assert_array_equal(get_from_registry(adata, "labels"),
                                  np.zeros((adata.shape[0], 1)))
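
The label assertions above hinge on pandas categorical codes matching what setup_anndata registers; the encoding itself is plain pandas:

import numpy as np
import pandas as pd

labels = pd.Series(["B", "A", "B", "C"], dtype="category")
codes = np.array(labels.cat.codes).reshape((-1, 1))   # registry-style column vector
assert np.array_equal(codes.ravel(), [1, 0, 1, 2])    # categories sorted: A, B, C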
Example #12
def custom_objective_hyperopt(
    space, is_best_training=False, dataset=None, n_epochs=None
):
    """Custom objective function for advanced autotune tutorial."""
    space = defaultdict(dict, space)
    model_tunable_kwargs = space["model_tunable_kwargs"]
    trainer_tunable_kwargs = space["trainer_tunable_kwargs"]
    train_func_tunable_kwargs = space["train_func_tunable_kwargs"]

    trainer_specific_kwargs = {}
    model_specific_kwargs = {}
    train_func_specific_kwargs = {}
    trainer_specific_kwargs["use_cuda"] = bool(torch.cuda.device_count())
    train_func_specific_kwargs["n_epochs"] = n_epochs

    # add hardcoded parameters
    # disable scVI progbar
    trainer_specific_kwargs["silent"] = True
    trainer_specific_kwargs["frequency"] = 1

    # merge params with fixed param precedence
    model_tunable_kwargs.update(model_specific_kwargs)
    trainer_tunable_kwargs.update(trainer_specific_kwargs)
    train_func_tunable_kwargs.update(train_func_specific_kwargs)

    scanvi = SCANVAE(
        dataset.uns["_scvi"]["summary_stats"]["n_vars"],
        dataset.uns["_scvi"]["summary_stats"]["n_batch"],
        dataset.uns["_scvi"]["summary_stats"]["n_labels"],
        **model_tunable_kwargs
    )
    trainer_scanvi = SemiSupervisedTrainer(scanvi, dataset, **trainer_tunable_kwargs)
    batch_indices = get_from_registry(dataset, _CONSTANTS.BATCH_KEY)
    trainer_scanvi.unlabelled_set = trainer_scanvi.create_scvi_dl(
        indices=(batch_indices == 1)
    )
    trainer_scanvi.unlabelled_set.to_monitor = ["reconstruction_error", "accuracy"]
    indices_labelled = batch_indices == 0

    if not is_best_training:
        # compute k-fold accuracy on a 20% validation set
        k = 5
        accuracies = np.zeros(k)
        for i in range(k):
            indices_labelled_train, indices_labelled_val = train_test_split(
                indices_labelled.nonzero()[0], test_size=0.2
            )
            trainer_scanvi.labelled_set = trainer_scanvi.create_scvi_dl(
                indices=indices_labelled_train
            )
            trainer_scanvi.labelled_set.to_monitor = [
                "reconstruction_error",
                "accuracy",
            ]
            trainer_scanvi.validation_set = trainer_scanvi.create_scvi_dl(
                indices=indices_labelled_val
            )
            trainer_scanvi.validation_set.to_monitor = ["accuracy"]
            trainer_scanvi.train(**train_func_tunable_kwargs)
            accuracies[i] = trainer_scanvi.history["accuracy_unlabelled_set"][-1]
        return {"loss": -accuracies.mean(), "space": space, "status": STATUS_OK}
    else:
        trainer_scanvi.labelled_set = trainer_scanvi.create_scvi_dl(
            indices=indices_labelled
        )
        trainer_scanvi.labelled_set.to_monitor = ["reconstruction_error", "accuracy"]
        trainer_scanvi.train(**train_func_tunable_kwargs)
        return trainer_scanvi
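
The evaluation branch repeats a random 80/20 split of the labelled indices k times and averages the resulting accuracies; hyperopt then minimises the negated mean. A skeleton of that loop detached from scVI (training and accuracy are stubbed with a random number; hyperopt's STATUS_OK is the string "ok"):

import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.RandomState(0)
labelled = np.arange(200)     # stand-in for (batch_indices == 0).nonzero()[0]

k = 5
accuracies = np.zeros(k)
for i in range(k):
    train_idx, val_idx = train_test_split(labelled, test_size=0.2, random_state=i)
    # train on train_idx, validate on val_idx; stubbed here
    accuracies[i] = rng.uniform(0.8, 0.9)

# hyperopt minimises, so return the negated mean accuracy as the loss
result = {"loss": -accuracies.mean(), "status": "ok"}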
Example #13
    def __init__(
        self,
        model,
        adata,
        n_labelled_samples_per_class=50,
        indices_labelled=None,
        indices_unlabelled=None,
        n_epochs_classifier=1,
        lr_classification=5 * 1e-3,
        classification_ratio=50,
        seed=0,
        scheme: Literal["joint", "alternate", "both"] = "both",
        **kwargs,
    ):
        super().__init__(model, adata, **kwargs)
        self.model = model
        self.adata = adata
        self.n_epochs_classifier = n_epochs_classifier
        self.lr_classification = lr_classification
        self.classification_ratio = classification_ratio
        self.scheme = scheme

        if scheme == "joint":
            self.n_epochs_classifier = 0

        if indices_labelled is None and indices_unlabelled is None:
            n_labelled_samples_per_class_array = [
                n_labelled_samples_per_class
            ] * self.adata.uns["_scvi"]["summary_stats"]["n_labels"]
            labels = np.array(
                get_from_registry(self.adata, _CONSTANTS.LABELS_KEY)).ravel()
            np.random.seed(seed=seed)
            permutation_idx = np.random.permutation(len(labels))
            labels = labels[permutation_idx]
            indices = []
            current_nbrs = np.zeros(len(n_labelled_samples_per_class_array))
            for idx, label in enumerate(labels):
                label = int(label)
                if current_nbrs[label] < n_labelled_samples_per_class_array[label]:
                    indices.insert(0, idx)
                    current_nbrs[label] += 1
                else:
                    indices.append(idx)
            indices = np.array(indices)
            total_labelled = sum(n_labelled_samples_per_class_array)
            indices_labelled = permutation_idx[indices[:total_labelled]]
            indices_unlabelled = permutation_idx[indices[total_labelled:]]

        class_kwargs = {}
        if "weight_decay" in kwargs.keys():
            class_kwargs["weight_decay"] = kwargs["weight_decay"]
        self.classifier_trainer = ClassifierTrainer(
            model.classifier,
            self.adata,
            metrics_to_monitor=[],
            silent=True,
            frequency=0,
            sampling_model=self.model,
            **class_kwargs,
        )
        self.full_dataset = self.create_scvi_dl(shuffle=True)
        self.labelled_set = self.create_scvi_dl(indices=indices_labelled)
        self.unlabelled_set = self.create_scvi_dl(indices=indices_unlabelled)

        for scdl in [self.labelled_set, self.unlabelled_set]:
            scdl.to_monitor = ["elbo", "reconstruction_error", "accuracy"]
        # allow to track ELBO
        self.unlabelled_set.unlabeled = True
        self.full_dataset.unlabeled = True
Example #14
    def __init__(
        self,
        adata: AnnData,
        cell_state_df: pd.DataFrame,
        model_class: Optional[PyroModule] = None,
        detection_mean_per_sample: bool = False,
        detection_mean_correction: float = 1.0,
        **model_kwargs,
    ):
        # clear Pyro's parameter store in case another model sharing the same parameter names was created earlier
        clear_param_store()

        if not np.all(adata.var_names == cell_state_df.index):
            raise ValueError(
                "adata.var_names should match cell_state_df.index, find interecting variables/genes first"
            )

        # add index for each cell (provided to pyro plate for correct minibatching)
        adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64")
        scvi.data.register_tensor_from_anndata(
            adata,
            registry_key="ind_x",
            adata_attr_name="obs",
            adata_key_name="_indices",
        )

        super().__init__(adata)

        if model_class is None:
            model_class = LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel

        self.cell_state_df_ = cell_state_df
        self.n_factors_ = cell_state_df.shape[1]
        self.factor_names_ = cell_state_df.columns.values

        if not detection_mean_per_sample:
            # compute expected change in sensitivity (m_g in V1 or y_s in V2)
            sc_total = cell_state_df.sum(0).mean()
            sp_total = get_from_registry(self.adata,
                                         _CONSTANTS.X_KEY).sum(1).mean()
            self.detection_mean_ = (sp_total / model_kwargs.get(
                "N_cells_per_location", 1)) / sc_total
            self.detection_mean_ = self.detection_mean_ * detection_mean_correction
            model_kwargs["detection_mean"] = self.detection_mean_
        else:
            # compute expected change in sensitivity (m_g in V1 and y_s in V2)
            sc_total = cell_state_df.sum(0).mean()
            sp_total = get_from_registry(self.adata, _CONSTANTS.X_KEY).sum(1)
            batch = get_from_registry(self.adata,
                                      _CONSTANTS.BATCH_KEY).flatten()
            sp_total = np.array([
                sp_total[batch == b].mean()
                for b in range(self.summary_stats["n_batch"])
            ])
            self.detection_mean_ = (sp_total / model_kwargs.get(
                "N_cells_per_location", 1)) / sc_total
            self.detection_mean_ = self.detection_mean_ * detection_mean_correction
            model_kwargs["detection_mean"] = self.detection_mean_.reshape(
                (self.summary_stats["n_batch"], 1)).astype("float32")

        detection_alpha = model_kwargs.get("detection_alpha", None)
        if detection_alpha is not None:
            if isinstance(detection_alpha, dict):
                batch_mapping = self.adata.uns["_scvi"]["categorical_mappings"]["_scvi_batch"]["mapping"]
                self.detection_alpha_ = pd.Series(detection_alpha)[batch_mapping]
                model_kwargs["detection_alpha"] = self.detection_alpha_.values.reshape(
                    (self.summary_stats["n_batch"], 1)
                ).astype("float32")

        self.module = Cell2locationBaseModule(
            model=model_class,
            n_obs=self.summary_stats["n_cells"],
            n_vars=self.summary_stats["n_vars"],
            n_factors=self.n_factors_,
            n_batch=self.summary_stats["n_batch"],
            cell_state_mat=self.cell_state_df_.values.astype("float32"),
            **model_kwargs,
        )
        self._model_summary_string = f'cell2location model with the following params: \nn_factors: {self.n_factors_} \nn_batch: {self.summary_stats["n_batch"]} '
        self.init_params_ = self._get_init_params(locals())
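
The detection_mean prior is just a ratio of average totals; a worked numeric sketch of the per-sample branch with toy totals (all values hypothetical):

import numpy as np

sc_total = 4000.0                         # mean total count per cell in the reference
sp_total = np.array([12000.0, 20000.0])   # mean total count per location, per batch
N_cells_per_location = 8
detection_mean_correction = 1.0

detection_mean = (sp_total / N_cells_per_location) / sc_total * detection_mean_correction
detection_mean = detection_mean.reshape((2, 1)).astype("float32")
assert np.allclose(detection_mean.ravel(), [0.375, 0.625])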