コード例 #1
0
class RedGreen(Sequence, Pygame):
    '''
    Two-colour blinking-dot task.

    The state machine cycles wait -> pretrial -> trial -> reward/penalty -> wait.
    During 'pretrial' a blinking dot in the first colour shrinks towards
    ``dot_radius``; during 'trial' a blinking dot in the second colour is shown
    at ``dot_radius``. An event before the colour switch is 'premature'
    (penalty); an event during 'trial' is 'correct' (reward).

    NOTE(review): ``init_dot``, ``shrinklen`` and ``freq`` are module-level
    names defined outside this class; ``freq`` is compared against milliseconds
    and ``shrinklen`` against seconds.
    '''
    status = dict(
        wait = dict(start_trial="pretrial", premature="penalty", stop=None),
        pretrial = dict(go="trial", premature="penalty"),
        trial = dict(correct="reward", timeout="penalty"),
        reward = dict(post_reward="wait"),
        penalty = dict(post_penalty="wait"),
    )

    colors = traits.Array(shape=(2, 3), value=[[255,0,0],[0,255,0]],
        desc="Tuple of colors (c1, c2) where c* = [r,g,b] between 0 and 255")
    dot_radius = traits.Int(100, desc='dot size')
    delay_range = traits.Tuple((0.5, 5.), 
        desc='delay before switching to second color will be drawn from uniform distribution in this range')

    def _draw_blinking_dots(self, color, radius, ts):
        '''
        Clear the surface, draw the pair of dots if the blink phase is "on", and flip.

        Parameters
        ----------
        color : sequence of 3 ints
            RGB colour of both dots.
        radius : int
            Radius of both dots in pixels.
        ts : float
            Seconds elapsed since ``self.start_time``; selects the blink phase.
        '''
        import pygame
        self.surf.fill(self.background)
        # Second copy of the dot, shifted 1920 px horizontally
        # (presumably the second screen of a side-by-side display -- TODO confirm).
        right = [self.next_trial[0] + 1920, self.next_trial[1]]
        # Dots are visible only during the first half of each `freq`-ms period.
        if (np.mod(np.round(ts*1000),freq) < freq/2):
            pygame.draw.circle(self.surf, color, self.next_trial, radius)
            pygame.draw.circle(self.surf, color, right, radius)
        self.flip_wait()

    def _while_pretrial(self):
        '''Draw the first-colour dot, shrinking to ``dot_radius`` once ``shrinklen`` s have elapsed.'''
        ts = time.time() - self.start_time
        # Linear shrink: the extra size goes to zero when ts reaches shrinklen.
        dotsize = (init_dot - self.dot_radius) * (shrinklen - min(ts, shrinklen)) + self.dot_radius
        self._draw_blinking_dots(self.colors[0], int(dotsize), ts)
    
    def _while_trial(self):
        '''Draw the second-colour dot at its final radius.'''
        ts = time.time() - self.start_time
        self._draw_blinking_dots(self.colors[1], self.dot_radius, ts)
    
    def _start_pretrial(self):
        '''Draw the delay before the colour switch uniformly from ``delay_range``.'''
        # Sort the bounds so a reversed range still samples inside the range
        # (previously the offset used delay_range[0] even when it was the larger bound).
        lo, hi = sorted(self.delay_range)
        self._wait_time = lo + np.random.rand()*(hi - lo)
    
    def _test_correct(self, ts):
        '''Return True if an event was registered during the trial state.'''
        return self.event is not None
    
    def _test_go(self, ts):
        '''Return True once the shrink period plus the random delay have elapsed.'''
        return ts > self._wait_time + shrinklen

    def _test_premature(self, ts):
        '''Return True if an event was registered before the colour switch.'''
        return self.event is not None
コード例 #2
0
class SimulatedEyeData(EyeData):
    '''Simulate an eyetracking system using a series of fixations, with saccades interpolated'''
    fixations = traits.Array(value=[(0, 0), (-0.6, 0.3), (0.6, 0.3)],
                             desc="Location of fixation points")
    fixation_len = traits.Float(0.5, desc="Length of a fixation")

    @property
    def eye_source(self):
        '''
        Source class and constructor arguments for the simulated eyetracker.

        Returns
        -------
        tuple
            ``(eyetracker.Simulate, kwargs)`` where ``kwargs`` carries this
            task's ``fixations`` and ``fixation_len`` trait values.
        '''
        from riglib import eyetracker
        # The traits must be read through `self`: class-body names are not in
        # scope inside a method, so the bare names raised NameError.
        return eyetracker.Simulate, dict(fixations=self.fixations,
                                         fixation_len=self.fixation_len)
