def _calc_sig_scores(data: UnimodalData,
                     signatures: Dict[str, List[str]],
                     show_omitted_genes: bool = False,
                     skip_threshold: int = 1) -> None:
    """Compute a per-cell score for every gene signature and store it in ``data.obs``.

    For each signature, only the genes present in ``data.var_names`` are used.
    Each used gene's expression is handled as follows:

    * centred by ``data.var["mean"]``;
    * centred by the per-cell background mean of the gene's bin
      (``data.obsm["sig_bkg_mean"]``);
    * divided by the per-cell background standard deviation of that bin
      (``data.obsm["sig_bkg_std"]``).

    The values are then averaged over the genes. The result is written as a
    float32 column ``data.obs[<signature name>]`` and registered with
    ``data.register_attr(<name>, "signature")``.

    Parameters
    ----------
    data:
        Data object providing ``X`` (sliced with a boolean gene mask and
        converted with ``.toarray()``), ``var["mean"]``, a categorical
        ``var["bins"]`` and the two background arrays in ``obsm``.
    signatures:
        Mapping from signature name to its list of gene names.
    show_omitted_genes:
        When True, the info log also lists genes missing from the data.
    skip_threshold:
        Signatures with fewer usable genes than this are not scored (a warning
        is logged instead).
    """
    for sig_name, sig_genes in signatures.items():
        gene_index = pd.Index(sig_genes)
        in_data = data.var_names.isin(gene_index)
        n_used = in_data.sum()

        omitted_msg = ""
        if show_omitted_genes and n_used < gene_index.size:
            missing = gene_index[~gene_index.isin(data.var_names)]
            omitted_msg = f" Genes {str(list(missing))[1:-1]} are not in the data and thus omitted."
        logger.info(
            f"Signature {sig_name}: {n_used} out of {gene_index.size} genes are used in signature score calculation.{omitted_msg}"
        )

        if n_used < skip_threshold:
            logger.warning(
                f"Signature {sig_name} has less than {skip_threshold} genes kept and thus its score calculation is skipped!"
            )
            continue

        if sig_name in data.obs:
            logger.warning(
                f"Signature key {sig_name} exists in data.obs, the existing content will be overwritten!"
            )

        # Bin code of every used gene selects its background column.
        bin_codes = data.var["bins"].cat.codes[in_data]
        centred = (data.X[:, in_data].toarray()
                   - data.var.loc[in_data, "mean"].values
                   - data.obsm["sig_bkg_mean"][:, bin_codes])
        scaled = centred / data.obsm["sig_bkg_std"][:, bin_codes]
        data.obs[sig_name] = scaled.mean(axis=1).astype(np.float32)
        data.register_attr(sig_name, "signature")
Example #2
0
def down_sampling(rna_gt: UnimodalData,
                  hto_gt: UnimodalData,
                  probs: List[float],
                  n_threads: int = 1):
    """Measure demultiplexing accuracy after binomially down-sampling hashtag counts.

    For each probability ``p`` in ``probs``, the following steps run on copies
    of the inputs:

    1. Every stored hashtag count ``x`` in ``hto_gt.X`` is replaced by a
       ``np.random.binomial(int(x + 1e-4), p)`` draw.
    2. Droplets whose down-sampled hashtag total is 0 are dropped.
    3. The rest is re-demultiplexed with ``estimate_background_probs`` and
       ``demultiplex``.
    4. The new ``obs["assignment"]`` is compared with ``rna_gt``'s.

    Parameters
    ----------
    rna_gt, hto_gt:
        RNA and hashtag data holding the reference ("ground truth") assignment
        in ``rna_gt.obs["assignment"]``.
    probs:
        Down-sampling probabilities to evaluate.
    n_threads:
        Passed through to ``demultiplex``.

    Returns
    -------
    (fracs, accuracy):
        ``fracs[i]`` is the fraction of the total hashtag count kept for
        ``probs[i]``. ``accuracy[i]`` is the fraction of ``rna_gt`` cells whose
        assignment (compared as strings) is unchanged.
    """
    # One binomial draw per stored count. The +1e-4 presumably guards against
    # float counts like 2.9999999 truncating to the wrong integer.
    f = np.vectorize(
        lambda x, p: np.random.binomial(int(x + 1e-4), p, size=1)[0])

    nsample = rna_gt.shape[0]
    nhto = hto_gt.X.sum()

    fracs = []
    accuracy = []
    for p in probs:
        rna_data = rna_gt.copy()
        hto_data = hto_gt.copy()

        hto_data.X.data = f(hto_data.X.data, p)
        kept = hto_data.X.sum(axis=1).A1 > 0
        hto_data = hto_data[kept, ].copy(deep=False)
        fracs.append(hto_data.X.sum() / nhto)

        estimate_background_probs(hto_data)
        demultiplex(rna_data, hto_data, n_threads=n_threads)
        predicted = rna_data.obs["assignment"].values.astype("str")
        expected = rna_gt.obs["assignment"].values.astype("str")
        # Count matches in C rather than with Python's per-element built-in sum.
        accuracy.append(np.count_nonzero(predicted == expected) / nsample)

    return fracs, accuracy
Example #3
0
def calc_demux(rna_data: UnimodalData,
               hashing_data: UnimodalData,
               nsample: int,
               min_signal: float,
               probs: str = "raw_probs") -> None:
    """Assign a demultiplexing type and sample label to each RNA droplet.

    A droplet is considered only when its hashtag signal reaches
    ``min_signal``. The signal is the droplet's ``hashing_data.obs["counts"]``
    (0 if absent) times ``1 - rna_data.obsm[probs][:, nsample]``. For such
    droplets, the first ``nsample`` probability columns are renormalised by
    ``1 - column nsample`` and passed to ``get_droplet_info``.

    Writes
    ------
    ``rna_data.obs["demux_type"]``:
        Categorical over singlet/doublet/unknown; "unknown" for droplets that
        are not considered.
    ``rna_data.obs["assignment"]``:
        Categorical of droplet ids; "" for droplets that are not considered.
        Categories are natsorted unique values.
    """
    n_cells = rna_data.shape[0]
    prob_mat = rna_data.obsm[probs]

    # NOTE(review): column ``nsample`` appears to hold a background probability.
    hto_counts = hashing_data.obs["counts"].reindex(
        rna_data.obs_names, fill_value=0.0).values
    signals = hto_counts * (1.0 - prob_mat[:, nsample])
    passed = signals >= min_signal

    selected = prob_mat[passed, ]
    norm_probs = selected[:, 0:nsample] / (1.0 - selected[:, nsample])[:, None]

    droplet_info = [
        get_droplet_info(norm_probs[row, ], hashing_data.var_names)
        for row in range(norm_probs.shape[0])
    ]

    demux_type = np.full(n_cells, "unknown", dtype="object")
    demux_type[passed] = [info[0] for info in droplet_info]
    rna_data.obs["demux_type"] = pd.Categorical(
        demux_type, categories=["singlet", "doublet", "unknown"])

    assignments = np.full(n_cells, "", dtype="object")
    assignments[passed] = [info[1] for info in droplet_info]
    rna_data.obs["assignment"] = pd.Categorical(
        assignments, categories=natsorted(np.unique(assignments)))
Example #4
0
def _check_and_calc_sig_background(data: UnimodalData, n_bins: int) -> bool:
    """Ensure gene-mean bins and the per-cell signature background exist.

    ``data.var["mean"]`` is computed with ``calc_mean`` if missing. When
    ``data.uns["sig_n_bins"]`` differs from ``n_bins``, the following are
    (re)computed:

    * genes are split into ``n_bins`` quantile bins of their mean with
      ``pd.qcut``; duplicate edges are dropped if needed;
    * the bins are stored as string-labelled categories in ``data.var["bins"]``;
    * ``data.obsm["sig_background"]`` is computed with ``calc_sig_background``.

    Parameters
    ----------
    data:
        Data object with ``X``, ``var``, ``uns`` and ``obsm``.
    n_bins:
        Requested number of gene bins.

    Returns
    -------
    bool
        False (after logging an error, with nothing computed) when ``n_bins``
        is not smaller than the number of genes; True otherwise.
    """
    if "mean" not in data.var:
        from pegasus.cylib.fast_utils import calc_mean
        data.var["mean"] = calc_mean(data.X, axis=0)

    # Reuse the cached binning when it was built with the same n_bins.
    if data.uns.get("sig_n_bins", 0) != n_bins:
        mean_vec = data.var["mean"].values
        if mean_vec.size <= n_bins:
            logger.error(
                f"Number of bins {n_bins} is larger or equal to the total number of genes {mean_vec.size}! Please adjust n_bins and rerun this function!"
            )
            return False
        data.uns["sig_n_bins"] = n_bins
        try:
            bins = pd.qcut(mean_vec, n_bins)
        except ValueError:
            # qcut raises when quantile edges coincide (many equal means).
            logger.warning("Detected and dropped duplicate bin edges!")
            bins = pd.qcut(mean_vec, n_bins, duplicates="drop")
        if bins.value_counts().min() == 1:
            logger.warning("Detected bins with only 1 gene!")
        # Assigning to ``Categorical.categories`` was deprecated in pandas 1.5
        # and removed in 2.0; rename_categories returns a relabelled copy.
        bins = bins.rename_categories(bins.categories.astype(str))
        data.var["bins"] = bins

        # calculate background expectations
        from pegasus.cylib.fast_utils import calc_sig_background
        data.obsm["sig_background"] = calc_sig_background(
            data.X, bins, mean_vec)

    return True
def apply_qc_filters(unidata: UnimodalData):
    """Keep only cells flagged by ``obs["passed_qc"]`` and reset derived state.

    Does nothing when ``unidata.obs`` has no ``passed_qc`` column. Otherwise:

    * cells are subset in place and the ``passed_qc`` column is dropped;
    * when ``uns["__del_demux_type"]`` is truthy, ``demux_type`` is dropped too
      and categories of ``assignment`` with no remaining cells are pruned;
    * ``obsm`` and ``varm`` are emptied;
    * every ``uns`` entry except genome, modality, norm_count and df_qcplot is
      removed.
    """
    if "passed_qc" not in unidata.obs:
        return

    n_before = unidata.shape[0]
    unidata._inplace_subset_obs(unidata.obs["passed_qc"])

    to_drop = ["passed_qc"]
    if unidata.uns.get("__del_demux_type", False):
        to_drop.append("demux_type")
        if "assignment" in unidata.obs:
            # Rebuild the categorical without categories left empty by subsetting.
            per_category = unidata.obs["assignment"].value_counts(sort=False)
            non_empty = per_category[per_category > 0].index.astype(str)
            unidata.obs["assignment"] = pd.Categorical(
                unidata.obs["assignment"], categories=non_empty)

    unidata.obs.drop(columns=to_drop, inplace=True)
    if len(unidata.obsm) > 0:
        unidata.obsm.clear()
    if len(unidata.varm) > 0:
        unidata.varm.clear()

    # "__del_demux_type" is not whitelisted, so it is removed here as well.
    kept_uns = {'genome', 'modality', 'norm_count', 'df_qcplot'}
    for uns_key in [k for k in unidata.uns if k not in kept_uns]:
        del unidata.uns[uns_key]

    logger.info(
        f"After filtration, {unidata.shape[0]} out of {n_before} cell barcodes are kept in UnimodalData object {unidata.get_uid()}."
    )
Example #6
0
    def concat_data(self, modality: str = "rna"):
        """Merge every stored UnimodalData into one object and select it.

        Intended for raw data: multiarrays are ignored and only the ``X``
        matrix of each object is used. All stored objects are popped from
        ``self.data`` and must have modality ``modality`` (checked with
        ``assert``).

        With a single object, it is re-inserted unchanged. Otherwise a new
        object is built:

        * feature metadata are concatenated, with NaN filled with "N/A";
        * ``X`` matrices are stacked horizontally as CSR;
        * the genome is the comma-joined input genomes;
        * barcodes are taken from the first object.

        The resulting object becomes the selected dataset.

        Raises
        ------
        ValueError
            If there is no data to concatenate.
        """
        genomes = []
        unidata_arr = []

        for key in list(self.data):
            unidata = self.data.pop(key)
            assert unidata.get_modality() == modality
            genomes.append(unidata.get_genome())
            unidata_arr.append(unidata)

        if not unidata_arr:
            raise ValueError("No data to concatenate!")

        if len(unidata_arr) == 1:
            unikey = unidata_arr[0].get_uid()
            self.data[unikey] = unidata_arr[0]
        else:
            genome = ",".join(genomes)
            feature_metadata = pd.concat(
                [unidata.feature_metadata for unidata in unidata_arr], axis=0)
            feature_metadata.reset_index(inplace=True)
            feature_metadata.fillna(value="N/A", inplace=True)
            X = hstack([unidata.matrices["X"] for unidata in unidata_arr],
                       format="csr")
            # Label the merged object with the modality every input was
            # asserted to have, instead of always "rna".
            unidata = UnimodalData(unidata_arr[0].barcode_metadata,
                                   feature_metadata, {"X": X}, {
                                       "genome": genome,
                                       "modality": modality
                                   })
            unikey = unidata.get_uid()
            self.data[unikey] = unidata
            # Release the per-genome inputs before returning.
            del unidata_arr
            gc.collect()

        self._selected = unikey
        self._unidata = self.data[unikey]
Example #7
0
def apply_qc_filters(unidata: UnimodalData):
    """Subset ``unidata`` in place to cells whose ``obs["passed_qc"]`` is True.

    Does nothing if the ``passed_qc`` column is absent. After subsetting,
    ``passed_qc`` is dropped from ``obs``. If ``uns["__del_demux_type"]`` is
    truthy, the ``demux_type`` column and that flag are removed as well.
    """
    if "passed_qc" not in unidata.obs:
        return

    n_before = unidata.shape[0]
    unidata._inplace_subset_obs(unidata.obs["passed_qc"])

    drop_cols = ["passed_qc"]
    if unidata.uns.get("__del_demux_type", False):
        drop_cols.append("demux_type")
        del unidata.uns["__del_demux_type"]
    unidata.obs.drop(columns=drop_cols, inplace=True)

    logger.info(
        f"After filtration, {unidata.shape[0]} out of {n_before} cell barcodes are kept in UnimodalData object {unidata.get_uid()}."
    )
