# Example 1
0
def _stack_with_baseline(lfc, baseline, col_names, condition):
    """Add ``baseline`` to a coefficient table and reshape it into a labeled series.

    Args:
        lfc: DataFrame of coefficients (one row per simulation). Its columns
            are replaced by ``col_names`` after the addition.
        baseline: Values added to ``lfc`` with pandas broadcasting
            (presumably one value per column/time point -- confirm).
        col_names: Column labels assigned to the summed frame.
        condition: Used as the series name and as the single value of a new
            outermost ``condition`` index level.

    Returns:
        A Series indexed by (condition, time, replicate).
    """
    summed = lfc + baseline
    summed.columns = col_names
    stacked = summed.unstack()
    stacked.index = stacked.index.set_names(['time', 'replicate'])
    stacked.name = condition
    return pd.concat([stacked], keys=[condition], names=['condition'])


def plot_gene_prediction(gene,
                         match,
                         data,
                         sim_der,
                         col_names,
                         ensembl_to_hgnc,
                         ax=None,
                         **kwargs):
    """Plot a gene's observed ``ki`` series against simulated predictions.

    The "predicted" series is the coefficients of the simulated networks
    matched to ``gene`` (rows of ``match`` whose ``train_gene`` equals
    ``gene``) plus the gene's mean ``wt`` value per time point. The
    "random" series does the same with every row of
    ``sim_der.coefficients``. Both are drawn with the observed ``ki`` data.

    Args:
        gene: Row label looked up in ``data`` and ``ensembl_to_hgnc``.
        match: Table with ``train_gene`` and ``net`` columns.
        data: Expression table; ``data.loc[gene, 'wt']`` must have a
            ``time`` index level (it is grouped by it).
        sim_der: Object whose ``coefficients`` DataFrame is indexed by the
            string form of ``match.net`` values.
        col_names: Column labels assigned to the coefficient tables.
        ensembl_to_hgnc: Table with an ``hgnc_symbol`` column; used for the
            axes title.
        ax: Axes to draw on; forwarded to ``DEPlot.tsplot``.
        **kwargs: Extra keyword arguments forwarded to ``DEPlot.tsplot``.

    Returns:
        The axes returned by ``DEPlot.tsplot``.
    """
    dep = DEPlot()
    matching_sim = match.loc[match.train_gene == gene,
                             'net'].astype(str).values
    pred_lfc = sim_der.coefficients.loc[matching_sim]
    # Mean wild-type expression per time point is the baseline that the
    # simulated coefficients are added to.
    baseline = data.loc[gene, 'wt'].groupby('time').mean().values

    random = _stack_with_baseline(sim_der.coefficients, baseline, col_names,
                                  'random')
    pred = _stack_with_baseline(pred_lfc, baseline, col_names, 'predicted')

    true = data.loc[gene, ['ki']]
    # Align index level order with the observed data so concat stacks them.
    pred = pred.reorder_levels(true.index.names)
    random = random.reorder_levels(true.index.names)
    ts_data = pd.concat([true, pred, random])
    ts_data.name = gene
    ax = dep.tsplot(ts_data, scatter=False, ax=ax, legend=False, **kwargs)
    ax.set_title(ensembl_to_hgnc.loc[gene, 'hgnc_symbol'])
    ax.set_ylabel('')
    return ax
# Example 2
0
def display_sim(network,
                stim,
                perturbation,
                times,
                directory,
                exp_condition='ko',
                ctrl_condition='wt',
                node=None):
    """Load control and experimental GNW simulations and plot them.

    Simulation results for ``ctrl_condition`` and ``exp_condition`` are read
    from ``<directory>/<network>/<stim>/`` and concatenated, transposed, and
    sorted along both axes.

    If ``node`` is truthy, only that node's time series is plotted,
    restricted to ``perturbation`` and ``times``. Otherwise the signed gold
    standard ``<directory>/<network>/<network>_goldstandard_signed.tsv`` is
    loaded, node ``'G'`` is relabeled, and ``draw_results`` is called on the
    log2(x + 1) transformed data.

    Returns:
        None; output is produced by the plotting helpers.
    """
    data_dir = '{}/{}/{}/'.format(directory, network, stim)
    network_structure = "{}/{}/{}_goldstandard_signed.tsv".format(
        directory, network, network)

    # The control condition is loaded first, then the experimental one.
    loaded = [
        GnwSimResults(data_dir,
                      network,
                      condition,
                      sim_suffix='dream4_timeseries.tsv',
                      perturb_suffix="dream4_timeseries_perturbations.tsv")
        for condition in (ctrl_condition, exp_condition)
    ]
    stacked = pd.concat([result.data for result in loaded])
    data = stacked.T.sort_index(axis=0).sort_index(axis=1)

    if not node:
        graph = nx.relabel_nodes(get_graph(network_structure),
                                 {'G': "PI3k"})
        draw_results(np.log2(data + 1),
                     perturbation,
                     ["x", "y", "PI3K"],
                     times=times,
                     g=graph)
        plt.tight_layout()
        return

    slicer = pd.IndexSlice
    DEPlot().tsplot(data.loc[node, slicer[:, :, perturbation, times]],
                    subgroup='Time',
                    no_fill_legend=True)
    plt.tight_layout()