コード例 #3
0
class SimulatedEyeData(EyeData):
    '''Simulate an eyetracking system using a series of fixations, with saccades interpolated'''
    fixations = traits.Array(value=[(0, 0), (-0.6, 0.3), (0.6, 0.3)],
                             desc="Location of fixation points")
    fixation_len = traits.Float(0.5, desc="Length of a fixation")

    @property
    def eye_source(self):
        '''
        Source class and constructor arguments for the simulated eyetracker.

        Returns
        -------
        tuple
            ``(eyetracker.Simulate, kwargs)`` where ``kwargs`` holds the
            ``fixations`` trait value.
        '''
        from riglib import eyetracker
        # NOTE(review): `fixation_len` is declared above but not forwarded here;
        # confirm whether eyetracker.Simulate should receive it.
        source_kwargs = dict(fixations=self.fixations)
        return eyetracker.Simulate, source_kwargs

    def _cycle(self):
        '''
        Pull the latest eye data, forward it to the sinks, then run the parent cycle.

        Nothing is sent when ``self.eyedata.get()`` returns None.
        '''
        sample = self.eyedata.get()

        # Only forward real samples; None means nothing was available.
        if sample is not None:
            self.sinks.send(self.eyedata.name, sample)

        super()._cycle()
コード例 #4
0
class ManualControlMixin(traits.HasTraits):
    '''Target capture task where the subject operates a joystick
    to control a cursor. Targets are captured by having the cursor
    dwell in the screen target for the allotted time'''

    # Settable Traits
    wait_time = traits.Float(2., desc="Time between successful trials")
    velocity_control = traits.Bool(False, desc="Position or velocity control")
    random_rewards = traits.Bool(False, desc="Add randomness to reward")
    rotation = traits.OptionsList(*rotations, desc="Control rotation matrix", bmi3d_input_options=list(rotations.keys()))
    scale = traits.Float(1.0, desc="Control scale factor")
    offset = traits.Array(value=[0,0,0], desc="Control offset")
    is_bmi_seed = True

    def __init__(self, *args, **kwargs):
        '''
        Initialise cursor bookkeeping, the input-quality report entry and the
        random-reward state.
        '''
        super().__init__(*args, **kwargs)
        self.current_pt=np.zeros([3]) #keep track of current pt
        self.last_pt=np.zeros([3]) #keep track of last pt to calc. velocity
        self.no_data_count = 0
        self.reportstats['Input quality'] = "100 %"
        # Must exist before _test_trial_complete reads it; it was previously
        # never initialised here, raising AttributeError with random_rewards on.
        self.rand_reward_set_flag = 0
        if self.random_rewards:
            self.reward_time_base = self.reward_time

    def init(self):
        '''Register the 'manual_input' field (3 float64 values) before the parent init.'''
        self.add_dtype('manual_input', 'f8', (3,))
        super().init()

    def _test_start_trial(self, ts):
        '''Return True once ``wait_time`` has elapsed and the task is not paused.'''
        return ts > self.wait_time and not self.pause

    def _test_trial_complete(self, ts):
        '''
        Return True when the current target is the last one in the chain.

        With ``random_rewards`` enabled, the first call on the last target also
        draws a randomised ``reward_time`` for the upcoming reward.
        '''
        is_last_target = self.target_index == self.chain_length - 1
        if is_last_target and self.random_rewards and not self.rand_reward_set_flag:
            # Base time +/- 1 s (uniform), but never below half the base time.
            self.reward_time = np.max([2*(np.random.rand()-0.5) + self.reward_time_base, self.reward_time_base/2]) #set randomly with min of base / 2
            self.rand_reward_set_flag = 1
        return is_last_target

    def _test_reward_end(self, ts):
        '''
        Return True once ``reward_time`` has elapsed.

        With ``random_rewards`` enabled, this also clears the flag so a new
        reward time is drawn on the next completed trial.
        '''
        reward_over = ts > self.reward_time
        if self.random_rewards and reward_over:
            self.rand_reward_set_flag = 0
        return reward_over

    def _transform_coords(self, coords):
        ''' 
        Returns transformed coordinates based on rotation, offset, and scale traits

        The input is flattened, extended to a homogeneous row vector and
        multiplied by the offset, scale and rotation matrices in that order.
        Returns the first three components of the result.
        '''
        # Translation lives in the last row because the point is a row vector.
        offset = np.array(
            [[1, 0, 0, 0], 
            [0, 1, 0, 0], 
            [0, 0, 1, 0], 
            [self.offset[0], self.offset[1], self.offset[2], 1]]
        )
        scale = np.array(
            [[self.scale, 0, 0, 0], 
            [0, self.scale, 0, 0], 
            [0, 0, self.scale, 0], 
            [0, 0, 0, 1]]
        )
        old = np.concatenate((np.reshape(coords, -1), [1]))
        new = np.linalg.multi_dot((old, offset, scale, rotations[self.rotation]))
        return new[0:3]

    def _get_manual_position(self):
        '''
        Fetches joystick position

        Returns a one-element list holding the latest sample (2D samples are
        padded with a zero third coordinate), or None when there is no
        ``joystick`` attribute or no new samples.
        '''
        if not hasattr(self, 'joystick'):
            return
        pt = self.joystick.get()
        if len(pt) == 0:
            return

        pt = pt[-1] # Use only the latest coordinate

        if len(pt) == 2:
            pt = np.concatenate((np.reshape(pt, -1), [0]))

        return [pt]

    def move_effector(self):
        ''' 
        Sets the 3D coordinates of the cursor. For manual control, uses
        motiontracker / joystick / mouse data. If no data available, returns None
        '''

        # Get raw input and save it as task data
        raw_coords = self._get_manual_position() # array of [3x1] arrays
        if raw_coords is None or len(raw_coords) < 1:
            self.no_data_count += 1
            self.update_report_stats()
            # Record NaN for missing input; np.empty would store whatever
            # uninitialised memory happened to contain.
            self.task_data['manual_input'] = np.full((3,), np.nan)
            return

        self.task_data['manual_input'] = raw_coords.copy()

        # Transform coordinates
        coords = self._transform_coords(raw_coords)
        if self.limit2d:
            coords[1] = 0

        # Set cursor position
        if not self.velocity_control:
            self.current_pt = coords
        else:
            epsilon = 2*(10**-2) # Define epsilon to stabilize cursor movement
            if sum((coords)**2) > epsilon:

                # Add the velocity (units/s) to the position (units)
                self.current_pt = coords / self.fps + self.last_pt
            else:
                self.current_pt = self.last_pt

        self.plant.set_endpoint_pos(self.current_pt)
        self.last_pt = self.current_pt.copy()

    def update_report_stats(self):
        '''Update the 'Input quality' entry: percentage of cycles that had input data.'''
        super().update_report_stats()
        # max(1, ...) avoids dividing by zero before the first cycle.
        quality = 1 - self.no_data_count / max(1, self.cycle_count)
        self.reportstats['Input quality'] = "{} %".format(int(100*quality))

    @classmethod
    def get_desc(cls, params, log_summary):
        '''
        Build a one-line summary from ``log_summary``.

        Reads the 'runtime' (converted from seconds to minutes),
        'n_success_trials' and 'n_trials' entries; ``params`` is unused.
        '''
        duration = round(log_summary['runtime'] / 60, 1)
        return "{}/{} succesful trials in {} min".format(
            log_summary['n_success_trials'], log_summary['n_trials'], duration)
