コード例 #1
0
 def __init__(self, fileName, cellLabel, cellLabelColor, rawCount = True):
     """Load an MTX expression matrix, annotate its cells and take a first backup.

     Parameters
     ----------
     fileName : str
         Matrix file, read via ``st.read(..., file_format='mtx')``.
     cellLabel : str
         Cell-label file handed to ``st.add_cell_labels``.
     cellLabelColor : str
         Label-colour file handed to ``st.add_cell_colors``.
     rawCount : bool, default True
         If truthy, the matrix is normalised per cell and log-transformed
         after mitochondrial genes are removed.

     Side effects
     ------------
     Sets ``adata``, ``allCells``, ``allGenes``, ``nCells``, ``nGenes``,
     ``backupDict`` and ``backupKey`` on the instance. Calls
     ``self._keepCurrentRecords()`` and ``self.backup(0)``, and prints
     progress messages.
     """
     # Parse the matrix and attach per-cell labels and colours.
     data = st.read(file_name=fileName, file_format='mtx')
     self.adata = data
     st.add_cell_labels(data, file_name=cellLabel)
     st.add_cell_colors(data, file_name=cellLabelColor)
     # De-duplicate gene (var) and cell (obs) names in place.
     data.var_names_make_unique()
     data.obs_names_make_unique()

     # Snapshot the complete cell/gene inventory before any filtering below.
     self.allCells = data.obs.index.to_list()
     self.allGenes = data.var.index.to_list()
     self.nCells, self.nGenes = data.n_obs, data.n_vars
     print('Raw input parsed...')
     print(data)

     self._keepCurrentRecords()
     # NOTE(review): self.adata is re-read from here on in case
     # _keepCurrentRecords rebinds it -- its body is not visible here.
     st.remove_mt_genes(self.adata)
     if rawCount:
         # Raw counts: normalise each cell, then log-transform.
         st.normalize_per_cell(self.adata)
         st.log_transform(self.adata)

     self.backupDict, self.backupKey = {}, 0
     self.backup(0)
     for message in ('Initial backup saved with key: 0',
                     'Restore with self.restoreFromBackup()'):
         print(message)