Example #8
0
def load_10x_h5_file_v2(h5_in: h5py.Group) -> MultimodalData:
    """Load a 10x Genomics v2 HDF5 count matrix.

    Every top-level group of ``h5_in`` is treated as one genome. Each group is
    turned into an RNA UnimodalData, whose channels are separated with
    ``separate_channels()``.

    Parameters
    ----------

    h5_in : h5py.Group
        An open h5py group connected to a 10x v2 formatted HDF5 file.

    Returns
    -------

    A MultimodalData object with one UnimodalData per genome.

    Examples
    --------
    >>> io.load_10x_h5_file_v2(h5_in)
    """
    data = MultimodalData()
    for genome in h5_in.keys():
        grp = h5_in[genome]

        n_features, n_barcodes = grp["shape"][...]
        # The stored (data, indices, indptr) triplet is read as CSR with one
        # row per barcode, i.e. the transpose of the stored shape.
        X = csr_matrix(
            (grp["data"][...], grp["indices"][...], grp["indptr"][...]),
            shape=(n_barcodes, n_features),
        )

        barcode_info = {"barcodekey": grp["barcodes"][...].astype(str)}
        feature_info = {
            "featurekey": grp["gene_names"][...].astype(str),
            "featureid": grp["genes"][...].astype(str),
        }
        unidata = UnimodalData(barcode_info, feature_info, {"X": X},
                               {"modality": "rna", "genome": genome})
        unidata.separate_channels()

        data.add_data(unidata)

    return data
def _generate_filter_plots(unidata: UnimodalData,
                           plot_filt: str,
                           plot_filt_figsize: str = None,
                           min_genes_before_filt: int = 100) -> None:
    """Save QC violin plots as PDFs; used only by the command line.

    Writes ``{plot_filt}.{uid}.filt.UMI.pdf`` (``qcviolin`` "count") and
    ``{plot_filt}.{uid}.filt.gene.pdf`` ("gene"). It also writes
    ``{plot_filt}.{uid}.filt.mito.pdf`` ("mito") when ``qcviolin`` returns a
    figure for it.

    ``plot_filt_figsize`` is an optional ``"width,height"`` string of integers
    used as ``panel_size``. ``min_genes_before_filt`` is accepted but unused
    here.
    """
    uid = unidata.get_uid()

    from pegasus.plotting import qcviolin

    plot_kwargs = {"return_fig": True, "dpi": 500}
    if plot_filt_figsize is not None:
        w_str, h_str = plot_filt_figsize.split(",")
        plot_kwargs["panel_size"] = (int(w_str), int(h_str))

    for plot_type, suffix in (("count", "UMI"), ("gene", "gene"), ("mito", "mito")):
        fig = qcviolin(unidata, plot_type, **plot_kwargs)
        if fig is None and plot_type == "mito":
            # No mito figure produced: nothing to save.
            continue
        fig.savefig(f"{plot_filt}.{uid}.filt.{suffix}.pdf")

    logger.info("Filtration plots are generated.")
def _write_mtx(unidata: UnimodalData, output_dir: str, precision: int):
    """Write all matrices of ``unidata`` plus barcode/feature tables to ``output_dir``.

    Each matrix is written as ``<key>.mtx.gz``, or ``matrix.mtx.gz`` for
    ``X``. ``write_mtx`` writes the cell x gene CSR matrix as gene x cell into
    a FIFO that a background ``gzip`` compresses. Metadata go to
    ``barcodes.tsv.gz`` and ``features.tsv.gz``. ``output_dir`` (including
    parents) is created if needed.

    Raises
    ------
    ImportError
        If the compiled ``pegasusio.cylib.io`` extension is unavailable.
    RuntimeError
        If ``gzip`` exits with a non-zero status.
    """
    try:
        from pegasusio.cylib.io import write_mtx
    except ModuleNotFoundError as err:
        # Fail before creating the FIFO. Otherwise gzip would block forever
        # waiting for a writer that never comes.
        raise ImportError("No module named 'pegasusio.cylib.io'") from err

    os.makedirs(output_dir, exist_ok=True)

    for key in unidata.list_keys():
        matrix = unidata.matrices[key]
        mtx_file = os.path.join(output_dir, ("matrix" if key == "X" else key) + ".mtx.gz")
        fifo_file = mtx_file + ".fifo"
        if os.path.exists(fifo_file):
            os.unlink(fifo_file)
        os.mkfifo(fifo_file)
        try:
            pobj = subprocess.Popen(f"gzip < {shlex.quote(fifo_file)} > {shlex.quote(mtx_file)}", shell = True)
            try:
                write_mtx(fifo_file, matrix.data, matrix.indices, matrix.indptr, matrix.shape[0], matrix.shape[1], precision = precision) # matrix is cell x gene csr_matrix, will write as gene x cell
            except BaseException:
                # Do not leave gzip blocked on the FIFO when writing fails.
                pobj.kill()
                pobj.wait()
                raise
            returncode = pobj.wait()
        finally:
            os.unlink(fifo_file)
        if returncode != 0:
            raise RuntimeError(f"gzip failed with exit code {returncode} while writing {mtx_file}!")
        logger.info(f"{mtx_file} is written.")

    unidata.barcode_metadata.to_csv(os.path.join(output_dir, "barcodes.tsv.gz"), sep = '\t')
    logger.info("barcodes.tsv.gz is written.")

    unidata.feature_metadata.to_csv(os.path.join(output_dir, "features.tsv.gz"), sep = '\t')
    logger.info("features.tsv.gz is written.")

    logger.info(f"Mtx for {unidata.get_uid()} is written.")
def _write_scp_expression(unidata: UnimodalData, output_name: str, is_sparse: bool, precision: int = 2) -> None:
    """Write the main matrix ``X`` of ``unidata`` in Single Cell Portal format.

    Sparse output (``is_sparse``):

    * ``{output_name}.scp.barcodes.tsv`` with one barcode per line;
    * ``{output_name}.scp.features.tsv`` with gene id and gene name, no header;
    * ``{output_name}.scp.matrix.mtx`` with the cell x gene CSR matrix written
      as gene x cell.

    Feature ids come from ``var["featureid"]``, else ``var["gene_ids"]``, else
    the gene names.

    Dense output: ``{output_name}.scp.expr.txt`` with genes as rows and cells
    as columns.

    Raises
    ------
    ImportError
        If the compiled ``pegasusio.cylib.io`` extension is unavailable.
    """
    try:
        from pegasusio.cylib.io import write_mtx, write_dense
    except ModuleNotFoundError as err:
        # Fail now instead of with a NameError after files were partially written.
        raise ImportError("No module named 'pegasusio.cylib.io'") from err

    matrix = unidata.get_matrix("X")
    if is_sparse:
        barcode_file = f"{output_name}.scp.barcodes.tsv"
        with open(barcode_file, "w") as fout:
            fout.write("\n".join(unidata.obs_names) + "\n")
        logger.info(f"Barcode file {barcode_file} is written.")

        feature_file = f"{output_name}.scp.features.tsv"

        gene_names = unidata.var_names.values
        if "featureid" in unidata.var:
            gene_ids = unidata.var["featureid"].values
        elif "gene_ids" in unidata.var:
            gene_ids = unidata.var["gene_ids"].values
        else:
            gene_ids = gene_names

        df = pd.DataFrame({"gene_ids": gene_ids, "gene_names": gene_names})
        df.to_csv(feature_file, sep="\t", header=False, index=False)
        logger.info(f"Feature file {feature_file} is written.")

        mtx_file = f"{output_name}.scp.matrix.mtx"
        write_mtx(mtx_file, matrix.data, matrix.indices, matrix.indptr, matrix.shape[0], matrix.shape[1], precision = precision) # matrix is cell x gene csr_matrix, will write as gene x cell
        logger.info(f"Matrix file {mtx_file} is written.")
    else:
        expr_file = f"{output_name}.scp.expr.txt"
        matrix = matrix.T.tocsr() # convert to gene x cell
        write_dense(expr_file, unidata.obs_names.values, unidata.var_names.values, matrix.data, matrix.indices, matrix.indptr, matrix.shape[0], matrix.shape[1], precision = precision)
        logger.info(f"Dense expression file {expr_file} is written.")
Example #12
0
def estimate_background_probs(hashing_data: UnimodalData,
                              random_state: int = 0) -> None:
    """For cell-hashing data, estimate the antibody background probability with KMeans.

    Steps:

    1. Droplets are split into two KMeans clusters on log10 of their total
       antibody counts. The cluster with the larger centre is labelled
       "signal", the other "background".
    2. The background probability vector is the normalised per-antibody sum
       over background droplets.
    3. Antibodies with zero background count are removed from
       ``hashing_data`` with a warning.

    Note that droplets with 0 total counts give ``-inf`` after ``log10``.

    Parameters
    ----------
    hashing_data: ``UnimodalData``
        Annotated data matrix for antibody.

    random_state: ``int``, optional, default: ``0``
        Random seed set for reproducing results.

    Returns
    -------
    ``None``

    Update ``hashing_data.obs``:
        * ``hashing_data.obs["counts"]``: total antibody counts per droplet.
        * ``hashing_data.obs["hto_type"]``: "signal" or "background".

    Update ``hashing_data.uns``:
        * ``hashing_data.uns["background_probs"]``: estimated antibody background probability.

    Example
    -------
    >>> demuxEM.estimate_background_probs(hashing_data)
    """
    obs = hashing_data.obs
    obs["counts"] = hashing_data.X.sum(axis=1).A1
    log_counts = np.log10(obs["counts"].values.reshape(-1, 1))
    km = KMeans(n_clusters=2, random_state=random_state).fit(log_counts)
    centers = km.cluster_centers_
    signal_label = 0 if centers[0] > centers[1] else 1
    obs["hto_type"] = "background"
    obs.loc[km.labels_ == signal_label, "hto_type"] = "signal"

    is_background = np.isin(obs["hto_type"], "background")
    background_totals = hashing_data.X[is_background, ].sum(axis=0).A1
    back_probs = background_totals / background_totals.sum()

    absent = back_probs <= 0.0
    n_absent = absent.sum()
    if n_absent > 0:
        logger.warning(
            f"Detected {n_absent} antibody barcodes {','.join(hashing_data.var_names[absent])} with 0 counts in the background! These barcodes are likely not in the experiment and thus removed."
        )
        hashing_data._inplace_subset_var(~absent)
        back_probs = back_probs[~absent]

    hashing_data.uns["background_probs"] = back_probs
    logger.info("Background probability distribution is estimated.")
Example #13
0
    def __init__(self, unidata: Union[UnimodalData, anndata.AnnData, MultiDataDict] = None, genome: str = None, modality: str = None):
        """Initialize the container.

        ``unidata`` can be one of:

        * a ``MultiDataDict``, used directly as the backing store;
        * a ``UnimodalData``, added with ``add_data``;
        * an ``anndata.AnnData``, first wrapped into a ``UnimodalData`` using
          ``genome`` and ``modality``;
        * None, giving an empty container.
        """
        self._selected = None
        self._unidata = None
        self._zarrobj = None

        if isinstance(unidata, MultiDataDict):
            self.data = unidata
            return

        self.data = MultiDataDict()
        if unidata is None:
            return
        if isinstance(unidata, anndata.AnnData):
            unidata = UnimodalData(unidata, genome = genome, modality = modality)
        self.add_data(unidata)
Example #14
0
 def add_data(self, unidata: UnimodalData) -> None:
     """Store ``unidata`` under its uid; the first dataset added becomes the selected one.

     Raises ValueError if a dataset with the same uid is already stored.
     """
     uid = unidata.get_uid()
     assert uid is not None
     if uid in self.data:
         raise ValueError(f"Key '{uid}' already exists!")
     self.data[uid] = unidata
     if self._selected is not None:
         return
     self._selected = uid
     self._unidata = unidata
def load_one_mtx_file(path: str, file_name: str, genome: str, modality: str) -> UnimodalData:
    """Load one gene-count matrix in MatrixMarket format into a UnimodalData object.

    ``file_name`` (``.mtx`` or ``.mtx.gz``) is read from ``path``. The
    matching barcode and feature files are found with
    ``_locate_barcode_and_feature_files`` using the name without its
    ``.mtx``/``.mtx.gz`` suffix.

    Gzipped matrices are streamed through a FIFO fed by ``gunzip``; the FIFO
    lives in a private temporary directory. If the matrix's columns match the
    number of barcodes, it is transposed so that rows are barcodes. Explicit
    zeros are removed. For "10x v2"/"10x v3" formats the channels are
    separated.

    Raises
    ------
    ImportError
        If the compiled ``pegasusio.cylib.io`` extension is unavailable.
    RuntimeError
        If ``gunzip`` exits with a non-zero status.
    """
    try:
        from pegasusio.cylib.io import read_mtx
    except ModuleNotFoundError as err:
        # Fail now. Otherwise gunzip would be left blocked on the FIFO and the
        # function would die later with a NameError.
        raise ImportError("No module named 'pegasusio.cylib.io'") from err

    # Dots escaped so only a literal ".mtx" / ".mtx.gz" suffix is stripped.
    fname = re.sub(r'\.mtx(\.gz)?$', '', file_name)
    barcode_file, feature_file = _locate_barcode_and_feature_files(path, fname)

    barcode_metadata, format_type = _load_barcode_metadata(barcode_file)
    feature_metadata, format_type = _load_feature_metadata(feature_file, format_type)
    logger.info(f"Detected mtx file in {format_type} format.")

    mtx_file = os.path.join(path, file_name)
    if file_name.endswith(".gz"):
        # A private directory avoids FIFO name clashes between processes.
        fifo_dir = tempfile.mkdtemp()
        mtx_fifo = os.path.join(fifo_dir, os.path.basename(file_name) + ".fifo")
        os.mkfifo(mtx_fifo)
        try:
            pobj = subprocess.Popen(f"gunzip -c {shlex.quote(mtx_file)} > {shlex.quote(mtx_fifo)}", shell = True)
            try:
                row_ind, col_ind, data, shape = read_mtx(mtx_fifo)
            except BaseException:
                # Do not leave gunzip blocked writing to the FIFO.
                pobj.kill()
                pobj.wait()
                raise
            # Reap gunzip and detect decompression errors (e.g. a corrupt file).
            returncode = pobj.wait()
        finally:
            os.unlink(mtx_fifo)
            os.rmdir(fifo_dir)
        if returncode != 0:
            raise RuntimeError(f"Failed to decompress {mtx_file}: gunzip exited with code {returncode}!")
    else:
        row_ind, col_ind, data, shape = read_mtx(mtx_file)

    if shape[1] == barcode_metadata.shape[0]: # Column is barcode, swap the coordinates
        row_ind, col_ind = col_ind, row_ind
        shape = (shape[1], shape[0])

    mat = csr_matrix((data, (row_ind, col_ind)), shape = shape)
    mat.eliminate_zeros()

    unidata = UnimodalData(barcode_metadata, feature_metadata, {"X": mat}, {"genome": genome, "modality": modality})
    if format_type == "10x v3" or format_type == "10x v2":
        unidata.separate_channels()

    return unidata
    def write_unimodal_data(self, group: zarr.Group, name: str, data: UnimodalData, overwrite: bool = True) -> None:
        """Serialize ``data`` into the zarr sub-group ``name`` of ``group``.

        The sub-group's attributes record the data type and the current matrix
        name. Barcode and feature metadata tables are always written. Each of
        the four mappings is written when ``overwrite`` is True or when that
        mapping reports ``is_dirty()``:

        * matrices
        * metadata
        * barcode_multiarrays
        * feature_multiarrays
        """
        sub_group = group.require_group(name, overwrite = overwrite)
        sub_group.attrs.update(data_type = 'UnimodalData', _cur_matrix = data.current_matrix())

        self.write_dataframe(sub_group, 'barcode_metadata', data.barcode_metadata)
        self.write_dataframe(sub_group, 'feature_metadata', data.feature_metadata)

        for attr_name in ('matrices', 'metadata', 'barcode_multiarrays', 'feature_multiarrays'):
            mapping = getattr(data, attr_name)
            if overwrite or mapping.is_dirty():
                self.write_mapping(sub_group, attr_name, mapping, overwrite = overwrite)
