def load_decoder(self):
    '''
    Build a randomly initialised 'seed' decoder for the simulation.

    A ``StateSpaceEndptVel2D`` state space is paired with the units reported by
    ``self.encoder.get_units()`` and handed to ``train.rand_KFDecoder``; the
    resulting decoder is stored on ``self.decoder``.
    '''
    from riglib.bmi.state_space_models import StateSpaceEndptVel2D
    seed_space = StateSpaceEndptVel2D()
    self.decoder = train.rand_KFDecoder(seed_space, self.encoder.get_units())
示例#2
0
    def test_zero_velocity_goal(self):
        """ZeroVelocityGoal for a target at the origin yields [0, 0, 0, 0, 0, 0, 1].

        The calculator is built on ``StateSpaceEndptVel2D`` and called with a
        float64 target of three zeros; the flattened goal state must equal the
        seven-element vector of zeros with a trailing 1.
        """
        ssm = state_space_models.StateSpaceEndptVel2D()
        goal_calc = goal_calculators.ZeroVelocityGoal(ssm)

        target_pos = np.array([0, 0, 0], dtype=np.float64)
        # The calculator returns ((goal_state, error), <other>); only goal_state is checked.
        (goal_state, _error), _ = goal_calc(target_pos)

        expected = np.array([0, 0, 0, 0, 0, 0, 1])
        # assert_array_equal raises AssertionError with the mismatching elements
        # and shapes, unlike assertTrue(np.array_equal(...)) which only reports
        # "False is not true".
        np.testing.assert_array_equal(goal_state.ravel(), expected)
示例#3
0
def make_FACosEnc(num, num_neurons=20,
                  out_dir='/storage/preeya/grom_data/sims/test_obs_vs_co_overlap'):
    """Create ``num`` randomly initialised FACosEnc encoders and pickle each one.

    Encoder ``n`` is written to ``<out_dir>/encoder_param_matched_<n>.pkl``.

    Parameters
    ----------
    num : int
        Number of encoders to create.
    num_neurons : int, optional
        Passed to ``FACosEnc`` as ``n_neurons`` and used as the row count of the
        random ``C`` matrix (default 20, the previously hard-coded value).
    out_dir : str, optional
        Directory the pickles are written to (defaults to the previously
        hard-coded location).
    """
    from riglib.bmi import state_space_models as ssm
    import os
    import pickle

    SSM = ssm.StateSpaceEndptVel2D()
    # FACosEnc receives a copy via ``**kwargs``, so one dict can be reused.
    kwargs = {'n_neurons': num_neurons}

    for n in range(num):
        C = np.random.rand(num_neurons, SSM.n_states)
        enc = FACosEnc(C, SSM, return_ts=True, **kwargs)
        # Divide the first four psi_unt_std entries by 40
        # (presumably lowers the untuned noise of units 0-3 -- confirm).
        enc.psi_unt_std[:4] /= 40
        out_path = os.path.join(out_dir, 'encoder_param_matched_' + str(n) + '.pkl')
        # ``with`` guarantees the pickle is flushed and the handle closed.
        with open(out_path, 'wb') as fh:
            pickle.dump(enc, fh)
示例#4
0
def _simulate_spikes_and_velocity(enc):
    """Simulate *enc* with ``sim_enc`` and build the regression inputs.

    Returns ``(spk, vel, enc)``:
    - ``spk`` is the vertically stacked ``'unt'`` output plus the vertically
      stacked ``'tun'`` output.
    - ``vel`` is rows 3 and 5 of the horizontally stacked ``'ctl'`` output,
      transposed, with a column of ones appended as an intercept.
    - ``enc`` is the encoder object returned by ``sim_enc``.
    """
    data, enc = sim_enc(enc)
    spk = np.vstack(data['unt']) + np.vstack(data['tun'])
    vel = np.hstack(data['ctl'])[[3, 5], :].T
    vel = np.hstack((np.array(vel), np.ones((len(vel), 1))))
    return spk, vel, enc


def _regression_variances(vel, spikes):
    """Least-squares regress one unit's *spikes* on *vel*.

    Returns ``(qsig, qnoise)``: the variance of the fitted values and the variance
    of the residuals. Plain ndarrays and ``@``-style products replace ``np.mat``,
    which was removed in NumPy 2.0; the variances are unchanged.
    """
    column = spikes[:, np.newaxis]
    # rcond=None selects NumPy's current default cutoff and silences the FutureWarning.
    coef = np.linalg.lstsq(vel, column, rcond=None)[0]
    fitted = vel.dot(coef)
    return np.var(fitted), np.var(column - fitted)


