Beispiel #1
0
def test_from_contents_invalid(id_column):
    """Check that ``from_contents`` raises ``ValueError`` for malformed input.

    Each case below pairs one kind of bad input with the fragment of the
    error message that is expected for it.
    """
    contents = OrderedDict()
    contents['cat1'] = {'aa', 'bb', 'cc'}
    contents['cat2'] = {'cc', 'dd'}
    contents['cat3'] = {'ee'}

    # data carries a column with the same name as one of the categories
    clashing_data = pd.DataFrame({'cat1': [1, 2, 3, 4, 5]})
    with pytest.raises(ValueError, match='columns overlap'):
        from_contents(contents, data=clashing_data, id_column=id_column)

    # the same identifier is listed twice within one category
    repeated_ids = {'cat1': ['aa', 'bb'], 'cat2': ['dd', 'dd']}
    with pytest.raises(ValueError, match='duplicate ids'):
        from_contents(repeated_ids, id_column=id_column)

    # a category is named like the id column
    category_named_id = {id_column: {'aa', 'bb', 'cc'}, 'cat2': {'cc', 'dd'}}
    with pytest.raises(ValueError, match='cannot be named'):
        from_contents(category_named_id, id_column=id_column)

    # a data column is named like the id column
    data_with_id_column = pd.DataFrame(
        {id_column: [1, 2, 3, 4, 5]},
        index=['aa', 'bb', 'cc', 'dd', 'ee'])
    with pytest.raises(ValueError, match='cannot contain'):
        from_contents(contents, data=data_with_id_column, id_column=id_column)

    # data is indexed by 0, which does not cover the identifier 'aa'
    unrelated_data = pd.DataFrame([[1]])
    with pytest.raises(ValueError, match='identifiers in contents'):
        from_contents({'cat1': ['aa']}, data=unrelated_data,
                      id_column=id_column)
Beispiel #2
0
def test_from_contents(typ=set, id_column='id'):
    """Compare several ``from_contents`` call variants against one baseline.

    The baseline is built from an ordered three-category mapping plus an
    empty data frame indexed by every identifier.
    """
    category_names = ['cat1', 'cat2', 'cat3']
    contents = OrderedDict()
    contents['cat1'] = {'aa', 'bb', 'cc'}
    contents['cat2'] = {'cc', 'dd'}
    contents['cat3'] = {'ee'}
    empty_data = pd.DataFrame(index=['aa', 'bb', 'cc', 'dd', 'ee'])
    baseline = from_contents(contents, data=empty_data, id_column=id_column)

    # Without data: same frame once rows are sorted by identifier.
    without_data = from_contents(contents, id_column=id_column)
    assert_frame_equal(without_data.sort_values(id_column), baseline)

    # Categories supplied in reverse order: same frame once the index
    # levels are put back in the baseline order.
    reversed_contents = {
        name: contents[name] for name in reversed(category_names)
    }
    reordered = from_contents(reversed_contents,
                              data=empty_data,
                              id_column=id_column)
    assert_frame_equal(reordered.reorder_levels(category_names), baseline)

    # An extra, empty category adds an all-False index level and nothing else.
    with_empty = from_contents(dict(contents, cat4=[]),
                               data=empty_data,
                               id_column=id_column)
    index_frame = with_empty.index.to_frame()
    assert not index_frame['cat4'].any()  # cat4 should be all-false
    assert len(with_empty.index.names) == 4
    with_empty.index = (
        with_empty.index.to_frame().set_index(category_names).index)
    assert_frame_equal(with_empty, baseline)
Beispiel #3
0
def run(json_file, kallisto_spec, eVIPP_predict, output_name):
    """Plot the overlap of pathway gene sets, restricted to selected genes.

    Args:
        json_file: Path to a JSON file holding a mapping of gene-set name to
            a list of genes.
        kallisto_spec: Path to a tab-separated file with a ``#gene_id``
            column; only genes listed there are kept in each gene set.
        eVIPP_predict: Path to a tab-separated file with a ``Pathway``
            column; only gene sets named there are considered.
        output_name: Path the figure is saved to.

    Nothing is plotted or written unless at least two gene sets remain after
    the pathway filter.
    """
    # Sets give O(1) membership tests; the originals were lists that were
    # scanned once per gene / per pathway.
    spec_genes = set(
        pd.read_csv(kallisto_spec, sep="\t",
                    index_col="#gene_id").index.tolist())
    pathways = set(
        pd.read_csv(eVIPP_predict, sep="\t",
                    index_col="Pathway").index.tolist())

    with open(json_file) as f:
        gene_set_dict = json.load(f)

    # Keep only the predicted pathways, and within each only the spec genes
    # (order of genes inside each set is preserved).
    spec_dict = {
        name: [gene for gene in genes if gene in spec_genes]
        for name, genes in gene_set_dict.items() if name in pathways
    }

    if len(spec_dict) <= 1:
        return

    e = from_contents(spec_dict)

    upsetplot.plot(e,
                   sort_by='cardinality',
                   sort_categories_by='cardinality',
                   show_counts=True)
    plt.savefig(output_name)
    plt.clf()
Beispiel #4
0
 def generate(self):
     """Collect the distinct chatters of each channel and show an UpSet plot.

     Reads connection settings from ``self.pgsql_creds`` and the channels to
     query from ``self.channel_ids``. ``self.crowds`` is reset and filled
     with one entry per channel: the shortened channel name mapped to the set
     of rows returned for that channel (each row is the tuple produced by the
     cursor, not a bare user id). Only messages sent on or after
     2021-08-01 UTC are considered.
     """
     self.crowds = dict()
     conn = psycopg2.connect(dbname=self.pgsql_creds["database"],
                             user=self.pgsql_creds["username"],
                             host=self.pgsql_creds["host"],
                             password=self.pgsql_creds["password"])
     # Close the connection even when a query fails part-way through.
     try:
         cur = conn.cursor()
         cutoff_datetime = datetime.datetime(2021,
                                             8,
                                             1,
                                             tzinfo=datetime.timezone.utc)
         for channel in self.channel_ids:
             cur.execute("select name from channel where id = %s;",
                         (channel, ))
             # Strip the agency suffix/prefix so plot labels stay short.
             chan_name = cur.fetchone()[0].replace(" Ch. hololive-EN", "")
             chan_name = chan_name.replace("【NIJISANJI EN】", "")
             print(chan_name)
             cur.execute(
                 "select distinct user_id from messages m inner join video v on m.video_id = v.video_id inner join channel c on c.id = v.channel_id where v.channel_id = %s and m.time_sent >= %s;",
                 (channel, cutoff_datetime))
             results = cur.fetchall()
             self.crowds[chan_name] = set(results)
     finally:
         conn.close()
     data = up.from_contents(self.crowds)
     updata = up.UpSet(data)
     updata.plot()
     plt.show()