コード例 #2
0
def main():
    sns.set_style('white')
    sns.set_context('poster')
    parser = argparse.ArgumentParser(
        description='%s Parameters' % __tool_name__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-m",
                        "--matrix",
                        dest="input_filename",
                        default=None,
                        help="input file name",
                        metavar="FILE")
    parser.add_argument("-l",
                        "--cell_labels",
                        dest="cell_label_filename",
                        default=None,
                        help="filename of cell labels")
    parser.add_argument("-c",
                        "--cell_labels_colors",
                        dest="cell_label_color_filename",
                        default=None,
                        help="filename of cell label colors")
    parser.add_argument(
        "-s",
        "--select_features",
        dest="s_method",
        default='LOESS',
        help=
        "LOESS,PCA or all: Select variable genes using LOESS or principal components using PCA or all the genes are kept"
    )
    parser.add_argument("--TG",
                        "--detect_TG_genes",
                        dest="flag_gene_TG_detection",
                        action="store_true",
                        help="detect transition genes automatically")
    parser.add_argument("--DE",
                        "--detect_DE_genes",
                        dest="flag_gene_DE_detection",
                        action="store_true",
                        help="detect DE genes automatically")
    parser.add_argument("--LG",
                        "--detect_LG_genes",
                        dest="flag_gene_LG_detection",
                        action="store_true",
                        help="detect leaf genes automatically")
    parser.add_argument(
        "-g",
        "--genes",
        dest="genes",
        default=None,
        help=
        "genes to visualize, it can either be filename which contains all the genes in one column or a set of gene names separated by comma"
    )
    parser.add_argument(
        "-p",
        "--use_precomputed",
        dest="use_precomputed",
        action="store_true",
        help=
        "use precomputed data files without re-computing structure learning part"
    )
    parser.add_argument("--new",
                        dest="new_filename",
                        default=None,
                        help="file name of data to be mapped")
    parser.add_argument("--new_l",
                        dest="new_label_filename",
                        default=None,
                        help="filename of new cell labels")
    parser.add_argument("--new_c",
                        dest="new_label_color_filename",
                        default=None,
                        help="filename of new cell label colors")
    parser.add_argument("--log2",
                        dest="flag_log2",
                        action="store_true",
                        help="perform log2 transformation")
    parser.add_argument("--norm",
                        dest="flag_norm",
                        action="store_true",
                        help="normalize data based on library size")
    parser.add_argument("--atac",
                        dest="flag_atac",
                        action="store_true",
                        help="indicate scATAC-seq data")
    parser.add_argument(
        "--n_jobs",
        dest="n_jobs",
        type=int,
        default=1,
        help="Specify the number of processes to use. (default, 1")
    parser.add_argument(
        "--loess_frac",
        dest="loess_frac",
        type=float,
        default=0.1,
        help="The fraction of the data used in LOESS regression")
    parser.add_argument(
        "--loess_cutoff",
        dest="loess_cutoff",
        type=int,
        default=95,
        help=
        "the percentile used in variable gene selection based on LOESS regression"
    )
    parser.add_argument("--pca_first_PC",
                        dest="flag_first_PC",
                        action="store_true",
                        help="keep first PC")
    parser.add_argument("--pca_n_PC",
                        dest="pca_n_PC",
                        type=int,
                        default=15,
                        help="The number of selected PCs,it's 15 by default")
    parser.add_argument(
        "--dr_method",
        dest="dr_method",
        default='se',
        help=
        "Method used for dimension reduction. Choose from {{'se','mlle','umap','pca'}}"
    )
    parser.add_argument("--n_neighbors",
                        dest="n_neighbors",
                        type=float,
                        default=50,
                        help="The number of neighbor cells")
    parser.add_argument(
        "--nb_pct",
        dest="nb_pct",
        type=float,
        default=None,
        help=
        "The percentage of neighbor cells (when sepcified, it will overwrite n_neighbors)."
    )
    parser.add_argument("--n_components",
                        dest="n_components",
                        type=int,
                        default=3,
                        help="Number of components to keep.")
    parser.add_argument(
        "--clustering",
        dest="clustering",
        default='kmeans',
        help=
        "Clustering method used for seeding the intial structure, choose from 'ap','kmeans','sc'"
    )
    parser.add_argument("--damping",
                        dest="damping",
                        type=float,
                        default=0.75,
                        help="Affinity Propagation: damping factor")
    parser.add_argument(
        "--n_clusters",
        dest="n_clusters",
        type=int,
        default=10,
        help="Number of clusters for spectral clustering or kmeans")
    parser.add_argument("--EPG_n_nodes",
                        dest="EPG_n_nodes",
                        type=int,
                        default=50,
                        help=" Number of nodes for elastic principal graph")
    parser.add_argument(
        "--EPG_lambda",
        dest="EPG_lambda",
        type=float,
        default=0.02,
        help="lambda parameter used to compute the elastic energy")
    parser.add_argument("--EPG_mu",
                        dest="EPG_mu",
                        type=float,
                        default=0.1,
                        help="mu parameter used to compute the elastic energy")
    parser.add_argument(
        "--EPG_trimmingradius",
        dest="EPG_trimmingradius",
        type=float,
        default=np.inf,
        help="maximal distance of point from a node to affect its embedment")
    parser.add_argument(
        "--EPG_alpha",
        dest="EPG_alpha",
        type=float,
        default=0.02,
        help=
        "positive numeric, the value of the alpha parameter of the penalized elastic energy"
    )
    parser.add_argument("--EPG_collapse",
                        dest="flag_EPG_collapse",
                        action="store_true",
                        help="collapsing small branches")
    parser.add_argument(
        "--EPG_collapse_mode",
        dest="EPG_collapse_mode",
        default="PointNumber",
        help=
        "the mode used to collapse branches. PointNumber,PointNumber_Extrema, PointNumber_Leaves,EdgesNumber or EdgesLength"
    )
    parser.add_argument(
        "--EPG_collapse_par",
        dest="EPG_collapse_par",
        type=float,
        default=5,
        help=
        "positive numeric, the cotrol paramter used for collapsing small branches"
    )
    parser.add_argument("--disable_EPG_optimize",
                        dest="flag_disable_EPG_optimize",
                        action="store_true",
                        help="disable optimizing branching")
    parser.add_argument("--EPG_shift",
                        dest="flag_EPG_shift",
                        action="store_true",
                        help="shift branching point ")
    parser.add_argument(
        "--EPG_shift_mode",
        dest="EPG_shift_mode",
        default='NodeDensity',
        help=
        "the mode to use to shift the branching points NodePoints or NodeDensity"
    )
    parser.add_argument(
        "--EPG_shift_DR",
        dest="EPG_shift_DR",
        type=float,
        default=0.05,
        help=
        "positive numeric, the radius to be used when computing point density if EPG_shift_mode is NodeDensity"
    )
    parser.add_argument(
        "--EPG_shift_maxshift",
        dest="EPG_shift_maxshift",
        type=int,
        default=5,
        help=
        "positive integer, the maxium distance (as number of edges) to consider when exploring the branching point neighborhood"
    )
    parser.add_argument("--disable_EPG_ext",
                        dest="flag_disable_EPG_ext",
                        action="store_true",
                        help="disable extending leaves with additional nodes")
    parser.add_argument(
        "--EPG_ext_mode",
        dest="EPG_ext_mode",
        default='QuantDists',
        help=
        " the mode used to extend the graph,QuantDists, QuantCentroid or WeigthedCentroid"
    )
    parser.add_argument(
        "--EPG_ext_par",
        dest="EPG_ext_par",
        type=float,
        default=0.5,
        help=
        "the control parameter used for contribution of the different data points when extending leaves with nodes"
    )
    parser.add_argument("--DE_zscore_cutoff",
                        dest="DE_zscore_cutoff",
                        default=2,
                        help="Differentially Expressed Genes z-score cutoff")
    parser.add_argument(
        "--DE_logfc_cutoff",
        dest="DE_logfc_cutoff",
        default=0.25,
        help="Differentially Expressed Genes log fold change cutoff")
    parser.add_argument("--TG_spearman_cutoff",
                        dest="TG_spearman_cutoff",
                        default=0.4,
                        help="Transition Genes Spearman correlation cutoff")
    parser.add_argument("--TG_logfc_cutoff",
                        dest="TG_logfc_cutoff",
                        default=0.25,
                        help="Transition Genes log fold change cutoff")
    parser.add_argument("--LG_zscore_cutoff",
                        dest="LG_zscore_cutoff",
                        default=1.5,
                        help="Leaf Genes z-score cutoff")
    parser.add_argument("--LG_pvalue_cutoff",
                        dest="LG_pvalue_cutoff",
                        default=1e-2,
                        help="Leaf Genes p value cutoff")
    parser.add_argument(
        "--umap",
        dest="flag_umap",
        action="store_true",
        help="whether to use UMAP for visualization (default: No)")
    parser.add_argument("-r",
                        dest="root",
                        default=None,
                        help="root node for subwaymap_plot and stream_plot")
    parser.add_argument("--stream_log_view",
                        dest="flag_stream_log_view",
                        action="store_true",
                        help="use log2 scale for y axis of stream_plot")
    parser.add_argument("-o",
                        "--output_folder",
                        dest="output_folder",
                        default=None,
                        help="Output folder")
    parser.add_argument("--for_web",
                        dest="flag_web",
                        action="store_true",
                        help="Output files for website")
    parser.add_argument(
        "--n_genes",
        dest="n_genes",
        type=int,
        default=5,
        help=
        "Number of top genes selected from each output marker gene file for website gene visualization"
    )

    args = parser.parse_args()
    if (args.input_filename is None) and (args.new_filename is None):
        parser.error("at least one of -m, --new required")

    new_filename = args.new_filename
    new_label_filename = args.new_label_filename
    new_label_color_filename = args.new_label_color_filename
    flag_stream_log_view = args.flag_stream_log_view
    flag_gene_TG_detection = args.flag_gene_TG_detection
    flag_gene_DE_detection = args.flag_gene_DE_detection
    flag_gene_LG_detection = args.flag_gene_LG_detection
    flag_web = args.flag_web
    flag_first_PC = args.flag_first_PC
    flag_umap = args.flag_umap
    genes = args.genes
    DE_zscore_cutoff = args.DE_zscore_cutoff
    DE_logfc_cutoff = args.DE_logfc_cutoff
    TG_spearman_cutoff = args.TG_spearman_cutoff
    TG_logfc_cutoff = args.TG_logfc_cutoff
    LG_zscore_cutoff = args.LG_zscore_cutoff
    LG_pvalue_cutoff = args.LG_pvalue_cutoff
    root = args.root

    input_filename = args.input_filename
    cell_label_filename = args.cell_label_filename
    cell_label_color_filename = args.cell_label_color_filename
    s_method = args.s_method
    use_precomputed = args.use_precomputed
    n_jobs = args.n_jobs
    loess_frac = args.loess_frac
    loess_cutoff = args.loess_cutoff
    pca_n_PC = args.pca_n_PC
    flag_log2 = args.flag_log2
    flag_norm = args.flag_norm
    flag_atac = args.flag_atac
    dr_method = args.dr_method
    nb_pct = args.nb_pct  # neighbour percent
    n_neighbors = args.n_neighbors
    n_components = args.n_components  #number of components to keep
    clustering = args.clustering
    damping = args.damping
    n_clusters = args.n_clusters
    EPG_n_nodes = args.EPG_n_nodes
    EPG_lambda = args.EPG_lambda
    EPG_mu = args.EPG_mu
    EPG_trimmingradius = args.EPG_trimmingradius
    EPG_alpha = args.EPG_alpha
    flag_EPG_collapse = args.flag_EPG_collapse
    EPG_collapse_mode = args.EPG_collapse_mode
    EPG_collapse_par = args.EPG_collapse_par
    flag_EPG_shift = args.flag_EPG_shift
    EPG_shift_mode = args.EPG_shift_mode
    EPG_shift_DR = args.EPG_shift_DR
    EPG_shift_maxshift = args.EPG_shift_maxshift
    flag_disable_EPG_optimize = args.flag_disable_EPG_optimize
    flag_disable_EPG_ext = args.flag_disable_EPG_ext
    EPG_ext_mode = args.EPG_ext_mode
    EPG_ext_par = args.EPG_ext_par
    output_folder = args.output_folder  #work directory
    n_genes = args.n_genes

    if (flag_web):
        flag_savefig = False
    else:
        flag_savefig = True
    gene_list = []
    if (genes != None):
        if (os.path.exists(genes)):
            gene_list = pd.read_csv(genes,
                                    sep='\t',
                                    header=None,
                                    index_col=None,
                                    compression='gzip' if genes.split('.')[-1]
                                    == 'gz' else None).iloc[:, 0].tolist()
            gene_list = list(set(gene_list))
        else:
            gene_list = genes.split(',')
        print('Genes to visualize: ')
        print(gene_list)
    if (new_filename is None):
        if (output_folder == None):
            workdir = os.path.join(os.getcwd(), 'stream_result')
        else:
            workdir = output_folder
        if (use_precomputed):
            print('Importing the precomputed pkl file...')
            adata = st.read(file_name='stream_result.pkl',
                            file_format='pkl',
                            file_path=workdir,
                            workdir=workdir)
        else:
            if (flag_atac):
                print('Reading in atac zscore matrix...')
                adata = st.read(file_name=input_filename,
                                workdir=workdir,
                                experiment='atac-seq')
            else:
                adata = st.read(file_name=input_filename, workdir=workdir)
                print('Input: ' + str(adata.obs.shape[0]) + ' cells, ' +
                      str(adata.var.shape[0]) + ' genes')
            adata.var_names_make_unique()
            adata.obs_names_make_unique()
            if (cell_label_filename != None):
                st.add_cell_labels(adata, file_name=cell_label_filename)
            else:
                st.add_cell_labels(adata)
            if (cell_label_color_filename != None):
                st.add_cell_colors(adata, file_name=cell_label_color_filename)
            else:
                st.add_cell_colors(adata)
            if (flag_atac):
                print('Selecting top principal components...')
                st.select_top_principal_components(adata,
                                                   n_pc=pca_n_PC,
                                                   first_pc=flag_first_PC,
                                                   save_fig=True)
                st.dimension_reduction(adata,
                                       method=dr_method,
                                       n_components=n_components,
                                       n_neighbors=n_neighbors,
                                       nb_pct=nb_pct,
                                       n_jobs=n_jobs,
                                       feature='top_pcs')
            else:
                if (flag_norm):
                    st.normalize_per_cell(adata)
                if (flag_log2):
                    st.log_transform(adata)
                if (s_method != 'all'):
                    print('Filtering genes...')
                    st.filter_genes(adata, min_num_cells=5)
                    print('Removing mitochondrial genes...')
                    st.remove_mt_genes(adata)
                    if (s_method == 'LOESS'):
                        print('Selecting most variable genes...')
                        st.select_variable_genes(adata,
                                                 loess_frac=loess_frac,
                                                 percentile=loess_cutoff,
                                                 save_fig=True)
                        pd.DataFrame(adata.uns['var_genes']).to_csv(
                            os.path.join(workdir,
                                         'selected_variable_genes.tsv'),
                            sep='\t',
                            index=None,
                            header=False)
                        st.dimension_reduction(adata,
                                               method=dr_method,
                                               n_components=n_components,
                                               n_neighbors=n_neighbors,
                                               nb_pct=nb_pct,
                                               n_jobs=n_jobs,
                                               feature='var_genes')
                    if (s_method == 'PCA'):
                        print('Selecting top principal components...')
                        st.select_top_principal_components(
                            adata,
                            n_pc=pca_n_PC,
                            first_pc=flag_first_PC,
                            save_fig=True)
                        st.dimension_reduction(adata,
                                               method=dr_method,
                                               n_components=n_components,
                                               n_neighbors=n_neighbors,
                                               nb_pct=nb_pct,
                                               n_jobs=n_jobs,
                                               feature='top_pcs')
                else:
                    print('Keep all the genes...')
                    st.dimension_reduction(adata,
                                           n_components=n_components,
                                           n_neighbors=n_neighbors,
                                           nb_pct=nb_pct,
                                           n_jobs=n_jobs,
                                           feature='all')
            st.plot_dimension_reduction(adata, save_fig=flag_savefig)
            st.seed_elastic_principal_graph(adata,
                                            clustering=clustering,
                                            damping=damping,
                                            n_clusters=n_clusters)
            st.plot_branches(
                adata,
                save_fig=flag_savefig,
                fig_name='seed_elastic_principal_graph_skeleton.pdf')
            st.plot_branches_with_cells(
                adata,
                save_fig=flag_savefig,
                fig_name='seed_elastic_principal_graph.pdf')

            st.elastic_principal_graph(adata,
                                       epg_n_nodes=EPG_n_nodes,
                                       epg_lambda=EPG_lambda,
                                       epg_mu=EPG_mu,
                                       epg_trimmingradius=EPG_trimmingradius,
                                       epg_alpha=EPG_alpha)
            st.plot_branches(adata,
                             save_fig=flag_savefig,
                             fig_name='elastic_principal_graph_skeleton.pdf')
            st.plot_branches_with_cells(adata,
                                        save_fig=flag_savefig,
                                        fig_name='elastic_principal_graph.pdf')
            if (not flag_disable_EPG_optimize):
                st.optimize_branching(adata,
                                      epg_trimmingradius=EPG_trimmingradius)
                st.plot_branches(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='optimizing_elastic_principal_graph_skeleton.pdf')
                st.plot_branches_with_cells(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='optimizing_elastic_principal_graph.pdf')
            if (flag_EPG_shift):
                st.shift_branching(adata,
                                   epg_shift_mode=EPG_shift_mode,
                                   epg_shift_radius=EPG_shift_DR,
                                   epg_shift_max=EPG_shift_maxshift,
                                   epg_trimmingradius=EPG_trimmingradius)
                st.plot_branches(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='shifting_elastic_principal_graph_skeleton.pdf')
                st.plot_branches_with_cells(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='shifting_elastic_principal_graph.pdf')
            if (flag_EPG_collapse):
                st.prune_elastic_principal_graph(
                    adata,
                    epg_collapse_mode=EPG_collapse_mode,
                    epg_collapse_par=EPG_collapse_par,
                    epg_trimmingradius=EPG_trimmingradius)
                st.plot_branches(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='pruning_elastic_principal_graph_skeleton.pdf')
                st.plot_branches_with_cells(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='pruning_elastic_principal_graph.pdf')
            if (not flag_disable_EPG_ext):
                st.extend_elastic_principal_graph(
                    adata,
                    epg_ext_mode=EPG_ext_mode,
                    epg_ext_par=EPG_ext_par,
                    epg_trimmingradius=EPG_trimmingradius)
                st.plot_branches(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='extending_elastic_principal_graph_skeleton.pdf')
                st.plot_branches_with_cells(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='extending_elastic_principal_graph.pdf')
            st.plot_branches(
                adata,
                save_fig=flag_savefig,
                fig_name='finalized_elastic_principal_graph_skeleton.pdf')
            st.plot_branches_with_cells(
                adata,
                save_fig=flag_savefig,
                fig_name='finalized_elastic_principal_graph.pdf')
            st.plot_flat_tree(adata, save_fig=flag_savefig)
            if (flag_umap):
                print('UMAP visualization based on top MLLE components...')
                st.plot_visualization_2D(adata,
                                         save_fig=flag_savefig,
                                         fig_name='umap_cells')
                st.plot_visualization_2D(adata,
                                         color_by='branch',
                                         save_fig=flag_savefig,
                                         fig_name='umap_branches')
            if (root is None):
                print('Visualization of subwaymap and stream plots...')
                flat_tree = adata.uns['flat_tree']
                list_node_start = [
                    value for key, value in nx.get_node_attributes(
                        flat_tree, 'label').items()
                ]
                for ns in list_node_start:
                    if (flag_web):
                        st.subwaymap_plot(adata,
                                          percentile_dist=100,
                                          root=ns,
                                          save_fig=flag_savefig)
                        st.stream_plot(adata,
                                       root=ns,
                                       fig_size=(8, 8),
                                       save_fig=True,
                                       flag_log_view=flag_stream_log_view,
                                       fig_legend=False,
                                       fig_name='stream_plot.png')
                    else:
                        st.subwaymap_plot(adata,
                                          percentile_dist=100,
                                          root=ns,
                                          save_fig=flag_savefig)
                        st.stream_plot(adata,
                                       root=ns,
                                       fig_size=(8, 8),
                                       save_fig=flag_savefig,
                                       flag_log_view=flag_stream_log_view)
            else:
                st.subwaymap_plot(adata,
                                  percentile_dist=100,
                                  root=root,
                                  save_fig=flag_savefig)
                st.stream_plot(adata,
                               root=root,
                               fig_size=(8, 8),
                               save_fig=flag_savefig,
                               flag_log_view=flag_stream_log_view)
            output_cell_info(adata)
            if (flag_web):
                output_for_website(adata)
            st.write(adata)

        if (flag_gene_TG_detection):
            print('Identifying transition genes...')
            st.detect_transistion_genes(adata,
                                        cutoff_spearman=TG_spearman_cutoff,
                                        cutoff_logfc=TG_logfc_cutoff,
                                        n_jobs=n_jobs)
            if (flag_web):
                ## Plot top5 genes
                flat_tree = adata.uns['flat_tree']
                list_node_start = [
                    value for key, value in nx.get_node_attributes(
                        flat_tree, 'label').items()
                ]
                gene_list = []
                for x in adata.uns['transition_genes'].keys():
                    gene_list = gene_list + adata.uns['transition_genes'][
                        x].index[:n_genes].tolist()
                gene_list = np.unique(gene_list)
                for ns in list_node_start:
                    output_for_website_subwaymap_gene(adata, gene_list)
                    st.stream_plot_gene(adata,
                                        root=ns,
                                        fig_size=(8, 8),
                                        genes=gene_list,
                                        save_fig=True,
                                        flag_log_view=flag_stream_log_view,
                                        fig_format='png')
            else:
                st.plot_transition_genes(adata, save_fig=flag_savefig)

        if (flag_gene_DE_detection):
            print('Identifying differentially expressed genes...')
            st.detect_de_genes(adata,
                               cutoff_zscore=DE_logfc_cutoff,
                               cutoff_logfc=DE_logfc_cutoff,
                               n_jobs=n_jobs)
            if (flag_web):
                flat_tree = adata.uns['flat_tree']
                list_node_start = [
                    value for key, value in nx.get_node_attributes(
                        flat_tree, 'label').items()
                ]
                gene_list = []
                for x in adata.uns['de_genes_greater'].keys():
                    gene_list = gene_list + adata.uns['de_genes_greater'][
                        x].index[:n_genes].tolist()
                for x in adata.uns['de_genes_less'].keys():
                    gene_list = gene_list + adata.uns['de_genes_less'][
                        x].index[:n_genes].tolist()
                gene_list = np.unique(gene_list)
                for ns in list_node_start:
                    output_for_website_subwaymap_gene(adata, gene_list)
                    st.stream_plot_gene(adata,
                                        root=ns,
                                        fig_size=(8, 8),
                                        genes=gene_list,
                                        save_fig=True,
                                        flag_log_view=flag_stream_log_view,
                                        fig_format='png')
            else:
                st.plot_de_genes(adata, save_fig=flag_savefig)

        if (flag_gene_LG_detection):
            print('Identifying leaf genes...')
            st.detect_leaf_genes(adata,
                                 cutoff_zscore=LG_zscore_cutoff,
                                 cutoff_pvalue=LG_pvalue_cutoff,
                                 n_jobs=n_jobs)
            if (flag_web):
                ## Plot top5 genes
                flat_tree = adata.uns['flat_tree']
                list_node_start = [
                    value for key, value in nx.get_node_attributes(
                        flat_tree, 'label').items()
                ]
                gene_list = []
                for x in adata.uns['leaf_genes'].keys():
                    gene_list = gene_list + adata.uns['leaf_genes'][
                        x].index[:n_genes].tolist()
                gene_list = np.unique(gene_list)
                for ns in list_node_start:
                    output_for_website_subwaymap_gene(adata, gene_list)
                    st.stream_plot_gene(adata,
                                        root=ns,
                                        fig_size=(8, 8),
                                        genes=gene_list,
                                        save_fig=True,
                                        flag_log_view=flag_stream_log_view,
                                        fig_format='png')

        if ((genes != None) and (len(gene_list) > 0)):
            print('Visualizing genes...')
            flat_tree = adata.uns['flat_tree']
            list_node_start = [
                value for key, value in nx.get_node_attributes(
                    flat_tree, 'label').items()
            ]
            if (root is None):
                for ns in list_node_start:
                    if (flag_web):
                        output_for_website_subwaymap_gene(adata, gene_list)
                        st.stream_plot_gene(adata,
                                            root=ns,
                                            fig_size=(8, 8),
                                            genes=gene_list,
                                            save_fig=True,
                                            flag_log_view=flag_stream_log_view,
                                            fig_format='png')
                    else:
                        st.subwaymap_plot_gene(adata,
                                               percentile_dist=100,
                                               root=ns,
                                               genes=gene_list,
                                               save_fig=flag_savefig)
                        st.stream_plot_gene(adata,
                                            root=ns,
                                            fig_size=(8, 8),
                                            genes=gene_list,
                                            save_fig=flag_savefig,
                                            flag_log_view=flag_stream_log_view)
            else:
                if (flag_web):
                    output_for_website_subwaymap_gene(adata, gene_list)
                    st.stream_plot_gene(adata,
                                        root=root,
                                        fig_size=(8, 8),
                                        genes=gene_list,
                                        save_fig=True,
                                        flag_log_view=flag_stream_log_view,
                                        fig_format='png')
                else:
                    st.subwaymap_plot_gene(adata,
                                           percentile_dist=100,
                                           root=root,
                                           genes=gene_list,
                                           save_fig=flag_savefig)
                    st.stream_plot_gene(adata,
                                        root=root,
                                        fig_size=(8, 8),
                                        genes=gene_list,
                                        save_fig=flag_savefig,
                                        flag_log_view=flag_stream_log_view)

    else:
        print('Starting mapping procedure...')
        if (output_folder == None):
            workdir_ref = os.path.join(os.getcwd(), 'stream_result')
        else:
            workdir_ref = output_folder
        adata = st.read(file_name='stream_result.pkl',
                        file_format='pkl',
                        file_path=workdir_ref,
                        workdir=workdir_ref)
        workdir = os.path.join(workdir_ref, os.pardir, 'mapping_result')
        adata_new = st.read(file_name=new_filename, workdir=workdir)
        st.add_cell_labels(adata_new, file_name=new_label_filename)
        st.add_cell_colors(adata_new, file_name=new_label_color_filename)
        if (s_method == 'LOESS'):
            st.map_new_data(adata, adata_new, feature='var_genes')
        if (s_method == 'all'):
            st.map_new_data(adata, adata_new, feature='all')
        if (flag_umap):
            st.plot_visualization_2D(adata,
                                     adata_new=adata_new,
                                     use_precomputed=False,
                                     save_fig=flag_savefig,
                                     fig_name='umap_new_cells')
            st.plot_visualization_2D(adata,
                                     adata_new=adata_new,
                                     show_all_colors=True,
                                     save_fig=flag_savefig,
                                     fig_name='umap_all_cells')
            st.plot_visualization_2D(adata,
                                     adata_new=adata_new,
                                     color_by='branch',
                                     save_fig=flag_savefig,
                                     fig_name='umap_branches')
        if (root is None):
            flat_tree = adata.uns['flat_tree']
            list_node_start = [
                value for key, value in nx.get_node_attributes(
                    flat_tree, 'label').items()
            ]
            for ns in list_node_start:
                st.subwaymap_plot(adata,
                                  adata_new=adata_new,
                                  percentile_dist=100,
                                  show_all_cells=False,
                                  root=ns,
                                  save_fig=flag_savefig)
                st.stream_plot(adata,
                               adata_new=adata_new,
                               show_all_colors=False,
                               root=ns,
                               fig_size=(8, 8),
                               save_fig=flag_savefig,
                               flag_log_view=flag_stream_log_view)
        else:
            st.subwaymap_plot(adata,
                              adata_new=adata_new,
                              percentile_dist=100,
                              show_all_cells=False,
                              root=root,
                              save_fig=flag_savefig)
            st.stream_plot(adata,
                           adata_new=adata_new,
                           show_all_colors=False,
                           root=root,
                           fig_size=(8, 8),
                           save_fig=flag_savefig,
                           flag_log_view=flag_stream_log_view)
        if ((genes != None) and (len(gene_list) > 0)):
            if (root is None):
                for ns in list_node_start:
                    st.subwaymap_plot_gene(adata,
                                           adata_new=adata_new,
                                           percentile_dist=100,
                                           root=ns,
                                           save_fig=flag_savefig,
                                           flag_log_view=flag_stream_log_view)
            else:
                st.subwaymap_plot_gene(adata,
                                       adata_new=adata_new,
                                       percentile_dist=100,
                                       root=root,
                                       save_fig=flag_savefig,
                                       flag_log_view=flag_stream_log_view)
        st.write(adata_new, file_name='stream_mapping_result.pkl')
    print('Finished computation.')
