Example #1
def add_fa_dict_to_decoder(decoder_training_te, dec_ix, fa_te, return_dec=False):
    # First make sure we're training from the correct task entry: spike counts n_units == BMI units
    from db import dbfunctions as dbfn

    te = dbfn.TaskEntry(fa_te)
    hdf = te.hdf
    sc_n_units = hdf.root.task[0]["spike_counts"].shape[0]

    from db.tracker import models

    # Find the decoder trained from this task entry whose filename index matches dec_ix
    te_arr = models.Decoder.objects.filter(entry=decoder_training_te)
    decoder_old = None
    for rec in te_arr:
        ix = rec.path.find("_")
        if int(rec.path[ix + 1 : ix + 3]) == dec_ix:
            decoder_old = rec
            break

    if decoder_old is None:
        raise Exception("No decoder from " + str(decoder_training_te) + " matching index: " + str(dec_ix))

    from tasks.factor_analysis_tasks import FactorBMIBase

    FA_dict = FactorBMIBase.generate_FA_matrices(fa_te)

    import pickle

    # Attach the FA dictionary to the existing decoder
    dec = pickle.load(open(decoder_old.filename, "rb"))
    dec.trained_fa_dict = FA_dict
    dec_n_units = dec.n_units

    if dec_n_units != sc_n_units:
        raise Exception("Cant use TE for BMI training and FA training -- n_units mismatch")
    if return_dec:
        return dec
    else:
        from db import trainbmi

        trainbmi.save_new_decoder_from_existing(dec, decoder_old, suffix="_w_fa_dict_from_" + str(fa_te))
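
A minimal usage sketch, assuming two hypothetical task-entry IDs (3231 trained the decoder, 3240 supplies the FA training data) and decoder index 0; return_dec=True keeps the result in memory instead of saving a new decoder record:

# Hypothetical task-entry IDs, for illustration only.
dec = add_fa_dict_to_decoder(3231, 0, 3240, return_dec=True)
print "FA dict keys:", dec.trained_fa_dict.keys()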
Example #2
import os
import pickle
import datetime
import numpy as np
import tables


def main():
    # Expects these module-level globals to be defined elsewhere:
    # master_encoder_fnames, input_devices, shar_unt_arr, priv_unt_arr, priv_to_shar

    vfb_file = os.path.expandvars('$FA_GROM_DATA/sims/input_output_analysis/enc052016_0140vfb60_15_15_10_.hdf')
    vfb = tables.openFile(vfb_file)
    vfb_ix = np.arange(0, vfb.root.task.shape[0], 6)

    from riglib.bmi import state_space_models 
    ssm = state_space_models.StateSpaceEndptVel2D()
    A, _, W = ssm.get_ssm_matrices()
    mean = np.zeros(A.shape[0])

    mins = 5.
    n_samples = np.min([int(mins*60*10), len(vfb_ix)])
    print 'n samples: ', n_samples

    number_of_reps = 1
    dat_dict = {}

    for i_m, master_encoder in enumerate(master_encoder_fnames):
        encoder = pickle.load(open(master_encoder))
        units = encoder.get_units()
        n_units = len(units)

        for qri, input_device in enumerate(input_devices):
            for shar_unt in shar_unt_arr:
                for priv_unt in priv_unt_arr:
                    for p2s in priv_to_shar:
                        priv_inp = []
                        shar_inp = []
                        all_inp = []

                        for n in range(number_of_reps):
                            tun = 1. - priv_unt - shar_unt
                            w = np.array([priv_unt, p2s*tun, shar_unt, (1.-p2s)*tun])
                            encoder.wt_sources = w
                            print encoder.wt_sources

                            spike_counts = np.zeros([n_units, n_samples])
                            priv_counts_tun = np.zeros_like(spike_counts)
                            priv_counts_unt = np.zeros_like(spike_counts)
                            shar_counts_tun = np.zeros_like(spike_counts)
                            shar_counts_unt = np.zeros_like(spike_counts)
                            shar_counts = np.zeros_like(spike_counts)
                            priv_counts = np.zeros_like(spike_counts)
                            ctl = np.zeros((7, n_samples ))

                            encoder.call_ds_rate = 1
                            state_samples = np.random.multivariate_normal(mean, W, n_samples)

                            # Drive the encoder sample-by-sample with the feedback controller's output
                            for k in range(n_samples):
                                current_state = vfb.root.task[vfb_ix[k]]['internal_decoder_state']
                                target_state = vfb.root.task[vfb_ix[k]]['target']
                                targ = np.vstack((target_state[:, np.newaxis], np.zeros((4, 1)), ))
                                targ[-1, 0] = 1
                                ctrl = input_device.calc_next_state(current_state, targ)/3.
                                ctrl[[0, 1, 2, 4]] = 0
                                #ctrl = np.mat(state_samples[k, :]).T
                                z = np.array(encoder(ctrl, mode='counts'))
                                spike_counts[:, k] = np.squeeze(z)
                                priv_counts_tun[:, k] = (w[1]*encoder.priv_tun)
                                shar_counts_tun[:, k] = (w[3]*encoder.shar_tun)
                                priv_counts_unt[:, k] = (w[0]*encoder.priv_unt)
                                shar_counts_unt[:, k] = (w[2]*encoder.shar_unt)
                                
                                shar_counts[:, k] = (w[3]*encoder.shar_tun) + (w[2]*encoder.shar_unt)
                                priv_counts[:, k] = (w[1]*encoder.priv_tun) + (w[0]*encoder.priv_unt)
                                ctl[:, k] = np.squeeze(np.array(ctrl))
                            # SOT in: shared-over-total variance of the input spike counts
                            sot_in = np.sum(np.var(shar_counts, axis=1))/(np.sum(np.var(shar_counts, axis=1))+
                                np.sum(np.var(priv_counts, axis=1)))

                            #FA on spike counts: 
                            print spike_counts[0, 0]
                            from tasks.factor_analysis_tasks import FactorBMIBase
                            FA_dict = FactorBMIBase.generate_FA_matrices(None, bin_spk=spike_counts.T)

                            sot_out = np.trace(FA_dict['U']*FA_dict['U'].T)/(np.trace(FA_dict['U']*FA_dict['U'].T) + np.trace(FA_dict['Psi']))

                            

                            if n==0:
                                #dat_dict[p2s, priv_unt, shar_unt, 'nf'] = [num_factors]
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri, 'nf_orth'] = [FA_dict['fa_main_shar_n_dim']]
                                
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri, 'priv_inp_tun_var'] = [np.sum(np.var(priv_counts_tun, axis=1))]
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri,  'shar_inp_tun_var'] = [np.sum(np.var(shar_counts_tun, axis=1))]
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri, 'priv_inp_var'] = [np.sum(np.var(priv_counts, axis=1))]
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri,  'priv_inp_unt_var'] = [np.sum(np.var(priv_counts_unt, axis=1))]
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri,  'shar_inp_unt_var'] = [np.sum(np.var(shar_counts_unt, axis=1))]
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri, 'shar_inp_var'] = [np.sum(np.var(shar_counts, axis=1))]
                                
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri,  'priv_out_var'] = [np.trace(FA_dict['Psi'])]
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri,  'shar_out_var'] = [np.trace(FA_dict['U']*FA_dict['U'].T)]

                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri, 'sot_out'] = [sot_out]
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri,  'sot_in'] = [sot_in]
                                #dat_dict[p2s, priv_unt, shar_unt, 'priv_out_calc_var'] = [np.sum(np.var(priv_proj, axis=1))]
                                #dat_dict[p2s, priv_unt, shar_unt, 'shar_out_calc_var'] = [np.sum(np.var(shar_proj, axis=1))]

                            else:
                                #dat_dict[p2s, priv_unt, shar_unt, 'nf'].append(num_factors)
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri,  'nf_orth'].append(FA_dict['fa_main_shar_n_dim'])
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri,  'priv_inp_tun_var'].append(np.sum(np.var(priv_counts_tun, axis=1)))
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri,  'shar_inp_tun_var'].append(np.sum(np.var(shar_counts_tun, axis=1)))
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri,  'priv_inp_var'].append(np.sum(np.var(priv_counts, axis=1)))
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri,  'priv_inp_unt_var'].append(np.sum(np.var(priv_counts_unt, axis=1)))
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri,  'shar_inp_unt_var'].append(np.sum(np.var(shar_counts_unt, axis=1)))
                                
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri,  'shar_inp_var'].append(np.sum(np.var(shar_counts, axis=1)))

                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri,  'priv_out_var'].append(np.trace(FA_dict['Psi']))
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri,  'shar_out_var'].append(np.trace(FA_dict['U']*FA_dict['U'].T))
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri,  'sot_out'].append(sot_out)
                                dat_dict[p2s, priv_unt, shar_unt, i_m, qri,  'sot_in'].append(sot_in)
                                #dat_dict[p2s, priv_unt, shar_unt, 'priv_out_calc_var'].append(np.sum(np.var(priv_proj, axis=1)))
                                #dat_dict[p2s, priv_unt, shar_unt, 'shar_out_calc_var'].append(np.sum(np.var(shar_proj, axis=1)))

    ct = datetime.datetime.now()
    p_fname = os.path.expandvars('$FA_GROM_DATA/sims/FR_val/dat_dict_'+ct.strftime("%m%d%y_%H%M")+'.pkl')
    pickle.dump(dat_dict, open(p_fname,'wb'))
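
The four encoder weights always sum to one: the tuned mass 1 - priv_unt - shar_unt is split between private and shared input by p2s. A quick check of that arithmetic with illustrative sweep values:

import numpy as np

priv_unt, shar_unt, p2s = 0.2, 0.3, 0.4          # illustrative sweep values
tun = 1. - priv_unt - shar_unt                   # tuned mass left over
w = np.array([priv_unt, p2s*tun, shar_unt, (1.-p2s)*tun])
print w, w.sum()                                 # [ 0.2  0.2  0.3  0.3] 1.0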
Example #3
def parse_task_entry_halves(te_num, hdf, decoder, epoch_1_end=10., epoch_2_end=20.):
    # Assumes module-level imports: numpy as np, the trial-extraction helper
    # module `pa`, and the decoder-training module `train`.
    drives_neurons_ix0 = 3
    #Get FA dict from the first half of rewarded trials:
    rew_ix = pa.get_trials_per_min(hdf)
    half_rew_ix = int(np.floor(len(rew_ix)/2.))
    bin_spk, targ_pos, targ_ix, trial_ix, reach_time, hdf_ix = pa.extract_trials_all(hdf,
        rew_ix[:half_rew_ix], hdf_ix=True, drives_neurons_ix0=drives_neurons_ix0)

    from tasks.factor_analysis_tasks import FactorBMIBase
    FA_dict = FactorBMIBase.generate_FA_matrices(None, bin_spk=bin_spk)


    #Get BMI update IX: 
    internal_state = hdf.root.task[:]['internal_decoder_state']
    update_bmi_ix = np.nonzero(np.diff(np.squeeze(internal_state[:, drives_neurons_ix0, 0])))[0]+1
    epoch1_ix = int(np.nonzero(update_bmi_ix > int(epoch_1_end*60*60))[0][0])
    epoch2_ix = int(np.nonzero(update_bmi_ix > int(epoch_2_end*60*60))[0][0])

    #Get spike counts and bin them: 
    spike_counts = hdf.root.task[:]['spike_counts'][:,:,0]
    bin_spk_cnts = np.zeros((epoch1_ix, spike_counts.shape[1]))
    bin_spk_cnts2 = np.zeros((epoch2_ix, spike_counts.shape[1]))

    for ib, i_ix in enumerate(update_bmi_ix[:epoch1_ix]):
        #Inclusive of EndBin
        bin_spk_cnts[ib,:] = np.sum(spike_counts[i_ix-5:i_ix+1,:], axis=0)

    for ib, i_ix in enumerate(update_bmi_ix[:epoch2_ix]):
        #Inclusive of EndBin
        bin_spk_cnts2[ib,:] = np.sum(spike_counts[i_ix-5:i_ix+1,:], axis=0)

    kin = hdf.root.task[update_bmi_ix[:epoch1_ix]]['cursor']
    binlen = decoder.binlen
    velocity = np.diff(kin, axis=0) * 1./binlen
    velocity = np.vstack([np.zeros(kin.shape[1]), velocity])
    kin = np.hstack([kin, velocity])

    ssm = decoder.ssm
    units = decoder.units

    #Shared and Scaled Shared Decoders: 
    T = bin_spk_cnts.shape[0]
    demean = bin_spk_cnts.T - np.tile(FA_dict['fa_mu'], [1, T])
    decoder_demn = train.train_KFDecoder_abstract(ssm, kin.T, demean, units, 0.1)
    decoder_demn.kin = kin.T
    decoder_demn.neur = demean
    decoder_demn.target = hdf.root.task[update_bmi_ix[:epoch1_ix]]['target']

    main_shar = FA_dict['fa_main_shared'] * demean
    #main_priv = demean - main_shar
    main_sc_shar = np.multiply(main_shar, np.tile(FA_dict['fa_main_shared_sc'], [1, T]))
    #full_sc = np.multiply(demean, np.tile(FA_dict['fa_main_shared_sc'], [1,T]))
    #main_sc_shar_pls_priv = main_sc_shar + main_priv

    decoder_shar = train.train_KFDecoder_abstract(ssm, kin.T, main_shar, units, 0.1)
    decoder_shar.kin = kin.T
    decoder_shar.neur = main_shar

    decoder_sc_shar = train.train_KFDecoder_abstract(ssm, kin.T, main_sc_shar, units, 0.1)
    decoder_sc_shar.kin = kin.T
    decoder_sc_shar.neur = main_sc_shar
    decs_all = dict(dmn=decoder_demn, shar=decoder_shar, sc_shar=decoder_sc_shar)

    return decoder_demn, decoder_shar, decoder_sc_shar, bin_spk_cnts2, epoch1_ix, epoch2_ix, update_bmi_ix, FA_dict
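
A usage sketch with a hypothetical task-entry ID, assuming (as in Example #1) that dbfunctions.TaskEntry exposes the HDF file, and that it also exposes the trained decoder as te.decoder:

from db import dbfunctions as dbfn

te = dbfn.TaskEntry(3231)                        # hypothetical task-entry ID
out = parse_task_entry_halves(3231, te.hdf, te.decoder, epoch_1_end=10., epoch_2_end=20.)
decoder_demn, decoder_shar, decoder_sc_shar = out[:3]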
Example #4
def train_FADecoder_from_KF(FA_nfactors, FA_te_id, decoder, use_scaled=True, use_main=True):
    # Assumes module-level imports: os, numpy as np, and the decoder-training
    # helpers used below (_get_tmask, get_neural_features, get_plant_pos_vel,
    # train_KFDecoder_abstract).

    from tasks.factor_analysis_tasks import FactorBMIBase

    FA_dict = FactorBMIBase.generate_FA_matrices(FA_nfactors, FA_te_id)

    # Now, retrain:
    binlen = decoder.binlen

    from db import dbfunctions as dbfn

    te_id = dbfn.TaskEntry(decoder.te_id)
    files = dict(plexon=te_id.plx_filename, hdf=te_id.hdf_filename)
    extractor_cls = decoder.extractor_cls
    extractor_kwargs = decoder.extractor_kwargs
    kin_extractor = get_plant_pos_vel
    ssm = decoder.ssm
    update_rate = decoder.binlen
    units = decoder.units
    tslice = (0.0, te_id.length)

    ## get kinematic data
    kin_source = "task"
    tmask, rows = _get_tmask(files, tslice, sys_name=kin_source)
    kin = kin_extractor(files, binlen, tmask, pos_key="cursor", vel_key=None)

    ## get neural features
    neural_features, units, extractor_kwargs = get_neural_features(
        files, binlen, extractor_cls.extract_from_file, extractor_kwargs, tslice=tslice, units=units, source=kin_source
    )

    # Get shared input:
    T = neural_features.shape[0]
    demean = neural_features.T - np.tile(FA_dict["fa_mu"], [1, T])

    if use_main:
        # Keep only the main shared input (scaled or unscaled) and add the mean
        # back in; assumed to mirror the non-main branch below.
        main_shar = FA_dict["fa_main_shared"] * demean
        main_sc_shar = np.multiply(main_shar, np.tile(FA_dict["fa_main_shared_sc"], [1, T])) + np.tile(FA_dict["fa_mu"], [1, T])
        main_unsc_shar = main_shar + np.tile(FA_dict["fa_mu"], [1, T])
        if use_scaled:
            neural_features = main_sc_shar[:, :-1]
        else:
            neural_features = main_unsc_shar[:, :-1]

    else:
        shar = FA_dict["fa_sharL"] * demean
        shar_sc = np.multiply(shar, np.tile(FA_dict["fa_shar_var_sc"], [1, T])) + np.tile(FA_dict["fa_mu"], [1, T])
        shar_unsc = shar + np.tile(FA_dict["fa_mu"], [1, T])
        if use_scaled:
            neural_features = shar_sc[:, :-1]
        else:
            neural_features = shar_unsc[:, :-1]

    # Remove 1st kinematic sample and last neural features sample to align the
    # velocity with the neural features
    kin = kin[1:].T

    decoder2 = train_KFDecoder_abstract(ssm, kin, neural_features, units, update_rate, tslice=tslice)
    decoder2.extractor_cls = extractor_cls
    decoder2.extractor_kwargs = extractor_kwargs
    decoder2.te_id = decoder.te_id
    decoder2.trained_fa_dict = FA_dict

    import datetime

    now = datetime.datetime.now()
    tp = now.isoformat()
    import pickle

    fname = os.path.expandvars("$FA_GROM_DATA/decoder_") + tp + ".pkl"
    f = open(fname, "w")
    pickle.dump(decoder2, f)
    f.close()
    return decoder2, fname
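
A hedged usage sketch: unpickle an existing KF decoder (path and FA task-entry ID are hypothetical) and retrain it on the scaled main shared input; the retrained decoder is also written to a timestamped pickle under $FA_GROM_DATA:

import os
import pickle

# Hypothetical decoder pickle and FA task-entry ID, for illustration only.
kf_dec = pickle.load(open(os.path.expandvars("$FA_GROM_DATA/decoder_old.pkl"), "rb"))
fa_dec, fa_fname = train_FADecoder_from_KF(5, 3240, kf_dec, use_scaled=True, use_main=True)
print "FA-retrained decoder saved to:", fa_fname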