# Example #1
# 0
class CircularTarget(object):
    '''A spherical graphics target.

    ``position`` is where the sphere is placed when its graphics are built;
    ``int_position`` is the position most recently requested through
    ``move_to_position``. Subclasses implement ``drive_to_new_pos``.
    '''

    def __init__(self,
                 target_radius=2,
                 target_color=(1, 0, 0, .5),
                 starting_pos=None):
        '''
        Parameters
        ----------
        target_radius : number
            Radius of the target sphere.
        target_color : sequence of 4 numbers
            RGBA color; a tuple copy is kept as ``default_target_color``.
        starting_pos : array-like of 3 numbers, optional
            Initial position; defaults to the origin. The value is copied into
            a new float array, so targets never share a position array with
            each other or with the caller. (The previous default,
            ``np.zeros(3)``, was evaluated once at definition time and shared
            by every default-constructed target.)
        '''
        if starting_pos is None:
            starting_pos = np.zeros(3)
        self.target_radius = target_radius
        self.target_color = target_color
        self.default_target_color = tuple(target_color)
        self.position = np.array(starting_pos, dtype=float)
        # Independent copy: in-place edits of one position must not move the other.
        self.int_position = self.position.copy()
        self._pickle_init()

    def _pickle_init(self):
        '''Build the sphere graphics model and translate it to ``position``.

        NOTE(review): the name suggests this is also re-run after unpickling;
        confirm against the pickling code.
        '''
        self.sphere = Sphere(radius=self.target_radius,
                             color=self.target_color)
        self.graphics_models = [self.sphere]
        self.sphere.translate(*self.position)

    def move_to_position(self, new_pos):
        '''Record ``new_pos`` as the requested position and start moving there.'''
        self.int_position = new_pos
        self.drive_to_new_pos()

    def drive_to_new_pos(self):
        '''Move the target toward ``int_position``; must be overridden by subclasses.'''
        raise NotImplementedError
# Example #2
# 0
class TestGraphics(Sequence, Window):
    '''Graphics smoke test: two translucent red target spheres, the first of
    which drifts along +z on every cycle of the single "wait" state.'''
    status = dict(wait=dict(stop=None), )

    # The task state machine starts (and stays) in "wait".
    state = "wait"
    target_radius = 2.

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-cycle records saved to file: target and cursor locations.
        self.dtype = [
            ('target', 'f', (3, )),
            ('cursor', 'f', (3, )),
            ('target_index', 'i', (1, )),
        ]
        # Build and register both target spheres, target1 first.
        for attr_name in ('target1', 'target2'):
            sphere = Sphere(radius=self.target_radius, color=(1, 0, 0, .5))
            setattr(self, attr_name, sphere)
            self.add_model(sphere)

        # Current location of target1 (float, so it can be nudged in place).
        self.target_location = np.zeros(3)

    ##### HELPER AND UPDATE FUNCTIONS ####

    def _get_renderer(self):
        '''Render through a stereo mirror display configured from the window settings.'''
        return stereo.MirrorDisplay(self.window_size, self.fov, 1, 1024,
                                    self.screen_dist, self.iod)

    def _cycle(self):
        super()._cycle()

    #### STATE FUNCTIONS ####
    def _while_wait(self):
        '''Advance target1 by 0.01 along z, redraw, and print its new position.'''
        self.target_location += np.array([0, 0, 0.01])
        x, y, z = self.target_location
        self.target1.translate(x, y, z, reset=True)
        self.requeue()
        self.draw_world()
        print('current target 1 position ' +
              np.array2string(self.target_location))
class BMIControlManipulatedFB(bmimultitasks.BMIControlMulti):
    '''
    BMI control task in which the task's own cursor is hidden and a separate
    white ``visible_cursor`` sphere is drawn instead. The task cursor is moved
    only every ``task_update_num`` cycles, and ``visible_cursor`` is snapped to
    it only every ``feedback_num`` cycles, both derived from a 60 Hz loop.
    '''

    feedback_rate = traits.Float(60, desc="Rate in hz that cursor position is updated on screen (best if factor of 60)")
    task_update_rate = traits.Float(60, desc="Rate in hz that decoded cursor position is updated within task (best if factor of 60)")
    ordered_traits = ['session_length', 'assist_level', 'assist_time', 'feedback_rate', 'task_update_rate']

    def __init__(self, *args, **kwargs):
        super(BMIControlManipulatedFB, self).__init__(*args, **kwargs)
        self.visible_cursor = Sphere(radius=self.cursor_radius, color=(1,1,1,1))
        self.add_model(self.visible_cursor)
        self.cursor_visible = True

    @staticmethod
    def _cycles_per_update(rate_hz, base_rate_hz=60.0):
        '''Number of base-rate loop cycles between updates at ``rate_hz``.

        Clamped to at least 1: rates above the base rate would otherwise
        truncate to 0 and make ``loopcount % 0`` raise ZeroDivisionError.
        A rate of 0 still raises ZeroDivisionError here.
        '''
        return max(1, int(base_rate_hz / rate_hz))

    def init(self):
        '''Register the visible-cursor HDF field and derive the update intervals.'''
        self.dtype.append(('visible_cursor','f8',3))
        super(BMIControlManipulatedFB, self).init()

        self.feedback_num = self._cycles_per_update(self.feedback_rate)
        self.task_update_num = self._cycles_per_update(self.task_update_rate)
        self.loopcount = 0

    def update_cursor(self):
        ''' Update the cursor's location and visibility status.'''
        pt = self.get_cursor_location()
        # The task cursor is always hidden; only visible_cursor is shown.
        prev = self.cursor_visible
        self.cursor_visible = False
        if prev != self.cursor_visible:
            self.show_object(self.cursor, show=False)
            self.requeue()
        # Update the "real" cursor location only at the task update rate.
        if self.loopcount % self.task_update_num == 0:
            if pt is not None:
                self.move_cursor(pt)
        # Snap the visible cursor to the real cursor only at the feedback rate.
        if self.loopcount % self.feedback_num == 0:
            loc = self.cursor.xfm.move
            self.visible_cursor.translate(*loc, reset=True)

    def _cycle(self):
        ''' Overwriting parent methods since this one works differently'''
        self.update_assist_level()
        self.task_data['assist_level'] = self.current_assist_level
        self.update_cursor()
        self.task_data['cursor'] = self.cursor.xfm.move.copy()
        self.task_data['target'] = self.target_location.copy()
        self.task_data['target_index'] = self.target_index
        self.task_data['visible_cursor'] = self.visible_cursor.xfm.move.copy()
        self.loopcount += 1
        # Write to screen.
        self.draw_world()
# Example #4
# 0
class VirtualKinChainWithToolLink(RobotArmGen2D):
    '''Kinematic-chain arm whose last two joint positions are marked by small
    spheres: red at the last joint ("tool tip"), blue at the one before it
    ("tool base").'''

    def _pickle_init(self):
        super(VirtualKinChainWithToolLink, self)._pickle_init()

        # Markers are half the radius of the last link.
        marker_radius = self.link_radii[-1] / 2
        self.tool_tip_cursor = Sphere(radius=marker_radius, color=RED)
        self.tool_base_cursor = Sphere(radius=marker_radius, color=BLUE)

        # Only the first link group plus the two markers are drawn.
        self.graphics_models = [self.link_groups[0],
                                self.tool_tip_cursor,
                                self.tool_base_cursor]

    def _update_link_graphics(self):
        super(VirtualKinChainWithToolLink, self)._update_link_graphics()

        # Columns of the returned array are indexed per joint (last = tip).
        positions = self.kin_chain.spatial_positions_of_joints(
            self.calc_joint_angles())
        tip, base = positions[:, -1], positions[:, -2]
        self.tool_tip_cursor.translate(*tip, reset=True)
        self.tool_base_cursor.translate(*base, reset=True)
# Example #5
# 0
class ArmPlant(Window):
    '''
    This task creates a RobotArm object and allows it to move around the screen based on either joint or endpoint
    positions. There is a spherical cursor at the end of the arm. The links of the arm can be visible or hidden.
    '''

    background = (0, 0, 0, 1)

    arm_visible = traits.Bool(
        True,
        desc='Specifies whether entire arm is displayed or just endpoint')

    cursor_radius = traits.Float(.5, desc="Radius of cursor")
    cursor_color = (.5, 0, .5, 1)

    arm_class = traits.OptionsList(*plantlist,
                                   bmi3d_input_options=plantlist.keys())
    starting_pos = (5, 0, 5)

    def __init__(self, *args, **kwargs):
        '''Create the arm and the endpoint cursor and register their HDF fields.'''
        super(ArmPlant, self).__init__(*args, **kwargs)
        self.cursor_visible = True

        # Initialize the arm.
        # NOTE(review): the arm is always ik.test_3d; arm_class only decides
        # below whether the arm's graphics and joint/visibility fields are used.
        self.arm = ik.test_3d
        self.arm_vis_prev = True

        if self.arm_class != 'CursorPlant':
            # NOTE(review): 'joint_angles' is registered but never written in _cycle.
            self.dtype.append(('joint_angles', 'f8', (self.arm.num_joints, )))
            self.dtype.append(('arm_visible', 'f8', (1, )))
            self.add_model(self.arm)

        ## Declare cursor
        self.dtype.append(('cursor', 'f8', (3, )))
        self.cursor = Sphere(radius=self.cursor_radius,
                             color=self.cursor_color)
        self.add_model(self.cursor)
        self.cursor.translate(*self.arm.get_endpoint_pos(), reset=True)

    def _cycle(self):
        '''
        Calls any update functions necessary and redraws screen. Runs 60x per second by default.
        '''
        ## Run graphics commands to show/hide the arm if the visibility has changed
        if self.arm_class != 'CursorPlant':
            if self.arm_visible != self.arm_vis_prev:
                self.arm_vis_prev = self.arm_visible
                self.show_object(self.arm, show=self.arm_visible)

        self.move_arm()
        self.update_cursor()
        if self.cursor_visible:
            self.task_data['cursor'] = self.cursor.xfm.move.copy()
        else:
            # If the cursor is not visible, write NaNs as its saved location.
            self.task_data['cursor'] = np.array([np.nan, np.nan, np.nan])

        if self.arm_class != 'CursorPlant':
            self.task_data['arm_visible'] = 1 if self.arm_visible else 0

        super(ArmPlant, self)._cycle()

    ## Functions to move the cursor using keyboard/mouse input
    def get_mouse_events(self):
        '''Return the button numbers of pending mouse button down/up events.'''
        import pygame
        return [btn.button for btn in pygame.event.get(
            (pygame.MOUSEBUTTONDOWN, pygame.MOUSEBUTTONUP))]

    def get_key_events(self):
        '''Return pygame's current key-pressed state sequence.'''
        import pygame
        return pygame.key.get_pressed()

    def move_arm(self):
        '''
        Move the arm from the keyboard: arrow keys move the endpoint in the x/z
        plane (y is forced to 0); for a 2-joint arm Q/W and O/P step the two
        joint groups; for a 4-joint arm Q/W/E/R step joints 0-3.
        '''
        import pygame

        keys = self.get_key_events()
        joint_speed = (np.pi / 6) / 60
        hand_speed = .2

        x, y, z = self.arm.get_endpoint_pos()

        # Each pressed key applies its step and updates the endpoint immediately.
        if keys[pygame.K_RIGHT]:
            x = x - hand_speed
            self.arm.set_endpoint_pos(np.array([x, 0, z]))
        if keys[pygame.K_LEFT]:
            x = x + hand_speed
            self.arm.set_endpoint_pos(np.array([x, 0, z]))
        if keys[pygame.K_DOWN]:
            z = z - hand_speed
            self.arm.set_endpoint_pos(np.array([x, 0, z]))
        if keys[pygame.K_UP]:
            z = z + hand_speed
            self.arm.set_endpoint_pos(np.array([x, 0, z]))

        if self.arm.num_joints == 2:
            xz, xy = self.get_arm_joints()
            e = np.array([xz[0], xy[0]])
            s = np.array([xz[1], xy[1]])

            if keys[pygame.K_q]:
                s = s - joint_speed
                self.set_arm_joints([e[0], s[0]], [e[1], s[1]])
            if keys[pygame.K_w]:
                s = s + joint_speed
                self.set_arm_joints([e[0], s[0]], [e[1], s[1]])
            if keys[pygame.K_o]:
                e = e - joint_speed
                self.set_arm_joints([e[0], s[0]], [e[1], s[1]])
            if keys[pygame.K_p]:
                e = e + joint_speed
                self.set_arm_joints([e[0], s[0]], [e[1], s[1]])

        if self.arm.num_joints == 4:
            # NOTE(review): assumes get_arm_joints() returns a mutable sequence
            # of 4 angles for 4-joint arms; confirm against the arm class.
            jts = self.get_arm_joints()
            keyspressed = [
                keys[pygame.K_q], keys[pygame.K_w], keys[pygame.K_e],
                keys[pygame.K_r]
            ]
            for i in range(self.arm.num_joints):
                if keyspressed[i]:
                    jts[i] = jts[i] + joint_speed
                    self.set_arm_joints(jts)

    def get_cursor_location(self):
        '''Return the arm endpoint position.'''
        return self.arm.get_endpoint_pos()

    def set_arm_endpoint(self, pt, **kwargs):
        '''Forward an endpoint position (and any options) to the arm.'''
        self.arm.set_endpoint_pos(pt, **kwargs)

    def set_arm_joints(self, angle_xz, angle_xy=None):
        '''Set the arm's joint angles.

        Supports both ``set_arm_joints(angle_xz, angle_xy)`` (used by the 2-joint
        controls) and ``set_arm_joints(angles)`` (used by the 4-joint controls),
        which previously raised TypeError because ``angle_xy`` was required.
        NOTE(review): assumes the arm's set_intrinsic_coordinates accepts a
        single argument in the one-argument form.
        '''
        if angle_xy is None:
            self.arm.set_intrinsic_coordinates(angle_xz)
        else:
            self.arm.set_intrinsic_coordinates(angle_xz, angle_xy)

    def get_arm_joints(self):
        '''Return the arm's joint angles as reported by the arm.'''
        return self.arm.get_intrinsic_coordinates()

    def update_cursor(self):
        '''
        Update the cursor's location and visibility status.
        '''
        pt = self.get_cursor_location()
        if pt is not None:
            self.move_cursor(pt)

    def move_cursor(self, pt):
        ''' Move the cursor object to the specified 3D location (skipped when the arm provides its own endpt_cursor). '''
        if not hasattr(self.arm, 'endpt_cursor'):
            self.cursor.translate(*pt[:3], reset=True)