def pseudobulk(
    data: MultimodalData,
    sample: str,
    attrs: Optional[Union[List[str], str]] = None,
    mat_key: Optional[str] = "counts",
    cluster: Optional[str] = None,
) -> UnimodalData:
    """Generate Pseudo-bulk count matrices.

    Parameters
    -----------
    data: ``MultimodalData`` or ``UnimodalData`` object
        Annotated data matrix with rows for cells and columns for genes.

    sample: ``str``
        Specify the cell attribute used for aggregating pseudo-bulk data.
        Key must exist in ``data.obs``.

    attrs: ``str`` or ``List[str]``, optional, default: ``None``
        Specify additional cell attributes to remain in the pseudo bulk data.
        If set, all attributes' keys must exist in ``data.obs``.
        Each pseudo-bulk's value is computed by ``set_bulk_value`` over its cells.
        Per the original documentation, this is the most frequent value for a
        categorical attribute and the mean for a numeric one.

    mat_key: ``str``, optional, default: ``counts``
        Key of the single-cell count matrix in ``data``'s matrices used for
        aggregating pseudo-bulk counts.

    cluster: ``str``, optional, default: ``None``
        If set, additionally generate pseudo-bulk matrices per cluster specified in ``data.obs[cluster]``.

    Returns
    -------
    A UnimodalData object ``udata`` containing pseudo-bulk information:
        * It has the following count matrices:

          * ``counts``: The pseudo-bulk count matrix over all cells. This is
            the current matrix.
          * If ``cluster`` is set, one matrix ``<cluster>_<cls>.X`` per
            cluster label ``cls``, aggregated over the cells of that cluster.
        * ``udata.obs``: It contains pseudo-bulk attributes aggregated from the corresponding single-cell attributes.
        * ``udata.var``: Gene names and, if present, ``featureid`` are maintained.

    Update ``data``:
        * Add the returned UnimodalData object above to ``data`` with key ``<sample>-pseudobulk``, where ``<sample>`` is replaced by the actual value of ``sample`` argument.

    Examples
    --------
    >>> pg.pseudobulk(data, sample="Channel")
    """
    X = data.get_matrix(mat_key)

    assert sample in data.obs.columns, f"Sample key '{sample}' must exist in data.obs!"

    # isinstance on the dtype replaces the deprecated is_categorical_dtype.
    sample_vec = (data.obs[sample] if isinstance(data.obs[sample].dtype, pd.CategoricalDtype)
                  else data.obs[sample].astype("category"))
    bulk_list = sample_vec.cat.categories

    df_barcode = data.obs.reset_index()

    mat_dict = {
        "counts": get_pseudobulk_count(X, df_barcode, sample, bulk_list)
    }

    # Generate pseudo-bulk attributes if specified
    bulk_attr_list = []

    if attrs is not None:
        if isinstance(attrs, str):
            attrs = [attrs]
        for attr in attrs:
            assert (attr in data.obs.columns
                    ), f"Cell attribute key '{attr}' must exist in data.obs!"

    for bulk in bulk_list:
        df_bulk = df_barcode.loc[df_barcode[sample] == bulk]
        if attrs is not None:
            bulk_attr = df_bulk[attrs].apply(set_bulk_value, axis=0)
            bulk_attr["barcodekey"] = bulk
        else:
            bulk_attr = pd.Series({"barcodekey": bulk})
        bulk_attr_list.append(bulk_attr)

    df_pseudobulk = pd.DataFrame(bulk_attr_list)

    df_feature = pd.DataFrame(index=data.var_names)
    if "featureid" in data.var.columns:
        df_feature["featureid"] = data.var["featureid"]

    if cluster is not None:
        # Report the cluster key itself (previously referenced `attr`, which
        # is undefined when attrs is None).
        assert (cluster in data.obs.columns
                ), f"Cluster key '{cluster}' must exist in data.obs!"

        cluster_list = data.obs[cluster].astype("category").cat.categories
        for cls in cluster_list:
            mat_dict[f"{cluster}_{cls}.X"] = get_pseudobulk_count(
                X, df_barcode.loc[df_barcode[cluster] == cls], sample,
                bulk_list)

    udata = UnimodalData(
        barcode_metadata=df_pseudobulk,
        feature_metadata=df_feature,
        matrices=mat_dict,
        genome=sample,
        modality="pseudobulk",
        cur_matrix="counts",
    )

    data.add_data(udata)

    return udata
def load_csv_file(
    input_csv: str,
    sep: str = ",",
    genome: str = None,
    modality: str = None,
) -> MultimodalData:
    """Load count matrix from a CSV-style file, such as CSV file or DGE style tsv file.

    If ``cells.csv[.gz]`` and ``genes.csv[.gz]`` exist in the same directory
    as ``input_csv``, they are loaded as barcode/feature metadata and must be
    in HCA DCP format.

    If the first header key is "cellkey" (HCA format), rows are cells.
    Otherwise rows are features and the matrix is transposed so that rows are
    barcodes.

    Gzipped input is streamed through a FIFO fed by ``gunzip``; the FIFO lives
    in a private temporary directory.

    Parameters
    ----------

    input_csv : `str`
        The CSV file, gzipped or not, containing the count matrix.
    sep: `str`, optional (default: ',')
        Separator between fields, either ',' or '\\t'.
    genome : `str`, optional (default None)
        The genome reference. If None, use "unknown" instead.
    modality: `str`, optional (default None)
        Modality. If None, use "rna" instead. "citeseq" yields a CITESeqData
        whose matrix is stored as "raw.count".

    Returns
    -------

    A MultimodalData object containing a (genome, UnimodalData) pair.

    Raises
    ------
    ImportError
        If the compiled ``pegasusio.cylib.io`` extension is unavailable.
    FileNotFoundError
        If ``input_csv`` does not exist.
    RuntimeError
        If ``gunzip`` exits with a non-zero status.

    Examples
    --------
    >>> io.load_csv_file('example_ADT.csv')
    >>> io.load_csv_file('example.umi.dge.txt.gz', genome = 'GRCh38', sep = '\\t')
    """
    try:
        from pegasusio.cylib.io import read_csv
    except ModuleNotFoundError as err:
        # Fail now. Otherwise gunzip would be left blocked on the FIFO and the
        # function would die later with a NameError.
        raise ImportError("No module named 'pegasusio.cylib.io'") from err

    if not os.path.exists(input_csv):
        raise FileNotFoundError(f"File {input_csv} does not exist!")

    barcode_metadata = feature_metadata = None

    input_csv = os.path.abspath(input_csv)
    path = os.path.dirname(input_csv)
    fname = os.path.basename(input_csv)

    barcode_file = os.path.join(path, "cells.csv")
    if not os.path.isfile(barcode_file):
        barcode_file += ".gz"
    feature_file = os.path.join(path, "genes.csv")
    if not os.path.isfile(feature_file):
        feature_file += ".gz"

    if os.path.isfile(barcode_file) and os.path.isfile(feature_file):
        barcode_metadata, format_type = _load_barcode_metadata(barcode_file, sep = sep)
        feature_metadata, format_type = _load_feature_metadata(feature_file, format_type, sep = sep)
        assert format_type == "HCA DCP"

    if input_csv.endswith(".gz"):
        # A private directory avoids FIFO name clashes between processes.
        fifo_dir = tempfile.mkdtemp()
        csv_fifo = os.path.join(fifo_dir, fname + ".fifo")
        os.mkfifo(csv_fifo)
        try:
            pobj = subprocess.Popen(f"gunzip -c {shlex.quote(input_csv)} > {shlex.quote(csv_fifo)}", shell = True)
            try:
                row_ind, col_ind, data, shape, rowkey, rownames, colnames = read_csv(csv_fifo, sep)
            except BaseException:
                # Do not leave gunzip blocked writing to the FIFO.
                pobj.kill()
                pobj.wait()
                raise
            # Reap gunzip and detect decompression errors (e.g. a corrupt file).
            returncode = pobj.wait()
        finally:
            os.unlink(csv_fifo)
            os.rmdir(fifo_dir)
        if returncode != 0:
            raise RuntimeError(f"Failed to decompress {input_csv}: gunzip exited with code {returncode}!")
    else:
        row_ind, col_ind, data, shape, rowkey, rownames, colnames = read_csv(input_csv, sep)

    if rowkey == "cellkey":
        # HCA format
        assert (barcode_metadata is not None) and (feature_metadata is not None) and (barcode_metadata.shape[0] == shape[0]) and (feature_metadata.shape[0] == shape[1]) and \
               ((barcode_metadata["barcodekey"].values != np.array(rownames)).sum() == 0) and ((feature_metadata["featureid"].values != np.array(colnames)).sum() == 0)
        mat = csr_matrix((data, (row_ind, col_ind)), shape = shape)
    else:
        # Rows are features: transpose so rows become barcodes.
        mat = csr_matrix((data, (col_ind, row_ind)), shape = (shape[1], shape[0]))
        if barcode_metadata is None:
            barcode_metadata = {"barcodekey": colnames}
        else:
            assert (barcode_metadata.shape[0] == shape[1]) and ((barcode_metadata["barcodekey"].values != np.array(colnames)).sum() == 0)
        if feature_metadata is None:
            feature_metadata = {"featurekey": rownames}
        else:
            assert (feature_metadata.shape[0] == shape[0]) and ((feature_metadata["featurekey"].values != np.array(rownames)).sum() == 0)

    genome = genome if genome is not None else "unknown"
    modality = modality if modality is not None else "rna"

    if modality == "citeseq":
        unidata = CITESeqData(barcode_metadata, feature_metadata, {"raw.count": mat}, {"genome": genome, "modality": modality})
    else:
        unidata = UnimodalData(barcode_metadata, feature_metadata, {"X": mat}, {"genome": genome, "modality": modality})

    data = MultimodalData(unidata)

    return data
Example #19
0
    def _aggregate_unidata(self, unilist: List[UnimodalData]) -> UnimodalData:
        """Merge several UnimodalData objects of one genome/modality into one object.

        With a single object, its "_sample" metadata entry is deleted and the
        object itself is returned.

        With more than one object, the inputs are modified in place:

        * modality-specific feature columns are dropped;
        * feature columns other than "featureid" are moved out of each
          ``feature_metadata``;
        * "genome", "modality", "_sample" and modality-specific keys are popped
          from each ``metadata``.

        Parameters
        ----------
        unilist:
            Objects to merge. Each ``metadata`` must contain "_sample"; when
            merging several objects, also "genome" and "modality", which must
            agree across objects (checked with ``assert``).

        Returns
        -------
        UnimodalData
            A new object of the same class as ``unilist[0]`` (CITESeqData,
            CytoData, VDJData or UnimodalData). It holds:

            * the concatenated barcode metadata;
            * the outer-merged feature metadata;
            * the matrices built by ``self._merge_matrices``.

            Per-sample leftovers are stored in ``metadata["var_dict"]`` (extra
            feature columns) and ``metadata["uns_dict"]`` (remaining metadata)
            when non-empty.
        """
        if len(unilist) == 1:
            del unilist[0].metadata["_sample"]
            return unilist[0]

        modality = unilist[0].get_modality()

        # Stack barcode metadata of all inputs, then fill missing values per column.
        barcode_metadata_dfs = [
            unidata.barcode_metadata for unidata in unilist
        ]
        barcode_metadata = pd.concat(barcode_metadata_dfs,
                                     axis=0,
                                     sort=False,
                                     copy=False)
        fillna_dict = _get_fillna_dict(barcode_metadata)
        barcode_metadata.fillna(value=fillna_dict, inplace=True)

        # Remove modality-specific feature columns. Then move every remaining
        # column except "featureid" into var_dict (keyed by sample name), so
        # only "featureid" takes part in the feature merge below.
        var_dict = {}
        for unidata in unilist:
            if modality == "citeseq":
                unidata.feature_metadata.drop(
                    columns=CITESeqData._var_keywords, inplace=True)
            elif modality == "cyto":
                unidata.feature_metadata.drop(columns=CytoData._var_keywords,
                                              inplace=True)

            idx = unidata.feature_metadata.columns.difference(["featureid"])
            if idx.size > 0:
                var_dict[unidata.
                         metadata["_sample"]] = unidata.feature_metadata[idx]
                unidata.feature_metadata.drop(columns=idx, inplace=True)

        # Outer-join the feature tables on "featurekey" plus shared columns.
        # NOTE(review): "featurekey" cannot be a column after the drop above,
        # so it presumably is the index name (pandas merge accepts index level
        # names in `on`). Confirm.
        feature_metadata = unilist[0].feature_metadata
        for other in unilist[1:]:
            keys = ["featurekey"] + feature_metadata.columns.intersection(
                other.feature_metadata.columns).values.tolist()
            feature_metadata = feature_metadata.merge(
                other.feature_metadata,
                on=keys,
                how="outer",
                sort=False,
                copy=False
            )  # If sort is True, feature keys will be changed even if all channels share the same feature keys.
        fillna_dict = _get_fillna_dict(feature_metadata)
        feature_metadata.fillna(value=fillna_dict, inplace=True)

        matrices = self._merge_matrices(feature_metadata, unilist, modality)

        # Keep genome/modality once (asserting they agree). After removing
        # modality-specific keys and "_sample", whatever is left in each
        # input's metadata is stored under its sample name in uns_dict.
        uns_dict = {}
        metadata = {
            "genome": unilist[0].metadata["genome"],
            "modality": unilist[0].metadata["modality"]
        }
        for unidata in unilist:
            assert unidata.metadata.pop("genome") == metadata["genome"]
            assert unidata.metadata.pop("modality") == metadata["modality"]
            if modality == "citeseq":
                for key in CITESeqData._uns_keywords:
                    unidata.metadata.pop(key, None)
                del unidata.metadata["_obs_keys"]
            elif modality == "cyto":
                for key in CytoData._uns_keywords:
                    unidata.metadata.pop(key, None)
            sample_name = unidata.metadata.pop("_sample")
            if len(unidata.metadata) > 0:
                uns_dict[sample_name] = unidata.metadata.mapping

        if len(var_dict) > 0:
            metadata["var_dict"] = var_dict
        if len(uns_dict) > 0:
            metadata["uns_dict"] = uns_dict

        # Build the result with the same class as the first input. VDJ data
        # first gets its metadata/matrices updated by a dedicated helper.
        unidata = None
        if isinstance(unilist[0], CITESeqData):
            unidata = CITESeqData(barcode_metadata, feature_metadata, matrices,
                                  metadata)
        elif isinstance(unilist[0], CytoData):
            unidata = CytoData(barcode_metadata, feature_metadata, matrices,
                               metadata)
        elif isinstance(unilist[0], VDJData):
            self._vdj_update_metadata_matrices(metadata, matrices, unilist)
            unidata = VDJData(barcode_metadata, feature_metadata, matrices,
                              metadata)
        else:
            unidata = UnimodalData(barcode_metadata, feature_metadata,
                                   matrices, metadata)

        return unidata
