def plt_traj(hdf, save=False):
    """Draw each target as a circle and overlay every rewarded cursor path.

    A trial's path runs from the task message three entries before each
    'reward' message up to the reward message itself (both 'time' values,
    used as task-table row indices). Cursor columns 0 and 2 are plotted as
    x and y.

    Parameters
    ----------
    hdf : open PyTables file with /task and /task_msgs tables.
    save : if True, write the figure next to the HDF as '<name>_traj.png'.
    """
    palette = ['maroon', 'orangered', 'darkgoldenrod', 'olivedrab', 'teal',
               'steelblue', 'midnightblue', 'darkmagenta', 'darkgray']
    target_radius = hdf.root.task.attrs.target_radius
    fig, axis = plt.subplots()

    task_rows = hdf.root.task[:]
    targets = task_rows['target']
    target_ids = pa.get_target_ix(targets[:, [0, 2]])

    # One translucent circle per distinct target, placed at its first occurrence.
    unique_ids = np.unique(target_ids)
    for k in range(len(unique_ids)):
        first_row = np.nonzero(target_ids == unique_ids[k])[0][0]
        circle = plt.Circle(targets[first_row, [0, 2]], target_radius,
                            color=palette[k], alpha=.5)
        axis.add_artist(circle)

    cursor = task_rows['cursor']
    messages = hdf.root.task_msgs[:]
    for row_num, message in enumerate(messages):
        if message['msg'] != 'reward':
            continue
        go_time = messages[row_num - 3]['time']
        reward_time = message['time']
        axis.plot(cursor[go_time:reward_time, 0], cursor[go_time:reward_time, 2],
                  color='lightgrey')

    axis.set_ylim([-14, 14])
    axis.set_xlim([-14, 14])
    if save:
        fig.savefig(hdf.filename[:-4] + '_traj.png', format='png')
def jerk_squared(R, input1, input2):
    """Compute squared-jerk metrics per rewarded trial for three velocity traces.

    Each trial spans [go_k, go_{k+1}), where go_k is the 'time' column of the
    task message three entries before the k-th 'reward' message. The last trial
    ends at len(R.cursor_vel). Only decoder-update steps (R.update_bmi_ix) inside
    that span are used.

    Parameters
    ----------
    R : re-run decoding object (project type) exposing ``hdf``,
        ``update_bmi_ix``, ``cursor_vel``, ``decoded_vel`` (indexed by input
        key) and ``target``.
    input1 : key into ``R.decoded_vel`` stored under 'priv'.
    input2 : key into ``R.decoded_vel`` stored under 'shar'.

    Returns
    -------
    jerk_sq : dict 'priv'/'shar'/'norm' -> array (n_trials,) holding the first
        value returned by ``_get_jk_squ`` (0 for trials starting past the data).
    jerk_sq_traj : dict 'priv'/'shar'/'norm' -> {trial_index: second value
        returned by ``_get_jk_squ``}.
    targ_ix : target index at each trial's go step.
    """
    hdf = R.hdf
    msgs = hdf.root.task_msgs[:]
    # Go indices: 3 messages before each reward.
    go_ix = np.array([msgs[it - 3][1] for it, t in enumerate(msgs) if t[0] == 'reward'])
    go_ix = np.hstack((go_ix, len(R.cursor_vel))).astype(int)

    n_trials = len(go_ix) - 1
    jerk_sq = dict(priv=np.zeros((n_trials, )), shar=np.zeros((n_trials, )),
                   norm=np.zeros((n_trials, )))
    jerk_sq_traj = dict(priv=dict(), shar=dict(), norm=dict())
    targ_ix = pa.get_target_ix(R.target[:, [0, 2]])

    # Hoisted out of the loop: the update indices do not change per trial.
    update_ix = np.asarray(R.update_bmi_ix)
    sources = (('priv', R.decoded_vel[input1]),
               ('shar', R.decoded_vel[input2]),
               ('norm', R.cursor_vel))
    for ig in range(n_trials):
        go = go_ix[ig]
        if go >= len(R.cursor_vel):
            continue
        # Sorted decoder-update steps within this trial; forced to int so an
        # empty result is still a valid index array.
        ix_ = np.intersect1d(update_ix, np.arange(go, go_ix[ig + 1])).astype(int)
        for key, vel in sources:
            traj = np.squeeze(vel[ix_, :]).copy()
            jerk_sq[key][ig], jerk_sq_traj[key][ig] = _get_jk_squ(traj)
    return jerk_sq, jerk_sq_traj, targ_ix[go_ix[:-1]]
