# Example 1
def fit_FA_w_init(factors=5, ntrials=100, hdf=None):
    '''
    Summary: Fit one FactorAnalysis model to the binned spikes of the
        rewarded trials of an hdf file.

    Input param: factors: number of FA components
    Input param: ntrials: 'max' to use every extracted bin, otherwise only
        bins whose trial number is <= ntrials
    Input param: hdf: task hdf file passed to the pa.* helpers
    Output param: (FA model, mu from pa.zscore_spks of the used bins,
        trial_ix, hdf_ix, all binned spikes)
    '''
    rewarded = pa.get_trials_per_min(hdf)
    spk_bins, _targ_pos, _targ_ix, trial_num, _reach_time, hdf_ix = \
        pa.extract_trials_all(hdf, rewarded, hdf_ix=True)

    # Select which bins (rows of spk_bins) go into the fit
    if ntrials == 'max':
        keep = np.arange(len(trial_num))
    else:
        keep = np.nonzero(trial_num <= ntrials)[0]

    # Only the mean is kept from the z-scoring; the model is fit on raw counts
    _zscored, mu = pa.zscore_spks(spk_bins[keep, :])
    model = skdecomp.FactorAnalysis(n_components=factors)
    model.fit(spk_bins[keep, :])
    return model, mu, trial_num, hdf_ix, spk_bins
# Example 2
def fit_FA_by_targ(factors=5, ntrials='max', hdf=None):
    rew_ix = pa.get_trials_per_min(hdf)
    bin_spk, targ_pos, targ_ix, trial_ix, reach_time, hdf_ix = pa.extract_trials_all(hdf, 
        rew_ix, hdf_ix=True)

    if ntrials=='max':
        ix_cutoff = len(trial_ix)
    else:
        ix_cutoff = len(np.nonzero(trial_ix<=ntrials)[0])

    FA_dict = dict()
    mu_dict = dict()
    for tg in np.unique(targ_ix):
        tg_ix = np.nonzero(targ_ix==tg)[0]
        tg_ix = tg_ix[tg_ix<ix_cutoff]
        FA = skdecomp.FactorAnalysis(n_components=factors)
        zsc, mu = pa.zscore_spks(bin_spk[tg_ix,:])
        FA.fit(bin_spk[tg_ix,:])

        FA_dict[str(tg)] = FA
        mu_dict[str(tg)] = mu
        print 'done with target: ', str(tg)

    return FA_dict, mu_dict, trial_ix, hdf_ix, bin_spk, targ_ix
