def plt_traj(hdf, save=False):
    '''
    Summary: draw every rewarded trial's cursor path (light grey) on top of
        translucent circles marking each target.

    Input param: hdf: open pytables file with root.task and root.task_msgs
    Input param: save: if True, write <hdf filename minus last 4 chars>_traj.png
    '''
    target_radius = hdf.root.task.attrs.target_radius

    palette = ['maroon', 'orangered', 'darkgoldenrod', 'olivedrab', 'teal',
               'steelblue', 'midnightblue', 'darkmagenta', 'darkgray']
    fig, axis = plt.subplots()

    targets = hdf.root.task[:]['target']
    target_ids = pa.get_target_ix(targets[:, [0, 2]])

    # One circle per distinct target, centred where that target first appears.
    for color_ix, target_id in enumerate(np.unique(target_ids)):
        first_row = np.nonzero(target_ids == target_id)[0][0]
        circle = plt.Circle(targets[first_row, [0, 2]], target_radius,
                            color=palette[color_ix], alpha=.5)
        axis.add_artist(circle)

    cursor = hdf.root.task[:]['cursor']
    msg_table = hdf.root.task_msgs
    # (start, reward) time pairs; start is the message 3 entries before the reward.
    epochs = np.array([[msg_table[k - 3]['time'], msg['time']]
                       for k, msg in enumerate(msg_table[:])
                       if msg['msg'] == 'reward'])

    for start, stop in epochs:
        axis.plot(cursor[start:stop, 0], cursor[start:stop, 2], color='lightgrey')

    axis.set_ylim([-14, 14])
    axis.set_xlim([-14, 14])

    if save:
        fig.savefig(hdf.filename[:-4] + '_traj.png', format='png')
def jerk_squared(R, input1, input2):
    '''
    Summary: squared-jerk metric per rewarded trial for two decoded-velocity
        inputs and for the actual cursor velocity, using only BMI update steps.

    Input param: R: rerun-decoding object; reads R.hdf, R.update_bmi_ix,
        R.cursor_vel, R.decoded_vel[input1], R.decoded_vel[input2], R.target
    Input param: input1: key into R.decoded_vel, reported under 'priv'
    Input param: input2: key into R.decoded_vel, reported under 'shar'
    Output param: (jerk_sq, jerk_sq_traj, target index at each trial's go index)
        jerk_sq: dict of per-trial arrays; trials whose go index is past the
            end of R.cursor_vel are skipped and stay 0
        jerk_sq_traj: dict of dicts keyed by trial number
    '''
    hdf = R.hdf

    #Go indices: the message 3 entries before each reward
    go_ix = np.array([hdf.root.task_msgs[it-3][1]
        for it, t in enumerate(hdf.root.task_msgs[:])
        if t[0] == 'reward'])

    n_steps = len(R.cursor_vel)
    # Sentinel end so trial ig spans [go_ix[ig], go_ix[ig+1])
    go_ix = np.hstack((go_ix, n_steps)).astype(int)
    n_trials = len(go_ix) - 1

    jerk_sq = dict(priv=np.zeros((n_trials, )), shar=np.zeros((n_trials, )), norm=np.zeros((n_trials, )))
    jerk_sq_traj = dict(priv=dict(), shar=dict(), norm=dict())
    targ_ix = pa.get_target_ix(R.target[:, [0,2]])

    # Sorted unique BMI update indices, built once instead of a new set per trial
    update_ix = np.unique(R.update_bmi_ix)

    for ig, go in enumerate(go_ix[:-1]):
        if go < n_steps:
            # BMI update steps inside this trial's window (sorted, integer dtype)
            ix_ = update_ix[np.logical_and(update_ix >= go, update_ix < go_ix[ig+1])]

            priv_traj = np.squeeze(R.decoded_vel[input1][ix_,:]).copy()
            shar_traj = np.squeeze(R.decoded_vel[input2][ix_,:]).copy()
            norm_traj = np.squeeze(R.cursor_vel[ix_,:]).copy()

            jerk_sq['priv'][ig], jerk_sq_traj['priv'][ig] = _get_jk_squ(priv_traj)
            jerk_sq['shar'][ig], jerk_sq_traj['shar'][ig] = _get_jk_squ(shar_traj)
            jerk_sq['norm'][ig], jerk_sq_traj['norm'][ig] = _get_jk_squ(norm_traj)

    return jerk_sq, jerk_sq_traj, targ_ix[go_ix[:-1]]