def get_epoch(hdf_list, hold_min=0., set_window=None, epoch='hold'):
    """Bin spike counts (6 task steps per bin) over the hold or reach epoch of rewarded trials.

    Epoch boundaries are task-message times counted back from each 'reward'
    message: 'hold' uses messages -2 .. -1 and 'reach' uses -3 .. -2.
    Durations are converted to seconds by dividing by 60.

    Parameters
    ----------
    hdf_list : file names relative to $FA_GROM_DATA.
    hold_min : keep only trials whose epoch lasts >= hold_min seconds.
    set_window : if given, fixed epoch length in seconds measured from the
        epoch start, instead of the actual epoch end.
    epoch : 'hold' or 'reach'.

    Returns
    -------
    trl_ix : array of the trial number for every bin (numbered across files).
    binned_counts : (n_bins, n_units) spike counts per bin.
    t2t : actual epoch durations (s) of the kept trials.
    cursor : internal_decoder_state[:, [3, 5], 0] at each bin start.
    target_index : target index of each kept trial.

    Raises
    ------
    ValueError : if ``epoch`` is not 'hold' or 'reach'.
    """
    # Message offsets (before each 'reward') for epoch start / end.
    msg_offsets = {'hold': (2, 1), 'reach': (3, 2)}
    if epoch not in msg_offsets:
        raise ValueError("epoch must be 'hold' or 'reach', got %r" % (epoch,))
    start_tab, end_tab = msg_offsets[epoch]

    binned_counts = []
    trl_ix = []
    trl_ix_start = 0
    t2t_arr = []
    target_info = []
    cursor = []
    for hdf_nm in hdf_list:
        hdf = tables.openFile(os.path.expandvars('$FA_GROM_DATA/' + hdf_nm))
        try:
            task = hdf.root.task[:]
            msgs = hdf.root.task_msgs[:]
            spike_counts = task['spike_counts']
            internal_state = task['internal_decoder_state']
            curs_vel = internal_state[:, [3, 5], 0]

            rew_rows = [i for i, m in enumerate(msgs) if m['msg'] == 'reward']
            start_hold = np.array([msgs[i - start_tab]['time'] for i in rew_rows])
            act_end_hold = np.array([msgs[i - end_tab]['time'] for i in rew_rows])
            if set_window is None:
                end_hold = act_end_hold.copy()
            else:
                end_hold = start_hold + (set_window * 60)

            t2t = (act_end_hold - start_hold) / 60.
            t2t_ix = t2t >= hold_min
            t2t_arr.append(t2t[t2t_ix])
            start_hold = start_hold[t2t_ix]
            end_hold = end_hold[t2t_ix]

            # Decoder updates are the steps where state column 3 changes.
            update_bmi_ix = np.nonzero(np.diff(np.squeeze(internal_state[:, 3, 0])))[0] + 1

            for i, (s, e) in enumerate(zip(start_hold, end_hold)):
                # Align bins to the first decoder update at or after the epoch start.
                first_update = update_bmi_ix[np.nonzero(update_bmi_ix >= s)[0][0]]
                trl = []
                c = []
                # astype(int): `e` is a float when set_window is fractional.
                for bin_edge in np.arange(first_update, e, 6).astype(int):
                    trl.append(np.sum(spike_counts[bin_edge:bin_edge + 6, :, 0], axis=0))
                    trl_ix.append(i + trl_ix_start)
                    c.append(curs_vel[bin_edge, :])
                target_info.append(task[s + 10]['target'][[0, 2]])
                binned_counts.append(np.vstack(trl))
                cursor.append(np.vstack(c))
            # Count trials from this file; correct even when none were kept.
            trl_ix_start += len(start_hold)
        finally:
            hdf.close()

    target_index = pa.get_target_ix(np.vstack(target_info))
    return (np.array(trl_ix), np.vstack(binned_counts), np.hstack(t2t_arr),
            np.vstack(cursor), target_index)
def perf_mets(hdf, save=False, hdf_fname=None):
    """Compute per-trial performance metrics for every rewarded trial.

    A trial runs from the task message three entries before each 'reward'
    (go) to the reward message. Cursor columns 0 and 2 form the path, which
    is scored by ``bha.path_metrics`` against the target 5 steps after go.

    Parameters
    ----------
    hdf : open PyTables file with /task and /task_msgs.
    save : if True, save the metrics to '<hdf name>_metrics.mat'.
    hdf_fname : needed to add 'output_sot' via ``get_output_sot``; if None,
        a message is printed and 'output_sot' is omitted.

    Returns
    -------
    metrics : dict with lists 'path_length', 'path_error', 'avg_speed',
        'targ_ix' and 'time2targ' (seconds, assuming 60 steps/s), plus
        optionally 'output_sot'. (Previously the function returned None.)
    """
    msgs = hdf.root.task_msgs[:]
    # reshape(-1, 2) keeps the array 2-D even when there are no rewards.
    epoch_ix = np.array([[msgs[j - 3]['time'], m['time']]
                         for j, m in enumerate(msgs) if m['msg'] == 'reward'],
                        dtype=int).reshape(-1, 2)

    task = hdf.root.task[:]
    # Get targ ix:
    tg = task['target']
    targ_ix = pa.get_target_ix(tg[:, [0, 2]])
    targ_ix = targ_ix[epoch_ix[:, 0] + 3]

    # Get time2targ:
    time2targ = (epoch_ix[:, 1] - epoch_ix[:, 0]) / 60.
    cursor = task['cursor'][:, [0, 2]]

    metrics = dict()
    metrics['path_length'] = []
    metrics['path_error'] = []
    metrics['avg_speed'] = []
    metrics['targ_ix'] = []
    metrics['time2targ'] = []

    # Path error:
    for i, (g, r) in enumerate(epoch_ix):
        path = cursor[g:r, :]
        targ_loc = tg[g + 5, [0, 2]]
        path_length, path_error, avg_speed = bha.path_metrics(path, targ_loc)
        metrics['path_length'].append(path_length)
        metrics['path_error'].append(path_error)
        metrics['avg_speed'].append(avg_speed)
        metrics['targ_ix'].append(targ_ix[i])
        metrics['time2targ'].append(time2targ[i])

    # Add output SOT to metrics:
    if hdf_fname is None:
        print('Need hdf_fname for sot calc')
    else:
        metrics['output_sot'] = get_output_sot(hdf, hdf_fname)

    if save:
        sio.savemat(hdf.filename[:-4] + '_metrics.mat', metrics)
    return metrics