# Example 3
def process_targets(te_list, new_hdf_name):
    '''
    Summary: Yields an array of trial information in 
        pytables with columns seen above in Trial_Metrics table class

    Input param: hdf: hdf for either fa_bmi or bmi_resetting
    Output param: trial dictionary
    '''
    h5file = tables.openFile(new_hdf_name, mode="w", title='FA Trial Analysis')
    trial_mets_table = h5file.createTable("/", 'trial_metrics', Trial_Metrics, "Trial Mets")
    meta_table = h5file.createTable("/", 'meta_metrics', Meta_Metrics, "Meta Mets")

    row_cnt = -1

    for te in te_list:

        te_obj = dbfn.TaskEntry(te)
        task_entry = te

        hdf = te_obj.hdf

        #Extract trials ignoring assist: 
        rew_ix, rew_per_min = pa.get_trials_per_min(hdf, nmin=2, rew_per_min_cutoff=0, 
            ignore_assist=True, return_rpm=True)

        #Go backwards 3 steps (rew --> targ trans --> hold --> target (onset))
        go_ix = np.array([hdf.root.task_msgs[it-3][1] for it, t in enumerate(hdf.root.task_msgs[:]) if 
            scipy.logical_and(t[0] == 'reward', t[1] in rew_ix)])

        assert len(rew_ix) == len(go_ix)#: raise Exception("Rew ix and Go ix are unequal")

        #Time to target is diff b/w target onset and reward time
        time2targs = (rew_ix - go_ix)/60.

        #Add buffer of 5 task steps (80 ms)
        target_locs = hdf.root.task[go_ix.astype(int) + 5]['target'][:, [0, 2]]

        try:
            obstacle_sz = hdf.root.task[go_ix.astype(int) + 5]['obstacle_size'][:, 0]
        except:
            obstacle_sz = np.zeros_like(go_ix) 

        #Targ ix -- should be 8
        targ_ixs = pa.get_target_ix(target_locs)
        assert len(np.unique(targ_ixs)) == 8#: raise Exception("Target Ix has more or less than 8 targets :/ ")

        #Input type: 
        try:
            input_types = hdf.root.task[go_ix.astype(int)+5]['fa_input']
        except:
            input_types = ['all']*len(go_ix)

        rew_time_2_trial = {}

        for ri, r_ix in enumerate(rew_ix):
            trl = trial_mets_table.row
            row_cnt += 1

            trl['trial_number'] = ri
            trl['task_entry'] = task_entry
            trl['rew_per_min'] = rew_per_min[ri]
            trl['target_index'] = targ_ixs[ri]
            trl['target_loc'] = target_locs[ri,:]
            trl['start_time'] = go_ix[ri]/60. # In seconds
            trl['input_type'] = input_types[ri]
            trl['time2targ'] = time2targs[ri]
            rew_time_2_trial[r_ix] = time2targs[ri]

            trl['obstacle_size'] = obstacle_sz[ri]


            path = hdf.root.task[int(go_ix[ri]):int(r_ix)]['cursor'][:, [0, 2]]
            path_length, path_error, avg_speed = path_metrics(path, target_locs[ri, :])

            trl['path_length'] = path_length
            trl['path_error'] = path_error
            trl['avg_speed'] = avg_speed

            trl['timeout_time'] = hdf.root.task.attrs.timeout_time
            trl['timeout_penalty_time'] = hdf.root.task.attrs.timeout_penalty_time

            trl.append()

        trial_mets_table.flush()

        #Meta metrics for hdf: 
        block_len = hdf.root.task.shape[0]
        wind_len = 5*60*60
        wind_step = 2.5*60*60

        meta_ix = np.arange(0, block_len - wind_len, wind_step)
        trial_ix = np.array([i for i in hdf.root.task_msgs[:] if 
            i['msg'] in ['reward','timeout_penalty','hold_penalty','obstacle_penalty'] ], dtype=hdf.root.task_msgs.dtype)

        for m in meta_ix:
            meta_trl = meta_table.row
            meta_trl['task_entry'] = te
            end_meta_ix = m + wind_len

            trial_ix_time = trial_ix[:,]
            msg_ix = np.nonzero(np.logical_and(trial_ix['time']<=end_meta_ix, trial_ix['time']>m))[0]

            targ = hdf.root.task[trial_ix[msg_ix]['time'].astype(int)]['target'][:,[0,2]]
            targ_ix = pa.get_target_ix(targ)

            #Only mark as correct if it's the first correct trial
            targ_ix_ix = [0]
            tg_prev = targ_ix[0]
            for ii, tg_ix in enumerate(targ_ix[1:]):
                if tg_prev != tg_ix:
                    targ_ix_ix.append(ii+1)
                    tg_prev = tg_ix

            targ_ix_mod = targ_ix[targ_ix_ix]
            msg_ix_mod = msg_ix[targ_ix_ix]

            targ_percent_success = np.zeros((8, )) - 1
            all_perc_succ_lte10sec = np.zeros((8, )) - 1

            #time2targs for rew_ix:
            rew_ix_meta = np.array([i['time'] for i in trial_ix[msg_ix] if i['msg']=='reward'])

            for t in range(8):
                t_ix = np.nonzero(targ_ix_mod==t)[0]
                msgs = trial_ix[msg_ix_mod[t_ix]]

                if len(msgs) > 0:
                    targ_percent_success[t] = len(np.nonzero(msgs['msg']=='reward')[0]) / float(len(msgs))
                    msgs_copy = msgs.copy()
                    for mc in msgs_copy:
                        if mc['msg'] == 'reward':
                            r_tm = mc['time']
                            if r_tm in rew_time_2_trial.keys():
                                if rew_time_2_trial[r_tm] > 10.:
                                    msgs_copy['msg'] = 'not_reward'
                                    print 'not reward: ', m, t, mc
                            else:
                                print 'asissted: ', m, t, mc
                                msgs_copy['msg'] = 'not_rewarded'
                    all_perc_succ_lte10sec[t] = len(np.nonzero(msgs_copy['msg'] == 'reward')[0]) / float(len(msgs_copy))
                    if all_perc_succ_lte10sec[t] < 0:
                        print 'error: ', m, t, mc
                else: print 'no msgs for this epoch: ', m
            meta_trl['targ_percent_success'] = targ_percent_success
            meta_trl['all_perc_succ_lte10sec'] = all_perc_succ_lte10sec

            all_msg = trial_ix[msg_ix_mod]
            meta_trl['all_percent_success'] = len(np.nonzero(all_msg['msg']=='reward')[0]) / float(len(all_msg))
            meta_trl.append()

        meta_table.flush()

        #How many 5 min segments in 2.5 min steps


    h5file.close()

    return new_hdf_name