def get_epoch(hdf_list, hold_min=0., set_window=None, epoch='hold'):
    '''
    Summary: bin spike counts (6-step bins starting at the first BMI update in
        the epoch) for each rewarded trial's hold or reach epoch, over several
        hdf files.

    Input param: hdf_list: hdf file names relative to $FA_GROM_DATA
    Input param: hold_min: minimum actual epoch duration (seconds) to keep a trial
    Input param: set_window: if given, fixed window length (seconds) from the
        epoch start; otherwise the epoch's actual end is used
    Input param: epoch: 'hold' or 'reach'; selects which task messages before
        each reward bound the epoch
    Output param: (trial index per bin, binned counts, epoch durations (s),
        cursor velocity per bin, target index per trial)
    Raises: ValueError if epoch is not 'hold' or 'reach'
    '''
    # task_msgs offsets before each reward message for epoch (start, end)
    epoch_offsets = {'hold': (2, 1), 'reach': (3, 2)}
    if epoch not in epoch_offsets:
        raise ValueError("epoch must be 'hold' or 'reach', got %r" % (epoch, ))
    start_tab, end_tab = epoch_offsets[epoch]

    binned_counts = []
    trl_ix = []
    trl_ix_start = 0
    t2t_arr = []
    target_info = []
    cursor = []
    for hdf_nm in hdf_list:
        hdf = tables.openFile(os.path.expandvars('$FA_GROM_DATA/'+hdf_nm))
        try:
            # Read each table once; everything below uses these in-memory copies
            task = hdf.root.task[:]
            msgs = hdf.root.task_msgs[:]
            spike_counts = task['spike_counts']
            internal_state = task['internal_decoder_state']
            curs_vel = internal_state[:, [3, 5], 0]

            rew = [i for i, j in enumerate(msgs) if j['msg'] == 'reward']
            start_hold = np.array([msgs[i-start_tab]['time'] for i in rew])
            act_end_hold = np.array([msgs[i-end_tab]['time'] for i in rew])

            if set_window is None:
                end_hold = act_end_hold.copy()
            else:
                end_hold = start_hold + (set_window*60)

            # Actual epoch durations in seconds (assumes 60 task steps per second)
            t2t = (act_end_hold - start_hold)/60.
            t2t_ix = t2t >= hold_min
            t2t_arr.append(t2t[t2t_ix])

            start_hold = start_hold[t2t_ix]
            end_hold = end_hold[t2t_ix]

            # BMI update steps: where decoder state element 3 changes value
            update_bmi_ix = np.nonzero(np.diff(np.squeeze(internal_state[:, 3, 0])))[0]+1

            for i, (s, e) in enumerate(zip(start_hold, end_hold)):
                # Bins start at the first BMI update at or after the epoch start
                s_e = update_bmi_ix[np.nonzero(update_bmi_ix >= s)[0][0]]
                trl = []
                c = []
                # astype(int): e is a float when set_window is given
                for bin_edge in np.arange(s_e, e, 6).astype(int):
                    trl.append(np.sum(spike_counts[bin_edge:bin_edge+6, :, 0], axis=0))
                    trl_ix.append(i+trl_ix_start)
                    c.append(curs_vel[bin_edge, :])

                target_info.append(task[s+10]['target'][[0, 2]])
                binned_counts.append(np.vstack(trl))
                cursor.append(np.vstack(c))
            # len() instead of the loop variable so a file with no kept trials
            # does not reuse a stale (or undefined) index
            trl_ix_start += len(start_hold)
        finally:
            hdf.close()

    target_index = pa.get_target_ix(np.vstack(target_info))
    return np.array(trl_ix), np.vstack(binned_counts), np.hstack(t2t_arr), np.vstack(cursor), target_index
def perf_mets(hdf, save=False, hdf_fname=None):

    epoch_ix = np.array([ [hdf.root.task_msgs[j-3]['time'], i['time']]
        for j, i in enumerate(hdf.root.task_msgs[:]) if i['msg']=='reward'])
    
    #Get targ ix:
    tg = hdf.root.task[:]['target']
    targ_ix = pa.get_target_ix(tg[:, [0, 2]])
    targ_ix = targ_ix[epoch_ix[:,0]+3]

    #Get time2targ:
    time2targ = (epoch_ix[:,1] - epoch_ix[:,0])/60.

    cursor = hdf.root.task[:]['cursor'][:, [0, 2]]

    metrics = dict()
    metrics['path_length'] = []
    metrics['path_error'] = []
    metrics['avg_speed'] = []
    metrics['targ_ix'] = []
    metrics['time2targ'] = []

    #Path error: 
    for i, (g, r) in enumerate(epoch_ix):
        path = cursor[g:r, :]
        targ_loc = tg[g+5,[0, 2]]

        path_length, path_error, avg_speed = bha.path_metrics(path, targ_loc)
        metrics['path_length'].append(path_length)
        metrics['path_error'].append(path_error)
        metrics['avg_speed'].append(avg_speed)
        metrics['targ_ix'].append(targ_ix[i])
        metrics['time2targ'].append(time2targ[i])

    #Add output SOT to metrics: 
    if hdf_fname is None: 
        print 'Need hdf_fname for sot calc'
    else:
        metrics['output_sot'] = get_output_sot(hdf, hdf_fname)

    if save:
        sio.savemat(hdf.filename[:-4]+'_metrics.mat', metrics)
