Esempio n. 1
0
def fit_FA_w_init(factors=5, ntrials=100, hdf=None):
    '''
    Summary: fit one factor analysis model to the binned spikes of the
        reward trials of a task entry, optionally restricted to the first trials.

    Input param: factors: number of FA components (n_components)
    Input param: ntrials: 'max' to use every extracted bin, otherwise only bins
        whose trial index is <= ntrials are used
    Input param: hdf: object passed through to pa.get_trials_per_min and
        pa.extract_trials_all (presumably an open task hdf file)

    Output param: (FA, mu, trial_ix, hdf_ix, bin_spk) -- the fitted model, the
        mean returned by pa.zscore_spks for the selected bins, and the trial
        index / hdf index / binned spikes from pa.extract_trials_all
    '''
    reward_ix = pa.get_trials_per_min(hdf)
    extracted = pa.extract_trials_all(hdf, reward_ix, hdf_ix=True)
    bin_spk, _targ_pos, _targ_ix, trial_ix, _reach_time, hdf_ix = extracted

    use_all = (ntrials == 'max')
    if use_all:
        sel = np.arange(len(trial_ix))
    else:
        sel = np.nonzero(trial_ix <= ntrials)[0]

    fa_model = skdecomp.FactorAnalysis(n_components=factors)
    # Only the mean from the z-scoring is kept; the model itself is fit on the
    # raw (not z-scored) selected counts.
    _, mu = pa.zscore_spks(bin_spk[sel, :])
    fa_model.fit(bin_spk[sel, :])
    return fa_model, mu, trial_ix, hdf_ix, bin_spk
Esempio n. 2
0
def fit_FA_by_targ(factors=5, ntrials='max', hdf=None):
    '''
    Summary: fit a separate factor analysis model for every target, using the
        binned spikes of the reward trials of a task entry.

    Input param: factors: number of FA components per model
    Input param: ntrials: 'max' to use every bin, otherwise the cutoff is the
        number of bins whose trial index is <= ntrials; bins at positions >= the
        cutoff are dropped (assumes trial_ix is non-decreasing -- TODO confirm)
    Input param: hdf: object passed through to the pa helpers (presumably an
        open task hdf file)

    Output param: (FA_dict, mu_dict, trial_ix, hdf_ix, bin_spk, targ_ix) where
        FA_dict / mu_dict are keyed by str(target) and hold the fitted model and
        the mean returned by pa.zscore_spks for that target's bins
    '''
    rew_ix = pa.get_trials_per_min(hdf)
    bin_spk, targ_pos, targ_ix, trial_ix, reach_time, hdf_ix = pa.extract_trials_all(hdf,
        rew_ix, hdf_ix=True)

    if ntrials == 'max':
        n_keep = len(trial_ix)
    else:
        n_keep = int(np.count_nonzero(trial_ix <= ntrials))

    FA_dict = {}
    mu_dict = {}
    for target in np.unique(targ_ix):
        key = str(target)
        rows = np.nonzero(targ_ix == target)[0]
        rows = rows[rows < n_keep]

        model = skdecomp.FactorAnalysis(n_components=factors)
        # Only the mean is kept; the model is fit on the raw counts.
        _, mu = pa.zscore_spks(bin_spk[rows, :])
        model.fit(bin_spk[rows, :])

        FA_dict[key] = model
        mu_dict[key] = mu
        # Single-argument print: same output under Python 2 and Python 3.
        print('%s %s' % ('done with target: ', key))

    return FA_dict, mu_dict, trial_ix, hdf_ix, bin_spk, targ_ix
    # NOTE(review): everything from here to the end of this function is
    # unreachable -- it follows the `return` statement above. It also reads
    # names that are neither parameters nor locals of fit_FA_by_targ (`te`,
    # `get_l_or_r`), and it initialises targ_ix_all, out_all, LR, trl_cnt and
    # builds lr without ever using or returning them. It looks like the body of
    # a separate routine whose `def` line (and ending) is missing.
    # TODO: restore it as its own function or delete it.
    targ_ix_all = []
    out_all = []
    LR = []
    trl_cnt = 0

    for t in te:
        task_entry = dbfn.TaskEntry(t)
        hdf = task_entry.hdf

        drives_neurons_ix0 = 3
        internal_state = hdf.root.task[:]['internal_decoder_state']
        # Rows where internal_decoder_state[:, 3, 0] changes value (+1 so the
        # index points at the row after the change).
        update_bmi_ix = np.nonzero(np.diff(np.squeeze(internal_state[:, drives_neurons_ix0, 0])))[0]+1

        step_dict = dict(reward=3, hold_penalty=2, timeout_penalty=1, obstacle_penalty=1)

        bin_spk, targ_pos, targ_ix, trial_ix, reach_time, hdf_ix = pa.extract_trials_all(hdf, None, 
            update_bmi_ix=update_bmi_ix, hdf_ix=True, rew_pls=True, step_dict=step_dict, time_cutoff=100000)

        # Task messages whose 'msg' is one of the step_dict keys.
        outcome_msg= np.array([hdf.root.task_msgs[it]['msg'] for it, m in enumerate(hdf.root.task_msgs[:]) 
            if m['msg'] in step_dict.keys()])

        # Sanity check: expect exactly one outcome message per extracted trial.
        if len(np.unique(trial_ix))!= len(outcome_msg):
            raise Exception

        cursor = hdf.root.task[hdf_ix]['cursor']

        # Per trial: call get_l_or_r on that trial's cursor rows and the target
        # position of the trial's first bin.
        lr = []
        for tr in np.unique(trial_ix):
            tr_ix = np.nonzero(trial_ix==tr)[0]
            lr.append(get_l_or_r(cursor[tr_ix,:], targ_pos[tr_ix[0], :]))

