Example #1
class PenaltyAudio(traits.HasTraits):
    '''
    Play a sound on entering any penalty state. A new _start method has to be defined for
    each penalty state that might occur.
    '''
    files = list(reversed([f for f in os.listdir(audio_path) if f.endswith('.wav')]))
    penalty_sound = traits.OptionsList(
        files, desc="File in riglib/audio to play on each penalty")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.penalty_player = AudioPlayer(self.penalty_sound)

    def _start_hold_penalty(self):
        if hasattr(super(), '_start_hold_penalty'):
            super()._start_hold_penalty()
        self.penalty_player.play()

    def _start_delay_penalty(self):
        if hasattr(super(), '_start_delay_penalty'):
            super()._start_delay_penalty()
        self.penalty_player.play()

    def _start_reach_penalty(self):
        if hasattr(super(), '_start_reach_penalty'):
            super()._start_reach_penalty()
        self.penalty_player.play()

    def _start_timeout_penalty(self):
        if hasattr(super(), '_start_timeout_penalty'):
            super()._start_timeout_penalty()
        self.penalty_player.play()
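A minimal sketch of how this mixin is meant to compose with a task class (BaseTask here is a hypothetical task class, not from the source); the mixin must come before the task in the bases so its _start_* overrides run first and forward via super():

class MyTask(PenaltyAudio, BaseTask):
    # On entering a penalty state, the MRO reaches PenaltyAudio._start_hold_penalty
    # first; it forwards to BaseTask._start_hold_penalty (if defined) and then
    # plays the penalty sound.
    pass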
Example #2
class EndPostureFeedbackController(BMILoop, traits.HasTraits):
    ssm_type_options = bmi_ssm_options
    ssm_type = traits.OptionsList(*bmi_ssm_options, bmi3d_input_options=bmi_ssm_options)

    def load_decoder(self):
        self.ssm = StateSpaceEndptVel2D()
        A, B, W = self.ssm.get_ssm_matrices()
        filt = MachineOnlyFilter(A, W)
        units = []
        self.decoder = Decoder(filt, units, self.ssm, binlen=0.1)
        self.decoder.n_features = 1

    def create_feature_extractor(self):
        self.extractor = DummyExtractor()
        self._add_feature_extractor_dtype()
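For context, a minimal sketch of what the MachineOnlyFilter used above could look like, assuming it ignores neural observations and simply propagates the state-space model x_{t+1} = A x_t + w with w ~ N(0, W); this is an assumption for illustration, not the library's actual implementation:

import numpy as np

class MachineOnlyFilterSketch:
    def __init__(self, A, W):
        self.A = np.asarray(A, dtype=float)   # state transition matrix
        self.W = np.asarray(W, dtype=float)   # process noise covariance
        self.x = np.zeros(self.A.shape[0])    # current state

    def step(self):
        # Propagate the state with process noise; no neural input is used.
        w = np.random.multivariate_normal(np.zeros(len(self.x)), self.W)
        self.x = self.A @ self.x + w
        return self.x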
Example #3
class RewardAudio(traits.HasTraits):
    '''
    Play a sound on entering the reward state. Add a corresponding _start method for any other reward states that should be included.
    '''

    files = [f for f in os.listdir(audio_path) if f.endswith('.wav')]
    reward_sound = traits.OptionsList(
        files, desc="File in riglib/audio to play on each reward")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.reward_player = AudioPlayer(self.reward_sound)

    def _start_reward(self):
        if hasattr(super(), '_start_reward'):
            super()._start_reward()
        self.reward_player.play()
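A self-contained illustration of the hasattr(super(), ...) guard used above: it forwards to a parent _start_reward only when one exists later in the MRO (all classes here are hypothetical):

class Mixin:
    def _start_reward(self):
        if hasattr(super(), '_start_reward'):
            super()._start_reward()       # forward only if a later class defines it
        print('play reward sound')

class BaseTask:
    def _start_reward(self):
        print('base reward logic')

class WithBase(Mixin, BaseTask):
    pass

class WithoutBase(Mixin):
    pass

WithBase()._start_reward()     # prints 'base reward logic' then 'play reward sound'
WithoutBase()._start_reward()  # prints only 'play reward sound'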
Example #4
class EndPostureFeedbackController(BMILoop, traits.HasTraits):
    ssm_type_options = bmi_ssm_options
    ssm_type = traits.OptionsList(*bmi_ssm_options,
                                  bmi3d_input_options=bmi_ssm_options)

    def load_decoder(self):
        from db.namelist import bmi_state_space_models
        # from config import config
        # with open(os.path.join(config.log_dir, 'EndPostureFeedbackController'), 'w') as fh:
        #     fh.write('%s' % self.ssm_type)
        self.ssm = bmi_state_space_models[self.ssm_type]
        A, B, W = self.ssm.get_ssm_matrices()
        filt = MachineOnlyFilter(A, W)
        units = []
        self.decoder = Decoder(filt, units, self.ssm, binlen=0.1)
        self.decoder.n_features = 1

    def create_feature_extractor(self):
        self.extractor = DummyExtractor()
        self._add_feature_extractor_dtype()
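For context, load_decoder above resolves ssm_type through a name-to-model registry; a minimal sketch of that lookup pattern (the dict contents here are assumptions for illustration):

bmi_state_space_models = {
    'Endpt2D': StateSpaceEndptVel2D(),
    # ... other models keyed by the strings in bmi_ssm_options
}
ssm = bmi_state_space_models['Endpt2D']   # self.ssm_type holds one such key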
Example #5
class ArmPlant(Window):
    '''
    This task creates a RobotArm object and allows it to move around the screen based on either joint or endpoint
    positions. There is a spherical cursor at the end of the arm. The links of the arm can be visible or hidden.
    '''

    background = (0, 0, 0, 1)

    arm_visible = traits.Bool(
        True,
        desc='Specifies whether entire arm is displayed or just endpoint')

    cursor_radius = traits.Float(.5, desc="Radius of cursor")
    cursor_color = (.5, 0, .5, 1)

    arm_class = traits.OptionsList(*plantlist,
                                   bmi3d_input_options=plantlist.keys())
    starting_pos = (5, 0, 5)

    def __init__(self, *args, **kwargs):
        super(ArmPlant, self).__init__(*args, **kwargs)
        self.cursor_visible = True

        # Initialize the arm
        self.arm = ik.test_3d
        self.arm_vis_prev = True

        if self.arm_class != 'CursorPlant':
            self.dtype.append(('joint_angles', 'f8', (self.arm.num_joints, )))
            self.dtype.append(('arm_visible', 'f8', (1, )))
            self.add_model(self.arm)

        ## Declare cursor
        self.dtype.append(('cursor', 'f8', (3, )))
        self.cursor = Sphere(radius=self.cursor_radius,
                             color=self.cursor_color)
        self.add_model(self.cursor)
        self.cursor.translate(*self.arm.get_endpoint_pos(), reset=True)

    def _cycle(self):
        '''
        Calls any update functions necessary and redraws screen. Runs 60x per second by default.
        '''
        ## Run graphics commands to show/hide the arm if the visibility has changed
        if self.arm_class != 'CursorPlant':
            if self.arm_visible != self.arm_vis_prev:
                self.arm_vis_prev = self.arm_visible
                self.show_object(self.arm, show=self.arm_visible)

        self.move_arm()
        self.update_cursor()
        if self.cursor_visible:
            self.task_data['cursor'] = self.cursor.xfm.move.copy()
        else:
            #if the cursor is not visible, write NaNs into cursor location saved in file
            self.task_data['cursor'] = np.array([np.nan, np.nan, np.nan])

        if self.arm_class != 'CursorPlant':
            if self.arm_visible:
                self.task_data['arm_visible'] = 1
            else:
                self.task_data['arm_visible'] = 0

        super(ArmPlant, self)._cycle()

    ## Functions to move the cursor using keyboard/mouse input
    def get_mouse_events(self):
        import pygame
        events = []
        for btn in pygame.event.get(
            (pygame.MOUSEBUTTONDOWN, pygame.MOUSEBUTTONUP)):
            events = events + [btn.button]
        return events

    def get_key_events(self):
        import pygame
        return pygame.key.get_pressed()

    def move_arm(self):
        '''
        Allows keyboard control for testing arm movement: Q/W and O/P (or Q/W/E/R for four joints) drive joint movements; the arrow keys drive endpoint movements.
        '''
        import pygame

        keys = self.get_key_events()
        joint_speed = (np.pi / 6) / 60
        hand_speed = .2

        x, y, z = self.arm.get_endpoint_pos()

        if keys[pygame.K_RIGHT]:
            x = x - hand_speed
            self.arm.set_endpoint_pos(np.array([x, 0, z]))
        if keys[pygame.K_LEFT]:
            x = x + hand_speed
            self.arm.set_endpoint_pos(np.array([x, 0, z]))
        if keys[pygame.K_DOWN]:
            z = z - hand_speed
            self.arm.set_endpoint_pos(np.array([x, 0, z]))
        if keys[pygame.K_UP]:
            z = z + hand_speed
            self.arm.set_endpoint_pos(np.array([x, 0, z]))

        if self.arm.num_joints == 2:
            xz, xy = self.get_arm_joints()
            e = np.array([xz[0], xy[0]])
            s = np.array([xz[1], xy[1]])

            if keys[pygame.K_q]:
                s = s - joint_speed
                self.set_arm_joints([e[0], s[0]], [e[1], s[1]])
            if keys[pygame.K_w]:
                s = s + joint_speed
                self.set_arm_joints([e[0], s[0]], [e[1], s[1]])
            if keys[pygame.K_o]:
                e = e - joint_speed
                self.set_arm_joints([e[0], s[0]], [e[1], s[1]])
            if keys[pygame.K_p]:
                e = e + joint_speed
                self.set_arm_joints([e[0], s[0]], [e[1], s[1]])

        if self.arm.num_joints == 4:
            jts = self.get_arm_joints()
            keyspressed = [
                keys[pygame.K_q], keys[pygame.K_w], keys[pygame.K_e],
                keys[pygame.K_r]
            ]
            for i in range(self.arm.num_joints):
                if keyspressed[i]:
                    jts[i] = jts[i] + joint_speed
                    self.set_arm_joints(jts)

    def get_cursor_location(self):
        return self.arm.get_endpoint_pos()

    def set_arm_endpoint(self, pt, **kwargs):
        self.arm.set_endpoint_pos(pt, **kwargs)

    def set_arm_joints(self, angle_xz, angle_xy):
        self.arm.set_intrinsic_coordinates(angle_xz, angle_xy)

    def get_arm_joints(self):
        return self.arm.get_intrinsic_coordinates()

    def update_cursor(self):
        '''
        Update the cursor's location and visibility status.
        '''
        pt = self.get_cursor_location()
        if pt is not None:
            self.move_cursor(pt)

    def move_cursor(self, pt):
        ''' Move the cursor object to the specified 3D location. '''
        if not hasattr(self.arm, 'endpt_cursor'):
            self.cursor.translate(*pt[:3], reset=True)
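A worked check of the default speeds in move_arm, assuming the 60 Hz cycle rate stated in _cycle:

import numpy as np

joint_speed = (np.pi / 6) / 60       # rad per frame
print(np.degrees(joint_speed) * 60)  # -> 30.0 degrees of joint rotation per second
hand_speed = 0.2                     # cm per frame, i.e. 0.2 * 60 = 12 cm/s endpoint speed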
Example #6
class FreeChoiceFA(FactorBMIBase):
    '''
    Task where the virtual plant starts in a configuration sampled from a discrete set and resets every trial.
    '''

    sequence_generators = [
        'centerout_2D_discrete_w_free_choice',
        'centerout_2D_discrete_w_free_choice_v2',
        'centerout_2D_discrete_w_free_choices_evenly_spaced'
    ]
    #sequence_generators = ['centerout_2D_discrete']

    input_type_list = [
        'shared', 'private', 'shared_scaled', 'private_scaled', 'all',
        'all_scaled_by_shar', 'sc_shared+unsc_priv', 'sc_shared+sc_priv',
        'main_shared', 'main_sc_shared', 'main_sc_private',
        'main_sc_shar+unsc_priv', 'main_sc_shar+sc_priv', 'pca', 'split'
    ]

    input_type_0 = traits.OptionsList(*input_type_list,
                                      bmi3d_input_options=input_type_list)
    color_0 = traits.OptionsList(*target_colors.keys(),
                                 bmi3d_input_options=target_colors.keys())

    input_type_1 = traits.OptionsList(*input_type_list,
                                      bmi3d_input_options=input_type_list)
    color_1 = traits.OptionsList(*target_colors.keys(),
                                 bmi3d_input_options=target_colors.keys())

    input_type_2 = traits.OptionsList(*input_type_list,
                                      bmi3d_input_options=input_type_list)
    color_2 = traits.OptionsList(*target_colors.keys(),
                                 bmi3d_input_options=target_colors.keys())

    input_type_3 = traits.OptionsList(*input_type_list,
                                      bmi3d_input_options=input_type_list)
    color_3 = traits.OptionsList(*target_colors.keys(),
                                 bmi3d_input_options=target_colors.keys())

    choice_assist = traits.Float(0.)
    target_assist = traits.Float(0.)
    choice_target_rad = traits.Float(2.)

    status = dict(wait=dict(start_trial="targ_transition", stop=None),
                  pre_choice_orig=dict(enter_orig='choice_target',
                                       timeout='timeout_penalty',
                                       stop=None),
                  choice_target=dict(enter_choice_target='targ_transition',
                                     timeout='timeout_penalty',
                                     stop=None),
                  target=dict(enter_target="hold",
                              timeout="timeout_penalty",
                              stop=None),
                  hold=dict(leave_early="hold_penalty",
                            hold_complete="targ_transition"),
                  targ_transition=dict(trial_complete="reward",
                                       trial_abort="wait",
                                       trial_incomplete="target",
                                       make_choice='pre_choice_orig'),
                  timeout_penalty=dict(timeout_penalty_end="targ_transition"),
                  hold_penalty=dict(hold_penalty_end="targ_transition"),
                  reward=dict(reward_end="wait"))
    hidden_traits = [
        'arm_hide_rate', 'arm_visible', 'hold_penalty_time', 'rand_start',
        'reset', 'window_size', 'assist_level', 'assist_level_time',
        'plant_hide_rate', 'plant_visible', 'show_environment',
        'trials_per_reward'
    ]

    def __init__(self, *args, **kwargs):
        super(FreeChoiceFA, self).__init__(*args, **kwargs)

        seq_params = eval(kwargs.pop('seq_params', '{}'))
        print('SEQ PARAMS: ', seq_params, type(seq_params))
        self.choice_per_n_blocks = seq_params.pop('blocks_per_free_choice', 1)
        self.n_free_choices = seq_params.pop('n_free_choices', 2)
        self.n_targets = seq_params.pop('ntargets', 8)

        self.input_type_dict = dict()
        self.input_type_dict[0] = self.input_type_0
        self.input_type_dict[0, 'color'] = target_colors[self.color_0]

        self.input_type_dict[1] = self.input_type_1
        self.input_type_dict[1, 'color'] = target_colors[self.color_1]

        self.input_type_dict[2] = self.input_type_2
        self.input_type_dict[2, 'color'] = target_colors[self.color_2]

        self.input_type_dict[3] = self.input_type_3
        self.input_type_dict[3, 'color'] = target_colors[self.color_3]

        # Instantiate the choice targets
        self.choices_targ_list = []
        for c in range(self.n_free_choices):
            self.choices_targ_list.append(
                target_graphics.VirtualCircularTarget(
                    target_radius=self.choice_target_rad,
                    target_color=self.input_type_dict[c, 'color']))

        for c in self.choices_targ_list:
            for model in c.graphics_models:
                self.add_model(model)

        self.subblock_cnt = 0
        self.subblock_end = self.choice_per_n_blocks * self.n_targets
        self.choice_made = 0
        self.choice_ts = 0
        self.chosen_input_ix = -1
        self.choice_locs = np.zeros((self.n_free_choices, 3))

    def init(self):
        self.add_dtype('trial_type', np.str_, 16)
        self.add_dtype('choice_ix', 'f8', (1, ))
        self.add_dtype('choice_targ_loc', 'f8', (self.n_free_choices, 3))

        super(FreeChoiceFA, self).init()

    def _start_pre_choice_orig(self):
        target = self.targets[0]
        target.move_to_position(np.array([0., 0., 0.]))
        target.cue_trial_start()
        self.chosen_input_ix = -1

    def _start_timeout_penalty(self):
        #hide targets
        for target in self.targets:
            target.hide()

        for target in self.choices_targ_list:
            target.hide()

        self.tries += 1
        self.target_index = -1

    def _test_enter_orig(self, ts):
        cursor_pos = self.plant.get_endpoint_pos()
        d = np.linalg.norm(cursor_pos)
        return d <= self.target_radius

    def update_level(self):
        pass

    def create_goal_calculator(self):
        self.goal_calculator = Choice_Goal_Calc(self.decoder.ssm)

    def get_target_BMI_state(self, *args):
        '''
        Run the goal calculator to determine the target state of the task
        '''
        target_state = self.goal_calculator(self.targs, self.choice_locs,
                                            self.choice_asst_ix,
                                            self.target_index, self.state)
        return np.array(target_state).reshape(-1, 1)

    def _parse_next_trial(self):
        print('parse next: ', self.next_trial[2][0])
        pairs = self.next_trial[0]
        self.targs = pairs[:, :, 1]
        self.choice_locs = pairs[:, :, 0]
        self.choice_asst_ix = self.next_trial[1][0]
        self.choice_instructed = self.next_trial[2][0]

        if self.subblock_cnt >= self.subblock_end:
            self.choice_made = 0
            self.subblock_cnt = 0

    def _test_make_choice(self, ts):
        return not self.choice_made

    def _cycle(self):
        self.task_data['trial_type'] = self.choice_instructed
        self.task_data['choice_ix'] = self.chosen_input_ix
        self.task_data['choice_targ_loc'] = self.choice_locs
        super(FreeChoiceFA, self)._cycle()

    def _start_choice_target(self):
        self.choice_ts = 0
        if self.choice_instructed == 'Free':
            for ic, c in enumerate(self.choices_targ_list):
                #move a target to current location (target1 and target2 alternate moving) and set location attribute
                c.move_to_position(self.choice_locs[ic, :])
                c.sphere.color = self.input_type_dict[ic, 'color']
                c.show()
        elif self.choice_instructed == 'Instructed':
            ic = self.choice_asst_ix
            c = self.choices_targ_list[ic]
            c.move_to_position(self.choice_locs[ic, :])
            c.sphere.color = self.input_type_dict[ic, 'color']
            c.show()

        target = self.targets[0]
        target.hide()
        self.choice_ts = 0

    def _start_target(self):
        super(FreeChoiceFA, self)._start_target()
        self.current_assist_level = self.target_assist
        for ic, c in enumerate(self.choices_targ_list):
            c.hide()

    def _test_enter_choice_target(self, ts):
        cursor_pos = self.plant.get_endpoint_pos()
        enter_targ = 0
        for ic, c in enumerate(self.choice_locs):
            d = np.linalg.norm(cursor_pos - c)
            if d <= self.choice_target_rad:  #NOTE, gets in if CENTER of cursor is in target (not entire cursor)
                enter_targ += 1

                #Set chosen as new input:
                self.chosen_input_ix = ic
                self.decoder.filt.FA_input = self.input_type_dict[ic]
                print('trial: ', self.decoder.filt.FA_input, self.choice_instructed)

                #Declare that choice has been made:
                self.choice_made = 1

                #Change color of cursor:
                sph = self.plant.graphics_models[0]
                sph.color = self.input_type_dict[ic, 'color']

        return enter_targ > 0

    def _test_trial_incomplete(self, ts):
        if self.choice_made == 0:
            return False
        else:
            return (not self._test_trial_complete(ts)) and (self.tries <
                                                            self.max_attempts)

    def _start_reward(self):
        self.subblock_cnt += 1
        super(FreeChoiceFA, self)._start_reward()

    @staticmethod
    def centerout_2D_discrete_w_free_choice(nblocks=100,
                                            ntargets=8,
                                            boundaries=(-18, 18, -12, 12),
                                            distance=10,
                                            n_free_choices=2,
                                            blocks_per_free_choice=1,
                                            percent_instructed=50.):
        return True

    @staticmethod
    def centerout_2D_discrete_w_free_choice_v2(nblocks=100,
                                               ntargets=8,
                                               boundaries=(-18, 18, -12, 12),
                                               distance=10,
                                               n_free_choices=2,
                                               blocks_per_free_choice=1,
                                               percent_instructed=50.):
        '''
        Generates a sequence of 2D (x and z) target pairs, with the first target
        always at the origin, plus a sequence of 2D (x and z) free-choice target
        locations for each of nblocks blocks, where the location of each choice changes.

        Parameters
        ----------
        nblocks : int
            The number of blocks of target pairs in the sequence.
        boundaries : 4-element tuple
            The limits of the allowed target locations (-x, x, -z, z).
        distance : float
            The distance in cm between the targets in a pair.
        n_free_choices : int
            The number of free choices.

        Returns
        -------
        List of (targets, choice assist index, trial type) tuples, one per trial,
        where targets stacks the choice locations ([:, :, 0]) with the origin/peripheral
        target pair ([:, :, 1]) along the last axis.
        '''

        # Choose a random sequence of points on the edge of a circle of radius
        # "distance"

        theta = []
        theta_choice = []
        ix_choice_assist = []
        ix_choice_instructed = []
        for i in range(nblocks):
            temp_ = []
            for j in range(blocks_per_free_choice):
                temp = np.arange(0, 2 * np.pi, 2 * np.pi / ntargets)
                np.random.shuffle(temp)
                temp_ = temp_ + list(temp)

            theta.append(temp_)
            temp2 = np.arange(0, np.pi / 2., np.pi / 2. / n_free_choices) + (
                np.pi / 4.) + (np.pi / 2.) * np.random.randint(0, 2)
            temp3 = np.random.randint(0, n_free_choices)
            temp4 = np.random.rand()
            if temp4 < percent_instructed / 100.:
                ix_choice_instructed.append('Instructed')
            else:
                ix_choice_instructed.append('Free')

            np.random.shuffle(temp2)
            theta_choice.append(temp2)
            ix_choice_assist.append(temp3)

        theta = np.vstack(theta)
        theta_choice = np.vstack(theta_choice)  #nblocks x n_free_choices
        ix_choice_assist = np.array(ix_choice_assist)
        ix_choice_instructed = np.array(ix_choice_instructed)

        #### calculate targets:
        x = distance * np.cos(theta)
        y = np.zeros((nblocks, ntargets * blocks_per_free_choice))
        z = distance * np.sin(theta)

        pairs = np.zeros([nblocks, ntargets * blocks_per_free_choice, 2, 3])
        pairs[:, :, 1, :] = np.dstack([x, y, z])

        #### calculate free choices:
        x = distance * np.cos(theta_choice)
        y = np.zeros((nblocks, n_free_choices))
        z = distance * np.sin(theta_choice)

        choice = np.dstack((x, y, z))  # nblocks x n_free_choices x 3

        g = []
        for i in range(nblocks):
            chz = choice[i, :, :]
            chz_assist = ix_choice_assist[i]
            type_chz = ix_choice_instructed[i]
            for j in range(ntargets * blocks_per_free_choice):
                tg = pairs[i, j, :, :]
                g.append((np.dstack((chz, tg)), [chz_assist], [type_chz]))
        return g

    @staticmethod
    def centerout_2D_discrete_w_free_choices_evenly_spaced(
            nblocks=100,
            ntargets=8,
            boundaries=(-18, 18, -12, 12),
            distance=10,
            n_free_choices=2,
            blocks_per_free_choice=1,
            percent_instructed=50.,
            choice_targ_ang=30.):
        '''
        Generates a sequence of 2D (x and z) target pairs, with the first target
        always at the origin, plus a sequence of 2D (x and z) free-choice target
        locations for each of nblocks blocks. The free-choice locations are placed
        opposite the previous target and spaced at +/- choice_targ_ang degree offsets.

        Parameters
        ----------
        nblocks : int
            The number of blocks of target pairs in the sequence.
        boundaries : 4-element tuple
            The limits of the allowed target locations (-x, x, -z, z).
        distance : float
            The distance in cm between the targets in a pair.
        n_free_choices : int
            The number of free choices.

        Returns
        -------
        List of (targets, choice assist index, trial type) tuples, one per trial,
        with the same layout as centerout_2D_discrete_w_free_choice_v2.
        '''

        # Choose a random sequence of points on the edge of a circle of radius
        # "distance"

        theta = []
        theta_choice = []
        ix_choice_assist = []
        ix_choice_instructed = []
        last_targ_ang_ = 0.
        for i in range(nblocks):
            temp_ = []
            for j in range(blocks_per_free_choice):
                temp = np.arange(0, 2 * np.pi, 2 * np.pi / ntargets)
                np.random.shuffle(temp)
                temp_ = temp_ + list(temp)

            theta.append(temp_)

            ang = np.array([
                -choice_targ_ang * (np.pi / 180),
                choice_targ_ang * (np.pi / 180.)
            ])
            temp2 = ang + np.pi + last_targ_ang_
            last_targ_ang_ = temp_[-1]

            temp3 = np.random.randint(0, n_free_choices)
            temp4 = np.random.rand()
            if temp4 < percent_instructed / 100.:
                ix_choice_instructed.append('Instructed')
            else:
                ix_choice_instructed.append('Free')

            np.random.shuffle(temp2)
            theta_choice.append(temp2)
            ix_choice_assist.append(temp3)

        theta = np.vstack(theta)
        theta_choice = np.vstack(theta_choice)  #nblocks x n_free_choices
        ix_choice_assist = np.array(ix_choice_assist)
        ix_choice_instructed = np.array(ix_choice_instructed)

        #### calculate targets:
        x = distance * np.cos(theta)
        y = np.zeros((nblocks, ntargets * blocks_per_free_choice))
        z = distance * np.sin(theta)

        pairs = np.zeros([nblocks, ntargets * blocks_per_free_choice, 2, 3])
        pairs[:, :, 1, :] = np.dstack([x, y, z])

        #### calculate free choices:
        x = distance * np.cos(theta_choice)
        y = np.zeros((nblocks, n_free_choices))
        z = distance * np.sin(theta_choice)

        choice = np.dstack((x, y, z))  # nblocks x n_free_choices x 3

        g = []
        for i in range(nblocks):
            chz = choice[i, :, :]
            chz_assist = ix_choice_assist[i]
            type_chz = ix_choice_instructed[i]
            for j in range(ntargets * blocks_per_free_choice):
                tg = pairs[i, j, :, :]
                g.append((np.dstack((chz, tg)), [chz_assist], [type_chz]))
        return g
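The generators above place reach targets evenly on a circle of radius distance in the x-z plane (y = 0); a short standalone sketch of that mapping:

import numpy as np

ntargets, distance = 8, 10
theta = np.arange(0, 2 * np.pi, 2 * np.pi / ntargets)   # evenly spaced angles
np.random.shuffle(theta)                                 # randomize presentation order
targets = np.stack([distance * np.cos(theta),            # x
                    np.zeros(ntargets),                  # y (depth) fixed at 0
                    distance * np.sin(theta)], axis=1)   # z; shape (ntargets, 3)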
Example #7
class LFP_Mod(BMILoop, Sequence, Window):

    background = (0,0,0,1)
    
    plant_visible = traits.Bool(True, desc='Specifies whether entire plant is displayed or just endpoint')
    
    lfp_cursor_rad = traits.Float(.5, desc="Radius of LFP cursor")
    lfp_cursor_color = (.5,0,.5,.75)  
     
    lfp_plant_type_options = plantlist.keys()
    lfp_plant_type = traits.OptionsList(*plantlist, bmi3d_input_options=plantlist.keys())

    window_size = traits.Tuple((1920*2, 1080), desc='window size')

    lfp_frac_lims = traits.Tuple((0., 0.35), desc='fraction limits')
    xlfp_frac_lims = traits.Tuple((-.7, 1.7), desc = 'x dir fraction limits')
    lfp_control_band = traits.Tuple((25, 40), desc='beta power band limits')
    lfp_totalpw_band = traits.Tuple((1, 100), desc='total power band limits')
    xlfp_control_band = traits.Tuple((0, 5), desc = 'x direction band limits')
    n_steps = traits.Int(2, desc='moving average for decoder')


    powercap = traits.Float(1, desc="Powercap penalty is triggered when total power exceeds this value")

    zboundaries=(-12,12)

    status = dict(
        wait = dict(start_trial="lfp_target", stop=None),
        lfp_target = dict(enter_lfp_target="lfp_hold", powercap_penalty="powercap_penalty", stop=None),
        lfp_hold = dict(leave_early="lfp_target", lfp_hold_complete="reward", powercap_penalty="powercap_penalty"),
        powercap_penalty = dict(powercap_penalty_end="lfp_target"),
        reward = dict(reward_end="wait")
        )

    static_states = [] # states in which the decoder is not run
    trial_end_states = ['reward']
    lfp_cursor_on = ['lfp_target', 'lfp_hold']

    #initial state
    state = "wait"

    #create settable traits
    reward_time = traits.Float(.5, desc="Length of juice reward")

    lfp_target_rad = traits.Float(3.6, desc="Length of targets in cm")
    
    lfp_hold_time = traits.Float(.2, desc="Length of hold required at lfp targets")
    lfp_hold_var = traits.Float(.05, desc="Length of hold variance required at lfp targets")

    hold_penalty_time = traits.Float(1, desc="Length of penalty time for target hold error")
    
    powercap_penalty_time = traits.Float(1, desc="Length of penalty time for timeout error")

    # max_attempts = traits.Int(10, desc='The number of attempts at a target before\
    #     skipping to the next one')

    session_length = traits.Float(0, desc="Time until task automatically stops. Length of 0 means no auto stop.")

    #plant_hide_rate = traits.Float(0.0, desc='If the plant is visible, specifies a percentage of trials where it will be hidden')
    lfp_target_color = (123/256.,22/256.,201/256.,.5)
    mc_target_color = (1,0,0,.5)

    target_index = -1 # Helper variable to keep track of which target to display within a trial
    #tries = 0 # Helper variable to keep track of the number of failed attempts at a given trial.
    
    cursor_visible = False # Determines when to hide the cursor.
    no_data_count = 0 # Counter for number of missing data frames in a row
    
    sequence_generators = ['lfp_mod_4targ']
    
    def __init__(self, *args, **kwargs):
        super(LFP_Mod, self).__init__(*args, **kwargs)
        self.cursor_visible = True

        print('INIT FRAC LIMS: ', self.lfp_frac_lims)
        
        dec_params = dict(lfp_frac_lims = self.lfp_frac_lims,
                          xlfp_frac_lims = self.xlfp_frac_lims,
                          powercap = self.powercap,
                          zboundaries = self.zboundaries,
                          lfp_control_band = self.lfp_control_band,
                          lfp_totalpw_band = self.lfp_totalpw_band,
                          xlfp_control_band = self.xlfp_control_band,
                          n_steps = self.n_steps)

        self.decoder.filt.init_from_task(**dec_params)
        self.decoder.init_from_task(**dec_params)

        self.lfp_plant = plantlist[self.lfp_plant_type]
        if self.lfp_plant_type == 'inv_cursor_onedimLFP':
            print('MAKE SURE INVERSE GENERATOR IS ON')
            
        self.plant_vis_prev = True

        self.current_assist_level = 0
        self.learn_flag = False

        if hasattr(self.lfp_plant, 'graphics_models'):
            for model in self.lfp_plant.graphics_models:
                self.add_model(model)

        # Instantiate the targets. Height and width on the Kinarm machine are 2.4;
        # here we make it 2.4/8*12 = 3.6.
        lfp_target = VirtualSquareTarget(target_radius=self.lfp_target_rad, target_color=self.lfp_target_color)
        self.targets = [lfp_target]
        
        # Initialize target location variable
        self.target_location_lfp = np.array([-100, -100, -100])

        # Declare any plant attributes which must be saved to the HDF file at the _cycle rate
        for attr in self.lfp_plant.hdf_attrs:
            self.add_dtype(*attr) 

    def init(self):
        self.plant = DummyPlant()
        self.add_dtype('lfp_target', 'f8', (3,)) 
        self.add_dtype('target_index', 'i', (1,))
        self.add_dtype('powercap_flag', 'i',(1,))

        for target in self.targets:
            for model in target.graphics_models:
                self.add_model(model)

        super(LFP_Mod, self).init()

    def _cycle(self):
        '''
        Calls any update functions necessary and redraws screen. Runs 60x per second.
        '''
        self.task_data['loop_time'] = self.iter_time()
        self.task_data['lfp_target'] = self.target_location_lfp.copy()
        self.task_data['target_index'] = self.target_index
        #self.task_data['internal_decoder_state'] = self.decoder.filt.current_lfp_pos
        self.task_data['powercap_flag'] = self.decoder.filt.current_powercap_flag

        self.move_plant()

        ## Save plant status to HDF file, ###ADD BACK
        lfp_plant_data = self.lfp_plant.get_data_to_save()
        for key in lfp_plant_data:
            self.task_data[key] = lfp_plant_data[key]

        super(LFP_Mod, self)._cycle()

    def move_plant(self):
        feature_data = self.get_features()

        # Save the "neural features" (e.g. spike counts vector) to HDF file
        for key, val in feature_data.items():
            self.task_data[key] = val
        Bu = None
        assist_weight = 0
        target_state = np.zeros([self.decoder.n_states, self.decoder.n_subbins])

        ## Run the decoder
        if self.state not in self.static_states:
            neural_features = feature_data[self.extractor.feature_type]
            self.call_decoder(neural_features, target_state, Bu=Bu, assist_level=assist_weight, feature_type=self.extractor.feature_type)

        ## Drive the plant to the decoded state, if permitted by the constraints of the plant
        self.lfp_plant.drive(self.decoder)
        self.task_data['decoder_state'] = decoder_state = self.decoder.get_state(shape=(-1,1))
        return decoder_state     

    def run(self):
        '''
        See experiment.Experiment.run for documentation. 
        '''
        # Fire up the plant. For virtual/simulation plants, this does little/nothing.
        self.lfp_plant.start()
        try:
            super(LFP_Mod, self).run()
        finally:
            self.lfp_plant.stop()

    ##### HELPER AND UPDATE FUNCTIONS ####
    def update_cursor_visibility(self):
        ''' Update cursor visible flag to hide cursor if there has been no good data for more than 3 frames in a row'''
        prev = self.cursor_visible
        if self.no_data_count < 3:
            self.cursor_visible = True
            if prev != self.cursor_visible:
                self.show_object(self.cursor, show=True)
        else:
            self.cursor_visible = False
            if prev != self.cursor_visible:
                self.show_object(self.cursor, show=False)

    def update_report_stats(self):
        '''
        see experiment.Experiment.update_report_stats for docs
        '''
        super(LFP_Mod, self).update_report_stats()
        self.reportstats['Trial #'] = self.calc_trial_num()
        self.reportstats['Reward/min'] = np.round(self.calc_events_per_min('reward', 120), decimals=2)

    #### TEST FUNCTIONS ####
    def _test_powercap_penalty(self, ts):
        if self.decoder.filt.current_powercap_flag:
            #Turn off power cap flag:
            self.decoder.filt.current_powercap_flag = 0
            return True
        else:
            return False


    def _test_enter_lfp_target(self, ts):
        '''
        return true if the cursor center is inside the square target, i.e. within
        half the target side length along both the x and z axes
        '''
        cursor_pos = self.lfp_plant.get_endpoint_pos()
        dx = np.linalg.norm(cursor_pos[0] - self.target_location_lfp[0])
        dz = np.linalg.norm(cursor_pos[2] - self.target_location_lfp[2])
        return dx <= (self.lfp_target_rad / 2.) and dz <= (self.lfp_target_rad / 2.)

        # #return d <= (self.lfp_target_rad - self.lfp_cursor_rad)

        # #If center of cursor enters target at all: 
        # return d <= (self.lfp_target_rad/2.)

        # #New version: 
        # cursor_pos = self.lfp_plant.get_endpoint_pos()
        # d = np.linalg.norm(cursor_pos[2] - self.target_location_lfp[2])
        # d <= (self.lfp_target_rad - self.lfp_cursor_rad)
        
    def _test_leave_early(self, ts):
        '''
        return true if the cursor center leaves the square target along the x or z axis
        '''
        cursor_pos = self.lfp_plant.get_endpoint_pos()
        dx = np.linalg.norm(cursor_pos[0] - self.target_location_lfp[0])
        dz = np.linalg.norm(cursor_pos[2] - self.target_location_lfp[2])
        return dx > (self.lfp_target_rad / 2.) or dz > (self.lfp_target_rad / 2.)

    def _test_lfp_hold_complete(self, ts):
        return ts>=self.lfp_hold_time_plus_var

    # def _test_lfp_timeout(self, ts):
    #     return ts>self.timeout_time

    def _test_powercap_penalty_end(self, ts):
        if ts>self.powercap_penalty_time:
            self.lfp_plant.turn_on()

        return ts>self.powercap_penalty_time

    def _test_reward_end(self, ts):
        return ts>self.reward_time

    def _test_stop(self, ts):
        if self.session_length > 0 and (self.get_time() - self.task_start_time) > self.session_length:
            self.end_task()
        return self.stop

    #### STATE FUNCTIONS ####
    def _parse_next_trial(self):
        self.targs = self.next_trial
        
    def _start_wait(self):
        super(LFP_Mod, self)._start_wait()
        self.tries = 0
        self.target_index = -1
        #hide targets
        for target in self.targets:
            target.hide()

        #get target locations for this trial
        self._parse_next_trial()
        self.chain_length = 1
        self.lfp_hold_time_plus_var = self.lfp_hold_time + np.random.uniform(low=-1,high=1)*self.lfp_hold_var

    def _start_lfp_target(self):
        self.target_index = 0

        #only 1 target: 
        target = self.targets[0]
        self.target_location_lfp = self.targs #Just one target. 
        
        target.move_to_position(self.target_location_lfp)
        target.cue_trial_start()

    def _start_lfp_hold(self):
        #make next target visible unless this is the final target in the trial
        idx = (self.target_index + 1)
        if idx < self.chain_length: 
            target = self.targets[idx % 2]
            target.move_to_position(self.targs[idx])
    
    def _end_lfp_hold(self):
        # change current target color to green
        self.targets[self.target_index % 2].cue_trial_end_success()
    
    def _start_timeout_penalty(self):
        #hide targets
        for target in self.targets:
            target.hide()

        self.tries += 1
        self.target_index = -1

    def _start_reward(self):
        super(LFP_Mod, self)._start_reward()
        #self.targets[self.target_index % 2].show()

    def _start_powercap_penalty(self):
        for target in self.targets:
            target.hide()
        self.lfp_plant.turn_off()

    @staticmethod
    def lfp_mod_4targ(nblocks=100, boundaries=(-18,18,-12,12), xaxis=-8):
        '''Mimics beta modulation task from Kinarm Rig:

        In Kinarm rig, the following linear transformations happen: 
            1. LFP cursor is calculated
            2. mapped from fraction limits [0, .35] to [-1, 1] (unit_coordinates)
            3. sent over UDP to the Kinarm machine and multiplied by 8
            4. translated upward in the Y direction by + 2.5

        This means, our targets which are at -8, [-0.75, 2.5, 5.75, 9.0]
        must be translated down by 2.5 to: -8, [-3.25,  0.  ,  3.25,  6.5]
        then divided by 8: -1, [-0.40625,  0.     ,  0.40625,  0.8125 ] in unit_coordinates

        The radius is 1.2, which is 0.15 in unit_coordinates

        Now, we map this to a new system: 
        - new_zero: (y1+y2) / 2
        - new_scale: (y2 - y1) / 2

         (([-0.40625,  0.     ,  0.40625,  0.8125 ]) * new_scale ) + new_zero
        
        new_zero = 0
        new_scale = 12

        12 * [-0.40625,  0.     ,  0.40625,  0.8125 ] 

        = array([-4.875,  0.   ,  4.875,  9.75 ])

        '''

        new_zero = (boundaries[3]+boundaries[2]) / 2.
        new_scale = (boundaries[3] - boundaries[2]) / 2.

        kin_targs = np.array([-0.40625,  0.     ,  0.40625,  0.8125 ])

        lfp_targ_y = (new_scale*kin_targs) + new_zero

        for i in range(nblocks):
            temp = lfp_targ_y.copy()
            np.random.shuffle(temp)
            if i==0:
                z = temp.copy()
            else:
                z = np.hstack((z, temp))

        #Fixed X axis: 
        x = np.tile(xaxis,(nblocks*4))
        y = np.zeros(nblocks*4)
        
        pairs = np.vstack([x, y, z]).T
        return pairs
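A worked check of the remapping described in the lfp_mod_4targ docstring, using the default boundaries:

import numpy as np

boundaries = (-18, 18, -12, 12)
new_zero = (boundaries[3] + boundaries[2]) / 2.    # -> 0.0
new_scale = (boundaries[3] - boundaries[2]) / 2.   # -> 12.0
kin_targs = np.array([-0.40625, 0., 0.40625, 0.8125])
print(new_scale * kin_targs + new_zero)            # -> [-4.875  0.  4.875  9.75]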
Example #8
class LFP_Mod_plus_MC_reach(LFP_Mod_plus_MC_hold):
    mc_cursor_radius = traits.Float(.5, desc="Radius of cursor")
    mc_target_radius = traits.Float(3, desc="Radius of MC target")
    mc_cursor_color = (.5,0,.5,1)
    mc_plant_type_options = plantlist.keys()
    mc_plant_type = traits.OptionsList(*plantlist, bmi3d_input_options=plantlist.keys())
    origin_hold_time = traits.Float(.2, desc="Hold time in center")
    mc_periph_holdtime = traits.Float(.2, desc="Hold time at peripheral MC target")
    mc_timeout_time = traits.Float(10, desc="Time allowed to go between targets")
    exclude_parent_traits = ['goal_cache_block'] #redefine this to NOT include marker_num, marker_count
    marker_num = traits.Int(14, desc='Index of the motion tracker marker controlling the cursor')
    marker_count = traits.Int(16, desc='Number of motion tracker markers')

    scale_factor = 3.0 #scale factor for converting hand movement to screen movement (1 cm hand movement = 3 cm cursor movement)
    wait_flag = 1
    # NOTE!!! The marker on the hand was changed from #0 to #14 on
    # 5/19/13 after LED #0 broke. All data files saved before this date
    # have LED #0 controlling the cursor.
    limit2d = 1
    #   state_file = open("/home/helene/preeya/tot_pw.txt","w")
    state_cnt = 0

    status = dict(
        wait = dict(start_trial="origin", stop=None),
        origin = dict(enter_origin="origin_hold", stop=None),
        origin_hold = dict(origin_hold_complete="lfp_target",leave_origin="hold_penalty", stop=None),
        lfp_target = dict(enter_lfp_target="lfp_hold", leave_origin="hold_penalty", powercap_penalty="powercap_penalty", stop=None),
        lfp_hold = dict(leave_early="lfp_target", lfp_hold_complete="mc_target", leave_origin="hold_penalty",powercap_penalty="powercap_penalty"),
        mc_target = dict(enter_mc_target='mc_hold',mc_timeout="timeout_penalty", stop=None),
        mc_hold = dict(leave_periph_early='hold_penalty',mc_hold_complete="reward"),
        powercap_penalty = dict(powercap_penalty_end="origin"),
        timeout_penalty = dict(timeout_penalty_end="wait"),
        hold_penalty = dict(hold_penalty_end="origin"),
        reward = dict(reward_end="wait"),
    )

    
    static_states = ['origin'] # states in which the decoder is not run
    trial_end_states = ['reward', 'timeout_penalty']
    lfp_cursor_on = ['lfp_target', 'lfp_hold', 'reward']

    sequence_generators = ['lfp_mod_plus_MC_reach', 'lfp_mod_plus_MC_reach_INV']

    def __init__(self, *args, **kwargs):
        # import pickle
        # decoder = pickle.load(open('/storage/decoders/cart20141216_03_cart_new2015_2.pkl'))
        # self.decoder = decoder
        super(LFP_Mod_plus_MC_reach, self).__init__(*args, **kwargs)

        mc_origin = VirtualCircularTarget(target_radius=self.mc_target_radius, target_color=RED)
        mc_periph = VirtualCircularTarget(target_radius=self.mc_target_radius, target_color=RED)
        lfp_target = VirtualSquareTarget(target_radius=self.lfp_target_rad, target_color=self.lfp_target_color)

        self.targets = [lfp_target, mc_origin, mc_periph]

        # #Should be unnecessary: 
        # for target in self.targets:
        #     for model in target.graphics_models:
        #         self.add_model(model)

        # self.lfp_plant = plantlist[self.lfp_plant_type] 
        # if hasattr(self.lfp_plant, 'graphics_models'):
        #     for model in self.lfp_plant.graphics_models:
        #         self.add_model(model)

        # self.mc_plant = plantlist[self.mc_plant_type]
        # if hasattr(self.mc_plant, 'graphics_models'):
        #     for model in self.mc_plant.graphics_models:
        #         self.add_model(model)

    def _parse_next_trial(self):
        t = self.next_trial
        self.lfp_targ = t['lfp']
        self.mc_targ_orig = t['origin']
        self.mc_targ_periph = t['periph']

    def _start_mc_target(self):
        #Turn off LFP things
        self.lfp_plant.turn_off()
        self.targets[0].hide()
        self.targets[1].hide()

        target = self.targets[2] #MC target
        self.target_location_mc = self.mc_targ_periph
        
        target.move_to_position(self.target_location_mc)
        target.cue_trial_start()

    def _test_enter_mc_target(self,ts):
        cursor_pos = self.mc_plant.get_endpoint_pos()
        d = np.linalg.norm(cursor_pos - self.target_location_mc)
        return d <= (self.mc_target_radius - self.mc_cursor_radius)

    def _test_mc_timeout(self, ts):
        return ts>self.mc_timeout_time

    def _test_leave_periph_early(self, ts):
        cursor_pos = self.mc_plant.get_endpoint_pos()
        d = np.linalg.norm(cursor_pos - self.target_location_mc)
        rad = self.mc_target_radius - self.mc_cursor_radius
        return d > rad

    def _test_mc_hold_complete(self, ts):
        return ts>self.mc_periph_holdtime

    def _timeout_penalty_end(self, ts):
        print('timeout', ts)
        #return ts > 1.
        return True

    def _end_mc_hold(self):
        self.targets[2].cue_trial_end_success()

    # def _cycle(self):
    #     if self.state_cnt < 3600*3:
    #         self.state_cnt +=1
    #         s = "%s\n" % self.state
    #         self.state_file.write(str(s))

    #     if self.state_cnt == 3600*3:
    #         self.state_file.close()

    #     super(LFP_Mod_plus_MC_reach, self)._cycle()

    def _start_reward(self):
        super(LFP_Mod_plus_MC_reach, self)._start_reward()
        lfp_targ = self.targets[0]
        mc_orig = self.targets[1]
        lfp_targ.hide()
        mc_orig.hide()

    @staticmethod
    def lfp_mod_plus_MC_reach(nblocks=100, boundaries=(-18,18,-12,12), xaxis=-8, target_distance=6, n_mc_targets=4, mc_target_angle_offset=0,**kwargs):
        new_zero = (boundaries[3]+boundaries[2]) / 2.
        new_scale = (boundaries[3] - boundaries[2]) / 2.
        kin_targs = np.array([-0.40625,  0.     ,  0.40625,  0.8125 ])
        lfp_targ_y = (new_scale*kin_targs) + new_zero

        for i in range(nblocks):
            temp = lfp_targ_y.copy()
            np.random.shuffle(temp)
            if i==0:
                z = temp.copy()
            else:
                z = np.hstack((z, temp))

        #Fixed X axis: 
        x = np.tile(xaxis,(nblocks*4))
        y = np.zeros(nblocks*4)
        lfp = np.vstack([x, y, z]).T
        origin = np.zeros(lfp.shape)

        theta = []
        for i in range(nblocks*4):
            temp = np.arange(0, 2*np.pi, 2*np.pi/float(n_mc_targets))
            np.random.shuffle(temp)
            theta = theta + [temp]
        theta = np.hstack(theta)
        theta = theta + (mc_target_angle_offset*(np.pi/180.))
        x = target_distance*np.cos(theta)
        y = np.zeros(len(theta))
        z = target_distance*np.sin(theta)
        periph = np.vstack([x, y, z]).T
        it = iter([dict(lfp=lfp[i,:], origin=origin[i,:], periph=periph[i,:]) for i in range(lfp.shape[0])])
        
        if kwargs.get('return_arrays', False):
            return lfp, origin, periph
        else:
            return it

    @staticmethod
    def lfp_mod_plus_MC_reach_INV(nblocks=100, boundaries=(-18,18,-12,12), xaxis=-8, target_distance=6, n_mc_targets=4, mc_target_angle_offset=0):
        kw = dict(return_arrays=True)
        lfp, origin, periph = LFP_Mod_plus_MC_reach.lfp_mod_plus_MC_reach(nblocks=nblocks, boundaries=boundaries, xaxis=xaxis, target_distance=target_distance, 
            n_mc_targets=n_mc_targets, mc_target_angle_offset=mc_target_angle_offset,**kw)

        #Invert LFP:
        lfp[:,2] = -1.0*lfp[:,2] 
        it = iter([dict(lfp=lfp[i,:], origin=origin[i,:], periph=periph[i,:]) for i in range(lfp.shape[0])])
        return it
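A brief usage sketch for the generators above: each returns an iterator of per-trial dicts with 'lfp', 'origin', and 'periph' locations (the argument value shown is an arbitrary choice for illustration):

gen = LFP_Mod_plus_MC_reach.lfp_mod_plus_MC_reach(nblocks=2)
trial = next(gen)
print(trial['lfp'], trial['origin'], trial['periph'])  # three (3,) arrays per trial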
Example #9
class LFP_Mod_plus_MC_hold(LFP_Mod):

    mc_cursor_radius = traits.Float(.5, desc="Radius of cursor")
    mc_target_radius = traits.Float(3, desc="Radius of MC target")
    mc_cursor_color = (.5,0,.5,1)
    mc_plant_type_options = plantlist.keys()
    mc_plant_type = traits.OptionsList(*plantlist, bmi3d_input_options=plantlist.keys())
    origin_hold_time = traits.Float(.2, desc="Hold time in center")
    exclude_parent_traits = ['goal_cache_block'] #redefine this to NOT include marker_num, marker_count
    marker_num = traits.Int(14, desc='Index of the motion tracker marker controlling the cursor')
    marker_count = traits.Int(16, desc='Number of motion tracker markers')
    joystick_method = traits.Float(1,desc="1: Normal velocity, 0: Position control")
    joystick_speed = traits.Float(20, desc="Speed scaling for joystick velocity control")
    move_while_in_center = traits.Float(1, desc="1 = update plant while in lfp_target, lfp_hold, 0 = don't update in these states")
    scale_factor = 3.0 #scale factor for converting hand movement to screen movement (1 cm hand movement = 3 cm cursor movement)
    wait_flag = 1
    # NOTE!!! The marker on the hand was changed from #0 to #14 on
    # 5/19/13 after LED #0 broke. All data files saved before this date
    # have LED #0 controlling the cursor.
    limit2d = 1

    status = dict(
        wait = dict(start_trial="origin", stop=None),
        origin = dict(enter_origin="origin_hold", stop=None),
        origin_hold = dict(origin_hold_complete="lfp_target",leave_origin="hold_penalty", stop=None),
        lfp_target = dict(enter_lfp_target="lfp_hold", leave_origin="hold_penalty", powercap_penalty="powercap_penalty", stop=None),
        lfp_hold = dict(leave_early="lfp_target", lfp_hold_complete="reward", leave_origin="hold_penalty", powercap_penalty="powercap_penalty",stop=None),
        powercap_penalty = dict(powercap_penalty_end="origin"),
        hold_penalty = dict(hold_penalty_end="origin",stop=None),
        reward = dict(reward_end="wait")
    )

    static_states = ['origin'] # states in which the decoder is not run
    trial_end_states = ['reward']
    lfp_cursor_on = ['lfp_target', 'lfp_hold', 'reward']

    sequence_generators = ['lfp_mod_4targ_plus_mc_orig']


    def __init__(self, *args, **kwargs):
        super(LFP_Mod_plus_MC_hold, self).__init__(*args, **kwargs)
        if self.move_while_in_center>0:
            self.no_plant_update_states = []
        else:
            self.no_plant_update_states = ['lfp_target', 'lfp_hold']

        mc_origin = VirtualCircularTarget(target_radius=self.mc_target_radius, target_color=RED)
        lfp_target = VirtualSquareTarget(target_radius=self.lfp_target_rad, target_color=self.lfp_target_color)

        self.targets = [lfp_target, mc_origin]

        self.mc_plant = plantlist[self.mc_plant_type]
        if hasattr(self.mc_plant, 'graphics_models'):
            for model in self.mc_plant.graphics_models:
                self.add_model(model)

        # Declare any plant attributes which must be saved to the HDF file at the _cycle rate
        for attr in self.mc_plant.hdf_attrs:
            self.add_dtype(*attr) 

        self.target_location_mc = np.array([-100, -100, -100])
        self.manual_control_type = None

        self.current_pt=np.zeros([3]) #keep track of current pt
        self.last_pt=np.zeros([3])

    def init(self):
        self.add_dtype('mc_targ', 'f8', (3,)) ###ADD BACK
        super(LFP_Mod_plus_MC_hold, self).init()


    def _cycle(self):
        '''
        Calls any update functions necessary and redraws screen. Runs 60x per second.
        '''
        self.task_data['mc_targ'] = self.target_location_mc.copy()


        mc_plant_data = self.mc_plant.get_data_to_save()
        for key in mc_plant_data:
            self.task_data[key] = mc_plant_data[key]

        super(LFP_Mod_plus_MC_hold, self)._cycle()


    def _parse_next_trial(self):
        t = self.next_trial
        self.lfp_targ = t['lfp']
        self.mc_targ_orig = t['origin']

    def _start_origin(self):
        if self.wait_flag:
            self.origin_hold_time_store = self.origin_hold_time
            self.origin_hold_time = 3
            self.wait_flag = 0
        else:
            self.origin_hold_time = self.origin_hold_time_store
        #only 1 target: 
        target = self.targets[1] #Origin
        self.target_location_mc = self.mc_targ_orig #Origin 
        
        target.move_to_position(self.target_location_mc)
        target.cue_trial_start()

        #Turn off lfp things
        self.lfp_plant.turn_off()
        self.targets[0].hide()

    def _start_lfp_target(self):
        #only 1 target: 
        target = self.targets[0] #LFP target
        self.target_location_lfp = self.lfp_targ #LFP target
        
        target.move_to_position(self.target_location_lfp)
        target.cue_trial_start()

        self.lfp_plant.turn_on()

    def _start_lfp_hold(self):
        #make next target visible unless this is the final target in the trial
        pass

    def _start_hold_penalty(self):
        #hide targets
        for target in self.targets:
            target.hide()

        self.tries += 1
        self.target_index = -1

        #Turn off lfp things
        self.lfp_plant.turn_off()
        self.targets[0].hide()

    def _end_origin(self):
        self.targets[1].cue_trial_end_success()

    def _test_enter_origin(self, ts):
        cursor_pos = self.mc_plant.get_endpoint_pos()
        d = np.linalg.norm(cursor_pos - self.target_location_mc)
        return d <= (self.mc_target_radius - self.mc_cursor_radius)

    # def _test_origin_timeout(self, ts):
    #     return ts>self.timeout_time

    def _test_leave_origin(self, ts):
        if self.manual_control_type == 'joystick':
            if hasattr(self,'touch'):
                if self.touch <0.5:
                    self.last_touch_zero_event = time.time()
                    return True

        cursor_pos = self.mc_plant.get_endpoint_pos()
        d = np.linalg.norm(cursor_pos - self.target_location_mc)
        return d > (self.mc_target_radius - self.mc_cursor_radius)

    def _test_origin_hold_complete(self,ts):
        return ts>=self.origin_hold_time

    # def _test_enter_lfp_target(self, ts):
    #     '''
    #     return true if the distance between center of cursor and target is smaller than the cursor radius
    #     '''
    #     cursor_pos = self.lfp_plant.get_endpoint_pos()
    #     cursor_pos = [cursor_pos[0], cursor_pos[2]]
    #     targ_loc = np.array([self.target_location_lfp[0], self.target_location_lfp[2]])


    #     d = np.linalg.norm(cursor_pos - targ_loc)
    #     return d <= (self.lfp_target_rad - self.lfp_cursor_rad)

    # def _test_leave_early(self, ts):
    #     '''
    #     return true if cursor moves outside the exit radius
    #     '''
    #     cursor_pos = self.lfp_plant.get_endpoint_pos()
    #     d = np.linalg.norm(cursor_pos - self.target_location_lfp)
    #     rad = self.lfp_target_rad - self.lfp_cursor_rad
    #     return d > rad

    def _test_hold_penalty_end(self, ts):
        return ts>self.hold_penalty_time

    def _end_lfp_hold(self):
        # change current target color to green
        self.targets[0].cue_trial_end_success()


    def move_plant(self):
        if self.state in self.lfp_cursor_on:
            feature_data = self.get_features()


            # Save the "neural features" (e.g. spike counts vector) to HDF file
            for key, val in feature_data.items():
                self.task_data[key] = val
            
            Bu = None
            assist_weight = 0
            target_state = np.zeros([self.decoder.n_states, self.decoder.n_subbins])

            ## Run the decoder
            neural_features = feature_data[self.extractor.feature_type]

            self.call_decoder(neural_features, target_state, Bu=Bu, assist_level=assist_weight, feature_type=self.extractor.feature_type)

           
            ## Drive the plant to the decoded state, if permitted by the constraints of the plant
            self.lfp_plant.drive(self.decoder)
            self.task_data['decoder_state'] = decoder_state = self.decoder.get_state(shape=(-1,1))
            #return decoder_state
           

        # Sets the plant configuration based on motion tracker data. For manual
        # control, uses motion tracker data; if none is available, returns None.
        
        #get data from motion tracker- take average of all data points since last poll
        if self.state in self.no_plant_update_states:
            pt = np.array([0, 0, 0])
            print('no update')
        else:
            if self.manual_control_type == 'motiondata':
                pt = self.motiondata.get()
                if len(pt) > 0:
                    pt = pt[:, self.marker_num, :]
                    conds = pt[:, 3]
                    inds = np.nonzero((conds>=0) & (conds!=4))[0]
                    if len(inds) > 0:
                        pt = pt[inds,:3]

                        #scale actual movement to desired amount of screen movement
                        pt = pt.mean(0) * self.scale_factor
                        #Set y coordinate to 0 for 2D tasks
                        if self.limit2d: 
                            #pt[1] = 0

                            pt[2] = pt[1].copy()
                            pt[1] = 0


                        pt[1] = pt[1]*2
                        # Return cursor location
                        self.no_data_count = 0
                        pt = pt * mm_per_cm #self.convert_to_cm(pt)
                    else: #if no usable data
                        self.no_data_count += 1
                        pt = None
                else: #if no new data
                    self.no_data_count +=1
                    pt = None
            
            elif self.manual_control_type == 'joystick':
                pt = self.joystick.get()
                #if touch sensor on: 
                try: 
                    self.touch = pt[-1][0][2]
                except:
                    pass

                if len(pt) > 0:
                    pt = pt[-1][0]
                    pt[0]=1-pt[0]; #Switch L / R axes
                    calib = [0.497,0.517] #Sometimes zero point is subject to drift this is the value of the incoming joystick when at 'rest' 
                    if self.joystick_method==0:
                        pos = np.array([(pt[0]-calib[0]), 0, calib[1]-pt[1]])
                        pos[0] = pos[0]*36
                        pos[2] = pos[2]*24
                        self.current_pt = pos

                    elif self.joystick_method==1:
                        vel=np.array([(pt[0]-calib[0]), 0, calib[1]-pt[1]])
                        epsilon = 2*(10**-2) #Define epsilon to stabilize cursor movement
                        if sum((vel)**2) > epsilon:
                            self.current_pt=self.last_pt+self.joystick_speed*vel*(1./60) #60 Hz update rate, dt = 1/60
                        else:
                            self.current_pt = self.last_pt

                        if self.current_pt[0] < -25: self.current_pt[0] = -25
                        if self.current_pt[0] > 25: self.current_pt[0] = 25
                        if self.current_pt[-1] < -14: self.current_pt[-1] = -14
                        if self.current_pt[-1] > 14: self.current_pt[-1] = 14
                    pt = self.current_pt

                #self.plant.set_endpoint_pos(self.current_pt)
                self.last_pt = self.current_pt.copy()
            
            elif self.manual_control_type is None:
                pt = None
                # Auto-detect which input source is available
                try:
                    pt0 = self.motiondata.get()
                    self.manual_control_type = 'motiondata'
                except:
                    print('not motiondata')

                try:
                    pt0 = self.joystick.get()
                    self.manual_control_type = 'joystick'
                except:
                    print('not joystick data')

        # Set the plant's endpoint to the position determined by the motiontracker, unless there is no data available
        if self.manual_control_type is not None:
            if pt is not None and len(pt)>0:
                self.mc_plant.set_endpoint_pos(pt)   
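
    @staticmethod
    def _joystick_velocity_step(raw_pt, last_pt, speed, calib=(0.497, 0.517), eps=2e-2, dt=1./60):
        '''
        Illustrative helper (not called by the task; a sketch of the velocity-control
        branch in move_plant above). Raw joystick samples in [0, 1] are re-centered on
        the calibration zero point, treated as a velocity, and integrated at the frame
        rate; velocities inside the epsilon dead zone are ignored so the cursor does
        not drift while the stick is at rest.
        '''
        vel = np.array([raw_pt[0] - calib[0], 0, calib[1] - raw_pt[1]])
        if np.sum(vel**2) > eps:
            return last_pt + speed*vel*dt
        return last_pt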

    @staticmethod
    def lfp_mod_4targ_plus_mc_orig(nblocks=100, boundaries=(-18,18,-12,12), xaxis=-8):
        '''
        See lfp_mod_4targ for lfp target explanation 

        '''
        new_zero = (boundaries[3]+boundaries[2]) / 2.
        new_scale = (boundaries[3] - boundaries[2]) / 2.
        kin_targs = np.array([-0.40625,  0.     ,  0.40625,  0.8125 ])
        lfp_targ_y = (new_scale*kin_targs) + new_zero

        for i in range(nblocks):
            temp = lfp_targ_y.copy()
            np.random.shuffle(temp)
            if i==0:
                z = temp.copy()
            else:
                z = np.hstack((z, temp))

        #Fixed X axis: 
        x = np.tile(xaxis,(nblocks*4))
        y = np.zeros(nblocks*4)
                
        lfp = np.vstack([x, y, z]).T
        origin = np.zeros(lfp.shape)

        it = iter([dict(lfp=lfp[i,:], origin=origin[i,:]) for i in range(lfp.shape[0])])
        return it
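
# Illustration (not part of the original task code): a minimal sketch of the
# block-shuffled target pattern used by lfp_mod_4targ_plus_mc_orig above. Each
# block contains every target value exactly once, in random order, so target
# frequencies stay balanced over the session.
def _block_shuffled(values, nblocks):
    blocks = []
    for _ in range(nblocks):
        temp = np.array(values, dtype=float)  # fresh copy each block
        np.random.shuffle(temp)               # shuffle within the block
        blocks.append(temp)
    return np.hstack(blocks)
# e.g. _block_shuffled([-0.4, 0.0, 0.4, 0.8], nblocks=2) -> 8 balanced samples
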
class ApproachAvoidanceTask(Sequence, Window):
    '''
    This is for a free-choice task with two targets (left and right) presented to choose from.  
    The position of the targets may change along the x-axis, according to the target generator, 
    and each target has a varying probability of reward, also according to the target generator.
    The code as it is written is for a joystick.  

    Notes: target_index should only be written once per trial. If so, instructed trials can be made random; otherwise, add a new state for instructed trials.
    '''

    background = (0,0,0,1)
    shoulder_anchor = np.array([2., 0., -15.]) # Coordinates of shoulder anchor point on screen
    
    arm_visible = traits.Bool(True, desc='Specifies whether entire arm is displayed or just endpoint')
    
    cursor_radius = traits.Float(.5, desc="Radius of cursor")
    cursor_color = (.5,0,.5,1)

    joystick_method = traits.Float(1,desc="1: Normal velocity, 0: Position control")
    joystick_speed = traits.Float(20, desc="Speed of cursor")

    plant_type_options = plantlist.keys()
    plant_type = traits.OptionsList(*plantlist, bmi3d_input_options=plantlist.keys())
    starting_pos = (5, 0, 5)
    # window_size = (1280*2, 1024)
    window_size = traits.Tuple((1366*2, 768), desc='window size')
    

    status = dict(
        wait = dict(start_trial="center", stop=None),
        center = dict(enter_center="hold_center", timeout="timeout_penalty", stop=None),
        hold_center = dict(leave_early_center="hold_penalty", hold_center_complete="target", timeout="timeout_penalty", stop=None),
        target = dict(enter_targetL="hold_targetL", enter_targetH="hold_targetH", timeout="timeout_penalty", stop=None),
        hold_targetH = dict(leave_early_H="hold_penalty", hold_complete="targ_transition"),
        hold_targetL = dict(leave_early_L="hold_penalty", hold_complete="targ_transition"),
        targ_transition = dict(trial_complete="check_reward", trial_abort="wait", trial_incomplete="center"),
        check_reward = dict(avoid="reward", approach="reward_and_airpuff"),
        timeout_penalty = dict(timeout_penalty_end="targ_transition"),
        hold_penalty = dict(hold_penalty_end="targ_transition"),
        reward = dict(reward_end="wait"),
        reward_and_airpuff = dict(reward_and_airpuff_end="wait"),
    )
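
    # How the FSM consumes this table (a sketch of the riglib.experiment
    # convention, for orientation): on each cycle, for the current state s, every
    # event name in status[s] is checked by calling self._test_<event>(ts), where
    # ts is the time spent in s; the first test returning True ends the state
    # (self._end_<s>), switches to the mapped next state n, and runs
    # self._start_<n>. While in a state, self._while_<s> runs every cycle.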

    target_color = (.5,1,.5,0)

    #initial state
    state = "wait"

    #create settable traits
    reward_time_avoid = traits.Float(.2, desc="Length of juice reward for avoid decision")
    reward_time_approach_min = traits.Float(.2, desc="Min length of juice for approach decision")
    reward_time_approach_max = traits.Float(.8, desc="Max length of juice for approach decision")
    target_radius = traits.Float(1.5, desc="Radius of targets in cm")
    block_length = traits.Float(100, desc="Number of trials per block")  
    
    hold_time = traits.Float(.5, desc="Length of hold required at targets")
    hold_penalty_time = traits.Float(1, desc="Length of penalty time for target hold error")
    timeout_time = traits.Float(10, desc="Time allowed to go between targets")
    timeout_penalty_time = traits.Float(1, desc="Length of penalty time for timeout error")
    max_attempts = traits.Int(10, desc='The number of attempts at a target before\
        skipping to the next one')
    session_length = traits.Float(0, desc="Time until task automatically stops. Length of 0 means no auto stop.")
    marker_num = traits.Int(14, desc="The index of the motiontracker marker to use for cursor position")
   
    arm_hide_rate = traits.Float(0.0, desc='If the arm is visible, specifies a percentage of trials where it will be hidden')
    target_index = 0 # Helper variable to keep track of whether trial is instructed (1 = 1 choice) or free-choice (2 = 2 choices)
    target_selected = 'L'   # Helper variable to indicate which target was selected
    tries = 0 # Helper variable to keep track of the number of failed attempts at a given trial.
    timedout = False    # Helper variable to keep track if transitioning from timeout_penalty
    reward_counter = 0.0
    cursor_visible = False # Determines when to hide the cursor.
    no_data_count = 0 # Counter for number of missing data frames in a row
    scale_factor = 3.0 #scale factor for converting hand movement to screen movement (1cm hand movement = 3cm cursor movement)
    starting_dist = 10.0 # starting distance from center target
    #color_targets = np.random.randint(2)
    color_targets = 1   # 0: yellow low, blue high; 1: blue low, yellow high
    stopped_center_hold = False   #keep track if center hold was released early
    
    limit2d = 1

    color1 = target_colors['purple']  			# approach color
    color2 = target_colors['lightsteelblue']  	# avoid color
    reward_color = target_colors['green'] 		# color of reward bar
    airpuff_color = target_colors['red']		# color of airpuff bar

    sequence_generators = ['colored_targets_with_probabilistic_reward','block_probabilistic_reward','colored_targets_with_randomwalk_reward','randomwalk_probabilistic_reward']
    
    def __init__(self, *args, **kwargs):
        super(ApproachAvoidanceTask, self).__init__(*args, **kwargs)
        self.cursor_visible = True

        # Initialize the plant
        self.plant = plantlist[self.plant_type]
        self.plant_vis_prev = True

        # Add graphics models for the plant and targets to the window
        if hasattr(self.plant, 'graphics_models'):
            for model in self.plant.graphics_models:
                self.add_model(model)

        self.current_pt = np.zeros([3]) # keep track of current pt
        self.last_pt = np.zeros([3])    # keep track of last pt
        ## Declare cursor
        self.cursor = Sphere(radius=self.cursor_radius, color=self.cursor_color)
        self.add_model(self.cursor)
        self.cursor.translate(*self.get_arm_endpoint(), reset=True)

        ## Instantiate the targets. Target 1 is center target, Target H is target with high probability of reward, Target L is target with low probability of reward.
        self.target1 = Sphere(radius=self.target_radius, color=self.target_color)           # center target
        self.add_model(self.target1)
        self.targetH = Sphere(radius=self.target_radius, color=self.target_color)           # high-reward-probability target
        self.add_model(self.targetH)
        self.targetL = Sphere(radius=self.target_radius, color=self.target_color)           # low-reward-probability target
        self.add_model(self.targetL)

        ### TODO: define a Rect target here and adapt its length during the task.

        # Initialize target location variables.
        self.target_location1 = np.array([0,0,0])
        self.target_locationH = np.array([-self.starting_dist,0,0])
        self.target_locationL = np.array([self.starting_dist,0,0])

        self.target1.translate(*self.target_location1, reset=True)
        self.targetH.translate(*self.target_locationH, reset=True)
        self.targetL.translate(*self.target_locationL, reset=True)

        # Initialize colors for high probability and low probability target.  Color will not change.
        self.targetH.color = self.color_targets*self.color1 + (1 - self.color_targets)*self.color2 # high is purple if color_targets = 1, lightsteelblue otherwise
        self.targetL.color = (1 - self.color_targets)*self.color1 + self.color_targets*self.color2

        #set target colors 
        self.target1.color = (1,0,0,.5)      # center target red
        
        
        # Initialize target location variable
        self.target_location = np.array([0, 0, 0])

        # Declare any plant attributes which must be saved to the HDF file at the _cycle rate
        for attr in self.plant.hdf_attrs:
            self.add_dtype(*attr)  


    def init(self):
        self.add_dtype('targetH', 'f8', (3,))
        self.add_dtype('targetL','f8', (3,))
        self.add_dtype('reward_scheduleH','f8', (1,))
        self.add_dtype('reward_scheduleL','f8', (1,))
        self.add_dtype('target_index', 'i', (1,))
        super(ApproachAvoidanceTask, self).init()
        self.trial_allocation = np.zeros(1000)

    def _cycle(self):
        ''' Calls any update functions necessary and redraws screen. Runs 60x per second. '''

        ## Run graphics commands to show/hide the arm if the visibility has changed
        if self.plant_type != 'cursor_14x14':
            if self.arm_visible != self.plant_vis_prev:
                self.plant_vis_prev = self.arm_visible
                self.show_object(self.plant, show=self.arm_visible)

        self.move_arm()
        #self.move_plant()

        ## Save plant status to HDF file
        plant_data = self.plant.get_data_to_save()
        for key in plant_data:
            self.task_data[key] = plant_data[key]

        self.update_cursor()

        if self.plant_type != 'cursor_14x14':
            self.task_data['joint_angles'] = self.get_arm_joints()

        super(ApproachAvoidanceTask, self)._cycle()
        
    ## Plant functions
    def get_cursor_location(self):
        # the arm reports its position as if it were anchored at the origin, so translate it to the correct place
        return self.get_arm_endpoint()

    def get_arm_endpoint(self):
        return self.plant.get_endpoint_pos() 

    def set_arm_endpoint(self, pt, **kwargs):
        self.plant.set_endpoint_pos(pt, **kwargs) 

    def set_arm_joints(self, angles):
        self.plant.set_intrinsic_coordinates(angles)

    def get_arm_joints(self):
        return self.plant.get_intrinsic_coordinates()

    def update_cursor(self):
        '''
        Update the cursor's location and visibility status.
        '''
        pt = self.get_cursor_location()
        self.update_cursor_visibility()
        if pt is not None:
            self.move_cursor(pt)

    def move_cursor(self, pt):
        ''' Move the cursor object to the specified 3D location. '''
        # if not hasattr(self.arm, 'endpt_cursor'):
        self.cursor.translate(*pt[:3],reset=True)


    ##### HELPER AND UPDATE FUNCTIONS ####

    def move_arm(self):
        ''' Sets the arm endpoint from joystick data. For manual control, uses
        joystick data; if no joystick data is available, the endpoint is left unchanged.'''

        pt = self.joystick.get()
        if len(pt) > 0:
            pt = pt[-1][0]
            pt[0]=1-pt[0]; #Switch L / R axes
            calib = [0.497,0.517] #Sometimes zero point is subject to drift this is the value of the incoming joystick when at 'rest' 

            if self.joystick_method==0:                
                pos = np.array([(pt[0]-calib[0]), 0, calib[1]-pt[1]])
                pos[0] = pos[0]*36
                pos[2] = pos[2]*24
                self.current_pt = pos

            elif self.joystick_method==1:
                vel=np.array([(pt[0]-calib[0]), 0, calib[1]-pt[1]])
                epsilon = 2*(10**-2) #Define epsilon to stabilize cursor movement
                if sum((vel)**2) > epsilon:
                    self.current_pt=self.last_pt+self.joystick_speed*vel*(1./60) #60 Hz update rate, dt = 1/60
                else:
                    self.current_pt = self.last_pt

                if self.current_pt[0] < -25: self.current_pt[0] = -25
                if self.current_pt[0] > 25: self.current_pt[0] = 25
                if self.current_pt[-1] < -14: self.current_pt[-1] = -14
                if self.current_pt[-1] > 14: self.current_pt[-1] = 14

            self.set_arm_endpoint(self.current_pt)
            self.last_pt = self.current_pt.copy()

    def convert_to_cm(self, val):
        return val/10.0

    def update_cursor_visibility(self):
        ''' Update cursor visible flag to hide cursor if there has been no good data for more than 3 frames in a row'''
        prev = self.cursor_visible
        if self.no_data_count < 3:
            self.cursor_visible = True
            if prev != self.cursor_visible:
                self.show_object(self.cursor, show=True)
                self.requeue()
        else:
            self.cursor_visible = False
            if prev != self.cursor_visible:
                self.show_object(self.cursor, show=False)
                self.requeue()

    def calc_n_successfultrials(self):
        trialendtimes = np.array([state[1] for state in self.state_log if state[0]=='check_reward'])
        return len(trialendtimes)

    def calc_n_rewards(self):
        rewardtimes = np.array([state[1] for state in self.state_log if state[0]=='reward'])
        return len(rewardtimes)

    def calc_trial_num(self):
        '''Calculates the current trial count: completed + aborted trials'''
        trialtimes = [state[1] for state in self.state_log if state[0] in ['wait']]
        return len(trialtimes)-1

    def calc_targetH_num(self):
        '''Calculates the total number of times the high-value target was selected'''
        trialtimes = [state[1] for state in self.state_log if state[0] in ['hold_targetH']]
        return len(trialtimes) - 1

    def calc_rewards_per_min(self, window):
        '''Calculates the Rewards/min for the most recent window of specified number of seconds in the past'''
        rewardtimes = np.array([state[1] for state in self.state_log if state[0]=='reward'])
        if (self.get_time() - self.task_start_time) < window:
            divideby = (self.get_time() - self.task_start_time)/sec_per_min
        else:
            divideby = window/sec_per_min
        return np.sum(rewardtimes >= (self.get_time() - window))/divideby

    def calc_success_rate(self, window):
        '''Calculates the rewarded trials/initiated trials for the most recent window of specified length in sec'''
        trialtimes = np.array([state[1] for state in self.state_log if state[0] in ['reward', 'timeout_penalty', 'hold_penalty']])
        rewardtimes = np.array([state[1] for state in self.state_log if state[0]=='reward'])
        if len(trialtimes) == 0:
            return 0.0
        else:
            return float(np.sum(rewardtimes >= (self.get_time() - window)))/np.sum(trialtimes >= (self.get_time() - window))

    def update_report_stats(self):
        '''Function to update any relevant report stats for the task. Values are saved in self.reportstats,
        an ordered dictionary. Keys are strings that will be displayed as the label for the stat in the web interface,
        values can be numbers or strings. Called every time task state changes.'''
        super(ApproachAvoidanceTask, self).update_report_stats()
        self.reportstats['Trial #'] = self.calc_trial_num()
        self.reportstats['Reward/min'] = np.round(self.calc_rewards_per_min(120),decimals=2)
        self.reportstats['High-value target selections'] = self.calc_targetH_num()
        #self.reportstats['Success rate'] = str(np.round(self.calc_success_rate(120)*100.0,decimals=2)) + '%'
        start_time = self.state_log[0][1]
        rewardtimes=np.array([state[1] for state in self.state_log if state[0]=='reward'])
        if len(rewardtimes):
            rt = rewardtimes[-1]-start_time
        else:
            rt= np.float64("0.0")

        sec = str(np.int(np.mod(rt,60)))
        if len(sec) < 2:
            sec = '0'+sec
        self.reportstats['Time Of Last Reward'] = str(np.int(np.floor(rt/60))) + ':' + sec



    #### TEST FUNCTIONS ####
    def _test_enter_center(self, ts):
        # return true if the cursor is fully inside the center target
        d = np.linalg.norm(self.cursor.xfm.move - self.target_location1)
        return d <= self.target_radius - self.cursor_radius

    def _test_enter_targetL(self, ts):
        if self.target_index == 1 and self.LH_target_on[0]==0:
            # return false if instructed trial and this target is not on
            return False
        else:
            # enter when the cursor is fully inside the target
            d = np.linalg.norm(self.cursor.xfm.move - self.target_locationL)
            if d <= self.target_radius - self.cursor_radius:
                self.target_selected = 'L'
                return True
            return False

    def _test_enter_targetH(self, ts):
        if self.target_index == 1 and self.LH_target_on[1]==0:
            # return false if instructed trial and this target is not on
            return False
        else:
            # enter when the cursor is fully inside the target
            d = np.linalg.norm(self.cursor.xfm.move - self.target_locationH)
            if d <= self.target_radius - self.cursor_radius:
                self.target_selected = 'H'
                return True
            return False

    def _test_leave_early_center(self, ts):
        # return true if cursor moves outside the exit radius (gives a bit of slack around the edge of target once cursor is inside)
        d = np.linalg.norm(self.cursor.xfm.move - self.target_location1)
        return d > self.target_radius - self.cursor_radius

    def _test_leave_early_L(self, ts):
        # return true if cursor moves outside the exit radius (gives a bit of slack around the edge of target once cursor is inside)
        d = np.linalg.norm(self.cursor.xfm.move - self.target_locationL)
        return d > self.target_radius - self.cursor_radius

    def _test_leave_early_H(self, ts):
        # return true if cursor moves outside the exit radius (gives a bit of slack around the edge of target once cursor is inside)
        d = np.linalg.norm(self.cursor.xfm.move - self.target_locationH)
        return d > self.target_radius - self.cursor_radius

    def _test_hold_center_complete(self, ts):
        return ts>=self.hold_time
    
    def _test_hold_complete(self, ts):
        return ts>=self.hold_time

    def _test_timeout(self, ts):
        return ts>self.timeout_time

    def _test_timeout_penalty_end(self, ts):
        return ts>self.timeout_penalty_time

    def _test_hold_penalty_end(self, ts):
        return ts>self.hold_penalty_time

    def _test_trial_complete(self, ts):
        #return self.target_index==self.chain_length-1
        return not self.timedout

    def _test_trial_incomplete(self, ts):
        return (not self._test_trial_complete(ts)) and (self.tries<self.max_attempts)

    def _test_trial_abort(self, ts):
        return (not self._test_trial_complete(ts)) and (self.tries==self.max_attempts)

    def _test_yes_reward(self,ts):
        if self.target_selected == 'H':
            #reward_assigned = self.targs[0,1]
            reward_assigned = self.rewardH
        else:
            #reward_assigned = self.targs[1,1]
            reward_assigned = self.rewardL
        if self.reward_SmallLarge:
            self.reward_time = reward_assigned*self.reward_time_large + (1 - reward_assigned)*self.reward_time_small   # update reward time if using Small/large schedule
            reward_assigned = 1    # always rewarded
        return bool(reward_assigned)

    def _test_no_reward(self,ts):
        if self.target_selected == 'H':
            #reward_assigned = self.targs[0,1]
            reward_assigned = self.rewardH
        else:
            #reward_assigned = self.targs[1,1]
            reward_assigned = self.rewardL
        if self.reward_SmallLarge:
            self.reward_time = reward_assigned*self.reward_time_large + (1 - reward_assigned)*self.reward_time_small   # update reward time if using Small/large schedule
            reward_assigned = 1    # always rewarded
        return bool(not reward_assigned)

    def _test_reward_end(self, ts):
        time_ended = (ts > self.reward_time)
        if time_ended:
            self.reward_counter = self.reward_counter + 1
        return time_ended

    def _test_stop(self, ts):
        if self.session_length > 0 and (time.time() - self.task_start_time) > self.session_length:
            self.end_task()
        return self.stop

    #### STATE FUNCTIONS ####

    def show_object(self, obj, show=False):
        '''
        Show or hide an object
        '''
        if show:
            obj.attach()
        else:
            obj.detach()
        self.requeue()


    def _start_wait(self):
        super(ApproachAvoidanceTask, self)._start_wait()
        self.tries = 0
        self.target_index = 0     # indicator for instructed or free-choice trial
        #hide targets
        self.show_object(self.target1, False)
        self.show_object(self.targetL, False)
        self.show_object(self.targetH, False)


        #get target positions and reward assignments for this trial
        self.targs = self.next_trial
        if self.plant_type != 'cursor_14x14' and np.random.rand() < self.arm_hide_rate:
            self.arm_visible = False
        else:
            self.arm_visible = True
        #self.chain_length = self.targs.shape[0] #Number of sequential targets in a single trial

        #self.task_data['target'] = self.target_locationH.copy()
        assign_reward = np.random.randint(0,100,size=2)
        self.rewardH = np.greater(self.targs[0,1],assign_reward[0])
        #print 'high value target reward prob', self.targs[0,1]
        self.rewardL = np.greater(self.targs[1,1],assign_reward[1])

        
        #print 'TARGET GENERATOR', self.targs[0,]
        self.task_data['targetH'] = self.targs[0,].copy()
        self.task_data['reward_scheduleH'] = self.rewardH.copy()
        self.task_data['targetL'] = self.targs[1,].copy()
        self.task_data['reward_scheduleL'] = self.rewardL.copy()
        
        self.requeue()

    def _start_center(self):
        self.show_object(self.target1, True)
        self.show_object(self.cursor, True)
        
        # Third argument in self.targs determines if target is on left or right
        # First argument in self.targs determines if location is offset to farther distances
        offsetH = (2*self.targs[0,2] - 1)*(self.starting_dist + self.location_offset_allowed*self.targs[0,0]*4.0)
        moveH = np.array([offsetH,0,0]) 
        offsetL = (2*self.targs[1,2] - 1)*(self.starting_dist + self.location_offset_allowed*self.targs[1,0]*4.0)
        moveL = np.array([offsetL,0,0])

        self.targetL.translate(*moveL, reset=True)
        self.show_object(self.targetL, True)
        self.target_locationL = self.targetL.xfm.move

        self.targetH.translate(*moveH, reset=True)
        self.show_object(self.targetH, True)
        self.target_locationH = self.targetH.xfm.move

        # Insert instructed trials within free choice trials
        if self.trial_allocation[self.calc_trial_num()] == 1:
        #if (self.calc_trial_num() % 10) < (self.percentage_instructed_trials/10):
            self.target_index = 1    # instructed trial
            leftright_coinflip = np.random.randint(0,2)
            if leftright_coinflip == 0:
                self.show_object(self.targetL, False)
                self.LH_target_on = (0, 1)
            else:
                self.show_object(self.targetH, False)
                self.LR_coinflip = 0
                self.LH_target_on = (1, 0)
        else:
            self.target_index = 2   # free-choice trial

        self.cursor_visible = True
        self.task_data['target_index'] = self.target_index
        self.requeue()

    def _start_target(self):
        # Hide the center target; the peripheral targets were placed and shown
        # in _start_center. Target1 (center target) position does not change.
        self.show_object(self.target1, False)
        self.show_object(self.cursor, True)

        self.update_cursor()
        self.requeue()

    def _start_hold_center(self):
        self.show_object(self.target1, True)
        self.timedout = False
        self.requeue()

    def _start_hold_targetL(self):
        self.show_object(self.targetL, True)
        self.timedout = False
        self.requeue()

    def _start_hold_targetH(self):
        self.show_object(self.targetH, True)
        self.timedout = False
        self.requeue()

    def _end_hold_center(self):
        self.target1.radius = 0.7*self.target_radius    # shrink center target to cue a successful hold
    
    def _end_hold_targetL(self):
        self.targetL.color = (0,1,0,0.5)    # color target green

    def _end_hold_targetH(self):
        self.targetH.color = (0,1,0,0.5)    # color target green

    def _start_hold_penalty(self):
        # hide targets
        self.show_object(self.target1, False)
        self.show_object(self.targetL, False)
        self.show_object(self.targetH, False)
        self.timedout = True
        self.requeue()
        self.tries += 1
    
    def _start_timeout_penalty(self):
        # hide targets
        self.show_object(self.target1, False)
        self.show_object(self.targetL, False)
        self.show_object(self.targetH, False)
        self.timedout = True
        self.requeue()
        self.tries += 1


    def _start_targ_transition(self):
        #hide targets

        self.show_object(self.target1, False)
        self.show_object(self.targetL, False)
        self.show_object(self.targetH, False)
        self.requeue()

    def _start_check_reward(self):
        #hide targets
        self.show_object(self.target1, False)
        self.show_object(self.targetL, False)
        self.show_object(self.targetH, False)
        self.requeue()

    def _start_reward(self):
        # show the selected target during the reward period
        if self.target_selected == 'L':
            self.show_object(self.targetL, True)
        else:
            self.show_object(self.targetH, True)
        self.requeue()

    @staticmethod
    def colored_targets_with_probabilistic_reward(length=1000, boundaries=(-18,18,-10,10,-15,15),reward_high_prob=80,reward_low_prob=40):

        """
        Generator should return array of ntrials x 2 x 3. The second dimension is for each target.
        For example, first is the target with high probability of reward, and the second 
        entry is for the target with low probability of reward.  The third dimension holds three variables indicating 
        position offset (yes/no), reward probability (fixed in this case), and location (binary returned where the
        ouput indicates either left or right).

        UPDATE: CHANGED SO THAT THE SECOND DIMENSION CARRIES THE REWARD PROBABILITY RATHER THAN THE REWARD SCHEDULE
        """

        position_offsetH = np.random.randint(2,size=(1,length))
        position_offsetL = np.random.randint(2,size=(1,length))
        location_int = np.random.randint(2,size=(1,length))

        # fixed reward probabilities for each target
        high_prob = reward_high_prob*np.ones((1,length))
        low_prob = reward_low_prob*np.ones((1,length))

        pairs = np.zeros([length,2,3])
        pairs[:,0,0] = position_offsetH
        pairs[:,0,1] = high_prob
        pairs[:,0,2] = location_int

        pairs[:,1,0] = position_offsetL
        pairs[:,1,1] = low_prob
        pairs[:,1,2] = 1 - location_int

        return pairs
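
    # Example of how these probabilities are consumed at trial start (see
    # _start_wait above): each trial draws one integer per target in [0, 100)
    # and the choice is rewarded when the stored probability exceeds the draw.
    #   targs = pairs[i]                         # one trial from the generator
    #   draw = np.random.randint(0, 100, size=2)
    #   rewardH = targs[0, 1] > draw[0]          # True on ~reward_high_prob% of trials
    #   rewardL = targs[1, 1] > draw[1]          # True on ~reward_low_prob% of trials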

    @staticmethod
    def block_probabilistic_reward(length=1000, boundaries=(-18,18,-10,10,-15,15),reward_high_prob=80,reward_low_prob=40):
        pairs = ApproachAvoidanceTask.colored_targets_with_probabilistic_reward(length=length, boundaries=boundaries,reward_high_prob=reward_high_prob,reward_low_prob=reward_low_prob)
        return pairs

    @staticmethod
    def colored_targets_with_randomwalk_reward(length=1000,reward_high_prob=80,reward_low_prob=40,reward_high_span = 20, reward_low_span = 20,step_size_mean = 0, step_size_var = 1):

        """
        Generator should return array of ntrials x 2 x 3. The second dimension is for each target.
        For example, first is the target with high probability of reward, and the second 
        entry is for the target with low probability of reward.  The third dimension holds three variables indicating 
        position offset (yes/no), reward probability, and location (binary returned where the
        ouput indicates either left or right).  The variables reward_high_span and reward_low_span indicate the width
        of the range that the high or low reward probability are allowed to span respectively, e.g. if reward_high_prob
        is 80 and reward_high_span is 20, then the reward probability for the high value target will be bounded
        between 60 and 100 percent.
        """

        position_offsetH = np.random.randint(2,size=(1,length))
        position_offsetL = np.random.randint(2,size=(1,length))
        location_int = np.random.randint(2,size=(1,length))

        # define variables for increments: amount of increment and in which direction (i.e. increasing or decreasing)
        assign_rewardH = np.random.randn(1,length)
        assign_rewardL = np.random.randn(1,length)
        assign_rewardH_direction = np.random.randn(1,length)
        assign_rewardL_direction = np.random.randn(1,length)

        r_0_high = reward_high_prob
        r_0_low = reward_low_prob
        r_lowerbound_high = r_0_high - (reward_high_span/2)
        r_upperbound_high = r_0_high + (reward_high_span/2)
        r_lowerbound_low = r_0_low - (reward_low_span/2)
        r_upperbound_low = r_0_low + (reward_low_span/2)
        
        reward_high = np.zeros(length)
        reward_low = np.zeros(length)
        reward_high[0] = r_0_high
        reward_low[0] = r_0_low

        eps_high = assign_rewardH*step_size_mean + (2*(assign_rewardH_direction > 0) - 1)*step_size_var
        eps_low = assign_rewardL*step_size_mean + (2*(assign_rewardL_direction > 0) - 1)*step_size_var

        eps_high = eps_high.ravel()
        eps_low = eps_low.ravel()

        for i in range(1,length):
            reward_high[i] = reward_high[i-1] + eps_high[i-1]
            reward_low[i] = reward_low[i-1] + eps_low[i-1]

            # reflect the walk back into bounds if the step left the allowed span
            reward_high[i] = (r_lowerbound_high < reward_high[i] < r_upperbound_high)*reward_high[i] + (r_lowerbound_high > reward_high[i])*(r_lowerbound_high + eps_high[i-1]) + (r_upperbound_high < reward_high[i])*(r_upperbound_high - eps_high[i-1])
            reward_low[i] = (r_lowerbound_low < reward_low[i] < r_upperbound_low)*reward_low[i] + (r_lowerbound_low > reward_low[i])*(r_lowerbound_low + eps_low[i-1]) + (r_upperbound_low < reward_low[i])*(r_upperbound_low - eps_low[i-1])

        pairs = np.zeros([length,2,3])
        pairs[:,0,0] = position_offsetH
        pairs[:,0,1] = reward_high
        pairs[:,0,2] = location_int

        pairs[:,1,0] = position_offsetL
        pairs[:,1,1] = reward_low
        pairs[:,1,2] = 1 - location_int

        return pairs

    @staticmethod
    def randomwalk_probabilistic_reward(length=1000,reward_high_prob=80,reward_low_prob=40,reward_high_span = 20, reward_low_span = 20,step_size_mean = 0, step_size_var = 1):
        pairs = ApproachAvoidanceTask.colored_targets_with_randomwalk_reward(length=length,reward_high_prob=reward_high_prob,reward_low_prob=reward_low_prob,reward_high_span = reward_high_span, reward_low_span = reward_low_span,step_size_mean = step_size_mean, step_size_var = step_size_var)
        return pairs
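
# Illustration (not part of the original task): a simplified version of the
# bounded random walk in colored_targets_with_randomwalk_reward above. Steps of
# fixed size and random sign accumulate, and any step that would leave
# [lower, upper] is reflected back toward the allowed span.
def _bounded_random_walk(start, lower, upper, steps, step_size=1.0):
    walk = np.empty(steps)
    walk[0] = start
    for i in range(1, steps):
        step = step_size*np.sign(np.random.randn())
        nxt = walk[i-1] + step
        if not (lower < nxt < upper):
            nxt = walk[i-1] - step  # reflect off the boundary
        walk[i] = nxt
    return walk
# e.g. _bounded_random_walk(80, 70, 90, steps=1000) stays within 70..90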
Exemplo n.º 11
0
class ScreenSync(NIDAQSync):
    '''Adds a square in one corner that switches color with every flip.'''

    sync_position = {
        'TopLeft': (-1, 1),
        'TopRight': (1, 1),
        'BottomLeft': (-1, -1),
        'BottomRight': (1, -1)
    }
    sync_position_2D = {
        'TopLeft': (-1, -1),
        'TopRight': (1, -1),
        'BottomLeft': (-1, 1),
        'BottomRight': (1, 1)
    }
    sync_corner = traits.OptionsList(tuple(sync_position.keys()),
                                     desc="Position of sync square")
    sync_size = traits.Float(1, desc="Sync square size (cm)")
    sync_color_off = traits.Tuple((0., 0., 0., 1.),
                                  desc="Sync off color (R,G,B,A)")
    sync_color_on = traits.Tuple((1., 1., 1., 1.),
                                 desc="Sync on color (R,G,B,A)")
    sync_state_duration = 1  # How long to delay the start of the experiment (seconds)

    def __init__(self, *args, **kwargs):

        # Create a new "sync" state at the beginning of the experiment
        if isinstance(self.status, dict):
            self.status["sync"] = dict(start_experiment="wait",
                                       stoppable=False)
        else:
            from riglib.fsm.fsm import StateTransitions
            self.status.states["sync"] = StateTransitions(
                start_experiment="wait", stoppable=False)
        self.state = "sync"

        super().__init__(*args, **kwargs)
        self.sync_state = False
        if hasattr(self, 'is_pygame_display'):
            screen_center = np.divide(self.window_size, 2)
            sync_size_pix = self.sync_size * self.window_size[
                0] / self.screen_cm[0]
            sync_center = [sync_size_pix / 2, sync_size_pix / 2]
            from_center = np.multiply(self.sync_position_2D[self.sync_corner],
                                      np.subtract(screen_center, sync_center))
            top_left = screen_center + from_center - sync_center
            self.sync_rect = pygame.Rect(top_left, np.multiply(sync_center, 2))
        else:
            from_center = np.multiply(
                self.sync_position[self.sync_corner],
                np.subtract(self.screen_cm, self.sync_size))
            pos = np.array(
                [from_center[0] / 2, self.screen_dist, from_center[1] / 2])
            self.sync_square = VirtualRectangularTarget(
                target_width=self.sync_size,
                target_height=self.sync_size,
                target_color=self.sync_color_off,
                starting_pos=pos)
            # self.sync_square = VirtualCircularTarget(target_radius=self.sync_size, target_color=self.sync_color_off, starting_pos=pos)
            for model in self.sync_square.graphics_models:
                self.add_model(model)

    def screen_init(self):
        super().screen_init()
        if hasattr(self, 'is_pygame_display'):
            self.sync = pygame.Surface(self.window_size)
            self.sync.fill(TRANSPARENT)
            self.sync.set_colorkey(TRANSPARENT)

    def _draw_other(self):
        # For pygame display
        color = self.sync_color_on if self.sync_state else self.sync_color_off
        self.sync.fill(255 * np.array(color), rect=self.sync_rect)
        self.screen.blit(self.sync, (0, 0))

    def init(self):
        self.add_dtype('sync_square', bool, (1, ))
        super().init()

    def _while_sync(self):
        '''
        Deliberate "startup sequence":
            1. Send a clock pulse to denote the start of the FSM loop
            2. Turn off the clock and send a single, longer, impulse
                to enable measurement of the screen latency
            3. Turn the clock back on
        '''

        # Turn off the clock after the first cycle is synced
        if self.cycle_count == 1:
            self.sync_every_cycle = False

        # Send an impulse to measure latency halfway through the sync state
        key_cycle = int(self.fps * self.sync_state_duration / 2)
        impulse_duration = 5  # cycles, to make sure it appears on the screen
        if self.cycle_count == key_cycle:
            self.sync_every_cycle = True
        elif self.cycle_count == key_cycle + 1:
            self.sync_every_cycle = False
        elif self.cycle_count == key_cycle + impulse_duration:
            self.sync_every_cycle = True
        elif self.cycle_count == key_cycle + impulse_duration + 1:
            self.sync_every_cycle = False

    def _end_sync(self):
        self.sync_every_cycle = True

    def _test_start_experiment(self, ts):
        return ts > self.sync_state_duration

    def _cycle(self):
        super()._cycle()

        # Update the sync state
        if self.sync_every_cycle:
            self.sync_state = not self.sync_state
        self.task_data['sync_square'] = copy.deepcopy(self.sync_state)

        # For OpenGL display, update the graphics
        if not hasattr(self, 'is_pygame_display'):
            color = self.sync_color_on if self.sync_state else self.sync_color_off
            self.sync_square.cube.color = color
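
# Illustration (not part of the original class): the 'sync_square' column saved
# by ScreenSync toggles every cycle while sync_every_cycle is True, so a
# photodiode placed over the on-screen square produces a matching signal. A
# hedged sketch of recovering flip times offline, assuming `sync` is the boolean
# HDF column sampled at the task rate:
def _sync_edge_indices(sync):
    # cycle indices where the sync square changed state
    sync = np.asarray(sync).astype(int).ravel()
    return np.nonzero(np.diff(sync) != 0)[0] + 1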
Exemplo n.º 12
0
class Optitrack(traits.HasTraits):
    '''
    Enable reading of raw motiontracker data from Optitrack system
    Requires the natnet library from https://github.com/leoscholl/python_natnet
    To be used as a feature with the ManualControl task for the time being. However,
    ideally this would be implemented as a decoder :)
    '''

    optitrack_feature = traits.OptionsList(("rigid body", "skeleton", "marker"))
    smooth_features = traits.Int(1, desc="How many features to average")
    scale = traits.Float(DEFAULT_SCALE, desc="Control scale factor")
    offset = traits.Array(value=DEFAULT_OFFSET, desc="Control offset")

    hidden_traits = ['optitrack_feature', 'smooth_features']

    def init(self):
        '''
        Secondary init function. See riglib.experiment.Experiment.init()
        Prior to starting the task, this 'init' sets up the DataSource for interacting with the 
        motion tracker system and registers the source with the SinkRegister so that the data gets saved to file as it is collected.
        '''

        # Start the natnet client and recording
        import natnet
        now = datetime.now()
        local_path = "C:/Users/Orsborn Lab/Documents"
        session = "OptiTrack/Session " + now.strftime("%Y-%m-%d")
        take = now.strftime("Take %Y-%m-%d %H:%M:%S")
        logger = Logger(take)
        try:
            client = natnet.Client.connect(logger=logger)
            if self.saveid is not None:
                take += " (%d)" % self.saveid
                client.set_session(os.path.join(local_path, session))
                client.set_take(take)
                self.filename = os.path.join(session, take + '.tak')
                client._send_command_and_wait("LiveMode")
                time.sleep(0.1)
                if client.start_recording():
                    self.optitrack_status = 'recording'
            else:
                self.optitrack_status = 'streaming'
        except natnet.DiscoveryError:
            self.optitrack_status = 'Optitrack couldn\'t be started, make sure Motive is open!'
            client = optitrack.SimulatedClient()
        self.client = client

        # Create a source to buffer the motion tracking data
        from riglib import source
        self.motiondata = source.DataSource(optitrack.make(optitrack.System, self.client, self.optitrack_feature, 1))

        # Save to the sink
        from riglib import sink
        sink_manager = sink.SinkManager.get_instance()
        sink_manager.register(self.motiondata)
        super().init()

    def run(self):
        '''
        Code to execute immediately prior to the beginning of the task FSM executing, or after the FSM has finished running. 
        See riglib.experiment.Experiment.run(). This 'run' method starts the motiondata source and stops it after the FSM has finished running
        '''
        if self.optitrack_status not in ['recording', 'streaming']:
            import io
            self.terminated_in_error = True
            self.termination_err = io.StringIO()
            self.termination_err.write(self.optitrack_status)
            self.termination_err.seek(0)
            self.state = None
            super().run()
        else:
            self.motiondata.start()
            try:
                super().run()
            finally:
                print("Stopping optitrack")
                self.client.stop_recording()
                self.motiondata.stop()

    def _start_None(self):
        '''
        Code to run before the 'None' state starts (i.e., the task stops)
        '''
        #self.client.stop_recording()
        self.motiondata.stop()
        super()._start_None()

    def join(self):
        '''
        See riglib.experiment.Experiment.join(). Re-join the motiondata source process before cleaning up the experiment thread
        '''
        print("Joining optitrack datasource")
        self.motiondata.join()
        super().join()

    def cleanup(self, database, saveid, **kwargs):
        '''
        Save the optitrack recorded file into the database
        '''
        super_result = super().cleanup(database, saveid, **kwargs)
        print("Saving optitrack file to database...")
        try:
            database.save_data(self.filename, "optitrack", saveid, False, False) # Make sure you actually have an "optitrack" system added!
        except Exception as e:
            print(e)
            return False
        print("...done.")
        return super_result

    def _get_manual_position(self):
        ''' Overridden method to get input coordinates based on motion data'''

        # Get data from optitrack datasource
        data = self.motiondata.get() # List of (list of features)
        if len(data) == 0: # Data is not being streamed
            return
        recent = data[-self.smooth_features:] # How many recent coordinates to average
        averaged = np.nanmean(recent, axis=0) # List of averaged features
        if np.isnan(averaged).any(): # No usable coords
            return
        return averaged*100 # convert meters to centimeters
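
# Illustration (not part of the original feature): _get_manual_position above
# averages the most recent `smooth_features` samples and converts meters to
# centimeters. A standalone sketch of the same reduction, assuming `data` is an
# (nsamples x 3) array of marker coordinates in meters with NaNs on dropped frames:
def _average_recent(data, n=1, m_to_cm=100.0):
    recent = np.asarray(data)[-n:]           # keep only the n most recent samples
    averaged = np.nanmean(recent, axis=0)    # NaN-aware average across samples
    if np.isnan(averaged).any():             # no usable coordinates
        return None
    return averaged*m_to_cm
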
class ScreenTargetCapture(TargetCapture, Window):
    """Concrete implementation of TargetCapture task where targets
    are acquired by "holding" a cursor in an on-screen target"""
    background = (0, 0, 0, 1)
    cursor_color = (.5, 0, .5, 1)

    starting_pos = (5, 0, 5)

    target_color = (1, 0, 0, .5)

    cursor_visible = False  # Determines when to hide the cursor.
    no_data_count = 0  # Counter for number of missing data frames in a row
    scale_factor = 3.0  #scale factor for converting hand movement to screen movement (1cm hand movement = 3cm cursor movement)

    limit2d = 1

    sequence_generators = ['centerout_2D_discrete']

    is_bmi_seed = True
    _target_color = RED

    # Runtime settable traits
    reward_time = traits.Float(.5, desc="Length of juice reward")
    target_radius = traits.Float(2, desc="Radius of targets in cm")

    hold_time = traits.Float(.2, desc="Length of hold required at targets")
    hold_penalty_time = traits.Float(
        1, desc="Length of penalty time for target hold error")
    timeout_time = traits.Float(10, desc="Time allowed to go between targets")
    timeout_penalty_time = traits.Float(
        1, desc="Length of penalty time for timeout error")
    max_attempts = traits.Int(10,
                              desc='The number of attempts at a target before\
        skipping to the next one')

    plant_hide_rate = traits.Float(
        0.0,
        desc=
        'If the plant is visible, specifies a percentage of trials where it will be hidden'
    )
    plant_type_options = list(plantlist.keys())
    plant_type = traits.OptionsList(*plantlist,
                                    bmi3d_input_options=list(plantlist.keys()))
    plant_visible = traits.Bool(
        True,
        desc='Specifies whether entire plant is displayed or just endpoint')
    cursor_radius = traits.Float(.5, desc="Radius of cursor")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cursor_visible = True

        # Initialize the plant
        if not hasattr(self, 'plant'):
            self.plant = plantlist[self.plant_type]
        self.plant_vis_prev = True

        # Add graphics models for the plant and targets to the window
        if hasattr(self.plant, 'graphics_models'):
            for model in self.plant.graphics_models:
                self.add_model(model)

        # Instantiate the targets
        instantiate_targets = kwargs.pop('instantiate_targets', True)
        if instantiate_targets:
            target1 = VirtualCircularTarget(target_radius=self.target_radius,
                                            target_color=self._target_color)
            target2 = VirtualCircularTarget(target_radius=self.target_radius,
                                            target_color=self._target_color)

            self.targets = [target1, target2]
            for target in self.targets:
                for model in target.graphics_models:
                    self.add_model(model)

        # Initialize target location variable
        self.target_location = np.array([0, 0, 0])

        # Declare any plant attributes which must be saved to the HDF file at the _cycle rate
        for attr in self.plant.hdf_attrs:
            self.add_dtype(*attr)

    def init(self):
        self.add_dtype('target', 'f8', (3, ))
        self.add_dtype('target_index', 'i', (1, ))
        super().init()

    def _cycle(self):
        '''
        Calls any update functions necessary and redraws screen. Runs 60x per second.
        '''
        self.task_data['target'] = self.target_location.copy()
        self.task_data['target_index'] = self.target_index

        ## Run graphics commands to show/hide the plant if the visibility has changed
        if self.plant_type != 'CursorPlant':
            if self.plant_visible != self.plant_vis_prev:
                self.plant_vis_prev = self.plant_visible
                self.plant.set_visibility(self.plant_visible)
                # self.show_object(self.plant, show=self.plant_visible)

        self.move_effector()

        ## Save plant status to HDF file
        plant_data = self.plant.get_data_to_save()
        for key in plant_data:
            self.task_data[key] = plant_data[key]

        super()._cycle()

    def move_effector(self):
        '''Move the end effector, if a robot or similar is being controlled'''
        pass

    def run(self):
        '''
        See experiment.Experiment.run for documentation.
        '''
        # Fire up the plant. For virtual/simulation plants, this does little/nothing.
        self.plant.start()
        try:
            super().run()
        finally:
            self.plant.stop()

    ##### HELPER AND UPDATE FUNCTIONS ####
    def update_cursor_visibility(self):
        ''' Update cursor visible flag to hide cursor if there has been no good data for more than 3 frames in a row'''
        prev = self.cursor_visible
        if self.no_data_count < 3:
            self.cursor_visible = True
            if prev != self.cursor_visible:
                self.show_object(self.cursor, show=True)
        else:
            self.cursor_visible = False
            if prev != self.cursor_visible:
                self.show_object(self.cursor, show=False)

    #### TEST FUNCTIONS ####
    def _test_enter_target(self, ts):
        '''
        return true if the distance between center of cursor and target is smaller than the cursor radius
        '''
        cursor_pos = self.plant.get_endpoint_pos()
        d = np.linalg.norm(cursor_pos - self.target_location)
        return d <= (self.target_radius - self.cursor_radius)

    def _test_leave_early(self, ts):
        '''
        return true if cursor moves outside the exit radius
        '''
        cursor_pos = self.plant.get_endpoint_pos()
        d = np.linalg.norm(cursor_pos - self.target_location)
        rad = self.target_radius - self.cursor_radius
        return d > rad

    #### STATE FUNCTIONS ####
    def _start_wait(self):
        super()._start_wait()
        # hide targets
        for target in self.targets:
            target.hide()

    def _start_target(self):
        super()._start_target()

        # move one of the two targets to the new target location
        target = self.targets[self.target_index % 2]
        target.move_to_position(self.target_location)
        target.cue_trial_start()

    def _start_hold(self):
        #make next target visible unless this is the final target in the trial
        idx = (self.target_index + 1)
        if idx < self.chain_length:
            target = self.targets[idx % 2]
            target.move_to_position(self.targs[idx])

    def _end_hold(self):
        # change current target color to green
        self.targets[self.target_index % 2].cue_trial_end_success()

    def _start_hold_penalty(self):
        super()._start_hold_penalty()
        # hide targets
        for target in self.targets:
            target.hide()

    def _start_timeout_penalty(self):
        super()._start_timeout_penalty()
        # hide targets
        for target in self.targets:
            target.hide()

    def _start_targ_transition(self):
        #hide targets
        for target in self.targets:
            target.hide()

    def _start_reward(self):
        self.targets[self.target_index % 2].show()

    #### Generator functions ####
    @staticmethod
    def centerout_2D_discrete(nblocks=100,
                              ntargets=8,
                              boundaries=(-18, 18, -12, 12),
                              distance=10):
        '''
        Generates a sequence of 2D (x and z) target pairs with the first target
        always at the origin.

        Parameters
        ----------
        nblocks : int
            The number of blocks; each block contains every target direction once.
        ntargets : int
            The number of evenly spaced target directions on the circle.
        boundaries : 4 element Tuple
            The limits of the allowed target locations (-x, x, -z, z)
        distance : float
            The distance in cm between the targets in a pair.

        Returns
        -------
        pairs : [nblocks*ntargets x 2 x 3] array of pairs of target locations
        '''

        # Choose a random sequence of points on the edge of a circle of radius
        # "distance"

        theta = []
        for i in range(nblocks):
            temp = np.arange(0, 2 * np.pi, 2 * np.pi / ntargets)
            np.random.shuffle(temp)
            theta = theta + [temp]
        theta = np.hstack(theta)

        x = distance * np.cos(theta)
        y = np.zeros(len(theta))
        z = distance * np.sin(theta)

        pairs = np.zeros([len(theta), 2, 3])
        pairs[:, 1, :] = np.vstack([x, y, z]).T

        return pairs
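
# Usage sketch (added for illustration; not part of the original task code).
# `task_cls` stands in for whichever task class defines centerout_2D_discrete.
def _demo_centerout_2D_discrete(task_cls):
    pairs = task_cls.centerout_2D_discrete(nblocks=2, ntargets=8, distance=10)
    assert pairs.shape == (16, 2, 3)           # nblocks * ntargets pairs of 3D points
    assert np.allclose(pairs[:, 0, :], 0)      # the first target of each pair is the origin
    # peripheral targets lie on a circle of radius `distance` in the x-z plane
    assert np.allclose(np.linalg.norm(pairs[:, 1, :], axis=1), 10)
    return pairs
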
class FactorBMIBase(BMIResetting):

    #Choose task to use trials from (no assist)
    #TODO make offline function to quickly assess optimal number of factors
    sequence_generators = [
        'centerout_2D_discrete', 'generate_catch_trials', 'all_shar_trials'
    ]
    input_type_list = [
        'shared', 'private', 'shared_scaled', 'private_scaled', 'all',
        'all_scaled_by_shar', 'sc_shared+unsc_priv', 'sc_shared+sc_priv',
        'main_shared', 'main_sc_shared', 'main_sc_private',
        'main_sc_shar+unsc_priv', 'main_sc_shar+sc_priv', 'pca', 'split'
    ]  #, 'priv_shar_concat']
    input_type = traits.OptionsList(*input_type_list,
                                    bmi3d_input_options=input_type_list)

    def init(self):
        nU = self.decoder.n_units

        #Add datatypes for 1) trial type 2) input types:
        self.decoder.filt.FA_input_dict = {}

        self.input_types = [i + '_input'
                            for i in self.input_type_list] + ['task_input']

        for k in self.input_types:
            if k == 'split_input':
                nUn = self.decoder.trained_fa_dict['fa_main_shar_n_dim'] + nU
                #nUn = 2*nU
            elif k == 'task_input' and self.input_type == 'split':
                nUn = self.decoder.trained_fa_dict['fa_main_shar_n_dim'] + nU
                #nUn = 2*nU

            else:
                nUn = nU
            self.decoder.filt.FA_input_dict[k] = np.zeros((nUn, 1))
            self.decoder.filt.FA_input_dict[k][:] = np.nan
            self.add_dtype(k, 'f8', (nUn, 1))

        #Add datatype for 'trial-type':
        self.add_dtype('fa_input', np.str_, 16)
        super(FactorBMIBase, self).init()

    def init_decoder_state(self):
        try:
            fa_dict = self.decoder.trained_fa_dict
            self.decoder.filt.FA_kwargs = fa_dict

        except AttributeError:
            raise Exception(
                'Decoder has no trained FA dict: run riglib.train.add_fa_dict_to_decoder and re-save with dbq.save'
            )

        #Check if isinstance of FAKalmanFilter or KalmanFilter
        from riglib.bmi.kfdecoder import KalmanFilter, FAKalmanFilter

        if isinstance(self.decoder.filt, KalmanFilter):
            self.decoder.filt.__class__ = FAKalmanFilter

        #Add FA elements to dict:
        print('adding', self.input_type, 'as input_type to FA decoder')
        import time
        time.sleep(2.)
        self.decoder.filt.FA_input = self.input_type
        super(FactorBMIBase, self).init_decoder_state()

    def _cycle(self):
        for k in self.input_types:
            try:
                self.task_data[k] = self.decoder.filt.FA_input_dict[k]
            except Exception:
                # Log the shape mismatch between the decoder dict and the task buffer
                print(k, self.decoder.filt.FA_input_dict[k].shape,
                      self.task_data[k].shape)
        self.task_data['fa_input'] = self.decoder.filt.FA_input

        super(FactorBMIBase, self)._cycle()

    def _parse_next_trial(self):
        if isinstance(self.next_trial[1], (str, bytes)):
            self.targs = self.next_trial[0]
        else:
            self.targs = self.next_trial
        #print 'trial: ', self.decoder.filt.FA_input, self.targs

    ### FA param saving functions ###:
    def cleanup_hdf(self):
        '''
        Re-open the HDF file and save any extra task data kept in RAM
        '''
        super(FactorBMIBase, self).cleanup_hdf()
        try:
            self.write_FA_data_to_hdf_table(self.h5file.name,
                                            self.decoder.filt.FA_kwargs)
            print('writing FA params to HDF file')
        except Exception:
            print('error in writing FA params to hdf file')
            import traceback
            traceback.print_exc()

    @staticmethod
    def write_FA_data_to_hdf_table(hdf_fname, FA_dict, ignore_none=False):

        import tables
        compfilt = tables.Filters(complevel=5, complib="zlib", shuffle=True)  # currently unused

        h5file = tables.open_file(hdf_fname, mode='a')
        fa_grp = h5file.create_group(h5file.root, "fa_params",
                                     "Parameters for FA model used")

        for key in FA_dict:
            if isinstance(FA_dict[key], np.ndarray):
                h5file.create_array(fa_grp, key, FA_dict[key])
            else:
                try:
                    h5file.create_array(fa_grp, key, np.array([FA_dict[key]]))
                except (TypeError, ValueError):
                    print('cannot save:', key, 'from FA in hdf file')
        h5file.close()
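    # Read-back sketch (added; illustrative only, assuming PyTables 3.x and a
    # hypothetical file name):
    #
    #     import tables
    #     with tables.open_file('task_data.hdf', mode='r') as h5:
    #         fa_mu = h5.root.fa_params.fa_mu[:]              # N x 1 mean vector
    #         n_dim = h5.root.fa_params.fa_main_shar_n_dim[0] # main shared dims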

    @classmethod
    def generate_FA_matrices(self,
                             training_task_entry,
                             plot=False,
                             hdf=None,
                             dec=None,
                             bin_spk=None):

        import utils.fa_decomp as pa
        if bin_spk is None:
            if training_task_entry is not None:
                from db import dbfunctions as dbfn
                te = dbfn.TaskEntry(training_task_entry)
                hdf = te.hdf
                dec = te.decoder

            bin_spk, targ_pos, targ_ix, z, zz = self.extract_trials_all(
                hdf, dec)

        #Zscore is in time x neurons
        zscore_X, mu = self.zscore_spks(bin_spk)

        # #Find optimal number of factors:
        LL, psv = pa.find_k_FA(zscore_X, iters=3, max_k=10, plot=False)

        # Average log-likelihood across iterations, ignoring NaNs (manual nanmean):
        nan_ix = np.isnan(LL)
        samp = np.sum(~nan_ix, axis=0)
        ll = np.nansum(LL, axis=0)
        LL_new = np.divide(ll, samp)

        num_factors = 1 + np.argmax(LL_new)
        print('optimal LL factors:', num_factors)

        FA = skdecomp.FactorAnalysis(n_components=num_factors)

        #Samples x features:
        FA.fit(zscore_X)

        #FA matrices:
        U = np.mat(FA.components_).T
        i = np.diag_indices(U.shape[0])
        Psi = np.mat(np.zeros((U.shape[0], U.shape[0])))
        Psi[i] = FA.noise_variance_
        A = U * U.T
        B = np.linalg.inv(U * U.T + Psi)
        mu_vect = np.array([mu[0, :]]).T  #Size = N x 1
        sharL = A * B

        #Calculate shared / priv scaling:
        bin_spk_tran = bin_spk.T
        mu_mat = np.tile(np.array([mu[0, :]]).T, (1, bin_spk_tran.shape[1]))
        demn = bin_spk_tran - mu_mat
        shared_bin_spk = (sharL * demn)
        priv_bin_spk = bin_spk_tran - mu_mat - shared_bin_spk

        #Scaling:
        eps = 1e-15
        x_var = np.var(np.mat(bin_spk_tran), axis=1) + eps
        pr_var = np.var(priv_bin_spk, axis=1) + eps
        sh_var = np.var(shared_bin_spk, axis=1) + eps

        priv_scalar = np.sqrt(np.divide(x_var, pr_var))
        shared_scalar = np.sqrt(np.divide(x_var, sh_var))

        if plot:
            tmp = np.diag(U.T * U)
            plt.plot(np.arange(1, num_factors + 1),
                     np.cumsum(tmp) / np.sum(tmp), '.-')
            plt.plot([0, num_factors + 1], [.9, .9], '-')

        #Get main shared space:
        u, s, v = np.linalg.svd(A)
        s_red = np.zeros_like(s)
        s_hd = np.zeros_like(s)

        ix = np.nonzero(np.cumsum(s**2) / float(np.sum(s**2)) > .90)[0]
        if len(ix) > 0:
            n_dim_main_shared = ix[0] + 1
        else:
            n_dim_main_shared = len(s)
        if n_dim_main_shared < 2:
            n_dim_main_shared = 2
        print "main shared: n_dim: ", n_dim_main_shared, np.cumsum(s) / float(
            np.sum(s))
        s_red[:n_dim_main_shared] = s[:n_dim_main_shared]
        s_hd[n_dim_main_shared:] = s[n_dim_main_shared:]

        main_shared_A = u * np.diag(s_red) * v
        hd_shared_A = u * np.diag(s_hd) * v
        main_shared_B = np.linalg.inv(main_shared_A + hd_shared_A + Psi)

        uut_psi_inv = main_shared_B.copy()
        u_svd = u[:, :n_dim_main_shared]

        main_sharL = main_shared_A * main_shared_B

        main_shar = main_sharL * demn
        main_shar_var = np.var(main_shar, axis=1) + eps
        main_shar_scal = np.sqrt(np.divide(x_var, main_shar_var))

        main_priv = demn - main_shar
        main_priv_var = np.var(main_priv, axis=1) + eps
        main_priv_scal = np.sqrt(np.divide(x_var, main_priv_var))

        # #Get PCA decomposition:
        #LL, ax = pa.FA_all_targ_ALLms(hdf, iters=2, max_k=20, PCA_instead=True)
        #num_PCs = 1+(np.argmax(np.mean(LL, axis=0)))

        # Main PCA space:
        # Get cov matrix:
        cov_pca = np.cov(zscore_X.T)
        eig_val, eig_vec = np.linalg.eig(cov_pca)
        eig_val = np.real(eig_val)  # covariance is symmetric, so eigenvalues are real
        eig_vec = np.real(eig_vec)

        # Sort eigenvalues (and matching eigenvectors) in descending order;
        # np.linalg.eig does not guarantee any ordering
        order = np.argsort(eig_val)[::-1]
        eig_val = eig_val[order]
        eig_vec = eig_vec[:, order]

        tot_var = np.sum(eig_val)
        cum_var_exp = np.cumsum(eig_val / tot_var)
        n_PCs = np.nonzero(cum_var_exp > 0.9)[0][0] + 1

        proj_mat = eig_vec[:, :n_PCs]
        proj_trans = np.mat(proj_mat) * np.mat(proj_mat.T)

        #PC matrices:
        return dict(fa_sharL=sharL,
                    fa_mu=mu_vect,
                    fa_shar_var_sc=shared_scalar,
                    fa_priv_var_sc=priv_scalar,
                    U=U,
                    Psi=Psi,
                    training_task_entry=training_task_entry,
                    FA_iterated_power=FA.iterated_power,
                    FA_score=FA.score(zscore_X),
                    FA_LL=np.array(FA.loglike_),
                    fa_main_shared=main_sharL,
                    fa_main_shared_sc=main_shar_scal,
                    fa_main_private_sc=main_priv_scal,
                    fa_main_shar_n_dim=n_dim_main_shared,
                    sing_vals=s,
                    own_pc_trans=proj_trans,
                    FA_model=FA,
                    uut_psi_inv=uut_psi_inv,
                    u_svd=u_svd)
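    # FA decomposition sketch (added note, grounded in the code above): for a
    # demeaned observation x, the model is
    #
    #     x = U z + eps,    z ~ N(0, I),    eps ~ N(0, Psi)
    #
    # so the posterior mean of the shared signal is
    #
    #     E[U z | x] = U U^T (U U^T + Psi)^{-1} x = A B x = sharL x
    #
    # and the private signal is the residual x - sharL x. The *_scalar outputs
    # rescale each unit's shared (or private) component so its variance matches
    # the unit's raw variance, and the "main shared" space keeps only the top
    # SVD modes of A = U U^T that explain 90% of the shared variance.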

    @classmethod
    def zscore_spks(self, proc_spks):
        '''Mean-center a time x units matrix. Note: despite the name, this only
        subtracts the mean; it does not divide by the standard deviation.'''
        mu = np.tile(np.mean(proc_spks, axis=0), (proc_spks.shape[0], 1))
        zscore_X = proc_spks - mu
        return zscore_X, mu

    @classmethod
    def extract_trials_all(self,
                           hdf,
                           dec,
                           neural_bins=100,
                           time_cutoff=None,
                           hdf_ix=False):
        '''
        Summary: method to extract all time points from rewarded trials
        Input param: hdf: task HDF file
        Input param: dec: decoder used during the task
        Input param: neural_bins: bin width in ms
        Input param: time_cutoff: time in minutes; only extract trials starting before this time
        Input param: hdf_ix: bool, whether to return the hdf row corresponding to the time of each decoder
        update (and hence the end of each spike bin)

        Output param: bin_spk -- binned spikes in time x units
                      targ_i_all -- target location at each update
                      targ_ix -- target index
                      trial_ix -- trial number
                      reach_time -- reach time for trial
                      hdf_ix -- end bin in units of hdf rows (only if hdf_ix=True)
        '''
        rew_ix = np.array([
            t[1] for it, t in enumerate(hdf.root.task_msgs[:])
            if t[0] == 'reward'
        ])

        if time_cutoff is not None:
            it_cutoff = time_cutoff * 60 * 60  # minutes * 60 s/min * 60 HDF rows/s
        else:
            it_cutoff = len(hdf.root.task)
        # Get the go-cue row for each rewarded trial (3 messages before 'reward')
        go_ix = np.array([
            hdf.root.task_msgs[it - 3][1]
            for it, t in enumerate(hdf.root.task_msgs[:]) if t[0] == 'reward'
        ])
        before_cutoff = go_ix < it_cutoff
        go_ix = go_ix[before_cutoff]
        rew_ix = rew_ix[before_cutoff]

        targ_i_all = np.array([[-1, -1]])
        trial_ix_all = np.array([-1])
        reach_tm_all = np.array([-1])
        hdf_ix_all = np.array([-1])

        bin_spk = np.zeros((1, hdf.root.task[0]['spike_counts'].shape[0])) - 1
        drives_neurons = dec.drives_neurons
        drives_neurons_ix0 = np.nonzero(drives_neurons)[0][0]
        update_bmi_ix = np.nonzero(
            np.diff(
                np.squeeze(hdf.root.task[:]['internal_decoder_state']
                           [:, drives_neurons_ix0, 0])))[0] + 1

        for ig, (g, r) in enumerate(zip(go_ix, rew_ix)):
            spk_i = hdf.root.task[g:r]['spike_counts'][:, :, 0]

            #Sum spikes in neural_bins:
            bin_spk_i, nbins, hdf_ix_i = self._bin_spks(
                spk_i, g, r, update_bmi_ix)
            bin_spk = np.vstack((bin_spk, bin_spk_i))
            targ_i_all = np.vstack(
                (targ_i_all,
                 np.tile(hdf.root.task[g + 1]['target'][[0, 2]],
                         (bin_spk_i.shape[0], 1))))
            trial_ix_all = np.hstack(
                (trial_ix_all, np.zeros((bin_spk_i.shape[0])) + ig))
            reach_tm_all = np.hstack(
                (reach_tm_all, np.zeros(
                    (bin_spk_i.shape[0])) + ((r - g) * 1000. / 60.)))
            hdf_ix_all = np.hstack((hdf_ix_all, hdf_ix_i))

        targ_ix = self._get_target_ix(targ_i_all[1:, :])

        if hdf_ix:
            return bin_spk[1:, :], targ_i_all[1:, :], targ_ix, trial_ix_all[
                1:], reach_tm_all[1:], hdf_ix_all[1:]
        else:
            return bin_spk[1:, :], targ_i_all[
                1:, :], targ_ix, trial_ix_all[1:], reach_tm_all[1:]

    @classmethod
    def _get_target_ix(self, targ_pos):
        #Target Index:
        b = np.ascontiguousarray(targ_pos).view(
            np.dtype((np.void, targ_pos.dtype.itemsize * targ_pos.shape[1])))
        _, idx = np.unique(b, return_index=True)
        unique_targ = targ_pos[idx, :]

        #Order by theta:
        theta = np.arctan2(unique_targ[:, 1], unique_targ[:, 0])
        thet_i = np.argsort(theta)
        unique_targ = unique_targ[thet_i, :]

        targ_ix = np.zeros((targ_pos.shape[0]), )
        for ig, (x, y) in enumerate(targ_pos):
            targ_ix[ig] = np.nonzero(
                np.sum(targ_pos[ig, :] == unique_targ, axis=1) == 2)[0]
        return targ_ix

    @classmethod
    def _bin_spks(self, spk_i, g_ix, r_ix, update_bmi_ix):

        #Need to use 'update_bmi_ix' from ReDecoder to get bin edges correctly:
        trial_inds = np.arange(g_ix, r_ix + 1)
        end_bin = np.array([
            (j, i) for j, i in enumerate(trial_inds)
            if np.logical_and(i in update_bmi_ix, i >= (g_ix + 5))
        ])
        nbins = len(end_bin)
        bin_spk_i = np.zeros((nbins, spk_i.shape[1]))

        hdf_ix_i = []
        for ib, (i_ix, hdf_ix) in enumerate(end_bin):
            #Inclusive of EndBin
            bin_spk_i[ib, :] = np.sum(spk_i[i_ix - 5:i_ix + 1, :], axis=0)
            hdf_ix_i.append(hdf_ix)
        return bin_spk_i, nbins, np.array(hdf_ix_i)
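    # Note on the binning above (added): each bin sums the 6 samples ending at a
    # decoder-update row (rows i_ix-5 .. i_ix inclusive), i.e. 100 ms of spike
    # counts at the 60 Hz HDF rate, matching neural_bins=100 in extract_trials_all.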

    @staticmethod
    def generate_catch_trials(nblocks=5,
                              ntargets=8,
                              distance=10,
                              perc_shar=10,
                              perc_priv=10):
        '''
        Generates a sequence of 2D (x and z) target pairs with the first target
        always at the origin and a second field indicating the extractor type (full, shared, priv)

        Each block is 10 shuffled repetitions of each target (80 trials with
        ntargets=8), tiled nblocks times.

        perc_shar, perc_priv: percentage of catch trials per type; multiples of 10, please
        '''
        assert (not np.mod(perc_shar, 10)) and (not np.mod(perc_priv, 10))

        #Make blocks of 80 trials:
        theta = []
        for i in range(10):
            temp = np.arange(0, 2 * np.pi, 2 * np.pi / ntargets)
            np.random.shuffle(temp)
            theta = theta + [temp]
        theta = np.hstack(theta)

        #Each target has correct % of private and correct % of shared targets
        trial_type = np.empty(len(theta), dtype='S10')

        # `temp` still holds one copy of each unique target angle from the last shuffle
        for i in temp:
            targ_ix = np.nonzero(theta == i)[0]
            trial_ix = np.arange(len(targ_ix))
            tmp_trial = np.array(['all'] * len(targ_ix), dtype='S10')

            n_trial_shar = np.floor(perc_shar / 100. * float(len(targ_ix)))
            n_trial_priv = np.floor(perc_priv / 100. * float(len(targ_ix)))

            tmp_trial[:int(n_trial_shar)] = ['shared']
            tmp_trial[int(n_trial_shar):int(n_trial_shar +
                                            n_trial_priv)] = ['private']
            np.random.shuffle(tmp_trial)
            trial_type[targ_ix] = tmp_trial

        #Make Target set:
        x = distance * np.cos(theta)
        y = np.zeros(len(theta))
        z = distance * np.sin(theta)

        pairs = np.zeros([len(theta), 2, 3])
        pairs[:, 1, :] = np.vstack([x, y, z]).T

        Pairs = np.tile(pairs, [nblocks, 1, 1])
        Trial_type = np.tile(trial_type, [nblocks])

        #Will yield a tuple where target location is in next_trial[0], trial_type is in next_trial[1]
        return list(zip(Pairs, Trial_type))
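    # Worked example (added): with ntargets=8 the base block is 10 shuffled
    # repetitions of each target = 80 trials. At perc_shar=perc_priv=10, each
    # target gets 1 'shared' and 1 'private' catch trial per block and the
    # remaining 8 stay 'all'; the block is then tiled nblocks times.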

    @staticmethod
    def all_shar_trials(nblocks=5, ntargets=8, distance=10):
        '''
        Generates a sequence of 2D (x and z) target pairs with the first target
        always at the origin and a second field indicating the extractor type (always shared)
        '''
        #Make blocks of 80 trials:
        theta = []
        for i in range(10):
            temp = np.arange(0, 2 * np.pi, 2 * np.pi / ntargets)
            np.random.shuffle(temp)
            theta = theta + [temp]
        theta = np.hstack(theta)

        #Each target has correct % of private and correct % of shared targets
        trial_type = np.empty(len(theta), dtype='S10')
        trial_type[:] = 'shared'

        #Make Target set:
        x = distance * np.cos(theta)
        y = np.zeros(len(theta))
        z = distance * np.sin(theta)

        pairs = np.zeros([len(theta), 2, 3])
        pairs[:, 1, :] = np.vstack([x, y, z]).T

        Pairs = np.tile(pairs, [nblocks, 1, 1])
        Trial_type = np.tile(trial_type, [nblocks])

        #Will yield a tuple where target location is in next_trial[0], trial_type is in next_trial[1]
        return list(zip(Pairs, Trial_type))

class RatBMI(BMILoop, LogExperiment):
    status = dict(wait=dict(start_trial='feedback_on', stop=None),
                  feedback_on=dict(baseline_hit='periph_targets', stop=None),
                  periph_targets=dict(target_hit='check_reward',
                                      timeout='noise_burst',
                                      stop=None),
                  check_reward=dict(rewarded_target='reward',
                                    unrewarded_target='feedback_pause'),
                  feedback_pause=dict(end_feedback_pause='wait'),
                  reward=dict(reward_end='wait'),
                  noise_burst=dict(noise_burst_end='noise_burst_timeout'),
                  noise_burst_timeout=dict(noise_burst_timeout_end='wait'))
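    # State-machine sketch (added note): wait -> feedback_on -> (baseline_hit)
    # -> periph_targets -> (target_hit) -> check_reward -> reward if the hit
    # target was 't2', else feedback_pause -> wait; a timeout in periph_targets
    # triggers noise_burst -> noise_burst_timeout -> wait.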

    #Flag for feedback on or not
    feedback = False
    prev_targ_hit = 't1'
    timeout_time = traits.Float(30.)
    noise_burst_time = traits.Float(3.)
    noise_burst_timeout_time = traits.Float(1.)
    reward_time = traits.Float(1., desc='reward time')
    #Frequency range:
    aud_freq_range = traits.Tuple((1000., 20000.))
    plant_type = traits.OptionsList(*plantlist,
                                    desc='',
                                    bmi3d_input_options=plantlist.keys())

    #Time to average over:
    nsteps = traits.Float(10.)
    feedback_pause = traits.Float(3.)

    def __init__(self, *args, **kwargs):
        super(RatBMI, self).__init__(*args, **kwargs)

        if hasattr(self, 'decoder'):
            print(self.decoder)
        else:
            self.decoder = kwargs['decoder']
        dec_params = dict(nsteps=self.nsteps, freq_lim=self.aud_freq_range)
        for key, val in dec_params.items():
            print(key, val, self.decoder.filt.dec_params[key])
            assert self.decoder.filt.dec_params[key] == val
        self.decoder.filt.init_from_task(self.decoder.n_units, **dec_params)
        self.plant = plantlist[self.plant_type]

    def init(self, *args, **kwargs):
        self.add_dtype('cursor', 'f8', (2, ))
        self.add_dtype('freq', 'f8', (2, ))
        super(RatBMI, self).init()
        self.decoder.count_max = self.feature_accumulator.count_max

    def _cycle(self):
        self.rat_cursor = self.decoder.filt.get_mean()
        self.task_data['cursor'] = self.rat_cursor
        self.task_data['freq'] = self.decoder.filt.F
        self.decoder.cnt = self.feature_accumulator.count
        self.decoder.feedback = self.feedback
        super(RatBMI, self)._cycle()

    # def move_plant(self):
    #     if self.feature_accumulator.count == self.feature_accumulator.count_max:
    #         print 'self.plant.drive from task.py'
    #         self.plant.drive(self.decoder)
    def _start_wait(self):
        return True

    def _test_start_trial(self, ts):
        return True

    def _test_rewarded_target(self, ts):
        if self.prev_targ_hit == 't1':
            return False
        elif self.prev_targ_hit == 't2':
            return True

    def _test_unrewarded_target(self, ts):
        if self.prev_targ_hit == 't1':
            return True
        elif self.prev_targ_hit == 't2':
            return False

    def _start_feedback_pause(self):
        self.feedback = False

    def _test_end_feedback_pause(self, ts):
        return ts > self.feedback_pause

    def _start_reward(self):
        print('reward!')

    def _start_feedback_on(self):
        self.feedback = True

    def _test_baseline_hit(self, ts):
        if self.prev_targ_hit == 't1':
            #Must go below baseline:
            return self.rat_cursor <= self.decoder.filt.mid
        elif self.prev_targ_hit == 't2':
            #Must rise above baseline:
            return self.rat_cursor >= self.decoder.filt.mid
        else:
            return False

    def _test_target_hit(self, ts):
        if self.rat_cursor >= self.decoder.filt.t1:
            self.prev_targ_hit = 't1'
            self.feedback = False
            return True
        elif self.rat_cursor <= self.decoder.filt.t2:
            self.prev_targ_hit = 't2'
            self.feedback = False
            return True
        else:
            return False

    def _test_timeout(self, ts):
        return ts > self.timeout_time

    def _test_noise_burst_end(self, ts):
        return ts > self.noise_burst_time

    def _test_noise_burst_timeout_end(self, ts):
        return ts > self.noise_burst_timeout_time

    def _start_noise_burst(self):
        self.feedback = False
        self.plant.play_white_noise()

    def move_plant(self):
        super(RatBMI, self).move_plant()

    def get_current_assist_level(self):
        return 0.
Exemplo n.º 16
0
class ManualControlMixin(traits.HasTraits):
    '''Target capture task where the subject operates a joystick
    to control a cursor. Targets are captured by having the cursor
    dwell in the screen target for the allotted time'''

    # Settable Traits
    wait_time = traits.Float(2., desc="Time between successful trials")
    velocity_control = traits.Bool(False, desc="Position or velocity control")
    random_rewards = traits.Bool(False, desc="Add randomness to reward")
    rotation = traits.OptionsList(*rotations, desc="Control rotation matrix", bmi3d_input_options=list(rotations.keys()))
    scale = traits.Float(1.0, desc="Control scale factor")
    offset = traits.Array(value=[0,0,0], desc="Control offset")
    is_bmi_seed = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.current_pt=np.zeros([3]) #keep track of current pt
        self.last_pt=np.zeros([3]) #keep track of last pt to calc. velocity
        self.no_data_count = 0
        self.reportstats['Input quality'] = "100 %"
        if self.random_rewards:
            self.reward_time_base = self.reward_time

    def init(self):
        self.add_dtype('manual_input', 'f8', (3,))
        super().init()

    def _test_start_trial(self, ts):
        return ts > self.wait_time and not self.pause

    def _test_trial_complete(self, ts):
        if self.target_index == self.chain_length - 1:
            if self.random_rewards and not self.rand_reward_set_flag:
                # Reward time has not been set for this trial; choose it randomly
                # with a minimum of half the base reward time
                self.reward_time = max(
                    2 * (np.random.rand() - 0.5) + self.reward_time_base,
                    self.reward_time_base / 2)
                self.rand_reward_set_flag = 1
            return True
        return False

    def _test_reward_end(self, ts):
        #When finished reward, reset flag.
        if self.random_rewards:
            if ts > self.reward_time:
                self.rand_reward_set_flag = 0
                #print self.reward_time, self.rand_reward_set_flag, ts
        return ts > self.reward_time

    def _transform_coords(self, coords):
        ''' 
        Returns transformed coordinates based on rotation, offset, and scale traits
        '''
        offset = np.array(
            [[1, 0, 0, 0], 
            [0, 1, 0, 0], 
            [0, 0, 1, 0], 
            [self.offset[0], self.offset[1], self.offset[2], 1]]
        )
        scale = np.array(
            [[self.scale, 0, 0, 0], 
            [0, self.scale, 0, 0], 
            [0, 0, self.scale, 0], 
            [0, 0, 0, 1]]
        )
        old = np.concatenate((np.reshape(coords, -1), [1]))
        new = np.linalg.multi_dot((old, offset, scale, rotations[self.rotation]))
        return new[0:3]
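    # Worked example (added): with the row-vector convention above, the offset is
    # applied first, then scale, then rotation -- so the offset itself gets scaled.
    # With offset=[1,0,0], scale=2, and an identity rotation matrix:
    #
    #     [0,0,0,1] @ offset -> [1,0,0,1]
    #     [1,0,0,1] @ scale  -> [2,0,0,1]    => transformed point (2, 0, 0)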

    def _get_manual_position(self):
        '''
        Fetches joystick position
        '''
        if not hasattr(self, 'joystick'):
            return
        pt = self.joystick.get()
        if len(pt) == 0:
            return

        pt = pt[-1] # Use only the latest coordinate

        if len(pt) == 2:
            pt = np.concatenate((np.reshape(pt, -1), [0]))

        return [pt]

    def move_effector(self):
        ''' 
        Sets the 3D coordinates of the cursor. For manual control, uses
        motiontracker / joystick / mouse data. If no data available, returns None
        '''

        # Get raw input and save it as task data
        raw_coords = self._get_manual_position() # array of [3x1] arrays
        if raw_coords is None or len(raw_coords) < 1:
            self.no_data_count += 1
            self.update_report_stats()
            self.task_data['manual_input'] = np.full((3,), np.nan)  # mark as missing data
            return

        self.task_data['manual_input'] = raw_coords.copy()

        # Transform coordinates
        coords = self._transform_coords(raw_coords)
        if self.limit2d:
            coords[1] = 0

        # Set cursor position
        if not self.velocity_control:
            self.current_pt = coords
        else:
            epsilon = 2e-2  # threshold on squared speed to stabilize cursor movement
            if np.sum(coords**2) > epsilon:

                # Add the velocity (units/s) to the position (units)
                self.current_pt = coords / self.fps + self.last_pt
            else:
                self.current_pt = self.last_pt

        self.plant.set_endpoint_pos(self.current_pt)
        self.last_pt = self.current_pt.copy()

    def update_report_stats(self):
        super().update_report_stats()
        quality = 1 - self.no_data_count / max(1, self.cycle_count)
        self.reportstats['Input quality'] = "{} %".format(int(100*quality))

    @classmethod
    def get_desc(cls, params, log_summary):
        duration = round(log_summary['runtime'] / 60, 1)
        return "{}/{} succesful trials in {} min".format(
            log_summary['n_success_trials'], log_summary['n_trials'], duration)
class BMIControlMulti(BMILoop, LinearlyDecreasingAssist,
                      manualcontrolmultitasks.ManualControlMulti):
    '''
    Target capture task with cursor position controlled by BMI output.
    Cursor movement can be assisted toward target by setting assist_level > 0.
    '''

    background = (.5, .5, .5, 1)  # Set the screen background color to grey
    reset = traits.Int(
        0, desc='reset the decoder state to the starting configuration')

    ordered_traits = [
        'session_length', 'assist_level', 'assist_level_time', 'reward_time',
        'timeout_time', 'timeout_penalty_time'
    ]
    exclude_parent_traits = ['marker_count', 'marker_num', 'goal_cache_block']

    static_states = []  # states in which the decoder is not run
    hidden_traits = [
        'arm_hide_rate', 'arm_visible', 'hold_penalty_time', 'rand_start',
        'reset', 'target_radius', 'window_size'
    ]

    is_bmi_seed = False

    cursor_color_adjust = traits.OptionsList(
        *target_colors.keys(), bmi3d_input_options=target_colors.keys())

    def __init__(self, *args, **kwargs):
        super(BMIControlMulti, self).__init__(*args, **kwargs)

    def init(self, *args, **kwargs):
        sph = self.plant.graphics_models[0]
        sph.color = target_colors[self.cursor_color_adjust]
        sph.radius = self.cursor_radius
        self.plant.cursor_radius = self.cursor_radius
        self.plant.cursor.radius = self.cursor_radius
        super(BMIControlMulti, self).init(*args, **kwargs)

    def move_effector(self, *args, **kwargs):
        pass

    def create_assister(self):
        # Create the appropriate type of assister object
        start_level, end_level = self.assist_level
        kwargs = dict(decoder_binlen=self.decoder.binlen,
                      target_radius=self.target_radius)
        if hasattr(self, 'assist_speed'):
            kwargs['assist_speed'] = self.assist_speed

        from db import namelist

        if self.decoder.ssm == namelist.endpt_2D_state_space and isinstance(
                self.decoder, ppfdecoder.PPFDecoder):
            self.assister = OFCEndpointAssister()
        elif self.decoder.ssm == namelist.endpt_2D_state_space:
            self.assister = SimpleEndpointAssister(**kwargs)
        elif (self.decoder.ssm == namelist.tentacle_2D_state_space) or (
                self.decoder.ssm == namelist.joint_2D_state_space):
            # kin_chain = self.plant.kin_chain
            # A, B, W = self.decoder.ssm.get_ssm_matrices(update_rate=self.decoder.binlen)
            # Q = np.mat(np.diag(np.hstack([kin_chain.link_lengths, np.zeros_like(kin_chain.link_lengths), 0])))
            # R = 10000*np.mat(np.eye(B.shape[1]))

            # fb_ctrl = LQRController(A, B, Q, R)
            # self.assister = FeedbackControllerAssist(fb_ctrl, style='additive')
            self.assister = TentacleAssist(ssm=self.decoder.ssm,
                                           kin_chain=self.plant.kin_chain,
                                           update_rate=self.decoder.binlen)
        else:
            raise NotImplementedError(
                "Cannot assist for this type of statespace: %r" %
                self.decoder.ssm)

        print(self.assister)

    def create_goal_calculator(self):
        from db import namelist
        if self.decoder.ssm == namelist.endpt_2D_state_space:
            self.goal_calculator = goal_calculators.ZeroVelocityGoal(
                self.decoder.ssm)
        elif self.decoder.ssm == namelist.joint_2D_state_space:
            self.goal_calculator = goal_calculators.PlanarMultiLinkJointGoal(
                self.decoder.ssm,
                self.plant.base_loc,
                self.plant.kin_chain,
                multiproc=False,
                init_resp=None)
        elif self.decoder.ssm == namelist.tentacle_2D_state_space:
            shoulder_anchor = self.plant.base_loc
            chain = self.plant.kin_chain
            q_start = self.plant.get_intrinsic_coordinates()
            x_init = np.hstack([q_start, np.zeros_like(q_start), 1])
            x_init = np.mat(x_init).reshape(-1, 1)

            cached = True

            if cached:
                goal_calc_class = goal_calculators.PlanarMultiLinkJointGoalCached
                multiproc = False
            else:
                goal_calc_class = goal_calculators.PlanarMultiLinkJointGoal
                multiproc = True

            self.goal_calculator = goal_calc_class(
                namelist.tentacle_2D_state_space,
                shoulder_anchor,
                chain,
                multiproc=multiproc,
                init_resp=x_init)
        else:
            raise ValueError("Unrecognized decoder state space!")

    def get_target_BMI_state(self, *args):
        '''
        Run the goal calculator to determine the target state of the task
        '''
        if isinstance(self.goal_calculator,
                      goal_calculators.PlanarMultiLinkJointGoalCached):
            task_eps = np.inf
        else:
            task_eps = 0.5
        ik_eps = task_eps / 10
        data, solution_updated = self.goal_calculator(
            self.target_location,
            verbose=False,
            n_particles=500,
            eps=ik_eps,
            n_iter=10,
            q_start=self.plant.get_intrinsic_coordinates())
        target_state, error = data

        if isinstance(self.goal_calculator,
                      goal_calculators.PlanarMultiLinkJointGoal
                      ) and error > task_eps and solution_updated:
            self.goal_calculator.reset()

        return np.array(target_state).reshape(-1, 1)

    def _end_timeout_penalty(self):
        if self.reset:
            self.decoder.filt.state.mean = self.init_decoder_mean
            self.hdf.sendMsg("reset")

Exemplo n.º 18
0
class ScreenTargetCapture(TargetCapture, Window):
    """Concrete implementation of TargetCapture task where targets
    are acquired by "holding" a cursor in an on-screen target"""

    limit2d = 1

    sequence_generators = [
        'out_2D',
        'centerout_2D',
        'centeroutback_2D',
        'rand_target_chain_2D',
        'rand_target_chain_3D',
    ]

    hidden_traits = [
        'cursor_color', 'target_color', 'cursor_bounds', 'cursor_radius',
        'plant_hide_rate', 'starting_pos'
    ]

    is_bmi_seed = True

    # Runtime settable traits
    target_radius = traits.Float(2, desc="Radius of targets in cm")
    target_color = traits.OptionsList("yellow",
                                      *target_colors,
                                      desc="Color of the target",
                                      bmi3d_input_options=list(
                                          target_colors.keys()))
    plant_hide_rate = traits.Float(
        0.0,
        desc=
        'If the plant is visible, specifies a percentage of trials where it will be hidden'
    )
    plant_type = traits.OptionsList(*plantlist,
                                    bmi3d_input_options=list(plantlist.keys()))
    plant_visible = traits.Bool(
        True,
        desc='Specifies whether entire plant is displayed or just endpoint')
    cursor_radius = traits.Float(.5, desc='Radius of cursor in cm')
    cursor_color = traits.OptionsList("pink",
                                      *target_colors,
                                      desc='Color of cursor endpoint',
                                      bmi3d_input_options=list(
                                          target_colors.keys()))
    cursor_bounds = traits.Tuple(
        (-10., 10., 0., 0., -10., 10.),
        desc='(x min, x max, y min, y max, z min, z max)')
    starting_pos = traits.Tuple((5., 0., 5.),
                                desc='Where to initialize the cursor')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Initialize the plant
        if not hasattr(self, 'plant'):
            self.plant = plantlist[self.plant_type]
        self.plant.set_endpoint_pos(np.array(self.starting_pos))
        self.plant.set_bounds(np.array(self.cursor_bounds))
        self.plant.set_color(target_colors[self.cursor_color])
        self.plant.set_cursor_radius(self.cursor_radius)
        self.plant_vis_prev = True
        self.cursor_vis_prev = True

        # Add graphics models for the plant and targets to the window
        if hasattr(self.plant, 'graphics_models'):
            for model in self.plant.graphics_models:
                self.add_model(model)

        # Instantiate the targets
        instantiate_targets = kwargs.pop('instantiate_targets', True)
        if instantiate_targets:

            # Need two targets to have the ability for delayed holds
            target1 = VirtualCircularTarget(
                target_radius=self.target_radius,
                target_color=target_colors[self.target_color])
            target2 = VirtualCircularTarget(
                target_radius=self.target_radius,
                target_color=target_colors[self.target_color])

            self.targets = [target1, target2]

        # Declare any plant attributes which must be saved to the HDF file at the _cycle rate
        for attr in self.plant.hdf_attrs:
            self.add_dtype(*attr)

    def init(self):
        self.add_dtype('trial', 'u4', (1, ))
        self.add_dtype('plant_visible', '?', (1, ))
        super().init()

    def _cycle(self):
        '''
        Calls any update functions necessary and redraws screen
        '''
        self.move_effector()

        ## Run graphics commands to show/hide the plant if the visibility has changed
        self.update_plant_visibility()
        self.task_data['plant_visible'] = self.plant_visible

        ## Save plant status to HDF file
        plant_data = self.plant.get_data_to_save()
        for key in plant_data:
            self.task_data[key] = plant_data[key]

        # Update the trial index
        self.task_data['trial'] = self.calc_trial_num()

        super()._cycle()

    def move_effector(self):
        '''Move the end effector, if a robot or similar is being controlled'''
        pass

    def run(self):
        '''
        See experiment.Experiment.run for documentation.
        '''
        # Fire up the plant. For virtual/simulation plants, this does little/nothing.
        self.plant.start()

        # Include some cleanup in case the parent class has errors
        try:
            super().run()
        finally:
            self.plant.stop()

    ##### HELPER AND UPDATE FUNCTIONS ####
    def update_plant_visibility(self):
        ''' Update plant visibility'''
        if self.plant_visible != self.plant_vis_prev:
            self.plant_vis_prev = self.plant_visible
            self.plant.set_visibility(self.plant_visible)

    #### TEST FUNCTIONS ####
    def _test_enter_target(self, ts):
        '''
        Return true if the cursor is entirely inside the target, i.e. the
        distance between the cursor center and the target center is at most
        (target radius - cursor radius).
        '''
        cursor_pos = self.plant.get_endpoint_pos()
        d = np.linalg.norm(cursor_pos - self.targs[self.target_index])
        return d <= (self.target_radius - self.cursor_radius)

    def _test_leave_target(self, ts):
        '''
        Return true if the cursor is no longer entirely inside the target,
        i.e. the distance to the target center exceeds (target radius - cursor radius).
        '''
        cursor_pos = self.plant.get_endpoint_pos()
        d = np.linalg.norm(cursor_pos - self.targs[self.target_index])
        rad = self.target_radius - self.cursor_radius
        return d > rad

    #### STATE FUNCTIONS ####
    def _start_wait(self):
        super()._start_wait()

        if self.calc_trial_num() == 0:

            # Instantiate the targets here so they don't show up in any states that might come before "wait"
            for target in self.targets:
                for model in target.graphics_models:
                    self.add_model(model)
                    target.hide()

    def _start_target(self):
        super()._start_target()

        # Show target if it is hidden (this is the first target, or previous state was a penalty)
        target = self.targets[self.target_index % 2]
        if self.target_index == 0:
            target.move_to_position(self.targs[self.target_index])
            target.show()
            self.sync_event('TARGET_ON', self.gen_indices[self.target_index])

    def _start_hold(self):
        super()._start_hold()
        self.sync_event('CURSOR_ENTER_TARGET',
                        self.gen_indices[self.target_index])

    def _start_delay(self):
        super()._start_delay()

        # Make next target visible unless this is the final target in the trial
        next_idx = (self.target_index + 1)
        if next_idx < self.chain_length:
            target = self.targets[next_idx % 2]
            target.move_to_position(self.targs[next_idx])
            target.show()
            self.sync_event('TARGET_ON', self.gen_indices[next_idx])
        else:
            # This delay state should only last 1 cycle, don't sync anything
            pass

    def _start_targ_transition(self):
        super()._start_targ_transition()
        if self.target_index == -1:

            # Came from a penalty state
            pass
        elif self.target_index + 1 < self.chain_length:

            # Hide the current target if there are more
            self.targets[self.target_index % 2].hide()
            self.sync_event('TARGET_OFF', self.gen_indices[self.target_index])

    def _start_hold_penalty(self):
        self.sync_event('HOLD_PENALTY')
        super()._start_hold_penalty()
        # Hide targets
        for target in self.targets:
            target.hide()
            target.reset()

    def _end_hold_penalty(self):
        super()._end_hold_penalty()
        self.sync_event('TRIAL_END')

    def _start_delay_penalty(self):
        self.sync_event('DELAY_PENALTY')
        super()._start_delay_penalty()
        # Hide targets
        for target in self.targets:
            target.hide()
            target.reset()

    def _end_delay_penalty(self):
        super()._end_delay_penalty()
        self.sync_event('TRIAL_END')

    def _start_timeout_penalty(self):
        self.sync_event('TIMEOUT_PENALTY')
        super()._start_timeout_penalty()
        # Hide targets
        for target in self.targets:
            target.hide()
            target.reset()

    def _end_timeout_penalty(self):
        super()._end_timeout_penalty()
        self.sync_event('TRIAL_END')

    def _start_reward(self):
        self.targets[self.target_index % 2].cue_trial_end_success()
        self.sync_event('REWARD')

    def _end_reward(self):
        super()._end_reward()
        self.sync_event('TRIAL_END')

        # Hide targets
        for target in self.targets:
            target.hide()
            target.reset()

    #### Generator functions ####
    '''
    Note to self: because of the way these get into the database, the parameters don't
    have human-readable descriptions like the other traits. So it is useful to define
    the descriptions elsewhere, in models.py under Generator.to_json().

    Ideally someone should take the time to reimplement generators as their own classes
    rather than static methods that belong to a task.
    '''

    @staticmethod
    def static(pos=(0, 0, 0), ntrials=0):
        '''Single location, finite (ntrials!=0) or infinite (ntrials==0)'''
        if ntrials == 0:
            while True:
                yield [0], np.array(pos)
        else:
            for _ in range(ntrials):
                yield [0], np.array(pos)

    @staticmethod
    def out_2D(nblocks=100, ntargets=8, distance=10, origin=(0, 0, 0)):
        '''
        Generates a sequence of 2D (x and z) targets at a given distance from the origin

        Parameters
        ----------
        nblocks : int
            The number of blocks; each block presents every target once in random order.
        ntargets : int
            The number of equally spaced targets
        distance : float
            The distance in cm between the center and peripheral targets.
        origin : 3-tuple
            Location of the center around which the peripheral targets are arranged

        Returns
        -------
        [nblocks*ntargets x 1] array of tuples containing trial indices and [1 x 3] target coordinates

        '''
        rng = np.random.default_rng()
        for _ in range(nblocks):
            order = np.arange(ntargets) + 1  # target indices, starting from 1
            rng.shuffle(order)
            for t in range(ntargets):
                idx = order[t]
                theta = 2 * np.pi * idx / ntargets
                pos = np.array(
                    [distance * np.cos(theta), 0, distance * np.sin(theta)]).T
                yield [idx], [pos + origin]
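    # Usage sketch (added): the generator yields ([index], [position]) tuples,
    # with indices running 1..ntargets and each block a fresh permutation:
    #
    #     gen = ScreenTargetCapture.out_2D(nblocks=1, ntargets=4, distance=10)
    #     [idx for idx, pos in gen]    # e.g. [[3], [1], [4], [2]]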

    @staticmethod
    def centerout_2D(nblocks=100, ntargets=8, distance=10, origin=(0, 0, 0)):
        '''
        Pairs of central targets at the origin and peripheral targets centered around the origin

        Returns
        -------
        [nblocks*ntargets x 1] array of tuples containing trial indices and [2 x 3] target coordinates
        '''
        gen = ScreenTargetCapture.out_2D(nblocks, ntargets, distance, origin)
        for _ in range(nblocks * ntargets):
            idx, pos = next(gen)
            targs = np.zeros([2, 3]) + origin
            targs[1, :] = pos[0]
            indices = np.zeros([2, 1])
            indices[1] = idx
            yield indices, targs

    @staticmethod
    def centeroutback_2D(nblocks=100,
                         ntargets=8,
                         distance=10,
                         origin=(0, 0, 0)):
        '''
        Triplets of central targets, peripheral targets, and central targets

        Returns
        -------
        [nblocks*ntargets x 1] array of tuples containing trial indices and [3 x 3] target coordinates
        '''
        gen = ScreenTargetCapture.out_2D(nblocks, ntargets, distance, origin)
        for _ in range(nblocks * ntargets):
            idx, pos = next(gen)
            targs = np.zeros([3, 3]) + origin
            targs[1, :] = pos[0]
            indices = np.zeros([3, 1])
            indices[1] = idx
            yield indices, targs

    @staticmethod
    def rand_target_chain_2D(ntrials=100,
                             chain_length=1,
                             boundaries=(-12, 12, -12, 12)):
        '''
        Generates a sequence of 2D (x and z) target pairs.

        Parameters
        ----------
        ntrials : int
            The number of target chains in the sequence.
        chain_length : int
            The number of targets in each chain
        boundaries: 4 element Tuple
            The limits of the allowed target locations (-x, x, -z, z)

        Returns
        -------
        [ntrials x chain_length x 3] array of target coordinates
        '''
        rng = np.random.default_rng()
        idx = 0
        for t in range(ntrials):

            # Choose a random sequence of points within the boundaries
            pts = rng.uniform(size=(chain_length, 3)) * (
                (boundaries[1] - boundaries[0]), 0,
                (boundaries[3] - boundaries[2]))
            pts = pts + (boundaries[0], 0, boundaries[2])
            yield idx + np.arange(chain_length), pts
            idx += chain_length

    @staticmethod
    def rand_target_chain_3D(ntrials=100,
                             chain_length=1,
                             boundaries=(-12, 12, -10, 10, -12, 12)):
        '''
        Generates a sequence of 3D target pairs.
        Parameters
        ----------
        ntrials : int
            The number of target chains in the sequence.
        chain_length : int
            The number of targets in each chain
        boundaries: 6 element Tuple
            The limits of the allowed target locations (-x, x, -y, y, -z, z)

        Returns
        -------
        [ntrials x chain_length x 3] array of target coordinates
        '''
        rng = np.random.default_rng()
        idx = 0
        for t in range(ntrials):

            # Choose a random sequence of points within the boundaries
            pts = rng.uniform(size=(chain_length, 3)) * (
                (boundaries[1] - boundaries[0]),
                (boundaries[3] - boundaries[2]),
                (boundaries[5] - boundaries[4]))
            pts = pts + (boundaries[0], boundaries[2], boundaries[4])
            yield idx + np.arange(chain_length), pts
            idx += chain_length