# Example 3
0
def _node_regulation_counts(models, labels, logics):
    """Tally, per network node, how many matched models use each regulation logic.

    Args:
        models: Table with ``<node>_logic`` and ``<node>_in`` columns for
            every node in ``labels``.
        labels: Node names to tally.
        logics: Logic suffixes; ``<node><suffix>`` is the value counted in
            ``<node>_logic``.

    Returns:
        Dict mapping each node to ``{'fracs': [...]}``. The list holds one
        count per entry of ``logics``, followed by the number of models in
        which the node has zero inputs.
    """
    node_info = {}
    for node in labels:
        counts = Counter(models['{}_logic'.format(node)])
        fracs = [counts['{}{}'.format(node, log)] for log in logics]
        no_in = sum(models['{}_in'.format(node)] == 0)
        # Presumably input-less nodes are also tagged with the last logic
        # (the bare node name); move them into their own bucket. Confirm
        # against how the model table is generated.
        fracs[-1] -= no_in
        fracs.append(no_in)
        node_info[node] = {'fracs': fracs}
    return node_info


def plot_genes(sub_spec, leg_spec, net_spec, matches, dde, sim_dea, tx_to_gene,
               net_data, ts):
    """Draw the training / prediction / network panel for a fixed gene set.

    For each gene:
      * the training time series goes in the top row of ``sub_spec``;
      * the prediction plot from ``plot_gene_prediction`` goes in the
        bottom row;
      * a network summary of the matched models goes in ``net_spec``.

    The last gene's axes receive the legends. ``leg_spec`` holds an arrow
    key and an edge-sign colorbar.

    Args:
        sub_spec: SubplotSpec split into a 2-row grid for the time series.
        leg_spec: SubplotSpec split into a 1x4 grid for the arrow/colorbar.
        net_spec: SubplotSpec split into a 1-row grid for network panels.
        matches: Table with ``train_gene`` and ``net`` columns.
        dde: Object providing ``training``, ``times`` and ``dea.data``.
        sim_dea: Object providing ``results['ki-wt']`` and ``times``.
        tx_to_gene: Table indexed by Ensembl ID with an ``hgnc_symbol``
            column.
        net_data: Per-network model table with ``<node>_logic`` and
            ``<node>_in`` columns.
        ts: Unused; kept for call compatibility.
    """
    # Hand-picked example genes (Ensembl IDs).
    genes = ['ENSG00000117289', 'ENSG00000213626', 'ENSG00000170044']
    # One column per gene plus two narrow trailing columns (presumably
    # leaving room for the legends placed outside the last gene's axes).
    w_ratios = [1] * len(genes) + [0.1, 0.1]
    cols = len(genes) + 2
    gs_top = gridspec.GridSpecFromSubplotSpec(2,
                                              cols,
                                              subplot_spec=sub_spec,
                                              width_ratios=w_ratios,
                                              wspace=0.5)

    gs_bottom = gridspec.GridSpecFromSubplotSpec(1,
                                                 4,
                                                 subplot_spec=leg_spec,
                                                 width_ratios=[1, 1, 0.1, 0.1],
                                                 wspace=0.5)
    gs_mid = gridspec.GridSpecFromSubplotSpec(1,
                                              cols,
                                              subplot_spec=net_spec,
                                              width_ratios=w_ratios,
                                              wspace=0.5)

    train_conditions = list(dde.training.values())
    dep = DEPlot()
    greys = ['0.1', '0.7', '0.4', '0.1', '0.7']
    conditions = ['wt', 'ko', 'ki', 'predicted', 'random']
    colors = dict(zip(conditions, greys))
    three_col = ['#AA0000', '#BBBBBB', '#0000AA']
    edge_cmap = LinearSegmentedColormap.from_list('custom', three_col)
    pie_colors = ['#0F8554', '#E17C05', '0.2', '0.7']
    train_markers = ['o', '^']
    test_markers = ['d', 'X', 's']
    # Time point 15 is left off the x axes. Assumes dde.times is a list
    # containing 15 (list.remove raises ValueError otherwise).
    plot_times = dde.times.copy()
    plot_times.remove(15)
    node_labels = ['G', 'x', 'y']
    logics = ['_multiplicative', '_linear', '']

    for idx, gene in enumerate(genes):
        with sns.axes_style("whitegrid"):
            train_ax = plt.subplot(gs_top[0, idx])
            pred_ax = plt.subplot(gs_top[1, idx])
            dep.tsplot(dde.dea.data.loc[gene, train_conditions],
                       ax=train_ax,
                       legend=False,
                       no_fill_legend=True,
                       color_dict=colors,
                       scatter=False,
                       markers=train_markers)
            train_ax.set_title(tx_to_gene.loc[gene, 'hgnc_symbol'])
            train_ax.set_ylabel('log2(counts)')
            train_ax.yaxis.set_major_formatter(FormatStrFormatter('%.0f'))
            train_ax.set_xlabel('')
            train_ax.set_xlim(min(plot_times), max(plot_times))
            train_ax.set_xticks(plot_times)
            train_ax.set_xticklabels([])

            plot_gene_prediction(gene,
                                 matches,
                                 dde.dea.data,
                                 sim_dea.results['ki-wt'],
                                 sim_dea.times,
                                 tx_to_gene,
                                 ax=pred_ax,
                                 no_fill_legend=True,
                                 color_dict=colors,
                                 markers=test_markers)
            pred_ax.set_xlim(min(plot_times), max(plot_times))
            pred_ax.set_xticks(plot_times)
            pred_ax.set_xticklabels(plot_times, rotation=90)
            pred_ax.set_ylabel('log2(counts)')
            pred_ax.yaxis.set_major_formatter(FormatStrFormatter('%.0f'))
            pred_ax.set_title('')

            # Only the leftmost column keeps its y-axis labels.
            if idx != 0:
                train_ax.set_ylabel('')
                pred_ax.set_ylabel('')

        net_ax = plt.subplot(gs_mid[0, idx])
        models = net_data.loc[matches[matches.train_gene == gene]
                              ['net'].values]
        node_info = _node_regulation_counts(models, node_labels, logics)
        plot_net(net_ax, node_info, models, node_labels, edge_cmap,
                 pie_colors)

        # Add the net legend
        if idx == len(genes) - 1:
            leg = net_ax.legend(['AND', 'OR', 'Single input', 'No input'],
                                loc='center left',
                                bbox_to_anchor=(1, 0.5),
                                handletextpad=0.5,
                                frameon=False,
                                handlelength=1)
            leg.set_title("Node regulation", prop={'size': 24})

    # Relabel the last column's lines for the legends: known condition codes
    # get descriptive names, anything else is capitalized.
    legend_names = {'wt': 'Wildtype', 'ko': 'PI3K KO', 'ki': 'PI3K KI'}
    for line in train_ax.get_lines() + pred_ax.get_lines():
        label = line.get_label()
        line.set_label(legend_names.get(label, label.capitalize()))

    train_leg = train_ax.legend(loc='center left',
                                bbox_to_anchor=(1, 0.5),
                                frameon=False,
                                handlelength=1)
    pred_leg = pred_ax.legend(loc='center left',
                              bbox_to_anchor=(1, 0.5),
                              frameon=False,
                              handlelength=1)

    train_leg.set_title('Training', prop={'size': 28, 'weight': 'bold'})
    pred_leg.set_title('Testing', prop={'size': 28, 'weight': 'bold'})

    # Add arrow
    sns.set_style(None)
    arrow_ax = plt.subplot(gs_bottom[0])
    despine(arrow_ax)
    detick(arrow_ax)

    # Scale the arrow tail to the axes height. The bbox is converted to
    # inches, so multiplying by dpi / 2 gives half the height in pixels.
    fig = arrow_ax.get_figure()
    bbox = arrow_ax.get_window_extent().transformed(
        fig.dpi_scale_trans.inverted())
    height = bbox.height * (fig.dpi / 2)
    astyle = ArrowStyle('wedge', tail_width=height, shrink_factor=0.5)
    fa = FancyArrowPatch(posA=[0, 0.5],
                         posB=[1, 0.5],
                         arrowstyle=astyle,
                         lw=0,
                         color='k')
    arrow_ax.add_artist(fa)
    arrow_ax.set_title('fraction of models \n edge exists')
    arrow_ax.set_xticks([0, 1])
    arrow_ax.set_xticklabels(['100%', '0%'])

    # Add colorbar
    norm = mpl.colors.Normalize(vmin=-1, vmax=1)

    cbar_ax = plt.subplot(gs_bottom[1])
    cb1 = mpl.colorbar.ColorbarBase(cbar_ax,
                                    cmap=edge_cmap,
                                    norm=norm,
                                    orientation='horizontal')
    cb1.set_ticks([-1, 0, 1])
    cbar_ax.set_title('Average edge sign')
