Beispiel #1
0
def update_dr(colors, alpha):
    """Build a Dash layout showing four STREAM plots for one annotation.

    Uses the module-level ``adata`` object together with the ``st``,
    ``html``, ``dbc`` and ``dcc`` names that are brought into scope
    elsewhere in this file.

    Args:
        colors: A single annotation key. Each plotting call receives it
            wrapped in a one-element list as its ``color=`` argument.
        alpha: Transparency value forwarded to the first three plots.

    Returns:
        An ``html.Div`` containing two ``dbc.Row`` objects. The first row
        holds the 3-component dimension-reduction plot and the 2-D UMAP
        plot. The second row holds the flat-tree plot and the branch plot.
    """
    # Every figure is requested as a plotly object that is returned to the
    # caller instead of being drawn or saved.
    as_plotly = {'plotly': True, 'return_fig': True}

    dr_3d = st.plot_dimension_reduction(adata,
                                        color=[colors],
                                        n_components=3,
                                        alpha=alpha,
                                        show_graph=True,
                                        show_text=False,
                                        **as_plotly)
    umap_2d = st.plot_visualization_2D(adata,
                                       method='umap',
                                       n_neighbors=50,
                                       alpha=alpha,
                                       color=[colors],
                                       use_precomputed=True,
                                       **as_plotly)
    flat_tree = st.plot_flat_tree(adata,
                                  color=[colors],
                                  alpha=alpha,
                                  dist_scale=0.5,
                                  show_graph=True,
                                  show_text=True,
                                  **as_plotly)
    # The branch plot does not depend on the selected annotation or alpha.
    skeleton = st.plot_branches(adata, show_text=True, **as_plotly)

    top_row = dbc.Row([dcc.Graph(figure=dr_3d),
                       dcc.Graph(figure=umap_2d)])
    bottom_row = dbc.Row([dcc.Graph(figure=flat_tree),
                          dcc.Graph(figure=skeleton)])
    return html.Div([top_row, bottom_row])
