예제 #1
0
 def run(self):
     mod = Dragonfly_Module(self.mid, 0)
     mod.ConnectToMMM(self.server)
     for sub in self.subs:
         print "subscribing to %s" % (sub)
         mod.Subscribe(sub)
     mod.SendModuleReady()
     while (self.status()):
         msg = CMessage()
         rcv = mod.ReadMessage(msg, 0)
         if rcv == 1:
             if msg.GetHeader().msg_type in self.subs:
                 self.recv_msg(msg)
         sleep(.001)
예제 #2
0
class SimpleArbitrator(object):
    debug = True
    vel = np.zeros(rc.MAX_CONTROL_DIMS)
    #pos = np.zeros(rc.MAX_CONTROL_DIMS)
    autoVelControlFraction = \
        np.ones_like(rc.MDF_ROBOT_CONTROL_CONFIG().autoVelControlFraction)
    extrinsic_vel = np.zeros_like(rc.MDF_COMPOSITE_MOVEMENT_COMMAND().vel)
    intrinsic_vel = np.zeros_like(rc.MDF_COMPOSITE_MOVEMENT_COMMAND().vel)

    def __init__(self, config_file, server):
        self.load_config(config_file)
        self.setup_dragonfly(server)
        self.run()

    def load_config(self, config_file):
        self.config = SafeConfigParser()
        self.config.read(config_file)
        self.timer_tag = self.config.get('main', 'timer_tag')
        self.extrinsic_tags = self.config.get('main', 'extrinsic_tags').split()
        self.intrinsic_tags = self.config.get('main', 'intrinsic_tags').split()
        default_auto = float(self.config.get('main', 'default_auto'))
        self.autoVelControlFraction[:] = default_auto
        self.gate = 1.  # default value
        self.idle_gateable = 0.  # default value

    def setup_dragonfly(self, server):
        self.mod = Dragonfly_Module(rc.MID_SIMPLE_ARBITRATOR, 0)
        self.mod.ConnectToMMM(server)
        self.mod.Subscribe(MT_EXIT)
        for sub in subscriptions:
            self.mod.Subscribe(eval('rc.MT_%s' % (sub)))
        self.mod.SendModuleReady()
        print "Connected to Dragonfly at", server

    def run(self):
        while True:
            msg = CMessage()
            rcv = self.mod.ReadMessage(msg, 0.1)
            if rcv == 1:
                msg_type = msg.GetHeader().msg_type
                dest_mod_id = msg.GetHeader().dest_mod_id
                if msg_type == MT_EXIT:
                    if (dest_mod_id == 0) or (dest_mod_id
                                              == self.mod.GetModuleID()):
                        print 'Received MT_EXIT, disconnecting...'
                        self.mod.SendSignal(rc.MT_EXIT_ACK)
                        self.mod.DisconnectFromMMM()
                        break
                elif msg_type == rc.MT_PING:
                    respond_to_ping(self.mod, msg, 'SimpleArbitrator')
                else:
                    self.process_message(msg)

    def process_message(self, msg):
        '''
        Needs to:
        1) combine non-conflicting controlledDims e.g. from
        OPERATOR_MOVEMENT_COMMANDs, into either extrinsic or
        intrinsic commands
        2) combine intrinsic and extrinsic commands into final command
        '''
        msg_type = msg.GetHeader().msg_type
        if msg_type in [
                rc.MT_OPERATOR_MOVEMENT_COMMAND,
                rc.MT_PLANNER_MOVEMENT_COMMAND, rc.MT_EM_MOVEMENT_COMMAND,
                rc.MT_FIXTURED_MOVEMENT_COMMAND
        ]:

            if msg_type == rc.MT_OPERATOR_MOVEMENT_COMMAND:
                mdf = rc.MDF_OPERATOR_MOVEMENT_COMMAND()

            elif msg_type == rc.MT_PLANNER_MOVEMENT_COMMAND:
                mdf = rc.MDF_PLANNER_MOVEMENT_COMMAND()

            elif msg_type == rc.MT_EM_MOVEMENT_COMMAND:
                mdf = rc.MDF_EM_MOVEMENT_COMMAND()

            elif msg_type == rc.MT_FIXTURED_MOVEMENT_COMMAND:
                mdf = rc.MDF_FIXTURED_MOVEMENT_COMMAND()

            # MOVEMENT_COMMAND
            # ----------------
            # controlledDims
            # pos
            # sample_header
            # sample_interval
            # tag
            # vel
            # ----------------

            copy_from_msg(mdf, msg)
            tag = mdf.tag
            #if not tag in self.accepted_tags:
            #    return
            dim = np.asarray(mdf.controlledDims, dtype=bool)  #.astype(bool)
            if mdf.tag in self.intrinsic_tags:
                # intrinsic is AUTO command
                self.intrinsic_vel[dim] = np.asarray(mdf.vel, dtype=float)[dim]
                #print "intr_vel = " + " ".join(["%5.2f" % (x) for x in self.intrinsic_vel])
            elif mdf.tag in self.extrinsic_tags:
                #print "!"
                # extrinsic is non-AUTO, i.e. EM, command
                self.extrinsic_vel[dim] = np.asarray(mdf.vel, dtype=float)[dim]
                #self.extrinsic_vel[:8] *= self.gate

            if tag == self.timer_tag:
                self.send_output(mdf.sample_header)
        elif msg_type == rc.MT_ROBOT_CONTROL_CONFIG:
            mdf = rc.MDF_ROBOT_CONTROL_CONFIG()
            copy_from_msg(mdf, msg)
            self.autoVelControlFraction[:] = mdf.autoVelControlFraction
        elif msg_type == rc.MT_IDLE:
            mdf = rc.MDF_IDLE()
            copy_from_msg(mdf, msg)
            self.gate = float(np.asarray(mdf.gain, dtype=float).item())
        elif msg_type == rc.MT_IDLE_DETECTION_ENDED:
            self.gate = 1.0
        elif msg_type == rc.MT_TASK_STATE_CONFIG:
            mdf = rc.MDF_TASK_STATE_CONFIG()
            copy_from_msg(mdf, msg)
            self.idle_gateable = mdf.idle_gateable

    def get_combined_command(self):
        C = 1 - self.autoVelControlFraction  # extrinsic fraction
        d = self.intrinsic_vel
        u = self.extrinsic_vel
        combined = C * u + (1 - C) * d
        print "--------------------------------------"
        print "C" + " ".join(["%0.2f" % (x) for x in C])
        print "d" + " ".join(["%0.2f" % (x) for x in d])
        print "u" + " ".join(["%0.2f" % (x) for x in u])
        print "+" + " ".join(["%0.2f" % (x) for x in combined])
        print "gain: ", self.gate
        print "gateable: ", self.idle_gateable
        return combined

    def send_output(self, sample_header):
        mdf = rc.MDF_COMPOSITE_MOVEMENT_COMMAND()
        mdf.tag = 'composite'
        vel = np.zeros_like(mdf.vel)
        vel[:] = self.get_combined_command()
        if self.idle_gateable == 1:
            vel[:8] *= self.gate
        mdf.vel[:] = vel
        mdf.sample_header = sample_header
        msg = CMessage(rc.MT_COMPOSITE_MOVEMENT_COMMAND)
        copy_to_msg(mdf, msg)
        self.mod.SendMessage(msg)
예제 #3
0
class RandomGen(object):
    def __init__(self, config_file, mm_ip):
        daq_config = self.load_config(config_file)
        self.setup_daq(daq_config)
        self.setup_dragonfly(mm_ip)
        self.serial_no = 2
        self.variable = 0  # 0 and 1 cause problems for LogReader
        self.run()

    def load_config(self, config_file):
        cfg = SafeConfigParser()
        cfg.read(config_file)
        daq_config = Config()
        daq_config.minV = cfg.getfloat('main', 'minV')
        daq_config.maxV = cfg.getfloat('main', 'maxV')
        daq_config.nsamp = cfg.getint('main', 'nsamp_per_chan_per_second')
        daq_config.nchan = cfg.getint('main', 'nchan')
        daq_config.nirq = self.freq = cfg.getint('main', 'nirq_per_second')
        return daq_config

    def setup_daq(self, daq_config):
        self.daq_task = DAQInterface(self, daq_config)
        self.daq_task.register_callback(self.on_daq_callback)
        print "DrAQonfly: DAQ configured"

    def setup_dragonfly(self, mm_ip):
        self.mod = Dragonfly_Module(0, 0)
        self.mod.ConnectToMMM(mm_ip)
        self.mod.Subscribe(MT_EXIT)
        self.mod.Subscribe(rc.MT_PING)
        self.mod.SendModuleReady()
        print "DrAQonfly: connected to dragonfly"

    def on_daq_callback(self, data):
        mdf = rc.MDF_PLOT_POSITION()
        self.serial_no += 1
        mdf.tool_id = 0
        mdf.missing = 0
        self.variable += 1
        mdf.xyz[:] = np.array([self.variable] * 3)
        mdf.ori[:] = np.array(
            [self.variable] * 4
        )  # will work but need!!! reading modules to know the format of buffer
        #mdf.buffer[data.size:] = -1
        msg = CMessage(rc.MT_PLOT_POSITION)
        copy_to_msg(mdf, msg)
        self.mod.SendMessage(msg)
        print self.variable
        sys.stdout.write('|')
        sys.stdout.flush()

        # now check for exit message
        in_msg = CMessage()
        rcv = self.mod.ReadMessage(msg, 0)
        if rcv == 1:
            hdr = msg.GetHeader()
            msg_type = hdr.msg_type
            dest_mod_id = hdr.dest_mod_id
            if msg_type == MT_EXIT:
                if (dest_mod_id == 0) or (dest_mod_id
                                          == self.mod.GetModuleID()):
                    print "Received MT_EXIT, disconnecting..."
                    self.daq_task.StopTask()
                    self.mod.SendSignal(rc.MT_EXIT_ACK)
                    self.mod.DisconnectFromMMM()
                    self.stop()
            elif msg_type == rc.MT_PING:
                respond_to_ping(self.mod, msg, 'RandomGen')

    def run(self):
        self.daq_task.StartTask()
        print "!"
        while True:
            pass

    def stop(self):
        self.daq_task.StopTask()
        self.daq_task.ClearTask()
예제 #4
0
class PlotHead(threading.Thread):

    def __init__(self, parent, config_file, server):#, parent):
        #HasTraits.__init__(self)
        threading.Thread.__init__(self)
        self.daemon = True
        self.count = 0
        self.parent = parent
        self.plot_vertex_vec = np.array([3,-2,2])
        self.load_config(config_file)
        self.setup_dragonfly(server)
        self.start()
    
    def load_config(self, config_file):
        cfg = PyFileConfigLoader(config_file)
        cfg.load_config()
        self.config = cfg.config
        self.filename = self.config.head_model
        
    def process_message(self, msg):
        # read a Dragonfly message
        msg_type = msg.GetHeader().msg_type
        dest_mod_id = msg.GetHeader().dest_mod_id
        if  msg_type == MT_EXIT:
            if (dest_mod_id == 0) or (dest_mod_id == self.mod.GetModuleID()):
                print 'Received MT_EXIT, disconnecting...'
                self.mod.SendSignal(rc.MT_EXIT_ACK)
                self.mod.DisconnectFromMMM()
                return
        elif msg_type == rc.MT_PING:
            respond_to_ping(self.mod, msg, 'PlotHead')
        elif msg_type == rc.MT_PLOT_POSITION:
            in_mdf = rc.MDF_PLOT_POSITION()
            copy_from_msg(in_mdf, msg)
            tail = np.array(in_mdf.xyz[:])*0.127 + (self.plot_vertex_vec)#Hotspot position
            head = np.array(in_mdf.ori[:3])/4 #Vector head of coil, used to find ori
                 
            if np.any(np.isnan(tail)) == True:
                pass
            elif np.any(np.isnan(head)) == True:
                 pass
            elif np.any(np.isinf(tail)) == True:
                pass
            elif np.any(np.isinf(head)) == True:
                pass
            else:
                queue.put(np.vstack((head, tail)))
                self.count=+1
                print 'sent message'
        elif msg_type == rc.MT_MNOME_STATE:
            in_mdf = rc.MDF_MNOME_STATE()
            copy_from_msg(in_mdf, msg)
            if in_mdf.state == 0:
                print 'got clear'
                self.parent.reset = True
               
                
        
    def setup_dragonfly(self, server):
        subscriptions = [MT_EXIT, \
                         rc.MT_PING, \
                         rc.MT_PLOT_POSITION, \
                         rc.MT_MNOME_STATE]
        self.mod = Dragonfly_Module(0, 0)
        self.mod.ConnectToMMM(server)
        for sub in subscriptions:
            self.mod.Subscribe(sub)
        self.mod.SendModuleReady()
        print "Connected to Dragonfly at ", server
    
   # def timer_event(self, parent):
    def run(self):
        while True:
            msg = CMessage()
            rcv = self.mod.ReadMessage(msg, 0)
            if rcv == 1:
                self.process_message(msg)
