Example #1
def filtered_strf_vs_resp_batch(batch,
                                gc,
                                stp,
                                LN,
                                combined,
                                strf,
                                save_path,
                                good_ln=0.0,
                                test_limit=None,
                                stat='r_ceiling',
                                bin_count=40):

    e, a, g, s, c = improved_cells_to_list(batch,
                                           gc,
                                           stp,
                                           LN,
                                           combined,
                                           good_ln=good_ln)
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    if stat == 'r_ceiling':
        df = df_c
    else:
        df = df_r

    tags = ['either', 'neither', 'gc', 'stp', 'combined']
    for cells, tag in zip([e, a, g, s, c], tags):
        _strf_resp_sub_batch(cells[:test_limit],
                             df,
                             tag,
                             stat,
                             batch,
                             gc,
                             stp,
                             LN,
                             combined,
                             strf,
                             save_path,
                             bin_count=bin_count)
Example #2
def filtered_pred_matched_batch(batch,
                                gc,
                                stp,
                                LN,
                                combined,
                                save_path,
                                good_ln=0.0,
                                test_limit=None,
                                stat='r_ceiling',
                                replace_existing=True):

    e, a, g, s, c = improved_cells_to_list(batch,
                                           gc,
                                           stp,
                                           LN,
                                           combined,
                                           good_ln=good_ln)
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    if stat == 'r_ceiling':
        df = df_c
    else:
        df = df_r

    tags = ['either', 'neither', 'gc', 'stp', 'combined']
    a = list(set(a) - set(e))
    for cells, tag in zip([e, a, g, s, c], tags):
        _sigmoid_sub_batch(cells[:test_limit],
                           df,
                           tag,
                           stat,
                           batch,
                           gc,
                           stp,
                           LN,
                           combined,
                           save_path,
                           replace_existing=replace_existing)
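The loops in Examples #1 and #2 partition the cell lists with plain set arithmetic; Example #2's 'neither' group, for instance, is the full cell list minus the cells improved by either model. A minimal, self-contained sketch of that partitioning, using hypothetical cell IDs rather than the real output of improved_cells_to_list:

# Hypothetical cell IDs; the real lists come from improved_cells_to_list.
either = ['cell-01', 'cell-02', 'cell-05']           # improved by gc or stp
all_cells = ['cell-0%d' % i for i in range(1, 7)]    # every fitted cell
neither = list(set(all_cells) - set(either))         # analogue of a = a - e
print(sorted(neither))  # ['cell-03', 'cell-04', 'cell-06']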
Example #3
def performance_table(batch1,
                      gc,
                      stp,
                      LN,
                      combined,
                      batch2,
                      plot_stat='r_ceiling',
                      height_scaling=3):
    # 4 tables: batch 289 and 263 all cells / improved cells
    df_r1, df_c1, df_e1 = get_dataframes(batch1, gc, stp, LN, combined)
    cellids1, under_chance1, less_LN1 = get_filtered_cellids(
        df_r1, df_e1, gc, stp, LN, combined)

    df_r2, df_c2, df_e2 = get_dataframes(batch2, gc, stp, LN, combined)
    cellids2, under_chance2, less_LN2 = get_filtered_cellids(
        df_r2, df_e2, gc, stp, LN, combined)
    e1, a1, _, _, _ = improved_cells_to_list(batch1, gc, stp, LN, combined)
    e2, a2, _, _, _ = improved_cells_to_list(batch2, gc, stp, LN, combined)

    if plot_stat == 'r_ceiling':
        df1 = df_c1
        df2 = df_c2
    else:
        df1 = df_r1
        df2 = df_r2

    models = [LN, gc, stp, combined]
    a289, a289_stats = _make_table(df1, df_e1, models, a1)
    a263, a263_stats = _make_table(df2, df_e2, models, a2)
    i289, i289_stats = _make_table(df1, df_e1, models, e1)
    i263, i263_stats = _make_table(df2, df_e2, models, e2)
    model_names = ['LN', 'GC', 'STP', 'GC+STP', 'Max(GC,STP)']

    fig1, ((ax_a, ax_b), (ax_c, ax_d)) = plt.subplots(2, 2)
    fig2, ((ax_e, ax_f), (ax_g, ax_h)) = plt.subplots(2, 2)
    fig1.patch.set_visible(False)
    fig2.patch.set_visible(False)
    iters = zip([ax_a, ax_b, ax_c, ax_d], [ax_e, ax_f, ax_g, ax_h],
                [a289, a263, i289, i263],
                [a289_stats, a263_stats, i289_stats, i263_stats],
                ['Natural stimuli, all cells', 'Voc. in noise, all cells',
                 'Natural stimuli, nonlinear cells',
                 'Voc. in noise, nonlinear cells'])

    for ax1, ax2, table, stats, title in iters:
        ax1.axis('off')
        ax1.axis('tight')
        table1 = ax1.table(cellText=table,
                           colLabels=model_names,
                           rowLabels=model_names,
                           loc='center',
                           cellLoc='center',
                           rowLoc='center')
        table1_cells = table1.properties()['child_artists']
        for c1 in table1_cells:
            current_height1 = c1.get_height()
            c1.set_height(current_height1 * height_scaling)
        ax1.set_title(title)

        ax2.axis('off')
        ax2.axis('tight')
        row_labels = ['mean', 'median', 'std err']
        table_text = np.empty((len(row_labels), len(model_names)), dtype='U7')
        for i, _ in enumerate(row_labels):
            for j, _ in enumerate(model_names):
                s = stats[j][i]
                text = '%.5f' % s
                table_text[i][j] = text
        table2 = ax2.table(cellText=table_text,
                           colLabels=model_names,
                           rowLabels=row_labels,
                           loc='center',
                           cellLoc='center',
                           rowLoc='center')
        table2_cells = table2.properties()['child_artists']
        for c2 in table2_cells:
            current_height2 = c2.get_height()
            c2.set_height(current_height2 * height_scaling)
        ax2.set_title(title)

    fig1.tight_layout()
    fig2.tight_layout()

    return fig2, fig1
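performance_table builds each panel with ax.table and then rescales every cell's height. A self-contained sketch of that pattern on dummy data; it uses table.get_celld(), the currently documented accessor, rather than the properties()['child_artists'] lookup above:

import numpy as np
import matplotlib.pyplot as plt

labels = ['LN', 'GC', 'STP']
data = np.round(np.random.rand(3, 3), 3).astype(str)   # dummy pairwise stats

fig, ax = plt.subplots()
ax.axis('off')
tab = ax.table(cellText=data, colLabels=labels, rowLabels=labels,
               loc='center', cellLoc='center', rowLoc='center')
for cell in tab.get_celld().values():       # same role as height_scaling
    cell.set_height(cell.get_height() * 3)
fig.tight_layout()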
Example #4
def performance_bar(batch,
                    gc,
                    stp,
                    LN,
                    combined,
                    se_filter=True,
                    LN_filter=False,
                    manual_cellids=None,
                    abbr_yaxis=False,
                    plot_stat='r_ceiling',
                    y_adjust=0.05,
                    manual_y=None,
                    only_improvements=False,
                    show_text_labels=False):
    '''
    gc: GC model
    stp: STP model
    LN: LN model
    combined: GC+STP model
    '''

    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    #    cellids, under_chance, less_LN = get_filtered_cellids(batch, gc, stp,
    #                                                          LN, combined,
    #                                                          se_filter,
    #                                                          LN_filter)
    e, a, g, s, c = improved_cells_to_list(batch,
                                           gc,
                                           stp,
                                           LN,
                                           combined,
                                           se_filter=se_filter,
                                           LN_filter=LN_filter)
    cellids = a

    if manual_cellids is not None:
        # WARNING: Will override se and ratio filters even if they are set
        cellids = manual_cellids
    elif only_improvements:
        e, a, g, s, c = improved_cells_to_list(batch,
                                               gc,
                                               stp,
                                               LN,
                                               combined,
                                               as_lists=True)
        cellids = e

    if plot_stat == 'r_ceiling':
        plot_df = df_c
    else:
        plot_df = df_r

    n_cells = len(cellids)
    gc_test = plot_df[gc][cellids]
    stp_test = plot_df[stp][cellids]
    ln_test = plot_df[LN][cellids]
    gc_stp_test = plot_df[combined][cellids]
    #max_test = np.maximum(gc_test, stp_test)

    gc = np.median(gc_test.values)
    stp = np.median(stp_test.values)
    ln = np.median(ln_test.values)
    gc_stp = np.median(gc_stp_test.values)
    #maximum = np.median(max_test)
    largest = max(gc, stp, ln, gc_stp)  #, maximum)

    colors = [model_colors[k]
              for k in ['LN', 'gc', 'stp', 'combined']]  #, 'max']]
    #fig = plt.figure(figsize=(15, 12))
    fig = plt.figure()
    plt.bar(
        [1, 2, 3, 4],
        [ln, gc, stp, gc_stp],  # maximum],
        color=colors,
        edgecolor="black",
        linewidth=1)
    plt.xticks([1, 2, 3, 4],
               ['LN', 'GC', 'STP', 'GC+STP'])  #, 'Max(GC,STP)'])
    if abbr_yaxis:
        if manual_y:
            lower, upper = manual_y
        else:
            lower = np.floor(10 * min(gc, stp, ln, gc_stp)) / 10
            upper = np.ceil(10 * max(gc, stp, ln, gc_stp)) / 10 + y_adjust
        plt.ylim(ymin=lower, ymax=upper)
    else:
        plt.ylim(ymax=largest * 1.4)
    common_kwargs = {'color': 'white', 'horizontalalignment': 'center'}
    if abbr_yaxis:
        y_text = 0.5 * (lower + min(gc, stp, ln, gc_stp))
    else:
        y_text = 0.2
    if show_text_labels:
        for i, m in enumerate([ln, gc, stp, gc_stp]):  #, maximum]):
            t = plt.text(i + 1, y_text, "%0.04f" % m, **common_kwargs)
        #t.set_path_effects([pe.withStroke(linewidth=3, foreground='black')])
    xmin, xmax = plt.xlim()
    plt.xlim(xmin, xmax - 0.35)
    ax_remove_box()
    plt.tight_layout()

    fig2 = plt.figure(figsize=text_fig)
    text = "Median Performance, batch: %d\, n:%d" % (batch, n_cells)
    plt.text(0.1, 0.5, text)

    return fig, fig2
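performance_bar reduces each model's per-cell scores to a single median and draws the four medians as bars, optionally printing the value inside each bar. A compact sketch with synthetic scores standing in for the real dataframe columns:

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
names = ['LN', 'GC', 'STP', 'GC+STP']
scores = {n: rng.uniform(0.3, 0.8, size=200) for n in names}  # per-cell scores
medians = [np.median(scores[n]) for n in names]

plt.figure()
plt.bar([1, 2, 3, 4], medians, edgecolor='black', linewidth=1)
plt.xticks([1, 2, 3, 4], names)
plt.ylim(top=max(medians) * 1.4)
for i, m in enumerate(medians):
    plt.text(i + 1, 0.2, '%0.04f' % m, color='white',
             horizontalalignment='center')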
Example #5
def combined_vs_max(batch,
                    gc,
                    stp,
                    LN,
                    combined,
                    se_filter=True,
                    LN_filter=False,
                    plot_stat='r_ceiling',
                    legend=False,
                    improved_only=True,
                    exclude_low_snr=False,
                    snr_path=None):

    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    #    cellids, under_chance, less_LN = get_filtered_cellids(df_r, df_e, gc, stp,
    #                                                          LN, combined,
    #                                                          se_filter,
    #                                                          LN_filter)
    e, a, g, s, c = improved_cells_to_list(batch,
                                           gc,
                                           stp,
                                           LN,
                                           combined,
                                           se_filter=se_filter,
                                           LN_filter=LN_filter)
    improved = c
    not_improved = list(set(a) - set(c))
    if exclude_low_snr:
        snr_df = pd.read_pickle(snr_path)
        med_snr = snr_df['snr'].median()
        high_snr = snr_df.loc[snr_df['snr'] >= med_snr]
        high_snr_cells = high_snr.index.values.tolist()
        improved = list(set(improved) & set(high_snr_cells))
        not_improved = list(set(not_improved) & set(high_snr_cells))

    if plot_stat == 'r_ceiling':
        plot_df = df_c
    else:
        plot_df = df_r

    gc_not = plot_df[gc][not_improved]
    gc_imp = plot_df[gc][improved]
    #gc_test_under_chance = plot_df[gc][under_chance]
    stp_not = plot_df[stp][not_improved]
    stp_imp = plot_df[stp][improved]
    #stp_test_under_chance = plot_df[stp][under_chance]
    #ln_test = plot_df[LN][cellids]
    gc_stp_not = plot_df[combined][not_improved]
    gc_stp_imp = plot_df[combined][improved]
    max_not = np.maximum(gc_not, stp_not)
    max_imp = np.maximum(gc_imp, stp_imp)
    #gc_stp_test_rel = gc_stp_test - ln_test
    #max_test_rel = np.maximum(gc_test, stp_test) - ln_test
    imp_T, imp_p = st.wilcoxon(gc_stp_imp, max_imp)
    med_combined = np.nanmedian(gc_stp_imp)
    med_max = np.nanmedian(max_imp)

    fig1 = plt.figure()
    c_not = model_colors['LN']
    c_imp = model_colors['max']
    if not improved_only:
        plt.scatter(max_not,
                    gc_stp_not,
                    c=c_not,
                    s=small_scatter,
                    label='no imp.')
    plt.scatter(max_imp, gc_stp_imp, c=c_imp, s=big_scatter, label='sig. imp.')
    ax = fig1.axes[0]
    plt.plot(ax.get_xlim(),
             ax.get_xlim(),
             'k--',
             linewidth=1,
             dashes=dash_spacing)
    if legend:
        plt.legend()
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.tight_layout()
    ax.set_aspect('equal')
    ax_remove_box()

    fig2 = plt.figure(figsize=text_fig)
    text = ("batch: %d\n"
            "x: Max(GC,STP)\n"
            "y: GC+STP\n"
            "wilcoxon: T: %.4E, p: %.4E\n"
            "combined median: %.4E\n"
            "max median: %.4E" % (batch, imp_T, imp_p, med_combined, med_max))
    plt.text(0.1, 0.5, text)

    return fig1, fig2
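The statistic reported by combined_vs_max is a paired Wilcoxon signed-rank test between the combined model's score and the element-wise maximum of the GC and STP scores for the same cells. A synthetic-data sketch of that comparison:

import numpy as np
import scipy.stats as st

rng = np.random.default_rng(1)
gc_scores = rng.uniform(0.4, 0.9, size=100)        # per-cell GC scores
stp_scores = rng.uniform(0.4, 0.9, size=100)       # per-cell STP scores
combined_scores = np.maximum(gc_scores, stp_scores) + rng.normal(0, 0.02, 100)

max_scores = np.maximum(gc_scores, stp_scores)     # best single model per cell
T, p = st.wilcoxon(combined_scores, max_scores)    # paired signed-rank test
print('T=%.1f, p=%.3g, medians: %.3f vs %.3f'
      % (T, p, np.nanmedian(combined_scores), np.nanmedian(max_scores)))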
Example #6
def single_scatter(batch,
                   gc,
                   stp,
                   LN,
                   combined,
                   compare,
                   plot_stat='r_ceiling',
                   legend=False):
    all_batch_cells = nd.get_batch_cells(batch, as_list=True)
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    e, a, g, s, c = improved_cells_to_list(batch,
                                           gc,
                                           stp,
                                           LN,
                                           combined,
                                           se_filter=True,
                                           LN_filter=False,
                                           as_lists=True)

    if plot_stat == 'r_ceiling':
        plot_df = df_c
    else:
        plot_df = df_r

    improved = c
    not_improved = list(set(a) - set(c))
    models = [gc, stp, LN, combined]
    names = ['gc', 'stp', 'LN', 'combined']
    m1 = models[compare[0]]
    m2 = models[compare[1]]
    name1 = names[compare[0]]
    name2 = names[compare[1]]
    n_batch = len(all_batch_cells)
    n_all = len(a)
    n_imp = len(improved)
    n_not_imp = len(not_improved)

    m1_scores = plot_df[m1][not_improved]
    m1_scores_improved = plot_df[m1][improved]
    m2_scores = plot_df[m2][not_improved]
    m2_scores_improved = plot_df[m2][improved]

    fig = plt.figure()
    plt.plot([0, 1], [0, 1],
             color='black',
             linewidth=1,
             linestyle='dashed',
             dashes=dash_spacing)
    plt.scatter(m1_scores,
                m2_scores,
                s=small_scatter,
                label='no imp.',
                color=model_colors['LN'])
    #color='none',
    #edgecolors='black', linewidth=0.35)
    plt.scatter(m1_scores_improved,
                m2_scores_improved,
                s=big_scatter,
                label='sig. imp.',
                color=model_colors['max'])
    #color='none',
    #edgecolors='black', linewidth=0.35)
    ax_remove_box()

    if legend:
        plt.legend()
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.tight_layout()
    plt.gca().set_aspect('equal')

    fig2 = plt.figure(figsize=text_fig)
    plt.text(
        0.1, 0.5, "batch %d\n"
        "%d/%d  auditory/total cells\n"
        "%d no improvements\n"
        "%d at least one improvement\n"
        "stat: %s, x: %s, y: %s" %
        (batch, n_all, n_batch, n_not_imp, n_imp, plot_stat, name1, name2))

    return fig, fig2
Example #7
def gc_stp_scatter(batch,
                   gc,
                   stp,
                   LN,
                   combined,
                   se_filter=True,
                   LN_filter=False,
                   manual_cellids=None,
                   plot_stat='r_ceiling'):

    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    cellids, under_chance, less_LN = get_filtered_cellids(
        df_r, df_e, gc, stp, LN, combined, se_filter, LN_filter)

    if manual_cellids is not None:
        # WARNING: Will override se and ratio filters even if they are set
        cellids = manual_cellids

    if plot_stat == 'r_ceiling':
        plot_df = df_c
    else:
        plot_df = df_r

    n_cells = len(cellids)
    n_under_chance = len(under_chance) if under_chance != cellids else 0
    n_less_LN = len(less_LN) if less_LN != cellids else 0

    gc_test = plot_df[gc][cellids]
    gc_test_under_chance = plot_df[gc][under_chance]

    stp_test = plot_df[stp][cellids]
    stp_test_under_chance = plot_df[stp][under_chance]

    fig = plt.figure()
    plt.scatter(gc_test, stp_test, c=wsu_gray, s=20)
    ax = fig.axes[0]
    plt.plot(ax.get_xlim(),
             ax.get_ylim(),
             'k--',
             linewidth=1,
             dashes=dash_spacing)
    plt.scatter(gc_test_under_chance,
                stp_test_under_chance,
                c=dropped_cell_color,
                s=20)
    plt.title('GC vs STP')
    plt.xlabel('GC')
    plt.ylabel('STP')
    if se_filter:
        plt.text(0.90,
                 -0.05,
                 'all = %d' % (n_cells + n_under_chance + n_less_LN),
                 ha='right',
                 va='bottom')
        plt.text(0.90, 0.00, 'n = %d' % n_cells, ha='right', va='bottom')
        plt.text(0.90,
                 0.05,
                 'uc = %d' % n_under_chance,
                 ha='right',
                 va='bottom',
                 color=dropped_cell_color)
Example #8
def performance_scatters(batch,
                         gc,
                         stp,
                         LN,
                         combined,
                         se_filter=True,
                         LN_filter=False,
                         manual_cellids=None,
                         plot_stat='r_ceiling',
                         show_dropped=True,
                         color_improvements=False):
    '''
    gc: GC model
    stp: STP model
    LN: LN model
    combined: GC+STP model
    '''
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    # under_chance is still needed for the dropped-cell scatters below
    _, under_chance, less_LN = get_filtered_cellids(df_r, df_e, gc, stp, LN,
                                                    combined, se_filter,
                                                    LN_filter)
    e, a, g, s, c = improved_cells_to_list(batch,
                                           gc,
                                           stp,
                                           LN,
                                           combined,
                                           se_filter=se_filter,
                                           LN_filter=LN_filter)
    cellids = a

    if manual_cellids is not None:
        # WARNING: Will override se and ratio filters even if they are set
        cellids = manual_cellids

    if plot_stat == 'r_ceiling':
        plot_df = df_c
    else:
        plot_df = df_r

    gc_test = plot_df[gc][cellids]
    gc_test_under_chance = plot_df[gc][under_chance]
    stp_test = plot_df[stp][cellids]
    stp_test_under_chance = plot_df[stp][under_chance]
    ln_test = plot_df[LN][cellids]
    ln_test_under_chance = plot_df[LN][under_chance]
    gc_stp_test = plot_df[combined][cellids]
    gc_stp_test_under_chance = plot_df[combined][under_chance]

    # Row 1 (vs LN)
    fig1, ax = plt.subplots(1, 1)
    ax.scatter(ln_test, gc_test, c=scatter_color, s=20)
    ax.plot(ax.get_xlim(),
            ax.get_ylim(),
            'k--',
            linewidth=2,
            dashes=dash_spacing)
    ax.scatter(gc_test_under_chance,
               ln_test_under_chance,
               c=dropped_cell_color,
               s=20)
    #ax.scatter(gc_test_less_LN, ln_test_less_LN, c='blue', s=1)
    ax.set_title('LN vs GC')
    ax.set_ylabel('GC')
    ax.set_xlabel('LN')

    fig2, ax = plt.subplots(1, 1)
    ax.scatter(ln_test, stp_test, c=scatter_color, s=20)
    ax.plot(ax.get_xlim(),
            ax.get_ylim(),
            'k--',
            linewidth=2,
            dashes=dash_spacing)
    ax.scatter(stp_test_under_chance,
               ln_test_under_chance,
               c=dropped_cell_color,
               s=20)
    #ax.scatter(stp_test_less_LN, ln_test_less_LN, c='blue', s=1)
    ax.set_title('LN vs STP')
    ax.set_ylabel('STP')
    ax.set_xlabel('LN')

    fig3, ax = plt.subplots(1, 1)
    ax.scatter(ln_test, gc_stp_test, c=scatter_color, s=20)
    ax.plot(ax.get_xlim(),
            ax.get_ylim(),
            'k--',
            linewidth=2,
            dashes=dash_spacing)
    ax.scatter(ln_test_under_chance,
               gc_stp_test_under_chance,
               c=dropped_cell_color,
               s=20)
    #ax.scatter(gc_stp_test_less_LN, ln_test_less_LN, c='blue', s=1)
    ax.set_title('LN vs GC + STP')
    ax.set_ylabel('GC + STP')
    ax.set_xlabel('LN')

    # Row 2 (head-to-head)
    fig4, ax = plt.subplots(1, 1)
    ax.scatter(gc_test, stp_test, c=scatter_color, s=20)
    ax.plot(ax.get_xlim(),
            ax.get_ylim(),
            'k--',
            linewidth=2,
            dashes=dash_spacing)
    ax.scatter(gc_test_under_chance,
               stp_test_under_chance,
               c=dropped_cell_color,
               s=20)
    #ax.scatter(gc_test_less_LN, stp_test_less_LN, c='blue', s=20)
    ax.set_title('GC vs STP')
    ax.set_xlabel('GC')
    ax.set_ylabel('STP')

    fig5, ax = plt.subplots(1, 1)
    ax.scatter(gc_test, gc_stp_test, c=scatter_color, s=20)
    ax.plot(ax.get_xlim(),
            ax.get_ylim(),
            'k--',
            linewidth=2,
            dashes=dash_spacing)
    ax.scatter(gc_test_under_chance,
               gc_stp_test_under_chance,
               c=dropped_cell_color,
               s=20)
    #ax.scatter(gc_test_less_LN, gc_stp_test_less_LN, c='blue', s=20)
    ax.set_title('GC vs GC + STP')
    ax.set_xlabel('GC')
    ax.set_ylabel('GC + STP')

    fig6, ax = plt.subplots(1, 1)
    ax.scatter(stp_test, gc_stp_test, c=scatter_color, s=20)
    ax.plot(ax.get_xlim(),
            ax.get_ylim(),
            'k--',
            linewidth=2,
            dashes=dash_spacing)
    ax.scatter(stp_test_under_chance,
               gc_stp_test_under_chance,
               c=dropped_cell_color,
               s=20)
    #ax.scatter(stp_test_less_LN, gc_stp_test_less_LN, c='blue', s=20)
    ax.set_title('STP vs GC + STP')
    ax.set_xlabel('STP')
    ax.set_ylabel('GC + STP')

    #plt.tight_layout()

    return fig1, fig2, fig3, fig4, fig5, fig6
Example #9
def cf_vs_model_performance(batch,
                            gc,
                            stp,
                            LN,
                            combined,
                            cf_load_path=None,
                            cf_kwargs={},
                            se_filter=True,
                            LN_filter=False,
                            plot_stat='r_ceiling',
                            include_LN=False,
                            include_combined=False,
                            only_improvements=False):
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    cellids, under_chance, less_LN = get_filtered_cellids(
        df_r, df_e, gc, stp, LN, combined, se_filter, LN_filter)

    if only_improvements:
        e, a, g, s, c = improved_cells_to_list(batch, gc, stp, LN, combined)
        gc_cells = list((set(cellids) & (set(e) | set(g))) - set(c) - set(s))
        stp_cells = list((set(cellids) & (set(e) | set(s))) - set(c) - set(g))
        n_gc = len(gc_cells)
        n_stp = len(stp_cells)
    else:
        gc_cells = stp_cells = cellids
        n_gc = n_stp = len(cellids)

    if plot_stat == 'r_ceiling':
        plot_df = df_c
    else:
        plot_df = df_r

    if cf_load_path is None:
        df = cf_batch(batch, LN, load_path=cf_load_path, **cf_kwargs)
    else:
        df = pd.read_pickle(cf_load_path)

    gc_cfs = df['cf'][gc_cells].values.astype('float32')
    stp_cfs = df['cf'][stp_cells].values.astype('float32')
    #ln_test = plot_df[LN][cellids].values.astype('float32')
    gc_test = plot_df[gc][gc_cells].values.astype('float32')
    stp_test = plot_df[stp][stp_cells].values.astype('float32')
    #combined_test = plot_df[combined][cellids].values.astype('float32')

    r_gc, p_gc = st.spearmanr(gc_cfs, gc_test)
    r_stp, p_stp = st.spearmanr(stp_cfs, stp_test)

    plt.figure()
    #    if include_LN:
    #        plt.scatter(cfs, ln_test, color='gray', alpha=0.5)
    plt.scatter(gc_cfs, gc_test, color='goldenrod', alpha=0.5)
    plt.scatter(stp_cfs, stp_test, color=wsu_crimson, alpha=0.5)
    #    if include_combined:
    #        plt.scatter(cfs, combined_test, color='purple', alpha=0.5)
    plt.xscale('log', base=np.e)

    title = ("CF vs model performance\n"
             "gc -- rho:  %.4f, p:  %.4E, n:  %d\n"
             "stp -- rho:  %.4f, p:  %.4E, n:  %d" %
             (r_gc, p_gc, n_gc, r_stp, p_stp, n_stp))
    plt.title(title)
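cf_vs_model_performance reduces to a Spearman rank correlation between characteristic frequency and model performance, plotted on a natural-log frequency axis. A standalone sketch on synthetic values; note that base=np.e is the current matplotlib keyword (older releases spelled it basex):

import numpy as np
import scipy.stats as st
import matplotlib.pyplot as plt

rng = np.random.default_rng(2)
cfs = np.exp(rng.uniform(np.log(200), np.log(20000), size=150))    # CF in Hz
r_test = 0.5 + 0.02 * np.log(cfs / 200) + rng.normal(0, 0.1, 150)  # fake scores

rho, p = st.spearmanr(cfs, r_test)
plt.figure()
plt.scatter(cfs, r_test, alpha=0.5)
plt.xscale('log', base=np.e)
plt.title('CF vs model performance\nrho: %.4f, p: %.4E, n: %d'
          % (rho, p, len(cfs)))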
Example #10
def equivalence_histogram(batch,
                          gc,
                          stp,
                          LN,
                          combined,
                          se_filter=True,
                          LN_filter=False,
                          test_limit=None,
                          alpha=0.05,
                          save_path=None,
                          load_path=None,
                          equiv_key='partial_corr',
                          effect_key='performance_effect',
                          self_equiv=False,
                          self_kwargs={},
                          eq_models=[],
                          cross_kwargs={},
                          cross_models=[],
                          use_median=True,
                          exclude_low_snr=False,
                          snr_path=None,
                          adjust_scores=True,
                          use_log_ratios=False):
    '''
    gc: GC model
    stp: STP model
    LN: LN model
    '''
    e, a, g, s, c = improved_cells_to_list(batch, gc, stp, LN, combined)
    _, cellids, _, _, _ = improved_cells_to_list(batch,
                                                 gc,
                                                 stp,
                                                 LN,
                                                 combined,
                                                 as_lists=False)
    improved = c
    not_improved = list(set(a) - set(c))

    if load_path is None:
        df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)

        rs = []
        for c in a[:test_limit]:
            xf1, ctx1 = xhelp.load_model_xform(c, batch, gc)
            xf2, ctx2 = xhelp.load_model_xform(c, batch, stp)
            xf3, ctx3 = xhelp.load_model_xform(c, batch, LN)

            gc_pred = ctx1['val'].apply_mask()['pred'].as_continuous()
            stp_pred = ctx2['val'].apply_mask()['pred'].as_continuous()
            ln_pred = ctx3['val'].apply_mask()['pred'].as_continuous()

            ff = np.isfinite(gc_pred) & np.isfinite(stp_pred) & np.isfinite(
                ln_pred)
            gcff = gc_pred[ff]
            stpff = stp_pred[ff]
            lnff = ln_pred[ff]
            rs.append(np.corrcoef(gcff - lnff, stpff - lnff)[0, 1])

        blank = np.full_like(rs, np.nan)
        results = {
            'cellid': a[:test_limit],
            'equivalence': rs,
            'effect_size': blank,
            'corr_gc_LN': blank,
            'corr_stp_LN': blank
        }
        df = pd.DataFrame.from_dict(results)
        df.set_index('cellid', inplace=True)
        if save_path is not None:
            df.to_pickle(save_path)
    else:
        df = pd.read_pickle(load_path)

    df = df[cellids]
    rs = df[equiv_key].values
    if exclude_low_snr:
        snr_df = pd.read_pickle(snr_path)
        med_snr = snr_df['snr'].median()
        high_snr = snr_df.loc[snr_df['snr'] >= med_snr]
        high_snr_cells = high_snr.index.values.tolist()
        improved = list(set(improved) & set(high_snr_cells))
        not_improved = list(set(not_improved) & set(high_snr_cells))

    imp = np.array(improved)
    not_imp = np.array(not_improved)
    imp_mask = np.isin(a, imp)
    not_mask = np.isin(a, not_imp)
    rs_not = rs[not_mask]
    rs_imp = rs[imp_mask]
    md_not = np.nanmedian(rs_not)
    md_imp = np.nanmedian(rs_imp)
    u, p = st.mannwhitneyu(rs_not, rs_imp, alternative='two-sided')
    n_not = len(not_improved)
    n_imp = len(improved)

    if self_equiv:
        stp1, stp2, gc1, gc2, LN1, LN2 = eq_models
        g1, s2, s1, g2, L1, L2 = cross_models
        _, ga, _, _, _ = improved_cells_to_list(batch,
                                                gc1,
                                                gc2,
                                                LN1,
                                                LN2,
                                                as_lists=True)
        _, sa, _, _, _ = improved_cells_to_list(batch,
                                                stp1,
                                                stp2,
                                                LN1,
                                                LN2,
                                                as_lists=True)
        aa = list(set(ga) & set(sa))
        if exclude_low_snr:
            snr_df = pd.read_pickle(snr_path)
            med_snr = snr_df['snr'].median()
            high_snr = snr_df.loc[snr_df['snr'] >= med_snr]
            high_snr_cells = high_snr.index.values.tolist()
            aa = list(set(aa) & set(high_snr_cells))

        eqs_stp, eqs_gc = _get_self_equivs(**self_kwargs, cellids=aa)
        md_stpeq = np.nanmedian(eqs_stp)
        md_gceq = np.nanmedian(eqs_gc)
        n_eq = eqs_stp.size
        u_gceq, p_gceq = st.mannwhitneyu(eqs_gc,
                                         rs_imp,
                                         alternative='two-sided')
        u_stpeq, p_stpeq = st.mannwhitneyu(eqs_stp,
                                           rs_imp,
                                           alternative='two-sided')

        eqs_x1, eqs_x2 = _get_self_equivs(**cross_kwargs, cellids=aa)
        md_x1 = np.nanmedian(eqs_x1)
        md_x2 = np.nanmedian(eqs_x2)

        sub1 = df.index.isin(ga) & df.index.isin(aa)
        sub2 = df.index.isin(sa) & df.index.isin(aa)
        eqs_sub1 = df[equiv_key][sub1].values
        eqs_sub2 = df[equiv_key][sub2].values
        md_sub1 = np.nanmedian(eqs_sub1)
        md_sub2 = np.nanmedian(eqs_sub2)

        if adjust_scores:
            if use_median:
                md_avg1 = 0.5 * (md_x1 + md_x2)
                md_avg2 = 0.5 * (md_sub1 + md_sub2)
                ratio = md_avg2 / md_avg1
                md_stpeq *= ratio
                md_gceq *= ratio

            else:
                eqs_avg1 = 0.5 * (eqs_x1 + eqs_x2)  # between-model, halved est
                eqs_avg2 = 0.5 * (eqs_sub1 + eqs_sub2
                                  )  # between-model, full data
                ratios = np.abs(eqs_avg2 / eqs_avg1)
                if use_log_ratios:
                    log_ratios = np.abs(np.log(ratios))
                    log_ratios /= log_ratios.max()
                    stp_scaled = eqs_stp + (1 - eqs_stp) * log_ratios
                    gc_scaled = eqs_gc + (1 - eqs_gc) * log_ratios
                else:
                    stp_scaled = eqs_stp * ratios
                    gc_scaled = eqs_gc * ratios
                md_stpeq = np.nanmedian(stp_scaled)
                md_gceq = np.nanmedian(gc_scaled)

    not_color = model_colors['LN']
    imp_color = model_colors['max']
    weights1 = [np.ones(len(rs_not)) / len(rs_not)]
    weights2 = [np.ones(len(rs_imp)) / len(rs_imp)]

    #n_cells = rs.shape[0]
    fig1, (a1, a2) = plt.subplots(2, 1)

    a1.hist(rs_not,
            bins=30,
            range=[-0.5, 1],
            weights=weights1,
            fc=faded_LN,
            edgecolor=dark_LN,
            linewidth=1)
    a2.hist(rs_imp,
            bins=30,
            range=[-0.5, 1],
            weights=weights2,
            fc=faded_max,
            edgecolor=dark_max,
            linewidth=1)

    a1.axes.axvline(md_not,
                    color=dark_LN,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a1.axes.axvline(md_imp,
                    color=dark_max,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a2.axes.axvline(md_not,
                    color=dark_LN,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a2.axes.axvline(md_imp,
                    color=dark_max,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)

    if self_equiv:
        a1.axes.annotate('',
                         xy=(md_gceq, 0),
                         xycoords='data',
                         xytext=(md_gceq, 0.07),
                         textcoords='data',
                         arrowprops=dict(arrowstyle="->",
                                         connectionstyle="arc3"))
        a1.axes.annotate('',
                         xy=(md_stpeq, 0),
                         xycoords='data',
                         xytext=(md_stpeq, 0.07),
                         textcoords='data',
                         arrowprops=dict(arrowstyle="->",
                                         connectionstyle="arc3"))
        a2.axes.annotate('',
                         xy=(md_gceq, 0),
                         xycoords='data',
                         xytext=(md_gceq, 0.07),
                         textcoords='data',
                         arrowprops=dict(arrowstyle="->",
                                         connectionstyle="arc3"))
        a2.axes.annotate('',
                         xy=(md_stpeq, 0),
                         xycoords='data',
                         xytext=(md_stpeq, 0.07),
                         textcoords='data',
                         arrowprops=dict(arrowstyle="->",
                                         connectionstyle="arc3"))

    ymin1, ymax1 = a1.get_ylim()
    ymin2, ymax2 = a2.get_ylim()
    ymax = max(ymax1, ymax2)
    a1.set_ylim(0, ymax)
    a2.set_ylim(0, ymax)

    ax_remove_box(a1)
    ax_remove_box(a2)
    plt.tight_layout()

    if equiv_key == 'equivalence':
        x_text = 'equivalence, CC(GC-LN, STP-LN)\n'
    elif equiv_key == 'partial_corr':
        x_text = 'equivalence, partial correlation\n'
    else:
        x_text = 'unknown equivalence key\n'
    fig3 = plt.figure(figsize=text_fig)
    text2 = ("hist: equivalence of changein prediction relative to LN model\n"
             "batch: %d\n"
             "x: %s"
             "y: cell fraction\n"
             "n not imp:  %d,  md:  %.2f\n"
             "n sig. imp:  %d,  md:  %.2f\n"
             "st.mannwhitneyu:  u:  %.4E p:  %.4E" %
             (batch, x_text, n_not, md_not, n_imp, md_imp, u, p))
    if self_equiv:
        text2 += ("\n\nSelf equivalence, n: %d\n"
                  "stp:  md:  %.2f,  u:  %.4E   p  %.4E\n"
                  "gc:   md:  %.2f,  u:  %.4E   p  %.4E\n"
                  "md sub1: %.2E\n"
                  "md sub2: %.2E\n"
                  "md x1: %.2E\n"
                  "md x2: %.2E" %
                  (n_eq, md_stpeq, u_stpeq, p_stpeq, md_gceq, u_gceq, p_gceq,
                   md_sub1, md_sub2, md_x1, md_x2))
    plt.text(0.1, 0.5, text2)

    return fig1, fig3
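When no precomputed dataframe is supplied, the equivalence score above is the correlation between each nonlinear model's change in prediction relative to the LN prediction, restricted to time bins where all three predictions are finite. A self-contained sketch on synthetic prediction traces:

import numpy as np

rng = np.random.default_rng(3)
ln_pred = rng.normal(size=2000)                    # stand-in for the LN prediction
gc_pred = ln_pred + 0.3 * rng.normal(size=2000)    # GC-style deviation
stp_pred = ln_pred + 0.3 * rng.normal(size=2000)   # STP-style deviation
gc_pred[:10] = np.nan                              # simulate masked-out bins

ff = np.isfinite(gc_pred) & np.isfinite(stp_pred) & np.isfinite(ln_pred)
equivalence = np.corrcoef(gc_pred[ff] - ln_pred[ff],
                          stp_pred[ff] - ln_pred[ff])[0, 1]
print('equivalence: %.3f' % equivalence)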
Example #11
def equivalence_scatter(batch,
                        gc,
                        stp,
                        LN,
                        combined,
                        se_filter=True,
                        LN_filter=False,
                        plot_stat='r_ceiling',
                        enable_hover=False,
                        manual_lims=None,
                        drop_outliers=False,
                        color_improvements=True,
                        xmodel='GC',
                        ymodel='STP',
                        legend=False,
                        self_equiv=False,
                        self_eq_models=[],
                        show_highlights=False,
                        exclude_low_snr=False,
                        snr_path=None):
    '''
    gc: GC model
    stp: STP model
    LN: LN model
    '''

    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    if plot_stat == 'r_ceiling':
        plot_df = df_c
    else:
        plot_df = df_r

    e, a, g, s, c = improved_cells_to_list(batch,
                                           gc,
                                           stp,
                                           LN,
                                           combined,
                                           se_filter=se_filter,
                                           LN_filter=LN_filter)
    improved = c
    not_improved = list(set(a) - set(c))
    if exclude_low_snr:
        snr_df = pd.read_pickle(snr_path)
        med_snr = snr_df['snr'].median()
        high_snr = snr_df.loc[snr_df['snr'] >= med_snr]
        high_snr_cells = high_snr.index.values.tolist()
        improved = list(set(improved) & set(high_snr_cells))
        not_improved = list(set(not_improved) & set(high_snr_cells))
        a = list(set(a) & set(high_snr_cells))
        # check number of animals included
        prefixes = [cellid[:3] for cellid in a]
        n_animals = len(list(set(prefixes)))
        print('n_animals: %d' % n_animals)

    models = [gc, stp, LN]
    gc_rel_imp, stp_rel_imp = _relative_score(plot_df, models, improved)
    gc_rel_not, stp_rel_not = _relative_score(plot_df, models, not_improved)
    gc_rel_all, stp_rel_all = _relative_score(plot_df, models, a)

    # LN, STP, GC
    cells_to_highlight = ['TAR010c-40-1', 'AMT005c-20-1', 'TAR009d-22-1']
    cells_to_plot_gc_rel = []
    cells_to_plot_stp_rel = []
    for c in cells_to_highlight:
        stp_rel = plot_df[stp][c] - plot_df[LN][c]
        gc_rel = plot_df[gc][c] - plot_df[LN][c]
        cells_to_plot_gc_rel.append(gc_rel)
        cells_to_plot_stp_rel.append(stp_rel)

    # compute corr. before dropping outliers (only dropping for visualization)
    r_imp, p_imp = st.pearsonr(gc_rel_imp, stp_rel_imp)
    r_not, p_not = st.pearsonr(gc_rel_not, stp_rel_not)
    r_all, p_all = st.pearsonr(gc_rel_all, stp_rel_all)

    if self_equiv:
        stp1, stp2, gc1, gc2, LN1, LN2 = self_eq_models
        _, ga, _, _, _ = improved_cells_to_list(batch,
                                                gc1,
                                                gc2,
                                                LN1,
                                                LN2,
                                                as_lists=True)
        _, sa, _, _, _ = improved_cells_to_list(batch,
                                                stp1,
                                                stp2,
                                                LN1,
                                                LN2,
                                                as_lists=True)
        aa = list(set(ga) & set(sa))

        if exclude_low_snr:
            snr_df = pd.read_pickle(snr_path)
            med_snr = snr_df['snr'].median()
            high_snr = snr_df.loc[snr_df['snr'] >= med_snr]
            high_snr_cells = high_snr.index.values.tolist()
            aa = list(set(aa) & set(high_snr_cells))

        df_r_eq = nd.batch_comp(batch, [gc1, gc2, stp1, stp2, LN1, LN2],
                                stat=plot_stat)
        df_r_eq.dropna(axis=0, how='any', inplace=True)
        df_r_eq.sort_index(inplace=True)
        df_r_eq = df_r_eq[df_r_eq.index.isin(aa)]

        gc1_rel_imp = df_r_eq[gc1].values - df_r_eq[LN1].values
        gc2_rel_imp = df_r_eq[gc2].values - df_r_eq[LN2].values
        stp1_rel_imp = df_r_eq[stp1].values - df_r_eq[LN1].values
        stp2_rel_imp = df_r_eq[stp2].values - df_r_eq[LN2].values
        r_gceq, p_gceq = st.pearsonr(gc1_rel_imp, gc2_rel_imp)
        r_stpeq, p_stpeq = st.pearsonr(stp1_rel_imp, stp2_rel_imp)
        n_eq = gc1_rel_imp.size

        # compute on same subset for full estimation data
        # to compare to cross-set
        gc_subset1 = df_r[gc][ga].values
        gc_subset2 = df_r[gc][sa].values
        stp_subset1 = df_r[stp][ga].values
        stp_subset2 = df_r[stp][sa].values
        LN_subset1 = df_r[LN][ga].values
        LN_subset2 = df_r[LN][sa].values
        gc_rel1 = gc_subset1 - LN_subset1
        gc_rel2 = gc_subset2 - LN_subset2
        stp_rel1 = stp_subset1 - LN_subset1
        stp_rel2 = stp_subset2 - LN_subset2
        r_sub1, p_sub1 = st.pearsonr(gc_rel1, stp_rel1)
        r_sub2, p_sub2 = st.pearsonr(gc_rel2, stp_rel2)

    gc_rel_imp, stp_rel_imp = drop_common_outliers(gc_rel_imp, stp_rel_imp)
    gc_rel_not, stp_rel_not = drop_common_outliers(gc_rel_not, stp_rel_not)
    gc_rel_all, stp_rel_all = drop_common_outliers(gc_rel_all, stp_rel_all)

    n_imp = len(improved)
    n_not = len(not_improved)
    n_all = len(a)

    y_max = np.max(stp_rel_all)
    y_min = np.min(stp_rel_all)
    x_max = np.max(gc_rel_all)
    x_min = np.min(gc_rel_all)

    fig = plt.figure()
    ax = plt.gca()
    ax.axes.axhline(0,
                    color='black',
                    linewidth=1,
                    linestyle='dashed',
                    dashes=dash_spacing)
    ax.axes.axvline(0,
                    color='black',
                    linewidth=1,
                    linestyle='dashed',
                    dashes=dash_spacing)
    if color_improvements:
        ax.scatter(gc_rel_not,
                   stp_rel_not,
                   c=model_colors['LN'],
                   s=small_scatter)
        ax.scatter(gc_rel_imp,
                   stp_rel_imp,
                   c=model_colors['max'],
                   s=big_scatter)
        if show_highlights:
            for i, (g, s) in enumerate(
                    zip(cells_to_plot_gc_rel, cells_to_plot_stp_rel)):
                plt.text(g, s, str(i + 1), fontsize=12, color='black')
    else:
        ax.scatter(gc_rel_all, stp_rel_all, c='black', s=big_scatter)

    if legend:
        plt.legend()

    if manual_lims is not None:
        ax.set_ylim(*manual_lims)
        ax.set_xlim(*manual_lims)
    else:
        upper = max(y_max, x_max)
        lower = min(y_min, x_min)
        upper_lim = np.ceil(10 * upper) / 10
        lower_lim = np.floor(10 * lower) / 10
        ax.set_ylim(lower_lim, upper_lim)
        ax.set_xlim(lower_lim, upper_lim)
    ax.set_aspect('equal')

    if enable_hover:
        mplcursors.cursor(ax, hover=True).connect(
            "add", lambda sel: sel.annotation.set_text(a[sel.target.index]))

    plt.tight_layout()

    fig2 = plt.figure(figsize=text_fig)
    text = ("Performance Improvements over LN\n"
            "batch: %d\n"
            "dropped outliers?:  %s\n"
            "all cells:  r: %.2E, p: %.2E, n: %d\n"
            "improved:  r: %.2E, p: %.2E, n: %d\n"
            "not imp:  r: %.2E, p: %.2E, n: %d\n"
            "x: %s - LN\n"
            "y: %s - LN" % (batch, drop_outliers, r_all, p_all, n_all, r_imp,
                            p_imp, n_imp, r_not, p_not, n_not, xmodel, ymodel))
    if self_equiv:
        text += ("\n\nSelf Equivalence, n:  %d\n"
                 "gc,  r: %.2Ef, p: %.2E\n"
                 "stp, r: %.2Ef, p: %.2E\n"
                 "sub1, r: %.2E, p: %.2E\n"
                 "sub2, r: %.2E, p: %.2E\n" %
                 (n_eq, r_gceq, p_gceq, r_stpeq, p_stpeq, r_sub1, p_sub1,
                  r_sub2, p_sub2))
    plt.text(0.1, 0.5, text)
    ax_remove_box(ax)

    return fig, fig2
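equivalence_scatter plots each model's improvement over LN on the two axes and reports Pearson correlations computed before outliers are dropped for display. A small sketch of that relative-score correlation on synthetic scores:

import numpy as np
import scipy.stats as st

rng = np.random.default_rng(4)
ln = rng.uniform(0.3, 0.7, size=300)               # per-cell LN scores
gc = ln + rng.normal(0.02, 0.05, size=300)         # per-cell GC scores
stp = ln + rng.normal(0.02, 0.05, size=300)        # per-cell STP scores

gc_rel = gc - ln                                   # GC improvement over LN
stp_rel = stp - ln                                 # STP improvement over LN
r, p = st.pearsonr(gc_rel, stp_rel)
print('r: %.2E, p: %.2E, n: %d' % (r, p, gc_rel.size))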
Example #12
def gain_by_contrast_slopes(batch,
                            gc,
                            stp,
                            LN,
                            combined,
                            se_filter=True,
                            good_LN=0,
                            bins=30,
                            use_exp=True):

    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    #cellids = df_r[LN] > good_LN
    cellids = df_r[LN] > df_e[LN] * 2
    gc_LN_SE = (df_e[gc] + df_e[LN])
    #    stp_LN_SE = (df_e[stp] + df_e[LN])
    gc_cells = (cellids) & ((df_r[gc] - df_r[LN]) > gc_LN_SE)
    #    stp_cells = (df_r[LN] > good_LN) & ((df_r[stp] - df_r[LN]) > stp_LN_SE)
    #    both_cells = gc_cells & stp_cells
    #    gc_cells = gc_cells & np.logical_not(both_cells)
    #    stp_cells = stp_cells & np.logical_not(both_cells)
    LN_cells = cellids & np.logical_not(gc_cells)  # | stp_cells | both_cells)
    meta = ['r_test', 'ctmax_val', 'ctmax_est', 'ctmin_val', 'ctmin_est']
    gc_params = fitted_params_per_batch(batch, gc, stats_keys=[], meta=meta)
    # drop cellids that haven't been fit for all models
    gc_params_cells = gc_params.transpose().index.values.tolist()
    for c in gc_params_cells:
        if c not in LN_cells:
            LN_cells[c] = False
        if c not in gc_cells:
            gc_cells[c] = False
#        if c not in stp_cells:
#            stp_cells[c] = False
#        if c not in both_cells:
#            both_cells[c] = False

    # index keys are formatted like "4--dsig.d--kappa"
    mod_keys = gc.split('_')[1]
    for i, k in enumerate(mod_keys.split('-')):
        if 'dsig' in k:
            break
    k_key = f'{i}--{k}--kappa'
    ka_key = k_key + '_mod'
    meta_keys = ['meta--' + k for k in meta]
    all_keys = [k_key, ka_key] + meta_keys
    phi_dfs = [
        gc_params[gc_params.index == k].transpose()[LN_cells].transpose()
        for k in all_keys
    ]
    sep_dfs = [df.values.flatten().astype(np.float64) for df in phi_dfs]
    gc_dfs = [
        gc_params[gc_params.index == k].transpose()[gc_cells].transpose()
        for k in all_keys
    ]
    gc_sep_dfs = [df.values.flatten().astype(np.float64) for df in gc_dfs]
    #    stp_dfs = [gc_params[gc_params.index==k].transpose()[stp_cells].transpose()
    #               for k in all_keys]
    #    stp_sep_dfs = [df.values.flatten().astype(np.float64) for df in stp_dfs]
    #    both_dfs = [gc_params[gc_params.index==k].transpose()[both_cells].transpose()
    #               for k in all_keys]
    #    both_sep_dfs = [df.values.flatten().astype(np.float64) for df in both_dfs]
    low, high, r_test, ctmax_val, ctmax_est, ctmin_val, ctmin_est = sep_dfs
    gc_low, gc_high, gc_r, gc_ctmax_val, \
        gc_ctmax_est, gc_ctmin_val, gc_ctmin_est = gc_sep_dfs
    #    stp_low, stp_high, stp_r, stp_ctmax_val, \
    #        stp_ctmax_est, stp_ctmin_val, stp_ctmin_est = stp_sep_dfs
    #    both_low, both_high, both_r, both_ctmax_val, \
    #        both_ctmax_est, both_ctmin_val, both_ctmin_est = both_sep_dfs

    ctmax = np.maximum(ctmax_val, ctmax_est)
    gc_ctmax = np.maximum(gc_ctmax_val, gc_ctmax_est)
    ctmin = np.minimum(ctmin_val, ctmin_est)
    gc_ctmin = np.minimum(gc_ctmin_val, gc_ctmin_est)
    #    stp_ctmax = np.maximum(stp_ctmax_val, stp_ctmax_est)
    #    stp_ctmin = np.minimum(stp_ctmin_val, stp_ctmin_est)
    #    both_ctmax = np.maximum(both_ctmax_val, both_ctmax_est)
    #    both_ctmin = np.minimum(both_ctmin_val, both_ctmin_est)
    ct_range = ctmax - ctmin
    gc_ct_range = gc_ctmax - gc_ctmin
    #    stp_ct_range = stp_ctmax - stp_ctmin
    #    both_ct_range = both_ctmax - both_ctmin
    gain = (high - low) * ct_range
    gc_gain = (gc_high - gc_low) * gc_ct_range
    # test whether gc gain slopes differ from those of LN cells
    # (two-sided Mann-Whitney U)
    gc_LN_p = st.mannwhitneyu(gc_gain, gain, alternative='two-sided')[1]
    med_gain = np.median(gain)
    gc_med_gain = np.median(gc_gain)
    #    stp_gain = (stp_high - stp_low)*stp_ct_range
    #    both_gain = (both_high - both_low)*both_ct_range

    k_low = low + (high - low) * ctmin
    k_high = low + (high - low) * ctmax
    gc_k_low = gc_low + (gc_high - gc_low) * gc_ctmin
    gc_k_high = gc_low + (gc_high - gc_low) * gc_ctmax
    #    stp_k_low = stp_low + (stp_high - stp_low)*stp_ctmin
    #    stp_k_high = stp_low + (stp_high - stp_low)*stp_ctmax
    #    both_k_low = both_low + (both_high - both_low)*both_ctmin
    #    both_k_high = both_low + (both_high - both_low)*both_ctmax

    if use_exp:
        k_low = np.exp(k_low)
        k_high = np.exp(k_high)
        gc_k_low = np.exp(gc_k_low)
        gc_k_high = np.exp(gc_k_high)
#        stp_k_low = np.exp(stp_k_low)
#        stp_k_high = np.exp(stp_k_high)
#        both_k_low = np.exp(both_k_low)
#        both_k_high = np.exp(both_k_high)

#    fig = plt.figure()#, axes = plt.subplots(1, 2, )
#    #axes[0].plot([ctmin, ctmax], [k_low, k_high], color='black', alpha=0.5)
#    plt.hist(high-low, bins=bins, color='black', alpha=0.5)
#
#    #axes[0].plot([gc_ctmin, gc_ctmax], [gc_k_low, gc_k_high], color='red',
#    #              alpha=0.3)
#    plt.hist(gc_high-gc_low, bins=bins, color='red', alpha=0.3)
#
#    #axes[0].plot([stp_ctmin, stp_ctmax], [stp_k_low, stp_k_high], color='blue',
#    #              alpha=0.3)
#    plt.hist(stp_high-stp_low, bins=bins, color='blue', alpha=0.3)
#    plt.xlabel('gain slope')
#    plt.ylabel('count')
#    plt.title(f'raw counts, LN > {good_LN}')
#    plt.legend([f'LN, {len(low)}', f'gc, {len(gc_low)}', f'stp, {len(stp_low)}',
#                f'Both, {len(both_low)}'])

    smallest_slope = min(np.min(gain), np.min(gc_gain))  #, np.min(stp_gain),
    #np.min(both_gain))
    largest_slope = max(np.max(gain), np.max(gc_gain))  #, np.max(stp_gain),
    #np.max(both_gain))
    slope_range = (smallest_slope, largest_slope)
    bins = np.linspace(smallest_slope, largest_slope, bins)
    bar_width = bins[1] - bins[0]
    axis_locs = bins[:-1]
    hist = np.histogram(gain, bins=bins, range=slope_range)
    gc_hist = np.histogram(gc_gain, bins=bins, range=slope_range)
    #    stp_hist = np.histogram(stp_gain, bins=bins, range=slope_range)
    #    both_hist = np.histogram(both_gain, bins=bins, range=slope_range)
    raw = hist[0]
    gc_raw = gc_hist[0]
    #    stp_raw = stp_hist[0]
    #    both_raw = both_hist[0]
    #prop_hist = hist[0] / np.sum(hist[0])
    #prop_gc_hist = gc_hist[0] / np.sum(gc_hist[0])
    #    prop_stp_hist = stp_hist[0] / np.sum(stp_hist[0])
    #    prop_both_hist = both_hist[0] / np.sum(both_hist[0])

    fig1 = plt.figure()
    plt.bar(axis_locs, raw, width=bar_width, color='gray', alpha=0.8)
    plt.bar(axis_locs,
            gc_raw,
            width=bar_width,
            color='maroon',
            alpha=0.8,
            bottom=raw)
    #    plt.bar(axis_locs, stp_raw, width=bar_width, color='teal', alpha=0.8,
    #            bottom=raw+gc_raw)
    #    plt.bar(axis_locs, both_raw, width=bar_width, color='goldenrod', alpha=0.8,
    #            bottom=raw+gc_raw+stp_raw)
    plt.xlabel('gain slope')
    plt.ylabel('count')
    plt.title(f'raw counts, LN > {good_LN}')
    plt.legend([
        f'LN, {len(low)}, md={med_gain:.4f}',
        f'gc, {len(gc_low)}, md={gc_med_gain:.4f}, p={gc_LN_p:.4f}'
    ])
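The final figure of gain_by_contrast_slopes is a stacked count histogram: both groups of gain slopes are binned on a shared grid with np.histogram, and the second group is drawn on top of the first via the bottom argument of plt.bar. A minimal sketch of that stacking on synthetic slopes:

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(5)
gain = rng.normal(0.0, 0.3, size=200)        # LN-only cells
gc_gain = rng.normal(-0.2, 0.3, size=60)     # GC-improved cells

bins = np.linspace(min(gain.min(), gc_gain.min()),
                   max(gain.max(), gc_gain.max()), 30)
width = bins[1] - bins[0]
raw, _ = np.histogram(gain, bins=bins)
gc_raw, _ = np.histogram(gc_gain, bins=bins)

plt.figure()
plt.bar(bins[:-1], raw, width=width, color='gray', alpha=0.8)
plt.bar(bins[:-1], gc_raw, width=width, color='maroon', alpha=0.8, bottom=raw)
plt.xlabel('gain slope')
plt.ylabel('count')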
Example #13
def gd_ratio(batch,
             gc,
             stp,
             LN,
             combined,
             se_filter=True,
             good_LN=0,
             bins=30,
             use_exp=True):
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    #cellids = df_r[LN] > good_LN
    cellids = df_r[LN] > df_e[LN] * 2
    gc_LN_SE = (df_e[gc] + df_e[LN])
    #stp_LN_SE = (df_e[stp] + df_e[LN])
    gc_cells = cellids & ((df_r[gc] - df_r[LN]) > gc_LN_SE)
    #stp_cells = (df_r[LN] > good_LN) & ((df_r[stp] - df_r[LN]) > stp_LN_SE)
    #both_cells = gc_cells & stp_cells
    LN_cells = cellids & np.logical_not(gc_cells)
    #stp_cells = stp_cells & np.logical_not(both_cells)
    meta = ['r_test', 'ctmax_val', 'ctmax_est', 'ctmin_val', 'ctmin_est']
    gc_params = fitted_params_per_batch(batch, gc, stats_keys=[], meta=meta)
    # drop cellids that haven't been fit for all models
    gc_params_cells = gc_params.transpose().index.values.tolist()
    for c in gc_params_cells:
        if c not in LN_cells:
            LN_cells[c] = False
        if c not in gc_cells:
            gc_cells[c] = False
#        if c not in stp_cells:
#            stp_cells[c] = False
#        if c not in both_cells:
#            both_cells[c] = False

    # index keys are formatted like "4--dsig.d--kappa"
    mod_keys = gc.split('_')[1]
    for i, k in enumerate(mod_keys.split('-')):
        if 'dsig' in k:
            break
    k_key = f'{i}--{k}--kappa'
    ka_key = k_key + '_mod'
    meta_keys = ['meta--' + k for k in meta]
    all_keys = [k_key, ka_key] + meta_keys
    phi_dfs = [
        gc_params[gc_params.index == k].transpose()[LN_cells].transpose()
        for k in all_keys
    ]
    sep_dfs = [df.values.flatten().astype(np.float64) for df in phi_dfs]
    gc_dfs = [
        gc_params[gc_params.index == k].transpose()[gc_cells].transpose()
        for k in all_keys
    ]
    gc_sep_dfs = [df.values.flatten().astype(np.float64) for df in gc_dfs]
    #    stp_dfs = [gc_params[gc_params.index==k].transpose()[stp_cells].transpose()
    #               for k in all_keys]
    #    stp_sep_dfs = [df.values.flatten().astype(np.float64) for df in stp_dfs]
    #    both_dfs = [gc_params[gc_params.index==k].transpose()[both_cells].transpose()
    #               for k in all_keys]
    #    both_sep_dfs = [df.values.flatten().astype(np.float64) for df in both_dfs]
    low, high, r_test, ctmax_val, ctmax_est, ctmin_val, ctmin_est = sep_dfs
    gc_low, gc_high, gc_r, gc_ctmax_val, \
        gc_ctmax_est, gc_ctmin_val, gc_ctmin_est = gc_sep_dfs
    #    stp_low, stp_high, stp_r, stp_ctmax_val, \
    #        stp_ctmax_est, stp_ctmin_val, stp_ctmin_est = stp_sep_dfs
    #    both_low, both_high, both_r, both_ctmax_val, \
    #        both_ctmax_est, both_ctmin_val, both_ctmin_est = both_sep_dfs

    ctmax = np.maximum(ctmax_val, ctmax_est)
    gc_ctmax = np.maximum(gc_ctmax_val, gc_ctmax_est)
    ctmin = np.minimum(ctmin_val, ctmin_est)
    gc_ctmin = np.minimum(gc_ctmin_val, gc_ctmin_est)
    #    stp_ctmax = np.maximum(stp_ctmax_val, stp_ctmax_est)
    #    stp_ctmin = np.minimum(stp_ctmin_val, stp_ctmin_est)
    #    both_ctmax = np.maximum(both_ctmax_val, both_ctmax_est)
    #    both_ctmin = np.minimum(both_ctmin_val, both_ctmin_est)

    k_low = low + (high - low) * ctmin
    k_high = low + (high - low) * ctmax
    gc_k_low = gc_low + (gc_high - gc_low) * gc_ctmin
    gc_k_high = gc_low + (gc_high - gc_low) * gc_ctmax
    #    stp_k_low = stp_low + (stp_high - stp_low)*stp_ctmin
    #    stp_k_high = stp_low + (stp_high - stp_low)*stp_ctmax
    #    both_k_low = both_low + (both_high - both_low)*both_ctmin
    #    both_k_high = both_low + (both_high - both_low)*both_ctmax

    if use_exp:
        k_low = np.exp(k_low)
        k_high = np.exp(k_high)
        gc_k_low = np.exp(gc_k_low)
        gc_k_high = np.exp(gc_k_high)


#        stp_k_low = np.exp(stp_k_low)
#        stp_k_high = np.exp(stp_k_high)
#        both_k_low = np.exp(both_k_low)
#        both_k_high = np.exp(both_k_high)

    ratio = k_low / k_high
    gc_ratio = gc_k_low / gc_k_high
    #    stp_ratio = stp_k_low / stp_k_high
    #    both_ratio = both_k_low / both_k_high

    fig1, (ax1, ax2) = plt.subplots(1, 2)
    ax1.hist(ratio, bins=bins)
    ax1.set_title('all cells')
    ax2.hist(gc_ratio, bins=bins)
    ax2.set_title('gc')
    #    ax3.hist(stp_ratio, bins=bins)
    #    ax3.set_title('stp')
    if not use_exp:
        title = 'k_low / k_high'
    else:
        title = 'e^(k_low - k_high)'
    fig1.suptitle(title)

    fig3 = plt.figure()
    plt.scatter(ratio, r_test)
    plt.title('low/high vs r_test')

    fig4 = plt.figure()
    plt.scatter(gc_ratio, gc_r)
    plt.title('low/high vs r_test, gc improvements only')
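A quick check of the identity behind the use_exp branch of gd_ratio: exponentiating k_low and k_high before taking their ratio is the same as e^(k_low - k_high), which is why the figure title changes in that case.

import numpy as np

k_low = np.array([0.2, -0.1, 0.7])
k_high = np.array([0.5, 0.3, 0.4])
assert np.allclose(np.exp(k_low) / np.exp(k_high), np.exp(k_low - k_high))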
Example #14
def stp_distributions(batch,
                      gc,
                      stp,
                      LN,
                      combined,
                      se_filter=True,
                      good_ln=0,
                      log_scale=False,
                      legend=False,
                      use_combined=False):

    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    cellids, under_chance, less_LN = get_filtered_cellids(batch,
                                                          gc,
                                                          stp,
                                                          LN,
                                                          combined,
                                                          as_lists=False)
    _, _, _, _, c = improved_cells_to_list(batch,
                                           gc,
                                           stp,
                                           LN,
                                           combined,
                                           good_ln=good_ln)

    if use_combined:
        params_model = combined
    else:
        params_model = stp
    stp_params = fitted_params_per_batch(batch,
                                         params_model,
                                         stats_keys=[],
                                         meta=[])
    stp_params_cells = stp_params.transpose().index.values.tolist()
    for cell in stp_params_cells:
        if cell not in cellids:
            cellids[cell] = False
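    # cells that have an STP fit and pass the filters but are not in the
    # improved set c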
    not_c = list(set(stp_params.transpose()[cellids].index.values) - set(c))

    # index keys are formatted like "2--stp.2--tau"
    mod_keys = params_model.split('_')[1]
    for i, k in enumerate(mod_keys.split('-')):
        if 'stp' in k:
            break
    tau_key = '%d--%s--tau' % (i, k)
    u_key = '%d--%s--u' % (i, k)

    all_taus = stp_params[stp_params.index ==
                          tau_key].transpose()[cellids].transpose()
    all_us = stp_params[stp_params.index ==
                        u_key].transpose()[cellids].transpose()
    dims = all_taus.values.flatten()[0].shape[0]

    # convert to a dims x cells array instead of a per-cell array of
    # multidimensional values
    #sep_taus = _df_to_array(all_taus, dims).mean(axis=0)
    #sep_us = _df_to_array(all_us, dims).mean(axis=0)
    #med_tau = np.median(sep_taus)
    #med_u = np.median(sep_u)
    sep_taus = _df_to_array(all_taus[not_c], dims).mean(axis=0)
    sep_us = _df_to_array(all_us[not_c], dims).mean(axis=0)
    med_tau = np.median(sep_taus)
    med_u = np.median(sep_us)

    stp_taus = all_taus[c]
    stp_us = all_us[c]
    stp_sep_taus = _df_to_array(stp_taus, dims).mean(axis=0)
    stp_sep_us = _df_to_array(stp_us, dims).mean(axis=0)

    stp_med_tau = np.median(stp_sep_taus)
    stp_med_u = np.median(stp_sep_us)
    #tau_t, tau_p = st.ttest_ind(sep_taus, stp_sep_taus)
    #u_t, u_p = st.ttest_ind(sep_us, stp_sep_us)

    # NOTE: these are Mann-Whitney U statistics now, not t statistics;
    #       the variable names are unchanged in case this gets reverted
    tau_t, tau_p = st.mannwhitneyu(sep_taus,
                                   stp_sep_taus,
                                   alternative='two-sided')
    u_t, u_p = st.mannwhitneyu(sep_us, stp_sep_us, alternative='two-sided')

    sep_taus, sep_us = drop_common_outliers(sep_taus, sep_us)
    stp_sep_taus, stp_sep_us = drop_common_outliers(stp_sep_taus, stp_sep_us)
    not_imp_outliers = len(sep_taus)
    imp_outliers = len(stp_sep_taus)

    fig1, (a1, a2) = plt.subplots(2, 1, sharex=True, sharey=True)
    color = model_colors['LN']
    imp_color = model_colors['max']
    stp_label = 'STP ++ (%d)' % len(c)
    total_cells = len(c) + len(not_c)
    bin_count = 30
    hist_kwargs = {'linewidth': 1, 'label': ['not imp', 'stp imp']}

    plt.sca(a1)
    weights1 = [np.ones(len(sep_taus)) / len(sep_taus)]
    weights2 = [np.ones(len(stp_sep_taus)) / len(stp_sep_taus)]
    upper = max(sep_taus.max(), stp_sep_taus.max())
    lower = min(sep_taus.min(), stp_sep_taus.min())
    bins = np.linspace(lower, upper, bin_count + 1)
    #    if log_scale:
    #        lower_bound = min(sep_taus.min(), stp_sep_taus.min())
    #        upper_bound = max(sep_taus.max(), stp_sep_taus.max())
    #        bins = np.logspace(lower_bound, upper_bound, bin_count+1)
    #        hist_kwargs['bins'] = bins
    #    plt.hist([sep_taus, stp_sep_taus], weights=weights, **hist_kwargs)
    a1.hist(sep_taus,
            weights=weights1,
            fc=faded_LN,
            edgecolor=dark_LN,
            bins=bins,
            **hist_kwargs)
    a2.hist(stp_sep_taus,
            weights=weights2,
            fc=faded_max,
            edgecolor=dark_max,
            bins=bins,
            **hist_kwargs)
    a1.axes.axvline(med_tau,
                    color=dark_LN,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a1.axes.axvline(stp_med_tau,
                    color=dark_max,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a2.axes.axvline(med_tau,
                    color=dark_LN,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a2.axes.axvline(stp_med_tau,
                    color=dark_max,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    ax_remove_box(a1)
    ax_remove_box(a2)

    #plt.title('tau,  sig diff?:  p=%.4E' % tau_p)
    #plt.xlabel('tau (ms)')

    fig2 = plt.figure(figsize=text_fig)
    text = ("tau distributions, n: %d\n"
            "n stp imp (bot): %d, med: %.4f\n"
            "n not imp (top): %d, med: %.4f\n"
            "yaxes: fraction of cells\n"
            "xaxis: tau(ms)\n"
            "st.mannwhitneyu u: %.4E,\np: %.4E\n"
            "not imp after outliers: %d\n"
            "imp after outliers: %d\n" %
            (total_cells, len(c), stp_med_tau, len(not_c), med_tau, tau_t,
             tau_p, not_imp_outliers, imp_outliers))
    plt.text(0.1, 0.5, text)

    fig3, (a3, a4) = plt.subplots(2, 1, sharex=True, sharey=True)
    weights3 = [np.ones(len(sep_us)) / len(sep_us)]
    weights4 = [np.ones(len(stp_sep_us)) / len(stp_sep_us)]
    upper = max(sep_us.max(), stp_sep_us.max())
    lower = min(sep_us.min(), stp_sep_us.min())
    bins = np.linspace(lower, upper, bin_count + 1)
    #    if log_scale:
    #        lower_bound = min(sep_us.min(), stp_sep_us.min())
    #        upper_bound = max(sep_us.max(), stp_sep_us.max())
    #        bins = np.logspace(lower_bound, upper_bound, bin_count+1)
    #        hist_kwargs['bins'] = bins
    #    plt.hist([sep_us, stp_sep_us], weights=weights, **hist_kwargs)
    a3.hist(sep_us,
            weights=weights3,
            fc=faded_LN,
            edgecolor=dark_LN,
            bins=bins,
            **hist_kwargs)
    a4.hist(stp_sep_us,
            weights=weights4,
            fc=faded_max,
            edgecolor=dark_max,
            bins=bins,
            **hist_kwargs)
    a3.axes.axvline(med_u,
                    color=dark_LN,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a3.axes.axvline(stp_med_u,
                    color=dark_max,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a4.axes.axvline(med_u,
                    color=dark_LN,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a4.axes.axvline(stp_med_u,
                    color=dark_max,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    ax_remove_box(a3)
    ax_remove_box(a4)
    #plt.title('u,  sig diff?:  p=%.4E' % u_p)
    #plt.xlabel('u (fractional change in gain \nper unit of stimulus amplitude)')
    #plt.ylabel('proportion within group')

    fig4 = plt.figure(figsize=text_fig)
    text = ("u distributions, n: %d\n"
            "n stp imp (bot): %d, med: %.4f\n"
            "n not imp (top): %d, med: %.4f\n"
            "yaxes: fraction of cells\n"
            "xaxis: u(fractional change in gain per unit stimulus amplitude)\n"
            "st.mannwhitneyu u: %.4E,\np: %.4E" %
            (total_cells, len(c), stp_med_u, len(not_c), med_u, u_t, u_p))
    plt.text(0.1, 0.5, text)

    stp_mag, stp_yin, stp_out = stp_magnitude(np.array([[stp_med_tau]]),
                                              np.array([[stp_med_u]]))
    mag, yin, out = stp_magnitude(np.array([[med_tau]]), np.array([[med_u]]))
    fig5 = plt.figure(figsize=short_fig)
    plt.plot(stp_out.as_continuous().flatten(),
             color=imp_color,
             label='STP ++')
    plt.plot(out.as_continuous().flatten(), color=color)
    if legend:
        plt.legend()
    ax_remove_box()

    return fig1, fig2, fig3, fig4, fig5
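
# Illustrative sketch (assumed toy data): the distribution comparison used in
# stp_distributions(), i.e. per-group-normalized histograms over shared bins
# plus a two-sided Mann-Whitney U test. Figure styling here is generic, not
# the styling used above.
import numpy as np
import scipy.stats as st
import matplotlib.pyplot as plt


def _compare_distributions_sketch(group_a, group_b, bin_count=30):
    u, p = st.mannwhitneyu(group_a, group_b, alternative='two-sided')
    lower = min(group_a.min(), group_b.min())
    upper = max(group_a.max(), group_b.max())
    bins = np.linspace(lower, upper, bin_count + 1)
    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, sharey=True)
    # weights normalize each histogram to the fraction of cells per bin
    ax1.hist(group_a, bins=bins, weights=np.ones(len(group_a)) / len(group_a))
    ax2.hist(group_b, bins=bins, weights=np.ones(len(group_b)) / len(group_b))
    ax1.axvline(np.median(group_a), linestyle='dashed')
    ax2.axvline(np.median(group_b), linestyle='dashed')
    fig.suptitle('Mann-Whitney U: %.3f, p: %.3E' % (u, p))
    return fig
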
Example #15
0
def gc_distributions(batch,
                     gc,
                     stp,
                     LN,
                     combined,
                     se_filter=True,
                     good_ln=0,
                     use_combined=False):
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    cellids, under_chance, less_LN = get_filtered_cellids(batch,
                                                          gc,
                                                          stp,
                                                          LN,
                                                          combined,
                                                          as_lists=False)
    _, _, _, _, c = improved_cells_to_list(batch,
                                           gc,
                                           stp,
                                           LN,
                                           combined,
                                           good_ln=good_ln)

    if use_combined:
        params_model = combined
    else:
        params_model = gc
    gc_params = fitted_params_per_batch(batch,
                                        params_model,
                                        stats_keys=[],
                                        meta=[])
    gc_params_cells = gc_params.transpose().index.values.tolist()
    for cell in gc_params_cells:
        if cell not in cellids:
            cellids[cell] = False
    not_c = list(set(gc_params.transpose()[cellids].index.values) - set(c))

    # index keys are formatted like "4--dsig.d--kappa"
    mod_keys = params_model.split('_')[1]
    for i, k in enumerate(mod_keys.split('-')):
        if 'dsig' in k:
            break
    b_key = f'{i}--{k}--base'
    a_key = f'{i}--{k}--amplitude'
    s_key = f'{i}--{k}--shift'
    k_key = f'{i}--{k}--kappa'
    ka_key = k_key + '_mod'
    ba_key = b_key + '_mod'
    aa_key = a_key + '_mod'
    sa_key = s_key + '_mod'
    all_keys = [b_key, a_key, s_key, k_key, ba_key, aa_key, sa_key, ka_key]

    phi_dfs = [
        gc_params[gc_params.index == k].transpose()[cellids].transpose()
        for k in all_keys
    ]
    sep_dfs = [df[not_c].values.flatten().astype(np.float64) for df in phi_dfs]
    gc_sep_dfs = [df[c].values.flatten().astype(np.float64) for df in phi_dfs]

    # remove extreme outliers: one or two cells kept showing values that were
    # multiple orders of magnitude different from all the others
    #    diffs = [sep_dfs[i+1] - sep_dfs[i]
    #             for i, _ in enumerate(sep_dfs[:-1])
    #             if i % 2 == 0]
    #diffs = sep_dfs[1::2] - sep_dfs[::2]

    #    gc_diffs = [gc_sep_dfs[i+1] - gc_sep_dfs[i]
    #                for i, _ in enumerate(gc_sep_dfs[:-1])
    #                if i % 2 == 0]
    #gc_diffs = gc_sep_dfs[1::2] - gc_sep_dfs[::2]

    raw_low, raw_high = sep_dfs[:4], sep_dfs[4:]
    diffs = [high - low for low, high in zip(raw_low, raw_high)]
    medians = [np.median(d) for d in diffs]
    medians_low = [np.median(d) for d in raw_low]
    medians_high = [np.median(d) for d in raw_high]

    gc_raw_low, gc_raw_high = gc_sep_dfs[:4], gc_sep_dfs[4:]
    gc_diffs = [high - low for low, high in zip(gc_raw_low, gc_raw_high)]

    gc_medians = [np.median(d) for d in gc_diffs]
    gc_medians_low = [np.median(d) for d in gc_raw_low]
    gc_medians_high = [np.median(d) for d in gc_raw_high]

    ts, ps = zip(*[
        st.mannwhitneyu(diff, gc_diff, alternative='two-sided')
        for diff, gc_diff in zip(diffs, gc_diffs)
    ])

    diffs = drop_common_outliers(*diffs)
    gc_diffs = drop_common_outliers(*gc_diffs)
    not_imp_outliers = len(diffs[0])
    imp_outliers = len(gc_diffs[0])

    color = model_colors['LN']
    c_color = model_colors['max']
    gc_label = 'GC ++ (%d)' % len(c)
    total_cells = len(c) + len(not_c)
    hist_kwargs = {'label': ['no imp', 'sig imp'], 'linewidth': 1}

    figs = []
    for i, name in zip([0, 1, 2, 3], ['base', 'amplitude', 'shift', 'kappa']):
        f1 = _stacked_hists(diffs[i],
                            gc_diffs[i],
                            medians[i],
                            gc_medians[i],
                            color,
                            c_color,
                            hist_kwargs=hist_kwargs)
        f2 = plt.figure(figsize=text_fig)
        text = ("%s distributions, n: %d\n"
                "n gc imp (bot): %d, med: %.4f\n"
                "n not imp (top): %d, med: %.4f\n"
                "yaxes: fraction of cells\n"
                "xaxis: 'fractional change in parameter per unit contrast'\n"
                "st.mannwhitneyu u: %.4E,\np: %.4E\n"
                "not imp w/o outliers: %d\n"
                "imp w/o outliers: %d" %
                (name, total_cells, len(c), gc_medians[i], len(not_c),
                 medians[i], ts[i], ps[i], not_imp_outliers, imp_outliers))
        plt.text(0.1, 0.5, text)
        figs.append(f1)
        figs.append(f2)

    f3 = plt.figure(figsize=small_fig)
    # median gc effect plots
    yin1, out1 = gc_dummy_sigmoid(*medians_low, low=0.0, high=0.3)
    yin2, out2 = gc_dummy_sigmoid(*medians_high, low=0.0, high=0.3)
    plt.scatter(yin1, out1, color=color, s=big_scatter, alpha=0.3)
    plt.scatter(yin2, out2, color=color, s=big_scatter * 2)
    figs.append(f3)
    plt.tight_layout()
    ax_remove_box()

    f3a = plt.figure(figsize=text_fig)
    text = ("non improved cells\n"
            "median low contrast:\n"
            "base:  %.4f,   amplitude:  %.4f\n"
            "shift:  %.4f,   kappa:  %.4f\n"
            "median high contrast:\n"
            "base:  %.4f,   amplitude:  %.4f\n"
            "shift:  %.4f,   kappa:  %.4f\n" % (*medians_low, *medians_high))
    plt.text(0.1, 0.5, text)
    figs.append(f3a)

    f4 = plt.figure(figsize=small_fig)
    gc_yin1, gc_out1 = gc_dummy_sigmoid(*gc_medians_low, low=0.0, high=0.3)
    gc_yin2, gc_out2 = gc_dummy_sigmoid(*gc_medians_high, low=0.0, high=0.3)
    plt.scatter(gc_yin1, gc_out1, color=c_color, s=big_scatter, alpha=0.3)
    plt.scatter(gc_yin2, gc_out2, color=c_color, s=big_scatter * 2)
    figs.append(f4)
    plt.tight_layout()
    ax_remove_box()

    f4a = plt.figure(figsize=text_fig)
    text = ("improved cells\n"
            "median low contrast:\n"
            "base:  %.4f,   amplitude:  %.4f\n"
            "shift:  %.4f,   kappa:  %.4f\n"
            "median high contrast:\n"
            "base:  %.4f,   amplitude:  %.4f\n"
            "shift:  %.4f,   kappa:  %.4f\n" %
            (*gc_medians_low, *gc_medians_high))
    plt.text(0.1, 0.5, text)
    figs.append(f4a)

    return figs
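
# Illustrative sketch: how gc_distributions() turns the eight flattened
# parameter vectors (the four dsig keys followed by their '_mod' counterparts)
# into per-parameter high-minus-low differences and their medians. The input
# ordering is taken from all_keys above; everything else is generic.
import numpy as np


def _param_diffs_sketch(sep_dfs):
    # sep_dfs: list of 8 1-D arrays, ordered [base, amplitude, shift, kappa,
    #          base_mod, amplitude_mod, shift_mod, kappa_mod]
    raw_low, raw_high = sep_dfs[:4], sep_dfs[4:]
    diffs = [high - low for low, high in zip(raw_low, raw_high)]
    medians = [np.median(d) for d in diffs]
    return diffs, medians
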
Example #16
0
def relative_bar_comparison(batch1,
                            batch2,
                            gc,
                            stp,
                            LN,
                            combined,
                            se_filter=True,
                            ln_filter=False,
                            plot_stat='r_ceiling',
                            only_improvements=False,
                            good_ln=0.0):
    raise ValueError('fix cellid filters before using me')

    df_r1, df_c1, df_e1 = get_dataframes(batch1, gc, stp, LN, combined)
    cellids1, under_chance1, less_LN1 = get_filtered_cellids(
        df_r1, df_e1, gc, stp, LN, combined, se_filter, ln_filter)

    df_r2, df_c2, df_e2 = get_dataframes(batch2, gc, stp, LN, combined)
    cellids2, under_chance2, less_LN2 = get_filtered_cellids(
        df_r2, df_e2, gc, stp, LN, combined, se_filter, ln_filter)

    if only_improvements:
        # only use cells for which there was a significant improvement
        # to one or more models
        e1, n1, g1, s1, c1 = improved_cells_to_list(batch1,
                                                    gc,
                                                    stp,
                                                    LN,
                                                    combined,
                                                    good_ln=good_ln)
        filter1 = list(set(e1) | set(g1) | set(s1) | set(c1))

        e2, n2, g2, s2, c2 = improved_cells_to_list(batch2,
                                                    gc,
                                                    stp,
                                                    LN,
                                                    combined,
                                                    good_ln=good_ln)
        filter2 = list(set(e2) | set(g2) | set(s2) | set(c2))

        cellids1 = [c for c in cellids1 if c in filter1]
        cellids2 = [c for c in cellids2 if c in filter2]

    if plot_stat == 'r_ceiling':
        plot_df1 = df_c1
        plot_df2 = df_c2
    else:
        plot_df1 = df_r1
        plot_df2 = df_r2

    n_cells1 = len(cellids1)
    gc_test1 = plot_df1[gc][cellids1]
    stp_test1 = plot_df1[stp][cellids1]
    ln_test1 = plot_df1[LN][cellids1]
    gc_stp_test1 = plot_df1[combined][cellids1]

    n_cells2 = len(cellids2)
    gc_test2 = plot_df2[gc][cellids2]
    stp_test2 = plot_df2[stp][cellids2]
    ln_test2 = plot_df2[LN][cellids2]
    gc_stp_test2 = plot_df2[combined][cellids2]

    gc_rel1 = gc_test1 - ln_test1
    stp_rel1 = stp_test1 - ln_test1
    gc_stp_rel1 = gc_stp_test1 - ln_test1
    gc1 = np.mean(gc_rel1.values)
    stp1 = np.mean(stp_rel1.values)
    gc_stp1 = np.mean(gc_stp_rel1.values)
    #largest1 = max(gc1, stp1, gc_stp1)

    gc_rel2 = gc_test2 - ln_test2
    stp_rel2 = stp_test2 - ln_test2
    gc_stp_rel2 = gc_stp_test2 - ln_test2
    gc2 = np.mean(gc_rel2.values)
    stp2 = np.mean(stp_rel2.values)
    gc_stp2 = np.mean(gc_stp_rel2.values)
    #largest2 = max(gc2, stp2, gc_stp2)

    fig = plt.figure(figsize=(15, 12))
    plt.bar(
        [1, 2, 3, 4, 5, 6],
        [gc1, stp1, gc_stp1, gc2, stp2, gc_stp2],
        #color=['purple', 'green', 'gray', 'blue'])
        color=[
            gc_color, stp_color, gc_stp_color, gc_color, stp_color,
            gc_stp_color
        ],
        edgecolor="black",
        linewidth=2)
    plt.xticks(
        [1, 2, 3, 4, 5, 6],
        ['GC %d' % batch1, 'STP', 'GC+STP',
         'GC %d' % batch2, 'STP', 'GC+STP'])
    #    if abbr_yaxis:
    #        lower = np.floor(10*min(gc, stp, ln, gc_stp))/10
    #        upper = np.ceil(10*max(gc, stp, ln, gc_stp))/10 + y_adjust
    #        plt.ylim(ymin=lower, ymax=upper)
    #    else:
    #        plt.ylim(ymax=largest*1.4)

    common_kwargs = {'color': 'white', 'horizontalalignment': 'center'}
    #    if abbr_yaxis:
    #        y_text = 0.5*(lower + min(gc, stp, ln, gc_stp))
    #    else:
    y_text = 0.2
    plt.text(1, y_text, "%0.04f" % gc1, **common_kwargs)
    plt.text(2, y_text, "%0.04f" % stp1, **common_kwargs)
    plt.text(3, y_text, "%0.04f" % gc_stp1, **common_kwargs)
    plt.text(4, y_text, "%0.04f" % gc2, **common_kwargs)
    plt.text(5, y_text, "%0.04f" % stp2, **common_kwargs)
    plt.text(6, y_text, "%0.04f" % gc_stp2, **common_kwargs)
    plt.title("Mean Relative (to LN) Performance for GC, STP, and GC+STP\n"
              "batch %d, n: %d   vs   batch %d, n: %d" %
              (batch1, n_cells1, batch2, n_cells2))

    return fig
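
# Illustrative sketch: the "relative to LN" summary plotted above, i.e. the
# mean difference between one model's test scores and the LN test scores over
# a shared set of cells. The dataframe layout matches plot_df above; the
# helper name is an assumption.
import numpy as np


def _mean_relative_to_ln_sketch(plot_df, model, LN, cellids):
    rel = plot_df[model][cellids] - plot_df[LN][cellids]
    return np.mean(rel.values)
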
Example #17
0
def significance(batch,
                 gc,
                 stp,
                 LN,
                 combined,
                 manual_cellids=None,
                 plot_stat='r_ceiling',
                 include_legend=True,
                 only_improvements=False):
    '''
    model1: GC
    model2: STP
    model3: LN
    model4: GC+STP

    '''
    # NOTE: The comparison of max(gc, stp) to gc/stp should be
    #       ignored. They're known to be different by definition
    #       and there's a bug in the scipy code that causes the
    #       W-statistic to be reported as 0.
    #       This happens because all of the differences are one-sided
    #       and the scipy code takes the minimum of either positive
    #       or negative differences, one of which will always be 0

    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    e, a, g, s, c = improved_cells_to_list(batch, gc, stp, LN, combined)
    cellids = a

    gc_test = df_r[gc][cellids]
    stp_test = df_r[stp][cellids]
    ln_test = df_r[LN][cellids]
    gc_stp_test = df_r[combined][cellids]
    max_test = np.maximum(gc_test, stp_test)

    modelnames = ['GC', 'STP', 'LN', 'GC + STP', 'Max(GC,STP)']
    models = {
        'GC': gc_test,
        'STP': stp_test,
        'LN': ln_test,
        'GC + STP': gc_stp_test,
        'Max(GC,STP)': max_test
    }
    array = np.ndarray(shape=(len(modelnames), len(modelnames)), dtype=float)

    for i, m_one in enumerate(modelnames):
        for j, m_two in enumerate(modelnames):
            # get series of values corresponding to selected measure
            # for each model
            series_one = models[m_one]
            series_two = models[m_two]
            # TODO: no reason to convert these to lists anymore?
            first = series_one.tolist()
            second = series_two.tolist()

            if j != i:
                w, p = st.wilcoxon(first, second)
            if j == i:
                # indices equal: on the diagonal, so no comparison
                array[i][j] = 0.00
            elif j > i:
                # upper triangle (j > i): store the Wilcoxon W statistic
                array[i][j] = w
            else:
                # lower triangle (j < i): store the p-value
                array[i][j] = p

    xticks = range(len(modelnames))
    yticks = xticks
    minor_xticks = np.arange(-0.5, len(modelnames), 1)
    minor_yticks = np.arange(-0.5, len(modelnames), 1)

    fig = plt.figure(figsize=(12, 12))
    ax = plt.gca()

    # adapted from Stack Overflow: add a text label at each grid position
    # (i, j) (model x model) showing z, the value of array at (i, j)
    for (i, j), z in np.ndenumerate(array):
        if j == i:
            color = "#EBEBEB"
        elif j > i:
            color = "#368DFF"
        else:
            if array[i][j] < 0.001:
                color = "#74E572"
            elif array[i][j] < 0.01:
                color = "#59AF57"
            elif array[i][j] < 0.05:
                color = "#397038"
            else:
                color = "#ABABAB"

        ax.add_patch(
            mpatch.Rectangle(
                xy=(j - 0.5, i - 0.5),
                width=1.0,
                height=1.0,
                angle=0.0,
                facecolor=color,
                edgecolor='black',
            ))
        if j == i:
            # don't draw text for diagonal
            continue
#        formatting = '{:.04f}'
#        if z <= 0.0001:
#            formatting = '{:.2E}'
        formatting = '{:.2E}'
        ax.text(
            j,
            i,
            formatting.format(z),
            ha='center',
            va='center',
        )

    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.set_yticks(yticks)
    ax.set_yticklabels(modelnames, fontsize=10)
    ax.set_xticks(xticks)
    ax.set_xticklabels(modelnames, fontsize=10, rotation="vertical")
    ax.set_yticks(minor_yticks, minor=True)
    ax.set_xticks(minor_xticks, minor=True)
    ax.grid(False)
    ax.grid(which='minor', color='b', linestyle='-', linewidth=0.75)
    title = "Wilcoxon Signed Test\nOnly improvements?:  %s" % only_improvements
    ax.set_title(title, ha='center', fontsize=14)

    if include_legend:
        blue_patch = mpatch.Patch(color='#368DFF',
                                  label='W statistic',
                                  edgecolor='black')
        p001_patch = mpatch.Patch(color='#74E572',
                                  label='P < 0.001',
                                  edgecolor='black')
        p01_patch = mpatch.Patch(color='#59AF57',
                                 label='P < 0.01',
                                 edgecolor='black')
        p05_patch = mpatch.Patch(color='#397038',
                                 label='P < 0.05',
                                 edgecolor='black')
        nonsig_patch = mpatch.Patch(
            color='#ABABAB',
            label='Not Significant',
            edgecolor='black',
        )

        plt.legend(
            #bbox_to_anchor=(0., 1.02, 1., .102), ncol=2,
            bbox_to_anchor=(1.05, 1),
            ncol=1,
            loc=2,
            handles=[
                p05_patch,
                p01_patch,
                p001_patch,
                nonsig_patch,
                blue_patch,
            ])
    plt.tight_layout()

    return fig
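
# Illustrative sketch: the pairwise comparison matrix built in significance(),
# with the Wilcoxon W statistic stored in the upper triangle and the p-value
# in the lower triangle. The `scores` dict of toy per-cell arrays is an
# assumption.
import numpy as np
import scipy.stats as st


def _wilcoxon_matrix_sketch(scores):
    # scores: dict mapping model name -> 1-D array of per-cell test scores
    names = list(scores.keys())
    n = len(names)
    out = np.zeros((n, n))
    for i, m1 in enumerate(names):
        for j, m2 in enumerate(names):
            if i == j:
                continue
            w, p = st.wilcoxon(scores[m1], scores[m2])
            out[i, j] = w if j > i else p
    return names, out
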
Example #18
0
def equivalence_effect_size(batch,
                            gc,
                            stp,
                            LN,
                            combined,
                            se_filter=True,
                            LN_filter=False,
                            save_path=None,
                            load_path=None,
                            test_limit=None,
                            only_improvements=False,
                            legend=False,
                            effect_key='performance_effect',
                            equiv_key='partial_corr',
                            enable_hover=False,
                            plot_stat='r_ceiling',
                            highlight_cells=None):

    e, a, g, s, c = improved_cells_to_list(batch,
                                           gc,
                                           stp,
                                           LN,
                                           combined,
                                           se_filter=se_filter,
                                           LN_filter=LN_filter)
    _, cellids, _, _, _ = improved_cells_to_list(batch,
                                                 gc,
                                                 stp,
                                                 LN,
                                                 combined,
                                                 se_filter=se_filter,
                                                 LN_filter=LN_filter,
                                                 as_lists=False)

    if load_path is None:
        equivs = []
        partials = []
        gcs = []
        stps = []
        effects = []
        for cell in a[:test_limit]:
            xf1, ctx1 = xhelp.load_model_xform(cell, batch, gc)
            xf2, ctx2 = xhelp.load_model_xform(cell, batch, stp)
            xf3, ctx3 = xhelp.load_model_xform(cell, batch, LN)

            gc_pred = ctx1['val'].apply_mask()['pred'].as_continuous()
            stp_pred = ctx2['val'].apply_mask()['pred'].as_continuous()
            ln_pred = ctx3['val'].apply_mask()['pred'].as_continuous()

            ff = np.isfinite(gc_pred) & np.isfinite(stp_pred) & np.isfinite(
                ln_pred)
            gcff = gc_pred[ff]
            stpff = stp_pred[ff]
            lnff = ln_pred[ff]

            C = np.hstack((np.expand_dims(gcff, 0).transpose(),
                           np.expand_dims(stpff, 0).transpose(),
                           np.expand_dims(lnff, 0).transpose()))
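            # per-cell metrics computed below: the partial correlation between
            # the GC and STP predictions controlling for the LN prediction
            # (partials), the correlation between the GC and STP changes
            # relative to LN (equivs), and an effect size defined as
            # 1 - 0.5 * (CC(GC, LN) + CC(STP, LN)) (effects)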
            partials.append(partial_corr(C)[0, 1])

            equivs.append(np.corrcoef(gcff - lnff, stpff - lnff)[0, 1])
            this_gc = np.corrcoef(gcff, lnff)[0, 1]
            this_stp = np.corrcoef(stpff, lnff)[0, 1]
            gcs.append(this_gc)
            stps.append(this_stp)
            effects.append(1 - 0.5 * (this_gc + this_stp))

        df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)

        if plot_stat == 'r_ceiling':
            plot_df = df_c
        else:
            plot_df = df_r

        models = [gc, stp, LN]
        gc_rel_all, stp_rel_all = _relative_score(plot_df, models, a)

        results = {
            'cellid': a[:test_limit],
            'equivalence': equivs,
            'effect_size': effects,
            'corr_gc_LN': gcs,
            'corr_stp_LN': stps,
            'partial_corr': partials,
            'performance_effect': 0.5 * (gc_rel_all[:test_limit]
                                         + stp_rel_all[:test_limit])
        }
        df = pd.DataFrame.from_dict(results)
        df.set_index('cellid', inplace=True)
        if save_path is not None:
            df.to_pickle(save_path)
    else:
        df = pd.read_pickle(load_path)

    df = df[cellids]
    equivalence = df[equiv_key].values
    effect_size = df[effect_key].values
    r, p = st.pearsonr(effect_size, equivalence)
    improved = c
    not_improved = list(set(a) - set(c))

    fig1, ax = plt.subplots(1, 1)
    ax.axes.axhline(0,
                    color='black',
                    linewidth=1,
                    linestyle='dashed',
                    dashes=dash_spacing)
    ax_remove_box(ax)

    extra_title_lines = []
    if only_improvements:
        equivalence_imp = df[equiv_key][improved].values
        equivalence_not = df[equiv_key][not_improved].values
        effectsize_imp = df[effect_key][improved].values
        effectsize_not = df[effect_key][not_improved].values

        r_imp, p_imp = st.pearsonr(effectsize_imp, equivalence_imp)
        r_not, p_not = st.pearsonr(effectsize_not, equivalence_not)
        n_imp = len(improved)
        n_not = len(not_improved)
        lines = [
            "improved cells,  r:  %.4f,    p:  %.4E" % (r_imp, p_imp),
            "not improved,  r:  %.4f,    p:  %.4E" % (r_not, p_not)
        ]
        extra_title_lines.extend(lines)

        #        plt.scatter(effectsize_not, equivalence_not, s=small_scatter,
        #                    color=model_colors['LN'], label='no imp')
        plt.scatter(effectsize_imp,
                    equivalence_imp,
                    s=big_scatter,
                    color=model_colors['max'],
                    label='sig. imp.')
        if enable_hover:
            mplcursors.cursor(ax, hover=True).connect(
                "add",
                lambda sel: sel.annotation.set_text(improved[sel.target.index]))
        if legend:
            plt.legend()
    else:
        plt.scatter(effect_size, equivalence, s=big_scatter, color='black')
        if enable_hover:
            mplcursors.cursor(ax, hover=True).connect(
                "add",
                lambda sel: sel.annotation.set_text(a[sel.target.index]))

    if highlight_cells is not None:
        highlights = [df[df.index == h] for h in highlight_cells]
        equiv_highlights = [h_df[equiv_key].values for h_df in highlights]
        effect_highlights = [h_df[effect_key].values for h_df in highlights]

        for eq, eff in zip(equiv_highlights, effect_highlights):
            plt.scatter(eff,
                        eq,
                        s=big_scatter * 10,
                        facecolors='none',
                        edgecolors='black',
                        linewidths=1)

    #plt.ylabel('Equivalence:  CC(GC-LN, STP-LN)')
    #plt.xlabel('Effect size:  1 - 0.5*(CC(GC,LN) + CC(STP,LN))')
    if equiv_key == 'equivalence':
        y_text = 'equivalence, CC(GC-LN, STP-LN)\n'
    elif equiv_key == 'partial_corr':
        y_text = 'equivalence, partial correlation\n'
    else:
        y_text = 'unknown equivalence key'

    if effect_key == 'effect_size':
        x_text = 'effect size: 1 - 0.5*(CC(GC,LN) + CC(STP,LN))'
    elif effect_key == 'performance_effect':
        x_text = 'effect size: 0.5*(rGC-rLN + rSTP-rLN)'
    else:
        x_text = 'unknown effect key'

    text = ("scatter: Equivalence of Change to Predicted PSTH\n"
            "batch: %d\n"
            "vs Effect Size\n"
            "all cells,  r:  %.4f,    p:  %.4E\n"
            "y: %s"
            "x: %s" % (batch, r, p, y_text, x_text))
    for ln in extra_title_lines:
        text += "\n%s" % ln

    plt.tight_layout()

    #    fig2 = plt.figure()
    #    md = np.nanmedian(equivalence)
    #    n_cells = equivalence.shape[0]
    #    plt.hist(equivalence, bins=30, range=[-0.5, 1], histtype='bar',
    #             color=model_colors['combined'], edgecolor='black', linewidth=1)
    #    plt.plot(np.array([0,0]), np.array(fig2.axes[0].get_ylim()), 'k--',
    #             linewidth=2, dashes=dash_spacing)
    #    plt.text(0.05, 0.95, 'n = %d\nmd = %.2f' % (n_cells, md),
    #             ha='left', va='top', transform=fig2.axes[0].transAxes)
    #plt.xlabel('CC, GC-LN vs STP-LN')
    #plt.title('Equivalence of Change in Prediction Relative to LN Model')

    fig3 = plt.figure()
    plt.text(0.1, 0.75, text, va='top')
    #plt.text(0.1, 0.25, text2, va='top')

    return fig1, fig3
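
# Illustrative sketch: the simple-correlation equivalence and effect-size
# metrics computed inside equivalence_effect_size(), applied to arbitrary
# prediction arrays (the partial-correlation variant is omitted because it
# relies on the partial_corr helper defined elsewhere).
import numpy as np


def _equivalence_sketch(gc_pred, stp_pred, ln_pred):
    ff = np.isfinite(gc_pred) & np.isfinite(stp_pred) & np.isfinite(ln_pred)
    gcff, stpff, lnff = gc_pred[ff], stp_pred[ff], ln_pred[ff]
    # equivalence: CC(GC - LN, STP - LN)
    equivalence = np.corrcoef(gcff - lnff, stpff - lnff)[0, 1]
    # effect size: 1 - 0.5 * (CC(GC, LN) + CC(STP, LN))
    effect = 1 - 0.5 * (np.corrcoef(gcff, lnff)[0, 1]
                        + np.corrcoef(stpff, lnff)[0, 1])
    return equivalence, effect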