Beispiel #1
0
def run(cfg, state=mp.Value('i', 1), queue=None, logger=logger):
    '''
    Entry point of the offline protocol template.

    Parameters
    ----------
    cfg : python.module
        The loaded config module from the corresponding config_offline.py.
        Its TRIGGER_FILE is read and a ``tdef`` attribute is attached to it.
    state : mp.Value
        Shared flag. A falsy value aborts the protocol before it starts;
        the value is set to 0 when the function completes.
    queue : mp.Queue
        If not None, redirect sys.stdout to GUI terminal
    logger : logging.logger
        The logger to use
    '''
    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Trigger definitions (mapping between event values and event names)
    trigger_def = TriggerDef(cfg.TRIGGER_FILE)
    cfg.tdef = trigger_def

    # A cleared state flag means the protocol must not start
    if not state.value:
        sys.exit(-1)

    #-------------------------------------
    # ADD YOUR CODE HERE
    #-------------------------------------

    # To train a decoder, look at trainer_mi.py protocol

    # Signal completion through the shared flag
    state_lock = state.get_lock()
    with state_lock:
        state.value = 0
Beispiel #2
0
    def on_new_tdef_file(self, key, trigger_file):
        """
        Reload the trigger definitions from a newly selected tdef file.

        The TriggerDef built from `trigger_file` is stored in self.tdef.
        When self.events is truthy, the layout is refreshed through
        on_update_VBoxLayout(). `key` is not used by this method.
        """
        self.tdef = TriggerDef(trigger_file)

        # Nothing to refresh as long as no events are set
        if not self.events:
            return
        self.on_update_VBoxLayout()
Beispiel #3
0
def run(cfg, state=mp.Value('i', 1), queue=None, logger=logger):
    '''
    Main function used to run the offline protocol.

    Parameters
    ----------
    cfg : python.module
        The loaded config module from the corresponding config_offline.py
    state : mp.Value
        Shared flag, 0: stop, 1: start, 2: wait. The protocol waits while it
        equals 2, exits if it is 0 and sets it back to 0 when done.
    queue : mp.Queue
        If not None, redirect sys.stdout to GUI terminal
    logger : logging.logger
        The logger to use
    '''
    import time  # local import: only used to poll the state while waiting

    # Use to redirect sys.stdout to GUI terminal if GUI usage
    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI); sleep between polls instead of
    # spinning so the wait does not keep a CPU core busy.
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        time.sleep(0.01)

    # Protocol start if equals to 1
    if not state.value:
        sys.exit()

    # Load the mapping from int to string for triggers events
    cfg.tdef = TriggerDef(cfg.TRIGGER_FILE)

    # Refresh rate
    refresh_delay = 1.0 / cfg.REFRESH_RATE

    # Trigger
    trigger = Trigger(lpttype=cfg.TRIGGER_DEVICE, state=state)
    # Explicit comparison kept: only a returned False counts as a failure
    if trigger.init(50) == False:
        logger.error(
            '\n** Error connecting to the trigger device. Use a mock trigger instead?'
        )
        input('Press Ctrl+C to stop or Enter to continue.')
        # The mock trigger shares the same state flag as the real one
        trigger = Trigger(lpttype='FAKE', state=state)
        trigger.init(50)

    # timers
    timer_refresh = Timer()

    trial = 1
    num_trials = cfg.TRIALS_NB

    # start; the loop also ends as soon as the shared state is cleared, so a
    # stop request is honoured even while trials remain.
    while trial <= num_trials and state.value:
        timer_refresh.sleep_atleast(refresh_delay)
        timer_refresh.reset()

        #-------------------------------------
        # ADD YOUR CODE HERE
        # NOTE: `trial` must be advanced by the code added here, otherwise
        # the loop only ends when `state` is cleared.
        #-------------------------------------

    with state.get_lock():
        state.value = 0
