def beh_only_plot(batch=311, state_list=None,
                  basemodel="-ref-psthfr.s_stategain.S",
                  loader="psth.fs20-ld-", fitter="_jk.nf20-basic"):
    """Plot baseline/gain and MI comparisons for a behavior-only state model.

    Loads per-cell results with ``get_model_results_per_state_model``, keeps the
    rows whose ``state_chan`` is ``'active'`` and draws two ``beta_comp`` plots:
    baseline (``d``) vs. gain (``g``) of the ``state_list[1]`` model, and MI of
    the ``state_list[0]`` model vs. MI of the ``state_list[1]`` model.

    Parameters
    ----------
    batch : int
        Batch id (default 311 = A1 old (SVD) data -- on BF).
    state_list : list of str, optional
        Two state-signal names ``[state_independent, state_dependent]``;
        defaults to ``['st.beh0', 'st.beh']``.
    basemodel, loader, fitter : str
        Model-name fragments forwarded to ``get_model_results_per_state_model``.
        The defaults are the values this function previously hard-coded.

    Returns
    -------
    DataFrame returned by ``get_model_results_per_state_model``.
    """
    if state_list is None:
        state_list = ['st.beh0', 'st.beh']
    s_indep, s_dep = state_list[0], state_list[1]

    df = get_model_results_per_state_model(batch=batch, state_list=state_list,
                                           basemodel=basemodel, fitter=fitter,
                                           loader=loader)
    da = df[df['state_chan'] == 'active']
    dp = da.pivot(index='cellid', columns='state_sig',
                  values=['r', 'r_se', 'MI', 'g', 'd'])

    # A cell counts as significant when the state-dependent r exceeds the
    # state-independent r by more than the sum of their standard errors.
    dr = dp['r'].copy()
    dr['sig'] = ((dp['r'][s_dep] - dp['r'][s_indep]) >
                 (dp['r_se'][s_dep] + dp['r_se'][s_indep]))

    g = dp['g'].copy()
    d = dp['d'].copy()
    ggood = np.isfinite(g[s_dep])
    stateplots.beta_comp(d.loc[ggood, s_dep], g.loc[ggood, s_dep],
                         n1='Baseline', n2='Gain',
                         title="Baseline/gain: batch {}".format(batch),
                         highlight=dr.loc[ggood, 'sig'], hist_range=[-1, 1])

    MI = dp['MI'].copy()
    migood = np.isfinite(MI[s_dep])
    stateplots.beta_comp(MI.loc[migood, s_indep], MI.loc[migood, s_dep],
                         n1='State independent', n2='State-dep',
                         title="MI: batch {}".format(batch),
                         highlight=dr.loc[migood, 'sig'], hist_range=[-0.5, 0.5])

    return df
def _stats_to_results_row(stats):
    """Flatten one per-cell ``stats`` dict into a row of the results table.

    The element order matches ``col_names`` in ``plot_save_examples``. The four
    entries of ``r_test``/``se_test``/``r_floor`` are stored under the column
    suffixes p0b0, p0b, pb0, pb, in that order.
    """
    r = stats['r_test']
    se = stats['se_test']
    rf = stats['r_floor']
    pm = stats['pred_mod']
    pmn = stats['pred_mod_norm']
    pmf = stats['pred_mod_full']
    pmfn = stats['pred_mod_full_norm']
    return [
        stats['cellid'],
        r[0], r[1], r[2], r[3],
        se[0], se[1], se[2], se[3],
        rf[0], rf[1], rf[2], rf[3],
        r[3] - r[1],  # r_pup      = r_pb  - r_p0b
        r[3] - r[2],  # r_beh      = r_pb  - r_pb0
        r[1] - r[0],  # r_beh_pup0 = r_p0b - r_p0b0
        pm[0, 1], pm[1, 2],      # pup_mod, beh_mod
        pmn[0, 1], pmn[1, 2],    # pup_mod_n, beh_mod_n
        pmf[0, 1], pmf[1, 2],    # pup_mod_beh0, beh_mod_pup0
        pmfn[0, 1], pmfn[1, 2],  # pup_mod_beh0_n, beh_mod_pup0_n
        stats['b'][3, 1], stats['b'][3, 2],  # d_pup, d_beh
        stats['g'][3, 1], stats['g'][3, 2],  # g_pup, g_beh
        stats['ref_all_resp'], stats['ref_common_resp'],
        stats['tar_max_resp'], stats['tar_probe_resp'],
    ]