def parse_task_entry_halves(te_num, hdf, decoder, epoch_1_end=10., epoch_2_end = 20.):
    '''
    Summary: Fit FA matrices on the first half of the rewarded trials, bin
        spike counts at each BMI update, and train three KF decoders
        (de-meaned, shared-space, scaled shared-space) on epoch 1.

    Input param: te_num: task entry number (unused; kept for callers)
    Input param: hdf: task hdf file
    Input param: decoder: decoder whose ssm, units and binlen are reused
    Input param: epoch_1_end: end of epoch 1; epoch_1_end*60*60 is compared
        against task iterations (i.e. minutes, assuming 60 iterations/s)
    Input param: epoch_2_end: end of epoch 2, same units
    Output param: (decoder_demn, decoder_shar, decoder_sc_shar, bin_spk_cnts2,
        epoch1_ix, epoch2_ix, update_bmi_ix, FA_dict)
    Raises: ValueError if no BMI update occurs after an epoch end
    '''
    drives_neurons_ix0 = 3

    #Get FA dict from the first half of the rewarded trials:
    rew_ix = pa.get_trials_per_min(hdf)
    half_rew_ix = len(rew_ix) // 2   # integer, so it is a valid slice bound
    bin_spk, targ_pos, targ_ix, trial_ix, reach_time, hdf_ix = pa.extract_trials_all(hdf,
        rew_ix[:half_rew_ix], hdf_ix=True, drives_neurons_ix0=drives_neurons_ix0)

    from tasks.factor_analysis_tasks import FactorBMIBase
    FA_dict = FactorBMIBase.generate_FA_matrices(None, bin_spk=bin_spk)

    #Get BMI update IX: task iterations at which the decoder state changed
    internal_state = hdf.root.task[:]['internal_decoder_state']
    update_bmi_ix = np.nonzero(np.diff(np.squeeze(internal_state[:, drives_neurons_ix0, 0])))[0]+1
    epoch1_ix = _first_update_after(update_bmi_ix, epoch_1_end)
    epoch2_ix = _first_update_after(update_bmi_ix, epoch_2_end)

    #Get spike counts and bin them (6 iterations ending at each update, inclusive):
    spike_counts = hdf.root.task[:]['spike_counts'][:,:,0]
    bin_spk_cnts = _bin_spike_counts(spike_counts, update_bmi_ix[:epoch1_ix])
    bin_spk_cnts2 = _bin_spike_counts(spike_counts, update_bmi_ix[:epoch2_ix])

    #Kinematics: cursor state at each update plus finite-difference velocity
    kin = hdf.root.task[update_bmi_ix[:epoch1_ix]]['cursor']
    velocity = np.diff(kin, axis=0) * 1./decoder.binlen
    velocity = np.vstack([np.zeros(kin.shape[1]), velocity])
    kin = np.hstack([kin, velocity])

    ssm = decoder.ssm
    units = decoder.units

    #De-meaned, Shared and Scaled Shared Decoders:
    T = bin_spk_cnts.shape[0]
    demean = bin_spk_cnts.T - np.tile(FA_dict['fa_mu'], [1, T])
    decoder_demn = train.train_KFDecoder_abstract(ssm, kin.T, demean, units, 0.1)
    decoder_demn.kin = kin.T
    decoder_demn.neur = demean
    decoder_demn.target = hdf.root.task[update_bmi_ix[:epoch1_ix]]['target']

    # NOTE(review): '*' is a matrix product only if FA_dict holds np.matrix
    # objects; with plain ndarrays it is elementwise -- confirm
    main_shar = FA_dict['fa_main_shared'] * demean
    main_sc_shar = np.multiply(main_shar, np.tile(FA_dict['fa_main_shared_sc'], [1, T]))

    decoder_shar = train.train_KFDecoder_abstract(ssm, kin.T, main_shar, units, 0.1)
    decoder_shar.kin = kin.T
    decoder_shar.neur = main_shar

    decoder_sc_shar = train.train_KFDecoder_abstract(ssm, kin.T, main_sc_shar, units, 0.1)
    decoder_sc_shar.kin = kin.T
    decoder_sc_shar.neur = main_sc_shar

    return decoder_demn, decoder_shar, decoder_sc_shar, bin_spk_cnts2, epoch1_ix, epoch2_ix, update_bmi_ix, FA_dict


