Esempio n. 1
0
    def classify(self,
                 decoder,
                 true_label,
                 title_text,
                 bar_dirs,
                 state='start',
                 prob_history=None):
        """
        Run a single trial.

        The trial is a state machine:
        'start' -> 'gap_s' -> 'gap' -> 'cue' -> 'dir_r' -> 'dir' -> 'feedback'
        (-> 'return' when FEEDBACK_TYPE is 'BODY'). During 'dir' the feedback
        bar is moved according to the decoder output, and optionally STIMO /
        FES devices are driven.

        Parameters
        ----------
        decoder : object
            Must provide get_prob_smooth_unread(), returning a sequence of
            probabilities (ordered like bar_dirs) or None when no new output
            is available.
        true_label : str
            Direction the subject is asked to perform ('L', 'R', 'U', 'D', 'B').
        title_text : str
            Text shown during the gap period (if TIMINGS['GAP'] > 0).
        bar_dirs : list of str
            Direction labels, in the order of the decoder probabilities.
        state : str
            State the trial starts from.
        prob_history : dict of lists or None
            If given, the probability of true_label is appended to
            prob_history[true_label] for every new decoder output.

        Returns
        -------
        The label the bar ended on, or None if the trial was aborted with ESC.
        """
        # Trigger-definition name prefix for every supported direction,
        # e.g. 'L' -> self.tdef.LEFT_READY / self.tdef.LEFT_GO.
        trigger_prefix = {
            'L': 'LEFT',  # left
            'R': 'RIGHT',  # right
            'U': 'UP',  # up
            'D': 'DOWN',  # down
            'B': 'BOTH',  # both hands
        }

        true_label_index = bar_dirs.index(true_label)
        self.tm_trigger.reset()
        if self.bar_bias is not None:
            bias_idx = bar_dirs.index(self.bar_bias[0])

        if self.logf is not None:
            self.logf.write('True label: %s\n' % true_label)

        tm_classify = qc.Timer(autoreset=True)
        self.stimo_timer = qc.Timer()
        while True:
            # Throttle the loop to the display refresh delay.
            self.tm_display.sleep_atleast(self.refresh_delay)
            self.tm_display.reset()
            if state == 'start' and self.tm_trigger.sec() > self.cfg.TIMINGS['INIT']:
                state = 'gap_s'
                if self.cfg.TRIALS_PAUSE:
                    self.viz.put_text('Press any key')
                    self.viz.update()
                    key = cv2.waitKeyEx()
                    if key == KEYS['esc']:
                        return
                self.viz.fill()
                self.tm_trigger.reset()
                self.trigger.signal(self.tdef.INIT)

            elif state == 'gap_s':
                if self.cfg.TIMINGS['GAP'] > 0:
                    self.viz.put_text(title_text)
                state = 'gap'
                self.tm_trigger.reset()

            elif state == 'gap' and self.tm_trigger.sec() > self.cfg.TIMINGS['GAP']:
                state = 'cue'
                self.viz.fill()
                self.viz.draw_cue()
                self.viz.glass_draw_cue()
                self.trigger.signal(self.tdef.CUE)
                self.tm_trigger.reset()

            elif state == 'cue' and self.tm_trigger.sec() > self.cfg.TIMINGS['READY']:
                state = 'dir_r'

                if self.cfg.SHOW_CUE is True:
                    if self.cfg.FEEDBACK_TYPE == 'BAR':
                        self.viz.move(true_label,
                                      100,
                                      overlay=False,
                                      barcolor='G')
                    elif self.cfg.FEEDBACK_TYPE == 'BODY':
                        self.viz.put_text(DIRS[true_label], 'R')
                    if true_label not in trigger_prefix:
                        raise RuntimeError('Unknown direction %s' % true_label)
                    self.trigger.signal(
                        getattr(self.tdef, trigger_prefix[true_label] + '_READY'))
                self.tm_trigger.reset()

            elif state == 'dir_r' and self.tm_trigger.sec() > self.cfg.TIMINGS['DIR_CUE']:
                self.viz.fill()
                self.viz.draw_cue()
                self.viz.glass_draw_cue()
                state = 'dir'

                # initialize bar scores
                bar_label = bar_dirs[0]
                bar_score = 0
                probs = [1.0 / len(bar_dirs)] * len(bar_dirs)
                self.viz.move(bar_label, bar_score, overlay=False)
                probs_acc = np.zeros(len(probs))

                if true_label not in trigger_prefix:
                    raise RuntimeError('Unknown truedirection %s' % true_label)
                self.trigger.signal(
                    getattr(self.tdef, trigger_prefix[true_label] + '_GO'))

                self.tm_watchdog.reset()
                self.tm_trigger.reset()

            elif state == 'dir':
                if self.tm_trigger.sec() > self.cfg.TIMINGS['CLASSIFY'] or (
                        self.premature_end and bar_score >= 100):
                    # ---- end of the classification period ----
                    if not hasattr(self.cfg, 'SHOW_RESULT') or self.cfg.SHOW_RESULT is True:
                        # show classification result
                        if self.cfg.WITH_STIMO is True:
                            if self.cfg.STIMO_FULLGAIT_CYCLE is not None and bar_label == 'U':
                                res_color = 'G'
                            elif self.cfg.TRIALS_RETRY is False or bar_label == true_label:
                                res_color = 'G'
                            else:
                                res_color = 'Y'
                        else:
                            res_color = 'Y'
                        if self.cfg.FEEDBACK_TYPE == 'BODY':
                            self.viz.move(bar_label,
                                          bar_score,
                                          overlay=False,
                                          barcolor=res_color,
                                          caption=DIRS[bar_label],
                                          caption_color=res_color)
                        else:
                            self.viz.move(bar_label,
                                          100,
                                          overlay=False,
                                          barcolor=res_color)
                    else:
                        # res_color is only computed when the result is shown,
                        # so it must be assigned here before use; use the
                        # neutral color since no result is displayed.
                        res_color = 'Y'
                        if self.cfg.FEEDBACK_TYPE == 'BODY':
                            self.viz.move(bar_label,
                                          bar_score,
                                          overlay=False,
                                          barcolor=res_color,
                                          caption='TRIAL END',
                                          caption_color=res_color)
                        else:
                            self.viz.move(bar_label,
                                          0,
                                          overlay=False,
                                          barcolor=res_color)
                    self.trigger.signal(self.tdef.FEEDBACK)

                    # STIMO event mode
                    if self.cfg.WITH_STIMO is True and self.cfg.STIMO_CONTINUOUS is False:
                        if self.cfg.STIMO_FULLGAIT_CYCLE is not None:
                            if bar_label == 'U':
                                self.ser.write(self.cfg.STIMO_FULLGAIT_PATTERN[0])
                                logger.info('STIMO: Sent 1')
                                time.sleep(self.cfg.STIMO_FULLGAIT_CYCLE)
                                self.ser.write(self.cfg.STIMO_FULLGAIT_PATTERN[1])
                                logger.info('STIMO: Sent 2')
                                time.sleep(self.cfg.STIMO_FULLGAIT_CYCLE)
                        elif self.cfg.TRIALS_RETRY is False or bar_label == true_label:
                            if bar_label == 'L':
                                self.ser.write(b'1')
                                logger.info('STIMO: Sent 1')
                            elif bar_label == 'R':
                                self.ser.write(b'2')
                                logger.info('STIMO: Sent 2')

                    # FES event mode: a 0.5 s pulse on the channel of the detected side
                    if self.cfg.WITH_FES is True and self.cfg.FES_CONTINUOUS is False:
                        if bar_label == 'L':
                            stim_code = [0, 30, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)
                            logger.info('FES: Sent Left')
                            time.sleep(0.5)
                            stim_code = [0, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)

                        elif bar_label == 'R':
                            stim_code = [30, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)
                            time.sleep(0.5)
                            logger.info('FES: Sent Right')
                            stim_code = [0, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)

                    if self.cfg.DEBUG_PROBS:
                        msg = 'DEBUG: Accumulated probabilities = %s' % qc.list2string(
                            probs_acc, '%.3f')
                        logger.info(msg)
                        if self.logf is not None:
                            self.logf.write(msg + '\n')
                    if self.logf is not None:
                        self.logf.write('%s detected as %s (%d)\n\n' %
                                        (true_label, bar_label, bar_score))
                        self.logf.flush()

                    # end of trial
                    state = 'feedback'
                    self.tm_trigger.reset()
                else:
                    # classify
                    probs_new = decoder.get_prob_smooth_unread()
                    if probs_new is None:
                        if self.tm_watchdog.sec() > 3:
                            logger.warning(
                                'No classification being done. Are you receiving data streams?'
                            )
                            self.tm_watchdog.reset()
                    else:
                        self.tm_watchdog.reset()

                        if prob_history is not None:
                            prob_history[true_label].append(
                                probs_new[true_label_index])

                        probs_acc += np.array(probs_new)

                        # The decoder output is already smoothed, so the bias
                        # is applied afterwards and the result renormalized.
                        probs = list(probs_new)
                        if self.bar_bias is not None:
                            probs[bias_idx] += self.bar_bias[1]
                            newsum = sum(probs)
                            probs = [p / newsum for p in probs]

                        # determine the direction
                        # TODO: np.argmax(probs)
                        max_pidx = qc.get_index_max(probs)
                        max_label = bar_dirs[max_pidx]

                        if self.cfg.POSITIVE_FEEDBACK is False or \
                                (self.cfg.POSITIVE_FEEDBACK and true_label == max_label):
                            dx = probs[max_pidx]
                            if max_label == 'R':
                                dx *= self.bar_step_right
                            elif max_label == 'L':
                                dx *= self.bar_step_left
                            elif max_label == 'U':
                                dx *= self.bar_step_up
                            elif max_label == 'D':
                                dx *= self.bar_step_down
                            elif max_label == 'B':
                                dx *= self.bar_step_both
                            else:
                                logger.debug('Direction %s using bar step %d' %
                                             (max_label, self.bar_step_left))
                                dx *= self.bar_step_left

                            # slow start: scale the step linearly during the
                            # first slow_start[0] seconds of the period.
                            selected = self.cfg.BAR_SLOW_START['selected']
                            slow_start = self.cfg.BAR_SLOW_START[selected]
                            if slow_start and self.tm_trigger.sec() < slow_start[0]:
                                dx *= self.tm_trigger.sec() / slow_start[0]

                            # add likelihoods
                            if max_label == bar_label:
                                bar_score += dx
                            else:
                                bar_score -= dx
                                # change of direction
                                if bar_score < 0:
                                    bar_score = -bar_score
                                    bar_label = max_label
                            bar_score = min(int(bar_score), 100)
                            if self.cfg.FEEDBACK_TYPE == 'BODY' and self.cfg.SHOW_CUE:
                                self.viz.move(bar_label,
                                              bar_score,
                                              overlay=False,
                                              caption=DIRS[true_label],
                                              caption_color='G')
                            else:
                                self.viz.move(bar_label,
                                              bar_score,
                                              overlay=False)

                            # send the confidence value continuously
                            if self.cfg.WITH_STIMO and self.cfg.STIMO_CONTINUOUS:
                                if self.stimo_timer.sec() >= self.cfg.STIMO_COOLOFF:
                                    if bar_label == 'U':
                                        stimo_code = bar_score
                                    else:
                                        stimo_code = 0
                                    self.ser.write(bytes([stimo_code]))
                                    logger.info('Sent STIMO code %d' %
                                                stimo_code)
                                    self.stimo_timer.reset()
                            # with FES
                            if self.cfg.WITH_FES is True and self.cfg.FES_CONTINUOUS is True:
                                if self.stimo_timer.sec() >= self.cfg.STIMO_COOLOFF:
                                    if bar_label == 'L':
                                        stim_code = [
                                            bar_score, 0, 0, 0, 0, 0, 0, 0
                                        ]
                                    else:
                                        stim_code = [
                                            0, bar_score, 0, 0, 0, 0, 0, 0
                                        ]
                                    self.stim.UpdateChannelSettings(stim_code)
                                    logger.info('Sent FES code %d' % bar_score)
                                    self.stimo_timer.reset()

                        if self.cfg.DEBUG_PROBS:
                            if self.bar_bias is not None:
                                biastxt = '[Bias=%s%.3f]  ' % (
                                    self.bar_bias[0], self.bar_bias[1])
                            else:
                                biastxt = ''
                            msg = '%s%s  prob %s   acc %s   bar %s%d  (%.1f ms)' % \
                                  (biastxt, bar_dirs, qc.list2string(probs_new, '%.2f'), qc.list2string(probs, '%.2f'),
                                   bar_label, bar_score, tm_classify.msec())
                            logger.info(msg)
                            if self.logf is not None:
                                self.logf.write(msg + '\n')

            elif state == 'feedback' and self.tm_trigger.sec() > self.cfg.TIMINGS['FEEDBACK']:
                self.trigger.signal(self.tdef.BLANK)
                if self.cfg.FEEDBACK_TYPE == 'BODY':
                    state = 'return'
                    self.tm_trigger.reset()
                else:
                    state = 'gap_s'
                    self.viz.fill()
                    self.viz.update()
                    return bar_label

            elif state == 'return':
                # Animate the bar back to zero, 5 units per refresh.
                self.viz.set_glass_feedback(False)
                if self.cfg.WITH_STIMO:
                    self.viz.move(bar_label,
                                  bar_score,
                                  overlay=False,
                                  barcolor='B')
                else:
                    self.viz.move(bar_label,
                                  bar_score,
                                  overlay=False,
                                  barcolor='Y')
                self.viz.set_glass_feedback(True)
                bar_score -= 5
                if bar_score <= 0:
                    state = 'gap_s'
                    self.viz.fill()
                    self.viz.update()
                    return bar_label

            self.viz.update()
            key = cv2.waitKeyEx(1)
            if key == KEYS['esc']:
                return
            elif key == KEYS['space']:
                bar_score = 0
                probs = [1.0 / len(bar_dirs)] * len(bar_dirs)
                self.viz.move(bar_dirs[0], bar_score, overlay=False)
                self.viz.update()
                logger.info('probs and dx reset.')
                self.tm_trigger.reset()
            elif key in ARROW_KEYS and ARROW_KEYS[key] in bar_dirs:
                # change bias on the fly
                if self.bar_bias is None:
                    self.bar_bias = [ARROW_KEYS[key], BIAS_INCREMENT]
                else:
                    if ARROW_KEYS[key] == self.bar_bias[0]:
                        self.bar_bias[1] += BIAS_INCREMENT
                    elif self.bar_bias[1] >= BIAS_INCREMENT:
                        self.bar_bias[1] -= BIAS_INCREMENT
                    else:
                        self.bar_bias = [ARROW_KEYS[key], BIAS_INCREMENT]
                if self.bar_bias[1] == 0:
                    self.bar_bias = None
                else:
                    bias_idx = bar_dirs.index(self.bar_bias[0])