def decompose_inputs(fa_dict, bin_spk_i, hdf, dec, task='bmi_resetting', process_to_ix=None, use_main = True):
    '''
    Summary: split binned spikes into FA shared and private parts and register
        the all / shared / private spike streams as inputs on a RerunDecoding.

    Input param: fa_dict: needs 'fa_mu' plus 'fa_main_shared' (use_main=True)
        or 'fa_sharL' (use_main=False)
    Input param: bin_spk_i: binned spikes, rows are bins (T = shape[0])
    Input param: hdf, dec: passed to trbt.RerunDecoding
    Input param: task: task name passed to trbt.RerunDecoding (previously
        ignored; 'bmi_resetting' was always used)
    Input param: process_to_ix: number of rebinned steps to add (default: all)
    Input param: use_main: which shared projection to use
    Output param: (ReSim, gr_ix (go, reward) time pairs, proc_ix)
    '''
    #Demean with the FA mean, tiled across all T bins
    T = bin_spk_i.shape[0]
    mn = np.tile(fa_dict['fa_mu'], [1, T])
    dmn = bin_spk_i.T - mn

    ReSim = trbt.RerunDecoding(hdf, dec, task=task)

    #Shared projection; private activity is the remainder
    shared_key = 'fa_main_shared' if use_main else 'fa_sharL'
    shar = (fa_dict[shared_key] * dmn)
    priv = (dmn - shar)
    ReSim.main = bool(use_main)

    main_shar_spks = sdf.rebin_spks(shar + mn)
    main_priv_spks = sdf.rebin_spks(priv + mn)

    if process_to_ix is None:
        proc_ix = main_shar_spks.shape[0]
    else:
        proc_ix = process_to_ix

    ReSim.add_input(ReSim.spike_counts[:proc_ix, :, :], 'all')
    ReSim.add_input(main_shar_spks[:proc_ix, :, :], 'main_shar')
    ReSim.add_input(main_priv_spks[:proc_ix, :, :], 'main_priv')

    #(go, reward) time pairs; go = message 3 entries before each reward
    gr_ix = np.array([(hdf.root.task_msgs[j-3]['time'], i['time']) for j, i
        in enumerate(hdf.root.task_msgs) if i['msg'] == 'reward'])

    #Targets at reward time, first 50 trials only (stored as targ_pos_trunc)
    targ_pos = hdf.root.task[gr_ix[:50,1]]['target']
    targ_ix = pa.get_target_ix(targ_pos[:,[0, 2]])
    ReSim.target_ix = targ_ix
    ReSim.targ_pos_trunc = targ_pos
    return ReSim, gr_ix, proc_ix
def psth(Recoder,bins_after_go=60*4, input_type='all'):
    '''
    Summary: collect, per target, the rebinned spike window of length
        bins_after_go that starts at each rewarded trial's go index.

    Input param: Recoder: reads Recoder.hdf, Recoder.decoded_pos[input_type],
        Recoder.dec_spk_cnt_bin[input_type] and Recoder.target
    Input param: bins_after_go: window length in steps after go
    Input param: input_type: key into the Recoder dicts
    Output param: dict int(target index) -> the single 2-D window when a
        target has one trial, else windows stacked along a third (trial) axis
    '''
    hdf = Recoder.hdf
    n_steps = len(Recoder.decoded_pos[input_type])

    #Go indices: the message 3 entries before each reward
    go_ix = np.array([hdf.root.task_msgs[it-3][1] for it, t in enumerate(hdf.root.task_msgs[:]) if t[0] == 'reward'])
    spks = sdf.rebin_spks(Recoder.dec_spk_cnt_bin[input_type])
    targ_ix = pa.get_target_ix(Recoder.target[:, [0,2]])

    windows = dict()
    for go in go_ix:
        ti = int(targ_ix[go])
        # Only trials whose whole window fits inside the decoded data
        if go+bins_after_go < n_steps:
            sp = np.squeeze(spks[go:go+bins_after_go, : , 0])
            windows.setdefault(ti, []).append(sp)

    # Stack once per target instead of re-copying a growing array every trial;
    # a lone trial stays 2-D, exactly as repeated dstack would leave it.
    result = dict()
    for ti, sp_list in windows.items():
        result[ti] = sp_list[0] if len(sp_list) == 1 else np.dstack(sp_list)
    return result