# Example 4
0
    # NOTE(review): commented-out exploratory counting code; candidate for
    # deletion.
    # a = 0
    # unique = 0
    # for gene, data in grouped:
    #     unique += 1
    #     if gene in true_dde_genes.index:
    #         print(data)
    #         a += 1
    # print(a)
    # print(true_dde_genes.shape)
    # print(unique)
    # sys.exit()
    # print(all)

    # Plot gene ENSG00000004799 from two DEA objects. `dea` is indexed with
    # the contrast's conditions after replacing 'ki' with 'ko'; `true_dea`
    # uses the contrast's conditions unchanged. (Presumably `contrast` is a
    # string like 'ki-wt' -- confirm.)
    dep = DEPlot()
    dep.tsplot(dea.voom_data.loc['ENSG00000004799',
                                 contrast.replace('ki', 'ko').split('-')],
               legend=False)
    plt.tight_layout()

    dep.tsplot(true_dea.voom_data.loc['ENSG00000004799',
                                      contrast.split('-')],
               legend=False)
    plt.tight_layout()

    # Display results
    # NOTE(review): display_sim's signature is (network, stim, perturbation,
    # times, directory, ...). These calls pass only four positional
    # arguments: `times` lands in `perturbation`, the path lands in `times`,
    # and the required `directory` is missing, which raises TypeError.
    # Confirm the intended perturbation value and fix the calls.
    display_sim(1995, -1, times, "../data/motif_library/gnw_networks/")
    display_sim(1995,
                -1,
                times,
                "../data/motif_library/gnw_networks/",
                exp_condition='ki')
