示例#1
0
def check_speed(decoder, max_count=float('inf')):
    """
    Test decoding speed across several classifications.

    Busy-waits on ``decoder.get_prob_unread()`` until a result is available,
    and roughly once per second prints the mean time per classification and
    the classification rate for that window. After ``max_count``
    classifications in total, prints the mean of the per-window averages.

    Parameters
    ----------
    decoder : BCIDecoder or BCIDecoderDaemon class
        The decoder to assess its performance
    max_count : int
        Total number of classifications to wait for before stopping
        (default: run forever).
    """
    tm = Timer()
    total = 0  # classifications received overall; decides when to stop
    count = 0  # classifications received in the current ~1 s window
    mslist = []
    while total < max_count:
        # busy-wait (no sleep) until a new classification result is available
        while decoder.get_prob_unread() is None:
            pass
        total += 1
        count += 1
        if tm.sec() > 1:
            t = tm.sec()
            ms = 1000.0 * t / count
            # show time per classification and its reciprocal
            print('%.0f ms/c   %.1f Hz' % (ms, count / t))
            mslist.append(ms)
            count = 0
            tm.reset()
    if not mslist and count > 0:
        # max_count was reached before a full window elapsed: report the
        # partial window instead of averaging an empty list (which gives nan).
        mslist.append(1000.0 * tm.sec() / count)
    if mslist:
        print('mean = %.1f ms' % np.mean(mslist))
    else:
        print('mean = n/a (no classification received)')
示例#2
0
def fit_predict_thres(cls,
                      X_train,
                      Y_train,
                      X_test,
                      Y_test,
                      cnum,
                      label_list,
                      ignore_thres=None,
                      decision_thres=None):
    """
    Fit a classifier on one cross-validation fold and score it on the test set.

    Accuracy, confusion matrix and F1 score (macro average) are computed.
    When ``ignore_thres`` is positive, test samples whose highest predicted
    likelihood is lower than the threshold are not counted as
    classifications. They are excluded from accuracy and F1, and are tallied
    per predicted label in an extra last column of the confusion matrix.

    Parameters
    ----------
    cls : classifier
        Object exposing ``fit``, ``predict``, ``predict_proba`` and
        ``classes_`` (scikit-learn style).
    X_train, Y_train : array-like
        Training samples and labels.
    X_test, Y_test : array-like
        Test samples and labels.
    cnum : int
        Fold number, used only in the log message.
    label_list : list
        Labels defining the row/column order of the confusion matrix.
    ignore_thres : float or None
        If not None and larger than 0, likelihood values lower than
        ignore_thres are ignored while computing the scores.
    decision_thres : float or None
        Not used for scoring. Setting it together with a positive
        ``ignore_thres`` is an error.

    Returns
    -------
    score : float
        Accuracy.
    cm : ndarray
        Confusion matrix. It has one extra "ignored" column when
        ``ignore_thres`` > 0.
    f1 : float
        Macro-averaged F1 score.

    Raises
    ------
    ValueError
        If ``decision_thres`` is set together with a positive ``ignore_thres``.
    """
    timer = Timer()
    cls.fit(X_train, Y_train)
    assert ignore_thres is None or ignore_thres >= 0
    if ignore_thres is None or ignore_thres == 0:
        Y_pred = cls.predict(X_test)
        score = skmetrics.accuracy_score(Y_test, Y_pred)
        # `labels` must be passed by keyword: it is keyword-only in
        # scikit-learn >= 1.0 and positional use raises TypeError there.
        cm = skmetrics.confusion_matrix(Y_test, Y_pred, labels=label_list)
        f1 = skmetrics.f1_score(Y_test, Y_pred, average='macro')
    else:
        if decision_thres is not None:
            msg = 'decision threshold and ignore_thres cannot be set at the same time.'
            logger.error(msg)
            raise ValueError(msg)
        Y_pred = cls.predict_proba(X_test)
        Y_pred_labels = np.argmax(Y_pred, axis=1)
        # highest likelihood of each sample (the value at Y_pred_labels)
        Y_pred_maxes = np.max(Y_pred, axis=1)
        overthres = Y_pred_maxes >= ignore_thres
        underthres = Y_pred_maxes < ignore_thres
        classes = np.asarray(cls.classes_)
        Y_pred_overthres = classes[Y_pred_labels[overthres]]
        Y_pred_underthres = classes[Y_pred_labels[underthres]]
        # per-label number of ignored (below-threshold) predictions
        Y_pred_underthres_count = np.array(
            [np.count_nonzero(Y_pred_underthres == c) for c in label_list])
        Y_test_overthres = np.asarray(Y_test)[overthres]
        score = skmetrics.accuracy_score(Y_test_overthres, Y_pred_overthres)
        cm = skmetrics.confusion_matrix(Y_test_overthres, Y_pred_overthres,
                                        labels=label_list)
        # append the ignored counts as an extra last column
        cm = np.concatenate((cm, Y_pred_underthres_count[:, np.newaxis]),
                            axis=1)
        f1 = skmetrics.f1_score(Y_test_overthres,
                                Y_pred_overthres,
                                average='macro')

    logger.info('Cross-validation %d (%.3f) - %.1f sec' %
                (cnum, score, timer.sec()))
    return score, cm, f1
示例#3
0
def get_predict_proba(cls, X_train, Y_train, X_test, Y_test, cnum):
    """
    Fit ``cls`` on one cross-validation fold and return test likelihoods.

    Only the first column of ``cls.predict_proba(X_test)`` is returned.
    The intended use is to collect these likelihoods from every fold of a
    cross-validation, so that the caller can choose a threshold that
    balances the true positive rate of each class. For that reason this is
    meaningful for binary classification only.

    Parameters
    ----------
    cls : classifier
        Object exposing ``fit`` and ``predict_proba``.
    X_train, Y_train : array-like
        Training samples and labels.
    X_test : array-like
        Test samples.
    Y_test : array-like
        Not used here (presumably kept for signature parity with other
        per-fold helpers).
    cnum : int
        Fold number, used only in the log message.

    Returns
    -------
    ndarray
        First-column likelihood of every test sample.
    """
    fold_timer = Timer()
    cls.fit(X_train, Y_train)
    likelihoods = cls.predict_proba(X_test)
    n_tests = likelihoods.shape[0]
    logger.info('Cross-validation %d (%d tests) - %.1f sec' %
                (cnum, n_tests, fold_timer.sec()))
    first_column = likelihoods[:, 0]
    return first_column
示例#4
0
def sample_decoding(decoder):
    """
    Decoding example: print every new classification result, forever.

    For each new result, one line is printed with:
    - the time elapsed since the previous result;
    - the smoothed and raw probability of every label;
    - the label with the highest smoothed probability.
    A warning is logged whenever no result arrives for 5 seconds.

    Parameters
    ----------
    decoder : BCIDecoder or BCIDecoderDaemon class
        The decoder to use. It must provide get_label_names(),
        get_prob_unread() and get_prob_smooth().
    """
    def get_index_max(seq):
        """Return the index (or, for a dict, the key) of the largest value."""
        if isinstance(seq, dict):
            return max(seq, key=seq.__getitem__)
        # Accept any sized, indexable sequence (list, tuple, numpy array, ...),
        # not only plain lists, so labels[maxi] below does not get None.
        if hasattr(seq, '__len__') and hasattr(seq, '__getitem__'):
            return max(range(len(seq)), key=seq.__getitem__)
        logger.error('Unsupported input %s' % type(seq))
        return None

    # load trigger definitions for labeling
    labels = decoder.get_label_names()
    tm_watchdog = Timer(autoreset=True)
    tm_cls = Timer()
    while True:
        praw = decoder.get_prob_unread()
        psmooth = decoder.get_prob_smooth()
        if praw is None:
            # watch dog
            if tm_cls.sec() > 5:
                logger.warning(
                    'No classification was done in the last 5 seconds. Are you receiving data streams?'
                )
                tm_cls.reset()
            tm_watchdog.sleep_atleast(0.001)
            continue

        txt = '[%8.1f msec]' % (tm_cls.msec())
        for i, label in enumerate(labels):
            txt += '   %s %.3f (raw %.3f)' % (label, psmooth[i], praw[i])
        maxi = get_index_max(psmooth)
        txt += '   %s' % labels[maxi]
        print(txt)
        tm_cls.reset()