예제 #5
0
class RTFT(object):
    def __init__(self, config_file, server):
        self.load_config(config_file)
        self.init_gui()
        self.setup_dragonfly(server)
        self.solo = True  #false if executed from demigod executive file
        self.RTFT_display = True  #default = True. if message from executive, then use that value
        self.state = -1  #-1= between trials 0 = outside target, 1 = close enough, waiting, 2 = close enough, hold time met
        self.start_hold = time.time()
        self.run()

    def load_config(self, config_file):  #Default config file is RTFT_CONFIG
        self.config = SafeConfigParser()
        self.config.read(config_file)
        self.rate = float(self.config.get('main', 'rate'))
        self.target_vector = [
            float(x)
            for x in self.config.get('main', 'target_vector').split(" ")
        ]
        self.target_color = [
            float(x)
            for x in self.config.get('main', 'target_color').split(" ")
        ]
        self.target_rad = float(self.config.get('main', 'target_radius'))
        self.ball_rad = float(self.config.get('main', 'cursor_radius'))
        self.ball_color = [
            float(x)
            for x in self.config.get('main', 'cursor_color').split(" ")
        ]
        self.max_factor = float(self.config.get('main', 'max_factor'))
        self.force_scale = float(self.config.get('main', 'force_scale'))
        self.threshold = float(self.config.get('main', 'threshold'))
        self.hold_time = float(self.config.get('main', 'hold_time'))

    def setup_dragonfly(self, server):
        subscriptions = [MT_EXIT, \
                         rc.MT_PING, \
                         rc.MT_FT_DATA, \
                         rc.MT_FT_COMPLETE, \
                         rc.MT_RTFT_CONFIG]
        self.mod = Dragonfly_Module(0, 0)
        self.mod.ConnectToMMM(server)
        for sub in subscriptions:
            self.mod.Subscribe(sub)
        self.mod.SendModuleReady()
        print "Connected to Dragonfly at ", server

    def init_gui(self):
        # Is the orientation matrix missing here????
        self.length = 100
        wallR = box(pos=vector(self.length / 2., 0, 0),
                    size=(0.2, self.length, self.length),
                    color=color.green)
        wallB = box(pos=vector(0, 0, -self.length / 2.),
                    size=(self.length, self.length, 0.2),
                    color=color.white)
        wallDown = box(pos=vector(0, -self.length / 2., 0),
                       size=(self.length, 0.2, self.length),
                       color=color.red)
        wallUp = box(pos=vector(0, self.length / 2., 0),
                     size=(self.length, 0.2, self.length),
                     color=color.white)
        wallL = box(pos=vector(-self.length / 2., 0, 0),
                    size=(0.2, self.length, self.length),
                    color=color.blue)

        self.unit_target = self.target_vector / np.linalg.norm(
            self.target_vector)
        self.target_position = np.array(
            self.unit_target) * self.max_factor * self.force_scale
        self.ball = sphere(pos=[0, 0, 0],
                           radius=self.ball_rad,
                           color=self.ball_color)
        self.target = sphere(pos=self.target_position,
                             radius=self.target_rad,
                             color=self.target_color)
        self.shadow_cursor = ring(pos=[0, -self.length / 2, 0],
                                  axis=(0, 10, 0),
                                  radius=self.ball_rad,
                                  thickness=1,
                                  color=[0.25, 0.25, 0.25])
        self.shadow_target = ring(pos=[
            self.target_position[0], -self.length / 2, self.target_position[2]
        ],
                                  axis=(0, 10, 0),
                                  radius=self.ball_rad,
                                  thickness=1,
                                  color=[0.25, 0.25, 0.25])

    def run(self):
        while True:
            msg = CMessage()
            rcv = self.mod.ReadMessage(msg, -1)
            if rcv == 1:
                msg_type = msg.GetHeader().msg_type
                dest_mod_id = msg.GetHeader().dest_mod_id
                if msg_type == MT_EXIT:
                    if (dest_mod_id == 0) or (dest_mod_id
                                              == self.mod.GetModuleID()):
                        print 'Received MT_EXIT, disconnecting...'
                        self.mod.SendSignal(rc.MT_EXIT_ACK)
                        self.mod.DisconnectFromMMM()
                        self.pi.ser.close()
                        break
                elif msg_type == rc.MT_PING:
                    respond_to_ping(self.mod, msg, 'RTFT_iso')
                else:
                    self.process_msg(msg)

    def process_msg(self, in_msg):
        header = in_msg.GetHeader()
        if header.msg_type == rc.MT_FT_DATA:
            mdf = rc.MDF_FT_DATA()
            copy_from_msg(mdf, in_msg)
            rate(self.rate)
            self.ball.pos = vector(mdf.F[0:3])
            self.shadow_cursor.pos = vector(
                [mdf.F[0], -self.length / 2, mdf.F[2]])
            self.unit_target = np.array(self.target_vector) / np.linalg.norm(
                self.target_vector)
            self.target_position = np.array(
                self.unit_target) * self.max_factor * self.force_scale
            self.target.pos = self.target_position
            self.shadow_target.pos = [
                self.target_position[0], -self.length / 2,
                self.target_position[2]
            ]
            distance = [a - b for a, b in zip(self.ball.pos, self.target.pos)]
            if (distance[0]**2 + distance[1]**2 + distance[2]**2)**(
                    1 / 2.) >= self.threshold and self.RTFT_display:
                self.ball.color = self.ball_color
                self.state = 0
            elif (distance[0]**2 + distance[1]**2 + distance[2]**2)**(
                    1 / 2.) < self.threshold and self.RTFT_display:
                if self.state == 0:  # if previous sample was outside radius, and now we're inside...
                    self.start_hold = time.time()
                    self.state = 1
                    self.ball.color = color.orange
                else:
                    if time.time() > (self.start_hold + self.hold_time):
                        self.ball.color = color.green
                        self.target.visible = False
                        self.shadow_target.visible = False
                        self.state = 2
                        out_mdf = rc.MDF_FT_COMPLETE()
                        out_mdf.FT_COMPLETE = self.state
                        out_mdf.sample_header = mdf.sample_header
                        msg = CMessage(rc.MT_FT_COMPLETE)
                        copy_to_msg(out_mdf, msg)
                        self.mod.SendMessage(msg)
                    else:
                        self.state = 1
                        self.ball.color = color.orange
            else:
                self.state = -1

            if self.state == 2 and self.solo:  #if no executive file
                self.target.pos = [
                    float(x) for x in [
                        np.random.rand(1, 1) * self.max_factor *
                        self.force_scale,
                        np.random.rand(1, 1) * self.max_factor *
                        self.force_scale,
                        np.random.rand(1, 1) * self.max_factor *
                        self.force_scale
                    ]
                ]
                self.shadow_target.pos = [
                    self.target.pos[0], -self.length / 2, self.target.pos[2]
                ]

            sys.stdout.write(
                "%7.4f, %5d, %16.2f\n" %
                (mdf.F[2], self.state,
                 (self.start_hold + self.hold_time) - time.time()))
            #msg_str = "%7.4f   " * 6 + "\n"
            #sys.stdout.write(msg_str % (mdf.F[0], mdf.F[1], mdf.F[2],
            #                            mdf.T[0], mdf.T[1], mdf.T[2]))
            sys.stdout.flush()

        elif header.msg_type == rc.MT_RTFT_CONFIG:
            mdf = rc.MDF_RTFT_CONFIG()
            copy_from_msg(mdf, in_msg)
            self.max_factor = mdf.max_factor
            self.RTFT_display = mdf.RTFT_display
            self.target_vector = mdf.target_vector[:]
            self.ball.visible = mdf.cursor_visible
            self.target.visible = mdf.target_visible
            self.shadow_target.visible = mdf.shadow_target_visible
            self.shadow_cursor.visible = mdf.shadow_cursor_visible
            self.ball_color = [1, 0, 0]
            self.solo = False
예제 #6
0
class BackgroundProcess(threading.Thread):
    
    def __init__(self, parent, config_file, server):
        threading.Thread.__init__(self)
        self.load_config(config_file)
        self.parent = parent
        self.collect_count = 0
        self.hotspot_count = 0
        self.TMS_trigger = False
        self.new_hotspot_data = False  
        self.new_collect_data = False
        self.ext_trig = False
        self.daemon = True
        self.setup_dragonfly(server)
        self.ext_trig = dit.DAQOut()
        self.setup_buffers()
        self.start()
     
    def load_config(self, config_file):
        cfg = SafeConfigParser()
        cfg.read(config_file)
        self.config = Config()
        #daq_config.minV  = cfg.getfloat('main', 'minV')
        #daq_config.maxV  = cfg.getfloat('main', 'maxV')
        self.config.nsamp = cfg.getint('main', 'nsamp_per_chan_per_second')
        self.config.nchan = cfg.getint('main', 'nchan')
        self.config.nemg = cfg.getint('main', 'nemg')
        self.config.nirq  = self.freq = cfg.getint('main', 'nirq_per_second')
        self.config.pre_trig = cfg.getfloat('main', 'pre_trigger')
        self.config.trig_chan = cfg.getfloat('main', 'trig_chan')
        self.config.perchan = self.config.nsamp / self.config.nirq
        self.config.npt   = self.config.nsamp * self.config.nchan / self.config.nirq
        self.config.pre_trig_samp = self.config.pre_trig * self.config.nsamp
        assert((self.config.nsamp * self.config.nchan) % self.config.nirq == 0)
        assert(self.config.nsamp % self.config.nirq == 0)
        
    
    def setup_dragonfly(self, server):
        subscriptions = [MT_EXIT, \
                         rc.MT_PING, \
                         rc.MT_DAQ_DATA, \
                         rc.MT_SAMPLE_GENERATED, \
                         rc.MT_TMS_TRIGGER]
        self.mod = Dragonfly_Module(0, 0)
        self.mod.ConnectToMMM(server)
        for sub in subscriptions:
            self.mod.Subscribe(sub)
        self.mod.SendModuleReady()
        print "Connected to Dragonfly at ", server
        
    def setup_buffers(self):
        self.cols = 2
        self.rows = self.config.nemg / self.cols
        self.npt = 2500 
        self.old_data = np.zeros((self.config.nchan, self.npt+300)) #  bigger to account for variable trig location in buffer
        self.new_data = np.zeros((self.config.nchan, self.npt+300))
        self.collect_data = np.zeros((self.config.nemg, self.npt))
        self.hotspot_data = np.zeros((self.config.nemg, self.npt))
        
        
    def run(self):
        while True:
            msg = CMessage()
            rcv = self.mod.ReadMessage(msg, 0)
            if rcv == 1:
                # read a Dragonfly message
                msg_type = msg.GetHeader().msg_type
                dest_mod_id = msg.GetHeader().dest_mod_id
                if  msg_type == MT_EXIT:
                    if (dest_mod_id == 0) or (dest_mod_id == self.mod.GetModuleID()):
                        print 'Received MT_EXIT, disconnecting...'
                        self.mod.SendSignal(rc.MT_EXIT_ACK)
                        self.mod.DisconnectFromMMM()
                        return
                elif msg_type == rc.MT_PING:
                    respond_to_ping(self.mod, msg, 'fast_display_rms')
                else:
                    self.process_message(msg)
                

    def process_message(self, msg):
        msg_type = msg.GetHeader().msg_type
        dest_mod_id = msg.GetHeader().dest_mod_id
        if msg_type == rc.MT_TMS_TRIGGER:
            self.ext_trig.run()
            self.TMS_trigger = True
            
        else:
            # if it is a NiDAQ message from channels 0-7, plot the data
            #self.counter += 1
            if msg_type == rc.MT_DAQ_DATA:
                #sys.stdout.write("*")
                #sys.stdout.flush()
                mdf = rc.MDF_DAQ_DATA()
                copy_from_msg(mdf, msg)
                # add data to data buffers (necessary, or just use graphics buffers?)
                # update plots to new data buffers
                buf = mdf.buffer
 
                self.new_data[:,:-self.config.perchan] = self.old_data[:,self.config.perchan:]
                for i in xrange(self.config.nchan):
                    #if i == 0:
                    #    print mdf.buffer[perchan * i:perchan * (i + 1)].size
                    self.new_data[i, -self.config.perchan:] = buf[i:self.config.nchan * self.config.perchan:self.config.nchan]
                self.old_data[:] = self.new_data[:]
                
        if self.parent.current_tab ==  'Collect':
            if self.TMS_trigger:
                if self.config.pre_trig_samp <= np.argmax(self.old_data[self.config.trig_chan, :] >= 3) <= (self.config.pre_trig_samp)+200:
                    self.trig_index = np.argmax(self.old_data[self.config.trig_chan, :] >= 3)
                    self.collect_data = self.old_data[:self.config.nemg, self.trig_index - self.config.pre_trig_samp:self.trig_index + self.npt - self.config.pre_trig_samp]
                    self.old_data = self.old_data * 0
                    self.new_data = self.new_data * 0
                    self.new_collect_data = True
                    self.TMS_trigger = False
                
        if self.parent.current_tab == 'Hotspot':
            if self.config.pre_trig_samp <= np.argmax(self.old_data[self.config.trig_chan, :] >= 3) <= self.config.pre_trig_samp+200:
                self.trig_index = np.argmax(self.old_data[self.config.trig_chan, :] >= 3)
                self.hotspot_data = self.old_data[:self.config.nemg, self.trig_index - self.config.pre_trig_samp:self.trig_index + self.npt - self.config.pre_trig_samp]
                self.new_hotspot_data = True
                self.old_data = self.old_data * 0
                self.new_data = self.new_data * 0
