def run(cfg, state=None, queue=None, logger=logger):
    '''
    Main function used to run the offline protocol.

    Parameters
    ----------
    cfg : python.module
        The loaded config module from the corresponding config_offline.py
    state : mp.Value, optional
        Shared integer flag (0: stop, 1: start, 2: wait). If None, a fresh
        mp.Value('i', 1) is created for this call.
    queue : mp.Queue
        If not None, redirect sys.stdout to GUI terminal
    logger : logging.logger
        The logger to use
    '''
    import time

    # Build the shared flag per call: an mp.Value used as a signature default
    # is created once at import time, shared by every call, and left at 0
    # after the first run.
    if state is None:
        state = mp.Value('i', 1)

    # Use to redirect sys.stdout to GUI terminal if GUI usage
    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI); sleep instead of spinning a CPU core
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        time.sleep(0.01)

    # Protocol start if equals to 1
    if not state.value:
        sys.exit()

    # Load the mapping from int to string for triggers events
    cfg.tdef = TriggerDef(cfg.TRIGGER_FILE)

    # Refresh rate
    refresh_delay = 1.0 / cfg.REFRESH_RATE

    # Trigger
    trigger = Trigger(lpttype=cfg.TRIGGER_DEVICE, state=state)
    if trigger.init(50) == False:
        logger.error('\n** Error connecting to the trigger device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        # Keep the shared flag on the fallback trigger too
        trigger = Trigger(lpttype='FAKE', state=state)
        trigger.init(50)

    # timers
    timer_refresh = Timer()

    trial = 1
    num_trials = cfg.TRIALS_NB

    # start; also leave the loop when the flag is set to 0 (stop)
    while trial <= num_trials and state.value:
        timer_refresh.sleep_atleast(refresh_delay)
        timer_refresh.reset()

        #-------------------------------------
        # ADD YOUR CODE HERE
        #-------------------------------------

    with state.get_lock():
        state.value = 0
def _log_decoding_helper(state, event_queue, amp_name=None, autostop=False):
    """
    Background worker: record a stream and report its non-zero trigger samples.

    Parameters
    ----------
    state : mp.Value
        Shared flag. The worker idles while it is 0 and records while it is 1.
        With autostop it sets the flag to 0 itself.
    event_queue : mp.Queue
        Receives one (event_times, event_values) tuple of lists when done.
    amp_name : str
        The stream name to connect to
    autostop : bool
        If True, stop once data has been received and a later acquisition
        returns no timestamps.
    """
    logger.info('Event acquisition subprocess started.')

    # Block until the parent moves the flag away from 0
    while state.value == 0:
        time.sleep(0.01)

    receiver = StreamReceiver(bufsize=0, stream_name=amp_name)
    pace = Timer(autoreset=True)
    got_data = False

    while state.value == 1:
        _, timestamps = receiver.acquire()
        if autostop:
            has_samples = len(timestamps) > 0
            if got_data and not has_samples:
                # the stream went quiet after having delivered data
                state.value = 0
                break
            if not got_data and has_samples:
                got_data = True
        pace.sleep_atleast(0.001)

    logger.info('Event acquisition subprocess finishing up ...')

    buffers, times = receiver.get_buffer()
    trigger_channel = buffers[:, 0]  # first channel is the trigger channel
    hits = np.where(trigger_channel != 0)[0]
    event_times = times[hits].reshape(-1).tolist()
    event_values = trigger_channel[hits].tolist()
    assert len(event_times) == len(event_values)
    event_queue.put((event_times, event_values))
def sample_decoding(decoder):
    """
    Decoding example

    Loops forever. For each new classification it prints the elapsed time
    since the previous one, each label's smoothed and raw probability, and
    the label with the highest smoothed probability.

    Parameters
    ----------
    decoder :
        The decoder to use (get_label_names, get_prob_unread and
        get_prob_smooth are called on it)
    """
    def get_index_max(seq):
        # dict: key of the largest value.
        # Any other indexable sequence (list, tuple, numpy array, ...): index
        # of the largest element. Previously only exact lists were accepted.
        if isinstance(seq, dict):
            return max(seq, key=seq.__getitem__)
        try:
            length = len(seq)
            getter = seq.__getitem__
        except (TypeError, AttributeError):
            logger.error('Unsupported input %s' % type(seq))
            return None
        return max(range(length), key=getter)

    # label names reported by the decoder, in probability order
    labels = decoder.get_label_names()

    tm_watchdog = Timer(autoreset=True)
    tm_cls = Timer()
    while True:
        praw = decoder.get_prob_unread()
        psmooth = decoder.get_prob_smooth()
        if praw is None:
            # watch dog
            if tm_cls.sec() > 5:
                logger.warning('No classification was done in the last 5 seconds. Are you receiving data streams?')
                tm_cls.reset()
            tm_watchdog.sleep_atleast(0.001)
            continue

        txt = '[%8.1f msec]' % (tm_cls.msec())
        for i, label in enumerate(labels):
            txt += ' %s %.3f (raw %.3f)' % (label, psmooth[i], praw[i])
        maxi = get_index_max(psmooth)
        txt += ' %s' % labels[maxi]
        print(txt)
        tm_cls.reset()
def run(cfg, state=None, queue=None):
    '''
    Main function used to run the online protocol.

    Parameters
    ----------
    cfg : python.module
        The loaded config module of this online protocol
    state : mp.Value, optional
        Shared integer flag (0: stop, 1: start, 2: wait). If None, a fresh
        mp.Value('i', 1) is created for this call.
    queue : mp.Queue
        If not None, redirect sys.stdout to GUI terminal

    Notes
    -----
    Messages go through the module-level ``logger``.
    '''
    import time

    # Build the shared flag per call (a signature-default mp.Value would be
    # created once at import time and shared across calls).
    if state is None:
        state = mp.Value('i', 1)

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI); sleep instead of spinning a CPU core
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        time.sleep(0.01)

    # Protocol runs if state equals to 1
    if not state.value:
        sys.exit(-1)

    # events and triggers
    cfg.tdef = TriggerDef(cfg.TRIGGER_FILE)

    # To send trigger events
    trigger = Trigger(cfg.TRIGGER_DEVICE, state)
    if trigger.init(50) == False:
        logger.error('Cannot connect to trigger device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = Trigger('FAKE', state)
        trigger.init(50)

    # Instance a stream receiver
    sr = StreamReceiver(bufsize=1, winsize=0.5, stream_name=None, eeg_only=True)

    # Timer pacing the loop at cfg.REFRESH_RATE
    tm = Timer(autoreset=True)

    # Refresh rate
    refresh_delay = 1.0 / cfg.REFRESH_RATE

    # Run until the flag is set to 0 (stop), so the cleanup below is reachable
    while state.value:
        # Acquire data from all the connected LSL streams by filling each associated buffers.
        sr.acquire()

        # Extract the latest window from the buffer of the chosen stream.
        window, tslist = sr.get_window()  # window = [samples x channels], tslist = [samples]

        #-------------------------------------
        # ADD YOUR CODE HERE
        #-------------------------------------

        # To run a trained BCI decoder, look at online_mi.py protocol

        tm.sleep_atleast(refresh_delay)

    with state.get_lock():
        state.value = 0

    logger.info('Finished.')
def run(cfg, state=None, queue=None):
    '''
    Run the offline protocol with a bar visual cue.

    Each direction in cfg.DIRECTIONS is repeated cfg.TRIALS_EACH times in a
    shuffled order. Every trial walks through the events
    start -> gap_s -> gap -> cue -> dir_r -> dir -> gap_s, and the matching
    cfg.tdef trigger is sent at each transition.

    Parameters
    ----------
    cfg : python.module
        The loaded config module of the protocol
    state : mp.Value, optional
        Shared integer flag (0: stop, 1: start, 2: wait). If None, a fresh
        mp.Value('i', 1) is created for this call.
    queue : mp.Queue
        If not None, redirect sys.stdout to GUI terminal

    Raises
    ------
    RuntimeError
        If a direction in cfg.DIRECTIONS is not one of L, R, U, D, B.
    '''
    import time

    # Build the shared flag per call (a signature-default mp.Value would be
    # created once at import time and shared across calls).
    if state is None:
        state = mp.Value('i', 1)

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI); sleep instead of spinning a CPU core
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        time.sleep(0.01)

    # Protocol start if equals to 1
    if not state.value:
        sys.exit()

    refresh_delay = 1.0 / cfg.REFRESH_RATE

    cfg.tdef = TriggerDef(cfg.TRIGGER_FILE)

    # visualizer key codes, compared against 0xFF-masked cv2.waitKey() values
    keys = {'left': 81, 'right': 83, 'up': 82, 'down': 84, 'pgup': 85, 'pgdn': 86,
            'home': 80, 'end': 87, 'space': 32, 'esc': 27, ',': 44, '.': 46,
            's': 115, 'c': 99, '[': 91, ']': 93, '1': 49, '!': 33, '2': 50,
            '@': 64, '3': 51, '#': 35}

    # direction code -> (trigger-name prefix in cfg.tdef, bars to move)
    dir_info = {
        'L': ('LEFT', ('L',)),       # left
        'R': ('RIGHT', ('R',)),      # right
        'U': ('UP', ('U',)),         # up
        'D': ('DOWN', ('D',)),       # down
        'B': ('BOTH', ('L', 'R')),   # both hands
    }

    def direction_info(direction):
        # (trigger prefix, bars) for a direction code; unknown codes are fatal
        if direction not in dir_info:
            raise RuntimeError('Unknown direction %s' % direction)
        return dir_info[direction]

    def draw_durations():
        # Randomised lengths of the 'dir' and 'dir_r' phases (DIR drawn first)
        t_dir = cfg.TIMINGS['DIR'] + random.uniform(-cfg.TIMINGS['DIR_RANDOMIZE'], cfg.TIMINGS['DIR_RANDOMIZE'])
        t_ready = cfg.TIMINGS['READY'] + random.uniform(-cfg.TIMINGS['READY_RANDOMIZE'], cfg.TIMINGS['READY_RANDOMIZE'])
        return t_dir, t_ready

    dir_sequence = []
    for _ in range(cfg.TRIALS_EACH):
        dir_sequence.extend(cfg.DIRECTIONS)
    random.shuffle(dir_sequence)
    num_trials = len(cfg.DIRECTIONS) * cfg.TRIALS_EACH

    event = 'start'
    trial = 1

    # Hardware trigger
    if cfg.TRIGGER_DEVICE is None:
        # NOTE(review): the message asks for Enter, but input() is disabled
        logger.warning('No trigger device set. Press Ctrl+C to stop or Enter to continue.')
        #input()
    trigger = Trigger(lpttype=cfg.TRIGGER_DEVICE, state=state)
    if trigger.init(50) == False:
        logger.error('\n** Error connecting to USB2LPT device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        # Keep the shared flag on the fallback trigger too
        trigger = Trigger(lpttype='FAKE', state=state)
        trigger.init(50)

    # timers
    timer_trigger = Timer()
    timer_dir = Timer()
    timer_refresh = Timer()
    t_dir, t_dir_ready = draw_durations()

    bar = BarVisual(cfg.GLASS_USE, screen_pos=cfg.SCREEN_POS, screen_size=cfg.SCREEN_SIZE)
    bar.fill()
    bar.glass_draw_cue()

    # start
    while trial <= num_trials:
        timer_refresh.sleep_atleast(refresh_delay)
        timer_refresh.reset()

        # segment= { 'cue':(s,e), 'dir':(s,e), 'label':0-4 } (zero-based)
        if event == 'start' and timer_trigger.sec() > cfg.TIMINGS['INIT']:
            event = 'gap_s'
            bar.fill()
            timer_trigger.reset()
            trigger.signal(cfg.tdef.INIT)

        elif event == 'gap_s':
            if cfg.TRIAL_PAUSE:
                bar.put_text('Press any key')
                bar.update()
                # mask like the main-loop key read so 'esc' matches on all backends
                key = 0xFF & cv2.waitKey()
                if key == keys['esc'] or not state.value:
                    break
                bar.fill()
            bar.put_text('Trial %d / %d' % (trial, num_trials))
            event = 'gap'
            timer_trigger.reset()

        elif event == 'gap' and timer_trigger.sec() > cfg.TIMINGS['GAP']:
            event = 'cue'
            bar.fill()
            bar.draw_cue()
            trigger.signal(cfg.tdef.CUE)
            timer_trigger.reset()

        elif event == 'cue' and timer_trigger.sec() > cfg.TIMINGS['CUE']:
            event = 'dir_r'
            direction = dir_sequence[trial - 1]
            prefix, sides = direction_info(direction)
            for side in sides:
                bar.move(side, 100, overlay=True)
            trigger.signal(getattr(cfg.tdef, prefix + '_READY'))
            timer_trigger.reset()

        elif event == 'dir_r' and timer_trigger.sec() > t_dir_ready:
            bar.fill()
            bar.draw_cue()
            event = 'dir'
            timer_trigger.reset()
            timer_dir.reset()
            prefix, _ = direction_info(direction)
            trigger.signal(getattr(cfg.tdef, prefix + '_GO'))

        elif event == 'dir' and timer_trigger.sec() > t_dir:
            event = 'gap_s'
            bar.fill()
            trial += 1
            logger.info('trial ' + str(trial - 1) + ' done')
            trigger.signal(cfg.tdef.BLANK)
            timer_trigger.reset()
            t_dir, t_dir_ready = draw_durations()

        # protocol: grow the bar(s) proportionally to the elapsed 'dir' time
        if event == 'dir':
            dx = min(100, int(100.0 * timer_dir.sec() / t_dir) + 1)
            for side in direction_info(direction)[1]:
                bar.move(side, dx, overlay=True)

        # wait for start
        if event == 'start':
            bar.put_text('Waiting to start')

        bar.update()
        key = 0xFF & cv2.waitKey(1)
        if key == keys['esc'] or not state.value:
            break

    bar.finish()

    with state.get_lock():
        state.value = 0
def log_decoding(decoder, logfile, amp_name=None, pklfile=True, matfile=False, autostop=False, prob_smooth=False):
    """
    Decode online and write results with event timestamps

    Parameters
    ----------
    decoder : BCIDecoder or BCIDecoderDaemon class
        The decoder to use
    logfile : str
        The file path to contain the result in Python pickle format
    amp_name : str
        The stream name to connect to
    pklfile : bool
        If True, export the results to Python pickle format
    matfile : bool
        If True, export the results to .mat file (same path, .mat extension)
    autostop : bool
        If True, automatically finish when no more data is received.
    prob_smooth : bool
        If True, use smoothed probability values according to decoder's smoothing parameter.

    Raises
    ------
    RuntimeError
        If no decoding result was obtained before stopping.
    """
    import cv2

    # run event acquisition process in the background
    state = mp.Value('i', 1)
    event_queue = mp.Queue()
    proc = mp.Process(target=_log_decoding_helper, args=[state, event_queue, amp_name, autostop])
    proc.start()
    logger.info('Spawned event acquisition process.')

    # init variables and choose decoding function
    labels = decoder.get_label_names()
    probs = []
    prob_times = []
    if prob_smooth:
        decode_fn = decoder.get_prob_smooth_unread
    else:
        decode_fn = decoder.get_prob_unread

    # simple controller UI
    cv2.namedWindow("Decoding", cv2.WINDOW_AUTOSIZE)
    cv2.moveWindow("Decoding", 1400, 50)
    img = np.zeros([100, 400, 3], np.uint8)
    cv2.putText(img, 'Press any key to start', (20, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.imshow("Decoding", img)
    cv2.waitKeyEx()
    img *= 0
    cv2.putText(img, 'Press ESC to stop', (40, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.imshow("Decoding", img)

    key = 0
    started = False
    tm_watchdog = Timer(autoreset=True)
    tm_cls = Timer()
    while key != 27:  # 27: ESC
        prob, prob_time = decode_fn(True)
        t_lsl = pylsl.local_clock()
        key = cv2.waitKeyEx(1)
        if prob is None:
            # watch dog
            if tm_cls.sec() > 5:
                if autostop and started:
                    logger.info('No more streaming data. Finishing.')
                    break
                tm_cls.reset()
            tm_watchdog.sleep_atleast(0.001)
            continue
        probs.append(prob)
        prob_times.append(prob_time)
        txt = '[%.3f] ' % prob_time
        txt += ', '.join(['%s: %.2f' % (l, p) for l, p in zip(labels, prob)])
        txt += ' (%d ms, LSL Diff = %.3f)' % (tm_cls.msec(), (t_lsl - prob_time))
        logger.info(txt)
        if not started:
            started = True
        tm_cls.reset()

    # finish up processes (read the queue before join to avoid a deadlock)
    cv2.destroyAllWindows()
    logger.info('Cleaning up event acquisition process.')
    state.value = 0
    decoder.stop()
    event_times, event_values = event_queue.get()
    proc.join()

    # save values
    if len(prob_times) == 0:
        # Previously dropped into pdb and then crashed on prob_times[0]
        logger.error('No decoding result. Please debug.')
        raise RuntimeError('No decoding result. Please debug.')
    t_start = prob_times[0]
    probs = np.vstack(probs)
    event_times = np.array(event_times)
    event_values = np.array(event_values)
    # Drop events before the first decoding; filter times AND values with the
    # same mask so they stay aligned.
    keep = event_times >= t_start
    event_times = event_times[keep] - t_start
    event_values = event_values[keep]
    prob_times = np.array(prob_times) - t_start
    data = dict(probs=probs, prob_times=prob_times, event_times=event_times,
                event_values=event_values, labels=labels)

    if pklfile:
        with open(logfile, 'wb') as f:
            pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
        logger.info('Saved to %s' % logfile)
    if matfile:
        # 'import scipy' alone does not guarantee the io submodule is loaded
        import scipy.io
        from pathlib import Path
        pp = Path(logfile)
        matout = '%s/%s.mat' % (pp.parent, pp.stem)
        scipy.io.savemat(matout, data)
        logger.info('Saved to %s' % matout)
class Feedback:
    """
    Perform a classification with visual feedback
    """

    # direction label -> trigger-name prefix in tdef (e.g. 'L' -> LEFT_READY / LEFT_GO)
    _TRIGGER_PREFIX = {'L': 'LEFT', 'R': 'RIGHT', 'U': 'UP', 'D': 'DOWN', 'B': 'BOTH'}

    def __init__(self, cfg, viz, tdef, trigger, logfile=None):
        """
        Parameters
        ----------
        cfg : python.module
            The loaded config module
        viz :
            Visual feedback object
        tdef : TriggerDef
            Trigger definitions used when signalling events
        trigger : Trigger
            The trigger device
        logfile : str or None
            If not None, per-trial results are written to this text file.
        """
        self.cfg = cfg
        self.tdef = tdef
        self.trigger = trigger
        self.viz = viz
        self.viz.fill()
        self.refresh_delay = 1.0 / self.cfg.REFRESH_RATE
        self.bar_step_left = self.cfg.BAR_STEP['left']
        self.bar_step_right = self.cfg.BAR_STEP['right']
        self.bar_step_up = self.cfg.BAR_STEP['up']
        self.bar_step_down = self.cfg.BAR_STEP['down']
        self.bar_step_both = self.cfg.BAR_STEP['both']

        # The bias is edited in place with the arrow keys, so a tuple is
        # converted to a mutable list.
        if type(self.cfg.BAR_BIAS) is tuple:
            self.bar_bias = list(self.cfg.BAR_BIAS)
        else:
            self.bar_bias = self.cfg.BAR_BIAS

        # New decoder: probabilities are already smoothed by the decoder, so
        # only the bias is applied in classify().

        # End the trial early once the bar reaches 100 if BAR_REACH_FINISH is set
        if hasattr(self.cfg, 'BAR_REACH_FINISH') and self.cfg.BAR_REACH_FINISH == True:
            self.premature_end = True
        else:
            self.premature_end = False

        self.tm_trigger = Timer()
        self.tm_display = Timer()
        self.tm_watchdog = Timer()
        if logfile is not None:
            self.logf = open(logfile, 'w')
        else:
            self.logf = None

        # STIMO only
        if self.cfg.WITH_STIMO is True:
            if self.cfg.STIMO_COMPORT is None:
                atens = list(serial.tools.list_ports.grep('ATEN'))
                if len(atens) == 0:
                    raise RuntimeError('No ATEN device found. Stop.')
                try:
                    self.stimo_port = atens[0].device
                except AttributeError:  # depends on Python distribution
                    self.stimo_port = atens[0][0]
            else:
                self.stimo_port = self.cfg.STIMO_COMPORT
            self.ser = serial.Serial(self.stimo_port, self.cfg.STIMO_BAUDRATE)
            logger.info('STIMO serial port %s is_open = %s' % (self.stimo_port, self.ser.is_open))

        # FES only
        if self.cfg.WITH_FES is True:
            self.stim = fes.Motionstim8()
            self.stim.OpenSerialPort(self.cfg.FES_COMPORT)
            self.stim.InitializeChannelListMode()
            logger.info('Opened FES serial port')

    def __del__(self):
        # Release the log file and hardware ports. The getattr/hasattr guards
        # make this safe when __init__ raised before setting everything up.
        logf = getattr(self, 'logf', None)
        if logf is not None:
            logf.close()
            self.logf = None
        cfg = getattr(self, 'cfg', None)
        if cfg is None:
            return
        # STIMO only
        if cfg.WITH_STIMO is True and hasattr(self, 'ser'):
            self.ser.close()
            logger.info('Closed STIMO serial port %s' % self.stimo_port)
        # FES only
        if cfg.WITH_FES is True and hasattr(self, 'stim'):
            stim_code = [0, 0, 0, 0, 0, 0, 0, 0]
            self.stim.UpdateChannelSettings(stim_code)
            self.stim.CloseSerialPort()
            logger.info('Closed FES serial port')

    def _signal_direction(self, label, suffix):
        """Send the '<PREFIX>_<suffix>' trigger (e.g. LEFT_GO) for a direction label."""
        if label not in self._TRIGGER_PREFIX:
            raise RuntimeError('Unknown direction %s' % label)
        self.trigger.signal(getattr(self.tdef, '%s_%s' % (self._TRIGGER_PREFIX[label], suffix)))

    def classify(self, decoder, true_label, title_text, bar_dirs, state='start', prob_history=None):
        """
        Run a single trial

        Steps through start -> gap_s -> gap -> cue -> dir_r -> dir -> feedback
        (-> return for BODY feedback) and returns when the trial is over.

        Parameters
        ----------
        decoder :
            Decoder providing get_prob_smooth_unread()
        true_label : str
            Direction of this trial; must be an element of bar_dirs
        title_text : str
            Text shown during the gap
        bar_dirs : list of str
            Direction labels in the order of the decoder's probabilities
        state : str
            Initial state of the trial state machine
        prob_history : dict or None
            If given, prob_history[true_label] collects the probability of
            the true class at each classification.

        Returns
        -------
        str or None
            The final bar direction, or None if ESC was pressed.
        """
        def list2string(vec, fmt, sep=' '):
            return sep.join(fmt % x for x in vec)

        true_label_index = bar_dirs.index(true_label)
        self.tm_trigger.reset()
        if self.bar_bias is not None:
            bias_idx = bar_dirs.index(self.bar_bias[0])

        if self.logf is not None:
            self.logf.write('True label: %s\n' % true_label)

        tm_classify = Timer(autoreset=True)
        self.stimo_timer = Timer()
        while True:
            self.tm_display.sleep_atleast(self.refresh_delay)
            self.tm_display.reset()
            if state == 'start' and self.tm_trigger.sec() > self.cfg.TIMINGS['INIT']:
                state = 'gap_s'
                if self.cfg.TRIALS_PAUSE:
                    self.viz.put_text('Press any key')
                    self.viz.update()
                    key = cv2.waitKeyEx()
                    if key == KEYS['esc']:
                        return
                self.viz.fill()
                self.tm_trigger.reset()
                self.trigger.signal(self.tdef.INIT)

            elif state == 'gap_s':
                if self.cfg.TIMINGS['GAP'] > 0:
                    self.viz.put_text(title_text)
                state = 'gap'
                self.tm_trigger.reset()

            elif state == 'gap' and self.tm_trigger.sec() > self.cfg.TIMINGS['GAP']:
                state = 'cue'
                self.viz.fill()
                self.viz.draw_cue()
                self.viz.glass_draw_cue()
                self.trigger.signal(self.tdef.CUE)
                self.tm_trigger.reset()

            elif state == 'cue' and self.tm_trigger.sec() > self.cfg.TIMINGS['READY']:
                state = 'dir_r'

                if self.cfg.SHOW_CUE is True:
                    if self.cfg.FEEDBACK_TYPE == 'BAR':
                        self.viz.move(true_label, 100, overlay=False, barcolor='G')
                    elif self.cfg.FEEDBACK_TYPE == 'BODY':
                        self.viz.put_text(DIRS[true_label], 'R')

                # LEFT_READY / RIGHT_READY / ... for the true direction
                self._signal_direction(true_label, 'READY')
                self.tm_trigger.reset()

            elif state == 'dir_r' and self.tm_trigger.sec() > self.cfg.TIMINGS['DIR_CUE']:
                self.viz.fill()
                self.viz.draw_cue()
                self.viz.glass_draw_cue()
                state = 'dir'

                # initialize bar scores
                bar_label = bar_dirs[0]
                bar_score = 0
                probs = [1.0 / len(bar_dirs)] * len(bar_dirs)
                self.viz.move(bar_label, bar_score, overlay=False)
                probs_acc = np.zeros(len(probs))

                # LEFT_GO / RIGHT_GO / ... for the true direction
                self._signal_direction(true_label, 'GO')
                self.tm_watchdog.reset()
                self.tm_trigger.reset()

            elif state == 'dir':
                if self.tm_trigger.sec() > self.cfg.TIMINGS['CLASSIFY'] or (self.premature_end and bar_score >= 100):
                    if not hasattr(self.cfg, 'SHOW_RESULT') or self.cfg.SHOW_RESULT is True:
                        # show classification result
                        if self.cfg.WITH_STIMO is True:
                            if self.cfg.STIMO_FULLGAIT_CYCLE is not None and bar_label == 'U':
                                res_color = 'G'
                            elif self.cfg.TRIALS_RETRY is False or bar_label == true_label:
                                res_color = 'G'
                            else:
                                res_color = 'Y'
                        else:
                            res_color = 'Y'
                        if self.cfg.FEEDBACK_TYPE == 'BODY':
                            self.viz.move(bar_label, bar_score, overlay=False, barcolor=res_color,
                                          caption=DIRS[bar_label], caption_color=res_color)
                        else:
                            self.viz.move(bar_label, 100, overlay=False, barcolor=res_color)
                    else:
                        # Result hidden: use the neutral color (res_color used
                        # to be unbound here -> UnboundLocalError).
                        res_color = 'Y'
                        if self.cfg.FEEDBACK_TYPE == 'BODY':
                            self.viz.move(bar_label, bar_score, overlay=False, barcolor=res_color,
                                          caption='TRIAL END', caption_color=res_color)
                        else:
                            self.viz.move(bar_label, 0, overlay=False, barcolor=res_color)
                    self.trigger.signal(self.tdef.FEEDBACK)

                    # STIMO
                    if self.cfg.WITH_STIMO is True and self.cfg.STIMO_CONTINUOUS is False:
                        if self.cfg.STIMO_FULLGAIT_CYCLE is not None:
                            if bar_label == 'U':
                                self.ser.write(self.cfg.STIMO_FULLGAIT_PATTERN[0])
                                logger.info('STIMO: Sent 1')
                                time.sleep(self.cfg.STIMO_FULLGAIT_CYCLE)
                                self.ser.write(self.cfg.STIMO_FULLGAIT_PATTERN[1])
                                logger.info('STIMO: Sent 2')
                                time.sleep(self.cfg.STIMO_FULLGAIT_CYCLE)
                        elif self.cfg.TRIALS_RETRY is False or bar_label == true_label:
                            if bar_label == 'L':
                                self.ser.write(b'1')
                                logger.info('STIMO: Sent 1')
                            elif bar_label == 'R':
                                self.ser.write(b'2')
                                logger.info('STIMO: Sent 2')

                    # FES event mode
                    # NOTE(review): here 'L' drives channel index 1 and 'R'
                    # index 0, while continuous mode below drives index 0 for
                    # 'L' -- verify the intended channel mapping.
                    if self.cfg.WITH_FES is True and self.cfg.FES_CONTINUOUS is False:
                        if bar_label == 'L':
                            stim_code = [0, 30, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)
                            logger.info('FES: Sent Left')
                            time.sleep(0.5)
                            stim_code = [0, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)
                        elif bar_label == 'R':
                            stim_code = [30, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)
                            time.sleep(0.5)
                            logger.info('FES: Sent Right')
                            stim_code = [0, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)

                    if self.cfg.DEBUG_PROBS:
                        msg = 'DEBUG: Accumulated probabilities = %s' % list2string(probs_acc, '%.3f')
                        logger.info(msg)
                        if self.logf is not None:
                            self.logf.write(msg + '\n')
                    if self.logf is not None:
                        self.logf.write('%s detected as %s (%d)\n\n' % (true_label, bar_label, bar_score))
                        self.logf.flush()

                    # end of trial
                    state = 'feedback'
                    self.tm_trigger.reset()

                else:
                    # classify
                    probs_new = decoder.get_prob_smooth_unread()
                    if probs_new is None:
                        if self.tm_watchdog.sec() > 3:
                            logger.warning('No classification being done. Are you receiving data streams?')
                            self.tm_watchdog.reset()
                    else:
                        self.tm_watchdog.reset()

                        if prob_history is not None:
                            prob_history[true_label].append(probs_new[true_label_index])

                        probs_acc += np.array(probs_new)

                        # New decoder: already smoothed by the decoder, so
                        # only apply the bias and renormalise.
                        probs = list(probs_new)
                        if self.bar_bias is not None:
                            probs[bias_idx] += self.bar_bias[1]
                            newsum = sum(probs)
                            probs = [p / newsum for p in probs]

                        # determine the direction
                        max_pidx = np.argmax(probs)
                        max_label = bar_dirs[max_pidx]

                        if self.cfg.POSITIVE_FEEDBACK is False or \
                                (self.cfg.POSITIVE_FEEDBACK and true_label == max_label):
                            dx = probs[max_pidx]
                            if max_label == 'R':
                                dx *= self.bar_step_right
                            elif max_label == 'L':
                                dx *= self.bar_step_left
                            elif max_label == 'U':
                                dx *= self.bar_step_up
                            elif max_label == 'D':
                                dx *= self.bar_step_down
                            elif max_label == 'B':
                                dx *= self.bar_step_both
                            else:
                                logger.debug('Direction %s using bar step %d' % (max_label, self.bar_step_left))
                                dx *= self.bar_step_left

                            # slow start: scale dx during the first slow_start[0] seconds.
                            # The duration is element 0 (it was compared as a
                            # whole list before -> TypeError on Python 3).
                            selected = self.cfg.BAR_SLOW_START['selected']
                            slow_start = self.cfg.BAR_SLOW_START[selected]
                            if slow_start and self.tm_trigger.sec() < slow_start[0]:
                                dx *= self.tm_trigger.sec() / slow_start[0]

                            # add likelihoods
                            if max_label == bar_label:
                                bar_score += dx
                            else:
                                bar_score -= dx
                                # change of direction
                                if bar_score < 0:
                                    bar_score = -bar_score
                                    bar_label = max_label
                            bar_score = int(bar_score)
                            if bar_score > 100:
                                bar_score = 100
                            if self.cfg.FEEDBACK_TYPE == 'BODY':
                                if self.cfg.SHOW_CUE:
                                    self.viz.move(bar_label, bar_score, overlay=False,
                                                  caption=DIRS[true_label], caption_color='G')
                                else:
                                    self.viz.move(bar_label, bar_score, overlay=False)
                            else:
                                self.viz.move(bar_label, bar_score, overlay=False)

                            # send the confidence value continuously
                            if self.cfg.WITH_STIMO and self.cfg.STIMO_CONTINUOUS:
                                if self.stimo_timer.sec() >= self.cfg.STIMO_COOLOFF:
                                    if bar_label == 'U':
                                        stimo_code = bar_score
                                    else:
                                        stimo_code = 0
                                    self.ser.write(bytes([stimo_code]))
                                    logger.info('Sent STIMO code %d' % stimo_code)
                                    self.stimo_timer.reset()

                            # with FES
                            if self.cfg.WITH_FES is True and self.cfg.FES_CONTINUOUS is True:
                                if self.stimo_timer.sec() >= self.cfg.STIMO_COOLOFF:
                                    if bar_label == 'L':
                                        stim_code = [bar_score, 0, 0, 0, 0, 0, 0, 0]
                                    else:
                                        stim_code = [0, bar_score, 0, 0, 0, 0, 0, 0]
                                    self.stim.UpdateChannelSettings(stim_code)
                                    logger.info('Sent FES code %d' % bar_score)
                                    self.stimo_timer.reset()

                        if self.cfg.DEBUG_PROBS:
                            if self.bar_bias is not None:
                                biastxt = '[Bias=%s%.3f] ' % (self.bar_bias[0], self.bar_bias[1])
                            else:
                                biastxt = ''
                            msg = '%s%s prob %s acc %s bar %s%d (%.1f ms)' % \
                                  (biastxt, bar_dirs, list2string(probs_new, '%.2f'),
                                   list2string(probs, '%.2f'), bar_label, bar_score,
                                   tm_classify.msec())
                            logger.info(msg)
                            if self.logf is not None:
                                self.logf.write(msg + '\n')

            elif state == 'feedback' and self.tm_trigger.sec() > self.cfg.TIMINGS['FEEDBACK']:
                self.trigger.signal(self.tdef.BLANK)
                if self.cfg.FEEDBACK_TYPE == 'BODY':
                    state = 'return'
                    self.tm_trigger.reset()
                else:
                    state = 'gap_s'
                    self.viz.fill()
                    self.viz.update()
                    return bar_label

            elif state == 'return':
                # shrink the bar by 5 per refresh until it is empty
                self.viz.set_glass_feedback(False)
                if self.cfg.WITH_STIMO:
                    self.viz.move(bar_label, bar_score, overlay=False, barcolor='B')
                else:
                    self.viz.move(bar_label, bar_score, overlay=False, barcolor='Y')
                self.viz.set_glass_feedback(True)
                bar_score -= 5
                if bar_score <= 0:
                    state = 'gap_s'
                    self.viz.fill()
                    self.viz.update()
                    return bar_label

            self.viz.update()
            key = cv2.waitKeyEx(1)
            if key == KEYS['esc']:
                return
            elif key == KEYS['space']:
                bar_score = 0
                probs = [1.0 / len(bar_dirs)] * len(bar_dirs)
                self.viz.move(bar_dirs[0], bar_score, overlay=False)
                self.viz.update()
                logger.info('probs and dx reset.')
                self.tm_trigger.reset()
            elif key in ARROW_KEYS and ARROW_KEYS[key] in bar_dirs:
                # change bias on the fly
                if self.bar_bias is None:
                    self.bar_bias = [ARROW_KEYS[key], BIAS_INCREMENT]
                else:
                    if ARROW_KEYS[key] == self.bar_bias[0]:
                        self.bar_bias[1] += BIAS_INCREMENT
                    elif self.bar_bias[1] >= BIAS_INCREMENT:
                        self.bar_bias[1] -= BIAS_INCREMENT
                    else:
                        self.bar_bias = [ARROW_KEYS[key], BIAS_INCREMENT]
                if self.bar_bias[1] == 0:
                    self.bar_bias = None
                else:
                    bias_idx = bar_dirs.index(self.bar_bias[0])
class GlassControl(object):
    """
    Controls Glass UI

    Constructor:
        mock: set to True if you don't have a Glass (all drawing calls become no-ops).
    """

    def __init__(self, mock=False):
        self.BUFFER_SIZE = 1024
        self.last_dir = 'L'
        self.timer = Timer(autoreset=True)
        self.mock = mock
        if self.mock:
            self.print('Using a fake, mock Glass control object.')

    def print(self, *args):
        # Prefix non-empty messages with the class tag
        if len(args) > 0:
            print('[GlassControl] ', end='')
        print(*args)

    def connect(self, ip, port):
        if self.mock:
            return
        self.ip = ip
        self.port = port

        # Networking via USB if IP=127.0.0.1
        if ip == '127.0.0.1':
            exe = 'adb forward tcp:%d tcp:%d' % (port, port)
            self.print(exe)
            os.system(exe)
            time.sleep(0.2)
        self.print('Connecting to %s:%d' % (ip, port))
        try:
            self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.socket.connect((self.ip, self.port))
        except Exception:
            # Not a bare except: let KeyboardInterrupt/SystemExit through
            self.print('* ERROR connecting to Glass. The error was:')
            self.print(sys.exc_info()[0], sys.exc_info()[1])
            sys.exit(-1)

    def disconnect(self):
        if self.mock:
            return
        self.print('Disconnecting from Glass')
        self.socket.close()

    def send_byte(self, msg):
        # Messages are newline-terminated UTF-8 text
        if sys.version_info.major >= 3:
            self.socket.sendall(bytes(msg + '\n', "UTF-8"))
        else:
            self.socket.sendall(bytes(unicode(msg + '\n')))

    def send_msg(self, msg, wait=True):
        """
        Send a message to the Glass

        Glass requires some delay after when the last command was sent.
        This function will be blocked until minimum this delay is satisfied.
        Set wait=False to force sending message, but the msg is likely to be ignored.

        Raises
        ------
        Exception
            If sending still fails after one reconnect attempt (chained to
            the underlying error).
        """
        if wait:
            # Wait only if the time hasn't passed enough
            self.timer.sleep_atleast(0.033)  # 30 Hz

        if self.mock:
            return

        try:
            self.send_byte(msg)
        except Exception:
            self.print('* ERROR: Glass communication failed! Attempting to reconnect again.')
            self.disconnect()
            time.sleep(2)
            # Let's try again
            self.connect(self.ip, self.port)
            try:
                self.send_byte(msg)
            except Exception as e:
                self.print('Sorry, cannot fix the problem. I give up.')
                raise Exception(e) from e

    # Show empty bars
    def clear(self):
        if self.mock:
            return
        self.send_msg('C')

    # Show empty bars (same as clear())
    def draw_cross(self):
        if self.mock:
            return
        self.clear()

    # Only one direction at a time
    def move_bar(self, new_dir, amount, overlay=False):
        if self.mock:
            return
        if overlay is False and self.last_dir != new_dir:
            # reset the previously drawn bar first
            self.send_msg('%s0' % self.last_dir)
        self.send_msg('%s%d' % (new_dir, amount))
        self.last_dir = new_dir

    # Fill screen with a solid color (None, 'R','G','B','K'); other values are ignored
    def fill(self, color=None):
        if self.mock:
            return
        if color is None:
            self.send_msg('F0')
        elif color == 'R':
            self.send_msg('F1')
        elif color == 'G':
            self.send_msg('F2')
        elif color == 'B':
            self.send_msg('F3')
        elif color == 'K':
            self.send_msg('F4')

    def fullbar_color(self, color):
        if color not in ['R', 'G', 'B', 'Y']:
            print('**** UNSUPPORTED GLASS BAR COLOR ****')
        else:
            msg = 'B' + color[0]
            self.send_msg(msg)