Пример #1
0
def prepare_intersection_data(data: CardLiveData,
                              type_value: str) -> upsetplot.UpSet:
    """
    Prepare the CardLiveData to generate intersection plots, specifically
    convert into an UpSet object containing all intersections and cardinalities.

    Each file (grouped by the ``filename`` column after ``reset_index``)
    contributes one membership: the sorted tuple of its ``categories``
    values. Identical memberships are counted, and those counts become the
    UpSet cardinalities.

    :param data: a CardLiveData object from which the rgi_parser is called
    :param type_value: The category in RGI to plot set memberships for
    :return: An upsetplot.UpSet class containing the intersections and category
             memberships for creating a plotly based UpSet plot
    """

    totals_df = data.rgi_parser.get_column_values(data_type=type_value,
                                                  values_name='categories',
                                                  drop_duplicates=True)
    totals_df = totals_df.dropna()

    # One sorted tuple of categories per file: sorting makes (A, B) and
    # (B, A) the same membership, and tuples are hashable so value_counts can
    # tally them. The sort_values keeps the same pre-count ordering as before.
    category_sets = (totals_df.reset_index()
                     .groupby('filename')['categories']
                     .agg(lambda cats: tuple(sorted(cats)))
                     .sort_values()
                     .value_counts())

    # convert to upset data
    upset_data = upsetplot.from_memberships(category_sets.index,
                                            category_sets.values)
    return upsetplot.UpSet(upset_data, sort_by='cardinality')
Пример #2
0
 def generate(self):
     """Collect each channel's chat crowd and show an UpSet plot of overlaps.

     For every ID in ``self.channel_ids`` the channel name is read from the
     ``channel`` table (with " Ch. hololive-EN" and "【NIJISANJI EN】"
     removed) and printed. ``self.crowds[name]`` is then set to the set of
     distinct ``user_id`` rows (as returned by ``fetchall``) of messages sent
     on that channel's videos at or after 2021-08-01 00:00 UTC. Finally the
     crowds are drawn as an UpSet plot and shown with ``plt.show()``.

     :raises ValueError: if a channel ID has no row in the ``channel`` table.
     """
     self.crowds = dict()
     cutoff_datetime = datetime.datetime(2021,
                                         8,
                                         1,
                                         tzinfo=datetime.timezone.utc)
     crowd_query = (
         "select distinct user_id from messages m"
         " inner join video v on m.video_id = v.video_id"
         " inner join channel c on c.id = v.channel_id"
         " where v.channel_id = %s and m.time_sent >= %s;")
     conn = psycopg2.connect(dbname=self.pgsql_creds["database"],
                             user=self.pgsql_creds["username"],
                             host=self.pgsql_creds["host"],
                             password=self.pgsql_creds["password"])
     try:
         with conn.cursor() as cur:
             for channel in self.channel_ids:
                 cur.execute("select name from channel where id = %s;",
                             (channel, ))
                 row = cur.fetchone()
                 if row is None:
                     raise ValueError(
                         f"channel id {channel!r} not found in channel table")
                 chan_name = row[0].replace(" Ch. hololive-EN", "")
                 chan_name = chan_name.replace("【NIJISANJI EN】", "")
                 print(chan_name)
                 cur.execute(crowd_query, (channel, cutoff_datetime))
                 self.crowds[chan_name] = set(cur.fetchall())
     finally:
         # psycopg2's ``with conn:`` only ends the transaction and leaves the
         # connection open, so close it explicitly even if a query fails.
         conn.close()
     data = up.from_contents(self.crowds)
     updata = up.UpSet(data)
     updata.plot()
     plt.show()