def _sum_spike_windows(spike_counts, bin_ends, window=6):
    '''
    Summary: sum spike counts over the `window` rows ending at (and including)
        each index in `bin_ends`.

    Input param: spike_counts: 2-D array (rows x units)
    Input param: bin_ends: row index of the last sample of each bin
    Input param: window: rows summed per bin (6 reproduces the slice
        i_ix-5:i_ix+1 used by parse_task_entry_halves)

    Output param: array of shape (len(bin_ends), spike_counts.shape[1])
    '''
    binned = np.zeros((len(bin_ends), spike_counts.shape[1]))
    for ib, end_ix in enumerate(bin_ends):
        end_ix = int(end_ix)
        # Clamp at 0: a negative start would wrap around from the end of the
        # array, typically giving an empty (all-zero) window.
        start_ix = max(end_ix - (window - 1), 0)
        #Inclusive of EndBin
        binned[ib, :] = np.sum(spike_counts[start_ix:end_ix+1, :], axis=0)
    return binned


def parse_task_entry_halves(te_num, hdf, decoder, epoch_1_end=10., epoch_2_end=20.):
    '''
    Summary: build an FA model from the first half of the reward trials, bin
        spike counts at every change of the decoder state, and train demeaned,
        shared and scaled-shared KF decoders on the first epoch.

    Input param: te_num: task entry number (not used in this function)
    Input param: hdf: task hdf object (reads task['internal_decoder_state'],
        task['spike_counts'], task['cursor'], task['target'])
    Input param: decoder: decoder whose binlen, ssm and units are reused
    Input param: epoch_1_end: end of epoch 1; compared against task rows via
        epoch_1_end*60*60 (presumably minutes at 60 rows/s -- TODO confirm)
    Input param: epoch_2_end: end of epoch 2, same units as epoch_1_end

    Output param: (decoder_demn, decoder_shar, decoder_sc_shar, bin_spk_cnts2,
        epoch1_ix, epoch2_ix, update_bmi_ix, FA_dict)

    NOTE(review): raises IndexError if no update row lies beyond an epoch end.
    '''
    drives_neurons_ix0 = 3

    #Get FA dict from the first half of the reward trials.
    rew_ix = pa.get_trials_per_min(hdf)
    # Slice bounds must be integers (np.floor returns a float, which numpy and
    # Python sequences reject as an index).
    half_rew_ix = int(len(rew_ix) // 2)
    bin_spk, targ_pos, targ_ix, trial_ix, reach_time, hdf_ix = pa.extract_trials_all(hdf,
        rew_ix[:half_rew_ix], hdf_ix=True, drives_neurons_ix0=drives_neurons_ix0)

    from tasks.factor_analysis_tasks import FactorBMIBase
    FA_dict = FactorBMIBase.generate_FA_matrices(None, bin_spk=bin_spk)

    #Get BMI update IX: rows where internal_decoder_state[:, 3, 0] changes
    # (+1 so the index points at the row after the change).
    internal_state = hdf.root.task[:]['internal_decoder_state']
    update_bmi_ix = np.nonzero(np.diff(np.squeeze(internal_state[:, drives_neurons_ix0, 0])))[0]+1
    # Number of update rows at or before each epoch end.
    epoch1_ix = int(np.nonzero(update_bmi_ix > int(epoch_1_end*60*60))[0][0])
    epoch2_ix = int(np.nonzero(update_bmi_ix > int(epoch_2_end*60*60))[0][0])

    #Get spike counts and bin them:
    spike_counts = hdf.root.task[:]['spike_counts'][:,:,0]
    bin_spk_cnts = _sum_spike_windows(spike_counts, update_bmi_ix[:epoch1_ix])
    bin_spk_cnts2 = _sum_spike_windows(spike_counts, update_bmi_ix[:epoch2_ix])

    # Cursor position at each epoch-1 update plus finite-difference velocity
    # (first velocity row is zero-padded).
    kin = hdf.root.task[update_bmi_ix[:epoch1_ix]]['cursor']
    binlen = decoder.binlen
    velocity = np.diff(kin, axis=0) * 1./binlen
    velocity = np.vstack([np.zeros(kin.shape[1]), velocity])
    kin = np.hstack([kin, velocity])

    ssm = decoder.ssm
    units = decoder.units

    #Demeaned decoder:
    T = bin_spk_cnts.shape[0]
    demean = bin_spk_cnts.T - np.tile(FA_dict['fa_mu'], [1, T])
    decoder_demn = train.train_KFDecoder_abstract(ssm, kin.T, demean, units, 0.1)
    decoder_demn.kin = kin.T
    decoder_demn.neur = demean
    decoder_demn.target = hdf.root.task[update_bmi_ix[:epoch1_ix]]['target']

    #Shared and Scaled Shared Decoders:
    # NOTE(review): `*` is a matrix product only if FA_dict['fa_main_shared']
    # is an np.matrix (elementwise for plain ndarrays) -- confirm.
    main_shar = FA_dict['fa_main_shared'] * demean
    main_sc_shar = np.multiply(main_shar, np.tile(FA_dict['fa_main_shared_sc'], [1, T]))

    decoder_shar = train.train_KFDecoder_abstract(ssm, kin.T, main_shar, units, 0.1)
    decoder_shar.kin = kin.T
    decoder_shar.neur = main_shar

    decoder_sc_shar = train.train_KFDecoder_abstract(ssm, kin.T, main_sc_shar, units, 0.1)
    decoder_sc_shar.kin = kin.T
    decoder_sc_shar.neur = main_sc_shar

    # The demeaned decoder is returned first; this slot previously named the
    # undefined `decoder_full`, which raised NameError.
    return decoder_demn, decoder_shar, decoder_sc_shar, bin_spk_cnts2, epoch1_ix, epoch2_ix, update_bmi_ix, FA_dict
Esempio n. 5
0
def targ_vs_all_subspace_align(te_list, file_name=None, cycle_FAs=None, epoch_size=50,
    compare_to_orig_te=False):
    '''
    Summary: function that takes a task entry list, parses each task entry into epochs, and computes
        a factor analysis model for each epoch using (1:cycle_FAs = number of factors), and then either
        saves or returns the output and a modified task list

    Input param: te_list: list of task entries (not tiered)
    Input param: file_name: name to save overlap data to afterwards, if None, then returns args
    Input param: cycle_FAs: number of factors to cycle through (1:cycle_FAs) -- uses optimal number if None
    Input param: epoch_size: size of epochs to parse task entries into (uses ONLY reward trials).
        If None, or if a task entry has fewer reward trials than epoch_size, the whole
        task entry is used as a single epoch.
    Input param: compare_to_orig_te: if True, then also compares all of the computed FAs to the original
        FA model in the hdf file of the task entry listed in 'compare to original te'

    Each epoch gets a modified task-entry id te_mod: float(te) for a single-epoch entry,
    otherwise te + 0.1 + 0.01*k for epoch k. FA models are stored in fa_dict under
    (te_mod, 0) when the number of factors is chosen automatically, or (te_mod, i)
    for i in 1..cycle_FAs.

    Output param: if file_name is None, the (d, te_mod) pair returned by ss_overlap;
        otherwise None (ss_overlap is called with file_name and its result is discarded)

    Raises: ValueError if epoch_size is not None and < 1 (such values previously
        looped forever or crashed).
    '''
    if epoch_size is not None and epoch_size < 1:
        raise ValueError('epoch_size must be a positive integer or None')

    fa_dict = {}
    te_list = np.array(te_list)
    te_mod_list = []

    #assume no more than 10 FA epochs per task entry
    mod_increment = 0.01
    increment_offs = 0.1
    #For each TE get the right FA model:
    for te in te_list:

        t = dbfn.TaskEntry(te)
        hdf = t.hdf

        #Now get the right FA model:
        drives_neurons_ix0 = 3
        rew_ix_total = pa.get_trials_per_min(hdf)

        #Epoch rew_ix into epochs of 'epoch_size'
        internal_state = hdf.root.task[:]['internal_decoder_state']
        update_bmi_ix = np.nonzero(np.diff(np.squeeze(internal_state[:, drives_neurons_ix0, 0])))[0]+1
        more_epochs = 1
        cnt = 0

        while more_epochs:

            #If not enough trials, still use TE, or if not epoch size specified
            if (epoch_size is None) or (len(rew_ix_total) < epoch_size):
                bin_spk, targ_pos, targ_ix, z, zz = pa.extract_trials_all(hdf, rew_ix_total, update_bmi_ix=update_bmi_ix)
                more_epochs = 0
                te_mod = float(te)

            else:
                if (cnt+1)*epoch_size <= len(rew_ix_total):
                    rew_ix = rew_ix_total[cnt*epoch_size:(cnt+1)*epoch_size]
                    if len(rew_ix) != epoch_size:
                        # Should be impossible given the bounds check above;
                        # previously a deliberate NameError (`print moose`).
                        raise RuntimeError('epoch %d of te %s has %d reward trials, expected %d'
                            % (cnt, te, len(rew_ix), epoch_size))
                    #Only use rew indices for that epoch:
                    bin_spk, targ_pos, targ_ix, z, zz = pa.extract_trials_all(hdf, rew_ix, time_cutoff=1000, update_bmi_ix=update_bmi_ix)
                    te_mod = float(te) + (mod_increment*cnt) + increment_offs
                    cnt += 1

                    #Be done if next epoch won't have enough trials
                    if (cnt+1)*epoch_size > len(rew_ix_total):
                        print('%s %s' % ((cnt+1)*epoch_size, len(rew_ix_total)))
                        more_epochs = 0

            #Add modified te to dictionary
            te_mod_list.append(te_mod)

            #Use the bin_spk from above
            zscore_X, mu = pa.zscore_spks(bin_spk)
            n_neurons = zscore_X.shape[1]
            #If no assigned number of factors to cycle through, find optimal
            if cycle_FAs is None:
                log_lik, ax = pa.find_k_FA(zscore_X, iters=10, max_k = n_neurons, plot=False)

                mn_log_like = np.mean(log_lik, axis=0)
                ix = np.argmax(mn_log_like)
                num_factors = ix + 1

                FA_full = skdecomp.FactorAnalysis(n_components = num_factors)
                FA_full.fit(zscore_X)
                # Key by te_mod (as the cycle_FAs branch does) so epochs of the
                # same task entry no longer overwrite each other. For a single-epoch
                # entry with an integer id, te_mod == float(te) is the same dict key
                # as te.
                fa_dict[te_mod, 0] = FA_full
                fa_dict['units', te] = t.decoder.units

            else:
                for i in np.arange(1, cycle_FAs+1):
                    FA = skdecomp.FactorAnalysis(n_components = i)
                    FA.fit(zscore_X)
                    fa_dict[te_mod, i] = FA

    #Now FA dict is completed:
    print('%s %s' % ('now computing overlaps', te_mod_list))

    if file_name is None:
        d, te_mod = ss_overlap(cycle_FAs, te_mod_list, fa_dict, file_name=file_name,
            compare_to_orig_te=compare_to_orig_te)
        return d, te_mod

    else:
        ss_overlap(cycle_FAs, te_mod_list, fa_dict, file_name=file_name,
            compare_to_orig_te=compare_to_orig_te)
# Task-entry ids used by the script below: vfb_te feeds sfv.get_vfb_spk,
# bmi_cal_te supplies the training trials and decoder units, and bmi_fa_te
# supplies the test trials.
vfb_te, bmi_cal_te, bmi_fa_te = 3997, 3999, 4000

# Factor counts to evaluate and repetitions per setting (these size the
# last two axes of LL below).
n_factors = [2, 3, 5, 7]
n_reps = 5

import dbfunctions as dbfn
import prelim_analysis as pa
import sklearn.decomposition as skdecomp
from online_analysis import spk_from_vfb as sfv

# np is used below (LL) but was never imported by this script's import block.
import numpy as np

#Test data: reward trials of the FA BMI task entry.
test_data = dbfn.TaskEntry(bmi_fa_te)
rew_ix = pa.get_trials_per_min(test_data.hdf)
test_bin_spk, targ_pos, targ_ix, z, zz = pa.extract_trials_all(test_data.hdf, rew_ix)

#Training paradigms:
train_data = dbfn.TaskEntry(bmi_cal_te)
train_rew_ix = pa.get_trials_per_min(train_data.hdf)

# VFB spikes restricted to the calibration decoder's units.
vfb = dbfn.TaskEntry(vfb_te)
rew_ix = pa.get_trials_per_min(vfb.hdf)
vfb_bin_spk = sfv.get_vfb_spk(vfb, rew_ix, train_data.decoder.units)


# 'vfb' followed by 8, 16, ..., 56. list() keeps this working on Python 3,
# where range() is not a list and cannot be concatenated to one.
train_param = ['vfb'] + list(range(8, 64, 8))

# LL[param, factor-count, repetition]
LL = np.zeros((len(train_param), len(n_factors), n_reps))

for it, tr in enumerate(train_param):
    if tr == 'vfb':