# Example #1
import sys

import matplotlib.pyplot as plt
import pandas as pd

from pydiffexp import DEPlot
from pydiffexp.utils.io import read_dea_pickle

# Destination of the saved heatmap image; adjust for the local machine.
HEATMAP_PATH = ('/Users/jfinkle/Dropbox/Justin Finkle, shared/Media/Sprouty/'
                'images/row_max_normalized_heatmap.png')

# Widen pandas' console display width (in characters).
pd.set_option('display.width', 1000)

# Load the DEAnalysis object with fits and data
dea = read_dea_pickle("./sprouty_pickle.pkl")

# Initialize a plotting object bound to the loaded analysis
dep = DEPlot(dea)

# Single-gene time-series plot (currently disabled):
# dep.tsplot(dea.data.loc['SPRY2'])
# plt.tight_layout()
# plt.show()

# Draw the heatmap and save it as a 600-dpi PNG.
# NOTE: the savefig keyword is `format`; `fmt` is not a savefig parameter.
dep.heatmap()
plt.savefig(HEATMAP_PATH, format='png', dpi=600)

# Stop here on purpose: the sections below are exploratory and are skipped.
sys.exit()

# Volcano Plot
x = dea.results['KO-WT'].top_table(coef=1, use_fstat=False)
# dep.volcano_plot(x, top_n=5, show_labels=True)

# Time Series Plot
x = dea.data.loc['IVL']
# Example #2
def plot_collections(hm_data, hash_data, term_data, output='show'):
    """Draw a gene-collection heatmap beside per-class term-depth box plots.

    Layout: the left half holds a heatmap of ``hm_data`` next to a hash panel
    built from ``hash_data``, with a horizontal colorbar (ticks -1, 0, 1) above
    the heatmap. The right half has an empty top panel reserved for a Venn
    diagram and, below it, a box plot of the ``depth`` column of
    ``term_data`` for each gene class plus an extra "All" class that pools
    every row. Classes with fewer than 50 rows are overlaid as a swarm, a
    vertical line marks the overall median depth, and each class row is
    labelled with its row count ("n=...").

    Args:
        hm_data: Heatmap values handed to ``DEPlot.heatmap``; its number of
            columns sets the relative width of the heatmap panel.
        hash_data: Data handed to ``DEPlot.heatmap`` for the hash panel.
        term_data: Table whose first index level is the gene class (one of
            the names in ``index_order`` below) and which has a ``depth``
            column.
        output: ``'show'`` (case-insensitive) to display the figure
            interactively; otherwise a file path to save to. The file format
            is inferred from the path's extension, defaulting to PDF when the
            path has no extension.
    """
    import os

    # Overall layout: one row, two columns (heatmap block | term-depth block).
    fig = plt.figure(figsize=(10, 10))
    gs = gridspec.GridSpec(1, 2)

    # Left block: 2x2 grid. The thin top row holds the colorbar; the heatmap
    # width scales with the number of heatmap columns relative to a 2-unit
    # hash panel.
    gs_left = gridspec.GridSpecFromSubplotSpec(
        2,
        2,
        subplot_spec=gs[0],
        width_ratios=[hm_data.shape[1], 2],
        height_ratios=[1, 50],
        wspace=0.05,
        hspace=0.05)
    # Right block: Venn placeholder on top, term-depth plot below.
    gs_right = gridspec.GridSpecFromSubplotSpec(2,
                                                1,
                                                subplot_spec=gs[1],
                                                height_ratios=[1, 1.5],
                                                hspace=0.25)

    cbar_ax = plt.subplot(gs_left[0, 0])
    hidden_ax = plt.subplot(gs_left[0, 1])
    hm_ax = plt.subplot(gs_left[1, 0])
    hash_ax = plt.subplot(gs_left[1, 1])
    gene_ax = plt.subplot(gs_right[0])
    go_ax = plt.subplot(gs_right[1])

    # Hide the top right axes where the venn diagram goes
    gene_ax.axis('off')

    # Initialize plotter and draw the heatmap, hash panel and colorbar.
    dep = DEPlot()

    hm_ax, hash_ax = dep.heatmap(hm_data,
                                 hash_data,
                                 hm_ax=hm_ax,
                                 hash_ax=hash_ax,
                                 cbar_ax=cbar_ax,
                                 yticklabels=False,
                                 cbar_kws=dict(orientation='horizontal',
                                               ticks=[-1, 0, 1]))
    # Put the colorbar ticks above the bar and reverse its direction.
    cbar_ax.xaxis.tick_top()
    cbar_ax.invert_xaxis()
    # The cell above the hash panel is only a spacer.
    hidden_ax.set_xlabel('')
    hidden_ax.set_ylabel('')
    hidden_ax.axis('off')

    # Fixed display order of gene classes; "All" pools every row.
    index_order = [
        'DEG', 'DDE', 'DRG', 'DEG∩DDE', 'DEG∩DRG', 'DDE∩DRG', 'DEG∩DDE∩DRG',
        'All'
    ]

    # Six Prism_10 colours for the first six classes, then black for the
    # triple intersection and grey for "All".
    c_index = [1, 7, 5, 9, 3, 6]
    colors = [Prism_10.mpl_colors[idx] for idx in c_index] + ['k', '0.5']
    cmap = dict(zip(index_order, colors))

    # Move the first index level (gene class) into a 'gene_class' column.
    x = term_data.reset_index(level=0)
    x.columns = ['gene_class'] + x.columns[1:].tolist()
    # Duplicate every row under the pooled "All" class.
    full_dist = x.copy()
    full_dist['gene_class'] = "All"
    full_dist.reset_index(drop=True, inplace=True)
    all_x = pd.concat([x, full_dist])
    g = all_x.groupby('gene_class')

    # Box plot of depth per class; outliers are hidden because small classes
    # get their individual points drawn by the swarm below.
    go_ax = sns.boxplot(data=g.filter(lambda grp: True),
                        x='depth',
                        y='gene_class',
                        order=index_order,
                        ax=go_ax,
                        showfliers=False,
                        boxprops=dict(linewidth=0),
                        medianprops=dict(solid_capstyle='butt', color='w'),
                        palette=cmap)

    # Overlay individual points only for classes with fewer than 50 rows.
    small_groups = g.filter(lambda grp: len(grp) < 50)
    go_ax = sns.swarmplot(data=small_groups,
                          x='depth',
                          y='gene_class',
                          order=index_order,
                          ax=go_ax,
                          color='k')

    # Vertical reference line at the median depth over all rows. The colour
    # is set only through `c`; putting 'k' in the format string as well made
    # matplotlib warn about a redundantly defined colour.
    overall_median = x['depth'].median()
    go_ax.plot([overall_median, overall_median],
               go_ax.get_ylim(),
               '-',
               lw=2,
               zorder=0,
               c='0.25')

    # Label each class row with its row count; classes absent from the data
    # show n=0.
    term_sizes = g.size().reindex(index_order).fillna(0).astype(int)
    y_ticks = ["n={}".format(term_sizes.loc[idx]) for idx in index_order]
    go_ax.set_yticklabels(y_ticks)
    go_ax.set_ylabel('')
    plt.tight_layout()

    if output.lower() == 'show':
        plt.show()
    else:
        # `format` is the savefig keyword (`fmt` is not). Let matplotlib infer
        # the format from the extension; fall back to PDF when there is none.
        has_extension = bool(os.path.splitext(output)[1])
        plt.savefig(output, format=None if has_extension else 'pdf')
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)