def _first_update_after(update_bmi_ix, epoch_end):
    '''Index into update_bmi_ix of the first update later than iteration int(epoch_end*60*60).'''
    later = np.nonzero(update_bmi_ix > int(epoch_end*60*60))[0]
    if len(later) == 0:
        raise ValueError('No BMI update after epoch end %s' % epoch_end)
    return int(later[0])


def _bin_spike_counts(spike_counts, update_ix, n_back=5):
    '''Sum spike_counts rows (i - n_back) .. i inclusive (clipped at 0) for each i in update_ix.'''
    binned = np.zeros((len(update_ix), spike_counts.shape[1]))
    for ib, i_ix in enumerate(update_ix):
        # Clip at 0: a negative start would wrap around and give an empty slice
        binned[ib, :] = np.sum(spike_counts[max(i_ix - n_back, 0):i_ix + 1, :], axis=0)
    return binned

# Example 5
def plot_traj(R, plot_pos=1, plot_vel=1, plot_force=0, it_cutoff=20000, 
    min_it_cutoff=0, input_type='all'):
    
    rew_ix = pa.get_trials_per_min(R.hdf,nmin=2)
    go_ix = np.array([R.hdf.root.task_msgs[it-3][1] for it, t in enumerate(R.hdf.root.task_msgs[:]) if t[0] == 'reward'])
    
    #Make sure no assist is used 
    try:
        zero_assist_start = np.nonzero(R.hdf.root.task[:]['assist_level']==0)[0][0]
    except:
        print 'ignoring assist'
        zero_assist_start = 0
    keep_ix = scipy.logical_and(go_ix>np.max([min_it_cutoff, zero_assist_start+(60*60)]), go_ix<it_cutoff)
    go_ix = go_ix[keep_ix]

    rew_ix = rew_ix[keep_ix]

    targ_pos = R.hdf.root.task[go_ix.astype(int)+5]['target']
    targ_ix = pa.get_target_ix(targ_pos[:,[0, 2]])

    if plot_pos == 1:
        f, ax = plt.subplots(nrows=4, ncols=2)
        f2, ax2 = plt.subplots(nrows=4, ncols=2)
        f0, ax0 = plt.subplots()
    
    if plot_vel == 1:
        f3, ax3 = plt.subplots(nrows=4, ncols=2)
        f4, ax4 = plt.subplots(nrows=4, ncols=2)        

    if plot_force ==1:
        f5, ax5 = plt.subplots(nrows=4, ncols=2)
        f6, ax6 = plt.subplots(nrows=4, ncols=2)
    #R = pickle.load(open(pkl_name))

    for i, (g,r) in enumerate(zip(go_ix, rew_ix)):
        #Choose axis: 
        targ = targ_ix[i]
        
        if plot_pos == 1:
            axi = ax[targ%4, targ/4]
            axi2 = ax2[targ%4, targ/4]

            axi.plot(R.cursor_pos[g:r, 0], 'k-')
            axi2.plot(R.cursor_pos[g:r, 2], 'k-')

            axi.plot(R.decoded_pos[input_type][g:r, 0], 'b-')
            axi2.plot(R.decoded_pos[input_type][g:r, 2], 'b-')
            
            ax0.plot(R.cursor_pos[g:r,0], R.cursor_pos[g:r, 2], '-', color='k')
            ax0.plot(R.decoded_pos[input_type][g:r,0], R.decoded_pos[input_type][g:r, 2], '--', color=cmap_list[int(targ)])


            if i==0:
                axi.set_title('Position X')
                axi2.set_title('Position Z')

        if plot_vel == 1:
            axi3 = ax3[targ%4, targ/4]
            axi4 = ax4[targ%4, targ/4]
            axi3.plot(R.cursor_vel[g:r, 0], 'k-')
            axi4.plot(R.cursor_vel[g:r, 2], 'k-')

            axi3.plot(R.decoded_vel[input_type][g:r, 0], 'b-')
            axi4.plot(R.decoded_vel[input_type][g:r, 2], 'b-')
            if i==0:
                axi3.set_title('Vel X')
                axi4.set_title('Vel Z')

        if plot_force ==1:
            axi5 = ax5[targ%4, targ/4]
            axi6 = ax6[targ%4, targ/4]
            axi5.plot(R.dec_state_mn[input_type][g-1:r-1, 9], 'b-')
            axi6.plot(R.dec_state_mn[input_type][g-1:r-1, 11], 'b-')

            axi5.plot(hdf.root.task[g:r]['internal_decoder_state'][:,9], 'k-')
            axi6.plot(hdf.root.task[g:r]['internal_decoder_state'][:,11], 'k-')
            if i==0:
                axi5.set_title('Acc X')
                axi6.set_title('Acc Z')

    plt.tight_layout()
