예제 #1
0
def run(cfg, state=mp.Value('i', 1), queue=None, logger=logger):
    '''
    Main function used to run the offline protocol.

    This is a protocol template: it sets up the trigger and the refresh
    timer, then loops at cfg.REFRESH_RATE until all trials are done.

    Parameters
    ----------
    cfg : python.module
        The loaded config module from the corresponding config_offline.py
    state : mp.Value
        Shared run flag (0: stop, 1: start, 2: wait). It is set to 0 when
        the protocol finishes. NOTE: the default value object is created
        once, when the function is defined, and is therefore shared by all
        calls that rely on the default.
    queue : mp.Queue
        If not None, redirect sys.stdout to GUI terminal
    logger : logging.logger
        The logger to use
    '''
    # Local import so the block does not depend on a module-level import.
    import time

    # Use to redirect sys.stdout to GUI terminal if GUI usage
    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI). Sleep between polls instead of
    # spinning, so the waiting process does not burn a full CPU core.
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        time.sleep(0.01)

    # Protocol start if equals to 1
    if not state.value:
        sys.exit()

    # Load the mapping from int to string for triggers events
    cfg.tdef = TriggerDef(cfg.TRIGGER_FILE)

    # Refresh rate
    refresh_delay = 1.0 / cfg.REFRESH_RATE

    # Trigger: fall back to a mock device if the real one cannot be opened
    trigger = Trigger(lpttype=cfg.TRIGGER_DEVICE, state=state)
    if trigger.init(50) == False:
        logger.error(
            '\n** Error connecting to the trigger device. Use a mock trigger instead?'
        )
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = Trigger(lpttype='FAKE')
        trigger.init(50)

    # timers
    timer_refresh = Timer()

    trial = 1
    num_trials = cfg.TRIALS_NB

    # start
    # NOTE: nothing in this template advances `trial`; the code added below
    # must increment it (or break), otherwise this loop never terminates.
    while trial <= num_trials:
        timer_refresh.sleep_atleast(refresh_delay)
        timer_refresh.reset()

        #-------------------------------------
        # ADD YOUR CODE HERE
        #-------------------------------------

    # Signal to the other processes sharing `state` that the protocol ended
    with state.get_lock():
        state.value = 0
예제 #2
0
def _log_decoding_helper(state, event_queue, amp_name=None, autostop=False):
    """
    Background worker that records trigger events from a StreamReceiver.

    Blocks until ``state`` leaves 0, acquires data while ``state`` is 1, then
    puts one ``(event_times, event_values)`` tuple of lists on
    ``event_queue``.

    Parameters
    ----------
    state : mp.Value
        The multiprocessing sharing variable (0 while waiting, 1 while
        acquiring). Set to 0 by this function when autostop triggers.
    event_queue : mp.Queue
        The queue used to share new events
    amp_name : str
        The stream name to connect to
    autostop : bool
        If True, automatically finish when no more data is received.
    """
    logger.info('Event acquisition subprocess started.')

    # Block until the parent gives the start signal.
    while state.value == 0:
        time.sleep(0.01)

    receiver = StreamReceiver(bufsize=0, stream_name=amp_name)
    pacer = Timer(autoreset=True)
    receiving = False
    while state.value == 1:
        _, timestamps = receiver.acquire()
        if autostop:
            got_samples = len(timestamps) > 0
            if receiving and not got_samples:
                # Data was flowing before and has now stopped: finish.
                state.value = 0
                break
            if got_samples:
                receiving = True
        pacer.sleep_atleast(0.001)
    logger.info('Event acquisition subprocess finishing up ...')

    # Keep only the samples where the trigger channel (column 0) is non-zero.
    buffers, times = receiver.get_buffer()
    trigger_channel = buffers[:, 0]
    hits = np.where(trigger_channel != 0)[0]
    event_values = trigger_channel[hits].tolist()
    event_times = times[hits].reshape(-1).tolist()
    assert len(event_times) == len(event_values)
    event_queue.put((event_times, event_values))
예제 #3
0
def sample_decoding(decoder):
    """
    Decoding example

    Polls the decoder forever and prints, for every new classification, the
    smoothed and raw probability of each label plus the most likely label.
    Logs a warning when no classification arrived for more than 5 seconds.

    Parameters
    ----------
    decoder : The decoder to use
    """
    def get_index_max(seq):
        """Return the key (dict) or index (sequence) of the largest value.

        Returns None, after logging an error, for unsupported inputs.
        """
        if isinstance(seq, dict):
            return max(seq, key=seq.__getitem__)
        # Accept any sized, indexable container (list, tuple, numpy array),
        # not only exact `list` instances.
        if hasattr(seq, '__len__') and hasattr(seq, '__getitem__'):
            return max(range(len(seq)), key=seq.__getitem__)
        logger.error('Unsupported input %s' % type(seq))
        return None

    # load trigger definitions for labeling
    labels = decoder.get_label_names()
    tm_watchdog = Timer(autoreset=True)
    tm_cls = Timer()
    while True:
        praw = decoder.get_prob_unread()
        psmooth = decoder.get_prob_smooth()
        if praw is None:
            # watch dog
            if tm_cls.sec() > 5:
                logger.warning(
                    'No classification was done in the last 5 seconds. Are you receiving data streams?'
                )
                tm_cls.reset()
            tm_watchdog.sleep_atleast(0.001)
            continue

        txt = '[%8.1f msec]' % (tm_cls.msec())
        for i, label in enumerate(labels):
            txt += '   %s %.3f (raw %.3f)' % (label, psmooth[i], praw[i])
        maxi = get_index_max(psmooth)
        # Only append the winning label when one could be determined;
        # indexing `labels` with None would raise a TypeError.
        if maxi is not None:
            txt += '   %s' % labels[maxi]
        print(txt)
        tm_cls.reset()