コード例 #3
0
# Infer trajectories

# Read in the method parameters. The context manager closes the definition
# file as soon as it has been parsed (the bare open() never closed it).
with open('./definition.yml', 'r') as definition:
    task = yaml.safe_load(definition)
# Map each parameter id to its declared default value.
p = {x['id']: x['default'] for x in task["parameters"]}

# Dump the count matrix transposed (gene ids become rows, cell ids columns)
# and read that same file back in with STREAM.
pd.DataFrame(
    counts.toarray(),
    index=cell_ids,
    columns=gene_ids,
).T.to_csv(output_folder + "counts.tsv", sep='\t')

checkpoints["method_afterpreproc"] = time.time()

adata = st.read(file_name=output_folder + "counts.tsv")
# No label or color files are supplied here.
st.add_cell_labels(adata)
st.add_cell_colors(adata)

# Optional transformations, switched on by the method parameters.
if p["norm"]:
    st.normalize_per_cell(adata)
if p["log2"]:
    st.log_transform(adata)

# Gene filter: the minimum number of cells is 5 or 0.1% of adata.shape[0]
# (presumably the cell count), whichever is larger, with expr_cutoff=1.
st.filter_genes(
    adata,
    min_num_cells=max(5, int(round(adata.shape[0] * 0.001))),
    min_pct_cells=None,
    expr_cutoff=1,
)
if (adata.shape[1] < 1000):
    adata.uns['var_genes'] = gene_ids
    adata.obsm['var_genes'] = adata.X