Beispiel #5
0
    def upset_members(self, threshold=0, path=None, plot_upset=False, show_counts_bool=True, exclude_singletons_from_threshold=False, threshold_dual_cats=None, exclude_skids=None):
        """Group skids by their exact combination of cell types and filter the groups by size.

        Every cell type in ``self.Celltypes`` contributes its name and skids.
        The skids are then partitioned by which cell types they belong to; each
        partition becomes a new ``Celltype`` whose name joins the participating
        cell type names with ' and '.

        Parameters
        ----------
        threshold : minimum number of skids a combination needs to be selected.
        path : prefix of the PDF file name written when ``plot_upset`` is true.
        plot_upset : when true, draw the selected data and save it as a PDF.
        show_counts_bool : when true, the plot is drawn with ``show_counts='%d'``.
        exclude_singletons_from_threshold : when true, combinations whose name
            contains no ' and ' are selected regardless of ``threshold``.
        threshold_dual_cats : optional additional size threshold, only used
            together with ``exclude_singletons_from_threshold``.
        exclude_skids : optional collection of skids to drop before grouping.

        Returns
        -------
        tuple ``(cat_types, cats_selected, skids_excluded)``: all combinations,
        the combinations that passed the threshold rules, and the skids that
        belong only to combinations that did not pass.
        """

        celltypes = self.Celltypes

        contents = {} # cell type name -> skids of that cell type
        for celltype in celltypes:
            name = celltype.get_name()
            contents[name] = celltype.get_skids()

        data = from_contents(contents)

        # identify indices of set intersection between all data and exclude_skids
        if(exclude_skids!=None):
            ind_dict = dict((k,i) for i,k in enumerate(data.id.values))
            inter = set(ind_dict).intersection(exclude_skids)
            indices = [ind_dict[x] for x in inter]
            # keep every row position that is not one of the excluded ones
            data = data.iloc[np.setdiff1d(range(0, len(data)), indices)]

        # one Celltype per distinct index value; its name lists the index
        # levels that are True for that value
        unique_indices = np.unique(data.index)
        cat_types = [Celltype(' and '.join([data.index.names[i] for i, value in enumerate(index) if value==True]), 
                    list(data.loc[index].id)) for index in unique_indices]

        # apply threshold to all category types
        if(exclude_singletons_from_threshold==False):
            cat_bool = [len(x.get_skids())>=threshold for x in cat_types]
        
        # allows categories with no intersection ('singletons') to dodge the threshold
        if((exclude_singletons_from_threshold==True) & (threshold_dual_cats==None)): 
            cat_bool = [(((len(x.get_skids())>=threshold) | (' and ' not in x.get_name()))) for x in cat_types]

        # allows categories with no intersection ('singletons') to dodge the threshold and additional threshold for dual combos
        # `&` binds tighter than `|`, so this reads:
        #   (size >= threshold or singleton) or (size >= threshold_dual_cats and count('+') < 2)
        # NOTE(review): names built above are joined with ' and ', so count('+')
        # only sees '+' characters already present in the cell type names --
        # confirm whether ' and ' was meant here.
        if((exclude_singletons_from_threshold==True) & (threshold_dual_cats!=None)): 
            cat_bool = [(((len(x.get_skids())>=threshold) | (' and ' not in x.get_name())) | (len(x.get_skids())>=threshold_dual_cats) & (x.get_name().count('+')<2)) for x in cat_types]

        cats_selected = list(np.array(cat_types)[cat_bool])
        skids_selected = [x for sublist in [cat.get_skids() for cat in cats_selected] for x in sublist]

        # identify indices of set intersection between all data and skids_selected
        ind_dict = dict((k,i) for i,k in enumerate(data.id.values))
        inter = set(ind_dict).intersection(skids_selected)
        indices = [ind_dict[x] for x in inter]

        # restrict the data to the selected skids (row order follows set iteration order)
        data = data.iloc[indices]

        # identify skids that weren't plotted in the upset plot (based on plotting threshold)
        all_skids = [x for sublist in [cat.get_skids() for cat in cat_types] for x in sublist]
        skids_excluded = list(np.setdiff1d(all_skids, skids_selected))

        if(plot_upset):
            if(show_counts_bool):
                fg = plot(data, sort_categories_by = None, show_counts='%d')
            else: 
                fg = plot(data, sort_categories_by = None)

            # NOTE(review): with the default path=None the file name starts with the text 'None'
            if(threshold_dual_cats==None):
                plt.savefig(f'{path}_excluded{len(skids_excluded)}_threshold{threshold}.pdf', bbox_inches='tight')
            if(threshold_dual_cats!=None):
                plt.savefig(f'{path}_excluded{len(skids_excluded)}_threshold{threshold}_dual-threshold{threshold_dual_cats}.pdf', bbox_inches='tight')

        return (cat_types, cats_selected, skids_excluded)
def plot_protein_upset(protein_dict):
    """Draw an UpSet plot of protein overlap and save it as SVG and PNG.

    ``protein_dict`` is handed to ``upsetplot.from_contents`` unchanged. The
    figure is written to ``Protein_upset.svg`` and ``Protein_upset.png`` in
    the current working directory.
    """
    memberships = upsetplot.from_contents(protein_dict)
    upsetplot.plot(
        memberships,
        sort_by='cardinality',
        subset_size='auto',
        facecolor='#21918cff',
    )
    plt.title("Distribution of Protein Overlap")

    plt.savefig("Protein_upset.svg")
    plt.savefig("Protein_upset.png")
Beispiel #7
0
def parse_ncov_watch(input_file: Path) -> pd.DataFrame:
    """
    Parse output from ncov_watch

    Reads a tab-separated file with ``mutation`` and ``sample`` columns,
    collects the samples seen for each mutation, and returns the result of
    ``upsetplot.from_contents`` for that mapping (a membership-indexed data
    frame, not an ``UpSet`` object).
    """
    variants = pd.read_csv(input_file, sep='\t')
    # mutation -> list of samples carrying it
    samples_by_mutation = (
        variants.groupby('mutation')['sample'].apply(list).to_dict())
    return upsetplot.from_contents(samples_by_mutation)