Esempio n. 2
0
def stream_player(server_name,
                  fif_file,
                  chunk_size,
                  auto_restart=True,
                  wait_start=True,
                  repeat=float('inf'),
                  high_resolution=False,
                  trigger_file=None):
    """
    Replay a recorded fif file as an LSL stream.

    Input
    =====
    server_name: LSL server name.
    fif_file: fif file to replay.
    chunk_size: number of samples to send at once (usually 16-32 is good enough).
    auto_restart: play from beginning again after reaching the end.
    wait_start: wait for user to start in the beginning.
    repeat: number of loops to play (default: play forever).
    high_resolution: use perf_counter() instead of sleep() for higher time resolution
                     but uses much more cpu due to polling.
    trigger_file: used to convert event numbers into event strings for readability.

    Note: Run neurodecode.set_log_level('DEBUG') to print out the relative time stamps since started.

    Raises
    ======
    RuntimeError: if fif_file could not be loaded.
    """
    raw, events = pu.load_raw(fif_file)
    # Validate before dereferencing raw.
    if raw is None:
        raise RuntimeError('Error while loading %s' % fif_file)

    sfreq = raw.info['sfreq']  # sampling frequency
    n_channels = len(raw.ch_names)  # number of channels
    if trigger_file is not None:
        tdef = trigger_def(trigger_file)
    try:
        event_ch = raw.ch_names.index('TRIGGER')
    except ValueError:
        event_ch = None

    logger.info_green('Successfully loaded %s' % fif_file)
    logger.info('Server name: %s' % server_name)
    logger.info('Sampling frequency %.3f Hz' % sfreq)
    logger.info('Number of channels : %d' % n_channels)
    logger.info('Chunk size : %d' % chunk_size)
    for i, ch in enumerate(raw.ch_names):
        logger.info('%d %s' % (i, ch))
    logger.info('Trigger channel : %s' % event_ch)

    # set server information
    sinfo = pylsl.StreamInfo(server_name, channel_count=n_channels, channel_format='float32',
                             nominal_srate=sfreq, type='EEG', source_id=server_name)
    desc = sinfo.desc()
    channel_desc = desc.append_child("channels")
    for ch in raw.ch_names:
        channel_desc.append_child('channel').append_child_value('label', str(ch))\
            .append_child_value('type', 'EEG').append_child_value('unit', 'microvolts')
    desc.append_child('amplifier').append_child('settings').append_child_value(
        'is_slave', 'false')
    desc.append_child('acquisition').append_child_value(
        'manufacturer',
        'NeuroDecode').append_child_value('serial_number', 'N/A')
    outlet = pylsl.StreamOutlet(sinfo, chunk_size=chunk_size)

    if wait_start:
        input('Press Enter to start streaming.')
    logger.info('Streaming started')

    # perf_counter() is used with busy-wait polling, time() with sleep().
    clock = time.perf_counter if high_resolution else time.time

    data_all = raw._data  # (n_channels, n_samples)
    n_samples = data_all.shape[1]
    idx_chunk = 0
    t_chunk = chunk_size / sfreq
    finished = False
    t_start = clock()
    t_log_start = time.perf_counter()  # reference for the debug time stamps

    # start streaming
    played = 0  # number of completed passes through the recording
    while played < repeat:
        idx_current = idx_chunk * chunk_size
        chunk = data_all[:, idx_current:idx_current + chunk_size]
        data = chunk.transpose().tolist()
        if idx_current >= n_samples - chunk_size:
            finished = True

        # Scheduled send time of this chunk, relative to the pass start.
        t_send = t_start + idx_chunk * t_chunk
        if high_resolution:
            # if a resolution over 2 KHz is needed
            while time.perf_counter() < t_send:
                pass
        else:
            # time.sleep() can have 500 us resolution using the tweak tool provided.
            t_wait = t_send - time.time()
            if t_wait > 0.001:
                time.sleep(t_wait)
        outlet.push_chunk(data)
        logger.debug('[%8.3fs] sent %d samples (LSL %8.3f)' %
                     (time.perf_counter() - t_log_start, len(data), pylsl.local_clock()))

        if event_ch is not None:
            event_values = set(chunk[event_ch]) - {0}
            if event_values:
                if trigger_file is None:
                    logger.info('Events: %s' % event_values)
                else:
                    for event in event_values:
                        if event in tdef.by_value:
                            logger.info('Events: %s (%s)' %
                                        (event, tdef.by_value[event]))
                        else:
                            logger.info('Events: %s (Undefined event)' % event)
        idx_chunk += 1

        if finished:
            played += 1
            if played >= repeat:
                # Requested number of passes reached: stop without prompting.
                break
            if auto_restart is False:
                input(
                    'Reached the end of data. Press Enter to restart or Ctrl+C to stop.'
                )
            else:
                logger.info('Reached the end of data. Restarting.')
            idx_chunk = 0
            finished = False
            t_start = clock()
