Ejemplo n.º 1
0
    def test_unit_conv_mm_to_cm(self):
        '''
        Replay saved spike counts through the stored decoder, once with its
        output scaled by 0.1 and once after train.rescale_KFDecoder_units,
        and require both to match the saved cursor exactly.
        '''
        # block with fixed decoder operating in mm
        entry = dbfunctions.get_task_entry(1762)
        hdf = dbfunctions.get_hdf(entry)

        # Every 6th row starting at row 5 (weird indexing feature of the way
        # the old BMI was running)
        rows = slice(5, None, 6)
        cursor = hdf.root.task[rows]['cursor']
        spike_counts = np.array(hdf.root.task[rows]['bins'], dtype=np.float64)

        def run_decoder(dec, spike_counts):
            # One predict() call per row of counts; predictions stacked row-wise
            n_steps = spike_counts.shape[0]
            outputs = [dec.predict(spike_counts[t, :]) for t in range(n_steps)]
            return np.array(np.vstack(outputs))

        # Decoder as stored: its output is multiplied by 0.1 before comparison
        dec_mm = dbfunctions.get_decoder(entry)
        state_mm = 0.1 * run_decoder(dec_mm, spike_counts)
        err_mm = np.max(np.abs(cursor - np.float32(state_mm[:, 0:3])))
        self.assertEqual(err_mm, 0)

        # Freshly loaded decoder after rescaling: output is compared unscaled
        dec_cm = train.rescale_KFDecoder_units(dbfunctions.get_decoder(entry))
        state_cm = run_decoder(dec_cm, spike_counts)
        err_cm = np.max(np.abs(cursor - np.float32(state_cm[:, 0:3])))
        self.assertEqual(err_cm, 0)
    def test_unit_conv_mm_to_cm(self):
        '''
        Check that the stored decoder (output scaled by 0.1) and the decoder
        returned by train.rescale_KFDecoder_units both reproduce the cursor
        saved in the HDF file with zero error.
        '''
        def run_decoder(dec, spike_counts):
            # Predict row by row, then stack the per-step outputs
            T = spike_counts.shape[0]
            return np.array(np.vstack(
                [dec.predict(spike_counts[t, :]) for t in range(T)]
            ))

        def worst_error(reference, decoded):
            # Largest absolute deviation over the first three decoded columns
            return np.max(np.abs(reference - np.float32(decoded[:, 0:3])))

        # block with fixed decoder operating in mm
        te = dbfunctions.get_task_entry(1762)
        hdf = dbfunctions.get_hdf(te)

        # weird indexing feature of the way the old BMI was running:
        # only every 6th row, starting from row 5, is used
        tslice = slice(5, None, 6)
        cursor = hdf.root.task[tslice]['cursor']
        spike_counts = np.array(hdf.root.task[tslice]['bins'], dtype=np.float64)

        # Unscaled decoder: multiply its output by 0.1 before comparing
        decoder = dbfunctions.get_decoder(te)
        decoded_mm = 0.1 * run_decoder(decoder, spike_counts)
        self.assertEqual(worst_error(cursor, decoded_mm), 0)

        # Reload the decoder, rescale it and compare its output directly
        decoder = dbfunctions.get_decoder(te)
        rescaled = train.rescale_KFDecoder_units(decoder)
        decoded_cm = run_decoder(rescaled, spike_counts)
        self.assertEqual(worst_error(cursor, decoded_cm), 0)
 def load_decoder(self):
     '''
     Load the initial decoder and allocate a zeroed state buffer.

     Sets self.decoder (looked up from self.te.record), self.n_subbins and
     self.decoder_state, an array of zeros with dimensions
     [n_iter, n_states, n_subbins].
     '''
     decoder = dbfn.get_decoder(self.te.record)
     self.decoder = decoder
     self.n_subbins = decoder.n_subbins

     state_dims = [self.n_iter, decoder.n_states, self.n_subbins]
     self.decoder_state = np.zeros(state_dims)
#!/usr/bin/python
from db import dbfunctions
from tasks import bmimultitasks
import numpy as np
from riglib.bmi import state_space_models, kfdecoder, train
reload(kfdecoder)

# Block with predict and update of kf running at 10Hz
te = dbfunctions.get_task_entry(1883)
hdf = dbfunctions.get_hdf(te)
dec = dbfunctions.get_decoder(te)