def decompose_inputs(fa_dict, bin_spk_i, hdf, dec, task='bmi_resetting', process_to_ix=None, use_main = True):
    """Split binned spikes into FA shared/private parts and register each as a decoder input.

    Parameters
    ----------
    fa_dict : FA results; uses 'fa_mu' plus 'fa_main_shared' (use_main=True)
        or 'fa_sharL' (use_main=False). The projection uses ``*``, which
        presumably relies on np.matrix semantics for a matrix product — confirm.
    bin_spk_i : (T, n_units) binned spike counts.
    hdf, dec : task HDF file and decoder passed to ``trbt.RerunDecoding``.
    task : task name forwarded to ``trbt.RerunDecoding``. Previously ignored,
        with 'bmi_resetting' always used.
    process_to_ix : number of rebinned steps to add as inputs (default: all).
    use_main : choose the main shared space, and set ``ReSim.main``.

    Returns
    -------
    ReSim : RerunDecoding with inputs 'all', 'main_shar', 'main_priv';
        ``target_ix`` and ``targ_pos_trunc`` are set from the first 50 rewards.
    gr_ix : (n_rewards, 2) array of [go time, reward time].
    proc_ix : number of steps processed.
    """
    # Main shared:
    T = bin_spk_i.shape[0]
    mn = np.tile(fa_dict['fa_mu'], [1, T])
    dmn = bin_spk_i.T - mn

    ReSim = trbt.RerunDecoding(hdf, dec, task=task)
    shared_key = 'fa_main_shared' if use_main else 'fa_sharL'
    shar = (fa_dict[shared_key] * dmn)
    priv = (dmn - shar)
    ReSim.main = bool(use_main)

    main_shar_spks = sdf.rebin_spks(shar + mn)
    main_priv_spks = sdf.rebin_spks(priv + mn)

    if process_to_ix is None:
        proc_ix = main_shar_spks.shape[0]
    else:
        proc_ix = process_to_ix

    ReSim.add_input(ReSim.spike_counts[:proc_ix, :, :], 'all')
    ReSim.add_input(main_shar_spks[:proc_ix, :, :], 'main_shar')
    ReSim.add_input(main_priv_spks[:proc_ix, :, :], 'main_priv')

    # Go (3 messages before reward) and reward times for each rewarded trial:
    msgs = hdf.root.task_msgs[:]
    gr_ix = np.array([(msgs[j - 3]['time'], m['time'])
                      for j, m in enumerate(msgs) if m['msg'] == 'reward'])
    targ_pos = hdf.root.task[gr_ix[:50, 1]]['target']
    targ_ix = pa.get_target_ix(targ_pos[:, [0, 2]])
    ReSim.target_ix = targ_ix
    ReSim.targ_pos_trunc = targ_pos
    return ReSim, gr_ix, proc_ix
def psth(Recoder,bins_after_go=60*4, input_type='all'):
    """Gather post-go spike snippets for each rewarded trial, grouped by target.

    The go step is the 'time' column of the task message three entries before
    each 'reward'. A snippet covers ``bins_after_go`` rows of the rebinned
    ``Recoder.dec_spk_cnt_bin[input_type]`` (channel 0). Trials whose window
    would reach the end of ``Recoder.decoded_pos[input_type]`` are skipped.

    Returns
    -------
    dict mapping int target index -> the single (squeezed) snippet, or the
    snippets stacked along a third axis via np.dstack.
    """
    msg_table = Recoder.hdf.root.task_msgs
    onsets = []
    for row_num, row in enumerate(msg_table[:]):
        if row[0] == 'reward':
            onsets.append(msg_table[row_num - 3][1])
    n_decoded = len(Recoder.decoded_pos[input_type])
    # Appending then dropping the end marker mirrors the original dtype handling.
    onsets = np.hstack((np.array(onsets), n_decoded))[:-1]

    rebinned = sdf.rebin_spks(Recoder.dec_spk_cnt_bin[input_type])
    target_ids = pa.get_target_ix(Recoder.target[:, [0, 2]])

    by_target = dict()
    for onset in onsets:
        target_id = target_ids[onset]
        if onset + bins_after_go >= n_decoded:
            continue
        snippet = np.squeeze(rebinned[onset:onset + bins_after_go, :, 0])
        key = int(target_id)
        if key in by_target:
            by_target[key] = np.dstack((by_target[key], snippet))
        else:
            by_target[key] = snippet
    return by_target
def process_targets(te_list, new_hdf_name):
    '''
    Summary: For each task entry, write one Trial_Metrics row per non-assisted
        rewarded trial and Meta_Metrics rows for sliding 5 min windows (2.5 min
        steps) into a new PyTables file.
    Input param: te_list: iterable of task entry ids accepted by dbfn.TaskEntry
        (fa_bmi or bmi_resetting blocks)
    Input param: new_hdf_name: path of the HDF file to create (opened with
        mode "w", so an existing file is overwritten)
    Output param: new_hdf_name; the file holds tables /trial_metrics and
        /meta_metrics
    '''
    h5file = tables.openFile(new_hdf_name, mode="w", title='FA Trial Analysis')
    try:
        trial_mets_table = h5file.createTable("/", 'trial_metrics', Trial_Metrics, "Trial Mets")
        meta_table = h5file.createTable("/", 'meta_metrics', Meta_Metrics, "Meta Mets")
        for te in te_list:
            hdf = dbfn.TaskEntry(te).hdf
            rew_time_2_trial = _process_targets_trials(hdf, te, trial_mets_table)
            _process_targets_meta(hdf, te, meta_table, rew_time_2_trial)
    finally:
        # Always release the output file, even if one task entry fails.
        h5file.close()
    return new_hdf_name