def main():
    sns.set_style('white')
    sns.set_context('poster')
    parser = argparse.ArgumentParser(
        description='%s Parameters' % __tool_name__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-m",
                        "--matrix",
                        dest="input_filename",
                        default=None,
                        help="input file name",
                        metavar="FILE")
    parser.add_argument("-l",
                        "--cell_labels",
                        dest="cell_label_filename",
                        default=None,
                        help="filename of cell labels")
    parser.add_argument("-c",
                        "--cell_labels_colors",
                        dest="cell_label_color_filename",
                        default=None,
                        help="filename of cell label colors")
    parser.add_argument(
        "-s",
        "--select_features",
        dest="s_method",
        default='LOESS',
        help=
        "LOESS,PCA or all: Select variable genes using LOESS or principal components using PCA or all the genes are kept"
    )
    parser.add_argument("--TG",
                        "--detect_TG_genes",
                        dest="flag_gene_TG_detection",
                        action="store_true",
                        help="detect transition genes automatically")
    parser.add_argument("--DE",
                        "--detect_DE_genes",
                        dest="flag_gene_DE_detection",
                        action="store_true",
                        help="detect DE genes automatically")
    parser.add_argument("--LG",
                        "--detect_LG_genes",
                        dest="flag_gene_LG_detection",
                        action="store_true",
                        help="detect leaf genes automatically")
    parser.add_argument(
        "-g",
        "--genes",
        dest="genes",
        default=None,
        help=
        "genes to visualize, it can either be filename which contains all the genes in one column or a set of gene names separated by comma"
    )
    parser.add_argument(
        "-p",
        "--use_precomputed",
        dest="use_precomputed",
        action="store_true",
        help=
        "use precomputed data files without re-computing structure learning part"
    )
    parser.add_argument("--new",
                        dest="new_filename",
                        default=None,
                        help="file name of data to be mapped")
    parser.add_argument("--new_l",
                        dest="new_label_filename",
                        default=None,
                        help="filename of new cell labels")
    parser.add_argument("--new_c",
                        dest="new_label_color_filename",
                        default=None,
                        help="filename of new cell label colors")
    parser.add_argument("--log2",
                        dest="flag_log2",
                        action="store_true",
                        help="perform log2 transformation")
    parser.add_argument("--norm",
                        dest="flag_norm",
                        action="store_true",
                        help="normalize data based on library size")
    parser.add_argument("--atac",
                        dest="flag_atac",
                        action="store_true",
                        help="indicate scATAC-seq data")
    parser.add_argument(
        "--n_jobs",
        dest="n_jobs",
        type=int,
        default=1,
        help="Specify the number of processes to use. (default, 1")
    parser.add_argument(
        "--loess_frac",
        dest="loess_frac",
        type=float,
        default=0.1,
        help="The fraction of the data used in LOESS regression")
    parser.add_argument(
        "--loess_cutoff",
        dest="loess_cutoff",
        type=int,
        default=95,
        help=
        "the percentile used in variable gene selection based on LOESS regression"
    )
    parser.add_argument("--pca_first_PC",
                        dest="flag_first_PC",
                        action="store_true",
                        help="keep first PC")
    parser.add_argument("--pca_n_PC",
                        dest="pca_n_PC",
                        type=int,
                        default=15,
                        help="The number of selected PCs,it's 15 by default")
    parser.add_argument(
        "--dr_method",
        dest="dr_method",
        default='se',
        help=
        "Method used for dimension reduction. Choose from {{'se','mlle','umap','pca'}}"
    )
    parser.add_argument("--n_neighbors",
                        dest="n_neighbors",
                        type=float,
                        default=50,
                        help="The number of neighbor cells")
    parser.add_argument(
        "--nb_pct",
        dest="nb_pct",
        type=float,
        default=None,
        help=
        "The percentage of neighbor cells (when sepcified, it will overwrite n_neighbors)."
    )
    parser.add_argument("--n_components",
                        dest="n_components",
                        type=int,
                        default=3,
                        help="Number of components to keep.")
    parser.add_argument(
        "--clustering",
        dest="clustering",
        default='kmeans',
        help=
        "Clustering method used for seeding the intial structure, choose from 'ap','kmeans','sc'"
    )
    parser.add_argument("--damping",
                        dest="damping",
                        type=float,
                        default=0.75,
                        help="Affinity Propagation: damping factor")
    parser.add_argument(
        "--n_clusters",
        dest="n_clusters",
        type=int,
        default=10,
        help="Number of clusters for spectral clustering or kmeans")
    parser.add_argument("--EPG_n_nodes",
                        dest="EPG_n_nodes",
                        type=int,
                        default=50,
                        help=" Number of nodes for elastic principal graph")
    parser.add_argument(
        "--EPG_lambda",
        dest="EPG_lambda",
        type=float,
        default=0.02,
        help="lambda parameter used to compute the elastic energy")
    parser.add_argument("--EPG_mu",
                        dest="EPG_mu",
                        type=float,
                        default=0.1,
                        help="mu parameter used to compute the elastic energy")
    parser.add_argument(
        "--EPG_trimmingradius",
        dest="EPG_trimmingradius",
        type=float,
        default=np.inf,
        help="maximal distance of point from a node to affect its embedment")
    parser.add_argument(
        "--EPG_alpha",
        dest="EPG_alpha",
        type=float,
        default=0.02,
        help=
        "positive numeric, the value of the alpha parameter of the penalized elastic energy"
    )
    parser.add_argument("--EPG_collapse",
                        dest="flag_EPG_collapse",
                        action="store_true",
                        help="collapsing small branches")
    parser.add_argument(
        "--EPG_collapse_mode",
        dest="EPG_collapse_mode",
        default="PointNumber",
        help=
        "the mode used to collapse branches. PointNumber,PointNumber_Extrema, PointNumber_Leaves,EdgesNumber or EdgesLength"
    )
    parser.add_argument(
        "--EPG_collapse_par",
        dest="EPG_collapse_par",
        type=float,
        default=5,
        help=
        "positive numeric, the cotrol paramter used for collapsing small branches"
    )
    parser.add_argument("--disable_EPG_optimize",
                        dest="flag_disable_EPG_optimize",
                        action="store_true",
                        help="disable optimizing branching")
    parser.add_argument("--EPG_shift",
                        dest="flag_EPG_shift",
                        action="store_true",
                        help="shift branching point ")
    parser.add_argument(
        "--EPG_shift_mode",
        dest="EPG_shift_mode",
        default='NodeDensity',
        help=
        "the mode to use to shift the branching points NodePoints or NodeDensity"
    )
    parser.add_argument(
        "--EPG_shift_DR",
        dest="EPG_shift_DR",
        type=float,
        default=0.05,
        help=
        "positive numeric, the radius to be used when computing point density if EPG_shift_mode is NodeDensity"
    )
    parser.add_argument(
        "--EPG_shift_maxshift",
        dest="EPG_shift_maxshift",
        type=int,
        default=5,
        help=
        "positive integer, the maxium distance (as number of edges) to consider when exploring the branching point neighborhood"
    )
    parser.add_argument("--disable_EPG_ext",
                        dest="flag_disable_EPG_ext",
                        action="store_true",
                        help="disable extending leaves with additional nodes")
    parser.add_argument(
        "--EPG_ext_mode",
        dest="EPG_ext_mode",
        default='QuantDists',
        help=
        " the mode used to extend the graph,QuantDists, QuantCentroid or WeigthedCentroid"
    )
    parser.add_argument(
        "--EPG_ext_par",
        dest="EPG_ext_par",
        type=float,
        default=0.5,
        help=
        "the control parameter used for contribution of the different data points when extending leaves with nodes"
    )
    parser.add_argument("--DE_zscore_cutoff",
                        dest="DE_zscore_cutoff",
                        default=2,
                        help="Differentially Expressed Genes z-score cutoff")
    parser.add_argument(
        "--DE_logfc_cutoff",
        dest="DE_logfc_cutoff",
        default=0.25,
        help="Differentially Expressed Genes log fold change cutoff")
    parser.add_argument("--TG_spearman_cutoff",
                        dest="TG_spearman_cutoff",
                        default=0.4,
                        help="Transition Genes Spearman correlation cutoff")
    parser.add_argument("--TG_logfc_cutoff",
                        dest="TG_logfc_cutoff",
                        default=0.25,
                        help="Transition Genes log fold change cutoff")
    parser.add_argument("--LG_zscore_cutoff",
                        dest="LG_zscore_cutoff",
                        default=1.5,
                        help="Leaf Genes z-score cutoff")
    parser.add_argument("--LG_pvalue_cutoff",
                        dest="LG_pvalue_cutoff",
                        default=1e-2,
                        help="Leaf Genes p value cutoff")
    parser.add_argument(
        "--umap",
        dest="flag_umap",
        action="store_true",
        help="whether to use UMAP for visualization (default: No)")
    parser.add_argument("-r",
                        dest="root",
                        default=None,
                        help="root node for subwaymap_plot and stream_plot")
    parser.add_argument("--stream_log_view",
                        dest="flag_stream_log_view",
                        action="store_true",
                        help="use log2 scale for y axis of stream_plot")
    parser.add_argument("-o",
                        "--output_folder",
                        dest="output_folder",
                        default=None,
                        help="Output folder")
    parser.add_argument("--for_web",
                        dest="flag_web",
                        action="store_true",
                        help="Output files for website")
    parser.add_argument(
        "--n_genes",
        dest="n_genes",
        type=int,
        default=5,
        help=
        "Number of top genes selected from each output marker gene file for website gene visualization"
    )

    args = parser.parse_args()
    if (args.input_filename is None) and (args.new_filename is None):
        parser.error("at least one of -m, --new required")

    new_filename = args.new_filename
    new_label_filename = args.new_label_filename
    new_label_color_filename = args.new_label_color_filename
    flag_stream_log_view = args.flag_stream_log_view
    flag_gene_TG_detection = args.flag_gene_TG_detection
    flag_gene_DE_detection = args.flag_gene_DE_detection
    flag_gene_LG_detection = args.flag_gene_LG_detection
    flag_web = args.flag_web
    flag_first_PC = args.flag_first_PC
    flag_umap = args.flag_umap
    genes = args.genes
    DE_zscore_cutoff = args.DE_zscore_cutoff
    DE_logfc_cutoff = args.DE_logfc_cutoff
    TG_spearman_cutoff = args.TG_spearman_cutoff
    TG_logfc_cutoff = args.TG_logfc_cutoff
    LG_zscore_cutoff = args.LG_zscore_cutoff
    LG_pvalue_cutoff = args.LG_pvalue_cutoff
    root = args.root

    input_filename = args.input_filename
    cell_label_filename = args.cell_label_filename
    cell_label_color_filename = args.cell_label_color_filename
    s_method = args.s_method
    use_precomputed = args.use_precomputed
    n_jobs = args.n_jobs
    loess_frac = args.loess_frac
    loess_cutoff = args.loess_cutoff
    pca_n_PC = args.pca_n_PC
    flag_log2 = args.flag_log2
    flag_norm = args.flag_norm
    flag_atac = args.flag_atac
    dr_method = args.dr_method
    nb_pct = args.nb_pct  # neighbour percent
    n_neighbors = args.n_neighbors
    n_components = args.n_components  #number of components to keep
    clustering = args.clustering
    damping = args.damping
    n_clusters = args.n_clusters
    EPG_n_nodes = args.EPG_n_nodes
    EPG_lambda = args.EPG_lambda
    EPG_mu = args.EPG_mu
    EPG_trimmingradius = args.EPG_trimmingradius
    EPG_alpha = args.EPG_alpha
    flag_EPG_collapse = args.flag_EPG_collapse
    EPG_collapse_mode = args.EPG_collapse_mode
    EPG_collapse_par = args.EPG_collapse_par
    flag_EPG_shift = args.flag_EPG_shift
    EPG_shift_mode = args.EPG_shift_mode
    EPG_shift_DR = args.EPG_shift_DR
    EPG_shift_maxshift = args.EPG_shift_maxshift
    flag_disable_EPG_optimize = args.flag_disable_EPG_optimize
    flag_disable_EPG_ext = args.flag_disable_EPG_ext
    EPG_ext_mode = args.EPG_ext_mode
    EPG_ext_par = args.EPG_ext_par
    output_folder = args.output_folder  #work directory
    n_genes = args.n_genes

    if (flag_web):
        flag_savefig = False
    else:
        flag_savefig = True
    gene_list = []
    if (genes != None):
        if (os.path.exists(genes)):
            gene_list = pd.read_csv(genes,
                                    sep='\t',
                                    header=None,
                                    index_col=None,
                                    compression='gzip' if genes.split('.')[-1]
                                    == 'gz' else None).iloc[:, 0].tolist()
            gene_list = list(set(gene_list))
        else:
            gene_list = genes.split(',')
        print('Genes to visualize: ')
        print(gene_list)
    if (new_filename is None):
        if (output_folder == None):
            workdir = os.path.join(os.getcwd(), 'stream_result')
        else:
            workdir = output_folder
        if (use_precomputed):
            print('Importing the precomputed pkl file...')
            adata = st.read(file_name='stream_result.pkl',
                            file_format='pkl',
                            file_path=workdir,
                            workdir=workdir)
        else:
            if (flag_atac):
                print('Reading in atac zscore matrix...')
                adata = st.read(file_name=input_filename,
                                workdir=workdir,
                                experiment='atac-seq')
            else:
                adata = st.read(file_name=input_filename, workdir=workdir)
                print('Input: ' + str(adata.obs.shape[0]) + ' cells, ' +
                      str(adata.var.shape[0]) + ' genes')
            adata.var_names_make_unique()
            adata.obs_names_make_unique()
            if (cell_label_filename != None):
                st.add_cell_labels(adata, file_name=cell_label_filename)
            else:
                st.add_cell_labels(adata)
            if (cell_label_color_filename != None):
                st.add_cell_colors(adata, file_name=cell_label_color_filename)
            else:
                st.add_cell_colors(adata)
            if (flag_atac):
                print('Selecting top principal components...')
                st.select_top_principal_components(adata,
                                                   n_pc=pca_n_PC,
                                                   first_pc=flag_first_PC,
                                                   save_fig=True)
                st.dimension_reduction(adata,
                                       method=dr_method,
                                       n_components=n_components,
                                       n_neighbors=n_neighbors,
                                       nb_pct=nb_pct,
                                       n_jobs=n_jobs,
                                       feature='top_pcs')
            else:
                if (flag_norm):
                    st.normalize_per_cell(adata)
                if (flag_log2):
                    st.log_transform(adata)
                if (s_method != 'all'):
                    print('Filtering genes...')
                    st.filter_genes(adata, min_num_cells=5)
                    print('Removing mitochondrial genes...')
                    st.remove_mt_genes(adata)
                    if (s_method == 'LOESS'):
                        print('Selecting most variable genes...')
                        st.select_variable_genes(adata,
                                                 loess_frac=loess_frac,
                                                 percentile=loess_cutoff,
                                                 save_fig=True)
                        pd.DataFrame(adata.uns['var_genes']).to_csv(
                            os.path.join(workdir,
                                         'selected_variable_genes.tsv'),
                            sep='\t',
                            index=None,
                            header=False)
                        st.dimension_reduction(adata,
                                               method=dr_method,
                                               n_components=n_components,
                                               n_neighbors=n_neighbors,
                                               nb_pct=nb_pct,
                                               n_jobs=n_jobs,
                                               feature='var_genes')
                    if (s_method == 'PCA'):
                        print('Selecting top principal components...')
                        st.select_top_principal_components(
                            adata,
                            n_pc=pca_n_PC,
                            first_pc=flag_first_PC,
                            save_fig=True)
                        st.dimension_reduction(adata,
                                               method=dr_method,
                                               n_components=n_components,
                                               n_neighbors=n_neighbors,
                                               nb_pct=nb_pct,
                                               n_jobs=n_jobs,
                                               feature='top_pcs')
                else:
                    print('Keep all the genes...')
                    st.dimension_reduction(adata,
                                           n_components=n_components,
                                           n_neighbors=n_neighbors,
                                           nb_pct=nb_pct,
                                           n_jobs=n_jobs,
                                           feature='all')
            st.plot_dimension_reduction(adata, save_fig=flag_savefig)
            st.seed_elastic_principal_graph(adata,
                                            clustering=clustering,
                                            damping=damping,
                                            n_clusters=n_clusters)
            st.plot_branches(
                adata,
                save_fig=flag_savefig,
                fig_name='seed_elastic_principal_graph_skeleton.pdf')
            st.plot_branches_with_cells(
                adata,
                save_fig=flag_savefig,
                fig_name='seed_elastic_principal_graph.pdf')

            st.elastic_principal_graph(adata,
                                       epg_n_nodes=EPG_n_nodes,
                                       epg_lambda=EPG_lambda,
                                       epg_mu=EPG_mu,
                                       epg_trimmingradius=EPG_trimmingradius,
                                       epg_alpha=EPG_alpha)
            st.plot_branches(adata,
                             save_fig=flag_savefig,
                             fig_name='elastic_principal_graph_skeleton.pdf')
            st.plot_branches_with_cells(adata,
                                        save_fig=flag_savefig,
                                        fig_name='elastic_principal_graph.pdf')
            if (not flag_disable_EPG_optimize):
                st.optimize_branching(adata,
                                      epg_trimmingradius=EPG_trimmingradius)
                st.plot_branches(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='optimizing_elastic_principal_graph_skeleton.pdf')
                st.plot_branches_with_cells(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='optimizing_elastic_principal_graph.pdf')
            if (flag_EPG_shift):
                st.shift_branching(adata,
                                   epg_shift_mode=EPG_shift_mode,
                                   epg_shift_radius=EPG_shift_DR,
                                   epg_shift_max=EPG_shift_maxshift,
                                   epg_trimmingradius=EPG_trimmingradius)
                st.plot_branches(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='shifting_elastic_principal_graph_skeleton.pdf')
                st.plot_branches_with_cells(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='shifting_elastic_principal_graph.pdf')
            if (flag_EPG_collapse):
                st.prune_elastic_principal_graph(
                    adata,
                    epg_collapse_mode=EPG_collapse_mode,
                    epg_collapse_par=EPG_collapse_par,
                    epg_trimmingradius=EPG_trimmingradius)
                st.plot_branches(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='pruning_elastic_principal_graph_skeleton.pdf')
                st.plot_branches_with_cells(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='pruning_elastic_principal_graph.pdf')
            if (not flag_disable_EPG_ext):
                st.extend_elastic_principal_graph(
                    adata,
                    epg_ext_mode=EPG_ext_mode,
                    epg_ext_par=EPG_ext_par,
                    epg_trimmingradius=EPG_trimmingradius)
                st.plot_branches(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='extending_elastic_principal_graph_skeleton.pdf')
                st.plot_branches_with_cells(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='extending_elastic_principal_graph.pdf')
            st.plot_branches(
                adata,
                save_fig=flag_savefig,
                fig_name='finalized_elastic_principal_graph_skeleton.pdf')
            st.plot_branches_with_cells(
                adata,
                save_fig=flag_savefig,
                fig_name='finalized_elastic_principal_graph.pdf')
            st.plot_flat_tree(adata, save_fig=flag_savefig)
            if (flag_umap):
                print('UMAP visualization based on top MLLE components...')
                st.plot_visualization_2D(adata,
                                         save_fig=flag_savefig,
                                         fig_name='umap_cells')
                st.plot_visualization_2D(adata,
                                         color_by='branch',
                                         save_fig=flag_savefig,
                                         fig_name='umap_branches')
            if (root is None):
                print('Visualization of subwaymap and stream plots...')
                flat_tree = adata.uns['flat_tree']
                list_node_start = [
                    value for key, value in nx.get_node_attributes(
                        flat_tree, 'label').items()
                ]
                for ns in list_node_start:
                    if (flag_web):
                        st.subwaymap_plot(adata,
                                          percentile_dist=100,
                                          root=ns,
                                          save_fig=flag_savefig)
                        st.stream_plot(adata,
                                       root=ns,
                                       fig_size=(8, 8),
                                       save_fig=True,
                                       flag_log_view=flag_stream_log_view,
                                       fig_legend=False,
                                       fig_name='stream_plot.png')
                    else:
                        st.subwaymap_plot(adata,
                                          percentile_dist=100,
                                          root=ns,
                                          save_fig=flag_savefig)
                        st.stream_plot(adata,
                                       root=ns,
                                       fig_size=(8, 8),
                                       save_fig=flag_savefig,
                                       flag_log_view=flag_stream_log_view)
            else:
                st.subwaymap_plot(adata,
                                  percentile_dist=100,
                                  root=root,
                                  save_fig=flag_savefig)
                st.stream_plot(adata,
                               root=root,
                               fig_size=(8, 8),
                               save_fig=flag_savefig,
                               flag_log_view=flag_stream_log_view)
            output_cell_info(adata)
            if (flag_web):
                output_for_website(adata)
            st.write(adata)

        if (flag_gene_TG_detection):
            print('Identifying transition genes...')
            st.detect_transistion_genes(adata,
                                        cutoff_spearman=TG_spearman_cutoff,
                                        cutoff_logfc=TG_logfc_cutoff,
                                        n_jobs=n_jobs)
            if (flag_web):
                ## Plot top5 genes
                flat_tree = adata.uns['flat_tree']
                list_node_start = [
                    value for key, value in nx.get_node_attributes(
                        flat_tree, 'label').items()
                ]
                gene_list = []
                for x in adata.uns['transition_genes'].keys():
                    gene_list = gene_list + adata.uns['transition_genes'][
                        x].index[:n_genes].tolist()
                gene_list = np.unique(gene_list)
                for ns in list_node_start:
                    output_for_website_subwaymap_gene(adata, gene_list)
                    st.stream_plot_gene(adata,
                                        root=ns,
                                        fig_size=(8, 8),
                                        genes=gene_list,
                                        save_fig=True,
                                        flag_log_view=flag_stream_log_view,
                                        fig_format='png')
            else:
                st.plot_transition_genes(adata, save_fig=flag_savefig)

        if (flag_gene_DE_detection):
            print('Identifying differentially expressed genes...')
            st.detect_de_genes(adata,
                               cutoff_zscore=DE_logfc_cutoff,
                               cutoff_logfc=DE_logfc_cutoff,
                               n_jobs=n_jobs)
            if (flag_web):
                flat_tree = adata.uns['flat_tree']
                list_node_start = [
                    value for key, value in nx.get_node_attributes(
                        flat_tree, 'label').items()
                ]
                gene_list = []
                for x in adata.uns['de_genes_greater'].keys():
                    gene_list = gene_list + adata.uns['de_genes_greater'][
                        x].index[:n_genes].tolist()
                for x in adata.uns['de_genes_less'].keys():
                    gene_list = gene_list + adata.uns['de_genes_less'][
                        x].index[:n_genes].tolist()
                gene_list = np.unique(gene_list)
                for ns in list_node_start:
                    output_for_website_subwaymap_gene(adata, gene_list)
                    st.stream_plot_gene(adata,
                                        root=ns,
                                        fig_size=(8, 8),
                                        genes=gene_list,
                                        save_fig=True,
                                        flag_log_view=flag_stream_log_view,
                                        fig_format='png')
            else:
                st.plot_de_genes(adata, save_fig=flag_savefig)

        if (flag_gene_LG_detection):
            print('Identifying leaf genes...')
            st.detect_leaf_genes(adata,
                                 cutoff_zscore=LG_zscore_cutoff,
                                 cutoff_pvalue=LG_pvalue_cutoff,
                                 n_jobs=n_jobs)
            if (flag_web):
                ## Plot top5 genes
                flat_tree = adata.uns['flat_tree']
                list_node_start = [
                    value for key, value in nx.get_node_attributes(
                        flat_tree, 'label').items()
                ]
                gene_list = []
                for x in adata.uns['leaf_genes'].keys():
                    gene_list = gene_list + adata.uns['leaf_genes'][
                        x].index[:n_genes].tolist()
                gene_list = np.unique(gene_list)
                for ns in list_node_start:
                    output_for_website_subwaymap_gene(adata, gene_list)
                    st.stream_plot_gene(adata,
                                        root=ns,
                                        fig_size=(8, 8),
                                        genes=gene_list,
                                        save_fig=True,
                                        flag_log_view=flag_stream_log_view,
                                        fig_format='png')

        if ((genes != None) and (len(gene_list) > 0)):
            print('Visualizing genes...')
            flat_tree = adata.uns['flat_tree']
            list_node_start = [
                value for key, value in nx.get_node_attributes(
                    flat_tree, 'label').items()
            ]
            if (root is None):
                for ns in list_node_start:
                    if (flag_web):
                        output_for_website_subwaymap_gene(adata, gene_list)
                        st.stream_plot_gene(adata,
                                            root=ns,
                                            fig_size=(8, 8),
                                            genes=gene_list,
                                            save_fig=True,
                                            flag_log_view=flag_stream_log_view,
                                            fig_format='png')
                    else:
                        st.subwaymap_plot_gene(adata,
                                               percentile_dist=100,
                                               root=ns,
                                               genes=gene_list,
                                               save_fig=flag_savefig)
                        st.stream_plot_gene(adata,
                                            root=ns,
                                            fig_size=(8, 8),
                                            genes=gene_list,
                                            save_fig=flag_savefig,
                                            flag_log_view=flag_stream_log_view)
            else:
                if (flag_web):
                    output_for_website_subwaymap_gene(adata, gene_list)
                    st.stream_plot_gene(adata,
                                        root=root,
                                        fig_size=(8, 8),
                                        genes=gene_list,
                                        save_fig=True,
                                        flag_log_view=flag_stream_log_view,
                                        fig_format='png')
                else:
                    st.subwaymap_plot_gene(adata,
                                           percentile_dist=100,
                                           root=root,
                                           genes=gene_list,
                                           save_fig=flag_savefig)
                    st.stream_plot_gene(adata,
                                        root=root,
                                        fig_size=(8, 8),
                                        genes=gene_list,
                                        save_fig=flag_savefig,
                                        flag_log_view=flag_stream_log_view)

    else:
        print('Starting mapping procedure...')
        if (output_folder == None):
            workdir_ref = os.path.join(os.getcwd(), 'stream_result')
        else:
            workdir_ref = output_folder
        adata = st.read(file_name='stream_result.pkl',
                        file_format='pkl',
                        file_path=workdir_ref,
                        workdir=workdir_ref)
        workdir = os.path.join(workdir_ref, os.pardir, 'mapping_result')
        adata_new = st.read(file_name=new_filename, workdir=workdir)
        st.add_cell_labels(adata_new, file_name=new_label_filename)
        st.add_cell_colors(adata_new, file_name=new_label_color_filename)
        if (s_method == 'LOESS'):
            st.map_new_data(adata, adata_new, feature='var_genes')
        if (s_method == 'all'):
            st.map_new_data(adata, adata_new, feature='all')
        if (flag_umap):
            st.plot_visualization_2D(adata,
                                     adata_new=adata_new,
                                     use_precomputed=False,
                                     save_fig=flag_savefig,
                                     fig_name='umap_new_cells')
            st.plot_visualization_2D(adata,
                                     adata_new=adata_new,
                                     show_all_colors=True,
                                     save_fig=flag_savefig,
                                     fig_name='umap_all_cells')
            st.plot_visualization_2D(adata,
                                     adata_new=adata_new,
                                     color_by='branch',
                                     save_fig=flag_savefig,
                                     fig_name='umap_branches')
        if (root is None):
            flat_tree = adata.uns['flat_tree']
            list_node_start = [
                value for key, value in nx.get_node_attributes(
                    flat_tree, 'label').items()
            ]
            for ns in list_node_start:
                st.subwaymap_plot(adata,
                                  adata_new=adata_new,
                                  percentile_dist=100,
                                  show_all_cells=False,
                                  root=ns,
                                  save_fig=flag_savefig)
                st.stream_plot(adata,
                               adata_new=adata_new,
                               show_all_colors=False,
                               root=ns,
                               fig_size=(8, 8),
                               save_fig=flag_savefig,
                               flag_log_view=flag_stream_log_view)
        else:
            st.subwaymap_plot(adata,
                              adata_new=adata_new,
                              percentile_dist=100,
                              show_all_cells=False,
                              root=root,
                              save_fig=flag_savefig)
            st.stream_plot(adata,
                           adata_new=adata_new,
                           show_all_colors=False,
                           root=root,
                           fig_size=(8, 8),
                           save_fig=flag_savefig,
                           flag_log_view=flag_stream_log_view)
        if ((genes != None) and (len(gene_list) > 0)):
            if (root is None):
                for ns in list_node_start:
                    st.subwaymap_plot_gene(adata,
                                           adata_new=adata_new,
                                           percentile_dist=100,
                                           root=ns,
                                           save_fig=flag_savefig,
                                           flag_log_view=flag_stream_log_view)
            else:
                st.subwaymap_plot_gene(adata,
                                       adata_new=adata_new,
                                       percentile_dist=100,
                                       root=root,
                                       save_fig=flag_savefig,
                                       flag_log_view=flag_stream_log_view)
        st.write(adata_new, file_name='stream_mapping_result.pkl')
    print('Finished computation.')
