Code example #1
0
    def test_unit_conv_mm_to_cm(self):
        """Check mm/cm unit handling of a fixed KFDecoder.

        The decoder for task entry 1762 operated in mm; its replayed output
        scaled by 0.1 must reproduce the logged cursor trajectory exactly,
        and so must the output of the same decoder after conversion with
        train.rescale_KFDecoder_units.
        """
        # Block with a fixed decoder operating in mm.
        te = dbfunctions.get_task_entry(1762)

        hdf = dbfunctions.get_hdf(te)
        # The old BMI consumed every 6th HDF row starting at index 5.
        row_slice = slice(5, None, 6)
        cursor = hdf.root.task[row_slice]['cursor']
        obs = np.array(hdf.root.task[row_slice]['bins'], dtype=np.float64)

        def run_decoder(dec, spike_counts):
            # Feed each observation row through the decoder and stack the
            # per-timestep state estimates into one array.
            decoded = [dec.predict(row) for row in spike_counts]
            return np.array(np.vstack(decoded))

        # Raw decoder output is in mm; * 0.1 converts to cm for comparison
        # against the logged cursor.
        dec = dbfunctions.get_decoder(te)
        state_mm = 0.1 * run_decoder(dec, obs)
        diff_mm = cursor - np.float32(state_mm[:, 0:3])
        self.assertEqual(np.max(np.abs(diff_mm)), 0)

        # A decoder rescaled to cm should match without any extra scaling.
        dec = dbfunctions.get_decoder(te)
        dec_cm = train.rescale_KFDecoder_units(dec)
        state_cm = run_decoder(dec_cm, obs)
        diff_cm = cursor - np.float32(state_cm[:, 0:3])
        self.assertEqual(np.max(np.abs(diff_cm)), 0)
Code example #2
0
    def test_unit_conv_mm_to_cm(self):
        """Verify unit conversion for a KFDecoder trained in mm.

        Both the raw decoder output multiplied by 0.1 and the output of the
        decoder converted via train.rescale_KFDecoder_units must agree
        exactly with the cursor positions stored in the HDF file.
        """
        # Block with a fixed decoder operating in mm.
        te = dbfunctions.get_task_entry(1762)

        hdf = dbfunctions.get_hdf(te)
        tslice = slice(5, None, 6)  # the old BMI used every 6th row, offset 5
        cursor = hdf.root.task[tslice]['cursor']
        counts = np.array(hdf.root.task[tslice]['bins'], dtype=np.float64)

        def decode_all(decoder, observations):
            # Run the decoder forward over every observation row.
            return np.array(
                np.vstack([decoder.predict(c) for c in observations]))

        # mm-scaled decoder: 0.1 * output should equal the logged cursor.
        dec = dbfunctions.get_decoder(te)
        mm_state = 0.1 * decode_all(dec, counts)
        mm_diff = cursor - np.float32(mm_state[:, 0:3])
        self.assertEqual(np.max(np.abs(mm_diff)), 0)

        # Rescaled (cm) decoder: output should equal the cursor directly.
        dec = dbfunctions.get_decoder(te)
        cm_state = decode_all(train.rescale_KFDecoder_units(dec), counts)
        cm_diff = cursor - np.float32(cm_state[:, 0:3])
        self.assertEqual(np.max(np.abs(cm_diff)), 0)
Code example #3
0
#!/usr/bin/python
# Offline replay script: re-run a recorded KF-decoder block through the
# decoder object and (intended) CLDA parameter updates.
from db import dbfunctions
from tasks import bmimultitasks
import numpy as np
from riglib.bmi import state_space_models, kfdecoder, train
reload(kfdecoder)  # Python 2 builtin reload; picks up local edits to the module

te = dbfunctions.get_task_entry(1883) # Block with predict and update of kf running at 10Hz
hdf = dbfunctions.get_hdf(te)
dec = dbfunctions.get_decoder(te)

# Per-iteration task data from the HDF log.
assist_level = hdf.root.task[:]['assist_level'].ravel()
spike_counts = hdf.root.task[:]['spike_counts']
target = hdf.root.task[:]['target']
cursor = hdf.root.task[:]['cursor']

# This replay only makes sense for blocks recorded with no assist.
assert np.all(assist_level == 0)

task_msgs = hdf.root.task_msgs[:]
update_bmi_msgs = filter(lambda x: x['msg'] == 'update_bmi', task_msgs)
inds = [x[1] for x in update_bmi_msgs]

# Expect no CLDA updates, i.e. the decoder was fixed for the whole block.
assert len(inds) == 0

T = spike_counts.shape[0]
# NOTE(review): 'error' is allocated but never filled in the visible code.
error = np.zeros(T)
for k in range(spike_counts.shape[0]):
    # NOTE(review): bmi_params is never defined in this script; this branch
    # is unreachable only because the assert above guarantees inds == [].
    if k - 1 in inds:
        dec.update_params(bmi_params[inds.index(k-1)])
    st = dec(spike_counts[k], target=target[k], target_radius=1.8, assist_level=assist_level[k])