def _process_targets_trials(hdf, task_entry, trial_mets_table):
    '''Append one Trial_Metrics row per non-assisted rewarded trial of `hdf`.

    Returns a dict mapping each reward time (task step) to that trial's time
    to target in seconds.
    '''
    # Extract trials ignoring assist:
    rew_ix, rew_per_min = pa.get_trials_per_min(hdf, nmin=2, rew_per_min_cutoff=0,
                                                ignore_assist=True, return_rpm=True)
    msgs = hdf.root.task_msgs[:]
    # Go backwards 3 steps (rew --> targ trans --> hold --> target (onset))
    go_ix = np.array([msgs[it - 3][1] for it, t in enumerate(msgs)
                      if t[0] == 'reward' and t[1] in rew_ix])
    assert len(rew_ix) == len(go_ix), "Rew ix and Go ix are unequal"

    # Time to target is diff b/w target onset and reward time
    time2targs = (rew_ix - go_ix) / 60.

    # Add buffer of 5 task steps (80 ms); read those task rows once.
    task_at_go = hdf.root.task[go_ix.astype(int) + 5]
    target_locs = task_at_go['target'][:, [0, 2]]
    try:
        obstacle_sz = task_at_go['obstacle_size'][:, 0]
    except (KeyError, ValueError, IndexError):
        # No usable 'obstacle_size' column: record 0 for every trial.
        obstacle_sz = np.zeros_like(go_ix)

    # Targ ix -- should be 8
    targ_ixs = pa.get_target_ix(target_locs)
    assert len(np.unique(targ_ixs)) == 8, "Target Ix has more or less than 8 targets :/ "

    # Input type:
    try:
        input_types = task_at_go['fa_input']
    except (KeyError, ValueError, IndexError):
        input_types = ['all'] * len(go_ix)

    timeout_time = hdf.root.task.attrs.timeout_time
    timeout_penalty_time = hdf.root.task.attrs.timeout_penalty_time
    rew_time_2_trial = {}
    for ri, r_ix in enumerate(rew_ix):
        trl = trial_mets_table.row
        trl['trial_number'] = ri
        trl['task_entry'] = task_entry
        trl['rew_per_min'] = rew_per_min[ri]
        trl['target_index'] = targ_ixs[ri]
        trl['target_loc'] = target_locs[ri, :]
        trl['start_time'] = go_ix[ri] / 60.  # In seconds
        trl['input_type'] = input_types[ri]
        trl['time2targ'] = time2targs[ri]
        rew_time_2_trial[r_ix] = time2targs[ri]
        trl['obstacle_size'] = obstacle_sz[ri]

        path = hdf.root.task[int(go_ix[ri]):int(r_ix)]['cursor'][:, [0, 2]]
        path_length, path_error, avg_speed = path_metrics(path, target_locs[ri, :])
        trl['path_length'] = path_length
        trl['path_error'] = path_error
        trl['avg_speed'] = avg_speed
        trl['timeout_time'] = timeout_time
        trl['timeout_penalty_time'] = timeout_penalty_time
        trl.append()
    trial_mets_table.flush()
    return rew_time_2_trial


def _process_targets_meta(hdf, task_entry, meta_table, rew_time_2_trial):
    '''Append Meta_Metrics rows for 5 min windows stepped by 2.5 min (60 steps/s).

    rew_time_2_trial maps non-assisted reward times to time to target (s).
    A reward counts toward 'all_perc_succ_lte10sec' only if its time is in
    that dict and its time to target is <= 10 s.
    '''
    block_len = hdf.root.task.shape[0]
    wind_len = 5 * 60 * 60
    wind_step = 2.5 * 60 * 60
    meta_ix = np.arange(0, block_len - wind_len, wind_step)
    outcome_msgs = ['reward', 'timeout_penalty', 'hold_penalty', 'obstacle_penalty']
    trial_ix = np.array([i for i in hdf.root.task_msgs[:] if i['msg'] in outcome_msgs],
                        dtype=hdf.root.task_msgs.dtype)

    for m in meta_ix:
        end_meta_ix = m + wind_len
        msg_ix = np.nonzero(np.logical_and(trial_ix['time'] <= end_meta_ix,
                                           trial_ix['time'] > m))[0]
        if len(msg_ix) == 0:
            print('no trials in window starting at %s; skipping' % (m,))
            continue
        targ = hdf.root.task[trial_ix[msg_ix]['time'].astype(int)]['target'][:, [0, 2]]
        targ_ix = pa.get_target_ix(targ)

        # Only mark as correct if it's the first correct trial: keep the first
        # outcome of each run of consecutive identical targets (presumably a
        # failed target is repeated; confirm against the task).
        targ_ix_ix = np.hstack(([0], np.nonzero(np.diff(targ_ix) != 0)[0] + 1)).astype(int)
        targ_ix_mod = targ_ix[targ_ix_ix]
        msg_ix_mod = msg_ix[targ_ix_ix]

        targ_percent_success = np.zeros((8, )) - 1
        all_perc_succ_lte10sec = np.zeros((8, )) - 1
        for t in range(8):
            msgs = trial_ix[msg_ix_mod[np.nonzero(targ_ix_mod == t)[0]]]
            if len(msgs) == 0:
                print('no msgs for this epoch: %s' % (m,))
                continue
            targ_percent_success[t] = np.sum(msgs['msg'] == 'reward') / float(len(msgs))

            # Count only this target's fast (<= 10 s) non-assisted rewards.
            # (Previously every message was relabelled once one failed the test.)
            n_fast = 0
            for mc in msgs:
                if mc['msg'] != 'reward':
                    continue
                r_tm = mc['time']
                if r_tm not in rew_time_2_trial:
                    print('asissted: %s %s %s' % (m, t, mc))
                elif rew_time_2_trial[r_tm] > 10.:
                    print('not reward: %s %s %s' % (m, t, mc))
                else:
                    n_fast += 1
            all_perc_succ_lte10sec[t] = n_fast / float(len(msgs))

        meta_trl = meta_table.row
        meta_trl['task_entry'] = task_entry
        meta_trl['targ_percent_success'] = targ_percent_success
        meta_trl['all_perc_succ_lte10sec'] = all_perc_succ_lte10sec
        all_msg = trial_ix[msg_ix_mod]
        meta_trl['all_percent_success'] = np.sum(all_msg['msg'] == 'reward') / float(len(all_msg))
        meta_trl.append()
    meta_table.flush()