예제 #7
0
class AppStarter(object):
    def __init__(self, server):
        self.setup_dragonfly(server)
        self.run()

    def setup_dragonfly(self, server):
        self.host_name = platform.uname()[1]
        self.host_os = platform.system()

        self.mod = Dragonfly_Module(0, 0)
        self.mod.ConnectToMMM(server)
        self.mod.Subscribe(rc.MT_APP_START)
        self.mod.Subscribe(rc.MT_PING)
        self.mod.Subscribe(MT_EXIT)
        self.mod.Subscribe(MT_KILL)

        self.mod.SendModuleReady()
        print "Connected to Dragonfly at", server

    # this one is slightly different than the one in Common/python, leave this one here
    def respond_to_ping(self, msg, module_name):
        #print "PING received for '{0}'".format(p.module_name)

        dest_mod_id = msg.GetHeader().dest_mod_id
        p = rc.MDF_PING()
        copy_from_msg(p, msg)

        if (p.module_name.lower() == module_name.lower()) or (p.module_name == "*") or \
            (dest_mod_id == self.mod.GetModuleID()):
            mdf = rc.MDF_PING_ACK()
            mdf.module_name = module_name + ":" + self.host_name  # + ":" + self.host_os
            msg_out = CMessage(rc.MT_PING_ACK)
            copy_to_msg(mdf, msg_out)
            self.mod.SendMessage(msg_out)

    def run(self):
        while True:
            msg = CMessage()
            rcv = self.mod.ReadMessage(msg, 0.1)
            if rcv == 1:
                msg_type = msg.GetHeader().msg_type

                if msg_type == rc.MT_APP_START:

                    try:
                        mdf = rc.MDF_APP_START()
                        copy_from_msg(mdf, msg)
                        config = mdf.config

                        print "Config: %s" % config

                        # -- to do --
                        # get a list of all modules in appman.conf for this host
                        # see if any of the modules above are already/still running
                        # start non-running modules
                        # -- to do --

                        print "Creating scripts"
                        appman.create_script(config, self.host_name)
                        print "Starting modules on host: %s" % self.host_name
                        appman.run_script(self.host_name)

                        self.mod.SendSignal(rc.MT_APP_START_COMPLETE)

                    except Exception, e:
                        print "ERROR: %s" % (e)

                elif msg_type == rc.MT_PING:
                    print 'got ping'
                    self.respond_to_ping(msg, 'AppStarter')

                # we use this msg to stop modules individually
                elif msg_type == MT_EXIT:
                    print 'got exit'

                elif msg_type == MT_KILL:
                    print 'got kill'
                    appman.kill_modules()
예제 #8
0
class Metronome(object):
    """Dragonfly module that paces TMS trigger messages and audible beeps.

    Each incoming trigger message (``DAQ_DATA``) advances ``count`` while the
    metronome is running.  When ``count`` reaches ``metronome_count`` one
    period is over: a ``TMS_TRIGGER`` message is sent and/or a beep is played
    (see ``_handle_tick`` for the ordering), and the next period starts after
    a random delay of up to 1.5 s.  The message rate is measured once at
    start-up from the ``DeltaTime`` of the first trigger message.

    The constructor blocks: it connects, calibrates and then runs the message
    loop until an MT_EXIT addressed to this module arrives.
    """

    def __init__(self, config_file, mm_ip):
        self.load_config(config_file)
        self.count = 0
        self.pause_state = True
        self.setup_Dragonfly(mm_ip)
        # calc_rates() returns False when MT_EXIT arrived before any trigger
        # message; the module is already disconnected, so skip the main loop.
        if self.calc_rates():
            self.run()

    def load_config(self, config_file):
        """Read ``pretrigger time`` and ``metronome period`` from [metronome].

        Both are floats; they are multiplied by the message rate, so they are
        presumably in seconds.
        """
        self.config = SafeConfigParser()
        self.config.read(config_file)
        self.pretrigger_time = self.config.getfloat('metronome',
                                                    'pretrigger time')
        self.metronome_period = self.config.getfloat('metronome',
                                                     'metronome period')
        self.in_msg_type = 'DAQ_DATA'  # message whose arrival drives the counter
        self.in_msg_num = getattr(rc, 'MT_%s' % self.in_msg_type.upper())
        print('%s config load complete' % (self.in_msg_num,))

    def calc_rates(self):
        """Convert the configured times into message counts.

        Sets ``in_msg_freq`` (messages per second), ``metronome_count`` and
        ``trigger_out_count`` (integer message counts).

        Returns:
            True on success, False if MT_EXIT arrived before a trigger message.
        """
        delta_time = self.chk_msg()
        if delta_time is None:
            return False
        # 1.0 guards against Python 2 integer division should DeltaTime be an int.
        self.in_msg_freq = 1.0 / delta_time
        # self.count is an integer compared with ==, so both thresholds must be
        # integers; a fractional float threshold would never be reached.
        self.metronome_count = int(round(self.metronome_period * self.in_msg_freq))
        # For either sign of pretrigger_time the earlier event falls
        # |pretrigger_time| seconds before the end of the period.
        self.trigger_out_count = int(round(
            (self.metronome_period - abs(self.pretrigger_time)) * self.in_msg_freq))
        print('Got frequency! %d' % self.in_msg_freq)
        print('%s %s' % (self.metronome_count, self.trigger_out_count))
        return True

    def chk_msg(self):
        """Block until the first trigger message arrives and return its DeltaTime.

        MT_PING is answered while waiting.  Returns None if an MT_EXIT for this
        module (or with dest_mod_id 0) arrives first; the module is then
        disconnected.
        """
        while True:
            in_msg = CMessage()
            if self.mod.ReadMessage(in_msg, 0.1) != 1:
                continue
            msg_type = in_msg.GetHeader().msg_type
            if msg_type == MT_EXIT:
                if self._exit_is_for_us(in_msg):
                    self._disconnect()
                    return None
            elif msg_type == rc.MT_PING:
                respond_to_ping(self.mod, in_msg, 'Metronome')
            elif msg_type == self.in_msg_num:
                in_mdf = getattr(rc, 'MDF_%s' % self.in_msg_type.upper())()
                copy_from_msg(in_mdf, in_msg)
                return in_mdf.sample_header.DeltaTime

    def setup_Dragonfly(self, mm_ip):
        """Connect to the message manager at ``mm_ip`` and subscribe."""
        self.mod = Dragonfly_Module(0, 0)
        self.mod.ConnectToMMM(mm_ip)
        for msg_num in (MT_EXIT, rc.MT_PING, self.in_msg_num, rc.MT_MNOME_STATE):
            self.mod.Subscribe(msg_num)
        self.mod.SendModuleReady()
        print('Connected to Dragonfly at %s' % (mm_ip,))

    def run(self):
        """Main message loop; returns after an MT_EXIT for this module."""
        while True:
            in_msg = CMessage()
            if self.mod.ReadMessage(in_msg, 0.1) != 1:
                continue
            msg_type = in_msg.GetHeader().msg_type
            if msg_type == MT_EXIT:
                if self._exit_is_for_us(in_msg):
                    self._disconnect()
                    break
            elif msg_type == rc.MT_PING:
                respond_to_ping(self.mod, in_msg, 'Metronome')
            elif msg_type == rc.MT_MNOME_STATE:
                self._handle_state(in_msg)
            elif msg_type == self.in_msg_num:
                self._handle_tick(in_msg)

    def _exit_is_for_us(self, in_msg):
        """True if an MT_EXIT targets this module or has dest_mod_id 0."""
        dest_mod_id = in_msg.GetHeader().dest_mod_id
        return dest_mod_id == 0 or dest_mod_id == self.mod.GetModuleID()

    def _disconnect(self):
        """Acknowledge MT_EXIT and disconnect from the message manager."""
        print('Received MT_EXIT, disconnecting...')
        self.mod.SendSignal(rc.MT_EXIT_ACK)
        self.mod.DisconnectFromMMM()

    def _handle_state(self, in_msg):
        """Apply an MNOME_STATE command: 0 = stop, 1 = start, 2 = pause.

        Every recognised command also resets the period counter; stop and
        pause have the same effect here.
        """
        print('got message')
        in_mdf = rc.MDF_MNOME_STATE()
        copy_from_msg(in_mdf, in_msg)
        if in_mdf.state == 0:
            print('got stop')
            self.pause_state = True
            self.count = 0
        elif in_mdf.state == 1:
            print('got start')
            self.pause_state = False
            self.count = 0
        elif in_mdf.state == 2:
            print('got pause')
            self.pause_state = True
            self.count = 0

    def _handle_tick(self, in_msg):
        """Advance the counter for one trigger message and fire due events.

        pretrigger_time > 0: beep at trigger_out_count, TMS trigger at
        metronome_count.  Otherwise: TMS trigger at trigger_out_count, beep at
        metronome_count.  Reaching metronome_count starts the next period.
        """
        if self.pause_state:
            return
        self.count += 1
        if self.pretrigger_time > 0:
            if self.count == self.metronome_count:
                self._send_trigger(in_msg)
                self._restart_period()
            if self.count == self.trigger_out_count:
                self._start_sound()
        else:
            if self.count == self.trigger_out_count:
                self._send_trigger(in_msg)
            if self.count == self.metronome_count:
                self._restart_period()
                self._start_sound()

    def _restart_period(self):
        """Start the next period after a random 0-1.5 s delay (negative count)."""
        self.count = 0 - int(np.random.uniform(0, 1.5, 1)[0] * self.in_msg_freq)

    def _send_trigger(self, in_msg):
        """Send a TMS_TRIGGER carrying the sample header of ``in_msg``."""
        in_mdf = getattr(rc, 'MDF_%s' % self.in_msg_type.upper())()
        copy_from_msg(in_mdf, in_msg)
        out_mdf = rc.MDF_TMS_TRIGGER()
        out_mdf.sample_header = in_mdf.sample_header
        out_msg = CMessage(rc.MT_TMS_TRIGGER)
        copy_to_msg(out_mdf, out_msg)
        self.mod.SendMessage(out_msg)

    def _start_sound(self):
        """Play the beep on a separate thread so the message loop is not blocked."""
        sound_thread = threading.Thread(target=self.play_sound)
        sound_thread.start()

    def play_sound(self):
        """Beep at 1500 Hz for 1000 ms (winsound is Windows-only; blocks)."""
        winsound.Beep(1500, 1000)