def plot_save_examples(batch, compare, loader, basemodel, fitter, RELOAD=False):
    """Plot per-cell state-model examples for a batch and save summary figures.

    For every cell in ``batch`` (when reloading), one of
    ``stateplots.pb_model_plot`` (``compare == "pb"``),
    ``stateplots.ppas_model_plot`` (``compare == "ppas"``) or
    ``stateplots.pp_model_plot`` (anything else) is called. The figure is saved
    as PDF/PNG and the returned stats are collected into ``results.csv``.
    Otherwise the cached ``results.csv`` is loaded. Four ``beta_comp`` summary
    plots are then drawn and, if the output directory is writable, saved.

    Parameters
    ----------
    batch : int
        Batch id; 301 and 307 are labelled area "AC", all others "IC".
    compare : str
        "pb", "ppas" or anything else (pre/post comparison).
    loader, basemodel, fitter : str
        Model-name fragments forwarded to the stateplots functions.
    RELOAD : bool
        Recompute per-cell stats even if ``results.csv`` exists. Reloading is
        forced when the cached file is missing.
    """
    area = "AC" if batch in [301, 307] else "IC"

    d = nd.get_batch_cells(batch)
    cellids = list(d['cellid'])
    stats_list = []
    root_path = '/auto/users/svd/projects/pupil-behavior'

    modelset = '{}_{}_{}_{}_{}_{}'.format(compare, area, batch, loader,
                                          basemodel, fitter)
    out_path = '{}/{}/'.format(root_path, modelset)

    if os.access(root_path, os.W_OK) and not os.path.exists(out_path):
        os.makedirs(out_path)

    datafile = out_path + 'results.csv'
    plt.close('all')

    if (not RELOAD) and (not os.path.isfile(datafile)):
        RELOAD = True
        print('datafile not found, reloading')

    if RELOAD:
        for cellid in cellids:
            if compare == "pb":
                fh, stats = stateplots.pb_model_plot(
                    cellid, batch, loader=loader, basemodel=basemodel,
                    fitter=fitter)
            elif compare == "ppas":
                fh, stats = stateplots.ppas_model_plot(
                    cellid, batch, loader=loader, basemodel=basemodel,
                    fitter=fitter)
            else:
                fh, stats = stateplots.pp_model_plot(
                    cellid, batch, loader=loader, basemodel=basemodel,
                    fitter=fitter)
            stats_list.append(stats)
            if os.access(out_path, os.W_OK):
                fh.savefig(out_path + cellid + '.pdf')
                fh.savefig(out_path + cellid + '.png')
            plt.close(fh)

        col_names = [
            'cellid', 'r_p0b0', 'r_p0b', 'r_pb0', 'r_pb', 'e_p0b0', 'e_p0b',
            'e_pb0', 'e_pb', 'rf_p0b0', 'rf_p0b', 'rf_pb0', 'rf_pb', 'r_pup',
            'r_beh', 'r_beh_pup0', 'pup_mod', 'beh_mod', 'pup_mod_n',
            'beh_mod_n', 'pup_mod_beh0', 'beh_mod_pup0', 'pup_mod_beh0_n',
            'beh_mod_pup0_n', 'd_pup', 'd_beh', 'g_pup', 'g_beh',
            'ref_all_resp', 'ref_common_resp', 'tar_max_resp',
            'tar_probe_resp'
        ]

        # Build the table in one call: DataFrame.append was removed in
        # pandas 2.0, and appending row-by-row was quadratic anyway.
        rows = [_stats_to_results_row(stats) for stats in stats_list]
        df = pd.DataFrame(rows, columns=col_names)
        df.set_index(['cellid'], inplace=True)
        if os.access(out_path, os.W_OK):
            df.to_csv(datafile)
    else:
        # load cached dataframe
        df = pd.read_csv(datafile, index_col=0)

    sig_mod = list(df['r_pb'] - df['e_pb'] > df['r_p0b0'] + df['e_p0b0'])

    if compare == "pb":
        alabel = "active"
    elif compare == "ppas":
        alabel = "each passive"
    else:
        alabel = "pre/post"

    mi_bounds = [-0.4, 0.4]

    fh1 = stateplots.beta_comp(df['r_pup'], df['r_beh'], n1='pupil',
                               n2=alabel, title=modelset + ' unique pred',
                               hist_range=[-0.02, 0.15], highlight=sig_mod)
    fh2 = stateplots.beta_comp(df['pup_mod_n'], df['beh_mod'], n1='pupil',
                               n2=alabel, title=modelset + ' mod index',
                               hist_range=mi_bounds, highlight=sig_mod)
    fh3 = stateplots.beta_comp(df['beh_mod_pup0'], df['beh_mod'],
                               n1=alabel + '-nopup', n2=alabel,
                               title=modelset + ' unique mod',
                               hist_range=mi_bounds, highlight=sig_mod)

    # unique behavior (r_beh):
    #   performance of full model minus performance with behavior shuffled
    # naive behavior, ignorant of pupil (r_beh_pup0):
    #   performance of behavior alone (pupil shuffled) minus all shuffled
    fh4 = stateplots.beta_comp(df['r_beh_pup0'], df['r_beh'],
                               n1=alabel + '-nopup', n2=alabel,
                               title=modelset + ' unique r',
                               hist_range=[-0.02, .15], highlight=sig_mod)

    if os.access(out_path, os.W_OK):
        fh1.savefig(out_path + 'summary_pred.pdf')
        fh2.savefig(out_path + 'summary_mod.pdf')
        fh3.savefig(out_path + 'summary_mod_ctl.pdf')
        fh4.savefig(out_path + 'summary_r_ctl.pdf')
