예제 #1
0
def plot_gene_prediction(gene,
                         match,
                         data,
                         sim_der,
                         col_names,
                         ensembl_to_hgnc,
                         ax=None,
                         **kwargs):
    """Plot the 'ki' data for ``gene`` next to predicted and random traces.

    The simulated coefficients in ``sim_der.coefficients`` are shifted by the
    per-time mean of the gene's 'wt' data. Rows matched to ``gene`` through
    ``match`` form the 'predicted' condition; all rows form the 'random'
    condition. Both are concatenated with ``data.loc[gene, ['ki']]`` and
    drawn with ``DEPlot.tsplot``.

    :param gene: row label in ``data`` and ``ensembl_to_hgnc``; also compared
        against ``match.train_gene``
    :param match: table with ``train_gene`` and ``net`` columns
    :param data: expression table indexed by gene with 'wt' and 'ki' columns
    :param sim_der: object exposing a ``coefficients`` frame indexed by the
        string form of ``match['net']``
    :param col_names: replacement column labels for the coefficient frames
    :param ensembl_to_hgnc: table with an ``hgnc_symbol`` column
    :param ax: axes forwarded to ``DEPlot.tsplot``
    :param kwargs: extra keyword arguments forwarded to ``DEPlot.tsplot``
    :return: None
    """

    def _to_condition_series(shifted, label):
        # ``shifted`` is always a freshly computed frame, so relabelling its
        # columns never touches ``sim_der.coefficients``.
        shifted.columns = col_names
        stacked = shifted.unstack()
        stacked.index = stacked.index.set_names(['time', 'replicate'])
        stacked.name = label
        # Wrap in an outer 'condition' level keyed by the label.
        return pd.concat([stacked], keys=[label], names=['condition'])

    plotter = DEPlot()

    # Simulated networks that were matched to this gene.
    is_match = match.train_gene == gene
    sim_ids = match.loc[is_match, 'net'].astype(str).values
    matched_lfc = sim_der.coefficients.loc[sim_ids]

    # Mean 'wt' value per time point, used as the offset for both traces.
    wt_mean = data.loc[gene, 'wt'].groupby('time').mean().values

    random_ts = _to_condition_series(sim_der.coefficients + wt_mean, 'random')
    predicted_ts = _to_condition_series(matched_lfc + wt_mean, 'predicted')

    # Align the index level order with the measured data before stacking.
    observed = data.loc[gene, ['ki']]
    level_order = observed.index.names
    predicted_ts = predicted_ts.reorder_levels(level_order)
    random_ts = random_ts.reorder_levels(level_order)

    combined = pd.concat([observed, predicted_ts, random_ts])
    combined.name = gene

    ax = plotter.tsplot(combined,
                        scatter=False,
                        ax=ax,
                        legend=False,
                        **kwargs)
    ax.set_title(ensembl_to_hgnc.loc[gene, 'hgnc_symbol'])
    ax.set_ylabel('')
예제 #2
0
def display_sim(network,
                stim,
                perturbation,
                times,
                directory,
                exp_condition='ko',
                ctrl_condition='wt',
                node=None):
    """Load control/experimental simulation results and plot them.

    Results for both conditions are read with ``GnwSimResults`` from
    ``<directory>/<network>/<stim>/``, concatenated, transposed and sorted on
    both axes. If ``node`` is truthy, only that node's time series is drawn
    with ``DEPlot.tsplot``; otherwise the network structure file is loaded
    and passed to ``draw_results`` together with ``log2(data + 1)``.

    :param network: network identifier used in directory and file names
    :param stim: sub-directory name below the network directory
    :param perturbation: perturbation selector used for indexing/plotting
    :param times: time selector used for indexing/plotting
    :param directory: root directory containing the network folders
    :param exp_condition: condition label of the experimental simulation
    :param ctrl_condition: condition label of the control simulation
    :param node: optional row label; when truthy only this node is plotted
    :return: None
    """
    data_dir = '{}/{}/{}/'.format(directory, network, stim)
    network_structure = "{}/{}/{}_goldstandard_signed.tsv".format(
        directory, network, network)

    def _load_results(condition):
        # Both conditions share the same directory and file suffixes.
        return GnwSimResults(
            data_dir,
            network,
            condition,
            sim_suffix='dream4_timeseries.tsv',
            perturb_suffix="dream4_timeseries_perturbations.tsv")

    ctrl_results = _load_results(ctrl_condition)
    exp_results = _load_results(exp_condition)

    stacked = pd.concat([ctrl_results.data, exp_results.data])
    data = stacked.T.sort_index(axis=0).sort_index(axis=1)

    if node:
        # Single-node view: slice the 3rd and 4th column levels.
        plotter = DEPlot()
        selection = pd.IndexSlice[:, :, perturbation, times]
        plotter.tsplot(data.loc[node, selection],
                       subgroup='Time',
                       no_fill_legend=True)
    else:
        # Whole-network view.
        # NOTE(review): node 'G' is relabelled "PI3k" while the title list
        # uses "PI3K" - confirm the casing difference is intentional.
        graph = get_graph(network_structure)
        panel_titles = ["x", "y", "PI3K"]
        graph = nx.relabel_nodes(graph, {'G': "PI3k"})
        draw_results(np.log2(data + 1),
                     perturbation,
                     panel_titles,
                     times=times,
                     g=graph)
    plt.tight_layout()