Beispiel #3
0
                            fig_name="branches_with_cells_elastic.pdf")

def _save_and_close_current_figure(file_name):
    """Save the current pyplot figure to *file_name* and then close it.

    ``plt.close()`` is called with no argument so that it closes the current
    figure. Passing the output file name instead would only match a figure
    *label*; no figure carries that label, so nothing would be closed and
    figures would pile up in memory.
    """
    plt.savefig(file_name)
    plt.close()


# Extend leaf branch to reach further cells
print("### Extend leaf branch to reach further cells ")
st.extend_elastic_principal_graph(adata)
st.plot_branches(adata,
                 save_fig=True,
                 fig_size=[16, 16],
                 fig_name="branches_extend_elastic.pdf")
st.plot_branches_with_cells(adata,
                            fig_legend=False,
                            save_fig=True,
                            fig_size=[16, 16],
                            fig_name="branches_with_cells_extend_elastic.pdf")

print("### plot_flat_tree")
st.plot_flat_tree(adata, fig_legend=False, save_fig=True, fig_size=[16, 16])

print("### write tsv file for stream command line")
X = adata.X
# NOTE(review): ndarray.tofile(sep="\t") writes every value flattened onto a
# single line (row structure is lost) and needs a dense ndarray -- confirm
# this is the matrix format the stream command line expects.
X.tofile(os.path.join(args.outdir, "adata.tsv"), sep="\t")
# NOTE(review): the label files below go to a hard-coded "stream_result/"
# directory while the matrix and the h5ad go to args.outdir -- confirm the
# split is intended.
cell_label = adata.obs["label"]
cell_label.to_csv("stream_result/cell_label.tsv", sep="\t", index=False)