def stp_v_beh():
    """Compare LN, STP and state-dependent model variants for batches 274 + 275.

    Builds a 2x2 figure:
      1. r_test of ``modelnames[0]`` vs. ``modelnames[-3]`` for cells with a
         significant prediction in either model, highlighting cells where the
         second model improves by more than the summed standard errors;
      2. delta(stp) = r(modelnames[3]) - r(modelnames[0]) vs.
         delta(beh) = r(modelnames[4]) - r(modelnames[3]), coloured by which
         delta exceeds its summed standard error;
      3. mean r_test of every model over the good cells, annotated with Wilcoxon
         signed-rank p-values between neighbouring models.

    Returns
    -------
    fh : matplotlib Figure
    list of str
        cellids whose delta(stp) exceeds its summed standard error.
    """
    # local import so the block does not rely on a module-level pandas alias
    import pandas as pd

    batch1 = 274
    batch2 = 275
    modelnames = [
        "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-fir.1x15-lvl.1-dexp.1_jk.nf5-init.st-basic",
        "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-fir.1x15-lvl.1-rep.2-dexp.2-mrg_jk.nf5-init.st-basic",
        "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-rep.2-fir.1x15x2-lvl.2-dexp.2-mrg_jk.nf5-init.st-basic",
        "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-stp.1-fir.1x15-lvl.1-dexp.1_jk.nf5-init.st-basic",
        "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-stp.1-fir.1x15-lvl.1-rep.2-dexp.2-mrg_jk.nf5-init.st-basic",
        "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-stp.1-rep.2-fir.1x15x2-lvl.2-dexp.2-mrg_jk.nf5-init.st-basic",
        "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-rep.2-stp.2-fir.1x15x2-lvl.2-dexp.2-mrg_jk.nf5-init.st-basic",
    ]
    n1 = modelnames[0]
    n2 = modelnames[-3]

    df1 = nd.batch_comp(batch1, modelnames, stat='r_test').reset_index()
    df1_e = nd.batch_comp(batch1, modelnames, stat='se_test').reset_index()
    df2 = nd.batch_comp(batch2, modelnames, stat='r_test').reset_index()
    df2_e = nd.batch_comp(batch2, modelnames, stat='se_test').reset_index()
    # DataFrame.append was removed in pandas 2.0; ignore_index avoids the
    # duplicated 0..n index produced by stacking two reset_index() frames.
    df = pd.concat([df1, df2], ignore_index=True)
    df_e = pd.concat([df1_e, df2_e], ignore_index=True)

    # --- panel 1: LN vs. modelnames[-3] ---------------------------------
    xc_range = [-0.05, 0.6]
    r1, r2 = df[n1], df[n2]
    se1, se2 = df_e[n1], df_e[n2]
    # test for significant improvement
    improvedcells = (r2 - se2 > r1 + se1)
    # test for significant prediction at all
    goodcells = ((r2 > se2 * 3) | (r1 > se1 * 3))
    # clipped copies for display; the original assigned through views of df
    beta1 = r1.clip(upper=1)
    beta2 = r2.clip(upper=1)

    fh = plt.figure()
    ax = plt.subplot(2, 2, 1)
    stateplots.beta_comp(beta1[goodcells], beta2[goodcells],
                         n1='LN STRF', n2='STP+BEH LN STRF',
                         hist_range=xc_range, ax=ax,
                         highlight=improvedcells[goodcells])

    # --- panel 2: delta(stp) vs. delta(beh) -----------------------------
    d_stp = df[modelnames[3]] - df[modelnames[0]]
    se_stp = df_e[modelnames[3]] + df_e[modelnames[0]]
    b1, b0 = 4, 3
    d_beh = df[modelnames[b1]] - df[modelnames[b0]]
    se_beh = df_e[modelnames[b1]] + df_e[modelnames[b0]]

    stpgood = d_stp > se_stp
    behgood = d_beh > se_beh
    neither_good = ~stpgood & ~behgood
    both_good = stpgood & behgood
    stp_only_good = stpgood & ~behgood
    beh_only_good = ~stpgood & behgood

    xc_range = np.array([-0.05, 0.15])
    # pin points below the axis floor onto the axis so they remain visible
    x_stp = d_stp.clip(lower=xc_range[0])
    y_beh = d_beh.clip(lower=xc_range[0])

    zz = np.zeros(2)
    ax = plt.subplot(2, 2, 2)
    ax.plot(xc_range, zz, 'k--', linewidth=0.5)
    ax.plot(zz, xc_range, 'k--', linewidth=0.5)
    ax.plot(xc_range, xc_range, 'k--', linewidth=0.5)
    groups = ((neither_good, 'lightgray'), (beh_only_good, 'purple'),
              (stp_only_good, 'orange'), (both_good, 'black'))
    handles = []
    for mask, color in groups:
        handles += ax.plot(x_stp[mask], y_beh[mask], '.', color=color)
    ax_remove_box(ax)
    ax.set_aspect('equal', 'box')
    ax.set_xlim(xc_range)
    ax.set_ylim(xc_range)
    ax.set_xlabel('delta(stp)')
    ax.set_ylabel('delta(beh)')
    ax.legend(handles, [np.sum(mask) for mask, _ in groups])

    # --- panel 3: mean prediction per model -----------------------------
    ax = plt.subplot(2, 2, 3)
    # select model columns first: the string 'cellid' column makes a bare
    # DataFrame.mean() raise in pandas >= 2
    m = np.array(df.loc[goodcells, modelnames].mean())
    plt.bar(np.arange(len(modelnames)), m, color='black')
    plt.plot(np.array([-1, len(modelnames)]), np.array([0, 0]), 'k--',
             linewidth=0.5)
    plt.ylim([-0.02, 0.2])
    plt.title("batch {}+{}, n={}/{} good cells".format(
        batch1, batch2, np.sum(goodcells), len(goodcells)))
    plt.ylabel('mean pred corr')
    plt.xlabel('model architecture')
    ax_remove_box(ax)
    for i in range(len(modelnames) - 1):
        d1 = np.array(df[modelnames[i]])
        d2 = np.array(df[modelnames[i + 1]])
        _, p = ss.wilcoxon(d1, d2)
        plt.text(i + 0.5, m[i + 1], "{:.1e}".format(p), ha='center',
                 fontsize=6)
    plt.xticks(np.arange(len(m)), np.round(m, 3))

    return fh, df.loc[stpgood, 'cellid'].tolist()
