Example #1
0
def beh_only_plot(batch=311):
    """Plot behavior-state model effects (baseline/gain and MI) for a batch.

    Loads per-cell results for the state-shuffled model ('st.beh0') and the
    behavior-state model ('st.beh') with
    ``get_model_results_per_state_model``, keeps rows whose ``state_chan``
    is 'active', and draws two ``stateplots.beta_comp`` comparisons:

    * baseline ``d`` vs. gain ``g`` of the 'st.beh' model, and
    * modulation index ``MI`` of 'st.beh0' vs. 'st.beh'.

    Cells whose 'st.beh' prediction correlation beats 'st.beh0' by more than
    the summed standard errors are highlighted.

    Parameters
    ----------
    batch : int
        Batch id; default 311 = A1 old (SVD) data, on BF.

    Returns
    -------
    pandas.DataFrame
        The unfiltered table returned by ``get_model_results_per_state_model``.
    """
    state_list = ['st.beh0', 'st.beh']
    # index 0: state-shuffled (null) model, index 1: behavior-state model
    null_state, beh_state = state_list
    basemodel = "-ref-psthfr.s_stategain.S"
    loader = "psth.fs20-ld-"
    fitter = "_jk.nf20-basic"
    df = get_model_results_per_state_model(batch=batch,
                                           state_list=state_list,
                                           basemodel=basemodel,
                                           fitter=fitter,
                                           loader=loader)
    da = df[df['state_chan'] == 'active']

    dp = da.pivot(index='cellid',
                  columns='state_sig',
                  values=['r', 'r_se', 'MI', 'g', 'd'])

    # significant: the gain in r exceeds the summed standard errors
    dr = dp['r'].copy()
    dr['sig'] = ((dp['r'][beh_state] - dp['r'][null_state]) >
                 (dp['r_se'][beh_state] + dp['r_se'][null_state]))

    g = dp['g'].copy()
    d = dp['d'].copy()
    ggood = np.isfinite(g[beh_state])
    stateplots.beta_comp(d.loc[ggood, beh_state],
                         g.loc[ggood, beh_state],
                         n1='Baseline',
                         n2='Gain',
                         title="Baseline/gain: batch {}".format(batch),
                         highlight=dr.loc[ggood, 'sig'],
                         hist_range=[-1, 1])

    MI = dp['MI'].copy()
    migood = np.isfinite(MI[beh_state])
    stateplots.beta_comp(MI.loc[migood, null_state],
                         MI.loc[migood, beh_state],
                         n1='State independent',
                         n2='State-dep',
                         title="MI: batch {}".format(batch),
                         highlight=dr.loc[migood, 'sig'],
                         hist_range=[-0.5, 0.5])

    return df
