def load_decoder(self):
        '''
        Instantiate the neural encoder and "train" the decoder
        '''
        from riglib.bmi import train

        encoder = self.encoder
        n_samples = 20000
        units = self.encoder.get_units()
        n_units = len(units)

        # draw samples from the W distribution
        ssm = PointClickSSM()
        A, _, W = ssm.get_ssm_matrices()
        mean = np.zeros(A.shape[0])
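        # the last element of the state vector is the constant offset term, so its mean is 1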
        mean[-1] = 1
        state_samples = np.random.multivariate_normal(mean, 100*W, n_samples)

        spike_counts = np.zeros([n_units, n_samples])
        self.encoder.call_ds_rate = 1
        for k in range(n_samples):
            spike_counts[:,k] = np.array(self.encoder(state_samples[k])).ravel()

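        # shuffle the rows of the spike-count matrix so the seeded decoder's
        # unit-to-tuning assignment no longer matches the encoder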
        unit_inds = np.arange(n_units)
        np.random.shuffle(unit_inds)
        spike_counts = spike_counts[unit_inds, :]

        kin = state_samples.T

        self.decoder = train.train_KFDecoder_abstract(ssm, kin, spike_counts, units, 0.1)
        self.encoder.call_ds_rate = 6
    def load_decoder(self):
        '''
        Instantiate the neural encoder and "train" the decoder
        '''
        print("Creating simulation decoder..")
        encoder = self.encoder
        n_samples = 20000
        units = self.encoder.get_units()
        n_units = len(units)

        # draw samples from the W distribution
        ssm = self.ssm
        A, _, W = ssm.get_ssm_matrices()
        mean = np.zeros(A.shape[0])
        mean[-1] = 1
        state_samples = np.random.multivariate_normal(mean, 100*W, n_samples)

        spike_counts = np.zeros([n_units, n_samples])
        self.encoder.call_ds_rate = 1
        for k in range(n_samples):
            spike_counts[:,k] = np.array(self.encoder(state_samples[k])).ravel()

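        # shuffle spike-count rows across units before training (hence the 'Shuffled' decoder seed)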
        inds = np.arange(spike_counts.shape[0])
        np.random.shuffle(inds)
        spike_counts = spike_counts[inds,:]

        kin = state_samples.T

        self.init_neural_features = spike_counts
        self.init_kin_features = kin

        self.decoder = train.train_KFDecoder_abstract(ssm, kin, spike_counts, units, 0.1)
        self.encoder.call_ds_rate = 6
        super(SimKFDecoderShuffled, self).load_decoder()
def analyze_OFC_gain(decs_all, dec_name='dmn'):
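    # note: ssm, units, demean, hdf, epoch1_ix, sdf, trbt, plot_sim_traj and
    # sim_fa_decoding are not defined in this function and are assumed to be
    # available at module scope (e.g. set up around parse_task_entry_halves)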
    qr_rat = np.array([1., 10e2, 10e4, 10e6, 10e8, 10e10])
    decoder = decs_all[dec_name]
    kin = decoder.kin
    dec_adj = dict()
    kin_adj = np.zeros_like(kin)
    kin_adj[[0, 2],:] = kin[[0, 2], :]
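    # sweep the OFC cost ratio (state error vs. control effort): for each ratio,
    # keep the recorded positions but replace the velocities with the OFC
    # controller's commands, then retrain a KF decoder on the adjusted kinematics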
    for i, qr in enumerate(qr_rat):
        input_device = sim_fa_decoding.EndPtAssisterOFC(cost_err_ratio=qr)
        for k in range(kin.shape[1]-1):
            cs = kin[:, k][:, np.newaxis]
            ts = np.vstack((decoder.target[k,:][:, np.newaxis], np.zeros((3, 1)), np.ones((1,1))))

            ctrl = input_device.calc_next_state(np.vstack((cs, np.ones((1,1)))), ts)
            kin_adj[[3, 5], k+1] = np.squeeze(np.array(ctrl[[3, 5], 0]))

        dec_adj[qr] = train.train_KFDecoder_abstract(ssm, kin_adj, demean, units, 0.1)

    trbt_dec = {}
    for d, qr_ in enumerate(dec_adj.keys()):
        dec = trbt.RerunDecoding(hdf, dec_adj[qr_], task='bmi_resetting')
        dec.add_input(sdf.rebin_spks(demean), 'demean')
        trbt_dec[qr_] = dec

    for iq, q in enumerate(trbt_dec.keys()):
        d = trbt_dec[q]
        plot_sim_traj.traj_sem_plot(d, 'demean', it_start=epoch1_ix, extra_title_text=str(q))
Example #4
def train_xz_KF_from_encoder(pnm, encoder_wts, n_samples, split_dec=False):
    encoder = pickle.load(open(pnm))
    encoder.wt_sources = encoder_wts
    units = encoder.get_units()
    n_units = len(units)

    # draw samples from the W distribution
    ssm = StateSpaceEndptVel2D()
    A, _, W = ssm.get_ssm_matrices()
    mean = np.zeros(A.shape[0])
    mean[-1] = 1
    state_samples = np.random.multivariate_normal(mean, W, n_samples)

    spike_counts = np.zeros([n_units, n_samples])
    encoder.call_ds_rate = 1
    for k in range(n_samples):
        z = np.array(encoder(state_samples[k], mode='counts'))
        #print k, state_samples[k], state_samples[k].shape, z.shape, z.ravel().shape
        spike_counts[:,k] = z.ravel()
    
    kin = state_samples.T
    decoder = train.train_KFDecoder_abstract(ssm, kin, spike_counts, units, 0.1)

    hdf_nm = pnm[:-4]+'.hdf'
    hdf = tables.openFile(hdf_nm)

    dim_red_dict = factor_analysis_tasks.FactorBMIBase.generate_FA_matrices(None,
        hdf=hdf, dec=decoder)
    
    return decoder, dim_red_dict, hdf
Example #5
def train_xz_KF_from_vfb(pnm, split_dec=False):
    
    hdf_nm = pnm[:-4]+'.hdf'
    hdf = tables.openFile(hdf_nm)

    enc = pickle.load(open(pnm))
    decoder = enc.corresp_dec

    units = decoder.units
    update_rate = 0.1

    kin, neural_features = get_kin_and_binnedspks(hdf, pos_key='int_kin')
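    # tslice is in seconds, assuming the task table is logged at 60 Hz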
    tslice = (0, len(hdf.root.task)/60.)
    ssm = StateSpaceEndptVel2D()

    kf_decoder2 = train.train_KFDecoder_abstract(ssm, kin.T, neural_features.T, 
        units, update_rate, tslice=tslice)

    #Decoder only used here for 'drives neurons'
    dim_red_dict = factor_analysis_tasks.FactorBMIBase.generate_FA_matrices(None,
        hdf=hdf, dec=decoder)

    if split_dec:
        print('training KF decoder for split input')
        kf_decoder2 = train.conv_KF_to_splitFA_dec(hdf, kf_decoder2, neural_features, kin.T, dim_red_dict)

    return kf_decoder2, dim_red_dict, hdf
Example #6
def train_xz_KF_fit_best_int_kin(pnm):
    hdf_nm = pnm[:-4]+'.hdf'
    hdf = tables.openFile(hdf_nm)

    enc = pickle.load(open(pnm))
    decoder = enc.corresp_dec

    units = decoder.units
    update_rate = 0.1
    tslice = (0, len(hdf.root.task)/60.)
    ssm = StateSpaceEndptVel2D()

    kin, neural_features = get_kin_and_binnedspks(hdf, pos_key='int_kin')
    
    #With current kinematics, fit the best intention signal
    int_qr, kin, neural_features = fit_optimal_int_kin(hdf) 
    
    kf_decoder2 = train.train_KFDecoder_abstract(ssm, kin.T, neural_features.T, 
        units, update_rate, tslice=tslice)

    #Decoder only used here for 'drives neurons'
    dim_red_dict = factor_analysis_tasks.FactorBMIBase.generate_FA_matrices(None,
        hdf=hdf, dec=decoder)

    dim_red_dict['fit_qr'] = int_qr

    return kf_decoder2, dim_red_dict, hdf
Example #7
    def load_decoder(self):
        '''
        Instantiate the neural encoder and "train" the decoder
        '''

        if hasattr(self, 'decoder'):
            print('Already have a decoder!')
        else:
            print("Creating simulation decoder..")
            print(self.encoder, type(self.encoder))
            n_samples = 2000
            units = self.encoder.get_units()
            n_units = len(units)
            print('units: ', n_units)

            # draw samples from the W distribution
            ssm = self.ssm
            A, _, W = ssm.get_ssm_matrices()
            mean = np.zeros(A.shape[0])
            mean[-1] = 1
            state_samples = np.random.multivariate_normal(mean, W, n_samples)

            spike_counts = np.zeros([n_units, n_samples])
            self.encoder.call_ds_rate = 1
            for k in range(n_samples):
                spike_counts[:, k] = np.array(
                    self.encoder(state_samples[k], mode='counts')).ravel()

            kin = state_samples.T
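            # z-score the neural features if CLDA will adapt mean firing-rate statistics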
            zscore = False
            if hasattr(self, 'clda_adapt_mFR_stats'):
                if self.clda_adapt_mFR_stats:
                    zscore = True
            print(' zscore decoder ? : ', zscore)
            self.decoder = train.train_KFDecoder_abstract(ssm,
                                                          kin,
                                                          spike_counts,
                                                          units,
                                                          0.1,
                                                          zscore=zscore)
            self.encoder.call_ds_rate = 6

            self.init_neural_features = spike_counts
            self.init_kin_features = kin

            super(SimKFDecoderSup, self).load_decoder()
    def load_decoder(self):
        '''
        Instantiate the neural encoder and "train" the decoder
        '''
        
        if hasattr(self, 'decoder'):
            print('Already have a decoder!')
        else:
            print("Creating simulation decoder..")
            print(self.encoder, type(self.encoder))
            n_samples = 2000
            units = self.encoder.get_units()
            n_units = len(units)
            print('units: ', n_units)

            # draw samples from the W distribution
            ssm = self.ssm
            A, _, W = ssm.get_ssm_matrices()
            mean = np.zeros(A.shape[0])
            mean[-1] = 1
            state_samples = np.random.multivariate_normal(mean, W, n_samples)

            spike_counts = np.zeros([n_units, n_samples])
            self.encoder.call_ds_rate = 1
            for k in range(n_samples):
                spike_counts[:,k] = np.array(self.encoder(state_samples[k], mode='counts')).ravel()

            kin = state_samples.T
            zscore = False
            if hasattr(self, 'clda_adapt_mFR_stats'):
                if self.clda_adapt_mFR_stats:
                    zscore = True
            print(' zscore decoder ? : ', zscore)
            self.decoder = train.train_KFDecoder_abstract(ssm, kin, spike_counts, units, 0.1, zscore=zscore)
            self.encoder.call_ds_rate = 6

            self.init_neural_features = spike_counts
            self.init_kin_features = kin

            super(SimKFDecoderSup, self).load_decoder()
def parse_task_entry_halves(te_num, hdf, decoder, epoch_1_end=10., epoch_2_end = 20.):
    drives_neurons_ix0 = 3
    #Get FA dict:
    rew_ix = pa.get_trials_per_min(hdf)
    half_rew_ix = int(np.floor(len(rew_ix)/2.))
    bin_spk, targ_pos, targ_ix, trial_ix, reach_time, hdf_ix = pa.extract_trials_all(hdf, 
        rew_ix[:half_rew_ix], hdf_ix=True, drives_neurons_ix0=3)

    from tasks.factor_analysis_tasks import FactorBMIBase
    FA_dict = FactorBMIBase.generate_FA_matrices(None, bin_spk=bin_spk)


    #Get BMI update IX: 
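    # frames at which the decoder's internal state changed, i.e. the BMI update ticks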
    internal_state = hdf.root.task[:]['internal_decoder_state']
    update_bmi_ix = np.nonzero(np.diff(np.squeeze(internal_state[:, drives_neurons_ix0, 0])))[0]+1
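    # epoch boundaries in units of BMI updates; epoch_*_end appears to be in
    # minutes, assuming a 60 Hz task frame rate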
    epoch1_ix = int(np.nonzero(update_bmi_ix > int(epoch_1_end*60*60))[0][0])
    epoch2_ix = int(np.nonzero(update_bmi_ix > int(epoch_2_end*60*60))[0][0])

    #Get spike counts and bin them:
    spike_counts = hdf.root.task[:]['spike_counts'][:,:,0]
    bin_spk_cnts = np.zeros((epoch1_ix, spike_counts.shape[1]))
    bin_spk_cnts2 = np.zeros((epoch2_ix, spike_counts.shape[1]))

    for ib, i_ix in enumerate(update_bmi_ix[:epoch1_ix]):
        #Inclusive of EndBin
        bin_spk_cnts[ib,:] = np.sum(spike_counts[i_ix-5:i_ix+1,:], axis=0)

    for ib, i_ix in enumerate(update_bmi_ix[:epoch2_ix]):
        #Inclusive of EndBin
        bin_spk_cnts2[ib,:] = np.sum(spike_counts[i_ix-5:i_ix+1,:], axis=0)

    kin = hdf.root.task[update_bmi_ix[:epoch1_ix]]['cursor']
    binlen = decoder.binlen
    velocity = np.diff(kin, axis=0) * 1./binlen
    velocity = np.vstack([np.zeros(kin.shape[1]), velocity])
    kin = np.hstack([kin, velocity])

    ssm = decoder.ssm
    units = decoder.units

    #Shared and Scaled Shared Decoders: 
    T = bin_spk_cnts.shape[0]
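    # subtract the FA mean firing rate from every time bin to get demeaned spike counts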
    demean = bin_spk_cnts.T - np.tile(FA_dict['fa_mu'], [1, T])
    decoder_demn = train.train_KFDecoder_abstract(ssm, kin.T, demean, units, 0.1)
    decoder_demn.kin = kin.T
    decoder_demn.neur = demean
    decoder_demn.target = hdf.root.task[update_bmi_ix[:epoch1_ix]]['target']

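    # project the demeaned activity onto the FA shared (main) subspace, and also
    # form a version rescaled by the per-unit shared-variance scaling factors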
    main_shar = FA_dict['fa_main_shared'] * demean
    #main_priv = demean - main_shar
    main_sc_shar = np.multiply(main_shar, np.tile(FA_dict['fa_main_shared_sc'], [1, T]))
    #full_sc = np.multiply(demean, np.tile(FA_dict['fa_main_shared_sc'], [1,T]))
    #main_sc_shar_pls_priv = main_sc_shar + main_priv

    decoder_shar = train.train_KFDecoder_abstract(ssm, kin.T, main_shar, units, 0.1)
    decoder_shar.kin = kin.T
    decoder_shar.neur = main_shar

    decoder_sc_shar = train.train_KFDecoder_abstract(ssm, kin.T, main_sc_shar, units, 0.1)
    decoder_sc_shar.kin = kin.T
    decoder_sc_shar.neur = main_sc_shar
    decs_all = dict(dmn=decoder_demn, shar = decoder_shar, sc_shar = decoder_sc_shar)

    return decoder_demn, decoder_shar, decoder_sc_shar, bin_spk_cnts2, epoch1_ix, epoch2_ix, update_bmi_ix, FA_dict