else:
コード例 #4
0
### Here we perform Pseudotime analysis with STREAM v0.36 [https://doi.org/10.1038/s41467-019-09670-4] [https://github.com/pinellolab/STREAM]
### Download counts matrix here:

import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.ioff()
import stream as st
import os.path
import pickle

#### Read Counts table
adata = st.read(file_name='./counts.tsv', workdir='./')
# Read the cell label table
st.add_cell_labels(adata, file_name='./cell_label.tsv')
# Read the label colors from the cell color table
st.add_cell_colors(adata, file_name='./cell_color.tsv')

### CHECK FOR VARIABLE GENES
# Check whether the blue (LOESS) curve fits the points well.
st.select_variable_genes(adata)
# Save the current figure (presumably the plot drawn by select_variable_genes).
plt.savefig('loess.png')
# Close the figure just saved. plt.close() with a string argument closes the
# figure whose *label* matches; 'loess.png' is only the output file name, not
# a label set anywhere here, so plt.close('loess.png') left the figure open.
plt.close()

# Refit with an explicit LOESS fraction so the blue curve fits better.
# NOTE(review): 0.01 may already be the library default for loess_frac -
# confirm, otherwise this second fit is identical to the first.
st.select_variable_genes(adata, loess_frac=0.01)
plt.savefig('adjust_loess.png')
plt.close()
コード例 #5
0
def main():
    """Preprocess an expression matrix with STREAM and pickle the result.

    Parses the command line, then:
      * reads the input matrix, or a pickled STREAM object when the file name
        ends in ``pkl``;
      * attaches cell labels and colors;
      * optionally normalizes by library size, log2-transforms and removes
        mitochondrial genes;
      * filters cells and genes;
      * writes ``<output_filename_prefix>_stream_result.pkl`` to the current
        directory.
    """
    sns.set_style('white')
    sns.set_context('poster')
    parser = argparse.ArgumentParser(
        description='%s Parameters' % __tool_name__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-m",
                        "--matrix",
                        dest="input_filename",
                        default=None,
                        help="input file name",
                        metavar="FILE")
    parser.add_argument("-l",
                        "--cell_labels",
                        dest="cell_label_filename",
                        default=None,
                        help="filename of cell labels")
    parser.add_argument("-c",
                        "--cell_labels_colors",
                        dest="cell_label_color_filename",
                        default=None,
                        help="filename of cell label colors")
    parser.add_argument("--log2",
                        dest="flag_log2",
                        action="store_true",
                        help="perform log2 transformation")
    parser.add_argument("--norm",
                        dest="flag_norm",
                        action="store_true",
                        help="normalize data based on library size")
    parser.add_argument("-o",
                        "--output_folder",
                        dest="output_folder",
                        default=None,
                        help="Output folder")
    parser.add_argument("-rmt",
                        "--remove_mt_genes",
                        dest="flag_remove_mt_genes",
                        action="store_true",
                        default=False,
                        help="Remove Mitochondrial genes")
    parser.add_argument("-mcg",
                        "--min_count_genes",
                        dest="min_count_genes",
                        type=int,
                        default=None,
                        help="filter cells with less than this many genes")
    parser.add_argument("-mpg",
                        "--min_percent_genes",
                        dest="min_percent_genes",
                        type=float,
                        default=None,
                        help="The minimum percent genes")
    parser.add_argument("-mpc",
                        "--min_percent_cells",
                        dest="min_percent_cells",
                        type=float,
                        default=None,
                        help="The minimum percent cells")
    parser.add_argument("-mcc",
                        "--min_count_cells",
                        dest="min_count_cells",
                        type=int,
                        default=None,
                        help="The minimum count cells")
    parser.add_argument("-mnc",
                        "--min_num_cells",
                        dest="min_num_cells",
                        type=int,
                        default=None,
                        help="The minimum number of cells")
    parser.add_argument("-ec",
                        "--expression_cutoff",
                        dest="expression_cutoff",
                        type=float,
                        default=None,
                        help="The expression cutoff")
    parser.add_argument("-of",
                        "--of",
                        dest="output_filename_prefix",
                        default="StreamOutput",
                        help="output file name prefix")

    args = parser.parse_args()

    print(args)

    # argparse treats -m as optional, but nothing below can run without an
    # input file; exit with a usage message instead of crashing on
    # None.endswith().
    if args.input_filename is None:
        parser.error('argument -m/--matrix is required')

    input_filename = args.input_filename
    cell_label_filename = args.cell_label_filename
    cell_label_color_filename = args.cell_label_color_filename
    flag_norm = args.flag_norm
    flag_log2 = args.flag_log2
    flag_remove_mt_genes = args.flag_remove_mt_genes
    # Thresholds passed to st.filter_cells.
    min_count_genes = args.min_count_genes
    min_percent_genes = args.min_percent_genes
    # Thresholds passed to st.filter_genes.
    min_count_cells = args.min_count_cells
    min_num_cells = args.min_num_cells
    min_percent_cells = args.min_percent_cells
    expression_cutoff = args.expression_cutoff
    output_filename_prefix = args.output_filename_prefix
    # NOTE(review): -o/--output_folder is parsed but not used; the result is
    # always written to the current directory (file_path='./' below).

    print('Starting mapping procedure...')
    workdir = "./"

    if input_filename.endswith('pkl'):
        adata = st.read(file_name=input_filename,
                        file_format='pkl',
                        workdir=workdir)
    else:
        adata = st.read(file_name=input_filename, workdir=workdir)
    print('Input: ' + str(adata.obs.shape[0]) + ' cells, ' +
          str(adata.var.shape[0]) + ' genes')

    adata.var_names_make_unique()
    adata.obs_names_make_unique()
    if cell_label_filename is not None:
        st.add_cell_labels(adata, file_name=cell_label_filename)
    else:
        st.add_cell_labels(adata)
    if cell_label_color_filename is not None:
        st.add_cell_colors(adata, file_name=cell_label_color_filename)
    else:
        st.add_cell_colors(adata)

    if flag_norm:
        st.normalize_per_cell(adata)

    if flag_log2:
        st.log_transform(adata, base=2)

    if flag_remove_mt_genes:
        st.remove_mt_genes(adata)

    # NOTE(review): expression_cutoff is None when -ec is omitted; confirm the
    # installed STREAM version accepts expr_cutoff=None.
    st.filter_cells(adata,
                    min_pct_genes=min_percent_genes,
                    min_count=min_count_genes,
                    expr_cutoff=expression_cutoff)
    # Gene filtering uses the *_cells options (-mnc, -mpc, -mcc); -mcg is the
    # per-cell threshold applied in filter_cells above.
    st.filter_genes(adata,
                    min_num_cells=min_num_cells,
                    min_pct_cells=min_percent_cells,
                    min_count=min_count_cells,
                    expr_cutoff=expression_cutoff)

    output_file = output_filename_prefix + '_stream_result.pkl'
    print("Writing " + output_file)
    st.write(adata,
             file_name=output_file,
             file_path='./',
             file_format='pkl')
    print('Output: ' + str(adata.obs.shape[0]) + ' cells, ' +
          str(adata.var.shape[0]) + ' genes')

    print('Finished computation.')