def shared_vs_priv_vel_activ(Recoder, input_types=('private', 'shared')):
    """Plot per-target mean +/- SEM of shared-over-total (SOT) decoded speed.

    For each rewarded trial (go = message 3 before 'reward'; the trial ends at
    the next go), speed is the norm of columns 0 and 2 of
    ``Recoder.decoded_vel[input_types[0]]`` (private) and
    ``[input_types[1]]`` (shared). SOT = shar / (shar + priv). The first and
    last 120 samples of each trial are NaN-padded when the trial is shorter,
    and are plotted on 4x2 grids: start in blue, end in red.
    (The previous version referenced undefined R_priv/R_shar and could not run.)

    Returns
    -------
    (sot_start, sot_end) : dicts mapping str(target index) -> list of
        length-120 arrays.
    """
    tiny = 1e-12
    n_samp = 120
    vel_priv = Recoder.decoded_vel[input_types[0]]
    vel_shar = Recoder.decoded_vel[input_types[1]]
    n_bins = min(len(vel_priv), len(vel_shar))

    hdf = Recoder.hdf
    msgs = hdf.root.task_msgs[:]
    # Go indices: 3 messages before each reward; the last trial ends at n_bins.
    go_ix = np.array([msgs[it - 3][1] for it, t in enumerate(msgs) if t[0] == 'reward'])
    go_ix = np.hstack((go_ix, n_bins)).astype(int)
    targ_ix = pa.get_target_ix(Recoder.target[:, [0, 2]])

    _, ax_start = plt.subplots(nrows=4, ncols=2)
    _, ax_end = plt.subplots(nrows=4, ncols=2)

    sot_start = dict()
    sot_end = dict()
    for ig in range(len(go_ix) - 1):
        go, nxt = go_ix[ig], go_ix[ig + 1]
        if go >= n_bins:
            continue
        # Speed calc; ravel tolerates a trailing singleton axis on decoded_vel.
        speed_priv = np.ravel(np.sqrt(np.sum(vel_priv[go:nxt, [0, 2]] ** 2, axis=1)))
        speed_shar = np.ravel(np.sqrt(np.sum(vel_shar[go:nxt, [0, 2]] ** 2, axis=1)))
        sot = speed_shar / (tiny + speed_shar + speed_priv)

        # First n_samp samples (NaN-padded at the end) ...
        start = np.empty(n_samp)
        start.fill(np.nan)
        head = sot[:n_samp]
        start[:len(head)] = head
        # ... and last n_samp samples (NaN-padded at the front).
        end = np.empty(n_samp)
        end.fill(np.nan)
        tail = sot[-n_samp:]
        end[n_samp - len(tail):] = tail

        key = str(int(targ_ix[go]))
        sot_start.setdefault(key, []).append(start)
        sot_end.setdefault(key, []).append(end)

    x = np.arange(n_samp) / 60.
    for i_t, ky in enumerate(sorted(sot_start, key=int)):
        row, col = i_t % 4, i_t // 4
        for axis, trials, color in ((ax_start[row, col], sot_start[ky], 'b'),
                                    (ax_end[row, col], sot_end[ky], 'r')):
            arr = np.vstack(trials)
            n_valid = np.sum(~np.isnan(arr), axis=0)
            mean = np.nanmean(arr, axis=0)
            sem = np.nanstd(arr, axis=0) / np.sqrt(np.maximum(n_valid, 1))
            axis.plot(x, mean, color + '-')
            axis.fill_between(x, mean - sem, mean + sem, alpha=.5, color=color)
            axis.set_ylim([.5, .75])
    return sot_start, sot_end
def PSTH(fa_hdf,task_tbl_met = 'spike_counts', plot_by_neur=False, save_f=False):
    """Plot spike-count PSTHs around the first decoder update after target onset.

    Decoder updates are taken every 6 task steps, phase-aligned to the first
    step with any spikes. For each rewarded trial (go = message 3 before
    'reward'), 4 updates before and 12 from the first update after go are
    kept. Trials without that full window are skipped.

    Parameters
    ----------
    fa_hdf : open PyTables file with /task and /task_msgs.
    task_tbl_met : task-table field holding spike counts; falls back to
        'spike_counts' if it cannot be read.
    plot_by_neur : if True, one subplot per unit (9 per figure) with one line
        per target; otherwise a 3x3 grid per target with one line per unit.
    save_f : with plot_by_neur, save each page as
        '<name>_PSTH_by_neur_<page>.png' (the last page is now saved too).
    """
    msgs = fa_hdf.root.task_msgs[:]
    epoch_ix = np.array([[msgs[j - 3]['time'], m['time']]
                         for j, m in enumerate(msgs) if m['msg'] == 'reward'])
    task = fa_hdf.root.task[:]
    try:
        spike_counts = task[task_tbl_met][:, :, 0]  # Time x units
    except (KeyError, ValueError, IndexError):
        spike_counts = task['spike_counts'][:, :, 0]
        print('USING SPIKE COUNTS from HDF FILE')

    st = np.nonzero(np.sum(spike_counts, axis=1))[0][0] % 6
    update_ix = np.arange(st, spike_counts.shape[0], 6)

    n_pre, n_post = 4, 12
    trl = []
    tg = []
    for g, r in epoch_ix:
        u_ix = np.nonzero(np.logical_and(update_ix <= r, update_ix > g))[0]
        if len(u_ix) == 0:
            continue
        u0 = u_ix[0]
        if u0 >= n_pre and u0 + n_post <= len(update_ix):
            trl.append(spike_counts[update_ix[u0 - n_pre:u0 + n_post], :])
            tg.append(task[update_ix[u0 + 3]]['target'][[0, 2]])
    if len(trl) == 0:
        return

    TRL = np.dstack(trl)  # time x units x trials
    TG_IX = pa.get_target_ix(np.vstack(tg))
    # 0.1 s per update (as in the original -.4..1.1 axis); derived from the
    # window length so it always matches TRL.shape[0].
    x_ax = (np.arange(TRL.shape[0]) - n_pre) * .1
    n_neurons = TRL.shape[1]
    targets = np.unique(TG_IX)

    if plot_by_neur:
        def _finish_page(fig, page):
            fig.tight_layout()
            if save_f:
                fig.savefig(fa_hdf.filename[:-4] + '_PSTH_by_neur_' + str(page) + '.png',
                            format='png')

        u_col = get_cmap(len(targets))
        f = None
        cnt = -1
        for n in range(n_neurons):
            if n % 9 == 0:
                if f is not None:
                    _finish_page(f, cnt)
                f, ax = plt.subplots(nrows=3, ncols=3)
                cnt += 1
            row, col = (n % 9) // 3, (n % 9) % 3
            axi = ax[row, col]
            if col == 0:
                axi.set_ylabel('FR')
            if row == 2:
                axi.set_xlabel('Sec w.r.t. Target Onset')
            for ii, it in enumerate(targets):
                ix = np.nonzero(TG_IX == it)[0]
                axi.plot(x_ax, np.mean(TRL[:, n, ix].T, axis=0), color=u_col[ii])
            axi.set_title('Unit ' + str(n))
        if f is not None:
            _finish_page(f, cnt)
    else:
        u_col = get_cmap(n_neurons)
        f, ax = plt.subplots(nrows=3, ncols=3)
        for it in targets:
            axix = bha.targ_ix_to_3x3_subplot(it)
            axi = ax[axix[0], axix[1]]
            ix = np.nonzero(TG_IX == it)[0]
            fr = TRL[:, :, ix].T  # Trials x units x time
            for u in range(n_neurons):
                # One color per unit (u_col has n_neurons entries).
                axi.plot(x_ax, np.mean(fr[:, u, :], axis=0), color=u_col[u])