class ApproachAvoidanceTask(Sequence, Window):
    '''
    This is for a free-choice task with two targets (left and right) presented to choose from.
    The position of the targets may change along the x-axis, according to the target generator,
    and each target has a varying probability of reward, also according to the target generator.
    The code as it is written is for a joystick.

    Notes: want target_index to only write once per trial.  if so, can make instructed trials random.  else, make new state for instructed trial.
    '''

    background = (0,0,0,1)
    shoulder_anchor = np.array([2., 0., -15.]) # Coordinates of shoulder anchor point on screen

    arm_visible = traits.Bool(True, desc='Specifies whether entire arm is displayed or just endpoint')

    cursor_radius = traits.Float(.5, desc="Radius of cursor")
    cursor_color = (.5,0,.5,1)

    joystick_method = traits.Float(1,desc="1: Normal velocity, 0: Position control")
    joystick_speed = traits.Float(20, desc="Speed of cursor")

    plant_type_options = plantlist.keys()
    plant_type = traits.OptionsList(*plantlist, bmi3d_input_options=plantlist.keys())
    starting_pos = (5, 0, 5)
    # window_size = (1280*2, 1024)
    window_size = traits.Tuple((1366*2, 768), desc='window size')

    # Task state machine: state -> {event: next state}. Each event name is
    # paired with a _test_<event> method of this class.
    status = dict(
        wait = dict(start_trial="center", stop=None),
        center = dict(enter_center="hold_center", timeout="timeout_penalty", stop=None),
        hold_center = dict(leave_early_center = "hold_penalty",hold_center_complete="target", timeout="timeout_penalty", stop=None),
        target = dict(enter_targetL="hold_targetL", enter_targetH = "hold_targetH", timeout="timeout_penalty", stop=None),
        # Previously keyed "hold_targetR" with an undefined "leave_early_R" event,
        # leaving the "hold_targetH" state (reached from "target") with no transitions.
        hold_targetH = dict(leave_early_H="hold_penalty", hold_complete="targ_transition"),
        hold_targetL = dict(leave_early_L="hold_penalty", hold_complete="targ_transition"),
        targ_transition = dict(trial_complete="check_reward",trial_abort="wait", trial_incomplete="center"),
        check_reward = dict(avoid="reward",approach="reward_and_airpuff"),
        timeout_penalty = dict(timeout_penalty_end="targ_transition"),
        hold_penalty = dict(hold_penalty_end="targ_transition"),
        reward = dict(reward_end="wait"),
        reward_and_airpuff = dict(reward_and_airpuff_end="wait"),
    )
    # Default target color (alpha 0); the center target's color is overridden in __init__.
    target_color = (.5,1,.5,0)

    # Initial state
    state = "wait"

    # Settable traits
    reward_time_avoid = traits.Float(.2, desc="Length of juice reward for avoid decision")
    reward_time_approach_min = traits.Float(.2, desc="Min length of juice for approach decision")
    reward_time_approach_max = traits.Float(.8, desc="Max length of juice for approach decision")
    target_radius = traits.Float(1.5, desc="Radius of targets in cm")
    block_length = traits.Float(100, desc="Number of trials per block")

    hold_time = traits.Float(.5, desc="Length of hold required at targets")
    hold_penalty_time = traits.Float(1, desc="Length of penalty time for target hold error")
    timeout_time = traits.Float(10, desc="Time allowed to go between targets")
    timeout_penalty_time = traits.Float(1, desc="Length of penalty time for timeout error")
    max_attempts = traits.Int(10, desc='The number of attempts at a target before\
        skipping to the next one')
    session_length = traits.Float(0, desc="Time until task automatically stops. Length of 0 means no auto stop.")
    marker_num = traits.Int(14, desc="The index of the motiontracker marker to use for cursor position")

    arm_hide_rate = traits.Float(0.0, desc='If the arm is visible, specifies a percentage of trials where it will be hidden')
    target_index = 0 # Helper variable to keep track of whether trial is instructed (1 = 1 choice) or free-choice (2 = 2 choices)
    target_selected = 'L'   # Helper variable to indicate which target was selected
    tries = 0 # Helper variable to keep track of the number of failed attempts at a given trial.
    timedout = False    # Helper variable to keep track if transitioning from timeout_penalty
    reward_counter = 0.0
    cursor_visible = False # Determines when to hide the cursor.
    no_data_count = 0 # Counter for number of missing data frames in a row
    scale_factor = 3.0 # Scale factor for converting hand movement to screen movement
    starting_dist = 10.0 # Starting distance from center target
    # 1: high-value target gets color1 and low-value target color2; 0: swapped.
    color_targets = 1
    stopped_center_hold = False   # Keep track if center hold was released early

    limit2d = 1

    color1 = target_colors['purple']          # approach color
    color2 = target_colors['lightsteelblue']  # avoid color
    reward_color = target_colors['green']     # color of reward bar
    airpuff_color = target_colors['red']      # color of airpuff bar

    sequence_generators = ['colored_targets_with_probabilistic_reward','block_probabilistic_reward','colored_targets_with_randomwalk_reward','randomwalk_probabilistic_reward']
    
    def __init__(self, *args, **kwargs):
        '''
        Build the plant, cursor and target graphics, and register the plant's HDF fields.

        The high-value target is being renamed from "H" to "R". The sphere and
        its location are created under the R names, and ``targetH`` /
        ``target_locationH`` are bound to the same objects because other
        methods still read the H names. Previously ``self.targetH`` was used
        before it was ever assigned, so construction raised AttributeError.
        '''
        super(ApproachAvoidanceTask, self).__init__(*args, **kwargs)
        self.cursor_visible = True

        # Plant selected by the plant_type trait.
        self.plant = plantlist[self.plant_type]
        self.plant_vis_prev = True

        # Add graphics models for the plant to the window.
        if hasattr(self.plant, 'graphics_models'):
            for model in self.plant.graphics_models:
                self.add_model(model)

        self.current_pt = np.zeros([3])  # endpoint commanded this cycle
        self.last_pt = np.zeros([3])     # endpoint commanded on the previous cycle

        ## Declare cursor
        self.cursor = Sphere(radius=self.cursor_radius, color=self.cursor_color)
        self.add_model(self.cursor)
        self.cursor.translate(*self.get_arm_endpoint(), reset=True)

        ## Instantiate the targets. target1 is the center target, targetR (alias
        ## targetH) the high-reward-probability target, targetL the low one.
        self.target1 = Sphere(radius=self.target_radius, color=self.target_color)
        self.add_model(self.target1)
        self.targetR = Sphere(radius=self.target_radius, color=self.target_color)
        self.targetH = self.targetR  # legacy alias, see docstring
        self.add_model(self.targetR)
        self.targetL = Sphere(radius=self.target_radius, color=self.target_color)
        self.add_model(self.targetL)

        # TODO: define a Rect target here and adapt its length during the task.

        # Initial target locations: R at x = -starting_dist, L at x = +starting_dist.
        self.target_location1 = np.array([0,0,0])
        self.target_locationR = np.array([-self.starting_dist,0,0])
        self.target_locationH = self.target_locationR  # legacy alias, see docstring
        self.target_locationL = np.array([self.starting_dist,0,0])

        self.target1.translate(*self.target_location1, reset=True)
        self.targetR.translate(*self.target_locationR, reset=True)
        self.targetL.translate(*self.target_locationL, reset=True)

        # Fixed colors: with color_targets == 1 the high-value target gets
        # color1 and the low-value target color2; with 0 they are swapped.
        self.targetR.color = self.color_targets*self.color1 + (1 - self.color_targets)*self.color2
        self.targetL.color = (1 - self.color_targets)*self.color1 + self.color_targets*self.color2

        # Center target is red.
        self.target1.color = (1,0,0,.5)

        # Initialize target location variable
        self.target_location = np.array([0, 0, 0])

        # Declare any plant attributes which must be saved to the HDF file at the _cycle rate
        for attr in self.plant.hdf_attrs:
            self.add_dtype(*attr)


    def init(self):
        '''Register this task's per-cycle task_data fields, then run the parent init.'''
        task_fields = (
            ('targetR', 'f8', (3,)),
            ('targetL', 'f8', (3,)),
            ('reward_scheduleR', 'f8', (1,)),
            ('reward_scheduleL', 'f8', (1,)),
            ('target_index', 'i', (1,)),
        )
        for field_spec in task_fields:
            self.add_dtype(*field_spec)
        super(ApproachAvoidanceTask, self).init()
        self.trial_allocation = np.zeros(1000)

    def _cycle(self):
        ''' Per-cycle update: arm visibility, joystick-driven movement, plant HDF data, cursor, then the parent _cycle (presumably once per frame). '''

        ## Run graphics commands to show/hide the arm if the visibility has changed
        # NOTE(review): __init__ sets plant_vis_prev/plant, not arm_vis_prev/arm;
        # confirm these attributes exist, otherwise this raises AttributeError
        # for any plant_type other than 'cursor_14x14'.
        if self.plant_type != 'cursor_14x14':
            if self.arm_visible != self.arm_vis_prev:
                self.arm_vis_prev = self.arm_visible
                self.show_object(self.arm, show=self.arm_visible)

        self.move_arm()
        #self.move_plant()

        ## Save plant status to HDF file
        plant_data = self.plant.get_data_to_save()
        for key in plant_data:
            self.task_data[key] = plant_data[key]

        self.update_cursor()

        # NOTE(review): 'joint_angles' is not registered by this class's init();
        # confirm it comes from plant.hdf_attrs.
        if self.plant_type != 'cursor_14x14':
            self.task_data['joint_angles'] = self.get_arm_joints()

        super(ApproachAvoidanceTask, self)._cycle()
        
    ## Plant functions
    def get_cursor_location(self):
        '''Return the cursor location: the plant endpoint, with no offset applied here.'''
        endpoint = self.get_arm_endpoint()
        return endpoint

    def get_arm_endpoint(self):
        '''Return the plant's current endpoint position.'''
        plant = self.plant
        return plant.get_endpoint_pos()

    def set_arm_endpoint(self, pt, **kwargs):
        '''Forward an endpoint position (plus any keyword options) to the plant.'''
        plant = self.plant
        plant.set_endpoint_pos(pt, **kwargs)

    def set_arm_joints(self, angles):
        '''Set joint angles on ``self.arm``.

        NOTE(review): __init__ only sets ``self.plant``; confirm ``self.arm`` exists.
        '''
        arm = self.arm
        arm.set_intrinsic_coordinates(angles)

    def get_arm_joints(self):
        '''Return joint angles from ``self.arm``.

        NOTE(review): __init__ only sets ``self.plant``; confirm ``self.arm`` exists.
        '''
        arm = self.arm
        return arm.get_intrinsic_coordinates()

    def update_cursor(self):
        '''
        Read the cursor location, refresh visibility, then move the cursor if a location was returned.
        '''
        location = self.get_cursor_location()
        self.update_cursor_visibility()
        if location is None:
            return
        self.move_cursor(location)

    def move_cursor(self, pt):
        ''' Place the cursor sphere at the first three coordinates of ``pt``. '''
        coords = pt[:3]
        self.cursor.translate(*coords, reset=True)

    ##    


    ##### HELPER AND UPDATE FUNCTIONS ####

    def move_arm(self):
        '''Drive the plant endpoint from the newest joystick sample.

        Returns nothing. Reads ``self.joystick``, updates ``current_pt`` and
        ``last_pt``, and sends the new point to the plant. Does nothing when the
        joystick returned no samples.

        ``joystick_method`` selects the mapping:
          0 -- position control: the offset from the rest reading, scaled by
               36 (x) and 24 (z), is used directly as the endpoint.
          1 -- velocity control: the offset is integrated with dt = 1/60 s; it is
               ignored when its squared magnitude is <= 0.02, and the result is
               clamped to x in [-25, 25] and z in [-14, 14].
        '''
        samples = self.joystick.get()
        if len(samples) == 0:
            return

        sample = samples[-1][0]
        sample[0] = 1 - sample[0]  # switch L / R axes (modifies the sample in place)
        # Joystick reading at rest; the zero point is subject to drift.
        rest = [0.497, 0.517]

        if self.joystick_method == 0:
            offset = np.array([(sample[0] - rest[0]), 0, rest[1] - sample[1]])
            offset[0] = offset[0] * 36
            offset[2] = offset[2] * 24
            self.current_pt = offset
        elif self.joystick_method == 1:
            velocity = np.array([(sample[0] - rest[0]), 0, rest[1] - sample[1]])
            dead_zone = 2 * (10 ** -2)  # stabilizes the cursor near rest
            if sum(velocity ** 2) > dead_zone:
                # 60 Hz update rate, dt = 1/60
                self.current_pt = self.last_pt + self.joystick_speed * velocity * (1 / 60)
            else:
                self.current_pt = self.last_pt
            # Clamp to the workspace: x (index 0) to +/-25, z (index -1) to +/-14.
            for axis, bound in ((0, 25), (-1, 14)):
                if self.current_pt[axis] < -bound:
                    self.current_pt[axis] = -bound
                if self.current_pt[axis] > bound:
                    self.current_pt[axis] = bound

        self.set_arm_endpoint(self.current_pt)
        self.last_pt = self.current_pt.copy()

    def convert_to_cm(self, val):
        '''Return ``val`` divided by 10 (presumably millimetres to centimetres).'''
        centimetres = val / 10.0
        return centimetres

    def update_cursor_visibility(self):
        ''' Show the cursor while fewer than 3 consecutive frames lacked good data, hide it otherwise.

        Graphics are only touched (show/hide plus a requeue) when the flag flips.
        '''
        was_visible = self.cursor_visible
        self.cursor_visible = bool(self.no_data_count < 3)
        if self.cursor_visible != was_visible:
            self.show_object(self.cursor, show=self.cursor_visible)
            self.requeue()

    def calc_n_successfultrials(self):
        '''Count how many times the task entered the 'check_reward' state.'''
        return sum(1 for entry in self.state_log if entry[0] == 'check_reward')

    def calc_n_rewards(self):
        '''Count how many times the task entered the 'reward' state.'''
        return sum(1 for entry in self.state_log if entry[0] == 'reward')

    def calc_trial_num(self):
        '''Calculates the current trial count: completed + aborted trials (entries into 'wait' minus one).'''
        n_waits = sum(1 for entry in self.state_log if entry[0] == 'wait')
        return n_waits - 1

    def calc_targetH_num(self):
        '''Return the total number of times the high-value target was selected,
        i.e. the number of entries into the 'hold_targetH' state.

        The previous version subtracted 1 (copied from calc_trial_num), which
        under-counted by one and reported -1 before any selection.
        '''
        return sum(1 for state in self.state_log if state[0] == 'hold_targetH')

    def calc_rewards_per_min(self, window):
        '''Return rewards per minute over the most recent ``window`` seconds.

        If the task has run for less than ``window`` seconds, the rate is taken
        over the elapsed time instead. A single timestamp is used for both the
        elapsed time and the window, and 0.0 is returned when no time has
        elapsed (previously a division by zero).
        '''
        now = self.get_time()
        reward_times = np.array([state[1] for state in self.state_log if state[0] == 'reward'])
        elapsed = now - self.task_start_time
        span_seconds = elapsed if elapsed < window else window
        span_minutes = span_seconds / sec_per_min
        if span_minutes <= 0:
            return 0.0
        return np.sum(reward_times >= (now - window)) / span_minutes

    def calc_success_rate(self, window):
        '''Return rewarded trials / initiated trials over the most recent ``window`` seconds.

        A trial counts as initiated when it ends in 'reward', 'timeout_penalty'
        or 'hold_penalty'. Returns 0.0 when no trial falls inside the window
        (previously 0/0 produced nan whenever older trials existed). A single
        timestamp is used for both counts.
        '''
        now = self.get_time()
        trial_times = np.array([state[1] for state in self.state_log
                                if state[0] in ['reward', 'timeout_penalty', 'hold_penalty']])
        reward_times = np.array([state[1] for state in self.state_log if state[0] == 'reward'])
        n_trials = np.sum(trial_times >= (now - window))
        if n_trials == 0:
            return 0.0
        return float(np.sum(reward_times >= (now - window))) / n_trials

    def update_report_stats(self):
        '''Function to update any relevant report stats for the task. Values are saved in self.reportstats,
        an ordered dictionary. Keys are strings that will be displayed as the label for the stat in the web interface,
        values can be numbers or strings. Called every time task state changes.'''
        super(ApproachAvoidanceTask, self).update_report_stats()
        self.reportstats['Trial #'] = self.calc_trial_num()
        self.reportstats['Reward/min'] = np.round(self.calc_rewards_per_min(120),decimals=2)
        self.reportstats['High-value target selections'] = self.calc_targetH_num()

        # Time of the latest reward relative to the first state-log entry, as M:SS.
        start_time = self.state_log[0][1]
        reward_times = [state[1] for state in self.state_log if state[0] == 'reward']
        rt = reward_times[-1] - start_time if reward_times else np.float64(0.0)
        # Built-in int() replaces np.int, which was removed in NumPy 1.24.
        minutes = int(np.floor(rt / 60))
        seconds = int(np.mod(rt, 60))
        self.reportstats['Time Of Last Reward'] = '{}:{:02d}'.format(minutes, seconds)



    #### TEST FUNCTIONS ####
    def _test_enter_center(self, ts):
        '''True once the whole cursor sphere lies inside the center target
        (center-to-center distance <= target_radius - cursor_radius).'''
        cursor_pos, target_pos = self.cursor.xfm.move, self.target_location1
        dist_sq = sum((cursor_pos[k] - target_pos[k]) ** 2 for k in range(3))
        return np.sqrt(dist_sq) <= self.target_radius - self.cursor_radius

    def _test_enter_targetL(self, ts):
        '''Return True once the cursor sphere lies entirely inside the L target
        (distance <= target_radius - cursor_radius).

        On an instructed trial (target_index == 1) where this target is off
        (LH_target_on[0] == 0), the target cannot be entered. ``target_selected``
        is set to 'L' only when the target is actually entered; previously it
        was overwritten on every evaluation.
        '''
        if self.target_index == 1 and self.LH_target_on[0] == 0:
            return False
        move, target = self.cursor.xfm.move, self.target_locationL
        d = np.sqrt((move[0]-target[0])**2 + (move[1]-target[1])**2 + (move[2]-target[2])**2)
        entered = d <= self.target_radius - self.cursor_radius
        if entered:
            self.target_selected = 'L'
        return entered

    def _test_enter_targetH(self, ts):
        '''Return True once the cursor sphere lies entirely inside the H target
        (distance <= target_radius - cursor_radius).

        On an instructed trial (target_index == 1) where this target is off
        (LH_target_on[1] == 0), the target cannot be entered. ``target_selected``
        is set to 'H' only when the target is actually entered; previously it
        was overwritten on every evaluation.
        '''
        if self.target_index == 1 and self.LH_target_on[1] == 0:
            return False
        # NOTE(review): requires self.target_locationH to be set on the instance.
        move, target = self.cursor.xfm.move, self.target_locationH
        d = np.sqrt((move[0]-target[0])**2 + (move[1]-target[1])**2 + (move[2]-target[2])**2)
        entered = d <= self.target_radius - self.cursor_radius
        if entered:
            self.target_selected = 'H'
        return entered

    def _test_leave_early_center(self, ts):
        '''True when the cursor center moves beyond target_radius - cursor_radius
        from the center target (the same radius used for entering; no extra slack).'''
        cursor_pos, target_pos = self.cursor.xfm.move, self.target_location1
        exit_radius = self.target_radius - self.cursor_radius
        return np.sqrt(sum((cursor_pos[k] - target_pos[k]) ** 2 for k in range(3))) > exit_radius

    def _test_leave_early_L(self, ts):
        '''True when the cursor center moves beyond target_radius - cursor_radius
        from the L target (the same radius used for entering; no extra slack).'''
        cursor_pos, target_pos = self.cursor.xfm.move, self.target_locationL
        exit_radius = self.target_radius - self.cursor_radius
        return np.sqrt(sum((cursor_pos[k] - target_pos[k]) ** 2 for k in range(3))) > exit_radius

    def _test_leave_early_H(self, ts):
        '''True when the cursor center moves beyond target_radius - cursor_radius
        from the H target (the same radius used for entering; no extra slack).'''
        cursor_pos, target_pos = self.cursor.xfm.move, self.target_locationH
        exit_radius = self.target_radius - self.cursor_radius
        return np.sqrt(sum((cursor_pos[k] - target_pos[k]) ** 2 for k in range(3))) > exit_radius

    def _test_hold_center_complete(self, ts):
        '''True once the center has been held for at least hold_time.'''
        return self.hold_time <= ts
    
    def _test_hold_complete(self, ts):
        '''True once a peripheral target has been held for at least hold_time.'''
        return self.hold_time <= ts

    def _test_timeout(self, ts):
        '''True once strictly more than timeout_time has passed in the current state.'''
        return self.timeout_time < ts

    def _test_timeout_penalty_end(self, ts):
        '''True once the timeout penalty has lasted strictly longer than timeout_penalty_time.'''
        return self.timeout_penalty_time < ts

    def _test_hold_penalty_end(self, ts):
        '''True once the hold penalty has lasted strictly longer than hold_penalty_time.'''
        return self.hold_penalty_time < ts

    def _test_trial_complete(self, ts):
        '''A trial counts as complete unless it arrived here via a timeout.'''
        timed_out = self.timedout
        return not timed_out

    def _test_trial_incomplete(self, ts):
        '''True for an unfinished trial that still has attempts remaining.'''
        if self._test_trial_complete(ts):
            return False
        return self.tries < self.max_attempts

    def _test_trial_abort(self, ts):
        '''True for an unfinished trial that has used up its attempts.

        Uses ``>=`` rather than ``==`` so that, together with
        _test_trial_incomplete (``tries < max_attempts``), exactly one of the
        two is true for any unfinished trial, even if ``tries`` overshoots.
        '''
        if self._test_trial_complete(ts):
            return False
        return self.tries >= self.max_attempts

    def _test_yes_reward(self, ts):
        '''True when the selected target's drawn reward is set.

        Under the small/large schedule (reward_SmallLarge == 1) every trial is
        rewarded, and reward_time is set to reward_time_large or
        reward_time_small according to the drawn reward.
        '''
        assigned = self.rewardH if self.target_selected == 'H' else self.rewardL
        if self.reward_SmallLarge == 1:
            self.reward_time = assigned*self.reward_time_large + (1 - assigned)*self.reward_time_small
            assigned = 1
        return bool(assigned)

    def _test_no_reward(self, ts):
        '''True when the selected target's drawn reward is not set.

        Under the small/large schedule (reward_SmallLarge == True) every trial
        is rewarded, so this is False, and reward_time is updated exactly as in
        _test_yes_reward.
        '''
        assigned = self.rewardH if self.target_selected == 'H' else self.rewardL
        if self.reward_SmallLarge == True:
            self.reward_time = assigned*self.reward_time_large + (1 - assigned)*self.reward_time_small
            assigned = 1
        return not assigned

    def _test_reward_end(self, ts):
        time_ended = (ts > self.reward_time)
        self.reward_counter = self.reward_counter + 1
        return time_ended

    def _test_stop(self, ts):
        if self.session_length > 0 and (time.time() - self.task_start_time) > self.session_length:
            self.end_task()
        return self.stop

    #### STATE FUNCTIONS ####

    def show_object(self, obj, show=False):
        '''
        Attach (show=True) or detach (show=False) a graphics model, then
        call self.requeue().
        '''
        toggle = obj.attach if show else obj.detach
        toggle()
        self.requeue()


    def _start_wait(self):
        '''
        Enter the wait state: reset per-trial bookkeeping, hide all targets,
        load the next trial's parameters and draw this trial's reward outcomes.

        self.next_trial is indexed below as a 2-row array: row 0 for the
        high-value target, row 1 for the low-value target, column 1 holding a
        reward probability compared against a draw in [0, 100). Presumably it
        comes from one of this class's sequence generators; confirm.
        '''
        super(ApproachAvoidanceTask, self)._start_wait()
        self.tries = 0
        # 0 here; _start_center sets 1 (instructed) or 2 (free-choice).
        self.target_index = 0
        for target in (self.target1, self.targetL, self.targetH):
            self.show_object(target, False)

        self.targs = self.next_trial
        # Short-circuit matters: the random draw is only made for plants other
        # than 'cursor_14x14'.
        hide_arm = (self.plant_type != 'cursor_14x14'
                    and np.random.rand() < self.arm_hide_rate)
        self.arm_visible = not hide_arm

        # Each target is rewarded when its probability exceeds its own draw.
        draw_high, draw_low = np.random.randint(0, 100, size=2)
        self.rewardH = np.greater(self.targs[0, 1], draw_high)
        self.rewardL = np.greater(self.targs[1, 1], draw_low)

        self.task_data['targetH'] = self.targs[0].copy()
        self.task_data['reward_scheduleH'] = self.rewardH.copy()
        self.task_data['targetL'] = self.targs[1].copy()
        self.task_data['reward_scheduleL'] = self.rewardL.copy()

        self.requeue()

    def _start_center(self):

        #self.target_index += 1

        
        self.show_object(self.target1, True)
        self.show_object(self.cursor, True)
        
        # Third argument in self.targs determines if target is on left or right
        # First argument in self.targs determines if location is offset to farther distances
        offsetH = (2*self.targs[0,2] - 1)*(self.starting_dist + self.location_offset_allowed*self.targs[0,0]*4.0)
        moveH = np.array([offsetH,0,0]) 
        offsetL = (2*self.targs[1,2] - 1)*(self.starting_dist + self.location_offset_allowed*self.targs[1,0]*4.0)
        moveL = np.array([offsetL,0,0])

        self.targetL.translate(*moveL, reset=True) 
        #self.targetL.move_to_position(*moveL, reset=True)           
        ##self.targetL.translate(*self.targs[self.target_index], reset=True)
        self.show_object(self.targetL, True)
        self.target_locationL = self.targetL.xfm.move

        self.targetH.translate(*moveH, reset=True)
        #self.targetR.move_to_position(*moveR, reset=True)
        ##self.targetR.translate(*self.targs[self.target_index], reset=True)
        self.show_object(self.targetH, True)
        self.target_locationH = self.targetH.xfm.move


        # Insert instructed trials within free choice trials
        if self.trial_allocation[self.calc_trial_num()] == 1:
        #if (self.calc_trial_num() % 10) < (self.percentage_instructed_trials/10):
            self.target_index = 1    # instructed trial
            leftright_coinflip = np.random.randint(0,2)
            if leftright_coinflip == 0:
                self.show_object(self.targetL, False)
                self.LH_target_on = (0, 1)
            else:
                self.show_object(self.targetH, False)
                self.LR_coinflip = 0
                self.LH_target_on = (1, 0)
        else:
            self.target_index = 2   # free-choice trial

        self.cursor_visible = True
        self.task_data['target_index'] = self.target_index
        self.requeue()

    def _start_target(self):

    	#self.target_index += 1

        #move targets to current location and set location attribute.  Target1 (center target) position does not change.                    
        
        self.show_object(self.target1, False)
        #self.target_location1 = self.target1.xfm.move
        self.show_object(self.cursor, True)
       
        self.update_cursor()
        self.requeue()

    def _start_hold_center(self):
        self.show_object(self.target1, True)
        self.timedout = False
        self.requeue()

    def _start_hold_targetL(self):
        #make next target visible unless this is the final target in the trial
        #if 1 < self.chain_length:
            #self.targetL.translate(*self.targs[self.target_index+1], reset=True)
         #   self.show_object(self.targetL, True)
         #   self.requeue()
        self.show_object(self.targetL, True)
        self.timedout = False
        self.requeue()

    def _start_hold_targetH(self):
        #make next target visible unless this is the final target in the trial
        #if 1 < self.chain_length:
            #self.targetR.translate(*self.targs[self.target_index+1], reset=True)
         #   self.show_object(self.targetR, True)
          #  self.requeue()
        self.show_object(self.targetH, True)
        self.timedout = False
        self.requeue()

    def _end_hold_center(self):
        # On leaving the center hold, set the center target's radius attribute
        # to 70% of target_radius.
        # NOTE(review): the previous comment here said "color target green"
        # (as the L/H hold handlers do), but this line only reassigns .radius;
        # whether that changes the rendered Sphere, and where it is reset,
        # is not visible here — confirm intent.
        self.target1.radius = 0.7*self.target_radius
    
    def _end_hold_targetL(self):
        self.targetL.color = (0,1,0,0.5)    # color target green

    def _end_hold_targetH(self):
        self.targetH.color = (0,1,0,0.5)    # color target green

    def _start_hold_penalty(self):
    	#hide targets
        self.show_object(self.target1, False)
        self.show_object(self.targetL, False)
        self.show_object(self.targetH, False)
        self.timedout = True
        self.requeue()
        self.tries += 1
        #self.target_index = -1
    
    def _start_timeout_penalty(self):
    	#hide targets
        self.show_object(self.target1, False)
        self.show_object(self.targetL, False)
        self.show_object(self.targetH, False)
        self.timedout = True
        self.requeue()
        self.tries += 1
        #self.target_index = -1


    def _start_targ_transition(self):
        #hide targets

        self.show_object(self.target1, False)
        self.show_object(self.targetL, False)
        self.show_object(self.targetH, False)
        self.requeue()

    def _start_check_reward(self):
        #hide targets
        self.show_object(self.target1, False)
        self.show_object(self.targetL, False)
        self.show_object(self.targetH, False)
        self.requeue()

    def _start_reward(self):
        #super(ApproachAvoidanceTask, self)._start_reward()
        if self.target_selected == 'L':
            self.show_object(self.targetL, True)  
            #reward_assigned = self.targs[1,1]
        else:
            self.show_object(self.targetH, True)
            #reward_assigned = self.targs[0,1]
        #self.reward_counter = self.reward_counter + float(reward_assigned)
        self.requeue()

    @staticmethod
    def colored_targets_with_probabilistic_reward(length=1000, boundaries=(-18,18,-10,10,-15,15),reward_high_prob=80,reward_low_prob=40):
        """
        Generate per-trial parameters for a high- and a low-value target with fixed reward probabilities.

        Returns a float array of shape (length, 2, 3). Axis 1 is the target
        (0 = high reward probability, 1 = low). Axis 2 holds three variables:
        position offset (random 0/1), reward probability (fixed, in percent),
        and location (random 0/1, with the two targets always on opposite
        sides, i.e. left or right).

        UPDATE: CHANGED SO THAT THE SECOND DIMENSION CARRIES THE REWARD PROBABILITY RATHER THAN THE REWARD SCHEDULE

        ``boundaries`` is accepted for interface compatibility but unused.
        """
        offsets_high = np.random.randint(2, size=(1, length))
        offsets_low = np.random.randint(2, size=(1, length))
        side = np.random.randint(2, size=(1, length))

        pairs = np.zeros([length, 2, 3])
        for row, (offset, prob, loc) in enumerate((
                (offsets_high, reward_high_prob, side),
                (offsets_low, reward_low_prob, 1 - side))):
            pairs[:, row, 0] = offset
            pairs[:, row, 1] = prob*np.ones((1, length))
            pairs[:, row, 2] = loc
        return pairs

    @staticmethod
    def block_probabilistic_reward(length=1000, boundaries=(-18,18,-10,10,-15,15),reward_high_prob=80,reward_low_prob=40):
        '''
        Sequence generator that forwards to colored_targets_with_probabilistic_reward.

        Returns the same (length, 2, 3) target-parameter array.
        '''
        # A bare name inside a method cannot see class-body attributes (the
        # original raised NameError), so reach the sibling staticmethod via
        # the class. The class name is the one used by super() in _start_wait.
        return ApproachAvoidanceTask.colored_targets_with_probabilistic_reward(
            length=length, boundaries=boundaries,
            reward_high_prob=reward_high_prob, reward_low_prob=reward_low_prob)

    @staticmethod
    def colored_targets_with_randomwalk_reward(length=1000,reward_high_prob=80,reward_low_prob=40,reward_high_span = 20, reward_low_span = 20,step_size_mean = 0, step_size_var = 1):
        """
        Generate per-trial parameters for two targets whose reward probabilities follow bounded random walks.

        Returns a float array of shape (length, 2, 3). Axis 1 is the target
        (0 = high reward probability, 1 = low). Axis 2 holds position offset
        (random 0/1), reward probability (percent), and location (random 0/1,
        with the two targets always on opposite sides).

        Each probability starts at reward_high_prob / reward_low_prob and moves
        each trial by eps = N(0, 1)*step_size_mean + (+/-1)*step_size_var.
        reward_high_span and reward_low_span are the total width of the
        allowed range. For example, reward_high_prob=80 with reward_high_span=20
        keeps the high target between 70 and 90 percent. A step that would leave
        the range is replaced by the crossed bound moved back inside by |eps|;
        values exactly on a bound are kept. This assumes |eps| does not exceed
        the span.
        """
        position_offsetH = np.random.randint(2, size=(1, length))
        position_offsetL = np.random.randint(2, size=(1, length))
        location_int = np.random.randint(2, size=(1, length))

        # Random step magnitudes (scaled by step_size_mean) and directions.
        # Draw order matches the original so seeded sequences stay comparable.
        assign_rewardH = np.random.randn(1, length)
        assign_rewardL = np.random.randn(1, length)
        assign_rewardH_direction = np.random.randn(1, length)
        assign_rewardL_direction = np.random.randn(1, length)

        def _steps(magnitude, direction):
            # (2*(d > 0) - 1) is a +/-1 array; it must be parenthesized rather
            # than wrapped in a list, otherwise "* step_size_var" repeats a
            # Python list (ignoring the step size, or raising for floats).
            return (magnitude*step_size_mean
                    + (2*(direction > 0) - 1)*step_size_var).ravel()

        def _bounded_walk(start, span, eps):
            lower = start - span/2
            upper = start + span/2
            walk = np.zeros(length)
            if length > 0:
                walk[0] = start
            for i in range(1, length):
                step = eps[i-1]
                value = walk[i-1] + step
                if value > upper:
                    value = upper - abs(step)
                elif value < lower:
                    value = lower + abs(step)
                walk[i] = value
            return walk

        reward_high = _bounded_walk(reward_high_prob, reward_high_span,
                                    _steps(assign_rewardH, assign_rewardH_direction))
        reward_low = _bounded_walk(reward_low_prob, reward_low_span,
                                   _steps(assign_rewardL, assign_rewardL_direction))

        pairs = np.zeros([length, 2, 3])
        pairs[:, 0, 0] = position_offsetH
        pairs[:, 0, 1] = reward_high
        pairs[:, 0, 2] = location_int

        pairs[:, 1, 0] = position_offsetL
        pairs[:, 1, 1] = reward_low
        pairs[:, 1, 2] = 1 - location_int

        return pairs

    @staticmethod
    def randomwalk_probabilistic_reward(length=1000,reward_high_prob=80,reward_low_prob=40,reward_high_span = 20, reward_low_span = 20,step_size_mean = 0, step_size_var = 1):
        '''
        Sequence generator that forwards to colored_targets_with_randomwalk_reward.

        Returns the same (length, 2, 3) target-parameter array.
        '''
        # A bare name inside a method cannot see class-body attributes (the
        # original raised NameError), so reach the sibling staticmethod via
        # the class. The class name is the one used by super() in _start_wait.
        return ApproachAvoidanceTask.colored_targets_with_randomwalk_reward(
            length=length, reward_high_prob=reward_high_prob,
            reward_low_prob=reward_low_prob, reward_high_span=reward_high_span,
            reward_low_span=reward_low_span, step_size_mean=step_size_mean,
            step_size_var=step_size_var)