コード例 #6
0
def stream_test_Nestorowa_2016():
    """Run STREAM on the Nestorowa 2016 dataset and compare with a reference.

    Setup:
      * Reference results are extracted from
        ``<_root>/datasets/Nestorowa_2016/output/stream_result.tar.gz`` into
        the system temp directory. The archive is expected to unpack to a
        ``stream_result`` folder.
      * The new run uses ``<tempdir>/stream_result_comp`` as its STREAM
        workdir.

    Comparison, for every non-hidden file under the reference folder:
      * PDFs must exist in the new run and be non-empty.
      * Every other file is read as a tab-separated table. Numeric columns
        are compared with ``np.isclose``; all other columns must be exactly
        equal.

    Raises:
        Exception: if a compared output does not match the reference.
        Any exception from the STREAM pipeline is re-raised after printing
        a failure message.

    Both temporary folders are removed only after every file has passed.
    """
    workdir = os.path.join(_root, 'datasets/Nestorowa_2016/')

    temp_folder = tempfile.gettempdir()

    # Context manager closes the archive even if extraction fails.
    with tarfile.open(workdir + 'output/stream_result.tar.gz') as tar:
        tar.extractall(path=temp_folder)
    ref_temp_folder = os.path.join(temp_folder, 'stream_result')

    print(workdir + 'data_Nestorowa.tsv.gz')
    input_file = os.path.join(workdir, 'data_Nestorowa.tsv.gz')
    label_file = os.path.join(workdir, 'cell_label.tsv.gz')
    label_color_file = os.path.join(workdir, 'cell_label_color.tsv.gz')
    comp_temp_folder = os.path.join(temp_folder, 'stream_result_comp')

    try:
        st.set_figure_params(dpi=80,
                             style='white',
                             figsize=[5.4, 4.8],
                             rc={'image.cmap': 'viridis'})
        adata = st.read(file_name=input_file, workdir=comp_temp_folder)
        adata.var_names_make_unique()
        adata.obs_names_make_unique()
        st.add_cell_labels(adata, file_name=label_file)
        st.add_cell_colors(adata, file_name=label_color_file)
        st.cal_qc(adata, assay='rna')
        st.filter_features(adata, min_n_cells=5)
        st.select_variable_genes(adata, n_genes=2000, save_fig=True)
        st.select_top_principal_components(adata,
                                           feature='var_genes',
                                           first_pc=True,
                                           n_pc=30,
                                           save_fig=True)
        st.dimension_reduction(adata,
                               method='se',
                               feature='top_pcs',
                               n_neighbors=100,
                               n_components=4,
                               n_jobs=2)
        st.plot_dimension_reduction(adata,
                                    color=['label', 'Gata1', 'n_genes'],
                                    n_components=3,
                                    show_graph=False,
                                    show_text=False,
                                    save_fig=True,
                                    fig_name='dimension_reduction.pdf')
        st.plot_visualization_2D(adata,
                                 method='umap',
                                 n_neighbors=100,
                                 color=['label', 'Gata1', 'n_genes'],
                                 use_precomputed=False,
                                 save_fig=True,
                                 fig_name='visualization_2D.pdf')
        st.seed_elastic_principal_graph(adata, n_clusters=20)
        st.plot_dimension_reduction(adata,
                                    color=['label', 'Gata1', 'n_genes'],
                                    n_components=2,
                                    show_graph=True,
                                    show_text=False,
                                    save_fig=True,
                                    fig_name='dr_seed.pdf')
        st.plot_branches(adata,
                         show_text=True,
                         save_fig=True,
                         fig_name='branches_seed.pdf')
        st.elastic_principal_graph(adata,
                                   epg_alpha=0.01,
                                   epg_mu=0.05,
                                   epg_lambda=0.01)
        st.plot_dimension_reduction(adata,
                                    color=['label', 'Gata1', 'n_genes'],
                                    n_components=2,
                                    show_graph=True,
                                    show_text=False,
                                    save_fig=True,
                                    fig_name='dr_epg.pdf')
        st.plot_branches(adata,
                         show_text=True,
                         save_fig=True,
                         fig_name='branches_epg.pdf')
        ###Extend leaf branch to reach further cells
        st.extend_elastic_principal_graph(adata,
                                          epg_ext_mode='QuantDists',
                                          epg_ext_par=0.8)
        st.plot_dimension_reduction(adata,
                                    color=['label'],
                                    n_components=2,
                                    show_graph=True,
                                    show_text=True,
                                    save_fig=True,
                                    fig_name='dr_extend.pdf')
        st.plot_branches(adata,
                         show_text=True,
                         save_fig=True,
                         fig_name='branches_extend.pdf')
        st.plot_visualization_2D(
            adata,
            method='umap',
            n_neighbors=100,
            color=['label', 'branch_id_alias', 'S4_pseudotime'],
            use_precomputed=False,
            save_fig=True,
            fig_name='visualization_2D_2.pdf')
        st.plot_flat_tree(adata,
                          color=['label', 'branch_id_alias', 'S4_pseudotime'],
                          dist_scale=0.5,
                          show_graph=True,
                          show_text=True,
                          save_fig=True)
        st.plot_stream_sc(adata,
                          root='S4',
                          color=['label', 'Gata1'],
                          dist_scale=0.5,
                          show_graph=True,
                          show_text=False,
                          save_fig=True)
        st.plot_stream(adata,
                       root='S4',
                       color=['label', 'Gata1'],
                       save_fig=True)
        st.detect_leaf_markers(adata,
                               marker_list=adata.uns['var_genes'][:300],
                               root='S4',
                               n_jobs=4)
        st.detect_transition_markers(adata,
                                     root='S4',
                                     marker_list=adata.uns['var_genes'][:300],
                                     n_jobs=4)
        st.detect_de_markers(adata,
                             marker_list=adata.uns['var_genes'][:300],
                             root='S4',
                             n_jobs=4)
        # st.write(adata,file_name='stream_result.pkl')
    except Exception:
        print("STREAM analysis failed!")
        raise
    else:
        print("STREAM analysis finished!")

    print(ref_temp_folder)
    print(comp_temp_folder)

    pathlist = Path(ref_temp_folder)
    for path in pathlist.glob('**/*'):
        if path.is_file() and (not path.name.startswith('.')):
            file = os.path.relpath(str(path), ref_temp_folder)
            print(file)
            if file.endswith('pdf'):
                # Figures are not compared pixel-wise; they only need to exist
                # and be non-empty.
                if os.path.getsize(os.path.join(comp_temp_folder, file)) > 0:
                    print('The file %s passed' % file)
                else:
                    raise Exception('Error! The file %s is not matched' % file)
            else:
                checklist = []
                df_ref = pd.read_csv(os.path.join(ref_temp_folder, file),
                                     sep='\t')
                df_comp = pd.read_csv(os.path.join(comp_temp_folder, file),
                                      sep='\t')
                # np.isclose broadcasts, so a one-row output could otherwise
                # "match" every row of the reference; require equal lengths.
                if len(df_ref) != len(df_comp):
                    raise Exception('Error! The file %s is not matched' % file)
                for c in df_ref.columns:
                    if is_numeric_dtype(df_ref[c]):
                        checklist.append(all(np.isclose(df_ref[c],
                                                        df_comp[c])))
                    else:
                        checklist.append(all(df_ref[c] == df_comp[c]))
                if all(checklist):
                    print('The file %s passed' % file)
                else:
                    raise Exception('Error! The file %s is not matched' % file)

    print('Successful!')

    rmtree(comp_temp_folder, ignore_errors=True)
    rmtree(ref_temp_folder, ignore_errors=True)