def check_speed(decoder, max_count=float('inf')):
    """
    Test decoding speed across several classifications.

    Busy-polls ``decoder.get_prob_unread()`` and, roughly once per second,
    prints the mean time per classification in that window (ms/c) and its
    reciprocal (Hz). When the loop ends, the mean of the per-window values
    is printed.

    Parameters
    ----------
    decoder : BCIDecoder or BCIDecoderDaemon class
        The decoder to assess its performance
    max_count : int
        Total number of classifications to measure before stopping.
        Defaults to infinity (run until interrupted).
    """
    tm = Timer()
    total = 0   # classifications since the start; drives the stop condition
    count = 0   # classifications in the current ~1 s reporting window
    mslist = []
    while total < max_count:
        # busy-wait until the decoder has produced a new result
        while decoder.get_prob_unread() is None:
            pass
        total += 1
        count += 1
        if tm.sec() > 1:
            t = tm.sec()
            ms = 1000.0 * t / count
            # show time per classification and its reciprocal
            print('%.0f ms/c %.1f Hz' % (ms, count / t))
            mslist.append(ms)
            count = 0
            tm.reset()

    # The run may end before one full window has elapsed; use the partial
    # window so the mean is never taken over an empty list (which gives nan).
    if not mslist and count > 0:
        mslist.append(1000.0 * tm.sec() / count)
    if mslist:
        print('mean = %.1f ms' % np.mean(mslist))
def fit_predict_thres(cls, X_train, Y_train, X_test, Y_test, cnum, label_list,
                      ignore_thres=None, decision_thres=None):
    """
    Fit ``cls`` on one fold and score it on the corresponding test fold.

    Confusion matrix, accuracy and F1 score (macro average) are computed.
    Optionally, any likelihood lower than a threshold is not counted as a
    classification score.

    Params
    ======
    cls: scikit-learn style classifier (fit / predict / predict_proba / classes_).
    X_train, Y_train: training features and labels.
    X_test, Y_test: test features and labels.
    cnum: fold number, used only in the log message.
    label_list: labels that fix the row/column order of the confusion matrix.
    ignore_thres: if not None or larger than 0, test samples whose highest
        likelihood is lower than ignore_thres are ignored while computing
        accuracy, F1 and the confusion matrix. The confusion matrix then gets
        one extra column whose i-th entry is the number of ignored samples
        *predicted* as label_list[i].
        NOTE(review): that column is indexed by predicted label while the
        matrix rows are true labels; confirm this is intended.
    decision_thres: must not be set together with a non-zero ignore_thres;
        it is not otherwise used by this function.

    Returns
    =======
    (score, cm, f1): accuracy, confusion matrix, macro-averaged F1 score.

    Raises
    ======
    ValueError: if both ignore_thres (> 0) and decision_thres are set.
    """
    # Validate arguments before the (possibly expensive) fit.
    assert ignore_thres is None or ignore_thres >= 0
    if ignore_thres and decision_thres is not None:
        msg = 'decision threshold and ignore_thres cannot be set at the same time.'
        logger.error(msg)
        raise ValueError(msg)

    timer = Timer()
    cls.fit(X_train, Y_train)

    if not ignore_thres:
        # ignore_thres is None or 0: plain hard predictions
        Y_pred = cls.predict(X_test)
        score = skmetrics.accuracy_score(Y_test, Y_pred)
        # `labels` is keyword-only in recent scikit-learn releases
        cm = skmetrics.confusion_matrix(Y_test, Y_pred, labels=label_list)
        f1 = skmetrics.f1_score(Y_test, Y_pred, average='macro')
    else:
        Y_prob = cls.predict_proba(X_test)
        pred_idx = np.argmax(Y_prob, axis=1)
        # highest likelihood of each test sample
        pred_max = Y_prob[np.arange(len(pred_idx)), pred_idx]
        over = pred_max >= ignore_thres

        classes = np.asarray(cls.classes_)
        Y_pred_overthres = classes[pred_idx[over]]
        Y_pred_underthres = classes[pred_idx[~over]]
        Y_pred_underthres_count = np.array(
            [np.count_nonzero(Y_pred_underthres == c) for c in label_list])
        Y_test_overthres = np.asarray(Y_test)[over]

        score = skmetrics.accuracy_score(Y_test_overthres, Y_pred_overthres)
        cm = skmetrics.confusion_matrix(Y_test_overthres, Y_pred_overthres,
                                        labels=label_list)
        # append the per-label count of ignored samples as an extra column
        cm = np.concatenate((cm, Y_pred_underthres_count[:, np.newaxis]), axis=1)
        f1 = skmetrics.f1_score(Y_test_overthres, Y_pred_overthres, average='macro')

    logger.info('Cross-validation %d (%.3f) - %.1f sec' % (cnum, score, timer.sec()))
    return score, cm, f1
def get_predict_proba(cls, X_train, Y_train, X_test, Y_test, cnum):
    """
    Fit ``cls`` on one training fold and return the test-fold likelihoods of
    the first class.

    All likelihoods are meant to be collected from every fold of a
    cross-validation. Based on these likelihoods, a threshold can be computed
    that balances the true positive rate of each class. Available with the
    binary classification scenario only.

    Y_test is accepted for interface compatibility and is not used.

    Returns
    -------
    Column 0 of ``cls.predict_proba(X_test)``.
    """
    stopwatch = Timer()
    cls.fit(X_train, Y_train)
    likelihoods = cls.predict_proba(X_test)
    n_tests = likelihoods.shape[0]
    logger.info('Cross-validation %d (%d tests) - %.1f sec' % (cnum, n_tests, stopwatch.sec()))
    first_class = likelihoods[:, 0]
    return first_class
def sample_decoding(decoder):
    """
    Decoding example: continuously print raw and smoothed class likelihoods.

    Loops forever. Each new result prints the time since the previous result,
    every label with its smoothed and raw likelihood, and the label with the
    highest smoothed likelihood. A warning is logged when no result has
    arrived for 5 seconds.

    Parameters
    ----------
    decoder :
        The decoder to use
    """
    def get_index_max(seq):
        """Return the key (dict) or index (any sequence) of the largest value."""
        if isinstance(seq, dict):
            return max(seq, key=seq.__getitem__)
        try:
            # works for list, tuple and numpy arrays alike
            return max(range(len(seq)), key=seq.__getitem__)
        except TypeError:
            logger.error('Unsupported input %s' % type(seq))
            return None

    # class label names, positionally aligned with the probability vectors
    labels = decoder.get_label_names()
    tm_watchdog = Timer(autoreset=True)
    tm_cls = Timer()
    while True:
        praw = decoder.get_prob_unread()
        psmooth = decoder.get_prob_smooth()
        if praw is None:
            # watch dog
            if tm_cls.sec() > 5:
                logger.warning(
                    'No classification was done in the last 5 seconds. Are you receiving data streams?'
                )
                tm_cls.reset()
            tm_watchdog.sleep_atleast(0.001)
            continue

        txt = '[%8.1f msec]' % (tm_cls.msec())
        for i, label in enumerate(labels):
            txt += '   %s %.3f (raw %.3f)' % (label, psmooth[i], praw[i])
        maxi = get_index_max(psmooth)
        txt += '   %s' % labels[maxi]
        print(txt)
        tm_cls.reset()