# Example 6
def traj_sem_plot(R, input_type_root, fa_by_targ=False, it_start=0, it_cutoff=20000, ax=None, 
    rm_assist=True, extra_title_text=''):
    rew_ix = pa.get_trials_per_min(R.hdf,nmin=2)
    go_ix = np.array([R.hdf.root.task_msgs[it-3][1] for it, t in enumerate(R.hdf.root.task_msgs[:]) if t[0] == 'reward'])

    if rm_assist:
        zero_assist_start = np.nonzero(R.hdf.root.task[:]['assist_level']==0)[0]
        if len(zero_assist_start) > 0:
            zero_assist_start = zero_assist_start[0]
        else:
            Exception('Assist is never == 0!')
    else:
        zero_assist_start = 0

    start_ix = np.max([zero_assist_start, it_start])
    keep_ix = scipy.logical_and(go_ix>start_ix, go_ix<it_cutoff)
    go_ix = go_ix[keep_ix]
    rew_ix = rew_ix[keep_ix]

    t2targ = rew_ix - go_ix

    targ_pos = R.hdf.root.task[go_ix.astype(int)+5]['target']
    targ_ix = pa.get_target_ix(targ_pos[:,[0, 2]])

    #For each target, make array of trials x time (min length for reward trial)
    trial_arr = dict()

    if ax is None:
        f, ax = plt.subplots()
    
    #For each target, concat trials: 
    for it, t in enumerate(np.unique(targ_ix)):
        if fa_by_targ:
            input_type = input_type_root + '_'+str(t)
        else:
            input_type = input_type_root
        print input_type, R.decoded_pos[input_type].shape

        t_ix = np.nonzero(targ_ix==t)[0]
        trial_arr[str(t)] = np.zeros((len(t_ix), np.max(t2targ[t_ix]), 2))
        trial_arr[str(t)][:,:] = np.nan
        cix = 0

        for i, ti in enumerate(t_ix):
            if rew_ix[ti]<= R.decoded_pos[input_type].shape[0]:
                cix = go_ix[ti].copy()
                trial_arr[str(t)][i,:t2targ[ti], :] = R.decoded_pos[input_type][go_ix[ti]:rew_ix[ti],[0,2]]
                ax.plot(trial_arr[str(t)][i,:,0], trial_arr[str(t)][i,:,1], '-', color='lightgrey')


        #Plot Mean, sem: 
        mean = np.nanmean(trial_arr[str(t)], axis=0)
        sem = np.nanstd(trial_arr[str(t)], axis=0)/np.sqrt(trial_arr[str(t)].shape[0])

        ax.plot(mean[:,0], mean[:,1], '-',color=cmap_list[it])
        ax.plot(mean[:,0] - sem[:,0], mean[:,1] - sem[:,1], '--',color=cmap_list[it])
        ax.plot(mean[:,0] + sem[:,0], mean[:,1] + sem[:,1], '--',color=cmap_list[it])

        tmp_circ = plt.Circle(targ_pos[t_ix[0],[0,2]], R.target_rad, color=cmap_list[it], alpha=.5)
        ax.add_artist(tmp_circ)
        ax.set_title(input_type_root + extra_title_text)
        ax.set_xlim([-14, 14])
        ax.set_ylim([-12, 12])
    return ax