Code example #4
0
# coding: utf-8
# Sanity check: verify that each 'spike_counts_batch' logged in the BMI
# parameter file matches the spike counts re-summed from the HDF table.
from db import dbfunctions as dbfn
import numpy as np

task_entry = 2023

dec = dbfn.get_decoder(task_entry)

# Older decoders lack the 'bminum' attribute (number of HDF rows summed per
# BMI update); default it to 1. Catch only AttributeError -- the original
# bare 'except:' would have hidden unrelated failures.
try:
    dec.bminum
except AttributeError:
    dec.bminum = 1

bmi_params = np.load(dbfn.get_bmiparams_file(task_entry))
hdf = dbfn.get_hdf(task_entry)
task_msgs = hdf.root.task_msgs[:]
update_bmi_msgs = [msg for msg in task_msgs if msg['msg'] == 'update_bmi']
spike_counts = hdf.root.task[:]['spike_counts']

# Print the index of every update whose logged batch disagrees with the
# counts recomputed over the bminum rows ending at msg['time'].
for k, msg in enumerate(update_bmi_msgs):
    time = msg['time']
    hdf_sc = np.sum(spike_counts[time - dec.bminum + 1:time + 1], axis=0)
    if not np.array_equal(hdf_sc, bmi_params[k]['spike_counts_batch']):
        print(k)
