class ManualControl2(ManualControl):
    '''ManualControl variant that adds a second terminus target.

    After the hold at the first terminus, the subject must move the cursor to
    ``terminus2_target`` and hold there before the reward. The second target is
    shown when the first terminus hold begins, and its position is read from
    ``self.next_trial.T[2]`` at the start of a fresh (``tries == 0``) trial.
    '''
    status = dict(wait=dict(start_trial="origin", stop=None),
                  origin=dict(enter_target="origin_hold", stop=None),
                  origin_hold=dict(leave_early="hold_penalty",
                                   hold="terminus"),
                  terminus=dict(timeout="timeout_penalty",
                                enter_target="terminus_hold",
                                stop=None),
                  timeout_penalty=dict(penalty_end="pre_target_change"),
                  terminus_hold=dict(leave_early="hold_penalty",
                                     hold="terminus2"),
                  terminus2=dict(timeout="timeout_penalty",
                                 enter_target="terminus2_hold",
                                 stop=None),
                  terminus2_hold=dict(leave_early="hold_penalty",
                                      hold="reward"),
                  reward=dict(reward_end="target_change"),
                  hold_penalty=dict(penalty_end="pre_target_change"),
                  pre_target_change=dict(tried_enough='target_change',
                                         not_tried_enough='wait'),
                  target_change=dict(target_change_end='wait'))

    scale_factor = 2
    cursor_radius = .4

    def __init__(self, *args, **kwargs):
        super(ManualControl2, self).__init__(*args, **kwargs)
        # Add the 2nd terminus target (same size and red color as the first)
        self.terminus2_target = Sphere(radius=self.terminus_size,
                                       color=(1, 0, 0, .5))
        self.add_model(self.terminus2_target)

    def _terminus2_distance(self):
        '''Return the Euclidean distance between the cursor and terminus2 centers.'''
        c = self.cursor.xfm.move
        t = self.terminus2_target.xfm.move
        return np.sqrt((c[0] - t[0])**2 + (c[1] - t[1])**2 + (c[2] - t[2])**2)

    def _terminus2_capture_radius(self):
        '''Largest center distance at which the cursor lies fully inside terminus2.'''
        # Use the 2nd target's own radius (not terminus_target's) so the two
        # targets can never silently disagree.
        return self.terminus2_target.radius - self.cursor.radius

    def _start_wait(self):
        # Reset the 2nd target to red and hide it; the parent resets the
        # origin and first terminus.
        self.terminus2_target.color = (1, 0, 0, .5)
        self.terminus2_target.detach()
        super(ManualControl2, self)._start_wait()

    def _test_enter_target(self, ts):
        '''True when the cursor is fully inside the active target.'''
        if self.state == "terminus2":
            return self._terminus2_distance() <= self._terminus2_capture_radius()
        return super(ManualControl2, self)._test_enter_target(ts)

    def _start_origin(self):
        # Only move the 2nd target on a fresh trial; retries reuse its position.
        if self.tries == 0:
            t2 = self.next_trial.T[2]
            self.terminus2_target.translate(*t2, reset=True)
        super(ManualControl2, self)._start_origin()

    def _start_terminus_hold(self):
        # Reveal the 2nd target (red) as soon as the first terminus hold begins.
        self.terminus2_target.color = (1, 0, 0, 0.5)
        self.terminus2_target.attach()
        self.requeue()

    def _start_timeout_penalty(self):
        # Hide the 2nd target; the parent counts the try and hides the first.
        self.terminus2_target.detach()
        super(ManualControl2, self)._start_timeout_penalty()

    def _start_terminus2(self):
        # The first terminus is done; hide it while reaching for the 2nd.
        self.terminus_target.detach()
        self.requeue()

    def _test_leave_early(self, ts):
        '''True when the cursor has strayed beyond exit_radius of the held target.'''
        if self.state == "terminus2_hold":
            rad = self._terminus2_capture_radius()
            return self._terminus2_distance() > rad * self.exit_radius
        return super(ManualControl2, self)._test_leave_early(ts)

    def _while_terminus2(self):
        self.update_cursor()

    def _while_terminus2_hold(self):
        self.update_cursor()

    def _end_terminus2_hold(self):
        # Successful hold: turn the 2nd target green.
        self.terminus2_target.color = (0, 1, 0, 0.5)

    def _start_hold_penalty(self):
        self.terminus2_target.detach()
        super(ManualControl2, self)._start_hold_penalty()

    def update_target_location(self):
        '''Record the active target position (cm -> mm) for assist/decoder use.'''
        # TODO - decide what to do with y location, target_xz ignores it!
        if self.state == 'terminus2' or self.state == 'terminus2_hold':
            self.location = 10 * self.terminus2_target.xfm.move
            self.target_xz = np.array([self.location[0], self.location[2]])
        # The parent leaves self.location untouched in terminus2 states and
        # writes it to task_data['target'].
        super(ManualControl2, self).update_target_location()