# Pull the per-iteration task fields used for the replay
task_data = hdf.root.task[:]
assist_level = task_data['assist_level'].ravel()
spike_counts = task_data['spike_counts']
target = task_data['target']
cursor = task_data['cursor']

# The replay below is only set up for blocks recorded without assist
assert np.all(assist_level == 0)

# Collect the second field of every 'update_bmi' message
task_msgs = hdf.root.task_msgs[:]
update_bmi_msgs = [x for x in task_msgs if x['msg'] == 'update_bmi']
inds = [x[1] for x in update_bmi_msgs]

# This block is expected to contain no decoder parameter updates
assert not inds

T = spike_counts.shape[0]
error = np.zeros(T)
for k in range(T):
    # NOTE(review): bmi_params is not defined in the visible code; this branch
    # is only reachable if the assertion on `inds` above does not hold
    if (k - 1) in inds:
        dec.update_params(bmi_params[inds.index(k - 1)])
    st = dec(
        spike_counts[k],
        target=target[k],
        target_radius=1.8,
        assist_level=assist_level[k],
    )
# coding: utf-8
from db import dbfunctions as dbfn
import numpy as np

# Compare the spike counts stored in the HDF file against the
# 'spike_counts_batch' saved with each BMI parameter update.
task_entry = 2023

dec = dbfn.get_decoder(task_entry)

# Default bminum to 1 when the decoder does not define it. Only the missing
# attribute is handled; any other error (or Ctrl-C) propagates.
try:
    dec.bminum
except AttributeError:
    dec.bminum = 1

bmi_params = np.load(dbfn.get_bmiparams_file(task_entry))
hdf = dbfn.get_hdf(task_entry)
task_msgs = hdf.root.task_msgs[:]
update_bmi_msgs = [msg for msg in task_msgs if msg['msg'] == 'update_bmi']
spike_counts = hdf.root.task[:]['spike_counts']

for k, msg in enumerate(update_bmi_msgs):
    time = msg['time']
    # Sum the last `bminum` rows of counts ending at (and including) `time`
    hdf_sc = np.sum(spike_counts[time - dec.bminum + 1:time + 1], axis=0)
    # Print the index of every update whose saved batch differs from the HDF data
    if not np.array_equal(hdf_sc, bmi_params[k]['spike_counts_batch']):
        print(k)
Ejemplo n.º 6
0
#!/usr/bin/python 
'''
Script to map between new and old assists
infinite horizon LQR feedback controllers are hackishly fit based on arrival times
'''
import numpy as np
from riglib.bmi import feedback_controllers, kfdecoder, ppfdecoder
import matplotlib.pyplot as plt
from db import dbfunctions as dbfn

dec = dbfn.get_decoder(2326)

# Choose the time step and a short type tag from the decoder class
if isinstance(dec, kfdecoder.KFDecoder):
    dt = 1./10
    dec_type = 'kf'
elif isinstance(dec, ppfdecoder.PPFDecoder):
    dt = 1./180
    dec_type = 'ppf'
else:
    # Without this branch `dt` is never bound and building B below would
    # fail with an unrelated-looking NameError
    raise TypeError("Unsupported decoder type: %r" % type(dec))

eff_targ_radius = 1.2

# B stacks a 3x3 zero block, a 3x3 block of (dt/1e-3) on the diagonal and a
# 1x3 zero row
I = np.eye(3)
B = np.bmat([[0*I],
              [dt/1e-3 * I],
              [np.zeros([1, 3])]])
A = dec.filt.A

## F = []
## F.append(np.zeros([3, 7]))
## for k in range(num_assist_levels):
## 
#!/usr/bin/python
'''
Equivalence test for changes to vfb training procedure (riglib.bmi.train._train_KFDecoder_visual_feedback)
'''
import numpy as np
from db import dbfunctions as dbfn
from db.tracker import models
from riglib.bmi import train
reload(train)

# Retrain a decoder from the training block of entry 1844 and report how far
# its parameters are from the decoder stored for that entry.
te = dbfn.get_task_entry(1844)
dec_record = dbfn.get_decoder_entry(te)
dec = dbfn.get_decoder(te)

