Ejemplo n.º 1
0
def run(cfg, state=mp.Value('i', 1), queue=None):
    """
    Compute and log alpha/theta reference levels for neurofeedback training.

    The trigger definitions are attached to cfg as cfg.tdef, features are
    computed with features.compute_features(cfg), and featdata['X_data'] is
    squeezed and averaged over its first axis. Frequency bins below 8 are
    treated as theta, bins at or above 8 as alpha. For each band the rounded
    mean (reference) and the rounded value of "reference minus half a
    standard deviation" (threshold) are logged. Nothing is returned.

    cfg: config object providing TRIGGER_FILE and FEATURES['PSD'] with the
         keys 'fmin', 'fmax' and 'wlen'.
    state: shared integer flag; a zero value aborts through sys.exit(-1).
    queue: forwarded to redirect_stdout_to_queue().
    """

    def reference_and_threshold(band_psd):
        # The threshold is derived from the already rounded reference.
        ref = round(np.mean(band_psd))
        thr = round(ref - (0.5 * np.std(band_psd)))
        return ref, thr

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Attach the trigger definitions to the config object
    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)

    # Stop here when the caller cleared the state flag
    if not state.value:
        sys.exit(-1)

    # Frequency axis built from the PSD settings
    psd_cfg = cfg.FEATURES['PSD']
    freqs = np.arange(psd_cfg['fmin'], psd_cfg['fmax'] + 0.5,
                      1 / psd_cfg['wlen'])
    featdata = features.compute_features(cfg)

    # PSD averaged over the first axis of the squeezed feature array
    mean_psd = np.mean(np.squeeze(featdata['X_data']), 0)

    # Alpha is evaluated first, then theta
    alpha_ref, alpha_thr = reference_and_threshold(mean_psd[freqs >= 8])
    theta_ref, theta_thr = reference_and_threshold(mean_psd[freqs < 8])

    logger.info('Theta ref = {}; alpha ref ={}'.format(theta_ref, alpha_ref))
    logger.info('Theta thr = {}; alpha thr ={}'.format(theta_thr, alpha_thr))
Ejemplo n.º 2
0
    def on_new_tdef_file(self, key, trigger_file):
        """
        Reload the trigger definitions and refresh the events layout.

        self.tdef is replaced by trigger_def(trigger_file). The layout is
        rebuilt through on_update_VBoxLayout() only when self.events is
        truthy. The key argument is not used here.
        """
        self.tdef = trigger_def(trigger_file)

        # Nothing to refresh when no events are registered
        if not self.events:
            return

        self.on_update_VBoxLayout()
Ejemplo n.º 3
0
 def on_new_tdef_file(self, key, trigger_file):
     """
     Rebuild the combo boxes from the events of a new tdef file.

     The existing horizontal layout is cleared first, then the event names
     found in trigger_def(trigger_file).by_name are handed to
     create_the_comboBoxes() together with self.chosen_value and a fixed
     count of 4 directions. The key argument is not used here.
     """
     self.clear_hBoxLayout()
     nb_directions = 4

     # The string 'None' stands for a real None (a real None is removed
     # when selected in the GUI), so it is converted back here.
     event_names = []
     for name in trigger_def(trigger_file).by_name:
         event_names.append(None if name == 'None' else name)

     self.create_the_comboBoxes(self.chosen_value, event_names,
                                nb_directions)
Ejemplo n.º 4
0
def _log_event_counts(eve, tdef, title):
    """Log how often each event value occurs; return values unknown to tdef."""
    logger.info(title)
    unknown = []
    for key in np.unique(eve[:, 2]):
        count = int((eve[:, 2] == key).sum())
        if key in tdef.by_value:
            logger.info('%s: %d events' % (tdef.by_value[key], count))
        else:
            logger.info('%d: %d events' % (key, count))
            unknown.append(key)
    return unknown


def merge_events(trigger_file, events, rawfile_in, rawfile_out):
    """
    Merge groups of event types of a recording into single event types.

    Input
    =====
    trigger_file: trigger definition file passed to trigger_def().
    events: dict mapping a target event name to an iterable of source event
            names. Every occurrence of a source event is relabelled to the
            target event. Merges are applied in dict order, so a later merge
            sees the result of an earlier one. An empty source list is a
            no-op.
    rawfile_in: recording loaded with pu.load_raw().
    rawfile_out: output file, overwritten if it exists.

    Raises
    ======
    KeyError: an event name is missing from the trigger definitions.
    AssertionError: two consecutive events share the same sample index.
    """
    tdef = trigger_def(trigger_file)
    raw, eve = pu.load_raw(rawfile_in)

    unknown = _log_event_counts(eve, tdef, '=== Before merging ===')
    for key in unknown:
        logger.warning('Key %d was not found in the definition file.' % key)

    for target_name, source_names in events.items():
        target_value = tdef.by_name[target_name]
        source_values = [tdef.by_name[name] for name in source_names]
        # A boolean mask copes with an empty source list, whereas
        # concatenating an empty list of index arrays raises ValueError.
        eve[np.isin(eve[:, 2], source_values), 2] = target_value

    # Sanity check. Raised explicitly so it is not stripped by "python -O";
    # the exception type is the one the former assert produced.
    if np.any(np.diff(eve[:, 0]) == 0):
        raise AssertionError(
            'Consecutive events share the same sample index.')

    # Reset the trigger channel (row 0 of the data) before re-adding events
    raw._data[0] *= 0
    raw.add_events(eve, 'TRIGGER')
    raw.save(rawfile_out, overwrite=True)

    _log_event_counts(eve, tdef, '=== After merging ===')
