def add_fa_dict_to_decoder(decoder_training_te, dec_ix, fa_te, return_dec=False):
    """Attach factor-analysis (FA) matrices fitted on ``fa_te`` to an existing decoder.

    The decoder record is chosen among ``models.Decoder`` rows whose ``entry``
    is ``decoder_training_te``: the first whose ``path`` has, in the two
    characters following its first ``_``, an integer equal to ``dec_ix``.

    Parameters
    ----------
    decoder_training_te : int
        Task entry the decoder was trained from (``Decoder.entry`` filter).
    dec_ix : int
        Decoder index encoded in the two characters after the first ``_``
        of the decoder record's path.
    fa_te : int
        Task entry passed to ``FactorBMIBase.generate_FA_matrices``; its
        ``spike_counts`` unit count must equal the decoder's ``n_units``.
    return_dec : bool
        If True, return the modified decoder instead of saving a new record.

    Returns
    -------
    The decoder with ``trained_fa_dict`` set when ``return_dec`` is True;
    otherwise None, after saving via ``trainbmi.save_new_decoder_from_existing``
    with suffix ``"_w_fa_dict_from_<fa_te>"``.

    Raises
    ------
    Exception
        If no matching decoder record exists, or if the FA task entry's unit
        count differs from the decoder's ``n_units``.
    """
    # First make sure we're training from the correct task entry:
    # spike counts n_units == BMI units.
    from db import dbfunctions as dbfn
    fa_entry = dbfn.TaskEntry(fa_te)
    sc_n_units = fa_entry.hdf.root.task[0]["spike_counts"].shape[0]

    from db.tracker import models
    decoder_old = None
    for dec_record in models.Decoder.objects.filter(entry=decoder_training_te):
        ix = dec_record.path.find("_")
        if int(dec_record.path[ix + 1 : ix + 3]) == dec_ix:
            decoder_old = dec_record
            break  # first match wins; later records are never inspected

    if decoder_old is None:
        raise Exception(
            "No decoder from %s and matching index: %s" % (decoder_training_te, dec_ix)
        )

    from tasks.factor_analysis_tasks import FactorBMIBase
    FA_dict = FactorBMIBase.generate_FA_matrices(fa_te)

    import pickle
    # Binary mode is required for pickle on Python 3 (and harmless on 2);
    # the with-block closes the handle the old code leaked.
    # NOTE(review): unpickling runs arbitrary code -- only load trusted decoder files.
    with open(decoder_old.filename, "rb") as f:
        dec = pickle.load(f)
    dec.trained_fa_dict = FA_dict

    if dec.n_units != sc_n_units:
        raise Exception("Cant use TE for BMI training and FA training -- n_units mismatch")

    if return_dec:
        return dec

    from db import trainbmi
    trainbmi.save_new_decoder_from_existing(
        dec, decoder_old, suffix="_w_fa_dict_from_" + str(fa_te)
    )