예제 #9
0
class PlotHead(HasTraits):
    scene = Instance(MlabSceneModel, ())

    # The layout of the panel created by Traits
    view = View(Item('scene',
                     editor=SceneEditor(),
                     resizable=True,
                     show_label=False),
                resizable=True)

    def __init__(self, config_file, server, parent):
        HasTraits.__init__(self)
        self.count = 0
        self.parent = parent
        self.pointer_position = np.zeros((1, 3))
        self.head_data = np.zeros((1, 3))
        self.load_config(config_file)
        self.init_plot()
        self.setup_dragonfly(server)

    def load_config(self, config_file):
        cfg = PyFileConfigLoader(config_file)
        cfg.load_config()
        self.config = cfg.config
        self.filename = self.config.head_model
        self.plate = self.config.tools.index('CB609')
        self.marker = self.config.tools.index('CT315')
        self.glasses = self.config.tools.index('ST568')
        self.pointer = self.config.tools.index('P717')
        self.pointer_Ti = np.array(self.config.tool_list[self.pointer].Ti)
        self.pointer_Qi = qa.norm(
            np.array(self.config.tool_list[self.pointer].Qi))
        self.pointer_Ni = np.array(self.config.tool_list[self.pointer].Ni)
        self.pointer_Xi = self.pointer_Ni - self.pointer_Ti
        self.tp = TP(self.pointer_Qi, self.pointer_Ni, self.pointer_Ti)

    def init_plot(self):
        '''
        # create a window with 14 plots (7 rows x 2 columns)
        ## create a window with 8 plots (4 rows x 2 columns)
        reader = tvtk.OBJReader()
        reader.file_name = self.filename
        mapper = tvtk.PolyDataMapper()
        mapper.input = reader.output
        actor = tvtk.Actor()
        mapper.color_mode = 0x000000
        actor.mapper = mapper
        actor.orientation = (-90,180,0)
        self.scene.add_actor(actor)
        '''
        self.plot = mlab.plot3d(0, 0, 0, color=(0, 0, 1))
        self.plot2 = mlab.plot3d(0, 0, 0, color=(1, 0, 0))
        self.pl = self.plot.mlab_source
        self.pl2 = self.plot2.mlab_source

        self.timer = wx.Timer(self.parent)
        self.timer.Start(50)
        self.parent.Bind(wx.EVT_TIMER, self.timer_event)

    def process_message(self, msg):
        # read a Dragonfly message
        msg_type = msg.GetHeader().msg_type
        dest_mod_id = msg.GetHeader().dest_mod_id
        if msg_type == MT_EXIT:
            if (dest_mod_id == 0) or (dest_mod_id == self.mod.GetModuleID()):
                print 'Received MT_EXIT, disconnecting...'
                self.mod.SendSignal(rc.MT_EXIT_ACK)
                self.mod.DisconnectFromMMM()
                return
        elif msg_type == rc.MT_PING:
            respond_to_ping(self.mod, msg, 'PlotHead')
        elif msg_type == rc.MT_POLARIS_POSITION:
            in_mdf = rc.MDF_POLARIS_POSITION()
            copy_from_msg(in_mdf, msg)
            positions = np.asarray(in_mdf.xyz[:])
            orientations = self.shuffle_q(np.asarray(in_mdf.ori[:]))
            if in_mdf.tool_id == (self.pointer + 1):
                Qf = qa.norm(orientations)
                Qr = qa.mult(Qf, qa.inv(self.pointer_Qi)).flatten()
                #find_nans(self.store_head, Qr, 'Qr')
                Tk = positions
                #find_nans(self.store_head, Tk, 'Tk')
                tip_pos = (qa.rotate(Qr, self.pointer_Xi) + Tk).flatten()
                self.pointer_position = np.append(self.pointer_position,
                                                  (tip_pos[np.newaxis, :]),
                                                  axis=0)
                #self.pl.reset(x=self.pointer_position[:,0], y=self.pointer_position[:,1], z=self.pointer_position[:,2])
                print("old=", tip_pos)
                print("new=", self.tp.get_pos(orientations, positions)[0])
            #elif in_mdf.tool_id == (self.glasses + 1):
            #    self.head_data = np.append(self.head_data, (head[np.newaxis,:]), axis=0)
            #    self.pl2.reset(x=self.head_data[:,0], y=self.head_data[:,1], z=self.head_data[:,2])

    def setup_dragonfly(self, server):
        subscriptions = [MT_EXIT, \
                         rc.MT_PING, \
                         rc.MT_POLARIS_POSITION]
        self.mod = Dragonfly_Module(0, 0)
        self.mod.ConnectToMMM(server)
        for sub in subscriptions:
            self.mod.Subscribe(sub)
        self.mod.SendModuleReady()
        print "Connected to Dragonfly at ", server

    def timer_event(self, parent):
        done = False
        sys.stdout.flush()
        while not done:
            msg = CMessage()
            rcv = self.mod.ReadMessage(msg, 0)
            if rcv == 1:
                self.process_message(msg)
            else:
                done = True

    def shuffle_q(self, q):
        return np.roll(q, -1, axis=0)
예제 #10
0
class MplCanvas(FigureCanvas):
    def __init__(self, parent=None, width=8, height=10, dpi=80):
        self.parent = parent
        self.redraw_yticks = True

        self.figure = Figure(figsize=(width, height), dpi=dpi, facecolor='#bbbbbb')
        FigureCanvas.__init__(self, self.figure)

        self.setParent(parent)

        FigureCanvas.setSizePolicy(self,
                                   QtGui.QSizePolicy.Expanding,
                                   QtGui.QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)


    def run(self, config_file, server):
        self.mod = Dragonfly_Module(0, 0)
        self.mod.ConnectToMMM(server)
        self.msg_types = ['END_TASK_STATE', 'SESSION_CONFIG', 'EM_DECODER_CONFIGURATION']
        self.msg_types.sort()
        self.mod.Subscribe(MT_EXIT)
        self.mod.Subscribe(rc.MT_PING)
        for i in self.msg_types:
            self.mod.Subscribe(eval('rc.MT_%s' % (i)))
        self.mod.SendModuleReady()
        print "Connected to Dragonfly at", server
        print "mod_id = ", self.mod.GetModuleID()

        self.config_file = config_file
        self.load_config()
        self.init_vars()
        self.init_plot()
        self.init_legend()

        timer = QtCore.QTimer(self)
        QtCore.QObject.connect(timer, QtCore.SIGNAL("timeout()"), self.timer_event)
        timer.start(10)


    def init_vars(self):
        self.num_trials = 0
        self.reset_counters()
        self.msg_cnt = 0
        self.console_disp_cnt = 0


    def reset_counters(self):
        self.trial_sync = 0
        self.num_trials_postcalib = 0
        self.num_trial_started_postcalib = 0
        self.num_trial_givenup_postcalib = 0
        self.num_trial_successful_postcalib = 0
        self.shadow_num_trial_started_postcalib = 0
        self.shadow_num_trial_givenup_postcalib = 0
        self.shadow_num_trial_successful_postcalib = 0
        self.started_window = []
        self.givenup_window = []
        self.success_window = []
        self.shadow_started_window = []
        self.shadow_givenup_window = []
        self.shadow_success_window = []
        self.percent_start = 0
        self.percent_success = 0
        self.percent_givenup = 0
        self.hist_narrow_SUR = []
        self.hist_narrow_GUR = []
        self.hist_narrow_STR = []

        self.hist_wide_SUR = []
        self.hist_wide_GUR = []
        self.hist_wide_STR = []

    def update_gui_label_data(self):
        self.parent.GALL.setText("%d" % self.num_trials_postcalib)
        self.parent.GSTR.setText("%d" % self.num_trial_started_postcalib) 
        self.parent.GGUR.setText("%d" % self.num_trial_givenup_postcalib) 
        self.parent.GSUR.setText("%d" % self.num_trial_successful_postcalib) 


    #def reload_config(self):
    #    self.load_config()
    #    for ax in self.figure.axes:
    #        self.figure.delaxes(ax)
    #    self.figure.clear()
    #    self.draw()
    #    self.init_plot(True)
    #    self.init_legend()
    #    self.redraw_yticks = True

    #def load_config(self):
    #    self.config = ConfigObj(self.config_file, unrepr=True)

    def load_config(self):
        self.config = SafeConfigParser()
        self.config.read(self.config_file)
        self.window_narrow = self.config.getint('general', 'window_narrow')
        self.window_wide = self.config.getint('general', 'window_wide')
        self.task_state_codes = {}
        for k, v in self.config.items('task state codes'):
            self.task_state_codes[k] = int(v)


    def init_plot(self, clear=False):
        self.nDims = 3

        self.figure.subplots_adjust(bottom=.05, right=.98, left=.08, top=.98, hspace=0.07)

        self.ax = []
        self.old_size = []
        self.ax_bkg = []
        self.lines = []

        ax = self.figure.add_subplot(1,1,1)

        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.85, box.height])

        self.reset_axis(ax)
        self.draw()

        bbox_width = ax.bbox.width
        bbox_height = ax.bbox.height
        if clear == True:   # force to redraw
            bbox_width = 0
            bbox_height = 0

        self.old_size.append( (bbox_width, bbox_height) )
        self.ax_bkg.append(self.copy_from_bbox(ax.bbox))

        self.colors = ['k', 'r', 'g']
        self.styles = ['-', '-', '--']

        for d in range(self.nDims):
            for m in range(3):
                line, = ax.plot([], [], self.colors[d]+self.styles[m], lw=1.5, aa=True, animated=True)
                line.set_ydata([0, 0])
                line.set_xdata([0, 1])
                self.lines.append(line)
                self.draw()

        self.ax.append(ax)


    def reset_axis(self, ax): #, label):
        ax.grid(True)
        ax.set_ylim(-1, 101)
        ax.set_autoscale_on(False)
        ax.get_xaxis().set_ticks([])
        for tick in ax.get_yticklabels():
            tick.set_fontsize(9)


    def init_legend(self):
        legnd = []

        for d in range(self.nDims):
            for m in range(3):
                line = matplotlib.lines.Line2D([0,0], [0,0], color=self.colors[d], ls=self.styles[m], lw=1.5)
                legnd.append(line)

        legend_text = []
        legend_text.append('STR')
        legend_text.append('STR%d' % self.window_narrow)
        legend_text.append('STR%d' % self.window_wide)
        legend_text.append('GUR')
        legend_text.append('GUR%d' % self.window_narrow)
        legend_text.append('GUR%d' % self.window_wide)
        legend_text.append('SUR')
        legend_text.append('SUR%d' % self.window_narrow)
        legend_text.append('SUR%d' % self.window_wide)

        self.figure.legend(legnd, legend_text, loc = 'right', bbox_to_anchor=(1, 0.5),
                           frameon=False, labelspacing=1.5, prop={'size':'11'})
        self.draw()


    def timer_event(self):
        done = False
        while not done:
            msg = CMessage()
            rcv = self.mod.ReadMessage(msg, 0)
            if rcv == 1:
                msg_type = msg.GetHeader().msg_type

                # SESSION_CONFIG => start of session
                if msg_type == rc.MT_SESSION_CONFIG:
                    #self.msg_cnt += 1
                    self.num_trials = 0
                    self.reset_counters()
                    self.update_gui_label_data()


                # EM_DECODER_CONFIGURATION => end of an adaptation round
                elif msg_type == rc.MT_EM_DECODER_CONFIGURATION:
                    #self.msg_cnt += 1
                    self.reset_counters()
                    self.update_gui_label_data()

                # END_TASK_STATE => end of a task
                elif msg_type == rc.MT_END_TASK_STATE:
                    #self.msg_cnt += 1
                    mdf = rc.MDF_END_TASK_STATE()
                    copy_from_msg(mdf, msg)

                    # need to know:
                    #    begin task state code
                    #    final task state code
                    #    intertrial state code

                    if (mdf.id == 1):
                        self.trial_sync = 1
                        self.shadow_started_window.append(0)

                    if (mdf.id == self.task_state_codes['begin']) & (mdf.outcome == 1):
                        if self.trial_sync:
                            #print "*** trial started ***"
                            #self.rewards_given += 1
                            self.shadow_num_trial_started_postcalib += 1
                            self.shadow_success_window.append(0)
                            self.shadow_givenup_window.append(0)
                            self.shadow_started_window[-1] = 1

                    if mdf.reason == "JV_IDLE_TIMEOUT":
                        if self.trial_sync:
                            self.shadow_num_trial_givenup_postcalib += 1
                            self.shadow_givenup_window[-1] = 1

                    if (mdf.id == self.task_state_codes['final']) & (mdf.outcome == 1):
                        if self.trial_sync:
                            #print "*** trial complete and successful"
                            self.shadow_num_trial_successful_postcalib += 1
                            self.shadow_success_window[-1] = 1

                    if (mdf.id == self.task_state_codes['intertrial']):
                        if self.trial_sync:
                            # do end-of-trial stuff here
                            self.num_trials += 1
                            self.num_trials_postcalib += 1
                            self.num_trial_started_postcalib = self.shadow_num_trial_started_postcalib
                            self.num_trial_successful_postcalib = self.shadow_num_trial_successful_postcalib
                            self.num_trial_givenup_postcalib = self.shadow_num_trial_givenup_postcalib

                            if len(self.shadow_success_window) > self.window_wide: #self.window_narrow:
                                self.shadow_success_window.pop(0)

                            if len(self.shadow_givenup_window) > self.window_wide: #self.window_narrow:
                                self.shadow_givenup_window.pop(0)

                            if len(self.shadow_started_window) > self.window_wide: #self.window_narrow:
                                self.shadow_started_window.pop(0)

                            self.success_window = copy.deepcopy(self.shadow_success_window)
                            self.started_window = copy.deepcopy(self.shadow_started_window)
                            self.givenup_window = copy.deepcopy(self.shadow_givenup_window)

                            if self.num_trials_postcalib > 0:
                                self.percent_start = 100 * self.num_trial_started_postcalib / self.num_trials_postcalib
                                self.percent_givenup = 100 * self.num_trial_givenup_postcalib / self.num_trials_postcalib
                                self.percent_success = 100 * self.num_trial_successful_postcalib / self.num_trials_postcalib


                            percent_success_wide_window = np.NAN
                            if len(self.success_window) >= self.window_wide:
                                num_success_window = np.sum(self.success_window)
                                percent_success_wide_window = 100 * num_success_window / len(self.success_window)

                            percent_givenup_wide_window = np.NAN
                            if len(self.givenup_window) >= self.window_wide:
                                num_givenup_window = np.sum(self.givenup_window)
                                percent_givenup_wide_window = 100 * num_givenup_window / len(self.givenup_window)

                            percent_started_wide_window = np.NAN
                            if len(self.started_window) >= self.window_wide:
                                num_started_window = np.sum(self.started_window)
                                percent_started_wide_window = 100 * num_started_window / len(self.started_window)

                            percent_success_narrow_window = np.NAN
                            if len(self.success_window) >= self.window_narrow:
                                success_window_narrow = self.success_window[len(self.success_window)-self.window_narrow:]
                                num_success_window = np.sum(success_window_narrow)
                                percent_success_narrow_window = 100 * num_success_window / len(success_window_narrow)

                            percent_givenup_narrow_window = np.NAN
                            if len(self.givenup_window) >= self.window_narrow:
                                givenup_window_narrow = self.givenup_window[len(self.givenup_window)-self.window_narrow:]
                                num_givenup_window = np.sum(givenup_window_narrow)
                                percent_givenup_narrow_window = 100 * num_givenup_window / len(givenup_window_narrow)

                            if len(self.started_window) >= self.window_narrow:
                                started_window_narrow = self.started_window[len(self.started_window)-self.window_narrow:]
                                num_started_window = np.sum(started_window_narrow)
                                percent_started_narrow_window = 100 * num_started_window / len(started_window_narrow)
                                self.hist_narrow_STR.append(percent_started_narrow_window)
                                self.hist_narrow_SUR.append(percent_success_narrow_window)
                                self.hist_narrow_GUR.append(percent_givenup_narrow_window)

                                self.hist_wide_STR.append(percent_started_wide_window)
                                self.hist_wide_SUR.append(percent_success_wide_window)
                                self.hist_wide_GUR.append(percent_givenup_wide_window)

                            self.update_gui_label_data()


                elif msg_type == rc.MT_PING:
                    respond_to_ping(self.mod, msg, 'TrialStatusDisplay')

                elif msg_type == MT_EXIT:
                    self.exit()
                    done = True

            else:
                done = True

                self.console_disp_cnt += 1
                if self.console_disp_cnt == 50:
                    self.update_plot()
                    self.console_disp_cnt = 0


    def update_plot(self):

        #print "All trials : %d" % (self.num_trials_postcalib)
        #print ""
        #print "GSTR: ", self.percent_start
        #print "GGUR: ", self.percent_givenup
        #print "GSUR: ", self.percent_success
        #print ""
        #print "STR win: ", self.started_window
        #print "GUP win: ", self.givenup_window
        #print "SUC win: ", self.success_window
        #print ""
        #print "nSTR: ", self.hist_narrow_STR
        #print "nGUR: ", self.hist_narrow_GUR
        #print "nSUR :", self.hist_narrow_SUR
        #print ""
        #print "wSTR: ", self.hist_wide_STR
        #print "wGUR: ", self.hist_wide_GUR
        #print "wSUR :", self.hist_wide_SUR
        #print ""
        #print "Msg cnt    : %d" % (self.msg_cnt)
        #print "\n ----------------------- \n"

        i = 0
        ax = self.ax[i]
        current_size = ax.bbox.width, ax.bbox.height
        if self.old_size[i] != current_size:
            self.old_size[i] = current_size
            self.draw()
            self.ax_bkg[i] = self.copy_from_bbox(ax.bbox)
        self.restore_region(self.ax_bkg[i])

        #if len(self.hist_narrow_STR) > 1:
        if not self.hist_narrow_STR:
            self.lines[0].set_ydata([self.percent_start, self.percent_start])
            self.lines[0].set_xdata([0, 1])
            ax.draw_artist(self.lines[0])

            self.lines[3].set_ydata([self.percent_givenup, self.percent_givenup])
            self.lines[3].set_xdata([0, 1])
            ax.draw_artist(self.lines[3])

            self.lines[6].set_ydata([self.percent_success, self.percent_success])
            self.lines[6].set_xdata([0, 1])
            ax.draw_artist(self.lines[6])

        else:
            ax.set_xlim(0, len(self.hist_narrow_STR)-1)

            for k in range(0,9):
                self.lines[k].set_xdata(range(len(self.hist_narrow_STR)))

            self.lines[0].set_ydata([self.percent_start, self.percent_start])
            self.lines[0].set_xdata([0, len(self.hist_narrow_STR)])
            self.lines[1].set_ydata(self.hist_narrow_STR)
            self.lines[2].set_ydata(self.hist_wide_STR)

            ###

            self.lines[3].set_ydata([self.percent_givenup, self.percent_givenup])
            self.lines[3].set_xdata([0, len(self.hist_narrow_STR)])
            self.lines[4].set_ydata(self.hist_narrow_GUR)
            self.lines[5].set_ydata(self.hist_wide_GUR)

            ###

            self.lines[6].set_ydata([self.percent_success, self.percent_success])
            self.lines[6].set_xdata([0, len(self.hist_narrow_STR)])
            self.lines[7].set_ydata(self.hist_narrow_SUR)
            self.lines[8].set_ydata(self.hist_wide_SUR)

            for k in range(0,9):
                ax.draw_artist(self.lines[k])

        self.blit(ax.bbox)

        # need to redraw once to update y-ticks
        if self.redraw_yticks == True:
            self.draw()
            self.redraw_yticks = False


    def exit(self):
        print "exiting"
        self.parent.exit_app()

    def stop(self):
        print 'disconnecting'
        self.mod.SendSignal(rc.MT_EXIT_ACK)
        self.mod.DisconnectFromMMM()
