Example #1
def run_pipeline(input_file, output_name, **kwargs):
    is_raw = not kwargs["processed"]

    if "seurat_compatible" not in kwargs:
        kwargs["seurat_compatible"] = False

    # load input data
    adata = io.read_input(
        input_file,
        genome=kwargs["genome"],
        concat_matrices=not kwargs["cite_seq"],
        h5ad_mode=("a" if (is_raw or kwargs["subcluster"]) else "r+"),
        select_singlets=kwargs["select_singlets"],
        channel_attr=kwargs["channel_attr"],
        black_list=(
            kwargs["black_list"].split(",") if kwargs["black_list"] is not None else []
        ),
    )

    if not kwargs["cite_seq"]:
        if is_raw:
            values = adata.X.getnnz(axis=1)
            if values.min() == 0:  # 10x raw data
                adata._inplace_subset_obs(values >= kwargs["min_genes_on_raw"])
    else:
        data_list = adata
        assert len(data_list) == 2
        adata = cdata = None
        for i in range(len(data_list)):
            if data_list[i].uns["genome"].startswith("CITE_Seq"):
                cdata = data_list[i]
            else:
                adata = data_list[i]
        assert adata is not None and cdata is not None
    print("Inputs are loaded.")

    if kwargs["seurat_compatible"]:
        assert is_raw and kwargs["select_hvf"]

    if kwargs["subcluster"]:
        adata = tools.get_anndata_for_subclustering(adata, kwargs["subset_selections"])
        is_raw = True  # get submat and then set is_raw to True

    if is_raw:
        if not kwargs["subcluster"]:
            # filter out low quality cells/genes
            tools.run_filter_data(
                adata,
                output_filt=kwargs["output_filt"],
                plot_filt=kwargs["plot_filt"],
                plot_filt_figsize=kwargs["plot_filt_figsize"],
                mito_prefix=kwargs["mito_prefix"],
                min_genes=kwargs["min_genes"],
                max_genes=kwargs["max_genes"],
                min_umis=kwargs["min_umis"],
                max_umis=kwargs["max_umis"],
                percent_mito=kwargs["percent_mito"],
                percent_cells=kwargs["percent_cells"],
            )

            if kwargs["seurat_compatible"]:
                raw_data = adata.copy()  # raw as count

            # normalize counts and then transform to log space
            tools.log_norm(adata, kwargs["norm_count"])

            # set group attribute
            if kwargs["batch_correction"] and kwargs["group_attribute"] is not None:
                tools.set_group_attribute(adata, kwargs["group_attribute"])

        # select highly variable features
        if kwargs["select_hvf"]:
            tools.highly_variable_features(
                adata,
                kwargs["batch_correction"],
                flavor=kwargs["hvf_flavor"],
                n_top=kwargs["hvf_ngenes"],
                n_jobs=kwargs["n_jobs"],
            )
            if kwargs["hvf_flavor"] == "pegasus":
                if kwargs["plot_hvf"] is not None:
                    from pegasus.plotting import plot_hvf

                    robust_idx = adata.var["robust"].values
                    plot_hvf(
                        adata.var.loc[robust_idx, "mean"],
                        adata.var.loc[robust_idx, "var"],
                        adata.var.loc[robust_idx, "hvf_loess"],
                        adata.var.loc[robust_idx, "highly_variable_features"],
                        kwargs["plot_hvf"] + ".hvf.pdf",
                    )

        # batch correction
        if kwargs["batch_correction"]:
            tools.correct_batch(adata, features="highly_variable_features")

        # PCA
        tools.pca(
            adata,
            n_components=kwargs["nPC"],
            features="highly_variable_features",
            random_state=kwargs["random_state"],
        )

        # Find K neighbors
        tools.neighbors(
            adata,
            K=kwargs["K"],
            rep="pca",
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"],
            full_speed=kwargs["full_speed"],
        )

        # calculate diffmap
        if (
            kwargs["fle"]
            or kwargs["net_fle"]
        ):
            if not kwargs["diffmap"]:
                print("Turn on --diffmap option!")
            kwargs["diffmap"] = True

        if kwargs["diffmap"]:
            tools.diffmap(
                adata,
                n_components=kwargs["diffmap_ndc"],
                rep="pca",
                solver=kwargs["diffmap_solver"],
                random_state=kwargs["random_state"],
                max_t=kwargs["diffmap_maxt"],
            )
            if kwargs["diffmap_to_3d"]:
                tools.reduce_diffmap_to_3d(adata, random_state=kwargs["random_state"])

    # calculate kBET
    if ("kBET" in kwargs) and kwargs["kBET"]:
        stat_mean, pvalue_mean, accept_rate = tools.calc_kBET(
            adata,
            kwargs["kBET_batch"],
            K=kwargs["kBET_K"],
            alpha=kwargs["kBET_alpha"],
            n_jobs=kwargs["n_jobs"],
        )
        print(
            "kBET stat_mean = {:.2f}, pvalue_mean = {:.4f}, accept_rate = {:.2%}.".format(
                stat_mean, pvalue_mean, accept_rate
            )
        )

    # clustering
    if kwargs["spectral_louvain"]:
        tools.cluster(
            adata,
            algo="spectral_louvain",
            rep="pca",
            resolution=kwargs["spectral_louvain_resolution"],
            rep_kmeans=kwargs["spectral_louvain_basis"],
            n_clusters=kwargs["spectral_louvain_nclusters"],
            n_clusters2=kwargs["spectral_louvain_nclusters2"],
            n_init=kwargs["spectral_louvain_ninit"],
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"],
            class_label="spectral_louvain_labels",
        )

    if kwargs["spectral_leiden"]:
        tools.cluster(
            adata,
            algo="spectral_leiden",
            rep="pca",
            resolution=kwargs["spectral_leiden_resolution"],
            rep_kmeans=kwargs["spectral_leiden_basis"],
            n_clusters=kwargs["spectral_leiden_nclusters"],
            n_clusters2=kwargs["spectral_leiden_nclusters2"],
            n_init=kwargs["spectral_leiden_ninit"],
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"],
            class_label="spectral_leiden_labels",
        )

    if kwargs["louvain"]:
        tools.cluster(
            adata,
            algo="louvain",
            rep="pca",
            resolution=kwargs["louvain_resolution"],
            random_state=kwargs["random_state"],
            class_label=kwargs["louvain_class_label"],
        )

    if kwargs["leiden"]:
        tools.cluster(
            adata,
            algo="leiden",
            rep="pca",
            resolution=kwargs["leiden_resolution"],
            n_iter=kwargs["leiden_niter"],
            random_state=kwargs["random_state"],
            class_label=kwargs["leiden_class_label"],
        )

    # visualization
    if kwargs["net_tsne"]:
        tools.net_tsne(
            adata,
            rep="pca",
            n_jobs=kwargs["n_jobs"],
            perplexity=kwargs["tsne_perplexity"],
            random_state=kwargs["random_state"],
            select_frac=kwargs["net_ds_frac"],
            select_K=kwargs["net_ds_K"],
            select_alpha=kwargs["net_ds_alpha"],
            net_alpha=kwargs["net_l2"],
            polish_learning_frac=kwargs["net_tsne_polish_learing_frac"],
            polish_n_iter=kwargs["net_tsne_polish_niter"],
            out_basis=kwargs["net_tsne_basis"],
        )

    if kwargs["net_umap"]:
        tools.net_umap(
            adata,
            rep="pca",
            n_jobs=kwargs["n_jobs"],
            n_neighbors=kwargs["umap_K"],
            min_dist=kwargs["umap_min_dist"],
            spread=kwargs["umap_spread"],
            random_state=kwargs["random_state"],
            select_frac=kwargs["net_ds_frac"],
            select_K=kwargs["net_ds_K"],
            select_alpha=kwargs["net_ds_alpha"],
            full_speed=kwargs["full_speed"],
            net_alpha=kwargs["net_l2"],
            polish_learning_rate=kwargs["net_umap_polish_learing_rate"],
            polish_n_epochs=kwargs["net_umap_polish_nepochs"],
            out_basis=kwargs["net_umap_basis"],
        )

    if kwargs["net_fle"]:
        tools.net_fle(
            adata,
            output_name,
            n_jobs=kwargs["n_jobs"],
            K=kwargs["fle_K"],
            full_speed=kwargs["full_speed"],
            target_change_per_node=kwargs["fle_target_change_per_node"],
            target_steps=kwargs["fle_target_steps"],
            is3d=False,
            memory=kwargs["fle_memory"],
            random_state=kwargs["random_state"],
            select_frac=kwargs["net_ds_frac"],
            select_K=kwargs["net_ds_K"],
            select_alpha=kwargs["net_ds_alpha"],
            net_alpha=kwargs["net_l2"],
            polish_target_steps=kwargs["net_fle_polish_target_steps"],
            out_basis=kwargs["net_fle_basis"],
        )

    if kwargs["tsne"]:
        tools.tsne(
            adata,
            rep="pca",
            n_jobs=kwargs["n_jobs"],
            perplexity=kwargs["tsne_perplexity"],
            random_state=kwargs["random_state"],
        )

    if kwargs["fitsne"]:
        tools.fitsne(
            adata,
            rep="pca",
            n_jobs=kwargs["n_jobs"],
            perplexity=kwargs["tsne_perplexity"],
            random_state=kwargs["random_state"],
        )

    if kwargs["umap"]:
        tools.umap(
            adata,
            rep="pca",
            n_neighbors=kwargs["umap_K"],
            min_dist=kwargs["umap_min_dist"],
            spread=kwargs["umap_spread"],
            random_state=kwargs["random_state"],
        )

    if kwargs["fle"]:
        tools.fle(
            adata,
            output_name,
            n_jobs=kwargs["n_jobs"],
            K=kwargs["fle_K"],
            full_speed=kwargs["full_speed"],
            target_change_per_node=kwargs["fle_target_change_per_node"],
            target_steps=kwargs["fle_target_steps"],
            is3d=False,
            memory=kwargs["fle_memory"],
            random_state=kwargs["random_state"],
        )

    # calculate diffusion-based pseudotime from roots
    if len(kwargs["pseudotime"]) > 0:
        tools.calc_pseudotime(adata, kwargs["pseudotime"])

    # merge cite-seq data and run t-SNE
    if kwargs["cite_seq"]:
        adt_matrix = np.zeros((adata.shape[0], cdata.shape[1]), dtype="float32")
        idx = adata.obs_names.isin(cdata.obs_names)
        adt_matrix[idx, :] = cdata[adata.obs_names[idx],].X.toarray()
        if abs(100.0 - kwargs["cite_seq_capping"]) > 1e-4:
            cite_seq.capping(adt_matrix, kwargs["cite_seq_capping"])

        var_names = np.concatenate(
            [adata.var_names, ["AD-" + x for x in cdata.var_names]]
        )

        new_data = anndata.AnnData(
            X=hstack([adata.X, csr_matrix(adt_matrix)], format="csr"),
            obs=adata.obs,
            obsm=adata.obsm,
            uns=adata.uns,
            var={
                "var_names": var_names,
                "gene_ids": var_names,
                "n_cells": np.concatenate(
                    [adata.var["n_cells"].values, [0] * cdata.shape[1]]
                ),
                "percent_cells": np.concatenate(
                    [adata.var["percent_cells"].values, [0.0] * cdata.shape[1]]
                ),
                "robust": np.concatenate(
                    [adata.var["robust"].values, [False] * cdata.shape[1]]
                ),
                "highly_variable_features": np.concatenate(
                    [
                        adata.var["highly_variable_features"].values,
                        [False] * cdata.shape[1],
                    ]
                ),
            },
        )
        new_data.obsm["X_CITE-Seq"] = adt_matrix
        adata = new_data
        print("ADT count matrix is attached.")

        tools.fitsne(
            adata,
            rep="CITE-Seq",
            n_jobs=kwargs["n_jobs"],
            perplexity=kwargs["tsne_perplexity"],
            random_state=kwargs["random_state"],
            out_basis="citeseq_fitsne",
        )
        print("Antibody embedding is done.")

    if kwargs["seurat_compatible"]:
        seurat_data = adata.copy()
        seurat_data.raw = raw_data
        seurat_data.uns["scale.data"] = adata.uns["fmat_highly_variable_features"] # assign by reference
        seurat_data.uns["scale.data.rownames"] = adata.var_names[
            adata.var["highly_variable_features"]
        ].values
        io.write_output(seurat_data, output_name + ".seurat.h5ad")

    # write out results
    io.write_output(adata, output_name + ".h5ad")

    if kwargs["output_loom"]:
        io.write_output(adata, output_name + ".loom")

    print("Results are written.")
Example #2
def infer_doublets(
    data: MultimodalData,
    channel_attr: Optional[str] = None,
    clust_attr: Optional[str] = None,
    min_cell: Optional[int] = 100,
    expected_doublet_rate: Optional[float] = None,
    sim_doublet_ratio: Optional[float] = 2.0,
    n_prin_comps: Optional[int] = 30,
    robust: Optional[bool] = False,
    k: Optional[int] = None,
    n_jobs: Optional[int] = -1,
    alpha: Optional[float] = 0.05,
    random_state: Optional[int] = 0,
    plot_hist: Optional[str] = "dbl",
) -> None:
    """Infer doublets using a Scrublet-like strategy. [Li20-2]_

    This function must be called after clustering. 

    Parameters
    ----------
    data: ``pegasusio.MultimodalData``
        Annotated data matrix with rows for cells and columns for genes.

    channel_attr: ``str``, optional, default: None
        Attribute indicating sample channels. If set, calculate scrublet-like doublet scores per channel.

    clust_attr: ``str``, optional, default: None
        Attribute indicating cluster labels. If set, estimate proportion of doublets in each cluster and statistical significance.

    min_cell: ``int``, optional, default: 100
        Minimum number of cells per sample required to calculate doublet scores. For samples with fewer than 'min_cell' cells, doublet score calculation is skipped.

    expected_doublet_rate: ``float``, optional, default: ``None``
        The expected doublet rate for the experiment. By default, the expected rate is calculated from the number of cells using the 10x multiplet rate table.

    sim_doublet_ratio: ``float``, optional, default: ``2.0``
        The ratio between synthetic doublets and observed cells.

    n_prin_comps: ``int``, optional, default: ``30``
        Number of principal components.

    robust: ``bool``, optional, default: ``False``
        If ``True``, use 'arpack' instead of 'randomized' for large matrices (i.e., max(X.shape) > 500 and n_components < 0.8 * min(X.shape)).

    k: ``int``, optional, default: ``None``
        Number of observed cell neighbors. If None, k = round(0.5 * sqrt(number of observed cells)). Total neighbors k_adj = round(k * (1.0 + sim_doublet_ratio)).

    n_jobs: ``int``, optional, default: ``-1``
        Number of threads to use. If ``-1``, use all available threads.

    alpha: ``float``, optional, default: ``0.05``
        FDR significance level for the cluster-level Fisher's exact test.

    random_state: ``int``, optional, default: ``0``
        Random seed for reproducing results.

    plot_hist: ``str``, optional, default: ``dbl``
        If not None, plot diagnostic histograms using ``plot_hist`` as the prefix. If `channel_attr` is None, ``plot_hist.png`` is generated; otherwise, ``plot_hist.channel_name.png`` files are generated.

    Returns
    -------
    ``None``

    Update ``data.obs``:
        * ``data.obs['doublet_score']``: Doublet scores.

        * ``data.obs['pred_dbl']``: Predicted singlet/doublet types.

    Update ``data.uns``:
        * ``data.uns['pred_dbl_cluster']``: Only generated if 'clust_attr' is not None. This is a dataframe with two columns, 'Cluster' and 'Qval'. Only clusters with significantly more doublets than expected will be recorded here.

    Examples
    --------
    >>> pg.infer_doublets(data, channel_attr = 'Channel', clust_attr = 'Annotation')
    """
    assert data.get_modality() == "rna"
    try:
        rawX = data.get_matrix("raw.X")
    except ValueError:
        raise ValueError(
            "Cannot detect the raw count matrix raw.X; stop inferring doublets!"
        )

    if_plot = plot_hist is not None

    if channel_attr is None:
        if data.shape[0] >= min_cell:
            fig = _run_scrublet(data, expected_doublet_rate = expected_doublet_rate, sim_doublet_ratio = sim_doublet_ratio, \
                                n_prin_comps = n_prin_comps, robust = robust, k = k, n_jobs = n_jobs, random_state = random_state, \
                                plot_hist = if_plot)
            if if_plot:
                fig.savefig(f"{plot_hist}.png")
        else:
            logger.warning(
                f"Data has {data.shape[0]} < {min_cell} cells and thus doublet score calculation is skipped!"
            )
            data.obs["doublet_score"] = 0.0
            data.obs["pred_dbl"] = False
    else:
        from pandas.api.types import is_categorical_dtype
        from pegasus.tools import identify_robust_genes, log_norm, highly_variable_features

        assert is_categorical_dtype(data.obs[channel_attr])
        genome = data.get_genome()
        modality = data.get_modality()
        channels = data.obs[channel_attr].cat.categories

        dbl_score = np.zeros(data.shape[0], dtype=np.float32)
        pred_dbl = np.zeros(data.shape[0], dtype=np.bool_)
        thresholds = {}
        for channel in channels:
            # Generate a new unidata object for the channel
            idx = np.where(data.obs[channel_attr] == channel)[0]
            if idx.size >= min_cell:
                unidata = UnimodalData({"barcodekey": data.obs_names[idx]},
                                       {"featurekey": data.var_names},
                                       {"X": rawX[idx]}, {
                                           "genome": genome,
                                           "modality": modality
                                       })
                # Identify robust genes, log-normalize counts, and select the top 2,000 highly variable features
                identify_robust_genes(unidata)
                log_norm(unidata)
                highly_variable_features(unidata)
                # Run _run_scrublet
                fig = _run_scrublet(unidata, name = channel, expected_doublet_rate = expected_doublet_rate, sim_doublet_ratio = sim_doublet_ratio, \
                                    n_prin_comps = n_prin_comps, robust = robust, k = k, n_jobs = n_jobs, random_state = random_state, \
                                    plot_hist = if_plot)
                if if_plot:
                    fig.savefig(f"{plot_hist}.{channel}.png")

                dbl_score[idx] = unidata.obs["doublet_score"].values
                pred_dbl[idx] = unidata.obs["pred_dbl"].values
                thresholds[channel] = unidata.uns["doublet_threshold"]
            else:
                logger.warning(
                    f"Channel {channel} has {idx.size} < {min_cell} cells and thus doublet score calculation is skipped!"
                )

        data.obs["doublet_score"] = dbl_score
        data.obs["pred_dbl"] = pred_dbl
        data.uns["doublet_thresholds"] = thresholds

    if clust_attr is not None:
        data.uns["pred_dbl_cluster"] = _identify_doublets_fisher(
            data.obs[clust_attr].values,
            data.obs["pred_dbl"].values,
            alpha=alpha)

    logger.info('Doublets are predicted!')
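The docstring above already shows the basic call; as a complement, here is a small hedged sketch of how the fields this function writes (data.obs['doublet_score'], data.obs['pred_dbl'], and data.uns['pred_dbl_cluster'] when clust_attr is set) might be inspected afterwards.

# Assumes infer_doublets(data, channel_attr='Channel', clust_attr='Annotation')
# has already been run on a clustered MultimodalData object.
n_doublets = int(data.obs["pred_dbl"].sum())
print(f"{n_doublets} predicted doublets out of {data.shape[0]} cells")
if "pred_dbl_cluster" in data.uns:  # only present when clust_attr is not None
    print(data.uns["pred_dbl_cluster"])  # clusters significantly enriched for doublets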
Example #3
def run_demuxEM_pipeline(input_adt_file, input_rna_file, output_name, **kwargs):
    # load input data
    adt = io.read_input(input_adt_file, genome="_ADT_")
    print("ADT file is loaded.")
    data = io.read_input(input_rna_file, genome=kwargs["genome"], concat_matrices=True)
    print("RNA file is loaded.")

    # Filter the RNA matrix
    data.obs["n_genes"] = data.X.getnnz(axis=1)
    data.obs["n_counts"] = data.X.sum(axis=1).A1
    obs_index = np.logical_and.reduce(
        (
            data.obs["n_genes"] >= kwargs["min_num_genes"],
            data.obs["n_counts"] >= kwargs["min_num_umis"],
        )
    )
    data._inplace_subset_obs(obs_index)
    data.var["robust"] = True

    # run demuxEM
    demuxEM.estimate_background_probs(adt, random_state=kwargs["random_state"])
    print("Background probability distribution is estimated.")
    demuxEM.demultiplex(
        data,
        adt,
        min_signal=kwargs["min_signal"],
        alpha=kwargs["alpha"],
        n_threads=kwargs["n_jobs"],
    )
    print("Demultiplexing is done.")

    # annotate raw matrix with demuxEM results
    genome_indexed_raw_data = io.read_input(
        input_rna_file, return_type="MemData", concat_matrices=False
    )
    for keyword in genome_indexed_raw_data.listKeys():
        array2d = genome_indexed_raw_data.getData(keyword)
        barcodes = array2d.barcode_metadata.index
        idx = barcodes.isin(data.obs_names)
        selected = barcodes[idx]

        demux_type = np.empty(barcodes.size, dtype="object")
        demux_type[:] = ""
        demux_type[idx] = data.obs.loc[selected, "demux_type"]
        array2d.barcode_metadata["demux_type"] = demux_type

        assignment = np.empty(barcodes.size, dtype="object")
        assignment[:] = ""
        assignment[idx] = data.obs.loc[selected, "assignment"]
        array2d.barcode_metadata["assignment"] = assignment

        if "assignment.dedup" in data.obs:
            assignment_dedup = np.empty(barcodes.size, dtype="object")
            assignment_dedup[:] = ""
            assignment_dedup[idx] = data.obs.loc[selected, "assignment.dedup"]
            array2d.barcode_metadata["assignment.dedup"] = assignment_dedup

    print("Demultiplexing results are added to raw expression matrices.")

    # generate plots
    if kwargs["gen_plots"]:
        demuxEM.plot_adt_hist(
            adt, "hto_type", output_name + ".ambient_hashtag.hist.pdf", alpha=1.0
        )
        demuxEM.plot_bar(
            adt.uns["background_probs"],
            adt.var_names,
            "Sample ID",
            "Background probability",
            output_name + ".background_probabilities.bar.pdf",
        )
        demuxEM.plot_adt_hist(
            adt, "rna_type", output_name + ".real_content.hist.pdf", alpha=0.5
        )
        demuxEM.plot_rna_hist(data, output_name + ".rna_demux.hist.pdf")
        print("Diagnostic plots are generated.")

    if len(kwargs["gen_gender_plot"]) > 0:
        tools.log_norm(data, 1e5)
        for gene_name in kwargs["gen_gender_plot"]:
            demuxEM.plot_violin(
                data,
                {"gene": gene_name},
                "{output_name}.{gene_name}.violin.pdf".format(
                    output_name=output_name, gene_name=gene_name
                ),
                title="{gene_name}: a gender-specific gene".format(gene_name=gene_name),
            )
        print("Gender-specific gene expression violin plots are generated.")

    # output results
    io.write_output(adt, output_name + "_ADTs.h5ad")
    print(
        "Hashtag count information is written to {output_name}_ADTs.h5ad .".format(
            output_name=output_name
        )
    )
    io.write_output(data, output_name + "_demux.h5ad")
    print(
        "Demutiplexed RNA expression information is written to {output_name}_demux.h5ad .".format(
            output_name=output_name
        )
    )
    io.write_output(genome_indexed_raw_data, output_name + "_demux")
    print(
        "Raw pegasus-format hdf5 file with demultiplexing results is written to {output_name}_demux.h5sc .".format(
            output_name=output_name
        )
    )

    # output summary statistics
    print("\nSummary statistics:")
    print("total\t{}".format(data.shape[0]))
    for name, value in data.obs["demux_type"].value_counts().items():
        print("{}\t{}".format(name, value))
Example #4
def analyze_one_modality(unidata: UnimodalData, output_name: str, is_raw: bool,
                         append_data: UnimodalData, **kwargs) -> None:
    print()
    logger.info(f"Begin to analyze UnimodalData {unidata.get_uid()}.")

    if is_raw:
        # normalize counts and then transform to log space
        tools.log_norm(unidata, kwargs["norm_count"])

        # select highly variable features
        standardize = False  # stays False if HVF selection is skipped
        if kwargs["select_hvf"]:
            if unidata.shape[1] <= kwargs["hvf_ngenes"]:
                logger.warning(
                    f"Number of genes {unidata.shape[1]} is no greater than the target number of highly variable features {kwargs['hvf_ngenes']}. HVF selection is omitted."
                )
            else:
                standardize = True
                tools.highly_variable_features(
                    unidata,
                    kwargs["batch_attr"]
                    if kwargs["batch_correction"] else None,
                    flavor=kwargs["hvf_flavor"],
                    n_top=kwargs["hvf_ngenes"],
                    n_jobs=kwargs["n_jobs"],
                )
                if kwargs["hvf_flavor"] == "pegasus":
                    if kwargs["plot_hvf"] is not None:
                        from pegasus.plotting import hvfplot
                        fig = hvfplot(unidata, return_fig=True)
                        fig.savefig(f"{kwargs['plot_hvf']}.hvf.pdf")

        n_pc = min(kwargs["pca_n"], unidata.shape[0], unidata.shape[1])
        if n_pc < kwargs["pca_n"]:
            logger.warning(
                f"UnimodalData {unidata.get_uid()} has either dimension ({unidata.shape[0]}, {unidata.shape[1]}) less than the specified number of PCs {kwargs['pca_n']}. Reduce the number of PCs to {n_pc}."
            )

        # Run PCA regardless of which batch correction method will be applied
        tools.pca(
            unidata,
            n_components=n_pc,
            features="highly_variable_features",
            standardize=standardize,
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"],
        )
        dim_key = "pca"

        if kwargs["nmf"] or (kwargs["batch_correction"]
                             and kwargs["correction_method"] == "inmf"):
            n_nmf = min(kwargs["nmf_n"], unidata.shape[0], unidata.shape[1])
            if n_nmf < kwargs["nmf_n"]:
                logger.warning(
                    f"UnimodalData {unidata.get_uid()} has either dimension ({unidata.shape[0]}, {unidata.shape[1]}) less than the specified number of NMF components {kwargs['nmf_n']}. Reduce the number of NMF components to {n_nmf}."
                )

        if kwargs["nmf"]:
            if kwargs["batch_correction"] and kwargs[
                    "correction_method"] == "inmf":
                logger.warning(
                    "NMF is skipped because integrative NMF is run instead.")
            else:
                tools.nmf(
                    unidata,
                    n_components=n_nmf,
                    features="highly_variable_features",
                    n_jobs=kwargs["n_jobs"],
                    random_state=kwargs["random_state"],
                )

        if kwargs["batch_correction"]:
            if kwargs["correction_method"] == "harmony":
                dim_key = tools.run_harmony(
                    unidata,
                    batch=kwargs["batch_attr"],
                    rep="pca",
                    n_jobs=kwargs["n_jobs"],
                    n_clusters=kwargs["harmony_nclusters"],
                    random_state=kwargs["random_state"])
            elif kwargs["correction_method"] == "inmf":
                dim_key = tools.integrative_nmf(
                    unidata,
                    batch=kwargs["batch_attr"],
                    n_components=n_nmf,
                    features="highly_variable_features",
                    lam=kwargs["inmf_lambda"],
                    n_jobs=kwargs["n_jobs"],
                    random_state=kwargs["random_state"])
            elif kwargs["correction_method"] == "scanorama":
                dim_key = tools.run_scanorama(
                    unidata,
                    batch=kwargs["batch_attr"],
                    n_components=n_pc,
                    features="highly_variable_features",
                    standardize=standardize,
                    random_state=kwargs["random_state"])
            else:
                raise ValueError(
                    f"Unknown batch correction method {kwargs['correction_method']}!"
                )

        # Find K neighbors
        tools.neighbors(
            unidata,
            K=kwargs["K"],
            rep=dim_key,
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"],
            full_speed=kwargs["full_speed"],
        )

    if kwargs["calc_sigscore"] is not None:
        sig_files = kwargs["calc_sigscore"].split(",")
        for sig_file in sig_files:
            tools.calc_signature_score(unidata, sig_file)

    # calculate diffmap
    if (kwargs["fle"] or kwargs["net_fle"]):
        if not kwargs["diffmap"]:
            print("Turn on --diffmap option!")
        kwargs["diffmap"] = True

    if kwargs["diffmap"]:
        tools.diffmap(
            unidata,
            n_components=kwargs["diffmap_ndc"],
            rep=dim_key,
            solver=kwargs["diffmap_solver"],
            max_t=kwargs["diffmap_maxt"],
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"],
        )

    # calculate kBET
    if ("kBET" in kwargs) and kwargs["kBET"]:
        stat_mean, pvalue_mean, accept_rate = tools.calc_kBET(
            unidata,
            kwargs["kBET_batch"],
            rep=dim_key,
            K=kwargs["kBET_K"],
            alpha=kwargs["kBET_alpha"],
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"])
        print(
            "kBET stat_mean = {:.2f}, pvalue_mean = {:.4f}, accept_rate = {:.2%}."
            .format(stat_mean, pvalue_mean, accept_rate))

    # clustering
    if kwargs["spectral_louvain"]:
        tools.cluster(
            unidata,
            algo="spectral_louvain",
            rep=dim_key,
            resolution=kwargs["spectral_louvain_resolution"],
            rep_kmeans=kwargs["spectral_louvain_basis"],
            n_clusters=kwargs["spectral_louvain_nclusters"],
            n_clusters2=kwargs["spectral_louvain_nclusters2"],
            n_init=kwargs["spectral_louvain_ninit"],
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"],
            class_label="spectral_louvain_labels",
        )

    if kwargs["spectral_leiden"]:
        tools.cluster(
            unidata,
            algo="spectral_leiden",
            rep=dim_key,
            resolution=kwargs["spectral_leiden_resolution"],
            rep_kmeans=kwargs["spectral_leiden_basis"],
            n_clusters=kwargs["spectral_leiden_nclusters"],
            n_clusters2=kwargs["spectral_leiden_nclusters2"],
            n_init=kwargs["spectral_leiden_ninit"],
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"],
            class_label="spectral_leiden_labels",
        )

    if kwargs["louvain"]:
        tools.cluster(
            unidata,
            algo="louvain",
            rep=dim_key,
            resolution=kwargs["louvain_resolution"],
            random_state=kwargs["random_state"],
            class_label=kwargs["louvain_class_label"],
        )

    if kwargs["leiden"]:
        tools.cluster(
            unidata,
            algo="leiden",
            rep=dim_key,
            resolution=kwargs["leiden_resolution"],
            n_iter=kwargs["leiden_niter"],
            random_state=kwargs["random_state"],
            class_label=kwargs["leiden_class_label"],
        )

    # visualization
    if kwargs["net_umap"]:
        tools.net_umap(
            unidata,
            rep=dim_key,
            n_jobs=kwargs["n_jobs"],
            n_neighbors=kwargs["umap_K"],
            min_dist=kwargs["umap_min_dist"],
            spread=kwargs["umap_spread"],
            random_state=kwargs["random_state"],
            select_frac=kwargs["net_ds_frac"],
            select_K=kwargs["net_ds_K"],
            select_alpha=kwargs["net_ds_alpha"],
            full_speed=kwargs["full_speed"],
            net_alpha=kwargs["net_l2"],
            polish_learning_rate=kwargs["net_umap_polish_learing_rate"],
            polish_n_epochs=kwargs["net_umap_polish_nepochs"],
            out_basis=kwargs["net_umap_basis"],
        )

    if kwargs["net_fle"]:
        tools.net_fle(
            unidata,
            output_name,
            n_jobs=kwargs["n_jobs"],
            K=kwargs["fle_K"],
            full_speed=kwargs["full_speed"],
            target_change_per_node=kwargs["fle_target_change_per_node"],
            target_steps=kwargs["fle_target_steps"],
            is3d=False,
            memory=kwargs["fle_memory"],
            random_state=kwargs["random_state"],
            select_frac=kwargs["net_ds_frac"],
            select_K=kwargs["net_ds_K"],
            select_alpha=kwargs["net_ds_alpha"],
            net_alpha=kwargs["net_l2"],
            polish_target_steps=kwargs["net_fle_polish_target_steps"],
            out_basis=kwargs["net_fle_basis"],
        )

    if kwargs["tsne"]:
        tools.tsne(
            unidata,
            rep=dim_key,
            n_jobs=kwargs["n_jobs"],
            perplexity=kwargs["tsne_perplexity"],
            random_state=kwargs["random_state"],
            initialization=kwargs["tsne_init"],
        )

    if kwargs["umap"]:
        tools.umap(
            unidata,
            rep=dim_key,
            n_neighbors=kwargs["umap_K"],
            min_dist=kwargs["umap_min_dist"],
            spread=kwargs["umap_spread"],
            n_jobs=kwargs["n_jobs"],
            full_speed=kwargs["full_speed"],
            random_state=kwargs["random_state"],
        )

    if kwargs["fle"]:
        tools.fle(
            unidata,
            output_name,
            n_jobs=kwargs["n_jobs"],
            K=kwargs["fle_K"],
            full_speed=kwargs["full_speed"],
            target_change_per_node=kwargs["fle_target_change_per_node"],
            target_steps=kwargs["fle_target_steps"],
            is3d=False,
            memory=kwargs["fle_memory"],
            random_state=kwargs["random_state"],
        )

    if kwargs["infer_doublets"]:
        channel_attr = "Channel"
        if (channel_attr not in unidata.obs) or (
                unidata.obs["Channel"].cat.categories.size == 1):
            channel_attr = None
        clust_attr = kwargs["dbl_cluster_attr"]
        if (clust_attr is None) or (clust_attr not in unidata.obs):
            clust_attr = None
            for value in [
                    "leiden_labels", "louvain_labels",
                    "spectral_leiden_labels", "spectral_louvain_labels"
            ]:
                if value in unidata.obs:
                    clust_attr = value
                    break

        if channel_attr is not None:
            logger.info(f"For doublet inference, channel_attr={channel_attr}.")
        if clust_attr is not None:
            logger.info(f"For doublet inference, clust_attr={clust_attr}.")

        tools.infer_doublets(
            unidata,
            channel_attr=channel_attr,
            clust_attr=clust_attr,
            expected_doublet_rate=kwargs["expected_doublet_rate"],
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"],
            plot_hist=output_name)

        dbl_clusts = None
        if clust_attr is not None:
            clusts = []
            for idx, row in unidata.uns["pred_dbl_cluster"].iterrows():
                if row["percentage"] >= 50.0:
                    logger.info(
                        f"Cluster {row['cluster']} (percentage={row['percentage']:.2f}%, q-value={row['qval']:.6g}) is identified as a doublet cluster."
                    )
                    clusts.append(row["cluster"])
            if len(clusts) > 0:
                dbl_clusts = f"{clust_attr}:{','.join(clusts)}"

        tools.mark_doublets(unidata, dbl_clusts=dbl_clusts)

    # calculate diffusion-based pseudotime from roots
    if len(kwargs["pseudotime"]) > 0:
        tools.calc_pseudotime(unidata, kwargs["pseudotime"])

    genome = unidata.uns["genome"]

    if append_data is not None:
        locs = unidata.obs_names.get_indexer(append_data.obs_names)
        idx = locs >= 0
        locs = locs[idx]
        Y = append_data.X[idx, :].tocoo(copy=False)
        Z = coo_matrix((Y.data, (locs[Y.row], Y.col)),
                       shape=(unidata.shape[0], append_data.shape[1])).tocsr()

        idy = Z.getnnz(axis=0) > 0
        n_nonzero = idy.sum()
        if n_nonzero > 0:
            if n_nonzero < append_data.shape[1]:
                Z = Z[:, idy]
                append_df = append_data.feature_metadata.loc[idy, :]
            else:
                append_df = append_data.feature_metadata

            if kwargs["citeseq"]:
                append_df = append_df.copy()
                append_df.index = append_df.index.map(lambda x: f"Ab-{x}")

            rawX = hstack([unidata.get_matrix("counts"), Z], format="csr")

            Zt = Z.astype(np.float32)
            if not kwargs["citeseq"]:
                Zt.data *= np.repeat(unidata.obs["scale"].values,
                                     np.diff(Zt.indptr))
                Zt.data = np.log1p(Zt.data)
            else:
                Zt.data = np.arcsinh(Zt.data / 5.0, dtype=np.float32)

            X = hstack([unidata.get_matrix(unidata.current_matrix()), Zt],
                       format="csr")

            new_genome = unidata.get_genome()
            if new_genome != append_data.get_genome():
                new_genome = f"{new_genome}_and_{append_data.get_genome()}"

            feature_metadata = pd.concat([unidata.feature_metadata, append_df],
                                         axis=0)
            feature_metadata.reset_index(inplace=True)
            _fillna(feature_metadata)
            unidata = UnimodalData(
                unidata.barcode_metadata, feature_metadata, {
                    unidata.current_matrix(): X,
                    "counts": rawX
                }, unidata.uns.mapping, unidata.obsm.mapping,
                unidata.varm.mapping
            )  # uns.mapping, obsm.mapping and varm.mapping are passed by reference
            unidata.uns["genome"] = new_genome

            if kwargs["citeseq"] and kwargs["citeseq_umap"]:
                umap_index = append_df.index.difference(
                    [f"Ab-{x}" for x in kwargs["citeseq_umap_exclude"]])
                unidata.obsm["X_citeseq"] = unidata.X[:,
                                                      unidata.var_names.
                                                      isin(umap_index
                                                           )].toarray()
                tools.umap(
                    unidata,
                    rep="citeseq",
                    n_neighbors=kwargs["umap_K"],
                    min_dist=kwargs["umap_min_dist"],
                    spread=kwargs["umap_spread"],
                    n_jobs=kwargs["n_jobs"],
                    full_speed=kwargs["full_speed"],
                    random_state=kwargs["random_state"],
                    out_basis="citeseq_umap",
                )

    if kwargs["output_h5ad"]:
        import time
        start_time = time.perf_counter()
        adata = unidata.to_anndata()
        if "_tmp_fmat_highly_variable_features" in adata.uns:
            adata.uns["scale.data"] = adata.uns.pop(
                "_tmp_fmat_highly_variable_features")  # assign by reference
            adata.uns["scale.data.rownames"] = unidata.var_names[
                unidata.var["highly_variable_features"] == True].values
        adata.write(f"{output_name}.h5ad", compression="gzip")
        del adata
        end_time = time.perf_counter()
        logger.info(
            f"H5AD file {output_name}.h5ad is written. Time spent = {end_time - start_time:.2f}s."
        )

    # write out results
    if kwargs["output_loom"]:
        write_output(unidata, f"{output_name}.loom")

    # Change genome name back if it was modified when appending append_data
    if unidata.uns["genome"] != genome:
        unidata.uns["genome"] = genome
    # Eliminate objects starting with _tmp from uns
    unidata.uns.pop("_tmp_fmat_highly_variable_features", None)
Example #5
def infer_doublets(
    data: MultimodalData,
    channel_attr: Optional[str] = None,
    clust_attr: Optional[str] = None,
    raw_mat_key: Optional[str] = 'counts',
    min_cell: Optional[int] = 100,
    expected_doublet_rate: Optional[float] = None,
    sim_doublet_ratio: Optional[float] = 2.0,
    n_prin_comps: Optional[int] = 30,
    k: Optional[int] = None,
    n_jobs: Optional[int] = -1,
    alpha: Optional[float] = 0.05,
    random_state: Optional[int] = 0,
    plot_hist: Optional[str] = "sample",
    manual_correction: Optional[str] = None,
) -> None:
    """Infer doublets by first calculating Scrublet-like [Wolock18]_ doublet scores and then smartly determining an appropriate doublet score cutoff [Li20-2]_ .

    This function should be called after clustering if clust_attr is not None. In this case, we will test if each cluster is significantly enriched for doublets using Fisher's exact test.

    Parameters
    ----------
    data: ``pegasusio.MultimodalData``
        Annotated data matrix with rows for cells and columns for genes.

    channel_attr: ``str``, optional, default: None
        Attribute indicating sample channels. If set, calculate scrublet-like doublet scores per channel.

    clust_attr: ``str``, optional, default: None
        Attribute indicating cluster labels. If set, estimate proportion of doublets in each cluster and statistical significance.

    min_cell: ``int``, optional, default: 100
        Minimum number of cells per sample required to calculate doublet scores. For samples with fewer than 'min_cell' cells, doublet score calculation is skipped.

    expected_doublet_rate: ``float``, optional, default: ``None``
        The expected doublet rate for the experiment. By default, the expected rate is calculated from the number of cells using the 10x multiplet rate table.

    sim_doublet_ratio: ``float``, optional, default: ``2.0``
        The ratio between synthetic doublets and observed cells.

    n_prin_comps: ``int``, optional, default: ``30``
        Number of principal components.

    k: ``int``, optional, default: ``None``
        Number of observed cell neighbors. If None, k = round(0.5 * sqrt(number of observed cells)). Total neighbors k_adj = round(k * (1.0 + sim_doublet_ratio)).

    n_jobs: ``int``, optional, default: ``-1``
        Number of threads to use. If ``-1``, use all physical CPU cores.

    alpha: ``float``, optional, default: ``0.05``
        FDR significance level for the cluster-level Fisher's exact test.

    random_state: ``int``, optional, default: ``0``
        Random seed for reproducing results.

    plot_hist: ``str``, optional, default: ``sample``
        If not None, plot diagnostic histograms using ``plot_hist`` as the prefix. If `channel_attr` is None, ``plot_hist.dbl.png`` is generated; otherwise, ``plot_hist.channel_name.dbl.png`` files are generated. Each figure consists of 4 panels showing histograms of doublet scores for observed cells (panel 1, density in log scale), simulated doublets (panel 2, density in log scale), KDE plot (panel 3) and signed curvature plot (panel 4) of log doublet scores for simulated doublets. Each plot contains two dashed lines. The red dashed line represents the theoretical cutoff (calculated based on number of cells and 10x doublet table) and the black dashed line represents the cutoff inferred from the data.
    
    manual_correction: ``str``, optional, default: ``None``
        Use human guidance to correct the doublet threshold for certain channels. This is a string representing a comma-separated list. Each item in the list represents one sample; the sample name and correction guide are separated by ':'. The only correction guide supported is 'peak', which means cut at the center of the peak. If only one sample is available, use '' as the sample name.

    Returns
    -------
    ``None``

    Update ``data.obs``:
        * ``data.obs['doublet_score']``: Doublet scores.

        * ``data.obs['pred_dbl']``: Predicted singlet/doublet types.

    Update ``data.uns``:
        * ``data.uns['pred_dbl_cluster']``: Only generated if 'clust_attr' is not None. This is a dataframe with two columns, 'Cluster' and 'Qval'. Only clusters with significantly more doublets than expected will be recorded here.

    Examples
    --------
    >>> pg.infer_doublets(data, channel_attr = 'Channel', clust_attr = 'Annotation')
    """
    assert data.get_modality() == "rna"
    try:
        rawX = data.get_matrix(raw_mat_key)
    except ValueError:
        raise ValueError(
            f"Cannot detect the raw count matrix {raw_mat_key}; stop inferring doublets!"
        )

    if_plot = plot_hist is not None

    mancor = {}
    if manual_correction is not None:
        for item in manual_correction.split(','):
            name, action = item.split(':')
            mancor[name] = action

    if channel_attr is None:
        if data.shape[0] >= min_cell:
            fig = _run_scrublet(data, raw_mat_key, expected_doublet_rate = expected_doublet_rate, sim_doublet_ratio = sim_doublet_ratio, \
                                n_prin_comps = n_prin_comps, k = k, n_jobs = n_jobs, random_state = random_state, plot_hist = if_plot, manual_correction = mancor.get('', None))
            if if_plot:
                fig.savefig(f"{plot_hist}.dbl.png")
        else:
            logger.warning(
                f"Data has {data.shape[0]} < {min_cell} cells and thus doublet score calculation is skipped!"
            )
            data.obs["doublet_score"] = 0.0
            data.obs["pred_dbl"] = False
    else:
        from pandas.api.types import is_categorical_dtype
        from pegasus.tools import identify_robust_genes, log_norm, highly_variable_features

        assert is_categorical_dtype(data.obs[channel_attr])
        genome = data.get_genome()
        modality = data.get_modality()
        channels = data.obs[channel_attr].cat.categories

        dbl_score = np.zeros(data.shape[0], dtype=np.float32)
        pred_dbl = np.zeros(data.shape[0], dtype=np.bool_)
        thresholds = {}
        for channel in channels:
            # Generate a new unidata object for the channel
            idx = np.where(data.obs[channel_attr] == channel)[0]
            if idx.size >= min_cell:
                unidata = UnimodalData({"barcodekey": data.obs_names[idx]},
                                       {"featurekey": data.var_names},
                                       {"counts": rawX[idx]}, {
                                           "genome": genome,
                                           "modality": modality
                                       },
                                       cur_matrix="counts")
                # Identify robust genes, log-normalize counts, and select the top 2,000 highly variable features
                identify_robust_genes(unidata)
                log_norm(unidata)
                highly_variable_features(unidata)
                # Run _run_scrublet
                fig = _run_scrublet(unidata, raw_mat_key, name = channel, expected_doublet_rate = expected_doublet_rate, sim_doublet_ratio = sim_doublet_ratio, \
                                    n_prin_comps = n_prin_comps, k = k, n_jobs = n_jobs, random_state = random_state, plot_hist = if_plot, manual_correction = mancor.get(channel, None))
                if if_plot:
                    fig.savefig(f"{plot_hist}.{channel}.dbl.png")

                dbl_score[idx] = unidata.obs["doublet_score"].values
                pred_dbl[idx] = unidata.obs["pred_dbl"].values
                thresholds[channel] = unidata.uns["doublet_threshold"]
            else:
                logger.warning(
                    f"Channel {channel} has {idx.size} < {min_cell} cells and thus doublet score calculation is skipped!"
                )

        data.obs["doublet_score"] = dbl_score
        data.obs["pred_dbl"] = pred_dbl
        data.uns["doublet_thresholds"] = thresholds

    if clust_attr is not None:
        data.uns["pred_dbl_cluster"] = _identify_doublets_fisher(
            data.obs[clust_attr].values,
            data.obs["pred_dbl"].values,
            alpha=alpha)

    logger.info('Doublets are predicted!')
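As the docstring explains, manual_correction is parsed as a comma-separated list of sample:guide pairs. A hypothetical example follows (the channel names are made up).

# Hypothetical channel names; format is "<sample>:<guide>[,<sample>:<guide>...]".
correction = "donor_1:peak,donor_3:peak"
# pg.infer_doublets(data, channel_attr="Channel", clust_attr="Annotation",
#                   manual_correction=correction)
# With a single unnamed sample, use '' as the sample name, i.e. manual_correction=":peak".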
Example #6
def analyze_one_modality(unidata: UnimodalData, output_name: str, is_raw: bool,
                         append_data: UnimodalData, **kwargs) -> None:
    print()
    logger.info(f"Begin to analyze UnimodalData {unidata.get_uid()}.")
    if kwargs["channel_attr"] is not None:
        unidata.obs["Channel"] = unidata.obs[kwargs["channel_attr"]]

    if is_raw:
        # normalize counts and then transform to log space
        tools.log_norm(unidata, kwargs["norm_count"])
        # set group attribute
        if kwargs["batch_correction"] and kwargs["group_attribute"] is not None:
            tools.set_group_attribute(unidata, kwargs["group_attribute"])

    # select highly variable features
    standardize = False  # stays False if HVF selection is skipped
    if kwargs["select_hvf"]:
        if unidata.shape[1] <= kwargs["hvf_ngenes"]:
            logger.warning(
                f"Number of genes {unidata.shape[1]} is no greater than the target number of highly variable features {kwargs['hvf_ngenes']}. HVF selection is omitted."
            )
        else:
            standardize = True
            tools.highly_variable_features(
                unidata,
                kwargs["batch_correction"],
                flavor=kwargs["hvf_flavor"],
                n_top=kwargs["hvf_ngenes"],
                n_jobs=kwargs["n_jobs"],
            )
            if kwargs["hvf_flavor"] == "pegasus":
                if kwargs["plot_hvf"] is not None:
                    from pegasus.plotting import hvfplot
                    fig = hvfplot(unidata, return_fig=True)
                    fig.savefig(f"{kwargs['plot_hvf']}.hvf.pdf")

    # batch correction: L/S
    if kwargs["batch_correction"] and kwargs["correction_method"] == "L/S":
        tools.correct_batch(unidata, features="highly_variable_features")

    if kwargs["calc_sigscore"] is not None:
        sig_files = kwargs["calc_sigscore"].split(",")
        for sig_file in sig_files:
            tools.calc_signature_score(unidata, sig_file)

    n_pc = min(kwargs["pca_n"], unidata.shape[0], unidata.shape[1])
    if n_pc < kwargs["pca_n"]:
        logger.warning(
            f"UnimodalData {unidata.get_uid()} has either dimension ({unidata.shape[0]}, {unidata.shape[1]}) less than the specified number of PCs {kwargs['pca_n']}. Reduce the number of PCs to {n_pc}."
        )

    if kwargs["batch_correction"] and kwargs[
            "correction_method"] == "scanorama":
        pca_key = tools.run_scanorama(unidata,
                                      n_components=n_pc,
                                      features="highly_variable_features",
                                      standardize=standardize,
                                      random_state=kwargs["random_state"])
    else:
        # PCA
        tools.pca(
            unidata,
            n_components=n_pc,
            features="highly_variable_features",
            standardize=standardize,
            robust=kwargs["pca_robust"],
            random_state=kwargs["random_state"],
        )
        pca_key = "pca"

    # batch correction: Harmony
    if kwargs["batch_correction"] and kwargs["correction_method"] == "harmony":
        pca_key = tools.run_harmony(unidata,
                                    rep="pca",
                                    n_jobs=kwargs["n_jobs"],
                                    n_clusters=kwargs["harmony_nclusters"],
                                    random_state=kwargs["random_state"])

    # Find K neighbors
    tools.neighbors(
        unidata,
        K=kwargs["K"],
        rep=pca_key,
        n_jobs=kwargs["n_jobs"],
        random_state=kwargs["random_state"],
        full_speed=kwargs["full_speed"],
    )

    # calculate diffmap
    if (kwargs["fle"] or kwargs["net_fle"]):
        if not kwargs["diffmap"]:
            print("Turn on --diffmap option!")
        kwargs["diffmap"] = True

    if kwargs["diffmap"]:
        tools.diffmap(
            unidata,
            n_components=kwargs["diffmap_ndc"],
            rep=pca_key,
            solver=kwargs["diffmap_solver"],
            random_state=kwargs["random_state"],
            max_t=kwargs["diffmap_maxt"],
        )
        if kwargs["diffmap_to_3d"]:
            tools.reduce_diffmap_to_3d(unidata,
                                       random_state=kwargs["random_state"])

    # calculate kBET
    if ("kBET" in kwargs) and kwargs["kBET"]:
        stat_mean, pvalue_mean, accept_rate = tools.calc_kBET(
            unidata,
            kwargs["kBET_batch"],
            rep=pca_key,
            K=kwargs["kBET_K"],
            alpha=kwargs["kBET_alpha"],
            n_jobs=kwargs["n_jobs"],
            random_state=kwargs["random_state"])
        print(
            "kBET stat_mean = {:.2f}, pvalue_mean = {:.4f}, accept_rate = {:.2%}."
            .format(stat_mean, pvalue_mean, accept_rate))

    # clustering
    if kwargs["spectral_louvain"]:
        tools.cluster(
            unidata,
            algo="spectral_louvain",
            rep=pca_key,
            resolution=kwargs["spectral_louvain_resolution"],
            rep_kmeans=kwargs["spectral_louvain_basis"],
            n_clusters=kwargs["spectral_louvain_nclusters"],
            n_clusters2=kwargs["spectral_louvain_nclusters2"],
            n_init=kwargs["spectral_louvain_ninit"],
            random_state=kwargs["random_state"],
            class_label="spectral_louvain_labels",
        )

    if kwargs["spectral_leiden"]:
        tools.cluster(
            unidata,
            algo="spectral_leiden",
            rep=pca_key,
            resolution=kwargs["spectral_leiden_resolution"],
            rep_kmeans=kwargs["spectral_leiden_basis"],
            n_clusters=kwargs["spectral_leiden_nclusters"],
            n_clusters2=kwargs["spectral_leiden_nclusters2"],
            n_init=kwargs["spectral_leiden_ninit"],
            random_state=kwargs["random_state"],
            class_label="spectral_leiden_labels",
        )

    if kwargs["louvain"]:
        tools.cluster(
            unidata,
            algo="louvain",
            rep=pca_key,
            resolution=kwargs["louvain_resolution"],
            random_state=kwargs["random_state"],
            class_label=kwargs["louvain_class_label"],
        )

    if kwargs["leiden"]:
        tools.cluster(
            unidata,
            algo="leiden",
            rep=pca_key,
            resolution=kwargs["leiden_resolution"],
            n_iter=kwargs["leiden_niter"],
            random_state=kwargs["random_state"],
            class_label=kwargs["leiden_class_label"],
        )

    # visualization
    if kwargs["net_tsne"]:
        tools.net_tsne(
            unidata,
            rep=pca_key,
            n_jobs=kwargs["n_jobs"],
            perplexity=kwargs["tsne_perplexity"],
            random_state=kwargs["random_state"],
            select_frac=kwargs["net_ds_frac"],
            select_K=kwargs["net_ds_K"],
            select_alpha=kwargs["net_ds_alpha"],
            net_alpha=kwargs["net_l2"],
            polish_learning_frac=kwargs["net_tsne_polish_learing_frac"],
            polish_n_iter=kwargs["net_tsne_polish_niter"],
            out_basis=kwargs["net_tsne_basis"],
        )

    if kwargs["net_umap"]:
        tools.net_umap(
            unidata,
            rep=pca_key,
            n_jobs=kwargs["n_jobs"],
            n_neighbors=kwargs["umap_K"],
            min_dist=kwargs["umap_min_dist"],
            spread=kwargs["umap_spread"],
            random_state=kwargs["random_state"],
            select_frac=kwargs["net_ds_frac"],
            select_K=kwargs["net_ds_K"],
            select_alpha=kwargs["net_ds_alpha"],
            full_speed=kwargs["full_speed"],
            net_alpha=kwargs["net_l2"],
            polish_learning_rate=kwargs["net_umap_polish_learing_rate"],
            polish_n_epochs=kwargs["net_umap_polish_nepochs"],
            out_basis=kwargs["net_umap_basis"],
        )

    if kwargs["net_fle"]:
        tools.net_fle(
            unidata,
            output_name,
            n_jobs=kwargs["n_jobs"],
            K=kwargs["fle_K"],
            full_speed=kwargs["full_speed"],
            target_change_per_node=kwargs["fle_target_change_per_node"],
            target_steps=kwargs["fle_target_steps"],
            is3d=False,
            memory=kwargs["fle_memory"],
            random_state=kwargs["random_state"],
            select_frac=kwargs["net_ds_frac"],
            select_K=kwargs["net_ds_K"],
            select_alpha=kwargs["net_ds_alpha"],
            net_alpha=kwargs["net_l2"],
            polish_target_steps=kwargs["net_fle_polish_target_steps"],
            out_basis=kwargs["net_fle_basis"],
        )

    if kwargs["tsne"]:
        tools.tsne(
            unidata,
            rep=pca_key,
            n_jobs=kwargs["n_jobs"],
            perplexity=kwargs["tsne_perplexity"],
            random_state=kwargs["random_state"],
        )

    if kwargs["fitsne"]:
        tools.fitsne(
            unidata,
            rep=pca_key,
            n_jobs=kwargs["n_jobs"],
            perplexity=kwargs["tsne_perplexity"],
            random_state=kwargs["random_state"],
        )

    if kwargs["umap"]:
        tools.umap(
            unidata,
            rep=pca_key,
            n_neighbors=kwargs["umap_K"],
            min_dist=kwargs["umap_min_dist"],
            spread=kwargs["umap_spread"],
            random_state=kwargs["random_state"],
        )

    if kwargs["fle"]:
        tools.fle(
            unidata,
            output_name,
            n_jobs=kwargs["n_jobs"],
            K=kwargs["fle_K"],
            full_speed=kwargs["full_speed"],
            target_change_per_node=kwargs["fle_target_change_per_node"],
            target_steps=kwargs["fle_target_steps"],
            is3d=False,
            memory=kwargs["fle_memory"],
            random_state=kwargs["random_state"],
        )

    # calculate diffusion-based pseudotime from roots
    if len(kwargs["pseudotime"]) > 0:
        tools.calc_pseudotime(unidata, kwargs["pseudotime"])

    genome = unidata.uns["genome"]

    if append_data is not None:
        locs = unidata.obs_names.get_indexer(append_data.obs_names)
        idx = locs >= 0
        locs = locs[idx]
        Y = append_data.X[idx, :].tocoo(copy=False)
        Z = coo_matrix((Y.data, (locs[Y.row], Y.col)),
                       shape=(unidata.shape[0], append_data.shape[1])).tocsr()

        idy = Z.getnnz(axis=0) > 0
        n_nonzero = idy.sum()
        if n_nonzero > 0:
            if n_nonzero < append_data.shape[1]:
                Z = Z[:, idy]
                append_df = append_data.feature_metadata.loc[idy, :]
            else:
                append_df = append_data.feature_metadata

            rawX = hstack([unidata.get_matrix("raw.X"), Z], format="csr")

            Zt = Z.astype(np.float32)
            Zt.data *= np.repeat(unidata.obs["scale"].values,
                                 np.diff(Zt.indptr))
            Zt.data = np.log1p(Zt.data)

            X = hstack([unidata.get_matrix("X"), Zt], format="csr")

            new_genome = unidata.get_genome() + "_and_" + append_data.get_genome()

            feature_metadata = pd.concat([unidata.feature_metadata, append_df],
                                         axis=0)
            feature_metadata.reset_index(inplace=True)
            feature_metadata.fillna(value=_get_fillna_dict(
                unidata.feature_metadata),
                                    inplace=True)

            unidata = UnimodalData(
                unidata.barcode_metadata, feature_metadata, {
                    "X": X,
                    "raw.X": rawX
                }, unidata.uns.mapping, unidata.obsm.mapping,
                unidata.varm.mapping
            )  # uns.mapping, obsm.mapping and varm.mapping are passed by reference
            unidata.uns["genome"] = new_genome

    if kwargs["output_h5ad"]:
        adata = unidata.to_anndata()
        adata.uns["scale.data"] = adata.uns.pop(
            "_tmp_fmat_highly_variable_features")  # assign by reference
        adata.uns["scale.data.rownames"] = unidata.var_names[
            unidata.var["highly_variable_features"]].values
        adata.write(f"{output_name}.h5ad", compression="gzip")
        del adata

    # write out results
    if kwargs["output_loom"]:
        write_output(unidata, f"{output_name}.loom")

    # Change genome name back if it was modified when appending append_data
    if unidata.uns["genome"] != genome:
        unidata.uns["genome"] = genome
    # Eliminate objects starting with _tmp_ from uns
    unidata.uns.pop("_tmp_fmat_highly_variable_features", None)