def process_targets(te_list, new_hdf_name):
    '''
    Summary: Yields an array of trial information in 
        pytables with columns seen above in Trial_Metrics table class

    Input param: hdf: hdf for either fa_bmi or bmi_resetting
    Output param: trial dictionary
    '''
    h5file = tables.openFile(new_hdf_name, mode="w", title='FA Trial Analysis')
    trial_mets_table = h5file.createTable("/", 'trial_metrics', Trial_Metrics, "Trial Mets")
    meta_table = h5file.createTable("/", 'meta_metrics', Meta_Metrics, "Meta Mets")

    row_cnt = -1

    for te in te_list:

        te_obj = dbfn.TaskEntry(te)
        task_entry = te

        hdf = te_obj.hdf

        #Extract trials ignoring assist: 
        rew_ix, rew_per_min = pa.get_trials_per_min(hdf, nmin=2, rew_per_min_cutoff=0, 
            ignore_assist=True, return_rpm=True)

        #Go backwards 3 steps (rew --> targ trans --> hold --> target (onset))
        go_ix = np.array([hdf.root.task_msgs[it-3][1] for it, t in enumerate(hdf.root.task_msgs[:]) if 
            scipy.logical_and(t[0] == 'reward', t[1] in rew_ix)])

        assert len(rew_ix) == len(go_ix)#: raise Exception("Rew ix and Go ix are unequal")

        #Time to target is diff b/w target onset and reward time
        time2targs = (rew_ix - go_ix)/60.

        #Add buffer of 5 task steps (80 ms)
        target_locs = hdf.root.task[go_ix.astype(int) + 5]['target'][:, [0, 2]]

        try:
            obstacle_sz = hdf.root.task[go_ix.astype(int) + 5]['obstacle_size'][:, 0]
        except:
            obstacle_sz = np.zeros_like(go_ix) 

        #Targ ix -- should be 8
        targ_ixs = pa.get_target_ix(target_locs)
        assert len(np.unique(targ_ixs)) == 8#: raise Exception("Target Ix has more or less than 8 targets :/ ")

        #Input type: 
        try:
            input_types = hdf.root.task[go_ix.astype(int)+5]['fa_input']
        except:
            input_types = ['all']*len(go_ix)

        rew_time_2_trial = {}

        for ri, r_ix in enumerate(rew_ix):
            trl = trial_mets_table.row
            row_cnt += 1

            trl['trial_number'] = ri
            trl['task_entry'] = task_entry
            trl['rew_per_min'] = rew_per_min[ri]
            trl['target_index'] = targ_ixs[ri]
            trl['target_loc'] = target_locs[ri,:]
            trl['start_time'] = go_ix[ri]/60. # In seconds
            trl['input_type'] = input_types[ri]
            trl['time2targ'] = time2targs[ri]
            rew_time_2_trial[r_ix] = time2targs[ri]

            trl['obstacle_size'] = obstacle_sz[ri]


            path = hdf.root.task[int(go_ix[ri]):int(r_ix)]['cursor'][:, [0, 2]]
            path_length, path_error, avg_speed = path_metrics(path, target_locs[ri, :])

            trl['path_length'] = path_length
            trl['path_error'] = path_error
            trl['avg_speed'] = avg_speed

            trl['timeout_time'] = hdf.root.task.attrs.timeout_time
            trl['timeout_penalty_time'] = hdf.root.task.attrs.timeout_penalty_time

            trl.append()

        trial_mets_table.flush()

        #Meta metrics for hdf: 
        block_len = hdf.root.task.shape[0]
        wind_len = 5*60*60
        wind_step = 2.5*60*60

        meta_ix = np.arange(0, block_len - wind_len, wind_step)
        trial_ix = np.array([i for i in hdf.root.task_msgs[:] if 
            i['msg'] in ['reward','timeout_penalty','hold_penalty','obstacle_penalty'] ], dtype=hdf.root.task_msgs.dtype)

        for m in meta_ix:
            meta_trl = meta_table.row
            meta_trl['task_entry'] = te
            end_meta_ix = m + wind_len

            trial_ix_time = trial_ix[:,]
            msg_ix = np.nonzero(np.logical_and(trial_ix['time']<=end_meta_ix, trial_ix['time']>m))[0]

            targ = hdf.root.task[trial_ix[msg_ix]['time'].astype(int)]['target'][:,[0,2]]
            targ_ix = pa.get_target_ix(targ)

            #Only mark as correct if it's the first correct trial
            targ_ix_ix = [0]
            tg_prev = targ_ix[0]
            for ii, tg_ix in enumerate(targ_ix[1:]):
                if tg_prev != tg_ix:
                    targ_ix_ix.append(ii+1)
                    tg_prev = tg_ix

            targ_ix_mod = targ_ix[targ_ix_ix]
            msg_ix_mod = msg_ix[targ_ix_ix]

            targ_percent_success = np.zeros((8, )) - 1
            all_perc_succ_lte10sec = np.zeros((8, )) - 1

            #time2targs for rew_ix:
            rew_ix_meta = np.array([i['time'] for i in trial_ix[msg_ix] if i['msg']=='reward'])

            for t in range(8):
                t_ix = np.nonzero(targ_ix_mod==t)[0]
                msgs = trial_ix[msg_ix_mod[t_ix]]

                if len(msgs) > 0:
                    targ_percent_success[t] = len(np.nonzero(msgs['msg']=='reward')[0]) / float(len(msgs))
                    msgs_copy = msgs.copy()
                    for mc in msgs_copy:
                        if mc['msg'] == 'reward':
                            r_tm = mc['time']
                            if r_tm in rew_time_2_trial.keys():
                                if rew_time_2_trial[r_tm] > 10.:
                                    msgs_copy['msg'] = 'not_reward'
                                    print 'not reward: ', m, t, mc
                            else:
                                print 'asissted: ', m, t, mc
                                msgs_copy['msg'] = 'not_rewarded'
                    all_perc_succ_lte10sec[t] = len(np.nonzero(msgs_copy['msg'] == 'reward')[0]) / float(len(msgs_copy))
                    if all_perc_succ_lte10sec[t] < 0:
                        print 'error: ', m, t, mc
                else: print 'no msgs for this epoch: ', m
            meta_trl['targ_percent_success'] = targ_percent_success
            meta_trl['all_perc_succ_lte10sec'] = all_perc_succ_lte10sec

            all_msg = trial_ix[msg_ix_mod]
            meta_trl['all_percent_success'] = len(np.nonzero(all_msg['msg']=='reward')[0]) / float(len(all_msg))
            meta_trl.append()

        meta_table.flush()

        #How many 5 min segments in 2.5 min steps


    h5file.close()

    return new_hdf_name