Beispiel #8
0
def writeIntersectionPlot(inputIterators, iter):
    """Plot which circ entries are present for each input and save the figure.

    For every element of ``inputIterators`` a category named after its
    ``name`` is built, holding the entries of ``iter`` whose ``getMeta`` for
    that element's ``id`` is not ``CircRow.META_INDEX_CIRC_NOT_IN_DB``. The
    resulting UpSet plot is saved to ``./output/out.png``.
    """
    # ``iter`` is scanned once per input; materialise it first so that a
    # one-shot iterable is not exhausted after the first category.
    circs = list(iter)

    contents = {}
    for circIter in inputIterators:
        contents[circIter.name] = [
            c for c in circs
            if c.getMeta(circIter.id) != CircRow.META_INDEX_CIRC_NOT_IN_DB
        ]

    df = from_contents(contents)
    plot(df, facecolor="red", sort_by="cardinality", show_counts='%d')
    pyplot.savefig('./output/out.png')
 def plot_peptide_upset(self, save=False):
     """Draw an UpSet plot of ``self.peptide_dict``.

     When ``save`` is true the figure is also written to
     ``Peptide_upset.svg`` and ``Peptide_upset.png`` in the current working
     directory.
     """
     memberships = upsetplot.from_contents(self.peptide_dict)
     plot_options = dict(sort_by='cardinality',
                         subset_size='auto',
                         facecolor='#21918cff')
     upsetplot.plot(memberships, **plot_options)
     plt.title("Distribution of Peptide Overlap")
     if not save:
         return
     plt.savefig("Peptide_upset.svg")
     plt.savefig("Peptide_upset.png")
def get_intersection(data: dict, sets_outfile: str = "upsetplot.tsv"):
    """
    Convert a dict of identifier lists into set-membership form.

    Arguments:
        (REQUIRED) data: dict of lists, values within each list must be unique.
        (OPTIONAL) sets_outfile: tab-separated file the per-identifier
            memberships are written to; pass a falsy value to skip writing.

    Returns the frame produced by ``from_contents``. The written table is
    indexed by ``id`` and carries an extra ``present_in_sets`` column holding
    the row-wise sum of the other columns.
    """
    frame = from_contents(data)
    if not sets_outfile:
        return frame

    table = frame.reset_index().set_index("id")
    table["present_in_sets"] = table.sum(axis=1)
    table.to_csv(sets_outfile, sep="\t")
    return frame
Beispiel #11
0
def plot_upset(sets, path):
    """Save an UpSet plot of ``sets`` to ``path``.

    With fewer than two sets there is nothing to intersect: a message is
    printed and no file is written.
    """
    if len(sets) <= 1:
        print(f'plot_upset: No sets to intersect for {path}')
        return

    memberships = from_contents(sets)
    options = {
        'sort_by': 'degree',
        'sort_categories_by': 'cardinality',
        'show_counts': True,
        'show_percentages': True,
    }
    fig = plt.figure()
    UpSet(memberships, **options).plot(fig=fig)
    fig.savefig(path)
Beispiel #12
0
def test_from_contents_vs_memberships(data, typ, id_column):
    """Check that ``from_contents`` agrees with ``from_memberships``.

    The identifier 'ff' appears in the data frame but in no category, so its
    expected membership is empty.
    """
    ids = ['aa', 'bb', 'cc', 'dd', 'ee', 'ff']
    contents = OrderedDict()
    contents['cat1'] = typ(['aa', 'bb', 'cc'])
    contents['cat2'] = typ(['cc', 'dd'])
    contents['cat3'] = typ(['ee'])

    data_df = pd.DataFrame(data, index=ids)
    baseline = from_contents(contents, data=data_df, id_column=id_column)

    # One membership collection per identifier, in the order of ``ids``.
    memberships = [
        {'cat1'},
        {'cat1'},
        {'cat1', 'cat2'},
        {'cat2'},
        {'cat3'},
        [],
    ]
    expected = from_memberships(memberships=memberships, data=data_df)

    observed_ids = baseline[id_column].reset_index(drop=True)
    assert_series_equal(observed_ids, pd.Series(ids, name=id_column))
    assert_frame_equal(baseline.drop([id_column], axis=1), expected)
Beispiel #13
0
def parse_type_variant(input_file: Path) -> pd.DataFrame:
    """
    Parse the output csv from type_variant

    For every genotype column a set of ``query`` values is collected:

    * ``aa:`` / ``snp:`` columns: rows whose value equals the last character
      of the column name (the alternative allele).
    * ``del:`` columns: rows whose value is ``'del'``; rows whose value is
      ``'X'`` go into an extra ``<column>:X`` set when there are any.

    All other columns are ignored. Returns the result of
    ``upsetplot.from_contents`` for these sets (a membership-indexed data
    frame, not an ``UpSet`` object).
    """
    variants = pd.read_csv(input_file, sep=',')

    def strains_with(column, call):
        """Return the set of ``query`` values whose *column* equals *call*."""
        return set(variants.loc[variants[column] == call, 'query'].values)

    variant_sets = {}
    for column in variants.columns:
        # handle SNP and amino acid changes
        if column.startswith(('aa:', 'snp:')):
            variant_sets[column] = strains_with(column, column[-1])
        # handle deletions slightly differently
        elif column.startswith('del:'):
            variant_sets[column] = strains_with(column, 'del')

            # X is reported if something other than ref or deletion is found
            strains_with_other = strains_with(column, 'X')
            if strains_with_other:
                variant_sets[column + ":X"] = strains_with_other

    # convert sets to df containing intersections
    return upsetplot.from_contents(variant_sets)