# Map each data file of the training block from its system name to its path
training_block = dbfn.get_task_entry(dec_record.entry_id)
datafiles = models.DataFile.objects.filter(entry_id=training_block.id)
files = {d.system.name: d.get_path() for d in datafiles}

dec_new = train._train_KFDecoder_visual_feedback(cells=dec.units, binlen=dec.binlen, tslice=dec.tslice, **files)

# Compare each retrained matrix against the stored decoder. The previous code
# subtracted dec_new.kf.X from itself, so every error was trivially 0.
# NOTE(review): assumes the stored decoder exposes the same `kf` attribute as
# dec_new -- confirm.
for name in ('C', 'Q', 'R', 'S', 'T'):
    diff = getattr(dec_new.kf, name) - getattr(dec.kf, name)
    print("%s error: %g" % (name, np.max(np.abs(diff))))
print("mFR error: %g" % np.max(np.abs(dec_new.mFR - dec.mFR)))

dec_new = train._train_PPFDecoder_visual_feedback(cells=dec.units, tslice=dec.tslice, **files)
Ejemplo n.º 8
0
#!/usr/bin/python
'''
A set of tests to ensure that all the dbfunctions are still functional
'''
from db import dbfunctions as dbfn
reload(dbfn)

# Named task_entry_id rather than `id` so the builtin id() is not shadowed
task_entry_id = 4150
te = dbfn.TaskEntry(task_entry_id)

# Getters to exercise, in order, with any extra positional arguments.
# Each is looked up on dbfn only when its turn comes, so a failure stops the
# run at the same point as a sequence of individual print statements would.
getter_calls = [
    ('get_plx_file', ()),
    ('get_decoder_name', ()),
    ('get_decoder_name_full', ()),
    ('get_decoder', ()),
    ('get_params', ()),
    ('get_param', ('decoder',)),
    ('get_date', ()),
    ('get_notes', ()),
    ('get_subject', ()),
    ('get_length', ()),
    ('get_success_rate', ()),
]
for getter_name, extra_args in getter_calls:
    # print(...) with a single argument behaves the same on Python 2 and 3
    print(getattr(dbfn, getter_name)(task_entry_id, *extra_args))

# The BMI parameter file lookup is exercised on a different entry
bmiparams_entry_id = 1956
print(dbfn.get_bmiparams_file(bmiparams_entry_id))

# TODO check blackrock file fns


#!/usr/bin/python
'''
Equivalence test for changes to vfb training procedure (riglib.bmi.train._train_KFDecoder_visual_feedback)
'''
import numpy as np
from db import dbfunctions as dbfn
from db.tracker import models
from riglib.bmi import train
reload(train)

# Retrain a decoder from the training block of entry 1844 and print the
# largest absolute difference between its parameters and the stored decoder's.
te = dbfn.get_task_entry(1844)
dec_record = dbfn.get_decoder_entry(te)
dec = dbfn.get_decoder(te)

# Data files of the training block, keyed by system name
training_block = dbfn.get_task_entry(dec_record.entry_id)
datafiles = models.DataFile.objects.filter(entry_id=training_block.id)
files = dict((d.system.name, d.get_path()) for d in datafiles)

dec_new = train._train_KFDecoder_visual_feedback(cells=dec.units, binlen=dec.binlen, tslice=dec.tslice, **files)


def _max_abs_diff(new, old):
    # Largest element-wise absolute difference between two arrays
    return np.max(np.abs(new - old))


# The right-hand operand is the stored decoder. The earlier version compared
# each dec_new.kf matrix with itself and therefore always reported 0.
# NOTE(review): assumes `dec` has the same `kf` attribute as dec_new -- confirm.
# print(...) with one argument works on both Python 2 and Python 3.
print("C error: %g" % _max_abs_diff(dec_new.kf.C, dec.kf.C))
print("Q error: %g" % _max_abs_diff(dec_new.kf.Q, dec.kf.Q))
print("R error: %g" % _max_abs_diff(dec_new.kf.R, dec.kf.R))
print("S error: %g" % _max_abs_diff(dec_new.kf.S, dec.kf.S))
print("T error: %g" % _max_abs_diff(dec_new.kf.T, dec.kf.T))
print("mFR error: %g" % _max_abs_diff(dec_new.mFR, dec.mFR))