# Example 5
0
# Save the current figure as a high-resolution PNG.
# `format` (not `fmt`) is savefig's keyword for the output file type; the
# unknown `fmt` keyword is not accepted by savefig.
# NOTE(review): hard-coded, user-specific output path.
plt.savefig(
    '/Users/jfinkle/Dropbox/Justin Finkle, shared/Media/Sprouty/images/row_max_normalized_heatmap.png',
    format='png',
    dpi=600)
#
#
# NOTE(review): everything below this exit is unreachable scratch code.
sys.exit()

# Volcano Plot
x = dea.results['KO-WT'].top_table(coef=1, use_fstat=False)
# dep.volcano_plot(x, top_n=5, show_labels=True)

# Time Series Plot
x = dea.data.loc['IVL']

dep.tsplot(x)
plt.show()
sys.exit()
x = dea.results['KO_ts-WT_ts']
y = dea.results['KO-WT']

gene = 'CXCL1'
# print(x.discrete_clusters[(x.discrete_clusters['Cluster'] != '(0, 0, 0, 0)') & (y.discrete['KO_0-WT_0'] == 0)])
print(x.continuous.head())
print(dea.results['(KO-WT)_ts'].continuous.head())

p_data = dea.data.loc[gene]

dep.tsplot(p_data)
plt.show()
# Example 6
0
                         sep=' ')
# Keep one row per (id, hgnc_symbol) pair. DataFrame.drop_duplicates
# ignores the index, so it must run before set_index('id'). Run afterwards,
# it compares only hgnc_symbol and drops distinct IDs that share a symbol
# (including every ID whose symbol is missing, except the first).
gene_names = (gene_names[['id', 'hgnc_symbol']]
              .drop_duplicates()
              .set_index('id'))
p_thresh = 0.001
scores = p_results.loc[(der.top_table()['adj_pval'] < p_thresh)
                       & (p_results['p_value'] < p_thresh)]

# Remove clusters with no dynamic DE, i.e. clusters whose pattern is
# constant (all 1, all -1, or all 0): keep only those with more than one
# distinct value.
interesting = scores.loc[
    scores.Cluster.apply(ast.literal_eval).apply(set).apply(len) > 1]
# print(interesting.sort_values(['score'], ascending=False).head(100))
c = (interesting[interesting.Cluster == '(0, 0, 0, 0, -1, -1)'].sort_values(
    'score', ascending=False))
print(c)
dep = DEPlot()
dep.tsplot(dea.data.loc["ENSG00000186187", ['ko', 'wt']], legend=False)
plt.title('ZNRF1')
plt.tight_layout()
plt.show()
sys.exit()
# NOTE(review): everything below is unreachable because of sys.exit() above.
for gene in c.index:
    dep.tsplot(dea.data.loc[gene, ['pten', 'wt']], legend=False)
    plt.tight_layout()
    plt.show()
sys.exit()

# Heatmap of expression
de_data = (der.top_table().iloc[:, :6])  #.multiply(der.p_value < p_thresh)
sort_idx = interesting.sort_values(['Cluster', 'score'],
                                   ascending=False).index.values
hm_data = de_data.loc[sort_idx]