Example #20
0
def infer_doublets(
    data: MultimodalData,
    channel_attr: Optional[str] = None,
    clust_attr: Optional[str] = None,
    min_cell: Optional[int] = 100,
    expected_doublet_rate: Optional[float] = None,
    sim_doublet_ratio: Optional[float] = 2.0,
    n_prin_comps: Optional[int] = 30,
    robust: Optional[bool] = False,
    k: Optional[int] = None,
    n_jobs: Optional[int] = -1,
    alpha: Optional[float] = 0.05,
    random_state: Optional[int] = 0,
    plot_hist: Optional[str] = "dbl",
) -> None:
    """Infer doublets using a Scrublet-like strategy. [Li20-2]_

    This function must be called after clustering.

    Parameters
    ----------
    data: ``pegasusio.MultimodalData``
        Annotated data matrix with rows for cells and columns for genes. Its current modality must be ``rna`` and it must contain the raw count matrix ``raw.X``.

    channel_attr: ``str``, optional, default: None
        Attribute indicating sample channels; the column must be categorical. If set, calculate scrublet-like doublet scores per channel.

    clust_attr: ``str``, optional, default: None
        Attribute indicating cluster labels. If set, estimate proportion of doublets in each cluster and statistical significance.

    min_cell: ``int``, optional, default: 100
        Minimum number of cells per sample to calculate doublet scores. For samples having less than 'min_cell' cells, doublet score calculation will be skipped.

    expected_doublet_rate: ``float``, optional, default: ``None``
        The expected doublet rate for the experiment. By default, calculate the expected rate based on number of cells from the 10x multiplet rate table

    sim_doublet_ratio: ``float``, optional, default: ``2.0``
        The ratio between synthetic doublets and observed cells.

    n_prin_comps: ``int``, optional, default: ``30``
        Number of principal components.

    robust: ``bool``, optional, default: ``False``.
        If true, use 'arpack' instead of 'randomized' for large matrices (i.e. max(X.shape) > 500 and n_components < 0.8 * min(X.shape))

    k: ``int``, optional, default: ``None``
        Number of observed cell neighbors. If None, k = round(0.5 * sqrt(number of observed cells)). Total neighbors k_adj = round(k * (1.0 + sim_doublet_ratio)).

    n_jobs: ``int``, optional, default: ``-1``
        Number of threads to use. If ``-1``, use all available threads.

    alpha: ``float``, optional, default: ``0.05``
        FDR significant level for cluster-level fisher exact test.

    random_state: ``int``, optional, default: ``0``
        Random seed for reproducing results.

    plot_hist: ``str``, optional, default: ``dbl``
        If not None, plot diagnostic histograms using ``plot_hist`` as the prefix. If `channel_attr` is None, ``plot_hist.png`` is generated; Otherwise, ``plot_hist.channel_name.png`` files are generated.

    Returns
    -------
    ``None``

    Update ``data.obs``:
        * ``data.obs['doublet_score']``: Doublet scores. Cells belonging to a sample (or, without ``channel_attr``, a data set) with fewer than ``min_cell`` cells get ``0``.

        * ``data.obs['pred_dbl']``: Boolean doublet predictions. Cells in skipped samples get ``False``.

    Update ``data.uns``:
        * ``data.uns['doublet_thresholds']``: Only generated if 'channel_attr' is not None. A dict mapping each processed channel to its doublet score threshold.

        * ``data.uns['pred_dbl_cluster']``: Only generated if 'clust_attr' is not None. Cluster-level Fisher exact test results (FDR level ``alpha``) as returned by ``_identify_doublets_fisher``.

    Without ``channel_attr``, ``_run_scrublet`` writes its per-cell results directly into ``data``.

    Raises
    ------
    ValueError
        If the raw count matrix ``raw.X`` cannot be found.

    Examples
    --------
    >>> pg.infer_doublets(data, channel_attr = 'Channel', clust_attr = 'Annotation')
    """
    assert data.get_modality() == "rna"
    try:
        rawX = data.get_matrix("raw.X")
    except ValueError as err:
        raise ValueError(
            "Cannot detect the raw count matrix raw.X; stop inferring doublets!"
        ) from err

    if_plot = plot_hist is not None

    if channel_attr is None:
        if data.shape[0] >= min_cell:
            # Score the whole data set at once; _run_scrublet stores its results on `data`.
            fig = _run_scrublet(data, expected_doublet_rate = expected_doublet_rate, sim_doublet_ratio = sim_doublet_ratio, \
                                n_prin_comps = n_prin_comps, robust = robust, k = k, n_jobs = n_jobs, random_state = random_state, \
                                plot_hist = if_plot)
            if if_plot:
                # NOTE(review): the figure is never closed here; confirm whether _run_scrublet/callers release it.
                fig.savefig(f"{plot_hist}.png")
        else:
            logger.warning(
                f"Data has {data.shape[0]} < {min_cell} cells and thus doublet score calculation is skipped!"
            )
            data.obs["doublet_score"] = 0.0
            data.obs["pred_dbl"] = False
    else:
        from pegasus.tools import identify_robust_genes, log_norm, highly_variable_features

        # `pandas.api.types.is_categorical_dtype` is deprecated (pandas >= 2.1); check the dtype class instead.
        assert isinstance(data.obs[channel_attr].dtype, pd.CategoricalDtype)
        genome = data.get_genome()
        modality = data.get_modality()
        channels = data.obs[channel_attr].cat.categories

        # Per-cell results are accumulated here; cells of skipped channels stay 0 / False.
        dbl_score = np.zeros(data.shape[0], dtype=np.float32)
        pred_dbl = np.zeros(data.shape[0], dtype=np.bool_)
        thresholds = {}
        for channel in channels:
            # Generate a new unidata object for the channel
            idx = np.where(data.obs[channel_attr] == channel)[0]
            if idx.size >= min_cell:
                unidata = UnimodalData({"barcodekey": data.obs_names[idx]},
                                       {"featurekey": data.var_names},
                                       {"X": rawX[idx]}, {
                                           "genome": genome,
                                           "modality": modality
                                       })
                # Identify robust genes, count and log normalized and select top 2,000 highly variable features
                identify_robust_genes(unidata)
                log_norm(unidata)
                highly_variable_features(unidata)
                # Run _run_scrublet
                fig = _run_scrublet(unidata, name = channel, expected_doublet_rate = expected_doublet_rate, sim_doublet_ratio = sim_doublet_ratio, \
                                    n_prin_comps = n_prin_comps, robust = robust, k = k, n_jobs = n_jobs, random_state = random_state, \
                                    plot_hist = if_plot)
                if if_plot:
                    # NOTE(review): one figure per channel is left open; with many channels this may accumulate.
                    fig.savefig(f"{plot_hist}.{channel}.png")

                dbl_score[idx] = unidata.obs["doublet_score"].values
                pred_dbl[idx] = unidata.obs["pred_dbl"].values
                thresholds[channel] = unidata.uns["doublet_threshold"]
            else:
                logger.warning(
                    f"Channel {channel} has {idx.size} < {min_cell} cells and thus doublet score calculation is skipped!"
                )

        data.obs["doublet_score"] = dbl_score
        data.obs["pred_dbl"] = pred_dbl
        data.uns["doublet_thresholds"] = thresholds

    if clust_attr is not None:
        data.uns["pred_dbl_cluster"] = _identify_doublets_fisher(
            data.obs[clust_attr].values,
            data.obs["pred_dbl"].values,
            alpha=alpha)

    logger.info('Doublets are predicted!')