Ejemplo n.º 5
0
def run(cfg,
        state=mp.Value('i', 1),
        queue=None,
        interactive=False,
        cv_file=None,
        feat_file=None,
        logger=logger):
    """
    Compute features, optionally cross-validate, optionally train a decoder.

    The shared state flag is checked before each stage; a zero value aborts
    through sys.exit(-1). On normal completion the flag is set to 0 under
    its lock.

    cfg: config object; cfg.tdef is set from cfg.TRIGGER_FILE.
    state: shared integer flag.
    queue: forwarded to redirect_stdout_to_queue().
    interactive: not used in this function.
    cv_file: forwarded to cross_validate().
    feat_file: forwarded to train_decoder().
    logger: logger handed to redirect_stdout_to_queue().
    """

    def exit_if_stopped():
        # The flag may be cleared at any time by whoever owns `state`
        if not state.value:
            sys.exit(-1)

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Attach the trigger definitions to the config object
    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)

    # Stage 1: feature extraction (TPR balancing is currently disabled)
    exit_if_stopped()
    featdata = features.compute_features(cfg)

    # Stage 2: cross validation, only when the selected entry is set
    exit_if_stopped()
    selected_cv = cfg.CV_PERFORM['selected']
    if cfg.CV_PERFORM[selected_cv] is not None:
        cross_validate(cfg, featdata, cv_file=cv_file)

    # Stage 3: decoder training, only when EXPORT_CLS is exactly True
    exit_if_stopped()
    if cfg.EXPORT_CLS is True:
        train_decoder(cfg, featdata, feat_file=feat_file)

    # Signal completion
    with state.get_lock():
        state.value = 0