class ManualControl2(ManualControl):
    '''
    ManualControl with a second terminus target.

    After holding at the first terminus, the cursor must reach and hold
    terminus2_target before reward. Its position is taken from
    self.next_trial.T[2] when a new trial starts (tries == 0).
    '''
    status = dict(wait=dict(start_trial="origin", stop=None),
                  origin=dict(enter_target="origin_hold", stop=None),
                  origin_hold=dict(leave_early="hold_penalty",
                                   hold="terminus"),
                  terminus=dict(timeout="timeout_penalty",
                                enter_target="terminus_hold",
                                stop=None),
                  timeout_penalty=dict(penalty_end="pre_target_change"),
                  terminus_hold=dict(leave_early="hold_penalty",
                                     hold="terminus2"),
                  terminus2=dict(timeout="timeout_penalty",
                                 enter_target="terminus2_hold",
                                 stop=None),
                  terminus2_hold=dict(leave_early="hold_penalty",
                                      hold="reward"),
                  reward=dict(reward_end="target_change"),
                  hold_penalty=dict(penalty_end="pre_target_change"),
                  pre_target_change=dict(tried_enough='target_change',
                                         not_tried_enough='wait'),
                  target_change=dict(target_change_end='wait'))

    scale_factor = 2
    cursor_radius = .4

    def __init__(self, *args, **kwargs):
        # Add the 2nd terminus target (same radius as the first terminus).
        super(ManualControl2, self).__init__(*args, **kwargs)
        self.terminus2_target = Sphere(radius=self.terminus_size,
                                       color=(1, 0, 0, .5))
        self.add_model(self.terminus2_target)

    def _start_wait(self):
        # Reset the second terminus to red and hide it before the base reset.
        self.terminus2_target.color = (1, 0, 0, .5)
        self.terminus2_target.detach()
        super(ManualControl2, self)._start_wait()

    def _test_enter_target(self, ts):
        '''In 'terminus2', True when the cursor sphere lies fully inside terminus2_target.'''
        if self.state == "terminus2":
            c = self.cursor.xfm.move
            t = self.terminus2_target.xfm.move
            d = np.sqrt((c[0] - t[0])**2 + (c[1] - t[1])**2 + (c[2] - t[2])**2)
            # Use the second target's own radius (the original used
            # terminus_target's; both are created with terminus_size).
            return d <= self.terminus2_target.radius - self.cursor.radius
        return super(ManualControl2, self)._test_enter_target(ts)

    def _start_origin(self):
        if self.tries == 0:
            # New trial: move the second terminus to its location.
            t2 = self.next_trial.T[2]
            self.terminus2_target.translate(*t2, reset=True)
        super(ManualControl2, self)._start_origin()

    def _start_terminus_hold(self):
        # Reveal the second terminus (red) while holding on the first.
        self.terminus2_target.color = (1, 0, 0, 0.5)
        self.terminus2_target.attach()
        self.requeue()

    def _start_terminus2(self):
        self.terminus_target.detach()
        self.requeue()

    def _test_leave_early(self, ts):
        '''In 'terminus2_hold', True once the cursor drifts beyond exit_radius times the allowed radius.'''
        if self.state == "terminus2_hold":
            c = self.cursor.xfm.move
            t = self.terminus2_target.xfm.move
            d = np.sqrt((c[0] - t[0])**2 + (c[1] - t[1])**2 + (c[2] - t[2])**2)
            rad = self.terminus2_target.radius - self.cursor.radius
            return d > rad * self.exit_radius
        return super(ManualControl2, self)._test_leave_early(ts)

    def _while_terminus2(self):
        self.update_cursor()

    def _while_terminus2_hold(self):
        self.update_cursor()

    def _end_terminus2_hold(self):
        self.terminus2_target.color = (0, 1, 0, 0.5)

    def _start_hold_penalty(self):
        self.terminus2_target.detach()
        super(ManualControl2, self)._start_hold_penalty()

    # Defined once; the original class body defined this method twice with
    # identical bodies, the second silently replacing the first.
    def _start_timeout_penalty(self):
        self.terminus2_target.detach()
        super(ManualControl2, self)._start_timeout_penalty()

    def update_target_location(self):
        # Determine the task target for assist/decoder adaptation purposes
        # (convert units from cm to mm for decoder). During the terminus2
        # states the second target is used; the base class then records
        # self.location into task_data.
        # TODO - decide what to do with y location, target_xz ignores it!
        if self.state == 'terminus2' or self.state == 'terminus2_hold':
            self.location = 10 * self.terminus2_target.xfm.move
            self.target_xz = np.array([self.location[0], self.location[2]])
        super(ManualControl2, self).update_target_location()