def run(cfg, state=None, queue=None):
    """
    Run a cue-based protocol: for each trial, show a direction cue with the
    bar visual and send the matching triggers. No decoding is done here.

    Parameters
    ----------
    cfg : protocol configuration
        Attributes read: REFRESH_RATE, TRIGGER_FILE, TRIALS_EACH, DIRECTIONS,
        TRIGGER_DEVICE, TIMINGS, GLASS_USE, SCREEN_POS, SCREEN_SIZE,
        TRIAL_PAUSE. ``cfg.tdef`` is set to the loaded TriggerDef.
    state : multiprocessing.Value('i'), optional
        Shared run state: 0 = stop, 1 = start, 2 = wait. A new Value set to 1
        is created when omitted. It is set to 0 when the protocol finishes.
    queue : optional
        Passed to redirect_stdout_to_queue for log redirection.
    """
    if state is None:
        # Created per call: a default of mp.Value(...) in the signature would be
        # evaluated once at import time and shared by every call.
        state = mp.Value('i', 1)
    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI)
    while state.value == 2:  # 0: stop, 1: start, 2: wait
        pass

    # Protocol start if equals to 1
    if not state.value:
        sys.exit()

    refresh_delay = 1.0 / cfg.REFRESH_RATE

    cfg.tdef = TriggerDef(cfg.TRIGGER_FILE)

    # visualizer key codes (as returned by cv2.waitKey)
    keys = {'left': 81, 'right': 83, 'up': 82, 'down': 84, 'pgup': 85, 'pgdn': 86,
            'home': 80, 'end': 87, 'space': 32, 'esc': 27, ',': 44, '.': 46,
            's': 115, 'c': 99, '[': 91, ']': 93, '1': 49, '!': 33, '2': 50,
            '@': 64, '3': 51, '#': 35}

    # direction code -> (bar directions to fill, trigger-name prefix in cfg.tdef)
    dir_specs = {
        'L': (('L',), 'LEFT'),
        'R': (('R',), 'RIGHT'),
        'U': (('U',), 'UP'),
        'D': (('D',), 'DOWN'),
        'B': (('L', 'R'), 'BOTH'),  # both hands
    }

    def jittered(name):
        # nominal duration TIMINGS[name] +/- uniform jitter TIMINGS[name + '_RANDOMIZE']
        spread = cfg.TIMINGS[name + '_RANDOMIZE']
        return cfg.TIMINGS[name] + random.uniform(-spread, spread)

    # every direction appears TRIALS_EACH times, in random order
    dir_sequence = []
    for _ in range(cfg.TRIALS_EACH):
        dir_sequence.extend(cfg.DIRECTIONS)
    random.shuffle(dir_sequence)
    num_trials = len(cfg.DIRECTIONS) * cfg.TRIALS_EACH

    event = 'start'
    trial = 1
    direction = None

    # Hardware trigger
    if cfg.TRIGGER_DEVICE is None:
        logger.warning('No trigger device set. Press Ctrl+C to stop or Enter to continue.')
        # input()
    trigger = Trigger(lpttype=cfg.TRIGGER_DEVICE, state=state)
    if trigger.init(50) == False:
        logger.error('\n** Error connecting to USB2LPT device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = Trigger(lpttype='FAKE')
        trigger.init(50)

    # timers
    timer_trigger = Timer()
    timer_dir = Timer()
    timer_refresh = Timer()
    t_dir = jittered('DIR')
    t_dir_ready = jittered('READY')

    bar = BarVisual(cfg.GLASS_USE, screen_pos=cfg.SCREEN_POS, screen_size=cfg.SCREEN_SIZE)
    bar.fill()
    bar.glass_draw_cue()

    # start
    while trial <= num_trials:
        timer_refresh.sleep_atleast(refresh_delay)
        timer_refresh.reset()

        # segment= { 'cue':(s,e), 'dir':(s,e), 'label':0-4 } (zero-based)
        if event == 'start' and timer_trigger.sec() > cfg.TIMINGS['INIT']:
            event = 'gap_s'
            bar.fill()
            timer_trigger.reset()
            trigger.signal(cfg.tdef.INIT)
        elif event == 'gap_s':
            if cfg.TRIAL_PAUSE:
                bar.put_text('Press any key')
                bar.update()
                key = cv2.waitKey()
                if key == keys['esc'] or not state.value:
                    break
                bar.fill()
            bar.put_text('Trial %d / %d' % (trial, num_trials))
            event = 'gap'
            timer_trigger.reset()
        elif event == 'gap' and timer_trigger.sec() > cfg.TIMINGS['GAP']:
            event = 'cue'
            bar.fill()
            bar.draw_cue()
            trigger.signal(cfg.tdef.CUE)
            timer_trigger.reset()
        elif event == 'cue' and timer_trigger.sec() > cfg.TIMINGS['CUE']:
            event = 'dir_r'
            direction = dir_sequence[trial - 1]
            if direction not in dir_specs:
                # %s, not %d: direction codes are strings
                raise RuntimeError('Unknown direction %s' % direction)
            bar_dirs, trig_prefix = dir_specs[direction]
            for bar_dir in bar_dirs:
                bar.move(bar_dir, 100, overlay=True)
            trigger.signal(getattr(cfg.tdef, trig_prefix + '_READY'))
            timer_trigger.reset()
        elif event == 'dir_r' and timer_trigger.sec() > t_dir_ready:
            bar.fill()
            bar.draw_cue()
            event = 'dir'
            timer_trigger.reset()
            timer_dir.reset()
            # direction was validated when the cue was shown
            trigger.signal(getattr(cfg.tdef, dir_specs[direction][1] + '_GO'))
        elif event == 'dir' and timer_trigger.sec() > t_dir:
            event = 'gap_s'
            bar.fill()
            trial += 1
            logger.info('trial ' + str(trial - 1) + ' done')
            trigger.signal(cfg.tdef.BLANK)
            timer_trigger.reset()
            t_dir = jittered('DIR')
            t_dir_ready = jittered('READY')

        # protocol: grow the bar from 1% to 100% over t_dir seconds
        if event == 'dir':
            dx = min(100, int(100.0 * timer_dir.sec() / t_dir) + 1)
            for bar_dir in dir_specs[direction][0]:
                bar.move(bar_dir, dx, overlay=True)

        # wait for start
        if event == 'start':
            bar.put_text('Waiting to start')

        bar.update()
        key = 0xFF & cv2.waitKey(1)
        if key == keys['esc'] or not state.value:
            break

    bar.finish()

    with state.get_lock():
        state.value = 0
def __init__(self, image_path, use_glass=False, glass_feedback=True, pc_feedback=True,
             screen_pos=None, screen_size=None):
    """
    Input:
        image_path: directory containing 'left' and 'right' image folders,
            or a gzip-compressed pickle file (.pkl) holding 'left_images'
            and 'right_images'
        use_glass: if False, mock Glass will be used
        glass_feedback: show feedback to the user?
        pc_feedback: show feedback on the pc screen?
        screen_pos: screen position in (x,y)
        screen_size: screen size in (x,y)
    """
    # screen size and message setting
    if screen_size is None:
        if sys.platform.startswith('win'):
            from win32api import GetSystemMetrics
            screen_width = GetSystemMetrics(0)
            screen_height = GetSystemMetrics(1)
        else:
            screen_width = 1024
            screen_height = 768
        screen_size = (screen_width, screen_height)
    else:
        screen_width, screen_height = screen_size
    if screen_pos is None:
        screen_x, screen_y = (0, 0)
    else:
        screen_x, screen_y = screen_pos

    self.text_size = 2
    self.img = np.zeros((screen_height, screen_width, 3), np.uint8)
    self.glass = bgi_client.GlassControl(mock=not use_glass)
    self.glass.connect('127.0.0.1', 59900)
    self.set_glass_feedback(glass_feedback)
    self.set_pc_feedback(pc_feedback)
    self.set_cue_color(boxcol='B', crosscol='W')
    self.width = self.img.shape[1]
    self.height = self.img.shape[0]
    # NOTE(review): self.barwidth is not assigned here; presumably a class
    # attribute of the enclosing class - confirm.
    hw = int(self.barwidth / 2)
    self.cx = int(self.width / 2)
    self.cy = int(self.height / 2)
    self.xl1 = self.cx - hw
    self.xl2 = self.xl1 - self.barwidth
    self.xr1 = self.cx + hw
    self.xr2 = self.xr1 + self.barwidth
    self.yl1 = self.cy - hw
    self.yl2 = self.yl1 - self.barwidth
    self.yr1 = self.cy + hw
    self.yr2 = self.yr1 + self.barwidth

    if os.path.isdir(image_path):
        # load images
        left_image_path = '%s/left' % image_path
        right_image_path = '%s/right' % image_path
        tm = Timer()
        logger.info('Reading images from %s' % left_image_path)
        self.left_images = read_images(left_image_path, screen_size)
        logger.info('Reading images from %s' % right_image_path)
        self.right_images = read_images(right_image_path, screen_size)
        logger.info('Took %.1f s' % tm.sec())
    else:
        # load pickled images
        # note: this is painfully slow in Python 2 even with cPickle (3s vs 27s)
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        assert image_path[-4:] == '.pkl', 'The file must be of .pkl format'
        logger.info('Loading image binary file %s ...' % image_path)
        tm = Timer()
        with gzip.open(image_path, 'rb') as fp:
            image_data = pickle.load(fp)
        self.left_images = image_data['left_images']
        self.right_images = image_data['right_images']

        # slice that centers an image of the (first) stored size on the screen
        feedback_w = self.left_images[0].shape[1] / 2
        feedback_h = self.left_images[0].shape[0] / 2
        loc_x = [int(self.cx - feedback_w), int(self.cx + feedback_w)]
        loc_y = [int(self.cy - feedback_h), int(self.cy + feedback_h)]

        # adjust to the current screen size: paste each image onto a black
        # screen-sized canvas, replacing the list entries in place
        logger.info('Fitting images into the current screen size')
        for images in (self.left_images, self.right_images):
            for i, img in enumerate(images):
                img_fit = np.zeros((screen_height, screen_width, 3), np.uint8)
                img_fit[loc_y[0]:loc_y[1], loc_x[0]:loc_x[1]] = img
                images[i] = img_fit
        logger.info('Took %.1f s' % tm.sec())
    logger.info('Done.')

    cv2.namedWindow("Protocol", cv2.WND_PROP_FULLSCREEN)
    cv2.moveWindow("Protocol", screen_x, screen_y)
    cv2.setWindowProperty("Protocol", cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
def log_decoding(decoder, logfile, amp_name=None, pklfile=True, matfile=False,
                 autostop=False, prob_smooth=False):
    """
    Decode online and write results with event timestamps

    A background process (_log_decoding_helper) collects events while this
    process polls the decoder until ESC is pressed (or, with autostop, until
    no result arrives for 5 s after the first one). Times in the saved data
    are relative to the first decoding result; events before it are dropped.

    Parameters
    ----------
    decoder : BCIDecoder or BCIDecoderDaemon class
        The decoder to use
    logfile : str
        The file path to contain the result in Python pickle format
    amp_name : str
        The stream name to connect to
    pklfile : bool
        If True, export the results to Python pickle format
    matfile : bool
        If True, export the results to .mat file (same directory and stem as
        logfile)
    autostop : bool
        If True, automatically finish when no more data is received.
    prob_smooth : bool
        If True, use smoothed probability values according to decoder's
        smoothing parameter.
    """
    import cv2
    import scipy.io  # `import scipy` alone does not guarantee scipy.io is loaded

    # run event acquisition process in the background
    state = mp.Value('i', 1)
    event_queue = mp.Queue()
    proc = mp.Process(target=_log_decoding_helper, args=[state, event_queue, amp_name, autostop])
    proc.start()
    logger.info('Spawned event acquisition process.')

    # init variables and choose decoding function
    labels = decoder.get_label_names()
    probs = []
    prob_times = []
    if prob_smooth:
        decode_fn = decoder.get_prob_smooth_unread
    else:
        decode_fn = decoder.get_prob_unread

    # simple controller UI
    cv2.namedWindow("Decoding", cv2.WINDOW_AUTOSIZE)
    cv2.moveWindow("Decoding", 1400, 50)
    img = np.zeros([100, 400, 3], np.uint8)
    cv2.putText(img, 'Press any key to start', (20, 60), cv2.FONT_HERSHEY_SIMPLEX, 1,
                (255, 255, 255), 2, cv2.LINE_AA)
    cv2.imshow("Decoding", img)
    cv2.waitKeyEx()
    img *= 0
    cv2.putText(img, 'Press ESC to stop', (40, 60), cv2.FONT_HERSHEY_SIMPLEX, 1,
                (255, 255, 255), 2, cv2.LINE_AA)
    cv2.imshow("Decoding", img)

    key = 0
    started = False
    tm_watchdog = Timer(autoreset=True)
    tm_cls = Timer()
    while key != 27:  # 27 = ESC
        prob, prob_time = decode_fn(True)
        t_lsl = pylsl.local_clock()
        key = cv2.waitKeyEx(1)
        if prob is None:
            # watch dog
            if tm_cls.sec() > 5:
                if autostop and started:
                    logger.info('No more streaming data. Finishing.')
                    break
                tm_cls.reset()
            tm_watchdog.sleep_atleast(0.001)
            continue
        probs.append(prob)
        prob_times.append(prob_time)
        txt = '[%.3f] ' % prob_time
        txt += ', '.join(['%s: %.2f' % (l, p) for l, p in zip(labels, prob)])
        txt += ' (%d ms, LSL Diff = %.3f)' % (tm_cls.msec(), (t_lsl - prob_time))
        logger.info(txt)
        if not started:
            started = True
        tm_cls.reset()

    # finish up processes
    cv2.destroyAllWindows()
    logger.info('Cleaning up event acquisition process.')
    state.value = 0
    decoder.stop()
    event_times, event_values = event_queue.get()
    proc.join()

    # save values
    if len(prob_times) == 0:
        # Nothing to save; previously this dropped into pdb and then crashed
        # on prob_times[0].
        logger.error('No decoding result. Please debug.')
        return
    t_start = prob_times[0]
    probs = np.vstack(probs)
    event_times = np.array(event_times)
    event_values = np.array(event_values)
    # Keep only events at/after the first decoding result. The same mask is
    # applied to the values so times and values stay paired.
    after_start = event_times >= t_start
    event_times = event_times[after_start] - t_start
    event_values = event_values[after_start]
    prob_times = np.array(prob_times) - t_start
    data = dict(probs=probs, prob_times=prob_times, event_times=event_times,
                event_values=event_values, labels=labels)
    if pklfile:
        with open(logfile, 'wb') as f:
            pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
        logger.info('Saved to %s' % logfile)
    if matfile:
        pp = Path(logfile)
        matout = '%s/%s.mat' % (pp.parent, pp.stem)
        scipy.io.savemat(matout, data)
        logger.info('Saved to %s' % matout)
class Feedback:
    """
    Perform a classification with visual feedback

    Drives one trial at a time (see classify): cue, direction cue, online
    bar/body feedback driven by the decoder's smoothed likelihoods, and
    result display. It optionally drives STIMO and FES devices over serial.
    """

    # label code -> trigger-name prefix in the trigger definition
    _TRIGGER_PREFIX = {'L': 'LEFT', 'R': 'RIGHT', 'U': 'UP', 'D': 'DOWN', 'B': 'BOTH'}

    def __init__(self, cfg, viz, tdef, trigger, logfile=None):
        """
        Parameters
        ----------
        cfg : protocol configuration object
        viz : visualizer used for drawing (fill / move / put_text / update ...)
        tdef : trigger definition; event codes are read as attributes
        trigger : object whose signal() sends an event code
        logfile : str, optional
            If given, per-trial results are written to this text file.
        """
        self.cfg = cfg
        self.tdef = tdef
        self.trigger = trigger
        self.viz = viz
        self.viz.fill()
        self.refresh_delay = 1.0 / self.cfg.REFRESH_RATE
        self.bar_step_left = self.cfg.BAR_STEP['left']
        self.bar_step_right = self.cfg.BAR_STEP['right']
        self.bar_step_up = self.cfg.BAR_STEP['up']
        self.bar_step_down = self.cfg.BAR_STEP['down']
        self.bar_step_both = self.cfg.BAR_STEP['both']

        # bar_bias is [label, amount] (or None); tuples are converted to a
        # list because the bias amount is changed on the fly in classify().
        if type(self.cfg.BAR_BIAS) is tuple:
            self.bar_bias = list(self.cfg.BAR_BIAS)
        else:
            self.bar_bias = self.cfg.BAR_BIAS

        # finish the trial early once the bar reaches 100 (if configured)
        if hasattr(self.cfg, 'BAR_REACH_FINISH') and self.cfg.BAR_REACH_FINISH == True:
            self.premature_end = True
        else:
            self.premature_end = False

        self.tm_trigger = Timer()
        self.tm_display = Timer()
        self.tm_watchdog = Timer()
        if logfile is not None:
            self.logf = open(logfile, 'w')
        else:
            self.logf = None

        # STIMO only
        if self.cfg.WITH_STIMO is True:
            if self.cfg.STIMO_COMPORT is None:
                atens = list(serial.tools.list_ports.grep('ATEN'))
                if len(atens) == 0:
                    raise RuntimeError('No ATEN device found. Stop.')
                try:
                    self.stimo_port = atens[0].device
                except AttributeError:  # depends on Python distribution
                    self.stimo_port = atens[0][0]
            else:
                self.stimo_port = self.cfg.STIMO_COMPORT
            self.ser = serial.Serial(self.stimo_port, self.cfg.STIMO_BAUDRATE)
            logger.info('STIMO serial port %s is_open = %s' % (self.stimo_port, self.ser.is_open))

        # FES only
        if self.cfg.WITH_FES is True:
            self.stim = fes.Motionstim8()
            self.stim.OpenSerialPort(self.cfg.FES_COMPORT)
            self.stim.InitializeChannelListMode()
            logger.info('Opened FES serial port')

    def __del__(self):
        """Release the serial devices and the log file opened in __init__."""
        # __init__ may have failed part-way; only release what exists.
        cfg = getattr(self, 'cfg', None)
        if cfg is None:
            return
        # STIMO only
        if cfg.WITH_STIMO is True and hasattr(self, 'ser'):
            self.ser.close()
            logger.info('Closed STIMO serial port %s' % self.stimo_port)
        # FES only
        if cfg.WITH_FES is True and hasattr(self, 'stim'):
            stim_code = [0, 0, 0, 0, 0, 0, 0, 0]
            self.stim.UpdateChannelSettings(stim_code)
            self.stim.CloseSerialPort()
            logger.info('Closed FES serial port')
        if getattr(self, 'logf', None) is not None:
            self.logf.close()

    def _signal_direction(self, label, suffix, err_fmt):
        """Send tdef.<PREFIX>_<suffix> for a label code; raise RuntimeError if unknown."""
        prefix = self._TRIGGER_PREFIX.get(label)
        if prefix is None:
            raise RuntimeError(err_fmt % label)
        self.trigger.signal(getattr(self.tdef, prefix + '_' + suffix))

    def classify(self, decoder, true_label, title_text, bar_dirs, state='start', prob_history=None):
        """
        Run a single trial

        Parameters
        ----------
        decoder : object providing get_prob_smooth_unread()
        true_label : label code of this trial; must be an element of bar_dirs
        title_text : text shown during the gap before the cue
        bar_dirs : label codes, in the order of the decoder's likelihoods
        state : initial state of the trial state machine
        prob_history : dict of lists, optional
            If given, the likelihood of the true label is appended to
            prob_history[true_label] for every decoding result.

        Returns
        -------
        The final bar label, or None if ESC was pressed.
        """
        def list2string(vec, fmt, sep=' '):
            return sep.join(fmt % x for x in vec)

        true_label_index = bar_dirs.index(true_label)
        self.tm_trigger.reset()
        if self.bar_bias is not None:
            bias_idx = bar_dirs.index(self.bar_bias[0])

        if self.logf is not None:
            self.logf.write('True label: %s\n' % true_label)

        tm_classify = Timer(autoreset=True)
        self.stimo_timer = Timer()
        while True:
            self.tm_display.sleep_atleast(self.refresh_delay)
            self.tm_display.reset()
            if state == 'start' and self.tm_trigger.sec() > self.cfg.TIMINGS['INIT']:
                state = 'gap_s'
                if self.cfg.TRIALS_PAUSE:
                    self.viz.put_text('Press any key')
                    self.viz.update()
                    key = cv2.waitKeyEx()
                    if key == KEYS['esc']:
                        return
                self.viz.fill()
                self.tm_trigger.reset()
                self.trigger.signal(self.tdef.INIT)

            elif state == 'gap_s':
                if self.cfg.TIMINGS['GAP'] > 0:
                    self.viz.put_text(title_text)
                state = 'gap'
                self.tm_trigger.reset()

            elif state == 'gap' and self.tm_trigger.sec() > self.cfg.TIMINGS['GAP']:
                state = 'cue'
                self.viz.fill()
                self.viz.draw_cue()
                self.viz.glass_draw_cue()
                self.trigger.signal(self.tdef.CUE)
                self.tm_trigger.reset()

            elif state == 'cue' and self.tm_trigger.sec() > self.cfg.TIMINGS['READY']:
                state = 'dir_r'
                if self.cfg.SHOW_CUE is True:
                    if self.cfg.FEEDBACK_TYPE == 'BAR':
                        self.viz.move(true_label, 100, overlay=False, barcolor='G')
                    elif self.cfg.FEEDBACK_TYPE == 'BODY':
                        self.viz.put_text(DIRS[true_label], 'R')
                self._signal_direction(true_label, 'READY', 'Unknown direction %s')
                self.tm_trigger.reset()

            elif state == 'dir_r' and self.tm_trigger.sec() > self.cfg.TIMINGS['DIR_CUE']:
                self.viz.fill()
                self.viz.draw_cue()
                self.viz.glass_draw_cue()
                state = 'dir'

                # initialize bar scores
                bar_label = bar_dirs[0]
                bar_score = 0
                probs = [1.0 / len(bar_dirs)] * len(bar_dirs)
                self.viz.move(bar_label, bar_score, overlay=False)
                probs_acc = np.zeros(len(probs))

                self._signal_direction(true_label, 'GO', 'Unknown truedirection %s')
                self.tm_watchdog.reset()
                self.tm_trigger.reset()

            elif state == 'dir':
                if self.tm_trigger.sec() > self.cfg.TIMINGS['CLASSIFY'] or (
                        self.premature_end and bar_score >= 100):
                    # Result colour. Computed before the SHOW_RESULT branch
                    # because both branches use it (it used to be unbound when
                    # SHOW_RESULT was False).
                    if self.cfg.WITH_STIMO is True:
                        if self.cfg.STIMO_FULLGAIT_CYCLE is not None and bar_label == 'U':
                            res_color = 'G'
                        elif self.cfg.TRIALS_RETRY is False or bar_label == true_label:
                            res_color = 'G'
                        else:
                            res_color = 'Y'
                    else:
                        res_color = 'Y'

                    if not hasattr(self.cfg, 'SHOW_RESULT') or self.cfg.SHOW_RESULT is True:
                        # show classification result
                        if self.cfg.FEEDBACK_TYPE == 'BODY':
                            self.viz.move(bar_label, bar_score, overlay=False, barcolor=res_color,
                                          caption=DIRS[bar_label], caption_color=res_color)
                        else:
                            self.viz.move(bar_label, 100, overlay=False, barcolor=res_color)
                    else:
                        if self.cfg.FEEDBACK_TYPE == 'BODY':
                            self.viz.move(bar_label, bar_score, overlay=False, barcolor=res_color,
                                          caption='TRIAL END', caption_color=res_color)
                        else:
                            self.viz.move(bar_label, 0, overlay=False, barcolor=res_color)
                    self.trigger.signal(self.tdef.FEEDBACK)

                    # STIMO (event mode)
                    if self.cfg.WITH_STIMO is True and self.cfg.STIMO_CONTINUOUS is False:
                        if self.cfg.STIMO_FULLGAIT_CYCLE is not None:
                            if bar_label == 'U':
                                self.ser.write(self.cfg.STIMO_FULLGAIT_PATTERN[0])
                                logger.info('STIMO: Sent 1')
                                time.sleep(self.cfg.STIMO_FULLGAIT_CYCLE)
                                self.ser.write(self.cfg.STIMO_FULLGAIT_PATTERN[1])
                                logger.info('STIMO: Sent 2')
                                time.sleep(self.cfg.STIMO_FULLGAIT_CYCLE)
                        elif self.cfg.TRIALS_RETRY is False or bar_label == true_label:
                            if bar_label == 'L':
                                self.ser.write(b'1')
                                logger.info('STIMO: Sent 1')
                            elif bar_label == 'R':
                                self.ser.write(b'2')
                                logger.info('STIMO: Sent 2')

                    # FES event mode: 0.5 s pulse on the channel of the detected side
                    if self.cfg.WITH_FES is True and self.cfg.FES_CONTINUOUS is False:
                        if bar_label == 'L':
                            stim_code = [0, 30, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)
                            logger.info('FES: Sent Left')
                            time.sleep(0.5)
                            stim_code = [0, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)
                        elif bar_label == 'R':
                            stim_code = [30, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)
                            time.sleep(0.5)
                            logger.info('FES: Sent Right')
                            stim_code = [0, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)

                    if self.cfg.DEBUG_PROBS:
                        msg = 'DEBUG: Accumulated probabilities = %s' % list2string(probs_acc, '%.3f')
                        logger.info(msg)
                        if self.logf is not None:
                            self.logf.write(msg + '\n')
                    if self.logf is not None:
                        self.logf.write('%s detected as %s (%d)\n\n' % (true_label, bar_label, bar_score))
                        self.logf.flush()

                    # end of trial
                    state = 'feedback'
                    self.tm_trigger.reset()
                else:
                    # classify
                    probs_new = decoder.get_prob_smooth_unread()
                    if probs_new is None:
                        if self.tm_watchdog.sec() > 3:
                            logger.warning('No classification being done. Are you receiving data streams?')
                            self.tm_watchdog.reset()
                    else:
                        self.tm_watchdog.reset()

                        if prob_history is not None:
                            prob_history[true_label].append(probs_new[true_label_index])

                        probs_acc += np.array(probs_new)

                        # The decoder already smooths its output, so the bias
                        # is applied afterwards and the result renormalized.
                        probs = list(probs_new)
                        if self.bar_bias is not None:
                            probs[bias_idx] += self.bar_bias[1]
                            newsum = sum(probs)
                            probs = [p / newsum for p in probs]

                        # determine the direction
                        max_pidx = np.argmax(probs)
                        max_label = bar_dirs[max_pidx]

                        if self.cfg.POSITIVE_FEEDBACK is False or \
                                (self.cfg.POSITIVE_FEEDBACK and true_label == max_label):
                            dx = probs[max_pidx]
                            bar_steps = {'L': self.bar_step_left, 'R': self.bar_step_right,
                                         'U': self.bar_step_up, 'D': self.bar_step_down,
                                         'B': self.bar_step_both}
                            if max_label in bar_steps:
                                dx *= bar_steps[max_label]
                            else:
                                logger.debug('Direction %s using bar step %d' % (max_label, self.bar_step_left))
                                dx *= self.bar_step_left

                            # slow start: scale dx linearly during the first
                            # BAR_SLOW_START[selected][0] seconds. The entry is
                            # indexed with [0], so compare against that element
                            # (comparing a float with the list raised TypeError).
                            selected = self.cfg.BAR_SLOW_START['selected']
                            slow_start = self.cfg.BAR_SLOW_START[selected]
                            if slow_start and self.tm_trigger.sec() < slow_start[0]:
                                dx *= self.tm_trigger.sec() / slow_start[0]

                            # add likelihoods
                            if max_label == bar_label:
                                bar_score += dx
                            else:
                                bar_score -= dx
                                # change of direction
                                if bar_score < 0:
                                    bar_score = -bar_score
                                    bar_label = max_label
                            bar_score = int(bar_score)
                            if bar_score > 100:
                                bar_score = 100
                            if self.cfg.FEEDBACK_TYPE == 'BODY':
                                if self.cfg.SHOW_CUE:
                                    self.viz.move(bar_label, bar_score, overlay=False,
                                                  caption=DIRS[true_label], caption_color='G')
                                else:
                                    self.viz.move(bar_label, bar_score, overlay=False)
                            else:
                                self.viz.move(bar_label, bar_score, overlay=False)

                            # send the confidence value continuously
                            if self.cfg.WITH_STIMO and self.cfg.STIMO_CONTINUOUS:
                                if self.stimo_timer.sec() >= self.cfg.STIMO_COOLOFF:
                                    if bar_label == 'U':
                                        stimo_code = bar_score
                                    else:
                                        stimo_code = 0
                                    self.ser.write(bytes([stimo_code]))
                                    logger.info('Sent STIMO code %d' % stimo_code)
                                    self.stimo_timer.reset()

                            # with FES
                            if self.cfg.WITH_FES is True and self.cfg.FES_CONTINUOUS is True:
                                if self.stimo_timer.sec() >= self.cfg.STIMO_COOLOFF:
                                    if bar_label == 'L':
                                        stim_code = [bar_score, 0, 0, 0, 0, 0, 0, 0]
                                    else:
                                        stim_code = [0, bar_score, 0, 0, 0, 0, 0, 0]
                                    self.stim.UpdateChannelSettings(stim_code)
                                    logger.info('Sent FES code %d' % bar_score)
                                    self.stimo_timer.reset()

                        if self.cfg.DEBUG_PROBS:
                            if self.bar_bias is not None:
                                biastxt = '[Bias=%s%.3f]  ' % (self.bar_bias[0], self.bar_bias[1])
                            else:
                                biastxt = ''
                            msg = '%s%s  prob %s   acc %s   bar %s%d  (%.1f ms)' % \
                                  (biastxt, bar_dirs, list2string(probs_new, '%.2f'),
                                   list2string(probs, '%.2f'), bar_label, bar_score,
                                   tm_classify.msec())
                            logger.info(msg)
                            if self.logf is not None:
                                self.logf.write(msg + '\n')

            elif state == 'feedback' and self.tm_trigger.sec() > self.cfg.TIMINGS['FEEDBACK']:
                self.trigger.signal(self.tdef.BLANK)
                if self.cfg.FEEDBACK_TYPE == 'BODY':
                    state = 'return'
                    self.tm_trigger.reset()
                else:
                    state = 'gap_s'
                    self.viz.fill()
                    self.viz.update()
                    return bar_label

            elif state == 'return':
                # shrink the body feedback back to zero, 5 units per refresh
                self.viz.set_glass_feedback(False)
                if self.cfg.WITH_STIMO:
                    self.viz.move(bar_label, bar_score, overlay=False, barcolor='B')
                else:
                    self.viz.move(bar_label, bar_score, overlay=False, barcolor='Y')
                self.viz.set_glass_feedback(True)
                bar_score -= 5
                if bar_score <= 0:
                    state = 'gap_s'
                    self.viz.fill()
                    self.viz.update()
                    return bar_label

            self.viz.update()
            key = cv2.waitKeyEx(1)
            if key == KEYS['esc']:
                return
            elif key == KEYS['space']:
                # reset the bar and the likelihood state
                bar_score = 0
                probs = [1.0 / len(bar_dirs)] * len(bar_dirs)
                self.viz.move(bar_dirs[0], bar_score, overlay=False)
                self.viz.update()
                logger.info('probs and dx reset.')
                self.tm_trigger.reset()
            elif key in ARROW_KEYS and ARROW_KEYS[key] in bar_dirs:
                # change bias on the fly: pressing the biased direction raises
                # the bias, another direction lowers it (or switches it when it
                # is already below one increment)
                if self.bar_bias is None:
                    self.bar_bias = [ARROW_KEYS[key], BIAS_INCREMENT]
                else:
                    if ARROW_KEYS[key] == self.bar_bias[0]:
                        self.bar_bias[1] += BIAS_INCREMENT
                    elif self.bar_bias[1] >= BIAS_INCREMENT:
                        self.bar_bias[1] -= BIAS_INCREMENT
                    else:
                        self.bar_bias = [ARROW_KEYS[key], BIAS_INCREMENT]
                if self.bar_bias[1] == 0:
                    self.bar_bias = None
                else:
                    bias_idx = bar_dirs.index(self.bar_bias[0])
def train_decoder(cfg, featdata, feat_file=None):
    """
    Train the final decoder using all data

    The classifier selected in cfg.CLASSIFIER is fitted on the concatenation
    of all epochs in ``featdata`` and pickled to
    ``<cfg.DATA_PATH>/classifier/classifier-<architecture>.pkl``. For PSD
    features, the most important features are logged. If
    cfg.EXPORT_GOOD_FEATURES is set, all features are also written, ranked,
    to ``feat_file`` (default ``<cfg.DATA_PATH>/classifier/good_features.txt``).

    Parameters
    ----------
    cfg : configuration object
    featdata : dict
        Keys read here: X_data, Y_data, wlen, w_frames, ch_names, picks,
        psde, sfreq.
    feat_file : str, optional
        Output path for the good-features list.
    """
    def sort_by_value(s, reverse=False):
        """
        Sort a dict (or a list, keyed by index) by value.

        Returns (keys, values) in matching order. Entries with identical
        values are all kept, in their original relative order. The previous
        reverse-dict approach silently dropped all but one of them.
        """
        assert isinstance(s, (dict, list)), 'Input must be a dictionary or list.'
        if isinstance(s, list):
            s = dict(enumerate(s))
        items = sorted(s.items(), key=lambda kv: kv[1], reverse=reverse)
        keys = [k for k, _ in items]
        values = [v for _, v in items]
        return keys, values

    # Init a classifier
    selected_classifier = cfg.CLASSIFIER['selected']
    cls_params = cfg.CLASSIFIER[selected_classifier]
    if selected_classifier == 'GB':
        # NOTE(review): `presort` was removed and `loss='deviance'` renamed in
        # newer scikit-learn releases; confirm against the installed version.
        cls = GradientBoostingClassifier(
            loss='deviance', learning_rate=cls_params['learning_rate'],
            n_estimators=cls_params['trees'], subsample=1.0,
            max_depth=cls_params['depth'], random_state=cls_params['seed'],
            max_features='sqrt', verbose=0, warm_start=False, presort='auto')
    elif selected_classifier == 'XGB':
        cls = XGBClassifier(
            loss='deviance', learning_rate=cls_params['learning_rate'],
            n_estimators=cls_params['trees'], subsample=1.0,
            max_depth=cls_params['depth'],
            # was cfg.GB['seed']; every other classifier takes its seed from
            # its own cfg.CLASSIFIER entry
            random_state=cls_params['seed'],
            max_features='sqrt', verbose=0, warm_start=False, presort='auto')
    elif selected_classifier == 'RF':
        cls = RandomForestClassifier(
            n_estimators=cls_params['trees'], max_features='auto',
            max_depth=cls_params['depth'], n_jobs=cfg.N_JOBS,
            random_state=cls_params['seed'], oob_score=False,
            class_weight='balanced_subsample')
    elif selected_classifier == 'LDA':
        cls = LDA(solver=cls_params['solver'], shrinkage=cls_params['shrinkage'])
    elif selected_classifier == 'rLDA':
        cls = rLDA(cls_params["r_coeff"])
    else:
        logger.error('Unknown classifier %s' % selected_classifier)
        raise ValueError

    # Setup features
    X_data = featdata['X_data']
    Y_data = featdata['Y_data']
    wlen = featdata['wlen']
    if cfg.FEATURES['PSD']['wlen'] is None:
        cfg.FEATURES['PSD']['wlen'] = wlen
    w_frames = featdata['w_frames']
    ch_names = featdata['ch_names']
    X_data_merged = np.concatenate(X_data)
    Y_data_merged = np.concatenate(Y_data)
    if cfg.CV['BALANCE_SAMPLES']:
        X_data_merged, Y_data_merged = balance_samples(
            X_data_merged, Y_data_merged, cfg.CV['BALANCE_SAMPLES'], verbose=True)

    # Start training the decoder
    logger.info('Training the decoder')
    timer = Timer()
    cls.n_jobs = cfg.N_JOBS
    cls.fit(X_data_merged, Y_data_merged)
    logger.info('Trained %d samples x %d dimension in %.1f sec' %
                (X_data_merged.shape[0], X_data_merged.shape[1], timer.sec()))
    cls.n_jobs = 1  # always set n_jobs=1 for testing

    # Export the decoder
    classes = {c: cfg.tdef.by_value[c] for c in np.unique(Y_data)}
    if cfg.FEATURES['selected'] == 'PSD':
        data = dict(cls=cls, ch_names=ch_names, psde=featdata['psde'], sfreq=featdata['sfreq'],
                    picks=featdata['picks'], classes=classes, epochs=cfg.EPOCH,
                    w_frames=w_frames, w_seconds=cfg.FEATURES['PSD']['wlen'],
                    wstep=cfg.FEATURES['PSD']['wstep'], spatial=cfg.SP_FILTER,
                    spatial_ch=cfg.SP_CHANNELS,
                    spectral=cfg.TP_FILTER[cfg.TP_FILTER['selected']],
                    spectral_ch=cfg.TP_CHANNELS,
                    notch=cfg.NOTCH_FILTER[cfg.NOTCH_FILTER['selected']],
                    notch_ch=cfg.NOTCH_CHANNELS, multiplier=cfg.MULTIPLIER,
                    ref_ch=cfg.REREFERENCE[cfg.REREFERENCE['selected']],
                    decim=cfg.FEATURES['PSD']['decim'])
    if cfg.SAVE_FEATURES:
        data["SAVED_FEAT"] = dict(X=X_data_merged, Y=Y_data_merged)
    clsfile = '%s/classifier/classifier-%s.pkl' % (cfg.DATA_PATH, platform.architecture()[0])
    make_dirs('%s/classifier' % cfg.DATA_PATH)
    with open(clsfile, 'wb') as f:
        pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
    logger.info('Decoder saved to %s' % clsfile)

    # Reverse-lookup frequency from FFT: frequencies in [fmin, fmax] at the
    # resolution 1 / wlen (first window length when several are used)
    fq = 0
    if type(cfg.FEATURES['PSD']['wlen']) == list:
        fq_res = 1.0 / cfg.FEATURES['PSD']['wlen'][0]
    else:
        fq_res = 1.0 / cfg.FEATURES['PSD']['wlen']
    fqlist = []
    while fq <= cfg.FEATURES['PSD']['fmax']:
        if fq >= cfg.FEATURES['PSD']['fmin']:
            fqlist.append(fq)
        fq += fq_res

    # Show top distinctive features
    if cfg.FEATURES['selected'] == 'PSD':
        logger.info('Good features ordered by importance')
        if selected_classifier in ['RF', 'GB', 'XGB']:
            keys, values = sort_by_value(list(cls.feature_importances_), reverse=True)
        elif selected_classifier in ['LDA', 'rLDA']:
            keys, values = sort_by_value(cls.coef_.reshape(-1).tolist(), reverse=True)
        keys = np.array(keys)
        values = np.array(values)

        # Channel names of the picked channels; prefixed per window when
        # several window lengths are used. Built from the original list; the
        # old code reset ch_names to [] and then indexed it (IndexError).
        all_ch_names = featdata['ch_names']
        picks = featdata['picks']
        if isinstance(wlen, list):
            feat_ch_names = ['w%d-%s' % (w, all_ch_names[c])
                             for w in range(len(wlen)) for c in picks]
        else:
            feat_ch_names = [all_ch_names[c] for c in picks]
        chlist, hzlist = features.feature2chz(keys, fqlist, ch_names=feat_ch_names)

        # importance of the top-N features, normalized to sum to 100
        valnorm = values[:cfg.FEAT_TOPN].copy()
        valsum = np.sum(valnorm)
        if valsum == 0:
            valsum = 1
        valnorm = valnorm / valsum * 100.0

        # show top-N features
        for i, (ch, hz) in enumerate(zip(chlist, hzlist)):
            if i >= cfg.FEAT_TOPN:
                break
            txt = '%-3s %5.1f Hz  normalized importance %-6s  raw importance %-6s  feature %-5d' %\
                  (ch, hz, '%.2f%%' % valnorm[i], '%.2f%%' % (values[i] * 100.0), keys[i])
            logger.info(txt)

        if cfg.EXPORT_GOOD_FEATURES:
            if feat_file is None:
                gf_path = '%s/classifier/good_features.txt' % cfg.DATA_PATH
            else:
                gf_path = feat_file
            with open(gf_path, 'w') as gfout:
                gfout.write('Importance(%) Channel Frequency Index\n')
                for i, (ch, hz) in enumerate(zip(chlist, hzlist)):
                    gfout.write('%.3f\t%s\t%s\t%d\n' % (values[i] * 100.0, ch, hz, keys[i]))
def cross_validate(cfg, featdata, cv_file=None):
    """
    Perform cross validation of the configured classifier on epoched features.

    The classifier is chosen by cfg.CLASSIFIER['selected'] and the
    cross-validation scheme by cfg.CV_PERFORM['selected']. A text report
    (timing, class counts, preprocessing settings, accuracy statistics and
    confusion matrix) is logged and, when
    cfg.CV_PERFORM[<selected cv>]['export_result'] is True, written to a file.

    Parameters
    ----------
    cfg : config object
        Training configuration exposing CLASSIFIER, CV_PERFORM, CV, FEATURES,
        tdef, EPOCH, DATA_PATH, ... attributes.
    featdata : dict
        Feature data. Must contain 'X_data' (trials x samples x features),
        'Y_data' (labels indexed as Y_data[:, 0] per trial), 'wlen', 'sfreq',
        'ch_names' and 'picks'.
    cv_file : str or None
        Path of the exported report. If None, the report goes to
        <DATA_PATH>/classifier/cv_result.txt when cfg.EXPORT_CLS is True,
        otherwise to <DATA_PATH>/cv_result.txt.

    Raises
    ------
    ValueError
        If the selected classifier type is unknown.
    NotImplementedError
        If the selected cross-validation type is not supported.
    """
    # Init a classifier. Validate the type before looking up its parameters
    # so an unknown type yields the intended ValueError rather than KeyError.
    selected_classifier = cfg.CLASSIFIER['selected']
    if selected_classifier not in ('GB', 'XGB', 'RF', 'LDA', 'rLDA'):
        logger.error('Unknown classifier type %s' % selected_classifier)
        raise ValueError('Unknown classifier type %s' % selected_classifier)
    cls_params = cfg.CLASSIFIER[selected_classifier]

    if selected_classifier == 'GB':
        cls = GradientBoostingClassifier(
            loss='deviance', learning_rate=cls_params['learning_rate'],
            presort='auto', n_estimators=cls_params['trees'], subsample=1.0,
            max_depth=cls_params['depth'], random_state=cls_params['seed'],
            max_features='sqrt', verbose=0, warm_start=False)
    elif selected_classifier == 'XGB':
        # NOTE(review): loss/presort/max_features/verbose/warm_start mirror the
        # GradientBoostingClassifier arguments above; confirm XGBClassifier
        # accepts them in the installed xgboost version.
        cls = XGBClassifier(
            loss='deviance', learning_rate=cls_params['learning_rate'],
            presort='auto', n_estimators=cls_params['trees'], subsample=1.0,
            max_depth=cls_params['depth'], random_state=cls_params['seed'],
            max_features='sqrt', verbose=0, warm_start=False)
    elif selected_classifier == 'RF':
        cls = RandomForestClassifier(
            n_estimators=cls_params['trees'], max_features='auto',
            max_depth=cls_params['depth'], n_jobs=cfg.N_JOBS,
            random_state=cls_params['seed'], oob_score=False,
            class_weight='balanced_subsample')
    elif selected_classifier == 'LDA':
        cls = LDA(solver=cls_params['solver'],
                  shrinkage=cls_params['shrinkage'])
    else:  # 'rLDA'
        cls = rLDA(cls_params['r_coeff'])

    # Setup features
    X_data = featdata['X_data']
    Y_data = featdata['Y_data']
    wlen = featdata['wlen']
    ntrials, nsamples, fsize = X_data.shape

    # Choose CV type. Per-CV settings may be absent (e.g. LeaveOneOut needs
    # none), so fall back to an empty dict.
    selected_cv = cfg.CV_PERFORM['selected']
    cv_params = cfg.CV_PERFORM.get(selected_cv) or {}
    if selected_cv == 'LeaveOneOut':
        logger.info('%d-fold leave-one-out cross-validation' % ntrials)
        if SKLEARN_OLD:
            cv = LeaveOneOut(len(Y_data))
        else:
            cv = LeaveOneOut()
    elif selected_cv == 'StratifiedShuffleSplit':
        n_folds = cv_params['folds']
        test_ratio = cv_params['test_ratio']
        logger.info(
            '%d-fold stratified cross-validation with test set ratio %.2f' %
            (n_folds, test_ratio))
        if SKLEARN_OLD:
            # Old sklearn API takes the labels in the constructor.
            cv = StratifiedShuffleSplit(Y_data[:, 0], n_folds,
                                        test_size=test_ratio,
                                        random_state=cv_params['seed'])
        else:
            cv = StratifiedShuffleSplit(n_splits=n_folds,
                                        test_size=test_ratio,
                                        random_state=cv_params['seed'])
    else:
        logger.error('%s is not supported yet. Sorry.' % selected_cv)
        raise NotImplementedError('%s is not supported yet.' % selected_cv)
    logger.info('%d trials, %d samples per trial, %d feature dimension' %
                (ntrials, nsamples, fsize))

    # Do it!
    timer_cv = Timer()
    scores, cm_txt = crossval_epochs(cv, X_data, Y_data, cls,
                                     cfg.tdef.by_value,
                                     cfg.CV['BALANCE_SAMPLES'],
                                     n_jobs=cfg.N_JOBS,
                                     ignore_thres=cfg.CV['IGNORE_THRES'],
                                     decision_thres=cfg.CV['DECISION_THRES'])
    t_cv = timer_cv.sec()

    # Export results
    txt = 'Cross validation took %d seconds.\n' % t_cv
    txt += '\n- Class information\n'
    txt += '%d epochs, %d samples per epoch, %d feature dimension (total %d samples)\n' %\
        (ntrials, nsamples, fsize, ntrials * nsamples)
    for ev in np.unique(Y_data):
        txt += '%s: %d trials\n' % (cfg.tdef.by_value[ev],
                                    np.count_nonzero(Y_data[:, 0] == ev))
    if cfg.CV['BALANCE_SAMPLES']:
        # Read from cfg.CV, the same setting passed to crossval_epochs above.
        txt += 'The number of samples was balanced using %ssampling.\n' %\
            str(cfg.CV['BALANCE_SAMPLES']).lower()
    txt += '\n- Experiment condition\n'
    txt += 'Sampling frequency: %.3f Hz\n' % featdata['sfreq']
    txt += 'Spatial filter: %s (channels: %s)\n' % (cfg.SP_FILTER,
                                                    cfg.SP_CHANNELS)
    txt += 'Spectral filter: %s (channels: %s)\n' % (
        cfg.TP_FILTER[cfg.TP_FILTER['selected']], cfg.TP_CHANNELS)
    txt += 'Notch filter: %s (channels: %s)\n' % (
        cfg.NOTCH_FILTER[cfg.NOTCH_FILTER['selected']], cfg.NOTCH_CHANNELS)
    txt += 'PSD Channels: ' + ','.join(
        [str(featdata['ch_names'][p]) for p in featdata['picks']]) + '\n'
    txt += 'PSD range: %.1f - %.1f Hz\n' % (cfg.FEATURES['PSD']['fmin'],
                                            cfg.FEATURES['PSD']['fmax'])
    # wstep is divided by sfreq, i.e. it is treated as a sample count.
    txt += 'Window step: %.2f msec\n' % (
        1000.0 * cfg.FEATURES['PSD']['wstep'] / featdata['sfreq'])
    # Epoch ranges are wrapped in a 1-tuple so that a tuple-valued range is
    # printed as a whole instead of being unpacked by the % operator.
    if isinstance(wlen, list):
        for i, w in enumerate(wlen):
            txt += 'Window size: %.1f msec\n' % (w * 1000.0)
            txt += 'Epoch range: %s sec\n' % (cfg.EPOCH[i],)
    else:
        txt += 'Window size: %.1f msec\n' % (cfg.FEATURES['PSD']['wlen'] * 1000.0)
        txt += 'Epoch range: %s sec\n' % (cfg.EPOCH,)
    txt += 'Decimation factor: %d\n' % cfg.FEATURES['PSD']['decim']

    # Compute stats
    cv_mean, cv_std = np.mean(scores), np.std(scores)
    txt += '\n- Average CV accuracy over %d epochs (random seed=%s)\n' % (
        ntrials, cv_params.get('seed'))
    # Every supported CV type (LeaveOneOut, StratifiedShuffleSplit) reaches
    # this point; unsupported types raised above.
    txt += "mean %.3f, std: %.3f\n" % (cv_mean, cv_std)
    txt += 'Classifier: %s, ' % selected_classifier
    if selected_classifier == 'RF':
        txt += '%d trees, %s max depth, random state %s\n' % (
            cls_params['trees'], cls_params['depth'], cls_params['seed'])
    elif selected_classifier in ('GB', 'XGB'):
        # Report the parameters of the classifier actually used.
        txt += '%d trees, %s max depth, %s learning_rate, random state %s\n' % (
            cls_params['trees'], cls_params['depth'],
            cls_params['learning_rate'], cls_params['seed'])
    elif selected_classifier == 'LDA':
        txt += 'solver %s, shrinkage %s\n' % (cls_params['solver'],
                                              cls_params['shrinkage'])
    else:  # 'rLDA'
        txt += 'regularization coefficient %.2f\n' % cls_params['r_coeff']
    # NOTE(review): this line reports IGNORE_THRES under the label
    # "Decision threshold"; DECISION_THRES itself is not reported.
    if cfg.CV['IGNORE_THRES'] is not None:
        txt += 'Decision threshold: %.2f\n' % cfg.CV['IGNORE_THRES']
    txt += '\n- Confusion Matrix\n' + cm_txt
    logger.info(txt)

    # Export to a file
    if cv_params.get('export_result') is True:
        if cv_file is None:
            if cfg.EXPORT_CLS is True:
                make_dirs('%s/classifier' % cfg.DATA_PATH)
                cv_file = '%s/classifier/cv_result.txt' % cfg.DATA_PATH
            else:
                cv_file = '%s/cv_result.txt' % cfg.DATA_PATH
        with open(cv_file, 'w') as fout:
            fout.write(txt)