def plot_save_examples(batch,
                       compare,
                       loader,
                       basemodel,
                       fitter,
                       RELOAD=False):
    """Plot per-cell pupil/behavior model fits and summary comparisons.

    For every cell of ``batch`` a per-cell model figure is made (by
    ``stateplots.pb_model_plot``, ``ppas_model_plot`` or ``pp_model_plot``
    depending on ``compare``) and its statistics are collected into a table
    cached as ``<out_path>/results.csv``. If that cache exists and ``RELOAD``
    is False, the cached table is read instead. Four summary
    ``stateplots.beta_comp`` figures are then drawn from the table. Figures
    and the CSV are only written when the output directory is writable.

    Parameters
    ----------
    batch : int
        Batch id. Batches 301 and 307 are labeled "AC", all others "IC".
    compare : str
        "pb" (active) or "ppas" (each passive); any other value is treated
        as pre/post and uses ``pp_model_plot``.
    loader, basemodel, fitter : str
        Model-name components passed to the plot functions and used to name
        the output directory.
    RELOAD : bool
        Regenerate per-cell figures and the results table even if cached.
    """
    area = "AC" if batch in [301, 307] else "IC"

    d = nd.get_batch_cells(batch)
    cellids = list(d['cellid'])

    root_path = '/auto/users/svd/projects/pupil-behavior'

    modelset = '{}_{}_{}_{}_{}_{}'.format(compare, area, batch, loader,
                                          basemodel, fitter)
    out_path = '{}/{}/'.format(root_path, modelset)

    if os.access(root_path, os.W_OK) and not os.path.exists(out_path):
        # exist_ok: tolerate the directory appearing between check and create
        os.makedirs(out_path, exist_ok=True)

    datafile = out_path + 'results.csv'
    plt.close('all')

    if (not RELOAD) and (not os.path.isfile(datafile)):
        RELOAD = True
        print('datafile not found, reloading')

    if RELOAD:
        # per-cell plotting function for this comparison
        if compare == "pb":
            model_plot = stateplots.pb_model_plot
        elif compare == "ppas":
            model_plot = stateplots.ppas_model_plot
        else:
            model_plot = stateplots.pp_model_plot

        stats_list = []
        for cellid in cellids:
            fh, stats = model_plot(cellid,
                                   batch,
                                   loader=loader,
                                   basemodel=basemodel,
                                   fitter=fitter)
            stats_list.append(stats)
            if os.access(out_path, os.W_OK):
                fh.savefig(out_path + cellid + '.pdf')
                fh.savefig(out_path + cellid + '.png')
            plt.close(fh)

        col_names = [
            'cellid', 'r_p0b0', 'r_p0b', 'r_pb0', 'r_pb', 'e_p0b0', 'e_p0b',
            'e_pb0', 'e_pb', 'rf_p0b0', 'rf_p0b', 'rf_pb0', 'rf_pb', 'r_pup',
            'r_beh', 'r_beh_pup0', 'pup_mod', 'beh_mod', 'pup_mod_n',
            'beh_mod_n', 'pup_mod_beh0', 'beh_mod_pup0', 'pup_mod_beh0_n',
            'beh_mod_pup0_n', 'd_pup', 'd_beh', 'g_pup', 'g_beh',
            'ref_all_resp', 'ref_common_resp', 'tar_max_resp', 'tar_probe_resp'
        ]

        # Collect rows and build the frame once: DataFrame.append was removed
        # in pandas 2.0, and appending row by row is quadratic.
        rows = []
        for stats in stats_list:
            r = stats['r_test']
            se = stats['se_test']
            rf = stats['r_floor']
            rows.append([
                stats['cellid'], r[0], r[1], r[2], r[3],
                se[0], se[1], se[2], se[3],
                rf[0], rf[1], rf[2], rf[3],
                r[3] - r[1],  # r_pup
                r[3] - r[2],  # r_beh
                r[1] - r[0],  # r_beh_pup0
                stats['pred_mod'][0, 1], stats['pred_mod'][1, 2],
                stats['pred_mod_norm'][0, 1], stats['pred_mod_norm'][1, 2],
                stats['pred_mod_full'][0, 1], stats['pred_mod_full'][1, 2],
                stats['pred_mod_full_norm'][0, 1],
                stats['pred_mod_full_norm'][1, 2],
                stats['b'][3, 1], stats['b'][3, 2],
                stats['g'][3, 1], stats['g'][3, 2],
                stats['ref_all_resp'], stats['ref_common_resp'],
                stats['tar_max_resp'], stats['tar_probe_resp'],
            ])
        df = pd.DataFrame(rows, columns=col_names)
        df.set_index(['cellid'], inplace=True)
        if os.access(out_path, os.W_OK):
            df.to_csv(datafile)
    else:
        # load cached dataframe
        df = pd.read_csv(datafile, index_col=0)

    # significant: full-model r minus its SE beats null r plus its SE
    sig_mod = list(df['r_pb'] - df['e_pb'] > df['r_p0b0'] + df['e_p0b0'])
    if compare == "pb":
        alabel = "active"
    elif compare == "ppas":
        alabel = "each passive"
    else:
        alabel = "pre/post"

    mi_bounds = [-0.4, 0.4]

    fh1 = stateplots.beta_comp(df['r_pup'],
                               df['r_beh'],
                               n1='pupil',
                               n2=alabel,
                               title=modelset + ' unique pred',
                               hist_range=[-0.02, 0.15],
                               highlight=sig_mod)
    # NOTE(review): this pairs the normalized pupil index (pup_mod_n) with the
    # unnormalized behavior index (beh_mod) -- confirm the mix is intended.
    fh2 = stateplots.beta_comp(df['pup_mod_n'],
                               df['beh_mod'],
                               n1='pupil',
                               n2=alabel,
                               title=modelset + ' mod index',
                               hist_range=mi_bounds,
                               highlight=sig_mod)
    fh3 = stateplots.beta_comp(df['beh_mod_pup0'],
                               df['beh_mod'],
                               n1=alabel + '-nopup',
                               n2=alabel,
                               title=modelset + ' unique mod',
                               hist_range=mi_bounds,
                               highlight=sig_mod)

    # unique behavior:
    #    performance of full model minus performance behavior shuffled
    # r_beh_with_pupil = df['r_pb'] - df['r_pb0']
    # naive behavior (ignorant of pupil):
    #    performance of behavior alone (pupil shuffled) minus all shuffled
    # r_beh_no_pupil = df['r_p0b'] - df['r_p0b0']

    fh4 = stateplots.beta_comp(df['r_beh_pup0'],
                               df['r_beh'],
                               n1=alabel + '-nopup',
                               n2=alabel,
                               title=modelset + ' unique r',
                               hist_range=[-0.02, .15],
                               highlight=sig_mod)

    if os.access(out_path, os.W_OK):
        fh1.savefig(out_path + 'summary_pred.pdf')
        fh2.savefig(out_path + 'summary_mod.pdf')
        fh3.savefig(out_path + 'summary_mod_ctl.pdf')
        fh4.savefig(out_path + 'summary_r_ctl.pdf')
