Example #1
class OptitrackPlayback(Optitrack):
    '''
    Read a csv file back into BMI3D as if it were live data
    '''

    filepath = traits.String("", desc="path to optitrack csv file for playback")
    
    def init(self):
        '''
        Secondary init function. See riglib.experiment.Experiment.init()
        Prior to starting the task, this init sets up the DataSource for interacting with the
        motion tracker system and registers the source with the sink manager so that the data are saved to file as they are collected.
        '''

        # Start the fake natnet client
        self.client = optitrack.PlaybackClient(self.filepath)

        # Create a source to buffer the motion tracking data
        from riglib import source
        self.motiondata = source.DataSource(optitrack.make(optitrack.System, self.client, self.optitrack_feature, 1))

        # Save to the sink
        from riglib import sink
        sink_manager = sink.SinkManager.get_instance()
        sink_manager.register(self.motiondata)
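        # Call the grandparent's init directly, skipping Optitrack's own init (whose
        # live-client setup is presumably replaced by the playback client created above)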
        super(Optitrack, self).init()
Example #2
class CLDAControlMulti(BMIControlMulti, LinearlyDecreasingHalfLife):
    '''
    BMI task that periodically refits the decoder parameters based on intended
    movements toward the targets. Inherits directly from BMIControl. Can be made
    to automatically linearly decrease assist level over set time period, or
    to provide constant assistance by setting assist_level and assist_min equal.
    '''

    batch_time = traits.Float(80.0, desc='The length of the batch in seconds')
    decoder_sequence = traits.String(
        'test', desc='signifier to group together sequences of decoders')

    ordered_traits = [
        'session_length', 'assist_level', 'assist_level_time', 'batch_time',
        'half_life', 'half_life_time'
    ]

    def __init__(self, *args, **kwargs):
        super(CLDAControlMulti, self).__init__(*args, **kwargs)
        self.learn_flag = True

    def create_learner(self):
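        # Number of decoder cycles (bins) that make up one CLDA batch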
        self.batch_size = int(self.batch_time / self.decoder.binlen)
        self.learner = CursorGoalLearner2(self.batch_size)
        self.learn_flag = True

    def create_updater(self):
        half_life_start, half_life_end = self.half_life
        self.updater = clda.KFSmoothbatch(self.batch_time, half_life_start)

    def call_decoder(self, *args, **kwargs):
        kwargs['half_life'] = self.current_half_life
        return super(CLDAControlMulti, self).call_decoder(*args, **kwargs)
Example #3
class ButtonTask(LogExperiment):
    side = traits.String("left",
                         desc='Use "left" for one side, "right" for the other')
    reward_time = traits.Float(5, desc='Amount of reward (in seconds)')
    penalty_time = traits.Float(5, desc='Amount of penalty (in seconds)')

    status = dict(
        left=dict(left_correct="reward", left_incorrect="penalty", stop=None),
        right=dict(right_correct="reward",
                   right_incorrect="penalty",
                   stop=None),
        reward=dict(post_reward="picktrial"),
        penalty=dict(post_penalty="picktrial"),
    )

    state = "picktrial"

    def __init__(self, **kwargs):
        from riglib import button
        super(ButtonTask, self).__init__(**kwargs)
        self.button = button.Button()

    def _start_picktrial(self):
        self.set_state(self.side)

    def _get_event(self):
        if self.button is not None:
            return self.button.pressed()
        return None

    def _while_left(self):
        self.event = self._get_event()

    def _while_right(self):
        self.event = self._get_event()

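    # Event codes 1 and 2 count as correct on 'left' trials; 4 and 8 on 'right' trials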
    def _test_left_correct(self, ts):
        return self.event is not None and self.event in [1, 2]

    def _test_left_incorrect(self, ts):
        return self.event is not None and self.event in [8, 4]

    def _test_right_correct(self, ts):
        return self.event is not None and self.event in [8, 4]

    def _test_right_incorrect(self, ts):
        return self.event is not None and self.event in [1, 2]

    def _test_post_reward(self, ts):
        return ts > self.reward_time

    def _test_post_penalty(self, ts):
        return ts > self.penalty_time

    def _test_both_correct(self, ts):
        return self.event is not None

    def _start_None(self):
        pass
Example #4
class CLDAControlExoEndpt(BMIControlExoEndpt, LinearlyDecreasingHalfLife):
    batch_time = traits.Float(0.1, desc='The length of the batch in seconds')
    decoder_sequence = traits.String('exo', desc='signifier to group together sequences of decoders')

    def create_updater(self):
        self.updater = clda.KFRML(self.batch_time, self.half_life[0])

    def init(self):
        super(CLDAControlExoEndpt, self).init()
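        # Match the batch length to the decoder bin length (one bin per batch)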
        self.batch_time = self.decoder.binlen
        self.updater.init(self.decoder)        

    def create_learner(self):
        self.batch_size = int(self.batch_time/self.decoder.binlen)
        self.learner = clda.FeedbackControllerLearner(self.batch_size, joint_vel_fb_ctrl, reset_states=['go_to_origin', 'wait', 'init_exo', 'move_target', 'pause', 'reward'])
        self.learn_flag = True
Example #5
class MAMCLDA(MAMBMI, LinearlyDecreasingHalfLife):
    batch_time = traits.Float(0.1, desc='The length of the batch in seconds')
    decoder_sequence = traits.String('MAM', desc='signifier to group together sequences of decoders')

    def create_learner(self):
        self.batch_size = int(self.batch_time/self.decoder.binlen)
        fb_ctrl = MSKController()
        self.learner = clda.FeedbackControllerLearner(self.batch_size, fb_ctrl, style='mixing')
        self.learn_flag = True

    def create_updater(self):
        self.updater = clda.KFRML(self.batch_time, self.half_life[0])

    def _cycle(self):
        super(MAMCLDA, self)._cycle()
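        # After 16 rewarded trials, make the batch infinite so the learner never
        # fills another batch, effectively freezing further decoder adaptation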
        if self.calc_state_occurrences('reward') > 16:
            self.learner.batch_size = np.inf

    def call_decoder(self, *args, **kwargs):
        kwargs['half_life'] = self.current_half_life
        return super(MAMCLDA, self).call_decoder(*args, **kwargs)
Example #6
class PointMassCLDAReconstruction(PointMassBMIReconstruction):
    batch_time = traits.Float(0.1, desc='The length of the batch in seconds')
    decoder_sequence = traits.String('test', desc='signifier to group together sequences of decoders')

    def __init__(self, *args, **kwargs):
        super(PointMassCLDAReconstruction, self).__init__(*args, **kwargs)
        self.half_life = self.te.half_life

    def create_learner(self):
        from tasks.point_mass_cursor import PointMassFBController
        self.batch_size = int(self.batch_time/self.decoder.binlen)
        fb_ctrl = PointMassFBController()
        self.learner = clda.FeedbackControllerLearner(self.batch_size, fb_ctrl)
        self.learn_flag = True

    def create_updater(self):
        self.updater = clda.KFRML(self.batch_time, self.half_life[0])

    def call_decoder(self, *args, **kwargs):
        kwargs['half_life'] = self.current_half_life
        return super(PointMassCLDAReconstruction, self).call_decoder(*args, **kwargs)
Example #7
class CLDA_BMIResettingObstacles(BMIResettingObstacles,
                                 LinearlyDecreasingHalfLife):
    sequence_generators = [
        'centerout_2D_discrete', 'centerout_2D_discrete_w_obstacle'
    ]
    batch_time = traits.Float(0.1, desc='The length of the batch in seconds')
    decoder_sequence = traits.String(
        'test', desc='signifier to group together sequences of decoders')
    memory_decay_rate = traits.Float(0.5, desc="")
    ordered_traits = [
        'session_length', 'assist_level', 'assist_level_time', 'batch_time',
        'half_life', 'half_life_time'
    ]

    def __init__(self, *args, **kwargs):
        super(CLDA_BMIResettingObstacles, self).__init__(*args, **kwargs)
        self.learn_flag = True

    def call_decoder(self, *args, **kwargs):
        kwargs['half_life'] = self.current_half_life
        return super(CLDA_BMIResettingObstacles,
                     self).call_decoder(*args, **kwargs)

    def create_updater(self):
        self.updater = clda.KFRML(self.batch_time, self.half_life[0])
        self.updater.default_gain = self.memory_decay_rate

    def create_learner(self):
        # self.batch_size = int(self.batch_time/self.decoder.binlen)
        # self.learner = ObstacleLearner(self.batch_size)
        # self.learn_flag = True

        self.batch_size = int(self.batch_time / self.decoder.binlen)
        A, B, _ = self.decoder.ssm.get_ssm_matrices()
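        # LQR-style cost matrices for the OFC learner: Q penalizes state error
        # (position/velocity terms; the last entry is the constant offset state),
        # while R penalizes control effort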
        Q = np.mat(np.diag([10, 10, 10, 5, 5, 5, 0]))
        R = 10**6 * np.mat(np.eye(B.shape[1]))
        from tentaclebmitasks import OFCLearnerTentacle
        self.learner = OFCLearnerTentacle(self.batch_size, A, B, Q, R)
        self.learn_flag = True
Example #8
class CLDAManipulatedFB(BMIControlManipulatedFB):
    '''
    BMI task that periodically refits the decoder parameters based on intended
    movements toward the targets. Inherits directly from BMIControl. Can be made
    to automatically linearly decrease assist level over set time period, or
    to provide constant assistance by setting assist_level and assist_min equal.
    '''

    batch_time = traits.Float(80.0, desc='The length of the batch in seconds')
    half_life = traits.Tuple((120., 120.0), desc='Half life of the adaptation in seconds')
    decoder_sequence = traits.String('test', desc='signifier to group together sequences of decoders')
    #assist_min = traits.Float(0, desc="Assist level to end task at")
    #half_life_final = traits.Float(120.0, desc='Half life of the adaptation in seconds')
    half_life_decay_time = traits.Float(900.0, desc="Time to go from initial half life to final")


    def __init__(self, *args, **kwargs):
        super(CLDAManipulatedFB, self).__init__(*args, **kwargs)
        #self.assist_start = self.assist_level
        self.learn_flag = True

    def init(self):
        '''
        Secondary init function. Decoder has already been created by inclusion
        of the 'bmi' feature in the task. Create the 'learner' and 'updater'
        components of the CLDA algorithm
        '''
        # Add CLDA-specific data to save to HDF file 
        self.dtype.append(('half_life', 'f8', (1,)))

        super(CLDAManipulatedFB, self).init()

        self.batch_size = int(self.batch_time/self.decoder.binlen)
        self.create_learner()

        # Create the updater second b/c the update algorithm might need to force
        # a particular batch size for the learner
        self.create_updater()

        # Create the BMI system which combines the decoder, learner, and updater
        self.bmi_system = riglib.bmi.BMISystem(self.decoder, self.learner,
            self.updater)

        

    def create_learner(self):
        self.learner = clda.CursorGoalLearner(self.batch_size)

        # Start "learn flag" at True
        self.learn_flag = True
        homedir = os.getenv('HOME')
        f = open(os.path.join(homedir, 'learn_flag_file'), 'w')
        f.write('1')
        f.close()

    def create_updater(self):
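        # Multiprocessing queues presumably used to exchange batches and updated
        # parameters with the smooth-batch updater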
        clda_input_queue = mp.Queue()
        clda_output_queue = mp.Queue()
        half_life_start, half_life_end = self.half_life
        self.updater = clda.KFSmoothbatch(clda_input_queue, clda_output_queue, self.batch_time, half_life_start)

    def update_learn_flag(self):
        # Tell the adaptive BMI when to learn (skip parts of the task where we
        # assume the subject is not trying to move toward the target)
        prev_learn_flag = self.learn_flag

        # Open file to read learn flag
        try:
            homedir = os.getenv('HOME')
            f = open(os.path.join(homedir, 'learn_flag_file'))
            new_learn_flag = bool(int(f.readline().rstrip('\n')))
        except:
            new_learn_flag = True

        if new_learn_flag and not prev_learn_flag:
            print "CLDA enabled"
        elif prev_learn_flag and not new_learn_flag:
            try:
                print "CLDA disabled after %d successful trials" % self.calc_n_rewards()
            except:
                print "CLDA disabled"
        self.learn_flag = new_learn_flag

    def call_decoder(self, spike_counts):
        half_life_start, half_life_end = self.half_life
        current_half_life = self._linear_change(half_life_start, half_life_end, self.half_life_decay_time)
        self.task_data['half_life'] = current_half_life

        # Get the decoder output
        decoder_output, uf =  self.bmi_system(spike_counts, self.target_location,
            self.state, task_data=self.task_data, assist_level=self.current_assist_level,
            target_radius=self.target_radius, speed=self.assist_speed*self.decoder.binlen, 
            learn_flag=self.learn_flag, half_life=current_half_life)
        if uf:
            #send msg to hdf file to indicate decoder update
            self.hdf.sendMsg("update_bmi")
        return decoder_output #self.decoder['hand_px', 'hand_py', 'hand_pz']

    def _cycle(self):
        self.update_learn_flag()
        super(CLDAManipulatedFB, self)._cycle()

    def cleanup(self, database, saveid, **kwargs):
        super(CLDAManipulatedFB, self).cleanup(database, saveid, **kwargs)
        import tempfile, cPickle, traceback, datetime

        # Open a log file in case of error b/c errors not visible to console
        # at this point
        f = open(os.path.join(os.getenv('HOME'), 'Desktop/log'), 'a')
        f.write('Opening log file\n')
        
        # save out the parameter history and new decoder unless task was stopped
        # before 1st update
        try:
            f.write('# of parameter updates: %d\n' % len(self.bmi_system.param_hist))
            if len(self.bmi_system.param_hist) > 0:
                f.write('Starting to save parameter hist\n')
                tf = tempfile.NamedTemporaryFile()
                # Get the update history for C and Q matrices and save them
                #C, Q, m, sd, intended_kin, spike_counts = zip(*self.bmi_system.param_hist)
                #np.savez(tf, C=C, Q=Q, mean=m, std=sd, intended_kin=intended_kin, spike_counts=spike_counts)
                cPickle.dump(self.bmi_system.param_hist, tf)
                tf.flush()
                # Add the parameter history file to the database entry for this
                # session
                database.save_data(tf.name, "bmi_params", saveid)
                f.write('Finished saving parameter hist\n')

                # Save the final state of the decoder as a new decoder
                tf2 = tempfile.NamedTemporaryFile(delete=False) 
                cPickle.dump(self.decoder, tf2)
                tf2.flush()
                # create suffix for new decoder that has the sequence and the current day
                # and time. This suffix will be appended to the name of the
                # decoder that we started with and saved as a new decoder.
                now = datetime.datetime.now()
                decoder_name = self.decoder_sequence + '%02d%02d%02d%02d' % (now.month, now.day, now.hour, now.minute)
                database.save_bmi(decoder_name, saveid, tf2.name)
        except:
            traceback.print_exc(file=f)
        f.close()
Example #9
class TargetCaptureReplay(ScreenTargetCapture):
    '''
    Reads the frame-by-frame cursor and trial-by-trial target positions from a saved
    HDF file to display an exact copy of a previous experiment. 
    Doesn't really work; not recommended.
    '''

    hdf_filepath = traits.String("", desc="Filepath of hdf file to replay")

    exclude_parent_traits = list(set(ScreenTargetCapture.class_traits().keys()) - \
        set(['window_size', 'fullscreen']))

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
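        # Wall-clock reference time; used in _cycle() to map elapsed time to replay frames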
        self.t0 = time.perf_counter()
        with tables.open_file(self.hdf_filepath, 'r') as f:
            task = f.root.task.read()
            state = f.root.task_msgs.read()
            trial = f.root.trials.read()
            params = f.root.task.attrs._f_list("user")
            self.task_meta = {k : getattr(f.root.task.attrs, k) for k in params}
        self.replay_state = state
        self.replay_task = task
        self.replay_trial = trial
        for k, v in self.task_meta.items():
            if k in self.exclude_parent_traits:
                print("setting {} to {}".format(k, v))
                setattr(self, k, v)

        # Have to additionally reset the targets since they are created in super().__init__()
        for target in self.targets:
            target.sphere.radius = self.task_meta['target_radius']
            target.sphere.color = target_colors[self.task_meta['target_color']]

    def _test_start_trial(self, time_in_state):
        '''Wait for the state change in the HDF file in case there is autostart enabled'''
        trials = self.replay_state[self.replay_state['msg'] == b'target']
        upcoming_trials = [t['time'] for t in trials if self.replay_task[t['time']]['trial'] > self.calc_trial_num()]
        return (np.array(upcoming_trials) <= self.cycle_count).any()

    def _parse_next_trial(self):
        '''Ignore the generator'''
        self.targs = []
        self.gen_indices = []
        trial_num = self.calc_trial_num()
        for trial in self.replay_trial:
            if trial['trial'] == trial_num:
                self.targs.append(trial['target'])
                self.gen_indices.append(trial['index'])

    def _cycle(self):
        '''Have to fudge the cycle_count a bit in case the fps isn't exactly the same'''
        super()._cycle()
        t1 = time.perf_counter() - self.t0
        self.cycle_count = int(t1*self.fps)

    def move_effector(self):
        current_pt = self.replay_task['cursor'][self.cycle_count]
        self.plant.set_endpoint_pos(current_pt)

    def _test_stop(self, ts):
        return super()._test_stop(ts) or self.cycle_count == len(self.replay_task)