class ManualControl(TargetCapture):
    '''Origin-to-terminus manual reaching task built on TargetCapture.

    After the origin hold, a terminus target is shown. The subject has
    ``timeout_time`` seconds to bring the cursor fully inside it, then must
    hold there for ``terminus_hold_time`` seconds to reach the reward state.
    Timeouts and early exits count as tries (see TargetCapture).
    '''
    status = dict(wait=dict(start_trial="origin", stop=None),
                  origin=dict(enter_target="origin_hold", stop=None),
                  origin_hold=dict(leave_early="hold_penalty",
                                   hold="terminus"),
                  terminus=dict(timeout="timeout_penalty",
                                enter_target="terminus_hold",
                                stop=None),
                  timeout_penalty=dict(penalty_end="pre_target_change"),
                  terminus_hold=dict(leave_early="hold_penalty",
                                     hold="reward"),
                  reward=dict(reward_end="target_change"),
                  hold_penalty=dict(penalty_end="pre_target_change"),
                  pre_target_change=dict(tried_enough='target_change',
                                         not_tried_enough='wait'),
                  target_change=dict(target_change_end='wait'))

    # User-settable traits
    terminus_size = traits.Float(1, desc="Radius of terminus targets")
    terminus_hold_time = traits.Float(
        2, desc="Length of hold required at terminus")
    timeout_time = traits.Float(
        10, desc="Time allowed to go between origin and terminus")
    timeout_penalty_time = traits.Float(
        3, desc="Length of penalty time for timeout error")

    def __init__(self, *args, **kwargs):
        # Declare the per-cycle target/cursor fields saved with the task data
        # before the parent constructor runs.
        self.dtype = [('target', 'f', (3, )), ('cursor', 'f', (3, ))]
        super(ManualControl, self).__init__(*args, **kwargs)
        self.terminus_target = Sphere(radius=self.terminus_size,
                                      color=(1, 0, 0, .5))
        self.add_model(self.terminus_target)
        # Last known task target (scaled x10) and its x/z projection;
        # refreshed by update_target_location().
        self.location = np.array([0, 0, 0])
        self.target_xz = np.array([0, 0])

    def _cursor_distance_to(self, target):
        '''Euclidean distance between the cursor center and ``target``'s center.'''
        here = self.cursor.xfm.move
        there = target.xfm.move
        dx = here[0] - there[0]
        dy = here[1] - there[1]
        dz = here[2] - there[2]
        return np.sqrt(dx**2 + dy**2 + dz**2)

    def _start_wait(self):
        super(ManualControl, self)._start_wait()
        # Reset the terminus to red and hide it until the next origin hold
        self.terminus_target.color = (1, 0, 0, .5)
        self.show_terminus(False)

    def _start_origin(self):
        # Only a fresh trial (no tries yet) moves the terminus to a new spot
        if self.tries == 0:
            terminus_pos = self.next_trial.T[1]
            self.terminus_target.translate(*terminus_pos, reset=True)
        super(ManualControl, self)._start_origin()

    def _start_origin_hold(self):
        # Reveal the terminus as soon as the origin hold begins
        self.show_terminus(True)

    def show_terminus(self, show=False):
        '''Attach (show=True) or detach (show=False) the terminus, then requeue.'''
        if not show:
            self.terminus_target.detach()
        else:
            self.terminus_target.attach()
        self.requeue()

    def _start_terminus(self):
        self.show_origin(False)

    def _end_terminus_hold(self):
        # Successful hold: turn the terminus green
        self.terminus_target.color = (0, 1, 0, 0.5)

    def _start_hold_penalty(self):
        # Parent counts the try and hides the origin; then hide the terminus
        super(ManualControl, self)._start_hold_penalty()
        self.show_terminus(False)

    def _start_timeout_penalty(self):
        # A timeout counts as a try; hide the terminus
        self.tries += 1
        self.show_terminus(False)

    def _start_reward(self):
        '''No-op: overrides any inherited _start_reward.'''
        pass

    def _test_enter_target(self, ts):
        '''True when the cursor center is close enough to lie fully inside the target.

        Returns None in states other than 'origin' and 'terminus'.
        '''
        if self.state == "origin":
            goal = self.origin_target
        elif self.state == "terminus":
            goal = self.terminus_target
        else:
            return None
        return self._cursor_distance_to(goal) <= goal.radius - self.cursor.radius

    def _test_leave_early(self, ts):
        '''True when the cursor drifts past exit_radius times the capture radius.

        Returns None in states other than 'origin_hold' and 'terminus_hold'.
        '''
        if self.state == "origin_hold":
            goal = self.origin_target
        elif self.state == "terminus_hold":
            goal = self.terminus_target
        else:
            return None
        allowed = (goal.radius - self.cursor.radius) * self.exit_radius
        return self._cursor_distance_to(goal) > allowed

    def _test_hold(self, ts):
        required = (self.origin_hold_time if self.state == "origin_hold"
                    else self.terminus_hold_time)
        return ts >= required

    def _test_timeout(self, ts):
        return ts > self.timeout_time

    def _test_penalty_end(self, ts):
        # Each penalty state has its own duration; hold_penalty is the default
        if self.state == "timeout_penalty":
            duration = self.timeout_penalty_time
        elif self.state == "fixation_penalty":
            duration = self.fixation_penalty_time
        else:
            duration = self.hold_penalty_time
        return ts > duration

    def _while_terminus(self):
        self.update_cursor()

    def _while_terminus_hold(self):
        self.update_cursor()

    def _while_timeout_penalty(self):
        self.update_cursor()

    def update_target_location(self):
        '''Record the active target position (cm -> mm) for assist/decoder use.

        Outside the origin/terminus states the previous location is kept.
        '''
        # TODO - decide what to do with y location, target_xz ignores it!
        if self.state in ('origin', 'origin_hold'):
            active = self.origin_target
        elif self.state in ('terminus', 'terminus_hold'):
            active = self.terminus_target
        else:
            active = None
        if active is not None:
            self.location = 10 * active.xfm.move
            self.target_xz = np.array([self.location[0], self.location[2]])
        self.task_data['target'] = self.location[:3]

    def update_cursor(self):
        '''Refresh the target, move/draw the cursor, and log its position.'''
        self.update_target_location()
        super(ManualControl, self).update_cursor()
        self.task_data['cursor'] = self.cursor.xfm.move.copy()