Beispiel #14
0
def run(json_file, gmt, kallisto_spec, eVIPP_predict, output_name):
    """Plot the overlap of pathway gene sets, restricted to selected genes.

    Args:
        json_file: Optional path to a JSON mapping of gene-set name to a list
            of genes; entries whose value is null are dropped.
        gmt: Optional path to a GMT file (tab-separated: name, one further
            field that is ignored, then genes). When both ``json_file`` and
            ``gmt`` are given, the GMT gene sets are the ones used.
        kallisto_spec: Path to a tab-separated file with a ``#gene_id``
            column; only genes listed there are kept in each gene set.
        eVIPP_predict: Path to a tab-separated file with a ``Pathway``
            column; only gene sets named there are considered.
        output_name: Path the figure is saved to.

    Raises:
        ValueError: if neither ``json_file`` nor ``gmt`` is given.

    Nothing is plotted or written unless at least two gene sets remain after
    the pathway filter.
    """
    if not json_file and not gmt:
        raise ValueError("either json_file or gmt must be provided")

    # Sets give O(1) membership tests; the originals were lists that were
    # scanned once per gene / per pathway.
    spec_genes = set(
        pd.read_csv(kallisto_spec, sep="\t",
                    index_col="#gene_id").index.tolist())
    pathways = set(
        pd.read_csv(eVIPP_predict, sep="\t",
                    index_col="Pathway").index.tolist())

    if json_file:
        with open(json_file, "rb") as JSON:
            old_dict = json.load(JSON)
        # dict.iteritems() does not exist on Python 3; items() works on both.
        gene_set_dict = {k: v for k, v in old_dict.items() if v is not None}

    if gmt:
        with open(gmt) as gmt_:
            content = gmt_.read().splitlines()
        lines = [i.split("\t") for i in content]
        gene_set_dict = {item[0]: item[2:] for item in lines}

    # Keep only the predicted pathways, and within each only the spec genes
    # (order of genes inside each set is preserved).
    spec_dict = {
        name: [gene for gene in genes if gene in spec_genes]
        for name, genes in gene_set_dict.items() if name in pathways
    }

    if len(spec_dict) <= 1:
        return

    e = from_contents(spec_dict)

    upsetplot.plot(e,
                   sort_by='cardinality',
                   sort_categories_by='cardinality',
                   show_counts=True)
    plt.savefig(output_name)
    plt.clf()
Beispiel #15
0
            elemName = line[0]
            dir = line[8]
            adjpval = line[7]

            try:
                adjpval = float(adjpval)

                if adjpval < pvalThreshold:
                    dir2file2sigs[dir][args.prefixes[tidx]].add(elemName)
            except:
                pass

    for dir in dir2file2sigs:

        allFileData = dir2file2sigs[dir]

        outname = args.output + "." + dir

        if not outname.endswith(".png"):
            outname += ".png"

        if len(allFileData) > 1:

            upIn = from_contents(allFileData)
            plot(upIn, subset_size="auto")

            plt.title("Overlap for {} elements (pval < {})".format(
                dir, pvalThreshold))

            plt.savefig(outname, bbox_inches="tight")

def extracting_rsid_from_file(filename):
    """Store the distinct ``rsid`` values of a CSV file in ``file_rsid_dict``.

    The module-level ``file_rsid_dict`` is updated in place, keyed by
    ``filename``; nothing is returned.
    """
    rsids = pd.read_csv(filename)['rsid'].values
    file_rsid_dict[filename] = set(rsids)


# Fill file_rsid_dict with one rsid set per input file.
for file in files:
    extracting_rsid_from_file(file)

# Report how many entries file_rsid_dict now holds.
print(len(file_rsid_dict))

# k=from_memberships([
# file_rsid_dict['A2_ALL_eur.csv'],
# file_rsid_dict['A2_ALL_eur_leave_ukbb.csv']
# ])
# print(k)

from upsetplot import generate_counts, plot, from_contents
from matplotlib import pyplot
# rsid sets to compare, keyed by the CSV file they were read from.
contents = {
    csv_name: file_rsid_dict[csv_name]
    for csv_name in (
        'A2_ALL_eur.csv',
        'A2_ALL_eur_leave_ukbb.csv',
        'A2_ALL_leave_23andme.csv',
        'A2_ALL_leave_UKBB.csv',
    )
}
k = from_contents(contents)
# Horizontal layout with integer counts on the bars; pass
# orientation='vertical' to plot() for the vertical layout instead.
plot(k, show_counts='%d')
pyplot.show()
Beispiel #17
0
 def gen_upset_plot(self, className=None):
     """Build an HTML card containing an UpSet plot of peptide overlap.

     One category is created per sample in ``self.results.samples`` (named
     by ``sample_name``, containing its ``peptides``). Intersections holding
     less than 0.5% of the summed per-sample peptide counts are dropped. The
     plot is laid out horizontally when at most 100 distinct intersections
     remain and vertically otherwise, and a boxen plot of peptide length is
     added in both cases.

     The figure is saved to ``upsetplot.svg`` inside ``self.fig_dir`` and
     embedded base64-encoded in the returned element.

     Args:
         className: value passed as ``className`` to the outer ``div``.

     Returns:
         The ``div`` wrapping the card.
     """
     def _verticalize_labels(axis, offset):
         """Put each text label of *axis* on one line, rotate it, raise it by *offset*."""
         for artist in axis.get_children():
             if isinstance(artist, plotText):
                 artist.set_text(artist.get_text().replace('\n', ' '))
                 artist.set_rotation('vertical')
                 x_pos, y_pos = artist.get_position()
                 artist.set_position((x_pos, y_pos + offset))

     total_peps = np.sum([len(s.peptides) for s in self.results.samples])
     data = from_contents({s.sample_name: set(s.peptides)
                           for s in self.results.samples})
     # Drop intersections that are too small to be worth displaying.
     for intersection in data.index.unique():
         if len(data.loc[intersection, :])/total_peps < 0.005:
             data.drop(index=intersection, inplace=True)
     data['peptide_length'] = np.vectorize(len)(data['id'])
     n_sets = len(data.index.unique())
     if n_sets <= 100:  # Plot horizontal
         upset = UpSet(data,
                       sort_by='cardinality',
                       show_counts=True,)
         upset.add_catplot(value='peptide_length', kind='boxen', color='gray')
         axes = upset.plot()
         axes['totals'].grid(False)
         # Leave headroom above the bars for the rotated count labels.
         ylim = axes['intersections'].get_ylim()[1]
         axes['intersections'].set_ylim((0, ylim * 1.1))
         _verticalize_labels(axes['intersections'], 0.02 * ylim)
     else:  # plot vertical
         upset = UpSet(data, subset_size='count',
                       orientation='vertical',
                       sort_by='cardinality',
                       sort_categories_by=None,
                       show_counts=True)
         upset.add_catplot(value='peptide_length', kind='boxen', color='gray')
         axes = upset.plot()
         lim = axes['intersections'].get_xlim()
         axes['intersections'].set_xlim([0, lim[1] * 1.6])
         axes['totals'].grid(False)
         ylim = axes['totals'].get_ylim()[1]
         _verticalize_labels(axes['totals'], 0.1 * ylim)
         plt.draw()
     upset_fig = f'{self.fig_dir / "upsetplot.svg"}'
     plt.savefig(upset_fig, bbox_inches="tight")
     # Read the saved figure back through a context manager so the file
     # handle is released promptly.
     with open(upset_fig, 'rb') as fig_file:
         encoded_upset_fig = base64.b64encode(fig_file.read()).decode()
     card = div(className='card', style="height: 100%")
     card.add(div([b('UpSet Plot'), p('Only intersections > 0.5% are displayed')], className='card-header'))
     plot_body = div(img(src=f'data:image/svg+xml;base64,{encoded_upset_fig}',
                         className='img-fluid',
                         style='width: 100%; height: auto'),
                     className='card-body')
     card.add(plot_body)
     return div(card, className=className)