cell_label_color = adata.obs["label_color"]
cell_label_color.to_csv("stream_result/cell_label_color.tsv",
                        sep="\t",
                        index=False)

adata.write(os.path.join(args.outdir, "adata_stream.h5ad"), compression='gzip')
st.plot_branches_with_cells(adata)
_save_and_close_current_figure('OPTepg_branches_and_cells.png')

# Extend leaf branch to reach further cells
st.extend_elastic_principal_graph(adata, epg_trimmingradius=0.1)
st.plot_branches(adata)
_save_and_close_current_figure('EXTepg_branches.png')

st.plot_branches_with_cells(adata)
_save_and_close_current_figure('EXTepg_branches_and_cells.png')

# Plot flat tree
st.plot_flat_tree(adata, fig_legend_ncol=6, fig_size=(15, 15))
_save_and_close_current_figure('flat_tree.png')

# Validate the learned structure by visualizing the branch assignment
st.plot_visualization_2D(adata, fig_legend_ncol=3, fig_size=(15, 15))
_save_and_close_current_figure('branch_assign.png')

st.plot_visualization_2D(adata,
                         color_by='branch',
                         fig_legend_ncol=3,
                         fig_size=(15, 15))
_save_and_close_current_figure('branch_assign_color.png')
def _run_nestorowa_pipeline(input_file, label_file, label_color_file,
                            comp_temp_folder):
    """Run the full STREAM analysis on the Nestorowa 2016 data.

    All figures/tables produced by STREAM are written under
    *comp_temp_folder* (used as the AnnData ``workdir``).
    """
    st.set_figure_params(dpi=80,
                         style='white',
                         figsize=[5.4, 4.8],
                         rc={'image.cmap': 'viridis'})
    adata = st.read(file_name=input_file, workdir=comp_temp_folder)
    adata.var_names_make_unique()
    adata.obs_names_make_unique()
    st.add_cell_labels(adata, file_name=label_file)
    st.add_cell_colors(adata, file_name=label_color_file)
    st.cal_qc(adata, assay='rna')
    st.filter_features(adata, min_n_cells=5)
    st.select_variable_genes(adata, n_genes=2000, save_fig=True)
    st.select_top_principal_components(adata,
                                       feature='var_genes',
                                       first_pc=True,
                                       n_pc=30,
                                       save_fig=True)
    st.dimension_reduction(adata,
                           method='se',
                           feature='top_pcs',
                           n_neighbors=100,
                           n_components=4,
                           n_jobs=2)
    st.plot_dimension_reduction(adata,
                                color=['label', 'Gata1', 'n_genes'],
                                n_components=3,
                                show_graph=False,
                                show_text=False,
                                save_fig=True,
                                fig_name='dimension_reduction.pdf')
    st.plot_visualization_2D(adata,
                             method='umap',
                             n_neighbors=100,
                             color=['label', 'Gata1', 'n_genes'],
                             use_precomputed=False,
                             save_fig=True,
                             fig_name='visualization_2D.pdf')
    st.seed_elastic_principal_graph(adata, n_clusters=20)
    st.plot_dimension_reduction(adata,
                                color=['label', 'Gata1', 'n_genes'],
                                n_components=2,
                                show_graph=True,
                                show_text=False,
                                save_fig=True,
                                fig_name='dr_seed.pdf')
    st.plot_branches(adata,
                     show_text=True,
                     save_fig=True,
                     fig_name='branches_seed.pdf')
    st.elastic_principal_graph(adata,
                               epg_alpha=0.01,
                               epg_mu=0.05,
                               epg_lambda=0.01)
    st.plot_dimension_reduction(adata,
                                color=['label', 'Gata1', 'n_genes'],
                                n_components=2,
                                show_graph=True,
                                show_text=False,
                                save_fig=True,
                                fig_name='dr_epg.pdf')
    st.plot_branches(adata,
                     show_text=True,
                     save_fig=True,
                     fig_name='branches_epg.pdf')
    # Extend leaf branches to reach further cells.
    st.extend_elastic_principal_graph(adata,
                                      epg_ext_mode='QuantDists',
                                      epg_ext_par=0.8)
    st.plot_dimension_reduction(adata,
                                color=['label'],
                                n_components=2,
                                show_graph=True,
                                show_text=True,
                                save_fig=True,
                                fig_name='dr_extend.pdf')
    st.plot_branches(adata,
                     show_text=True,
                     save_fig=True,
                     fig_name='branches_extend.pdf')
    st.plot_visualization_2D(
        adata,
        method='umap',
        n_neighbors=100,
        color=['label', 'branch_id_alias', 'S4_pseudotime'],
        use_precomputed=False,
        save_fig=True,
        fig_name='visualization_2D_2.pdf')
    st.plot_flat_tree(adata,
                      color=['label', 'branch_id_alias', 'S4_pseudotime'],
                      dist_scale=0.5,
                      show_graph=True,
                      show_text=True,
                      save_fig=True)
    st.plot_stream_sc(adata,
                      root='S4',
                      color=['label', 'Gata1'],
                      dist_scale=0.5,
                      show_graph=True,
                      show_text=False,
                      save_fig=True)
    st.plot_stream(adata,
                   root='S4',
                   color=['label', 'Gata1'],
                   save_fig=True)
    st.detect_leaf_markers(adata,
                           marker_list=adata.uns['var_genes'][:300],
                           root='S4',
                           n_jobs=4)
    st.detect_transition_markers(adata,
                                 root='S4',
                                 marker_list=adata.uns['var_genes'][:300],
                                 n_jobs=4)
    st.detect_de_markers(adata,
                         marker_list=adata.uns['var_genes'][:300],
                         root='S4',
                         n_jobs=4)