# Example 7
def targ_vs_all_subspace_align(te_list, file_name=None, cycle_FAs=None, epoch_size=50,
    compare_to_orig_te = False):
    '''
    Summary: function that takes a task entry list, parses each task entry into epochs, and computes
        a factor analysis model for each epoch using (1:cycle_FAs = number of factors), and then either
        saves or returns the output and a modified task list

    Input param: te_list: list of task entries (not tiered)
    Input param: file_name: name to save overlap data to afterwards, if None, then returns args
    Input param: cycle_FAs: number of factors to cycle through (1:cycle_FAs) -- uses optimal number if None
    Input param: epoch_size: size of epochs to parse task entries into (uses ONLY reward trials)
    Input param: compare_to_orig_te: if True, then also compares all of the compute FAs to the original 
        FA model in the hdf file of the task entry listed in 'compare to original te'

    Output param: 
    '''

    fa_dict = {}
    te_list = np.array(te_list)
    te_mod_list = []

    #assume no more than 10 FA epochs per task entry
    mod_increment = 0.01
    increment_offs = 0.1
    #For each TE get the right FA model: 
    for te in te_list:

        t = dbfn.TaskEntry(te)
        hdf = t.hdf

        #Now get the right FA model: 
        drives_neurons_ix0 = 3
        rew_ix_total = pa.get_trials_per_min(hdf)

        #Epoch rew_ix into epochs of 'epoch_size' 
        internal_state = hdf.root.task[:]['internal_decoder_state']
        update_bmi_ix = np.nonzero(np.diff(np.squeeze(internal_state[:, drives_neurons_ix0, 0])))[0]+1
        more_epochs = 1
        cnt = 0

        while more_epochs:

            #If not enough trials, still use TE, or if not epoch size specified
            if (epoch_size is None) or (len(rew_ix_total) < epoch_size):
                bin_spk, targ_pos, targ_ix, z, zz = pa.extract_trials_all(hdf, rew_ix_total, update_bmi_ix=update_bmi_ix)
                more_epochs = 0
                te_mod = float(te)

            else:
                if (cnt+1)*epoch_size <= len(rew_ix_total):
                    rew_ix = rew_ix_total[cnt*epoch_size:(cnt+1)*epoch_size]
                    if len(rew_ix) != epoch_size:
                        print moose
                    #Only use rew indices for that epoch:
                    bin_spk, targ_pos, targ_ix, z, zz = pa.extract_trials_all(hdf, rew_ix, time_cutoff=1000, update_bmi_ix=update_bmi_ix)
                    te_mod = float(te) + (mod_increment*cnt) + increment_offs
                    cnt += 1

                    #Be done if next epoch won't have enought trials
                    if (cnt+1)*epoch_size > len(rew_ix_total):
                        print (cnt+1)*epoch_size, len(rew_ix_total)
                        more_epochs = 0
    
            #Add modified te to dictionary
            te_mod_list.append(te_mod)

            #Use the bin_spk from above
            zscore_X, mu = pa.zscore_spks(bin_spk)
            n_neurons = zscore_X.shape[1]
            #If no assigned number of factors to cycle through, find optimal
            if cycle_FAs is None:
                log_lik, ax = pa.find_k_FA(zscore_X, iters=10, max_k = n_neurons, plot=False)

                mn_log_like = np.mean(log_lik, axis=0)
                ix = np.argmax(mn_log_like)
                num_factors = ix + 1

                FA_full = skdecomp.FactorAnalysis(n_components = num_factors)
                FA_full.fit(zscore_X)
                fa_dict[te, 0] = FA_full
                fa_dict['units', te] = t.decoder.units

            else:
                for i in np.arange(1, cycle_FAs+1):
                    FA = skdecomp.FactorAnalysis(n_components = i)
                    FA.fit(zscore_X)
                    fa_dict[te_mod, i] = FA
    
    #Now FA dict is completed:
    print 'now computing overlaps', te_mod_list
    
    if file_name is None:
        d, te_mod = ss_overlap(cycle_FAs, te_mod_list, fa_dict, file_name=file_name, 
            compare_to_orig_te=compare_to_orig_te)
        return d, te_mod

    else:
        ss_overlap(cycle_FAs, te_mod_list, fa_dict, file_name=file_name, 
            compare_to_orig_te=compare_to_orig_te)