Esempio n. 3
0
    def _daemon(self,
                classifier,
                probs,
                probs_smooth,
                pread,
                t_problast,
                running,
                return_psd,
                psd_ctypes,
                lock,
                interleave=None,
                label=None):
        """
        Runs Decoder class as a daemon.

        Creates a BCIDecoder in the current process and keeps writing its
        output into the shared objects below while running.value == 1.

        Parameters
        ----------
        classifier, label : forwarded to BCIDecoder.
        probs : shared sequence receiving the latest probabilities.
        probs_smooth : shared sequence receiving the exponentially smoothed
            probabilities (alpha_old / alpha_new), renormalized to sum to 1.
        pread : shared value, set to 0 whenever new probabilities are written.
        t_problast : shared value holding the time stamp of the last update.
        running : shared value; the worker loops while its value is 1.
        return_psd : shared flag; when 1, the latest PSD row is copied into
            psd_ctypes and the flag is reset to 0 (only when not fake).
        psd_ctypes : ctypes buffer, viewed as a numpy array for the PSD copy.
        lock : lock guarding psd (and probs/probs_smooth in interleaved mode).
        interleave : None for single-core decoding, otherwise a dict with
            't_start' and 'period' (seconds) defining the decoding time slots.
        """
        pid = os.getpid()
        ps = psutil.Process(pid)

        if os.name == 'posix':
            # Unix: a negative value increases priority but requires root privileges
            ps.nice(0)
        else:
            # Windows
            ps.nice(psutil.HIGH_PRIORITY_CLASS)

        logger.debug('[DecodeWorker-%-6d] Decoder worker process started' %
                     (pid))
        decoder = BCIDecoder(self.amp_name,
                             classifier,
                             buffer_size=self.buffer_sec,
                             fake=self.fake,
                             label=label)
        use_psd = self.fake == False  # PSD is only available with a real amplifier
        psd = ctypeslib.as_array(psd_ctypes) if use_psd else None

        def _update_smoothed():
            """Blend the current probs into probs_smooth and renormalize to sum 1."""
            total = 0
            for i in range(len(probs_smooth)):
                probs_smooth[i] = probs_smooth[i] * self.alpha_old + probs[i] * self.alpha_new
                total += probs_smooth[i]
            if total != 0:  # avoid crashing the worker on an all-zero output
                for i in range(len(probs_smooth)):
                    probs_smooth[i] /= total

        def _copy_psd_if_requested():
            """Copy the latest PSD row into the shared buffer when requested."""
            if use_psd and return_psd.value == 1:
                with lock:
                    psd[:] = decoder.psd_buffer[-1].reshape((1, -1))
                return_psd.value = 0

        if interleave is None:
            # single-core decoding
            with running.get_lock():
                running.value = 1

            while running.value == 1:
                # compute features and likelihoods
                probs[:], t_prob = decoder.get_prob(True)
                _update_smoothed()
                pread.value = 0
                t_problast.value = t_prob
                _copy_psd_if_requested()
        else:
            # interleaved parallel decoding
            t_start = interleave['t_start']
            period = interleave['period']
            running.value = 1
            t_next = t_start + math.ceil(
                ((time.time() - t_start) / period)) * period

            while running.value == 1:
                # end of the current time slot
                t_next += period

                # compute likelihoods
                t_prob_wall = time.time()
                probs_local, t_prob_lsl = decoder.get_prob(True)

                # update the probs only if the current value is the latest
                # TODO: use timestamp to compare instead of time.time()
                if t_prob_wall > t_problast.value:
                    with lock:
                        probs[:] = probs_local
                        _update_smoothed()
                        pread.value = 0
                        t_problast.value = t_prob_wall

                _copy_psd_if_requested()

                # get the next time slot if didn't finish in the current slot
                now = time.time()
                if now > t_next:
                    t_next_new = t_start + math.ceil(
                        ((now - t_start) / period)) * period
                    logger.warning('[DecodeWorker-%-6d] High decoding delay (%.1f ms): t_next = %.3f -> %.3f' %
                                   (pid, (now - t_next + period) * 1000, t_next, t_next_new))
                    t_next = t_next_new

                # sleep until the next time slot
                t_sleep = t_next - time.time()
                if t_sleep > 0.001:
                    time.sleep(t_sleep)
                logger.debug('[DecodeWorker-%-6d] Woke up at %.3f' %
                             (pid, time.time()))