def compute_overlaps(d,
                     unmatched,
                     ensembl_genes_map,
                     species,
                     use_gene_id=False,
                     gene_id_from_tool=False):
    """Write a per-gene table of tools reporting events and plot the tool overlap.

    Args:
        d: mapping of tool name to a dict-like of gene -> number of events
            (its keys form the tool's gene set, its values are written out).
        unmatched: genes written to the unmatched-genes file when
            ``use_gene_id`` is true and ``gene_id_from_tool`` is false;
            recomputed here when ``use_gene_id`` is false.
        ensembl_genes_map: table with ``gene_name`` and ``gene_id`` columns,
            used to map names to ids when ``gene_id_from_tool`` is false.
        species: passed to ``mygene`` when ``gene_id_from_tool`` is true.
        use_gene_id: plot on gene ids instead of gene names.
        gene_id_from_tool: the genes in ``d`` already are gene ids; symbols
            are fetched through ``mygene``.

    Files written (``<base>`` depends on the two flags):
        ``<base>.tsv``, ``<base>_unmatched_genes.tsv`` and ``<base>.pdf``.

    Returns:
        None.
    """
    final_table = defaultdict(list)
    sets = []

    # Output base name encodes which identifier type is being used.
    if use_gene_id and not gene_id_from_tool:
        outbasename = "tools_overlap_gene_id_fetched_from_symbol"
    elif gene_id_from_tool:
        outbasename = "tools_overlap_gene_id_from_tool"
    else:
        outbasename = "tools_overlap_gene_name"

    # sets: (tool, genes of that tool); final_table: gene -> [[tool, count], ...]
    for tool, counter in d.items():
        sets.append((tool, set(counter.keys())))
        for gene in counter:
            final_table[gene].append([tool, counter[gene]])

    with open("{}.tsv".format(outbasename), "w") as outw:

        if gene_id_from_tool:
            import mygene
            mg = mygene.MyGeneInfo()
            gene_ids = [k for k in final_table.keys()]
            # Get symbols from geneIDs
            ens_map = mg.querymany(
                qterms=gene_ids,
                scopes="ensembl.gene",
                fields=["symbol"],
                returnall=True,
                as_dataframe=True,
                size=1,
                species=species)['out'][["symbol"]].to_dict()['symbol']

            # Ids whose symbol is a float are written out as unmatched
            # (presumably NaN for ids without a symbol -- TODO confirm).
            ensembl_map = {
                k
                for k, v in ens_map.items() if isinstance(v, float)
            }
            with open("{}_unmatched_genes.tsv".format(outbasename),
                      "w") as nogene:
                nogene.write("\n".join(list(ensembl_map)) + "\n")
            # NOTE(review): redundant, the with-block above already closed the file.
            nogene.close()

        else:
            gene_names = [k for k in final_table.keys()]
            # Get gene IDs from symbols
            # ens_map = mg.querymany(qterms=gene_names, scopes="symbol", fields=["ensembl.gene"], returnall=True,
            #             as_dataframe=True, size=1, species="human")['out'][["ensembl.gene"]].to_dict()['ensembl.gene']
            # ensembl_map = {k for k, v in ens_map.items() if isinstance(v, float)}
            # many unknown as well as genes mapping to multiple IDs (due to the use of haplotype regions)

            # gene_name -> gene_id for the names seen in any tool
            ens_map = ensembl_genes_map[ensembl_genes_map['gene_name'].isin(gene_names)][["gene_name", "gene_id"]]. \
                drop_duplicates(keep="first").set_index("gene_name").to_dict()['gene_id']

            with open("{}_unmatched_genes.tsv".format(outbasename),
                      "w") as nogene:
                if not use_gene_id:  # if using gene names, unfetched gene names are only checked here
                    unmatched = [
                        v for v in gene_names if v not in list(ens_map.keys())
                    ]

                nogene.write("\n".join(list(unmatched)) + "\n")
            # NOTE(review): redundant, the with-block above already closed the file.
            nogene.close()

        # Header: the first column is whichever identifier the tools reported.
        if gene_id_from_tool:
            outw.write(
                "#gene_id\tgene_name\ttools_with_event\tnumber_of_events" +
                "\n")
        else:
            outw.write(
                "#gene_name\tgene_id\ttools_with_event\tnumber_of_events" +
                "\n")

        for k, v in final_table.items():
            # v is [[tool, count], ...]; separate the two by element type
            flat_list = [item for sublist in v for item in sublist]
            tools = [t for t in flat_list if isinstance(t, str)]
            number_of_events = [
                str(t) for t in flat_list if isinstance(t, int)
            ]
            # When the mapped value is missing or not a string, the second
            # column is left empty.
            try:
                outw.write(k + "\t" + ens_map[k] + "\t" + ",".join(tools) +
                           "\t" + ",".join(number_of_events) + "\n")
            except TypeError:  # e.g. cases where a vast-tools eventID is reported
                outw.write(k + "\t" + "" + "\t" + ",".join(tools) + "\t" +
                           ",".join(number_of_events) + "\n")
            except KeyError:  # e.g. cases where geneID doesn't exist for the given symbol
                outw.write(k + "\t" + "" + "\t" + ",".join(tools) + "\t" +
                           ",".join(number_of_events) + "\n")

    # NOTE(review): redundant, the with-block above already closed the file.
    outw.close()
    plt.figure()
    plt.title("Genes with splicing events")
    plt.tight_layout()
    out = "{}.pdf".format(outbasename)
    labels = [v[0] for v in sets]
    just_sets = [v[1] for v in sets]

    # Gene names were reported but ids are wanted: translate each set,
    # silently dropping names without a known id.
    if use_gene_id and not gene_id_from_tool:
        converted_sets = []
        for s in just_sets:
            new_s = set()
            for gene in s:
                try:
                    new_s.add(ens_map[gene])
                except KeyError:
                    continue
            converted_sets.append(new_s)
        just_sets = converted_sets

    # 2 sets -> venn2, more than 3 -> UpSet plot, anything else -> venn3.
    # NOTE(review): zero or one set also falls through to venn3 -- confirm
    # callers always pass at least two tools.
    if len(just_sets) == 2:
        venn2(just_sets, tuple(labels))
    elif len(just_sets) > 3:
        # With gene ids, plot the (possibly converted) sets; otherwise the
        # original per-tool mappings in d are passed on as they are.
        if use_gene_id:
            d = {tool: just_sets[i] for i, tool in enumerate(labels)}
        data = upsetplot.from_contents(d)

        upset = upsetplot.UpSet(data,
                                subset_size="count",
                                intersection_plot_elements=4,
                                show_counts='%d')
        upset.plot()
    else:
        venn3(just_sets, tuple(labels))

    plt.savefig(out)
    plt.close()