Code example #5
0
def trial_array(sess_list, save=None):
    '''Build a per-trial summary array for the center-out 8-target task.

    Works for the center-out 8-target task only. sess_list must be a list of
    1 or more session IDs; with more than 1, the sessions are concatenated as
    if they were a single continuous session.

    Each row of the output array represents a single initiated trial (trials
    count as initiated when a center hold is completed). Column indices are
    given by the returned var_dict:
      0 'completed'      -- trial-start timestamp if the trial was rewarded, else 0
      1 'timed_out'      -- trial-start timestamp if the trial timed out, else 0
      2 'hold_error'     -- trial-start timestamp if the trial ended in a hold
                            error, else 0
      3 'target_number'  -- which target was presented (0-7)
      4 'reach_time'     -- reach duration in seconds (nan for timeout trials,
                            since the cursor never reached the target)
      5 'movement_error' -- max perpendicular distance from the straight-line
                            path to the target (nan for timeout trials)

    The trial start is the time at the end of the center hold state.
    Timestamps are raw task_msgs 'time' values; reach_time divides by 60,
    which presumably assumes a 60 Hz task loop -- TODO confirm.

    Returns (output_array, var_dict, all_biases).
    '''

    # Column index for each trial statistic in the output array.
    var_dict = {
        'completed': 0,
        'timed_out': 1,
        'hold_error': 2,
        'target_number': 3,
        'reach_time': 4,
        'movement_error': 5
    }
    angle_tolerance = .01
    angle_list = np.arange(-3 * np.pi / 4, 5 * np.pi / 4,
                           np.pi / 4.)  # 8 targets from -135 to 180 degrees

    # Accumulators across all sessions in sess_list.
    all_states = []
    all_targstate_inds = []
    all_target_numbers = []
    all_trajectories = []
    all_biases = []

    for j, sess_id in enumerate(sess_list):
        hdf = dbf.get_hdf(sess_id)
        states = hdf.root.task_msgs[:]

        # pull out all target states and corresponding target index values, then remove center targets
        # (target_index == 1 marks the peripheral target; 0 is the center).
        target_states = [(i, s) for i, s in enumerate(states[:-2])
                         if (s[0] == 'target')]
        target_inds = [
            hdf.root.task[s[1][1]]['target_index'][0] for s in target_states
        ]
        target_states = [
            target_states[i] for i, ti in enumerate(target_inds) if ti == 1
        ]
        targstate_inds = [s[0] for s in target_states]

        # get location of target for each trial
        target_locs = [hdf.root.task[s[1][1]]['target'] for s in target_states]
        target_numbers = []
        # Find the angle of the target for each trial and classify it with the correct target index key
        for t in target_locs:
            # Angle in the x-z plane; coordinates are (x, y, z) with y the
            # unused depth axis -- assumed from the [0] and [2] indexing.
            targ_angle = np.arctan2(t[2], t[0])
            target_num = None
            for i, ang in enumerate(angle_list):
                if targ_angle >= ang - angle_tolerance and targ_angle <= ang + angle_tolerance:
                    target_num = i
            assert (target_num is not None), "Unrecognized target angle! "
            target_numbers.append(target_num)

        # Get cursor trajectory for each trial
        trajectories = [
            hdf.root.task[states[s[0]][1]:states[s[0] + 1][1]]['cursor']
            for s in target_states
        ]

        # Adjust timestamps and indices if concatenating multiple sessions:
        # shift this session's times by the last time of the previous one and
        # its state indices by the number of states already accumulated.
        if j > 0:
            last_time = all_states[-1][1]
            last_ind = len(all_states)
            states = [(s[0], s[1] + last_time) for s in states]
            targstate_inds = [k + last_ind for k in targstate_inds]

        # Combine data from this file with previous ones
        all_states.extend(states)
        all_targstate_inds.extend(targstate_inds)
        all_target_numbers.extend(target_numbers)
        all_trajectories.extend(trajectories)

        # NOTE(review): bare except deliberately skips sessions whose task
        # table has no 'bias' column, but it would also hide any other
        # failure here -- consider narrowing to the specific exception.
        try:
            biases = [hdf.root.task[s[1][1]]['bias'] for s in target_states]
            all_biases.extend(biases)
        except:
            pass

    # initialize array to hold trial data
    output_array = np.zeros([len(all_targstate_inds), 6])
    output_array[:, var_dict['reach_time']] = np.nan
    output_array[:, var_dict['movement_error']] = np.nan

    for i, state_ind in enumerate(all_targstate_inds):
        state_ts = all_states[state_ind][1]
        # if trial ended in timeout (state right after 'target')
        if all_states[state_ind + 1][0] == 'timeout_penalty':
            output_array[i, var_dict['timed_out']] = state_ts
            output_array[i, var_dict['target_number']] = all_target_numbers[i]
        # if trial ended in hold error (two states after 'target')
        if all_states[state_ind + 2][0] == 'hold_penalty':
            output_array[i, var_dict['hold_error']] = state_ts
            output_array[i, var_dict['reach_time']] = (
                all_states[state_ind + 1][1] - state_ts) / 60.
            output_array[i, var_dict['target_number']] = all_target_numbers[i]
            # Find the movement error (max perpendicular distance from straight line path to target)
            points = all_trajectories[i][:, [0, 2]]
            dists = np.zeros(len(points))
            ang = angle_list[all_target_numbers[i]]
            m = np.tan(ang)
            # Near-vertical target direction: distance to the line x = 0 is
            # just the x coordinate (abs applied below).
            if np.abs(m) > 1000000:
                dists = points[:, 0]
            else:
                # Distance from (x, z) to the line z = m*x through the origin.
                dists = np.abs(points[:, 1] -
                               m * points[:, 0]) / np.sqrt(m**2 + 1)
            output_array[i, var_dict['movement_error']] = np.max(np.abs(dists))
        # if trial ended in reward (three states after 'target')
        if all_states[state_ind + 3][0] == 'reward':
            output_array[i, var_dict['completed']] = state_ts
            output_array[i, var_dict['reach_time']] = (
                all_states[state_ind + 1][1] - state_ts) / 60.
            output_array[i, var_dict['target_number']] = all_target_numbers[i]
            # Find the movement error (max perpendicular distance from straight line path to target)
            points = all_trajectories[i][:, [0, 2]]
            dists = np.zeros(len(points))
            ang = angle_list[all_target_numbers[i]]
            m = np.tan(ang)
            if np.abs(m) > 1000000:
                dists = points[:, 0]
            else:
                dists = np.abs(points[:, 1] -
                               m * points[:, 0]) / np.sqrt(m**2 + 1)
            output_array[i, var_dict['movement_error']] = np.max(np.abs(dists))

    if save is not None:
        np.save(save, output_array)

    return output_array, var_dict, all_biases
# coding: utf-8
# Sanity check: verify that each 'spike_counts_batch' logged in the BMI
# parameter file matches the spike counts re-summed from the HDF table.
from db import dbfunctions as dbfn
import numpy as np

task_entry = 2023

dec = dbfn.get_decoder(task_entry)

# Older decoders lack the 'bminum' attribute (number of HDF rows summed per
# BMI update); default it to 1. Catch only AttributeError -- the original
# bare 'except:' would have hidden unrelated failures.
try:
    dec.bminum
except AttributeError:
    dec.bminum = 1


bmi_params = np.load(dbfn.get_bmiparams_file(task_entry))
hdf = dbfn.get_hdf(task_entry)
task_msgs = hdf.root.task_msgs[:]
# List comprehension (not py2 filter) for consistency with the sibling
# script above and py2/py3 compatibility.
update_bmi_msgs = [msg for msg in task_msgs if msg['msg'] == 'update_bmi']
spike_counts = hdf.root.task[:]['spike_counts']

# Print the index of every update whose logged batch disagrees with the
# counts recomputed over the bminum rows ending at msg['time'].
for k, msg in enumerate(update_bmi_msgs):
    time = msg['time']
    hdf_sc = np.sum(spike_counts[time-dec.bminum+1:time+1], axis=0)
    if not np.array_equal(hdf_sc, bmi_params[k]['spike_counts_batch']):
        print(k)