def main():
    """Sweep encoder input weights, simulate spike counts, and save FA statistics.

    For every combination of master encoder (``master_encoder_fnames``),
    input device (``input_devices``), shared-untuned weight
    (``shar_unt_arr``), private-untuned weight (``priv_unt_arr``) and
    private/shared split of the tuned weight (``priv_to_shar``) -- module-level
    names not defined in this function -- the encoder is driven by control
    inputs computed from a recorded visual-feedback HDF file. Input-side
    variance statistics are compared with the output of
    ``FactorBMIBase.generate_FA_matrices`` run on the simulated counts.

    Results are collected in a dict keyed by
    ``(p2s, priv_unt, shar_unt, i_m, qri, metric_name)``, each value a list
    with one entry per repetition. The dict is pickled to
    ``$FA_GROM_DATA/sims/FR_val/dat_dict_<mmddyy_HHMM>.pkl``.
    """
    import datetime
    import itertools
    import os
    import pickle

    from riglib.bmi import state_space_models
    from tasks.factor_analysis_tasks import FactorBMIBase

    def summed_var(counts):
        # Sum over units (rows) of each unit's variance across samples (columns).
        return np.sum(np.var(counts, axis=1))

    vfb_file = os.path.expandvars(
        '$FA_GROM_DATA/sims/input_output_analysis/enc052016_0140vfb60_15_15_10_.hdf'
    )
    vfb = tables.openFile(vfb_file)
    dat_dict = {}
    try:
        # Use every 6th row of the task table as one simulation sample.
        vfb_ix = np.arange(0, vfb.root.task.shape[0], 6)

        ssm = state_space_models.StateSpaceEndptVel2D()
        A, _, W = ssm.get_ssm_matrices()
        mean = np.zeros(A.shape[0])

        # Cap the run at mins*60*10 samples (10 samples/s if the task table is
        # 60 Hz and every 6th row is used -- TODO confirm) or the available rows.
        mins = 5.
        n_samples = np.min([int(mins * 60 * 10), len(vfb_ix)])
        print('n samples:  %s' % n_samples)

        number_of_reps = 1

        for i_m, master_encoder in enumerate(master_encoder_fnames):
            # NOTE(review): unpickling runs arbitrary code -- trusted files only.
            with open(master_encoder, 'rb') as f:
                encoder = pickle.load(f)
            n_units = len(encoder.get_units())

            # product() iterates the right-most factor fastest, matching the
            # original nested qri -> shar_unt -> priv_unt -> p2s loops.
            sweep = itertools.product(
                enumerate(input_devices), shar_unt_arr, priv_unt_arr, priv_to_shar
            )
            for (qri, input_device), shar_unt, priv_unt, p2s in sweep:
                key = (p2s, priv_unt, shar_unt, i_m, qri)
                for n in range(number_of_reps):
                    # Weights: [private untuned, private tuned, shared untuned, shared tuned]
                    tun = 1. - priv_unt - shar_unt
                    w = np.array([priv_unt, p2s * tun, shar_unt, (1. - p2s) * tun])
                    encoder.wt_sources = w
                    print(encoder.wt_sources)

                    spike_counts = np.zeros([n_units, n_samples])
                    priv_counts_tun = np.zeros_like(spike_counts)
                    priv_counts_unt = np.zeros_like(spike_counts)
                    shar_counts_tun = np.zeros_like(spike_counts)
                    shar_counts_unt = np.zeros_like(spike_counts)

                    encoder.call_ds_rate = 1
                    # Only consumed by the alternative Gaussian drive commented out below.
                    state_samples = np.random.multivariate_normal(mean, W, n_samples)

                    for k in range(n_samples):
                        row = vfb.root.task[vfb_ix[k]]
                        current_state = row['internal_decoder_state']
                        target_state = row['target']
                        targ = np.vstack((target_state[:, np.newaxis], np.zeros((4, 1))))
                        targ[-1, 0] = 1
                        ctrl = input_device.calc_next_state(current_state, targ) / 3.
                        ctrl[[0, 1, 2, 4]] = 0
                        # ctrl = np.mat(state_samples[k, :]).T
                        z = np.array(encoder(ctrl, mode='counts'))
                        spike_counts[:, k] = np.squeeze(z)
                        priv_counts_tun[:, k] = w[1] * encoder.priv_tun
                        shar_counts_tun[:, k] = w[3] * encoder.shar_tun
                        priv_counts_unt[:, k] = w[0] * encoder.priv_unt
                        shar_counts_unt[:, k] = w[2] * encoder.shar_unt

                    shar_counts = shar_counts_tun + shar_counts_unt
                    priv_counts = priv_counts_tun + priv_counts_unt

                    # SOT (shared over total variance) of the inputs:
                    shar_inp_var = summed_var(shar_counts)
                    priv_inp_var = summed_var(priv_counts)
                    sot_in = shar_inp_var / (shar_inp_var + priv_inp_var)

                    # FA on spike counts:
                    print(spike_counts[0, 0])
                    FA_dict = FactorBMIBase.generate_FA_matrices(None, bin_spk=spike_counts.T)
                    # Assumes 'U' is an np.matrix so '*' is a matrix product -- TODO confirm.
                    shar_out_var = np.trace(FA_dict['U'] * FA_dict['U'].T)
                    priv_out_var = np.trace(FA_dict['Psi'])
                    sot_out = shar_out_var / (shar_out_var + priv_out_var)

                    metrics = {
                        'nf_orth': FA_dict['fa_main_shar_n_dim'],
                        'priv_inp_tun_var': summed_var(priv_counts_tun),
                        'shar_inp_tun_var': summed_var(shar_counts_tun),
                        'priv_inp_var': priv_inp_var,
                        'priv_inp_unt_var': summed_var(priv_counts_unt),
                        'shar_inp_unt_var': summed_var(shar_counts_unt),
                        'shar_inp_var': shar_inp_var,
                        'priv_out_var': priv_out_var,
                        'shar_out_var': shar_out_var,
                        'sot_out': sot_out,
                        'sot_in': sot_in,
                    }
                    # One list per metric, one entry per repetition. Every
                    # repetition records the same quantities (the old rep>0
                    # branch referenced an undefined FA and the SSM A matrix).
                    for name, value in metrics.items():
                        dat_dict.setdefault(key + (name,), []).append(value)
    finally:
        vfb.close()

    ct = datetime.datetime.now()
    p_fname = os.path.expandvars(
        '$FA_GROM_DATA/sims/FR_val/dat_dict_' + ct.strftime("%m%d%y_%H%M") + '.pkl'
    )
    with open(p_fname, 'wb') as f:
        pickle.dump(dat_dict, f)
