Example #1
def scatter_soundstats(results, legend=False, manual_lims=False):
    # HACK to get separate markers for batch 263 noisy vs. clean sounds
    voc_in_noise = False
    for k in results:
        if '0dB' in k:
            voc_in_noise = True
            break

    if voc_in_noise:
        clean_means = []
        clean_sds = []
        noisy_means = []
        noisy_sds = []
        for k, (mean, sd) in results.items():
            if '0dB' in k:
                noisy_means.append(mean)
                noisy_sds.append(sd)
            else:
                clean_means.append(mean)
                clean_sds.append(sd)
        fig = plt.figure(figsize=small_fig)
        plt.scatter(clean_sds, clean_means, color='black',
                    s=big_scatter, label='clean')
        plt.scatter(noisy_sds, noisy_means, color=model_colors['combined'],
                    s=small_scatter, label='noisy')
        if legend:
            plt.legend()
        if manual_lims:
            plt.xlim(0,32)
            plt.ylim(0,65)

    else:
        means = []
        sds = []
        for k, (mean, sd) in results.items():
            means.append(mean)
            sds.append(sd)

        fig = plt.figure(figsize=small_fig)
        plt.scatter(sds, means, color='black', s=big_scatter)
        if manual_lims:
            plt.xlim(0,32)
            plt.ylim(0,65)

    plt.tight_layout()
    ax_remove_box()

    return fig
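For reference, a minimal standalone sketch of the same clean-vs-noisy scatter, using synthetic data and plain matplotlib in place of the module-level figure settings (small_fig, big_scatter, model_colors, ax_remove_box). The assumed input format is a dict mapping sound names to (mean, sd) tuples, with '0dB' in the name marking the noisy condition.

import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
# Toy results dict: {sound_name: (mean, sd)}; '0dB' in the key marks noisy sounds.
results = {'sound%02d' % i: (30 + 10 * rng.random(), 5 + 3 * rng.random())
           for i in range(20)}
results.update({'sound%02d_0dB' % i: (35 + 10 * rng.random(), 8 + 3 * rng.random())
                for i in range(20)})

clean = [(sd, m) for k, (m, sd) in results.items() if '0dB' not in k]
noisy = [(sd, m) for k, (m, sd) in results.items() if '0dB' in k]

fig, ax = plt.subplots(figsize=(3, 3))
ax.scatter(*zip(*clean), color='black', s=10, label='clean')
ax.scatter(*zip(*noisy), color='purple', s=5, label='noisy')
ax.set_xlabel('SD')
ax.set_ylabel('mean')
ax.legend()
fig.tight_layout()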
Example #2
def scatter_bar(batches, modelnames, stest=SIG_TEST_MODELS, axes=None):
    if axes is None:
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(3.5, 6))
    else:
        ax1, ax2 = axes

    cellids = [
        get_significant_cells(batch, stest, as_list=True) for batch in batches
    ]
    r_values = [
        nd.batch_comp(batch, modelnames, cellids=cells, stat=PLOT_STAT)
        for batch, cells in zip(batches, cellids)
    ]
    all_r_values = pd.concat(r_values)

    # NOTE: if ALL_FAMILY_MODELS changes, the 3, 2 indices will need to be updated
    # Scatter Plot -- LN vs c1dx2+d
    ax1.scatter(all_r_values[modelnames[3]],
                all_r_values[modelnames[2]],
                s=2,
                c='black')
    ax1.plot([0, 1], [0, 1], c='black', linestyle='dashed', linewidth=1)  # unity line
    ax_remove_box(ax1)
    #set_equal_axes(ax1)
    ax1.set_ylim(0, 1)
    ax1.set_xlim(0, 1)
    ax1.set_xlabel('LN_pop prediction accuracy')
    ax1.set_ylabel('conv1Dx2 prediction accuracy')

    # Bar Plot -- Median for each model
    # NOTE: ordering of names is assuming ALL_FAMILY_MODELS is being used and has not changed.
    short_names = ['conv2d', 'conv1d', 'conv1dx2+d', 'LN_pop', 'dnn1_single']
    bar_colors = [DOT_COLORS[k] for k in short_names]
    ax2.bar(np.arange(0, len(modelnames)),
            all_r_values.median(axis=0).values,
            color=bar_colors,
            tick_label=short_names)
    ax_remove_box(ax2)
    ax2.set_ylabel('Median Prediction Accuracy')
    ax2.set_xticklabels(short_names, rotation=45, ha='right')

    fig = plt.gcf()
    fig.tight_layout()

    return ax1, ax2
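A standalone sketch of the same scatter-plus-median-bar layout, with a random DataFrame standing in for the per-cell accuracies returned by nd.batch_comp; the column names and figure size here are placeholders, not the actual modelnames.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
# Stand-in for the per-cell prediction-accuracy table (rows: cells, cols: models).
r_values = pd.DataFrame(rng.uniform(0, 1, size=(200, 3)),
                        columns=['LN_pop', 'conv1Dx2', 'conv2D'])

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(3.5, 6))
ax1.scatter(r_values['LN_pop'], r_values['conv1Dx2'], s=2, c='black')
ax1.plot([0, 1], [0, 1], c='black', linestyle='dashed', linewidth=1)  # unity line
ax1.set_xlim(0, 1)
ax1.set_ylim(0, 1)
ax1.set_xlabel('LN_pop prediction accuracy')
ax1.set_ylabel('conv1Dx2 prediction accuracy')

ax2.bar(np.arange(r_values.shape[1]), r_values.median(axis=0).values,
        tick_label=list(r_values.columns))
ax2.set_ylabel('Median prediction accuracy')
plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')
fig.tight_layout()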
Example #3
def _stacked_hists(var1, var2, m1, m2, c1, c2, bin_count=30, hist_kwargs={}):
    # c1: LN, c2: max (note: c1/c2 are not used below; the fill and edge
    # colors come from the module-level faded_LN/dark_LN and faded_max/dark_max)
    fig, (a1, a2) = plt.subplots(2, 1, sharex=True, sharey=True)
    w1 = [np.ones(len(var1)) / len(var1)]
    w2 = [np.ones(len(var2)) / len(var2)]
    upper = max(var1.max(), var2.max())
    lower = min(var1.min(), var2.min())
    bins = np.linspace(lower, upper, bin_count + 1)
    a1.hist(var1,
            weights=w1,
            fc=faded_LN,
            bins=bins,
            edgecolor=dark_LN,
            **hist_kwargs)
    a2.hist(var2,
            weights=w2,
            fc=faded_max,
            bins=bins,
            edgecolor=dark_max,
            **hist_kwargs)
    a1.axes.axvline(m1,
                    color=dark_LN,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a1.axes.axvline(m2,
                    color=dark_max,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a2.axes.axvline(m1,
                    color=dark_LN,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a2.axes.axvline(m2,
                    color=dark_max,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    ax_remove_box(a1)
    ax_remove_box(a2)

    return fig
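The weights argument above is what turns raw counts into the fraction of cells per bin. A self-contained sketch of that normalization with two random samples and generic colors (the faded_/dark_ constants are module-level in the original):

import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
var1 = rng.normal(0.0, 1.0, 300)
var2 = rng.normal(0.5, 1.2, 500)

fig, (a1, a2) = plt.subplots(2, 1, sharex=True, sharey=True)
bins = np.linspace(min(var1.min(), var2.min()), max(var1.max(), var2.max()), 31)
# Weights of 1/n make each bar the fraction of samples in that bin.
a1.hist(var1, bins=bins, weights=np.ones(len(var1)) / len(var1),
        fc='lightblue', edgecolor='navy')
a2.hist(var2, bins=bins, weights=np.ones(len(var2)) / len(var2),
        fc='mistyrose', edgecolor='darkred')
for ax in (a1, a2):
    ax.axvline(np.median(var1), color='navy', linestyle='dashed', linewidth=2)
    ax.axvline(np.median(var2), color='darkred', linestyle='dashed', linewidth=2)
fig.tight_layout()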
Example #4
def plot_self_equivalence(batch,
                          stp1,
                          stp2,
                          gc1,
                          gc2,
                          LN1,
                          LN2,
                          stp_load,
                          gc_load,
                          axes=None):
    eqs1 = self_equivalence_data(batch,
                                 stp1,
                                 stp2,
                                 LN1,
                                 LN2,
                                 load_path=stp_load)
    eqs2 = self_equivalence_data(batch, gc1, gc2, LN1, LN2, load_path=gc_load)
    eqs = np.hstack([eqs1, eqs2])
    weights1 = [np.ones(len(eqs)) / len(eqs)]

    if axes is not None:
        a1 = axes
    else:
        _, a1 = plt.subplots(1, 1)

    a1.hist(eqs,
            bins=30,
            range=[-0.5, 1],
            weights=weights1,
            fc='gray',
            edgecolor='black',
            linewidth=1,
            alpha=0.6)
    a1.axes.axvline(np.median(eqs),
                    color='black',
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    if axes is None:
        ax_remove_box(a1)
        plt.tight_layout()

    return a1
Example #5
def compare_sims(start=0, end=None):
    # TODO: set up to compare on synthetic stimuli
    xfspec, ctx = xhelp.load_model_xform(_DEFAULT_CELL, _DEFAULT_BATCH,
                                         _DEFAULT_MODEL)
    val = ctx['val']
    gc_sim = build_toy_gc_cell(0, 0, 0, -0.5) #base, amp, shift, kappa
    gc_sim[-2]['fn_kwargs']['compute_contrast'] = True
    stp_sim = build_toy_stp_cell([0, 0.1], [0.08, 0.08]) #u, tau
    LN_sim = build_toy_LN_cell()

    stim = val['stim'].as_continuous()
    gc_val = gc_sim.evaluate(val)
    gc_sim.recording = gc_val
    gc_psth = gc_val['pred'].as_continuous().flatten()
    stp_val = stp_sim.evaluate(val)
    stp_sim.recording = stp_val
    stp_psth = stp_val['pred'].as_continuous().flatten()
    LN_val = LN_sim.evaluate(val)
    LN_sim.recording = LN_val
    LN_psth = LN_val['pred'].as_continuous().flatten()

    fig = plt.figure(figsize=wide_fig)
    if end is None:
        end = stim.shape[-1]
    plt.imshow(stim, aspect='auto', cmap=spectrogram_cmap,
               origin='lower', extent=(0, stim.shape[-1], 2.1, 3.4))
    lw = 0.75
    plt.plot(LN_psth, color=model_colors['LN'], linewidth=lw)
    plt.plot(gc_psth, color=model_colors['gc'], linewidth=lw*1.25)
    plt.plot(stp_psth, color=model_colors['stp'], alpha=0.75,
             linewidth=lw*1.25)
    plt.ylim(-0.1, 3.4)
    plt.xlim(start, end)
    ax = plt.gca()
    ax_remove_box(ax)

    return fig
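The extent argument to imshow is what stacks the spectrogram in a y-band above the prediction traces on a single axis. A standalone sketch of that layout with random data (the band limits here are arbitrary):

import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
stim = rng.random((18, 400))                                       # stand-in spectrogram (channels x time)
pred = np.convolve(rng.random(400), np.ones(5) / 5, mode='same')   # stand-in prediction trace

fig = plt.figure(figsize=(10, 2.5))
# extent=(xmin, xmax, ymin, ymax) parks the image above the line plot.
plt.imshow(stim, aspect='auto', origin='lower',
           extent=(0, stim.shape[-1], 1.1, 1.35))
plt.plot(pred, color='black', linewidth=0.75)
plt.ylim(-0.1, 1.35)
plt.xlim(0, stim.shape[-1])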
Example #6
def equivalence_effect_size(batch,
                            gc,
                            stp,
                            LN,
                            combined,
                            se_filter=True,
                            LN_filter=False,
                            save_path=None,
                            load_path=None,
                            test_limit=None,
                            only_improvements=False,
                            legend=False,
                            effect_key='performance_effect',
                            equiv_key='partial_corr',
                            enable_hover=False,
                            plot_stat='r_ceiling',
                            highlight_cells=None):

    e, a, g, s, c = improved_cells_to_list(batch,
                                           gc,
                                           stp,
                                           LN,
                                           combined,
                                           se_filter=se_filter,
                                           LN_filter=LN_filter)
    _, cellids, _, _, _ = improved_cells_to_list(batch,
                                                 gc,
                                                 stp,
                                                 LN,
                                                 combined,
                                                 se_filter=se_filter,
                                                 LN_filter=LN_filter,
                                                 as_lists=False)

    if load_path is None:
        equivs = []
        partials = []
        gcs = []
        stps = []
        effects = []
        for cell in a[:test_limit]:
            xf1, ctx1 = xhelp.load_model_xform(cell, batch, gc)
            xf2, ctx2 = xhelp.load_model_xform(cell, batch, stp)
            xf3, ctx3 = xhelp.load_model_xform(cell, batch, LN)

            gc_pred = ctx1['val'].apply_mask()['pred'].as_continuous()
            stp_pred = ctx2['val'].apply_mask()['pred'].as_continuous()
            ln_pred = ctx3['val'].apply_mask()['pred'].as_continuous()

            ff = np.isfinite(gc_pred) & np.isfinite(stp_pred) & np.isfinite(
                ln_pred)
            gcff = gc_pred[ff]
            stpff = stp_pred[ff]
            lnff = ln_pred[ff]

            C = np.hstack((np.expand_dims(gcff, 0).transpose(),
                           np.expand_dims(stpff, 0).transpose(),
                           np.expand_dims(lnff, 0).transpose()))
            partials.append(partial_corr(C)[0, 1])

            equivs.append(np.corrcoef(gcff - lnff, stpff - lnff)[0, 1])
            this_gc = np.corrcoef(gcff, lnff)[0, 1]
            this_stp = np.corrcoef(stpff, lnff)[0, 1]
            gcs.append(this_gc)
            stps.append(this_stp)
            effects.append(1 - 0.5 * (this_gc + this_stp))

        df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)

        if plot_stat == 'r_ceiling':
            plot_df = df_c
        else:
            plot_df = df_r

        models = [gc, stp, LN]
        gc_rel_all, stp_rel_all = _relative_score(plot_df, models, a)

        results = {
            'cellid': a[:test_limit],
            'equivalence': equivs,
            'effect_size': effects,
            'corr_gc_LN': gcs,
            'corr_stp_LN': stps,
            'partial_corr': partials,
            'performance_effect': 0.5 * (gc_rel_all[:test_limit]
                                         + stp_rel_all[:test_limit])
        }
        df = pd.DataFrame.from_dict(results)
        df.set_index('cellid', inplace=True)
        if save_path is not None:
            df.to_pickle(save_path)
    else:
        df = pd.read_pickle(load_path)

    df = df[cellids]
    equivalence = df[equiv_key].values
    effect_size = df[effect_key].values
    r, p = st.pearsonr(effect_size, equivalence)
    improved = c
    not_improved = list(set(a) - set(c))

    fig1, ax = plt.subplots(1, 1)
    ax.axes.axhline(0,
                    color='black',
                    linewidth=1,
                    linestyle='dashed',
                    dashes=dash_spacing)
    ax_remove_box(ax)

    extra_title_lines = []
    if only_improvements:
        equivalence_imp = df[equiv_key][improved].values
        equivalence_not = df[equiv_key][not_improved].values
        effectsize_imp = df[effect_key][improved].values
        effectsize_not = df[effect_key][not_improved].values

        r_imp, p_imp = st.pearsonr(effectsize_imp, equivalence_imp)
        r_not, p_not = st.pearsonr(effectsize_not, equivalence_not)
        n_imp = len(improved)
        n_not = len(not_improved)
        lines = [
            "improved cells,  r:  %.4f,    p:  %.4E" % (r_imp, p_imp),
            "not improved,  r:  %.4f,    p:  %.4E" % (r_not, p_not)
        ]
        extra_title_lines.extend(lines)

        #        plt.scatter(effectsize_not, equivalence_not, s=small_scatter,
        #                    color=model_colors['LN'], label='no imp')
        plt.scatter(effectsize_imp,
                    equivalence_imp,
                    s=big_scatter,
                    color=model_colors['max'],
                    label='sig. imp.')
        if enable_hover:
            mplcursors.cursor(ax, hover=True).connect(
                "add",
                lambda sel: sel.annotation.set_text(improved[sel.target.index]))
        if legend:
            plt.legend()
    else:
        plt.scatter(effect_size, equivalence, s=big_scatter, color='black')
        if enable_hover:
            mplcursors.cursor(ax, hover=True).connect(
                "add",
                lambda sel: sel.annotation.set_text(a[sel.target.index]))

    if highlight_cells is not None:
        highlights = [df[df.index == h] for h in highlight_cells]
        equiv_highlights = [h_df[equiv_key].values for h_df in highlights]
        effect_highlights = [h_df[effect_key].values for h_df in highlights]

        for eq, eff in zip(equiv_highlights, effect_highlights):
            plt.scatter(eff,
                        eq,
                        s=big_scatter * 10,
                        facecolors='none',
                        edgecolors='black',
                        linewidths=1)

    #plt.ylabel('Equivalence:  CC(GC-LN, STP-LN)')
    #plt.xlabel('Effect size:  1 - 0.5*(CC(GC,LN) + CC(STP,LN))')
    if equiv_key == 'equivalence':
        y_text = 'equivalence, CC(GC-LN, STP-LN)\n'
    elif equiv_key == 'partial_corr':
        y_text = 'equivalence, partial correlation\n'
    else:
        y_text = 'unknown equivalence key\n'

    if effect_key == 'effect_size':
        x_text = 'effect size: 1 - 0.5*(CC(GC,LN) + CC(STP,LN))'
    elif effect_key == 'performance_effect':
        x_text = 'effect size: 0.5*(rGC-rLN + rSTP-rLN)'
    else:
        x_text = 'unknown effect key'

    text = ("scatter: Equivalence of Change to Predicted PSTH\n"
            "batch: %d\n"
            "vs Effect Size\n"
            "all cells,  r:  %.4f,    p:  %.4E\n"
            "y: %s"
            "x: %s" % (batch, r, p, y_text, x_text))
    for ln in extra_title_lines:
        text += "\n%s" % ln

    plt.tight_layout()

    #    fig2 = plt.figure()
    #    md = np.nanmedian(equivalence)
    #    n_cells = equivalence.shape[0]
    #    plt.hist(equivalence, bins=30, range=[-0.5, 1], histtype='bar',
    #             color=model_colors['combined'], edgecolor='black', linewidth=1)
    #    plt.plot(np.array([0,0]), np.array(fig2.axes[0].get_ylim()), 'k--',
    #             linewidth=2, dashes=dash_spacing)
    #    plt.text(0.05, 0.95, 'n = %d\nmd = %.2f' % (n_cells, md),
    #             ha='left', va='top', transform=fig2.axes[0].transAxes)
    #plt.xlabel('CC, GC-LN vs STP-LN')
    #plt.title('Equivalence of Change in Prediction Relative to LN Model')

    fig3 = plt.figure()
    plt.text(0.1, 0.75, text, va='top')
    #plt.text(0.1, 0.25, text2, va='top')

    return fig1, fig3
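The 'partial_corr' equivalence metric used above is the correlation between the GC and STP predictions with the LN prediction partialled out. As a sketch of one standard way to compute it (via the inverse of the correlation matrix), assuming this matches what the project's partial_corr helper does:

import numpy as np

def partial_corr_sketch(C):
    # Pairwise partial correlations between columns of C, each controlling for
    # all remaining columns: p_ij = -P_ij / sqrt(P_ii * P_jj), where P is the
    # inverse of the correlation matrix. Only the off-diagonal entries are
    # meaningful.
    corr = np.corrcoef(C, rowvar=False)
    prec = np.linalg.inv(corr)
    d = np.sqrt(np.diag(prec))
    return -prec / np.outer(d, d)

rng = np.random.default_rng(0)
ln = rng.normal(size=1000)               # stand-in LN prediction
gc = ln + 0.3 * rng.normal(size=1000)    # stand-in GC prediction
stp = ln + 0.3 * rng.normal(size=1000)   # stand-in STP prediction
C = np.stack([gc, stp, ln], axis=1)
print(partial_corr_sketch(C)[0, 1])      # GC vs STP with LN factored out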
Example #7
def equivalence_histogram(batch,
                          gc,
                          stp,
                          LN,
                          combined,
                          se_filter=True,
                          LN_filter=False,
                          test_limit=None,
                          alpha=0.05,
                          save_path=None,
                          load_path=None,
                          equiv_key='partial_corr',
                          effect_key='performance_effect',
                          self_equiv=False,
                          self_kwargs={},
                          eq_models=[],
                          cross_kwargs={},
                          cross_models=[],
                          use_median=True,
                          exclude_low_snr=False,
                          snr_path=None,
                          adjust_scores=True,
                          use_log_ratios=False):
    '''
    model1: GC
    model2: STP
    model3: LN

    '''
    e, a, g, s, c = improved_cells_to_list(batch, gc, stp, LN, combined)
    _, cellids, _, _, _ = improved_cells_to_list(batch,
                                                 gc,
                                                 stp,
                                                 LN,
                                                 combined,
                                                 as_lists=False)
    improved = c
    not_improved = list(set(a) - set(c))

    if load_path is None:
        df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)

        rs = []
        for cell in a[:test_limit]:
            xf1, ctx1 = xhelp.load_model_xform(cell, batch, gc)
            xf2, ctx2 = xhelp.load_model_xform(cell, batch, stp)
            xf3, ctx3 = xhelp.load_model_xform(cell, batch, LN)

            gc_pred = ctx1['val'].apply_mask()['pred'].as_continuous()
            stp_pred = ctx2['val'].apply_mask()['pred'].as_continuous()
            ln_pred = ctx3['val'].apply_mask()['pred'].as_continuous()

            ff = np.isfinite(gc_pred) & np.isfinite(stp_pred) & np.isfinite(
                ln_pred)
            gcff = gc_pred[ff]
            stpff = stp_pred[ff]
            lnff = ln_pred[ff]
            rs.append(np.corrcoef(gcff - lnff, stpff - lnff)[0, 1])

        blank = np.full_like(rs, np.nan)
        results = {
            'cellid': a[:test_limit],
            'equivalence': rs,
            'effect_size': blank,
            'corr_gc_LN': blank,
            'corr_stp_LN': blank
        }
        df = pd.DataFrame.from_dict(results)
        df.set_index('cellid', inplace=True)
        if save_path is not None:
            df.to_pickle(save_path)
    else:
        df = pd.read_pickle(load_path)

    df = df[cellids]
    rs = df[equiv_key].values
    if exclude_low_snr:
        snr_df = pd.read_pickle(snr_path)
        med_snr = snr_df['snr'].median()
        high_snr = snr_df.loc[snr_df['snr'] >= med_snr]
        high_snr_cells = high_snr.index.values.tolist()
        improved = list(set(improved) & set(high_snr_cells))
        not_improved = list(set(not_improved) & set(high_snr_cells))

    imp = np.array(improved)
    not_imp = np.array(not_improved)
    imp_mask = np.isin(a, imp)
    not_mask = np.isin(a, not_imp)
    rs_not = rs[not_mask]
    rs_imp = rs[imp_mask]
    md_not = np.nanmedian(rs_not)
    md_imp = np.nanmedian(rs_imp)
    u, p = st.mannwhitneyu(rs_not, rs_imp, alternative='two-sided')
    n_not = len(not_improved)
    n_imp = len(improved)

    if self_equiv:
        stp1, stp2, gc1, gc2, LN1, LN2 = eq_models
        g1, s2, s1, g2, L1, L2 = cross_models
        _, ga, _, _, _ = improved_cells_to_list(batch,
                                                gc1,
                                                gc2,
                                                LN1,
                                                LN2,
                                                as_lists=True)
        _, sa, _, _, _ = improved_cells_to_list(batch,
                                                stp1,
                                                stp2,
                                                LN1,
                                                LN2,
                                                as_lists=True)
        aa = list(set(ga) & set(sa))
        if exclude_low_snr:
            snr_df = pd.read_pickle(snr_path)
            med_snr = snr_df['snr'].median()
            high_snr = snr_df.loc[snr_df['snr'] >= med_snr]
            high_snr_cells = high_snr.index.values.tolist()
            aa = list(set(aa) & set(high_snr_cells))

        eqs_stp, eqs_gc = _get_self_equivs(**self_kwargs, cellids=aa)
        md_stpeq = np.nanmedian(eqs_stp)
        md_gceq = np.nanmedian(eqs_gc)
        n_eq = eqs_stp.size
        u_gceq, p_gceq = st.mannwhitneyu(eqs_gc,
                                         rs_imp,
                                         alternative='two-sided')
        u_stpeq, p_stpeq = st.mannwhitneyu(eqs_stp,
                                           rs_imp,
                                           alternative='two-sided')

        eqs_x1, eqs_x2 = _get_self_equivs(**cross_kwargs, cellids=aa)
        md_x1 = np.nanmedian(eqs_x1)
        md_x2 = np.nanmedian(eqs_x2)

        sub1 = df.index.isin(ga) & df.index.isin(aa)
        sub2 = df.index.isin(sa) & df.index.isin(aa)
        eqs_sub1 = df[equiv_key][sub1].values
        eqs_sub2 = df[equiv_key][sub2].values
        md_sub1 = np.nanmedian(eqs_sub1)
        md_sub2 = np.nanmedian(eqs_sub2)

        if adjust_scores:
            if use_median:
                md_avg1 = 0.5 * (md_x1 + md_x2)
                md_avg2 = 0.5 * (md_sub1 + md_sub2)
                ratio = md_avg2 / md_avg1
                md_stpeq *= ratio
                md_gceq *= ratio

            else:
                eqs_avg1 = 0.5 * (eqs_x1 + eqs_x2)  # between-model, halved est
                eqs_avg2 = 0.5 * (eqs_sub1 + eqs_sub2)  # between-model, full data
                ratios = np.abs(eqs_avg2 / eqs_avg1)
                if use_log_ratios:
                    log_ratios = np.abs(np.log(ratios))
                    log_ratios /= log_ratios.max()
                    stp_scaled = eqs_stp + (1 - eqs_stp) * log_ratios
                    gc_scaled = eqs_gc + (1 - eqs_gc) * log_ratios
                else:
                    stp_scaled = eqs_stp * ratios
                    gc_scaled = eqs_gc * ratios
                md_stpeq = np.nanmedian(stp_scaled)
                md_gceq = np.nanmedian(gc_scaled)

    not_color = model_colors['LN']
    imp_color = model_colors['max']
    weights1 = [np.ones(len(rs_not)) / len(rs_not)]
    weights2 = [np.ones(len(rs_imp)) / len(rs_imp)]

    #n_cells = rs.shape[0]
    fig1, (a1, a2) = plt.subplots(2, 1)

    a1.hist(rs_not,
            bins=30,
            range=[-0.5, 1],
            weights=weights1,
            fc=faded_LN,
            edgecolor=dark_LN,
            linewidth=1)
    a2.hist(rs_imp,
            bins=30,
            range=[-0.5, 1],
            weights=weights2,
            fc=faded_max,
            edgecolor=dark_max,
            linewidth=1)

    a1.axes.axvline(md_not,
                    color=dark_LN,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a1.axes.axvline(md_imp,
                    color=dark_max,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a2.axes.axvline(md_not,
                    color=dark_LN,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a2.axes.axvline(md_imp,
                    color=dark_max,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)

    if self_equiv:
        a1.axes.annotate('',
                         xy=(md_gceq, 0),
                         xycoords='data',
                         xytext=(md_gceq, 0.07),
                         textcoords='data',
                         arrowprops=dict(arrowstyle="->",
                                         connectionstyle="arc3"))
        a1.axes.annotate('',
                         xy=(md_stpeq, 0),
                         xycoords='data',
                         xytext=(md_stpeq, 0.07),
                         textcoords='data',
                         arrowprops=dict(arrowstyle="->",
                                         connectionstyle="arc3"))
        a2.axes.annotate('',
                         xy=(md_gceq, 0),
                         xycoords='data',
                         xytext=(md_gceq, 0.07),
                         textcoords='data',
                         arrowprops=dict(arrowstyle="->",
                                         connectionstyle="arc3"))
        a2.axes.annotate('',
                         xy=(md_stpeq, 0),
                         xycoords='data',
                         xytext=(md_stpeq, 0.07),
                         textcoords='data',
                         arrowprops=dict(arrowstyle="->",
                                         connectionstyle="arc3"))

    ymin1, ymax1 = a1.get_ylim()
    ymin2, ymax2 = a2.get_ylim()
    ymax = max(ymax1, ymax2)
    a1.set_ylim(0, ymax)
    a2.set_ylim(0, ymax)

    ax_remove_box(a1)
    ax_remove_box(a2)
    plt.tight_layout()

    if equiv_key == 'equivalence':
        x_text = 'equivalence, CC(GC-LN, STP-LN)\n'
    elif equiv_key == 'partial_corr':
        x_text = 'equivalence, partial correlation\n'
    else:
        x_text = 'unknown equivalence key\n'
    fig3 = plt.figure(figsize=text_fig)
    text2 = ("hist: equivalence of changein prediction relative to LN model\n"
             "batch: %d\n"
             "x: %s"
             "y: cell fraction\n"
             "n not imp:  %d,  md:  %.2f\n"
             "n sig. imp:  %d,  md:  %.2f\n"
             "st.mannwhitneyu:  u:  %.4E p:  %.4E" %
             (batch, x_text, n_not, md_not, n_imp, md_imp, u, p))
    if self_equiv:
        text2 += ("\n\nSelf equivalence, n: %d\n"
                  "stp:  md:  %.2f,  u:  %.4E   p  %.4E\n"
                  "gc:   md:  %.2f,  u:  %.4E   p  %.4E\n"
                  "md sub1: %.2E\n"
                  "md sub2: %.2E\n"
                  "md x1: %.2E\n"
                  "md x2: %.2E" %
                  (n_eq, md_stpeq, u_stpeq, p_stpeq, md_gceq, u_gceq, p_gceq,
                   md_sub1, md_sub2, md_x1, md_x2))
    plt.text(0.1, 0.5, text2)

    return fig1, fig3
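The group comparison above comes down to a two-sided Mann-Whitney U test between the equivalence values of improved and non-improved cells; a minimal sketch with synthetic values:

import numpy as np
import scipy.stats as st

rng = np.random.default_rng(0)
rs_not = rng.normal(0.1, 0.2, 150)   # stand-in equivalence, cells without improvement
rs_imp = rng.normal(0.4, 0.2, 80)    # stand-in equivalence, significantly improved cells

u, p = st.mannwhitneyu(rs_not, rs_imp, alternative='two-sided')
print("n not imp: %d, md: %.2f" % (rs_not.size, np.nanmedian(rs_not)))
print("n sig. imp: %d, md: %.2f" % (rs_imp.size, np.nanmedian(rs_imp)))
print("mannwhitneyu: u: %.4E, p: %.4E" % (u, p))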
Example #8
def equivalence_scatter(batch,
                        gc,
                        stp,
                        LN,
                        combined,
                        se_filter=True,
                        LN_filter=False,
                        plot_stat='r_ceiling',
                        enable_hover=False,
                        manual_lims=None,
                        drop_outliers=False,
                        color_improvements=True,
                        xmodel='GC',
                        ymodel='STP',
                        legend=False,
                        self_equiv=False,
                        self_eq_models=[],
                        show_highlights=False,
                        exclude_low_snr=False,
                        snr_path=None):
    '''
    model1: GC
    model2: STP
    model3: LN

    '''

    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    if plot_stat == 'r_ceiling':
        plot_df = df_c
    else:
        plot_df = df_r

    e, a, g, s, c = improved_cells_to_list(batch,
                                           gc,
                                           stp,
                                           LN,
                                           combined,
                                           se_filter=se_filter,
                                           LN_filter=LN_filter)
    improved = c
    not_improved = list(set(a) - set(c))
    if exclude_low_snr:
        snr_df = pd.read_pickle(snr_path)
        med_snr = snr_df['snr'].median()
        high_snr = snr_df.loc[snr_df['snr'] >= med_snr]
        high_snr_cells = high_snr.index.values.tolist()
        improved = list(set(improved) & set(high_snr_cells))
        not_improved = list(set(not_improved) & set(high_snr_cells))
        a = list(set(a) & set(high_snr_cells))
        # check number of animals included
        prefixes = [s[:3] for s in a]
        n_animals = len(list(set(prefixes)))
        print('n_animals: %d' % n_animals)

    models = [gc, stp, LN]
    gc_rel_imp, stp_rel_imp = _relative_score(plot_df, models, improved)
    gc_rel_not, stp_rel_not = _relative_score(plot_df, models, not_improved)
    gc_rel_all, stp_rel_all = _relative_score(plot_df, models, a)

    # LN, STP, GC
    cells_to_highlight = ['TAR010c-40-1', 'AMT005c-20-1', 'TAR009d-22-1']
    cells_to_plot_gc_rel = []
    cells_to_plot_stp_rel = []
    for c in cells_to_highlight:
        stp_rel = plot_df[stp][c] - plot_df[LN][c]
        gc_rel = plot_df[gc][c] - plot_df[LN][c]
        cells_to_plot_gc_rel.append(gc_rel)
        cells_to_plot_stp_rel.append(stp_rel)

    # compute corr. before dropping outliers (only dropping for visualization)
    r_imp, p_imp = st.pearsonr(gc_rel_imp, stp_rel_imp)
    r_not, p_not = st.pearsonr(gc_rel_not, stp_rel_not)
    r_all, p_all = st.pearsonr(gc_rel_all, stp_rel_all)

    if self_equiv:
        stp1, stp2, gc1, gc2, LN1, LN2 = self_eq_models
        _, ga, _, _, _ = improved_cells_to_list(batch,
                                                gc1,
                                                gc2,
                                                LN1,
                                                LN2,
                                                as_lists=True)
        _, sa, _, _, _ = improved_cells_to_list(batch,
                                                stp1,
                                                stp2,
                                                LN1,
                                                LN2,
                                                as_lists=True)
        aa = list(set(ga) & set(sa))

        if exclude_low_snr:
            snr_df = pd.read_pickle(snr_path)
            med_snr = snr_df['snr'].median()
            high_snr = snr_df.loc[snr_df['snr'] >= med_snr]
            high_snr_cells = high_snr.index.values.tolist()
            aa = list(set(aa) & set(high_snr_cells))

        df_r_eq = nd.batch_comp(batch, [gc1, gc2, stp1, stp2, LN1, LN2],
                                stat=plot_stat)
        df_r_eq.dropna(axis=0, how='any', inplace=True)
        df_r_eq.sort_index(inplace=True)
        df_r_eq = df_r_eq[df_r_eq.index.isin(aa)]

        gc1_rel_imp = df_r_eq[gc1].values - df_r_eq[LN1].values
        gc2_rel_imp = df_r_eq[gc2].values - df_r_eq[LN2].values
        stp1_rel_imp = df_r_eq[stp1].values - df_r_eq[LN1].values
        stp2_rel_imp = df_r_eq[stp2].values - df_r_eq[LN2].values
        r_gceq, p_gceq = st.pearsonr(gc1_rel_imp, gc2_rel_imp)
        r_stpeq, p_stpeq = st.pearsonr(stp1_rel_imp, stp2_rel_imp)
        n_eq = gc1_rel_imp.size

        # compute on same subset for full estimation data
        # to compare to cross-set
        gc_subset1 = df_r[gc][ga].values
        gc_subset2 = df_r[gc][sa].values
        stp_subset1 = df_r[stp][ga].values
        stp_subset2 = df_r[stp][sa].values
        LN_subset1 = df_r[LN][ga].values
        LN_subset2 = df_r[LN][sa].values
        gc_rel1 = gc_subset1 - LN_subset1
        gc_rel2 = gc_subset2 - LN_subset2
        stp_rel1 = stp_subset1 - LN_subset1
        stp_rel2 = stp_subset2 - LN_subset2
        r_sub1, p_sub1 = st.pearsonr(gc_rel1, stp_rel1)
        r_sub2, p_sub2 = st.pearsonr(gc_rel2, stp_rel2)

    gc_rel_imp, stp_rel_imp = drop_common_outliers(gc_rel_imp, stp_rel_imp)
    gc_rel_not, stp_rel_not = drop_common_outliers(gc_rel_not, stp_rel_not)
    gc_rel_all, stp_rel_all = drop_common_outliers(gc_rel_all, stp_rel_all)

    n_imp = len(improved)
    n_not = len(not_improved)
    n_all = len(a)

    y_max = np.max(stp_rel_all)
    y_min = np.min(stp_rel_all)
    x_max = np.max(gc_rel_all)
    x_min = np.min(gc_rel_all)

    fig = plt.figure()
    ax = plt.gca()
    ax.axes.axhline(0,
                    color='black',
                    linewidth=1,
                    linestyle='dashed',
                    dashes=dash_spacing)
    ax.axes.axvline(0,
                    color='black',
                    linewidth=1,
                    linestyle='dashed',
                    dashes=dash_spacing)
    if color_improvements:
        ax.scatter(gc_rel_not,
                   stp_rel_not,
                   c=model_colors['LN'],
                   s=small_scatter)
        ax.scatter(gc_rel_imp,
                   stp_rel_imp,
                   c=model_colors['max'],
                   s=big_scatter)
        if show_highlights:
            for i, (g, s) in enumerate(
                    zip(cells_to_plot_gc_rel, cells_to_plot_stp_rel)):
                plt.text(g, s, str(i + 1), fontsize=12, color='black')
    else:
        ax.scatter(gc_rel_all, stp_rel_all, c='black', s=big_scatter)

    if legend:
        plt.legend()

    if manual_lims is not None:
        ax.set_ylim(*manual_lims)
        ax.set_xlim(*manual_lims)
    else:
        upper = max(y_max, x_max)
        lower = min(y_min, x_min)
        upper_lim = np.ceil(10 * upper) / 10
        lower_lim = np.floor(10 * lower) / 10
        ax.set_ylim(lower_lim, upper_lim)
        ax.set_xlim(lower_lim, upper_lim)
    ax.set_aspect('equal')

    if enable_hover:
        mplcursors.cursor(ax, hover=True).connect(
            "add", lambda sel: sel.annotation.set_text(a[sel.target.index]))

    plt.tight_layout()

    fig2 = plt.figure(figsize=text_fig)
    text = ("Performance Improvements over LN\n"
            "batch: %d\n"
            "dropped outliers?:  %s\n"
            "all cells:  r: %.2E, p: %.2E, n: %d\n"
            "improved:  r: %.2E, p: %.2E, n: %d\n"
            "not imp:  r: %.2E, p: %.2E, n: %d\n"
            "x: %s - LN\n"
            "y: %s - LN" % (batch, drop_outliers, r_all, p_all, n_all, r_imp,
                            p_imp, n_imp, r_not, p_not, n_not, xmodel, ymodel))
    if self_equiv:
        text += ("\n\nSelf Equivalence, n:  %d\n"
                 "gc,  r: %.2Ef, p: %.2E\n"
                 "stp, r: %.2Ef, p: %.2E\n"
                 "sub1, r: %.2E, p: %.2E\n"
                 "sub2, r: %.2E, p: %.2E\n" %
                 (n_eq, r_gceq, p_gceq, r_stpeq, p_stpeq, r_sub1, p_sub1,
                  r_sub2, p_sub2))
    plt.text(0.1, 0.5, text)
    ax_remove_box(ax)

    return fig, fig2
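The scatter above plots each model's improvement over the LN model and reports their Pearson correlation; a minimal sketch of that relative-score computation with synthetic test correlations standing in for the batch results:

import numpy as np
import scipy.stats as st

rng = np.random.default_rng(0)
ln_r = rng.uniform(0.2, 0.8, 200)                        # stand-in LN test correlations
gc_r = np.clip(ln_r + rng.normal(0.02, 0.05, 200), 0, 1)
stp_r = np.clip(ln_r + rng.normal(0.02, 0.05, 200), 0, 1)

gc_rel = gc_r - ln_r       # improvement of GC over LN
stp_rel = stp_r - ln_r     # improvement of STP over LN
r_all, p_all = st.pearsonr(gc_rel, stp_rel)
print("all cells:  r: %.2E, p: %.2E, n: %d" % (r_all, p_all, gc_rel.size))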
Example #9
def compare_sim_fits(batch, gc, stp, LN, combined, simulation_spec=None,
                     start=0, end=None, load_path=None, skip_combined=True,
                     save_path=None, tag='', ext_start=1.1):
    if load_path is None:
        if simulation_spec is None:
            raise ValueError("simulation_spec required unless loading previous"
                              " result")
        stp_ctx = fit_to_simulation(stp, simulation_spec)
        gc_ctx = fit_to_simulation(gc, simulation_spec)
        LN_ctx = fit_to_simulation(LN, simulation_spec)
        combined_ctx = fit_to_simulation(combined, simulation_spec)

        if save_path is not None:
            results = {'simulation': simulation_spec,
                       'contexts': [stp_ctx, gc_ctx, LN_ctx, combined_ctx]}
            pickle.dump(results, open(save_path, 'wb'))
    else:
        results = pickle.load(open(load_path, 'rb'))
        simulation_spec = results['simulation']
        stp_ctx, gc_ctx, LN_ctx, combined_ctx = results['contexts']

    simulation = stp_ctx['val']['resp'].as_continuous().flatten()
    stp_pred = stp_ctx['val']['pred'].as_continuous().flatten()
    gc_pred = gc_ctx['val']['pred'].as_continuous().flatten()
    LN_pred = LN_ctx['val']['pred'].as_continuous().flatten()
    combined_pred = combined_ctx['val']['pred'].as_continuous().flatten()

    stim = stp_ctx['val']['stim'].as_continuous()
    if end is None:
        end = stim.shape[-1]

    fig1 = plt.figure(figsize=wide_fig)
    ext_stop = 1.25*(ext_start+0.1)
    plt.imshow(stim, aspect='auto', cmap=spectrogram_cmap,
               origin='lower', extent=(0, stim.shape[-1], ext_start, ext_stop))
    lw = 0.75
    plt.plot(simulation, color='gray', alpha=0.65, linewidth=lw*2)
    t = np.linspace(0, simulation.shape[-1]-1, simulation.shape[-1])
    plt.fill_between(t, simulation, color='gray', alpha=0.15)
    plt.plot(LN_pred, color='black', alpha=0.55, linewidth=lw)
    plt.plot(gc_pred, color=model_colors['gc'], linewidth=lw*1.25)
    plt.plot(stp_pred, color=model_colors['stp'], linewidth=lw*1.25)
    if not skip_combined:
        plt.plot(combined_pred, color=model_colors['combined'], \
                 linewidth=lw*1.25)

    plt.ylim(-0.1, ext_stop)
    plt.xlim(start, end)
    ax = plt.gca()
    ax_remove_box(ax)

    fig2 = plt.figure(figsize=text_fig)
    text = ("simulation_spec: %s\n"
            "cellid: %s\n"
            "tag: %s\n"
            "stp_r_test: %.4f\n"
            "gc_r_test: %.4f\n"
            "LN_r_test: %.4f\n"
            "combined_r_test: %.4f"
            % (simulation_spec.meta['modelname'],
               simulation_spec.meta['cellid'],
               tag,
               stp_ctx['modelspec'].meta['r_test'],
               gc_ctx['modelspec'].meta['r_test'],
               LN_ctx['modelspec'].meta['r_test'],
               combined_ctx['modelspec'].meta['r_test']
               ))
    plt.text(0.1, 0.5, text)

    return fig1, fig2
Example #10
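# Fragment of a parameter-comparison figure; it assumes amp_bounds, amp_mtx,
# amp_mtx_norm, ampmean, amperr, show_units, u_mtx, batch, xstr, ystr, dotcolor,
# ax_remove_box, plt (matplotlib.pyplot), and ss (scipy.stats) are already
# defined by the surrounding analysis (compare stp_parameter_comp in Example #14).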
dotcolor_ns = 'lightgray'
thinlinecolor = 'gray'
barcolors = [(211/255, 211/255, 211/255), (102/255, 1/255, 104/255)]
barwidth = 0.5

ax = plt.subplot(2, 3, 1)
plt.plot(amp_bounds, amp_bounds, 'k--')
plt.plot(amp_mtx[~show_units, 0], amp_mtx[~show_units, 1], '.',
         color=dotcolor_ns)
plt.plot(amp_mtx[show_units, 0], amp_mtx[show_units, 1], '.', color=dotcolor)
plt.title('bat {} n={}/{} good units'.format(
        batch, np.sum(show_units), u_mtx.shape[0]))
plt.xlabel(xstr+' gain')
plt.ylabel(ystr+' gain')
plt.axis('equal')
ax_remove_box(ax)

ax = plt.subplot(2, 3, 2)
plt.plot(np.array([-0.5, 1.5]), np.array([0, 0]), 'k--')
plt.bar(np.arange(2), ampmean, color=barcolors, width=barwidth)
plt.errorbar(np.arange(2), ampmean, yerr=amperr, color='black', linewidth=2)
plt.plot(amp_mtx[show_units].T, linewidth=0.5, color=thinlinecolor)

w, p = ss.wilcoxon(amp_mtx_norm[show_units, 0], amp_mtx_norm[show_units, 1])
plt.ylim(amp_bounds)
plt.ylabel('STRF gain')
plt.xlabel('{} {:.3f} - {} {:.3f} - rat {:.3f} - p<{:.5f}'.format(
        xstr, ampmean[0], ystr, ampmean[1], ampmean[1]/ampmean[0], p))
ax_remove_box(ax)

ax = plt.subplot(2, 3, 3)
Example #11
def combined_vs_max(batch,
                    gc,
                    stp,
                    LN,
                    combined,
                    se_filter=True,
                    LN_filter=False,
                    plot_stat='r_ceiling',
                    legend=False,
                    improved_only=True,
                    exclude_low_snr=False,
                    snr_path=None):

    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    #    cellids, under_chance, less_LN = get_filtered_cellids(df_r, df_e, gc, stp,
    #                                                          LN, combined,
    #                                                          se_filter,
    #                                                          LN_filter)
    e, a, g, s, c = improved_cells_to_list(batch,
                                           gc,
                                           stp,
                                           LN,
                                           combined,
                                           se_filter=se_filter,
                                           LN_filter=LN_filter)
    improved = c
    not_improved = list(set(a) - set(c))
    if exclude_low_snr:
        snr_df = pd.read_pickle(snr_path)
        med_snr = snr_df['snr'].median()
        high_snr = snr_df.loc[snr_df['snr'] >= med_snr]
        high_snr_cells = high_snr.index.values.tolist()
        improved = list(set(improved) & set(high_snr_cells))
        not_improved = list(set(not_improved) & set(high_snr_cells))

    if plot_stat == 'r_ceiling':
        plot_df = df_c
    else:
        plot_df = df_r

    gc_not = plot_df[gc][not_improved]
    gc_imp = plot_df[gc][improved]
    #gc_test_under_chance = plot_df[gc][under_chance]
    stp_not = plot_df[stp][not_improved]
    stp_imp = plot_df[stp][improved]
    #stp_test_under_chance = plot_df[stp][under_chance]
    #ln_test = plot_df[LN][cellids]
    gc_stp_not = plot_df[combined][not_improved]
    gc_stp_imp = plot_df[combined][improved]
    max_not = np.maximum(gc_not, stp_not)
    max_imp = np.maximum(gc_imp, stp_imp)
    #gc_stp_test_rel = gc_stp_test - ln_test
    #max_test_rel = np.maximum(gc_test, stp_test) - ln_test
    imp_T, imp_p = st.wilcoxon(gc_stp_imp, max_imp)
    med_combined = np.nanmedian(gc_stp_imp)
    med_max = np.nanmedian(max_imp)


    fig1 = plt.figure()
    c_not = model_colors['LN']
    c_imp = model_colors['max']
    if not improved_only:
        plt.scatter(max_not,
                    gc_stp_not,
                    c=c_not,
                    s=small_scatter,
                    label='no imp.')
    plt.scatter(max_imp, gc_stp_imp, c=c_imp, s=big_scatter, label='sig. imp.')
    ax = fig1.axes[0]
    plt.plot(ax.get_xlim(),
             ax.get_xlim(),
             'k--',
             linewidth=1,
             dashes=dash_spacing)
    if legend:
        plt.legend()
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.tight_layout()
    ax.set_aspect('equal')
    ax_remove_box()

    fig2 = plt.figure(figsize=text_fig)
    text = ("batch: %d\n"
            "x: Max(GC,STP)\n"
            "y: GC+STP\n"
            "wilcoxon: T: %.4E, p: %.4E\n"
            "combined median: %.4E\n"
            "max median: %.4E" % (batch, imp_T, imp_p, med_combined, med_max))
    plt.text(0.1, 0.5, text)

    return fig1, fig2
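The statistic reported above is a paired Wilcoxon signed-rank test of the combined model against the per-cell max of GC and STP; a self-contained sketch with synthetic per-cell correlations:

import numpy as np
import scipy.stats as st

rng = np.random.default_rng(0)
# Stand-in per-cell prediction correlations for the GC, STP, and combined models.
gc_r = rng.uniform(0.2, 0.9, 100)
stp_r = np.clip(gc_r + rng.normal(0, 0.05, 100), 0, 1)
combined_r = np.clip(np.maximum(gc_r, stp_r) + rng.normal(0, 0.02, 100), 0, 1)

max_r = np.maximum(gc_r, stp_r)
T, p = st.wilcoxon(combined_r, max_r)    # paired, per-cell comparison
print("combined median: %.4E, max median: %.4E"
      % (np.median(combined_r), np.median(max_r)))
print("wilcoxon: T: %.4E, p: %.4E" % (T, p))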
Example #12
File: snr.py Project: LBHB/nems_db
def snr_vs_equivalence(snr_path, stp_path, gc_path):
    stp_equiv_df = pd.read_pickle(stp_path)
    gc_equiv_df = pd.read_pickle(gc_path)
    snr_df = pd.read_pickle(snr_path)
    cellids = list(
        set(stp_equiv_df.index.values.tolist())
        & set(gc_equiv_df.index.values.tolist()))

    snr_df = snr_df.loc[cellids].reindex(cellids)
    stp_equiv_df = stp_equiv_df.loc[cellids].reindex(cellids)
    gc_equiv_df = gc_equiv_df.loc[cellids].reindex(cellids)

    stp_equivs = stp_equiv_df['equivalence'].values
    gc_equivs = gc_equiv_df['equivalence'].values
    snrs = snr_df['snr'].values

    md_snr = np.nanmedian(snrs)
    low_snr_mask = snrs < md_snr
    high_snr_mask = snrs >= md_snr
    stp_low_equivs = stp_equivs[low_snr_mask]
    stp_high_equivs = stp_equivs[high_snr_mask]
    gc_low_equivs = gc_equivs[low_snr_mask]
    gc_high_equivs = gc_equivs[high_snr_mask]

    md_stp_low = np.median(stp_low_equivs)
    md_stp_high = np.median(stp_high_equivs)
    md_gc_low = np.median(gc_low_equivs)
    md_gc_high = np.median(gc_high_equivs)

    u_stp, p_stp = st.mannwhitneyu(stp_low_equivs,
                                   stp_high_equivs,
                                   alternative='two-sided')
    u_gc, p_gc = st.mannwhitneyu(gc_low_equivs,
                                 gc_high_equivs,
                                 alternative='two-sided')

    #r_stp, p_stp = st.pearsonr(stp_equivs, snrs)
    #r_gc, p_gc = st.pearsonr(gc_equivs, snrs)

    fig1 = plt.figure(figsize=small_fig)
    ax1 = plt.gca()
    plt.scatter(snrs, stp_equivs, c=model_colors['stp'], s=big_scatter)
    ax1.axes.axvline(md_snr,
                     color='black',
                     linewidth=1,
                     linestyle='dashed',
                     dashes=dash_spacing)
    plt.tight_layout()
    plt.subplots_adjust(left=0.25)
    ax_remove_box(ax1)

    fig2 = plt.figure(figsize=small_fig)
    ax2 = plt.gca()
    plt.scatter(snrs, gc_equivs, c=model_colors['gc'], s=big_scatter)
    ax2.axes.axvline(md_snr,
                     color='black',
                     linewidth=1,
                     linestyle='dashed',
                     dashes=dash_spacing)
    plt.tight_layout()
    plt.subplots_adjust(left=0.25)
    ax_remove_box(ax2)

    ymin1, ymax1 = ax1.get_ylim()
    ymin2, ymax2 = ax2.get_ylim()
    ax1.set_ylim(min(ymin1, ymin2), max(ymax1, ymax2))
    ax2.set_ylim(min(ymin1, ymin2), max(ymax1, ymax2))

    fig3 = plt.figure(figsize=text_fig)
    text = ("SNR vs equivalence\n"
            "x axis: signal power / total power\n"
            "y axis: equivalence (partial corr)\n"
            "mannwhitneyu two sided low vs high snr\n"
            "n_high: %d\n"
            "n_low: %d\n"
            "u_stp: %.4E\n"
            "p_stp: %.4E\n"
            "md_stp_low: %.4E\n"
            "md_stp_high: %.4E\n"
            "u_gc: %.4E\n"
            "p_gc: %.4E\n"
            "md_gc_low: %.4E\n"
            "md_gc_high: %.4E" %
            (stp_high_equivs.size, stp_low_equivs.size, u_stp, p_stp,
             md_stp_low, md_stp_high, u_gc, p_gc, md_gc_low, md_gc_high))

    plt.text(0.1, 0.5, text)

    return fig1, fig2, fig3
Example #13
def rate_histogram(batch,
                   gc,
                   stp,
                   LN,
                   combined,
                   load_path,
                   rate_type='mean',
                   plot_stat='r_ceiling',
                   fs=100,
                   allow_overlap=True):

    e, a, g, s, c = improved_cells_to_list(batch, gc, stp, LN, combined)
    if allow_overlap:
        gc_imp = g
        stp_imp = s
        not_imp = list(set(a) - set(c))
        both_imp = list((set(g) & set(s)) | set(c))
    else:
        gc_imp = list(set(g) - set(s))
        stp_imp = list(set(s) - set(g))
        not_imp = list(set(a) - set(c) - set(g) - set(s))
        # either both improve or only combined improve
        both_imp = list((set(g) & set(s)) | ((set(c) - set(s) - set(g))))

    df = pd.read_pickle(load_path)
    # * fs to get spikes per second
    #rates_LN = df.loc[not_imp]['rate'].values * fs
    rates_LN = df[df.index.isin(not_imp)]['rate'].values * fs
    #rates_gc = df.loc[gc_imp]['rate'].values * fs
    rates_gc = df[df.index.isin(gc_imp)]['rate'].values * fs
    #rates_stp = df.loc[stp_imp]['rate'].values * fs
    rates_stp = df[df.index.isin(stp_imp)]['rate'].values * fs
    rates_both = df[df.index.isin(both_imp)]['rate'].values * fs

    xmax = max(rates_gc.max(), rates_stp.max(), rates_both.max()) * 1.10

    md_LN = np.nanmedian(rates_LN)
    n_LN = rates_LN.size
    md_gc = np.nanmedian(rates_gc)
    n_gc = rates_gc.size
    md_stp = np.nanmedian(rates_stp)
    n_stp = rates_stp.size
    md_both = np.nanmedian(rates_both)
    n_both = rates_both.size
    # vs each other
    u, p = st.mannwhitneyu(rates_gc, rates_stp, alternative='two-sided')
    u_both_gc, p_both_gc = st.mannwhitneyu(rates_both,
                                           rates_gc,
                                           alternative='two-sided')
    u_both_stp, p_both_stp = st.mannwhitneyu(rates_both,
                                             rates_stp,
                                             alternative='two-sided')
    # vs LN
    u_gc, p_gc = st.mannwhitneyu(rates_gc, rates_LN, alternative='two-sided')
    u_stp, p_stp = st.mannwhitneyu(rates_stp,
                                   rates_LN,
                                   alternative='two-sided')
    u_both, p_both = st.mannwhitneyu(rates_both,
                                     rates_LN,
                                     alternative='two-sided')

    # TODO: stat comparison for both group

    weights1 = [np.ones(len(rates_stp)) / len(rates_stp)]
    weights2 = [np.ones(len(rates_gc)) / len(rates_gc)]
    weights3 = [np.ones(len(rates_LN)) / len(rates_LN)]
    weights4 = [np.ones(len(rates_both)) / len(rates_both)]

    fig, (a3, a1, a2, a4) = plt.subplots(4, 1, figsize=tall_fig)
    axes = [a3, a1, a2, a4]
    a1.hist(rates_stp,
            bins=30,
            range=[0, xmax],
            weights=weights1,
            fc=faded_stp,
            edgecolor=dark_stp,
            linewidth=1)
    a2.hist(rates_gc,
            bins=30,
            range=[0, xmax],
            weights=weights2,
            fc=faded_gc,
            edgecolor=dark_gc,
            linewidth=1)
    a3.hist(rates_LN,
            bins=30,
            range=[0, xmax],
            weights=weights3,
            fc=model_colors['LN'],
            alpha=0.5,
            edgecolor=dark_LN,
            linewidth=1)
    a4.hist(rates_both,
            bins=30,
            range=[0, xmax],
            weights=weights4,
            fc=model_colors['combined'],
            alpha=0.5,
            edgecolor=dark_combined,
            linewidth=1)

    for ax in axes:
        ax.axes.axvline(md_stp,
                        color=dark_stp,
                        linewidth=2,
                        linestyle='dashed',
                        dashes=dash_spacing)
        ax.axes.axvline(md_gc,
                        color=dark_gc,
                        linewidth=2,
                        linestyle='dashed',
                        dashes=dash_spacing)
        ax.axes.axvline(md_LN,
                        color=dark_LN,
                        linewidth=2,
                        linestyle='dashed',
                        dashes=dash_spacing)
        ax.axes.axvline(md_both,
                        color=dark_combined,
                        linewidth=2,
                        linestyle='dashed',
                        dashes=dash_spacing)

    ymaxes = [ax.get_ylim()[1] for ax in axes]
    ymins = [ax.get_ylim()[0] for ax in axes]
    ymax = max(ymaxes)
    ymin = min(ymins)
    for ax in axes:
        ax.set_ylim(ymin, ymax)
    fig.tight_layout()
    ax_remove_box(a1)
    ax_remove_box(a2)
    ax_remove_box(a3)
    ax_remove_box(a4)

    fig2 = plt.figure(figsize=text_fig)
    text = ("%s rate blocked by model improvement\n"
            "gc, md:  %.2E, n:  %d\n"
            "stp, md: %.2E, n:  %d\n"
            "LN, md:  %.2E, n:  %d\n"
            "both, md:%.2E, n:  %d\n"
            "m.w. stp v gc: u:  %.4E, p:  %.4E\n"
            "gc vs LN:      u:  %.4E, p:  %.4E\n"
            "stp vs LN:     u:  %.4E, p:  %.4E\n"
            "both vs gc:    u:  %.4E, p:  %.4E\n"
            "both vs stp:   u:  %.4E, p:  %.4E\n"
            "both vs LN:    u:  %.4E, p:  %.4E\n" %
            (rate_type, md_gc, n_gc, md_stp, n_stp, md_LN, n_LN, md_both,
             n_both, u, p, u_gc, p_gc, u_stp, p_stp, u_both_gc, p_both_gc,
             u_both_stp, p_both_stp, u_both, p_both))
    plt.text(0.1, 0.5, text)

    return fig, fig2
Example #14
def stp_parameter_comp(batch, modelname, modelname0=None):

    d = nems_db.params.fitted_params_per_batch(batch,
                                               modelname,
                                               stats_keys=[],
                                               multi='first')

    u_bounds = np.array([-0.6, 2.1])
    tau_bounds = np.array([-0.1, 1.5])
    str_bounds = np.array([-0.25, 0.55])
    amp_bounds = np.array([-1, 1.5])

    indices = list(d.index)

    fir_index = None
    do_index = None
    for ind in indices:
        if '--u' in ind:
            u_index = ind
        elif '--tau' in ind:
            tau_index = ind
        elif '--fir' in ind:
            fir_index = ind
        elif ('--do' in ind) and ('gains' in ind):
            do_index = ind

    u = d.loc[u_index]
    tau = d.loc[tau_index]

    if fir_index:
        fir = d.loc[fir_index]
    elif do_index:
        fir = d.loc[do_index]
        delay_index = do_index.replace('gains', 'delays')
        f1s_index = do_index.replace('gains', 'f1s')
        taus_index = do_index.replace('gains', 'taus')
        for cellid in fir.index:
            print(cellid)
            c = da_coefficients(f1s=d.loc[f1s_index, cellid],
                                taus=d.loc[taus_index, cellid],
                                delays=d.loc[delay_index, cellid],
                                gains=d.loc[do_index, cellid],
                                n_coefs=10)
            fir[cellid] = c
    else:
        raise ValueError('FIR/DO index not found')
    r_test = d.loc['meta--r_test']
    se_test = d.loc['meta--se_test']
    print(u)
    if modelname0 is not None:
        d0 = nems_db.params.fitted_params_per_batch(batch,
                                                    modelname0,
                                                    stats_keys=[],
                                                    multi='first')
        r0_test = d0.loc['meta--r_test']
        se0_test = d0.loc['meta--se_test']

    u_mtx = np.zeros((len(u), 2))
    tau_mtx = np.zeros_like(u_mtx)
    m_fir = np.zeros_like(u_mtx)
    r_test_mtx = np.zeros(len(u))
    r0_test_mtx = np.zeros(len(u))
    se_test_mtx = np.zeros(len(u))
    se0_test_mtx = np.zeros(len(u))
    str_mtx = np.zeros_like(u_mtx)

    i = 0
    for cellid in u.index:
        r_test_mtx[i] = r_test[cellid]
        se_test_mtx[i] = se_test[cellid]
        if modelname0 is not None:
            r0_test_mtx[i] = r0_test[cellid]
            se0_test_mtx[i] = se0_test[cellid]

        t_fir = fir[cellid]
        x = np.mean(t_fir, axis=1) / np.std(t_fir)
        mn, = np.where(x == np.min(x))
        mx, = np.where(x == np.max(x))
        xidx = np.array([mx[0], mn[0]])
        m_fir[i, :] = x[xidx]
        u_mtx[i, :] = u[cellid][xidx]
        tau_mtx[i, :] = np.abs(tau[cellid][xidx])
        str_mtx[i, :] = stp_magnitude(tau_mtx[i, :], u_mtx[i, :], fs=100,
                                      A=1)[0]
        i += 1

    # EI_units = (m_fir[:,0]>0) & (m_fir[:,1]<0)
    EI_units = (m_fir[:, 1] < 0)
    #good_pred = (r_test_mtx > se_test_mtx*2)
    good_pred = ((r_test_mtx > se_test_mtx * 3) |
                 (r0_test_mtx > se0_test_mtx * 3))

    mod_units = (r_test_mtx - se_test_mtx) > (r0_test_mtx + se0_test_mtx)

    show_units = mod_units & good_pred

    u_mtx[u_mtx < u_bounds[0]] = u_bounds[0]
    u_mtx[u_mtx > u_bounds[1]] = u_bounds[1]
    tau_mtx[tau_mtx > tau_bounds[1]] = tau_bounds[1]
    str_mtx[str_mtx < str_bounds[0]] = str_bounds[0]
    str_mtx[str_mtx > str_bounds[1]] = str_bounds[1]
    m_fir[m_fir < amp_bounds[0]] = amp_bounds[0]
    m_fir[m_fir > amp_bounds[1]] = amp_bounds[1]

    umean = np.median(u_mtx[show_units], axis=0)
    uerr = np.std(u_mtx[show_units], axis=0) / np.sqrt(np.sum(show_units))
    taumean = np.median(tau_mtx[show_units], axis=0)
    tauerr = np.std(tau_mtx[show_units], axis=0) / np.sqrt(str_mtx.shape[0])
    strmean = np.median(str_mtx[show_units], axis=0)
    strerr = np.std(str_mtx[show_units], axis=0) / np.sqrt(str_mtx.shape[0])

    xstr = 'E'
    ystr = 'I'

    fh = plt.figure(figsize=(8, 5))

    dotcolor = 'black'
    dotcolor_ns = 'lightgray'
    thinlinecolor = 'gray'
    barcolors = [(235 / 255, 47 / 255, 40 / 255),
                 (115 / 255, 200 / 255, 239 / 255)]
    barwidth = 0.5

    ax = plt.subplot(2, 3, 1)
    plt.plot(np.array(amp_bounds), np.array(amp_bounds), 'k--')
    plt.plot(m_fir[~show_units, 0],
             m_fir[~show_units, 1],
             '.',
             color=dotcolor_ns)
    plt.plot(m_fir[show_units, 0], m_fir[show_units, 1], '.', color=dotcolor)
    plt.title('n={}/{} good units'.format(np.sum(show_units),
                                          np.sum(good_pred)))
    plt.xlabel('exc channel gain')
    plt.ylabel('inh channel gain')
    ax_remove_box(ax)

    ax = plt.subplot(2, 3, 2)
    plt.plot(u_bounds, u_bounds, 'k--')
    plt.plot(u_mtx[~show_units, 0],
             u_mtx[~show_units, 1],
             '.',
             color=dotcolor_ns)
    plt.plot(u_mtx[show_units, 0], u_mtx[show_units, 1], '.', color=dotcolor)
    plt.axis('equal')
    plt.xlabel('exc channel u')
    plt.ylabel('inh channel u')
    plt.ylim(u_bounds)
    ax_remove_box(ax)

    ax = plt.subplot(2, 3, 3)
    plt.plot(str_bounds, str_bounds, 'k--')
    plt.plot(str_mtx[~show_units, 0],
             str_mtx[~show_units, 1],
             '.',
             color=dotcolor_ns)
    plt.plot(str_mtx[show_units, 0],
             str_mtx[show_units, 1],
             '.',
             color=dotcolor)
    plt.axis('equal')
    plt.xlabel('exc channel str')
    plt.ylabel('inh channel str')
    plt.ylim(str_bounds)
    ax_remove_box(ax)

    ax = plt.subplot(2, 3, 4)
    plt.plot(np.array([-0.5, 1.5]), np.array([0, 0]), 'k--')
    plt.bar(np.arange(2), umean, color=barcolors, width=barwidth)
    plt.errorbar(np.arange(2), umean, yerr=uerr, color='black', linewidth=2)
    plt.plot(u_mtx[show_units].T, linewidth=0.5, color=thinlinecolor)
    #    plt.plot(np.random.normal(0, 0.05, size=u_mtx[show_units, 0].shape),
    #             u_mtx[show_units, 0], '.', color=dotcolor)
    #    plt.plot(np.random.normal(1, 0.05, size=u_mtx[show_units, 0].shape),
    #             u_mtx[show_units, 1], '.', color=dotcolor)

    w, p = ss.wilcoxon(u_mtx[show_units, 0], u_mtx[show_units, 1])
    plt.ylim(u_bounds)
    plt.ylabel('u')
    plt.xlabel('{} {:.3f} - {} {:.3f} - rat {:.3f} - p={:.1e}'.format(
        xstr, umean[0], ystr, umean[1], umean[1] / umean[0], p))
    ax_remove_box(ax)

    ax = plt.subplot(2, 3, 5)
    plt.plot(np.array([-0.5, 1.5]), np.array([0, 0]), 'k--')
    plt.bar(np.arange(2), np.sqrt(taumean), color=barcolors, width=barwidth)
    plt.errorbar(np.arange(2),
                 np.sqrt(taumean),
                 yerr=np.sqrt(tauerr),
                 color='black',
                 linewidth=2)
    plt.plot(np.sqrt(tau_mtx[show_units].T),
             linewidth=0.5,
             color=thinlinecolor)

    w, p = ss.wilcoxon(tau_mtx[show_units, 0], tau_mtx[show_units, 1])
    plt.ylim((-np.sqrt(np.abs(tau_bounds[0])), np.sqrt(tau_bounds[1])))
    plt.ylabel('sqrt(tau)')
    plt.xlabel('E {:.3f} - I {:.3f} - rat {:.3f} - p={:.1e}'.format(
        taumean[0], taumean[1], taumean[1] / taumean[0], p))
    ax_remove_box(ax)

    ax = plt.subplot(2, 3, 6)
    plt.plot(np.array([-0.5, 1.5]), np.array([0, 0]), 'k--')
    plt.bar(np.arange(2), strmean, color=barcolors, width=barwidth)
    plt.errorbar(np.arange(2),
                 strmean,
                 yerr=strerr,
                 color='black',
                 linewidth=2)
    plt.plot(str_mtx[show_units].T, linewidth=0.5, color=thinlinecolor)

    w, p = ss.wilcoxon(str_mtx[show_units, 0], str_mtx[show_units, 1])
    plt.ylim(str_bounds)
    plt.ylabel('STP str')
    plt.xlabel('E {:.3f} - I {:.3f} - rat {:.3f} - p={:.1e}'.format(
        strmean[0], strmean[1], strmean[1] / strmean[0], p))
    ax_remove_box(ax)

    plt.tight_layout()

    return fh
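
# A minimal, self-contained sketch of the paired comparison performed above:
# a Wilcoxon signed-rank test between the two columns (E vs. I channel) of a
# parameter matrix, restricted to a boolean mask of "good" units. The data
# here are synthetic placeholders; only numpy and scipy.stats (as ss, matching
# the alias used above) are assumed.
import numpy as np
import scipy.stats as ss

rng = np.random.default_rng(0)
param_mtx = rng.normal(loc=[0.2, 0.1], scale=0.1, size=(50, 2))  # E, I columns
show_units = rng.random(50) > 0.2                                # keep ~80% of units

w, p = ss.wilcoxon(param_mtx[show_units, 0], param_mtx[show_units, 1])
print("E median %.3f - I median %.3f - p=%.1e"
      % (np.median(param_mtx[show_units, 0]),
         np.median(param_mtx[show_units, 1]), p))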
Example no. 15
def model_comp_pareto(batch,
                      modelgroups,
                      ax,
                      cellids,
                      nparms_modelgroups=None,
                      dot_colors=None,
                      dot_markers=None,
                      fill_styles=None,
                      plot_stat='r_test',
                      plot_medians=False,
                      labeled_models=None,
                      show_legend=True):

    if labeled_models is None:
        labeled_models = []
    if nparms_modelgroups is None:
        nparms_modelgroups = copy.copy(modelgroups)
    if fill_styles is None:
        fill_styles = {k: 'full' for (k, v) in dot_colors.items()}
    mean_cells_per_site = len(cellids)  # NAT4 dataset, so all cellids are used
    overall_min = 100
    overall_max = -100

    all_model_means = []
    labeled_data = []
    for k, modelnames in modelgroups.items():
        np_modelnames = nparms_modelgroups[k]
        b_ceiling = nd.batch_comp(batch,
                                  modelnames,
                                  cellids=cellids,
                                  stat=plot_stat)
        b_n = nd.batch_comp(batch,
                            np_modelnames,
                            cellids=cellids,
                            stat='n_parms')
        if not plot_medians:
            model_mean = b_ceiling.mean()
        else:
            model_mean = b_ceiling.median()
        b_m = np.array(model_mean)
        print(
            f"{k} modelcount {len(modelnames)} fits per model: {b_n.count().values}"
        )
        n_parms = np.array([np.mean(b_n[m]) for m in np_modelnames])

        # don't divide by cells per site if only one cell was fit
        if ('single' not in k) and (k != 'LN') and (k != 'stp'):
            n_parms = n_parms / mean_cells_per_site

        y_max = b_m.max() * 1.05
        y_min = b_m.min() * 0.9
        overall_max = max(overall_max, y_max)
        overall_min = min(overall_min, y_min)

        ax.plot(n_parms,
                b_m,
                color=dot_colors[k],
                marker=dot_markers[k],
                label=k,
                markersize=4.5,
                fillstyle=fill_styles[k])
        for m in labeled_models:
            if m in modelnames:
                i = modelnames.index(m)
                labeled_data.append([n_parms[i], b_m[i]])

        all_model_means.append(model_mean)

    if labeled_data:
        ax.plot(*list(zip(*labeled_data)),
                linestyle='none',
                color='black',
                marker='o',
                fillstyle='none',
                markersize=10)

    handles, labels = ax.get_legend_handles_labels()
    # labels are also returned to the caller below
    if show_legend:
        ax.legend(handles,
                  labels,
                  loc='lower right',
                  fontsize=7,
                  frameon=False)
    ax.set_xlabel('Free parameters per neuron')
    if plot_medians:
        ax.set_ylabel('Median prediction correlation')
    else:
        ax.set_ylabel('Mean prediction correlation')
    ax.set_ylim((overall_min, overall_max))
    ax_remove_box(ax)
    plt.tight_layout()

    return ax, b_ceiling, all_model_means, labels
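
# Hedged usage sketch for the pareto-style summary above: each model family is
# drawn as one line of (parameter count, prediction correlation) points. All
# values, family names and styling here are illustrative placeholders, not the
# model groups defined elsewhere in this module; only numpy and matplotlib are
# assumed.
import numpy as np
import matplotlib.pyplot as plt

families = {
    'LN':   (np.array([20, 40, 80]),  np.array([0.30, 0.34, 0.36])),
    'conv': (np.array([30, 60, 120]), np.array([0.35, 0.42, 0.45])),
}
fig, ax = plt.subplots(figsize=(4, 3))
for name, (n_parms, r) in families.items():
    ax.plot(n_parms, r, marker='o', markersize=4.5, label=name)
ax.set_xlabel('Free parameters per neuron')
ax.set_ylabel('Mean prediction correlation')
ax.legend(loc='lower right', fontsize=7, frameon=False)
fig.tight_layout()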
Example no. 16
def single_scatter(batch,
                   gc,
                   stp,
                   LN,
                   combined,
                   compare,
                   plot_stat='r_ceiling',
                   legend=False):
    all_batch_cells = nd.get_batch_cells(batch, as_list=True)
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    e, a, g, s, c = improved_cells_to_list(batch,
                                           gc,
                                           stp,
                                           LN,
                                           combined,
                                           se_filter=True,
                                           LN_filter=False,
                                           as_lists=True)

    if plot_stat == 'r_ceiling':
        plot_df = df_c
    else:
        plot_df = df_r

    improved = c
    not_improved = list(set(a) - set(c))
    models = [gc, stp, LN, combined]
    names = ['gc', 'stp', 'LN', 'combined']
    m1 = models[compare[0]]
    m2 = models[compare[1]]
    name1 = names[compare[0]]
    name2 = names[compare[1]]
    n_batch = len(all_batch_cells)
    n_all = len(a)
    n_imp = len(improved)
    n_not_imp = len(not_improved)

    m1_scores = plot_df[m1][not_improved]
    m1_scores_improved = plot_df[m1][improved]
    m2_scores = plot_df[m2][not_improved]
    m2_scores_improved = plot_df[m2][improved]

    fig = plt.figure()
    plt.plot([0, 1], [0, 1],
             color='black',
             linewidth=1,
             linestyle='dashed',
             dashes=dash_spacing)
    plt.scatter(m1_scores,
                m2_scores,
                s=small_scatter,
                label='no imp.',
                color=model_colors['LN'])
    #color='none',
    #edgecolors='black', linewidth=0.35)
    plt.scatter(m1_scores_improved,
                m2_scores_improved,
                s=big_scatter,
                label='sig. imp.',
                color=model_colors['max'])
    #color='none',
    #edgecolors='black', linewidth=0.35)
    ax_remove_box()

    if legend:
        plt.legend()
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.tight_layout()
    plt.gca().set_aspect('equal')

    fig2 = plt.figure(figsize=text_fig)
    plt.text(
        0.1, 0.5, "batch %d\n"
        "%d/%d  auditory/total cells\n"
        "%d no improvements\n"
        "%d at least one improvement\n"
        "stat: %s, x: %s, y: %s" %
        (batch, n_all, n_batch, n_not_imp, n_imp, plot_stat, name1, name2))

    return fig, fig2
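
# Minimal sketch of the model-vs-model scatter above: per-cell prediction
# scores for two models plotted against each other with a unity line, and
# "improved" cells drawn larger. Scores, threshold and colors are synthetic
# placeholders; only numpy and matplotlib are assumed.
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(1)
m1_scores = rng.uniform(0.1, 0.9, 100)
m2_scores = np.clip(m1_scores + rng.normal(0.02, 0.05, 100), 0, 1)
improved = (m2_scores - m1_scores) > 0.05

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], color='black', linewidth=1, linestyle='dashed')
ax.scatter(m1_scores[~improved], m2_scores[~improved], s=10, color='gray',
           label='no imp.')
ax.scatter(m1_scores[improved], m2_scores[improved], s=30, color='red',
           label='sig. imp.')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_aspect('equal')
ax.legend()
fig.tight_layout()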
Example no. 17
def gc_distributions(batch,
                     gc,
                     stp,
                     LN,
                     combined,
                     se_filter=True,
                     good_ln=0,
                     use_combined=False):
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    cellids, under_chance, less_LN = get_filtered_cellids(batch,
                                                          gc,
                                                          stp,
                                                          LN,
                                                          combined,
                                                          as_lists=False)
    _, _, _, _, c = improved_cells_to_list(batch,
                                           gc,
                                           stp,
                                           LN,
                                           combined,
                                           good_ln=good_ln)

    if use_combined:
        params_model = combined
    else:
        params_model = gc
    gc_params = fitted_params_per_batch(batch,
                                        params_model,
                                        stats_keys=[],
                                        meta=[])
    gc_params_cells = gc_params.transpose().index.values.tolist()
    for cell in gc_params_cells:
        if cell not in cellids:
            cellids[cell] = False
    not_c = list(set(gc_params.transpose()[cellids].index.values) - set(c))

    # index keys are formatted like "4--dsig.d--kappa"
    mod_keys = params_model.split('_')[1]
    for i, k in enumerate(mod_keys.split('-')):
        if 'dsig' in k:
            break
    b_key = f'{i}--{k}--base'
    a_key = f'{i}--{k}--amplitude'
    s_key = f'{i}--{k}--shift'
    k_key = f'{i}--{k}--kappa'
    ka_key = k_key + '_mod'
    ba_key = b_key + '_mod'
    aa_key = a_key + '_mod'
    sa_key = s_key + '_mod'
    all_keys = [b_key, a_key, s_key, k_key, ba_key, aa_key, sa_key, ka_key]

    phi_dfs = [
        gc_params[gc_params.index == k].transpose()[cellids].transpose()
        for k in all_keys
    ]
    sep_dfs = [df[not_c].values.flatten().astype(np.float64) for df in phi_dfs]
    gc_sep_dfs = [df[c].values.flatten().astype(np.float64) for df in phi_dfs]

    # removing extreme outliers b/c kept getting one or two cells with
    # values that were multiple orders of magnitude different than all others
    #    diffs = [sep_dfs[i+1] - sep_dfs[i]
    #             for i, _ in enumerate(sep_dfs[:-1])
    #             if i % 2 == 0]
    #diffs = sep_dfs[1::2] - sep_dfs[::2]

    #    gc_diffs = [gc_sep_dfs[i+1] - gc_sep_dfs[i]
    #                for i, _ in enumerate(gc_sep_dfs[:-1])
    #                if i % 2 == 0]
    #gc_diffs = gc_sep_dfs[1::2] - gc_sep_dfs[::2]

    raw_low, raw_high = sep_dfs[:4], sep_dfs[4:]
    diffs = [high - low for low, high in zip(raw_low, raw_high)]
    medians = [np.median(d) for d in diffs]
    medians_low = [np.median(d) for d in raw_low]
    medians_high = [np.median(d) for d in raw_high]

    gc_raw_low, gc_raw_high = gc_sep_dfs[:4], gc_sep_dfs[4:]
    gc_diffs = [high - low for low, high in zip(gc_raw_low, gc_raw_high)]

    gc_medians = [np.median(d) for d in gc_diffs]
    gc_medians_low = [np.median(d) for d in gc_raw_low]
    gc_medians_high = [np.median(d) for d in gc_raw_high]

    ts, ps = zip(*[
        st.mannwhitneyu(diff, gc_diff, alternative='two-sided')
        for diff, gc_diff in zip(diffs, gc_diffs)
    ])

    diffs = drop_common_outliers(*diffs)
    gc_diffs = drop_common_outliers(*gc_diffs)
    not_imp_outliers = len(diffs[0])
    imp_outliers = len(gc_diffs[0])

    color = model_colors['LN']
    c_color = model_colors['max']
    gc_label = 'GC ++ (%d)' % len(c)
    total_cells = len(c) + len(not_c)
    hist_kwargs = {'label': ['no imp', 'sig imp'], 'linewidth': 1}

    figs = []
    for i, name in zip([0, 1, 2, 3], ['base', 'amplitude', 'shift', 'kappa']):
        f1 = _stacked_hists(diffs[i],
                            gc_diffs[i],
                            medians[i],
                            gc_medians[i],
                            color,
                            c_color,
                            hist_kwargs=hist_kwargs)
        f2 = plt.figure(figsize=text_fig)
        text = ("%s distributions, n: %d\n"
                "n gc imp (bot): %d, med: %.4f\n"
                "n not imp (top): %d, med: %.4f\n"
                "yaxes: fraction of cells\n"
                "xaxis: 'fractional change in parameter per unit contrast'\n"
                "st.mannwhitneyu u: %.4E,\np: %.4E\n"
                "not imp w/o outliers: %d\n"
                "imp w/o outliers: %d" %
                (name, total_cells, len(c), gc_medians[i], len(not_c),
                 medians[i], ts[i], ps[i], not_imp_outliers, imp_outliers))
        plt.text(0.1, 0.5, text)
        figs.append(f1)
        figs.append(f2)

    f3 = plt.figure(figsize=small_fig)
    # median gc effect plots
    yin1, out1 = gc_dummy_sigmoid(*medians_low, low=0.0, high=0.3)
    yin2, out2 = gc_dummy_sigmoid(*medians_high, low=0.0, high=0.3)
    plt.scatter(yin1, out1, color=color, s=big_scatter, alpha=0.3)
    plt.scatter(yin2, out2, color=color, s=big_scatter * 2)
    figs.append(f3)
    plt.tight_layout()
    ax_remove_box()

    f3a = plt.figure(figsize=text_fig)
    text = ("non improved cells\n"
            "median low contrast:\n"
            "base:  %.4f,   amplitude:  %.4f\n"
            "shift:  %.4f,   kappa:  %.4f\n"
            "median high contrast:\n"
            "base:  %.4f,   amplitude:  %.4f\n"
            "shift:  %.4f,   kappa:  %.4f\n" % (*medians_low, *medians_high))
    plt.text(0.1, 0.5, text)
    figs.append(f3a)

    f4 = plt.figure(figsize=small_fig)
    gc_yin1, gc_out1 = gc_dummy_sigmoid(*gc_medians_low, low=0.0, high=0.3)
    gc_yin2, gc_out2 = gc_dummy_sigmoid(*gc_medians_high, low=0.0, high=0.3)
    plt.scatter(gc_yin1, gc_out1, color=c_color, s=big_scatter, alpha=0.3)
    plt.scatter(gc_yin2, gc_out2, color=c_color, s=big_scatter * 2)
    figs.append(f4)
    plt.tight_layout()
    ax_remove_box()

    f4a = plt.figure(figsize=text_fig)
    text = ("improved cells\n"
            "median low contrast:\n"
            "base:  %.4f,   amplitude:  %.4f\n"
            "shift:  %.4f,   kappa:  %.4f\n"
            "median high contrast:\n"
            "base:  %.4f,   amplitude:  %.4f\n"
            "shift:  %.4f,   kappa:  %.4f\n" %
            (*gc_medians_low, *gc_medians_high))
    plt.text(0.1, 0.5, text)
    figs.append(f4a)

    return figs
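
# Self-contained sketch of the distribution comparison used above: parameter
# changes for "improved" vs. "non-improved" cells compared with a two-sided
# Mann-Whitney U test and shown as stacked histograms normalized to the
# fraction of cells in each group. Synthetic data; assumes numpy, matplotlib
# and scipy.stats imported as st (the alias used above).
import numpy as np
import scipy.stats as st
import matplotlib.pyplot as plt

rng = np.random.default_rng(2)
not_imp = rng.normal(0.0, 0.2, 200)   # e.g. change in kappa, non-improved cells
imp = rng.normal(0.15, 0.2, 60)       # improved cells

u, p = st.mannwhitneyu(not_imp, imp, alternative='two-sided')

bins = np.linspace(min(not_imp.min(), imp.min()),
                   max(not_imp.max(), imp.max()), 31)
fig, (a1, a2) = plt.subplots(2, 1, sharex=True, sharey=True)
a1.hist(not_imp, bins=bins, weights=np.ones_like(not_imp) / len(not_imp))
a2.hist(imp, bins=bins, weights=np.ones_like(imp) / len(imp))
a1.set_title("medians %.3f vs %.3f, p=%.1e"
             % (np.median(not_imp), np.median(imp), p))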
Example no. 18
def performance_bar(batch,
                    gc,
                    stp,
                    LN,
                    combined,
                    se_filter=True,
                    LN_filter=False,
                    manual_cellids=None,
                    abbr_yaxis=False,
                    plot_stat='r_ceiling',
                    y_adjust=0.05,
                    manual_y=None,
                    only_improvements=False,
                    show_text_labels=False):
    '''
    gc: GC model
    stp: STP model
    LN: LN model
    combined: GC+STP model
    '''

    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    #    cellids, under_chance, less_LN = get_filtered_cellids(batch, gc, stp,
    #                                                          LN, combined,
    #                                                          se_filter,
    #                                                          LN_filter)
    e, a, g, s, c = improved_cells_to_list(batch,
                                           gc,
                                           stp,
                                           LN,
                                           combined,
                                           se_filter=se_filter,
                                           LN_filter=LN_filter)
    cellids = a

    if manual_cellids is not None:
        # WARNING: Will override se and ratio filters even if they are set
        cellids = manual_cellids
    elif only_improvements:
        e, a, g, s, c = improved_cells_to_list(batch,
                                               gc,
                                               stp,
                                               LN,
                                               combined,
                                               as_lists=True)
        cellids = e

    if plot_stat == 'r_ceiling':
        plot_df = df_c
    else:
        plot_df = df_r

    n_cells = len(cellids)
    gc_test = plot_df[gc][cellids]
    stp_test = plot_df[stp][cellids]
    ln_test = plot_df[LN][cellids]
    gc_stp_test = plot_df[combined][cellids]
    #max_test = np.maximum(gc_test, stp_test)

    gc = np.median(gc_test.values)
    stp = np.median(stp_test.values)
    ln = np.median(ln_test.values)
    gc_stp = np.median(gc_stp_test.values)
    #maximum = np.median(max_test)
    largest = max(gc, stp, ln, gc_stp)  #, maximum)

    colors = [model_colors[k]
              for k in ['LN', 'gc', 'stp', 'combined']]  #, 'max']]
    #fig = plt.figure(figsize=(15, 12))
    fig = plt.figure()
    plt.bar(
        [1, 2, 3, 4],
        [ln, gc, stp, gc_stp],  # maximum],
        color=colors,
        edgecolor="black",
        linewidth=1)
    plt.xticks([1, 2, 3, 4],
               ['LN', 'GC', 'STP', 'GC+STP'])  #, 'Max(GC,STP)'])
    if abbr_yaxis:
        if manual_y:
            lower, upper = manual_y
        else:
            lower = np.floor(10 * min(gc, stp, ln, gc_stp)) / 10
            upper = np.ceil(10 * max(gc, stp, ln, gc_stp)) / 10 + y_adjust
        plt.ylim(ymin=lower, ymax=upper)
    else:
        plt.ylim(ymax=largest * 1.4)
    common_kwargs = {'color': 'white', 'horizontalalignment': 'center'}
    if abbr_yaxis:
        y_text = 0.5 * (lower + min(gc, stp, ln, gc_stp))
    else:
        y_text = 0.2
    if show_text_labels:
        for i, m in enumerate([ln, gc, stp, gc_stp]):  #, maximum]):
            t = plt.text(i + 1, y_text, "%0.04f" % m, **common_kwargs)
        #t.set_path_effects([pe.withStroke(linewidth=3, foreground='black')])
    xmin, xmax = plt.xlim()
    plt.xlim(xmin, xmax - 0.35)
    ax_remove_box()
    plt.tight_layout()

    fig2 = plt.figure(figsize=text_fig)
    text = "Median Performance, batch: %d\, n:%d" % (batch, n_cells)
    plt.text(0.1, 0.5, text)

    return fig, fig2
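
# Sketch of the abbreviated-y-axis trick used above: bar heights are per-model
# medians and the y-limits are snapped to the nearest 0.1 below/above so small
# differences between models stay visible. The median values are synthetic
# placeholders; only numpy and matplotlib are assumed.
import numpy as np
import matplotlib.pyplot as plt

medians = np.array([0.42, 0.45, 0.44, 0.47])     # e.g. LN, GC, STP, GC+STP
lower = np.floor(10 * medians.min()) / 10
upper = np.ceil(10 * medians.max()) / 10 + 0.05  # y_adjust-style padding

fig, ax = plt.subplots()
ax.bar([1, 2, 3, 4], medians, edgecolor='black', linewidth=1)
ax.set_xticks([1, 2, 3, 4])
ax.set_xticklabels(['LN', 'GC', 'STP', 'GC+STP'])
ax.set_ylim(lower, upper)
fig.tight_layout()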
Example no. 19
def stp_distributions(batch,
                      gc,
                      stp,
                      LN,
                      combined,
                      se_filter=True,
                      good_ln=0,
                      log_scale=False,
                      legend=False,
                      use_combined=False):

    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    cellids, under_chance, less_LN = get_filtered_cellids(batch,
                                                          gc,
                                                          stp,
                                                          LN,
                                                          combined,
                                                          as_lists=False)
    _, _, _, _, c = improved_cells_to_list(batch,
                                           gc,
                                           stp,
                                           LN,
                                           combined,
                                           good_ln=good_ln)

    if use_combined:
        params_model = combined
    else:
        params_model = stp
    stp_params = fitted_params_per_batch(batch,
                                         params_model,
                                         stats_keys=[],
                                         meta=[])
    stp_params_cells = stp_params.transpose().index.values.tolist()
    for cell in stp_params_cells:
        if cell not in cellids:
            cellids[cell] = False
    not_c = list(set(stp_params.transpose()[cellids].index.values) - set(c))

    # index keys are formatted like "2--stp.2--tau"
    mod_keys = stp.split('_')[1]
    for i, k in enumerate(mod_keys.split('-')):
        if 'stp' in k:
            break
    tau_key = '%d--%s--tau' % (i, k)
    u_key = '%d--%s--u' % (i, k)

    all_taus = stp_params[stp_params.index ==
                          tau_key].transpose()[cellids].transpose()
    all_us = stp_params[stp_params.index ==
                        u_key].transpose()[cellids].transpose()
    dims = all_taus.values.flatten()[0].shape[0]

    # convert to dims x cells array instead of cells, array w/ multidim values
    #sep_taus = _df_to_array(all_taus, dims).mean(axis=0)
    #sep_us = _df_to_array(all_us, dims).mean(axis=0)
    #med_tau = np.median(sep_taus)
    #med_u = np.median(sep_u)
    sep_taus = _df_to_array(all_taus[not_c], dims).mean(axis=0)
    sep_us = _df_to_array(all_us[not_c], dims).mean(axis=0)
    med_tau = np.median(sep_taus)
    med_u = np.median(sep_us)

    stp_taus = all_taus[c]
    stp_us = all_us[c]
    stp_sep_taus = _df_to_array(stp_taus, dims).mean(axis=0)
    stp_sep_us = _df_to_array(stp_us, dims).mean(axis=0)

    stp_med_tau = np.median(stp_sep_taus)
    stp_med_u = np.median(stp_sep_us)
    #tau_t, tau_p = st.ttest_ind(sep_taus, stp_sep_taus)
    #u_t, u_p = st.ttest_ind(sep_us, stp_sep_us)

    # NOTE: not actually a t statistic now, it's the Mann-Whitney U statistic;
    #       just didn't want to change all of the var names in case I revert
    tau_t, tau_p = st.mannwhitneyu(sep_taus,
                                   stp_sep_taus,
                                   alternative='two-sided')
    u_t, u_p = st.mannwhitneyu(sep_us, stp_sep_us, alternative='two-sided')

    sep_taus, sep_us = drop_common_outliers(sep_taus, sep_us)
    stp_sep_taus, stp_sep_us = drop_common_outliers(stp_sep_taus, stp_sep_us)
    not_imp_outliers = len(sep_taus)
    imp_outliers = len(stp_sep_taus)

    fig1, (a1, a2) = plt.subplots(2, 1, sharex=True, sharey=True)
    color = model_colors['LN']
    imp_color = model_colors['max']
    stp_label = 'STP ++ (%d)' % len(c)
    total_cells = len(c) + len(not_c)
    bin_count = 30
    hist_kwargs = {'linewidth': 1, 'label': ['not imp', 'stp imp']}

    plt.sca(a1)
    weights1 = [np.ones(len(sep_taus)) / len(sep_taus)]
    weights2 = [np.ones(len(stp_sep_taus)) / len(stp_sep_taus)]
    upper = max(sep_taus.max(), stp_sep_taus.max())
    lower = min(sep_taus.min(), stp_sep_taus.min())
    bins = np.linspace(lower, upper, bin_count + 1)
    #    if log_scale:
    #        lower_bound = min(sep_taus.min(), stp_sep_taus.min())
    #        upper_bound = max(sep_taus.max(), stp_sep_taus.max())
    #        bins = np.logspace(lower_bound, upper_bound, bin_count+1)
    #        hist_kwargs['bins'] = bins
    #    plt.hist([sep_taus, stp_sep_taus], weights=weights, **hist_kwargs)
    a1.hist(sep_taus,
            weights=weights1,
            fc=faded_LN,
            edgecolor=dark_LN,
            bins=bins,
            **hist_kwargs)
    a2.hist(stp_sep_taus,
            weights=weights2,
            fc=faded_max,
            edgecolor=dark_max,
            bins=bins,
            **hist_kwargs)
    a1.axes.axvline(med_tau,
                    color=dark_LN,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a1.axes.axvline(stp_med_tau,
                    color=dark_max,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a2.axes.axvline(med_tau,
                    color=dark_LN,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a2.axes.axvline(stp_med_tau,
                    color=dark_max,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    ax_remove_box(a1)
    ax_remove_box(a2)

    #plt.title('tau,  sig diff?:  p=%.4E' % tau_p)
    #plt.xlabel('tau (ms)')

    fig2 = plt.figure(figsize=text_fig)
    text = ("tau distributions, n: %d\n"
            "n stp imp (bot): %d, med: %.4f\n"
            "n not imp (top): %d, med: %.4f\n"
            "yaxes: fraction of cells\n"
            "xaxis: tau(ms)\n"
            "st.mannwhitneyu u: %.4E,\np: %.4E\n"
            "not imp after outliers: %d\n"
            "imp after outliers: %d\n" %
            (total_cells, len(c), stp_med_tau, len(not_c), med_tau, tau_t,
             tau_p, not_imp_outliers, imp_outliers))
    plt.text(0.1, 0.5, text)

    fig3, (a3, a4) = plt.subplots(2, 1, sharex=True, sharey=True)
    weights3 = [np.ones(len(sep_us)) / len(sep_us)]
    weights4 = [np.ones(len(stp_sep_us)) / len(stp_sep_us)]
    upper = max(sep_us.max(), stp_sep_us.max())
    lower = min(sep_us.min(), stp_sep_us.min())
    bins = np.linspace(lower, upper, bin_count + 1)
    #    if log_scale:
    #        lower_bound = min(sep_us.min(), stp_sep_us.min())
    #        upper_bound = max(sep_us.max(), stp_sep_us.max())
    #        bins = np.logspace(lower_bound, upper_bound, bin_count+1)
    #        hist_kwargs['bins'] = bins
    #    plt.hist([sep_us, stp_sep_us], weights=weights, **hist_kwargs)
    a3.hist(sep_us,
            weights=weights3,
            fc=faded_LN,
            edgecolor=dark_LN,
            bins=bins,
            **hist_kwargs)
    a4.hist(stp_sep_us,
            weights=weights4,
            fc=faded_max,
            edgecolor=dark_max,
            bins=bins,
            **hist_kwargs)
    a3.axes.axvline(med_u,
                    color=dark_LN,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a3.axes.axvline(stp_med_u,
                    color=dark_max,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a4.axes.axvline(med_u,
                    color=dark_LN,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    a4.axes.axvline(stp_med_u,
                    color=dark_max,
                    linewidth=2,
                    linestyle='dashed',
                    dashes=dash_spacing)
    ax_remove_box(a3)
    ax_remove_box(a4)
    #plt.title('u,  sig diff?:  p=%.4E' % u_p)
    #plt.xlabel('u (fractional change in gain \nper unit of stimulus amplitude)')
    #plt.ylabel('proportion within group')

    fig4 = plt.figure(figsize=text_fig)
    text = ("u distributions, n: %d\n"
            "n stp imp (bot): %d, med: %.4f\n"
            "n not imp (top): %d, med: %.4f\n"
            "yaxes: fraction of cells\n"
            "xaxis: u(fractional change in gain per unit stimulus amplitude)\n"
            "st.mannwhitneyu u: %.4E,\np: %.4E" %
            (total_cells, len(c), stp_med_u, len(not_c), med_u, u_t, u_p))
    plt.text(0.1, 0.5, text)

    stp_mag, stp_yin, stp_out = stp_magnitude(np.array([[stp_med_tau]]),
                                              np.array([[stp_med_u]]))
    mag, yin, out = stp_magnitude(np.array([[med_tau]]), np.array([[med_u]]))
    fig5 = plt.figure(figsize=short_fig)
    plt.plot(stp_out.as_continuous().flatten(),
             color=imp_color,
             label='STP ++')
    plt.plot(out.as_continuous().flatten(), color=color)
    if legend:
        plt.legend()
    ax_remove_box()

    return fig1, fig2, fig3, fig4, fig5
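
# Hedged sketch of the reshaping step used above: a pandas Series whose
# entries are small per-cell arrays (one tau per STP channel) is stacked into
# a (channels x cells) matrix so each cell can be summarized by its mean
# across channels, mirroring what _df_to_array(...).mean(axis=0) appears to do
# here. Cell names and values are synthetic; numpy and pandas are assumed.
import numpy as np
import pandas as pd

taus = pd.Series({'cell1': np.array([0.10, 0.25]),
                  'cell2': np.array([0.05, 0.30]),
                  'cell3': np.array([0.20, 0.20])})
tau_mtx = np.stack(list(taus.values), axis=1)   # shape: (channels, cells)
per_cell_tau = tau_mtx.mean(axis=0)             # mean across channels
print(per_cell_tau, np.median(per_cell_tau))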
Example no. 20
def stp_v_beh():

    batch1 = 274
    batch2 = 275
    modelnames=["env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-fir.1x15-lvl.1-dexp.1_jk.nf5-init.st-basic",
                "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-fir.1x15-lvl.1-rep.2-dexp.2-mrg_jk.nf5-init.st-basic",
                "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-rep.2-fir.1x15x2-lvl.2-dexp.2-mrg_jk.nf5-init.st-basic",
                "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-stp.1-fir.1x15-lvl.1-dexp.1_jk.nf5-init.st-basic",
                "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-stp.1-fir.1x15-lvl.1-rep.2-dexp.2-mrg_jk.nf5-init.st-basic",
                "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-stp.1-rep.2-fir.1x15x2-lvl.2-dexp.2-mrg_jk.nf5-init.st-basic",
                "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-rep.2-stp.2-fir.1x15x2-lvl.2-dexp.2-mrg_jk.nf5-init.st-basic"]
    fileprefix="fig8.stp_v_beh"
    n1=modelnames[0]
    n2=modelnames[-3]

    xc_range = [-0.05, 0.6]

    df1 = nd.batch_comp(batch1,modelnames,stat='r_test').reset_index()
    df1_e = nd.batch_comp(batch1,modelnames,stat='se_test').reset_index()

    df2 = nd.batch_comp(batch2,modelnames,stat='r_test').reset_index()
    df2_e = nd.batch_comp(batch2,modelnames,stat='se_test').reset_index()

    # DataFrame.append is deprecated/removed in newer pandas; use pd.concat
    df = pd.concat([df1, df2])
    df_e = pd.concat([df1_e, df2_e])

    cellcount = len(df)

    beta1 = df[n1].copy()
    beta2 = df[n2].copy()
    beta1_test = df[n1]
    beta2_test = df[n2]
    se1 = df_e[n1]
    se2 = df_e[n2]

    # cap r_test at 1 without mutating the underlying dataframe
    beta1[beta1 > 1] = 1
    beta2[beta2 > 1] = 1

    # test for significant improvement
    improvedcells = (beta2_test-se2 > beta1_test+se1)

    # test for significant prediction at all
    goodcells = ((beta2_test > se2*3) | (beta1_test > se1*3))

    fh = plt.figure()
    ax = plt.subplot(2,2,1)
    stateplots.beta_comp(beta1[goodcells], beta2[goodcells],
                         n1='LN STRF', n2='STP+BEH LN STRF',
                         hist_range=xc_range, ax=ax,
                         highlight=improvedcells[goodcells])


    # LN vs. STP:
    beta1b = df[modelnames[3]]
    beta1a = df[modelnames[0]]
    beta1 = beta1b - beta1a
    se1a = df_e[modelnames[3]]
    se1b= df_e[modelnames[0]]

    b1=4
    b0=3
    beta2b = df[modelnames[b1]]
    beta2a = df[modelnames[b0]]
    beta2 = beta2b - beta2a
    se2a = df_e[modelnames[b1]]
    se2b= df_e[modelnames[b0]]

    stpgood = (beta1 > se1a+se1b)
    behgood = (beta2 > se2a+se2b)
    neither_good = np.logical_not(stpgood) & np.logical_not(behgood)
    both_good = stpgood & behgood
    stp_only_good = stpgood & np.logical_not(behgood)
    beh_only_good = np.logical_not(stpgood) & behgood

    xc_range = np.array([-0.05, 0.15])
    beta1[beta1<xc_range[0]]=xc_range[0]
    beta2[beta2<xc_range[0]]=xc_range[0]

    zz = np.zeros(2)
    ax=plt.subplot(2,2,2)
    ax.plot(xc_range,zz,'k--',linewidth=0.5)
    ax.plot(zz,xc_range,'k--',linewidth=0.5)
    ax.plot(xc_range, xc_range, 'k--', linewidth=0.5)
    l = ax.plot(beta1[neither_good], beta2[neither_good], '.', color='lightgray') +\
        ax.plot(beta1[beh_only_good], beta2[beh_only_good], '.', color='purple') +\
        ax.plot(beta1[stp_only_good], beta2[stp_only_good], '.', color='orange') +\
        ax.plot(beta1[both_good], beta2[both_good], '.', color='black')
    ax_remove_box(ax)
    ax.set_aspect('equal', 'box')
    #plt.axis('equal')
    ax.set_xlim(xc_range)
    ax.set_ylim(xc_range)
    ax.set_xlabel('delta(stp)')
    ax.set_ylabel('delta(beh)')

    olap=np.zeros(100)
    a = stpgood.values.copy()
    b = behgood.values.copy()
    for i in range(100):
        np.random.shuffle(a)
        olap[i] = np.sum(a & b)

    ll=[np.sum(neither_good), np.sum(beh_only_good),
        np.sum(stp_only_good), np.sum(both_good)]
    ax.legend(l, ll)

    ax=plt.subplot(2,2,3)
    m = np.array(df.loc[goodcells, modelnames].mean())
    xc_range = [-0.02, 0.2]
    plt.bar(np.arange(len(modelnames)), m, color='black')
    plt.plot(np.array([-1, len(modelnames)]), np.array([0, 0]), 'k--',
             linewidth=0.5)
    plt.ylim(xc_range)
    plt.title("batch {}, n={}/{} good cells".format(
            batch, np.sum(goodcells), len(goodcells)))
    plt.ylabel('mean pred corr')
    plt.xlabel('model architecture')
    ax_remove_box(ax)

    for i in range(len(modelnames)-1):

        d1 = np.array(df[modelnames[i]])
        d2 = np.array(df[modelnames[i+1]])
        s, p = ss.wilcoxon(d1, d2)
        plt.text(i+0.5, m[i+1], "{:.1e}".format(p), ha='center', fontsize=6)

    plt.xticks(np.arange(len(m)),np.round(m,3))

    return fh, df[stpgood]['cellid'].tolist()
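
# Self-contained sketch of the shuffle test started in the `olap` loop above:
# estimate how often the two boolean "significant" flags would overlap by
# chance if one of them were shuffled. The flags here are synthetic; only
# numpy is assumed.
import numpy as np

rng = np.random.default_rng(3)
stpgood = rng.random(200) < 0.3
behgood = rng.random(200) < 0.2
observed = np.sum(stpgood & behgood)

a = stpgood.copy()
olap = np.zeros(1000)
for i in range(1000):
    rng.shuffle(a)
    olap[i] = np.sum(a & behgood)

p_shuffle = np.mean(olap >= observed)
print("observed overlap %d, shuffled mean %.1f, p~=%.3f"
      % (observed, olap.mean(), p_shuffle))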
Example no. 21
def example_clip(cellid,
                 batch,
                 gc,
                 stp,
                 LN,
                 combined,
                 skip_combined=False,
                 normalize=True,
                 stim_idx=0,
                 trim_start=None,
                 trim_end=None,
                 smooth_response=False,
                 kernel_length=3,
                 strf_spec='LN',
                 load_path=None,
                 save_path=None):
    if load_path is None:
        gc_ctx, stp_ctx, LN_ctx, combined_ctx = \
                _get_plot_contexts(cellid, batch, gc, stp, LN, combined)
        if save_path is not None:
            results = {'contexts': [gc_ctx, stp_ctx, LN_ctx, combined_ctx]}
            with open(save_path, 'wb') as f:
                pickle.dump(results, f)
    else:
        with open(load_path, 'rb') as f:
            results = pickle.load(f)
        gc_ctx, stp_ctx, LN_ctx, combined_ctx = results['contexts']

    gc_pred, stp_pred, LN_pred, combined_pred, resp, stim = \
        _get_plot_signals(gc_ctx, stp_ctx, LN_ctx, combined_ctx)

    gc_v, stp_v, LN_v, combined_v = _get_plot_vals(gc_ctx, stp_ctx, LN_ctx,
                                                   combined_ctx)

    # break up into separate stims
    epochs = gc_v.epochs
    stims = ep.epoch_names_matching(epochs, 'STIM_')
    s = stims[stim_idx]
    row = epochs[epochs.name == s]
    fs = gc_v['resp'].fs
    start = int(row['start'].values[0] * fs)
    end = int(row['end'].values[0] * fs)

    if trim_start is not None:
        start += trim_start
    if trim_end is not None:
        end = start + (trim_end - trim_start)

    resp_plot = resp[start:end]
    if smooth_response:
        # box filter, "simple average"
        kernel = np.ones((kernel_length, )) * (1 / kernel_length)
        resp_plot = convolve(resp_plot, kernel, mode='same')
    LN_plot = LN_pred[start:end]
    gc_plot = gc_pred[start:end]
    stp_plot = stp_pred[start:end]
    combined_plot = combined_pred[start:end]
    stim_plot = stim[:, start:end]
    if normalize:
        max_all = np.nanmax(
            np.concatenate(
                [resp_plot, LN_plot, gc_plot, stp_plot, combined_plot]))
        gc_plot = gc_plot / max_all
        stp_plot = stp_plot / max_all
        LN_plot = LN_plot / max_all
        combined_plot = combined_plot / max_all
        resp_plot = resp_plot / max_all

    fig = plt.figure(figsize=wide_fig)
    xmin = 0
    xmax = end - start
    plt.imshow(stim_plot,
               aspect='auto',
               cmap=spectrogram_cmap,
               origin='lower',
               extent=(xmin, xmax, 1.1, 1.5))
    lw = 0.75
    plt.plot(resp_plot, color=model_colors['LN'], linewidth=lw)
    t = np.linspace(0, resp_plot.shape[-1] - 1, resp_plot.shape[-1])
    plt.fill_between(t, resp_plot, color='gray', alpha=0.15)
    plt.plot(gc_plot, color=model_colors['gc'], linewidth=lw)
    plt.plot(stp_plot,
             color=model_colors['stp'],
             alpha=0.65,
             linewidth=lw * 1.25)
    plt.plot(LN_plot, color='black', alpha=0.55, linewidth=lw)
    signals = ['Response', 'LN', 'GC', 'STP']
    if not skip_combined:
        plt.plot(combined_plot,
                 color=model_colors['combined'],
                 linewidth=lw,
                 linestyle='--')
        signals.append('GC+STP')
    plt.ylim(-0.1, 1.5)
    ax = plt.gca()
    ax_remove_box(ax)

    fig2 = plt.figure(figsize=text_fig)
    text = ("cellid: %s\n"
            "stp_r_test: %.4f\n"
            "gc_r_test: %.4f\n"
            "LN_r_test: %.4f\n"
            "comb_r_test: %.4f" % (cellid, stp_ctx['modelspec'].meta['r_test'],
                                   gc_ctx['modelspec'].meta['r_test'],
                                   LN_ctx['modelspec'].meta['r_test'],
                                   combined_ctx['modelspec'].meta['r_test']))
    plt.text(0.1, 0.5, text)

    # TODO: probably need to just rip code out of strf_heatmap instead,
    #       setting the extent is not working. or alternatively just
    #       resize it manually to match the spectrogram
    fig3 = plt.figure(figsize=wide_fig)
    ax2 = plt.gca()

    if strf_spec == 'LN':
        modelspec = LN_ctx['modelspec']
    elif strf_spec == 'stp':
        modelspec = stp_ctx['modelspec']
    elif strf_spec == 'gc':
        modelspec = gc_ctx['modelspec']
    else:
        modelspec = combined_ctx['modelspec']

    nplt.strf_heatmap(modelspec,
                      ax=ax2,
                      show_factorized=False,
                      show_cbar=False,
                      manual_extent=(0, 1, 1.1, 1.5))
    ax2.set_ylim(-0.1, 1.5)
    ax_remove_box(ax2)

    return fig, fig2, fig3
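
# Minimal sketch of the response smoothing branch above: a boxcar ("simple
# average") kernel convolved with the response trace. The trace is synthetic;
# scipy.signal.convolve is assumed here, which matches the call signature of
# the `convolve` used in the function.
import numpy as np
from scipy.signal import convolve

rng = np.random.default_rng(4)
resp = rng.poisson(2.0, size=300).astype(float)

kernel_length = 3
kernel = np.ones((kernel_length,)) * (1 / kernel_length)
resp_smooth = convolve(resp, kernel, mode='same')
print(resp[:5], resp_smooth[:5])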