Example #3
0
def stp_v_beh():
    """Compare LN, STP and state-model prediction accuracy for batches 274/275.

    Prediction correlations (``r_test``) and their standard errors
    (``se_test``) for every model in ``modelnames`` are loaded with
    ``nd.batch_comp`` for both batches and pooled. Three panels are drawn:

    1. ``modelnames[0]`` (LN STRF) vs. ``modelnames[-3]`` for cells with a
       significant prediction, highlighting significant improvements;
    2. improvement of ``modelnames[3]`` over ``modelnames[0]`` ("stp") vs.
       improvement of ``modelnames[4]`` over ``modelnames[3]`` ("beh"),
       colored by which improvements exceed the summed standard errors;
    3. mean prediction correlation per model over good cells, annotated with
       ``ss.wilcoxon`` p-values between consecutive models.

    Returns
    -------
    fh : matplotlib.figure.Figure
        The summary figure.
    list
        cellids whose "stp" improvement is significant.
    """
    batch1 = 274
    batch2 = 275
    modelnames = ["env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-fir.1x15-lvl.1-dexp.1_jk.nf5-init.st-basic",
                  "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-fir.1x15-lvl.1-rep.2-dexp.2-mrg_jk.nf5-init.st-basic",
                  "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-rep.2-fir.1x15x2-lvl.2-dexp.2-mrg_jk.nf5-init.st-basic",
                  "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-stp.1-fir.1x15-lvl.1-dexp.1_jk.nf5-init.st-basic",
                  "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-stp.1-fir.1x15-lvl.1-rep.2-dexp.2-mrg_jk.nf5-init.st-basic",
                  "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-stp.1-rep.2-fir.1x15x2-lvl.2-dexp.2-mrg_jk.nf5-init.st-basic",
                  "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-rep.2-stp.2-fir.1x15x2-lvl.2-dexp.2-mrg_jk.nf5-init.st-basic"]
    n1 = modelnames[0]
    n2 = modelnames[-3]

    xc_range = [-0.05, 0.6]

    df1 = nd.batch_comp(batch1, modelnames, stat='r_test').reset_index()
    df1_e = nd.batch_comp(batch1, modelnames, stat='se_test').reset_index()

    df2 = nd.batch_comp(batch2, modelnames, stat='r_test').reset_index()
    df2_e = nd.batch_comp(batch2, modelnames, stat='se_test').reset_index()

    # DataFrame.append was removed in pandas 2.0. ignore_index gives a unique
    # 0..N-1 index; appending the two reset frames would repeat labels and
    # break the label alignment between df and df_e used below.
    df = pd.concat([df1, df2], ignore_index=True)
    df_e = pd.concat([df1_e, df2_e], ignore_index=True)

    # Unclipped values drive the significance tests; clipped copies are only
    # for display (clip() returns a new Series, so df is left untouched).
    beta1_test = df[n1]
    beta2_test = df[n2]
    se1 = df_e[n1]
    se2 = df_e[n2]
    beta1 = beta1_test.clip(upper=1)
    beta2 = beta2_test.clip(upper=1)

    # test for significant improvement
    improvedcells = (beta2_test - se2 > beta1_test + se1)

    # test for significant prediction at all
    goodcells = ((beta2_test > se2 * 3) | (beta1_test > se1 * 3))

    fh = plt.figure()
    ax = plt.subplot(2, 2, 1)
    stateplots.beta_comp(beta1[goodcells], beta2[goodcells],
                         n1='LN STRF', n2='STP+BEH LN STRF',
                         hist_range=xc_range, ax=ax,
                         highlight=improvedcells[goodcells])

    # LN vs. STP: improvement of modelnames[3] over modelnames[0]
    stp_delta = df[modelnames[3]] - df[modelnames[0]]
    stp_se = df_e[modelnames[3]] + df_e[modelnames[0]]

    # improvement of modelnames[b1] over modelnames[b0] (plotted as "beh")
    b1 = 4
    b0 = 3
    beh_delta = df[modelnames[b1]] - df[modelnames[b0]]
    beh_se = df_e[modelnames[b1]] + df_e[modelnames[b0]]

    stpgood = stp_delta > stp_se
    behgood = beh_delta > beh_se
    neither_good = ~stpgood & ~behgood
    both_good = stpgood & behgood
    stp_only_good = stpgood & ~behgood
    beh_only_good = ~stpgood & behgood

    # pin values below the lower axis limit to it so they stay visible
    xc_range = np.array([-0.05, 0.15])
    stp_delta = stp_delta.clip(lower=xc_range[0])
    beh_delta = beh_delta.clip(lower=xc_range[0])

    zz = np.zeros(2)
    ax = plt.subplot(2, 2, 2)
    ax.plot(xc_range, zz, 'k--', linewidth=0.5)
    ax.plot(zz, xc_range, 'k--', linewidth=0.5)
    ax.plot(xc_range, xc_range, 'k--', linewidth=0.5)
    lines = (ax.plot(stp_delta[neither_good], beh_delta[neither_good], '.',
                     color='lightgray') +
             ax.plot(stp_delta[beh_only_good], beh_delta[beh_only_good], '.',
                     color='purple') +
             ax.plot(stp_delta[stp_only_good], beh_delta[stp_only_good], '.',
                     color='orange') +
             ax.plot(stp_delta[both_good], beh_delta[both_good], '.',
                     color='black'))
    ax_remove_box(ax)
    ax.set_aspect('equal', 'box')
    ax.set_xlim(xc_range)
    ax.set_ylim(xc_range)
    ax.set_xlabel('delta(stp)')
    ax.set_ylabel('delta(beh)')

    # legend entries are the cell counts in each category
    counts = [np.sum(neither_good), np.sum(beh_only_good),
              np.sum(stp_only_good), np.sum(both_good)]
    ax.legend(lines, counts)

    ax = plt.subplot(2, 2, 3)
    # Select the model columns before averaging: df also holds the cellid
    # column (presumably strings), which DataFrame.mean() rejects in pandas 2.
    m = np.array(df.loc[goodcells, modelnames].mean())
    plt.bar(np.arange(len(modelnames)), m, color='black')
    plt.plot(np.array([-1, len(modelnames)]), np.array([0, 0]), 'k--',
             linewidth=0.5)
    plt.ylim([-0.02, 0.2])
    plt.title("batch {}+{}, n={}/{} good cells".format(
        batch1, batch2, np.sum(goodcells), len(goodcells)))
    plt.ylabel('mean pred corr')
    plt.xlabel('model architecture')
    ax_remove_box(ax)

    # paired test between each pair of consecutive models (all cells)
    for i, (name_a, name_b) in enumerate(zip(modelnames[:-1],
                                             modelnames[1:])):
        _, p = ss.wilcoxon(np.array(df[name_a]), np.array(df[name_b]))
        plt.text(i + 0.5, m[i + 1], "{:.1e}".format(p), ha='center',
                 fontsize=6)

    plt.xticks(np.arange(len(m)), np.round(m, 3))

    return fh, df.loc[stpgood, 'cellid'].tolist()
Example #4
0
         'k--',
         lw=0.5)
# Finish the channel-gain scatter on the current axes ``ax``: a dashed
# vertical line at x=0 spanning amp_bounds, then non-selected units
# (~show_units) in dotcolor_ns and selected units in dotcolor. Column 0 / 1 of
# m_fir are plotted as "bigger" / "smaller" channel gain per the axis labels.
plt.plot(np.zeros(2), np.array(amp_bounds), 'k--', lw=0.5)
plt.plot(m_fir[~show_units, 0], m_fir[~show_units, 1], '.', color=dotcolor_ns)
plt.plot(m_fir[show_units, 0], m_fir[show_units, 1], '.', color=dotcolor)
# title reports selected units out of good-prediction units
plt.title('STP STRF n={}/{} good units'.format(np.sum(show_units),
                                               np.sum(good_pred)))
plt.xlabel('bigger channel gain')
plt.ylabel('smaller channel gain')
ax.set_aspect('equal', 'box')
nplt.ax_remove_box(ax)

# Panel 3: ceiling-corrected prediction of the LN model (r0_*) vs. the
# "RW3 STP STRF" model (r_*), restricted to good_pred cells.
ax = fh0.add_subplot(3, 2, 3)
stateplots.beta_comp(r0_ceiling_mtx[good_pred],
                     r_ceiling_mtx[good_pred],
                     n1='LN STRF',
                     n2='RW3 STP STRF',
                     hist_range=[0.0, 1.0],
                     ax=ax)