def _frames_match(df_ref, df_comp):
    """Return True if every column of *df_ref* is reproduced in *df_comp*.

    The two frames must have the same number of rows. Numeric columns are
    compared with ``np.isclose`` (NaN in both frames counts as equal); other
    columns are compared exactly, with missing values in the same position
    counted as equal. Extra columns in *df_comp* are ignored. A missing
    column or a length mismatch returns False instead of raising.
    """
    if len(df_ref) != len(df_comp):
        return False
    for column in df_ref.columns:
        if column not in df_comp.columns:
            return False
        ref_col = df_ref[column]
        comp_col = df_comp[column]
        if is_numeric_dtype(ref_col):
            if not is_numeric_dtype(comp_col):
                return False
            if not np.isclose(ref_col, comp_col, equal_nan=True).all():
                return False
        elif not ((ref_col == comp_col) |
                  (ref_col.isna() & comp_col.isna())).all():
            return False
    return True


def _compare_result_folders(ref_folder, comp_folder):
    """Check each non-hidden file under *ref_folder* against *comp_folder*.

    PDF files only need to exist and be non-empty in *comp_folder*. Every
    other file is read as a tab-separated table on both sides and compared
    with ``_frames_match``.

    Raises:
        Exception: if a counterpart file is missing, empty (PDF) or differs.
    """
    for path in Path(ref_folder).glob('**/*'):
        if not path.is_file() or path.name.startswith('.'):
            continue
        file = os.path.relpath(str(path), ref_folder)
        print(file)
        comp_path = os.path.join(comp_folder, file)
        if file.endswith('pdf'):
            passed = (os.path.isfile(comp_path)
                      and os.path.getsize(comp_path) > 0)
        else:
            passed = os.path.isfile(comp_path) and _frames_match(
                pd.read_csv(str(path), sep='\t'),
                pd.read_csv(comp_path, sep='\t'))
        if not passed:
            raise Exception('Error! The file %s is not matched' % file)
        print('The file %s passed' % file)