コード例 #5
0
class Optitrack(traits.HasTraits):
    '''
    Enable reading of raw motiontracker data from Optitrack system
    Requires the natnet library from https://github.com/leoscholl/python_natnet
    To be used as a feature with the ManualControl task for the time being. However,
    ideally this would be implemented as a decoder :)
    '''

    optitrack_feature = traits.OptionsList(("rigid body", "skeleton", "marker"))
    smooth_features = traits.Int(1, desc="How many features to average")
    scale = traits.Float(DEFAULT_SCALE, desc="Control scale factor")
    offset = traits.Array(value=DEFAULT_OFFSET, desc="Control offset")

    hidden_traits = ['optitrack_feature', 'smooth_features']

    def init(self):
        '''
        Secondary init function. See riglib.experiment.Experiment.init()
        Prior to starting the task, this 'init' sets up the DataSource for interacting with the 
        motion tracker system and registers the source with the SinkRegister so that the data gets saved to file as it is collected.

        Sets ``self.optitrack_status`` to 'recording', 'streaming' or an error
        message; ``run`` only starts the task for the first two.
        '''

        # Start the natnet client and recording
        import natnet
        now = datetime.now()
        local_path = "C:/Users/Orsborn Lab/Documents"
        session = "OptiTrack/Session " + now.strftime("%Y-%m-%d")
        take = now.strftime("Take %Y-%m-%d %H:%M:%S")
        logger = Logger(take)
        try:
            client = natnet.Client.connect(logger=logger)
            if self.saveid is not None:
                # A save id means this run is stored, so record a take named after it.
                take += " (%d)" % self.saveid
                client.set_session(os.path.join(local_path, session))
                client.set_take(take)
                self.filename = os.path.join(session, take + '.tak')
                client._send_command_and_wait("LiveMode")
                time.sleep(0.1)
                if client.start_recording():
                    self.optitrack_status = 'recording'
                else:
                    # Previously the status stayed unset here, so run() failed
                    # with AttributeError instead of reporting the problem.
                    self.optitrack_status = 'Optitrack recording couldn\'t be started'
            else:
                self.optitrack_status = 'streaming'
        except natnet.DiscoveryError:
            self.optitrack_status = 'Optitrack couldn\'t be started, make sure Motive is open!'
            client = optitrack.SimulatedClient()
        self.client = client

        # Create a source to buffer the motion tracking data
        from riglib import source
        self.motiondata = source.DataSource(optitrack.make(optitrack.System, self.client, self.optitrack_feature, 1))

        # Save to the sink
        from riglib import sink
        sink_manager = sink.SinkManager.get_instance()
        sink_manager.register(self.motiondata)
        super().init()

    def run(self):
        '''
        Code to execute immediately prior to the beginning of the task FSM executing, or after the FSM has finished running. 
        See riglib.experiment.Experiment.run(). This 'run' method starts the motiondata source and stops it after the FSM has finished running

        If ``init`` did not reach the 'recording' or 'streaming' status, the
        status message is stored as the termination error and the task is run
        with ``state`` set to None.
        '''
        if self.optitrack_status not in ['recording', 'streaming']:
            import io
            self.terminated_in_error = True
            self.termination_err = io.StringIO()
            self.termination_err.write(self.optitrack_status)
            self.termination_err.seek(0)
            self.state = None
            super().run()
        else:
            self.motiondata.start()
            try:
                super().run()
            finally:
                # Always stop recording and the data source, even if the task raised.
                print("Stopping optitrack")
                self.client.stop_recording()
                self.motiondata.stop()

    def _start_None(self):
        '''
        Code to run before the 'None' state starts (i.e., the task stops)
        '''
        self.motiondata.stop()
        super()._start_None()

    def join(self):
        '''
        See riglib.experiment.Experiment.join(). Re-join the motiondata source process before cleaning up the experiment thread
        '''
        print("Joining optitrack datasource")
        self.motiondata.join()
        super().join()

    def cleanup(self, database, saveid, **kwargs):
        '''
        Save the optitrack recorded file into the database

        Returns the parent ``cleanup`` result on success, or False if saving
        the file raised (the exception is printed, not re-raised).
        '''
        super_result = super().cleanup(database, saveid, **kwargs)
        print("Saving optitrack file to database...")
        try:
            database.save_data(self.filename, "optitrack", saveid, False, False) # Make sure you actually have an "optitrack" system added!
        except Exception as e:
            # Best effort: report the failure to the caller via the return value.
            print(e)
            return False
        print("...done.")
        return super_result

    def _get_manual_position(self):
        ''' Overridden method to get input coordinates based on motion data

        Returns the NaN-aware mean of the most recent ``smooth_features``
        samples multiplied by 100, or None when no data arrived or the mean
        still contains NaN.
        '''

        # Get data from optitrack datasource
        data = self.motiondata.get() # List of (list of features)
        if len(data) == 0: # Data is not being streamed
            return
        recent = data[-self.smooth_features:] # How many recent coordinates to average
        averaged = np.nanmean(recent, axis=0) # List of averaged features
        if np.isnan(averaged).any(): # No usable coords
            return
        return averaged*100 # convert meters to centimeters