# Example 8
import numpy as np
import dbfunctions as dbfn
import prelim_analysis as pa
import sklearn.decomposition as skdecomp
from online_analysis import spk_from_vfb as sfv

#Task entry ids used below
vfb_te = 3997
bmi_cal_te = 3999
bmi_fa_te = 4000

#Numbers of FA factors to try and repetitions per setting
n_factors = [2, 3, 5, 7]
n_reps = 5

#Test data: binned spikes of the rewarded trials of the FA BMI task entry
test_data = dbfn.TaskEntry(bmi_fa_te)
rew_ix = pa.get_trials_per_min(test_data.hdf)
test_bin_spk, targ_pos, targ_ix, z, zz = pa.extract_trials_all(test_data.hdf, rew_ix)

#Training paradigms:
train_data = dbfn.TaskEntry(bmi_cal_te)
train_rew_ix = pa.get_trials_per_min(train_data.hdf)

#VFB spikes, extracted for the units of the calibration task entry's decoder
vfb = dbfn.TaskEntry(vfb_te)
rew_ix = pa.get_trials_per_min(vfb.hdf)
vfb_bin_spk = sfv.get_vfb_spk(vfb, rew_ix, train_data.decoder.units)

#Training settings: 'vfb' followed by 8, 16, ..., 56
#(list() so the concatenation also works where range() is lazy)
train_param = ['vfb'] + list(range(8, 64, 8))

#One value per (training setting, number of factors, repetition)
LL = np.zeros((len(train_param), len(n_factors), n_reps))

for it, tr in enumerate(train_param):
    if tr == 'vfb':