def stream_test_Nestorowa_2016():
    """Regression test: rerun STREAM on Nestorowa 2016 and compare outputs.

    The reference results shipped in ``output/stream_result.tar.gz`` are
    unpacked into the system temp directory, the analysis is rerun into a
    sibling ``stream_result_comp`` folder, and every reference file is
    checked against the new output.

    Raises:
        Exception: if the analysis fails or any output does not match.
    """
    workdir = os.path.join(_root, 'datasets/Nestorowa_2016/')
    temp_folder = tempfile.gettempdir()

    # The archive ships with the repository (trusted input); the context
    # manager closes it even if extraction fails.
    archive = os.path.join(workdir, 'output', 'stream_result.tar.gz')
    with tarfile.open(archive) as tar:
        tar.extractall(path=temp_folder)
    ref_temp_folder = os.path.join(temp_folder, 'stream_result')

    input_file = os.path.join(workdir, 'data_Nestorowa.tsv.gz')
    print(input_file)
    label_file = os.path.join(workdir, 'cell_label.tsv.gz')
    label_color_file = os.path.join(workdir, 'cell_label_color.tsv.gz')
    comp_temp_folder = os.path.join(temp_folder, 'stream_result_comp')

    try:
        _run_nestorowa_pipeline(input_file, label_file, label_color_file,
                                comp_temp_folder)
    except Exception:
        print("STREAM analysis failed!")
        raise
    else:
        print("STREAM analysis finished!")

    print(ref_temp_folder)
    print(comp_temp_folder)

    _compare_result_folders(ref_temp_folder, comp_temp_folder)

    print('Successful!')

    # Temp folders are only removed on success so failures can be inspected.
    rmtree(comp_temp_folder, ignore_errors=True)
    rmtree(ref_temp_folder, ignore_errors=True)