def get_files(te_num=None, get_psth=False):
    """Load a task entry's HDF file and decoder, fit FA matrices and bin spikes.

    Known task entries (4048, the default; 4526; 2991; 4549) are loaded from
    $FA_GROM_DATA. Any other id is fetched through db.dbfunctions.TaskEntry.
    Spikes are binned by summing the 6 task steps that end at each decoder
    update (the steps where internal_decoder_state[:, 3, 0] changes).

    Returns
    -------
    get_psth False : (fa_dict, bin_spk_i, hdf, dec), with bin_spk_i of shape
        (n_updates, n_units).
    get_psth True : (psth, psth_max, psth_shar, psth_priv, hdf, dec, targ,
        targ_ix); see _get_files_psth.
    """
    # te_num -> (hdf file, decoder pickle) under $FA_GROM_DATA
    local_files = {
        4048: ('grom20160201_04_te4048.hdf', 'grom20160201_01_RMLC02011515.pkl'),  # Good centerout
        4526: ('grom20160317_11_te4526.hdf', 'grom20160317_03_test03171536.pkl'),  # Learning obstacles
        2991: ('grom20150414_04.hdf', 'grom20150414_02_RMLC04141528.pkl'),  # Suraj's BMI holding
        4549: ('grom20160319_11_te4549.hdf', 'grom20160319_09_test03191611.pkl'),  # Better obstacles
    }
    if te_num is None:
        te_num = 4048
    if te_num in local_files:
        hdf_name, dec_name = local_files[te_num]
        hdf = tables.openFile(os.path.expandvars('$FA_GROM_DATA/' + hdf_name))
        # Local, trusted data file; the handle is now closed after loading.
        with open(os.path.expandvars('$FA_GROM_DATA/' + dec_name), 'rb') as dec_file:
            dec = pickle.load(dec_file)
    else:
        from db import dbfunctions
        te = dbfunctions.TaskEntry(te_num)
        hdf = te.hdf
        dec = te.decoder

    # Use reward to calc fa_dict:
    fa_dict = factor_analysis_tasks.FactorBMIBase.generate_FA_matrices(None, hdf=hdf, dec=dec)

    # Calculate spikes to transform:
    drives_neurons_ix0 = 3
    task = hdf.root.task[:]
    internal_state = task['internal_decoder_state']
    update_bmi_ix = np.nonzero(np.diff(np.squeeze(internal_state[:, drives_neurons_ix0, 0])))[0] + 1
    spike_i = task['spike_counts'][:, :, 0]

    if get_psth:
        psth, psth_max, psth_shar, psth_priv, targ, targ_ix = _get_files_psth(
            hdf, dec, fa_dict, update_bmi_ix, spike_i)
        return psth, psth_max, psth_shar, psth_priv, hdf, dec, targ, targ_ix

    # binned spikes (clamped at 0 so early updates do not wrap to the end):
    bin_spk_i = np.zeros((len(update_bmi_ix), spike_i.shape[1]))
    for ib, i_ix in enumerate(update_bmi_ix):
        bin_spk_i[ib, :] = np.sum(spike_i[max(i_ix - 5, 0):i_ix + 1, :], axis=0)
    return fa_dict, bin_spk_i, hdf, dec