def analyze_one_modality(unidata: UnimodalData, output_name: str, is_raw: bool,
                         append_data: UnimodalData, **kwargs) -> None:
    """Run the single-modality analysis pipeline on ``unidata``, driven by ``kwargs``.

    When ``is_raw`` is True, the data are log-normalized, highly variable features are
    optionally selected, PCA (and optionally NMF) is computed, optional batch correction
    (harmony / inmf / scanorama) is applied and a K-nearest-neighbor graph is built.
    The remaining steps (signature scores, diffusion map, kBET, clustering, embeddings,
    doublet inference, pseudotime) each run only when enabled in ``kwargs``.
    If ``append_data`` is given, its features with nonzero counts are appended to a
    combined object that is then written out (H5AD and/or loom).

    Parameters
    ----------
    unidata: ``UnimodalData``
        Data to analyze. Results are written into it. If ``append_data`` is merged, the
        combined object is local to this function, but it shares ``uns``/``obsm``/``varm``
        mappings with ``unidata`` by reference.
    output_name: ``str``
        Prefix for output files (``{output_name}.h5ad``, ``{output_name}.loom``, doublet
        histograms) and the output-name argument passed to the FLE tools.
    is_raw: ``bool``
        If True, run normalization, HVF selection, PCA/NMF, batch correction and neighbor search.
    append_data: ``UnimodalData`` or ``None``
        Optional extra features to append before writing outputs. These are treated as
        antibody counts when ``kwargs["citeseq"]`` is set.
    **kwargs
        Pipeline options; each step below reads its own keys (e.g. ``norm_count``,
        ``select_hvf``, ``pca_n``, ``batch_correction``, ``K``, ``output_h5ad``).

    Returns
    -------
    ``None``
    """
    print()
    logger.info(f"Begin to analyze UnimodalData {unidata.get_uid()}.")

    # NOTE(review): `dim_key` (and `standardize`, `n_pc`, `n_nmf`) are only bound inside this
    # branch; with is_raw=False any later step that passes `rep=dim_key` would raise NameError.
    # Confirm callers always use is_raw=True when such steps are enabled.
    if is_raw:
        # normalize counts and then transform to log space
        tools.log_norm(unidata, kwargs["norm_count"])

        # select highly variable features
        standardize = False  # if no select HVF, False
        if kwargs["select_hvf"]:
            if unidata.shape[1] <= kwargs["hvf_ngenes"]:
                logger.warning(
                    f"Number of genes {unidata.shape[1]} is no greater than the target number of highly variable features {kwargs['hvf_ngenes']}. HVF selection is omitted."
                )
            else:
                standardize = True
                tools.highly_variable_features(
                    unidata,
                    kwargs["batch_attr"]
                    if kwargs["batch_correction"] else None,
                    flavor=kwargs["hvf_flavor"],
                    n_top=kwargs["hvf_ngenes"],
                    n_jobs=kwargs["n_jobs"],
                )
                if kwargs["hvf_flavor"] == "pegasus":
                    if kwargs["plot_hvf"] is not None:
                        from pegasus.plotting import hvfplot
                        fig = hvfplot(unidata, return_fig=True)
                        fig.savefig(f"{kwargs['plot_hvf']}.hvf.pdf")

        # The number of PCs cannot exceed either matrix dimension.
        n_pc = min(kwargs["pca_n"], unidata.shape[0], unidata.shape[1])
        if n_pc < kwargs["pca_n"]:
            logger.warning(
                f"UnimodalData {unidata.get_uid()} has either dimension ({unidata.shape[0]}, {unidata.shape[1]}) less than the specified number of PCs {kwargs['pca_n']}. Reduce the number of PCs to {n_pc}."
            )

        # Run PCA irrespective of which batch correction method would apply
        tools.pca(
            unidata,
            n_components=n_pc,
            features="highly_variable_features",
            standardize=standardize,
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"],
        )
        # Representation used by downstream steps; replaced below if batch correction runs.
        dim_key = "pca"

        # n_nmf is only computed when NMF or iNMF will run; the uses below sit under the same conditions.
        if kwargs["nmf"] or (kwargs["batch_correction"]
                             and kwargs["correction_method"] == "inmf"):
            n_nmf = min(kwargs["nmf_n"], unidata.shape[0], unidata.shape[1])
            if n_nmf < kwargs["nmf_n"]:
                logger.warning(
                    f"UnimodalData {unidata.get_uid()} has either dimension ({unidata.shape[0]}, {unidata.shape[1]}) less than the specified number of NMF components {kwargs['nmf_n']}. Reduce the number of NMF components to {n_nmf}."
                )

        if kwargs["nmf"]:
            if kwargs["batch_correction"] and kwargs[
                    "correction_method"] == "inmf":
                logger.warning(
                    "NMF is skipped because integrative NMF is run instead.")
            else:
                tools.nmf(
                    unidata,
                    n_components=n_nmf,
                    features="highly_variable_features",
                    n_jobs=kwargs["n_jobs"],
                    random_state=kwargs["random_state"],
                )

        # Each correction method returns the key of its corrected representation.
        if kwargs["batch_correction"]:
            if kwargs["correction_method"] == "harmony":
                dim_key = tools.run_harmony(
                    unidata,
                    batch=kwargs["batch_attr"],
                    rep="pca",
                    n_jobs=kwargs["n_jobs"],
                    n_clusters=kwargs["harmony_nclusters"],
                    random_state=kwargs["random_state"])
            elif kwargs["correction_method"] == "inmf":
                dim_key = tools.integrative_nmf(
                    unidata,
                    batch=kwargs["batch_attr"],
                    n_components=n_nmf,
                    features="highly_variable_features",
                    lam=kwargs["inmf_lambda"],
                    n_jobs=kwargs["n_jobs"],
                    random_state=kwargs["random_state"])
            elif kwargs["correction_method"] == "scanorama":
                dim_key = tools.run_scanorama(
                    unidata,
                    batch=kwargs["batch_attr"],
                    n_components=n_pc,
                    features="highly_variable_features",
                    standardize=standardize,
                    random_state=kwargs["random_state"])
            else:
                raise ValueError(
                    f"Unknown batch correction method {kwargs['correction_method']}!"
                )

        # Find K neighbors
        tools.neighbors(
            unidata,
            K=kwargs["K"],
            rep=dim_key,
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"],
            full_speed=kwargs["full_speed"],
        )

    # `calc_sigscore` is a comma-separated list of signature sources.
    if kwargs["calc_sigscore"] is not None:
        sig_files = kwargs["calc_sigscore"].split(",")
        for sig_file in sig_files:
            tools.calc_signature_score(unidata, sig_file)

    # calculate diffmap
    # FLE / net-FLE options force the diffusion map on.
    if (kwargs["fle"] or kwargs["net_fle"]):
        if not kwargs["diffmap"]:
            print("Turn on --diffmap option!")
        kwargs["diffmap"] = True

    if kwargs["diffmap"]:
        tools.diffmap(
            unidata,
            n_components=kwargs["diffmap_ndc"],
            rep=dim_key,
            solver=kwargs["diffmap_solver"],
            max_t=kwargs["diffmap_maxt"],
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"],
        )

    # calculate kBET
    if ("kBET" in kwargs) and kwargs["kBET"]:
        stat_mean, pvalue_mean, accept_rate = tools.calc_kBET(
            unidata,
            kwargs["kBET_batch"],
            rep=dim_key,
            K=kwargs["kBET_K"],
            alpha=kwargs["kBET_alpha"],
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"])
        print(
            "kBET stat_mean = {:.2f}, pvalue_mean = {:.4f}, accept_rate = {:.2%}."
            .format(stat_mean, pvalue_mean, accept_rate))

    # clustering
    if kwargs["spectral_louvain"]:
        tools.cluster(
            unidata,
            algo="spectral_louvain",
            rep=dim_key,
            resolution=kwargs["spectral_louvain_resolution"],
            rep_kmeans=kwargs["spectral_louvain_basis"],
            n_clusters=kwargs["spectral_louvain_nclusters"],
            n_clusters2=kwargs["spectral_louvain_nclusters2"],
            n_init=kwargs["spectral_louvain_ninit"],
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"],
            class_label="spectral_louvain_labels",
        )

    if kwargs["spectral_leiden"]:
        tools.cluster(
            unidata,
            algo="spectral_leiden",
            rep=dim_key,
            resolution=kwargs["spectral_leiden_resolution"],
            rep_kmeans=kwargs["spectral_leiden_basis"],
            n_clusters=kwargs["spectral_leiden_nclusters"],
            n_clusters2=kwargs["spectral_leiden_nclusters2"],
            n_init=kwargs["spectral_leiden_ninit"],
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"],
            class_label="spectral_leiden_labels",
        )

    if kwargs["louvain"]:
        tools.cluster(
            unidata,
            algo="louvain",
            rep=dim_key,
            resolution=kwargs["louvain_resolution"],
            random_state=kwargs["random_state"],
            class_label=kwargs["louvain_class_label"],
        )

    if kwargs["leiden"]:
        tools.cluster(
            unidata,
            algo="leiden",
            rep=dim_key,
            resolution=kwargs["leiden_resolution"],
            n_iter=kwargs["leiden_niter"],
            random_state=kwargs["random_state"],
            class_label=kwargs["leiden_class_label"],
        )

    # visualization
    # NOTE(review): the key "net_umap_polish_learing_rate" is spelled this way by the caller; do not "fix" it here alone.
    if kwargs["net_umap"]:
        tools.net_umap(
            unidata,
            rep=dim_key,
            n_jobs=kwargs["n_jobs"],
            n_neighbors=kwargs["umap_K"],
            min_dist=kwargs["umap_min_dist"],
            spread=kwargs["umap_spread"],
            random_state=kwargs["random_state"],
            select_frac=kwargs["net_ds_frac"],
            select_K=kwargs["net_ds_K"],
            select_alpha=kwargs["net_ds_alpha"],
            full_speed=kwargs["full_speed"],
            net_alpha=kwargs["net_l2"],
            polish_learning_rate=kwargs["net_umap_polish_learing_rate"],
            polish_n_epochs=kwargs["net_umap_polish_nepochs"],
            out_basis=kwargs["net_umap_basis"],
        )

    if kwargs["net_fle"]:
        tools.net_fle(
            unidata,
            output_name,
            n_jobs=kwargs["n_jobs"],
            K=kwargs["fle_K"],
            full_speed=kwargs["full_speed"],
            target_change_per_node=kwargs["fle_target_change_per_node"],
            target_steps=kwargs["fle_target_steps"],
            is3d=False,
            memory=kwargs["fle_memory"],
            random_state=kwargs["random_state"],
            select_frac=kwargs["net_ds_frac"],
            select_K=kwargs["net_ds_K"],
            select_alpha=kwargs["net_ds_alpha"],
            net_alpha=kwargs["net_l2"],
            polish_target_steps=kwargs["net_fle_polish_target_steps"],
            out_basis=kwargs["net_fle_basis"],
        )

    if kwargs["tsne"]:
        tools.tsne(
            unidata,
            rep=dim_key,
            n_jobs=kwargs["n_jobs"],
            perplexity=kwargs["tsne_perplexity"],
            random_state=kwargs["random_state"],
            initialization=kwargs["tsne_init"],
        )

    if kwargs["umap"]:
        tools.umap(
            unidata,
            rep=dim_key,
            n_neighbors=kwargs["umap_K"],
            min_dist=kwargs["umap_min_dist"],
            spread=kwargs["umap_spread"],
            n_jobs=kwargs["n_jobs"],
            full_speed=kwargs["full_speed"],
            random_state=kwargs["random_state"],
        )

    if kwargs["fle"]:
        tools.fle(
            unidata,
            output_name,
            n_jobs=kwargs["n_jobs"],
            K=kwargs["fle_K"],
            full_speed=kwargs["full_speed"],
            target_change_per_node=kwargs["fle_target_change_per_node"],
            target_steps=kwargs["fle_target_steps"],
            is3d=False,
            memory=kwargs["fle_memory"],
            random_state=kwargs["random_state"],
        )

    if kwargs["infer_doublets"]:
        # Score per channel only when a "Channel" column exists with more than one category.
        channel_attr = "Channel"
        if (channel_attr not in unidata.obs) or (
                unidata.obs["Channel"].cat.categories.size == 1):
            channel_attr = None
        # Fall back to the first available clustering label if the requested one is absent.
        clust_attr = kwargs["dbl_cluster_attr"]
        if (clust_attr is None) or (clust_attr not in unidata.obs):
            clust_attr = None
            for value in [
                    "leiden_labels", "louvain_labels",
                    "spectral_leiden_labels", "spectral_louvain_labels"
            ]:
                if value in unidata.obs:
                    clust_attr = value
                    break

        if channel_attr is not None:
            logger.info(f"For doublet inference, channel_attr={channel_attr}.")
        if clust_attr is not None:
            logger.info(f"For doublet inference, clust_attr={clust_attr}.")

        tools.infer_doublets(
            unidata,
            channel_attr=channel_attr,
            clust_attr=clust_attr,
            expected_doublet_rate=kwargs["expected_doublet_rate"],
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"],
            plot_hist=output_name)

        # Clusters with at least 50% predicted doublets are marked as doublet clusters,
        # encoded as "<clust_attr>:<c1>,<c2>,..." for mark_doublets.
        dbl_clusts = None
        if clust_attr is not None:
            clusts = []
            for idx, row in unidata.uns["pred_dbl_cluster"].iterrows():
                if row["percentage"] >= 50.0:
                    logger.info(
                        f"Cluster {row['cluster']} (percentage={row['percentage']:.2f}%, q-value={row['qval']:.6g}) is identified as a doublet cluster."
                    )
                    clusts.append(row["cluster"])
            if len(clusts) > 0:
                dbl_clusts = f"{clust_attr}:{','.join(clusts)}"

        tools.mark_doublets(unidata, dbl_clusts=dbl_clusts)

    # calculate diffusion-based pseudotime from roots
    if len(kwargs["pseudotime"]) > 0:
        tools.calc_pseudotime(unidata, kwargs["pseudotime"])

    # Remember the genome so it can be restored after a possible append below.
    genome = unidata.uns["genome"]

    if append_data is not None:
        # Map each append_data cell to its row in unidata (-1 if absent) and drop absent cells,
        # producing a matrix aligned to unidata's cell order (missing cells are all-zero rows).
        locs = unidata.obs_names.get_indexer(append_data.obs_names)
        idx = locs >= 0
        locs = locs[idx]
        Y = append_data.X[idx, :].tocoo(copy=False)
        Z = coo_matrix((Y.data, (locs[Y.row], Y.col)),
                       shape=(unidata.shape[0], append_data.shape[1])).tocsr()

        # Keep only appended features with at least one nonzero entry.
        idy = Z.getnnz(axis=0) > 0
        n_nonzero = idy.sum()
        if n_nonzero > 0:
            if n_nonzero < append_data.shape[1]:
                Z = Z[:, idy]
                append_df = append_data.feature_metadata.loc[idy, :]
            else:
                append_df = append_data.feature_metadata

            # Prefix antibody names with "Ab-" to keep them distinct from gene names.
            if kwargs["citeseq"]:
                append_df = append_df.copy()
                append_df.index = append_df.index.map(lambda x: f"Ab-{x}")

            rawX = hstack([unidata.get_matrix("counts"), Z], format="csr")

            # Transform appended counts: log1p of counts multiplied by obs["scale"] (presumably the
            # per-cell normalization factor — confirm), or arcsinh(x / 5) for CITE-seq.
            Zt = Z.astype(np.float32)
            if not kwargs["citeseq"]:
                Zt.data *= np.repeat(unidata.obs["scale"].values,
                                     np.diff(Zt.indptr))
                Zt.data = np.log1p(Zt.data)
            else:
                Zt.data = np.arcsinh(Zt.data / 5.0, dtype=np.float32)

            X = hstack([unidata.get_matrix(unidata.current_matrix()), Zt],
                       format="csr")

            new_genome = unidata.get_genome()
            if new_genome != append_data.get_genome():
                new_genome = f"{new_genome}_and_{append_data.get_genome()}"

            feature_metadata = pd.concat([unidata.feature_metadata, append_df],
                                         axis=0)
            feature_metadata.reset_index(inplace=True)
            _fillna(feature_metadata)
            # Rebinds the local name only; the caller's object is not replaced.
            unidata = UnimodalData(
                unidata.barcode_metadata, feature_metadata, {
                    unidata.current_matrix(): X,
                    "counts": rawX
                }, unidata.uns.mapping, unidata.obsm.mapping,
                unidata.varm.mapping
            )  # uns.mapping, obsm.mapping and varm.mapping are passed by reference
            unidata.uns["genome"] = new_genome

            # Optional UMAP on the appended antibody features only (minus excluded ones).
            if kwargs["citeseq"] and kwargs["citeseq_umap"]:
                umap_index = append_df.index.difference(
                    [f"Ab-{x}" for x in kwargs["citeseq_umap_exclude"]])
                unidata.obsm["X_citeseq"] = unidata.X[:,
                                                      unidata.var_names.
                                                      isin(umap_index
                                                           )].toarray()
                tools.umap(
                    unidata,
                    rep="citeseq",
                    n_neighbors=kwargs["umap_K"],
                    min_dist=kwargs["umap_min_dist"],
                    spread=kwargs["umap_spread"],
                    n_jobs=kwargs["n_jobs"],
                    full_speed=kwargs["full_speed"],
                    random_state=kwargs["random_state"],
                    out_basis="citeseq_umap",
                )

    if kwargs["output_h5ad"]:
        import time
        start_time = time.perf_counter()
        adata = unidata.to_anndata()
        # Expose the temporary HVF matrix as "scale.data" with matching row names.
        if "_tmp_fmat_highly_variable_features" in adata.uns:
            adata.uns["scale.data"] = adata.uns.pop(
                "_tmp_fmat_highly_variable_features")  # assign by reference
            adata.uns["scale.data.rownames"] = unidata.var_names[
                unidata.var["highly_variable_features"] == True].values
        adata.write(f"{output_name}.h5ad", compression="gzip")
        del adata
        end_time = time.perf_counter()
        logger.info(
            f"H5AD file {output_name}.h5ad is written. Time spent = {end_time - start_time:.2f}s."
        )

    # write out results
    if kwargs["output_loom"]:
        write_output(unidata, f"{output_name}.loom")

    # Change genome name back if append_data is True
    # (uns is shared by reference with the caller's object, so this restores it there too.)
    if unidata.uns["genome"] != genome:
        unidata.uns["genome"] = genome
    # Eliminate objects starting with _tmp from uns
    unidata.uns.pop("_tmp_fmat_highly_variable_features", None)
Example #22
0
def load_10x_h5_file_v3(h5_in: h5py.Group) -> MultimodalData:
    """Load a 10x v3 format matrix from an hdf5 file, splitting it by genome and feature type.

    Every (genome, feature_type) combination becomes its own UnimodalData, which lets gene
    expression, CRISPR guide capture and antibody capture (citeseq) libraries be detected
    separately. Features with an empty genome inherit the file's genome when exactly one
    non-empty genome is present.

    Parameters
    ----------

    h5_in : h5py.Group
        An instance of h5py.Group class that is connected to a 10x v3 formatted hdf5 file.

    Returns
    -------

    A MultimodalData object holding one UnimodalData per (genome, feature type) pair.
    Antibody Capture groups are stored as CITESeqData with the matrix under ``raw.count``.
    All other groups are stored as UnimodalData with the matrix under ``X``.

    Examples
    --------
    >>> io.load_10x_h5_file_v3(h5_in)
    """
    n_features, n_barcodes = h5_in["matrix/shape"][...]
    # Reading the stored (data, indices, indptr) triplet as CSR with swapped shape yields a
    # barcode-by-feature matrix (rows = barcodes, columns = features).
    count_matrix = csr_matrix(
        (
            h5_in["matrix/data"][...],
            h5_in["matrix/indices"][...],
            h5_in["matrix/indptr"][...],
        ),
        shape=(n_barcodes, n_features),
    )
    barcodes = h5_in["matrix/barcodes"][...].astype(str)
    features = pd.DataFrame(
        data={
            column: h5_in[f"matrix/features/{column}"][...].astype(str)
            for column in ("genome", "feature_type", "id", "name")
        })

    # A lone non-empty genome is used for features whose genome field is blank.
    named_genomes = [g for g in features["genome"].unique() if g != ""]
    fallback_genome = named_genomes[0] if len(named_genomes) == 1 else None

    feature_type_to_modality = {
        "Gene Expression": "rna",
        "CRISPR Guide Capture": "crispr",
        "Antibody Capture": "citeseq",
    }

    data = MultimodalData()
    for (genome_label, feature_type), group in features.groupby(
            by=["genome", "feature_type"]):
        use_fallback = genome_label == "" and fallback_genome is not None
        genome = fallback_genome if use_fallback else genome_label
        modality = feature_type_to_modality.get(feature_type, "custom")

        barcode_metadata = {"barcodekey": barcodes}
        feature_metadata = {
            "featurekey": group["name"].values,
            "featureid": group["id"].values
        }
        # group.index holds this group's row labels, i.e. its feature column positions.
        mat = count_matrix[:, group.index]
        uns_info = {"genome": genome, "modality": modality}

        if modality != "citeseq":
            unidata = UnimodalData(barcode_metadata, feature_metadata,
                                   {"X": mat}, uns_info)
        else:
            unidata = CITESeqData(barcode_metadata, feature_metadata,
                                  {"raw.count": mat}, uns_info)
        unidata.separate_channels()

        data.add_data(unidata)

    return data
Example #23
0
def load_loom_file(input_loom: str,
                   genome: str = None,
                   modality: str = None) -> MultimodalData:
    """Load a count matrix, its attributes and its layers from a LOOM file.

    Column attributes become barcode metadata and row attributes become feature metadata.
    1-D attributes become metadata columns and attributes with more dimensions become
    multi-arrays. The default layer is stored as ``X``.

    Parameters
    ----------

    input_loom : `str`
        The LOOM file, containing the count matrix.
    genome : `str`, optional (default None)
        The genome reference. If None, use "unknown" instead. If not None and input loom contains genome attribute, the attribute will be overwritten.
    modality: `str`, optional (default None)
        Modality. If None, use "rna" instead. If not None and input loom contains modality attribute, the attribute will be overwritten.

    Returns
    -------

    A MultimodalData object containing a (genome, UnimodalData) pair.

    Examples
    --------
    >>> io.load_loom_file('example.loom', genome = 'GRCh38')
    """
    column_renames = {"CellID": "barcodekey", "obs_names": "barcodekey"}
    row_renames = {
        "Gene": "featurekey",
        "var_names": "featurekey",
        "Accession": "featureid",
        "gene_ids": "featureid"
    }

    def _split_by_rank(attrs, renames, axis_label):
        # Partition loom attributes into 1-D metadata columns and N-D (N > 1) arrays.
        flat, multi = {}, {}
        for raw_key, arr in attrs.items():
            key = renames.get(raw_key, raw_key)
            if arr.ndim == 1:
                flat[key] = arr
            elif arr.ndim > 1:
                multi[key] = arr
            else:
                raise ValueError(
                    f"Detected {axis_label} attribute '{key}' has ndim = {arr.ndim}!"
                )
        return flat, multi

    import loompy
    with loompy.connect(input_loom) as ds:
        barcode_metadata, barcode_multiarrays = _split_by_rank(
            ds.col_attrs, column_renames, "column")
        feature_metadata, feature_multiarrays = _split_by_rank(
            ds.row_attrs, row_renames, "row")

        # Loom layers are feature-by-cell; transpose to cell-by-feature CSR.
        matrices = {
            ("X" if layer_name == "" else layer_name): layer.sparse().T.tocsr()
            for layer_name, layer in ds.layers.items()
        }

        metadata = dict(ds.attrs)
        if genome is not None:
            metadata["genome"] = genome
        else:
            metadata.setdefault("genome", "unknown")

        if modality is not None:
            metadata["modality"] = modality
        elif "modality" not in metadata:
            # Older files may record the modality as "experiment_type".
            legacy_type = metadata.get("experiment_type", "none")
            metadata["modality"] = (metadata.pop("experiment_type")
                                    if legacy_type in modalities else "rna")

        unidata = UnimodalData(barcode_metadata, feature_metadata, matrices,
                               metadata, barcode_multiarrays,
                               feature_multiarrays)

    return MultimodalData(unidata)
Example #24
0
def demultiplex(
    rna_data: UnimodalData,
    hashing_data: UnimodalData,
    min_signal: float = 10.0,
    alpha: float = 0.0,
    alpha_noise: float = 1.0,
    tol: float = 1e-6,
    n_threads: int = 1,
):
    """Demultiplexing cell/nucleus-hashing data, using the estimated antibody background probability calculated in ``demuxEM.estimate_background_probs``.

    Parameters
    ----------
    rna_data: ``UnimodalData``
        Data matrix for gene expression matrix.

    hashing_data: ``UnimodalData``
        Data matrix for HTO count matrix. Must contain ``uns["background_probs"]`` with strictly positive entries.

    min_signal: ``float``, optional, default: ``10.0``
        Any cell/nucleus with less than ``min_signal`` hashtags from the signal will be marked as ``unknown``.

    alpha: ``float``, optional, default: ``0.0``
        The Dirichlet prior concentration parameter (alpha) on samples. An alpha value < 1.0 will make the prior sparse.

    alpha_noise: ``float``, optional, default: ``1.0``
        The Dirichlet prior concentration parameter on the background noise.

    tol: ``float``, optional, default: ``1e-6``
        Threshold used for the EM convergence.

    n_threads: ``int``, optional, default: ``1``
        Number of threads to use. Must be a positive integer.

    Returns
    -------
    ``None``

    Update ``rna_data``:
        * ``rna_data.uns["background_probs"]``: Copied from ``hashing_data.uns["background_probs"]``.
        * ``rna_data.obsm["raw_probs"]``: ``(n_cells, nsample + 1)`` array of per-cell probabilities. With two or more barcodes the last column is the background, and cells absent from ``hashing_data`` keep probability 1.0 there. With a single barcode every cell gets 1.0 in column 0.
        * ``rna_data.obs["demux_type"]``: Demultiplexed types of the cells. Either ``singlet``, ``doublet``, or ``unknown``.
        * ``rna_data.obs["assignment"]``: Assigned samples of origin for each cell barcode.
        * ``rna_data.obs["assignment.dedup"]``: Only exist if one sample name can correspond to multiple feature barcodes. In this case, each feature barcode is assigned a unique sample name.

    Update ``hashing_data.obs["rna_type"]``: ``"signal"`` for barcodes also present in ``rna_data``, ``"background"`` otherwise.

    Examples
    --------
    >>> demuxEM.demultiplex(rna_data, hashing_data)
    """
    nsample = hashing_data.shape[1]
    rna_data.uns["background_probs"] = hashing_data.uns["background_probs"]
    assert (rna_data.uns["background_probs"] <= 0.0).sum() == 0

    # Label hashing barcodes by whether they also appear in the RNA data.
    idx_df = rna_data.obs_names.isin(hashing_data.obs_names)
    hashing_data.obs["rna_type"] = "background"
    hashing_data.obs.loc[rna_data.obs_names[idx_df], "rna_type"] = "signal"

    if nsample == 1:
        logger.warning("Detected only one barcode, no need to demultiplex!")
        rna_data.obsm["raw_probs"] = np.zeros((rna_data.shape[0], nsample + 1))
        rna_data.obsm["raw_probs"][:, 0] = 1.0
        rna_data.obsm["raw_probs"][:, 1] = 0.0
        rna_data.obs["demux_type"] = "singlet"
        rna_data.obs["assignment"] = hashing_data.var_names[0]
    else:
        if nsample == 2:
            logger.warning(
                "Detected only two barcodes, demultiplexing accuracy might be affected!"
            )

        ncalc = idx_df.sum()
        if ncalc < rna_data.shape[0]:
            nzero = rna_data.shape[0] - ncalc
            logger.warning(
                "Warning: {} cells do not have ADTs, percentage = {:.2f}%.".
                format(nzero, nzero * 100.0 / rna_data.shape[0]))

        # Every cell starts fully on the background column; cells with hashtag data are overwritten below.
        rna_data.obsm["raw_probs"] = np.zeros((rna_data.shape[0], nsample + 1))
        rna_data.obsm["raw_probs"][:, nsample] = 1.0

        # With no overlapping cells, starmap returns [], which cannot be broadcast into the
        # (0, nsample + 1) slice; skip the EM (and the process pool) entirely in that case.
        if ncalc > 0:
            hto_small = hashing_data[rna_data.obs_names[idx_df], ].X.toarray()
            iter_array = [(hto_small[i], hashing_data.uns["background_probs"],
                           alpha, alpha_noise, tol) for i in range(ncalc)]
            with multiprocessing.Pool(n_threads) as pool:
                # Rows come back in rna_data order, matching the boolean mask positions.
                rna_data.obsm["raw_probs"][idx_df, :] = pool.starmap(
                    estimate_probs, iter_array)

        calc_demux(rna_data, hashing_data, nsample, min_signal)

        if has_duplicate_names(hashing_data.var_names):
            rna_data.obs["assignment.dedup"] = rna_data.obs["assignment"]
            rna_data.obs["assignment"] = remove_suffix(
                rna_data.obs["assignment"].values)

    logger.info("Demultiplexing is done.")
def deseq2(
    pseudobulk: UnimodalData,
    design: str,
    contrast: Tuple[str, str, str],
    de_key: str = "deseq2",
    replaceOutliers: bool = True,
) -> None:
    """Perform Differential Expression (DE) Analysis using DESeq2 on pseudobulk data. This function calls the R package DESeq2, which must be installed in R.

    DE analysis will be performed on all pseudo-bulk matrices in pseudobulk.

    Parameters
    ----------
    pseudobulk: ``UnimodalData``
        Pseudobulk data with rows for samples and columns for genes. If pseudobulk contains multiple matrices, DESeq2 will apply to all matrices.

    design: ``str``
        Design formula that will be passed to DESeq2

    contrast: ``Tuple[str, str, str]``
        A tuple of three elements passing to DESeq2: a factor in design formula, a level in the factor as numerator of fold change, and a level as denominator of fold change.

    de_key: ``str``, optional, default: ``"deseq2"``
        Key name of DE analysis results stored. For cluster.X, stored key will be cluster.de_key

    replaceOutliers: ``bool``, optional, default: ``True``
        Whether to execute DESeq2's replaceOutliers step. If set to ``False``, we will set minReplicatesForReplace=Inf in ``DESeq`` function and set cooksCutoff=False in ``results`` function.

    Returns
    -------
    ``None``

    Update ``pseudobulk.varm``:
        ``pseudobulk.varm[de_key]``: DE analysis result for pseudo-bulk count matrix.
        ``pseudobulk.varm[cluster.de_key]``: DE results for cluster-specific pseudo-bulk count matrices.

        Missing values in the DESeq2 result are filled before storing: ``log2FoldChange``, ``lfcSE`` and ``stat`` with 0.0, ``pvalue`` and ``padj`` with 1.0.

    Notes
    -----
    If rpy2 or DESeq2 is unavailable, an error is logged and the interpreter exits via ``sys.exit(-1)``.

    Examples
    --------
    >>> pg.deseq2(pseudobulk, '~gender', ('gender', 'female', 'male'))
    """
    try:
        import rpy2.robjects as ro
        from rpy2.robjects import pandas2ri, numpy2ri, Formula
        from rpy2.robjects.packages import importr
        from rpy2.robjects.conversion import localconverter
    except ModuleNotFoundError as e:
        import sys
        logger.error(f"{e}\nNeed rpy2! Try 'pip install rpy2'.")
        sys.exit(-1)

    try:
        deseq2 = importr('DESeq2')
    except ImportError:
        # importr reports a missing R package with rpy2's PackageNotInstalledError,
        # which derives from ImportError (through LibraryError) but not from
        # ModuleNotFoundError; catching ImportError covers both cases.
        import sys
        text = """Please install DESeq2 in order to run this function.\n
                To install this package, start R and enter:\n
                if (!require("BiocManager", quietly = TRUE))
                    install.packages("BiocManager")
                BiocManager::install("DESeq2")"""

        logger.error(text)
        sys.exit(-1)

    import math
    # R helper that coerces a DESeqResults object into a plain data.frame for pandas conversion.
    to_dataframe = ro.r('function(x) data.frame(x)')

    for mat_key in pseudobulk.list_keys():
        # DESeq2 expects genes as rows and samples as columns, hence the transpose.
        with localconverter(ro.default_converter + numpy2ri.converter +
                            pandas2ri.converter):
            dds = deseq2.DESeqDataSetFromMatrix(
                countData=pseudobulk.get_matrix(mat_key).T,
                colData=pseudobulk.obs,
                design=Formula(design))

        if replaceOutliers:
            dds = deseq2.DESeq(dds)
            res = deseq2.results(dds, contrast=ro.StrVector(contrast))
        else:
            # Disable outlier replacement and Cook's-distance p-value filtering.
            dds = deseq2.DESeq(dds, minReplicatesForReplace=math.inf)
            res = deseq2.results(dds,
                                 contrast=ro.StrVector(contrast),
                                 cooksCutoff=False)
        with localconverter(ro.default_converter + pandas2ri.converter):
            res_df = ro.conversion.rpy2py(to_dataframe(res))
            # Replace NA statistics with neutral values (no change, non-significant).
            res_df.fillna(
                {
                    'log2FoldChange': 0.0,
                    'lfcSE': 0.0,
                    'stat': 0.0,
                    'pvalue': 1.0,
                    'padj': 1.0
                },
                inplace=True)

        # "X" -> de_key; "cluster.X" -> "cluster.<de_key>".
        de_res_key = de_key if mat_key.find(
            '.') < 0 else f"{mat_key.partition('.')[0]}.{de_key}"
        pseudobulk.varm[de_res_key] = res_df.to_records(index=False)
Example #26
0
def infer_doublets(
    data: MultimodalData,
    channel_attr: Optional[str] = None,
    clust_attr: Optional[str] = None,
    raw_mat_key: Optional[str] = 'counts',
    min_cell: Optional[int] = 100,
    expected_doublet_rate: Optional[float] = None,
    sim_doublet_ratio: Optional[float] = 2.0,
    n_prin_comps: Optional[int] = 30,
    k: Optional[int] = None,
    n_jobs: Optional[int] = -1,
    alpha: Optional[float] = 0.05,
    random_state: Optional[int] = 0,
    plot_hist: Optional[str] = "sample",
    manual_correction: Optional[str] = None,
) -> None:
    """Infer doublets by first calculating Scrublet-like [Wolock18]_ doublet scores and then smartly determining an appropriate doublet score cutoff [Li20-2]_ .

    This function should be called after clustering if clust_attr is not None. In this case, we will test if each cluster is significantly enriched for doublets using Fisher's exact test.

    Parameters
    ----------
    data: ``pegasusio.MultimodalData``
        Annotated data matrix with rows for cells and columns for genes.

    channel_attr: ``str``, optional, default: None
        Attribute indicating sample channels. If set, calculate scrublet-like doublet scores per channel. The column must be categorical.

    clust_attr: ``str``, optional, default: None
        Attribute indicating cluster labels. If set, estimate proportion of doublets in each cluster and statistical significance.

    raw_mat_key: ``str``, optional, default: ``"counts"``
        Key of the raw count matrix in ``data``. A ``ValueError`` is raised if it cannot be retrieved.

    min_cell: ``int``, optional, default: 100
        Minimum number of cells per sample to calculate doublet scores. For samples having less than 'min_cell' cells, doublet score calculation will be skipped (their score is 0 and they are predicted as singlets).

    expected_doublet_rate: ``float``, optional, default: ``None``
        The expected doublet rate for the experiment. By default, calculate the expected rate based on number of cells from the 10x multiplet rate table

    sim_doublet_ratio: ``float``, optional, default: ``2.0``
        The ratio between synthetic doublets and observed cells.

    n_prin_comps: ``int``, optional, default: ``30``
        Number of principal components.

    k: ``int``, optional, default: ``None``
        Number of observed cell neighbors. If None, k = round(0.5 * sqrt(number of observed cells)). Total neighbors k_adj = round(k * (1.0 + sim_doublet_ratio)).

    n_jobs: ``int``, optional, default: ``-1``
        Number of threads to use. If ``-1``, use all physical CPU cores.

    alpha: ``float``, optional, default: ``0.05``
        FDR significant level for cluster-level fisher exact test.

    random_state: ``int``, optional, default: ``0``
        Random seed for reproducing results.

    plot_hist: ``str``, optional, default: ``sample``
        If not None, plot diagnostic histograms using ``plot_hist`` as the prefix. If `channel_attr` is None, ``plot_hist.dbl.png`` is generated; Otherwise, ``plot_hist.channel_name.dbl.png`` files are generated. Each figure consists of 4 panels showing histograms of doublet scores for observed cells (panel 1, density in log scale), simulated doublets (panel 2, density in log scale), KDE plot (panel 3) and signed curvature plot (panel 4) of log doublet scores for simulated doublets. Each plot contains two dashed lines. The red dashed line represents the theoretical cutoff (calculated based on number of cells and 10x doublet table) and the black dashed line represents the cutoff inferred from the data.

    manual_correction: ``str``, optional, default: ``None``
        Use human guide to correct doublet threshold for certain channels. This is a string representing a comma-separated list. Each item in the list represents one sample; the sample name and correction guide are separated using ':'. The only correction guide supported is 'peak', which means cut at the center of the peak. If only one sample is available, use '' as the sample name.

    Returns
    -------
    ``None``

    Update ``data.obs``:
        * ``data.obs['doublet_score']``: Scrublet-like doublet scores.

        * ``data.obs['pred_dbl']``: Predicted singlet/doublet types.

    Update ``data.uns``:
        * ``data.uns['doublet_thresholds']``: Only generated if 'channel_attr' is not None. Dictionary mapping each scored channel to its doublet score threshold.

        * ``data.uns['pred_dbl_cluster']``: Only generated if 'clust_attr' is not None. This is a dataframe with two columns, 'Cluster' and 'Qval'. Only clusters with significantly more doublets than expected will be recorded here.

    Examples
    --------
    >>> pg.infer_doublets(data, channel_attr = 'Channel', clust_attr = 'Annotation')
    """
    assert data.get_modality() == "rna"
    try:
        rawX = data.get_matrix(raw_mat_key)
    except ValueError as err:
        raise ValueError(
            f"Cannot detect the raw count matrix {raw_mat_key}; stop inferring doublets!"
        ) from err

    if_plot = plot_hist is not None

    # Parse "name1:action1,name2:action2" into {name: action}.
    mancor = {}
    if manual_correction is not None:
        for item in manual_correction.split(','):
            try:
                name, action = item.split(':')
            except ValueError:
                raise ValueError(
                    f"Cannot parse manual_correction item '{item}'; expected the format 'sample_name:action'!"
                ) from None
            mancor[name] = action

    if channel_attr is None:
        if data.shape[0] >= min_cell:
            fig = _run_scrublet(data, raw_mat_key, expected_doublet_rate=expected_doublet_rate, sim_doublet_ratio=sim_doublet_ratio,
                                n_prin_comps=n_prin_comps, k=k, n_jobs=n_jobs, random_state=random_state, plot_hist=if_plot,
                                manual_correction=mancor.get('', None))
            if if_plot:
                fig.savefig(f"{plot_hist}.dbl.png")
        else:
            logger.warning(
                f"Data has {data.shape[0]} < {min_cell} cells and thus doublet score calculation is skipped!"
            )
            data.obs["doublet_score"] = 0.0
            data.obs["pred_dbl"] = False
    else:
        from pandas.api.types import CategoricalDtype
        from pegasus.tools import identify_robust_genes, log_norm, highly_variable_features

        channel_col = data.obs[channel_attr]
        # pandas' is_categorical_dtype is deprecated (pandas >= 2.1); isinstance on the dtype is the supported check.
        assert isinstance(channel_col.dtype, CategoricalDtype), \
            f"data.obs['{channel_attr}'] must be categorical!"
        genome = data.get_genome()
        modality = data.get_modality()
        channels = channel_col.cat.categories
        # Integer codes align with the category order, so code i selects cells of channels[i];
        # computed once instead of re-comparing the whole column for every channel.
        channel_codes = channel_col.cat.codes.values

        # Channels that are skipped keep a score of 0 and are predicted as singlets.
        dbl_score = np.zeros(data.shape[0], dtype=np.float32)
        pred_dbl = np.zeros(data.shape[0], dtype=np.bool_)
        thresholds = {}
        for code, channel in enumerate(channels):
            # Generate a new unidata object for the channel
            idx = np.where(channel_codes == code)[0]
            if idx.size >= min_cell:
                unidata = UnimodalData({"barcodekey": data.obs_names[idx]},
                                       {"featurekey": data.var_names},
                                       {"counts": rawX[idx]}, {
                                           "genome": genome,
                                           "modality": modality
                                       },
                                       cur_matrix="counts")
                # Identify robust genes, count and log normalized and select top 2,000 highly variable features
                identify_robust_genes(unidata)
                log_norm(unidata)
                highly_variable_features(unidata)
                # Run _run_scrublet
                fig = _run_scrublet(unidata, raw_mat_key, name=channel, expected_doublet_rate=expected_doublet_rate, sim_doublet_ratio=sim_doublet_ratio,
                                    n_prin_comps=n_prin_comps, k=k, n_jobs=n_jobs, random_state=random_state, plot_hist=if_plot,
                                    manual_correction=mancor.get(channel, None))
                if if_plot:
                    fig.savefig(f"{plot_hist}.{channel}.dbl.png")

                dbl_score[idx] = unidata.obs["doublet_score"].values
                pred_dbl[idx] = unidata.obs["pred_dbl"].values
                thresholds[channel] = unidata.uns["doublet_threshold"]
            else:
                logger.warning(
                    f"Channel {channel} has {idx.size} < {min_cell} cells and thus doublet score calculation is skipped!"
                )

        data.obs["doublet_score"] = dbl_score
        data.obs["pred_dbl"] = pred_dbl
        data.uns["doublet_thresholds"] = thresholds

    if clust_attr is not None:
        data.uns["pred_dbl_cluster"] = _identify_doublets_fisher(
            data.obs[clust_attr].values,
            data.obs["pred_dbl"].values,
            alpha=alpha)

    logger.info('Doublets are predicted!')
Example #27
0
def analyze_one_modality(unidata: UnimodalData, output_name: str, is_raw: bool,
                         append_data: UnimodalData, **kwargs) -> None:
    """Run the analysis pipeline on one UnimodalData and optionally write the results to disk.

    Every step is switched on and configured by an entry in ``kwargs``:
    log-normalization (only if ``is_raw``), highly variable feature (HVF) selection,
    L/S, scanorama or harmony batch correction, signature scores, PCA, kNN graph,
    diffusion maps, kBET, clustering (spectral Louvain/Leiden, Louvain, Leiden),
    embeddings (net-tSNE, net-UMAP, net-FLE, tSNE, FIt-SNE, UMAP, FLE) and pseudotime.

    Parameters
    ----------
    unidata: ``UnimodalData``
        Data to analyze; analysis results are written into it in place.
    output_name: ``str``
        Output prefix. It is passed to the FLE steps and used for ``{output_name}.h5ad``
        and ``{output_name}.loom``.
    is_raw: ``bool``
        If True, log-normalize with ``kwargs["norm_count"]`` and, when batch correction
        is on, set the group attribute first.
    append_data: ``UnimodalData`` or ``None``
        If not None, its features that have nonzero counts in barcodes shared with
        ``unidata`` are appended to the matrices of a new local UnimodalData. That object
        is only used for the h5ad/loom outputs; the caller's object keeps its original
        matrices but shares ``uns``/``obsm``/``varm`` with it.
    kwargs:
        Analysis options (presumably the parsed command-line options — confirm against
        the caller). Keys are looked up directly, so a required key that is missing raises
        ``KeyError``; ``kBET`` is the only optional key.

    Returns
    -------
    ``None``
    """
    print()
    logger.info(f"Begin to analyze UnimodalData {unidata.get_uid()}.")
    # Expose the user-chosen channel attribute under the standard "Channel" column.
    if kwargs["channel_attr"] is not None:
        unidata.obs["Channel"] = unidata.obs[kwargs["channel_attr"]]

    if is_raw:
        # normalize counts and then transform to log space
        tools.log_norm(unidata, kwargs["norm_count"])
        # set group attribute
        if kwargs["batch_correction"] and kwargs["group_attribute"] is not None:
            tools.set_group_attribute(unidata, kwargs["group_attribute"])

    # select highly variable features
    standardize = False  # if no select HVF, False
    if kwargs["select_hvf"]:
        # HVF selection is pointless when there are no more genes than the target count.
        if unidata.shape[1] <= kwargs["hvf_ngenes"]:
            logger.warning(
                f"Number of genes {unidata.shape[1]} is no greater than the target number of highly variable features {kwargs['hvf_ngenes']}. HVF selection is omitted."
            )
        else:
            standardize = True
            tools.highly_variable_features(
                unidata,
                kwargs["batch_correction"],
                flavor=kwargs["hvf_flavor"],
                n_top=kwargs["hvf_ngenes"],
                n_jobs=kwargs["n_jobs"],
            )
            # The HVF diagnostic plot is only produced for the "pegasus" flavor.
            if kwargs["hvf_flavor"] == "pegasus":
                if kwargs["plot_hvf"] is not None:
                    from pegasus.plotting import hvfplot
                    fig = hvfplot(unidata, return_fig=True)
                    fig.savefig(f"{kwargs['plot_hvf']}.hvf.pdf")

    # batch correction: L/S
    if kwargs["batch_correction"] and kwargs["correction_method"] == "L/S":
        tools.correct_batch(unidata, features="highly_variable_features")

    # Signature score files are given as a comma-separated list.
    if kwargs["calc_sigscore"] is not None:
        sig_files = kwargs["calc_sigscore"].split(",")
        for sig_file in sig_files:
            tools.calc_signature_score(unidata, sig_file)

    # The number of PCs cannot exceed either matrix dimension.
    n_pc = min(kwargs["pca_n"], unidata.shape[0], unidata.shape[1])
    if n_pc < kwargs["pca_n"]:
        logger.warning(
            f"UnimodalData {unidata.get_uid()} has either dimension ({unidata.shape[0]}, {unidata.shape[1]}) less than the specified number of PCs {kwargs['pca_n']}. Reduce the number of PCs to {n_pc}."
        )

    # pca_key names the low-dimensional representation used by all downstream steps.
    if kwargs["batch_correction"] and kwargs[
            "correction_method"] == "scanorama":
        pca_key = tools.run_scanorama(unidata,
                                      n_components=n_pc,
                                      features="highly_variable_features",
                                      standardize=standardize,
                                      random_state=kwargs["random_state"])
    else:
        # PCA
        tools.pca(
            unidata,
            n_components=n_pc,
            features="highly_variable_features",
            standardize=standardize,
            robust=kwargs["pca_robust"],
            random_state=kwargs["random_state"],
        )
        pca_key = "pca"

    # batch correction: Harmony
    if kwargs["batch_correction"] and kwargs["correction_method"] == "harmony":
        pca_key = tools.run_harmony(unidata,
                                    rep="pca",
                                    n_jobs=kwargs["n_jobs"],
                                    n_clusters=kwargs["harmony_nclusters"],
                                    random_state=kwargs["random_state"])

    # Find K neighbors
    tools.neighbors(
        unidata,
        K=kwargs["K"],
        rep=pca_key,
        n_jobs=kwargs["n_jobs"],
        random_state=kwargs["random_state"],
        full_speed=kwargs["full_speed"],
    )

    # calculate diffmap
    # FLE and net-FLE need the diffusion map, so force it on (this only edits the local kwargs dict).
    if (kwargs["fle"] or kwargs["net_fle"]):
        if not kwargs["diffmap"]:
            print("Turn on --diffmap option!")
        kwargs["diffmap"] = True

    if kwargs["diffmap"]:
        tools.diffmap(
            unidata,
            n_components=kwargs["diffmap_ndc"],
            rep=pca_key,
            solver=kwargs["diffmap_solver"],
            random_state=kwargs["random_state"],
            max_t=kwargs["diffmap_maxt"],
        )
        if kwargs["diffmap_to_3d"]:
            tools.reduce_diffmap_to_3d(unidata,
                                       random_state=kwargs["random_state"])

    # calculate kBET
    if ("kBET" in kwargs) and kwargs["kBET"]:
        stat_mean, pvalue_mean, accept_rate = tools.calc_kBET(
            unidata,
            kwargs["kBET_batch"],
            rep=pca_key,
            K=kwargs["kBET_K"],
            alpha=kwargs["kBET_alpha"],
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"])
        print(
            "kBET stat_mean = {:.2f}, pvalue_mean = {:.4f}, accept_rate = {:.2%}."
            .format(stat_mean, pvalue_mean, accept_rate))

    # clustering
    if kwargs["spectral_louvain"]:
        tools.cluster(
            unidata,
            algo="spectral_louvain",
            rep=pca_key,
            resolution=kwargs["spectral_louvain_resolution"],
            rep_kmeans=kwargs["spectral_louvain_basis"],
            n_clusters=kwargs["spectral_louvain_nclusters"],
            n_clusters2=kwargs["spectral_louvain_nclusters2"],
            n_init=kwargs["spectral_louvain_ninit"],
            random_state=kwargs["random_state"],
            class_label="spectral_louvain_labels",
        )

    if kwargs["spectral_leiden"]:
        tools.cluster(
            unidata,
            algo="spectral_leiden",
            rep=pca_key,
            resolution=kwargs["spectral_leiden_resolution"],
            rep_kmeans=kwargs["spectral_leiden_basis"],
            n_clusters=kwargs["spectral_leiden_nclusters"],
            n_clusters2=kwargs["spectral_leiden_nclusters2"],
            n_init=kwargs["spectral_leiden_ninit"],
            random_state=kwargs["random_state"],
            class_label="spectral_leiden_labels",
        )

    if kwargs["louvain"]:
        tools.cluster(
            unidata,
            algo="louvain",
            rep=pca_key,
            resolution=kwargs["louvain_resolution"],
            random_state=kwargs["random_state"],
            class_label=kwargs["louvain_class_label"],
        )

    if kwargs["leiden"]:
        tools.cluster(
            unidata,
            algo="leiden",
            rep=pca_key,
            resolution=kwargs["leiden_resolution"],
            n_iter=kwargs["leiden_niter"],
            random_state=kwargs["random_state"],
            class_label=kwargs["leiden_class_label"],
        )

    # visualization
    # NOTE: the "learing" spellings below are the option keys as supplied by the caller; do not rename them here alone.
    if kwargs["net_tsne"]:
        tools.net_tsne(
            unidata,
            rep=pca_key,
            n_jobs=kwargs["n_jobs"],
            perplexity=kwargs["tsne_perplexity"],
            random_state=kwargs["random_state"],
            select_frac=kwargs["net_ds_frac"],
            select_K=kwargs["net_ds_K"],
            select_alpha=kwargs["net_ds_alpha"],
            net_alpha=kwargs["net_l2"],
            polish_learning_frac=kwargs["net_tsne_polish_learing_frac"],
            polish_n_iter=kwargs["net_tsne_polish_niter"],
            out_basis=kwargs["net_tsne_basis"],
        )

    if kwargs["net_umap"]:
        tools.net_umap(
            unidata,
            rep=pca_key,
            n_jobs=kwargs["n_jobs"],
            n_neighbors=kwargs["umap_K"],
            min_dist=kwargs["umap_min_dist"],
            spread=kwargs["umap_spread"],
            random_state=kwargs["random_state"],
            select_frac=kwargs["net_ds_frac"],
            select_K=kwargs["net_ds_K"],
            select_alpha=kwargs["net_ds_alpha"],
            full_speed=kwargs["full_speed"],
            net_alpha=kwargs["net_l2"],
            polish_learning_rate=kwargs["net_umap_polish_learing_rate"],
            polish_n_epochs=kwargs["net_umap_polish_nepochs"],
            out_basis=kwargs["net_umap_basis"],
        )

    if kwargs["net_fle"]:
        tools.net_fle(
            unidata,
            output_name,
            n_jobs=kwargs["n_jobs"],
            K=kwargs["fle_K"],
            full_speed=kwargs["full_speed"],
            target_change_per_node=kwargs["fle_target_change_per_node"],
            target_steps=kwargs["fle_target_steps"],
            is3d=False,
            memory=kwargs["fle_memory"],
            random_state=kwargs["random_state"],
            select_frac=kwargs["net_ds_frac"],
            select_K=kwargs["net_ds_K"],
            select_alpha=kwargs["net_ds_alpha"],
            net_alpha=kwargs["net_l2"],
            polish_target_steps=kwargs["net_fle_polish_target_steps"],
            out_basis=kwargs["net_fle_basis"],
        )

    if kwargs["tsne"]:
        tools.tsne(
            unidata,
            rep=pca_key,
            n_jobs=kwargs["n_jobs"],
            perplexity=kwargs["tsne_perplexity"],
            random_state=kwargs["random_state"],
        )

    if kwargs["fitsne"]:
        tools.fitsne(
            unidata,
            rep=pca_key,
            n_jobs=kwargs["n_jobs"],
            perplexity=kwargs["tsne_perplexity"],
            random_state=kwargs["random_state"],
        )

    if kwargs["umap"]:
        tools.umap(
            unidata,
            rep=pca_key,
            n_neighbors=kwargs["umap_K"],
            min_dist=kwargs["umap_min_dist"],
            spread=kwargs["umap_spread"],
            random_state=kwargs["random_state"],
        )

    if kwargs["fle"]:
        tools.fle(
            unidata,
            output_name,
            n_jobs=kwargs["n_jobs"],
            K=kwargs["fle_K"],
            full_speed=kwargs["full_speed"],
            target_change_per_node=kwargs["fle_target_change_per_node"],
            target_steps=kwargs["fle_target_steps"],
            is3d=False,
            memory=kwargs["fle_memory"],
            random_state=kwargs["random_state"],
        )

    # calculate diffusion-based pseudotime from roots
    if len(kwargs["pseudotime"]) > 0:
        tools.calc_pseudotime(unidata, kwargs["pseudotime"])

    # Remember the original genome name so it can be restored after a possible append below.
    genome = unidata.uns["genome"]

    if append_data is not None:
        # Map each append_data barcode to its row in unidata (-1 if absent) and drop absent ones.
        locs = unidata.obs_names.get_indexer(append_data.obs_names)
        idx = locs >= 0
        locs = locs[idx]
        # Re-index the kept rows into unidata's row order; unidata cells missing from append_data get empty rows.
        Y = append_data.X[idx, :].tocoo(copy=False)
        Z = coo_matrix((Y.data, (locs[Y.row], Y.col)),
                       shape=(unidata.shape[0], append_data.shape[1])).tocsr()

        # Keep only appended features with at least one nonzero entry.
        idy = Z.getnnz(axis=0) > 0
        n_nonzero = idy.sum()
        if n_nonzero > 0:
            if n_nonzero < append_data.shape[1]:
                Z = Z[:, idy]
                append_df = append_data.feature_metadata.loc[idy, :]
            else:
                append_df = append_data.feature_metadata

            rawX = hstack([unidata.get_matrix("raw.X"), Z], format="csr")

            # Scale each row by unidata.obs["scale"] (presumably the per-cell normalization
            # factor from log_norm — confirm) and log1p-transform, matching the "X" matrix.
            Zt = Z.astype(np.float32)
            Zt.data *= np.repeat(unidata.obs["scale"].values,
                                 np.diff(Zt.indptr))
            Zt.data = np.log1p(Zt.data)

            X = hstack([unidata.get_matrix("X"), Zt], format="csr")

            new_genome = unidata.get_genome(
            ) + "_and_" + append_data.get_genome()

            feature_metadata = pd.concat([unidata.feature_metadata, append_df],
                                         axis=0)
            feature_metadata.reset_index(inplace=True)
            feature_metadata.fillna(value=_get_fillna_dict(
                unidata.feature_metadata),
                                    inplace=True)

            # Rebind the local name only: the caller's object keeps its matrices, but shares uns/obsm/varm.
            unidata = UnimodalData(
                unidata.barcode_metadata, feature_metadata, {
                    "X": X,
                    "raw.X": rawX
                }, unidata.uns.mapping, unidata.obsm.mapping,
                unidata.varm.mapping
            )  # uns.mapping, obsm.mapping and varm.mapping are passed by reference
            unidata.uns["genome"] = new_genome

    if kwargs["output_h5ad"]:
        adata = unidata.to_anndata()
        # NOTE(review): pop() without a default raises KeyError if the cached HVF feature
        # matrix was not stored in uns by the preceding steps — confirm it is always present.
        adata.uns["scale.data"] = adata.uns.pop(
            "_tmp_fmat_highly_variable_features")  # assign by reference
        # Row names of scale.data: the highly variable features.
        adata.uns["scale.data.rownames"] = unidata.var_names[
            unidata.var["highly_variable_features"]].values
        adata.write(f"{output_name}.h5ad", compression="gzip")
        del adata

    # write out results
    if kwargs["output_loom"]:
        write_output(unidata, f"{output_name}.loom")

    # Change genome name back if append_data is True
    # (uns is shared with the caller's object, so this also restores the caller's genome name).
    if unidata.uns["genome"] != genome:
        unidata.uns["genome"] = genome
    # Eliminate objects starting with fmat_ from uns
    unidata.uns.pop("_tmp_fmat_highly_variable_features", None)
def calc_qc_filters(unidata: UnimodalData,
                    select_singlets: bool = False,
                    remap_string: str = None,
                    subset_string: str = None,
                    min_genes: int = None,
                    max_genes: int = None,
                    min_umis: int = None,
                    max_umis: int = None,
                    mito_prefix: str = None,
                    percent_mito: float = None) -> None:
    """Calculate Quality Control (QC) metrics and mark barcodes based on the combination of QC metrics.

    Parameters
    ----------
    unidata: ``UnimodalData``
       Unimodal data matrix with rows for cells and columns for genes.
    select_singlets: ``bool``, optional, default ``False``
        If select only singlets. Only takes effect if ``unidata.obs`` contains a ``demux_type`` column.
    remap_string: ``str``, optional, default ``None``
        Remap singlet names using <remap_string>, where <remap_string> takes the format "new_name_i:old_name_1,old_name_2;new_name_ii:old_name_3;...". For example, if we hashed 5 libraries from 3 samples sample1_lib1, sample1_lib2, sample2_lib1, sample2_lib2 and sample3, we can remap them to 3 samples using this string: "sample1:sample1_lib1,sample1_lib2;sample2:sample2_lib1,sample2_lib2". In this way, the new singlet names will be in metadata field with key 'assignment', while the old names will be kept in metadata field with key 'assignment.orig'. Empty items (e.g. from a trailing ';') are ignored.
    subset_string: ``str``, optional, default ``None``
        If select singlets, only select singlets in the <subset_string>, which takes the format "name1,name2,...". Note that if remap_string is specified, subsetting happens after remapping. For example, we can only select singlets from sample 1 and 3 using "sample1,sample3".
    min_genes: ``int``, optional, default: None
       Only keep cells with at least ``min_genes`` genes.
    max_genes: ``int``, optional, default: None
       Only keep cells with less than ``max_genes`` genes.
    min_umis: ``int``, optional, default: None
       Only keep cells with at least ``min_umis`` UMIs.
    max_umis: ``int``, optional, default: None
       Only keep cells with less than ``max_umis`` UMIs.
    mito_prefix: ``str``, optional, default: None
       Prefix for mitochondrial genes.
    percent_mito: ``float``, optional, default: None
       Only keep cells whose counts from mitochondrial genes are less than ``percent_mito`` % of total counts. Only when both mito_prefix and percent_mito are set will the mitochondrial filter be triggered.

    Returns
    -------
    ``None``

    Update ``unidata.obs``:

        * ``n_genes``: Total number of genes for each cell (computed only if not already present).
        * ``n_counts``: Total number of counts for each cell (computed only if not already present).
        * ``percent_mito``: Percent of counts from mitochondrial genes for each cell.
        * ``passed_qc``: Boolean type indicating if a cell passes the QC process based on the QC metrics.
        * ``assignment`` / ``assignment.orig``: remapped and original assignments, if remap_string is used.

    Update ``unidata.uns``:

        * ``__del_demux_type``: set to True when singlet selection is applied; presumably a flag for a later step to drop ``demux_type``.

    Examples
    --------
    >>> calc_qc_filters(unidata, min_umis = 500, select_singlets = True)
    """
    assert unidata.uns["modality"] == "rna"

    filters = []

    if select_singlets and ("demux_type" in unidata.obs):
        if remap_string is not None:
            if "assignment" not in unidata.obs:
                raise ValueError("No assignment field detected!")
            unidata.obs["assignment.orig"] = unidata.obs["assignment"]

            remap = {}
            for token in remap_string.split(";"):
                if not token:
                    continue  # tolerate empty items such as a trailing ';'
                new_key, old_str = token.split(":")
                for old_key in old_str.split(","):
                    remap[old_key] = new_key

            unidata.obs["assignment"] = pd.Categorical(
                unidata.obs["assignment"].apply(lambda x: remap.get(x, x)))
            logger.info("Singlets are remapped.")

        is_singlet = unidata.obs["demux_type"] == "singlet"
        if subset_string is None:
            filters.append(is_singlet)
        else:
            if "assignment" not in unidata.obs:
                raise ValueError("No assignment field detected!")

            subset = np.array(subset_string.split(","))
            # Subsetting refines singlet selection: a non-singlet whose assignment happens
            # to match a listed name must still be excluded.
            filters.append(is_singlet & np.isin(unidata.obs["assignment"], subset))

        unidata.uns["__del_demux_type"] = True

    if "n_genes" not in unidata.obs:
        unidata.obs["n_genes"] = unidata.X.getnnz(axis=1)

    if "n_counts" not in unidata.obs:
        # np.asarray(...).ravel() flattens both np.matrix (spmatrix.sum) and ndarray (sparse array .sum) results.
        unidata.obs["n_counts"] = np.asarray(unidata.X.sum(axis=1)).ravel()

    if min_genes is not None:
        filters.append(unidata.obs["n_genes"] >= min_genes)
    if max_genes is not None:
        filters.append(unidata.obs["n_genes"] < max_genes)
    if min_umis is not None:
        filters.append(unidata.obs["n_counts"] >= min_umis)
    if max_umis is not None:
        filters.append(unidata.obs["n_counts"] < max_umis)

    if (mito_prefix is not None) and (percent_mito is not None):
        mito_genes = unidata.var_names.map(
            lambda x: x.startswith(mito_prefix)).values.nonzero()[0]
        mito_counts = np.asarray(unidata.X[:, mito_genes].sum(axis=1)).ravel()
        # Guard against division by zero for cells with no counts.
        unidata.obs["percent_mito"] = (mito_counts / np.maximum(
            unidata.obs["n_counts"].values, 1.0)) * 100
        filters.append(unidata.obs["percent_mito"] < percent_mito)

    if len(filters) > 0:
        selected = np.logical_and.reduce(filters)
        unidata.obs["passed_qc"] = selected
    else:
        unidata.obs["passed_qc"] = True