Ejemplo n.º 6
0
def stream_player(server_name,
                  fif_file,
                  chunk_size,
                  auto_restart=True,
                  wait_start=True,
                  repeat=float('inf'),
                  high_resolution=False,
                  trigger_file=None):
    """
    Replay a fif recording as an LSL stream.

    Input
    =====
    server_name: LSL server name.
    fif_file: fif file to replay.
    chunk_size: number of samples to send at once (usually 16-32 is good enough).
    auto_restart: play from beginning again after reaching the end. When
                  False, the user is asked to press Enter before each restart.
    wait_start: wait for user to start in the beginning.
    repeat: number of loops to play (the file is played exactly this many
            times; the default plays forever).
    high_resolution: use perf_counter() instead of sleep() for higher time resolution
                     but uses much more cpu due to polling.
    trigger_file: used to convert event numbers into event strings for readability.

    Raises
    ======
    RuntimeError: the recording could not be loaded.

    Note: Run neurodecode.set_log_level('DEBUG') to log each sent chunk with
    a time.perf_counter() reading and the LSL local clock.

    """
    raw, _ = pu.load_raw(fif_file)
    # Validate before any attribute of raw is touched
    if raw is None:
        raise RuntimeError('Error while loading %s' % fif_file)

    sfreq = raw.info['sfreq']  # sampling frequency
    n_channels = len(raw.ch_names)  # number of channels
    tdef = trigger_def(trigger_file) if trigger_file is not None else None
    try:
        event_ch = raw.ch_names.index('TRIGGER')
    except ValueError:
        event_ch = None

    logger.info_green('Successfully loaded %s' % fif_file)
    logger.info('Server name: %s' % server_name)
    logger.info('Sampling frequency %.3f Hz' % sfreq)
    logger.info('Number of channels : %d' % n_channels)
    logger.info('Chunk size : %d' % chunk_size)
    for i, ch in enumerate(raw.ch_names):
        logger.info('%d %s' % (i, ch))
    logger.info('Trigger channel : %s' % event_ch)

    # set server information
    sinfo = pylsl.StreamInfo(server_name,
                             channel_count=n_channels,
                             channel_format='float32',
                             nominal_srate=sfreq,
                             type='EEG',
                             source_id=server_name)
    desc = sinfo.desc()
    channel_desc = desc.append_child("channels")
    for ch in raw.ch_names:
        channel = channel_desc.append_child('channel')
        channel.append_child_value('label', str(ch))\
            .append_child_value('type', 'EEG')\
            .append_child_value('unit', 'microvolts')
    desc.append_child('amplifier').append_child('settings').append_child_value(
        'is_slave', 'false')
    desc.append_child('acquisition').append_child_value(
        'manufacturer',
        'NeuroDecode').append_child_value('serial_number', 'N/A')
    outlet = pylsl.StreamOutlet(sinfo, chunk_size=chunk_size)

    if wait_start:
        input('Press Enter to start streaming.')
    logger.info('Streaming started')

    # The reference clock must match the one used for pacing below
    clock = time.perf_counter if high_resolution else time.time

    n_samples = raw._data.shape[1]
    t_chunk = chunk_size / sfreq
    idx_chunk = 0
    t_start = clock()

    # start streaming
    played = 0  # number of completed loops
    while played < repeat:
        idx_current = idx_chunk * chunk_size
        chunk = raw._data[:, idx_current:idx_current + chunk_size]
        data = chunk.transpose().tolist()
        finished = idx_current >= n_samples - chunk_size

        # Pace the stream: chunk k is due at t_start + k * t_chunk
        t_due = t_start + idx_chunk * t_chunk
        if high_resolution:
            # if a resolution over 2 KHz is needed
            while time.perf_counter() < t_due:
                pass
        else:
            # time.sleep() can have 500 us resolution using the tweak tool provided.
            t_wait = t_due - time.time()
            if t_wait > 0.001:
                time.sleep(t_wait)

        outlet.push_chunk(data)
        logger.debug('[%8.3fs] sent %d samples (LSL %8.3f)' %
                     (time.perf_counter(), len(data), pylsl.local_clock()))

        if event_ch is not None:
            event_values = set(chunk[event_ch]) - {0}
            if event_values:
                if tdef is None:
                    logger.info('Events: %s' % event_values)
                else:
                    for event in event_values:
                        if event in tdef.by_value:
                            logger.info('Events: %s (%s)' %
                                        (event, tdef.by_value[event]))
                        else:
                            logger.info('Events: %s (Undefined event)' % event)
        idx_chunk += 1

        if finished:
            played += 1
            if played >= repeat:
                # No loop left: do not announce or ask for a restart
                break
            if auto_restart is False:
                input(
                    'Reached the end of data. Press Enter to restart or Ctrl+C to stop.'
                )
            else:
                logger.info('Reached the end of data. Restarting.')
            idx_chunk = 0
            t_start = clock()