예제 #3
0
def plot_genes(sub_spec, leg_spec, net_spec, matches, dde, sim_dea, tx_to_gene,
               net_data, ts):
    """Draw training/prediction time series and network summaries for 3 genes.

    For each of three hard-coded gene IDs the function draws, inside nested
    gridspecs built from the given subplot specs:

    * top row of ``sub_spec``: the gene's data for the training conditions;
    * bottom row of ``sub_spec``: the output of ``plot_gene_prediction``;
    * ``net_spec``: a network drawn by ``plot_net`` from the models matched
      to the gene.

    Legends for the last column, a wedge "arrow" and a horizontal colorbar
    (both inside ``leg_spec``) are added afterwards.

    :param sub_spec: SubplotSpec hosting the 2-row time-series grid
    :param leg_spec: SubplotSpec hosting the arrow and colorbar axes
    :param net_spec: SubplotSpec hosting the per-gene network axes
    :param matches: table with ``train_gene`` and ``net`` columns
    :param dde: object providing ``training``, ``times`` and ``dea.data``
    :param sim_dea: object providing ``results['ki-wt']`` and ``times``
    :param tx_to_gene: table with an ``hgnc_symbol`` column, indexed by gene
    :param net_data: table indexed by network id with ``<node>_logic`` and
        ``<node>_in`` columns for nodes 'G', 'x' and 'y'
    :param ts: unused by the active code (only referenced in commented-out
        gene selections below)
    :return: None
    """
    # Earlier gene selections, kept for reference:
    # genes = [top.index[29],  ts.sort_values('percent', ascending=False).index[0],
    #          ts.sort_values('percent', ascending=False).index[30]]
    # genes = ts.index[np.random.randint(0, len(top), size=3)]
    # genes = top.sort_values('grouped_e').index[:3]
    genes = ['ENSG00000117289', 'ENSG00000213626', 'ENSG00000170044']
    # One unit-width column per gene plus two narrow trailing columns.
    w_ratios = [1] * len(genes) + [0.1, 0.1]
    cols = len(genes) + 2
    gs_top = gridspec.GridSpecFromSubplotSpec(2,
                                              cols,
                                              subplot_spec=sub_spec,
                                              width_ratios=w_ratios,
                                              wspace=0.5)

    gs_bottom = gridspec.GridSpecFromSubplotSpec(1,
                                                 4,
                                                 subplot_spec=leg_spec,
                                                 width_ratios=[1, 1, 0.1, 0.1],
                                                 wspace=0.5)
    gs_mid = gridspec.GridSpecFromSubplotSpec(1,
                                              cols,
                                              subplot_spec=net_spec,
                                              width_ratios=w_ratios,
                                              wspace=0.5)

    train_conditions = list(dde.training.values())
    dep = DEPlot()
    # Grey level assigned to each condition, paired by position.
    greys = ['0.1', '0.7', '0.4', '0.1', '0.7']
    conditions = ['wt', 'ko', 'ki', 'predicted', 'random']
    colors = {c: idx for c, idx in zip(conditions, greys)}
    # colors['random'] = '0.7'
    # Red -> grey -> blue colormap, used for edges and for the colorbar.
    three_col = ['#AA0000', '#BBBBBB', '#0000AA']
    new_map = LinearSegmentedColormap.from_list('custom', three_col)
    edge_cmap = new_map
    pie_colors = ['#0F8554', '#E17C05', '0.2', '0.7']
    train_markers = ['o', '^']
    test_markers = ['d', 'X', 's']
    # Work on a copy so dde.times is not modified; 15 is excluded from the
    # axis limits and ticks. list.remove raises ValueError if 15 is absent.
    plot_times = dde.times.copy()
    plot_times.remove(15)
    for idx, gene in enumerate(genes):
        with sns.axes_style("whitegrid"):
            train_ax = plt.subplot(gs_top[0, idx])
            pred_ax = plt.subplot(gs_top[1, idx])
            # Top row: data for the training conditions.
            dep.tsplot(dde.dea.data.loc[gene, train_conditions],
                       ax=train_ax,
                       legend=False,
                       no_fill_legend=True,
                       color_dict=colors,
                       scatter=False,
                       markers=train_markers)
            train_ax.set_title(tx_to_gene.loc[gene, 'hgnc_symbol'])
            train_ax.set_ylabel('log2(counts)')
            train_ax.yaxis.set_major_formatter(FormatStrFormatter('%.0f'))
            train_ax.set_xlabel('')
            train_ax.set_xlim(min(plot_times), max(plot_times))
            train_ax.set_xticks(plot_times)
            # Tick labels are only shown on the bottom (prediction) row.
            train_ax.set_xticklabels([])

            # Bottom row: measured 'ki' vs. predicted and random traces.
            plot_gene_prediction(gene,
                                 matches,
                                 dde.dea.data,
                                 sim_dea.results['ki-wt'],
                                 sim_dea.times,
                                 tx_to_gene,
                                 ax=pred_ax,
                                 no_fill_legend=True,
                                 color_dict=colors,
                                 markers=test_markers)
            pred_ax.set_xlim(min(plot_times), max(plot_times))
            pred_ax.set_xticks(plot_times)
            pred_ax.set_xticklabels(plot_times, rotation=90)
            pred_ax.set_ylabel('log2(counts)')
            pred_ax.yaxis.set_major_formatter(FormatStrFormatter('%.0f'))
            # Drop the title set inside plot_gene_prediction; the top row
            # already carries the gene symbol.
            pred_ax.set_title('')

            # Only the first column keeps its y-axis labels.
            if idx != 0:
                train_ax.set_ylabel('')
                pred_ax.set_ylabel('')

        net_ax = plt.subplot(gs_mid[0, idx])
        labels = ['G', 'x', 'y']
        logics = ['_multiplicative', '_linear', '']
        node_info = {}
        # Rows of net_data for the networks matched to this gene.
        models = net_data.loc[matches[matches.train_gene == gene]
                              ['net'].values]
        for node in labels:
            cur_dict = {}
            # Count how often each '<node><suffix>' logic label occurs.
            counts = Counter(models['{}_logic'.format(node)])
            cur_dict['fracs'] = [
                counts['{}{}'.format(node, log)] for log in logics
            ]
            # Models where '<node>_in' is 0 are moved out of the last
            # (empty-suffix) category into a separate fourth entry.
            no_in = sum(models['{}_in'.format(node)] == 0)
            cur_dict['fracs'][-1] -= no_in
            cur_dict['fracs'].append(no_in)
            node_info[node] = cur_dict
        plot_net(net_ax, node_info, models, labels, edge_cmap, pie_colors)

        # Add the net legend (last gene column only).
        # NOTE(review): the four legend entries presumably line up with the
        # four 'fracs' entries in order - confirm against plot_net.
        if idx == len(genes) - 1:
            leg = net_ax.legend(['AND', 'OR', 'Single input', 'No input'],
                                loc='center left',
                                bbox_to_anchor=(1, 0.5),
                                handletextpad=0.5,
                                frameon=False,
                                handlelength=1)
            leg.set_title("Node regulation", prop={'size': 24})

    # Add the legends. train_ax / pred_ax still refer to the axes created in
    # the final loop iteration. Known condition labels get a display name;
    # any other label is just capitalized.
    labels = {'wt': 'Wildtype', 'ko': 'PI3K KO', 'ki': 'PI3K KI'}
    for line in train_ax.get_lines() + pred_ax.get_lines():
        label = line.get_label()
        try:
            new = labels[label]
        except KeyError:
            new = label.capitalize()
        line.set_label(new)

    train_leg = train_ax.legend(loc='center left',
                                bbox_to_anchor=(1, 0.5),
                                frameon=False,
                                handlelength=1)
    pred_leg = pred_ax.legend(loc='center left',
                              bbox_to_anchor=(1, 0.5),
                              frameon=False,
                              handlelength=1)

    train_leg.set_title('Training', prop={'size': 28, 'weight': 'bold'})
    pred_leg.set_title('Testing', prop={'size': 28, 'weight': 'bold'})

    # Add arrow
    # NOTE(review): presumably resets the seaborn style for the axes created
    # below - confirm that passing None has the intended effect.
    sns.set_style(None)
    arrow_ax = plt.subplot(gs_bottom[0])
    despine(arrow_ax)
    detick(arrow_ax)

    # Scale arrow width: take the axes extent in inches, then set the wedge
    # tail width to half the axes height multiplied back by the figure dpi.
    # `width` is computed but not used afterwards.
    fig = arrow_ax.get_figure()
    bbox = arrow_ax.get_window_extent().transformed(
        fig.dpi_scale_trans.inverted())
    width, height = bbox.width, bbox.height
    height *= fig.dpi / 2
    astyle = ArrowStyle('wedge', tail_width=height, shrink_factor=0.5)
    fa = FancyArrowPatch(posA=[0, 0.5],
                         posB=[1, 0.5],
                         arrowstyle=astyle,
                         lw=0,
                         color='k')
    arrow_ax.add_artist(fa)
    arrow_ax.set_title('fraction of models \n edge exists')
    # The wedge runs from x=0 (labelled 100%) to x=1 (labelled 0%).
    arrow_ax.set_xticks([0, 1])
    arrow_ax.set_xticklabels(['100%', '0%'])

    # Add colorbar spanning [-1, 1] with the edge colormap.
    norm = mpl.colors.Normalize(vmin=-1, vmax=1)

    cbar_ax = plt.subplot(gs_bottom[1])
    cb1 = mpl.colorbar.ColorbarBase(cbar_ax,
                                    cmap=edge_cmap,
                                    norm=norm,
                                    orientation='horizontal')
    cb1.set_ticks([-1, 0, 1])
    cbar_ax.set_title('Average edge sign')