Beispiel #19
0
def plotTopExpGenes(exp_matrix,
                    id2symbol=None,
                    top=50,
                    controls="/nfs2/database/gencode_v29/chrM.gene.list",
                    ncols=2,
                    control_name='MT',
                    out_name="TopExpGenes.html",
                    no_strip_version=False,
                    venn_list: list = None,
                    venn_names: list = None):
    """Plot the top expressed genes of each sample and their overlap between samples.

    Args:
        exp_matrix: tab-separated file; first row is the header, first
            column the row index, remaining columns one per sample.
        id2symbol: optional path to a file whose first two
            whitespace-separated fields per line map an id to a symbol.
        top: number of highest rows kept per sample.
        controls: path to a file whose first field per line is a control id;
            a falsy value means no controls. Controls are drawn with '*'.
        ncols: number of columns of the plot grid.
        control_name: legend label of the control points.
        out_name: file the plot grid is saved to.
        no_strip_version: when false, each index value is cut at its first '.'.
        venn_list: ``None``, or a list of comma-separated sample groups, or a
            single-element list naming a file (no comma in the name) whose
            first two fields per line are sample and group.
        venn_names: one output name per entry of ``venn_list``; derived from
            the groups when ``None``.

    Files written: ``out_name``; ``venn.pdf`` or
    ``top<top>.<name>.venn.png`` for the top-gene sets; ``all.cmbVenn.png``
    or ``all.<name>.cmbVenn.png`` for the genes with a value above 0.

    Returns:
        None.
    """
    # NOTE(review): both open() calls below leave closing the file to the
    # garbage collector.
    controls = [x.strip().split()[0]
                for x in open(controls)] if controls else []
    id2symbol = dict(x.strip().split()[:2]
                     for x in open(id2symbol)) if id2symbol else dict()
    df = pd.read_csv(exp_matrix, sep='\t', header=0, index_col=0)
    df.index.name = 'id'
    if not no_strip_version:
        df.index = [x.split('.')[0] for x in df.index]
    # plot for each sample
    plots = list()
    top_dict = dict()  # sample -> set of its top ids
    detected_gene_dict = dict()  # sample -> ids with a value above 0
    for sample in df.columns:
        data = df.loc[:, [sample]].sort_values(by=sample, ascending=False)
        detected_gene_dict[sample] = list(data[data[sample] > 0].index)
        # share of the sample total contributed by each row
        data['percent'] = data[sample] / data[sample].sum()
        plot_data = data.iloc[:top].copy()
        top_dict[sample] = set(plot_data.index)
        total_percent = plot_data['percent'].sum()
        plot_data['order'] = range(plot_data.shape[0])
        plot_data['symbols'] = [
            id2symbol[x] if x in id2symbol else x for x in plot_data.index
        ]
        plot_data['marker'] = [
            '*' if x in controls else "circle" for x in plot_data.index
        ]
        p = figure(title="Top {} account for {:.2%} of total in {}".format(
            top, total_percent, sample),
                   tools="wheel_zoom,reset,hover",
                   tooltips=[
                       ('x', '@percent{0.00%}'),
                       ('y', '@symbols'),
                   ])
        # min_val = plot_data['percent'].min()
        # max_val = plot_data['percent'].max()
        # mapper = linear_cmap(field_name='percent', palette=bp.Oranges[8], low=min_val, high=max_val)
        p.xaxis.axis_label = '% of Total(~{})'.format(int(data[sample].sum()))
        # one scatter call per marker type so each gets its own legend entry
        for marker in ['*', 'circle']:
            source_data = plot_data[plot_data['marker'] == marker]
            if source_data.shape[0] == 0:
                continue
            source = ColumnDataSource(source_data)
            p.scatter(
                x='percent',
                y='order',
                # line_color=mapper,
                # color=mapper,
                fill_alpha=0.2,
                size=10,
                marker=marker,
                legend='non-{}'.format(control_name)
                if marker == "circle" else control_name,
                source=source)
        # label each y position with the symbol (or id) of its row
        p.yaxis.ticker = list(range(plot_data.shape[0]))
        p.yaxis.major_label_overrides = dict(
            zip(range(plot_data.shape[0]),
                (id2symbol[x] if x in id2symbol else x
                 for x in plot_data.index)))
        p.xaxis[0].formatter = NumeralTickFormatter(format="0.00%")
        plots.append(p)
    else:
        # The loop has no break, so this always runs once the loop finishes.
        fig = gridplot(plots, ncols=ncols)
        output_file(out_name)
        save(fig)
    # plot venn
    if venn_list is None:
        if len(top_dict) <= 6:
            venn.venn(top_dict, cmap="tab10")
            plt.savefig('venn.pdf')
    else:
        # A single entry without a comma is read as a sample -> group file
        # and expanded into one comma-joined sample list per group.
        if len(venn_list) == 1 and ',' not in venn_list[0]:
            with open(venn_list[0]) as f:
                group_dict = dict(x.strip().split()[:2] for x in f)
                tmp_dict = dict()
                for k, v in group_dict.items():
                    tmp_dict.setdefault(v, set())
                    tmp_dict[v].add(k)
            venn_list = []
            venn_names = []
            for k, v in tmp_dict.items():
                venn_list.append(','.join(v))
                venn_names.append(k)
        if venn_names is None:
            venn_names = []
            for group in venn_list:
                venn_names.append(group.replace(',', '-'))
        for group, name in zip(venn_list, venn_names):
            groups = group.split(',')
            tmp_dict = {x: y for x, y in top_dict.items() if x in groups}
            if 2 <= len(tmp_dict) <= 6:
                venn.venn(tmp_dict,
                          cmap="tab10",
                          fmt="{size}\n{percentage:.2f}%",
                          fontsize=9)
                plt.savefig('top{}.{}.venn.png'.format(top, name), dpi=300)
            else:
                print('venn for {}?'.format(groups))
                print('venn only support 2-6 sets')
    plt.close()
    # intersection plot
    if venn_list is None:
        if len(detected_gene_dict) <= 9:
            # NOTE(review): sum_over is passed straight to plot(); confirm the
            # installed upsetplot version still accepts this keyword.
            plot(from_contents(detected_gene_dict),
                 sum_over=False,
                 sort_categories_by=None,
                 show_counts=True)
            plt.savefig('all.cmbVenn.png', dpi=300)
            plt.close()
    else:
        for group, name in zip(venn_list, venn_names):
            groups = group.split(',')
            if len(groups) < 2:
                continue
            tmp_dict = {
                x: y
                for x, y in detected_gene_dict.items() if x in groups
            }
            plot(from_contents(tmp_dict),
                 sum_over=False,
                 sort_categories_by=None,
                 show_counts=True)
            plt.savefig('all.{}.cmbVenn.png'.format(name), dpi=300)
            plt.close()