# Plotly layout shared by every figure rendered by ``update_dr``.
_DR_FIG_LAYOUT = dict(
    autosize=False,
    width=450,
    height=350,
    margin=dict(l=30, r=30, b=5, t=5, pad=4),
    plot_bgcolor='rgba(0,0,0,0)',
)

# CSS applied to every ``dcc.Graph`` container rendered by ``update_dr``.
_DR_GRAPH_STYLE = {
    'margin-left': 'auto',
    'margin-right': 'auto',
    'lineHeight': '60px',
    'borderWidth': '1px',
    'borderStyle': 'dashed',
    'borderRadius': '5px',
}


def _dr_graph(fig):
    """Apply the shared layout to *fig* and wrap it in a styled ``dcc.Graph``."""
    fig.update_layout(**_DR_FIG_LAYOUT)
    # Give each graph its own style dict so components never share state.
    return dcc.Graph(figure=fig, style=dict(_DR_GRAPH_STYLE))


def update_dr(color, alpha):
    """Build the Dash layout with the STREAM plots for the selected colouring.

    Row 1 shows the 3D dimension reduction and the 2D UMAP visualization;
    row 2 shows the flat tree and the branches. Each graph previously in
    the first row showed the dimension-reduction figure twice.

    Parameters
    ----------
    color : str
        Annotation forwarded as ``color=[color]`` to the coloured plots.
    alpha : float
        Value forwarded as ``alpha`` to the coloured plots.

    Returns
    -------
    html.Div
        Container holding two ``dbc.Row`` rows of ``dcc.Graph`` components.
    """
    fig_dr = st.plot_dimension_reduction(adata, color=[color], n_components=3,
                                         alpha=alpha, show_graph=True,
                                         show_text=False, plotly=True,
                                         return_fig=True)
    fig_2d = st.plot_visualization_2D(adata, method='umap', n_neighbors=50,
                                      alpha=alpha, color=[color],
                                      use_precomputed=True, plotly=True,
                                      return_fig=True)
    fig_ft = st.plot_flat_tree(adata, color=[color], alpha=alpha,
                               dist_scale=0.5, show_graph=True,
                               show_text=True, plotly=True, return_fig=True)
    fig_br = st.plot_branches(adata, show_text=True, plotly=True,
                              return_fig=True)

    return html.Div([
        dbc.Row([_dr_graph(fig_dr), _dr_graph(fig_2d)]),
        html.Br(),
        dbc.Row([_dr_graph(fig_ft), _dr_graph(fig_br)]),
    ])