Beispiel #3
0
class TentacleMultiConfigObstacleAvoidance(BMIJointPerturb):
    '''Obstacle-avoidance task with several starting elbow configurations.

    While ``target_index == 1`` an obstacle sphere is shown at
    ``self.targs[0] / 2``. If the plant's kinematic chain comes within
    ``obstacle_radius`` (plus the first link radius) of it, the task enters
    ``obstacle_penalty``.
    '''
    status = dict(
        wait=dict(start_trial="premove", stop=None),
        premove=dict(premove_complete="target"),
        target=dict(enter_target="hold",
                    timeout="timeout_penalty",
                    stop=None,
                    hit_obstacle="obstacle_penalty"),
        hold=dict(leave_early="hold_penalty", hold_complete="targ_transition"),
        targ_transition=dict(trial_complete="reward",
                             trial_abort="wait",
                             trial_incomplete="target",
                             trial_restart="premove"),
        timeout_penalty=dict(timeout_penalty_end="targ_transition"),
        hold_penalty=dict(hold_penalty_end="targ_transition"),
        obstacle_penalty=dict(obstacle_penalty_end="targ_transition"),
        reward=dict(reward_end="wait"))

    obstacle_radius = traits.Float(2.0, desc='Radius of cylindrical obstacle')
    obstacle_penalty = traits.Float(
        0.0, desc='Penalty time if the chain hits the obstacle(s)')

    def __init__(self, *args, **kwargs):
        super(TentacleMultiConfigObstacleAvoidance,
              self).__init__(*args, **kwargs)

        ## Create an obstacle object, hidden by default.
        # NOTE(review): drawn 0.6 larger than obstacle_radius; the collision
        # test below adds link_radii[0] instead - confirm the intent.
        self.obstacle = Sphere(radius=self.obstacle_radius + 0.6,
                               color=(0, 0, 1, .5))
        self.obstacle_on = False
        self.obstacle_pos = np.ones(3) * np.nan
        self.hit_obstacle = False

        self.add_model(self.obstacle)

    def init(self):
        # Save obstacle visibility and position alongside the task data
        self.add_dtype('obstacle_on', 'f8', (1, ))
        self.add_dtype('obstacle_pos', 'f8', (3, ))
        super(TentacleMultiConfigObstacleAvoidance, self).init()

    def _cycle(self):
        self.task_data['obstacle_on'] = self.obstacle_on
        self.task_data['obstacle_pos'] = self.obstacle_pos
        super(TentacleMultiConfigObstacleAvoidance, self)._cycle()

    def _start_target(self):
        super(TentacleMultiConfigObstacleAvoidance, self)._start_target()
        if self.target_index == 1:
            # Obstacle halfway between the coordinate origin and targs[0]
            self.obstacle_pos = (self.targs[0] / 2)
            self.obstacle.translate(*self.obstacle_pos, reset=True)
            self.obstacle.attach()
            self.obstacle_on = True

    def _test_obstacle_penalty_end(self, ts):
        return ts > self.obstacle_penalty

    def _start_obstacle_penalty(self):
        # Hide targets, count the try, and reset the target index
        for target in self.targets:
            target.hide()

        self.tries += 1
        self.target_index = -1

    def _end_target(self):
        self.obstacle.detach()
        self.obstacle_on = False

    def _test_hit_obstacle(self, ts):
        '''Return True if the kinematic chain intersects the obstacle.

        Only checked while target_index == 1, the only time the obstacle is
        shown. Always returns a bool (previously None on a miss).
        '''
        if self.target_index != 1:
            return False
        joint_angles = self.plant.get_intrinsic_coordinates()
        distances_to_links = self.plant.kin_chain.detect_collision(
            joint_angles, self.obstacle_pos)

        # NOTE(review): only the first link's radius is used for every link.
        hit = np.min(distances_to_links) < (self.obstacle_radius +
                                            self.plant.link_radii[0])
        if hit:
            self.hit_obstacle = True
        return bool(hit)

    @staticmethod
    def tentacle_multi_start_config(nblocks=100,
                                    ntargets=4,
                                    distance=8,
                                    startangle=45):
        '''Generate a shuffled sequence of (targets, elbow angle) trials.

        Parameters
        ----------
        nblocks : int
            Number of blocks. Each block contains every (target, elbow angle)
            pair exactly once, in random order.
        ntargets : int
            Number of targets evenly spaced on a circle in the x-z plane.
        distance : float
            Radius of the target circle.
        startangle : float
            Angle of the first target, in degrees.

        Returns
        -------
        list of tuple
            ``(targs, elbow_angle)`` items. ``targs`` is a 2x3 array stacking
            the target position and the origin (zeros); ``elbow_angle`` is in
            radians.
        '''
        import random

        elbow_angles = np.array([
            135, 180, 225
        ]) * np.pi / 180  # TODO make this a function argument!
        # Honor the caller's start angle (previously always overwritten by 45)
        startangle = startangle * np.pi / 180
        n_configs_per_target = len(elbow_angles)
        target_angles = np.arange(startangle, startangle + (2 * np.pi),
                                  2 * np.pi / ntargets)
        targets = distance * np.vstack(
            [np.cos(target_angles), 0 * target_angles,
             np.sin(target_angles)])

        # Item k of a block -> (target index, elbow configuration index); the
        # same table serves every block, so build it once.
        n_items = n_configs_per_target * ntargets
        target_inds = np.tile(np.arange(ntargets),
                              (n_configs_per_target, 1)).T.ravel()
        config_inds = np.tile(np.arange(n_configs_per_target), ntargets)

        seq = []
        for _ in range(nblocks):
            inds = np.arange(n_items)
            random.shuffle(inds)
            for k in inds:
                seq_item = (np.vstack([targets[:, target_inds[k]],
                                       np.zeros(3)]),
                            elbow_angles[config_inds[k]])
                seq.append(seq_item)

        return seq