Beispiel #20
0
def _read_line_set(path):
    """Return the set of lines in *path*, each without its trailing newline."""
    with open(path) as handle:
        return {line.rstrip('\n') for line in handle}


def _resolve_venn_groups(venn_list, venn_names):
    """Turn the venn_list / venn_names arguments into (name, set names) pairs.

    A single entry without a comma is read as a file whose first column is a
    set name and whose second column is a group label; every group becomes
    one figure named after its label (venn_names is ignored in that case).
    Otherwise each entry is a comma separated list of set names, and the
    figure name defaults to the entry with ',' replaced by '-'.
    """
    if len(venn_list) == 1 and ',' not in venn_list[0]:
        membership = dict()
        with open(venn_list[0]) as handle:
            for line in handle:
                fields = line.split()
                # Blank or one-column lines carry no group information.
                if len(fields) < 2:
                    continue
                # A set listed twice keeps only its last group label.
                membership[fields[0]] = fields[1]
        grouped = dict()
        for set_name, label in membership.items():
            grouped.setdefault(label, []).append(set_name)
        return list(grouped.items())

    if venn_names is None:
        venn_names = [group.replace(',', '-') for group in venn_list]
    return [(name, group.split(',')) for group, name in zip(venn_list, venn_names)]


def _draw_venn(sets_by_name, show_percent):
    """Draw one venn diagram, optionally labelled with sizes and percentages."""
    if show_percent:
        venn.venn(sets_by_name, cmap="tab10", fmt="{size}\n{percentage:.2f}%", fontsize=9)
    else:
        venn.venn(sets_by_name, cmap="tab10")


def run(files:list, exp=None, out_prefix='result', has_header=False,
        intersect_only=True, intersect_xoy=1, union_only=False, show_venn_percent=False,
        set_names:list=None, venn_list:list=None, venn_names:list=None, graph_format='png'):
    """
    Build sets from file contents, combine them, and plot venn / UpSet figures.

    The selected operation is, in order of precedence: exp, intersect_xoy > 1,
    union_only, intersect_only. The result is written to '<out_prefix>.list',
    one element per line; for intersect_xoy each line also carries the number
    of sets containing the element, separated by a tab.

    :param files: with a single file, every column of the table is one set and
        the non-empty cells are its elements; with several files, every file
        is one set and every line is one element.
    :param exp: set expression as a string, e.g. 's1-s2' for the first set
        minus the second one; s1, s2, ... follow the order of the sets.
        When given, no figure is drawn.
    :param out_prefix: prefix of every output file.
    :param has_header: whether the table has a header row that must not take
        part in the computation. Only used when a single file is given.
    :param intersect_only: intersect all sets (the default operation).
    :param intersect_xoy: when greater than 1, keep the elements that occur in
        at least this many sets.
    :param union_only: compute the union of all sets.
    :param show_venn_percent: show percentages in the venn diagrams.
    :param set_names: names of the sets, in the same order as the sets.
        Defaults to the file name without its last extension, or to the table
        column labels in single-file mode.
    :param venn_list: which sets go into which figure, e.g. 'A,B,C' 'B,C,D'
        for two figures. None draws one figure with all sets. A single file
        name may be given instead: first column set name, second column group.
    :param venn_names: figure names matching venn_list one to one.
    :param graph_format: output figure format, default png
    :return: None
    """
    # Keep the sets in a plain list instead of creating s1..sN through exec():
    # names created by exec() inside a function are not reliably visible to
    # later code, and pasting file paths into source text breaks on
    # backslashes and quotes.
    if len(files) >= 2:
        sets = [_read_line_set(path) for path in files]
        if set_names is None:
            labels = [os.path.basename(path).rsplit('.', 1)[0] for path in files]
        else:
            labels = set_names
    else:
        import pandas as pd
        table = pd.read_table(files[0], header=0 if has_header else None)
        labels = table.columns if set_names is None else set_names
        sets = [set(table.iloc[:, i].dropna()) for i in range(table.shape[1])]

    # Keys are always strings so they can be matched against venn_list entries.
    venn_set_dict = dict()
    for position, members in enumerate(sets):
        venn_set_dict[str(labels[position])] = members

    result = list()
    count_dict = dict()
    if exp:
        print("do as you say in exp")
        # NOTE: exp is evaluated as Python code; only pass trusted expressions.
        namespace = {'s{}'.format(i): s for i, s in enumerate(sets, start=1)}
        result = eval(exp, namespace)
    elif intersect_xoy > 1:
        print('do intersect_xoy')
        occurrences = dict()
        for members in sets:
            for element in members:
                occurrences[element] = occurrences.get(element, 0) + 1
        count_dict = {
            element: times
            for element, times in occurrences.items() if times >= intersect_xoy
        }
        result = set(count_dict)
    elif union_only:
        print('do union only')
        result = set().union(*sets)
    elif intersect_only:
        print('do intersect only')
        result = set.intersection(*sets) if sets else set()

    if not result:
        print('result is empty!')
    else:
        print('result size: {}'.format(len(result)))

    with open(out_prefix + '.list', 'w') as f:
        if not count_dict:
            f.writelines('{}\n'.format(x) for x in result)
        else:
            f.writelines(
                '{}\t{}\n'.format(str(x).strip(), count_dict[x]) for x in result
            )
    if exp:
        return

    if venn_list is None:
        # plot venn
        if 2 <= len(venn_set_dict) <= 6:
            _draw_venn(venn_set_dict, show_venn_percent)
            plt.savefig(out_prefix+f'.venn.{graph_format}')
            # Close the venn figure, as the grouped branch below does, so the
            # UpSet plot that follows starts from a fresh figure.
            plt.close()
        # intersection plot
        if len(venn_set_dict) <= 8:
            plot(from_contents(venn_set_dict), sum_over=False, sort_categories_by=None, show_counts=True)
            plt.savefig('{}.upSet.{}'.format(out_prefix, graph_format), dpi=300)
            plt.close()
        return

    figure_groups = _resolve_venn_groups(venn_list, venn_names)

    # plot venn
    for name, groups in figure_groups:
        tmp_dict = {x: y for x, y in venn_set_dict.items() if x in groups}
        if 2 <= len(tmp_dict) <= 6:
            _draw_venn(tmp_dict, show_venn_percent)
            out_name = out_prefix + '.{}.venn.{}'.format(name, graph_format)
            plt.savefig(out_name, dpi=300)
            plt.close()
        else:
            print('venn for {}?'.format(groups))
            print('venn only support 2-6 sets')

    # intersection plot
    for name, groups in figure_groups:
        tmp_dict = {x: y for x, y in venn_set_dict.items() if x in groups}
        if len(tmp_dict) > 1:
            plot(from_contents(tmp_dict), sum_over=False, sort_categories_by=None, show_counts=True)
            plt.savefig('{}.{}.upSet.{}'.format(out_prefix, name, graph_format), dpi=300)
            plt.close()