Beispiel #7
0
def _build_plot_parser():
    """Return the argument parser for the STREAM plotting command line."""
    parser = argparse.ArgumentParser(
        description='%s Parameters' % __tool_name__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "-m",
        "--data-file",
        dest="input_filename",
        default=None,
        help="input file name, pkl format from Stream preprocessing module",
        metavar="FILE")
    parser.add_argument("-of",
                        "--of",
                        dest="output_filename_prefix",
                        default="StreamiFSOutput",
                        help="output file name prefix")
    parser.add_argument("-fig_width",
                        dest="fig_width",
                        type=int,
                        default=8,
                        help="")
    parser.add_argument("-fig_height",
                        dest="fig_height",
                        type=int,
                        default=8,
                        help="")
    parser.add_argument("-fig_legend_ncol",
                        dest="fig_legend_ncol",
                        type=int,
                        default=None,
                        help="")

    parser.add_argument("-root", dest="root", default=None, help="")
    parser.add_argument("-preference", dest="preference", help="")
    parser.add_argument("-subway_factor",
                        dest="subway_factor",
                        type=float,
                        default=2.0,
                        help="")
    parser.add_argument("-color_by", dest="color_by", default='label', help="")
    parser.add_argument("-factor_num_win",
                        dest="factor_num_win",
                        type=int,
                        default=10,
                        help="")
    parser.add_argument("-factor_min_win",
                        dest="factor_min_win",
                        type=float,
                        default=2.0,
                        help="")
    parser.add_argument("-factor_width",
                        dest="factor_width",
                        type=float,
                        default=2.5,
                        help="")
    parser.add_argument("-flag_log_view",
                        dest="flag_log_view",
                        action="store_true",
                        help="")
    parser.add_argument("-factor_zoomin",
                        dest="factor_zoomin",
                        type=float,
                        default=100.0,
                        help="")
    parser.add_argument("-flag_cells",
                        dest="flag_cells",
                        action="store_true",
                        help="")
    parser.add_argument("-flag_genes",
                        dest="flag_genes",
                        action="store_true",
                        help="")

    parser.add_argument("-genes", dest="genes", default=None, help="")
    parser.add_argument("-percentile_dist",
                        dest="percentile_dist",
                        type=float,
                        default=100,
                        help="")
    return parser


def _plot_cell_figures(adata, args, preference):
    """Save the flat tree, cell subway map and cell stream plot as PNGs."""
    fig_size = (args.fig_width, args.fig_height)
    st.plot_flat_tree(adata,
                      save_fig=True,
                      fig_path="./",
                      fig_name=(args.output_filename_prefix +
                                '_flat_tree.png'),
                      fig_size=fig_size,
                      fig_legend_ncol=args.fig_legend_ncol)

    st.subwaymap_plot(adata,
                      root=args.root,
                      percentile_dist=args.percentile_dist,
                      preference=preference,
                      factor=args.subway_factor,
                      color_by=args.color_by,
                      save_fig=True,
                      fig_path="./",
                      fig_name=(args.output_filename_prefix +
                                '_cell_subway_map.png'),
                      fig_size=fig_size,
                      fig_legend_ncol=args.fig_legend_ncol)

    st.stream_plot(adata,
                   root=args.root,
                   preference=preference,
                   factor_num_win=args.factor_num_win,
                   factor_min_win=args.factor_min_win,
                   factor_width=args.factor_width,
                   flag_log_view=args.flag_log_view,
                   factor_zoomin=args.factor_zoomin,
                   save_fig=True,
                   fig_path="./",
                   fig_name=(args.output_filename_prefix +
                             '_cell_stream_plot.png'),
                   fig_size=fig_size,
                   fig_legend=True,
                   fig_legend_ncol=args.fig_legend_ncol,
                   tick_fontsize=20,
                   label_fontsize=25)


def _plot_gene_figures(adata, args, preference, genes):
    """Save a gene subway map and a gene stream plot (PNG) for *genes*.

    ``fig_name`` is not passed, so the output file names are chosen by STREAM.
    """
    fig_size = (args.fig_width, args.fig_height)
    st.subwaymap_plot_gene(adata,
                           root=args.root,
                           genes=genes,
                           preference=preference,
                           percentile_dist=args.percentile_dist,
                           factor=args.subway_factor,
                           save_fig=True,
                           fig_path="./",
                           fig_format='png',
                           fig_size=fig_size)

    # -flag_log_view is honoured here too, as it is for the cell stream plot.
    st.stream_plot_gene(adata,
                        root=args.root,
                        genes=genes,
                        preference=preference,
                        factor_min_win=args.factor_min_win,
                        factor_num_win=args.factor_num_win,
                        factor_width=args.factor_width,
                        flag_log_view=args.flag_log_view,
                        save_fig=True,
                        fig_path="./",
                        fig_format='png',
                        fig_size=fig_size,
                        tick_fontsize=20,
                        label_fontsize=25)


def main():
    """Command-line entry point: draw STREAM cell and/or gene plots.

    Loads a pickled STREAM result (``-m``), draws the cell plots when
    ``-flag_cells`` is given and the gene plots when ``-flag_genes`` is given
    (which then requires ``-genes``), and writes the result back as
    ``<prefix>_stream_result.pkl``.
    """
    sns.set_style('white')
    sns.set_context('poster')
    parser = _build_plot_parser()
    args = parser.parse_args()
    # Fail early with a usage message instead of an AttributeError on None.
    if args.flag_genes and not args.genes:
        parser.error("-genes is required when -flag_genes is set")

    workdir = "./"

    adata = st.read(file_name=args.input_filename,
                    file_format='pkl',
                    experiment='rna-seq',
                    workdir=workdir)
    # -preference has no default; fall back to None when it is omitted.
    preference = args.preference.split(',') if args.preference else None

    # Both flags are store_true (always True/False, never None), so test
    # truthiness; comparing against None made both sections always run.
    if args.flag_cells:
        _plot_cell_figures(adata, args, preference)

    if args.flag_genes:
        _plot_gene_figures(adata, args, preference, args.genes.split(','))

    st.write(adata,
             file_name=(args.output_filename_prefix + '_stream_result.pkl'),
             file_path='./',
             file_format='pkl')

    print('Finished computation.')