Пример #3
0
    def run(self, output):
        """Tally meryl ``print venn`` lines per membership bitmask and plot them.

        The counts come from ``<output>.raw.tab`` when that file exists (left
        by a previous run). Otherwise they are computed by running
        ``<self.meryl> print venn <self.dbs...>`` and counting its output lines
        by their second field; the counts are then written to
        ``<output>.raw.tab``, most frequent first. Each counter key is read as
        an integer bitmask whose bit ``i`` marks membership in ``self.dbs[i]``.
        The resulting UpSet plot is saved to ``<output>.pdf``.

        :param output: path prefix for the ``.raw.tab`` cache and ``.pdf`` plot
        :raises subprocess.CalledProcessError: if meryl exits non-zero; the
            cache file is not written in that case
        """
        raw_path = output + ".raw.tab"

        if os.path.exists(raw_path):
            print("Starting from previous task")
            with open(raw_path, 'r') as cached:
                for line in cached:
                    fields = line.split()
                    if fields:  # tolerate blank lines
                        self.counter[fields[0]] = int(fields[1])
        else:
            import shlex

            # Argument list instead of a shell string: database paths are
            # passed verbatim (no word splitting / shell expansion). The meryl
            # command itself is still split like a shell word list so a value
            # carrying extra options keeps working.
            cmd = [*shlex.split(str(self.meryl)), 'print', 'venn', *self.dbs]
            dcount = 0
            with sp.Popen(cmd,
                          stdout=sp.PIPE,
                          bufsize=1,
                          universal_newlines=True) as proc:
                for line in proc.stdout:
                    # the second field is the membership bitmask key
                    self.counter[line.split()[1]] += 1
                    dcount += 1
                    if dcount % 10000000 == 0:
                        print(f'Progress: {dcount}')
            # Leaving the with-block waits for meryl; never cache counts from
            # a failed (possibly truncated) run, or later runs would reuse them.
            if proc.returncode != 0:
                raise sp.CalledProcessError(proc.returncode, cmd)

            # print out raw data
            with open(raw_path, 'w') as out:
                for key in sorted(self.counter,
                                  key=self.counter.get,
                                  reverse=True):
                    out.write(f'{key}\t{self.counter[key]}\n')

            print("Created raw output file")

        # Prepare memberships: decode each bitmask into the names (basename
        # up to the first '.') of the databases whose bit is set.
        db_names = [basename(db).split('.')[0] for db in self.dbs]
        memberships = []
        counts = []
        for bitmask, count in self.counter.items():
            mask = int(bitmask)
            memberships.append(
                [name for i, name in enumerate(db_names) if mask & (1 << i)])
            counts.append(count)

        # Plot things out
        dataset = upsetplot.from_memberships(memberships, data=counts)
        print(dataset)

        upset = upsetplot.UpSet(dataset,
                                sort_by='cardinality',
                                show_percentages=True)
        upset.plot()

        plt.savefig(output + ".pdf")
def compute_overlaps(d,
                     unmatched,
                     ensembl_genes_map,
                     species,
                     use_gene_id=False,
                     gene_id_from_tool=False):
    """Tabulate and plot which tools report splicing events for each gene.

    Three files are written to the current directory, all sharing a basename
    that encodes the mode (chosen below):

    * ``<basename>.tsv`` -- one row per gene: the gene key, its
      cross-reference (gene name when ``gene_id_from_tool``, otherwise gene
      ID; empty when unresolved), the comma-separated tools reporting it and
      the matching comma-separated event counts.
    * ``<basename>_unmatched_genes.tsv`` -- one unmapped gene per line.
    * ``<basename>.pdf`` -- a Venn diagram for 2 or 3 tools, or an UpSet plot
      for more than 3 tools, of the per-tool gene sets.

    :param d: mapping of tool name -> mapping of gene -> number of events.
    :param unmatched: genes already known to be unmatched; written out only
        when ``use_gene_id`` is set and ``gene_id_from_tool`` is not
        (otherwise it is recomputed or unused).
    :param ensembl_genes_map: table with ``gene_name`` and ``gene_id`` columns
        used to map names to IDs (unused when ``gene_id_from_tool``).
    :param species: species passed to MyGene.info when ``gene_id_from_tool``.
    :param use_gene_id: plot gene IDs mapped from the reported gene names.
    :param gene_id_from_tool: genes in ``d`` are Ensembl gene IDs; their
        symbols are fetched from MyGene.info.
    :raises ValueError: if ``d`` holds fewer than two tools (nothing to
        compare); both TSV files are written before this is raised.
    """
    if use_gene_id and not gene_id_from_tool:
        outbasename = "tools_overlap_gene_id_fetched_from_symbol"
    elif gene_id_from_tool:
        outbasename = "tools_overlap_gene_id_from_tool"
    else:
        outbasename = "tools_overlap_gene_name"

    # Per-tool gene sets for the plot and per-gene (tool, n_events) pairs for
    # the table; both follow the iteration order of ``d``.
    labels = list(d)
    just_sets = [set(counter) for counter in d.values()]
    final_table = defaultdict(list)
    for tool, counter in d.items():
        for gene, n_events in counter.items():
            final_table[gene].append((tool, n_events))
    genes = list(final_table)

    if gene_id_from_tool:
        ens_map = _symbols_from_gene_ids(genes, species)
        # Float values (NaN in the MyGene.info frame) mark IDs with no symbol.
        # A list over dict keys is already duplicate-free and, unlike a set,
        # gives a reproducible line order.
        unmatched = [
            gene for gene, symbol in ens_map.items()
            if isinstance(symbol, float)
        ]
    else:
        # A local table is used instead of MyGene.info: the latter leaves
        # many names unknown and maps some to several IDs (haplotype regions).
        ens_map = _gene_ids_from_names(ensembl_genes_map, genes)
        if not use_gene_id:
            # With gene names, unfetched names are only detected here.
            unmatched = [gene for gene in genes if gene not in ens_map]

    with open("{}_unmatched_genes.tsv".format(outbasename), "w") as nogene:
        nogene.write("\n".join(unmatched) + "\n")

    _write_overlap_table("{}.tsv".format(outbasename), final_table, ens_map,
                         gene_id_from_tool)

    if use_gene_id and not gene_id_from_tool:
        # Plot gene IDs rather than names; names without an ID are dropped.
        just_sets = [{ens_map[gene] for gene in s if gene in ens_map}
                     for s in just_sets]

    _plot_overlaps(labels, just_sets, "{}.pdf".format(outbasename))