def _get_files_psth(hdf, dec, fa_dict, update_bmi_ix, spike_i):
    """Per-target trial-averaged binned spikes and their FA shared/private parts.

    Returns (psth, psth_max, psth_shar, psth_priv, targ, targ_ix). Each dict
    maps target -> (bins, units, 1). psth, psth_shar and psth_priv are
    truncated to the shortest trial; psth_max is NaN-padded to the longest.
    """
    n_units = dec.n_units
    msgs = hdf.root.task_msgs[:]
    rew_rows = [i for i, m in enumerate(msgs) if m['msg'] == 'reward']
    go_ix = np.array([msgs[i - 3]['time'] for i in rew_rows])
    rew_ix = np.array([msgs[i]['time'] for i in rew_rows])
    targ = hdf.root.task[go_ix + 10]['target'][:, [0, 2]]
    targ_ix = pa.get_target_ix(targ)

    psth, psth_max, psth_shar, psth_priv = {}, {}, {}, {}
    for t in np.unique(targ_ix):
        trials = []
        for trl in np.nonzero(targ_ix == t)[0]:
            start_bin = np.nonzero(update_bmi_ix >= go_ix[trl])[0][0]
            end_bin = np.nonzero(update_bmi_ix <= rew_ix[trl])[0][-1]
            bins = [np.sum(spike_i[max(i_ix - 5, 0):i_ix + 1, :], axis=0)
                    for i_ix in update_bmi_ix[start_bin:end_bin + 1]]
            trials.append(np.vstack(bins))

        # Python ints, so they are valid slice bounds.
        min_len = min(b.shape[0] for b in trials)
        max_len = max(b.shape[0] for b in trials)

        # Bins x neurons x trials
        spk = np.dstack([b[:min_len, :] for b in trials])
        padded = []
        for b in trials:
            mx = np.empty((max_len, n_units))
            mx.fill(np.nan)
            mx[:b.shape[0], :] = b
            padded.append(mx)

        n_bins, n_neur, n_trl = spk.shape
        dmn = spk - fa_dict['fa_mu'][np.newaxis, :, :]
        # Put neurons first before flattening; a plain reshape would mix the
        # bin and neuron axes before the projection.
        dmn_2d = dmn.transpose(1, 0, 2).reshape(n_neur, n_bins * n_trl)
        # `*` presumably is a matrix product (np.matrix) — confirm fa_main_shared type.
        shar_2d = np.array(fa_dict['fa_main_shared'] * dmn_2d)
        shar = shar_2d.reshape(n_neur, n_bins, n_trl).transpose(1, 0, 2)
        priv = dmn - shar

        psth[t] = np.mean(spk, axis=2)[:, :, np.newaxis]
        psth_max[t] = np.nanmean(np.dstack(padded), axis=2)[:, :, np.newaxis]
        psth_shar[t] = np.mean(shar, axis=2)[:, :, np.newaxis]
        psth_priv[t] = np.mean(priv, axis=2)[:, :, np.newaxis]
    return psth, psth_max, psth_shar, psth_priv, targ, targ_ix
def plot_traj(R, plot_pos=1, plot_vel=1, plot_force=0, it_cutoff=20000, min_it_cutoff=0, input_type='all'):
    """Overlay cursor (black) and decoded (blue) traces per rewarded trial on 4x2 target grids.

    Trials are the rewards returned by ``pa.get_trials_per_min(R.hdf, nmin=2)``
    with go = message 3 before each reward. Trials are kept when go lies after
    max(min_it_cutoff, first zero-assist step + 3600) and before it_cutoff.

    Parameters
    ----------
    R : re-run decoding object (project type) with ``hdf``, ``cursor_pos``,
        ``cursor_vel`` and ``decoded_pos``/``decoded_vel``/``dec_state_mn``
        indexed by ``input_type``.
    plot_pos, plot_vel, plot_force : 1 to draw the position, velocity or
        acceleration (decoder state columns 9 and 11) figures.
    """
    hdf = R.hdf
    rew_ix = np.asarray(pa.get_trials_per_min(hdf, nmin=2))
    msgs = hdf.root.task_msgs[:]
    # Keep only rewards returned above so go_ix and rew_ix stay paired.
    go_ix = np.array([msgs[it - 3][1] for it, t in enumerate(msgs)
                      if t[0] == 'reward' and t[1] in rew_ix])
    if len(go_ix) != len(rew_ix):
        raise ValueError('Could not pair every reward with a go message')

    # Make sure no assist is used
    try:
        zero_assist_start = np.nonzero(hdf.root.task[:]['assist_level'] == 0)[0][0]
    except (KeyError, ValueError, IndexError):
        print('ignoring assist')
        zero_assist_start = 0

    keep_ix = np.logical_and(go_ix > max(min_it_cutoff, zero_assist_start + (60 * 60)),
                             go_ix < it_cutoff)
    go_ix = go_ix[keep_ix].astype(int)
    rew_ix = rew_ix[keep_ix].astype(int)

    targ_pos = hdf.root.task[go_ix + 5]['target']
    targ_ix = pa.get_target_ix(targ_pos[:, [0, 2]])

    if plot_pos == 1:
        f, ax = plt.subplots(nrows=4, ncols=2)
        f2, ax2 = plt.subplots(nrows=4, ncols=2)
        f0, ax0 = plt.subplots()
    if plot_vel == 1:
        f3, ax3 = plt.subplots(nrows=4, ncols=2)
        f4, ax4 = plt.subplots(nrows=4, ncols=2)
    if plot_force == 1:
        f5, ax5 = plt.subplots(nrows=4, ncols=2)
        f6, ax6 = plt.subplots(nrows=4, ncols=2)

    for i, (g, r) in enumerate(zip(go_ix, rew_ix)):
        # Choose axis:
        targ = int(targ_ix[i])
        row, col = targ % 4, targ // 4
        if plot_pos == 1:
            axi = ax[row, col]
            axi2 = ax2[row, col]
            axi.plot(R.cursor_pos[g:r, 0], 'k-')
            axi2.plot(R.cursor_pos[g:r, 2], 'k-')
            axi.plot(R.decoded_pos[input_type][g:r, 0], 'b-')
            axi2.plot(R.decoded_pos[input_type][g:r, 2], 'b-')
            ax0.plot(R.cursor_pos[g:r, 0], R.cursor_pos[g:r, 2], '-', color='k')
            ax0.plot(R.decoded_pos[input_type][g:r, 0], R.decoded_pos[input_type][g:r, 2],
                     '--', color=cmap_list[targ])
            if i == 0:
                axi.set_title('Position X')
                axi2.set_title('Position Z')
        if plot_vel == 1:
            axi3 = ax3[row, col]
            axi4 = ax4[row, col]
            axi3.plot(R.cursor_vel[g:r, 0], 'k-')
            axi4.plot(R.cursor_vel[g:r, 2], 'k-')
            axi3.plot(R.decoded_vel[input_type][g:r, 0], 'b-')
            axi4.plot(R.decoded_vel[input_type][g:r, 2], 'b-')
            if i == 0:
                axi3.set_title('Vel X')
                axi4.set_title('Vel Z')
        if plot_force == 1:
            axi5 = ax5[row, col]
            axi6 = ax6[row, col]
            axi5.plot(R.dec_state_mn[input_type][g - 1:r - 1, 9], 'b-')
            axi6.plot(R.dec_state_mn[input_type][g - 1:r - 1, 11], 'b-')
            # Uses R.hdf; the original referenced an undefined `hdf`.
            axi5.plot(hdf.root.task[g:r]['internal_decoder_state'][:, 9], 'k-')
            axi6.plot(hdf.root.task[g:r]['internal_decoder_state'][:, 11], 'k-')
            if i == 0:
                axi5.set_title('Acc X')
                axi6.set_title('Acc Z')
    plt.tight_layout()