예제 #11
0
class TrialStatus(object):
    def __init__(self, config_file, server):
        """Load config, connect to Dragonfly and run the message loop.

        Blocks inside ``run`` until an MT_EXIT for this module arrives.

        Args:
            config_file: path of the ini file read by ``load_config``.
            server: message manager address passed to ``ConnectToMMM``.
        """
        self.load_config(config_file)
        self.msg_nums = [getattr(rc, 'MT_%s' % x) for x in self.msg_types]
        self.num_trials = 0
        # The per-round counters and windows start out exactly as
        # reset_counters() leaves them, so reuse it instead of duplicating.
        self.reset_counters()

        self.last_time = time()
        self.setup_dragonfly(server)

        self.run()

    def load_config(self, config_file):
        """Read the window length and task-state codes from ``config_file``.

        Also fixes the (sorted) list of message type names this module tracks.
        Task-state codes come from the ``[task state codes]`` section and are
        converted to ints.
        """
        self.config = parser = SafeConfigParser()
        parser.read(config_file)
        # 'GIVE_REWARD' was once tracked here as well.
        self.msg_types = sorted(['END_TASK_STATE', 'SESSION_CONFIG',
                                 'EM_DECODER_CONFIGURATION'])
        self.window_len = parser.getint('general', 'window_len')
        codes = self.task_state_codes = {}
        for name, raw_code in parser.items('task state codes'):
            codes[name] = int(raw_code)

    def setup_dragonfly(self, server):
        self.mod = Dragonfly_Module(0, 0)
        self.mod.ConnectToMMM(server)
        self.mod.Subscribe(MT_EXIT)
        self.mod.Subscribe(rc.MT_PING)
        for i in self.msg_types:
            self.mod.Subscribe(eval('rc.MT_%s' % (i)))
        self.mod.SendModuleReady()
        print "Connected to RTMA at", server

    def run(self):
        while True:
            msg = CMessage()
            rcv = self.mod.ReadMessage(msg, 0.1)
            if rcv == 1:
                msg_type = msg.GetHeader().msg_type
                dest_mod_id = msg.GetHeader().dest_mod_id
                if msg_type == MT_EXIT:
                    if (dest_mod_id == 0) or (dest_mod_id
                                              == self.mod.GetModuleID()):
                        print 'Received MT_EXIT, disconnecting...'
                        self.mod.SendSignal(rc.MT_EXIT_ACK)
                        self.mod.DisconnectFromMMM()
                        break
                elif msg_type == rc.MT_PING:
                    respond_to_ping(self.mod, msg, 'TrialStatus')
                else:
                    self.process_message(msg)

            this_time = time()
            self.diff_time = this_time - self.last_time
            if self.diff_time > 1.:
                self.last_time = this_time
                self.write()

    def reset_counters(self):
        """Zero all per-round trial counts and empty every outcome window.

        Also clears ``trial_sync``.  The ``num_trials`` session total is left
        untouched.
        """
        self.trial_sync = 0

        # Published counts.
        self.num_trials_postcalib = 0
        (self.num_trial_started_postcalib,
         self.num_trial_givenup_postcalib,
         self.num_trial_successful_postcalib) = (0, 0, 0)

        # Counts accumulated while a trial is in progress.
        (self.shadow_num_trial_started_postcalib,
         self.shadow_num_trial_givenup_postcalib,
         self.shadow_num_trial_successful_postcalib) = (0, 0, 0)

        # Every window gets its own fresh list (no aliasing between them).
        self.started_window, self.givenup_window, self.success_window = [], [], []
        (self.shadow_started_window,
         self.shadow_givenup_window,
         self.shadow_success_window) = ([], [], [])

    def process_message(self, in_msg):
        """Update trial statistics from one subscribed message.

        - SESSION_CONFIG: new session; zero ``num_trials`` and all counters.
        - EM_DECODER_CONFIGURATION: end of an adaptation round; zero the
          post-calibration counters and windows.
        - END_TASK_STATE: per-trial bookkeeping.  Progress inside a trial is
          accumulated in the ``shadow_*`` counters/windows and only copied to
          the counters/windows read by ``write`` when the intertrial state
          ends, so ``write`` never reports a half-finished trial.

        Messages whose type is not in ``self.msg_nums`` are ignored.
        """
        msg_type = in_msg.GetHeader().msg_type
        if msg_type not in self.msg_nums:
            return

        if msg_type == rc.MT_SESSION_CONFIG:
            self.num_trials = 0
            self.reset_counters()

        elif msg_type == rc.MT_EM_DECODER_CONFIGURATION:
            self.reset_counters()

        elif msg_type == rc.MT_END_TASK_STATE:
            mdf = rc.MDF_END_TASK_STATE()
            copy_from_msg(mdf, in_msg)

            def mark_last(window):
                # A window can still be empty here (e.g. an idle timeout before
                # any "begin" state succeeded); indexing [-1] would raise
                # IndexError and stop the message loop.
                if window:
                    window[-1] = 1

            # An END_TASK_STATE with id 1 opens a new slot in the started
            # window and enables counting (trial_sync) after a reset.
            if mdf.id == 1:
                self.trial_sync = 1
                self.shadow_started_window.append(0)

            # Every branch below only counts once synchronised to a trial.
            if not self.trial_sync:
                return

            codes = self.task_state_codes

            if mdf.id == codes['begin'] and mdf.outcome == 1:
                self.shadow_num_trial_started_postcalib += 1
                self.shadow_success_window.append(0)
                self.shadow_givenup_window.append(0)
                mark_last(self.shadow_started_window)

            if mdf.reason == "JV_IDLE_TIMEOUT":
                self.shadow_num_trial_givenup_postcalib += 1
                mark_last(self.shadow_givenup_window)

            if mdf.id == codes['final'] and mdf.outcome == 1:
                self.shadow_num_trial_successful_postcalib += 1
                mark_last(self.shadow_success_window)

            if mdf.id == codes['intertrial']:
                # End of trial: publish the shadow counters.
                self.num_trials += 1
                self.num_trials_postcalib += 1
                self.num_trial_started_postcalib = \
                    self.shadow_num_trial_started_postcalib
                self.num_trial_successful_postcalib = \
                    self.shadow_num_trial_successful_postcalib
                self.num_trial_givenup_postcalib = \
                    self.shadow_num_trial_givenup_postcalib

                # Keep only the most recent window_len entries (a `while`, so
                # a window that grew by more than one is still bounded).
                for window in (self.shadow_success_window,
                               self.shadow_givenup_window,
                               self.shadow_started_window):
                    while len(window) > self.window_len:
                        window.pop(0)

                # The windows only hold 0/1 ints, so a shallow copy is a
                # complete snapshot.
                self.success_window = list(self.shadow_success_window)
                self.started_window = list(self.shadow_started_window)
                self.givenup_window = list(self.shadow_givenup_window)

    def write(self):
        """Print the published trial statistics: totals and sliding windows.

        Percentages are computed with float arithmetic so that Python 2
        integer division does not truncate them before the ``%0.0f``
        formatting rounds them (2 of 3 prints as 67%, not 66%).  Empty
        totals/windows report 0.
        """

        def pct(num, den):
            # float numerator: never integer (floor) division
            return 100.0 * num / den if den > 0 else 0

        def window_stats(window):
            # (number of 1-entries, percentage of window length)
            if len(window) == 0:
                return 0, 0
            hits = np.sum(window)
            return hits, pct(hits, len(window))

        total = self.num_trials_postcalib
        percent_start = pct(self.num_trial_started_postcalib, total)
        percent_givenup = pct(self.num_trial_givenup_postcalib, total)
        percent_success = pct(self.num_trial_successful_postcalib, total)

        num_started_window, percent_started_window = window_stats(
            self.started_window)
        num_givenup_window, percent_givenup_window = window_stats(
            self.givenup_window)
        num_success_window, percent_success_window = window_stats(
            self.success_window)

        print("All trials : %d\n" % (self.num_trials_postcalib))
        print("Started trials  : %d (%0.0f%%)" % (
            self.num_trial_started_postcalib, percent_start))
        print("Given-up trials : %d (%0.0f%%)" % (
            self.num_trial_givenup_postcalib, percent_givenup))
        print("Success trials  : %d (%0.0f%%)\n" % (
            self.num_trial_successful_postcalib, percent_success))
        print("Started trials out of last %d\t\t: %d (%0.0f%%)" % (len(
            self.started_window), num_started_window, percent_started_window))
        print("Given-up trials out of last started %d\t: %d (%0.0f%%)" % (len(
            self.givenup_window), num_givenup_window, percent_givenup_window))
        print("Success trials out of last started %d\t: %d (%0.0f%%)" % (len(
            self.success_window), num_success_window, percent_success_window))
        print("")