def _symbols_from_gene_ids(gene_ids, species):
    """Return {Ensembl gene ID: symbol} from MyGene.info (floats = no symbol)."""
    import mygene
    mg = mygene.MyGeneInfo()
    return mg.querymany(qterms=gene_ids,
                        scopes="ensembl.gene",
                        fields=["symbol"],
                        returnall=True,
                        as_dataframe=True,
                        size=1,
                        species=species)['out'][["symbol"]].to_dict()['symbol']


def _gene_ids_from_names(ensembl_genes_map, gene_names):
    """Return {gene_name: gene_id} for the names present in the local table."""
    hits = ensembl_genes_map[ensembl_genes_map['gene_name'].isin(gene_names)]
    # NOTE(review): drop_duplicates only removes identical (name, id) pairs;
    # for a name with several IDs, to_dict keeps the last one listed.
    return (hits[["gene_name", "gene_id"]]
            .drop_duplicates(keep="first")
            .set_index("gene_name")
            .to_dict()['gene_id'])


def _write_overlap_table(path, final_table, ens_map, gene_id_from_tool):
    """Write one TSV row per gene: key, cross-reference, tools, event counts."""
    if gene_id_from_tool:
        header = "#gene_id\tgene_name\ttools_with_event\tnumber_of_events\n"
    else:
        header = "#gene_name\tgene_id\ttools_with_event\tnumber_of_events\n"
    with open(path, "w") as outw:
        outw.write(header)
        for gene, hits in final_table.items():
            xref = ens_map.get(gene)
            # Missing or non-string cross-references (e.g. a vast-tools event
            # ID, or NaN from MyGene.info) leave the column empty.
            if not isinstance(xref, str):
                xref = ""
            tools = ",".join(tool for tool, _ in hits)
            # str() every count so NumPy integer counts are kept as well.
            n_events = ",".join(str(n) for _, n in hits)
            outw.write("\t".join([gene, xref, tools, n_events]) + "\n")


def _plot_overlaps(labels, sets, out):
    """Save a Venn diagram (2-3 sets) or an UpSet plot (>3 sets) to ``out``."""
    n_sets = len(sets)
    if n_sets < 2:
        raise ValueError(
            "need at least two tools to plot overlaps, got {}".format(n_sets))
    if n_sets > 3:
        data = upsetplot.from_contents(dict(zip(labels, sets)))
        upset = upsetplot.UpSet(data,
                                subset_size="count",
                                intersection_plot_elements=4,
                                show_counts='%d')
        # UpSet.plot() creates its own figure when none is given, so no
        # figure is pre-created here (it would be left open and never saved).
        upset.plot()
    else:
        plt.figure()
        plt.title("Genes with splicing events")
        plt.tight_layout()
        venn = venn2 if n_sets == 2 else venn3
        venn(sets, tuple(labels))
    plt.savefig(out)
    plt.close()