class ManualControl(TargetCapture):
    '''
    Center-out task: capture and hold the origin target, then reach the
    terminus target within timeout_time seconds and hold it for
    terminus_hold_time seconds to earn a reward.

    The terminus position comes from self.next_trial.T[1] when a new trial
    starts (tries == 0).
    '''
    status = dict(wait=dict(start_trial="origin", stop=None),
                  origin=dict(enter_target="origin_hold", stop=None),
                  origin_hold=dict(leave_early="hold_penalty",
                                   hold="terminus"),
                  terminus=dict(timeout="timeout_penalty",
                                enter_target="terminus_hold",
                                stop=None),
                  timeout_penalty=dict(penalty_end="pre_target_change"),
                  terminus_hold=dict(leave_early="hold_penalty",
                                     hold="reward"),
                  reward=dict(reward_end="target_change"),
                  hold_penalty=dict(penalty_end="pre_target_change"),
                  pre_target_change=dict(tried_enough='target_change',
                                         not_tried_enough='wait'),
                  target_change=dict(target_change_end='wait'))

    #create settable traits
    terminus_size = traits.Float(1, desc="Radius of terminus targets")
    terminus_hold_time = traits.Float(
        2, desc="Length of hold required at terminus")
    timeout_time = traits.Float(
        10, desc="Time allowed to go between origin and terminus")
    timeout_penalty_time = traits.Float(
        3, desc="Length of penalty time for timeout error")

    def __init__(self, *args, **kwargs):
        # Add the target and cursor locations to the task data saved to file.
        self.dtype = [('target', 'f', (3, )), ('cursor', 'f', (3, ))]
        super(ManualControl, self).__init__(*args, **kwargs)
        self.terminus_target = Sphere(radius=self.terminus_size,
                                      color=(1, 0, 0, .5))
        self.add_model(self.terminus_target)
        # Initialize target location variables
        self.location = np.array([0, 0, 0])
        self.target_xz = np.array([0, 0])

    def _cursor_distance(self, target):
        '''Euclidean distance between the cursor center and *target*'s center.'''
        c = self.cursor.xfm.move
        t = target.xfm.move
        return np.sqrt((c[0] - t[0])**2 + (c[1] - t[1])**2 + (c[2] - t[2])**2)

    def _start_wait(self):
        super(ManualControl, self)._start_wait()
        # Reset the terminus colour and hide it from the previous trial.
        self.terminus_target.color = (1, 0, 0, .5)
        self.show_terminus(False)

    def _start_origin(self):
        if self.tries == 0:
            # New trial: move the terminus to its location.
            t = self.next_trial.T[1]
            self.terminus_target.translate(*t, reset=True)
        super(ManualControl, self)._start_origin()

    def _start_origin_hold(self):
        # Reveal the terminus while the origin is being held.
        self.show_terminus(True)

    def show_terminus(self, show=False):
        '''Attach (show=True) or detach the terminus target, then requeue.'''
        if show:
            self.terminus_target.attach()
        else:
            self.terminus_target.detach()
        self.requeue()

    def _start_terminus(self):
        self.show_origin(False)

    def _end_terminus_hold(self):
        self.terminus_target.color = (0, 1, 0, 0.5)

    def _start_hold_penalty(self):
        super(ManualControl, self)._start_hold_penalty()
        self.show_terminus(False)

    def _start_timeout_penalty(self):
        # A timeout counts as a failed try; hide the terminus.
        self.tries += 1
        self.show_terminus(False)

    def _start_reward(self):
        pass

    def _test_enter_target(self, ts):
        '''
        True when the cursor sphere lies fully inside the current state's
        target (origin in 'origin', terminus in 'terminus'); False in any
        other state (previously an implicit None).
        '''
        if self.state == "origin":
            target = self.origin_target
        elif self.state == "terminus":
            target = self.terminus_target
        else:
            return False
        return self._cursor_distance(target) <= target.radius - self.cursor.radius

    def _test_leave_early(self, ts):
        '''
        True when the cursor is farther than exit_radius times the allowed
        radius from the target being held; False in any other state.
        '''
        if self.state == "origin_hold":
            target = self.origin_target
        elif self.state == "terminus_hold":
            target = self.terminus_target
        else:
            return False
        rad = target.radius - self.cursor.radius
        return self._cursor_distance(target) > rad * self.exit_radius

    def _test_hold(self, ts):
        if self.state == "origin_hold":
            return ts >= self.origin_hold_time
        return ts >= self.terminus_hold_time

    def _test_timeout(self, ts):
        return ts > self.timeout_time

    def _test_penalty_end(self, ts):
        if self.state == "timeout_penalty":
            return ts > self.timeout_penalty_time
        if self.state == "fixation_penalty":
            return ts > self.fixation_penalty_time
        return ts > self.hold_penalty_time

    def _while_terminus(self):
        self.update_cursor()

    def _while_terminus_hold(self):
        self.update_cursor()

    def _while_timeout_penalty(self):
        self.update_cursor()

    def update_target_location(self):
        # Determine the task target for assist/decoder adaptation purposes
        # (convert units from cm to mm for decoder) and record it.
        # TODO - decide what to do with y location, target_xz ignores it!
        if self.state == 'origin' or self.state == 'origin_hold':
            self.location = 10 * self.origin_target.xfm.move
            self.target_xz = np.array([self.location[0], self.location[2]])
        elif self.state == 'terminus' or self.state == 'terminus_hold':
            self.location = 10 * self.terminus_target.xfm.move
            self.target_xz = np.array([self.location[0], self.location[2]])
        self.task_data['target'] = self.location[:3]

    def update_cursor(self):
        self.update_target_location()
        super(ManualControl, self).update_cursor()
        self.task_data['cursor'] = self.cursor.xfm.move.copy()