'k--', lw=0.5) plt.plot(np.zeros(2), np.array(amp_bounds), 'k--', lw=0.5) plt.plot(m_fir[~show_units, 0], m_fir[~show_units, 1], '.', color=dotcolor_ns) plt.plot(m_fir[show_units, 0], m_fir[show_units, 1], '.', color=dotcolor) plt.title('STP STRF n={}/{} good units'.format(np.sum(show_units), np.sum(good_pred))) plt.xlabel('bigger channel gain') plt.ylabel('smaller channel gain') ax.set_aspect('equal', 'box') nplt.ax_remove_box(ax) ax = fh0.add_subplot(3, 2, 3) stateplots.beta_comp(r0_ceiling_mtx[good_pred], r_ceiling_mtx[good_pred], n1='LN STRF', n2='RW3 STP STRF', hist_range=[0.0, 1.0], ax=ax) ax.set_title('good_pred: {}/{}'.format(np.sum(good_pred), len(r0_ceiling_mtx))) ax = fh0.add_subplot(3, 2, 4) stateplots.beta_comp(r0_test_mtx[good_pred], r_test_mtx[good_pred], n1='LN STRF', n2='RW3 STP STRF', hist_range=[0.0, 1.0], ax=ax) ax = fh0.add_subplot(3, 2, 5) F0 = np.concatenate(fir0, axis=0) plt.hist(F0.flatten())
beta2_test = df_r[n2] se1 = df_e[n1] se2 = df_e[n2] beta1[beta1 > 1] = 1 beta2[beta2 > 1] = 1 # test for significant improvement improvedcells = (beta2_test - se2 > beta1_test + se1) # test for signficant prediction at all goodcells = ((beta2_test > se2 * 3) | (beta1_test > se1 * 3)) fh1 = stateplots.beta_comp(beta1[goodcells], beta2[goodcells], n1='LN STRF', n2='RW3 STP STRF', hist_range=xc_range, highlight=improvedcells[goodcells]) #fh1 = stateplots.beta_comp(beta1, beta2, # n1='LN STRF', n2='RW3 STP STRF', # hist_range=xc_range, # highlight=improvedcells) fh2 = plt.figure(figsize=(3.5, 3)) m = np.array(df.loc[goodcells].mean()[modelnames]) plt.bar(np.arange(len(modelnames)), m, color='black') plt.plot(np.array([-1, len(modelnames)]), np.array([0, 0]), 'k--') plt.ylim((-.05, 0.8)) plt.title("batch {}, n={}/{} good cells".format(batch, np.sum(goodcells), len(goodcells))) plt.ylabel('median pred corr')
batch=batch, modelname=modelname) ctx['modelspec'].quickplot(rec=ctx['val']) figure, axes = plt.subplots(2, 3, figsize=(8, 5)) ax = axes[0, 0] bound = 1.2 histbins = np.linspace(-0.5, 0.5, 21) beta_comp(b, f, n1='bg', n2='fg', hist_range=[-bound, bound], ax=ax, click_fun=_rdt_info, highlight=si, title=bs) ax = axes[0, 1] stat, p = wilcoxon(f, b) md = np.mean(f - b) rg[batch] = gdiff list_of_tuples = list( zip(cellids, tar_id, f, b, gdiff, r_test, r_test_S, r_test_SR)) rgdf[batch] = pd.DataFrame( list_of_tuples,
cellids = t.index diff = t['r_pb'] - t['r_p0b0'] # screen for signficant cellids cellids_sig = cellids[diff > np.std(diff)] sig_mod = (diff > np.std(diff)) # filter mod results by significant cellids t_filtered = t.T.filter(cellids_sig).T # ================= Plot master summary figure of single cells ============== # for all cells (compare sig vs. not sig) f = stateplots.beta_comp(t['pup_mod'], t['beh_mod'], n1='pupil', n2='active', title='mod index', hist_range=[-0.4, 0.4], highlight=sig_mod) # just for significant cells (compare fs vs. rs) fs = [] for cid in t_filtered.index: try: if nd.get_wft(cid) == 1: fs.append(1) else: fs.append(0) except: fs.append(0) f2 = stateplots.beta_comp(t_filtered['pup_mod'],
beta1[beta1 > 1] = 1 beta2[beta2 > 1] = 1 # test for significant improvement improvedcells = (beta2_test - se2 > beta1_test + se1) # test for signficant prediction at all goodcells = ((beta2_test > se2 * 3) | (beta1_test > se1 * 3)) fh = plt.figure() ax = plt.subplot(2, 2, 1) stateplots.beta_comp(beta1[goodcells], beta2[goodcells], n1='LN STRF', n2='STP+BEH LN STRF', hist_range=xc_range, ax=ax, highlight=improvedcells[goodcells]) # LN vs. STP: beta1b = df[modelnames[3]] beta1a = df[modelnames[0]] beta1 = beta1b - beta1a se1a = df_e[modelnames[3]] se1b = df_e[modelnames[0]] b1 = 1 b0 = 0 beta2b = df[modelnames[b1]] beta2a = df[modelnames[b0]]
df = get_model_results_per_state_model(batch=batch, state_list=state_list, basemodel=basemodel) # figure out what cells show significant state ef da = df[df['state_chan'] == 'pupil'] dp = pd.pivot_table(da, index='cellid', columns='state_sig', values=['r', 'r_se']) dr = dp['r'].copy() dr['b_unique'] = dr[state_list[3]]**2 - dr[state_list[2]]**2 dr['p_unique'] = dr[state_list[3]]**2 - dr[state_list[1]]**2 dr['bp_common'] = dr[state_list[3]]**2 - dr[ state_list[0]]**2 - dr['b_unique'] - dr['p_unique'] dr['bp_full'] = dr['b_unique'] + dr['p_unique'] + dr['bp_common'] dr['null'] = dr[state_list[0]]**2 * np.sign(dr[state_list[0]]) dr['full'] = dr[state_list[3]]**2 * np.sign(dr[state_list[3]]) dr['sig']=((dp['r'][state_list[3]]-dp['r'][state_list[0]]) > \ (dp['r_se'][state_list[3]]+dp['r_se'][state_list[0]])) plt.close('all') fig = plt.figure() ax = plt.subplot(1, 1, 1) beta_comp(dr['p_unique'], dr['b_unique'], n1='Pupil', n2='Behavior', highlight=dr['sig'], ax=ax, hist_range=[-0.05, 0.15])
cid.extend([m['meta']['cellid']] * len(s)) r_test = np.append(r_test, s * r) se_test = np.append(se_test, s * se) r_test_S = np.append(r_test_S, s * modelspecs_shf[i].meta['r_test'][0]) se_test_S = np.append(se_test_S, s * modelspecs_shf[i].meta['se_test'][0]) si = (r_test - r_test_S) > (se_test + se_test_S) def _rdt_info(i): print("{}: f={:.3} b={:.3}".format(cid[i], f[i], b[i])) cellid = cid[i] xfspec, ctx = nw.load_model_baphy_xform(cellid, batch=batch, modelname=modelname) ctx['modelspec'].quickplot(rec=ctx['val']) #plt.figure() #ax=plt.subplot(1,1,1) fig = beta_comp(b, f, n1='bg', n2='fg', hist_range=[-0.75, 0.75], click_fun=_rdt_info, highlight=si, title=bs) fig.savefig(outpath + 'gain_comp_' + keywordstring + '_' + bs + '.png')
def aud_vs_state(df, nb=5, title=None, state_list=None):
    """Decompose state-dependent prediction r^2 into common and unique parts.

    d = dataframe output by get_model_results_per_state_model()
    nb = number of bins over the state-independent signed r^2 (bins are
         (i/nb, (i+1)/nb]); nb <= 0 plots each cell individually
    state_list = four state-signal names [null, state_list[1], state_list[2],
                 full]; defaults to
                 ['st.pup0.beh0', 'st.pup0.beh', 'st.pup.beh0', 'st.pup.beh']

    With r0..r3 the prediction correlations of the four state models:
        b_unique  = r3^2 - r2^2
        p_unique  = r3^2 - r1^2
        bp_common = r3^2 - r0^2 - b_unique - p_unique

    Returns the three axes (scatter, stacked bars, cumulative lines).
    """
    if state_list is None:
        state_list = ['st.pup0.beh0', 'st.pup0.beh', 'st.pup.beh0',
                      'st.pup.beh']
    s0, s1, s2, s3 = state_list[0], state_list[1], state_list[2], state_list[3]

    plt.figure(figsize=(4, 6))

    da = df[df['state_chan'] == 'active']
    dp = da.pivot(index='cellid', columns='state_sig', values=['r', 'r_se'])
    dr = dp['r'].copy()

    dr['b_unique'] = dr[s3]**2 - dr[s2]**2
    dr['p_unique'] = dr[s3]**2 - dr[s1]**2
    dr['bp_common'] = (dr[s3]**2 - dr[s0]**2
                       - dr['b_unique'] - dr['p_unique'])
    # signed r^2 keeps negative predictions distinguishable from positive ones
    dr['null'] = dr[s0]**2 * np.sign(dr[s0])
    dr['full'] = dr[s3]**2 * np.sign(dr[s3])
    dr['sig'] = ((dp['r'][s3] - dp['r'][s0]) >
                 (dp['r_se'][s3] + dp['r_se'][s0]))

    cols = ['null', 'full', 'bp_common', 'p_unique', 'b_unique', 'sig']
    dm = dr.loc[:, cols].sort_values(['null'])
    # float array: mixing the bool 'sig' column with floats otherwise gives
    # an object-dtype array
    mfull = dm[cols].values.astype(float)

    if nb > 0:
        mm = np.zeros((nb, mfull.shape[1]))
        for i in range(nb):
            x01 = (mfull[:, 0] > i / nb) & (mfull[:, 0] <= (i + 1) / nb)
            # leave empty bins at zero instead of averaging an empty slice
            if np.any(x01):
                mm[i, :] = np.nanmean(mfull[x01, :], axis=0)
        print(np.round(mm, 3))
        m = mm.copy()
    else:
        # alt to look at each cell individually:
        m = mfull.copy()

    mb = m[:, 2:]

    ax1 = plt.subplot(3, 1, 1)
    stateplots.beta_comp(mfull[:, 0], mfull[:, 1], n1='State independent',
                         n2='Full state-dep', ax=ax1, highlight=dm['sig'],
                         hist_range=[-0.1, 1])

    ax2 = plt.subplot(3, 1, 2)
    ind = np.arange(mb.shape[0])
    width = 0.8
    plt.bar(ind, mb[:, 0], width=width)
    plt.bar(ind, mb[:, 1], width=width, bottom=mb[:, 0])
    plt.bar(ind, mb[:, 2], width=width, bottom=mb[:, 0] + mb[:, 1])
    plt.legend(('common', 'p_unique', 'b-unique'))
    if title is not None:
        plt.title(title)
    plt.xlabel('behavior-independent quintile')
    plt.ylabel('mean r2')

    ax3 = plt.subplot(3, 1, 3)
    plt.plot(ind, mb[:, 0])
    plt.plot(ind, mb[:, 1] + mb[:, 0])
    plt.plot(ind, mb[:, 2] + mb[:, 0] + mb[:, 1])
    plt.legend(('common', 'p_unique', 'b-unique'))
    plt.xlabel('behavior-independent quintile')
    plt.ylabel('mean r2')
    plt.tight_layout()

    return ax1, ax2, ax3
def aud_vs_state(df, nb=5, title=None, state_list=None,
                 colors=('r', 'g', 'b', 'k'), norm_by_null=False):
    """Summarize how much state information adds to the sensory prediction.

    d = dataframe output by get_model_results_per_state_model(); must contain
        r_shuff, r_full, sig_state and SNR columns (plus r_task_unique and
        r_pupil_unique when four states are used), indexed by cellid
    nb = number of bins over r_shuff (bins are (i/nb, (i+1)/nb]);
         nb <= 0 uses each cell individually
    state_list = 2 or 4 state-signal names; only its length is used here
    colors = bar colors; indices 1..3 color common / b-unique / p-unique
    norm_by_null = divide (full - shuffled) by (1 - |shuffled|)

    Returns the matplotlib Figure.

    Raises ValueError if state_list does not have 2 or 4 entries.
    """
    if state_list is None:
        state_list = ['st.pup0.beh0', 'st.pup0.beh', 'st.pup.beh0',
                      'st.pup.beh']
    if len(state_list) not in (2, 4):
        raise ValueError("state_list must have 2 or 4 entries, got {}"
                         .format(len(state_list)))

    f = plt.figure(figsize=(5.0, 5.0))

    dr = df.copy()
    if len(state_list) == 4:
        dr['bp_common'] = (dr['r_full'] - dr['r_task_unique']
                           - dr['r_pupil_unique'] - dr['r_shuff'])
        dr = dr.sort_values('r_shuff')
        mfull = dr[['r_shuff', 'r_full', 'bp_common', 'r_task_unique',
                    'r_pupil_unique', 'sig_state']].values
    else:
        dr['bp_common'] = dr['r_full'] - dr['r_shuff']
        dr = dr.sort_values('r_shuff')
        dr['b_unique'] = dr['bp_common'] * 0
        dr['p_unique'] = dr['bp_common'] * 0
        mfull = dr[['r_shuff', 'r_full', 'bp_common', 'b_unique',
                    'p_unique', 'sig_state']].values
    mfull = mfull.astype(float)

    if nb > 0:
        mm = np.zeros((nb, mfull.shape[1]))
        for i in range(nb):
            x01 = (mfull[:, 0] > i / nb) & (mfull[:, 0] <= (i + 1) / nb)
            if np.any(x01):
                mm[i, :] = np.nanmean(mfull[x01, :], axis=0)
        print(np.round(mm, 3))
        m = mm.copy()
    else:
        # alt to look at each cell individually:
        m = mfull.copy()

    mall = np.nanmean(mfull, axis=0, keepdims=True)
    # remove sensory component, which swamps everything else
    mall = mall[:, 2:]
    mb = m[:, 2:]

    ax1 = plt.subplot(2, 2, 1)
    stateplots.beta_comp(mfull[:, 0], mfull[:, 1], n1='State independent',
                         n2='Full state-dep', ax=ax1,
                         highlight=mfull[:, -1], hist_range=[-0.1, 1])

    plt.subplot(2, 2, 3)
    width = 0.8
    # first bar = mean over all cells, remaining bars = per-bin means
    mplots = np.concatenate((mall, mb), axis=0)
    ind = np.arange(mplots.shape[0])
    plt.bar(ind, mplots[:, 0], width=width, color=colors[1])
    plt.bar(ind, mplots[:, 1], width=width, bottom=mplots[:, 0],
            color=colors[2])
    plt.bar(ind, mplots[:, 2], width=width,
            bottom=mplots[:, 0] + mplots[:, 1], color=colors[3])
    plt.legend(('common', 'b-unique', 'p_unique'))
    if title is not None:
        plt.title(title)
    plt.xlabel('behavior-independent quintile')
    plt.ylabel('mean r2')

    # state-dependent gain over the shuffled model (computed once, used twice)
    if norm_by_null:
        d = (mfull[:, 1] - mfull[:, 0]) / (1 - np.abs(mfull[:, 0]))
        ylabel = "dep-indep normed"
    else:
        d = mfull[:, 1] - mfull[:, 0]
        ylabel = "dep-indep"

    dr['site'] = [c[:7] for c in dr.index.get_level_values(0)]
    site_masks = {s: (dr.site == s).values for s in dr.site.unique()}

    def _boot_pvalue(x, y, ok):
        """Site-resampled bootstrap probability for corrcoef(x[ok], y[ok])."""
        xb = get_bootstrapped_sample(
            {s: x[msk & ok] for s, msk in site_masks.items()},
            {s: y[msk & ok] for s, msk in site_masks.items()},
            metric='corrcoef', nboot=10000)
        pboot, _ = get_direct_prob(xb, np.zeros(xb.shape[0]))
        return pboot

    ax3 = plt.subplot(2, 2, 2)
    stateplots.beta_comp(mfull[:, 0], d, n1='State independent', n2=ylabel,
                         ax=ax3, highlight=mfull[:, -1], hist_range=[-0.1, 1],
                         markersize=4)
    if not norm_by_null:
        ax3.plot([1, 0], [0, 1], 'k--', linewidth=0.5)
    # regress on finite values only (norm_by_null can produce inf/nan)
    ok3 = np.isfinite(mfull[:, 0]) & np.isfinite(d)
    slope, intercept, r, p, std_err = st.linregress(mfull[ok3, 0], d[ok3])
    pboot = _boot_pvalue(mfull[:, 0], d, ok3)
    x_ends = np.array([np.min(mfull[ok3, 0]), np.max(mfull[ok3, 0])])
    ax3.plot(x_ends, intercept + slope * x_ends, 'k--', linewidth=0.5)
    ax3.set_title('n={} cc={:.3} p={:.4}, pboot={:.5f}'.format(
        int(np.sum(ok3)), r, p, 1 - pboot), fontsize=7)

    ax4 = plt.subplot(2, 2, 4)
    snr = np.log(dr['SNR'].values)
    ok4 = np.isfinite(d) & np.isfinite(snr)
    ax4.plot(snr[ok4], d[ok4], 'k.', markersize=4)
    slope, intercept, r, p, std_err = st.linregress(snr[ok4], d[ok4])
    pboot = _boot_pvalue(snr, d, ok4)
    x_ends = np.array([np.min(snr[ok4]), np.max(snr[ok4])])
    ax4.plot(x_ends, intercept + slope * x_ends, 'k--', linewidth=0.5)
    ax4.set_xlabel('log(SNR)')
    ax4.set_ylabel(ylabel)
    ax4.set_title('n={} cc={:.3} p={:.4}, pboot={:.5f}'.format(
        int(np.sum(ok4)), r, p, 1 - pboot), fontsize=7)
    nplt.ax_remove_box(ax4)

    f.tight_layout()

    return f
def aud_vs_state(df, nb=5, title=None, state_list=None,
                 colors=('r', 'g', 'b', 'k')):
    """Decompose state-dependent prediction r^2 into common and unique parts.

    d = dataframe output by get_model_results_per_state_model()
    nb = number of bins over the state-independent signed r^2 (bins are
         (i/nb, (i+1)/nb]); nb <= 0 uses each cell individually
    state_list = 2 or 4 state-signal names; state_list[0] is the
                 state-independent model and state_list[-1] the full model.
                 Defaults to
                 ['st.pup0.beh0', 'st.pup0.beh', 'st.pup.beh0', 'st.pup.beh']
    colors = bar colors; indices 1..3 color common / b-unique / p-unique

    With four states and r0..r3 their prediction correlations:
        b_unique  = r3^2 - r2^2
        p_unique  = r3^2 - r1^2
        bp_common = r3^2 - r0^2 - b_unique - p_unique
    With two states, bp_common = r1^2 - r0^2 and both unique terms are zero.

    Cells whose full - null signed r^2 exceeds 0.2 are printed.

    Returns the matplotlib Figure.

    Raises ValueError if state_list does not have 2 or 4 entries.
    """
    if state_list is None:
        state_list = ['st.pup0.beh0', 'st.pup0.beh', 'st.pup.beh0',
                      'st.pup.beh']
    if len(state_list) not in (2, 4):
        raise ValueError("state_list must have 2 or 4 entries, got {}"
                         .format(len(state_list)))
    s_null, s_full = state_list[0], state_list[-1]

    f = plt.figure(figsize=(4, 6))
    da = df[df['state_chan'] == 'active']
    dp = da.pivot(index='cellid', columns='state_sig', values=['r', 'r_se'])
    dr = dp['r'].copy()

    if len(state_list) == 4:
        dr['b_unique'] = dr[s_full]**2 - dr[state_list[2]]**2
        dr['p_unique'] = dr[s_full]**2 - dr[state_list[1]]**2
        dr['bp_common'] = (dr[s_full]**2 - dr[s_null]**2
                           - dr['b_unique'] - dr['p_unique'])
    else:
        dr['bp_common'] = dr[s_full]**2 - dr[s_null]**2
        dr['b_unique'] = dr['bp_common'] * 0
        dr['p_unique'] = dr['bp_common'] * 0
    # signed r^2 keeps negative predictions distinguishable from positive ones
    dr['null'] = dr[s_null]**2 * np.sign(dr[s_null])
    dr['full'] = dr[s_full]**2 * np.sign(dr[s_full])
    dr['sig'] = ((dp['r'][s_full] - dp['r'][s_null]) >
                 (dp['r_se'][s_full] + dp['r_se'][s_null]))

    cols = ['null', 'full', 'bp_common', 'b_unique', 'p_unique', 'sig']
    dm = dr.loc[:, cols].sort_values(['null'])
    # float array: mixing the bool 'sig' column with floats otherwise gives
    # an object-dtype array
    mfull = dm[cols].values.astype(float)

    # dm is indexed by cellid (pivot index); a 'cellid' column previously
    # existed only in the 2-state branch
    cellids = dm.index.to_list()
    big_idx = mfull[:, 1] - mfull[:, 0] > 0.2
    for i, b in enumerate(big_idx):
        if b:
            print('{} : {:.3f} - {:.3f}'.format(cellids[i], mfull[i, 0],
                                                mfull[i, 1]))

    if nb > 0:
        mm = np.zeros((nb, mfull.shape[1]))
        for i in range(nb):
            x01 = (mfull[:, 0] > i / nb) & (mfull[:, 0] <= (i + 1) / nb)
            if np.any(x01):
                mm[i, :] = np.nanmean(mfull[x01, :], axis=0)
        print(np.round(mm, 3))
        m = mm.copy()
    else:
        # alt to look at each cell individually:
        m = mfull.copy()

    mall = np.nanmean(mfull, axis=0, keepdims=True)
    # remove sensory component, which swamps everything else
    mall = mall[:, 2:]
    mb = m[:, 2:]

    ax1 = plt.subplot(3, 1, 1)
    stateplots.beta_comp(mfull[:, 0], mfull[:, 1], n1='State independent',
                         n2='Full state-dep', ax=ax1, highlight=dm['sig'],
                         hist_range=[-0.1, 1])

    plt.subplot(3, 1, 2)
    width = 0.8
    # first bar = mean over all cells, remaining bars = per-bin means
    mplots = np.concatenate((mall, mb), axis=0)
    ind = np.arange(mplots.shape[0])
    plt.bar(ind, mplots[:, 0], width=width, color=colors[1])
    plt.bar(ind, mplots[:, 1], width=width, bottom=mplots[:, 0],
            color=colors[2])
    plt.bar(ind, mplots[:, 2], width=width,
            bottom=mplots[:, 0] + mplots[:, 1], color=colors[3])
    plt.legend(('common', 'b-unique', 'p_unique'))
    if title is not None:
        plt.title(title)
    plt.xlabel('behavior-independent quintile')
    plt.ylabel('mean r2')

    ax3 = plt.subplot(3, 1, 3)
    d = mfull[:, 1] - mfull[:, 0]
    stateplots.beta_comp(mfull[:, 0], d, n1='State independent',
                         n2='dep - indep', ax=ax3, highlight=dm['sig'],
                         hist_range=[-0.1, 1], markersize=4)
    ax3.plot([1, 0], [0, 1], 'k--', linewidth=0.5)
    # correlate finite values only; a single NaN would make pearsonr NaN
    ok = np.isfinite(mfull[:, 0]) & np.isfinite(d)
    r, p = st.pearsonr(mfull[ok, 0], d[ok])
    ax3.set_title('cc={:.3} p={:.4}'.format(r, p))

    plt.tight_layout()

    return f
state_mod = np.stack(state_mod) r_test = np.stack(r_test) se_test = np.stack(se_test) u_state_mod = state_mod[[sv_len],:] - state_mod[:sv_len, :] u_r_test = r_test[[sv_len],:] - r_test[:sv_len,:] u_r_good = u_r_test > se_test[:sv_len,:] plt.close('all') plt.figure() for i,p in enumerate(sv_pairs): ax = plt.subplot(len(sv_pairs),3,i*3+1) stateplots.beta_comp(u_r_test[p[0],:], u_r_test[p[1],:], n1=statevars[p[0]], n2=statevars[p[1]], title='u r_test', hist_range=[-0.05, 0.15], ax=ax, highlight=(u_r_good[p[0],:] | u_r_good[p[1],:])) ax = plt.subplot(len(sv_pairs),3,i*3+2) stateplots.beta_comp(u_state_mod[p[0],:,p[0]+1], u_state_mod[p[1],:,p[1]+1], n1=statevars[p[0]], n2=statevars[p[1]], title='u state_mod', hist_range=[-0.6, 0.6], ax=ax, highlight=(u_r_good[p[0],:] | u_r_good[p[1],:])) # plt.plot(tvr[np.logical_not(u_r_good[p[0],:])], # state_mod[-1,np.logical_not(u_r_good[p[0],:]),p[0]+1], # '.', color='LightGray') # plt.plot(tvr[u_r_good[p[0],:]], state_mod[-1,u_r_good[p[0],:],p[0]+1], 'k.') # plt.xlabel('tvr') # plt.ylabel("raw state_mod " + statevars[p[0]]) ax = plt.subplot(len(sv_pairs),3,i*3+3)