예제 #4
0
    # grouped = match.groupby('true_gene')
    # a = 0
    # unique = 0
    # for gene, data in grouped:
    #     unique += 1
    #     if gene in true_dde_genes.index:
    #         print(data)
    #         a += 1
    # print(a)
    # print(true_dde_genes.shape)
    # print(unique)
    # sys.exit()
    # print(all)

    dep = DEPlot()
    dep.tsplot(dea.voom_data.loc['ENSG00000004799',
                                 contrast.replace('ki', 'ko').split('-')],
               legend=False)
    plt.tight_layout()

    dep.tsplot(true_dea.voom_data.loc['ENSG00000004799',
                                      contrast.split('-')],
               legend=False)
    plt.tight_layout()

    # Display results
    display_sim(1995, -1, times, "../data/motif_library/gnw_networks/")
    display_sim(1995,
                -1,
                times,
예제 #5
0
# Drop every row in which all values are below 5.
raw_data = raw_data[~(raw_data < 5).all(axis=1)]
# Make the Differential Expression Analysis Object
# The reference labels specify how samples will be organized into unique values
dea = DEAnalysis(raw_data,
                 index_names=hierarchy,
                 reference_labels=['condition'],
                 time=None,
                 counts=True)
# Data can be standardized if desired
# norm_data = dea.standardize()

# Fit the contrasts (nothing is saved to disk in this section)
# cont = dea.possible_contrasts()
# cont[0] = 'CRE-BRaf'
dea.fit_contrasts()
dep = DEPlot(dea)
# NOTE(review): this exit stops the script here, so everything below is
# currently unreachable - looks like a debugging leftover; confirm.
sys.exit()
# Table of results for the first contrast, called with p=0.05
x = dea.results[0].top_table(p=0.05)

# sns.clustermap(x.iloc[:, :10])
# Cluster map of the first 10 columns for the rows selected by
# utils.grepl('SOX', ...).
genes = utils.grepl('SOX', x.index)
g = sns.clustermap(x.loc[genes].iloc[:, :10])
plt.setp(g.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)
plt.setp(g.ax_heatmap.xaxis.get_majorticklabels(), rotation=30)
plt.show()
# NOTE(review): second early exit; the single-gene printout below never runs.
sys.exit()
gene = 'SPRY4'
print(rh.rvect_to_py(dea.data_matrix).loc[gene].reset_index())
print(dea.data.loc[gene])
# ax = sns.boxplot(data=rh.rvect_to_py(dea.data_matrix).loc[gene].reset_index(), x='index', y=gene)
예제 #6
0
import sys

import matplotlib.pyplot as plt
import pandas as pd
from pydiffexp import DEPlot
from pydiffexp.utils.io import read_dea_pickle

# Widen pandas' console output so wide tables are not wrapped.
pd.set_option('display.width', 1000)

# Load the DEAnalysis object with fits and data
dea = read_dea_pickle("./sprouty_pickle.pkl")

# Initialize a plotting object
dep = DEPlot(dea)
#
# dep.tsplot(dea.data.loc['SPRY2'])
# plt.tight_layout()
# plt.show()

# Draw the heatmap and write it to disk. `savefig` takes the file type via
# the `format` keyword; the previous `fmt` keyword is not a savefig argument
# and is rejected by current matplotlib releases.
dep.heatmap()
plt.savefig(
    '/Users/jfinkle/Dropbox/Justin Finkle, shared/Media/Sprouty/images/row_max_normalized_heatmap.png',
    format='png',
    dpi=600)
#
#
# NOTE(review): the script stops here, so the volcano-plot section below is
# currently unreachable - confirm whether this exit is still wanted.
sys.exit()

# Volcano Plot
x = dea.results['KO-WT'].top_table(coef=1, use_fstat=False)
# dep.volcano_plot(x, top_n=5, show_labels=True)
예제 #7
0
def plot_sankey(gs, gc, ar_der, dea, e, ts_der, c_condition):
    """Draw two stacked flow diagrams for the 'DRG' gene class.

    The top panel (``gs[0, 0]``) shows the discrete values from
    ``ar_der.discrete``; the bottom panel (``gs[1, 0]``) shows the cumulative
    sum of the signs returned by ``sign_diff``. For both panels a zero column
    is prepended, all-zero rows are dropped, columns are relabelled with
    ``dea.times`` and a per-column value-count table is printed.

    :param gs: gridspec-like object indexed as ``gs[row, col]``
    :param gc: mapping from gene-class name to the genes in that class
    :param ar_der: object exposing a ``discrete`` frame indexed by gene
    :param dea: object exposing ``times``; also forwarded to ``sign_diff``
    :param e: forwarded to ``sign_diff``
    :param ts_der: forwarded to ``sign_diff``
    :param c_condition: forwarded to ``sign_diff``
    :return: None
    """

    def _prepare_paths(paths):
        # Prepend a column labelled 0 filled with 0 (in place), keep rows
        # with at least one non-zero entry, then relabel columns with times.
        paths.insert(0, 0, 0)
        paths = paths[(paths != 0).any(axis=1)]
        paths.columns = dea.times
        # Print how many rows hold each value at each time point.
        tally = paths.apply(pd.Series.value_counts, axis=0)
        print(tally.fillna(0).sort_index(ascending=False).astype(int))
        return paths

    # Initialize plotter
    plotter = DEPlot()

    # Keyword arguments shared by both flow plots.
    base_color = Prism_10.mpl_colors[5]
    accent = Bold_8.mpl_colors[0]
    shared_flow_kwargs = dict(min_sw=0.01,
                              max_sw=1,
                              uniform=True,
                              node_width=None,
                              legend=False,
                              node_color=base_color)

    # --- Top panel: discrete values --------------------------------------
    top_ax = plt.subplot(gs[0, 0])
    genes = gc['DRG']
    discrete_paths = _prepare_paths(ar_der.discrete.loc[genes])
    top_ax = plotter.plot_flows(top_ax, ['diff'], ['0.5'], [1], ['all'],
                                x_coords=discrete_paths.columns,
                                path_df=discrete_paths,
                                # recolor selected node
                                node_color_dict={(5, -1): accent},
                                **shared_flow_kwargs)
    top_ax.set_xticklabels('')
    annotate_n(top_ax, len(discrete_paths))
    top_ax.set_ylabel('Discrete \nFC')
    top_ax.set_yticks(range(-1, 2))

    # --- Bottom panel: cumulative sign differences -----------------------
    sign_table = sign_diff(dea, ts_der, genes, e, c_condition)
    cumulative_paths = _prepare_paths(np.cumsum(np.sign(sign_table), axis=1))

    bottom_ax = plt.subplot(gs[1, 0])
    highlighted_segments = [(3, -1), (3, 0), (3, 1), (3, 2),
                            (4, -2), (4, -1), (4, 0), (4, 1)]
    segment_colors = {
        'up': dict.fromkeys(highlighted_segments, accent),
        'down': None,
        'over': None
    }
    bottom_ax = plotter.plot_flows(bottom_ax, ['diff'], ['0.5'], [1], ['all'],
                                   x_coords=cumulative_paths.columns,
                                   path_df=cumulative_paths,
                                   seg_color_dict=segment_colors,
                                   **shared_flow_kwargs)

    bottom_ax.set_xticklabels(cumulative_paths.columns)
    bottom_ax.set_ylim([-4, 4])
    annotate_n(bottom_ax, len(cumulative_paths))
    # The bottom axes are the current axes at this point.
    plt.xlabel('Time (min)')
    bottom_ax.set_ylabel('Cumulative Trajectory\nDifferences')
예제 #8
0
def plot_collections(hm_data, hash_data, term_data, output='show'):
    """Draw the heatmap / term-depth summary figure and show or save it.

    The left half of the figure holds a heatmap (with colorbar and hash
    column) drawn by ``DEPlot.heatmap``; the lower right holds a box plot of
    ``depth`` per gene class, overlaid with a swarm plot for classes that
    have fewer than 50 rows. The upper-right axes are created and left blank.

    :param hm_data: heatmap data passed to ``DEPlot.heatmap``; its number of
        columns also sets the relative width of the heatmap panel
    :param hash_data: second positional argument of ``DEPlot.heatmap``
    :param term_data: frame with a ``depth`` column; index level 0 is moved
        into a column and treated as the gene class
    :param output: ``'show'`` (case-insensitive) to display the figure,
        otherwise a path handed to ``plt.savefig`` and written as PDF
    :return: None
    """
    # Organize plot
    # Overall gridspec with 1 row, two columns
    plt.figure(figsize=(10, 10))
    gs = gridspec.GridSpec(1, 2)

    # Create a gridspec within the gridspec. 2 rows and 2 columns on the
    # left (thin colorbar row on top), 2 rows and 1 column on the right.
    gs_left = gridspec.GridSpecFromSubplotSpec(
        2,
        2,
        subplot_spec=gs[0],
        width_ratios=[hm_data.shape[1], 2],
        height_ratios=[1, 50],
        wspace=0.05,
        hspace=0.05)
    gs_right = gridspec.GridSpecFromSubplotSpec(2,
                                                1,
                                                subplot_spec=gs[1],
                                                height_ratios=[1, 1.5],
                                                hspace=0.25)

    cbar_ax = plt.subplot(gs_left[0, 0])
    hidden_ax = plt.subplot(gs_left[0, 1])
    hm_ax = plt.subplot(gs_left[1, 0])
    hash_ax = plt.subplot(gs_left[1, 1])
    gene_ax = plt.subplot(gs_right[0])
    go_ax = plt.subplot(gs_right[1])

    # Hide the top right axes where the venn diagram goes
    gene_ax.axis('off')

    # Initialize plotter
    dep = DEPlot()

    hm_ax, hash_ax = dep.heatmap(hm_data,
                                 hash_data,
                                 hm_ax=hm_ax,
                                 hash_ax=hash_ax,
                                 cbar_ax=cbar_ax,
                                 yticklabels=False,
                                 cbar_kws=dict(orientation='horizontal',
                                               ticks=[-1, 0, 1]))
    cbar_ax.xaxis.tick_top()
    cbar_ax.invert_xaxis()
    # The cell next to the colorbar only pads the grid; blank it out.
    hidden_ax.set_xlabel('')
    hidden_ax.set_ylabel('')
    hidden_ax.axis('off')

    index_order = [
        'DEG', 'DDE', 'DRG', 'DEG∩DDE', 'DEG∩DRG', 'DDE∩DRG', 'DEG∩DDE∩DRG',
        'All'
    ]

    # Six palette colours plus black and grey: one colour per class above.
    c_index = [1, 7, 5, 9, 3, 6]
    colors = [Prism_10.mpl_colors[idx] for idx in c_index] + ['k', '0.5']
    cmap = dict(zip(index_order, colors))

    # Move index level 0 into the first column and name it 'gene_class'.
    terms = term_data.reset_index(level=0)
    terms.columns = ['gene_class'] + terms.columns[1:].values.tolist()
    # Duplicate every row under the pooled class "All".
    full_dist = terms.copy()
    full_dist['gene_class'] = "All"
    full_dist.reset_index(drop=True, inplace=True)
    all_terms = pd.concat([terms, full_dist])
    grouped = all_terms.groupby('gene_class')

    go_ax = sns.boxplot(data=grouped.filter(lambda grp: True),
                        x='depth',
                        y='gene_class',
                        order=index_order,
                        ax=go_ax,
                        showfliers=False,
                        boxprops=dict(linewidth=0),
                        medianprops=dict(solid_capstyle='butt', color='w'),
                        palette=cmap)

    # Individual points are only drawn for classes with fewer than 50 rows.
    small_groups = grouped.filter(lambda grp: len(grp) < 50)
    go_ax = sns.swarmplot(data=small_groups,
                          x='depth',
                          y='gene_class',
                          order=index_order,
                          ax=go_ax,
                          color='k')

    # Vertical reference line at the median depth of the un-pooled rows.
    # The colour is given once via `color`; previously it was specified both
    # in the format string ('k-') and as `c='0.25'`.
    median_depth = terms['depth'].median()
    go_ax.plot([median_depth, median_depth],
               go_ax.get_ylim(),
               '-',
               lw=2,
               zorder=0,
               color='0.25')

    # Label each row with its group size; classes with no rows show n=0.
    term_sizes = grouped.size().reindex(index_order).fillna(0).astype(int)
    y_ticks = ["n={}".format(term_sizes.loc[idx]) for idx in index_order]
    go_ax.set_yticklabels(y_ticks)
    go_ax.set_ylabel('')
    plt.tight_layout()

    if output.lower() == 'show':
        plt.show()

    else:
        # `format` is the savefig keyword for the file type; `fmt` is not a
        # savefig argument and is rejected by current matplotlib releases.
        plt.savefig(output, format='pdf')
예제 #9
0
# Lookup table from gene id to HGNC symbol, read from a space-separated file.
gene_names = pd.read_csv('../data/GSE69822/GSE69822_RNA-Seq_RPKMs_cleaned.txt',
                         sep=' ')
gene_names = gene_names[['id',
                         'hgnc_symbol']].set_index('id').drop_duplicates()
# Keep rows that pass the threshold on both the top-table adjusted p-value
# and the p_results p-value.
p_thresh = 0.001
scores = p_results.loc[(der.top_table()['adj_pval'] < p_thresh)
                       & (p_results['p_value'] < p_thresh)]

# Keep only clusters whose parsed tuple contains more than one distinct
# value (i.e. drop clusters that are all 1, all -1 or all 0)
interesting = scores.loc[
    scores.Cluster.apply(ast.literal_eval).apply(set).apply(len) > 1]
# print(interesting.sort_values(['score'], ascending=False).head(100))
# Members of one specific cluster, best score first.
c = (interesting[interesting.Cluster == '(0, 0, 0, 0, -1, -1)'].sort_values(
    'score', ascending=False))
print(c)
dep = DEPlot()
# Plot a single hard-coded gene for the 'ko' and 'wt' conditions.
dep.tsplot(dea.data.loc["ENSG00000186187", ['ko', 'wt']], legend=False)
plt.title('ZNRF1')
plt.tight_layout()
plt.show()
# NOTE(review): the script exits here, so the loop and heatmap preparation
# below are currently unreachable - looks like a debugging leftover; confirm.
sys.exit()
# NOTE(review): this loop selects 'pten' where the plot above uses 'ko' -
# confirm which condition label the data actually carries.
for gene in c.index:
    dep.tsplot(dea.data.loc[gene, ['pten', 'wt']], legend=False)
    plt.tight_layout()
    plt.show()
sys.exit()

# Heatmap of expression: first six top-table columns, plus the row order
# sorted by cluster and score (descending).
de_data = (der.top_table().iloc[:, :6])  #.multiply(der.p_value < p_thresh)
sort_idx = interesting.sort_values(['Cluster', 'score'],
                                   ascending=False).index.values