dec_new = train._train_PPFDecoder_visual_feedback(cells=dec.units, tslice=dec.tslice, **files)
 def seed_decoder(self):
     '''
     Return the decoder that dbfn.get_decoder finds for self.record
     '''
     seed = dbfn.get_decoder(self.record)
     return seed
Ejemplo n.º 11
0
#!/usr/bin/python
'''
A set of tests to ensure that all the dbfunctions are still functional
'''
from db import dbfunctions as dbfn
import imp
imp.reload(dbfn)

# Use a descriptive name instead of `id`, which would shadow the builtin id()
entry_id = 4150
te = dbfn.TaskEntry(entry_id)

# Call each getter in turn and print its result; any exception aborts the run
print(dbfn.get_plx_file(entry_id))
print(dbfn.get_decoder_name(entry_id))
print(dbfn.get_decoder_name_full(entry_id))
print(dbfn.get_decoder(entry_id))
print(dbfn.get_params(entry_id))
print(dbfn.get_param(entry_id, 'decoder'))
print(dbfn.get_date(entry_id))
print(dbfn.get_notes(entry_id))
print(dbfn.get_subject(entry_id))
print(dbfn.get_length(entry_id))
print(dbfn.get_success_rate(entry_id))

# The BMI parameter file lookup is checked against a different entry
bmiparams_entry_id = 1956
print(dbfn.get_bmiparams_file(bmiparams_entry_id))

# TODO check blackrock file fns


Ejemplo n.º 12
0
#!/usr/bin/python
'''
Script to map between new and old assists
infinite horizon LQR feedback controllers are hackishly fit based on arrival times
'''
import numpy as np
from riglib.bmi import feedback_controllers, kfdecoder, ppfdecoder
import matplotlib.pyplot as plt
from db import dbfunctions as dbfn

dec = dbfn.get_decoder(2326)

# Time step and type tag depend on which decoder class was loaded
if isinstance(dec, kfdecoder.KFDecoder):
    dt = 1. / 10
    dec_type = 'kf'
elif isinstance(dec, ppfdecoder.PPFDecoder):
    dt = 1. / 180
    dec_type = 'ppf'
else:
    # Fail here with a clear message; otherwise `dt` would be unbound and the
    # construction of B below would raise NameError
    raise TypeError("Unsupported decoder type: %r" % type(dec))

eff_targ_radius = 1.2

# B is built from three vertically stacked blocks: 3x3 zeros,
# (dt / 1e-3) times the 3x3 identity, and a 1x3 row of zeros
I = np.eye(3)
B = np.bmat([[0 * I], [dt / 1e-3 * I], [np.zeros([1, 3])]])
A = dec.filt.A

## F = []
## F.append(np.zeros([3, 7]))
## for k in range(num_assist_levels):
##
##     F_k = np.array(feedback_controllers.LQRController.dlqr(A, B, Q, R, eps=1e-15))
##     F.append(F_k)
# coding: utf-8
from db import dbfunctions as dbfn
import numpy as np

# Check that, for every 'update_bmi' message, the spike counts in the HDF file
# add up to the 'spike_counts_batch' saved with the matching BMI parameters.
task_entry = 2023

dec = dbfn.get_decoder(task_entry)

# Fall back to bminum = 1 only when the attribute is missing; a bare
# `except:` here would also hide unrelated errors and Ctrl-C.
try:
    dec.bminum
except AttributeError:
    dec.bminum = 1


bmi_params = np.load(dbfn.get_bmiparams_file(task_entry))
hdf = dbfn.get_hdf(task_entry)
task_msgs = hdf.root.task_msgs[:]
# A list (not a lazy filter object) so the messages can be iterated more than
# once regardless of Python version
update_bmi_msgs = [msg for msg in task_msgs if msg['msg'] == 'update_bmi']
spike_counts = hdf.root.task[:]['spike_counts']

for k, msg in enumerate(update_bmi_msgs):
    time = msg['time']
    # Total counts over the `bminum` rows ending at (and including) `time`
    hdf_sc = np.sum(spike_counts[time-dec.bminum+1:time+1], axis=0)
    if not np.array_equal(hdf_sc, bmi_params[k]['spike_counts_batch']):
        # Report the index of the mismatching update; print(k) is valid on
        # both Python 2 and Python 3
        print(k)