def __init__(self, hdf, decoder, task='point_mass', drives_neurons = 0):
    '''
    Summary: load recorded task data from `hdf` and configure the plant for `task`
    (module-level copy of the RerunDecoding constructor body).
    Input param: hdf: open HDF file with a `root.task` table (and `root.task_msgs`
        for the 'bmi_multi' / 'bmi_resetting' tasks)
    Input param: decoder: decoder object; its `drives_neurons` attribute selects the
        state column whose changes mark decoder updates
    Input param: task: 'point_mass', 'bmi_multi' or 'bmi_resetting'; any other value
        leaves `plant` / `move_plant` unset
    Input param: drives_neurons: unused -- `decoder.drives_neurons` is used instead
        (kept for backward compatibility)
    '''
    # numpy raises ValueError for an unknown field of a structured ndarray; the
    # other two are caught defensively. Anything else is a real error and propagates.
    missing = (KeyError, ValueError, IndexError)

    # Read the whole task table once instead of once per field.
    rows = hdf.root.task[:]

    # Older files use different column names; fall back when a field is absent.
    try:
        self.cursor_pos = rows['cursor_pos']
    except missing:
        self.cursor_pos = rows['cursor']

    try:
        self.cursor_vel = rows['cursor_vel']
    except missing:
        self.cursor_vel = rows['internal_decoder_state'][:, [3, 4, 5]]

    self.target = rows['target']

    self.target_rad = hdf.root.task.attrs.target_radius
    self.cursor_rad = hdf.root.task.attrs.cursor_radius

    try:
        spike_counts = rows['spike_counts']
    except missing:
        spike_counts = rows['all']
    self.spike_counts = np.array(spike_counts, dtype=np.float64)

    self.internal_state = rows['internal_decoder_state']

    self.dec = decoder
    self.drives_neurons = self.dec.drives_neurons
    # First driven state index; iterations where its recorded value changes are
    # presumably those on which the online decoder produced a new state.
    self.drives_neurons_ix0 = np.nonzero(self.drives_neurons)[0][0]
    driven = np.squeeze(self.internal_state[:, self.drives_neurons_ix0, 0])
    self.update_bmi_ix = np.nonzero(np.diff(driven))[0] + 1

    if task == 'point_mass':
        self.plant = CursorPlantWithMass(endpt_bounds=(-14, 14, 0., 0., -14, 14))
        self.move_plant = self.move_mass_plant
    elif task in ('bmi_multi', 'bmi_resetting'):
        self.plant = CursorPlant(endpt_bounds=(-25, 25, 0., 0., -14, 14))
        self.move_plant = self.move_vel_plant
        self.task_msgs = hdf.root.task_msgs

    self.task = task
    self.assist = rows['assist_level']
    self.hdf = hdf
# Example #2
# 0
import unittest
import numpy as np
import plantlist
from tasks import bmi_recon_tasks
import dbfunctions as dbfn

# Replay task entry `idx` from the 'testing' database through the LFP BMI
# reconstruction task and print the largest absolute reconstruction error.
idx = 849
te = dbfn.TaskEntry(idx, dbname='testing')
n_iter = len(te.hdf.root.task)

cls = bmi_recon_tasks.LFPBMIReconstruction
gen = []
task = cls(te, n_iter)

from riglib.plants import CursorPlant
# Swap in a CursorPlant with +/-10 endpoint bounds and the velocity wall disabled.
task.plant = CursorPlant(endpt_bounds=[-10, 10, -10, 10, -10, 10],
                         vel_wall=False)
task.init()

error = task.calc_recon_error(verbose=False, n_iter_betw_fb=1000)
abs_max_error = np.max(np.abs(error))

