コード例 #1
0
ファイル: test_mi_cascade.py プロジェクト: syzhang/pycnbi
                    bar.update()
                    qc.print_c('Executing Rex action %s' % rex_dir, 'W')
                    os.system('%s/Rex/RexControlSimple.exe %s %s' % (pycnbi.ROOT, cfg.REX_COMPORT, rex_dir))
                    time.sleep(8)

            if true_label == pred_label:
                msg = 'Correct'
            else:
                msg = 'Wrong'
            print('Trial %d: %s (%s -> %s)' % (trial, msg, true_label, pred_label))
            trial += 1

    # write performance
    fdir, _, _ = qc.parse_path_list(cfg.CLS_MI)
    logfile = time.strftime(fdir + "/online-%Y%m%d-%H%M%S.txt", time.localtime())
    with open(logfile, 'w') as fout:
        for dt, gt in zip(dir_detected, dir_seq):
            fout.write('%s,%s\n' % (gt, dt))
        cfmat, acc = qc.confusion_matrix(dir_seq, dir_detected)
        fout.write('\nAccuracy %.3f\nConfusion matrix\n' % acc)
        fout.write(cfmat)
        print('\nAccuracy %.3f\nConfusion matrix\n' % acc)
        print(cfmat)
    print('Log exported to %s' % logfile)

    bar.finish()
    if decoder_UD:
        decoder_UD.stop()

    print('Finished.')