Beispiel #4
0
 def on_new_tdef_file(self, key, trigger_file):
     """
     Rebuild the event QComboBoxes from the events of a new tdef file.

     The current layout is cleared, the event names are read from the
     TriggerDef built from `trigger_file`, and the combo boxes are created
     again for 4 directions. `key` is not used by this method.
     """
     self.clear_hBoxLayout()
     trigger_def = TriggerDef(trigger_file)

     # The string 'None' stands for a real None (a real None is removed
     # when selected in the GUI)
     event_names = []
     for name in trigger_def.by_name:
         event_names.append(None if name == 'None' else name)

     self.create_the_comboBoxes(self.chosen_value, event_names, 4)
Beispiel #5
0
def run(cfg,
        state=mp.Value('i', 1),
        queue=None,
        interactive=False,
        cv_file=None,
        feat_file=None,
        logger=logger):
    '''
    Compute the features, then optionally cross-validate and train a decoder.

    Parameters
    ----------
    cfg : python.module
        The loaded config module. TRIGGER_FILE, CV_PERFORM and EXPORT_CLS are
        read here and a ``tdef`` attribute is attached to it.
    state : mp.Value
        Shared flag. It is checked before each step and a falsy value exits
        the process; it is set to 0 when the function completes.
    queue : mp.Queue
        Forwarded to redirect_stdout_to_queue() together with the logger.
    interactive : bool
        Not used by this function.
    cv_file : str
        Forwarded to cross_validate().
    feat_file : str
        Forwarded to train_decoder().
    logger : logging.logger
        The logger to use
    '''

    def exit_if_stopped():
        # A cleared state flag is treated as a stop request
        if not state.value:
            sys.exit(-1)

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Attach the trigger definitions to the config
    cfg.tdef = TriggerDef(cfg.TRIGGER_FILE)

    # Step 1: feature extraction
    exit_if_stopped()
    featdata = features.compute_features(cfg)

    # (disabled) optimal threshold search for TPR balancing:
    # balance_tpr(cfg, featdata)

    # Step 2: cross validation, if the selected CV_PERFORM entry is truthy
    exit_if_stopped()
    selected_cv = cfg.CV_PERFORM['selected']
    if cfg.CV_PERFORM[selected_cv]:
        cross_validate(cfg, featdata, cv_file=cv_file)

    # Step 3: decoder training, only when EXPORT_CLS is exactly True
    exit_if_stopped()
    if cfg.EXPORT_CLS is True:
        train_decoder(cfg, featdata, feat_file=feat_file)

    state_lock = state.get_lock()
    with state_lock:
        state.value = 0