class TargetCapture(Sequence, FixationTraining):
    '''
    Motion-tracker task: move the cursor into the origin target and hold it
    for origin_hold_time seconds to earn a reward.

    Leaving early triggers a hold penalty and counts a try. After 3 tries the
    state machine goes through target_change, which resets tries so the next
    trial draws a new target.
    '''
    status = dict(wait=dict(start_trial="origin", stop=None),
                  origin=dict(enter_target="origin_hold", stop=None),
                  origin_hold=dict(leave_early="hold_penalty", hold="reward"),
                  reward=dict(reward_end="target_change"),
                  hold_penalty=dict(penalty_end="pre_target_change"),
                  pre_target_change=dict(tried_enough='target_change',
                                         not_tried_enough='wait'),
                  target_change=dict(target_change_end='wait'))

    #create settable traits
    origin_size = traits.Float(1, desc="Radius of origin targets"
                               )  #add error if target is smaller than cursor
    origin_hold_time = traits.Float(2,
                                    desc="Length of hold required at origin")
    hold_penalty_time = traits.Float(
        3, desc="Length of penalty time for target hold error")
    exit_radius = 1.5  #Multiplier for the actual radius which is considered 'exiting' the target

    no_data_count = 0
    tries = 0
    scale_factor = 3.5  #scale factor for converting hand movement to screen movement (1cm hand movement = 3.5cm cursor movement)
    cursor_radius = .5

    def __init__(self, *args, **kwargs):
        super(TargetCapture, self).__init__(*args, **kwargs)
        self.origin_target = Sphere(radius=self.origin_size,
                                    color=(1, 0, 0, .5))
        self.add_model(self.origin_target)
        self.cursor = Sphere(radius=self.cursor_radius, color=(.5, 0, .5, 1))
        self.add_model(self.cursor)

    def _start_wait(self):
        super(TargetCapture, self)._start_wait()
        # Reset target colour and hide the target from the previous trial.
        self.origin_target.color = (1, 0, 0, .5)
        self.show_origin(False)

    def show_origin(self, show=False):
        '''Attach (show=True) or detach the origin target, then requeue.'''
        if show:
            self.origin_target.attach()
        else:
            self.origin_target.detach()
        self.requeue()

    def _start_origin(self):
        if self.tries == 0:
            # New trial: move the origin to the next trial's location.
            o = self.next_trial.T[0]
            self.origin_target.translate(*o, reset=True)
        self.show_origin(True)

    def _end_origin_hold(self):
        self.origin_target.color = (0, 1, 0, 0.5)

    def _start_hold_penalty(self):
        self.tries += 1
        self.show_origin(False)

    def _start_target_change(self):
        self.tries = 0

    def _test_target_change_end(self, ts):
        return True

    def _test_enter_target(self, ts):
        '''True when the cursor sphere lies fully inside the origin target.'''
        c = self.cursor.xfm.move
        t = self.origin_target.xfm.move
        d = np.sqrt((c[0] - t[0])**2 + (c[1] - t[1])**2 + (c[2] - t[2])**2)
        return d <= self.origin_target.radius - self.cursor.radius

    def _test_leave_early(self, ts):
        '''True when the cursor is farther than exit_radius times the allowed radius from the origin.'''
        c = self.cursor.xfm.move
        t = self.origin_target.xfm.move
        d = np.sqrt((c[0] - t[0])**2 + (c[1] - t[1])**2 + (c[2] - t[2])**2)
        rad = self.origin_target.radius - self.cursor.radius
        return d > rad * self.exit_radius

    def _test_hold(self, ts):
        return ts >= self.origin_hold_time

    def _test_penalty_end(self, ts):
        if self.state == "fixation_penalty":
            return ts > self.fixation_penalty_time
        return ts > self.hold_penalty_time

    def _test_tried_enough(self, ts):
        return self.tries == 3

    def _test_not_tried_enough(self, ts):
        return self.tries != 3

    def _update(self, pt):
        '''
        Move the cursor to pt[:3] (reset=True). With an empty pt, count a
        no-data frame and detach the cursor once more than 2 were counted.

        NOTE(review): update_cursor only calls this with a non-empty 3-vector,
        so the no-data branches are not reached from there. Nothing here
        resets no_data_count or re-attaches the cursor. Behaviour left
        unchanged — confirm intent.
        '''
        if len(pt) > 0:
            self.cursor.translate(*pt[:3], reset=True)
        elif self.no_data_count > 2:
            self.no_data_count += 1
            self.cursor.detach()
            self.requeue()
        else:
            self.no_data_count += 1

    def update_cursor(self):
        '''
        Poll motiondata, average the usable samples of marker #14 and move the
        cursor there, then redraw.

        Samples are kept when column 3 is >= 0 and != 4 (presumably a
        per-sample condition flag; confirm). Positions are converted mm -> cm,
        multiplied by scale_factor, and y is zeroed.
        '''
        pt = self.motiondata.get()
        if len(pt) > 0:
            pt = pt[:, 14, :]
            # NOTE!!! The marker on the hand was changed from #0 to #14 on
            # 5/19/13 after LED #0 broke. All data files saved before this date
            # have LED #0 controlling the cursor.
            conds = pt[:, 3]
            inds = np.nonzero((conds >= 0) & (conds != 4))
            if len(inds[0]) > 0:
                pt = pt[inds[0], :3]
                pt = pt.mean(0) * .1 * self.scale_factor
                pt[1] = 0
                self._update(pt)
            else:
                self.no_data_count += 1
        else:
            self.no_data_count += 1
        self.draw_world()

    def calc_trial_num(self):
        '''Calculates the current trial count'''
        trialtimes = [
            state[1] for state in self.state_log
            if state[0] in ['reward', 'timeout_penalty', 'hold_penalty']
        ]
        return len(trialtimes)

    def calc_rewards_per_min(self, window):
        '''
        Calculates the Rewards/min for the most recent window of specified number of seconds in the past.

        Returns 0.0 when no time has elapsed yet (the original divided by zero).
        '''
        now = self.get_time()
        divideby = min(now - self.task_start_time, window) / 60.0
        if divideby <= 0:
            return 0.0
        rewardtimes = np.array(
            [state[1] for state in self.state_log if state[0] == 'reward'])
        return np.sum(rewardtimes >= (now - window)) / divideby

    def calc_success_rate(self, window):
        '''
        Calculates the rewarded trials/initiated trials for the most recent window of specified length in sec.

        Returns 0.0 when no trials fall inside the window (the original
        computed 0/0 and reported nan).
        '''
        now = self.get_time()
        trialtimes = np.array([
            state[1] for state in self.state_log
            if state[0] in ['reward', 'timeout_penalty', 'hold_penalty']
        ])
        if len(trialtimes) == 0:
            return 0.0
        n_trials = np.sum(trialtimes >= (now - window))
        if n_trials == 0:
            return 0.0
        rewardtimes = np.array(
            [state[1] for state in self.state_log if state[0] == 'reward'])
        return float(np.sum(rewardtimes >= (now - window))) / n_trials

    def update_report_stats(self):
        '''Function to update any relevant report stats for the task. Values are saved in self.reportstats,
        an ordered dictionary. Keys are strings that will be displayed as the label for the stat in the web interface,
        values can be numbers or strings. Called every time task state changes.'''
        super(TargetCapture, self).update_report_stats()
        self.reportstats['Trial #'] = self.calc_trial_num()
        self.reportstats['Reward/min'] = np.round(
            self.calc_rewards_per_min(120), decimals=2)
        self.reportstats['Success rate'] = str(
            np.round(self.calc_success_rate(120) * 100.0, decimals=2)) + '%'

    def _while_wait(self):
        self.update_cursor()

    def _while_origin(self):
        self.update_cursor()

    def _while_origin_hold(self):
        self.update_cursor()

    def _while_fixation_penalty(self):
        self.update_cursor()

    def _while_hold_penalty(self):
        self.update_cursor()

    def _while_reward(self):
        self.update_cursor()

    def _while_pre_target_change(self):
        self.update_cursor()

    def _while_target_change(self):
        self.update_cursor()