Ejemplo n.º 7
0
def run(cfg, state=mp.Value('i', 1), queue=None):
    """
    Offline cue-based protocol: show direction cues on a bar visual and send
    the matching trigger events.

    Each trial goes through the phases gap -> cue -> direction ready ->
    direction go, after which a BLANK trigger is sent. The trial order is a
    shuffled sequence containing every entry of cfg.DIRECTIONS
    cfg.TRIALS_EACH times. Supported direction codes are 'L', 'R', 'U', 'D'
    and 'B' (both left and right bars).

    cfg: config object (TRIGGER_FILE, TRIGGER_DEVICE, REFRESH_RATE, TIMINGS,
         DIRECTIONS, TRIALS_EACH, TRIAL_PAUSE, GLASS_USE, SCREEN_POS,
         SCREEN_SIZE); cfg.tdef is set from cfg.TRIGGER_FILE.
    state: shared integer flag (0: stop, 1: start, 2: wait). It is set to 0
           under its lock when the protocol ends.
    queue: forwarded to redirect_stdout_to_queue().

    Raises RuntimeError when the trial sequence contains an unknown
    direction code.
    """
    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI)
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        pass
    #  Protocol start if equals to 1
    if not state.value:
        sys.exit()

    refresh_delay = 1.0 / cfg.REFRESH_RATE

    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)

    # Key code returned by cv2.waitKey() for the Escape key
    esc_key = 27

    # direction code -> (READY event name, GO event name, bars to move).
    # Event names are resolved lazily with getattr() so that only the
    # events of directions actually used must exist in the tdef file.
    direction_map = {
        'L': ('LEFT_READY', 'LEFT_GO', ('L', )),
        'R': ('RIGHT_READY', 'RIGHT_GO', ('R', )),
        'U': ('UP_READY', 'UP_GO', ('U', )),
        'D': ('DOWN_READY', 'DOWN_GO', ('D', )),
        'B': ('BOTH_READY', 'BOTH_GO', ('L', 'R')),
    }

    def jittered(name):
        # Base duration plus a uniform jitter of +/- TIMINGS[name_RANDOMIZE]
        spread = cfg.TIMINGS[name + '_RANDOMIZE']
        return cfg.TIMINGS[name] + random.uniform(-spread, spread)

    dir_sequence = []
    for _ in range(cfg.TRIALS_EACH):
        dir_sequence.extend(cfg.DIRECTIONS)
    random.shuffle(dir_sequence)
    num_trials = len(cfg.DIRECTIONS) * cfg.TRIALS_EACH

    event = 'start'
    trial = 1

    # Hardware trigger
    if cfg.TRIGGER_DEVICE is None:
        logger.warning(
            'No trigger device set. Press Ctrl+C to stop or Enter to continue.'
        )
    trigger = pyLptControl.Trigger(state, cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        logger.error(
            '\n** Error connecting to USB2LPT device. Use a mock trigger instead?'
        )
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.MockTrigger()
        trigger.init(50)

    # timers
    timer_trigger = qc.Timer()
    timer_dir = qc.Timer()
    timer_refresh = qc.Timer()
    t_dir = jittered('DIR')
    t_dir_ready = jittered('READY')

    bar = BarVisual(cfg.GLASS_USE,
                    screen_pos=cfg.SCREEN_POS,
                    screen_size=cfg.SCREEN_SIZE)
    bar.fill()
    bar.glass_draw_cue()

    # Set during the 'cue' phase and used by the phases that follow it
    go_event = None
    bar_sides = ()

    # start
    while trial <= num_trials:
        timer_refresh.sleep_atleast(refresh_delay)
        timer_refresh.reset()

        if event == 'start' and timer_trigger.sec() > cfg.TIMINGS['INIT']:
            event = 'gap_s'
            bar.fill()
            timer_trigger.reset()
            trigger.signal(cfg.tdef.INIT)
        elif event == 'gap_s':
            if cfg.TRIAL_PAUSE:
                bar.put_text('Press any key')
                bar.update()
                # Masked like the per-frame key poll at the end of the loop
                key = 0xFF & cv2.waitKey()
                if key == esc_key or not state.value:
                    break
                bar.fill()
            bar.put_text('Trial %d / %d' % (trial, num_trials))
            event = 'gap'
            timer_trigger.reset()
        elif event == 'gap' and timer_trigger.sec() > cfg.TIMINGS['GAP']:
            event = 'cue'
            bar.fill()
            bar.draw_cue()
            trigger.signal(cfg.tdef.CUE)
            timer_trigger.reset()
        elif event == 'cue' and timer_trigger.sec() > cfg.TIMINGS['CUE']:
            event = 'dir_r'
            direction = dir_sequence[trial - 1]
            if direction not in direction_map:
                # %s, not %d: direction codes are strings
                raise RuntimeError('Unknown direction %s' % direction)
            ready_event, go_event, bar_sides = direction_map[direction]
            for side in bar_sides:
                bar.move(side, 100, overlay=True)
            trigger.signal(getattr(cfg.tdef, ready_event))
            timer_trigger.reset()
        elif event == 'dir_r' and timer_trigger.sec() > t_dir_ready:
            bar.fill()
            bar.draw_cue()
            event = 'dir'
            timer_trigger.reset()
            timer_dir.reset()
            trigger.signal(getattr(cfg.tdef, go_event))
        elif event == 'dir' and timer_trigger.sec() > t_dir:
            event = 'gap_s'
            bar.fill()
            trial += 1
            logger.info('trial ' + str(trial - 1) + ' done')
            trigger.signal(cfg.tdef.BLANK)
            timer_trigger.reset()
            # Draw fresh jittered durations for the next trial
            t_dir = jittered('DIR')
            t_dir_ready = jittered('READY')

        # protocol: grow the bar(s) proportionally to the elapsed time
        if event == 'dir':
            dx = min(100, int(100.0 * timer_dir.sec() / t_dir) + 1)
            for side in bar_sides:
                bar.move(side, dx, overlay=True)

        # wait for start
        if event == 'start':
            bar.put_text('Waiting to start')

        bar.update()
        key = 0xFF & cv2.waitKey(1)
        if key == esc_key or not state.value:
            break

    bar.finish()

    with state.get_lock():
        state.value = 0
Ejemplo n.º 8
0
import neurodecode
import numpy as np
import neurodecode.utils.q_common as qc
from neurodecode.triggers.trigger_def import trigger_def

# Trigger definitions loaded from the given definition file; event values
# are referenced below as attributes (e.g. tdef.LEFT_GO).
tdef = trigger_def('triggerdef_16.ini')

# Directories containing the data to process.
DATA_DIRS = [r'D:\data\MI\rx1\train']
# Channels to use — presumably channel indices; confirm against the consumer.
CHANNEL_PICKS = [5, 6, 7, 11]

'''"""""""""""""""""""""""""""
 Epochs and events of interest
"""""""""""""""""""""""""""'''
# Set of event values of interest.
TRIGGERS = {tdef.LEFT_GO, tdef.RIGHT_GO}
# Epoch window [start, end] — presumably seconds relative to event onset.
EPOCH = [-2.0, 4.0]
# No external event file is configured.
EVENT_FILE = None

'''"""""""""""""""""""""""""""
 Baseline relative to onset while plotting
 None in index 0: beginning of data
 None in index 1: end of data
"""""""""""""""""""""""""""'''
# (None, 0): from the beginning of the data up to 0.
BS_TIMES = (None, 0)

'''"""""""""""""""""""""""""""
 PSD
"""""""""""""""""""""""""""'''
# Values 1, 2, ..., 39 (the stop value 40 is excluded by np.arange).
FREQ_RANGE = np.arange(1, 40, 1)

'''"""""""""""""""""""""""""""
 Unit conversion
Ejemplo n.º 9
0
def run(cfg, state=mp.Value('i', 1), queue=None):
    """
    Offline protocol: record while the participant relaxes with eyes closed.

    A start voice is played and the protocol waits for a key press (Escape
    aborts). An INIT trigger is then sent and the protocol idles until
    cfg.GLOBAL_TIME seconds have elapsed on the global timer, Escape is
    pressed, or the state flag leaves 1. Finally an END trigger is sent and
    the end voice is played.

    cfg: config object (TRIGGER_FILE, TRIGGER_DEVICE, START_VOICE,
         END_VOICE, GLOBAL_TIME, GLASS_USE, SCREEN_POS, SCREEN_SIZE);
         cfg.tdef is set from cfg.TRIGGER_FILE.
    state: shared integer flag (0: stop, 1: start, 2: wait).
    queue: forwarded to redirect_stdout_to_queue().

    Raises RuntimeError when the trigger device cannot be initialised.
    """
    # Key code returned by cv2.waitKey() for the Escape key
    esc_key = 27

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI)
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        pass

    # Protocol runs if state equals to 1
    if not state.value:
        sys.exit(-1)

    # NOTE(review): the timer is created before the blocking key-press wait
    # below, so that waiting time appears to count towards GLOBAL_TIME.
    # Confirm whether this is intended.
    global_timer = qc.Timer(autoreset=False)

    # Init trigger communication
    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)
    trigger = pyLptControl.Trigger(state, cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        logger.error('\n** Error connecting to trigger device.')
        raise RuntimeError('Error connecting to trigger device.')

    # Preload the starting voice
    pgmixer.init()
    pgmixer.music.load(cfg.START_VOICE)

    # Init feedback
    viz = BarVisual(cfg.GLASS_USE,
                    screen_pos=cfg.SCREEN_POS,
                    screen_size=cfg.SCREEN_SIZE)
    viz.fill()
    viz.put_text('Close your eyes and relax')
    viz.update()

    # Play the start voice
    pgmixer.music.play()

    # Wait a key press
    key = 0xFF & cv2.waitKey(0)
    if key == esc_key or not state.value:
        sys.exit(-1)

    viz.fill()
    viz.put_text('Recording in progress')
    viz.update()

    #----------------------------------------------------------------------
    # Main
    #----------------------------------------------------------------------
    trigger.signal(cfg.tdef.INIT)

    while state.value == 1 and global_timer.sec() < cfg.GLOBAL_TIME:

        # Masked with 0xFF like the start-up key check above, so both
        # Escape checks compare the same low byte of the key code.
        key = 0xFF & cv2.waitKey(1)
        if key == esc_key:
            with state.get_lock():
                state.value = 0

    trigger.signal(cfg.tdef.END)

    # Remove the text
    viz.fill()
    viz.put_text('Recording is finished')
    viz.update()

    # Ending voice
    pgmixer.music.load(cfg.END_VOICE)
    pgmixer.music.play()
    time.sleep(5)  # leave time for the ending voice before closing

    # Close cv2 window
    viz.finish()
Ejemplo n.º 10
0
def run(cfg, state=mp.Value('i', 1), queue=None):
    """
    Online protocol: classify in real time and give visual feedback.

    For every trial feedback.classify() is run against the decoder daemon
    and the predicted direction is compared with the cued one. When
    cfg.TRIALS_RETRY is set (anything but False), a wrongly classified trial
    is repeated. At the end the (ground-truth, prediction) pair of every
    classified attempt, the accuracy and the confusion matrix are written
    to an "online-<timestamp>.txt" file next to cfg.DECODER_FILE.

    cfg: config object; see the attributes read below.
    state: shared integer flag (0: stop, 1: start, 2: wait). It is set to 0
           under its lock when the protocol ends.
    queue: forwarded to redirect_stdout_to_queue().

    Raises ValueError when cfg.FEEDBACK_TYPE is not 'BAR', 'COLORS' or
    'BODY'.
    """
    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI)
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        pass

    #  Protocol runs if state equals to 1
    if not state.value:
        sys.exit(-1)

    if cfg.FAKE_CLS is None:
        # choose amp
        if cfg.AMP_NAME is None and cfg.AMP_SERIAL is None:
            amp_name, amp_serial = pu.search_lsl(state, ignore_markers=True)
        else:
            amp_name = cfg.AMP_NAME
            amp_serial = cfg.AMP_SERIAL
        fake_dirs = None
    else:
        amp_name = None
        amp_serial = None
        fake_dirs = [v for (k, v) in cfg.DIRECTIONS]

    # events and triggers
    tdef = trigger_def(cfg.TRIGGER_FILE)
    trigger = pyLptControl.Trigger(state, cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        logger.error(
            'Cannot connect to USB2LPT device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.MockTrigger()
        trigger.init(50)

    # init classification
    decoder = BCIDecoderDaemon(
        cfg.DECODER_FILE,
        buffer_size=1.0,
        fake=(cfg.FAKE_CLS is not None),
        amp_name=amp_name,
        amp_serial=amp_serial,
        fake_dirs=fake_dirs,
        parallel=cfg.PARALLEL_DECODING[cfg.PARALLEL_DECODING['selected']],
        alpha_new=cfg.PROB_ALPHA_NEW)

    # Decoder labels that are not listed in cfg.DIRECTIONS are trigger
    # values and are translated to event names; the others are kept as is.
    dirdata = {d[1] for d in cfg.DIRECTIONS}
    labels = []
    for x in decoder.get_labels():
        if x not in dirdata:
            labels.append(tdef.by_value[x])
        else:
            labels.append(x)

    # map class labels to bar directions
    bar_def = {label: str(direction) for direction, label in cfg.DIRECTIONS}
    bar_dirs = [bar_def[l] for l in labels]
    dir_seq = []
    for _ in range(cfg.TRIALS_EACH):
        dir_seq.extend(bar_dirs)
    if cfg.TRIALS_RANDOMIZE:
        random.shuffle(dir_seq)
    else:
        dir_seq = [d[0] for d in cfg.DIRECTIONS] * cfg.TRIALS_EACH
    num_trials = len(dir_seq)

    logger.info('Initializing decoder.')
    # Value comparison: "is 0" tests object identity, which is an
    # implementation detail for ints and a SyntaxWarning on Python >= 3.8.
    while decoder.is_running() == 0:
        time.sleep(0.01)

    # visual feedback object
    feedback_type = cfg.FEEDBACK_TYPE
    if feedback_type == 'BAR':
        from neurodecode.protocols.viz_bars import BarVisual
        visual = BarVisual(cfg.GLASS_USE,
                           screen_pos=cfg.SCREEN_POS,
                           screen_size=cfg.SCREEN_SIZE)
    elif feedback_type == 'COLORS':
        from neurodecode.protocols.viz_colors import ColorVisual
        visual = ColorVisual(cfg.GLASS_USE,
                             screen_pos=cfg.SCREEN_POS,
                             screen_size=cfg.SCREEN_SIZE)
    elif feedback_type == 'BODY':
        assert hasattr(cfg, 'FEEDBACK_IMAGE_PATH'
                       ), 'FEEDBACK_IMAGE_PATH is undefined in your config.'
        from neurodecode.protocols.viz_human import BodyVisual
        visual = BodyVisual(cfg.FEEDBACK_IMAGE_PATH,
                            use_glass=cfg.GLASS_USE,
                            screen_pos=cfg.SCREEN_POS,
                            screen_size=cfg.SCREEN_SIZE)
    else:
        # Without this branch `visual` stayed unbound and the next line
        # failed with an unrelated-looking NameError.
        raise ValueError('Unknown FEEDBACK_TYPE %s' % feedback_type)
    visual.put_text('Waiting to start')
    if cfg.LOG_PROBS:
        logdir = qc.parse_path_list(cfg.DECODER_FILE)[0]
        # NOTE(review): no path separator is inserted here, unlike the
        # "/online-..." log file below — confirm that logdir ends with one.
        probs_logfile = time.strftime(logdir + "probs-%Y%m%d-%H%M%S.txt",
                                      time.localtime())
    else:
        probs_logfile = None
    feedback = Feedback(cfg, state, visual, tdef, trigger, probs_logfile)

    # predicted direction -> Rex action
    rex_actions = {'U': 'N', 'L': 'W', 'R': 'E', 'D': 'S'}

    # start
    trial = 1
    # Ground truth and prediction of every classified attempt. They are
    # recorded together so the two lists stay aligned when a trial is
    # retried or the session is aborted early.
    dir_truth = []
    dir_detected = []
    prob_history = {c: [] for c in bar_dirs}
    while trial <= num_trials:
        if cfg.SHOW_TRIALS:
            title_text = 'Trial %d / %d' % (trial, num_trials)
        else:
            title_text = 'Ready'
        true_label = dir_seq[trial - 1]

        pred_label = feedback.classify(decoder,
                                       true_label,
                                       title_text,
                                       bar_dirs,
                                       prob_history=prob_history)

        # classify() returning None ends the session
        if pred_label is None:
            break
        dir_truth.append(true_label)
        dir_detected.append(pred_label)

        if cfg.WITH_REX is True and pred_label == true_label:
            rex_dir = rex_actions.get(pred_label)
            if rex_dir is None:
                logger.warning('Rex cannot execute undefined action %s' %
                               pred_label)
            else:
                visual.move(pred_label, 100, overlay=False, barcolor='B')
                visual.update()
                logger.info('Executing Rex action %s' % rex_dir)
                # NOTE(review): relies on the name `pycnbi` being available
                # in this module — confirm, the other imports here use
                # `neurodecode`.
                os.system('%s/Rex/RexControlSimple.exe %s %s' %
                          (pycnbi.ROOT, cfg.REX_COMPORT, rex_dir))
                time.sleep(8)

        if true_label == pred_label:
            msg = 'Correct'
        else:
            msg = 'Wrong'
        # Advance unless retrying is enabled and the trial was wrong
        if cfg.TRIALS_RETRY is False or true_label == pred_label:
            logger.info('Trial %d: %s (%s -> %s)' %
                        (trial, msg, true_label, pred_label))
            trial += 1

    if len(dir_detected) > 0:
        # write performance and log results
        fdir, _, _ = qc.parse_path_list(cfg.DECODER_FILE)
        logfile = time.strftime(fdir + "/online-%Y%m%d-%H%M%S.txt",
                                time.localtime())
        with open(logfile, 'w') as fout:
            fout.write('Ground-truth,Prediction\n')
            for gt, dt in zip(dir_truth, dir_detected):
                fout.write('%s,%s\n' % (gt, dt))
            cfmat, acc = qc.confusion_matrix(dir_truth, dir_detected)
            fout.write('\nAccuracy %.3f\nConfusion matrix\n' % acc)
            fout.write(cfmat)
            logger.info('Log exported to %s' % logfile)
        print('\nAccuracy %.3f\nConfusion matrix\n' % acc)
        print(cfmat)

    visual.finish()

    with state.get_lock():
        state.value = 0

    if decoder:
        decoder.stop()

    logger.info('Finished.')
Ejemplo n.º 11
0
def run(cfg, amp_name, amp_serial, state=mp.Value('i', 1), experiment_mode=True, baseline=False):
    """
    Online protocol for Alpha/Theta neurofeedback.

    Windows are acquired from an LSL stream for cfg['global_time'] seconds.
    Unless baseline is set, a PSD feature is computed for each window,
    normalised by the largest value held in a queue of
    cfg['window_size_psd_max'] recent features, smoothed, and used to mix
    the two feedback sounds.

    cfg: dict-like configuration; see the keys read below.
    amp_name, amp_serial: identify the LSL stream to connect to.
    state: shared integer flag handed to the trigger device and checked
           after the initial key press.
    experiment_mode: when True, triggers are sent, the start/end voices are
                     played and an on-screen message is shown.
    baseline: when True, no feature is computed and no feedback sound is
              played.

    Raises
    ======
    ValueError: cfg['feature_type'] is not supported (only checked when
                baseline is False).
    RuntimeError: the trigger device cannot be initialised.
    """
    # Fail before any hardware is touched or the participant is prompted.
    # Previously an unsupported type left `feature` unbound in the loop.
    if not baseline and cfg['feature_type'] not in (FeatureType.THETA,
                                                    FeatureType.ALPHA_THETA):
        raise ValueError('Unsupported feature type: %s' % cfg['feature_type'])

    #----------------------------------------------------------------------
    # LSL stream connection
    #----------------------------------------------------------------------

    sr = protocol_utils.connect_lsl_stream(amp_name=amp_name,
                                           amp_serial=amp_serial,
                                           window_size=cfg['window_size'],
                                           buffer_size=cfg['buffer_size'])

    sfreq = sr.get_sample_rate()
    trg_ch = sr.get_trigger_channel()  # currently unused

    #----------------------------------------------------------------------
    # PSD estimators initialization
    #----------------------------------------------------------------------
    # NOTE(review): the estimators use cfg['sampling_frequency'] while the
    # preprocessing below uses the stream's own rate (sfreq) — confirm that
    # both always agree.
    psde_alpha = protocol_utils.init_psde(*list(cfg['alpha_band_freq'].values()),
                                          sampling_frequency=cfg['sampling_frequency'],
                                          n_jobs=cfg['n_jobs'])
    psde_theta = protocol_utils.init_psde(*list(cfg['theta_band_freq'].values()),
                                          sampling_frequency=cfg['sampling_frequency'],
                                          n_jobs=cfg['n_jobs'])
    #----------------------------------------------------------------------
    # Initialize the feedback sounds
    #----------------------------------------------------------------------
    sound_1, sound_2 = protocol_utils.init_feedback_sounds(cfg['music_state_1_path'],
                                                           cfg['music_state_2_path'])

    #----------------------------------------------------------------------
    # Main
    #----------------------------------------------------------------------
    # NOTE(review): in experiment mode the global timer starts before the
    # blocking key-press wait, so that waiting time appears to count
    # towards cfg['global_time'] — confirm whether this is intended.
    global_timer   = qc.Timer(autoreset=False)
    internal_timer = qc.Timer(autoreset=True)

    pgmixer.init()
    if experiment_mode:
        # Init trigger communication
        trigger_signals = trigger_def(cfg['trigger_file'])
        trigger = pyLptControl.Trigger(state, cfg['trigger_device'])
        if trigger.init(50) == False:
            logger.error('\n** Error connecting to trigger device.')
            raise RuntimeError('Error connecting to trigger device.')

        # Preload the starting voice
        print(cfg['start_voice_file'])
        pgmixer.music.load(cfg['start_voice_file'])

        # Init feedback
        viz = BarVisual(False,
                        screen_pos=cfg['screen_pos'],
                        screen_size=cfg['screen_size'])

        viz.fill()
        viz.put_text('Close your eyes and relax')
        viz.update()

        pgmixer.music.play()

        # Wait a key press
        key = 0xFF & cv2.waitKey(0)
        if key == KEYS['esc'] or not state.value:
            sys.exit(-1)

        print('recording started')

        trigger.signal(trigger_signals.INIT)

    if not baseline:
        sound_1.play(loops=-1)
        sound_2.play(loops=-1)

    last_ts = None

    last_ratio = None
    measured_psd_ratios = np.full(cfg['window_size_psd_max'], np.nan)

    while global_timer.sec() < cfg['global_time']:

        #----------------------------------------------------------------------
        # Data acquisition
        #----------------------------------------------------------------------
        sr.acquire()
        window, tslist = sr.get_window()    # window = [samples x channels]
        window = window.T                   # window = [channels x samples]

        # Check if proper real-time acquisition: at least one timestamp
        # must be newer than the last one seen
        if last_ts is not None:
            tsnew = np.where(np.array(tslist) > last_ts)[0]
            if len(tsnew) == 0:
                logger.warning('There seems to be delay in receiving data.')
                time.sleep(1)
                continue

        # Spatial filtering
        window = pu.preprocess(window,
                               sfreq=sfreq,
                               spatial=cfg.get('spatial_filter'),
                               spatial_ch=cfg.get('spatial_channels'))

        #----------------------------------------------------------------------
        # Computing the Power Spectrum Densities using multitapers
        #----------------------------------------------------------------------
        if not baseline:
            if cfg['feature_type'] == FeatureType.THETA:
                feature = protocol_utils.compute_psd(window, psde_theta)
            else:
                # FeatureType.ALPHA_THETA (validated at the top)
                psd_alpha = protocol_utils.compute_psd(window, psde_alpha)
                psd_theta = protocol_utils.compute_psd(window, psde_theta)
                feature = psd_alpha / psd_theta

            # Normalise by the largest non-NaN value currently in the queue
            measured_psd_ratios = add_to_queue(measured_psd_ratios, feature)
            current_music_ratio  = feature / np.max(measured_psd_ratios[~np.isnan(measured_psd_ratios)])

            # Move a quarter of the way from the previously applied ratio
            # towards the new one to avoid abrupt changes
            if last_ratio is not None:
                applied_music_ratio = last_ratio + (current_music_ratio - last_ratio) * 0.25
            else:
                applied_music_ratio = current_music_ratio

            mix_sounds(style=cfg['music_mix_style'],
                       sounds=(sound_1, sound_2),
                       feature_value=applied_music_ratio)

            print((f"{cfg['feature_type']}: {feature:0.3f}"
                   f"\t, current_music_ratio: {current_music_ratio:0.3f}"
                   f"\t, applied music ratio: {applied_music_ratio:0.3f}"
                   ))

            last_ratio = applied_music_ratio

        last_ts = tslist[-1]
        internal_timer.sleep_atleast(cfg['timer_sleep'])

    if not baseline:
        # NOTE(review): if these are pygame Sound objects, fadeout() takes
        # milliseconds, so 3 would be an almost instant stop — confirm
        # whether 3000 was intended.
        sound_1.fadeout(3)
        sound_2.fadeout(3)

    if experiment_mode:
        trigger.signal(trigger_signals.END)

        # Remove the text
        viz.fill()
        viz.put_text('Recording is finished')
        viz.update()

        # Ending voice
        pgmixer.music.load(cfg['end_voice_file'])
        pgmixer.music.play()
        time.sleep(5)  # leave time for the ending voice before closing

        # Close cv2 window
        viz.finish()

    print('done')