示例#5
0
def run(cfg, state=None, queue=None):
    """
    Run the offline bar protocol with hardware triggers.

    Each direction in cfg.DIRECTIONS is repeated cfg.TRIALS_EACH times and
    the resulting sequence is shuffled. Every trial steps through the events
    gap_s -> gap -> cue -> dir_r -> dir, after an initial 'start' event.
    Each transition draws the bar and sends the matching trigger. The
    protocol stops after the last trial, on ESC, or when ``state.value``
    becomes 0.

    Parameters
    ----------
    cfg : module or object
        Protocol configuration. It reads REFRESH_RATE, TRIGGER_FILE,
        TRIALS_EACH, DIRECTIONS ('L', 'R', 'U', 'D' or 'B'),
        TRIGGER_DEVICE, TIMINGS, TRIAL_PAUSE, GLASS_USE, SCREEN_POS and
        SCREEN_SIZE. The loaded trigger definitions are stored in
        ``cfg.tdef``.
    state : multiprocessing.Value('i') or None
        Shared run state (0: stop, 1: start, 2: wait). The protocol waits
        while it is 2, exits if it is 0, and sets it to 0 when finished.
        If None, a new Value initialised to 1 is created.
    queue : optional
        Forwarded to redirect_stdout_to_queue for log output.

    Raises
    ------
    RuntimeError
        If a direction is not one of 'L', 'R', 'U', 'D', 'B'.
    """
    import time

    # A default of mp.Value(...) would be created once at import time and
    # shared by every call. This function sets it to 0 on exit, so a second
    # call relying on the default would sys.exit() immediately.
    if state is None:
        state = mp.Value('i', 1)

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI)
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        time.sleep(0.001)  # poll without pinning a CPU core
    #  Protocol start if equals to 1
    if not state.value:
        sys.exit()

    refresh_delay = 1.0 / cfg.REFRESH_RATE

    cfg.tdef = TriggerDef(cfg.TRIGGER_FILE)

    # visualizer key codes (only 'esc' is used below)
    keys = {'left':81, 'right':83, 'up':82, 'down':84, 'pgup':85, 'pgdn':86,
        'home':80, 'end':87, 'space':32, 'esc':27, ',':44, '.':46, 's':115, 'c':99,
        '[':91, ']':93, '1':49, '!':33, '2':50, '@':64, '3':51, '#':35}

    # Trigger-name prefix for each direction code. For example 'L' uses
    # cfg.tdef.LEFT_READY and cfg.tdef.LEFT_GO.
    trigger_prefix = {'L': 'LEFT', 'R': 'RIGHT', 'U': 'UP', 'D': 'DOWN', 'B': 'BOTH'}
    # Bar(s) to animate for each direction code; 'B' (both hands) moves L and R.
    bar_moves = {'L': ('L',), 'R': ('R',), 'U': ('U',), 'D': ('D',), 'B': ('L', 'R')}

    dir_sequence = list(cfg.DIRECTIONS) * cfg.TRIALS_EACH
    random.shuffle(dir_sequence)
    num_trials = len(dir_sequence)

    event = 'start'
    trial = 1

    # Hardware trigger
    if cfg.TRIGGER_DEVICE is None:
        logger.warning('No trigger device set. Press Ctrl+C to stop or Enter to continue.')
        #input()
    trigger = Trigger(lpttype=cfg.TRIGGER_DEVICE, state=state)
    if trigger.init(50) == False:
        logger.error('\n** Error connecting to USB2LPT device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = Trigger(lpttype='FAKE')
        trigger.init(50)

    def move_bar(direction, amount):
        """Draw the bar(s) of `direction` filled to `amount` as an overlay."""
        for bar_dir in bar_moves[direction]:
            bar.move(bar_dir, amount, overlay=True)

    def send_trigger(direction, phase):
        """Send cfg.tdef.<PREFIX>_<phase> for `direction` (phase READY/GO)."""
        trigger.signal(getattr(cfg.tdef, '%s_%s' % (trigger_prefix[direction], phase)))

    def jittered(key):
        """Return cfg.TIMINGS[key] plus a uniform offset in +/- <key>_RANDOMIZE."""
        spread = cfg.TIMINGS[key + '_RANDOMIZE']
        return cfg.TIMINGS[key] + random.uniform(-spread, spread)

    # timers
    timer_trigger = Timer()
    timer_dir = Timer()
    timer_refresh = Timer()
    t_dir = jittered('DIR')
    t_dir_ready = jittered('READY')

    bar = BarVisual(cfg.GLASS_USE, screen_pos=cfg.SCREEN_POS, screen_size=cfg.SCREEN_SIZE)
    bar.fill()
    bar.glass_draw_cue()

    # start
    while trial <= num_trials:
        timer_refresh.sleep_atleast(refresh_delay)
        timer_refresh.reset()

        # segment= { 'cue':(s,e), 'dir':(s,e), 'label':0-4 } (zero-based)
        if event == 'start' and timer_trigger.sec() > cfg.TIMINGS['INIT']:
            event = 'gap_s'
            bar.fill()
            timer_trigger.reset()
            trigger.signal(cfg.tdef.INIT)
        elif event == 'gap_s':
            if cfg.TRIAL_PAUSE:
                bar.put_text('Press any key')
                bar.update()
                key = cv2.waitKey()
                if key == keys['esc'] or not state.value:
                    break
                bar.fill()
            bar.put_text('Trial %d / %d' % (trial, num_trials))
            event = 'gap'
            timer_trigger.reset()
        elif event == 'gap' and timer_trigger.sec() > cfg.TIMINGS['GAP']:
            event = 'cue'
            bar.fill()
            bar.draw_cue()
            trigger.signal(cfg.tdef.CUE)
            timer_trigger.reset()
        elif event == 'cue' and timer_trigger.sec() > cfg.TIMINGS['CUE']:
            event = 'dir_r'
            direction = dir_sequence[trial - 1]
            if direction not in trigger_prefix:
                # %s, not %d: direction codes are strings
                raise RuntimeError('Unknown direction %s' % direction)
            move_bar(direction, 100)
            send_trigger(direction, 'READY')
            timer_trigger.reset()
        elif event == 'dir_r' and timer_trigger.sec() > t_dir_ready:
            bar.fill()
            bar.draw_cue()
            event = 'dir'
            timer_trigger.reset()
            timer_dir.reset()
            send_trigger(direction, 'GO')
        elif event == 'dir' and timer_trigger.sec() > t_dir:
            event = 'gap_s'
            bar.fill()
            trial += 1
            logger.info('trial ' + str(trial - 1) + ' done')
            trigger.signal(cfg.tdef.BLANK)
            timer_trigger.reset()
            t_dir = jittered('DIR')
            t_dir_ready = jittered('READY')

        # protocol: grow the bar from 1 to 100 over t_dir seconds
        if event == 'dir':
            dx = min(100, int(100.0 * timer_dir.sec() / t_dir) + 1)
            move_bar(direction, dx)

        # wait for start
        if event == 'start':
            bar.put_text('Waiting to start')

        bar.update()
        key = 0xFF & cv2.waitKey(1)
        if key == keys['esc'] or not state.value:
            break

    bar.finish()

    with state.get_lock():
        state.value = 0
示例#6
0
    def __init__(self,
                 image_path,
                 use_glass=False,
                 glass_feedback=True,
                 pc_feedback=True,
                 screen_pos=None,
                 screen_size=None):
        """
        Set up the Glass client, the full-screen "Protocol" window and the
        left/right feedback images.

        Input:
            image_path: either a directory containing 'left' and 'right'
                        image folders, or a gzip-compressed .pkl file holding
                        a dict with 'left_images' and 'right_images' lists
            use_glass: if False, mock Glass will be used
            glass_feedback: show feedback to the user?
            pc_feedback: show feedback on the pc screen?
            screen_pos: screen position in (x,y)
            screen_size: screen size in (x,y). Defaults to the primary display
                         size on Windows and to 1024x768 elsewhere.
        """
        # screen size and message setting
        if screen_size is None:
            if sys.platform.startswith('win'):
                from win32api import GetSystemMetrics
                screen_width = GetSystemMetrics(0)
                screen_height = GetSystemMetrics(1)
            else:
                screen_width = 1024
                screen_height = 768
            screen_size = (screen_width, screen_height)
        else:
            screen_width, screen_height = screen_size
        if screen_pos is None:
            screen_x, screen_y = (0, 0)
        else:
            screen_x, screen_y = screen_pos

        self.text_size = 2
        self.img = np.zeros((screen_height, screen_width, 3), np.uint8)
        self.glass = bgi_client.GlassControl(mock=not use_glass)
        self.glass.connect('127.0.0.1', 59900)
        self.set_glass_feedback(glass_feedback)
        self.set_pc_feedback(pc_feedback)
        self.set_cue_color(boxcol='B', crosscol='W')
        self.width = self.img.shape[1]
        self.height = self.img.shape[0]
        # NOTE(review): self.barwidth is not assigned in this method; it
        # presumably comes from the class or a call above — confirm.
        hw = int(self.barwidth / 2)
        self.cx = int(self.width / 2)
        self.cy = int(self.height / 2)
        self.xl1 = self.cx - hw
        self.xl2 = self.xl1 - self.barwidth
        self.xr1 = self.cx + hw
        self.xr2 = self.xr1 + self.barwidth
        self.yl1 = self.cy - hw
        self.yl2 = self.yl1 - self.barwidth
        self.yr1 = self.cy + hw
        self.yr2 = self.yr1 + self.barwidth

        if os.path.isdir(image_path):
            # load images
            left_image_path = '%s/left' % image_path
            right_image_path = '%s/right' % image_path
            tm = Timer()
            logger.info('Reading images from %s' % left_image_path)
            self.left_images = read_images(left_image_path, screen_size)
            logger.info('Reading images from %s' % right_image_path)
            self.right_images = read_images(right_image_path, screen_size)
            logger.info('Took %.1f s' % tm.sec())
        else:
            # load pickled images
            # note: this is painfully slow in Python 2 even with cPickle (3s vs 27s)
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # image files from a trusted source.
            assert image_path[-4:] == '.pkl', 'The file must be of .pkl format'
            logger.info('Loading image binary file %s ...' % image_path)
            tm = Timer()
            with gzip.open(image_path, 'rb') as fp:
                image_data = pickle.load(fp)
            self.left_images = image_data['left_images']
            self.right_images = image_data['right_images']
            # The target region is derived from the first left image, so all
            # images are assumed to share its shape.
            feedback_w = self.left_images[0].shape[1] / 2
            feedback_h = self.left_images[0].shape[0] / 2
            loc_x = [int(self.cx - feedback_w), int(self.cx + feedback_w)]
            loc_y = [int(self.cy - feedback_h), int(self.cy + feedback_h)]

            def fit_to_screen(images):
                # Replace each image in place by a screen-sized black canvas
                # with the image centered on it.
                for i, img in enumerate(images):
                    img_fit = np.zeros((screen_height, screen_width, 3), np.uint8)
                    img_fit[loc_y[0]:loc_y[1], loc_x[0]:loc_x[1]] = img
                    images[i] = img_fit

            # adjust to the current screen size
            logger.info('Fitting images into the current screen size')
            fit_to_screen(self.left_images)
            fit_to_screen(self.right_images)

            logger.info('Took %.1f s' % tm.sec())
            logger.info('Done.')
        cv2.namedWindow("Protocol", cv2.WND_PROP_FULLSCREEN)
        cv2.moveWindow("Protocol", screen_x, screen_y)
        cv2.setWindowProperty("Protocol", cv2.WND_PROP_FULLSCREEN,
                              cv2.WINDOW_FULLSCREEN)
示例#7
0
def log_decoding(decoder,
                 logfile,
                 amp_name=None,
                 pklfile=True,
                 matfile=False,
                 autostop=False,
                 prob_smooth=False):
    """
    Decode online and write results with event timestamps

    A background process (``_log_decoding_helper``) collects events while
    this process decodes. Press any key in the "Decoding" window to start
    and ESC to stop. Probability and event times are saved relative to the
    first decoding output.

    Parameters
    ----------
    decoder : BCIDecoder or BCIDecoderDaemon class
        The decoder to use
    logfile : str
        The file path to contain the result in Python pickle format
    amp_name : str
        The stream name to connect to
    pklfile : bool
        If True, export the results to Python pickle format
    matfile : bool
        If True, export the results to a .mat file next to logfile
    autostop : bool
        If True, automatically finish when no more data is received.
    prob_smooth : bool
        If True, use smoothed probability values according to decoder's smoothing parameter.
    """

    import cv2
    # `import scipy` alone does not guarantee the io submodule is loaded.
    import scipy.io

    # run event acquisition process in the background
    state = mp.Value('i', 1)
    event_queue = mp.Queue()
    proc = mp.Process(target=_log_decoding_helper,
                      args=[state, event_queue, amp_name, autostop])
    proc.start()
    logger.info('Spawned event acquisition process.')

    # init variables and choose decoding function
    labels = decoder.get_label_names()
    probs = []
    prob_times = []
    if prob_smooth:
        decode_fn = decoder.get_prob_smooth_unread
    else:
        decode_fn = decoder.get_prob_unread

    # simple controller UI
    cv2.namedWindow("Decoding", cv2.WINDOW_AUTOSIZE)
    cv2.moveWindow("Decoding", 1400, 50)
    img = np.zeros([100, 400, 3], np.uint8)
    cv2.putText(img, 'Press any key to start', (20, 60),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.imshow("Decoding", img)
    cv2.waitKeyEx()
    img *= 0
    cv2.putText(img, 'Press ESC to stop', (40, 60), cv2.FONT_HERSHEY_SIMPLEX,
                1, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.imshow("Decoding", img)

    key = 0
    started = False
    tm_watchdog = Timer(autoreset=True)
    tm_cls = Timer()
    while key != 27:  # 27 = ESC
        prob, prob_time = decode_fn(True)
        t_lsl = pylsl.local_clock()
        key = cv2.waitKeyEx(1)
        if prob is None:
            # watch dog
            if tm_cls.sec() > 5:
                if autostop and started:
                    logger.info('No more streaming data. Finishing.')
                    break
                tm_cls.reset()
            tm_watchdog.sleep_atleast(0.001)
            continue
        probs.append(prob)
        prob_times.append(prob_time)
        txt = '[%.3f] ' % prob_time
        txt += ', '.join(['%s: %.2f' % (label, p) for label, p in zip(labels, prob)])
        txt += ' (%d ms, LSL Diff = %.3f)' % (tm_cls.msec(),
                                              (t_lsl - prob_time))
        logger.info(txt)
        if not started:
            started = True
        tm_cls.reset()

    # finish up processes
    cv2.destroyAllWindows()
    logger.info('Cleaning up event acquisition process.')
    state.value = 0
    decoder.stop()
    # read the queue before join() so the child can flush its data and exit
    event_times, event_values = event_queue.get()
    proc.join()

    # save values
    if len(prob_times) == 0:
        # Nothing to save: prob_times[0] below would raise IndexError.
        logger.error('No decoding result. Please debug.')
        return
    t_start = prob_times[0]
    probs = np.vstack(probs)
    event_times = np.array(event_times)
    event_values = np.array(event_values)
    # Drop events that happened before the first decoding output. Times and
    # values are filtered with the same mask so they stay paired (assumes the
    # helper returns index-aligned lists of equal length).
    keep = event_times >= t_start
    event_times = event_times[keep] - t_start
    event_values = event_values[keep]
    prob_times = np.array(prob_times) - t_start
    data = dict(probs=probs,
                prob_times=prob_times,
                event_times=event_times,
                event_values=event_values,
                labels=labels)
    if pklfile:
        with open(logfile, 'wb') as f:
            pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
        logger.info('Saved to %s' % logfile)
    if matfile:
        pp = Path(logfile)
        matout = '%s/%s.mat' % (pp.parent, pp.stem)
        scipy.io.savemat(matout, data)
        logger.info('Saved to %s' % matout)
示例#8
0
class Feedback:
    """
    Perform a classification with visual feedback
    """
    def __init__(self, cfg, viz, tdef, trigger, logfile=None):
        """
        Store the protocol objects and open the optional output devices.

        Parameters
        ----------
        cfg : module or object
            Protocol configuration. It reads REFRESH_RATE, BAR_STEP, BAR_BIAS,
            the optional BAR_REACH_FINISH, WITH_STIMO (and STIMO_COMPORT,
            STIMO_BAUDRATE) and WITH_FES (and FES_COMPORT).
        viz : visualizer
            Feedback display; it is cleared immediately with ``viz.fill()``.
        tdef : trigger definitions
            Trigger codes sent through ``trigger``.
        trigger : trigger device
            Used to send event codes.
        logfile : str or None
            If given, a text file opened for writing trial results.

        Raises
        ------
        RuntimeError
            If cfg.WITH_STIMO is True, STIMO_COMPORT is None and no serial
            port matching 'ATEN' is found.
        """
        self.cfg = cfg
        self.tdef = tdef
        self.trigger = trigger
        self.viz = viz
        self.viz.fill()
        self.refresh_delay = 1.0 / cfg.REFRESH_RATE

        # bar_step_left, bar_step_right, bar_step_up, bar_step_down, bar_step_both
        bar_step = cfg.BAR_STEP
        for side in ('left', 'right', 'up', 'down', 'both'):
            setattr(self, 'bar_step_' + side, bar_step[side])

        # a tuple bias is copied into a list; anything else is kept as-is
        bias = cfg.BAR_BIAS
        self.bar_bias = list(bias) if type(bias) is tuple else bias

        # Probabilities are already smoothed by the decoder, so the bias is
        # applied afterwards (no extra accumulation factor here).

        # optional flag: finish the trial early once the bar reaches 100
        self.premature_end = bool(
            hasattr(cfg, 'BAR_REACH_FINISH')
            and cfg.BAR_REACH_FINISH == True)  # noqa: E712

        self.tm_trigger = Timer()
        self.tm_display = Timer()
        self.tm_watchdog = Timer()
        self.logf = open(logfile, 'w') if logfile is not None else None

        # STIMO only
        if cfg.WITH_STIMO is True:
            port = cfg.STIMO_COMPORT
            if port is None:
                atens = list(serial.tools.list_ports.grep('ATEN'))
                if not atens:
                    raise RuntimeError('No ATEN device found. Stop.')
                try:
                    port = atens[0].device
                except AttributeError:  # depends on Python distribution
                    port = atens[0][0]
            self.stimo_port = port
            self.ser = serial.Serial(self.stimo_port, cfg.STIMO_BAUDRATE)
            logger.info('STIMO serial port %s is_open = %s' %
                        (self.stimo_port, self.ser.is_open))

        # FES only
        if cfg.WITH_FES is True:
            stim = fes.Motionstim8()
            self.stim = stim
            stim.OpenSerialPort(cfg.FES_COMPORT)
            stim.InitializeChannelListMode()
            logger.info('Opened FES serial port')

    def __del__(self):
        # STIMO only
        if self.cfg.WITH_STIMO is True:
            self.ser.close()
            logger.info('Closed STIMO serial port %s' % self.stimo_port)
        # FES only
        if self.cfg.WITH_FES is True:
            stim_code = [0, 0, 0, 0, 0, 0, 0, 0]
            self.stim.UpdateChannelSettings(stim_code)
            self.stim.CloseSerialPort()
            logger.info('Closed FES serial port')

    def classify(self,
                 decoder,
                 true_label,
                 title_text,
                 bar_dirs,
                 state='start',
                 prob_history=None):
        """
        Run a single trial
        """
        def list2string(vec, fmt, sep=' '):
            return sep.join(fmt % x for x in vec)

        true_label_index = bar_dirs.index(true_label)
        self.tm_trigger.reset()
        if self.bar_bias is not None:
            bias_idx = bar_dirs.index(self.bar_bias[0])

        if self.logf is not None:
            self.logf.write('True label: %s\n' % true_label)

        tm_classify = Timer(autoreset=True)
        self.stimo_timer = Timer()
        while True:
            self.tm_display.sleep_atleast(self.refresh_delay)
            self.tm_display.reset()
            if state == 'start' and self.tm_trigger.sec(
            ) > self.cfg.TIMINGS['INIT']:
                state = 'gap_s'
                if self.cfg.TRIALS_PAUSE:
                    self.viz.put_text('Press any key')
                    self.viz.update()
                    key = cv2.waitKeyEx()
                    if key == KEYS['esc']:
                        return
                self.viz.fill()
                self.tm_trigger.reset()
                self.trigger.signal(self.tdef.INIT)

            elif state == 'gap_s':
                if self.cfg.TIMINGS['GAP'] > 0:
                    self.viz.put_text(title_text)
                state = 'gap'
                self.tm_trigger.reset()

            elif state == 'gap' and self.tm_trigger.sec(
            ) > self.cfg.TIMINGS['GAP']:
                state = 'cue'
                self.viz.fill()
                self.viz.draw_cue()
                self.viz.glass_draw_cue()
                self.trigger.signal(self.tdef.CUE)
                self.tm_trigger.reset()

            elif state == 'cue' and self.tm_trigger.sec(
            ) > self.cfg.TIMINGS['READY']:
                state = 'dir_r'

                if self.cfg.SHOW_CUE is True:
                    if self.cfg.FEEDBACK_TYPE == 'BAR':
                        self.viz.move(true_label,
                                      100,
                                      overlay=False,
                                      barcolor='G')
                    elif self.cfg.FEEDBACK_TYPE == 'BODY':
                        self.viz.put_text(DIRS[true_label], 'R')
                    if true_label == 'L':  # left
                        self.trigger.signal(self.tdef.LEFT_READY)
                    elif true_label == 'R':  # right
                        self.trigger.signal(self.tdef.RIGHT_READY)
                    elif true_label == 'U':  # up
                        self.trigger.signal(self.tdef.UP_READY)
                    elif true_label == 'D':  # down
                        self.trigger.signal(self.tdef.DOWN_READY)
                    elif true_label == 'B':  # both hands
                        self.trigger.signal(self.tdef.BOTH_READY)
                    else:
                        raise RuntimeError('Unknown direction %s' % true_label)
                self.tm_trigger.reset()
                '''
                if self.cfg.FEEDBACK_TYPE == 'BODY':
                    self.viz.set_pc_feedback(False)
                self.viz.move(true_label, 100, overlay=False, barcolor='G')

                if self.cfg.FEEDBACK_TYPE == 'BODY':
                    self.viz.set_pc_feedback(True)
                    if self.cfg.SHOW_CUE is True:
                        self.viz.put_text(dirs[true_label], 'R')

                if true_label == 'L':  # left
                    self.trigger.signal(self.tdef.LEFREADY)
                elif true_label == 'R':  # right
                    self.trigger.signal(self.tdef.RIGHT_READY)
                elif true_label == 'U':  # up
                    self.trigger.signal(self.tdef.UP_READY)
                elif true_label == 'D':  # down
                    self.trigger.signal(self.tdef.DOWN_READY)
                elif true_label == 'B':  # both hands
                    self.trigger.signal(self.tdef.BOTH_READY)
                else:
                    raise RuntimeError('Unknown direction %s' % true_label)
                self.tm_trigger.reset()
                '''
            elif state == 'dir_r' and self.tm_trigger.sec(
            ) > self.cfg.TIMINGS['DIR_CUE']:
                self.viz.fill()
                self.viz.draw_cue()
                self.viz.glass_draw_cue()
                state = 'dir'

                # initialize bar scores
                bar_label = bar_dirs[0]
                bar_score = 0
                probs = [1.0 / len(bar_dirs)] * len(bar_dirs)
                self.viz.move(bar_label, bar_score, overlay=False)
                probs_acc = np.zeros(len(probs))

                if true_label == 'L':  # left
                    self.trigger.signal(self.tdef.LEFT_GO)
                elif true_label == 'R':  # right
                    self.trigger.signal(self.tdef.RIGHT_GO)
                elif true_label == 'U':  # up
                    self.trigger.signal(self.tdef.UP_GO)
                elif true_label == 'D':  # down
                    self.trigger.signal(self.tdef.DOWN_GO)
                elif true_label == 'B':  # both
                    self.trigger.signal(self.tdef.BOTH_GO)
                else:
                    raise RuntimeError('Unknown truedirection %s' % true_label)

                self.tm_watchdog.reset()
                self.tm_trigger.reset()

            elif state == 'dir':
                if self.tm_trigger.sec() > self.cfg.TIMINGS['CLASSIFY'] or (
                        self.premature_end and bar_score >= 100):
                    if not hasattr(
                            self.cfg,
                            'SHOW_RESULT') or self.cfg.SHOW_RESULT is True:
                        # show classfication result
                        if self.cfg.WITH_STIMO is True:
                            if self.cfg.STIMO_FULLGAIT_CYCLE is not None and bar_label == 'U':
                                res_color = 'G'
                            elif self.cfg.TRIALS_RETRY is False or bar_label == true_label:
                                res_color = 'G'
                            else:
                                res_color = 'Y'
                        else:
                            res_color = 'Y'
                        if self.cfg.FEEDBACK_TYPE == 'BODY':
                            self.viz.move(bar_label,
                                          bar_score,
                                          overlay=False,
                                          barcolor=res_color,
                                          caption=DIRS[bar_label],
                                          caption_color=res_color)
                        else:
                            self.viz.move(bar_label,
                                          100,
                                          overlay=False,
                                          barcolor=res_color)
                    else:
                        if self.cfg.FEEDBACK_TYPE == 'BODY':
                            self.viz.move(bar_label,
                                          bar_score,
                                          overlay=False,
                                          barcolor=res_color,
                                          caption='TRIAL END',
                                          caption_color=res_color)
                        else:
                            self.viz.move(bar_label,
                                          0,
                                          overlay=False,
                                          barcolor=res_color)
                    self.trigger.signal(self.tdef.FEEDBACK)

                    # STIMO
                    if self.cfg.WITH_STIMO is True and self.cfg.STIMO_CONTINUOUS is False:
                        if self.cfg.STIMO_FULLGAIT_CYCLE is not None:
                            if bar_label == 'U':
                                self.ser.write(
                                    self.cfg.STIMO_FULLGAIT_PATTERN[0])
                                logger.info('STIMO: Sent 1')
                                time.sleep(self.cfg.STIMO_FULLGAIT_CYCLE)
                                self.ser.write(
                                    self.cfg.STIMO_FULLGAIT_PATTERN[1])
                                logger.info('STIMO: Sent 2')
                                time.sleep(self.cfg.STIMO_FULLGAIT_CYCLE)
                        elif self.cfg.TRIALS_RETRY is False or bar_label == true_label:
                            if bar_label == 'L':
                                self.ser.write(b'1')
                                logger.info('STIMO: Sent 1')
                            elif bar_label == 'R':
                                self.ser.write(b'2')
                                logger.info('STIMO: Sent 2')

                    # FES event mode mode
                    if self.cfg.WITH_FES is True and self.cfg.FES_CONTINUOUS is False:
                        if bar_label == 'L':
                            stim_code = [0, 30, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)
                            logger.info('FES: Sent Left')
                            time.sleep(0.5)
                            stim_code = [0, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)

                        elif bar_label == 'R':
                            stim_code = [30, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)
                            time.sleep(0.5)
                            logger.info('FES: Sent Right')
                            stim_code = [0, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)

                    if self.cfg.DEBUG_PROBS:
                        msg = 'DEBUG: Accumulated probabilities = %s' % list2string(
                            probs_acc, '%.3f')
                        logger.info(msg)
                        if self.logf is not None:
                            self.logf.write(msg + '\n')
                    if self.logf is not None:
                        self.logf.write('%s detected as %s (%d)\n\n' %
                                        (true_label, bar_label, bar_score))
                        self.logf.flush()

                    # end of trial
                    state = 'feedback'
                    self.tm_trigger.reset()
                else:
                    # classify
                    probs_new = decoder.get_prob_smooth_unread()
                    if probs_new is None:
                        if self.tm_watchdog.sec() > 3:
                            logger.warning(
                                'No classification being done. Are you receiving data streams?'
                            )
                            self.tm_watchdog.reset()
                    else:
                        self.tm_watchdog.reset()

                        if prob_history is not None:
                            prob_history[true_label].append(
                                probs_new[true_label_index])

                        probs_acc += np.array(probs_new)
                        '''
                        New decoder: already smoothed by the decoder so bias after.
                        '''
                        probs = list(probs_new)
                        if self.bar_bias is not None:
                            probs[bias_idx] += self.bar_bias[1]
                            newsum = sum(probs)
                            probs = [p / newsum for p in probs]
                        '''
                        # Method 2: bias and smoothen
                        if self.bar_bias is not None:
                            # print('BEFORE: %.3f %.3f'% (probs_new[0], probs_new[1]) )
                            probs_new[bias_idx] += self.bar_bias[1]
                            newsum = sum(probs_new)
                            probs_new = [p / newsum for p in probs_new]
                        # print('AFTER: %.3f %.3f'% (probs_new[0], probs_new[1]) )
                        for i in range(len(probs_new)):
                            probs[i] = probs[i] * self.alpha_old + probs_new[i] * self.alpha_new
                        '''
                        ''' Original method
                        # Method 1: smoothen and bias
                        for i in range( len(probs_new) ):
                            probs[i] = probs[i] * self.alpha_old + probs_new[i] * self.alpha_new
                        # bias bar
                        if self.bar_bias is not None:
                            probs[bias_idx] += self.bar_bias[1]
                            newsum = sum(probs)
                            probs = [p/newsum for p in probs]
                        '''

                        # determine the direction
                        # TODO: np.argmax(probs)
                        max_pidx = np.argmax(probs)
                        max_label = bar_dirs[max_pidx]

                        if self.cfg.POSITIVE_FEEDBACK is False or \
                                (self.cfg.POSITIVE_FEEDBACK and true_label == max_label):
                            dx = probs[max_pidx]
                            if max_label == 'R':
                                dx *= self.bar_step_right
                            elif max_label == 'L':
                                dx *= self.bar_step_left
                            elif max_label == 'U':
                                dx *= self.bar_step_up
                            elif max_label == 'D':
                                dx *= self.bar_step_down
                            elif max_label == 'B':
                                dx *= self.bar_step_both
                            else:
                                logger.debug('Direction %s using bar step %d' %
                                             (max_label, self.bar_step_left))
                                dx *= self.bar_step_left

                            # slow start
                            selected = self.cfg.BAR_SLOW_START['selected']
                            if self.cfg.BAR_SLOW_START[
                                    selected] and self.tm_trigger.sec(
                                    ) < self.cfg.BAR_SLOW_START[selected]:
                                dx *= self.tm_trigger.sec(
                                ) / self.cfg.BAR_SLOW_START[selected][0]

                            # add likelihoods
                            if max_label == bar_label:
                                bar_score += dx
                            else:
                                bar_score -= dx
                                # change of direction
                                if bar_score < 0:
                                    bar_score = -bar_score
                                    bar_label = max_label
                            bar_score = int(bar_score)
                            if bar_score > 100:
                                bar_score = 100
                            if self.cfg.FEEDBACK_TYPE == 'BODY':
                                if self.cfg.SHOW_CUE:
                                    self.viz.move(bar_label,
                                                  bar_score,
                                                  overlay=False,
                                                  caption=DIRS[true_label],
                                                  caption_color='G')
                                else:
                                    self.viz.move(bar_label,
                                                  bar_score,
                                                  overlay=False)
                            else:
                                self.viz.move(bar_label,
                                              bar_score,
                                              overlay=False)

                            # send the confidence value continuously
                            if self.cfg.WITH_STIMO and self.cfg.STIMO_CONTINUOUS:
                                if self.stimo_timer.sec(
                                ) >= self.cfg.STIMO_COOLOFF:
                                    if bar_label == 'U':
                                        stimo_code = bar_score
                                    else:
                                        stimo_code = 0
                                    self.ser.write(bytes([stimo_code]))
                                    logger.info('Sent STIMO code %d' %
                                                stimo_code)
                                    self.stimo_timer.reset()
                            # with FES
                            if self.cfg.WITH_FES is True and self.cfg.FES_CONTINUOUS is True:
                                if self.stimo_timer.sec(
                                ) >= self.cfg.STIMO_COOLOFF:
                                    if bar_label == 'L':
                                        stim_code = [
                                            bar_score, 0, 0, 0, 0, 0, 0, 0
                                        ]
                                    else:
                                        stim_code = [
                                            0, bar_score, 0, 0, 0, 0, 0, 0
                                        ]
                                    self.stim.UpdateChannelSettings(stim_code)
                                    logger.info('Sent FES code %d' % bar_score)
                                    self.stimo_timer.reset()

                        if self.cfg.DEBUG_PROBS:
                            if self.bar_bias is not None:
                                biastxt = '[Bias=%s%.3f]  ' % (
                                    self.bar_bias[0], self.bar_bias[1])
                            else:
                                biastxt = ''
                            msg = '%s%s  prob %s   acc %s   bar %s%d  (%.1f ms)' % \
                                  (biastxt, bar_dirs, list2string(probs_new, '%.2f'), list2string(probs, '%.2f'),
                                   bar_label, bar_score, tm_classify.msec())
                            logger.info(msg)
                            if self.logf is not None:
                                self.logf.write(msg + '\n')

            elif state == 'feedback' and self.tm_trigger.sec(
            ) > self.cfg.TIMINGS['FEEDBACK']:
                self.trigger.signal(self.tdef.BLANK)
                if self.cfg.FEEDBACK_TYPE == 'BODY':
                    state = 'return'
                    self.tm_trigger.reset()
                else:
                    state = 'gap_s'
                    self.viz.fill()
                    self.viz.update()
                    return bar_label

            elif state == 'return':
                self.viz.set_glass_feedback(False)
                if self.cfg.WITH_STIMO:
                    self.viz.move(bar_label,
                                  bar_score,
                                  overlay=False,
                                  barcolor='B')
                else:
                    self.viz.move(bar_label,
                                  bar_score,
                                  overlay=False,
                                  barcolor='Y')
                self.viz.set_glass_feedback(True)
                bar_score -= 5
                if bar_score <= 0:
                    state = 'gap_s'
                    self.viz.fill()
                    self.viz.update()
                    return bar_label

            self.viz.update()
            key = cv2.waitKeyEx(1)
            if key == KEYS['esc']:
                return
            elif key == KEYS['space']:
                dx = 0
                bar_score = 0
                probs = [1.0 / len(bar_dirs)] * len(bar_dirs)
                self.viz.move(bar_dirs[0], bar_score, overlay=False)
                self.viz.update()
                logger.info('probs and dx reset.')
                self.tm_trigger.reset()
            elif key in ARROW_KEYS and ARROW_KEYS[key] in bar_dirs:
                # change bias on the fly
                if self.bar_bias is None:
                    self.bar_bias = [ARROW_KEYS[key], BIAS_INCREMENT]
                else:
                    if ARROW_KEYS[key] == self.bar_bias[0]:
                        self.bar_bias[1] += BIAS_INCREMENT
                    elif self.bar_bias[1] >= BIAS_INCREMENT:
                        self.bar_bias[1] -= BIAS_INCREMENT
                    else:
                        self.bar_bias = [ARROW_KEYS[key], BIAS_INCREMENT]
                if self.bar_bias[1] == 0:
                    self.bar_bias = None
                else:
                    bias_idx = bar_dirs.index(self.bar_bias[0])
示例#9
0
def train_decoder(cfg, featdata, feat_file=None):
    """
    Train the final decoder using all data and export it.

    The classifier selected by cfg.CLASSIFIER['selected'] is fitted on the
    concatenation of all trials and pickled, together with the preprocessing
    settings, to <cfg.DATA_PATH>/classifier/classifier-<arch>.pkl. The most
    important PSD features are then logged and optionally written to a file.

    Parameters
    ----------
    cfg : config object
        Training configuration. Read here: CLASSIFIER, FEATURES, CV, N_JOBS,
        tdef, EPOCH, SP_FILTER, SP_CHANNELS, TP_FILTER, TP_CHANNELS,
        NOTCH_FILTER, NOTCH_CHANNELS, MULTIPLIER, REREFERENCE, SAVE_FEATURES,
        DATA_PATH, EXPORT_GOOD_FEATURES, FEAT_TOPN.
        cfg.FEATURES['PSD']['wlen'] is filled in from featdata when it is None.
    featdata : dict
        Extracted features. Keys read: 'X_data', 'Y_data', 'wlen',
        'w_frames', 'ch_names', 'psde', 'sfreq', 'picks'.
    feat_file : str or None
        Where to write the good-features listing when
        cfg.EXPORT_GOOD_FEATURES is set. Defaults to
        <cfg.DATA_PATH>/classifier/good_features.txt.

    Raises
    ------
    ValueError
        If cfg.CLASSIFIER['selected'] is not a supported classifier.
    NotImplementedError
        If cfg.FEATURES['selected'] is not 'PSD' (only PSD decoders are
        exported by this function).
    """
    def sort_by_value(s, reverse=False):
        """
        Sort a dict (or a list, treated as {index: value}) by its values.

        Returns two parallel lists (keys, values). Entries sharing the same
        value are all kept, in their original relative order (sorted() is
        stable, also with reverse=True).
        """
        if not isinstance(s, (dict, list)):
            raise TypeError('Input must be a dictionary or list.')
        if isinstance(s, list):
            s = dict(enumerate(s))
        pairs = sorted(s.items(), key=lambda kv: kv[1], reverse=reverse)
        keys = [k for k, _ in pairs]
        values = [v for _, v in pairs]
        return keys, values

    # Init a classifier
    selected_classifier = cfg.CLASSIFIER['selected']
    cls_params = cfg.CLASSIFIER[selected_classifier]

    if selected_classifier == 'GB':
        cls = GradientBoostingClassifier(
            loss='deviance',
            learning_rate=cls_params['learning_rate'],
            n_estimators=cls_params['trees'],
            subsample=1.0,
            max_depth=cls_params['depth'],
            random_state=cls_params['seed'],
            max_features='sqrt',
            verbose=0,
            warm_start=False,
            presort='auto')
    elif selected_classifier == 'XGB':
        cls = XGBClassifier(loss='deviance',
                            learning_rate=cls_params['learning_rate'],
                            n_estimators=cls_params['trees'],
                            subsample=1.0,
                            max_depth=cls_params['depth'],
                            random_state=cls_params['seed'],
                            max_features='sqrt',
                            verbose=0,
                            warm_start=False,
                            presort='auto')
    elif selected_classifier == 'RF':
        cls = RandomForestClassifier(n_estimators=cls_params['trees'],
                                     max_features='auto',
                                     max_depth=cls_params['depth'],
                                     n_jobs=cfg.N_JOBS,
                                     random_state=cls_params['seed'],
                                     oob_score=False,
                                     class_weight='balanced_subsample')
    elif selected_classifier == 'LDA':
        cls = LDA(solver=cls_params['solver'],
                  shrinkage=cls_params['shrinkage'])
    elif selected_classifier == 'rLDA':
        cls = rLDA(cls_params["r_coeff"])
    else:
        logger.error('Unknown classifier %s' % selected_classifier)
        raise ValueError

    # The exported decoder dict below is only defined for PSD features; fail
    # before the potentially long training step instead of after it.
    if cfg.FEATURES['selected'] != 'PSD':
        logger.error('Unsupported feature type %s' % cfg.FEATURES['selected'])
        raise NotImplementedError

    # Setup features
    X_data = featdata['X_data']
    Y_data = featdata['Y_data']
    wlen = featdata['wlen']
    if cfg.FEATURES['PSD']['wlen'] is None:
        cfg.FEATURES['PSD']['wlen'] = wlen
    w_frames = featdata['w_frames']
    ch_names = featdata['ch_names']
    X_data_merged = np.concatenate(X_data)
    Y_data_merged = np.concatenate(Y_data)
    if cfg.CV['BALANCE_SAMPLES']:
        X_data_merged, Y_data_merged = balance_samples(
            X_data_merged,
            Y_data_merged,
            cfg.CV['BALANCE_SAMPLES'],
            verbose=True)

    # Start training the decoder
    logger.info('Training the decoder')
    timer = Timer()
    cls.n_jobs = cfg.N_JOBS
    cls.fit(X_data_merged, Y_data_merged)
    logger.info('Trained %d samples x %d dimension in %.1f sec' %
                (X_data_merged.shape[0], X_data_merged.shape[1], timer.sec()))
    cls.n_jobs = 1  # always set n_jobs=1 for testing

    # Export the decoder
    classes = {c: cfg.tdef.by_value[c] for c in np.unique(Y_data)}
    data = dict(cls=cls,
                ch_names=ch_names,
                psde=featdata['psde'],
                sfreq=featdata['sfreq'],
                picks=featdata['picks'],
                classes=classes,
                epochs=cfg.EPOCH,
                w_frames=w_frames,
                w_seconds=cfg.FEATURES['PSD']['wlen'],
                wstep=cfg.FEATURES['PSD']['wstep'],
                spatial=cfg.SP_FILTER,
                spatial_ch=cfg.SP_CHANNELS,
                spectral=cfg.TP_FILTER[cfg.TP_FILTER['selected']],
                spectral_ch=cfg.TP_CHANNELS,
                notch=cfg.NOTCH_FILTER[cfg.NOTCH_FILTER['selected']],
                notch_ch=cfg.NOTCH_CHANNELS,
                multiplier=cfg.MULTIPLIER,
                ref_ch=cfg.REREFERENCE[cfg.REREFERENCE['selected']],
                decim=cfg.FEATURES['PSD']['decim'])

    if cfg.SAVE_FEATURES:
        data["SAVED_FEAT"] = dict(X=X_data_merged, Y=Y_data_merged)

    clsfile = '%s/classifier/classifier-%s.pkl' % (cfg.DATA_PATH,
                                                   platform.architecture()[0])
    make_dirs('%s/classifier' % cfg.DATA_PATH)
    with open(clsfile, 'wb') as f:
        pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
    logger.info('Decoder saved to %s' % clsfile)

    # Reverse-lookup frequency from FFT: bins from fmin to fmax spaced by
    # 1/wlen (the first window length when several are configured)
    psd_wlen = cfg.FEATURES['PSD']['wlen']
    if isinstance(psd_wlen, list):
        fq_res = 1.0 / psd_wlen[0]
    else:
        fq_res = 1.0 / psd_wlen
    fqlist = []
    fq = 0
    while fq <= cfg.FEATURES['PSD']['fmax']:
        if fq >= cfg.FEATURES['PSD']['fmin']:
            fqlist.append(fq)
        fq += fq_res

    # Show top distinctive features
    logger.info('Good features ordered by importance')
    if selected_classifier in ['RF', 'GB', 'XGB']:
        importances = list(cls.feature_importances_)
    else:  # 'LDA' or 'rLDA', the only other classifiers accepted above
        importances = cls.coef_.reshape(-1).tolist()
    keys, values = sort_by_value(importances, reverse=True)
    keys = np.array(keys)
    values = np.array(values)

    # Channel label for each picked channel; with several window lengths the
    # labels are repeated per window and prefixed by the window index.
    # Built into a new list so the source ch_names stays readable while indexing.
    picks = featdata['picks']
    if isinstance(wlen, list):
        feat_ch_names = ['w%d-%s' % (w, ch_names[c])
                         for w in range(len(wlen)) for c in picks]
    else:
        feat_ch_names = [ch_names[c] for c in picks]

    chlist, hzlist = features.feature2chz(keys, fqlist, ch_names=feat_ch_names)
    valnorm = values[:cfg.FEAT_TOPN].copy()
    valsum = np.sum(valnorm)
    if valsum == 0:
        valsum = 1  # avoid dividing by zero when all top importances are 0
    valnorm = valnorm / valsum * 100.0

    # show top-N features
    for i, (ch, hz) in enumerate(zip(chlist, hzlist)):
        if i >= cfg.FEAT_TOPN:
            break
        txt = '%-3s %5.1f Hz  normalized importance %-6s  raw importance %-6s  feature %-5d' %\
              (ch, hz, '%.2f%%' % valnorm[i], '%.2f%%' % (values[i] * 100.0), keys[i])
        logger.info(txt)

    if cfg.EXPORT_GOOD_FEATURES:
        if feat_file is None:
            feat_file = '%s/classifier/good_features.txt' % cfg.DATA_PATH
        with open(feat_file, 'w') as gfout:
            gfout.write('Importance(%) Channel Frequency Index\n')
            for i, (ch, hz) in enumerate(zip(chlist, hzlist)):
                gfout.write('%.3f\t%s\t%s\t%d\n' %
                            (values[i] * 100.0, ch, hz, keys[i]))
示例#10
0
def cross_validate(cfg, featdata, cv_file=None):
    """
    Perform cross validation and report the results.

    Builds the classifier selected by cfg.CLASSIFIER['selected'], runs the
    scheme selected by cfg.CV_PERFORM['selected'] through crossval_epochs(),
    logs a text report, and optionally writes that report to a file.

    Parameters
    ----------
    cfg : config object
        Read here: CLASSIFIER, CV_PERFORM, CV, N_JOBS, tdef, SP_FILTER,
        SP_CHANNELS, TP_FILTER, TP_CHANNELS, NOTCH_FILTER, NOTCH_CHANNELS,
        FEATURES, EPOCH, EXPORT_CLS, DATA_PATH.
    featdata : dict
        Extracted features. Keys read: 'X_data' (unpacked as
        trials x samples x feature dimension), 'Y_data', 'wlen', 'sfreq',
        'ch_names', 'picks'.
    cv_file : str or None
        Output path of the report when the selected CV entry has
        'export_result' set to True. Defaults to
        <DATA_PATH>/classifier/cv_result.txt when cfg.EXPORT_CLS is True,
        otherwise <DATA_PATH>/cv_result.txt.

    Raises
    ------
    ValueError
        If the selected classifier is unknown.
    NotImplementedError
        If the selected cross-validation scheme is not supported.
    """
    # Init a classifier
    selected_classifier = cfg.CLASSIFIER['selected']
    cls_params = cfg.CLASSIFIER[selected_classifier]

    if selected_classifier == 'GB':
        cls = GradientBoostingClassifier(
            loss='deviance',
            learning_rate=cls_params['learning_rate'],
            presort='auto',
            n_estimators=cls_params['trees'],
            subsample=1.0,
            max_depth=cls_params['depth'],
            random_state=cls_params['seed'],
            max_features='sqrt',
            verbose=0,
            warm_start=False)
    elif selected_classifier == 'XGB':
        cls = XGBClassifier(loss='deviance',
                            learning_rate=cls_params['learning_rate'],
                            presort='auto',
                            n_estimators=cls_params['trees'],
                            subsample=1.0,
                            max_depth=cls_params['depth'],
                            random_state=cls_params['seed'],
                            max_features='sqrt',
                            verbose=0,
                            warm_start=False)
    elif selected_classifier == 'RF':
        cls = RandomForestClassifier(n_estimators=cls_params['trees'],
                                     max_features='auto',
                                     max_depth=cls_params['depth'],
                                     n_jobs=cfg.N_JOBS,
                                     random_state=cls_params['seed'],
                                     oob_score=False,
                                     class_weight='balanced_subsample')
    elif selected_classifier == 'LDA':
        cls = LDA(solver=cls_params['solver'],
                  shrinkage=cls_params['shrinkage'])
    elif selected_classifier == 'rLDA':
        cls = rLDA(cls_params['r_coeff'])
    else:
        logger.error('Unknown classifier type %s' % selected_classifier)
        raise ValueError

    # Setup features
    X_data = featdata['X_data']
    Y_data = featdata['Y_data']
    wlen = featdata['wlen']

    # Choose CV type
    ntrials, nsamples, fsize = X_data.shape
    selected_cv = cfg.CV_PERFORM['selected']
    # Parameters of the selected scheme; a missing or None entry is treated
    # as "no parameters" so the seed / export_result lookups below don't crash.
    cv_params = cfg.CV_PERFORM.get(selected_cv) or {}
    if selected_cv == 'LeaveOneOut':
        logger.info('%d-fold leave-one-out cross-validation' % ntrials)
        if SKLEARN_OLD:
            cv = LeaveOneOut(len(Y_data))
        else:
            cv = LeaveOneOut()
    elif selected_cv == 'StratifiedShuffleSplit':
        logger.info(
            '%d-fold stratified cross-validation with test set ratio %.2f' %
            (cv_params['folds'], cv_params['test_ratio']))
        if SKLEARN_OLD:
            cv = StratifiedShuffleSplit(
                Y_data[:, 0],
                cv_params['folds'],
                test_size=cv_params['test_ratio'],
                random_state=cv_params['seed'])
        else:
            cv = StratifiedShuffleSplit(
                n_splits=cv_params['folds'],
                test_size=cv_params['test_ratio'],
                random_state=cv_params['seed'])
    else:
        logger.error('%s is not supported yet. Sorry.' % selected_cv)
        raise NotImplementedError
    logger.info('%d trials, %d samples per trial, %d feature dimension' %
                (ntrials, nsamples, fsize))

    # Do it!
    timer_cv = Timer()
    scores, cm_txt = crossval_epochs(cv,
                                     X_data,
                                     Y_data,
                                     cls,
                                     cfg.tdef.by_value,
                                     cfg.CV['BALANCE_SAMPLES'],
                                     n_jobs=cfg.N_JOBS,
                                     ignore_thres=cfg.CV['IGNORE_THRES'],
                                     decision_thres=cfg.CV['DECISION_THRES'])
    t_cv = timer_cv.sec()

    # Export results
    txt = 'Cross validation took %d seconds.\n' % t_cv
    txt += '\n- Class information\n'
    txt += '%d epochs, %d samples per epoch, %d feature dimension (total %d samples)\n' %\
        (ntrials, nsamples, fsize, ntrials * nsamples)
    for ev in np.unique(Y_data):
        txt += '%s: %d trials\n' % (cfg.tdef.by_value[ev],
                                    np.count_nonzero(Y_data[:, 0] == ev))
    if cfg.CV['BALANCE_SAMPLES']:
        txt += 'The number of samples was balanced using %ssampling.\n' % str(
            cfg.CV['BALANCE_SAMPLES']).lower()
    txt += '\n- Experiment condition\n'
    txt += 'Sampling frequency: %.3f Hz\n' % featdata['sfreq']
    txt += 'Spatial filter: %s (channels: %s)\n' % (cfg.SP_FILTER,
                                                    cfg.SP_CHANNELS)
    txt += 'Spectral filter: %s (channels: %s)\n' % (
        cfg.TP_FILTER[cfg.TP_FILTER['selected']], cfg.TP_CHANNELS)
    txt += 'Notch filter: %s (channels: %s)\n' % (
        cfg.NOTCH_FILTER[cfg.NOTCH_FILTER['selected']], cfg.NOTCH_CHANNELS)
    txt += 'PSD Channels: ' + ','.join(
        [str(featdata['ch_names'][p]) for p in featdata['picks']]) + '\n'
    txt += 'PSD range: %.1f - %.1f Hz\n' % (cfg.FEATURES['PSD']['fmin'],
                                            cfg.FEATURES['PSD']['fmax'])
    txt += 'Window step: %.2f msec\n' % (
        1000.0 * cfg.FEATURES['PSD']['wstep'] / featdata['sfreq'])
    if isinstance(wlen, list):
        for i, w in enumerate(wlen):
            txt += 'Window size: %.1f msec\n' % (w * 1000.0)
            txt += 'Epoch range: %s sec\n' % (cfg.EPOCH[i])
    else:
        # fall back to featdata's window length when the config leaves it unset
        win_len = cfg.FEATURES['PSD']['wlen']
        if win_len is None:
            win_len = wlen
        txt += 'Window size: %.1f msec\n' % (win_len * 1000.0)
        txt += 'Epoch range: %s sec\n' % (cfg.EPOCH)
    txt += 'Decimation factor: %d\n' % cfg.FEATURES['PSD']['decim']

    # Compute stats
    cv_mean, cv_std = np.mean(scores), np.std(scores)
    txt += '\n- Average CV accuracy over %d epochs (random seed=%s)\n' % (
        ntrials, cv_params.get('seed'))
    if selected_cv in ('LeaveOneOut', 'StratifiedShuffleSplit'):
        txt += "mean %.3f, std: %.3f\n" % (cv_mean, cv_std)
    txt += 'Classifier: %s, ' % selected_classifier
    if selected_classifier == 'RF':
        txt += '%d trees, %s max depth, random state %s\n' % (
            cls_params['trees'], cls_params['depth'], cls_params['seed'])
    elif selected_classifier in ('GB', 'XGB'):
        # report the parameters of the classifier actually used (GB or XGB)
        txt += '%d trees, %s max depth, %s learning_rate, random state %s\n' % (
            cls_params['trees'], cls_params['depth'],
            cls_params['learning_rate'], cls_params['seed'])
    elif selected_classifier == 'LDA':
        txt += 'solver %s, shrinkage %s\n' % (cls_params['solver'],
                                              cls_params['shrinkage'])
    elif selected_classifier == 'rLDA':
        txt += 'regularization coefficient %.2f\n' % cls_params['r_coeff']
    # IGNORE_THRES and DECISION_THRES are distinct settings passed separately
    # to crossval_epochs(); report each under its own name.
    if cfg.CV['IGNORE_THRES'] is not None:
        txt += 'Ignore threshold: %.2f\n' % cfg.CV['IGNORE_THRES']
    if cfg.CV['DECISION_THRES'] is not None:
        txt += 'Decision threshold: %s\n' % cfg.CV['DECISION_THRES']
    txt += '\n- Confusion Matrix\n' + cm_txt
    logger.info(txt)

    # Export to a file
    if cv_params.get('export_result') is True:
        if cv_file is None:
            if cfg.EXPORT_CLS is True:
                make_dirs('%s/classifier' % cfg.DATA_PATH)
                cv_file = '%s/classifier/cv_result.txt' % cfg.DATA_PATH
            else:
                cv_file = '%s/cv_result.txt' % cfg.DATA_PATH
        with open(cv_file, 'w') as fout:
            fout.write(txt)