def parse_task_entry_halves(te_num, hdf, decoder, epoch_1_end=10., epoch_2_end=20.):
    """Fit FA on the first half of rewarded trials and train KF decoders on epoch 1.

    Parameters
    ----------
    te_num : int
        Task entry number; not used in the body, kept for interface compatibility.
    hdf : tables.File
        Open task HDF file whose ``root.task`` table has ``internal_decoder_state``,
        ``spike_counts``, ``cursor`` and ``target`` columns.
    decoder : object
        Existing decoder providing ``binlen``, ``ssm`` and ``units``.
    epoch_1_end, epoch_2_end : float
        Epoch boundaries, converted to task-table rows via ``* 60 * 60``
        (minutes at a 60 Hz task loop, presumably -- TODO confirm).

    Returns
    -------
    tuple
        ``(decoder_demn, decoder_shar, decoder_sc_shar, bin_spk_cnts2,
        epoch1_ix, epoch2_ix, update_bmi_ix, FA_dict)``. The three decoders
        are trained on epoch-1 bins using demeaned, FA main-shared, and scaled
        main-shared activity respectively.

    Raises
    ------
    IndexError
        If no BMI update row lies beyond an epoch boundary.
    """
    drives_neurons_ix0 = 3
    bin_width = 6  # task rows summed per BMI update, inclusive of the update row

    # Get FA dict from the first half of rewarded trials
    # (slice bounds must be ints; np.floor would return a float):
    rew_ix = pa.get_trials_per_min(hdf)
    half_rew_ix = len(rew_ix) // 2
    bin_spk, _, _, _, _, _ = pa.extract_trials_all(
        hdf, rew_ix[:half_rew_ix], hdf_ix=True, drives_neurons_ix0=drives_neurons_ix0
    )
    from tasks.factor_analysis_tasks import FactorBMIBase
    FA_dict = FactorBMIBase.generate_FA_matrices(None, bin_spk=bin_spk)

    # Read the task table once; it was previously read in full twice.
    task = hdf.root.task[:]

    # BMI update rows: rows where state component `drives_neurons_ix0` changes.
    internal_state = task['internal_decoder_state']
    update_bmi_ix = np.nonzero(np.diff(np.squeeze(internal_state[:, drives_neurons_ix0, 0])))[0] + 1
    epoch1_ix = int(np.nonzero(update_bmi_ix > int(epoch_1_end * 60 * 60))[0][0])
    epoch2_ix = int(np.nonzero(update_bmi_ix > int(epoch_2_end * 60 * 60))[0][0])

    spike_counts = task['spike_counts'][:, :, 0]

    def bin_counts(end_rows):
        # Sum spike counts over the `bin_width` rows ending at (inclusive) each row.
        # Clamp the start at 0: a negative start would wrap to the end of the
        # array and silently produce an empty (all-zero) bin.
        binned = np.zeros((len(end_rows), spike_counts.shape[1]))
        for ib, i_ix in enumerate(end_rows):
            start = max(i_ix - bin_width + 1, 0)
            binned[ib, :] = np.sum(spike_counts[start:i_ix + 1, :], axis=0)
        return binned

    epoch1_rows = update_bmi_ix[:epoch1_ix]
    bin_spk_cnts = bin_counts(epoch1_rows)
    bin_spk_cnts2 = bin_counts(update_bmi_ix[:epoch2_ix])

    # Kinematics: cursor position plus finite-difference velocity (first row 0).
    kin = task[epoch1_rows]['cursor']
    velocity = np.diff(kin, axis=0) * 1. / decoder.binlen
    velocity = np.vstack([np.zeros(kin.shape[1]), velocity])
    kin = np.hstack([kin, velocity])

    ssm = decoder.ssm
    units = decoder.units

    # Demeaned, shared and scaled-shared decoders:
    T = bin_spk_cnts.shape[0]
    demean = bin_spk_cnts.T - np.tile(FA_dict['fa_mu'], [1, T])
    decoder_demn = train.train_KFDecoder_abstract(ssm, kin.T, demean, units, 0.1)
    decoder_demn.kin = kin.T
    decoder_demn.neur = demean
    decoder_demn.target = task[epoch1_rows]['target']

    # Assumes 'fa_main_shared' is an np.matrix so '*' is a matrix product -- TODO confirm.
    main_shar = FA_dict['fa_main_shared'] * demean
    main_sc_shar = np.multiply(main_shar, np.tile(FA_dict['fa_main_shared_sc'], [1, T]))

    decoder_shar = train.train_KFDecoder_abstract(ssm, kin.T, main_shar, units, 0.1)
    decoder_shar.kin = kin.T
    decoder_shar.neur = main_shar

    decoder_sc_shar = train.train_KFDecoder_abstract(ssm, kin.T, main_sc_shar, units, 0.1)
    decoder_sc_shar.kin = kin.T
    decoder_sc_shar.neur = main_sc_shar

    # The demeaned ("full") decoder is returned first; this used to reference
    # an undefined `decoder_full` and raise NameError.
    return (decoder_demn, decoder_shar, decoder_sc_shar, bin_spk_cnts2,
            epoch1_ix, epoch2_ix, update_bmi_ix, FA_dict)
