Example #1
    def __init__(
        self,
        model: Union[SCVI, SCANVI, TOTALVI],
        adata: anndata.AnnData,
        trainer: Optional['Trainer'] = None,
        cell_type_key: Optional[str] = None,
        batch_key: Optional[str] = None,
    ):
        self.outer_model = model
        self.model = model.model
        self.model.eval()

        if trainer is None:
            self.trainer = model.trainer
        else:
            self.trainer = trainer

        self.adata = adata
        self.modified = getattr(model.model, 'encode_covariates', True)
        self.annotated = type(model) is SCANVI
        self.predictions = None
        self.certainty = None
        self.prediction_names = None
        self.class_check = None
        self.post_adata_2 = None

        if trainer is not None:
            if self.trainer.use_cuda:
                self.device = torch.device('cuda')
            else:
                self.device = torch.device('cpu')
        else:
            # use .device rather than .get_device(): get_device() returns a bare
            # integer index (-1 on CPU), which is not a valid device argument
            self.device = next(self.model.parameters()).device

        if issparse(self.adata.X):
            X = self.adata.X.toarray()
        else:
            X = self.adata.X
        self.x_tensor = torch.tensor(X, device=self.device)

        self.labels = None
        self.label_tensor = None
        if self.annotated:
            self.labels = get_from_registry(self.adata,
                                            "labels").astype(np.int8)
            self.label_tensor = torch.tensor(self.labels, device=self.device)
        self.cell_types = self.adata.obs[cell_type_key].tolist()

        self.batch_indices = get_from_registry(self.adata,
                                               "batch_indices").astype(np.int8)
        self.batch_tensor = torch.tensor(self.batch_indices,
                                         device=self.device)
        self.batch_names = self.adata.obs[batch_key].tolist()
        # Map scvi integer label codes back to cell-type names. This relies on
        # pandas unique() preserving order of first appearance, so the i-th
        # unique cell type corresponds to the i-th unique label code.
        unique_cell_types = self.adata.obs[cell_type_key].unique().tolist()
        unique_labels = self.adata.obs['_scvi_labels'].unique().tolist()
        self.celltype_enc = [0] * len(unique_cell_types)
        for cell_type, label in zip(unique_cell_types, unique_labels):
            self.celltype_enc[label] = cell_type
        self.post_adata = self.latent_as_anndata()
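A minimal usage sketch for this constructor; the owning class is not shown above, so the name `ModelWrapper` and the variables below are illustrative assumptions, not the actual API:

# hypothetical wrapper name -- the class that owns this __init__ is not shown
wrapper = ModelWrapper(
    model=trained_model,            # a trained SCVI, SCANVI, or TOTALVI model
    adata=adata,                    # AnnData already set up with scvi
    cell_type_key="cell_type",      # obs column holding cell-type names
    batch_key="batch",              # obs column holding batch names
)
latent_adata = wrapper.post_adata   # latent representation as an AnnData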
Example #2
    def _validate_anndata(self,
                          adata: Optional[AnnData] = None,
                          copy_if_view: bool = True):
        """Validate anndata has been properly registered, transfer if necessary."""
        if adata is None:
            adata = self.adata
        if adata.is_view:
            if copy_if_view:
                logger.info("Received view of anndata, making copy.")
                adata = adata.copy()
            else:
                raise ValueError("Please run `adata = adata.copy()`")

        if "_scvi" not in adata.uns_keys():
            logger.info("Input adata not setup with scvi, "
                        "attempting to transfer anndata setup.")
            transfer_anndata_setup(self.scvi_setup_dict_, adata)
        is_nonneg_int = _check_nonnegative_integers(
            get_from_registry(adata, _CONSTANTS.X_KEY))
        if not is_nonneg_int:
            logger.warning(
                "Make sure the registered X field in anndata contains unnormalized count data."
            )

        _check_anndata_setup_equivalence(self.scvi_setup_dict_, adata)

        return adata
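A hedged sketch of the typical call pattern, assuming `model` is a trained scvi model exposing this method and `query_adata` is a compatible AnnData (names are illustrative):

# passing None falls back to the AnnData the model was trained on
adata = model._validate_anndata(None)
# a new AnnData gets the stored scvi setup transferred onto it;
# views are copied first so registration does not mutate a view
query = model._validate_anndata(query_adata)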
Example #3
def scatac_raw_counts_properties(
    adata: anndata.AnnData,
    idx1: Union[List[int], np.ndarray],
    idx2: Union[List[int], np.ndarray],
) -> Dict[str, np.ndarray]:
    """
    Computes and returns some statistics on the raw counts of two sub-populations.

    Parameters
    ----------
    adata
        AnnData object setup with `scvi`.
    idx1
        subset of indices describing the first population.
    idx2
        subset of indices describing the second population.

    Returns
    -------
    type
        Dict of ``np.ndarray`` containing, for each sub-population, the
        empirical detection rate per region (``emp_mean1``, ``emp_mean2``)
        and their difference (``emp_effect``).
    """
    data = get_from_registry(adata, _CONSTANTS.X_KEY)
    data1 = data[idx1]
    data2 = data[idx2]
    mean1 = np.asarray((data1 > 0).mean(axis=0)).ravel()
    mean2 = np.asarray((data2 > 0).mean(axis=0)).ravel()
    properties = dict(emp_mean1=mean1,
                      emp_mean2=mean2,
                      emp_effect=(mean1 - mean2))
    return properties
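A toy illustration of the statistics computed above, replicating the core arithmetic with plain numpy in place of the registry lookup (data and indices are made up):

import numpy as np

data = np.array([[0, 2, 0, 1],                  # 5 cells x 4 regions,
                 [1, 0, 0, 3],                  # nonzero = accessible
                 [0, 1, 0, 0],
                 [2, 2, 1, 0],
                 [0, 0, 0, 1]])
idx1, idx2 = [0, 1, 2], [3, 4]                  # two sub-populations
mean1 = (data[idx1] > 0).mean(axis=0)           # detection rate in group 1
mean2 = (data[idx2] > 0).mean(axis=0)           # detection rate in group 2
print(dict(emp_mean1=mean1, emp_mean2=mean2, emp_effect=mean1 - mean2))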
Example #4
def _init_library_size(
    adata: anndata.AnnData, n_batch: int
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Computes and returns library size.

    Parameters
    ----------
    adata
        AnnData object setup with `scvi`.
    n_batch
        Number of batches.

    Returns
    -------
    type
        Tuple of two 1 x n_batch ``np.ndarray`` containing the means and variances
        of library size in each batch in adata.

        If a certain batch is not present in the adata, the mean defaults to 0,
        and the variance defaults to 1. These defaults are arbitrary placeholders which
        should not be used in any downstream computation.
    """
    data = get_from_registry(adata, _CONSTANTS.X_KEY)
    batch_indices = get_from_registry(adata, _CONSTANTS.BATCH_KEY)

    library_log_means = np.zeros(n_batch)
    library_log_vars = np.ones(n_batch)

    for i_batch in np.unique(batch_indices):
        idx_batch = np.squeeze(batch_indices == i_batch)
        batch_data = data[
            idx_batch.nonzero()[0]
        ]  # h5ad requires integer indexing arrays.
        sum_counts = batch_data.sum(axis=1)
        masked_log_sum = np.ma.log(sum_counts)
        if np.ma.is_masked(masked_log_sum):
            warnings.warn(
                "This dataset has some empty cells, this might fail inference."
                "Data should be filtered with `scanpy.pp.filter_cells()`"
            )

        log_counts = masked_log_sum.filled(0)
        library_log_means[i_batch] = np.mean(log_counts).astype(np.float32)
        library_log_vars[i_batch] = np.var(log_counts).astype(np.float32)

    return library_log_means.reshape(1, -1), library_log_vars.reshape(1, -1)
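The per-batch computation above can be checked on toy data; this sketch mirrors the loop body with plain numpy (all values made up):

import numpy as np

counts = np.array([[1, 2], [3, 4], [5, 6], [2, 7]])   # 4 cells x 2 genes
batches = np.array([0, 0, 1, 1])                      # batch index per cell
n_batch = 2
library_log_means = np.zeros(n_batch)
library_log_vars = np.ones(n_batch)
for b in np.unique(batches):
    log_counts = np.log(counts[batches == b].sum(axis=1))  # log library sizes
    library_log_means[b] = log_counts.mean()
    library_log_vars[b] = log_counts.var()
print(library_log_means.reshape(1, -1), library_log_vars.reshape(1, -1))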
Example #5
def scrna_raw_counts_properties(
    adata: anndata.AnnData,
    idx1: Union[List[int], np.ndarray],
    idx2: Union[List[int], np.ndarray],
) -> Dict[str, np.ndarray]:
    """
    Computes and returns some statistics on the raw counts of two sub-populations.

    Parameters
    ----------
    adata
        AnnData object setup with `scvi`.
    idx1
        subset of indices describing the first population.
    idx2
        subset of indices describing the second population.

    Returns
    -------
    type
        Dict of ``np.ndarray`` containing, by pair (one for each sub-population),
        mean expression per gene, proportion of non-zero expression per gene, mean of normalized expression.
    """
    data = get_from_registry(adata, _CONSTANTS.X_KEY)
    data1 = data[idx1]
    data2 = data[idx2]
    mean1 = np.asarray((data1).mean(axis=0)).ravel()
    mean2 = np.asarray((data2).mean(axis=0)).ravel()
    nonz1 = np.asarray((data1 != 0).mean(axis=0)).ravel()
    nonz2 = np.asarray((data2 != 0).mean(axis=0)).ravel()

    key = "_scvi_raw_norm_scaling"
    if key not in adata.obs.keys():
        scaling_factor = 1 / np.asarray(data.sum(axis=1)).ravel().reshape(
            -1, 1)
        scaling_factor *= 1e4
        adata.obs[key] = scaling_factor.ravel()
    else:
        scaling_factor = adata.obs[key].to_numpy().ravel().reshape(-1, 1)

    if isinstance(data, sp_sparse.spmatrix):
        norm_data1 = data1.multiply(scaling_factor[idx1])
        norm_data2 = data2.multiply(scaling_factor[idx2])
    else:
        norm_data1 = data1 * scaling_factor[idx1]
        norm_data2 = data2 * scaling_factor[idx2]

    norm_mean1 = np.asarray(norm_data1.mean(axis=0)).ravel()
    norm_mean2 = np.asarray(norm_data2.mean(axis=0)).ravel()

    properties = dict(
        raw_mean1=mean1,
        raw_mean2=mean2,
        non_zeros_proportion1=nonz1,
        non_zeros_proportion2=nonz2,
        raw_normalized_mean1=norm_mean1,
        raw_normalized_mean2=norm_mean2,
    )
    return properties
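The `_scvi_raw_norm_scaling` factor above is a counts-per-10k normalization; a minimal numpy sketch of the same arithmetic on made-up counts:

import numpy as np

data = np.array([[10, 0, 90], [5, 5, 0]])               # raw counts, 2 cells x 3 genes
scaling_factor = 1e4 / data.sum(axis=1).reshape(-1, 1)  # counts-per-10k factor
norm_mean = (data * scaling_factor).mean(axis=0)        # mean normalized expression
print(norm_mean)                                        # [3000. 2500. 4500.]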
Example #6
def cite_seq_raw_counts_properties(
    adata: anndata.AnnData,
    idx1: Union[List[int], np.ndarray],
    idx2: Union[List[int], np.ndarray],
) -> Dict[str, np.ndarray]:
    """
    Computes and returns some statistics on the raw counts of two sub-populations.

    Parameters
    ----------
    adata
        AnnData object setup with `scvi`.
    idx1
        subset of indices describing the first population.
    idx2
        subset of indices describing the second population.

    Returns
    -------
    type
        Dict of ``np.ndarray`` containing, by pair (one for each sub-population),
        mean expression per gene, proportion of non-zero expression per gene, mean of normalized expression.
    """
    gp = scrna_raw_counts_properties(adata, idx1, idx2)
    protein_exp = get_from_registry(adata, _CONSTANTS.PROTEIN_EXP_KEY)

    # normalized means are not defined for proteins, so fill their slots with NaN
    nan = np.array([np.nan] * len(adata.uns["_scvi"]["protein_names"]))
    mean1_pro = np.asarray(protein_exp[idx1].mean(0))
    mean2_pro = np.asarray(protein_exp[idx2].mean(0))
    nonz1_pro = np.asarray((protein_exp[idx1] > 0).mean(0))
    nonz2_pro = np.asarray((protein_exp[idx2] > 0).mean(0))
    properties = dict(
        raw_mean1=np.concatenate([gp["raw_mean1"], mean1_pro]),
        raw_mean2=np.concatenate([gp["raw_mean2"], mean2_pro]),
        non_zeros_proportion1=np.concatenate(
            [gp["non_zeros_proportion1"], nonz1_pro]),
        non_zeros_proportion2=np.concatenate(
            [gp["non_zeros_proportion2"], nonz2_pro]),
        raw_normalized_mean1=np.concatenate([gp["raw_normalized_mean1"], nan]),
        raw_normalized_mean2=np.concatenate([gp["raw_normalized_mean2"], nan]),
    )

    return properties
Example #7
def test_data_format():
    # if the data was a dense np array, check that after setup_anndata it is C_CONTIGUOUS
    adata = synthetic_iid(run_setup_anndata=False)

    old_x = adata.X
    old_pro = adata.obsm["protein_expression"]
    old_obs = adata.obs
    adata.X = np.asfortranarray(old_x)
    adata.obsm["protein_expression"] = np.asfortranarray(old_pro)
    assert adata.X.flags["C_CONTIGUOUS"] is False
    assert adata.obsm["protein_expression"].flags["C_CONTIGUOUS"] is False

    _setup_anndata(adata, protein_expression_obsm_key="protein_expression")
    assert adata.X.flags["C_CONTIGUOUS"] is True
    assert adata.obsm["protein_expression"].flags["C_CONTIGUOUS"] is True

    assert np.array_equal(old_x, adata.X)
    assert np.array_equal(old_pro, adata.obsm["protein_expression"])
    assert np.array_equal(old_obs, adata.obs)

    assert np.array_equal(adata.X, get_from_registry(adata, _CONSTANTS.X_KEY))
    assert np.array_equal(
        adata.obsm["protein_expression"],
        get_from_registry(adata, _CONSTANTS.PROTEIN_EXP_KEY),
    )

    # if obsm is a dataframe, make it C_CONTIGUOUS if it isn't
    adata = synthetic_iid()
    pe = np.asfortranarray(adata.obsm["protein_expression"])
    adata.obsm["protein_expression"] = pd.DataFrame(pe, index=adata.obs_names)
    assert (
        adata.obsm["protein_expression"].to_numpy().flags["C_CONTIGUOUS"]
        is False
    )
    _setup_anndata(adata, protein_expression_obsm_key="protein_expression")
    new_pe = get_from_registry(adata, "protein_expression")
    assert new_pe.to_numpy().flags["C_CONTIGUOUS"] is True
    assert np.array_equal(pe, new_pe)
    assert np.array_equal(adata.X, get_from_registry(adata, _CONSTANTS.X_KEY))
    assert np.array_equal(
        adata.obsm["protein_expression"],
        get_from_registry(adata, _CONSTANTS.PROTEIN_EXP_KEY),
    )
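The contiguity flags this test relies on can be demonstrated in isolation; a small self-contained sketch:

import numpy as np

x = np.ones((3, 2))
print(x.flags["C_CONTIGUOUS"])                          # True: numpy defaults to row-major
x_f = np.asfortranarray(x)                              # column-major copy
print(x_f.flags["C_CONTIGUOUS"])                        # False
print(np.ascontiguousarray(x_f).flags["C_CONTIGUOUS"])  # True: converted back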
Example #8
    def _validate_anndata(
        self, adata: Optional[AnnData] = None, copy_if_view: bool = True
    ):
        adata = super()._validate_anndata(adata, copy_if_view)
        error_msg = "Number of {} in anndata different from when setup_anndata was run. Please rerun setup_anndata."
        if _CONSTANTS.PROTEIN_EXP_KEY in adata.uns["_scvi"]["data_registry"].keys():
            if (
                self.summary_stats["n_proteins"]
                != get_from_registry(adata, _CONSTANTS.PROTEIN_EXP_KEY).shape[1]
            ):
                raise ValueError(error_msg.format("proteins"))
        else:
            raise ValueError("No protein data found, please setup or transfer anndata")

        return adata
Example #9
    def create_doublets(
        adata: AnnData,
        doublet_ratio: int,
        indices: Optional[Sequence[int]] = None,
        seed: int = 1,
    ) -> AnnData:
        """Simulate doublets.

        Parameters
        ----------
        adata
            AnnData object setup with :func:`~scvi.data.setup_anndata`.
        doublet_ratio
            Ratio of doublets to generate relative to the number of cells in
            adata (or the length of `indices`, if not `None`).
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        seed
            Seed for reproducibility
        """
        n_obs = adata.n_obs if indices is None else len(indices)
        num_doublets = doublet_ratio * n_obs

        # counts can be in many locations, this uses where it was registered in setup
        x = get_from_registry(adata, _CONSTANTS.X_KEY)
        if indices is not None:
            x = x[indices]

        random_state = np.random.RandomState(seed=seed)
        parent_inds = random_state.choice(n_obs, size=(num_doublets, 2))
        doublets = x[parent_inds[:, 0]] + x[parent_inds[:, 1]]

        doublets_ad = AnnData(doublets)
        doublets_ad.var_names = adata.var_names
        doublets_ad.obs_names = [
            "sim_doublet_{}".format(i) for i in range(num_doublets)
        ]

        # if adata setup with a layer, need to add layer to doublets adata
        data_registry = adata.uns["_scvi"]["data_registry"]
        x_loc = data_registry[_CONSTANTS.X_KEY]["attr_name"]
        layer = (data_registry[_CONSTANTS.X_KEY]["attr_key"]
                 if x_loc == "layers" else None)
        if layer is not None:
            doublets_ad.layers[layer] = doublets

        return doublets_ad
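The doublet construction reduces to summing the counts of two randomly chosen parent cells; a self-contained sketch of that core step on toy data:

import numpy as np

counts = np.array([[1, 0], [2, 3], [0, 4]])           # 3 cells x 2 genes
doublet_ratio = 2
rs = np.random.RandomState(seed=1)
parents = rs.choice(counts.shape[0], size=(doublet_ratio * counts.shape[0], 2))
doublets = counts[parents[:, 0]] + counts[parents[:, 1]]  # sum parent profiles
print(doublets.shape)                                 # (6, 2)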
Example #10
    def _validate_anndata(
        self, adata: Optional[AnnData] = None, copy_if_view: bool = True
    ):
        adata = super()._validate_anndata(adata, copy_if_view)
        error_msg = "Number of {} in anndata different from when setup_anndata was run. Please rerun setup_anndata."
        if _CONSTANTS.PROTEIN_EXP_KEY in adata.uns["_scvi"]["data_registry"].keys():
            pro_exp = get_from_registry(adata, _CONSTANTS.PROTEIN_EXP_KEY)
            if self.summary_stats["n_proteins"] != pro_exp.shape[1]:
                raise ValueError(error_msg.format("proteins"))
            is_nonneg_int = _check_nonnegative_integers(pro_exp)
            if not is_nonneg_int:
                warnings.warn(
                    "Make sure the registered protein expression in anndata contains unnormalized count data."
                )
        else:
            raise ValueError("No protein data found, please setup or transfer anndata")

        return adata
Example #11
    def create_doublets(adata: AnnData,
                        seed: int = 1,
                        doublet_ratio: int = 2) -> AnnData:
        """Simulate doublets."""
        num_doublets = doublet_ratio * adata.n_obs

        # counts can be in many locations, this uses where it was registered in setup
        x = get_from_registry(adata, _CONSTANTS.X_KEY)

        # use a seeded RandomState so the simulation is reproducible
        # (the `seed` parameter was previously declared but never used)
        random_state = np.random.RandomState(seed=seed)
        parent_inds = random_state.choice(adata.n_obs, size=(num_doublets, 2))
        doublets = x[parent_inds[:, 0]] + x[parent_inds[:, 1]]

        doublets_ad = AnnData(doublets)
        doublets_ad.var_names = adata.var_names
        doublets_ad.obs_names = [
            "sim_doublet_{}".format(i) for i in range(num_doublets)
        ]

        return doublets_ad
Example #12
    def scale_sampler(
        self,
        selection: Union[List[bool], np.ndarray],
        n_samples: Optional[int] = 5000,
        n_samples_per_cell: Optional[int] = None,
        batchid: Optional[Union[List[int], np.ndarray]] = None,
        use_observed_batches: Optional[bool] = False,
        give_mean: Optional[bool] = False,
    ) -> dict:
        """
        Samples the posterior scale using the variational posterior distribution.

        Parameters
        ----------
        selection
            Mask or list of cell ids to select
        n_samples
            Total number of posterior samples per batch (fill either
            `n_samples` or `n_samples_per_cell`)
        n_samples_per_cell
            Number of times we sample from each observation per batch
            (fill either `n_samples` or `n_samples_per_cell`)
        batchid
            Biological batches for which to sample.
            Default (None) samples from all batches
        use_observed_batches
            Whether normalized means are conditioned on the observed
            batches instead of the batches in `batchid`
        give_mean
            Whether to return the mean of the samples instead of the
            samples themselves


        Returns
        -------
        type
            Dictionary containing:

            - `scale`: posterior aggregated scale samples of shape
              (n_samples, n_genes), where n_samples corresponds to either
              n_bio_batches * n_cells * n_samples_per_cell or n_samples_total
            - `batch`: the associated batch ids

        """
        # Get overall number of desired samples and desired batches
        if batchid is None and not use_observed_batches:
            # TODO determine if we iterate over all categorical batches from train dataset
            # or just the batches in adata
            batchid = np.unique(
                get_from_registry(self.adata, key=_CONSTANTS.BATCH_KEY))
        if use_observed_batches:
            if batchid is not None:
                raise ValueError("Unconsistent batch policy")
            batchid = [None]
        if n_samples is None and n_samples_per_cell is None:
            n_samples = 5000
        elif n_samples_per_cell is not None and n_samples is None:
            n_samples = n_samples_per_cell * len(selection)
        if (n_samples_per_cell is not None) and (n_samples is not None):
            warnings.warn(
                "n_samples and n_samples_per_cell were provided. Ignoring n_samples_per_cell"
            )
        n_samples = int(n_samples / len(batchid))
        if n_samples == 0:
            warnings.warn(
                "very small sample size, please consider increasing `n_samples`"
            )
            n_samples = 2

        # Selection of desired cells for sampling
        if selection is None:
            raise ValueError(
                "selections should be a list of cell subsets indices")
        selection = np.asarray(selection)
        if selection.dtype is np.dtype("bool"):
            if len(selection) != self.adata.shape[0]:
                raise ValueError("Mask must be same length as adata.")
            selection = np.asarray(np.where(selection)[0].ravel())

        # Sampling loop
        px_scales = []
        batch_ids = []
        for batch_idx in batchid:
            idx = np.random.choice(
                np.arange(self.adata.shape[0])[selection], n_samples)
            px_scales.append(
                self.model_fn(self.adata,
                              indices=idx,
                              transform_batch=batch_idx))
            batch_idx = batch_idx if batch_idx is not None else np.nan
            batch_ids.append([batch_idx] * px_scales[-1].shape[0])
        px_scales = np.concatenate(px_scales)
        batch_ids = np.concatenate(batch_ids).reshape(-1)
        if px_scales.shape[0] != batch_ids.shape[0]:
            raise ValueError(
                "sampled scales and batches have inconsistent shapes")
        if give_mean:
            px_scales = px_scales.mean(0)
        return dict(scale=px_scales, batch=batch_ids)
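A hedged usage sketch, assuming `de` is an instance of the (unshown) class that owns this method and `mask` is a boolean cell mask over `de.adata`; names are illustrative:

# scale_sampler returns a dict of posterior scale samples and batch ids
samples = de.scale_sampler(selection=mask, n_samples=2000)
print(samples["scale"].shape)      # (n_samples, n_genes)
print(samples["batch"][:5])        # batch id associated with each sample row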
Example #13
def _get_totalvi_protein_priors(adata, n_cells=100):
    """Compute an empirical prior for protein background."""
    import warnings

    from sklearn.exceptions import ConvergenceWarning
    from sklearn.mixture import GaussianMixture

    warnings.filterwarnings("error")

    batch = get_from_registry(adata, _CONSTANTS.BATCH_KEY).ravel()
    cats = adata.uns["_scvi"]["categorical_mappings"]["_scvi_batch"]["mapping"]
    codes = np.arange(len(cats))

    batch_avg_mus, batch_avg_scales = [], []
    for b in np.unique(codes):
        # can happen during online updates
        # the values of these batches will not be used
        num_in_batch = np.sum(batch == b)
        if num_in_batch == 0:
            batch_avg_mus.append(0)
            batch_avg_scales.append(1)
            continue
        pro_exp = get_from_registry(adata, _CONSTANTS.PROTEIN_EXP_KEY)[batch == b]

        # for missing batches, put dummy values -- scarches case, will be replaced anyway
        if pro_exp.shape[0] == 0:
            batch_avg_mus.append(0.0)
            batch_avg_scales.append(0.05)
            continue  # nothing to fit for an empty batch

        cells = np.random.choice(np.arange(pro_exp.shape[0]), size=n_cells)
        if isinstance(pro_exp, pd.DataFrame):
            pro_exp = pro_exp.to_numpy()
        pro_exp = pro_exp[cells]
        gmm = GaussianMixture(n_components=2)
        mus, scales = [], []
        # fit per cell GMM
        for c in pro_exp:
            try:
                gmm.fit(np.log1p(c.reshape(-1, 1)))
            # when cell is all 0
            except ConvergenceWarning:
                mus.append(0)
                scales.append(0.05)
                continue

            means = gmm.means_.ravel()
            sorted_fg_bg = np.argsort(means)
            mu = means[sorted_fg_bg].ravel()[0]
            covariances = gmm.covariances_[sorted_fg_bg].ravel()[0]
            scale = np.sqrt(covariances)
            mus.append(mu)
            scales.append(scale)

        # average distribution over cells
        batch_avg_mu = np.mean(mus)
        batch_avg_scale = np.sqrt(np.sum(np.square(scales)) / (n_cells ** 2))

        batch_avg_mus.append(batch_avg_mu)
        batch_avg_scales.append(batch_avg_scale)

    # repeat prior for each protein
    batch_avg_mus = np.array(batch_avg_mus, dtype=np.float32).reshape(1, -1)
    batch_avg_scales = np.array(batch_avg_scales, dtype=np.float32).reshape(1, -1)
    batch_avg_mus = np.tile(batch_avg_mus, (pro_exp.shape[1], 1))
    batch_avg_scales = np.tile(batch_avg_scales, (pro_exp.shape[1], 1))

    warnings.resetwarnings()

    return batch_avg_mus, batch_avg_scales
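The per-cell GMM step above can be exercised on synthetic data; a sketch where a two-component mixture separates background from foreground protein counts (all values made up):

import numpy as np
from sklearn.mixture import GaussianMixture

rs = np.random.RandomState(0)
# one cell: a mix of background (low) and foreground (high) protein counts
counts = np.concatenate([rs.poisson(2, 50), rs.poisson(40, 50)])
gmm = GaussianMixture(n_components=2).fit(np.log1p(counts.reshape(-1, 1)))
means = gmm.means_.ravel()
order = np.argsort(means)
background_mu = means[order][0]                       # smaller mean = background
background_scale = np.sqrt(gmm.covariances_[order].ravel()[0])
print(background_mu, background_scale)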
Example #14
def test_setup_anndata():
    # test regular setup
    adata = synthetic_iid(run_setup_anndata=False)
    _setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
        protein_expression_obsm_key="protein_expression",
        protein_names_uns_key="protein_names",
    )
    np.testing.assert_array_equal(
        get_from_registry(adata, "batch_indices"),
        np.array(adata.obs["_scvi_batch"]).reshape((-1, 1)),
    )
    np.testing.assert_array_equal(
        get_from_registry(adata, "labels"),
        np.array(adata.obs["labels"].cat.codes).reshape((-1, 1)),
    )
    np.testing.assert_array_equal(get_from_registry(adata, "X"), adata.X)
    np.testing.assert_array_equal(
        get_from_registry(adata, "protein_expression"),
        adata.obsm["protein_expression"],
    )
    np.testing.assert_array_equal(adata.uns["_scvi"]["protein_names"],
                                  adata.uns["protein_names"])

    # test that an error is thrown if it's a view:
    adata = synthetic_iid()
    with pytest.raises(ValueError):
        _setup_anndata(adata[1])

    # If obsm is a df and protein_names_uns_key is None, protein names should be grabbed from the columns of the df
    adata = synthetic_iid()
    new_protein_names = np.array(random.sample(range(100), 100)).astype("str")
    df = pd.DataFrame(
        adata.obsm["protein_expression"],
        index=adata.obs_names,
        columns=new_protein_names,
    )
    adata.obsm["protein_expression"] = df
    _setup_anndata(adata, protein_expression_obsm_key="protein_expression")
    np.testing.assert_array_equal(adata.uns["_scvi"]["protein_names"],
                                  new_protein_names)

    # test that layer is working properly
    adata = synthetic_iid()
    true_x = adata.X
    adata.layers["X"] = true_x
    adata.X = np.ones_like(adata.X)
    _setup_anndata(adata, layer="X")
    np.testing.assert_array_equal(get_from_registry(adata, "X"), true_x)

    # test that it creates layers and batch if no layers_key is passed
    adata = synthetic_iid()
    _setup_anndata(
        adata,
        protein_expression_obsm_key="protein_expression",
        protein_names_uns_key="protein_names",
    )
    np.testing.assert_array_equal(get_from_registry(adata, "batch_indices"),
                                  np.zeros((adata.shape[0], 1)))
    np.testing.assert_array_equal(get_from_registry(adata, "labels"),
                                  np.zeros((adata.shape[0], 1)))
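The `labels` assertion above relies on pandas categorical codes; a minimal self-contained sketch of that mapping:

import numpy as np
import pandas as pd

labels = pd.Series(["b", "a", "b", "c"], dtype="category")
print(np.array(labels.cat.codes).reshape((-1, 1)))    # [[1] [0] [1] [2]]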
Example #15
    def get_bayes_factors(
        self,
        idx1: Union[List[bool], np.ndarray],
        idx2: Union[List[bool], np.ndarray],
        mode: Literal["vanilla", "change"] = "vanilla",
        batchid1: Optional[Sequence[Union[Number, str]]] = None,
        batchid2: Optional[Sequence[Union[Number, str]]] = None,
        use_observed_batches: Optional[bool] = False,
        n_samples: int = 5000,
        use_permutation: bool = False,
        m_permutation: int = 10000,
        change_fn: Optional[Union[str, Callable]] = None,
        m1_domain_fn: Optional[Callable] = None,
        delta: Optional[float] = 0.5,
        pseudocounts: Union[float, None] = 0.0,
        cred_interval_lvls: Optional[Union[List[float], np.ndarray]] = None,
    ) -> Dict[str, np.ndarray]:
        r"""
        A unified method for differential expression inference.

        Two modes coexist:

        - the `"vanilla"` mode follows protocol described in [Lopez18]_ and [Xu21]_
        In this case, we perform hypothesis testing based on the hypotheses

        .. math::
            M_1: h_1 > h_2 ~\text{and}~ M_2: h_1 \leq h_2.

        DE can then be based on the study of the Bayes factors

        .. math::
            \log p(M_1 | x_1, x_2) / p(M_2 | x_1, x_2).

        - the `"change"` mode (described in [Boyeau19]_)
        This mode consists of estimating an effect size random variable (e.g., log fold-change) and
        performing Bayesian hypothesis testing on this variable.
        The `change_fn` function computes the effect size variable :math:`r` based on two inputs
        corresponding to the posterior quantities (e.g., normalized expression) in both populations.

        Hypotheses:

        .. math::
            M_1: r \in R_1 ~\text{(effect size r in region inducing differential expression)}

        .. math::
            M_2: r  \notin R_1 ~\text{(no differential expression)}

        To characterize the region :math:`R_1`, which induces DE, the user has two choices.

        1. A common case is when the region :math:`[-\delta, \delta]` does not induce differential
           expression. If the user specifies a threshold delta, we suppose that :math:`R_1 = \mathbb{R} \setminus [-\delta, \delta]`
        2. Specify a custom indicator function:

        .. math::
            f: \mathbb{R} \mapsto \{0, 1\} ~\text{s.t.}~ r \in R_1 ~\text{iff.}~ f(r) = 1.

        Decision-making can then be based on the estimates of

        .. math::
            p(M_1 \mid x_1, x_2).

        Both modes require sampling the posterior distributions.
        To that end, we sample the posterior in the following way:

        1. The posterior is sampled `n_samples` times for each subpopulation.
        2. For computational efficiency (posterior sampling is quite expensive), instead of
           comparing the obtained samples element-wise, we can permute posterior samples.
           Remember that computing the Bayes Factor requires sampling :math:`q(z_A \mid x_A)` and :math:`q(z_B \mid x_B)`.

        Currently, the code covers several batch handling configurations:

        1. If ``use_observed_batches=True``, then batches are considered as observations
           and cells' normalized means are conditioned on real batch observations.
        2. If case (cell group 1) and control (cell group 2) are conditioned on the same
           batch ids. This requires ``set(batchid1) == set(batchid2)`` or ``batchid1 == batchid2 == None``.
        3. If case and control are conditioned on different batch ids that do not intersect
           i.e., ``set(batchid1) != set(batchid2)`` and ``len(set(batchid1).intersection(set(batchid2))) == 0``.

        This function does not cover other cases yet and will warn users in such cases.

        Parameters
        ----------
        mode
            one of ["vanilla", "change"]
        idx1
            bool array masking subpopulation cells 1. Should be True where cell is
            from associated population
        idx2
            bool array masking subpopulation cells 2. Should be True where cell is
            from associated population
        batchid1
            List of batch ids for which you want to perform DE Analysis for
            subpopulation 1. By default, all ids are taken into account
        batchid2
            List of batch ids for which you want to perform DE Analysis for
            subpopulation 2. By default, all ids are taken into account
        use_observed_batches
            Whether posterior values are conditioned on observed
            batches
        n_samples
            Number of posterior samples
        use_permutation
            Activates step 2 described above.
            Simply formulated, pairs obtained from posterior sampling
            will be randomly permuted so that the number of pairs used
            to compute Bayes Factors becomes `m_permutation`.
        m_permutation
            Number of times we will "mix" posterior samples in step 2.
            Only makes sense when `use_permutation=True`
        change_fn
            function computing effect size based on both posterior values
        m1_domain_fn
            custom indicator function of effect size regions
            inducing differential expression
        delta
            specific case of region inducing differential expression.
            In this case, we suppose that :math:`[-\delta, \delta]` does not induce differential
            expression (LFC case). If the provided value is `None`, then a proper threshold is
            determined from the distribution of LFCs across genes.
        pseudocounts
            pseudocount offset used for the mode `change`.
            When None, observations from non-expressed genes are used to estimate its value.
        cred_interval_lvls
            List of credible interval levels to compute for the posterior
            LFC distribution

        Returns
        -------
        Differential expression properties

        """
        # if not np.array_equal(self.indices, np.arange(len(self.dataset))):
        #     warnings.warn(
        #         "Differential expression requires a Posterior object created with all indices."
        #     )
        eps = 1e-8
        # Normalized means sampling for both populations
        scales_batches_1 = self.scale_sampler(
            selection=idx1,
            batchid=batchid1,
            use_observed_batches=use_observed_batches,
            n_samples=n_samples,
        )
        scales_batches_2 = self.scale_sampler(
            selection=idx2,
            batchid=batchid2,
            use_observed_batches=use_observed_batches,
            n_samples=n_samples,
        )

        px_scale_mean1 = scales_batches_1["scale"].mean(axis=0)
        px_scale_mean2 = scales_batches_2["scale"].mean(axis=0)

        # Sampling pairs
        # The objective of the code section below is to ensure that the samples of normalized
        # means we consider are conditioned on the same batch id
        batchid1_vals = np.unique(scales_batches_1["batch"])
        batchid2_vals = np.unique(scales_batches_2["batch"])

        create_pairs_from_same_batches = (
            set(batchid1_vals)
            == set(batchid2_vals)) and not use_observed_batches
        if create_pairs_from_same_batches:
            # First case: same batch normalization in two groups
            logger.debug("Same batches in both cell groups")
            n_batches = len(set(batchid1_vals))
            n_samples_per_batch = (m_permutation // n_batches
                                   if m_permutation is not None else None)
            logger.debug("Using {} samples per batch for pair matching".format(
                n_samples_per_batch))
            scales_1 = []
            scales_2 = []
            for batch_val in set(batchid1_vals):
                # Select scale samples that originate from the same batch id
                scales_1_batch = scales_batches_1["scale"][
                    scales_batches_1["batch"] == batch_val]
                scales_2_batch = scales_batches_2["scale"][
                    scales_batches_2["batch"] == batch_val]

                # Create more pairs
                scales_1_local, scales_2_local = pairs_sampler(
                    scales_1_batch,
                    scales_2_batch,
                    use_permutation=use_permutation,
                    m_permutation=n_samples_per_batch,
                )
                scales_1.append(scales_1_local)
                scales_2.append(scales_2_local)
            scales_1 = np.concatenate(scales_1, axis=0)
            scales_2 = np.concatenate(scales_2, axis=0)
        else:
            logger.debug("Ignoring batch conditionings to compare means")
            if len(set(batchid1_vals).intersection(set(batchid2_vals))) >= 1:
                warnings.warn(
                    "Batchids of cells groups 1 and 2 are different but have an non-null "
                    "intersection. Specific handling of such situations is not implemented "
                    "yet and batch correction is not trustworthy.")
            scales_1, scales_2 = pairs_sampler(
                scales_batches_1["scale"],
                scales_batches_2["scale"],
                use_permutation=use_permutation,
                m_permutation=m_permutation,
            )

        # Adding pseudocounts to the scales
        if pseudocounts is None:
            logger.debug("Estimating pseudocounts offet from the data")
            x = get_from_registry(self.adata, _CONSTANTS.X_KEY)
            where_zero_a = densify(np.max(x[idx1], 0)) == 0
            where_zero_b = densify(np.max(x[idx2], 0)) == 0
            pseudocounts = estimate_pseudocounts_offset(
                scales_a=scales_1,
                scales_b=scales_2,
                where_zero_a=where_zero_a,
                where_zero_b=where_zero_b,
            )
        logger.debug("Using pseudocounts ~ {}".format(pseudocounts))
        # Core of function: hypotheses testing based on the posterior samples we obtained above
        if mode == "vanilla":
            logger.debug("Differential expression using vanilla mode")
            proba_m1 = np.mean(scales_1 > scales_2, 0)
            proba_m2 = 1.0 - proba_m1
            res = dict(
                proba_m1=proba_m1,
                proba_m2=proba_m2,
                bayes_factor=np.log(proba_m1 + eps) - np.log(proba_m2 + eps),
                scale1=px_scale_mean1,
                scale2=px_scale_mean2,
            )

        elif mode == "change":
            logger.debug("Differential expression using change mode")

            # step 1: Construct the change function
            def lfc(x, y):
                return np.log2(x + pseudocounts) - np.log2(y + pseudocounts)

            if change_fn == "log-fold" or change_fn is None:
                change_fn = lfc
            elif not callable(change_fn):
                raise ValueError("'change_fn' attribute not understood")

            # step2: Construct the DE area function
            if m1_domain_fn is None:

                def m1_domain_fn(samples):
                    delta_ = (delta if delta is not None else estimate_delta(
                        lfc_means=samples.mean(0)))
                    logger.debug("Using delta ~ {:.2f}".format(delta_))
                    return np.abs(samples) >= delta_

            change_fn_specs = inspect.getfullargspec(change_fn)
            domain_fn_specs = inspect.getfullargspec(m1_domain_fn)
            if len(change_fn_specs.args) != 2 or len(domain_fn_specs.args) != 1:
                raise ValueError(
                    "change_fn should take exactly two parameters as inputs; m1_domain_fn one parameter."
                )
            try:
                change_distribution = change_fn(scales_1, scales_2)
                is_de = m1_domain_fn(change_distribution)
                delta_ = (estimate_delta(lfc_means=change_distribution.mean(0))
                          if delta is None else delta)
            except TypeError:
                raise TypeError(
                    "change_fn or m1_domain_fn have has wrong properties."
                    "Please ensure that these functions have the right signatures and"
                    "outputs and that they can process numpy arrays")
            proba_m1 = np.mean(is_de, 0)
            change_distribution_props = describe_continuous_distrib(
                samples=change_distribution,
                credible_intervals_levels=cred_interval_lvls,
            )
            change_distribution_props = {
                "lfc_" + key: val
                for (key, val) in change_distribution_props.items()
            }

            res = dict(
                proba_de=proba_m1,
                proba_not_de=1.0 - proba_m1,
                bayes_factor=np.log(proba_m1 + eps) -
                np.log(1.0 - proba_m1 + eps),
                scale1=px_scale_mean1,
                scale2=px_scale_mean2,
                pseudocounts=pseudocounts,
                delta=delta_,
                **change_distribution_props,
            )
        else:
            raise NotImplementedError(
                "Mode {mode} not recognized".format(mode=mode))

        return res
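A self-contained sketch of the `"vanilla"` branch above, applied to fake posterior scale samples (everything below is synthetic):

import numpy as np

rs = np.random.RandomState(0)
scales_1 = rs.gamma(5.0, size=(1000, 3))         # posterior scales, population 1
scales_2 = rs.gamma(4.0, size=(1000, 3))         # posterior scales, population 2
eps = 1e-8
proba_m1 = np.mean(scales_1 > scales_2, axis=0)  # P(M1: h1 > h2) per gene
bayes_factor = np.log(proba_m1 + eps) - np.log(1.0 - proba_m1 + eps)
print(proba_m1, bayes_factor)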