def traj_sem_plot(R, input_type_root, fa_by_targ=False, it_start=0, it_cutoff=20000, ax=None, rm_assist=True, extra_title_text=''):
    """Plot decoded trajectories per target with their mean and +/- SEM outlines.

    Trials are the rewards from ``pa.get_trials_per_min(R.hdf, nmin=2)`` with
    go = message 3 before each reward, kept when go lies in
    (max(first zero-assist step, it_start), it_cutoff). Each trial's decoded
    position (columns 0 and 2) is NaN-padded to the longest trial of its target.

    Parameters
    ----------
    input_type_root : key into ``R.decoded_pos``; with fa_by_targ the key is
        input_type_root + '_' + str(target).
    ax : axes to draw on (a new figure is created if None).
    rm_assist : require a step with assist_level == 0.

    Returns
    -------
    ax

    Raises
    ------
    ValueError : rm_assist is True and assist_level is never 0. (The original
        created this exception without raising it.)
    """
    hdf = R.hdf
    rew_ix = np.asarray(pa.get_trials_per_min(hdf, nmin=2))
    msgs = hdf.root.task_msgs[:]
    # Keep only rewards returned above so go_ix and rew_ix stay paired.
    go_ix = np.array([msgs[it - 3][1] for it, t in enumerate(msgs)
                      if t[0] == 'reward' and t[1] in rew_ix])
    if len(go_ix) != len(rew_ix):
        raise ValueError('Could not pair every reward with a go message')

    if rm_assist:
        zero_assist = np.nonzero(hdf.root.task[:]['assist_level'] == 0)[0]
        if len(zero_assist) == 0:
            raise ValueError('Assist is never == 0!')
        zero_assist_start = zero_assist[0]
    else:
        zero_assist_start = 0

    start_ix = max(zero_assist_start, it_start)
    keep_ix = np.logical_and(go_ix > start_ix, go_ix < it_cutoff)
    # Integer indices so they can be used as slice bounds and array sizes.
    go_ix = go_ix[keep_ix].astype(int)
    rew_ix = rew_ix[keep_ix].astype(int)
    t2targ = rew_ix - go_ix

    targ_pos = hdf.root.task[go_ix + 5]['target']
    targ_ix = pa.get_target_ix(targ_pos[:, [0, 2]])

    # For each target, make array of trials x time (max length, NaN padded)
    trial_arr = dict()
    if ax is None:
        f, ax = plt.subplots()

    for it, t in enumerate(np.unique(targ_ix)):
        if fa_by_targ:
            input_type = input_type_root + '_' + str(t)
        else:
            input_type = input_type_root
        decoded = R.decoded_pos[input_type]
        print('%s %s' % (input_type, decoded.shape))

        t_ix = np.nonzero(targ_ix == t)[0]
        arr = np.empty((len(t_ix), int(np.max(t2targ[t_ix])), 2))
        arr.fill(np.nan)
        trial_arr[str(t)] = arr
        for i, ti in enumerate(t_ix):
            if rew_ix[ti] <= decoded.shape[0]:
                arr[i, :t2targ[ti], :] = decoded[go_ix[ti]:rew_ix[ti], [0, 2]]
                ax.plot(arr[i, :, 0], arr[i, :, 1], '-', color='lightgrey')

        # Plot Mean, sem (SEM uses the number of trials still running at each step):
        n_valid = np.sum(~np.isnan(arr), axis=0)
        mean = np.nanmean(arr, axis=0)
        sem = np.nanstd(arr, axis=0) / np.sqrt(np.maximum(n_valid, 1))
        ax.plot(mean[:, 0], mean[:, 1], '-', color=cmap_list[it])
        ax.plot(mean[:, 0] - sem[:, 0], mean[:, 1] - sem[:, 1], '--', color=cmap_list[it])
        ax.plot(mean[:, 0] + sem[:, 0], mean[:, 1] + sem[:, 1], '--', color=cmap_list[it])
        tmp_circ = plt.Circle(targ_pos[t_ix[0], [0, 2]], R.target_rad,
                              color=cmap_list[it], alpha=.5)
        ax.add_artist(tmp_circ)

    ax.set_title(input_type_root + extra_title_text)
    ax.set_xlim([-14, 14])
    ax.set_ylim([-12, 12])
    return ax