class TargetCapture(Sequence, FixationTraining):
    '''Origin-capture task driven by a motion-tracked hand marker.

    Each trial shows an origin target. The subject must bring the cursor fully
    inside it and stay within ``exit_radius`` times the capture radius for
    ``origin_hold_time`` seconds to be rewarded. Leaving early is a hold
    penalty and counts as a try. The origin is only repositioned from
    ``next_trial`` when ``tries`` is 0, so a penalized trial repeats the same
    position until 3 tries have accumulated and ``target_change`` resets the
    count.
    '''
    status = dict(wait=dict(start_trial="origin", stop=None),
                  origin=dict(enter_target="origin_hold", stop=None),
                  origin_hold=dict(leave_early="hold_penalty", hold="reward"),
                  reward=dict(reward_end="target_change"),
                  hold_penalty=dict(penalty_end="pre_target_change"),
                  pre_target_change=dict(tried_enough='target_change',
                                         not_tried_enough='wait'),
                  target_change=dict(target_change_end='wait'))

    # User-settable traits
    # TODO: add an error if the target is smaller than the cursor
    origin_size = traits.Float(1, desc="Radius of origin targets")
    origin_hold_time = traits.Float(2,
                                    desc="Length of hold required at origin")
    hold_penalty_time = traits.Float(
        3, desc="Length of penalty time for target hold error")
    exit_radius = 1.5  #Multiplier for the actual radius which is considered 'exiting' the target

    # Consecutive cycles without usable marker data
    no_data_count = 0
    # Whether _update() has hidden the cursor because of missing data
    _cursor_hidden = False
    tries = 0
    scale_factor = 3.5  #scale factor for converting hand movement to screen movement (1cm hand movement = 3.5cm cursor movement)
    cursor_radius = .5

    # States whose entry marks the end of an initiated trial (for report stats)
    _trial_end_states = ('reward', 'timeout_penalty', 'hold_penalty')

    def __init__(self, *args, **kwargs):
        super(TargetCapture, self).__init__(*args, **kwargs)
        self.origin_target = Sphere(radius=self.origin_size,
                                    color=(1, 0, 0, .5))
        self.add_model(self.origin_target)
        self.cursor = Sphere(radius=self.cursor_radius, color=(.5, 0, .5, 1))
        self.add_model(self.cursor)

    def _start_wait(self):
        super(TargetCapture, self)._start_wait()
        # Reset the origin to red and hide it from the previous trial
        self.origin_target.color = (1, 0, 0, .5)
        self.show_origin(False)

    def show_origin(self, show=False):
        '''Attach (show=True) or detach (show=False) the origin, then requeue.'''
        if show:
            self.origin_target.attach()
        else:
            self.origin_target.detach()
        self.requeue()

    def _start_origin(self):
        # Only a fresh trial moves the origin; retries keep its position
        if self.tries == 0:
            o = self.next_trial.T[0]
            self.origin_target.translate(*o, reset=True)
        self.show_origin(True)

    def _end_origin_hold(self):
        # Successful hold: turn the origin green
        self.origin_target.color = (0, 1, 0, 0.5)

    def _start_hold_penalty(self):
        self.tries += 1
        self.show_origin(False)

    def _start_target_change(self):
        self.tries = 0

    def _test_target_change_end(self, ts):
        return True

    def _test_enter_target(self, ts):
        '''True when the cursor center is close enough to lie fully inside the origin.'''
        c = self.cursor.xfm.move
        t = self.origin_target.xfm.move
        d = np.sqrt((c[0] - t[0])**2 + (c[1] - t[1])**2 + (c[2] - t[2])**2)
        return d <= self.origin_target.radius - self.cursor.radius

    def _test_leave_early(self, ts):
        '''True when the cursor drifts past exit_radius times the capture radius.'''
        c = self.cursor.xfm.move
        t = self.origin_target.xfm.move
        d = np.sqrt((c[0] - t[0])**2 + (c[1] - t[1])**2 + (c[2] - t[2])**2)
        rad = self.origin_target.radius - self.cursor.radius
        return d > rad * self.exit_radius

    def _test_hold(self, ts):
        return ts >= self.origin_hold_time

    def _test_penalty_end(self, ts):
        if self.state == "fixation_penalty":
            return ts > self.fixation_penalty_time
        else:
            return ts > self.hold_penalty_time

    def _test_tried_enough(self, ts):
        # >= (not ==) so the task can never get stuck retrying one target
        # if tries ever overshoots 3.
        return self.tries >= 3

    def _test_not_tried_enough(self, ts):
        return self.tries < 3

    def _update(self, pt):
        '''Move the cursor to ``pt``, or count a missed frame if ``pt`` is empty.

        Once at least 3 frames have been missed in a row, a further empty
        update hides the cursor. It is shown again as soon as data returns.
        '''
        if len(pt) > 0:
            # Fresh data: the dropout streak is over
            self.no_data_count = 0
            self.cursor.translate(*pt[:3], reset=True)
            if self._cursor_hidden:
                self.cursor.attach()
                self._cursor_hidden = False
                self.requeue()
        elif self.no_data_count > 2:
            self.no_data_count += 1
            if not self._cursor_hidden:
                self.cursor.detach()
                self._cursor_hidden = True
                self.requeue()
        else:
            self.no_data_count += 1

    def update_cursor(self):
        '''Move the cursor to the averaged hand-marker position and redraw.'''
        # Get all samples of marker #14 since the last poll and average them.
        # NOTE(review): assumes motiondata.get() returns (n_samples, n_markers, 4)
        # with a per-sample condition code in the last column - confirm.
        pt = self.motiondata.get()
        if len(pt) > 0:
            pt = pt[:, 14, :]
            # NOTE!!! The marker on the hand was changed from #0 to #14 on
            # 5/19/13 after LED #0 broke. All data files saved before this date
            # have LED #0 controlling the cursor.
            conds = pt[:, 3]
            # Keep samples with a non-negative condition code other than 4
            inds = np.nonzero((conds >= 0) & (conds != 4))
            if len(inds[0]) > 0:
                pt = pt[inds[0], :3]
                # Convert units from mm to cm and scale to the desired amount
                pt = pt.mean(0) * .1 * self.scale_factor
                # Ignore the y direction
                pt[1] = 0
                self._update(pt)
            else:
                self.no_data_count += 1
        else:
            self.no_data_count += 1
        self.draw_world()

    def calc_trial_num(self):
        '''Calculates the current trial count'''
        return sum(1 for state in self.state_log
                   if state[0] in self._trial_end_states)

    def calc_rewards_per_min(self, window):
        '''Calculates the Rewards/min for the most recent window of specified number of seconds in the past'''
        now = self.get_time()
        elapsed = now - self.task_start_time
        rewardtimes = np.array(
            [state[1] for state in self.state_log if state[0] == 'reward'])
        # Early in the session, normalize by the time actually elapsed
        divideby = min(elapsed, window) / 60.0
        if divideby <= 0:
            # No time has elapsed yet; avoid dividing by zero
            return 0.0
        return np.sum(rewardtimes >= (now - window)) / divideby

    def calc_success_rate(self, window):
        '''Calculates the rewarded trials/initiated trials for the most recent window of specified length in sec'''
        cutoff = self.get_time() - window
        trialtimes = np.array([
            state[1] for state in self.state_log
            if state[0] in self._trial_end_states
        ])
        rewardtimes = np.array(
            [state[1] for state in self.state_log if state[0] == 'reward'])
        n_trials = np.sum(trialtimes >= cutoff)
        if n_trials == 0:
            # No trials ended inside the window (or at all): report 0, not nan
            return 0.0
        return float(np.sum(rewardtimes >= cutoff)) / n_trials

    def update_report_stats(self):
        '''Function to update any relevant report stats for the task. Values are saved in self.reportstats,
        an ordered dictionary. Keys are strings that will be displayed as the label for the stat in the web interface,
        values can be numbers or strings. Called every time task state changes.'''
        super(TargetCapture, self).update_report_stats()
        self.reportstats['Trial #'] = self.calc_trial_num()
        self.reportstats['Reward/min'] = np.round(
            self.calc_rewards_per_min(120), decimals=2)
        self.reportstats['Success rate'] = str(
            np.round(self.calc_success_rate(120) * 100.0, decimals=2)) + '%'

    def _while_wait(self):
        self.update_cursor()

    def _while_origin(self):
        self.update_cursor()

    def _while_origin_hold(self):
        self.update_cursor()

    def _while_fixation_penalty(self):
        self.update_cursor()

    def _while_hold_penalty(self):
        self.update_cursor()

    def _while_reward(self):
        self.update_cursor()

    def _while_pre_target_change(self):
        self.update_cursor()

    def _while_target_change(self):
        self.update_cursor()