Пример #5
0
def plot_upset_plot(bulk, sc, opref, gtf=None,
                    kind='gene', novelty='Known'):
    """Draw and save an UpSet plot of feature detection across datasets.

    Builds the count tables with ``make_counts_table`` for features of the
    given ``kind`` and ``novelty`` and plots their intersections. For known
    features it also adds a box plot of feature length per intersection on a
    log y-axis. The figure is saved to
    ``<opref>_<novelty>_<kind>_detection_upset.pdf``.

    :param bulk: data source passed to ``get_dataset_names`` and
        ``make_counts_table``
    :param sc: second data source, handled like ``bulk`` (presumably
        single-cell)
    :param opref: output path prefix for the PDF
    :param gtf: forwarded to ``make_counts_table``; replaced by None unless
        ``novelty == 'Known'``
    :param kind: ``'gene'`` or ``'transcript'``. Other values are only
        accepted for novel features and get no length column.
    :param novelty: ``'Known'``, ``'NNC'`` or ``'NIC'``
    :return: ``(df, df_copy)`` from ``make_counts_table``. ``df`` is the
        counts table and ``df_copy`` is the frame before grouping. When
        ``kind`` is gene/transcript, the ``len`` column of ``df_copy`` is
        renamed to ``'Gene length'``/``'Transcript length'``.
    :raises ValueError: for an unsupported ``novelty``, or an unsupported
        ``kind`` when ``novelty == 'Known'`` (checked before any work is done)
    """
    colors = {
        'Known': '#009E73',  # known green
        'NNC': '#E69F00',  # NNC gold
        'NIC': '#D55E00',  # NIC orange
    }
    nov_labels = {'Known': 'known', 'NNC': 'NNC', 'NIC': 'NIC'}
    len_cols = {'gene': 'Gene length', 'transcript': 'Transcript length'}

    # Fail fast instead of hitting an undefined name after the expensive
    # table construction below.
    if novelty not in colors:
        raise ValueError("novelty must be one of {}, got {!r}".format(
            sorted(colors), novelty))
    if novelty == 'Known' and kind not in len_cols:
        raise ValueError(
            "kind must be one of {} for known features, got {!r}".format(
                sorted(len_cols), kind))

    sns.set_context('paper', font_scale=1)

    sc_datasets = get_dataset_names(sc)
    bulk_datasets = get_dataset_names(bulk)
    sample_df = get_sample_df(sc_datasets + bulk_datasets)

    # gtf is only forwarded to make_counts_table for known features.
    if novelty != 'Known':
        gtf = None
    # df is table with counts, df_copy is df right before groupby and counting
    df, df_copy = make_counts_table(bulk,
                                    sc,
                                    sample_df,
                                    gtf,
                                    kind=kind,
                                    novelty=novelty)

    len_col = len_cols.get(kind)
    if len_col is not None:
        df_copy.rename({'len': len_col}, axis=1, inplace=True)

    ylab = 'Number of {} {}s'.format(nov_labels[novelty], kind)

    # plot de plot
    blot = upsetplot.UpSet(df_copy,
                           subset_size='auto',
                           show_counts='%d',
                           sort_by='cardinality')

    if novelty == 'Known':
        blot.add_catplot(value=len_col,
                         kind='box',
                         color=colors[novelty],
                         fliersize=1,
                         linewidth=1)

    ax_dict = blot.plot()
    ax_dict['intersections'].set_ylabel(ylab)

    if novelty == 'Known':
        # 'extra1' is the axis of the length box plot added above.
        ax_dict['extra1'].set_yscale('log')

    fname = '{}_{}_{}_detection_upset.pdf'.format(opref, novelty, kind)
    plt.gcf().savefig(fname, dpi=300, bbox_inches='tight')

    return (df, df_copy)
    gene_lists[file.split('/')[-1].split('.')[0]] = these_genes
    combined_list += these_genes
    de_intersection = list(set(these_genes) & set(all_de_genes))
    intersection_dict[file.split('/')[-1].split('.')[0]] = len(de_intersection)/len(these_genes)
    # Write the file with genes that intersect with DE genes
    open(output + '/' + file.split('/')[-1].split('.')[0] + '_DEGenes_Intersection.txt', 'w').write\
        ('\n'.join(de_intersection))

# Write another file that intersects the merged gene IDs from all approaches with the DE genes
with open(output + '/' + 'ApproachesMerged_DEGenes_intersection.txt', 'w') as merged_file:
    merged_file.write('\n'.join(set(combined_list) & set(all_de_genes)))


# create the barplot: one bar per approach showing the fraction (0-1) of its
# target genes that are also DE. Positions and heights both come from
# intersection_dict so they always have the same length (file_list can be
# longer when two files share a basename).
approach_names = list(intersection_dict)
fractions = list(intersection_dict.values())
bar_positions = list(range(len(approach_names)))
f, ax = plt.subplots(1, figsize=(8, 6))
plt.ylabel('Fraction of intersection with DE genes', fontsize=14)
plt.title('Fraction of identified target genes that\nare also differentially expressed', fontsize=16, fontweight='bold')
ax.set_facecolor('#f2f2f2')
ax.grid(True, axis='y', color='white', linewidth=1, which='major')
ax.bar(bar_positions, fractions, width=0.8, color='#045f8c', zorder=12)
plt.xticks(bar_positions, approach_names, fontsize=14)
plt.savefig(output + '/DEGenes_Intersection_Fraction.pdf', bbox_inches='tight')


# Build the gene-by-approach membership table; the UpSet plot shows how many
# genes fall in every combination of approaches (not just pairs).
intersection = upsetplot.from_contents(gene_lists)

upset = upsetplot.UpSet(intersection, show_counts=True, element_size=40, intersection_plot_elements=8)
upset.plot()
plt.savefig(output + '/DE_Genes_Intersection_UpSet.pdf', bbox_inches='tight')