def from_file_to_FACosEnc(plot=False):
    """Build FACosEnc encoders whose per-unit SNR is matched to values loaded from disk.

    For each key of the loaded dict, visited in ``np.sort`` order, the function
    runs the following steps.

    1. Create a ``FACosEnc``.
       - It gets a random ``C``, ``wt_sources=[1, 1, 0, 0]`` and one neuron per
         entry of ``dat[key]``.
       - ``d[n][3][0, :]`` is copied into ``psi_tun[n, [3, 5, 6]]`` and
         ``mu[n]`` is zeroed.
    2. Simulate it and regress each unit's spikes on ``vel``.
       - Set ``psi_unt_std[n] = sqrt(qsig / snr_des)``.
       - ``snr_des = d[n][2]``, or 0.3 when that value is NaN.
    3. Re-simulate and compute the achieved ``qsig / qnoise`` per unit.
    4. Optionally plot achieved vs. desired SNR.
    5. Pickle the encoder to
       ``$FA_GROM_DATA/sims/test_obs_vs_co_overlap/encoder_param_matched_<key>.pkl``.

    Parameters
    ----------
    plot : bool
        If True, draw one scatter panel per key (achieved SNR on x, desired SNR
        on y), titled with that key. The figure is neither shown nor saved here.
    """
    from riglib.bmi import state_space_models as ssm
    import pickle
    import os
    import matplotlib.pyplot as plt

    # Binary mode is required: pickle.load cannot read a text-mode file on Python 3.
    in_path = os.path.expandvars(
        '/home/lab/preeya/fa_analysis/grom_data/co_obs_SNR_w_coefficients.pkl')
    with open(in_path, 'rb') as fh:
        dat = pickle.load(fh)
    SSM = ssm.StateSpaceEndptVel2D()

    sorted_keys = np.sort(list(dat.keys()))

    if plot:
        # Keep the original 3x3 grid, but add rows if there are more than 9 keys.
        nrows = max(3, int(np.ceil(len(sorted_keys) / 3.0)))
        _, ax = plt.subplots(nrows=nrows, ncols=3)

    for j, i in enumerate(sorted_keys):
        d = dat[i]
        n_neurons = len(d)
        kwargs = {'n_neurons': n_neurons}
        C = np.random.rand(n_neurons, SSM.n_states)
        kwargs['wt_sources'] = [1, 1, 0, 0]
        enc = FACosEnc(C, SSM, return_ts=True, **kwargs)

        for n in range(n_neurons):
            # Per-unit tuning: d[n][3] presumably holds the fitted coefficients
            # for these three state columns -- confirm against the data file.
            enc.psi_tun[n, [3, 5, 6]] = d[n][3][0, :]
            enc.mu[n] = 0

        # Set the untuned std so each unit's simulated SNR matches the target.
        spk, vel, enc = _simulate_spikes_and_velocity(enc)
        n_units = spk.shape[1]
        snr_act = []
        for n in range(n_units):
            snr_des = d[n][2]
            if np.isnan(snr_des):
                snr_des = .3
                # NOTE(review): debug print; fires when the target SNR was NaN.
                print('sucdess')
            snr_act.append(snr_des)
            qsig, _ = _regression_variances(vel, spk[:, n])
            enc.psi_unt_std[n] = np.sqrt(qsig / snr_des)

        # Re-simulate to measure the SNR actually achieved.
        spk, vel, enc = _simulate_spikes_and_velocity(enc)
        snr_sim = []
        for n in range(n_units):
            qsig, qnoise = _regression_variances(vel, spk[:, n])
            snr_sim.append(qsig / qnoise)

        if plot:
            # Integer division: a float index would raise IndexError on Python 3.
            axi = ax[j // 3, j % 3]
            axi.plot(snr_sim, snr_act, '.')
            # Title with the key actually plotted (the sorted key, not dict order).
            axi.set_title(str(i))

        out_path = os.path.expandvars(
            '$FA_GROM_DATA/sims/test_obs_vs_co_overlap/encoder_param_matched_'
            + str(i) + '.pkl')
        with open(out_path, 'wb') as fh:
            pickle.dump(enc, fh)