Beispiel #21
0
# Read the table given as the first snakemake input, drop the ":-" row, and
# turn every cell into a flag telling whether its value is positive.
df = pd.read_table(snakemake.input[0], comment="#", index_col=0)
df = df.drop(":-")
df = df > 0

# For every column, collect the row labels whose flag is set.
data = dict()
for col in df.columns:
    # Series.items() replaces Series.iteritems(), which pandas 2.0 removed.
    data[col] = [index for index, val in df[col].items() if val]

f, ax = plt.subplots()
ax.axis("off")

# generate the upset plot
# but make sure to filter to a max of 31 combinations
# 31 is the maximum of combinations of 5 different items
content = from_contents(data)
uniques, counts = np.unique(content.index, return_counts=True)

# Membership combinations ordered from most to least frequent.
sorted_uniques = [x for _, x in sorted(zip(counts, uniques), reverse=True)]

# NOTE(review): this call plots the filtered data without an explicit fig=,
# while the next call plots the unfiltered data onto f - confirm which of the
# two is meant to end up in the saved image.
plot(content.loc[sorted_uniques[:31]], sort_by=None)

upsetplot = plot(from_contents(data), fig=f)

plt.savefig(snakemake.output[0], dpi=250)
    gene_lists[file.split('/')[-1].split('.')[0]] = these_genes
    combined_list += these_genes
    de_intersection = list(set(these_genes) & set(all_de_genes))
    intersection_dict[file.split('/')[-1].split('.')[0]] = len(de_intersection)/len(these_genes)
    # Write the file with genes that intersect with DE genes
    open(output + '/' + file.split('/')[-1].split('.')[0] + '_DEGenes_Intersection.txt', 'w').write\
        ('\n'.join(de_intersection))

# Write another file that intersects the merged gene IDs from all approaches with the DE genes.
# The with-block closes the handle, so the content is flushed before the script moves on.
merged_de_genes = list(set(combined_list) & set(all_de_genes))
with open(output + '/' + 'ApproachesMerged_DEGenes_intersection.txt', 'w') as merged_file:
    merged_file.write('\n'.join(merged_de_genes))


# create the barplot
# Labels and heights both come from intersection_dict, so the bars always
# line up with their tick labels.
approach_names = list(intersection_dict.keys())
bar_positions = list(range(len(approach_names)))
f, ax = plt.subplots(1, figsize=(8, 6))
plt.ylabel('Percentage of intersection with DE genes', fontsize=14)
plt.title('Fraction of identified target genes that\nare also differentially expressed', fontsize=16, fontweight='bold')
ax.set_facecolor('#f2f2f2')
ax.grid(True, axis='y', color='white', linewidth=1, which='major')
ax.bar(bar_positions, [intersection_dict[x] for x in approach_names], width=0.8, color='#045f8c', zorder=12)
plt.xticks(bar_positions, approach_names, fontsize=14)
plt.savefig(output + '/DEGenes_Intersection_Fraction.pdf', bbox_inches='tight')


# Upsetplot creates the intersection between each pairing of elements
intersection = upsetplot.from_contents(gene_lists)

upset = upsetplot.UpSet(intersection, show_counts=True, element_size=40, intersection_plot_elements=8)
upset.plot()
plt.savefig(output + '/DE_Genes_Intersection_UpSet.pdf', bbox_inches='tight')
        for compMethod in foundRes[topN]:

            print(topN, compMethod, "before sets")

            inputSets = [set([y[0] for y in x[1]]) for x in foundRes[topN][compMethod]]

            print(topN, compMethod, "after sets")

            method2genes = {}
            for x in foundRes[topN][compMethod]:
                method2genes[x[0]] = set([y[0] for y in x[1]])
                print(x[0], len(method2genes[x[0]]))

            #print(set(method2genes["RobustDE+Robust"]).difference(method2genes["combined+msEmpiRe_DESeq2"]))

            upIn = from_contents(method2genes)

            print(topN, compMethod, "after content", [x for x in method2genes], upIn.index.values)

            lvls = set([x for x in upIn.index.values])
            print("lvls", lvls)

            if len(lvls) <= 2:

                plt.figure()
                plt.title("no data to plot - maybe only 1 or 2 groups?")
                outname = args.output + "." + str(topN) + "." + compMethod

                if not outname.endswith(".png"):
                    outname += ".png"