ax.set_title('good_pred: {}/{}'.format(np.sum(good_pred), len(r0_ceiling_mtx)))
# Panel 4: same comparison using r_test instead of the ceiling values.
ax = fh0.add_subplot(3, 2, 4)
stateplots.beta_comp(r0_test_mtx[good_pred],
                     r_test_mtx[good_pred],
                     n1='LN STRF',
                     n2='RW3 STP STRF',
                     hist_range=[0.0, 1.0],
                     ax=ax)

# Panel 5: histogram of every coefficient in fir0, concatenated along axis 0.
# plt.hist draws on the current axes; add_subplot makes the new axes current
# (assumes fh0 is the current figure -- TODO confirm).
ax = fh0.add_subplot(3, 2, 5)
F0 = np.concatenate(fir0, axis=0)
plt.hist(F0.flatten())
Example #5
0
beta2_test = df_r[n2]
se1 = df_e[n1]
se2 = df_e[n2]

# Clip displayed values at 1.
# NOTE(review): beta1/beta2 are defined above this view; if they are the same
# objects as beta1_test/beta2_test (or columns of df_r), these assignments
# modify those too -- confirm.
beta1[beta1 > 1] = 1
beta2[beta2 > 1] = 1

# test for significant improvement: the gap exceeds the summed standard errors
improvedcells = (beta2_test - se2 > beta1_test + se1)

# test for significant prediction at all (either model > 3 SE above zero)
goodcells = ((beta2_test > se2 * 3) | (beta1_test > se1 * 3))

fh1 = stateplots.beta_comp(beta1[goodcells],
                           beta2[goodcells],
                           n1='LN STRF',
                           n2='RW3 STP STRF',
                           hist_range=xc_range,
                           highlight=improvedcells[goodcells])

fh2 = plt.figure(figsize=(3.5, 3))
# Select the model columns before averaging so non-numeric columns (if any,
# e.g. a cellid column) do not make DataFrame.mean() fail on pandas >= 2.0.
m = np.array(df.loc[goodcells, modelnames].mean())
plt.bar(np.arange(len(modelnames)), m, color='black')
plt.plot(np.array([-1, len(modelnames)]), np.array([0, 0]), 'k--')
plt.ylim((-.05, 0.8))
plt.title("batch {}, n={}/{} good cells".format(batch, np.sum(goodcells),
                                                len(goodcells)))