def train_FADecoder_from_KF(FA_nfactors, FA_te_id, decoder, use_scaled=True, use_main=True):
    """Retrain ``decoder`` on the FA-shared component of its own training data.

    Kinematics and neural features are re-extracted from the decoder's task
    entry (``decoder.te_id``). The features are demeaned and projected onto an
    FA shared space, re-centred on the FA mean, and used to train a new KF
    decoder. The new decoder is pickled to ``$FA_GROM_DATA/decoder_<isoformat>.pkl``.

    Parameters
    ----------
    FA_nfactors, FA_te_id :
        Passed positionally to ``FactorBMIBase.generate_FA_matrices``.
    decoder : object
        Source decoder providing ``binlen``, ``te_id``, ``extractor_cls``,
        ``extractor_kwargs``, ``ssm`` and ``units``.
    use_scaled : bool
        If True, scale the shared projection per unit before re-adding the mean
        (``fa_main_shared_sc`` or ``fa_shar_var_sc``).
    use_main : bool
        If True, project with ``fa_main_shared``; otherwise with ``fa_sharL``.

    Returns
    -------
    (decoder2, fname) : the retrained decoder and the path it was pickled to.
    """
    from tasks.factor_analysis_tasks import FactorBMIBase
    # NOTE(review): other callers in this file pass the task entry as the first
    # argument of generate_FA_matrices; confirm this (FA_nfactors, FA_te_id) order.
    FA_dict = FactorBMIBase.generate_FA_matrices(FA_nfactors, FA_te_id)

    # Now, retrain:
    from db import dbfunctions as dbfn
    te = dbfn.TaskEntry(decoder.te_id)
    files = dict(plexon=te.plx_filename, hdf=te.hdf_filename)

    binlen = decoder.binlen
    update_rate = decoder.binlen
    extractor_cls = decoder.extractor_cls
    extractor_kwargs = decoder.extractor_kwargs
    ssm = decoder.ssm
    units = decoder.units
    tslice = (0.0, te.length)

    # Get kinematic data:
    kin_source = "task"
    tmask, _ = _get_tmask(files, tslice, sys_name=kin_source)
    kin = get_plant_pos_vel(files, binlen, tmask, pos_key="cursor", vel_key=None)

    # Get neural features:
    neural_features, units, extractor_kwargs = get_neural_features(
        files, binlen, extractor_cls.extract_from_file, extractor_kwargs,
        tslice=tslice, units=units, source=kin_source,
    )

    # Project the demeaned features onto the chosen shared space. Previously
    # the use_main branch never defined shar_sc/shar_unsc, so the default
    # call raised UnboundLocalError.
    T = neural_features.shape[0]
    mu = np.tile(FA_dict["fa_mu"], [1, T])
    demean = neural_features.T - mu
    if use_main:
        shar = FA_dict["fa_main_shared"] * demean
        scale_key = "fa_main_shared_sc"
    else:
        shar = FA_dict["fa_sharL"] * demean
        scale_key = "fa_shar_var_sc"

    if use_scaled:
        shar_out = np.multiply(shar, np.tile(FA_dict[scale_key], [1, T])) + mu
    else:
        shar_out = shar + mu

    # Remove the last neural-feature sample and the first kinematic sample to
    # align the velocity with the neural features.
    neural_features = shar_out[:, :-1]
    kin = kin[1:].T

    decoder2 = train_KFDecoder_abstract(ssm, kin, neural_features, units, update_rate, tslice=tslice)
    decoder2.extractor_cls = extractor_cls
    decoder2.extractor_kwargs = extractor_kwargs
    decoder2.te_id = decoder.te_id
    decoder2.trained_fa_dict = FA_dict

    import datetime
    import os
    import pickle
    tp = datetime.datetime.now().isoformat()
    fname = os.path.expandvars("$FA_GROM_DATA/decoder_") + tp + ".pkl"
    # Pickle streams are bytes: binary mode is required on Python 3.
    with open(fname, "wb") as f:
        pickle.dump(decoder2, f)
    return decoder2, fname