Beispiel #5
0
class TentacleObstacleAvoidance(BMIControlMultiTentacleAttractor):
    '''Tentacle BMI task with an obstacle shown during the second target.

    While ``target_index == 1`` an obstacle sphere is placed using a per-angle
    lookup file. If that lookup fails, it falls back to half the target
    location. If the kinematic chain touches the obstacle, the task enters
    ``obstacle_penalty``.
    '''
    status = dict(
        wait=dict(start_trial="target", stop=None),
        target=dict(enter_target="hold",
                    timeout="timeout_penalty",
                    stop=None,
                    hit_obstacle="obstacle_penalty"),
        hold=dict(leave_early="hold_penalty", hold_complete="targ_transition"),
        targ_transition=dict(trial_complete="reward",
                             trial_abort="wait",
                             trial_incomplete="target"),
        timeout_penalty=dict(timeout_penalty_end="targ_transition"),
        hold_penalty=dict(hold_penalty_end="targ_transition"),
        obstacle_penalty=dict(obstacle_penalty_end="targ_transition"),
        reward=dict(reward_end="wait"))
    obstacle_radius = traits.Float(2.0, desc='Radius of cylindrical obstacle')
    obstacle_penalty = traits.Float(
        0.0, desc='Penalty time if the chain hits the obstacle(s)')

    # Pickled lookup indexed by the rounded target angle in degrees
    # (presumably a dict angle -> obstacle position; verify against the file).
    _obstacle_pos_path = '/storage/task_data/TentacleObstacleAvoidance/center_out_obstacle_pos.pkl'

    def __init__(self, *args, **kwargs):
        super(TentacleObstacleAvoidance, self).__init__(*args, **kwargs)

        ## Create an obstacle object, hidden by default.
        # Alternative previously tried here:
        # Cylinder(radius=self.obstacle_radius, height=1, color=(0,0,1,1))
        self.obstacle = Sphere(
            radius=self.obstacle_radius + 0.6, color=(0, 0, 1, .5))
        self.obstacle_on = False
        self.obstacle_pos = np.ones(3) * np.nan
        self.hit_obstacle = False

        self.add_model(self.obstacle)

    def init(self):
        # Save obstacle visibility and position alongside the task data
        self.add_dtype('obstacle_on', 'f8', (1, ))
        self.add_dtype('obstacle_pos', 'f8', (3, ))
        super(TentacleObstacleAvoidance, self).init()

    def _cycle(self):
        self.task_data['obstacle_on'] = self.obstacle_on
        self.task_data['obstacle_pos'] = self.obstacle_pos
        super(TentacleObstacleAvoidance, self)._cycle()

    def _lookup_obstacle_pos(self, target_angle):
        '''Return the stored obstacle position for ``target_angle`` (degrees).

        Best-effort: on any failure (missing or unreadable file, missing
        angle), fall back to half of ``self.target_location``.
        '''
        try:
            # NOTE(review): pickle.load can execute arbitrary code; this path
            # must only ever point at trusted local task data.
            with open(self._obstacle_pos_path, 'rb') as f:
                obstacle_data = pickle.load(f)
            return obstacle_data[target_angle]
        except Exception:
            # Deliberately broad (but no longer a bare except, so Ctrl-C and
            # SystemExit still propagate): any lookup problem uses the fallback.
            return (self.target_location / 2)

    def _start_target(self):
        super(TentacleObstacleAvoidance, self)._start_target()
        if self.target_index == 1:
            # Target angle in the x-z plane, rounded to whole degrees
            target_angle = np.round(
                np.rad2deg(
                    np.arctan2(self.target_location[-1],
                               self.target_location[0])))
            self.obstacle_pos = self._lookup_obstacle_pos(target_angle)
            self.obstacle.translate(*self.obstacle_pos, reset=True)
            self.obstacle.attach()
            self.obstacle_on = True

    def _test_obstacle_penalty_end(self, ts):
        return ts > self.obstacle_penalty

    def _start_obstacle_penalty(self):
        # Hide targets, count the try, and reset the target index
        for target in self.targets:
            target.hide()

        self.tries += 1
        self.target_index = -1

    def _end_target(self):
        self.obstacle.detach()
        self.obstacle_on = False

    def _test_hit_obstacle(self, ts):
        '''Return True if the kinematic chain intersects the obstacle.

        Only checked while target_index == 1, the only time the obstacle is
        shown. Always returns a bool (previously None on a miss).
        '''
        if self.target_index != 1:
            return False
        joint_angles = self.plant.get_intrinsic_coordinates()
        distances_to_links = self.plant.kin_chain.detect_collision(
            joint_angles, self.obstacle_pos)

        # NOTE(review): only the first link's radius is used for every link.
        hit = np.min(distances_to_links) < (self.obstacle_radius +
                                            self.plant.link_radii[0])
        if hit:
            self.hit_obstacle = True
        return bool(hit)