class MovementTraining(Window):
    '''Reward movements of the tracked hand marker within a speed range.

    The cursor follows motion-tracker marker #14, with the y coordinate
    zeroed. Speed is estimated from the displacement between the oldest and
    newest buffered positions, divided by ``frame_offset / 60.0``; the 60 is
    presumably the update rate in Hz (TODO confirm).
    '''
    status = dict(wait=dict(stop=None, move_start="movement"),
                  movement=dict(move_end="reward", move_stop="wait",
                                stop=None),
                  reward=dict(reward_end="wait"))
    log_exclude = set((("wait", "move_start"), ("movement", "move_stop")))

    #initial state
    state = "wait"

    # Class-level defaults. `path` is re-created per instance in __init__
    # because update_cursor mutates it in place; sharing the class-level list
    # would leak positions between task instances.
    path = [[0, 0, 0]]
    speed = 0
    frame_offset = 2
    over = 0
    inside = 0

    #settable traits
    movement_distance = traits.Float(
        1, desc="Minimum movement distance to trigger reward")
    speed_range = traits.Tuple(
        20, 30, desc="Range of movement speed in cm/s to trigger reward")
    reward_time = traits.Float(14)

    #initialize
    def __init__(self, **kwargs):
        super(MovementTraining, self).__init__(**kwargs)
        # Per-instance position history (see the note on the class attribute).
        self.path = [[0, 0, 0]]
        self.cursor = Sphere(radius=.5, color=(.5, 0, .5, 1))
        self.add_model(self.cursor)

    def update_cursor(self):
        '''Move the cursor to the hand marker and update the speed counters.

        Averages all non-NaN samples of marker #14 received since the last
        poll. If no valid sample arrived, the previous position is repeated.
        Once more than ``frame_offset`` positions are buffered, the speed over
        that span is computed. ``over`` / ``inside`` are then incremented when
        the speed exceeds ``speed_range[0]`` / lies strictly inside
        ``speed_range``. Finally the world is redrawn.
        '''
        pt = self.motiondata.get()
        if len(pt) > 0:
            # NOTE!!! The marker on the hand was changed from #0 to #14 on
            # 5/19/13 after LED #0 broke. All data files saved before this date
            # have LED #0 controlling the cursor.
            # assumes pt is (n_samples, n_markers, n_coords) -- TODO confirm
            pt = pt[:, 14, :]
            # drop samples in which any coordinate is NaN
            pt = pt[~np.isnan(pt).any(1)]
        if len(pt) > 0:
            pt = pt.mean(0)
            self.path.append(pt)
            # scale by 0.25 and ignore the y direction
            t = pt * .25
            t[1] = 0
            #move cursor to marker location
            self.cursor.translate(*t[:3], reset=True)
        else:
            self.path.append(self.path[-1])
        if len(self.path) > self.frame_offset:
            self.path.pop(0)
            start, end = self.path[0], self.path[-1]
            d = np.sqrt((end[0] - start[0])**2 +
                        (end[1] - start[1])**2 +
                        (end[2] - start[2])**2)
            # 60.0 rather than 60: under Python 2, frame_offset / 60 is
            # integer division and evaluates to 0 (ZeroDivisionError).
            self.speed = d / (self.frame_offset / 60.0)
            if self.speed > self.speed_range[0]:
                self.over += 1
            if self.speed_range[0] < self.speed < self.speed_range[1]:
                self.inside += 1
        #write to screen
        self.draw_world()

    def _start_wait(self):
        # reset the speed counters for the next movement
        self.over = 0
        self.inside = 0

    def _while_wait(self):
        self.update_cursor()

    def _while_movement(self):
        self.update_cursor()

    def _while_reward(self):
        self.update_cursor()

    def _test_move_start(self, ts):
        return self.over > self.frame_offset

    def _test_move_end(self, ts):
        # time needed to cover movement_distance at the minimum rewarded speed
        return ts > self.movement_distance / self.speed_range[0]

    def _test_move_stop(self, ts):
        return self.inside > self.frame_offset

    def _test_reward_end(self, ts):
        return ts > self.reward_time