예제 #4
0
def run(cfg, state=mp.Value('i', 1), queue=None):
    '''
    Main function used to run the online protocol.

    This is a protocol template: it sets up the trigger and a stream
    receiver, then acquires the latest data window at cfg.REFRESH_RATE for
    as long as the shared state stays non-zero.

    Parameters
    ----------
    cfg : python.module
        The loaded config module of the protocol
    state : mp.Value
        Shared run flag (0: stop, 1: start, 2: wait). It is set to 0 when
        the protocol finishes. NOTE: the default value object is created
        once, when the function is defined, and is therefore shared by all
        calls that rely on the default.
    queue : mp.Queue
        If not None, redirect sys.stdout to GUI terminal
    '''
    # Local import so the block does not depend on a module-level import.
    import time

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI). Sleep between polls instead of
    # spinning, so the waiting process does not burn a full CPU core.
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        time.sleep(0.01)

    #  Protocol runs if state equals to 1
    if not state.value:
        sys.exit(-1)

    # events and triggers
    cfg.tdef = TriggerDef(cfg.TRIGGER_FILE)

    # To send trigger events
    trigger = Trigger(cfg.TRIGGER_DEVICE, state)

    # Fall back to a mock device if the real one cannot be opened
    if trigger.init(50) == False:
        logger.error(
            'Cannot connect to trigger device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = Trigger('FAKE', state)
        trigger.init(50)

    # Instance a stream receiver
    sr = StreamReceiver(bufsize=1,
                        winsize=0.5,
                        stream_name=None,
                        eeg_only=True)

    # Timer pacing the acquisition loop at cfg.REFRESH_RATE
    tm = Timer(autoreset=True)

    # Refresh rate
    refresh_delay = 1.0 / cfg.REFRESH_RATE

    # Loop while the shared flag is non-zero, so that an external stop
    # request (state set to 0) ends the protocol and the cleanup below is
    # actually reached.
    while state.value:

        # Acquire data from all the connected LSL streams by filling each associated buffers.
        sr.acquire()

        # Extract the latest window from the buffer of the chosen stream.
        window, tslist = sr.get_window(
        )  # window = [samples x channels], tslist = [samples]

        #-------------------------------------
        # ADD YOUR CODE HERE
        #-------------------------------------

        #  To run a trained BCI decoder, look at online_mi.py protocol

        tm.sleep_atleast(refresh_delay)

    # Signal to the other processes sharing `state` that the protocol ended
    with state.get_lock():
        state.value = 0

    logger.info('Finished.')
예제 #5
0
def run(cfg, state=mp.Value('i', 1), queue=None):
    '''
    Run the cue-based training protocol with a bar visual.

    Every trial walks through the events
    start -> gap_s -> gap -> cue -> dir_r (ready) -> dir (go) -> gap_s,
    sending the matching trigger at each transition. The direction order is
    cfg.DIRECTIONS repeated cfg.TRIALS_EACH times and shuffled. The protocol
    stops after the last trial, when ESC is pressed, or when `state` is
    cleared.

    Parameters
    ----------
    cfg : python.module
        The loaded config module of the protocol
    state : mp.Value
        Shared run flag (0: stop, 1: start, 2: wait). It is set to 0 when
        the protocol finishes. NOTE: the default value object is created
        once, when the function is defined, and is therefore shared by all
        calls that rely on the default.
    queue : mp.Queue
        If not None, redirect sys.stdout to GUI terminal

    Raises
    ------
    RuntimeError
        If a direction in cfg.DIRECTIONS is not one of L, R, U, D, B.
    '''
    # Local import so the block does not depend on a module-level import.
    import time

    # Key code returned by cv2.waitKey for the ESC key
    key_esc = 27

    # direction -> (trigger name prefix in cfg.tdef, bars to move)
    direction_info = {
        'L': ('LEFT', ('L',)),
        'R': ('RIGHT', ('R',)),
        'U': ('UP', ('U',)),
        'D': ('DOWN', ('D',)),
        'B': ('BOTH', ('L', 'R')),  # both hands
    }

    def jittered(name):
        """Return cfg.TIMINGS[name] randomized by +/- TIMINGS[name_RANDOMIZE]."""
        base = cfg.TIMINGS[name]
        spread = cfg.TIMINGS[name + '_RANDOMIZE']
        return base + random.uniform(-spread, spread)

    def direction_trigger(direction, phase):
        """Look up the trigger value of `direction` for phase READY or GO."""
        if direction not in direction_info:
            # Directions are strings, so format with %s; %d would raise a
            # TypeError instead of this RuntimeError.
            raise RuntimeError('Unknown direction %s' % direction)
        # Resolved lazily: the trigger file only needs to define the
        # directions that are actually used.
        return getattr(cfg.tdef, '%s_%s' % (direction_info[direction][0], phase))

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI). Sleep between polls instead of
    # spinning, so the waiting process does not burn a full CPU core.
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        time.sleep(0.01)
    #  Protocol start if equals to 1
    if not state.value:
        sys.exit()

    refresh_delay = 1.0 / cfg.REFRESH_RATE

    cfg.tdef = TriggerDef(cfg.TRIGGER_FILE)

    # Shuffled sequence holding every direction TRIALS_EACH times
    dir_sequence = []
    for _ in range(cfg.TRIALS_EACH):
        dir_sequence.extend(cfg.DIRECTIONS)
    random.shuffle(dir_sequence)
    num_trials = len(cfg.DIRECTIONS) * cfg.TRIALS_EACH

    event = 'start'
    trial = 1
    direction = None  # direction of the current trial, set at the cue

    # Hardware trigger
    if cfg.TRIGGER_DEVICE is None:
        logger.warning('No trigger device set. Press Ctrl+C to stop or Enter to continue.')
        #input()
    trigger = Trigger(lpttype=cfg.TRIGGER_DEVICE, state=state)
    if trigger.init(50) == False:
        logger.error('\n** Error connecting to USB2LPT device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = Trigger(lpttype='FAKE')
        trigger.init(50)

    # timers
    timer_trigger = Timer()
    timer_dir = Timer()
    timer_refresh = Timer()
    t_dir = jittered('DIR')
    t_dir_ready = jittered('READY')

    bar = BarVisual(cfg.GLASS_USE, screen_pos=cfg.SCREEN_POS, screen_size=cfg.SCREEN_SIZE)
    bar.fill()
    bar.glass_draw_cue()

    # start
    while trial <= num_trials:
        timer_refresh.sleep_atleast(refresh_delay)
        timer_refresh.reset()

        # segment= { 'cue':(s,e), 'dir':(s,e), 'label':0-4 } (zero-based)
        if event == 'start' and timer_trigger.sec() > cfg.TIMINGS['INIT']:
            event = 'gap_s'
            bar.fill()
            timer_trigger.reset()
            trigger.signal(cfg.tdef.INIT)
        elif event == 'gap_s':
            if cfg.TRIAL_PAUSE:
                bar.put_text('Press any key')
                bar.update()
                key = cv2.waitKey()
                if key == key_esc or not state.value:
                    break
                bar.fill()
            bar.put_text('Trial %d / %d' % (trial, num_trials))
            event = 'gap'
            timer_trigger.reset()
        elif event == 'gap' and timer_trigger.sec() > cfg.TIMINGS['GAP']:
            event = 'cue'
            bar.fill()
            bar.draw_cue()
            trigger.signal(cfg.tdef.CUE)
            timer_trigger.reset()
        elif event == 'cue' and timer_trigger.sec() > cfg.TIMINGS['CUE']:
            event = 'dir_r'
            direction = dir_sequence[trial - 1]
            # Validate before drawing so an unknown direction fails cleanly.
            ready_trigger = direction_trigger(direction, 'READY')
            for side in direction_info[direction][1]:
                bar.move(side, 100, overlay=True)
            trigger.signal(ready_trigger)
            timer_trigger.reset()
        elif event == 'dir_r' and timer_trigger.sec() > t_dir_ready:
            bar.fill()
            bar.draw_cue()
            event = 'dir'
            timer_trigger.reset()
            timer_dir.reset()
            trigger.signal(direction_trigger(direction, 'GO'))
        elif event == 'dir' and timer_trigger.sec() > t_dir:
            event = 'gap_s'
            bar.fill()
            logger.info('trial ' + str(trial) + ' done')
            trial += 1
            trigger.signal(cfg.tdef.BLANK)
            timer_trigger.reset()
            # Draw fresh randomized timings for the next trial
            t_dir = jittered('DIR')
            t_dir_ready = jittered('READY')

        # protocol: grow the bar(s) proportionally to the elapsed go time
        if event == 'dir':
            dx = min(100, int(100.0 * timer_dir.sec() / t_dir) + 1)
            for side in direction_info[direction][1]:
                bar.move(side, dx, overlay=True)

        # wait for start
        if event == 'start':
            bar.put_text('Waiting to start')

        bar.update()
        key = 0xFF & cv2.waitKey(1)
        if key == key_esc or not state.value:
            break

    bar.finish()

    # Signal to the other processes sharing `state` that the protocol ended
    with state.get_lock():
        state.value = 0
예제 #6
0
def log_decoding(decoder,
                 logfile,
                 amp_name=None,
                 pklfile=True,
                 matfile=False,
                 autostop=False,
                 prob_smooth=False):
    """
    Decode online and write results with event timestamps

    Trigger events are collected by a background process
    (_log_decoding_helper) while this function polls the decoder and logs
    every new probability vector. Decoding ends when ESC is pressed in the
    control window or, with autostop, when no result arrived for more than
    5 seconds after decoding had started. All times are saved relative to
    the first decoding result. Nothing is saved when no result was decoded.

    Parameters
    ----------
    decoder : BCIDecoder or BCIDecoderDaemon class
        The decoder to use
    logfile : str
        The file path to contain the result in Python pickle format
    amp_name : str
        The stream name to connect to
    pklfile : bool
        If True, export the results to Python pickle format
    matfile : bool
        If True, export the results to .mat file (same folder and stem as
        logfile)
    autostop : bool
        If True, automatically finish when no more data is received.
    prob_smooth : bool
        If True, use smoothed probability values according to decoder's smoothing parameter.
    """

    import cv2
    # Import the submodule explicitly: a bare `import scipy` does not
    # guarantee that `scipy.io` is loaded.
    import scipy.io

    # run event acquisition process in the background
    state = mp.Value('i', 1)
    event_queue = mp.Queue()
    proc = mp.Process(target=_log_decoding_helper,
                      args=[state, event_queue, amp_name, autostop])
    proc.start()
    logger.info('Spawned event acquisition process.')

    # init variables and choose decoding function
    labels = decoder.get_label_names()
    probs = []
    prob_times = []
    if prob_smooth:
        decode_fn = decoder.get_prob_smooth_unread
    else:
        decode_fn = decoder.get_prob_unread

    # simple controller UI
    cv2.namedWindow("Decoding", cv2.WINDOW_AUTOSIZE)
    cv2.moveWindow("Decoding", 1400, 50)
    img = np.zeros([100, 400, 3], np.uint8)
    cv2.putText(img, 'Press any key to start', (20, 60),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.imshow("Decoding", img)
    cv2.waitKeyEx()
    img *= 0
    cv2.putText(img, 'Press ESC to stop', (40, 60), cv2.FONT_HERSHEY_SIMPLEX,
                1, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.imshow("Decoding", img)

    key = 0
    started = False  # True once at least one result has been decoded
    tm_watchdog = Timer(autoreset=True)
    tm_cls = Timer()
    while key != 27:  # 27: ESC
        prob, prob_time = decode_fn(True)
        t_lsl = pylsl.local_clock()
        key = cv2.waitKeyEx(1)
        if prob is None:
            # watch dog
            if tm_cls.sec() > 5:
                if autostop and started:
                    logger.info('No more streaming data. Finishing.')
                    break
                tm_cls.reset()
            tm_watchdog.sleep_atleast(0.001)
            continue
        probs.append(prob)
        prob_times.append(prob_time)
        txt = '[%.3f] ' % prob_time
        txt += ', '.join(['%s: %.2f' % (l, p) for l, p in zip(labels, prob)])
        txt += ' (%d ms, LSL Diff = %.3f)' % (tm_cls.msec(),
                                              (t_lsl - prob_time))
        logger.info(txt)
        started = True
        tm_cls.reset()

    # finish up processes
    cv2.destroyAllWindows()
    logger.info('Cleaning up event acquisition process.')
    state.value = 0  # tells the helper process to stop acquiring
    decoder.stop()
    # Fetch the result before joining so the child is not left holding
    # unread queue data.
    event_times, event_values = event_queue.get()
    proc.join()

    # Nothing to save: bail out instead of dropping into a debugger and
    # then failing on prob_times[0].
    if len(prob_times) == 0:
        logger.error('No decoding result. Please debug.')
        return

    # save values, with all times relative to the first decoding result
    t_start = prob_times[0]
    probs = np.vstack(probs)
    event_times = np.array(event_times)
    event_times = event_times[np.where(event_times >= t_start)[0]] - t_start
    prob_times = np.array(prob_times) - t_start
    event_values = np.array(event_values)
    data = dict(probs=probs,
                prob_times=prob_times,
                event_times=event_times,
                event_values=event_values,
                labels=labels)
    if pklfile:
        with open(logfile, 'wb') as f:
            pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
        logger.info('Saved to %s' % logfile)
    if matfile:
        pp = Path(logfile)
        matout = '%s/%s.mat' % (pp.parent, pp.stem)
        scipy.io.savemat(matout, data)
        logger.info('Saved to %s' % matout)
예제 #7
0
class Feedback:
    """
    Perform a classification with visual feedback
    """
    def __init__(self, cfg, viz, tdef, trigger, logfile=None):
        self.cfg = cfg
        self.tdef = tdef
        self.trigger = trigger
        self.viz = viz
        self.viz.fill()
        self.refresh_delay = 1.0 / self.cfg.REFRESH_RATE
        self.bar_step_left = self.cfg.BAR_STEP['left']
        self.bar_step_right = self.cfg.BAR_STEP['right']
        self.bar_step_up = self.cfg.BAR_STEP['up']
        self.bar_step_down = self.cfg.BAR_STEP['down']
        self.bar_step_both = self.cfg.BAR_STEP['both']

        if type(self.cfg.BAR_BIAS) is tuple:
            self.bar_bias = list(self.cfg.BAR_BIAS)
        else:
            self.bar_bias = self.cfg.BAR_BIAS

        # New decoder: already smoothed by the decoder so bias after.
        #self.alpha_old = self.cfg.PROB_ACC_ALPHA
        #self.alpha_new = 1.0 - self.cfg.PROB_ACC_ALPHA

        if hasattr(self.cfg,
                   'BAR_REACH_FINISH') and self.cfg.BAR_REACH_FINISH == True:
            self.premature_end = True
        else:
            self.premature_end = False

        self.tm_trigger = Timer()
        self.tm_display = Timer()
        self.tm_watchdog = Timer()
        if logfile is not None:
            self.logf = open(logfile, 'w')
        else:
            self.logf = None

        # STIMO only
        if self.cfg.WITH_STIMO is True:
            if self.cfg.STIMO_COMPORT is None:
                atens = [x for x in serial.tools.list_ports.grep('ATEN')]
                if len(atens) == 0:
                    raise RuntimeError('No ATEN device found. Stop.')
                try:
                    self.stimo_port = atens[0].device
                except AttributeError:  # depends on Python distribution
                    self.stimo_port = atens[0][0]
            else:
                self.stimo_port = self.cfg.STIMO_COMPORT
            self.ser = serial.Serial(self.stimo_port, self.cfg.STIMO_BAUDRATE)
            logger.info('STIMO serial port %s is_open = %s' %
                        (self.stimo_port, self.ser.is_open))

        # FES only
        if self.cfg.WITH_FES is True:
            self.stim = fes.Motionstim8()
            self.stim.OpenSerialPort(self.cfg.FES_COMPORT)
            self.stim.InitializeChannelListMode()
            logger.info('Opened FES serial port')

    def __del__(self):
        # STIMO only
        if self.cfg.WITH_STIMO is True:
            self.ser.close()
            logger.info('Closed STIMO serial port %s' % self.stimo_port)
        # FES only
        if self.cfg.WITH_FES is True:
            stim_code = [0, 0, 0, 0, 0, 0, 0, 0]
            self.stim.UpdateChannelSettings(stim_code)
            self.stim.CloseSerialPort()
            logger.info('Closed FES serial port')

    def classify(self,
                 decoder,
                 true_label,
                 title_text,
                 bar_dirs,
                 state='start',
                 prob_history=None):
        """
        Run a single trial
        """
        def list2string(vec, fmt, sep=' '):
            return sep.join(fmt % x for x in vec)

        true_label_index = bar_dirs.index(true_label)
        self.tm_trigger.reset()
        if self.bar_bias is not None:
            bias_idx = bar_dirs.index(self.bar_bias[0])

        if self.logf is not None:
            self.logf.write('True label: %s\n' % true_label)

        tm_classify = Timer(autoreset=True)
        self.stimo_timer = Timer()
        while True:
            self.tm_display.sleep_atleast(self.refresh_delay)
            self.tm_display.reset()
            if state == 'start' and self.tm_trigger.sec(
            ) > self.cfg.TIMINGS['INIT']:
                state = 'gap_s'
                if self.cfg.TRIALS_PAUSE:
                    self.viz.put_text('Press any key')
                    self.viz.update()
                    key = cv2.waitKeyEx()
                    if key == KEYS['esc']:
                        return
                self.viz.fill()
                self.tm_trigger.reset()
                self.trigger.signal(self.tdef.INIT)

            elif state == 'gap_s':
                if self.cfg.TIMINGS['GAP'] > 0:
                    self.viz.put_text(title_text)
                state = 'gap'
                self.tm_trigger.reset()

            elif state == 'gap' and self.tm_trigger.sec(
            ) > self.cfg.TIMINGS['GAP']:
                state = 'cue'
                self.viz.fill()
                self.viz.draw_cue()
                self.viz.glass_draw_cue()
                self.trigger.signal(self.tdef.CUE)
                self.tm_trigger.reset()

            elif state == 'cue' and self.tm_trigger.sec(
            ) > self.cfg.TIMINGS['READY']:
                state = 'dir_r'

                if self.cfg.SHOW_CUE is True:
                    if self.cfg.FEEDBACK_TYPE == 'BAR':
                        self.viz.move(true_label,
                                      100,
                                      overlay=False,
                                      barcolor='G')
                    elif self.cfg.FEEDBACK_TYPE == 'BODY':
                        self.viz.put_text(DIRS[true_label], 'R')
                    if true_label == 'L':  # left
                        self.trigger.signal(self.tdef.LEFT_READY)
                    elif true_label == 'R':  # right
                        self.trigger.signal(self.tdef.RIGHT_READY)
                    elif true_label == 'U':  # up
                        self.trigger.signal(self.tdef.UP_READY)
                    elif true_label == 'D':  # down
                        self.trigger.signal(self.tdef.DOWN_READY)
                    elif true_label == 'B':  # both hands
                        self.trigger.signal(self.tdef.BOTH_READY)
                    else:
                        raise RuntimeError('Unknown direction %s' % true_label)
                self.tm_trigger.reset()
                '''
                if self.cfg.FEEDBACK_TYPE == 'BODY':
                    self.viz.set_pc_feedback(False)
                self.viz.move(true_label, 100, overlay=False, barcolor='G')

                if self.cfg.FEEDBACK_TYPE == 'BODY':
                    self.viz.set_pc_feedback(True)
                    if self.cfg.SHOW_CUE is True:
                        self.viz.put_text(dirs[true_label], 'R')

                if true_label == 'L':  # left
                    self.trigger.signal(self.tdef.LEFREADY)
                elif true_label == 'R':  # right
                    self.trigger.signal(self.tdef.RIGHT_READY)
                elif true_label == 'U':  # up
                    self.trigger.signal(self.tdef.UP_READY)
                elif true_label == 'D':  # down
                    self.trigger.signal(self.tdef.DOWN_READY)
                elif true_label == 'B':  # both hands
                    self.trigger.signal(self.tdef.BOTH_READY)
                else:
                    raise RuntimeError('Unknown direction %s' % true_label)
                self.tm_trigger.reset()
                '''
            elif state == 'dir_r' and self.tm_trigger.sec(
            ) > self.cfg.TIMINGS['DIR_CUE']:
                self.viz.fill()
                self.viz.draw_cue()
                self.viz.glass_draw_cue()
                state = 'dir'

                # initialize bar scores
                bar_label = bar_dirs[0]
                bar_score = 0
                probs = [1.0 / len(bar_dirs)] * len(bar_dirs)
                self.viz.move(bar_label, bar_score, overlay=False)
                probs_acc = np.zeros(len(probs))

                if true_label == 'L':  # left
                    self.trigger.signal(self.tdef.LEFT_GO)
                elif true_label == 'R':  # right
                    self.trigger.signal(self.tdef.RIGHT_GO)
                elif true_label == 'U':  # up
                    self.trigger.signal(self.tdef.UP_GO)
                elif true_label == 'D':  # down
                    self.trigger.signal(self.tdef.DOWN_GO)
                elif true_label == 'B':  # both
                    self.trigger.signal(self.tdef.BOTH_GO)
                else:
                    raise RuntimeError('Unknown truedirection %s' % true_label)

                self.tm_watchdog.reset()
                self.tm_trigger.reset()

            elif state == 'dir':
                if self.tm_trigger.sec() > self.cfg.TIMINGS['CLASSIFY'] or (
                        self.premature_end and bar_score >= 100):
                    if not hasattr(
                            self.cfg,
                            'SHOW_RESULT') or self.cfg.SHOW_RESULT is True:
                        # show classfication result
                        if self.cfg.WITH_STIMO is True:
                            if self.cfg.STIMO_FULLGAIT_CYCLE is not None and bar_label == 'U':
                                res_color = 'G'
                            elif self.cfg.TRIALS_RETRY is False or bar_label == true_label:
                                res_color = 'G'
                            else:
                                res_color = 'Y'
                        else:
                            res_color = 'Y'
                        if self.cfg.FEEDBACK_TYPE == 'BODY':
                            self.viz.move(bar_label,
                                          bar_score,
                                          overlay=False,
                                          barcolor=res_color,
                                          caption=DIRS[bar_label],
                                          caption_color=res_color)
                        else:
                            self.viz.move(bar_label,
                                          100,
                                          overlay=False,
                                          barcolor=res_color)
                    else:
                        if self.cfg.FEEDBACK_TYPE == 'BODY':
                            self.viz.move(bar_label,
                                          bar_score,
                                          overlay=False,
                                          barcolor=res_color,
                                          caption='TRIAL END',
                                          caption_color=res_color)
                        else:
                            self.viz.move(bar_label,
                                          0,
                                          overlay=False,
                                          barcolor=res_color)
                    self.trigger.signal(self.tdef.FEEDBACK)

                    # STIMO
                    if self.cfg.WITH_STIMO is True and self.cfg.STIMO_CONTINUOUS is False:
                        if self.cfg.STIMO_FULLGAIT_CYCLE is not None:
                            if bar_label == 'U':
                                self.ser.write(
                                    self.cfg.STIMO_FULLGAIT_PATTERN[0])
                                logger.info('STIMO: Sent 1')
                                time.sleep(self.cfg.STIMO_FULLGAIT_CYCLE)
                                self.ser.write(
                                    self.cfg.STIMO_FULLGAIT_PATTERN[1])
                                logger.info('STIMO: Sent 2')
                                time.sleep(self.cfg.STIMO_FULLGAIT_CYCLE)
                        elif self.cfg.TRIALS_RETRY is False or bar_label == true_label:
                            if bar_label == 'L':
                                self.ser.write(b'1')
                                logger.info('STIMO: Sent 1')
                            elif bar_label == 'R':
                                self.ser.write(b'2')
                                logger.info('STIMO: Sent 2')

                    # FES event mode mode
                    if self.cfg.WITH_FES is True and self.cfg.FES_CONTINUOUS is False:
                        if bar_label == 'L':
                            stim_code = [0, 30, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)
                            logger.info('FES: Sent Left')
                            time.sleep(0.5)
                            stim_code = [0, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)

                        elif bar_label == 'R':
                            stim_code = [30, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)
                            time.sleep(0.5)
                            logger.info('FES: Sent Right')
                            stim_code = [0, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)

                    if self.cfg.DEBUG_PROBS:
                        msg = 'DEBUG: Accumulated probabilities = %s' % list2string(
                            probs_acc, '%.3f')
                        logger.info(msg)
                        if self.logf is not None:
                            self.logf.write(msg + '\n')
                    if self.logf is not None:
                        self.logf.write('%s detected as %s (%d)\n\n' %
                                        (true_label, bar_label, bar_score))
                        self.logf.flush()

                    # end of trial
                    state = 'feedback'
                    self.tm_trigger.reset()
                else:
                    # classify
                    probs_new = decoder.get_prob_smooth_unread()
                    if probs_new is None:
                        if self.tm_watchdog.sec() > 3:
                            logger.warning(
                                'No classification being done. Are you receiving data streams?'
                            )
                            self.tm_watchdog.reset()
                    else:
                        self.tm_watchdog.reset()

                        if prob_history is not None:
                            prob_history[true_label].append(
                                probs_new[true_label_index])

                        probs_acc += np.array(probs_new)
                        '''
                        New decoder: already smoothed by the decoder so bias after.
                        '''
                        probs = list(probs_new)
                        if self.bar_bias is not None:
                            probs[bias_idx] += self.bar_bias[1]
                            newsum = sum(probs)
                            probs = [p / newsum for p in probs]
                        '''
                        # Method 2: bias and smoothen
                        if self.bar_bias is not None:
                            # print('BEFORE: %.3f %.3f'% (probs_new[0], probs_new[1]) )
                            probs_new[bias_idx] += self.bar_bias[1]
                            newsum = sum(probs_new)
                            probs_new = [p / newsum for p in probs_new]
                        # print('AFTER: %.3f %.3f'% (probs_new[0], probs_new[1]) )
                        for i in range(len(probs_new)):
                            probs[i] = probs[i] * self.alpha_old + probs_new[i] * self.alpha_new
                        '''
                        ''' Original method
                        # Method 1: smoothen and bias
                        for i in range( len(probs_new) ):
                            probs[i] = probs[i] * self.alpha_old + probs_new[i] * self.alpha_new
                        # bias bar
                        if self.bar_bias is not None:
                            probs[bias_idx] += self.bar_bias[1]
                            newsum = sum(probs)
                            probs = [p/newsum for p in probs]
                        '''

                        # determine the direction
                        # TODO: np.argmax(probs)
                        max_pidx = np.argmax(probs)
                        max_label = bar_dirs[max_pidx]

                        if self.cfg.POSITIVE_FEEDBACK is False or \
                                (self.cfg.POSITIVE_FEEDBACK and true_label == max_label):
                            dx = probs[max_pidx]
                            if max_label == 'R':
                                dx *= self.bar_step_right
                            elif max_label == 'L':
                                dx *= self.bar_step_left
                            elif max_label == 'U':
                                dx *= self.bar_step_up
                            elif max_label == 'D':
                                dx *= self.bar_step_down
                            elif max_label == 'B':
                                dx *= self.bar_step_both
                            else:
                                logger.debug('Direction %s using bar step %d' %
                                             (max_label, self.bar_step_left))
                                dx *= self.bar_step_left

                            # slow start
                            selected = self.cfg.BAR_SLOW_START['selected']
                            if self.cfg.BAR_SLOW_START[
                                    selected] and self.tm_trigger.sec(
                                    ) < self.cfg.BAR_SLOW_START[selected]:
                                dx *= self.tm_trigger.sec(
                                ) / self.cfg.BAR_SLOW_START[selected][0]

                            # add likelihoods
                            if max_label == bar_label:
                                bar_score += dx
                            else:
                                bar_score -= dx
                                # change of direction
                                if bar_score < 0:
                                    bar_score = -bar_score
                                    bar_label = max_label
                            bar_score = int(bar_score)
                            if bar_score > 100:
                                bar_score = 100
                            if self.cfg.FEEDBACK_TYPE == 'BODY':
                                if self.cfg.SHOW_CUE:
                                    self.viz.move(bar_label,
                                                  bar_score,
                                                  overlay=False,
                                                  caption=DIRS[true_label],
                                                  caption_color='G')
                                else:
                                    self.viz.move(bar_label,
                                                  bar_score,
                                                  overlay=False)
                            else:
                                self.viz.move(bar_label,
                                              bar_score,
                                              overlay=False)

                            # send the confidence value continuously
                            if self.cfg.WITH_STIMO and self.cfg.STIMO_CONTINUOUS:
                                if self.stimo_timer.sec(
                                ) >= self.cfg.STIMO_COOLOFF:
                                    if bar_label == 'U':
                                        stimo_code = bar_score
                                    else:
                                        stimo_code = 0
                                    self.ser.write(bytes([stimo_code]))
                                    logger.info('Sent STIMO code %d' %
                                                stimo_code)
                                    self.stimo_timer.reset()
                            # with FES
                            if self.cfg.WITH_FES is True and self.cfg.FES_CONTINUOUS is True:
                                if self.stimo_timer.sec(
                                ) >= self.cfg.STIMO_COOLOFF:
                                    if bar_label == 'L':
                                        stim_code = [
                                            bar_score, 0, 0, 0, 0, 0, 0, 0
                                        ]
                                    else:
                                        stim_code = [
                                            0, bar_score, 0, 0, 0, 0, 0, 0
                                        ]
                                    self.stim.UpdateChannelSettings(stim_code)
                                    logger.info('Sent FES code %d' % bar_score)
                                    self.stimo_timer.reset()

                        if self.cfg.DEBUG_PROBS:
                            if self.bar_bias is not None:
                                biastxt = '[Bias=%s%.3f]  ' % (
                                    self.bar_bias[0], self.bar_bias[1])
                            else:
                                biastxt = ''
                            msg = '%s%s  prob %s   acc %s   bar %s%d  (%.1f ms)' % \
                                  (biastxt, bar_dirs, list2string(probs_new, '%.2f'), list2string(probs, '%.2f'),
                                   bar_label, bar_score, tm_classify.msec())
                            logger.info(msg)
                            if self.logf is not None:
                                self.logf.write(msg + '\n')

            elif state == 'feedback' and self.tm_trigger.sec(
            ) > self.cfg.TIMINGS['FEEDBACK']:
                self.trigger.signal(self.tdef.BLANK)
                if self.cfg.FEEDBACK_TYPE == 'BODY':
                    state = 'return'
                    self.tm_trigger.reset()
                else:
                    state = 'gap_s'
                    self.viz.fill()
                    self.viz.update()
                    return bar_label

            elif state == 'return':
                self.viz.set_glass_feedback(False)
                if self.cfg.WITH_STIMO:
                    self.viz.move(bar_label,
                                  bar_score,
                                  overlay=False,
                                  barcolor='B')
                else:
                    self.viz.move(bar_label,
                                  bar_score,
                                  overlay=False,
                                  barcolor='Y')
                self.viz.set_glass_feedback(True)
                bar_score -= 5
                if bar_score <= 0:
                    state = 'gap_s'
                    self.viz.fill()
                    self.viz.update()
                    return bar_label

            self.viz.update()
            key = cv2.waitKeyEx(1)
            if key == KEYS['esc']:
                return
            elif key == KEYS['space']:
                dx = 0
                bar_score = 0
                probs = [1.0 / len(bar_dirs)] * len(bar_dirs)
                self.viz.move(bar_dirs[0], bar_score, overlay=False)
                self.viz.update()
                logger.info('probs and dx reset.')
                self.tm_trigger.reset()
            elif key in ARROW_KEYS and ARROW_KEYS[key] in bar_dirs:
                # change bias on the fly
                if self.bar_bias is None:
                    self.bar_bias = [ARROW_KEYS[key], BIAS_INCREMENT]
                else:
                    if ARROW_KEYS[key] == self.bar_bias[0]:
                        self.bar_bias[1] += BIAS_INCREMENT
                    elif self.bar_bias[1] >= BIAS_INCREMENT:
                        self.bar_bias[1] -= BIAS_INCREMENT
                    else:
                        self.bar_bias = [ARROW_KEYS[key], BIAS_INCREMENT]
                if self.bar_bias[1] == 0:
                    self.bar_bias = None
                else:
                    bias_idx = bar_dirs.index(self.bar_bias[0])
예제 #8
0
class GlassControl(object):
    """
    Controls the Glass UI over a TCP socket.

    Commands are short text messages terminated by a newline. Every
    drawing method is a no-op when the object was created in mock mode.

    Parameters
    ----------
    mock : bool
        Set to True if you don't have a Glass; no connection is opened and
        no message is sent.
    """

    # Message sent by fill() for each supported color code.
    _FILL_MESSAGES = {None: 'F0', 'R': 'F1', 'G': 'F2', 'B': 'F3', 'K': 'F4'}

    # Colors accepted by fullbar_color().
    _BAR_COLORS = ('R', 'G', 'B', 'Y')

    def __init__(self, mock=False):
        self.BUFFER_SIZE = 1024
        self.last_dir = 'L'
        self.timer = Timer(autoreset=True)
        self.mock = mock
        # None until connect() succeeds, so disconnect() is always safe.
        self.socket = None
        if self.mock:
            self.print('Using a fake, mock Glass control object.')

    def print(self, *args):
        """Print args prefixed with '[GlassControl] ' (bare newline if empty)."""
        if args:
            print('[GlassControl] ', end='')
        print(*args)

    def connect(self, ip, port):
        """
        Open the TCP connection to the Glass.

        When ip is '127.0.0.1', 'adb forward' is run first so that the
        connection goes through USB.

        Parameters
        ----------
        ip : str
            Address of the Glass.
        port : int
            TCP port of the Glass.

        Notes
        -----
        On connection failure the error is printed and the process exits
        through sys.exit(-1).
        """
        if self.mock:
            return
        self.ip = ip
        self.port = port

        # Networking via USB if IP=127.0.0.1
        if ip == '127.0.0.1':
            exe = 'adb forward tcp:%d tcp:%d' % (port, port)
            self.print(exe)
            os.system(exe)
            time.sleep(0.2)
        self.print('Connecting to %s:%d' % (ip, port))
        sock = None
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect((self.ip, self.port))
        except Exception:
            # 'except Exception' (not a bare except) lets Ctrl+C propagate.
            self.print('* ERROR connecting to Glass. The error was:')
            self.print(sys.exc_info()[0], sys.exc_info()[1])
            if sock is not None:
                # Do not leak the half-open socket.
                sock.close()
            sys.exit(-1)
        self.socket = sock

    def disconnect(self):
        """Close the connection; safe to call when not connected."""
        if self.mock:
            return
        self.print('Disconnecting from Glass')
        sock = getattr(self, 'socket', None)
        if sock is not None:
            sock.close()
        self.socket = None

    def send_byte(self, msg):
        """Send msg followed by a newline, encoded as UTF-8."""
        self.socket.sendall((msg + '\n').encode('utf-8'))

    def send_msg(self, msg, wait=True):
        """
        Send a message to the Glass

        Glass requires some delay after when the last command was sent.
        This function will be blocked until minimum this delay is satisfied.
        Set wait=False to force sending message, but the msg is likely to be ignored.

        If sending fails, the connection is closed, re-opened after 2 seconds
        and the message is sent once more; a second failure is re-raised
        unchanged.
        """
        if wait:
            # Wait only if the time hasn't passed enough
            self.timer.sleep_atleast(0.033)  # 30 Hz
        if self.mock:
            return
        try:
            self.send_byte(msg)
            return
        except Exception:
            self.print('* ERROR: Glass communication failed! Attempting to reconnect again.')

        self.disconnect()
        time.sleep(2)
        # Let's try again
        self.connect(self.ip, self.port)
        try:
            self.send_byte(msg)
        except Exception:
            self.print('Sorry, cannot fix the problem. I give up.')
            # Bare raise keeps the original exception type and traceback.
            raise

    def clear(self):
        """Show empty bars."""
        if self.mock:
            return
        self.send_msg('C')

    def draw_cross(self):
        """Currently identical to clear()."""
        if self.mock:
            return
        self.clear()

    def move_bar(self, new_dir, amount, overlay=False):
        """
        Set the bar of direction new_dir to amount.

        Only one direction at a time: unless overlay is set, the bar of the
        previously used direction is reset to 0 when the direction changes.
        """
        if self.mock:
            return
        if overlay is False and self.last_dir != new_dir:
            self.send_msg('%s0' % self.last_dir)
        self.send_msg('%s%d' % (new_dir, amount))
        self.last_dir = new_dir

    def fill(self, color=None):
        """
        Fill screen with a solid color (None, 'R', 'G', 'B', 'K').

        Any other value is silently ignored.
        """
        if self.mock:
            return
        msg = self._FILL_MESSAGES.get(color)
        if msg is not None:
            self.send_msg(msg)

    def fullbar_color(self, color):
        """Send the full-bar color command; color is one of 'R', 'G', 'B', 'Y'."""
        if color not in self._BAR_COLORS:
            print('**** UNSUPPORTED GLASS BAR COLOR ****')
        else:
            self.send_msg('B' + color)