def shared_vs_priv_vel_activ(Recoder, input_types=['private', 'shared']):
    '''
    Summary: per target, plot mean +/- SEM of the shared fraction of decoded
        speed, shar / (shar + priv), over the first (blue) and last (red) 120
        steps of each rewarded trial.

    Input param: Recoder: reads Recoder.hdf, Recoder.target and
        Recoder.decoded_vel[key] for both keys in input_types
    Input param: input_types: (private key, shared key) into Recoder.decoded_vel
    Output param: (sot_start, sot_end): dicts str(target index) -> list of
        per-trial ratio traces (the original returned nothing)
    '''
    tiny = 1e-12  # keeps the ratio finite when both speeds are zero
    hdf = Recoder.hdf

    # NOTE(review): the original used undefined R_priv / R_shar (NameError) and
    # len(Recoder.decoded_pos), which other functions here index by input type.
    # Assumed intent: Recoder.decoded_vel[input_types[0] / [1]] -- confirm.
    priv_vel = Recoder.decoded_vel[input_types[0]]
    shar_vel = Recoder.decoded_vel[input_types[1]]
    n_steps = len(priv_vel)

    #Go indices: the message 3 entries before each reward
    go_ix = np.array([hdf.root.task_msgs[it-3][1]
        for it, t in enumerate(hdf.root.task_msgs[:])
        if t[0] == 'reward'])
    # Sentinel end so trial ig spans [go_ix[ig], go_ix[ig+1])
    go_ix = np.hstack((go_ix, n_steps)).astype(int)

    # Target index for every time step (the original indexed rows with
    # go_ix+5 and columns [0, 2] together, which does not broadcast)
    targ_ix = pa.get_target_ix(Recoder.target[:, [0,2]])

    sot_start = dict()
    sot_end = dict()
    for ig, go in enumerate(go_ix[:-1]):
        if go < n_steps:
            #Speed calc (X and Z velocity components):
            speed_priv = np.sqrt(np.sum(priv_vel[go:go_ix[ig+1], [0, 2]]**2, axis=1))
            speed_shar = np.sqrt(np.sum(shar_vel[go:go_ix[ig+1], [0, 2]]**2, axis=1))
            ratio = speed_shar / (tiny + speed_shar + speed_priv)

            ky = str(int(targ_ix[go]))
            sot_start.setdefault(ky, []).append(ratio[:120])
            sot_end.setdefault(ky, []).append(ratio[-120:])

    f_start, ax_start = plt.subplots(nrows=4, ncols=2)
    f_end, ax_end = plt.subplots(nrows=4, ncols=2)
    x = np.arange(120)/60.

    for i_t, ky in enumerate(sorted(sot_start, key=int)):
        axi = ax_start[i_t % 4, i_t // 4]
        axj = ax_end[i_t % 4, i_t // 4]

        # NOTE(review): trials shorter than 120 steps make ragged rows and
        # np.vstack will fail -- consider NaN padding.
        tmp_start = np.vstack(sot_start[ky])
        tmp_end = np.vstack(sot_end[ky])

        for axk, tmp, style, color in ((axi, tmp_start, 'b-', 'b'), (axj, tmp_end, 'r-', 'r')):
            mean = np.mean(tmp, axis=0)
            sem = np.std(tmp, axis=0)/np.sqrt(tmp.shape[0])
            axk.plot(x, mean, style)
            axk.fill_between(x, mean - sem, mean + sem, alpha=.5, color=color)
            axk.set_ylim([.5, .75])

    return sot_start, sot_end
def PSTH(fa_hdf,task_tbl_met = 'spike_counts', plot_by_neur=False, save_f=False):
    epoch_ix = np.array([ [fa_hdf.root.task_msgs[j-3]['time'], i['time']]
        for j, i in enumerate(fa_hdf.root.task_msgs[:]) if i['msg']=='reward'])

    cursor_pos = fa_hdf.root.task[:]['cursor'][:, [0, 2]]
    try:
        spike_counts = fa_hdf.root.task[:][task_tbl_met][:,:,0] #Time x units
    except:
        spike_counts = fa_hdf.root.task[:]['spike_counts'][:,:,0]
        print 'USING SPIKE COUNTS from HDF FILE'
    st = np.nonzero(np.sum(spike_counts, axis=1))[0][0]
    st = st % 6
    update_ix = np.arange(st, spike_counts.shape[0], 6)

    ixx = []
    trl = []
    tg = []
    for i, (g,r) in enumerate(epoch_ix):
        u_ix = np.nonzero(np.logical_and(update_ix<=r, update_ix>g))[0]
        ixx.append(update_ix[u_ix])
        if u_ix[0] >=4:
            trl.append(spike_counts[update_ix[u_ix[0]-4:u_ix[0]+12], :])
            tg.append(fa_hdf.root.task[update_ix[u_ix[0]+3]]['target'][[0, 2]])
    
    IX = np.hstack((ixx))
    if len(trl) > 0:
        TRL = np.dstack((trl))
        TG_IX = pa.get_target_ix(np.vstack((tg)))
        x_ax = np.arange(-.4, 1.2, .1)
        n_neurons = TRL.shape[1]
        cnt = -1
        if plot_by_neur:
            u_col = get_cmap(len(np.unique(TG_IX)))
            for n in range(n_neurons):
                print n
                if not (n % 9):
                    try:
                        plt.tight_layout()
                        if save_f: 
                            f.savefig(fa_hdf.filename[:-4]+'_PSTH_by_neur_'+str(cnt)+'.png', format='png')
                    except:
                        pass
                    print n
                    f, ax = plt.subplots(nrows=3, ncols=3)
                    cnt += 1
                axi = ax[(n%9)/3, (n%9)%3]
                if not ((n%9)%3):
                    axi.set_ylabel('FR')
                if (n%9)/3 == 2:
                    axi.set_xlabel('Sec w.r.t. Target Onset')

                for ii, it in enumerate(np.unique(TG_IX)):
                    ix = np.nonzero(TG_IX==it)[0]
                    fr = TRL[:, n, ix].T
                    mn = np.mean(fr, axis=0)
                    axi.plot(x_ax, mn, color=u_col[ii])
                    axi.set_title('Unit '+str(n))

        else:
            u_col = get_cmap(n_neurons)
            f, ax = plt.subplots(nrows = 3, ncols = 3)
            for ii, it in enumerate(np.unique(TG_IX)):
                axix = bha.targ_ix_to_3x3_subplot(it)
                axi = ax[axix[0], axix[1]]
                ix = np.nonzero(TG_IX==it)[0]

                fr = TRL[:, :, ix].T #Trials x units x time

                for u in range(n_neurons):
                    mn = np.mean(fr[:, u, :], axis=0)
                    axi.plot(x_ax, mn, color=u_col[ii])
def _load_task_data(te_num):
    '''Return (hdf, decoder) for a known session file pair, else via the task-entry database.'''
    # te_num -> (hdf file, decoder pickle) under $FA_GROM_DATA
    known_sessions = {
        4048: ('grom20160201_04_te4048.hdf', 'grom20160201_01_RMLC02011515.pkl'),  #Good centerout
        4526: ('grom20160317_11_te4526.hdf', 'grom20160317_03_test03171536.pkl'),  #Learning obstacles
        2991: ('grom20150414_04.hdf', 'grom20150414_02_RMLC04141528.pkl'),  #Suraj's BMI holding
        4549: ('grom20160319_11_te4549.hdf', 'grom20160319_09_test03191611.pkl'),  #Better obstacles
    }
    if te_num is None:
        te_num = 4048
    if te_num in known_sessions:
        hdf_nm, dec_nm = known_sessions[te_num]
        hdf = tables.openFile(os.path.expandvars('$FA_GROM_DATA/' + hdf_nm))
        # Local lab decoder pickles only -- never unpickle untrusted files
        with open(os.path.expandvars('$FA_GROM_DATA/' + dec_nm), 'rb') as dec_file:
            dec = pickle.load(dec_file)
        return hdf, dec

    from db import dbfunctions
    te = dbfunctions.TaskEntry(te_num)
    return te.hdf, te.decoder


def _target_psths(hdf, fa_dict, n_units, spike_i, update_bmi_ix):
    '''Per-target mean binned spikes (truncated and NaN-padded) plus FA shared / private means.'''
    go_ix = np.array([hdf.root.task_msgs[i-3]['time'] for i, j in enumerate(hdf.root.task_msgs) if j['msg'] == 'reward'])
    rew_ix = np.array([j['time'] for i, j in enumerate(hdf.root.task_msgs) if j['msg'] == 'reward'])

    targ = hdf.root.task[go_ix+10]['target'][:,[0, 2]]
    targ_ix = pa.get_target_ix(targ)

    psth = {}
    psth_max = {}
    psth_shar = {}
    psth_priv = {}

    for t in np.unique(targ_ix):
        # One (bins x units) array per trial: 6-step sums ending at each BMI
        # update between go and reward (inclusive)
        trials = []
        for trl in np.nonzero(targ_ix==t)[0]:
            start_bin = np.nonzero(update_bmi_ix>=go_ix[trl])[0][0]
            end_bin = np.nonzero(update_bmi_ix<=rew_ix[trl])[0][-1]
            trials.append(np.vstack([np.sum(spike_i[i_ix-5:i_ix+1, :], axis=0)
                for i_ix in update_bmi_ix[start_bin:end_bin+1]]))

        # Integers (the old 10e10 seed made min_len a float slice bound)
        min_len = min(b.shape[0] for b in trials)
        max_len = max(b.shape[0] for b in trials)

        #Bins x neurons x trials, truncated to the shortest trial
        spk = np.dstack([b[:min_len, :] for b in trials])

        #NaN-padded to the longest trial, for the nan-mean
        tmp_max = []
        for b in trials:
            mx = np.zeros((max_len, n_units))
            mx[:,:] = np.nan
            mx[:b.shape[0], :] = b
            tmp_max.append(mx)

        #FA decomposition: demean, project onto the shared space, private = rest
        dmn = spk - np.tile(fa_dict['fa_mu'][np.newaxis, :, :], [spk.shape[0], 1, spk.shape[2]])
        n_bins, n_neur, n_trl = dmn.shape
        # Move neurons to the row axis before flattening so each column is one
        # (bin, trial) sample; the old plain reshape mixed bins and neurons.
        dmn_resh = dmn.transpose(1, 0, 2).reshape(n_neur, n_bins*n_trl)
        shar_resh = np.array((fa_dict['fa_main_shared'] * dmn_resh))
        shar = shar_resh.reshape(n_neur, n_bins, n_trl).transpose(1, 0, 2)
        priv = dmn - shar

        psth[t] = np.mean(spk, axis=2)[:, :, np.newaxis]
        psth_max[t] = np.nanmean(np.dstack(tmp_max), axis=2)[:, :, np.newaxis]
        psth_shar[t] = np.mean(shar, axis=2)[:, :, np.newaxis]
        psth_priv[t] = np.mean(priv, axis=2)[:, :, np.newaxis]

    return psth, psth_max, psth_shar, psth_priv, targ, targ_ix


def get_files(te_num=None, get_psth=False):
    '''
    Summary: load a session's hdf and decoder, build its FA matrices, and
        return either binned spikes or per-target PSTHs.

    Input param: te_num: task entry id; None / 4048 / 4526 / 2991 / 4549 use
        hard-coded files under $FA_GROM_DATA, others go through db.dbfunctions
    Input param: get_psth: if True return per-target PSTHs, else binned spikes
    Output param: get_psth False -> (fa_dict, bin_spk_i, hdf, dec)
        get_psth True -> (psth, psth_max, psth_shar, psth_priv, hdf, dec, targ, targ_ix)
    '''
    hdf, dec = _load_task_data(te_num)

    #Use reward to calc fa_dict:
    fa_dict = factor_analysis_tasks.FactorBMIBase.generate_FA_matrices(None, hdf=hdf, dec=dec)

    #Calculate spikes to transform: BMI updates are where this decoder state changes
    drives_neurons_ix0 = 3
    internal_state = hdf.root.task[:]['internal_decoder_state']
    update_bmi_ix = np.nonzero(np.diff(np.squeeze(internal_state[:, drives_neurons_ix0, 0])))[0]+1
    spike_i = hdf.root.task[:]['spike_counts'][:,:,0]

    if get_psth:
        psth, psth_max, psth_shar, psth_priv, targ, targ_ix = _target_psths(
            hdf, fa_dict, dec.n_units, spike_i, update_bmi_ix)
        return psth, psth_max, psth_shar, psth_priv, hdf, dec, targ, targ_ix

    #binned spikes: sum of the 6 steps ending at each update
    bin_spk_i = np.zeros((len(update_bmi_ix), spike_i.shape[1]))
    for ib, i_ix in enumerate(update_bmi_ix):
        bin_spk_i[ib, :]= np.sum(spike_i[i_ix-5:i_ix+1, :], axis=0)

    return fa_dict, bin_spk_i, hdf, dec
# Example #11
# 0
def plot_traj(R, plot_pos=1, plot_vel=1, plot_force=0, it_cutoff=20000, 
    min_it_cutoff=0, input_type='all'):
    
    rew_ix = pa.get_trials_per_min(R.hdf,nmin=2)
    go_ix = np.array([R.hdf.root.task_msgs[it-3][1] for it, t in enumerate(R.hdf.root.task_msgs[:]) if t[0] == 'reward'])
    
    #Make sure no assist is used 
    try:
        zero_assist_start = np.nonzero(R.hdf.root.task[:]['assist_level']==0)[0][0]
    except:
        print 'ignoring assist'
        zero_assist_start = 0
    keep_ix = scipy.logical_and(go_ix>np.max([min_it_cutoff, zero_assist_start+(60*60)]), go_ix<it_cutoff)
    go_ix = go_ix[keep_ix]

    rew_ix = rew_ix[keep_ix]

    targ_pos = R.hdf.root.task[go_ix.astype(int)+5]['target']
    targ_ix = pa.get_target_ix(targ_pos[:,[0, 2]])

    if plot_pos == 1:
        f, ax = plt.subplots(nrows=4, ncols=2)
        f2, ax2 = plt.subplots(nrows=4, ncols=2)
        f0, ax0 = plt.subplots()
    
    if plot_vel == 1:
        f3, ax3 = plt.subplots(nrows=4, ncols=2)
        f4, ax4 = plt.subplots(nrows=4, ncols=2)        

    if plot_force ==1:
        f5, ax5 = plt.subplots(nrows=4, ncols=2)
        f6, ax6 = plt.subplots(nrows=4, ncols=2)
    #R = pickle.load(open(pkl_name))

    for i, (g,r) in enumerate(zip(go_ix, rew_ix)):
        #Choose axis: 
        targ = targ_ix[i]
        
        if plot_pos == 1:
            axi = ax[targ%4, targ/4]
            axi2 = ax2[targ%4, targ/4]

            axi.plot(R.cursor_pos[g:r, 0], 'k-')
            axi2.plot(R.cursor_pos[g:r, 2], 'k-')

            axi.plot(R.decoded_pos[input_type][g:r, 0], 'b-')
            axi2.plot(R.decoded_pos[input_type][g:r, 2], 'b-')
            
            ax0.plot(R.cursor_pos[g:r,0], R.cursor_pos[g:r, 2], '-', color='k')
            ax0.plot(R.decoded_pos[input_type][g:r,0], R.decoded_pos[input_type][g:r, 2], '--', color=cmap_list[int(targ)])


            if i==0:
                axi.set_title('Position X')
                axi2.set_title('Position Z')

        if plot_vel == 1:
            axi3 = ax3[targ%4, targ/4]
            axi4 = ax4[targ%4, targ/4]
            axi3.plot(R.cursor_vel[g:r, 0], 'k-')
            axi4.plot(R.cursor_vel[g:r, 2], 'k-')

            axi3.plot(R.decoded_vel[input_type][g:r, 0], 'b-')
            axi4.plot(R.decoded_vel[input_type][g:r, 2], 'b-')
            if i==0:
                axi3.set_title('Vel X')
                axi4.set_title('Vel Z')

        if plot_force ==1:
            axi5 = ax5[targ%4, targ/4]
            axi6 = ax6[targ%4, targ/4]
            axi5.plot(R.dec_state_mn[input_type][g-1:r-1, 9], 'b-')
            axi6.plot(R.dec_state_mn[input_type][g-1:r-1, 11], 'b-')

            axi5.plot(hdf.root.task[g:r]['internal_decoder_state'][:,9], 'k-')
            axi6.plot(hdf.root.task[g:r]['internal_decoder_state'][:,11], 'k-')
            if i==0:
                axi5.set_title('Acc X')
                axi6.set_title('Acc Z')

    plt.tight_layout()
# Example #12
# 0
def traj_sem_plot(R, input_type_root, fa_by_targ=False, it_start=0, it_cutoff=20000, ax=None, 
    rm_assist=True, extra_title_text=''):
    rew_ix = pa.get_trials_per_min(R.hdf,nmin=2)
    go_ix = np.array([R.hdf.root.task_msgs[it-3][1] for it, t in enumerate(R.hdf.root.task_msgs[:]) if t[0] == 'reward'])

    if rm_assist:
        zero_assist_start = np.nonzero(R.hdf.root.task[:]['assist_level']==0)[0]
        if len(zero_assist_start) > 0:
            zero_assist_start = zero_assist_start[0]
        else:
            Exception('Assist is never == 0!')
    else:
        zero_assist_start = 0

    start_ix = np.max([zero_assist_start, it_start])
    keep_ix = scipy.logical_and(go_ix>start_ix, go_ix<it_cutoff)
    go_ix = go_ix[keep_ix]
    rew_ix = rew_ix[keep_ix]

    t2targ = rew_ix - go_ix

    targ_pos = R.hdf.root.task[go_ix.astype(int)+5]['target']
    targ_ix = pa.get_target_ix(targ_pos[:,[0, 2]])

    #For each target, make array of trials x time (min length for reward trial)
    trial_arr = dict()

    if ax is None:
        f, ax = plt.subplots()
    
    #For each target, concat trials: 
    for it, t in enumerate(np.unique(targ_ix)):
        if fa_by_targ:
            input_type = input_type_root + '_'+str(t)
        else:
            input_type = input_type_root
        print input_type, R.decoded_pos[input_type].shape

        t_ix = np.nonzero(targ_ix==t)[0]
        trial_arr[str(t)] = np.zeros((len(t_ix), np.max(t2targ[t_ix]), 2))
        trial_arr[str(t)][:,:] = np.nan
        cix = 0

        for i, ti in enumerate(t_ix):
            if rew_ix[ti]<= R.decoded_pos[input_type].shape[0]:
                cix = go_ix[ti].copy()
                trial_arr[str(t)][i,:t2targ[ti], :] = R.decoded_pos[input_type][go_ix[ti]:rew_ix[ti],[0,2]]
                ax.plot(trial_arr[str(t)][i,:,0], trial_arr[str(t)][i,:,1], '-', color='lightgrey')


        #Plot Mean, sem: 
        mean = np.nanmean(trial_arr[str(t)], axis=0)
        sem = np.nanstd(trial_arr[str(t)], axis=0)/np.sqrt(trial_arr[str(t)].shape[0])

        ax.plot(mean[:,0], mean[:,1], '-',color=cmap_list[it])
        ax.plot(mean[:,0] - sem[:,0], mean[:,1] - sem[:,1], '--',color=cmap_list[it])
        ax.plot(mean[:,0] + sem[:,0], mean[:,1] + sem[:,1], '--',color=cmap_list[it])

        tmp_circ = plt.Circle(targ_pos[t_ix[0],[0,2]], R.target_rad, color=cmap_list[it], alpha=.5)
        ax.add_artist(tmp_circ)
        ax.set_title(input_type_root + extra_title_text)
        ax.set_xlim([-14, 14])
        ax.set_ylim([-12, 12])
    return ax