class TestBoundary(Window):
    '''
    Minimal task that places a small marker at each of the six configured
    boundary coordinates, to help choose sensible target boundary values.
    '''

    status = dict(wait=dict(stop=None))

    state = "wait"

    boundaries = traits.Tuple((-18, 18, -10, 10, -12, 12),
                              desc="x,y,z boundaries to display")

    def __init__(self, **kwargs):
        super(TestBoundary, self).__init__(**kwargs)
        # One small purple sphere per boundary, stored as self.xmin ... self.zmax
        for marker_name in ('xmin', 'xmax', 'ymin', 'ymax', 'zmin', 'zmax'):
            marker = Sphere(radius=.1, color=(.5, 0, .5, 1))
            setattr(self, marker_name, marker)
            self.add_model(marker)

    def _start_wait(self):
        # `boundaries` is ordered (xmin, xmax, ymin, ymax, zmin, zmax), so the
        # i-th marker lies on axis i // 2 at coordinate boundaries[i].
        marker_names = ('xmin', 'xmax', 'ymin', 'ymax', 'zmin', 'zmax')
        for i, marker_name in enumerate(marker_names):
            coords = [0, 0, 0]
            coords[i // 2] = self.boundaries[i]
            marker = getattr(self, marker_name)
            marker.translate(*coords, reset=True)
            marker.attach()
        self.requeue()

    def _while_wait(self):
        self.draw_world()
Пример #12
0
class ActiveExoPlant(plants.Plant):
    '''Plant that drives the active exoskeleton over UDP.

    Velocity commands and sensor requests are sent to ``tx_port``. Sensor
    packets are received on ``rx_port``; they carry joint angles, joint
    velocities, applied joint torques and, optionally, an end-point force
    sensor reading. The graphics models are for the experimenter's display
    only.
    '''

    n_joints = 5
    # ADC reading that maps to zero end-point force (see _read_sock)
    force_sensor_offset = 1544.

    def __init__(self, *args, **kwargs):
        '''
        Parameters
        ----------
        has_force_sensor : bool, optional (keyword)
            Whether sensor packets end with a force-sensor ADC value
            (default True). Popped from kwargs before forwarding.
        *args, **kwargs
            Forwarded to the parent ``plants.Plant.__init__``.
        '''
        # addresses depend on the module-level SIM flag
        if SIM:
            self.rx_port = ('localhost', 60000)
            self.tx_port = ('localhost', 60001)
        else:
            self.rx_port = ('10.0.0.1', 60000)
            self.tx_port = ('10.0.0.14', 60001)

        self.has_force_sensor = kwargs.pop('has_force_sensor', True)

        # Fields reported for saving (cf. get_data_to_save)
        self.hdf_attrs = [('joint_angles', 'f8', (self.n_joints,)),
                          ('joint_velocities', 'f8', (self.n_joints,)),
                          ('joint_applied_torque', 'f8', (self.n_joints,))]
        if self.has_force_sensor:
            self.hdf_attrs.append(('endpt_force', 'f8', (1,)))

        # Socket for transmitting commands; the receive socket is created in start()
        self.tx_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

        ## kinematic chain
        self.kin_chain = ActiveExoChain()

        # joint limits in radians, based on mechanical limits---some configurations
        # may still be outside the comfortable range for the subject
        self.kin_chain.joint_limits = [(-1.25, 1.25), (0, 1.7), (-0.95, 0.9), (0, 1.4), (-1.5, 1.5)]

        ## Graphics, for experimenter only
        self.link_lengths = link_lengths
        self.cursor = Sphere(radius=arm_radius/2, color=arm_color)

        self.upperarm_graphics = Cylinder(radius=arm_radius, height=self.link_lengths[0], color=arm_color)

        self.forearm_graphics = Cone(radius1=arm_radius, radius2=arm_radius/3, height=self.link_lengths[1], color=arm_color)
        self.forearm_graphics.xfm.translate(*exo_chain_graphics_base_loc)

        self.graphics_models = [self.upperarm_graphics, self.forearm_graphics, self.cursor]
        self.enabled = True

        super(ActiveExoPlant, self).__init__(*args, **kwargs)

    def disable(self):
        '''Make subsequent velocity commands send zeros.'''
        self.enabled = False

    def enable(self):
        '''Allow velocity commands to be sent as given.'''
        self.enabled = True

    def start(self):
        '''Create and bind the receive socket, then start the parent plant.'''
        self.rx_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.rx_sock.bind(self.rx_port)
        super(ActiveExoPlant, self).start()

    def _get_sensor_data(self):
        '''Request a sensor packet from the exo and parse the reply.

        Raises
        ------
        Exception
            If start() has not been called (no receive socket yet).
        '''
        if not hasattr(self, 'rx_sock'):
            raise Exception("You seem to have forgotten to 'start' the plant!")
        # sockets take bytes under Python 3 (b'' literals are also valid in Python 2)
        self.tx_sock.sendto(b's', self.tx_port)
        self._read_sock()

    def _check_if_ready(self):
        '''Send a sensor request and report whether a reply arrived within 0.5 s.'''
        rx_sock = self.rx_sock
        self.tx_sock.sendto(b's', self.tx_port)

        time.sleep(0.5)

        # Non-blocking check (timeout 0) for readable data on the receive socket
        read_sockets, _, _ = select.select([rx_sock], [], [], 0)

        return rx_sock in read_sockets

    def _read_sock(self):
        '''Receive one sensor packet and update the joint state (and force).

        Packet layout (big-endian): three groups, each an unsigned-int header
        followed by ``n_joints`` doubles -- joint angles, joint velocities and
        applied torques, in that order -- optionally followed by an unsigned
        short force-sensor ADC reading.
        '''
        fmt = '>' + ('I' + 'd' * self.n_joints) * 3
        if self.has_force_sensor:
            fmt += 'H'
        # Previously the size was computed from an undefined name
        # (n_bytes_per_double), which raised NameError on every read.
        n_rx_bytes = struct.calcsize(fmt)

        bin_data = self.rx_sock.recvfrom(n_rx_bytes)
        data = np.array(struct.unpack(fmt, bin_data[0]))

        if self.has_force_sensor:
            force_adc_data = data[-1]
            # assumes a 14-bit ADC and a 10 lb full-scale sensor -- TODO confirm
            frac_of_max_force = (force_adc_data - self.force_sensor_offset)/(2**14 - self.force_sensor_offset)
            force_lbs = frac_of_max_force * 10
            force_N = force_lbs * 4.448221628254617  # newtons per pound-force
            self.force_N = max(force_N, 0) # force must be positive for this sensor
            data = data[:-1]
        # one row per group; column 0 is the integer header, which is dropped
        data = data.reshape(3, self.n_joints + 1)
        data = data[:, 1:]
        self.joint_angles, self.joint_velocities, self.joint_applied_torque = data

    def get_data_to_save(self):
        '''Return the latest sensor values keyed by their hdf_attrs field names.

        Returns an empty dict (after printing a message) if no packet has been
        read yet.
        '''
        if not hasattr(self, 'joint_angles'):
            print("No data has been acquired yet!")
            return dict()
        data = dict(joint_angles=self.joint_angles, joint_velocities=self.joint_velocities, joint_applied_torque=self.joint_applied_torque)
        if self.has_force_sensor:
            data['endpt_force'] = self.force_N
        return data

    def _set_joint_velocity(self, vel):
        '''Send a joint-velocity command and read back the sensor reply.

        Sends zeros instead of ``vel`` while the plant is disabled.

        Raises
        ------
        ValueError
            If ``vel`` does not have ``n_joints`` elements.
        '''
        if not len(vel) == self.n_joints:
            raise ValueError("Improper number of joint velocities!")

        if self.enabled:
            vel = np.asarray(vel).ravel()
        else:
            vel = np.zeros(self.n_joints)
        # command packet: joint count (uint) followed by one double per joint
        cmd = struct.pack('>I' + 'd' * self.n_joints, self.n_joints, *vel)
        self.tx_sock.sendto(cmd, self.tx_port)
        self._read_sock()

    def stop_vel(self):
        '''Command zero velocity on all joints.'''
        self._set_joint_velocity(np.zeros(self.n_joints))

    def stop(self):
        '''Close the receive socket created by start().'''
        self.rx_sock.close()
        print("RX socket closed!")

    def set_intrinsic_coordinates(self, theta):
        '''
        Update the experimenter graphics for joint angles ``theta`` (radians).

        NOTE(review): an earlier docstring said NaN elements leave that angle
        unchanged; nothing here handles NaN explicitly -- confirm whether
        kin_chain.spatial_positions_of_joints does.
        '''
        joint_locations = self.kin_chain.spatial_positions_of_joints(theta)

        vec_sh_to_elbow = joint_locations[:, 2]
        vec_elbow_to_endpt = joint_locations[:, 4] - joint_locations[:, 2]

        self.upperarm_graphics.xfm.rotate = Quaternion.rotate_vecs((0, 0, 1), vec_sh_to_elbow)
        self.forearm_graphics.xfm.rotate = Quaternion.rotate_vecs((0, 0, 1), vec_elbow_to_endpt)
        self.forearm_graphics.xfm.translate(*vec_sh_to_elbow, reset=True)

        self.upperarm_graphics._recache_xfm()
        self.forearm_graphics._recache_xfm()

        self.cursor.translate(*self.get_endpoint_pos(), reset=True)

    def get_intrinsic_coordinates(self, new=False):
        '''Return joint angles, reading the sensors if ``new`` or none cached.'''
        if new or not hasattr(self, 'joint_angles'):
            self._get_sensor_data()
        return self.joint_angles

    def get_endpoint_pos(self):
        '''Return the chain endpoint position for the current joint angles.'''
        if not hasattr(self, 'joint_angles'):
            self._get_sensor_data()
        return self.kin_chain.endpoint_pos(self.joint_angles)

    def draw_state(self):
        '''Read fresh sensor data and update the graphics.'''
        self._get_sensor_data()
        self.set_intrinsic_coordinates(self.joint_angles)

    def drive(self, decoder):
        '''Send the decoder's joint velocities to the exo and sync the decoder.

        Velocities are clipped to [-0.2, 0.2] and entries with magnitude below
        0.02 are zeroed before sending. The decoder's 'q' entry is then
        overwritten with the measured joint angles.
        '''
        joint_vel = decoder['qdot']
        # NOTE(review): these assignments modify the array returned by
        # decoder['qdot'] in place -- confirm whether that aliases decoder state.
        joint_vel[joint_vel > 0.2] = 0.2
        joint_vel[joint_vel < -0.2] = -0.2

        joint_vel[np.abs(joint_vel) < 0.02] = 0

        # send the command to the robot
        self._set_joint_velocity(joint_vel)

        # set the decoder state to the actual joint angles
        decoder['q'] = self.joint_angles

        self.set_intrinsic_coordinates(self.joint_angles)

    def vel_control_to_joint_config(self, fb_ctrl, target_config, sim=True, control_rate=10, tol=np.deg2rad(10)):
        '''
        Drive the joints toward ``target_config`` with a feedback controller.

        The state vector is [joint angles, velocities, third n_joints block,
        1]. Its velocity block (``n_joints:2*n_joints``) is what gets sent to
        the exo. The meaning of the third block is not shown here (TODO
        confirm); it is set to zeros.

        Parameters
        ----------
        fb_ctrl : object
            Provides ``calc_next_state(current_state, target_state)``.
        target_config : array-like
            Target joint angles, in radians.
        sim : bool
            If True, only simulate (no commands are sent); ``self.joint_angles``
            must already hold the starting configuration.
        control_rate : int
            Control rate, in Hz
        tol : float
            Per-joint convergence tolerance, in radians.

        Returns
        -------
        traj : ndarray, shape (n_states, 250)
            Controller states per step; columns after convergence stay NaN.
        '''
        n = self.n_joints
        target_state = np.hstack([target_config, np.zeros_like(target_config), np.zeros_like(target_config), 1])
        target_state = np.mat(target_state.reshape(-1, 1))

        if not sim:
            self._get_sensor_data()
        # in simulation, self.joint_angles is assumed to have been set already
        current_state = np.hstack([self.joint_angles, np.zeros_like(target_config), np.zeros_like(target_config), 1])
        current_state = np.mat(current_state.reshape(-1, 1))

        N = 250
        traj = np.full([current_state.shape[0], N], np.nan)

        for k in range(N):
            print(k)
            current_config = np.array(current_state[0:n, 0]).ravel()

            if np.all(np.abs(current_config - target_config) < tol):
                print(np.abs(current_config - target_config))
                print(tol)
                break

            current_state = fb_ctrl.calc_next_state(current_state, target_state)

            traj[:, k] = np.array(current_state).ravel()

            if not sim:
                current_vel = np.array(current_state[n:2 * n, 0]).ravel()
                self._set_joint_velocity(current_vel)

                # update the current state using the joint encoders
                current_state = np.hstack([self.joint_angles, np.zeros_like(target_config), np.zeros_like(target_config), 1])
                current_state = np.mat(current_state.reshape(-1, 1))

                time.sleep(1. / control_rate)

        return traj
Пример #13
0
class TentacleMultiConfigObstacleAvoidance(BMIJointPerturb):
    '''Joint-perturbation BMI task with an obstacle during the second target.

    While ``target_index == 1`` a sphere is shown at half of the first
    target's position (``self.targs[0] / 2``). If any chain link comes
    closer to it than ``obstacle_radius + plant.link_radii[0]``, the task
    enters ``obstacle_penalty``.
    '''
    status = dict(
        wait=dict(start_trial="premove", stop=None),
        premove=dict(premove_complete="target"),
        target=dict(enter_target="hold",
                    timeout="timeout_penalty",
                    stop=None,
                    hit_obstacle="obstacle_penalty"),
        hold=dict(leave_early="hold_penalty", hold_complete="targ_transition"),
        targ_transition=dict(trial_complete="reward",
                             trial_abort="wait",
                             trial_incomplete="target",
                             trial_restart="premove"),
        timeout_penalty=dict(timeout_penalty_end="targ_transition"),
        hold_penalty=dict(hold_penalty_end="targ_transition"),
        obstacle_penalty=dict(obstacle_penalty_end="targ_transition"),
        reward=dict(reward_end="wait"))

    obstacle_radius = traits.Float(2.0, desc='Radius of cylindrical obstacle')
    obstacle_penalty = traits.Float(
        0.0, desc='Penalty time if the chain hits the obstacle(s)')

    def __init__(self, *args, **kwargs):
        super(TentacleMultiConfigObstacleAvoidance,
              self).__init__(*args, **kwargs)

        ## Create an obstacle object, hidden by default
        # NOTE(review): the displayed sphere is 0.6 larger than obstacle_radius
        # used for collision checks -- confirm this margin is intended.
        self.obstacle = Sphere(radius=self.obstacle_radius + 0.6,
                               color=(0, 0, 1, .5))
        self.obstacle_on = False
        self.obstacle_pos = np.ones(3) * np.nan
        self.hit_obstacle = False

        self.add_model(self.obstacle)

    def init(self):
        # register the obstacle fields in the saved task data
        self.add_dtype('obstacle_on', 'f8', (1, ))
        self.add_dtype('obstacle_pos', 'f8', (3, ))
        super(TentacleMultiConfigObstacleAvoidance, self).init()

    def _cycle(self):
        self.task_data['obstacle_on'] = self.obstacle_on
        self.task_data['obstacle_pos'] = self.obstacle_pos
        super(TentacleMultiConfigObstacleAvoidance, self)._cycle()

    def _start_target(self):
        super(TentacleMultiConfigObstacleAvoidance, self)._start_target()
        if self.target_index == 1:
            self.obstacle_pos = (self.targs[0] / 2)
            self.obstacle.translate(*self.obstacle_pos, reset=True)
            self.obstacle.attach()
            self.obstacle_on = True

    def _test_obstacle_penalty_end(self, ts):
        return ts > self.obstacle_penalty

    def _start_obstacle_penalty(self):
        #hide targets
        for target in self.targets:
            target.hide()

        self.tries += 1
        self.target_index = -1

    def _end_target(self):
        self.obstacle.detach()
        self.obstacle_on = False

    def _test_hit_obstacle(self, ts):
        '''Return True (and set ``self.hit_obstacle``) if the chain touches
        the obstacle; always False when the obstacle is not active.'''
        if self.target_index != 1:
            return False
        joint_angles = self.plant.get_intrinsic_coordinates()
        distances_to_links = self.plant.kin_chain.detect_collision(
            joint_angles, self.obstacle_pos)

        hit = np.min(distances_to_links) < (self.obstacle_radius +
                                            self.plant.link_radii[0])
        if hit:
            self.hit_obstacle = True
        # explicit bool: previously a miss fell through and returned None
        return bool(hit)

    @staticmethod
    def tentacle_multi_start_config(nblocks=100,
                                    ntargets=4,
                                    distance=8,
                                    startangle=45,
                                    elbow_angles=(135, 180, 225)):
        '''Generate a randomized sequence of (target, elbow angle) trials.

        Parameters
        ----------
        nblocks : int
            Number of blocks. Each block contains every target / elbow-angle
            combination exactly once, in shuffled order.
        ntargets : int
            Number of targets, evenly spaced on a circle in the x-z plane.
        distance : float
            Radius of the target circle.
        startangle : float
            Angle of the first target, in degrees. This argument used to be
            silently overwritten with 45; the default keeps that behavior.
        elbow_angles : sequence of float
            Elbow angles, in degrees, paired with every target.

        Returns
        -------
        list of (ndarray, float)
            Each item is a 2x3 array (target position row, then a row of
            zeros) together with an elbow angle in radians.
        '''
        import random

        elbow_angles = np.array(elbow_angles) * np.pi / 180
        startangle = startangle * np.pi / 180
        n_configs_per_target = len(elbow_angles)
        n_trials_per_block = n_configs_per_target * ntargets
        target_angles = np.arange(startangle, startangle + (2 * np.pi),
                                  2 * np.pi / ntargets)
        targets = distance * np.vstack(
            [np.cos(target_angles), 0 * target_angles,
             np.sin(target_angles)])

        # Trial k of a block uses target target_inds[k] with elbow angle
        # config_inds[k]; these are the same for every block.
        target_inds = np.tile(np.arange(ntargets),
                              (n_configs_per_target, 1)).T.ravel()
        config_inds = np.tile(np.arange(n_configs_per_target), ntargets)

        seq = []
        for _ in range(nblocks):
            inds = np.arange(n_trials_per_block)
            random.shuffle(inds)
            for k in inds:
                seq_item = (np.vstack([targets[:, target_inds[k]],
                                       np.zeros(3)]),
                            elbow_angles[config_inds[k]])
                seq.append(seq_item)

        return seq
Пример #14
0
class TentacleObstacleAvoidance(BMIControlMultiTentacleAttractor):
    '''Tentacle BMI task with an obstacle shown during the second target.

    While ``target_index == 1`` a sphere is placed at a per-target position
    loaded from a pickle file, or halfway to the target if that lookup fails.
    Touching it with the chain sends the task to ``obstacle_penalty``.
    '''
    status = dict(
        wait=dict(start_trial="target", stop=None),
        target=dict(enter_target="hold",
                    timeout="timeout_penalty",
                    stop=None,
                    hit_obstacle="obstacle_penalty"),
        hold=dict(leave_early="hold_penalty", hold_complete="targ_transition"),
        targ_transition=dict(trial_complete="reward",
                             trial_abort="wait",
                             trial_incomplete="target"),
        timeout_penalty=dict(timeout_penalty_end="targ_transition"),
        hold_penalty=dict(hold_penalty_end="targ_transition"),
        obstacle_penalty=dict(obstacle_penalty_end="targ_transition"),
        reward=dict(reward_end="wait"))
    obstacle_radius = traits.Float(2.0, desc='Radius of cylindrical obstacle')
    obstacle_penalty = traits.Float(
        0.0, desc='Penalty time if the chain hits the obstacle(s)')

    def __init__(self, *args, **kwargs):
        super(TentacleObstacleAvoidance, self).__init__(*args, **kwargs)

        ## Create an obstacle object, hidden by default
        # NOTE(review): the displayed sphere is 0.6 larger than obstacle_radius
        # used for collision checks -- confirm this margin is intended.
        self.obstacle = Sphere(
            radius=self.obstacle_radius + 0.6, color=(0, 0, 1, .5)
        )
        self.obstacle_on = False
        self.obstacle_pos = np.ones(3) * np.nan
        self.hit_obstacle = False

        self.add_model(self.obstacle)

    def init(self):
        # register the obstacle fields in the saved task data
        self.add_dtype('obstacle_on', 'f8', (1, ))
        self.add_dtype('obstacle_pos', 'f8', (3, ))
        super(TentacleObstacleAvoidance, self).init()

    def _cycle(self):
        self.task_data['obstacle_on'] = self.obstacle_on
        self.task_data['obstacle_pos'] = self.obstacle_pos
        super(TentacleObstacleAvoidance, self)._cycle()

    def _start_target(self):
        super(TentacleObstacleAvoidance, self)._start_target()
        if self.target_index == 1:
            # target direction in the x-z plane, rounded to whole degrees;
            # used as the key into the stored obstacle positions
            target_angle = np.round(
                np.rad2deg(
                    np.arctan2(self.target_location[-1],
                               self.target_location[0])))
            obstacle_pos_file = '/storage/task_data/TentacleObstacleAvoidance/center_out_obstacle_pos.pkl'
            try:
                # Binary mode is required by pickle under Python 3 (text mode
                # always failed and silently triggered the fallback). The
                # with-block closes the file. NOTE: unpickling is only
                # acceptable because this is a trusted local file.
                with open(obstacle_pos_file, 'rb') as f:
                    obstacle_data = pickle.load(f)
                self.obstacle_pos = obstacle_data[target_angle]
            except Exception:
                # best effort: missing/unreadable file or no entry for this
                # angle -> place the obstacle halfway to the target
                self.obstacle_pos = (self.target_location / 2)
            self.obstacle.translate(*self.obstacle_pos, reset=True)
            self.obstacle.attach()
            self.obstacle_on = True

    def _test_obstacle_penalty_end(self, ts):
        return ts > self.obstacle_penalty

    def _start_obstacle_penalty(self):
        #hide targets
        for target in self.targets:
            target.hide()

        self.tries += 1
        self.target_index = -1

    def _end_target(self):
        self.obstacle.detach()
        self.obstacle_on = False

    def _test_hit_obstacle(self, ts):
        '''Return True (and set ``self.hit_obstacle``) if the chain touches
        the obstacle; always False when the obstacle is not active.'''
        if self.target_index != 1:
            return False
        joint_angles = self.plant.get_intrinsic_coordinates()
        distances_to_links = self.plant.kin_chain.detect_collision(
            joint_angles, self.obstacle_pos)

        hit = np.min(distances_to_links) < (self.obstacle_radius +
                                            self.plant.link_radii[0])
        if hit:
            self.hit_obstacle = True
        # explicit bool: previously a miss fell through and returned None
        return bool(hit)