コード例 #2
0
    if len(dir_detected) > 0:
        # write performance and log results
        fdir, _, _ = qc.parse_path_list(cfg.DECODER_FILE)
        logfile = time.strftime(fdir + "/online-%Y%m%d-%H%M%S.txt", time.localtime())
        with open(logfile, 'w') as fout:
            fout.write('Ground-truth,Prediction\n')
            for gt, dt in zip(dir_seq, dir_detected):
                fout.write('%s,%s\n' % (gt, dt))
            cfmat, acc = qc.confusion_matrix(dir_seq, dir_detected)
            fout.write('\nAccuracy %.3f\nConfusion matrix\n' % acc)
            fout.write(cfmat)
            logger.info('Log exported to %s' % logfile)
        print('\nAccuracy %.3f\nConfusion matrix\n' % acc)
        print(cfmat)

    visual.finish()

    with state.get_lock():
        state.value = 0

    if decoder:
        decoder.stop()

    '''
    # automatic thresholding
    if prob_history and len(bar_dirs) == 2:
        total = sum(len(prob_history[c]) for c in prob_history)
        fout = open(probs_logfile, 'a')
        msg = 'Automatic threshold optimization.\n'
        max_acc = 0
        max_bias = 0
コード例 #3
0
ファイル: test_mi.py プロジェクト: syzhang/pycnbi
def config_run(cfg_module):
    """
    Run an online motor-imagery feedback session described by a config file.

    Loads the config, picks the amplifier (or the fake classifier), connects
    the trigger device, starts the decoder daemon, runs the feedback trials
    and finally writes the ground-truth/prediction pairs, accuracy and
    confusion matrix to an "online-<timestamp>.txt" file in the directory of
    cfg.CLS_MI.

    Input:
        cfg_module: path to the Python config file loaded with load_cfg().

    Raises:
        IOError: cfg_module is not an existing file.
        ValueError: cfg.FEEDBACK_TYPE is not 'BAR' or 'BODY', or 'BODY'
            feedback is requested without cfg.IMAGE_PATH.
    """
    if not (os.path.exists(cfg_module) and os.path.isfile(cfg_module)):
        raise IOError('%s cannot be loaded.' % os.path.realpath(cfg_module))
    cfg = load_cfg(cfg_module)

    # Validate the feedback type before touching any hardware or starting the
    # decoder process, instead of failing later with a NameError on `visual`.
    if cfg.FEEDBACK_TYPE not in ('BAR', 'BODY'):
        raise ValueError('Unknown FEEDBACK_TYPE %r (expected BAR or BODY).' %
                         (cfg.FEEDBACK_TYPE,))
    if cfg.FEEDBACK_TYPE == 'BODY' and not hasattr(cfg, 'IMAGE_PATH'):
        raise ValueError('IMAGE_PATH is undefined in your config.')

    if cfg.FAKE_CLS is None:
        # choose amp: search LSL unless the config pins a name or serial
        if cfg.AMP_NAME is None and cfg.AMP_SERIAL is None:
            amp_name, amp_serial = pu.search_lsl(ignore_markers=True)
        else:
            amp_name = cfg.AMP_NAME
            amp_serial = cfg.AMP_SERIAL
        fake_dirs = None
    else:
        amp_name = None
        amp_serial = None
        fake_dirs = [v for (k, v) in cfg.DIRECTIONS]

    # events and triggers
    tdef = trigger_def(cfg.TRIGGER_DEF)
    if cfg.TRIGGER_DEVICE is None:
        input(
            '\n** Warning: No trigger device set. Press Ctrl+C to stop or Enter to continue.'
        )
    trigger = pyLptControl.Trigger(cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        qc.print_c(
            '\n** Error connecting to USB2LPT device. Use a mock trigger instead?',
            'R')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.MockTrigger()
        trigger.init(50)

    # init classification
    decoder = BCIDecoderDaemon(cfg.CLS_MI,
                               buffer_size=1.0,
                               fake=(cfg.FAKE_CLS is not None),
                               amp_name=amp_name,
                               amp_serial=amp_serial,
                               fake_dirs=fake_dirs,
                               parallel=cfg.PARALLEL_DECODING,
                               alpha_new=cfg.PROB_ALPHA_NEW)

    # Decoder labels are either values already listed as the second element
    # of a cfg.DIRECTIONS entry, or trigger values translated via tdef.by_value.
    dirdata = set(d[1] for d in cfg.DIRECTIONS)
    labels = [x if x in dirdata else tdef.by_value[x]
              for x in decoder.get_labels()]

    # map class labels to bar directions
    bar_def = {label: str(direction) for direction, label in cfg.DIRECTIONS}
    bar_dirs = [bar_def[l] for l in labels]
    if cfg.TRIALS_RANDOMIZE:
        dir_seq = bar_dirs * cfg.TRIALS_EACH
        random.shuffle(dir_seq)
    else:
        dir_seq = [d[0] for d in cfg.DIRECTIONS] * cfg.TRIALS_EACH
    num_trials = len(dir_seq)

    qc.print_c('Initializing decoder.', 'W')
    while decoder.is_running() == 0:
        time.sleep(0.01)

    # feedback visual object (FEEDBACK_TYPE validated above)
    if cfg.FEEDBACK_TYPE == 'BAR':
        from pycnbi.protocols.viz_bars import BarVisual
        visual = BarVisual(cfg.GLASS_USE,
                           screen_pos=cfg.SCREEN_POS,
                           screen_size=cfg.SCREEN_SIZE)
    else:
        from pycnbi.protocols.viz_human import BodyVisual
        visual = BodyVisual(cfg.IMAGE_PATH,
                            use_glass=cfg.GLASS_USE,
                            screen_pos=cfg.SCREEN_POS,
                            screen_size=cfg.SCREEN_SIZE)
    visual.put_text('Waiting to start')
    if cfg.LOG_PROBS:
        logdir = qc.parse_path_list(cfg.CLS_MI)[0]
        # Only the timestamp goes through strftime so a '%' in the path is
        # not misinterpreted as a format directive.
        probs_logfile = logdir + time.strftime("probs-%Y%m%d-%H%M%S.txt",
                                               time.localtime())
    else:
        probs_logfile = None
    feedback = Feedback(cfg, visual, tdef, trigger, probs_logfile)

    # Bar label -> direction code passed to the Rex control program.
    rex_actions = {'U': 'N', 'L': 'W', 'R': 'E', 'D': 'S'}

    # start
    trial = 1
    # One entry per classified attempt. The true labels are recorded alongside
    # the predictions because with TRIALS_RETRY a trial can be attempted
    # several times, so dir_seq alone no longer lines up with dir_detected.
    dir_truth = []
    dir_detected = []
    prob_history = {c: [] for c in bar_dirs}
    while trial <= num_trials:
        if cfg.SHOW_TRIALS:
            title_text = 'Trial %d / %d' % (trial, num_trials)
        else:
            title_text = 'Ready'
        true_label = dir_seq[trial - 1]

        pred_label = feedback.classify(decoder,
                                       true_label,
                                       title_text,
                                       bar_dirs,
                                       prob_history=prob_history)
        if pred_label is None:
            break
        dir_truth.append(true_label)
        dir_detected.append(pred_label)

        if cfg.WITH_REX is True and pred_label == true_label:
            rex_dir = rex_actions.get(pred_label)
            if rex_dir is None:
                qc.print_c(
                    'Warning: Rex cannot execute undefined action %s' %
                    pred_label, 'W')
            else:
                visual.move(pred_label, 100, overlay=False, barcolor='B')
                visual.update()
                qc.print_c('Executing Rex action %s' % rex_dir, 'W')
                os.system('%s/Rex/RexControlSimple.exe %s %s' %
                          (pycnbi.ROOT, cfg.REX_COMPORT, rex_dir))
                time.sleep(8)

        msg = 'Correct' if true_label == pred_label else 'Wrong'
        # Without retry every attempt advances; with retry only a correct one.
        if cfg.TRIALS_RETRY is False or true_label == pred_label:
            print('Trial %d: %s (%s -> %s)' %
                  (trial, msg, true_label, pred_label))
            trial += 1

    if dir_detected:
        # write performance and log results
        fdir, _, _ = qc.parse_path_list(cfg.CLS_MI)
        logfile = fdir + time.strftime("/online-%Y%m%d-%H%M%S.txt",
                                       time.localtime())
        with open(logfile, 'w') as fout:
            fout.write('Ground-truth,Prediction\n')
            for gt, dt in zip(dir_truth, dir_detected):
                fout.write('%s,%s\n' % (gt, dt))
            cfmat, acc = qc.confusion_matrix(dir_truth, dir_detected)
            fout.write('\nAccuracy %.3f\nConfusion matrix\n' % acc)
            fout.write(cfmat)
            print('Log exported to %s' % logfile)
        print('\nAccuracy %.3f\nConfusion matrix\n' % acc)
        print(cfmat)

    visual.finish()
    if decoder:
        decoder.stop()

    # Disabled: automatic threshold optimization, kept as an unused string
    # literal. It is never executed.
    '''
    # automatic thresholding
    if prob_history and len(bar_dirs) == 2:
        total = sum(len(prob_history[c]) for c in prob_history)
        fout = open(probs_logfile, 'a')
        msg = 'Automatic threshold optimization.\n'
        max_acc = 0
        max_bias = 0
        for bias in np.arange(-0.99, 1.00, 0.01):
            corrects = 0
            for p in prob_history[bar_dirs[0]]:
                p_biased = (p + bias) / (bias + 1) # new sum = (p+bias) + (1-p) = bias+1
                if p_biased >= 0.5:
                    corrects += 1
            for p in prob_history[bar_dirs[1]]:
                p_biased = (p + bias) / (bias + 1) # new sum = (p+bias) + (1-p) = bias+1
                if p_biased < 0.5:
                    corrects += 1
            acc = corrects / total
            msg += '%s%.2f: %.3f\n' % (bar_dirs[0], bias, acc)
            if acc > max_acc:
                max_acc = acc
                max_bias = bias
        msg += 'Max acc = %.3f at bias %.2f\n' % (max_acc, max_bias)
        fout.write(msg)
        fout.close()
        print(msg)
    '''

    print('Finished.')
コード例 #4
0
ファイル: train_mi.py プロジェクト: iPsych/neurodecode
def run(cfg, state=None, queue=None):
    """
    Run the cue-guided motor-imagery training protocol with bar feedback.

    Each trial walks through the events start -> gap_s -> gap -> cue ->
    dir_r -> dir, sending the matching trigger from cfg.tdef at each
    transition and animating the bar toward the cued direction during 'dir'.
    The loop stops after all trials, on Esc, or when state.value becomes 0.

    Input:
        cfg: protocol configuration (TIMINGS, DIRECTIONS, TRIALS_EACH,
            TRIGGER_FILE, TRIGGER_DEVICE, screen settings, ...).
        state: shared integer value (0: stop, 1: start, 2: wait) with
            get_lock(). If None, a fresh mp.Value('i', 1) is created so the
            protocol starts immediately. Set to 0 when the protocol ends.
        queue: forwarded to redirect_stdout_to_queue().

    Raises:
        RuntimeError: cfg.DIRECTIONS contains a direction other than
            'L', 'R', 'U', 'D' or 'B'.
    """
    import time

    if state is None:
        # Create the Value per call. A default of mp.Value('i', 1) is built
        # once at definition time; the first run sets it to 0, so every later
        # default-argument call would exit immediately.
        state = mp.Value('i', 1)

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait for the recording to start (GUI)
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        time.sleep(0.01)  # yield the CPU instead of busy-spinning
    #  Protocol start if equals to 1
    if not state.value:
        sys.exit()

    refresh_delay = 1.0 / cfg.REFRESH_RATE

    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)

    key_esc = 27  # cv2.waitKey() code of the Esc key

    # Bar segments drawn for each direction ('B' = both hands) and the prefix
    # of its <PREFIX>_READY / <PREFIX>_GO trigger names in cfg.tdef.
    bar_segments = {'L': ('L',), 'R': ('R',), 'U': ('U',), 'D': ('D',),
                    'B': ('L', 'R')}
    trigger_prefix = {'L': 'LEFT', 'R': 'RIGHT', 'U': 'UP', 'D': 'DOWN',
                      'B': 'BOTH'}
    for direction in cfg.DIRECTIONS:
        if direction not in bar_segments:
            raise RuntimeError('Unknown direction %s' % (direction,))

    dir_sequence = list(cfg.DIRECTIONS) * cfg.TRIALS_EACH
    random.shuffle(dir_sequence)
    num_trials = len(dir_sequence)

    event = 'start'
    trial = 1

    # Hardware trigger
    if cfg.TRIGGER_DEVICE is None:
        logger.warning(
            'No trigger device set. Press Ctrl+C to stop or Enter to continue.'
        )
    trigger = pyLptControl.Trigger(state, cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        logger.error(
            '\n** Error connecting to USB2LPT device. Use a mock trigger instead?'
        )
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.MockTrigger()
        trigger.init(50)

    def jittered(base_key, spread_key):
        """Return TIMINGS[base_key] randomized by +/- TIMINGS[spread_key]."""
        spread = cfg.TIMINGS[spread_key]
        return cfg.TIMINGS[base_key] + random.uniform(-spread, spread)

    # timers
    timer_trigger = qc.Timer()
    timer_dir = qc.Timer()
    timer_refresh = qc.Timer()
    t_dir = jittered('DIR', 'DIR_RANDOMIZE')
    t_dir_ready = jittered('READY', 'READY_RANDOMIZE')

    bar = BarVisual(cfg.GLASS_USE,
                    screen_pos=cfg.SCREEN_POS,
                    screen_size=cfg.SCREEN_SIZE)
    bar.fill()
    bar.glass_draw_cue()

    direction = None  # cued direction of the current trial, set at 'cue'
    # start
    while trial <= num_trials:
        timer_refresh.sleep_atleast(refresh_delay)
        timer_refresh.reset()

        if event == 'start' and timer_trigger.sec() > cfg.TIMINGS['INIT']:
            event = 'gap_s'
            bar.fill()
            timer_trigger.reset()
            trigger.signal(cfg.tdef.INIT)
        elif event == 'gap_s':
            if cfg.TRIAL_PAUSE:
                bar.put_text('Press any key')
                bar.update()
                # Mask like the main loop so Esc is recognized regardless of
                # extra high bits in the returned key code.
                key = 0xFF & cv2.waitKey()
                if key == key_esc or not state.value:
                    break
                bar.fill()
            bar.put_text('Trial %d / %d' % (trial, num_trials))
            event = 'gap'
            timer_trigger.reset()
        elif event == 'gap' and timer_trigger.sec() > cfg.TIMINGS['GAP']:
            event = 'cue'
            bar.fill()
            bar.draw_cue()
            trigger.signal(cfg.tdef.CUE)
            timer_trigger.reset()
        elif event == 'cue' and timer_trigger.sec() > cfg.TIMINGS['CUE']:
            event = 'dir_r'
            direction = dir_sequence[trial - 1]
            for segment in bar_segments[direction]:
                bar.move(segment, 100, overlay=True)
            trigger.signal(
                getattr(cfg.tdef, trigger_prefix[direction] + '_READY'))
            timer_trigger.reset()
        elif event == 'dir_r' and timer_trigger.sec() > t_dir_ready:
            bar.fill()
            bar.draw_cue()
            event = 'dir'
            timer_trigger.reset()
            timer_dir.reset()
            trigger.signal(getattr(cfg.tdef, trigger_prefix[direction] + '_GO'))
        elif event == 'dir' and timer_trigger.sec() > t_dir:
            event = 'gap_s'
            bar.fill()
            trial += 1
            logger.info('trial ' + str(trial - 1) + ' done')
            trigger.signal(cfg.tdef.BLANK)
            timer_trigger.reset()
            t_dir = jittered('DIR', 'DIR_RANDOMIZE')
            t_dir_ready = jittered('READY', 'READY_RANDOMIZE')

        # protocol: grow the bar proportionally to the elapsed 'dir' time
        if event == 'dir':
            dx = min(100, int(100.0 * timer_dir.sec() / t_dir) + 1)
            for segment in bar_segments[direction]:
                bar.move(segment, dx, overlay=True)

        # wait for start
        if event == 'start':
            bar.put_text('Waiting to start')

        bar.update()
        key = 0xFF & cv2.waitKey(1)
        if key == key_esc or not state.value:
            break

    bar.finish()

    with state.get_lock():
        state.value = 0