class MplCanvas(FigureCanvas):
    """Live matplotlib/Qt display of feedback position for the active DOFs.

    One subplot per active dimension shows the most recent
    ``number_of_data_points`` position samples, coloured background bands for
    the task states listed in ``marked_task_states``, red bands where an
    END_TASK_STATE reported outcome 0, and hatched threshold regions when
    judging data is present.  Drawing is done by blitting from a QTimer that
    also drains the Dragonfly message queue.
    """
    subscriptions = [
        rc.MT_PING, MT_EXIT, rc.MT_TASK_STATE_CONFIG, rc.MT_FORCE_SENSOR_DATA,
        rc.MT_FORCE_FEEDBACK, rc.MT_END_TASK_STATE
    ]

    # Polygon patches preallocated per subplot for background/threshold bands.
    NUM_ZONE_PATCHES = 60
    # Number of DOFs kept in LiveData: x, y, z, rx, ry, rz.
    ALL_DIMS = 6
    # Axis-label suffix for each 1-based dim number in ``active_dims``.
    DIM_LABELS = ['x', 'y', 'z', 'rx', 'ry', 'rz']

    def __init__(self, parent=None, width=8, height=10, dpi=80):
        self.parent = parent
        self.paused = False
        self.LiveData = None
        self.PausedData = None  # snapshot of LiveData taken by pause()
        self.fdbk_actual_pos = None
        self.tsc_mdf = None  # most recent TASK_STATE_CONFIG
        self.ets_mdf = None  # END_TASK_STATE not yet folded into LiveData
        self.redraw_yticks = True

        self.figure = Figure(figsize=(width, height),
                             dpi=dpi,
                             facecolor='#bbbbbb')
        FigureCanvas.__init__(self, self.figure)

        self.setParent(parent)

        FigureCanvas.setSizePolicy(self, QtGui.QSizePolicy.Expanding,
                                   QtGui.QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

    def init_judge_display(self):
        """(Re)allocate the rolling LiveData buffers from the config.

        Every buffer holds ``number_of_data_points`` samples.  Buffers are
        kept for all ALL_DIMS dims regardless of which ones are plotted.
        """
        N = self.config['config']['number_of_data_points']

        self.nAccumPoints = 0
        self.nScaleResetPoints = N * 2

        self.LiveData = {
            'ActualPos': {},
            'ThreshUpper': {},
            'ThreshLower': {},
            'JudgingMethod': {},
            'JudgingPolarity': {},
            'max_scale': {},
            'min_scale': {}
        }

        eps = np.finfo(float).eps
        for d in range(self.ALL_DIMS):
            self.LiveData['ActualPos'][d] = np.zeros(N)
            for key in ('ThreshUpper', 'ThreshLower', 'JudgingMethod',
                        'JudgingPolarity'):
                self.LiveData[key][d] = nan_array(N)
            # Start from a tiny symmetric range; update_scales only grows it.
            self.LiveData['max_scale'][d] = eps
            self.LiveData['min_scale'][d] = -eps

        self.LiveData['TaskStateNo'] = np.zeros(N)
        self.LiveData['TaskStateVerdict'] = np.ones(N)

    def run(self, config_file, server):
        """Connect to Dragonfly, build the plots from *config_file* and start
        polling for messages every 10 ms."""
        self.mod = Dragonfly_Module(0, 0)
        self.mod.ConnectToMMM(server)
        for sub in self.subscriptions:
            self.mod.Subscribe(sub)
        self.mod.SendModuleReady()
        print("Connected to Dragonfly at %s" % (server,))

        self.config_file = config_file
        self.load_config()

        self.init_judge_display()
        self.init_plot()
        self.init_legend()

        self.timer = QtCore.QTimer(self)
        QtCore.QObject.connect(self.timer, QtCore.SIGNAL("timeout()"),
                               self.timer_event)
        self.timer.start(10)

    def update_scales(self, d):
        """Grow the running y-range of dim `d` (0-based) to cover its buffer.

        The stored range is only ever expanded, to the data extreme plus 5%.
        If it changed and the dim's subplot is auto-scaled, the y-ticks are
        flagged for a full redraw.
        """
        data = self.LiveData['ActualPos'][d]
        grew = False

        min_data = np.nanmin(data)
        if min_data < self.LiveData['min_scale'][d]:
            self.LiveData['min_scale'][d] = min_data * 1.05  # add 5%
            grew = True

        max_data = np.nanmax(data)
        if max_data > self.LiveData['max_scale'][d]:
            self.LiveData['max_scale'][d] = max_data * 1.05  # add 5%
            grew = True

        # auto_scale is indexed by subplot position (as in update_plot and its
        # default of one entry per subplot), not by dim number, so translate
        # d back to its position in self.dims.
        if grew and self.auto_scale[self.dims.index(d + 1)]:
            self.redraw_yticks = True

    def update_judging_data(self):
        """Append the latest feedback sample and task-state info to LiveData.

        Position traces are updated for every plotted dim even before any
        TASK_STATE_CONFIG has been received; task-state buffers need tsc_mdf.
        """
        for dim in self.dims:
            d = dim - 1
            self.LiveData['ActualPos'][d] = self.add_to_windowed_array(
                self.LiveData['ActualPos'][d], self.fdbk_actual_pos[d])
            self.update_scales(d)

        if self.tsc_mdf is None:
            return

        # NOTE(review): ThreshUpper/ThreshLower/JudgingMethod/JudgingPolarity
        # are never appended to here, so the threshold overlay drawn by
        # update_plot only reflects what init_judge_display stored in them.

        self.LiveData['TaskStateNo'] = self.add_to_windowed_array(
            self.LiveData['TaskStateNo'], self.tsc_mdf.id)

        if self.ets_mdf is not None:
            # Apply the reported outcome to the two most recent samples and
            # to the new one, then consume the END_TASK_STATE.
            self.LiveData['TaskStateVerdict'][-2] = self.ets_mdf.outcome
            self.LiveData['TaskStateVerdict'][-1] = self.ets_mdf.outcome
            self.LiveData['TaskStateVerdict'] = self.add_to_windowed_array(
                self.LiveData['TaskStateVerdict'], self.ets_mdf.outcome)
            self.ets_mdf = None
        else:
            self.LiveData['TaskStateVerdict'] = self.add_to_windowed_array(
                self.LiveData['TaskStateVerdict'], 1)

    def add_to_windowed_array(self, arr, data):
        """Return a copy of `arr` shifted left by one with `data` appended."""
        arr = np.append(arr, data)
        arr = np.delete(arr, 0)
        return arr

    def load_config(self):
        """Parse ``self.config_file`` (Python-literal values, unrepr=True)."""
        self.config = ConfigObj(self.config_file, unrepr=True)

    def reload_config(self):
        """Re-read the config file and rebuild all axes and the legend."""
        self.load_config()

        # LiveData buffers must have as many samples as the plots are wide.
        N = self.config['config']['number_of_data_points']
        if self.LiveData is None or len(self.LiveData['TaskStateNo']) != N:
            self.init_judge_display()

        # Iterate over a copy: delaxes() may mutate figure.axes.
        for ax in list(self.figure.axes):
            self.figure.delaxes(ax)
        self.figure.clear()
        self.draw()
        self.init_plot(True)
        self.init_legend()
        self.redraw_yticks = True

    def timer_event(self):
        """Drain all pending Dragonfly messages, then redraw the plots."""
        while True:
            msg = CMessage()
            if self.mod.ReadMessage(msg, 0) != 1:
                break  # queue empty
            msg_type = msg.GetHeader().msg_type

            if msg_type == rc.MT_TASK_STATE_CONFIG:
                self.tsc_mdf = rc.MDF_TASK_STATE_CONFIG()
                copy_from_msg(self.tsc_mdf, msg)

            elif msg_type == rc.MT_FORCE_FEEDBACK:
                mdf = rc.MDF_FORCE_FEEDBACK()
                copy_from_msg(mdf, msg)
                # Only x/y/z are provided; the rotational dims read as 0.
                self.fdbk_actual_pos = [mdf.x, mdf.y, mdf.z, 0.0, 0.0, 0.0]
                self.update_judging_data()

            elif msg_type == rc.MT_FORCE_SENSOR_DATA:
                mdf = rc.MDF_FORCE_SENSOR_DATA()
                copy_from_msg(mdf, msg)
                self.fdbk_actual_pos = list(mdf.data)
                self.update_judging_data()

            elif msg_type == rc.MT_END_TASK_STATE:
                self.ets_mdf = rc.MDF_END_TASK_STATE()
                copy_from_msg(self.ets_mdf, msg)

            elif msg_type == rc.MT_PING:
                respond_to_ping(self.mod, msg, 'SimpleDisplay')

            elif msg_type == MT_EXIT:
                self.exit()
                break

        self.update_plot()

    def init_plot(self, clear=False):
        """Create one subplot per valid entry of ``active_dims``.

        Entries outside 1..6 are skipped with a warning.  With ``clear=True``
        the stored axis sizes are zeroed so update_plot re-grabs every
        background on its next call.
        """
        self.figure.subplots_adjust(bottom=.05,
                                    right=.98,
                                    left=.08,
                                    top=.98,
                                    hspace=0.07)

        active_dims = 0
        if 'active_dims' in self.config['config']:
            active_dims = self.config['config']['active_dims']

        axis_labels = []
        self.dims = []
        if active_dims:
            for f in active_dims:
                if 0 < f <= len(self.DIM_LABELS):
                    self.dims.append(f)
                    axis_labels.append('#%d (%s)' % (f, self.DIM_LABELS[f - 1]))
                else:
                    print("Warning: invalid dim specified: %d, skipping.." % f)

        # Only validated dims get a subplot; sizing by the raw config list
        # would index past self.dims / axis_labels.
        self.nDims = len(self.dims)
        self.xN = self.config['config']['number_of_data_points']
        self.bg = list(self.config['marked_task_states'].keys())

        self.max_scale = self.config['config']['max_scale']
        self.min_scale = self.config['config']['min_scale']

        if 'auto_scale' in self.config['config']:
            self.auto_scale = self.config['config']['auto_scale']
        else:
            self.auto_scale = [0] * self.nDims

        self.ax = []
        self.old_size = []
        self.ax_bkg = []
        self.pos = []
        self.zones = {}
        self.zone_idx = []

        for d in range(self.nDims):
            ax = self.figure.add_subplot(self.nDims, 1, d + 1)
            self.reset_axis(ax, axis_labels[d])
            self.draw()  # render so the bbox and cached background are valid

            if clear:
                size = (0, 0)  # force update_plot to re-grab the background
            else:
                size = (ax.bbox.width, ax.bbox.height)
            self.old_size.append(size)
            self.ax_bkg.append(self.copy_from_bbox(ax.bbox))

            line, = ax.plot([], [], 'k-', lw=1.0, aa=None, animated=True)
            line.set_xdata(range(self.xN))
            line.set_ydata([0.0] * self.xN)
            self.pos.append(line)

            # Preallocate invisible patches that update_plot re-shapes.
            self.zones[d] = []
            self.zone_idx.append(0)
            for _ in range(self.NUM_ZONE_PATCHES):
                patch = ax.add_patch(
                    Polygon(
                        [[0, 1e-12], [1e-12, 0], [1e-12, 1e-12], [0, 1e-12]],
                        fc='none',
                        ec='none',
                        fill=True,
                        closed=True,
                        aa=None,
                        animated=True))
                self.zones[d].append(patch)
            # One draw after all artists exist (previously one per patch).
            self.draw()

            self.ax.append(ax)

    def reset_axis(self, ax, label):
        """Apply the common grid/limits/tick formatting to one subplot."""
        ax.grid(True)
        ax.set_xlim(0, self.xN - 1)
        ax.set_autoscale_on(False)
        ax.set_ylabel(label, fontsize='small')
        ax.get_xaxis().set_ticks([])
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.02f'))
        for tick in ax.get_yticklabels():
            tick.set_fontsize(9)

    def init_legend(self):
        """Add a figure legend: the position line plus one swatch per marked
        task state, in ``self.bg`` order."""
        marked = self.config['marked_task_states']
        handles = [matplotlib.lines.Line2D([0, 0], [0, 0], color='k')]
        for name in self.bg:
            handles.append(
                Polygon([[0, 0], [0, 0], [0, 0], [0, 0]],
                        fc=marked[name]['color'],
                        ec='none',
                        fill=True,
                        closed=True,
                        alpha=0.65))

        self.figure.legend(handles, ['Position'] + list(self.bg),
                           loc='lower center',
                           frameon=False,
                           ncol=20,
                           prop={'size': '11'},
                           columnspacing=.5)
        self.draw()

    def _take_zone(self, idx):
        """Return the next unused preallocated patch of subplot `idx`, or
        None once all NUM_ZONE_PATCHES are used.

        The counter keeps increasing past the limit so update_plot can report
        the overflow instead of raising IndexError.
        """
        a = self.zone_idx[idx]
        self.zone_idx[idx] = a + 1
        zones = self.zones[idx]
        return zones[a] if a < len(zones) else None

    def plot_bg_mask(self, ax, idx, x, mask, ylim, fc, ec, hatch, alpha):
        """Draw one band per contiguous run of true values in `mask`.

        Each run spans half a sample beyond its first/last x position and
        covers `ylim` (the axis y-limits when `ylim` is empty).  Bands reuse
        the preallocated patches of subplot `idx`.
        """
        # Pad with a 0 on the left (for starts) and right (for ends) so runs
        # touching the array bounds are detected.
        begin_indices = np.where(
            np.diff(np.asarray(np.insert(mask, 0, 0), dtype=int)) == 1)[0]
        end_indices = np.where(
            np.diff(np.asarray(np.append(mask, 0), dtype=int)) == -1)[0]

        dx = np.mean(np.diff(x))

        if len(ylim) == 0:
            ylim = ax.get_ylim()

        for b, e in zip(begin_indices, end_indices):
            patch = self._take_zone(idx)
            if patch is None:
                continue  # out of patches; update_plot reports it
            xb = x[b] - dx / 2
            xe = x[e] + dx / 2
            patch.set_xy([[xb, ylim[0]], [xe, ylim[0]], [xe, ylim[1]],
                          [xb, ylim[1]]])
            patch.set_edgecolor(ec)
            patch.set_facecolor(fc)
            patch.set_hatch(hatch)
            patch.set_alpha(alpha)
            ax.draw_artist(patch)

    def _draw_threshold_zones(self, ax, i, d, LiveData, x, yLimG):
        """Hatch the threshold regions of dim `d` on subplot `i`.

        JudgingMethod 1 (distance): polarity 1 hatches [ThreshLower,
        ThreshUpper], otherwise the ranges below ThreshLower and above
        ThreshUpper.  Any other method (absolute): polarity 1 hatches below
        ThreshUpper, otherwise above it.  `yLimG` is the subplot's y-range,
        used as the open end of one-sided regions.
        """
        method_arr = LiveData['JudgingMethod'][d]
        valid = ~np.isnan(method_arr)
        if not valid.any():
            return

        polarity_arr = LiveData['JudgingPolarity'][d]
        upper_arr = LiveData['ThreshUpper'][d]
        lower_arr = LiveData['ThreshLower'][d]
        style = ('none', 'black', '//', 1)  # fc, ec, hatch, alpha

        for method in np.unique(method_arr[valid]):
            met_mask = method_arr == method
            for polarity in np.unique(polarity_arr[met_mask]):
                mask = met_mask & (polarity_arr == polarity)
                for yLimU in np.unique(upper_arr[mask]):
                    submask = (upper_arr == yLimU) & mask
                    if method == 1:  # distance
                        for yLimL in np.unique(lower_arr[submask]):
                            submask2 = (lower_arr == yLimL) & submask
                            if polarity == 1:  # <
                                self.plot_bg_mask(ax, i, x, submask2,
                                                  [yLimL, yLimU], *style)
                            else:
                                self.plot_bg_mask(ax, i, x, submask2,
                                                  [yLimG[0], yLimL], *style)
                                self.plot_bg_mask(ax, i, x, submask2,
                                                  [yLimU, yLimG[1]], *style)
                    else:  # absolute
                        if polarity == 1:  # <
                            self.plot_bg_mask(ax, i, x, submask,
                                              [yLimG[0], yLimU], *style)
                        else:
                            self.plot_bg_mask(ax, i, x, submask,
                                              [yLimU, yLimG[1]], *style)

    def update_plot(self):
        """Blit the live (or paused) data into every subplot."""
        LiveData = self.PausedData if self.paused else self.LiveData
        marked = self.config['marked_task_states']
        x = range(self.xN)

        for i, dim in enumerate(self.dims):
            ax = self.ax[i]
            d = dim - 1

            # A resized axis invalidates the cached background.
            current_size = ax.bbox.width, ax.bbox.height
            if self.old_size[i] != current_size:
                self.old_size[i] = current_size
                self.draw()
                self.ax_bkg[i] = self.copy_from_bbox(ax.bbox)

            self.restore_region(self.ax_bkg[i])
            self.zone_idx[i] = 0

            # Config min/max_scale and auto_scale are indexed per subplot.
            if self.auto_scale[i]:
                ax.set_ylim(LiveData['min_scale'][d], LiveData['max_scale'][d])
            else:
                ax.set_ylim(self.min_scale[i], self.max_scale[i])
            yLimG = ax.get_ylim()

            for name in self.bg:
                b_id = marked[name]['id']
                b_color = marked[name]['color']
                mask = np.where(LiveData['TaskStateNo'] == b_id, 1, 0)
                if mask.any():
                    self.plot_bg_mask(ax, i, x, mask, [], b_color, 'none',
                                      None, 0.65)
                else:
                    # Original rationale: always draw a patch for every colour
                    # so it shows up in the legend (here a degenerate one).
                    patch = self._take_zone(i)
                    if patch is not None:
                        patch.set_xy([[0, 0], [0, 0], [0, 0], [0, 0]])
                        ax.draw_artist(patch)

            self._draw_threshold_zones(ax, i, d, LiveData, x, yLimG)

            fail_mask = LiveData['TaskStateVerdict'] == 0
            self.plot_bg_mask(ax, i, x, fail_mask, [], 'red', 'none', None,
                              0.65)

            self.pos[i].set_ydata(LiveData['ActualPos'][d])
            ax.draw_artist(self.pos[i])

            self.blit(ax.bbox)

            if self.zone_idx[i] > self.NUM_ZONE_PATCHES:
                print("ERROR: too many zones! Increase number of preallocated patches")

        # Blitting does not redraw tick labels; a full draw is needed once.
        if self.redraw_yticks:
            self.draw()
            self.redraw_yticks = False

    def pause(self, pause_state):
        """Freeze (True) or resume (False) the display.

        A deep copy of LiveData is taken on every call.  While paused the
        plots render that copy; incoming samples keep accumulating in
        LiveData.
        """
        self.paused = pause_state
        self.PausedData = copy.deepcopy(self.LiveData)

    def exit(self):
        """Ask the parent window to shut the application down."""
        print("exiting")
        self.parent.exit_app()

    def stop(self):
        """Stop polling, acknowledge the exit and leave Dragonfly."""
        print('disconnecting')
        self.timer.stop()
        self.mod.SendSignal(rc.MT_EXIT_ACK)
        self.mod.DisconnectFromMMM()
예제 #13
0
class ButtonDetector(object):
    """Convert analog controller DOF values into debounced PS3 button events.

    For every INPUT_DOF_DATA message whose tag belongs to a configured
    ``[input ...]`` section, each DOF listed in ``name_lookup`` acts as a
    button.  It must stay above ``btn_threshold`` for more than
    ``bounce_threshold`` seconds before a PS3_BUTTON_PRESS is sent.  A
    PS3_BUTTON_RELEASE is sent once it falls below the threshold again.
    Constructing an instance connects to Dragonfly and blocks in ``run``.
    """
    debug = True

    def __init__(self, config_file, server):
        # Debounce state: one row per controller id, one column per input
        # DOF.  Created per instance (not as class attributes) so detectors
        # never share mutable state.
        # bounce_start: time the DOF went above threshold, -1 = not timing.
        self.bounce_start = np.zeros([ncontrollers, rc.MAX_INPUT_DOFS]) - 1
        # was_pressed: a PRESS was sent and no RELEASE yet.
        self.was_pressed = np.zeros([ncontrollers, rc.MAX_INPUT_DOFS],
                                    dtype=bool)
        self.load_config(config_file)
        self.get_inputs()
        self.setup_dragonfly(server)
        self.run()

    def load_config(self, config_file):
        """Parse the INI-style configuration file."""
        self.config = SafeConfigParser()
        self.config.read(config_file)

    def setup_dragonfly(self, server):
        """Connect to the MMM at *server* and subscribe to our messages."""
        self.mod = Dragonfly_Module(0, 0)
        self.mod.ConnectToMMM(server)
        self.mod.Subscribe(MT_EXIT)
        self.mod.Subscribe(rc.MT_PING)
        self.mod.Subscribe(rc.MT_INPUT_DOF_DATA)
        self.mod.SendModuleReady()
        print("Connected to Dragonfly at %s" % (server,))

    def run(self):
        """Dispatch messages until an MT_EXIT addressed to us (or to all)."""
        while True:
            msg = CMessage()
            if self.mod.ReadMessage(msg, 0.1) != 1:
                continue
            hdr = msg.GetHeader()
            msg_type = hdr.msg_type
            dest_mod_id = hdr.dest_mod_id
            if msg_type == MT_EXIT:
                if dest_mod_id in (0, self.mod.GetModuleID()):
                    print('Received MT_EXIT, disconnecting...')
                    self.mod.SendSignal(rc.MT_EXIT_ACK)
                    self.mod.DisconnectFromMMM()
                    break
            elif msg_type == rc.MT_PING:
                respond_to_ping(self.mod, msg, 'ButtonDetector')
            else:
                self.process_message(msg)

    def get_inputs(self):
        """Collect the ``tag`` of every config section whose name contains
        'input '."""
        self.inputs = {self.config.get(sec, 'tag'): {}
                       for sec in self.config.sections() if 'input ' in sec}

    def process_message(self, msg):
        """Run button debouncing for one INPUT_DOF_DATA message."""
        if msg.GetHeader().msg_type != rc.MT_INPUT_DOF_DATA:
            return
        mdf = rc.MDF_INPUT_DOF_DATA()
        copy_from_msg(mdf, msg)
        if mdf.tag not in self.inputs:
            return

        dof_vals = np.asarray(mdf.dof_vals[:], dtype=float)
        # The controller index is the last character of the input tag.
        cid = int(mdf.tag[-1])
        now = time.time()

        above = dof_vals > btn_threshold
        idle = ~self.was_pressed[cid]  # new array, unaffected by later writes
        pressed = idle & above

        # A DOF that dropped back below threshold before its press was
        # accepted must restart its debounce timer.  Otherwise a short glitch
        # leaves the timer running and any later touch is accepted at once.
        self.bounce_start[cid, idle & ~above] = -1

        # Start timers on DOFs that just went above threshold.
        started = self.bounce_start[cid] > 0
        self.bounce_start[cid, pressed & ~started] = now

        held = (now - self.bounce_start[cid]) > bounce_threshold
        valid_held = pressed & held

        for vh in np.flatnonzero(valid_held):
            if vh in name_lookup:
                self.was_pressed[cid, vh] = True
                self.send_btn_press(name_lookup[vh], cid)

        released = self.was_pressed[cid] & (dof_vals < btn_threshold)
        valid_released = released & held & ~valid_held

        for vr in np.flatnonzero(valid_released):
            if vr in name_lookup:
                self.was_pressed[cid, vr] = False
                self.send_btn_release(name_lookup[vr], cid)
                self.bounce_start[cid, vr] = -1

    def _btn_code(self, btn):
        """Return the rc.PS3_B_* constant for a short button name (KeyError
        for unknown names)."""
        return {'l1': rc.PS3_B_L1,
                'l2': rc.PS3_B_L2,
                'r1': rc.PS3_B_R1,
                'x': rc.PS3_B_X,
                'sq': rc.PS3_B_SQUARE,
                'crc': rc.PS3_B_CIRCLE,
                'trg': rc.PS3_B_TRIANGLE}[btn]

    def _send_btn_event(self, mdf_out, msg_type, btn, controller_id):
        """Fill *mdf_out* with button and controller id and send it as a
        *msg_type* message."""
        mdf_out.whichButton = self._btn_code(btn)
        mdf_out.controllerId = controller_id
        msg_out = CMessage(msg_type)
        copy_to_msg(mdf_out, msg_out)
        self.mod.SendMessage(msg_out)

    def send_btn_press(self, btn, controller_id):
        """Send a PS3_BUTTON_PRESS for button name *btn*."""
        print("controller_id %d sending button press %s" % (controller_id, btn))
        self._send_btn_event(rc.MDF_PS3_BUTTON_PRESS(),
                             rc.MT_PS3_BUTTON_PRESS, btn, controller_id)

    def send_btn_release(self, btn, controller_id):
        """Send a PS3_BUTTON_RELEASE for button name *btn*."""
        print("controller_id %d sending button release %s" % (controller_id, btn))
        self._send_btn_event(rc.MDF_PS3_BUTTON_RELEASE(),
                             rc.MT_PS3_BUTTON_RELEASE, btn, controller_id)
예제 #14
0
class SampleGenerator(object):
    """Emit SAMPLE_GENERATED messages, either free-running or on triggers.

    If ``[main] triggers`` lists message names, one sample is sent per
    received trigger message.  Otherwise samples are sent at ``[main]
    frequency`` Hz (default 50).  Constructing an instance connects to
    Dragonfly and blocks in ``run``.
    """

    def __init__(self, config_file, server):
        self.serial_no = 2
        self.freq = 50  # Hz
        self.load_config(config_file)
        self.setup_dragonfly(server)
        self.run()

    def load_config(self, config_file):
        """Read trigger message names and, in free-running mode, the rate."""
        self.config = SafeConfigParser()
        self.config.read(config_file)
        triggers = self.config.get('main', 'triggers').split()
        # getattr instead of eval: config text is looked up, never executed.
        self.triggers = [getattr(rc, 'MT_%s' % x) for x in triggers]
        if not triggers:
            if self.config.get('main', 'frequency') != '':
                self.freq = self.config.getfloat('main', 'frequency')
            print("Freq: %.2f" % (self.freq,))

    def setup_dragonfly(self, server):
        """Connect to the MMM, subscribe, and choose the timestamp clock."""
        self.mod = Dragonfly_Module(rc.MID_SAMPLE_GENERATOR, 0)
        self.mod.ConnectToMMM(server)
        self.mod.Subscribe(MT_EXIT)
        self.mod.Subscribe(rc.MT_PING)
        self.mod.Subscribe(rc.MT_SPM_SPIKECOUNT)
        for trigger in self.triggers:
            self.mod.Subscribe(trigger)
        self.mod.SendModuleReady()
        print("Connected to Dragonfly at %s" % (server,))

        if platform.system() == "Windows":
            # On Windows, the best timer is time.clock()
            self.default_timer = time.clock
        else:
            # On most other platforms the best timer is time.time()
            self.default_timer = time.time

    def run(self):
        """Main loop: handle control messages and emit samples until an
        MT_EXIT addressed to this module (or to all) arrives."""
        self.delta_time_calc = self.default_timer()
        while True:
            msg = CMessage()
            rcv = self.mod.ReadMessage(msg, 0.001)
            if rcv == 1:
                hdr = msg.GetHeader()
                msg_type = hdr.msg_type
                if msg_type == MT_EXIT:
                    if hdr.dest_mod_id in (0, self.mod.GetModuleID()):
                        print('Received MT_EXIT, disconnecting...')
                        self.mod.SendSignal(rc.MT_EXIT_ACK)
                        self.mod.DisconnectFromMMM()
                        break
                elif msg_type == rc.MT_PING:
                    respond_to_ping(self.mod, msg, 'SampleGenerator')
                elif msg_type == rc.MT_SPM_SPIKECOUNT:
                    # Another sample source is already running; do not compete.
                    if hdr.src_mod_id == rc.MID_SPM_MOD:
                        print("\n\n ** Detected SPM_SPIKECOUNT messages coming from SPM_MOD! Quitting..\n\n")
                        sys.exit(0)
                elif self.triggers:
                    self.process_msg(msg)
            elif not self.triggers:
                # Free-running mode: sleep until the next scheduled sample.
                period = (1. / self.freq)
                time_now = self.default_timer()
                delta_time = period - (time_now - self.delta_time_calc)
                if delta_time > 0:
                    time.sleep(delta_time)
                # Advance the schedule by exactly one period so sleep overshoot
                # is absorbed by the next iteration instead of accumulating.
                self.delta_time_calc = self.delta_time_calc + period
                self.send_sample_generated()

    def process_msg(self, msg):
        """Emit one sample if *msg* is one of the configured triggers."""
        if msg.GetHeader().msg_type in self.triggers:
            self.delta_time_calc = self.default_timer()
            self.send_sample_generated()

    def send_sample_generated(self):
        """Send one SAMPLE_GENERATED message and print a '|' progress tick.

        NOTE(review): DeltaTime is always the nominal period 1/freq, also in
        trigger mode where samples follow the trigger rate instead.
        """
        sg = rc.MDF_SAMPLE_GENERATED()
        self.serial_no += 1
        sg.sample_header.SerialNo = self.serial_no
        sg.sample_header.Flags = 0
        sg.sample_header.DeltaTime = (1. / self.freq)
        sg.source_timestamp = self.default_timer()
        sg_msg = CMessage(rc.MT_SAMPLE_GENERATED)
        copy_to_msg(sg, sg_msg)
        self.mod.SendMessage(sg_msg)
        sys.stdout.write('|')
        sys.stdout.flush()
예제 #15
0
class MessageWatcher(object):
    # msg_types = ['GROBOT_RAW_FEEDBACK',
    #              'GROBOT_FEEDBACK',
    #              'SAMPLE_GENERATED',
    #              'SPM_SPIKECOUNT',
    #              'EM_MOVEMENT_COMMAND',
    #              'COMPOSITE_MOVEMENT_COMMAND'
    #              ]

    def __init__(self, config_file):
        """Load *config_file*, connect to Dragonfly and start monitoring.

        This constructor does not return under normal operation: it finishes
        by calling ``self.run()``, which loops with ``while True``.

        Parameters
        ----------
        config_file : str
            Path of the INI file handed to ``load_config``.
        """
        self.load_config(config_file)
        # Map each configured name to its rc.MT_* constant. getattr is used
        # instead of eval() because the names come from the config file.
        self.msg_nums = [getattr(rc, 'MT_%s' % (name))
                         for name in self.msg_types]
        # One counter per watched type, index-aligned with self.msg_types.
        self.count = np.zeros(len(self.msg_nums), dtype=int)
        self.last_time = time()
        self.setup_Dragonfly()
        self.run()

    def load_config(self, config_file):
        self.config = SafeConfigParser()
        self.config.read(config_file)
        self.msg_types = [x.upper() for x in self.config.options('messages')]
        self.msg_types.sort()

    def setup_Dragonfly(self):
        server = self.config.get('Dragonfly', 'server')
        self.mod = Dragonfly_Module(0, 0)
        self.mod.ConnectToMMM(server)
        for i in self.msg_types:
            self.mod.Subscribe(eval('rc.MT_%s' % (i)))
        self.mod.SendModuleReady()
        print "Connected to Dragonfly at", server

    def run(self):
        """Poll Dragonfly forever and report message rates about once a second.

        Each pass calls ``ReadMessage`` with 0.1 as its timeout argument
        (presumably seconds). Every message read is counted by
        ``process_message``. Once more than 1.0 has elapsed on the ``time()``
        clock since the last report, ``write`` prints the window's rates and
        all counters are zeroed.
        """
        report_period = 1.
        while True:
            incoming = CMessage()
            if self.mod.ReadMessage(incoming, 0.1) == 1:
                self.process_message(incoming)

            now = time()
            self.diff_time = now - self.last_time
            if self.diff_time <= report_period:
                continue
            self.last_time = now
            self.write()
            self.count[:] = 0

    def process_message(self, in_msg):
        msg_type = in_msg.GetHeader().msg_type
        if not msg_type in self.msg_nums:
            return
        msg_idx = self.msg_nums.index(msg_type)
        self.count[msg_idx] += 1

    def write(self):
        """Print the rate of each watched message type for the last window.

        For each name in ``self.msg_types`` the rate is its counter divided by
        ``self.diff_time``. ``self.msg_types`` and ``self.count`` are
        index-aligned because ``msg_nums`` is built from ``msg_types`` in order.
        ``run`` only calls this once ``self.diff_time`` exceeds 1.0, so the
        division is by a positive window length.
        """
        for msg_type, c in zip(self.msg_types, self.count):
            rate = c / self.diff_time
            print "%40s %5.2f Hz" % (msg_type, rate)
            # Substring test: any type whose name contains GROBOT_RAW_FEEDBACK.
            # NOTE(review): the 48 Hz floor is hard-coded; presumably the
            # nominal raw-feedback rate is a little above that -- confirm.
            if (('GROBOT_RAW_FEEDBACK' in msg_type) and (rate < 48.0)):
                # Repeated, presumably so the warning stands out in the
                # continuously scrolling output.
                print "Raw feedback rate is too low!"
                print "Raw feedback rate is too low!"
                print "Raw feedback rate is too low!"
                print "Raw feedback rate is too low!"
        # The explicit "\n" plus the following empty print leave blank lines
        # between successive reports.
        print "window was %0.3f seconds\n" % (self.diff_time)
        print ""