# bars are means (see .mean() above), not medians
plt.ylabel('mean pred corr')
Example #6
0
                                                batch=batch,
                                                modelname=modelname)
        ctx['modelspec'].quickplot(rec=ctx['val'])

    figure, axes = plt.subplots(2, 3, figsize=(8, 5))

    ax = axes[0, 0]

    bound = 1.2
    histbins = np.linspace(-0.5, 0.5, 21)

    beta_comp(b,
              f,
              n1='bg',
              n2='fg',
              hist_range=[-bound, bound],
              ax=ax,
              click_fun=_rdt_info,
              highlight=si,
              title=bs)

    ax = axes[0, 1]

    stat, p = wilcoxon(f, b)
    md = np.mean(f - b)
    rg[batch] = gdiff

    list_of_tuples = list(
        zip(cellids, tar_id, f, b, gdiff, r_test, r_test_S, r_test_SR))
    rgdf[batch] = pd.DataFrame(
        list_of_tuples,
cellids = t.index
diff = t['r_pb'] - t['r_p0b0']

# Screen for significant cells: the full model (r_pb) beats the shuffled
# model (r_p0b0) by more than one standard deviation (ddof=0) of that
# difference across cells. Computed once, used for both mask and cellids.
sig_mod = (diff > np.std(diff))
cellids_sig = cellids[sig_mod]

# Keep only the significant cells. Boolean row selection preserves t's column
# dtypes (transposing a mixed-dtype frame would make every column object).
t_filtered = t.loc[sig_mod]

# =================   Plot master summary figure of single cells  ==============
# for all cells (compare sig vs. not sig)
f = stateplots.beta_comp(t['pup_mod'],
                         t['beh_mod'],
                         n1='pupil',
                         n2='active',
                         title='mod index',
                         hist_range=[-0.4, 0.4],
                         highlight=sig_mod)
# just for significant cells (compare fs vs. rs)
# fs[i] is 1 when nd.get_wft() returns 1 for the cell (presumably fast
# spiking), else 0; failed lookups count as 0 (best effort), but
# KeyboardInterrupt/SystemExit are no longer swallowed.
fs = []
for cid in t_filtered.index:
    try:
        fs.append(1 if nd.get_wft(cid) == 1 else 0)
    except Exception:
        fs.append(0)

f2 = stateplots.beta_comp(t_filtered['pup_mod'],
Example #8
0
# Clip displayed values at 1.
# NOTE(review): if beta1/beta2 (defined above this view) are the same objects
# as beta1_test/beta2_test or columns of df, these assignments modify those
# too -- confirm.
beta1[beta1 > 1] = 1
beta2[beta2 > 1] = 1

# test for significant improvement: the gap exceeds the summed standard errors
improvedcells = (beta2_test - se2 > beta1_test + se1)

# test for significant prediction at all (either model > 3 SE above zero)
goodcells = ((beta2_test > se2 * 3) | (beta1_test > se1 * 3))

fh = plt.figure()
ax = plt.subplot(2, 2, 1)
stateplots.beta_comp(beta1[goodcells],
                     beta2[goodcells],
                     n1='LN STRF',
                     n2='STP+BEH LN STRF',
                     hist_range=xc_range,
                     ax=ax,
                     highlight=improvedcells[goodcells])

# LN vs. STP: improvement of modelnames[3] over modelnames[0].
# NOTE(review): se1a holds the SE of modelnames[3] and se1b that of
# modelnames[0] (the reverse of the beta1a/beta1b naming); harmless if the
# two are only ever summed.
beta1b = df[modelnames[3]]
beta1a = df[modelnames[0]]
beta1 = beta1b - beta1a
se1a = df_e[modelnames[3]]
se1b = df_e[modelnames[0]]

# second comparison: modelnames[b1] vs. modelnames[b0]
b1 = 1
b0 = 0
beta2b = df[modelnames[b1]]
beta2a = df[modelnames[b0]]
Example #9
0
    df = get_model_results_per_state_model(batch=batch,
                                           state_list=state_list,
                                           basemodel=basemodel)

    # figure out what cells show significant state ef
    da = df[df['state_chan'] == 'pupil']
    dp = pd.pivot_table(da,
                        index='cellid',
                        columns='state_sig',
                        values=['r', 'r_se'])
    dr = dp['r'].copy()
    dr['b_unique'] = dr[state_list[3]]**2 - dr[state_list[2]]**2
    dr['p_unique'] = dr[state_list[3]]**2 - dr[state_list[1]]**2
    dr['bp_common'] = dr[state_list[3]]**2 - dr[
        state_list[0]]**2 - dr['b_unique'] - dr['p_unique']
    dr['bp_full'] = dr['b_unique'] + dr['p_unique'] + dr['bp_common']
    dr['null'] = dr[state_list[0]]**2 * np.sign(dr[state_list[0]])
    dr['full'] = dr[state_list[3]]**2 * np.sign(dr[state_list[3]])

    dr['sig']=((dp['r'][state_list[3]]-dp['r'][state_list[0]]) > \
         (dp['r_se'][state_list[3]]+dp['r_se'][state_list[0]]))
    plt.close('all')
    fig = plt.figure()
    ax = plt.subplot(1, 1, 1)
    beta_comp(dr['p_unique'],
              dr['b_unique'],
              n1='Pupil',
              n2='Behavior',
              highlight=dr['sig'],
              ax=ax,
              hist_range=[-0.05, 0.15])
Example #10
0
            cid.extend([m['meta']['cellid']] * len(s))
            r_test = np.append(r_test, s * r)
            se_test = np.append(se_test, s * se)
            r_test_S = np.append(r_test_S,
                                 s * modelspecs_shf[i].meta['r_test'][0])
            se_test_S = np.append(se_test_S,
                                  s * modelspecs_shf[i].meta['se_test'][0])

    si = (r_test - r_test_S) > (se_test + se_test_S)

    def _rdt_info(i):
        """Handle a click on point ``i`` of the bg/fg gain scatter.

        Prints the cell id with its ``f`` and ``b`` values, then loads that
        cell's fitted model via ``nw.load_model_baphy_xform`` and calls
        ``quickplot`` on ``ctx['val']`` (presumably the validation data).
        """
        cellid = cid[i]
        print("{}: f={:.3} b={:.3}".format(cellid, f[i], b[i]))
        _xfspec, ctx = nw.load_model_baphy_xform(cellid,
                                                 batch=batch,
                                                 modelname=modelname)
        ctx['modelspec'].quickplot(rec=ctx['val'])

    #plt.figure()
    #ax=plt.subplot(1,1,1)
    fig = beta_comp(b,
                    f,
                    n1='bg',
                    n2='fg',
                    hist_range=[-0.75, 0.75],
                    click_fun=_rdt_info,
                    highlight=si,
                    title=bs)

    fig.savefig(outpath + 'gain_comp_' + keywordstring + '_' + bs + '.png')
Example #11
0
def aud_vs_state(df, nb=5, title=None, state_list=None):
    """Split the state-dependent gain in prediction r**2 into common/unique parts.

    Parameters
    ----------
    df : pandas.DataFrame
        Output of ``get_model_results_per_state_model()``. Rows with
        ``state_chan == 'active'`` are pivoted to ``cellid`` x ``state_sig``
        using the ``r`` and ``r_se`` columns.
    nb : int
        Number of bins of the state-independent r**2 (bin ``i`` covers
        ``(i/nb, (i+1)/nb]``). If ``nb <= 0`` each cell is shown separately.
    title : str, optional
        Title of the stacked-bar subplot.
    state_list : list of str, optional
        Four state-model names. Index 0 is the state-independent model and
        index 3 the full model; ``p_unique = r[3]**2 - r[1]**2`` and
        ``b_unique = r[3]**2 - r[2]**2``. Defaults to
        ``['st.pup0.beh0', 'st.pup0.beh', 'st.pup.beh0', 'st.pup.beh']``
        (previously ``None`` raised a TypeError).

    Returns
    -------
    ax1, ax2, ax3 : matplotlib Axes
        Scatter of null vs. full, stacked bars, and cumulative lines.
    """
    if state_list is None:
        state_list = [
            'st.pup0.beh0', 'st.pup0.beh', 'st.pup.beh0', 'st.pup.beh'
        ]

    plt.figure(figsize=(4, 6))

    da = df[df['state_chan'] == 'active']

    dp = da.pivot(index='cellid', columns='state_sig', values=['r', 'r_se'])

    dr = dp['r'].copy()
    dr['b_unique'] = dr[state_list[3]]**2 - dr[state_list[2]]**2
    dr['p_unique'] = dr[state_list[3]]**2 - dr[state_list[1]]**2
    dr['bp_common'] = (dr[state_list[3]]**2 - dr[state_list[0]]**2
                       - dr['b_unique'] - dr['p_unique'])
    dr['bp_full'] = dr['b_unique'] + dr['p_unique'] + dr['bp_common']
    # signed r**2 of the state-independent and full models
    dr['null'] = dr[state_list[0]]**2 * np.sign(dr[state_list[0]])
    dr['full'] = dr[state_list[3]]**2 * np.sign(dr[state_list[3]])

    # significant: full-model r exceeds null-model r by more than summed SEs
    dr['sig'] = ((dp['r'][state_list[3]] - dp['r'][state_list[0]]) >
                 (dp['r_se'][state_list[3]] + dp['r_se'][state_list[0]]))

    cols = ['null', 'full', 'bp_common', 'p_unique', 'b_unique', 'sig']
    dm = dr.loc[:, cols].sort_values(['null'])
    # cast so the boolean 'sig' column does not make the array object dtype
    mfull = dm[cols].values.astype(float)

    if nb > 0:
        mm = np.zeros((nb, mfull.shape[1]))
        for i in range(nb):
            x01 = (mfull[:, 0] > i / nb) & (mfull[:, 0] <= (i + 1) / nb)
            # leave empty bins at zero instead of taking nanmean of nothing
            if np.any(x01):
                mm[i, :] = np.nanmean(mfull[x01, :], axis=0)
        print(np.round(mm, 3))

        m = mm.copy()
    else:
        # alt to look at each cell individually:
        m = mfull.copy()

    # first three columns: bp_common, p_unique, b_unique
    mb = m[:, 2:]

    ax1 = plt.subplot(3, 1, 1)
    stateplots.beta_comp(mfull[:, 0],
                         mfull[:, 1],
                         n1='State independent',
                         n2='Full state-dep',
                         ax=ax1,
                         highlight=dm['sig'],
                         hist_range=[-0.1, 1])

    ax2 = plt.subplot(3, 1, 2)
    ind = np.arange(mb.shape[0])
    width = 0.8
    plt.bar(ind, mb[:, 0], width=width)
    plt.bar(ind, mb[:, 1], width=width, bottom=mb[:, 0])
    plt.bar(ind, mb[:, 2], width=width, bottom=mb[:, 0] + mb[:, 1])
    plt.legend(('common', 'p_unique', 'b-unique'))
    if title is not None:
        plt.title(title)
    plt.xlabel('behavior-independent quintile')
    plt.ylabel('mean r2')

    ax3 = plt.subplot(3, 1, 3)
    # cumulative sums of the same three components
    plt.plot(ind, mb[:, 0])
    plt.plot(ind, mb[:, 1] + mb[:, 0])
    plt.plot(ind, mb[:, 2] + mb[:, 0] + mb[:, 1])
    plt.legend(('common', 'p_unique', 'b-unique'))
    plt.xlabel('behavior-independent quintile')
    plt.ylabel('mean r2')

    plt.tight_layout()
    return ax1, ax2, ax3
Example #12
0
def aud_vs_state(df, nb=5, title=None, state_list=None,
                 colors=('r', 'g', 'b', 'k'), norm_by_null=False):
    """Plot how state-dependent prediction gains relate to baseline prediction.

    Parameters
    ----------
    df : pandas.DataFrame
        Per-cell results table (described as output of
        ``get_model_results_per_state_model()``). Must contain ``r_shuff``,
        ``r_full``, ``sig_state`` and ``SNR``; with four state models also
        ``r_task_unique`` and ``r_pupil_unique``. The first 7 characters of
        the first index level (presumably the cellid) are used as site id.
    nb : int
        Number of bins of ``r_shuff`` (bin ``i`` covers ``(i/nb, (i+1)/nb]``);
        if ``nb <= 0`` each cell is shown individually.
    title : str, optional
        Title of the stacked-bar subplot.
    state_list : list of str, optional
        Only its length is used here; must be 2 or 4.
    colors : sequence of str
        ``colors[1:4]`` color the common / b-unique / p-unique bars.
    norm_by_null : bool
        If True, divide the dep-indep difference by ``1 - |r_shuff|``.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``state_list`` does not have 2 or 4 entries (previously this
        failed later with a NameError).
    """
    if state_list is None:
        state_list = ['st.pup0.beh0', 'st.pup0.beh', 'st.pup.beh0', 'st.pup.beh']
    if len(state_list) not in (2, 4):
        raise ValueError(
            "state_list must contain 2 or 4 state models, got {}".format(
                len(state_list)))

    f = plt.figure(figsize=(5.0, 5.0))

    dr = df.copy()

    if len(state_list) == 4:
        # common part: full-model gain not explained by either unique part
        dr['bp_common'] = (dr['r_full'] - dr['r_task_unique']
                           - dr['r_pupil_unique'] - dr['r_shuff'])
        dr = dr.sort_values('r_shuff')
        mfull = dr[['r_shuff', 'r_full', 'bp_common', 'r_task_unique',
                    'r_pupil_unique', 'sig_state']].values
    else:
        # a single state model: the whole gain is "common", unique parts zero
        dr['bp_common'] = dr['r_full'] - dr['r_shuff']
        dr = dr.sort_values('r_shuff')
        dr['b_unique'] = dr['bp_common'] * 0
        dr['p_unique'] = dr['bp_common'] * 0
        mfull = dr[['r_shuff', 'r_full', 'bp_common', 'b_unique',
                    'p_unique', 'sig_state']].values

    mfull = mfull.astype(float)
    if nb > 0:
        mm = np.zeros((nb, mfull.shape[1]))
        for i in range(nb):
            x01 = (mfull[:, 0] > i / nb) & (mfull[:, 0] <= (i + 1) / nb)
            if np.sum(x01):
                mm[i, :] = np.nanmean(mfull[x01, :], axis=0)

        print(np.round(mm, 3))

        m = mm.copy()
    else:
        # alt to look at each cell individually:
        m = mfull.copy()

    mall = np.nanmean(mfull, axis=0, keepdims=True)

    # remove sensory component, which swamps everything else
    mall = mall[:, 2:]
    mb = m[:, 2:]

    ax1 = plt.subplot(2, 2, 1)
    stateplots.beta_comp(mfull[:, 0], mfull[:, 1], n1='State independent',
                         n2='Full state-dep', ax=ax1,
                         highlight=mfull[:, -1], hist_range=[-0.1, 1])

    plt.subplot(2, 2, 3)
    width = 0.8
    # first bar: mean over all cells; remaining bars: one per bin
    mplots = np.concatenate((mall, mb), axis=0)
    ind = np.arange(mplots.shape[0])

    plt.bar(ind, mplots[:, 0], width=width, color=colors[1])
    plt.bar(ind, mplots[:, 1], width=width, bottom=mplots[:, 0],
            color=colors[2])
    plt.bar(ind, mplots[:, 2], width=width,
            bottom=mplots[:, 0] + mplots[:, 1], color=colors[3])
    plt.legend(('common', 'b-unique', 'p_unique'))
    if title is not None:
        plt.title(title)
    plt.xlabel('behavior-independent quintile')
    plt.ylabel('mean r2')

    # state-dependent minus state-independent prediction, shared by the two
    # scatter panels below
    if norm_by_null:
        d = (mfull[:, 1] - mfull[:, 0]) / (1 - np.abs(mfull[:, 0]))
        ylabel = "dep-indep normed"
    else:
        d = mfull[:, 1] - mfull[:, 0]
        ylabel = "dep-indep"

    ax3 = plt.subplot(2, 2, 2)
    stateplots.beta_comp(mfull[:, 0], d, n1='State independent', n2=ylabel,
                         ax=ax3, highlight=mfull[:, -1],
                         hist_range=[-0.1, 1], markersize=4)
    if not norm_by_null:
        ax3.plot([1, 0], [0, 1], 'k--', linewidth=0.5)

    slope, intercept, r, p, std_err = st.linregress(mfull[:, 0], d)

    # group cells by site (first 7 chars of the first index level) for the
    # bootstrap; masks follow the sorted row order of mfull
    dr['site'] = [c[:7] for c in dr.index.get_level_values(0)]
    site_masks = {s: (dr['site'] == s).values for s in dr['site'].unique()}

    x = get_bootstrapped_sample(
        {s: mfull[mask, 0] for s, mask in site_masks.items()},
        {s: d[mask] for s, mask in site_masks.items()},
        metric='corrcoef', nboot=10000)
    pboot, _ = get_direct_prob(x, np.zeros(x.shape[0]))

    mm = np.array([np.min(mfull[:, 0]), np.max(mfull[:, 0])])
    ax3.plot(mm, intercept + slope * mm, 'k--', linewidth=0.5)
    plt.title('n={} cc={:.3} p={:.4}, pboot={:.5f}'.format(
        len(d), r, p, 1 - pboot), fontsize=7)

    ax4 = plt.subplot(2, 2, 4)
    snr = np.log(dr['SNR'].values)
    _ok = np.isfinite(d) & np.isfinite(snr)
    ax4.plot(snr[_ok], d[_ok], 'k.', markersize=4)
    slope, intercept, r, p, std_err = st.linregress(snr[_ok], d[_ok])

    x = get_bootstrapped_sample(
        {s: snr[mask & _ok] for s, mask in site_masks.items()},
        {s: d[mask & _ok] for s, mask in site_masks.items()},
        metric='corrcoef', nboot=10000)
    pboot, _ = get_direct_prob(x, np.zeros(x.shape[0]))

    mm = np.array([np.min(snr[_ok]), np.max(snr[_ok])])
    ax4.plot(mm, intercept + slope * mm, 'k--', linewidth=0.5)
    ax4.set_xlabel('log(SNR)')
    ax4.set_ylabel(ylabel)
    # n counts only the finite (log SNR, d) pairs actually used in the fit
    ax4.set_title('n={} cc={:.3} p={:.4}, pboot={:.5f}'.format(
        int(np.sum(_ok)), r, p, 1 - pboot), fontsize=7)
    nplt.ax_remove_box(ax4)

    f.tight_layout()

    return f
Example #13
0
def aud_vs_state(df,
                 nb=5,
                 title=None,
                 state_list=None,
                 colors=('r', 'g', 'b', 'k')):
    """Split the state-dependent gain in prediction r**2 into common/unique parts.

    Parameters
    ----------
    df : pandas.DataFrame
        Output of ``get_model_results_per_state_model()``. Rows with
        ``state_chan == 'active'`` are pivoted to ``cellid`` x ``state_sig``
        using the ``r`` and ``r_se`` columns.
    nb : int
        Number of bins of the state-independent r**2 (bin ``i`` covers
        ``(i/nb, (i+1)/nb]``). If ``nb <= 0`` each cell is shown separately.
    title : str, optional
        Title of the stacked-bar subplot.
    state_list : list of str, optional
        Either four names (index 0 = state-independent model, index 3 = full
        model; ``b_unique = r[3]**2 - r[2]**2``,
        ``p_unique = r[3]**2 - r[1]**2``) or two names (null, full). Defaults
        to ``['st.pup0.beh0', 'st.pup0.beh', 'st.pup.beh0', 'st.pup.beh']``.
    colors : sequence of str
        ``colors[1:4]`` color the common / b-unique / p-unique bars.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``state_list`` does not have 2 or 4 entries (previously this
        failed later with a NameError).
    """
    if state_list is None:
        state_list = [
            'st.pup0.beh0', 'st.pup0.beh', 'st.pup.beh0', 'st.pup.beh'
        ]
    if len(state_list) not in (2, 4):
        raise ValueError(
            "state_list must contain 2 or 4 state models, got {}".format(
                len(state_list)))

    f = plt.figure(figsize=(4, 6))

    da = df[df['state_chan'] == 'active']

    dp = da.pivot(index='cellid', columns='state_sig', values=['r', 'r_se'])

    dr = dp['r'].copy()

    cols = ['null', 'full', 'bp_common', 'b_unique', 'p_unique', 'sig']
    if len(state_list) == 4:
        dr['b_unique'] = dr[state_list[3]]**2 - dr[state_list[2]]**2
        dr['p_unique'] = dr[state_list[3]]**2 - dr[state_list[1]]**2
        dr['bp_common'] = (dr[state_list[3]]**2 - dr[state_list[0]]**2
                           - dr['b_unique'] - dr['p_unique'])
        dr['bp_full'] = dr['b_unique'] + dr['p_unique'] + dr['bp_common']
        # signed r**2 of the null and full models
        dr['null'] = dr[state_list[0]]**2 * np.sign(dr[state_list[0]])
        dr['full'] = dr[state_list[3]]**2 * np.sign(dr[state_list[3]])

        # significant: full r exceeds null r by more than the summed SEs
        dr['sig'] = ((dp['r'][state_list[3]] - dp['r'][state_list[0]]) >
                     (dp['r_se'][state_list[3]] + dp['r_se'][state_list[0]]))

        dm = dr.loc[:, cols].sort_values(['null'])
        mfull = dm[cols].values

    else:
        # two state models: the whole gain is "common", unique parts zero
        dr['bp_common'] = dr[state_list[1]]**2 - dr[state_list[0]]**2
        dr['b_unique'] = dr['bp_common'] * 0
        dr['p_unique'] = dr['bp_common'] * 0

        dr['bp_full'] = dr['b_unique'] + dr['p_unique'] + dr['bp_common']
        dr['null'] = dr[state_list[0]]**2 * np.sign(dr[state_list[0]])
        dr['full'] = dr[state_list[1]]**2 * np.sign(dr[state_list[1]])

        dr['sig'] = ((dp['r'][state_list[1]] - dp['r'][state_list[0]]) >
                     (dp['r_se'][state_list[1]] + dp['r_se'][state_list[0]]))
        dr['cellid'] = dp['r'][state_list[1]].index
        dm = dr.loc[:, ['cellid'] + cols].sort_values(['null'])
        mfull = dm[cols].values
        cellids = dm['cellid'].to_list()

        # report cells whose state-dependent r**2 beats the null by > 0.2
        big_idx = mfull[:, 1] - mfull[:, 0] > 0.2
        for i, is_big in enumerate(big_idx):
            if is_big:
                print('{} : {:.3f} - {:.3f}'.format(cellids[i], mfull[i, 0],
                                                    mfull[i, 1]))

    # cast so the boolean 'sig' column does not leave an object-dtype array
    mfull = mfull.astype(float)

    if nb > 0:
        mm = np.zeros((nb, mfull.shape[1]))
        for i in range(nb):
            x01 = (mfull[:, 0] > i / nb) & (mfull[:, 0] <= (i + 1) / nb)
            if np.sum(x01):
                mm[i, :] = np.nanmean(mfull[x01, :], axis=0)

        print(np.round(mm, 3))

        m = mm.copy()
    else:
        # alt to look at each cell individually:
        m = mfull.copy()

    mall = np.nanmean(mfull, axis=0, keepdims=True)

    # remove sensory component, which swamps everything else
    mall = mall[:, 2:]
    mb = m[:, 2:]

    ax1 = plt.subplot(3, 1, 1)
    stateplots.beta_comp(mfull[:, 0],
                         mfull[:, 1],
                         n1='State independent',
                         n2='Full state-dep',
                         ax=ax1,
                         highlight=dm['sig'],
                         hist_range=[-0.1, 1])

    plt.subplot(3, 1, 2)
    width = 0.8
    # first bar: mean over all cells; remaining bars: one per bin
    mplots = np.concatenate((mall, mb), axis=0)
    ind = np.arange(mplots.shape[0])

    plt.bar(ind, mplots[:, 0], width=width, color=colors[1])
    plt.bar(ind,
            mplots[:, 1],
            width=width,
            bottom=mplots[:, 0],
            color=colors[2])
    plt.bar(ind,
            mplots[:, 2],
            width=width,
            bottom=mplots[:, 0] + mplots[:, 1],
            color=colors[3])
    plt.legend(('common', 'b-unique', 'p_unique'))
    if title is not None:
        plt.title(title)
    plt.xlabel('behavior-independent quintile')
    plt.ylabel('mean r2')

    ax3 = plt.subplot(3, 1, 3)
    d = mfull[:, 1] - mfull[:, 0]
    stateplots.beta_comp(mfull[:, 0],
                         d,
                         n1='State independent',
                         n2='dep - indep',
                         ax=ax3,
                         highlight=dm['sig'],
                         hist_range=[-0.1, 1],
                         markersize=4)
    ax3.plot([1, 0], [0, 1], 'k--', linewidth=0.5)
    r, p = st.pearsonr(mfull[:, 0], d)
    plt.title('cc={:.3} p={:.4}'.format(r, p))

    plt.tight_layout()
    return f
Example #14
0
# Stack the per-model lists into arrays; the first axis indexes the model.
state_mod = np.stack(state_mod)
r_test = np.stack(r_test)
se_test = np.stack(se_test)

# Differences between row sv_len (the reference model) and each of rows
# 0..sv_len-1. Indexing with [[sv_len]] keeps a length-1 leading axis so the
# reference row broadcasts against all sv_len other rows.
# NOTE(review): "u_" presumably means the unique contribution of each state
# variable (reference vs. that variable removed/shuffled) -- confirm.
u_state_mod = state_mod[[sv_len],:] - state_mod[:sv_len, :]
u_r_test = r_test[[sv_len],:] - r_test[:sv_len,:]
# "good" when the r_test gain exceeds the SE of the reduced model only
# (the reference model's SE is not included)
u_r_good = u_r_test > se_test[:sv_len,:]

plt.close('all')
plt.figure()


for i,p in enumerate(sv_pairs):
    ax = plt.subplot(len(sv_pairs),3,i*3+1)
    stateplots.beta_comp(u_r_test[p[0],:], u_r_test[p[1],:],
                         n1=statevars[p[0]], n2=statevars[p[1]],
                         title='u r_test', hist_range=[-0.05, 0.15],
                         ax=ax, highlight=(u_r_good[p[0],:] | u_r_good[p[1],:]))

    ax = plt.subplot(len(sv_pairs),3,i*3+2)
    stateplots.beta_comp(u_state_mod[p[0],:,p[0]+1], u_state_mod[p[1],:,p[1]+1],
                         n1=statevars[p[0]], n2=statevars[p[1]],
                         title='u state_mod', hist_range=[-0.6, 0.6],
                         ax=ax, highlight=(u_r_good[p[0],:] | u_r_good[p[1],:]))
#    plt.plot(tvr[np.logical_not(u_r_good[p[0],:])],
#             state_mod[-1,np.logical_not(u_r_good[p[0],:]),p[0]+1],
#             '.', color='LightGray')
#    plt.plot(tvr[u_r_good[p[0],:]], state_mod[-1,u_r_good[p[0],:],p[0]+1], 'k.')
#    plt.xlabel('tvr')
#    plt.ylabel("raw state_mod " + statevars[p[0]])

    ax = plt.subplot(len(sv_pairs),3,i*3+3)