Beispiel #6
0
def run(cfg, state=mp.Value('i', 1), queue=None):
    '''
    Main function used to run the online feedback protocol.

    A decoder daemon is started and, for every trial, feedback.classify()
    returns the predicted direction while the visual feedback is shown.
    After each run the ground truth / prediction pairs, the accuracy and the
    confusion matrix are written to a log file in the directory of
    cfg.DECODER_FILE.

    Parameters
    ----------
    cfg : python.module
        The loaded config module.
    state : mp.Value
        Shared flag, 0: stop, 1: start, 2: wait. The protocol waits while it
        equals 2, exits if it is 0 and sets it to 0 after the last run.
    queue : mp.Queue
        Forwarded to redirect_stdout_to_queue() together with the module
        logger.

    Raises
    ------
    RuntimeError
        If cfg.FEEDBACK_TYPE is neither 'BAR' nor 'BODY'.
    '''

    def confusion_matrix(Y_true, Y_pred, label_len=6):
        """
        Generate confusion matrix in a string format

        Parameters
        ----------
        Y_true : list
            The true labels
        Y_pred : list
            The test labels
        label_len : int
            The maximum label text length displayed (minimum length: 6)

        Returns
        -------
        cfmat : str
            The confusion matrix in str format (X-axis: prediction, Y-axis: ground truth)
        acc : float
            The accuracy

        Raises
        ------
        RuntimeError
            If Y_pred has more items than Y_true.
        """
        import numpy as np
        from sklearn.metrics import confusion_matrix as sk_confusion_matrix

        # find labels
        if isinstance(Y_true, np.ndarray):
            Y_labels = np.unique(Y_true)
        else:
            Y_labels = list(set(Y_true))

        # Check the provided label name length
        if label_len < 6:
            label_len = 6
            logger.warning('label_len < 6. Setting to 6.')
        label_tpl = '%' + '-%ds' % label_len
        col_tpl = '%' + '-%d.2f' % label_len

        # sanity check
        if len(Y_pred) > len(Y_true):
            raise RuntimeError('Y_pred has more items than Y_true')
        elif len(Y_pred) < len(Y_true):
            Y_true = Y_true[:len(Y_pred)]

        # `labels` is passed by keyword: recent scikit-learn releases no
        # longer accept it positionally.
        cm = sk_confusion_matrix(Y_true, Y_pred, labels=Y_labels)

        # Row-normalised copy; rows without any sample are left untouched
        cm_rate = cm.copy().astype('float')
        cm_sum = np.sum(cm, axis=1)
        for r, s in zip(cm_rate, cm_sum):
            if s > 0:
                r /= s  # in-place division on the row view

        # Fill confusion string (backslash escaped explicitly; the text
        # itself is unchanged)
        cm_txt = label_tpl % 'gt\\dt'
        for l in Y_labels:
            cm_txt += label_tpl % str(l)[:label_len]
        cm_txt += '\n'
        for l, r in zip(Y_labels, cm_rate):
            cm_txt += label_tpl % str(l)[:label_len]
            for c in r:
                cm_txt += col_tpl % c
            cm_txt += '\n'

        # compute accuracy: diagonal (correct predictions) over all samples
        total = cm.sum()
        if total > 0:
            acc = float(np.trace(cm)) / total
        else:
            acc = 0.0

        return cm_txt, acc

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI); sleep between polls instead of
    # spinning so the wait does not keep a CPU core busy.
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        time.sleep(0.01)

    #  Protocol runs if state equals to 1
    if not state.value:
        sys.exit(-1)

    if cfg.FAKE_CLS is None:
        # choose amp
        if cfg.AMP_NAME is None:
            amp_name = search_lsl(ignore_markers=True, state=state)
        else:
            amp_name = cfg.AMP_NAME
        fake_dirs = None
    else:
        amp_name = None
        fake_dirs = [v for (k, v) in cfg.DIRECTIONS]

    # events and triggers
    tdef = TriggerDef(cfg.TRIGGER_FILE)
    trigger = Trigger(cfg.TRIGGER_DEVICE, state)
    # Explicit comparison kept: only a returned False counts as a failure
    if trigger.init(50) == False:
        logger.error(
            'Cannot connect to USB2LPT device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = Trigger('FAKE', state)
        trigger.init(50)

    # For adaptive (need to share the actual true label across process)
    label = mp.Value('i', 0)

    # init classification
    decoder = BCIDecoderDaemon(
        amp_name,
        cfg.DECODER_FILE,
        buffer_size=1.0,
        fake=(cfg.FAKE_CLS is not None),
        fake_dirs=fake_dirs,
        parallel=cfg.PARALLEL_DECODING[cfg.PARALLEL_DECODING['selected']],
        alpha_new=cfg.PROB_ALPHA_NEW,
        label=label)

    # Decoder labels that are not already one of the configured direction
    # labels are translated through the trigger definitions.
    dirdata = {d[1] for d in cfg.DIRECTIONS}
    labels = []
    for x in decoder.get_labels():
        if x not in dirdata:
            labels.append(tdef.by_value[x])
        else:
            labels.append(x)

    # map class labels to bar directions
    bar_def = {cls: str(direction) for direction, cls in cfg.DIRECTIONS}
    bar_dirs = [bar_def[l] for l in labels]
    dir_seq = bar_dirs * cfg.TRIALS_EACH

    logger.info('Initializing decoder.')
    while decoder.is_running() == 0:
        time.sleep(0.01)

    # visual feedback object
    if cfg.FEEDBACK_TYPE == 'BAR':
        from neurodecode.protocols.viz_bars import BarVisual
        visual = BarVisual(cfg.GLASS_USE,
                           screen_pos=cfg.SCREEN_POS,
                           screen_size=cfg.SCREEN_SIZE)
    elif cfg.FEEDBACK_TYPE == 'BODY':
        assert hasattr(cfg, 'FEEDBACK_IMAGE_PATH'
                       ), 'FEEDBACK_IMAGE_PATH is undefined in your config.'
        from neurodecode.protocols.viz_human import BodyVisual
        visual = BodyVisual(cfg.FEEDBACK_IMAGE_PATH,
                            use_glass=cfg.GLASS_USE,
                            screen_pos=cfg.SCREEN_POS,
                            screen_size=cfg.SCREEN_SIZE)
    else:
        # Fail with an explicit message instead of a NameError on `visual`
        # below; the already running decoder is stopped first.
        decoder.stop()
        raise RuntimeError('Unknown FEEDBACK_TYPE: %s' % cfg.FEEDBACK_TYPE)
    visual.put_text('Waiting to start')

    if cfg.LOG_PROBS:
        logdir = io.parse_path(cfg.DECODER_FILE).dir
        # Only the file name goes through strftime; os.path.join supplies the
        # separator whether or not `logdir` ends with one.
        probs_logfile = os.path.join(
            logdir,
            time.strftime("probs-%Y%m%d-%H%M%S.txt", time.localtime()))
    else:
        probs_logfile = None
    feedback = Feedback(cfg, state, visual, tdef, trigger, probs_logfile)

    # If adaptive classifier
    if cfg.ADAPTIVE[cfg.ADAPTIVE['selected']]:
        nb_runs = cfg.ADAPTIVE[cfg.ADAPTIVE['selected']][0]
        adaptive = True
    else:
        nb_runs = 1
        adaptive = False

    # Direction predicted by the decoder -> action sent to the Rex executable
    rex_dirs = {'U': 'N', 'L': 'W', 'R': 'E', 'D': 'S'}

    run_idx = 1
    while run_idx <= nb_runs:

        if cfg.TRIALS_RANDOMIZE:
            random.shuffle(dir_seq)
        else:
            dir_seq = [d[0] for d in cfg.DIRECTIONS] * cfg.TRIALS_EACH
        num_trials = len(dir_seq)

        # For adaptive, retrain classifier
        if run_idx > 1:

            #  Allow to retrain classifier
            with decoder.label.get_lock():
                decoder.label.value = 1

            # Wait that the retraining is done
            while decoder.label.value == 1:
                time.sleep(0.01)

            feedback.viz.put_text('Press any key')
            feedback.viz.update()
            cv2.waitKeyEx()
            feedback.viz.fill()

        # start
        trial = 1
        # One entry per classification attempt. Ground truth is recorded
        # next to each prediction so both lists stay aligned even when a
        # trial is repeated (cfg.TRIALS_RETRY).
        dir_truth = []
        dir_detected = []
        prob_history = {c: [] for c in bar_dirs}
        while trial <= num_trials:
            if cfg.SHOW_TRIALS:
                title_text = 'Trial %d / %d' % (trial, num_trials)
            else:
                title_text = 'Ready'
            true_label = dir_seq[trial - 1]

            result = feedback.classify(decoder,
                                       true_label,
                                       title_text,
                                       bar_dirs,
                                       prob_history=prob_history,
                                       adaptive=adaptive)

            # No result: stop the decoder and leave the protocol
            if result is None:
                decoder.stop()
                return
            pred_label = result
            dir_truth.append(true_label)
            dir_detected.append(pred_label)

            correct = true_label == pred_label

            if cfg.WITH_REX is True and correct:
                rex_dir = rex_dirs.get(pred_label)
                if rex_dir is None:
                    logger.warning('Rex cannot execute undefined action %s' %
                                   pred_label)
                else:
                    visual.move(pred_label, 100, overlay=False, barcolor='B')
                    visual.update()
                    logger.info('Executing Rex action %s' % rex_dir)
                    os.system(
                        '%s/Rex/RexControlSimple.exe %s %s' %
                        (os.environ['NEUROD_ROOT'], cfg.REX_COMPORT, rex_dir))
                    time.sleep(8)

            msg = 'Correct' if correct else 'Wrong'
            # A wrong trial is repeated unless TRIALS_RETRY is exactly False
            if cfg.TRIALS_RETRY is False or correct:
                logger.info('Trial %d: %s (%s -> %s)' %
                            (trial, msg, true_label, pred_label))
                trial += 1

        if dir_detected:
            # write performance and log results
            fdir = io.parse_path(cfg.DECODER_FILE).dir
            logfile = os.path.join(
                fdir,
                time.strftime("online-%Y%m%d-%H%M%S.txt", time.localtime()))
            cfmat, acc = confusion_matrix(dir_truth, dir_detected)
            with open(logfile, 'w') as fout:
                fout.write('Ground-truth,Prediction\n')
                for gt, dt in zip(dir_truth, dir_detected):
                    fout.write('%s,%s\n' % (gt, dt))
                fout.write('\nAccuracy %.3f\nConfusion matrix\n' % acc)
                fout.write(cfmat)
                logger.info('Log exported to %s' % logfile)
            print('\nAccuracy %.3f\nConfusion matrix\n' % acc)
            print(cfmat)

        run_idx += 1

    visual.finish()

    with state.get_lock():
        state.value = 0

    if decoder.is_running():
        decoder.stop()
    '''
    # automatic thresholding
    if prob_history and len(bar_dirs) == 2:
        total = sum(len(prob_history[c]) for c in prob_history)
        fout = open(probs_logfile, 'a')
        msg = 'Automatic threshold optimization.\n'
        max_acc = 0
        max_bias = 0
        for bias in np.arange(-0.99, 1.00, 0.01):
            corrects = 0
            for p in prob_history[bar_dirs[0]]:
                p_biased = (p + bias) / (bias + 1) # new sum = (p+bias) + (1-p) = bias+1
                if p_biased >= 0.5:
                    corrects += 1
            for p in prob_history[bar_dirs[1]]:
                p_biased = (p + bias) / (bias + 1) # new sum = (p+bias) + (1-p) = bias+1
                if p_biased < 0.5:
                    corrects += 1
            acc = corrects / total
            msg += '%s%.2f: %.3f\n' % (bar_dirs[0], bias, acc)
            if acc > max_acc:
                max_acc = acc
                max_bias = bias
        msg += 'Max acc = %.3f at bias %.2f\n' % (max_acc, max_bias)
        fout.write(msg)
        fout.close()
        print(msg)
    '''

    logger.info('Finished.')
Beispiel #7
0
def run(cfg, state=mp.Value('i', 1), queue=None):
    '''
    Main function used to run the online protocol.

    Parameters
    ----------
    cfg : python.module
        The loaded config module. TRIGGER_FILE, TRIGGER_DEVICE and
        REFRESH_RATE are read here and a ``tdef`` attribute is attached.
    state : mp.Value
        Shared flag, 0: stop, 1: start, 2: wait. The protocol waits while it
        equals 2, exits if it is 0, runs the acquisition loop while it is
        truthy and sets it to 0 when done.
    queue : mp.Queue
        If not None, redirect sys.stdout to GUI terminal
    '''
    import time  # local import: only used to poll the state while waiting

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI); sleep between polls instead of
    # spinning so the wait does not keep a CPU core busy.
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        time.sleep(0.01)

    #  Protocol runs if state equals to 1
    if not state.value:
        sys.exit(-1)

    # events and triggers
    cfg.tdef = TriggerDef(cfg.TRIGGER_FILE)

    # To send trigger events
    trigger = Trigger(cfg.TRIGGER_DEVICE, state)

    # Explicit comparison kept: only a returned False counts as a failure
    if trigger.init(50) == False:
        logger.error(
            'Cannot connect to trigger device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = Trigger('FAKE', state)
        trigger.init(50)

    # Instance a stream receiver
    sr = StreamReceiver(bufsize=1,
                        winsize=0.5,
                        stream_name=None,
                        eeg_only=True)

    # Timer used to pace the acquisition loop
    tm = Timer(autoreset=True)

    # Minimum duration of one loop iteration, from cfg.REFRESH_RATE
    refresh_delay = 1.0 / cfg.REFRESH_RATE

    # Run until the shared state is cleared, so that the clean-up below is
    # reachable and a stop request ends the protocol.
    while state.value:

        # Acquire data from all the connected LSL streams by filling each associated buffers.
        sr.acquire()

        # Extract the latest window from the buffer of the chosen stream.
        # window = [samples x channels], tslist = [samples]
        window, tslist = sr.get_window()

        #-------------------------------------
        # ADD YOUR CODE HERE
        #-------------------------------------

        #  To run a trained BCI decoder, look at online_mi.py protocol

        tm.sleep_atleast(refresh_delay)

    with state.get_lock():
        state.value = 0

    logger.info('Finished.')
Beispiel #8
0
def run(cfg, state=mp.Value('i', 1), queue=None):
    '''
    Run the cue-based training protocol with bar feedback.

    Every trial walks through the events
    'gap_s' -> 'gap' -> 'cue' -> 'dir_r' -> 'dir', each one lasting the
    matching cfg.TIMINGS entry, and a trigger is sent on each transition.
    During 'dir' the bar of the current direction grows with elapsed time.
    The protocol ends after the last trial, when ESC is pressed or when the
    shared state is cleared.

    Parameters
    ----------
    cfg : python.module
        The loaded config module.
    state : mp.Value
        Shared flag, 0: stop, 1: start, 2: wait. It is set to 0 when the
        protocol ends.
    queue : mp.Queue
        Forwarded to redirect_stdout_to_queue() together with the module
        logger.

    Raises
    ------
    RuntimeError
        If a direction of cfg.DIRECTIONS is not one of L, R, U, D, B.
    '''
    import time  # local import: only used to poll the state while waiting

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI); sleep between polls instead of
    # spinning so the wait does not keep a CPU core busy.
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        time.sleep(0.01)
    #  Protocol start if equals to 1
    if not state.value:
        sys.exit()

    refresh_delay = 1.0 / cfg.REFRESH_RATE

    cfg.tdef = TriggerDef(cfg.TRIGGER_FILE)

    # Key code compared against the masked cv2.waitKey() result
    key_esc = 27

    # direction code -> (bars moved on screen, prefix of the trigger names)
    directions = {
        'L': (('L',), 'LEFT'),
        'R': (('R',), 'RIGHT'),
        'U': (('U',), 'UP'),
        'D': (('D',), 'DOWN'),
        'B': (('L', 'R'), 'BOTH'),  # both hands
    }

    def draw_timing(name):
        # Base duration plus a uniform jitter in +/- cfg.TIMINGS[name_RANDOMIZE]
        jitter = cfg.TIMINGS[name + '_RANDOMIZE']
        return cfg.TIMINGS[name] + random.uniform(-jitter, jitter)

    dir_sequence = list(cfg.DIRECTIONS) * cfg.TRIALS_EACH
    random.shuffle(dir_sequence)
    num_trials = len(dir_sequence)

    event = 'start'
    trial = 1
    direction = None  # set when the cue of a trial ends

    # Hardware trigger
    if cfg.TRIGGER_DEVICE is None:
        # NOTE: the confirmation prompt is disabled, this is only a warning
        logger.warning('No trigger device set. Press Ctrl+C to stop or Enter to continue.')
    trigger = Trigger(lpttype=cfg.TRIGGER_DEVICE, state=state)
    # Explicit comparison kept: only a returned False counts as a failure
    if trigger.init(50) == False:
        logger.error('\n** Error connecting to USB2LPT device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        # The mock trigger shares the same state flag as the real one
        trigger = Trigger(lpttype='FAKE', state=state)
        trigger.init(50)

    # timers
    timer_trigger = Timer()
    timer_dir = Timer()
    timer_refresh = Timer()
    t_dir = draw_timing('DIR')
    t_dir_ready = draw_timing('READY')

    bar = BarVisual(cfg.GLASS_USE, screen_pos=cfg.SCREEN_POS, screen_size=cfg.SCREEN_SIZE)
    bar.fill()
    bar.glass_draw_cue()

    # start
    while trial <= num_trials:
        timer_refresh.sleep_atleast(refresh_delay)
        timer_refresh.reset()

        # segment= { 'cue':(s,e), 'dir':(s,e), 'label':0-4 } (zero-based)
        if event == 'start' and timer_trigger.sec() > cfg.TIMINGS['INIT']:
            event = 'gap_s'
            bar.fill()
            timer_trigger.reset()
            trigger.signal(cfg.tdef.INIT)
        elif event == 'gap_s':
            if cfg.TRIAL_PAUSE:
                bar.put_text('Press any key')
                bar.update()
                # Masked like the per-frame check at the end of the loop so
                # ESC is recognised in the same way in both places.
                key = 0xFF & cv2.waitKey()
                if key == key_esc or not state.value:
                    break
                bar.fill()
            bar.put_text('Trial %d / %d' % (trial, num_trials))
            event = 'gap'
            timer_trigger.reset()
        elif event == 'gap' and timer_trigger.sec() > cfg.TIMINGS['GAP']:
            event = 'cue'
            bar.fill()
            bar.draw_cue()
            trigger.signal(cfg.tdef.CUE)
            timer_trigger.reset()
        elif event == 'cue' and timer_trigger.sec() > cfg.TIMINGS['CUE']:
            event = 'dir_r'
            direction = dir_sequence[trial - 1]
            if direction not in directions:
                # %s: the direction codes are strings
                raise RuntimeError('Unknown direction %s' % direction)
            bars, prefix = directions[direction]
            for b in bars:
                bar.move(b, 100, overlay=True)
            trigger.signal(getattr(cfg.tdef, prefix + '_READY'))
            timer_trigger.reset()
        elif event == 'dir_r' and timer_trigger.sec() > t_dir_ready:
            bar.fill()
            bar.draw_cue()
            event = 'dir'
            timer_trigger.reset()
            timer_dir.reset()
            # `direction` was validated when the cue ended
            trigger.signal(getattr(cfg.tdef, directions[direction][1] + '_GO'))
        elif event == 'dir' and timer_trigger.sec() > t_dir:
            event = 'gap_s'
            bar.fill()
            trial += 1
            logger.info('trial ' + str(trial - 1) + ' done')
            trigger.signal(cfg.tdef.BLANK)
            timer_trigger.reset()
            # New jittered durations for the next trial
            t_dir = draw_timing('DIR')
            t_dir_ready = draw_timing('READY')

        # protocol: bar length follows the time elapsed in the 'dir' event
        if event == 'dir':
            dx = min(100, int(100.0 * timer_dir.sec() / t_dir) + 1)
            for b in directions[direction][0]:
                bar.move(b, dx, overlay=True)

        # wait for start
        if event == 'start':
            bar.put_text('Waiting to start')

        bar.update()
        key = 0xFF & cv2.waitKey(1)
        if key == key_esc or not state.value:
            break

    bar.finish()

    with state.get_lock():
        state.value = 0
import neurodecode
import numpy as np
from neurodecode.triggers import TriggerDef

# Trigger definitions used below to select the events of interest
tdef = TriggerDef('triggerdef_16.ini')

DATA_DIRS = [r'D:\data\MI\rx1\train']
CHANNEL_PICKS = [5, 6, 7, 11]

#-------------------------------------
# Epochs and events of interest
#-------------------------------------
TRIGGERS = {tdef.LEFT_GO, tdef.RIGHT_GO}
EPOCH = [-2.0, 4.0]
EVENT_FILE = None

#-------------------------------------
# Baseline relative to onset while plotting
#   None in index 0: beginning of data
#   None in index 1: end of data
#-------------------------------------
BS_TIMES = (None, 0)

#-------------------------------------
# PSD
#-------------------------------------
FREQ_RANGE = np.arange(1, 40, 1)

#-------------------------------------
# Unit conversion
#-------------------------------------
MULTIPLIER = 10**6  # (V->uV)

#-------------------------------------
# Filters
#-------------------------------------