# Single-argument print: identical output under Python 2 and Python 3.
print(abs_max_error)
class RerunDecoding(object):
    '''
    Re-run a BMI decoder offline over the data recorded in an HDF file.

    `run_decoder` replays spike counts through the decoder, and `main_move_plant`
    turns the decoded states into cursor positions/velocities. Results are stored
    in dicts keyed by an `input_type` label, so several inputs can be compared.
    '''

    # numpy raises ValueError for an unknown field of a structured ndarray; the
    # other two are caught defensively. Anything else is a real error and propagates.
    _MISSING_FIELD_ERRORS = (KeyError, ValueError, IndexError)

    def __init__(self, hdf, decoder, task='point_mass', drives_neurons = 0):
        '''
        Summary: load recorded task data from `hdf` and configure the plant for `task`
        Input param: hdf: open HDF file with a `root.task` table (and
            `root.task_msgs` for the 'bmi_multi' / 'bmi_resetting' tasks)
        Input param: decoder: decoder object; its `drives_neurons` attribute
            selects the state column whose changes mark decoder updates
        Input param: task: 'point_mass', 'bmi_multi' or 'bmi_resetting'; any
            other value leaves `plant` / `move_plant` unset
        Input param: drives_neurons: unused -- `decoder.drives_neurons` is used
            instead (kept for backward compatibility)
        '''
        missing = self._MISSING_FIELD_ERRORS

        # Read the whole task table once instead of once per field.
        rows = hdf.root.task[:]

        # Older files use different column names; fall back when a field is absent.
        try:
            self.cursor_pos = rows['cursor_pos']
        except missing:
            self.cursor_pos = rows['cursor']

        try:
            self.cursor_vel = rows['cursor_vel']
        except missing:
            self.cursor_vel = rows['internal_decoder_state'][:, [3, 4, 5]]

        self.target = rows['target']

        self.target_rad = hdf.root.task.attrs.target_radius
        self.cursor_rad = hdf.root.task.attrs.cursor_radius

        try:
            spike_counts = rows['spike_counts']
        except missing:
            spike_counts = rows['all']
        self.spike_counts = np.array(spike_counts, dtype=np.float64)

        self.internal_state = rows['internal_decoder_state']

        self.dec = decoder
        self.drives_neurons = self.dec.drives_neurons
        # First driven state index; iterations where its recorded value changes
        # are presumably those on which the online decoder produced a new state.
        self.drives_neurons_ix0 = np.nonzero(self.drives_neurons)[0][0]
        driven = np.squeeze(self.internal_state[:, self.drives_neurons_ix0, 0])
        self.update_bmi_ix = np.nonzero(np.diff(driven))[0] + 1

        if task == 'point_mass':
            self.plant = CursorPlantWithMass(endpt_bounds=(-14, 14, 0., 0., -14, 14))
            self.move_plant = self.move_mass_plant
        elif task in ('bmi_multi', 'bmi_resetting'):
            self.plant = CursorPlant(endpt_bounds=(-25, 25, 0., 0., -14, 14))
            self.move_plant = self.move_vel_plant
            self.task_msgs = hdf.root.task_msgs

        self.task = task
        self.assist = rows['assist_level']
        self.hdf = hdf

    def run_decoder(self, spike_counts, input_type = 'all', cutoff=None):
        '''
        Summary: method to use the 'predict' function in the decoder object
        Input param: spike_counts: unbinned spike counts in iter x units x 1
        Input param: input_type: key under which results are stored in
            `dec_spk_cnt_bin` and `dec_state_mn`
        Input param: cutoff:  cutoff in iterations
        Output: none. Sets `self.dec_spk_cnt_bin[input_type]` (units x n_updates:
            spike counts summed between decoder updates) and
            `self.dec_state_mn[input_type]` (one row per iteration; the last
            decoded state is repeated between updates)
        '''
        n_iter = spike_counts.shape[0]
        if cutoff is not None:
            n_iter = int(min(n_iter, cutoff))

        decoded_state = []
        spike_accum = np.zeros_like(spike_counts[0, :])

        # predict() is called here only to learn the output shape; the filter
        # state is (re)initialised right below.
        dec_last = np.zeros_like(self.dec.predict(spike_counts[0, :]))

        # Per-update spike bins are collected in a list and concatenated once
        # (an hstack per update was quadratic). The leading -1 column is a
        # placeholder that is dropped at the end.
        binned = [np.zeros_like(spike_counts[0, :]) - 1]

        if self.task == 'point_mass':
            self.dec.filt.state.mean = np.zeros_like(self.dec.filt.state.mean)
        elif self.task in ('bmi_multi', 'bmi_resetting'):
            self.dec.filt._init_state()
            self.state = self.task_msgs[0]['msg']

        # Map each message time stamp to the index of its first message.
        # `task_msgs` is only set for the bmi_* tasks; without it there is no
        # task state to track (this previously raised AttributeError for 'point_mass').
        msgs = None
        first_msg_at = {}
        task_msgs = getattr(self, 'task_msgs', None)
        if task_msgs is not None:
            msgs = task_msgs[:]
            for k, msg_time in enumerate(msgs['time'].tolist()):
                first_msg_at.setdefault(msg_time, k)
        update_at = set(np.ravel(self.update_bmi_ix).tolist())

        for t in range(n_iter):
            spike_accum = spike_accum + spike_counts[t, :]

            if t in first_msg_at:
                self.state = msgs[first_msg_at[t]]['msg']

            if t not in update_at:
                # No decoder update on this iteration: repeat the last state.
                decoded_state.append(dec_last)
                continue

            dec_new = self.dec.predict(spike_accum)
            if self.task in ('bmi_multi', 'bmi_resetting'):
                self._constrain_to_plant(dec_new)

            binned.append(spike_accum)
            decoded_state.append(dec_new)
            dec_last = dec_new
            spike_accum = np.zeros_like(spike_counts[0, :])

        if not hasattr(self, 'dec_spk_cnt_bin'):
            self.dec_spk_cnt_bin = dict()
            self.dec_state_mn = dict()
        # Single-argument print: identical output under Python 2 and Python 3.
        print('input_type:  %s' % (input_type,))
        self.dec_spk_cnt_bin[input_type] = np.hstack(binned)[:, 1:]
        self.dec_state_mn[input_type] = np.vstack(decoded_state)

    def _constrain_to_plant(self, dec_new):
        '''
        Bound a freshly decoded state with the plant (modifies `dec_new` in place)
        and write the bounded pos/vel back into the decoder's filter state.
        During 'premove' of 'bmi_resetting', the position is forced to the origin.
        '''
        if self.task == 'bmi_resetting' and self.state == 'premove':
            self.plant.set_endpoint_pos(np.array([0., 0., 0.]))
            self.dec['q'] = self.plant.get_intrinsic_coordinates()
            pos1 = np.array([0., 0., 0.])
            vel1 = dec_new[[3, 4, 5]]
        else:
            pos1, vel1 = self.plant._bound(dec_new[[0, 1, 2]], dec_new[[3, 4, 5]])

        dec_new[[0, 1, 2]] = pos1
        dec_new[[3, 4, 5]] = vel1
        self.dec.filt.state.mean = np.array([np.hstack((pos1, vel1, np.array([1.])))]).T

    def move_vel_plant(self, reset_ix, dt=1/60., input_type='all'):
        '''
        Placeholder: for the velocity-plant tasks, `main_move_plant` takes positions
        directly from the decoded states, so nothing is integrated here.
        `input_type` is accepted only to match `move_mass_plant`'s interface.
        '''
        pass

    def move_mass_plant(self, reset_ix=None, dt=1/60., input_type='all'):
        '''
        Summary: integrate decoded forces (state columns 9-11) through the plant
        Input param: reset_ix: iteration indices; when i+1 is among them,
            integration restarts from the recorded cursor pos/vel at i
        Input param: dt: integration step (default 1/60.)
        Input param: input_type: key into `dec_state_mn`; results are stored
            under the same key in `decoded_pos` / `decoded_vel`
        '''
        # Set lookup instead of a list scan each step; no shared mutable default.
        resets = set(reset_ix) if reset_ix is not None else set()
        if not hasattr(self, 'decoded_pos'):
            self.decoded_pos = dict()
            self.decoded_vel = dict()

        states = self.dec_state_mn[input_type]
        go_res = 0
        p0 = self.cursor_pos[0, :].copy()
        v0 = self.cursor_vel[0, :].copy()

        pos_arr = []
        vel_arr = []
        for i in range(1, states.shape[0]):
            force = states[i - 1, [9, 10, 11]]
            vel = v0 + dt * force
            pos = p0 + dt * vel + 0.5 * dt ** 2 * force
            pos, vel = self.plant._bound(pos, vel)

            # Check if next index is the start of a trial
            if i + 1 in resets:
                p0 = self.cursor_pos[i, :]
                v0 = self.cursor_vel[i, :]
                go_res += 1
                print(go_res)
            else:
                p0 = pos.copy()
                v0 = vel.copy()

            pos_arr.append(pos)
            vel_arr.append(vel)
        self.decoded_pos[input_type] = np.array(pos_arr)
        self.decoded_vel[input_type] = np.array(vel_arr)

    def add_input(self, spike_counts, input_type):
        '''Decode `spike_counts` and move the plant, storing results under `input_type`.'''
        self.run_decoder(spike_counts, input_type=input_type)
        self.main_move_plant(input_type=input_type)

    def _go_indices(self):
        '''
        Return field 1 (time) of the task message three entries before each
        'reward' message. Rewards with fewer than three earlier messages are
        skipped: a negative index would silently wrap to the end of the log.
        '''
        msgs = self.hdf.root.task_msgs[:]
        return np.array([msgs[it - 3][1] for it, msg in enumerate(msgs)
                         if msg[0] == 'reward' and it >= 3], dtype=int)

    def main_move_plant(self, input_type):
        '''
        Summary: convert `dec_state_mn[input_type]` into `decoded_pos` / `decoded_vel`
        Input param: input_type: key of the decoded states to use and to store under
        '''
        if not hasattr(self, 'decoded_pos'):
            self.decoded_pos = dict()
            self.decoded_vel = dict()

        if self.task == 'bmi_multi':
            go_ix = self._go_indices()
            states = self.dec_state_mn[input_type]
            pos = states[:, [0, 1, 2]]
            self.decoded_pos[input_type] = pos
            self.decoded_vel[input_type] = states[:, [3, 4, 5]]

            # At each go index, shift the rest of the decoded trajectory so it
            # starts from the recorded cursor position.
            for g in go_ix:
                if g < pos.shape[0]:
                    pos[g:, :] += self.cursor_pos[g, :] - pos[g, :]

        elif self.task == 'bmi_resetting':
            states = self.dec_state_mn[input_type]
            self.decoded_pos[input_type] = states[:, [0, 1, 2]]
            self.decoded_vel[input_type] = states[:, [3, 4, 5]]

        else:
            # Forward input_type; it was previously dropped, so the 'all' states
            # were always integrated regardless of the requested input.
            self.move_plant(reset_ix=list(self._go_indices()), input_type=input_type)


#