# Example #1
0
def run(cfg, state=None, queue=None, interactive=False, cv_file=None, feat_file=None, logger=logger):
    """
    Offline training pipeline: feature extraction, optional cross-validation
    and optional decoder export.

    Parameters
    ----------
    cfg : configuration object. Reads TRIGGER_FILE, CV_PERFORM and EXPORT_CLS;
        a ``tdef`` attribute built from TRIGGER_FILE is attached to it.
    state : multiprocessing.Value('i') used as a run flag. It is checked
        before each stage; a value of 0 makes the process exit with
        ``sys.exit(-1)``. It is set to 0 when the pipeline completes.
        If None (default), a fresh ``mp.Value('i', 1)`` is created per call.
    queue : forwarded to ``redirect_stdout_to_queue``.
    interactive : not used by this function.
    cv_file : forwarded to ``cross_validate``.
    feat_file : forwarded to ``train_decoder``.
    logger : forwarded to ``redirect_stdout_to_queue``.
    """
    # A default of mp.Value('i', 1) in the signature is evaluated once, at
    # definition time, and shared by every call. Because this function sets
    # the flag to 0 when it finishes, a second call relying on that default
    # would exit at the first checkpoint. Create a fresh flag per call instead.
    if state is None:
        state = mp.Value('i', 1)

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # add tdef object
    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)

    def abort_if_stopped():
        # A zero run flag means a stop was requested
        if not state.value:
            sys.exit(-1)

    # Extract features
    abort_if_stopped()
    featdata = features.compute_features(cfg)

    # Perform cross validation (TPR balancing via balance_tpr is disabled)
    abort_if_stopped()
    if cfg.CV_PERFORM[cfg.CV_PERFORM['selected']] is not None:
        cross_validate(cfg, featdata, cv_file=cv_file)

    # Train a decoder
    abort_if_stopped()
    if cfg.EXPORT_CLS is True:
        train_decoder(cfg, featdata, feat_file=feat_file)

    # Signal completion to whoever holds the shared flag
    with state.get_lock():
        state.value = 0
# Example #2
0
def merge_events(trigger_file, events, eeg_in, eeg_out):
    """
    Relabel (merge) events of a raw recording and save the result.

    Parameters
    ----------
    trigger_file : trigger definition file passed to trigger_def().
    events : dict mapping a target event key to an iterable of source event
        keys. Every event whose value matches a source key (looked up with
        tdef.by_key) is relabelled with the target key's value. A target
        with no source keys is skipped.
    eeg_in : input recording, loaded with pu.load_raw().
    eeg_out : output path; an existing file is overwritten.

    Event counts are printed before and after merging. Event values that are
    not in the trigger definition are printed by their numeric value instead
    of raising KeyError.
    """
    tdef = trigger_def(trigger_file)
    raw, eve = pu.load_raw(eeg_in)

    def print_event_counts():
        # One pass over the event-value column: each distinct value and its count
        values, counts = np.unique(eve[:, 2], return_counts=True)
        for value, count in zip(values, counts):
            name = tdef.by_value[value] if value in tdef.by_value else value
            print('%s: %d' % (name, count))

    print('\nEvents before merging')
    print_event_counts()

    for key in events:
        ev_out = tdef.by_key[key]
        # Collect all matching indices before relabelling anything
        matches = [np.where(eve[:, 2] == tdef.by_key[e])[0] for e in events[key]]
        # np.concatenate raises ValueError on an empty list, so skip targets
        # without any source keys
        if matches:
            eve[np.concatenate(matches), 2] = ev_out

    # sanity check: no two consecutive events share a sample index, and no
    # event value exceeds the largest value in the trigger definition
    dups = np.where(0 == np.diff(eve[:, 0]))[0]
    assert len(dups) == 0
    assert max(eve[:, 2]) <= max(tdef.by_value.keys())

    # reset trigger channel (row 0 of raw._data), then write the merged events back
    raw._data[0] *= 0
    raw.add_events(eve, 'TRIGGER')
    raw.save(eeg_out, overwrite=True)

    print('\nResulting events')
    print_event_counts()
def run(cfg, state=None, queue=None):
    """
    Compute EEG microstate maps from a BrainVision recording and save them.

    Steps: read cfg.DATA_PATH with mne.io.read_raw_brainvision, keep a fixed
    19-channel subset, apply the 'standard_1005' montage, an average
    reference and a 1-30 Hz filter, segment the data into 4 microstates with
    microstates.segment(), and write the maps to cfg.OUT_MICROSTATES_FILE as
    space-delimited text. ``cfg.tdef`` is also set from cfg.TRIGGER_FILE.

    Parameters
    ----------
    cfg : configuration object (DATA_PATH, OUT_MICROSTATES_FILE, TRIGGER_FILE).
    state : multiprocessing.Value('i') run flag; if its value is 0 before the
        data is loaded, the process exits via sys.exit(-1). If None
        (default), a fresh mp.Value('i', 1) is created for this call.
    queue : forwarded to redirect_stdout_to_queue.
    """
    # A default of mp.Value('i', 1) in the signature would be created once at
    # definition time and shared by every call that omits `state`.
    if state is None:
        state = mp.Value('i', 1)

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # add tdef object
    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)

    # Abort if a stop was requested before processing starts
    if not state.value:
        sys.exit(-1)

    raw = mne.io.read_raw_brainvision(cfg.DATA_PATH, preload=True)
    outfile = cfg.OUT_MICROSTATES_FILE

    # Fixed channel subset used for the segmentation
    ch_names = [
        'P3', 'C3', 'F3', 'Fz', 'F4', 'C4', 'P4', 'Cz', 'Pz', 'Fp1', 'Fp2',
        'T3', 'T5', 'O1', 'O2', 'F7', 'F8', 'T6', 'T4'
    ]
    raw.pick_channels(ch_names)

    raw.set_montage('standard_1005')
    raw.set_eeg_reference('average')
    raw.filter(1, 30)
    # Only the maps are saved; the segmentation itself is not used here
    maps, _segmentation = microstates.segment(raw.get_data(),
                                              n_states=4,
                                              max_n_peaks=10000000,
                                              max_iter=5000,
                                              normalize=True)
    np.savetxt(outfile, maps, delimiter=" ")
# Example #4
0
def preprocess(loadedraw, events,
               APPLY_CAR,
               l_freq,
               h_freq,
               filter_method,
               tmin,
               tmax,
               tlow,
               thigh,
               n_jobs,
               picks_feat,
               baselineRange,
               verbose=False):
    """
    Re-reference, temporally filter and epoch a copy of a raw recording.

    The input raw object is copied, so the caller's data is not modified.

    Parameters
    ----------
    loadedraw : raw recording (object with ``copy``, ``_data``, ``info`` and
        ``filter``). Row 0 of ``_data`` is treated as the trigger channel and
        is excluded from CAR and LFILT filtering.
    events : events array passed to mne.Epochs.
    APPLY_CAR : if true, subtract the mean over data channels (rows 1:)
        from every data channel (common average reference).
    l_freq, h_freq : band edges passed to pu.butter_bandpass; h_freq is also
        the upper cut-off of the 'NC' filter.
    filter_method : 'NC' (raw.filter with method='fft') or 'LFILT'
        (per-channel lfilter with the butter_bandpass coefficients); any
        other value applies no temporal filter.
    tmin, tmax, tlow : not used.
    thigh : upper epoch bound in seconds (the lower bound is fixed at 0).
    n_jobs : forwarded to raw.filter.
    picks_feat : channel picks for filtering and epoching.
    baselineRange : baseline forwarded to mne.Epochs.
    verbose : forwarded to mne.Epochs.

    Returns
    -------
    tuple
        (tdef, sfreq, event_id, b, a, zi, t_lower, t_upper, epochs,
        total_wframes), where total_wframes is the number of samples per
        epoch (last axis of epochs.get_data()).
    """
    # Work on a copy so the caller's raw object is left untouched
    raw = loadedraw.copy()

    # Spatial filter - Common Average Reference (CAR), data channels only
    if APPLY_CAR:
        raw._data[1:] = raw._data[1:] - np.mean(raw._data[1:], axis=0)

    # Properties initialization
    tdef = trigger_def('triggerdef_errp.ini')
    sfreq = raw.info['sfreq']
    event_id = dict(correct=tdef.by_key['FEEDBACK_CORRECT'], wrong=tdef.by_key['FEEDBACK_WRONG'])

    # Band-pass coefficients and per-channel initial conditions; the channel
    # count excludes channel 0 (trigger).
    # NOTE(review): h_freq is passed before l_freq; confirm this matches the
    # argument order of pu.butter_bandpass.
    b, a, zi = pu.butter_bandpass(h_freq, l_freq, sfreq,
                                  raw._data.shape[0] - 1)

    # Compare strings with ==: `is` tests object identity and only matched
    # when both operands happened to be the same interned string object.
    # NOTE(review): cv_container is not a parameter of this function; it must
    # exist at module level, otherwise the 'NC' branch raises NameError.
    if filter_method == 'NC' and cv_container is None:
        # NOTE(review): the low cut-off is hard-coded to 2 Hz instead of l_freq; confirm intent.
        raw.filter(l_freq=2, filter_length='10s', h_freq=h_freq, n_jobs=n_jobs, picks=picks_feat, method='fft',
                   iir_params=None)
    if filter_method == 'LFILT':
        # Filter each data channel (channel 0 is the trigger), carrying the
        # filter state in column ch - 1 of zi
        for ch in range(1, raw._data.shape[0]):
            raw._data[ch, :], zi[:, ch - 1] = lfilter(b, a, raw._data[ch, :], -1, zi[:, ch - 1])

    # Epoching and baselining
    t_lower = 0
    t_upper = thigh
    epochs = mne.Epochs(raw, events=events, event_id=event_id, tmin=t_lower, tmax=t_upper, baseline=baselineRange,
                        picks=picks_feat, preload=True, proj=False, verbose=verbose)
    total_wframes = epochs.get_data().shape[2]
    print('Preprocess: Epoching done')
    return tdef, sfreq, event_id, b, a, zi, t_lower, t_upper, epochs, total_wframes
# Example #5
0
    def __init__(self, classifier=None, buffer_size=1.0, fake=False, amp_serial=None,
                 amp_name=None, fake_dirs=None, parallel=None, alpha_new=None):
        """
        Params
        ------
        classifier: file name of the classifier (loaded with qc.load_obj unless fake)
        buffer_size: buffer window size in seconds
        fake:
            False/None: Connect to an amplifier LSL server and decode
            True: Create a mock decoder (fake probabilities biased to 1.0)
        amp_serial: amplifier serial number, stored as self.amp_serial
        amp_name: amplifier name, stored as self.amp_name
        fake_dirs: list of event keys (looked up in 'triggerdef_16.ini') used
            as the class labels of the fake decoder; required when fake is set.
        parallel: dict(period, stride, num_strides)
            period: Decoding period length for a single decoder in seconds.
            stride: Time step between decoders in seconds.
            num_strides: Number of decoders to run in parallel.
        alpha_new: exponential smoothing factor, real value in [0, 1].
            p_new = p_new * alpha_new + p_old * (1 - alpha_new)
            Defaults to 1 (no smoothing) when None.

        Example: If the decoder runs 32ms per cycle, we can set
                 period=0.04, stride=0.01, num_strides=4
                 to achieve 100 Hz decoding.

        Raises
        ------
        ValueError: alpha_new is outside [0, 1].
        IOError: the classifier file could not be loaded.
        RuntimeError: a fake decoder is requested but fake_dirs is not a list.
        """

        self.classifier = classifier
        self.buffer_sec = buffer_size
        self.startmsg = 'Decoder daemon started.'
        self.stopmsg = 'Decoder daemon stopped.'
        self.fake = fake
        self.amp_serial = amp_serial
        self.amp_name = amp_name
        self.parallel = parallel
        if alpha_new is None:
            alpha_new = 1
        if not 0 <= alpha_new <= 1:
            raise ValueError('alpha_new must be a real number between 0 and 1.')
        self.alpha_new = alpha_new
        self.alpha_old = 1 - alpha_new

        if fake == False or fake is None:
            self.model = qc.load_obj(self.classifier)
            if self.model is None:
                # Report the file that failed to load; self.model is None here,
                # so formatting it would only print "None".
                raise IOError('Error loading %s' % self.classifier)
            self.labels = self.model['cls'].classes_
            self.label_names = [self.model['classes'][k] for k in self.labels]
        else:
            # create a fake decoder whose classes are the requested directions
            self.model = None
            tdef = trigger_def('triggerdef_16.ini')
            if not isinstance(fake_dirs, list):
                raise RuntimeError('Decoder(): wrong argument type of fake_dirs: %s.' % type(fake_dirs))
            self.labels = [tdef.by_key[t] for t in fake_dirs]
            self.label_names = [tdef.by_value[v] for v in self.labels]
            self.startmsg = '** WARNING: FAKE ' + self.startmsg
            self.stopmsg = 'FAKE ' + self.stopmsg

        self.psdlock = mp.Lock()
        self.reset()
        self.start()
# Example #6
0
 def on_new_tdef_file(self, key, trigger_file):
     """
     Load the trigger definition from *trigger_file* into ``self.tdef``.

     When events are already present, the layout is refreshed through
     ``on_update_VBoxLayout`` so it reflects the new definition. ``key`` is
     accepted but not used.
     """
     new_tdef = trigger_def(trigger_file)
     self.tdef = new_tdef

     if not self.events:
         return
     self.on_update_VBoxLayout()
# Example #7
0
class Basic:
    """
    Basic parameters for the training modality of the Motor Imagery protocol.
    """

    # ------------------------------------------------------------------
    # DATA
    # ------------------------------------------------------------------
    # Every data file in this directory is read for training.
    DATADIR = r'C:\LSL\pycnbi_local\z2\records\fif'

    # Trigger definition set, and the two events used for training.
    tdef = trigger_def('triggerdef_16.ini')
    TRIGGER_DEF = {tdef.LEFT_GO, tdef.RIGHT_GO}

    # Epoch range in seconds, relative to event onset.
    EPOCH = [0.5, 4.5]

    # ------------------------------------------------------------------
    # CHANNEL SPECIFICATION
    #
    # CHANNEL_PICKS: subset of channels used for the PSD. Python indexing is
    #   zero-based, but in fif files saved with the PyCNBI library index 0 is
    #   the trigger channel and data channels start at index 1 (to be
    #   consistent with MATLAB). None means all channels. Ignored if
    #   LOAD_PSD == True.
    #
    # REF_CH_NEW: re-reference to this set of channels (averaged if > 1).
    # REF_CH_OLD: recover this channel, which was used as the reference.
    # NOTE(review): the two names above are not defined in this class; the
    #   reference setting here is REF_CH.
    # ------------------------------------------------------------------
    # Alternatives: None, or CAP['ANTNEURO_64_NO_PERIPHERAL'].
    CHANNEL_PICKS = [
        'Fp1', 'Fp2', 'Fz',
        'F3', 'F4', 'F7', 'F8',
        'Cz', 'C3', 'C4',
        'P3', 'Pz', 'P4',
    ]
    EXCLUDES = ['M1', 'M2', 'EOG']
    # Alternative: ['CPz', ['M1', 'M2']]
    REF_CH = None

    # ------------------------------------------------------------------
    # FILTERS
    # ------------------------------------------------------------------
    # Spatial filter applied right after loading: None | 'car' | 'laplacian'.
    SP_FILTER = 'car'
    # Channels the spatial filter considers (None | list). This is the same
    # list object as CHANNEL_PICKS.
    SP_CHANNELS = CHANNEL_PICKS

    # Spectral filter applied right after SP_FILTER, either overlap-add FIR
    # or forward-backward IIR (filtfilt). Value: None or [lfreq, hfreq]:
    #   lfreq < hfreq -> band-pass
    #   lfreq > hfreq -> band-stop
    #   lfreq is None -> high-pass
    #   hfreq is None -> low-pass
    # Alternative: [0.6, 4.0]
    TP_FILTER = None

    # None or a list of notch frequencies.
    NOTCH_FILTER = None
# Example #8
0
 def on_new_tdef_file(self, key, trigger_file):
     """
     Rebuild the combo boxes from the events of a new trigger definition file.

     The horizontal layout is cleared first, then create_the_comboBoxes is
     called with every event name of the new definition and
     nb_directions=4. Event names equal to the string 'None' are passed as a
     real ``None`` (a real None entry is removed when selected in the GUI).
     ``key`` is accepted but not used.
     """
     self.clear_hBoxLayout()
     event_names = trigger_def(trigger_file).by_name
     nb_directions = 4

     choices = []
     for name in event_names:
         choices.append(None if name == 'None' else name)

     self.create_the_comboBoxes(self.chosen_value, choices, nb_directions)
# Example #9
0
def merge_events(trigger_file, events, rawfile_in, rawfile_out):
    """
    Relabel (merge) events of a raw recording and save the result.

    Parameters
    ----------
    trigger_file : trigger definition file passed to trigger_def().
    events : dict mapping a target event name to an iterable of source event
        names. Every event whose value matches a source name (looked up with
        tdef.by_name) is relabelled with the target name's value. A target
        with no source names is skipped.
    rawfile_in : input recording, loaded with pu.load_raw().
    rawfile_out : output path; an existing file is overwritten.

    Event counts are logged before and after merging. Values missing from the
    trigger definition are logged by number; before merging, a warning is
    also emitted for each of them.
    """
    tdef = trigger_def(trigger_file)
    raw, eve = pu.load_raw(rawfile_in)

    def log_event_counts():
        # Log the count of each event value; return values missing from tdef.
        missing = []
        values, counts = np.unique(eve[:, 2], return_counts=True)
        for value, count in zip(values, counts):
            if value in tdef.by_value:
                logger.info('%s: %d events' % (tdef.by_value[value], count))
            else:
                logger.info('%d: %d events' % (value, count))
                missing.append(value)
        return missing

    logger.info('=== Before merging ===')
    for key in log_event_counts():
        logger.warning('Key %d was not found in the definition file.' % key)

    for key in events:
        ev_out = tdef.by_name[key]
        # Collect all matching indices before relabelling anything
        matches = [np.where(eve[:, 2] == tdef.by_name[e])[0] for e in events[key]]
        # np.concatenate raises ValueError on an empty list, so skip targets
        # without any source names
        if matches:
            eve[np.concatenate(matches), 2] = ev_out

    # sanity check: no two consecutive events share a sample index
    dups = np.where(0 == np.diff(eve[:, 0]))[0]
    assert len(dups) == 0

    # reset trigger channel (row 0 of raw._data), then write the merged events back
    raw._data[0] *= 0
    raw.add_events(eve, 'TRIGGER')
    raw.save(rawfile_out, overwrite=True)

    logger.info('=== After merging ===')
    log_event_counts()
# Example #10
0
def run(cfg, queue=None, interactive=False, cv_file=None, feat_file=None):
    """
    Offline training pipeline: extract features, then optionally
    cross-validate and optionally export a trained decoder.

    Parameters
    ----------
    cfg : configuration object; ``cfg.tdef`` is built from cfg.TRIGGER_FILE.
        Cross-validation runs when cfg.CV_PERFORM[cfg.CV_PERFORM['selected']]
        is not None; the decoder is trained only when cfg.EXPORT_CLS is True.
    queue : forwarded to redirect_stdout_to_queue.
    interactive : not used by this function.
    cv_file : forwarded to cross_validate.
    feat_file : forwarded to train_decoder.
    """
    redirect_stdout_to_queue(queue)
    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)

    featdata = features.compute_features(cfg)

    # TPR balancing (balance_tpr) is currently disabled.

    cv_settings = cfg.CV_PERFORM
    selected_cv = cv_settings['selected']
    if cv_settings[selected_cv] is not None:
        cross_validate(cfg, featdata, cv_file=cv_file)

    export_decoder = cfg.EXPORT_CLS is True
    if export_decoder:
        train_decoder(cfg, featdata, feat_file=feat_file)
# Example #11
0
def config_run(cfg_module):
    """
    Run an online Motor Imagery feedback session described by a config file.

    The session runs in this order:
    - Load the config.
    - Pick the amplifier: search LSL unless AMP_NAME/AMP_SERIAL are set;
      no amplifier for a fake classifier.
    - Open the trigger device, falling back to a mock trigger after user
      confirmation.
    - Start a BCIDecoderDaemon.
    - Run the trials with bar or body feedback. Correct predictions can
      optionally be executed on the Rex device.
    - Write a ground-truth/prediction log with accuracy and confusion matrix
      to the directory parsed from cfg.CLS_MI.

    Parameters
    ----------
    cfg_module : path of the configuration file, loaded with load_cfg().

    Raises
    ------
    IOError
        cfg_module does not exist or is not a regular file.
    ValueError
        cfg.FEEDBACK_TYPE is neither 'BAR' nor 'BODY'.
    """
    if not (os.path.exists(cfg_module) and os.path.isfile(cfg_module)):
        raise IOError('%s cannot be loaded.' % os.path.realpath(cfg_module))
    cfg = load_cfg(cfg_module)

    # Validate the feedback type before any device or decoder is started. An
    # unknown value previously surfaced only later, as a NameError on
    # `visual`, with the decoder daemon already running.
    if cfg.FEEDBACK_TYPE not in ('BAR', 'BODY'):
        raise ValueError("cfg.FEEDBACK_TYPE must be 'BAR' or 'BODY', got %r" % (cfg.FEEDBACK_TYPE,))

    if cfg.FAKE_CLS is None:
        # choose amp
        if cfg.AMP_NAME is None and cfg.AMP_SERIAL is None:
            amp_name, amp_serial = pu.search_lsl(ignore_markers=True)
        else:
            amp_name = cfg.AMP_NAME
            amp_serial = cfg.AMP_SERIAL
        fake_dirs = None
    else:
        amp_name = None
        amp_serial = None
        fake_dirs = [v for (k, v) in cfg.DIRECTIONS]

    # events and triggers
    tdef = trigger_def(cfg.TRIGGER_DEF)
    if cfg.TRIGGER_DEVICE is None:
        input(
            '\n** Warning: No trigger device set. Press Ctrl+C to stop or Enter to continue.'
        )
    trigger = pyLptControl.Trigger(cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        qc.print_c(
            '\n** Error connecting to USB2LPT device. Use a mock trigger instead?',
            'R')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.MockTrigger()
        trigger.init(50)

    # init classification
    decoder = BCIDecoderDaemon(cfg.CLS_MI,
                               buffer_size=1.0,
                               fake=(cfg.FAKE_CLS is not None),
                               amp_name=amp_name,
                               amp_serial=amp_serial,
                               fake_dirs=fake_dirs,
                               parallel=cfg.PARALLEL_DECODING,
                               alpha_new=cfg.PROB_ALPHA_NEW)

    # Decoder labels are either trigger values (translated to event names via
    # tdef) or labels that already appear in cfg.DIRECTIONS.
    dirdata = {d[1] for d in cfg.DIRECTIONS}
    labels = [x if x in dirdata else tdef.by_value[x] for x in decoder.get_labels()]

    # map class labels to bar directions
    bar_def = {label: str(direction) for direction, label in cfg.DIRECTIONS}
    bar_dirs = [bar_def[l] for l in labels]
    if cfg.TRIALS_RANDOMIZE:
        dir_seq = bar_dirs * cfg.TRIALS_EACH
        random.shuffle(dir_seq)
    else:
        dir_seq = [d[0] for d in cfg.DIRECTIONS] * cfg.TRIALS_EACH
    num_trials = len(dir_seq)

    qc.print_c('Initializing decoder.', 'W')
    # Wait until the daemon reports it is running. `== 0` rather than `is 0`:
    # identity against an int literal is unreliable, and is never true for a
    # returned False.
    while decoder.is_running() == 0:
        time.sleep(0.01)

    # visual feedback object (FEEDBACK_TYPE was validated above)
    if cfg.FEEDBACK_TYPE == 'BAR':
        from pycnbi.protocols.viz_bars import BarVisual
        visual = BarVisual(cfg.GLASS_USE,
                           screen_pos=cfg.SCREEN_POS,
                           screen_size=cfg.SCREEN_SIZE)
    else:
        assert hasattr(cfg,
                       'IMAGE_PATH'), 'IMAGE_PATH is undefined in your config.'
        from pycnbi.protocols.viz_human import BodyVisual
        visual = BodyVisual(cfg.IMAGE_PATH,
                            use_glass=cfg.GLASS_USE,
                            screen_pos=cfg.SCREEN_POS,
                            screen_size=cfg.SCREEN_SIZE)
    visual.put_text('Waiting to start')
    if cfg.LOG_PROBS:
        logdir = qc.parse_path_list(cfg.CLS_MI)[0]
        probs_logfile = time.strftime(logdir + "probs-%Y%m%d-%H%M%S.txt",
                                      time.localtime())
    else:
        probs_logfile = None
    feedback = Feedback(cfg, visual, tdef, trigger, probs_logfile)

    # Bar direction -> Rex direction for executing correct predictions
    rex_dirs = {'U': 'N', 'L': 'W', 'R': 'E', 'D': 'S'}

    # start
    trial = 1
    dir_truth = []     # ground truth of every classified attempt
    dir_detected = []  # prediction of every classified attempt
    prob_history = {c: [] for c in bar_dirs}
    while trial <= num_trials:
        if cfg.SHOW_TRIALS:
            title_text = 'Trial %d / %d' % (trial, num_trials)
        else:
            title_text = 'Ready'
        true_label = dir_seq[trial - 1]

        result = feedback.classify(decoder,
                                   true_label,
                                   title_text,
                                   bar_dirs,
                                   prob_history=prob_history)
        if result is None:
            break
        pred_label = result
        # Record the true label with every prediction: with TRIALS_RETRY one
        # trial can yield several predictions, and an early break leaves
        # dir_seq longer than dir_detected, so zipping against dir_seq
        # would misalign ground truth and predictions.
        dir_truth.append(true_label)
        dir_detected.append(pred_label)

        if cfg.WITH_REX is True and pred_label == true_label:
            rex_dir = rex_dirs.get(pred_label)
            if rex_dir is None:
                qc.print_c(
                    'Warning: Rex cannot execute undefined action %s' %
                    pred_label, 'W')
            else:
                visual.move(pred_label, 100, overlay=False, barcolor='B')
                visual.update()
                qc.print_c('Executing Rex action %s' % rex_dir, 'W')
                # NOTE(review): config-supplied values are passed through the
                # shell here; subprocess.run with an argument list would avoid that.
                os.system('%s/Rex/RexControlSimple.exe %s %s' %
                          (pycnbi.ROOT, cfg.REX_COMPORT, rex_dir))
                time.sleep(8)

        msg = 'Correct' if true_label == pred_label else 'Wrong'
        # Without retries every attempt advances; with retries only a correct one does
        if cfg.TRIALS_RETRY is False or true_label == pred_label:
            print('Trial %d: %s (%s -> %s)' %
                  (trial, msg, true_label, pred_label))
            trial += 1

    if dir_detected:
        # write performance and log results
        fdir, _, _ = qc.parse_path_list(cfg.CLS_MI)
        logfile = time.strftime(fdir + "/online-%Y%m%d-%H%M%S.txt",
                                time.localtime())
        with open(logfile, 'w') as fout:
            fout.write('Ground-truth,Prediction\n')
            for gt, dt in zip(dir_truth, dir_detected):
                fout.write('%s,%s\n' % (gt, dt))
            cfmat, acc = qc.confusion_matrix(dir_truth, dir_detected)
            fout.write('\nAccuracy %.3f\nConfusion matrix\n' % acc)
            fout.write(cfmat)
            print('Log exported to %s' % logfile)
        print('\nAccuracy %.3f\nConfusion matrix\n' % acc)
        print(cfmat)

    visual.finish()
    if decoder:
        decoder.stop()
    '''
    # automatic thresholding
    if prob_history and len(bar_dirs) == 2:
        total = sum(len(prob_history[c]) for c in prob_history)
        fout = open(probs_logfile, 'a')
        msg = 'Automatic threshold optimization.\n'
        max_acc = 0
        max_bias = 0
        for bias in np.arange(-0.99, 1.00, 0.01):
            corrects = 0
            for p in prob_history[bar_dirs[0]]:
                p_biased = (p + bias) / (bias + 1) # new sum = (p+bias) + (1-p) = bias+1
                if p_biased >= 0.5:
                    corrects += 1
            for p in prob_history[bar_dirs[1]]:
                p_biased = (p + bias) / (bias + 1) # new sum = (p+bias) + (1-p) = bias+1
                if p_biased < 0.5:
                    corrects += 1
            acc = corrects / total
            msg += '%s%.2f: %.3f\n' % (bar_dirs[0], bias, acc)
            if acc > max_acc:
                max_acc = acc
                max_bias = bias
        msg += 'Max acc = %.3f at bias %.2f\n' % (max_acc, max_bias)
        fout.write(msg)
        fout.close()
        print(msg)
    '''

    print('Finished.')
# Example #12
0
    # Fragment of a larger function whose header is not shown: load the
    # config, pick the amplifier, set up the trigger and create the decoder.
    # NOTE(review): the `imp` module is deprecated and was removed in
    # Python 3.12; importlib would be needed on newer interpreters.
    cfg = check_cfg(imp.load_source(cfg_module, cfg_module))
    if cfg.FAKE_CLS is None:
        # choose amp: search LSL unless a name or serial is configured
        if cfg.AMP_NAME is None and cfg.AMP_SERIAL is None:
            amp_name, amp_serial = pu.search_lsl(ignore_markers=True)
        else:
            amp_name = cfg.AMP_NAME
            amp_serial = cfg.AMP_SERIAL
        fake_dirs = None
    else:
        # Fake classifier: no amplifier; fake_dirs takes the second element
        # of each cfg.DIRECTIONS pair
        amp_name = None
        amp_serial = None
        fake_dirs = [v for (k, v) in cfg.DIRECTIONS]

    # events and triggers
    tdef = trigger_def(cfg.TRIGGER_DEF)

    if cfg.TRIGGER_DEVICE is None:
        input('\n** Warning: No trigger device set. Press Ctrl+C to stop or Enter to continue.')
    trigger = pyLptControl.Trigger(cfg.TRIGGER_DEVICE)
    # If the device cannot be initialised, fall back to a mock trigger after
    # the user confirms
    if trigger.init(50) == False:
        qc.print_c('\n** Error connecting to USB2LPT device. Use a mock trigger instead?', 'R')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.MockTrigger()
        trigger.init(50)

    # init classification
    qc.print_c('Initializing decoder.', 'W')

    # The fragment ends here; how decoder_UD is used is not shown.
    decoder_UD = BCIDecoder(cfg.CLS_MI, buffer_size=10.0, fake=(cfg.FAKE_CLS is not None),
                            amp_name=amp_name, amp_serial=amp_serial, fake_dirs=fake_dirs)
# Example #13
0
    def getperf(APPLY_CAR, APPLY_PCA, APPLY_OVERSAMPLING, DO_CV, l_freq,
                h_freq, picks_feat, offset, tmin, tmax, baselineRange,
                reg_coeff, verbose):
        raw = loadedraw.copy()
        if APPLY_CAR:
            raw._data[1:] = raw._data[1:] - np.mean(raw._data[1:], axis=0)
            processing_steps.append('Car')

        tdef = trigger_def('triggerdef_errp.ini')
        sfreq = raw.info['sfreq']
        event_id = dict(correct=tdef.by_name['FEEDBACK_CORRECT'],
                        wrong=tdef.by_name['FEEDBACK_WRONG'])
        # %% simu online
        SIMULATE_ONLINE = False

        #    if SIMULATE_ONLINE is True:
        #        signal = mne.Epochs(raw, events=events, event_id=event_id, tmin=tmin, tmax=tmax,baseline=baselineRange, picks=picks_feat, preload=True, proj=False)
        #        signal_data = np.reshape(signal[0]._data,(signal[0]._data.shape[1],signal[0]._data.shape[2]))
        #        foo = compute_features(signal_data,sfreq,l_freq,h_freq,decim_factor)
        #        print('finalfeat')
        #        print(foo)

        # Spatial filter - Common Average Reference (CAR) # todo: reactivate

        # %% Dataset wide processing
        # Bandpass temporal filtering
        if SIMULATE_ONLINE is False:
            raw.filter(
                l_freq=l_freq,
                h_freq=h_freq,
                n_jobs=n_jobs,
                picks=picks_feat,
                method='iir',
                iir_params=None
            )  # method='iir'and irr_params=None -> filter with a 4th order Butterworth

        # %% Epoching and baselining
        # epochs = mne.Epochs(raw, events=events, event_id=event_id, tmin=tmin, tmax=tmax, baseline=baselineRange, picks=picks_feat, preload=True)
        epochs = mne.Epochs(raw,
                            events=events,
                            event_id=event_id,
                            tmin=tmin,
                            tmax=tmax,
                            baseline=baselineRange,
                            picks=picks_feat,
                            preload=True,
                            proj=False,
                            verbose=False)

        if (baselineRange):
            processing_steps.append('Baselining')
        if tmin != tmin_bkp:
            # if the baseline range was before the initial tmin, epochs was tricked to
            # to select this range (it expects that baseline is witin [tmin,tmax])
            # this part restore the initial tmin and prune the data
            epochs.tmin = tmin_bkp
            epochs._data = epochs._data[:, :, int((tmin_bkp - tmin) * sfreq):]

        # %% Fold creation
        # epochs.events contains the label that we want on the third column
        # We can then get the relevent data within a fold by doing epochs._data[test]
        # It will return an array with size ({test}L, [channel]L,{time}L)
        label = epochs.events[:, 2]

        cv = StratifiedShuffleSplit(label,
                                    n_iter=20,
                                    test_size=0.2,
                                    random_state=1337)

        if useLeaveOneOut is True:
            cv = LeaveOneOut(len(label))

        if APPLY_PCA:
            processing_steps.append('PCA')
        if APPLY_OVERSAMPLING:
            processing_steps.append('Oversampling')
        processing_steps.append('Normalization')
        processing_steps.append('Downsampling')

        # %% Fold processing
        def apply_cv(epochs):
            count = 1
            confusion_matrixes = []
            confusion_matrixes_percent = []
            tn_rates = []
            tp_rates = []
            predicted = ''
            test_label = ''
            firstIterCV = True
            probabilities = np.array([[]], ndmin=2)
            predictions = np.array([])
            my_tpr_cont = []
            my_fpr_cont = []
            my_aucs = []
            best_threshold = []
            cv_probabilities = []
            cv_probabilities_label = []
            for train, test in cv:
                ## Train Data processing ##
                train_data = epochs._data[train]
                train_label = label[train]

                # Online simulation flag
                if SIMULATE_ONLINE is True:  # epochs should have one epoch only
                    train_bp = mne.filter.band_pass_filter(
                        train_data, sfreq, l_freq, h_freq,
                        method='iir')  # bandpass on one epoch
                else:
                    train_bp = train_data

                # Normalization
                (train_normalized, trainShiftFactor,
                 trainScaleFactor) = normalizeAcrossEpoch(train_bp, 'MinMax')

                # Downsampling
                train_downsampling = train_normalized[:, :, ::decim_factor]

                # Merge (reshape) channel and time for the PCA
                train_reshaped = train_downsampling.reshape(
                    train_downsampling.shape[0], -1)

                # PCA initialisation
                if APPLY_PCA is False:
                    pca = None
                    train_pcaed = train_reshaped
                else:
                    pca = PCA(0.95)
                    pca.fit(train_reshaped)
                    pca.components_ = -pca.components_  # inversion of vector to be constistant with Inaki's code
                    train_pcaed = pca.transform(train_reshaped)

                # PCA
                #			train_pcaed = train_reshaped

                ## Test data processing ##
                test_data = epochs._data[test]
                test_label = label[test]

                # Compute_feature does the same steps as for train, but requires a computed PCA (that we got from train)
                # (bandpass, norm, ds, and merge channel and time)
                test_pcaed = compute_features(test_data, sfreq, l_freq, h_freq,
                                              decim_factor, trainShiftFactor,
                                              trainScaleFactor, pca)
                #			test_pcaed = compute_features(test_data,sfreq,l_freq,h_freq,decim_factor,trainShiftFactor,trainScaleFactor,pca=None)

                ## Test ##
                train_x = train_pcaed
                test_x = test_pcaed

                # oversampling the least present sample
                # if APPLY_OVERSAMPLING:
                #	idx_offset = balance_idx(train_label)
                #	oversampled_train_label = np.append(train_label,train_label[idx_offset])
                #	oversampled_train_x = np.concatenate((train_x,train_x[idx_offset]),0)
                #	train_label = oversampled_train_label
                #	train_x = oversampled_train_x

                # Classifier init
                RF = dict(trees=100, maxdepth=None)
                cls = RandomForestClassifier(n_estimators=RF['trees'],
                                             max_features='auto',
                                             max_depth=RF['maxdepth'],
                                             n_jobs=n_jobs)
                # cls = RandomForestClassifier(n_estimators=RF['trees'], max_features='auto', max_depth=RF['maxdepth'], class_weight="balanced", n_jobs=n_jobs)
                # cls = LDA(solver='eigen')
                #			cls = QDA(reg_param=0.3) # regularized LDA

                #			cls.fit( train_x, train_label )
                # Y_pred= cls.predict( test_x )
                # prediction = Y_pred

                # Fitting
                #				cls= rLDA(regcoeff)
                cls.fit(train_x, train_label)

                predicted = cls.predict(test_x)
                probs = cls.predict_proba(test_x)
                prediction = np.array(predicted)

                if useLeaveOneOut is True:
                    if firstIterCV is True:
                        probabilities = np.append(probabilities, probs, axis=1)
                        firstIterCV = False
                        predictions = np.append(predictions, prediction)
                    else:
                        probabilities = np.append(probabilities, probs, axis=0)
                        predictions = np.append(predictions, prediction)
                else:
                    predictions = np.append(predictions, prediction)
                    probabilities = np.append(probabilities, probs)

                # Performance
                if useLeaveOneOut is not True:
                    cm = np.array(confusion_matrix(test_label, prediction))
                    cm_normalized = cm.astype('float') / cm.sum(
                        axis=1)[:, np.newaxis]
                    confusion_matrixes.append(cm)
                    confusion_matrixes_percent.append(cm_normalized)
                    avg_confusion_matrixes = np.mean(
                        confusion_matrixes_percent, axis=0)

                # print('CV #'+str(count))
                #				print('Prediction: '+str(prediction))
                #				print('    Actual: '+str(test_label))

                # Append probs to the global list
                probs_np = np.array(probs)
                cv_probabilities.append(probs_np[:, 0])
                cv_probabilities_label.append(test_label)

                #			if useLeaveOneOut is not True:
                #				print('Confusion matrix')
                #				print(cm)
                #				print('Confusion matrix (normalized)')
                #				print(cm_normalized)
                #				print('---')
                #				print('True positive rate: '+str(cm_normalized[0][0]))
                #				print('True negative rate: '+str(cm_normalized[1][1]))
                #				print('===================')

                ## Manual ROC curve computation
                #			if useLeaveOneOut is not True:
                #				probs_np = np.array(probs)
                #				myfpr = []
                #				mytpr = []
                #				mythresh = []
                #				for thresh in np.linspace(0,1,100):
                #					newpred = [4 if x[0] < thresh else 3 for x in probs_np] #list comprehension to quickly go through the list. x[0] because hp_probs is shape (2,20)
                #					newcm = confusion_matrix(test_label,newpred)
                #					newcm_norm = newcm.astype('float') / newcm.sum(axis=1)[:, np.newaxis]
                #					mytpr.append(newcm_norm[0][0])
                #					myfpr.append(newcm_norm[1][0])
                #					mythresh.append(thresh)
                #
                #				my_tpr_cont.append(mytpr)
                #				my_fpr_cont.append(myfpr)
                #
                #				myroc_auc = auc(myfpr, mytpr)
                #				my_aucs.append(myroc_auc)

                ## One CV done, go to the next one
                count += 1

            # if useLeaveOneOut is not True:
            #			my_fpr_cont_np = np.array(my_fpr_cont)
            #			my_tpr_cont_np = np.array(my_tpr_cont)
            #
            #			my_fpr_cont_avg = np.mean(my_fpr_cont_np,axis=0)
            #			my_tpr_cont_avg = np.mean(my_tpr_cont_np,axis=0)
            #
            #
            #			plt.plot(my_fpr_cont_avg,my_tpr_cont_avg)
            #			plt.xlabel('false positive rate')
            #			plt.ylabel('true positive rate')
            #
            #
            #			auc_from_avg = auc(my_fpr_cont_avg,my_tpr_cont_avg)
            #			auc_from_my_aucs = np.mean(my_aucs)
            #
            #			# Make a subset of data where FPR < 0.2
            #			idx_below_fpr_0_2 = np.where(my_fpr_cont_avg < MAX_FPR)
            #			fpr_below_fpr_0_2 = my_fpr_cont_avg[idx_below_fpr_0_2]
            #			tpr_below_fpr_0_2 = my_tpr_cont_avg[idx_below_fpr_0_2]
            #
            #			# Look for the best (max value) FPR in that subset
            #			best_tpr_below_fpr_0_2 = np.max(tpr_below_fpr_0_2)
            #			# ... get its idx
            #			best_tpr_below_fpr_0_2_idx = np.array(np.where(my_tpr_cont_avg == best_tpr_below_fpr_0_2)).ravel()
            #
            #			# Get the associated TPRs
            #			best_tpr_below_fpr_0_2_associated_fpr = np.array(my_fpr_cont_avg)[best_tpr_below_fpr_0_2_idx]
            #			# Get the best (min value) in that subset
            #			best_associated_fpr = np.min(best_tpr_below_fpr_0_2_associated_fpr)
            #			# ... get its idx
            #			best_associated_fpr_idx = np.array(np.where(my_fpr_cont_avg == best_associated_fpr)).ravel()
            #
            #			# The best idx is the one that is on both set
            #			best_idx = best_tpr_below_fpr_0_2_idx[np.in1d(best_tpr_below_fpr_0_2_idx,best_associated_fpr_idx)]
            #			plt.xlabel('False positive rate')
            #			plt.ylabel('True positive rate')
            #			threshold_list = np.linspace(0,1,100)
            #			best_threshold = threshold_list[best_idx]
            #			print('Best treshold(s):'+str(best_threshold))
            #			print('Gives a TPR of '+str(best_tpr_below_fpr_0_2))
            #			print('And a FPR of '+str(best_associated_fpr))
            #
            #
            ##		from mpl_toolkits.mplot3d import Axes3D
            ##		fig = plt.figure()
            ##		ax = fig.add_subplot(111,projection='3d')
            ##		ax.plot(my_fpr_cont_avg,my_fpr_cont_avg,mythresh)
            ##		mean_tpr /= len(cv)
            ##		mean_tpr[-1] = 1.0
            ##		mean_auc = auc(mean_fpr, mean_tpr)
            #		#plt.plot(mean_fpr, mean_tpr, 'k--',label='Mean ROC (area = %0.2f)' % mean_auc, lw=2)
            #
            ##		if useLeaveOneOut is True:
            ##			finalCM = confusion_matrix(predictions,label)
            ##			#Used w/ one out
            ##			cms = []
            ##			cms_norm = []
            ##			#threshold search
            ##			probabilities_right = probabilities[0,:]
            ##			probabilities_wrong = probabilities[1,:]
            ##			for thresh in np.arange(0,1,0.05):
            ##				pred_tmp = np.array([])
            ##				for prob in probabilities:
            ##					if prob[0] < thresh:
            ##						pred_tmp = np.append(pred_tmp,4)
            ##					else:
            ##						pred_tmp = np.append(pred_tmp,3)
            ##				cm_tmp = confusion_matrix(pred_tmp,label)
            ##				cms.append(cm_tmp)
            ##				cm_tmp_norm = cm_tmp.astype('float') / cm_tmp.sum(axis=1)[:, np.newaxis]
            ##				cms_norm.append(cm_tmp_norm)
            #		if useLeaveOneOut is True:
            #			avg_confusion_matrixes = 0
            #			auc_from_avg = 0
            #			auc_from_my_aucs = 0
            #			label_np = np.array(label)
            #			lvofpr = []
            #			lvotpr = []
            #			lvothresh = []
            #			lvocms = []
            #			threshold_list = np.linspace(0,1,100)
            #			for thresh in threshold_list:
            #				lvopred = [4 if x[0] < thresh else 3 for x in probabilities] #list comprehension to quickly go through the list. x[0] because hp_probs is shape (2,20)
            #				lvocm = confusion_matrix(label_np,lvopred)
            #				lvocm_norm = lvocm.astype('float') / lvocm.sum(axis=1)[:, np.newaxis]
            #				lvocms.append(lvocm_norm)
            #				lvotpr.append(lvocm_norm[0][0])
            #				lvofpr.append(lvocm_norm[1][0])
            #				lvothresh.append(thresh)
            #
            #			lvo_auc = auc(lvofpr,lvotpr)
            #
            #			# Make a subset of data where FPR < 0.2
            #
            #			idx_below_fpr_0_2 = np.where(np.array(lvofpr) < MAX_FPR)
            #			fpr_below_fpr_0_2 = np.array(lvofpr)[idx_below_fpr_0_2[0]]
            #			tpr_below_fpr_0_2 = np.array(lvotpr)[idx_below_fpr_0_2[0]]
            #
            #			# Look for the best (max value) FPR in that subset
            #			best_tpr_below_fpr_0_2 = np.max(tpr_below_fpr_0_2)
            #			# ... get its idx
            #			best_tpr_below_fpr_0_2_idx = np.array(np.where(lvotpr == best_tpr_below_fpr_0_2)).ravel()
            #
            #			# Get the associated TPRs
            #			best_tpr_below_fpr_0_2_associated_fpr = np.array(lvofpr)[best_tpr_below_fpr_0_2_idx]
            #			# Get the best (min value) in that subset
            #			best_associated_fpr = np.min(best_tpr_below_fpr_0_2_associated_fpr)
            #			# ... get its idx
            #			best_associated_fpr_idx = np.array(np.where(lvofpr == best_associated_fpr)).ravel()
            #
            #			# The best idx is the one that is on both set
            #			best_idx = best_tpr_below_fpr_0_2_idx[np.in1d(best_tpr_below_fpr_0_2_idx,best_associated_fpr_idx)]
            #
            #			plt.plot(lvofpr,lvotpr)
            #			plt.xlabel('False positive rate')
            #			plt.ylabel('True positive rate')
            #			best_threshold = threshold_list[best_idx]
            #			print('Best treshold:'+str(best_threshold))
            #			print('Gives a TPR of '+str(best_tpr_below_fpr_0_2))
            #			print('And a FPR of '+str(best_associated_fpr))
            #			print('CM')
            #			print(lvocms[best_idx[0]])
            auc_from_avg = None
            auc_from_my_aucs = None
            best_threshold = None
            cv_probabilities_np = np.array(cv_probabilities)
            cv_prob_linear = np.ravel(cv_probabilities)
            cv_prob_label_np = np.array(cv_probabilities_label)
            cv_prob_label_linear = np.ravel(cv_prob_label_np)
            threshold_list = np.linspace(0, 1, 100)

            biglist_fpr = []
            biglist_tpr = []
            biglist_thresh = []
            biglist_cms = []

            for thresh in threshold_list:
                biglist_pred = [
                    4 if x < thresh else 3 for x in cv_prob_linear
                ]  # list comprehension to quickly go through the list.
                biglist_cm = confusion_matrix(cv_prob_label_linear,
                                              biglist_pred)
                biglist_cm_norm = biglist_cm.astype('float') / biglist_cm.sum(
                    axis=1)[:, np.newaxis]
                biglist_cms.append(biglist_cm_norm)
                biglist_tpr.append(biglist_cm_norm[0][0])
                biglist_fpr.append(biglist_cm_norm[1][0])
                biglist_thresh.append(thresh)
            biglist_auc = auc(biglist_fpr, biglist_tpr)

            # Make a subset of data where FPR < MAX_FPR
            idx_below_maxfpr = np.where(np.array(biglist_fpr) < MAX_FPR)
            fpr_below_maxfpr = np.array(biglist_fpr)[idx_below_maxfpr[0]]
            tpr_below_maxfpr = np.array(biglist_tpr)[idx_below_maxfpr[0]]

            # Look for the best (max value) FPR in that subset
            best_tpr_below_maxfpr = np.max(tpr_below_maxfpr)
            best_tpr_below_maxfpr_idx = np.array(
                np.where(biglist_tpr ==
                         best_tpr_below_maxfpr)).ravel()  # get its idx

            # Get the associated TPRs
            best_tpr_below_maxfpr_associated_fpr = np.array(
                biglist_fpr)[best_tpr_below_maxfpr_idx]
            # Get the best (min value) in that subset
            best_associated_fpr = np.min(best_tpr_below_maxfpr_associated_fpr)
            # ... get its idx
            best_associated_fpr_idx = np.array(
                np.where(biglist_fpr == best_associated_fpr)).ravel()

            # The best idx is the one that is on both set
            best_idx = best_tpr_below_maxfpr_idx[np.in1d(
                best_tpr_below_maxfpr_idx, best_associated_fpr_idx)]
            best_threshold = threshold_list[best_idx]

            if False:
                plt.plot(biglist_fpr, biglist_tpr)
                plt.xlabel('False positive rate')
                plt.ylabel('True positive rate')
                print('#################################')
                print('Best treshold:' + str(best_threshold))
                print('Gives a TPR of ' + str(best_tpr_below_maxfpr))
                print('And a FPR of ' + str(best_associated_fpr))
                print('CM')
                print(biglist_cms[best_idx[0]])
            return (biglist_auc, biglist_cms, best_threshold,
                    best_tpr_below_maxfpr)

        biglist_auc, biglist_cms, best_threshold = apply_cv(epochs)

        return (biglist_auc, biglist_cms, best_threshold,
                best_tpr_below_maxfpr)
Beispiel #14
0
def config_run(cfg_module):
    """
    Run the cue-based bar training protocol configured by a Python file.

    The config file is loaded with imp.load_source() and must provide
    TRIGGER_DEF, REFRESH_RATE, TRIALS_EACH, DIRECTIONS, TRIGGER_DEVICE,
    GLASS_USE, SCREEN_POS, SCREEN_SIZE and the phase durations T_INIT,
    T_GAP, T_CUE, T_DIR_READY and T_DIR (compared against qc.Timer.sec()).

    After an initial 'start' state, every trial walks the states
        gap_s -> gap -> cue -> dir_r -> dir -> gap_s (next trial)
    sending a trigger at each transition. Pressing ESC in the visualizer
    window stops the protocol early.

    Params
    ------
    cfg_module: path to the config file.

    Raises
    ------
    RuntimeError: if cfg.DIRECTIONS contains a code other than
        'L', 'R', 'U', 'D' or 'B'.
    """
    cfg = imp.load_source(cfg_module, cfg_module)
    tdef = trigger_def(cfg.TRIGGER_DEF)
    refresh_delay = 1.0 / cfg.REFRESH_RATE

    # visualizer key codes (compared against 0xFF & cv2.waitKey())
    keys = {'left':81, 'right':83, 'up':82, 'down':84, 'pgup':85, 'pgdn':86,
        'home':80, 'end':87, 'space':32, 'esc':27, ',':44, '.':46, 's':115, 'c':99,
        '[':91, ']':93, '1':49, '!':33, '2':50, '@':64, '3':51, '#':35}

    # direction code -> (bar sides to animate, trigger-name prefix in tdef).
    # Trigger values are fetched lazily with getattr() so that, as with the
    # former if/elif chains, only the events of the directions actually used
    # have to exist in the trigger definition.
    direction_spec = {
        'L': (('L',), 'LEFT'),
        'R': (('R',), 'RIGHT'),
        'U': (('U',), 'UP'),
        'D': (('D',), 'DOWN'),
        'B': (('L', 'R'), 'BOTH'),  # both hands
    }

    dir_sequence = []
    for _ in range(cfg.TRIALS_EACH):
        dir_sequence.extend(cfg.DIRECTIONS)
    random.shuffle(dir_sequence)
    num_trials = len(cfg.DIRECTIONS) * cfg.TRIALS_EACH

    event = 'start'
    trial = 1
    # filled in when the cue of the current trial ends
    sides, trig_prefix = (), ''

    # Hardware trigger
    if cfg.TRIGGER_DEVICE is None:
        input('\n** Warning: No trigger device set. Press Ctrl+C to stop or Enter to continue.')
    trigger = pyLptControl.Trigger(cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        print('\n** Error connecting to USB2LPT device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.MockTrigger()
        trigger.init(50)

    timer_trigger = qc.Timer()
    timer_dir = qc.Timer()
    timer_refresh = qc.Timer()

    bar = BarVisual(cfg.GLASS_USE, screen_pos=cfg.SCREEN_POS, screen_size=cfg.SCREEN_SIZE)
    bar.fill()
    bar.glass_draw_cue()

    # start
    while trial <= num_trials:
        timer_refresh.sleep_atleast(refresh_delay)
        timer_refresh.reset()

        # segment= { 'cue':(s,e), 'dir':(s,e), 'label':0-4 } (zero-based)
        if event == 'start' and timer_trigger.sec() > cfg.T_INIT:
            event = 'gap_s'
            bar.fill()
            timer_trigger.reset()
            trigger.signal(tdef.INIT)
        elif event == 'gap_s':
            bar.put_text('Trial %d / %d' % (trial, num_trials))
            event = 'gap'
        elif event == 'gap' and timer_trigger.sec() > cfg.T_GAP:
            event = 'cue'
            bar.fill()
            bar.draw_cue()
            trigger.signal(tdef.CUE)
            timer_trigger.reset()
        elif event == 'cue' and timer_trigger.sec() > cfg.T_CUE:
            event = 'dir_r'
            direction = dir_sequence[trial - 1]
            if direction not in direction_spec:
                # '%s' rather than '%d': direction codes are strings, and
                # '%d' would raise TypeError instead of this error.
                raise RuntimeError('Unknown direction %s' % (direction,))
            sides, trig_prefix = direction_spec[direction]
            # show the full-length bar(s) as the "ready" cue
            for side in sides:
                bar.move(side, 100, overlay=True)
            trigger.signal(getattr(tdef, trig_prefix + '_READY'))
            timer_trigger.reset()
        elif event == 'dir_r' and timer_trigger.sec() > cfg.T_DIR_READY:
            bar.fill()
            bar.draw_cue()
            event = 'dir'
            timer_trigger.reset()
            timer_dir.reset()
            trigger.signal(getattr(tdef, trig_prefix + '_GO'))
        elif event == 'dir' and timer_trigger.sec() > cfg.T_DIR:
            event = 'gap_s'
            bar.fill()
            trial += 1
            print('trial ' + str(trial - 1) + ' done')
            trigger.signal(tdef.BLANK)
            timer_trigger.reset()

        # protocol: grow the bar(s) linearly over T_DIR, capped at 100
        if event == 'dir':
            dx = min(100, int(100.0 * timer_dir.sec() / cfg.T_DIR) + 1)
            for side in sides:
                bar.move(side, dx, overlay=True)

        # wait for start
        if event == 'start':
            bar.put_text('Waiting to start')

        bar.update()
        key = 0xFF & cv2.waitKey(1)

        if key == keys['esc']:
            break
Beispiel #15
0
def run(cfg, state=None, queue=None):
    """
    Run the cue-based bar training protocol, controlled by a shared state flag.

    Params
    ------
    cfg: config object providing TRIGGER_FILE, REFRESH_RATE, TRIALS_EACH,
        DIRECTIONS, TRIGGER_DEVICE, TRIAL_PAUSE, GLASS_USE, SCREEN_POS,
        SCREEN_SIZE and TIMINGS with keys INIT, GAP, CUE, READY,
        READY_RANDOMIZE, DIR and DIR_RANDOMIZE. cfg.tdef is set here.
    state: multiprocessing Value('i') shared with the GUI
        (0: stop, 1: start, 2: wait). When omitted, a fresh Value(1) is
        created for this call. It is set to 0 when the protocol ends,
        including when it ends with an exception.
    queue: passed to redirect_stdout_to_queue() together with the logger.

    Raises
    ------
    RuntimeError: if cfg.DIRECTIONS contains a code other than
        'L', 'R', 'U', 'D' or 'B'.
    """
    import time

    if state is None:
        # A fresh Value per call: a Value created once as the default would
        # stay at 0 after the first run and make later default calls exit.
        state = mp.Value('i', 1)

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait for the recording to start (GUI); sleep instead of spinning.
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        time.sleep(0.01)
    #  Protocol start if equals to 1
    if not state.value:
        sys.exit()

    refresh_delay = 1.0 / cfg.REFRESH_RATE

    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)

    # visualizer key codes (compared against 0xFF & cv2.waitKey())
    keys = {
        'left': 81,
        'right': 83,
        'up': 82,
        'down': 84,
        'pgup': 85,
        'pgdn': 86,
        'home': 80,
        'end': 87,
        'space': 32,
        'esc': 27,
        ',': 44,
        '.': 46,
        's': 115,
        'c': 99,
        '[': 91,
        ']': 93,
        '1': 49,
        '!': 33,
        '2': 50,
        '@': 64,
        '3': 51,
        '#': 35
    }

    # direction code -> (bar sides to animate, trigger-name prefix in tdef).
    # Trigger values are fetched lazily with getattr() so only the events of
    # the directions actually used have to exist in the trigger file.
    direction_spec = {
        'L': (('L',), 'LEFT'),
        'R': (('R',), 'RIGHT'),
        'U': (('U',), 'UP'),
        'D': (('D',), 'DOWN'),
        'B': (('L', 'R'), 'BOTH'),  # both hands
    }

    def draw_timings():
        """Return randomized (go-phase, ready-phase) durations for one trial."""
        timings = cfg.TIMINGS
        t_go = timings['DIR'] + random.uniform(-timings['DIR_RANDOMIZE'],
                                               timings['DIR_RANDOMIZE'])
        t_ready = timings['READY'] + random.uniform(
            -timings['READY_RANDOMIZE'], timings['READY_RANDOMIZE'])
        return t_go, t_ready

    dir_sequence = []
    for _ in range(cfg.TRIALS_EACH):
        dir_sequence.extend(cfg.DIRECTIONS)
    random.shuffle(dir_sequence)
    num_trials = len(cfg.DIRECTIONS) * cfg.TRIALS_EACH

    event = 'start'
    trial = 1
    # filled in when the cue of the current trial ends
    sides, trig_prefix = (), ''

    # Hardware trigger
    if cfg.TRIGGER_DEVICE is None:
        logger.warning(
            'No trigger device set. Press Ctrl+C to stop or Enter to continue.'
        )
        #input()
    trigger = pyLptControl.Trigger(state, cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        logger.error(
            '\n** Error connecting to USB2LPT device. Use a mock trigger instead?'
        )
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.MockTrigger()
        trigger.init(50)

    # timers
    timer_trigger = qc.Timer()
    timer_dir = qc.Timer()
    timer_refresh = qc.Timer()
    t_dir, t_dir_ready = draw_timings()

    bar = BarVisual(cfg.GLASS_USE,
                    screen_pos=cfg.SCREEN_POS,
                    screen_size=cfg.SCREEN_SIZE)
    try:
        bar.fill()
        bar.glass_draw_cue()

        # start
        while trial <= num_trials:
            timer_refresh.sleep_atleast(refresh_delay)
            timer_refresh.reset()

            # segment= { 'cue':(s,e), 'dir':(s,e), 'label':0-4 } (zero-based)
            if event == 'start' and timer_trigger.sec() > cfg.TIMINGS['INIT']:
                event = 'gap_s'
                bar.fill()
                timer_trigger.reset()
                trigger.signal(cfg.tdef.INIT)
            elif event == 'gap_s':
                if cfg.TRIAL_PAUSE:
                    bar.put_text('Press any key')
                    bar.update()
                    # masked like the main-loop key check below
                    key = 0xFF & cv2.waitKey()
                    if key == keys['esc'] or not state.value:
                        break
                    bar.fill()
                bar.put_text('Trial %d / %d' % (trial, num_trials))
                event = 'gap'
                timer_trigger.reset()
            elif event == 'gap' and timer_trigger.sec() > cfg.TIMINGS['GAP']:
                event = 'cue'
                bar.fill()
                bar.draw_cue()
                trigger.signal(cfg.tdef.CUE)
                timer_trigger.reset()
            elif event == 'cue' and timer_trigger.sec() > cfg.TIMINGS['CUE']:
                event = 'dir_r'
                direction = dir_sequence[trial - 1]
                if direction not in direction_spec:
                    # '%s' rather than '%d': direction codes are strings, and
                    # '%d' would raise TypeError instead of this error.
                    raise RuntimeError('Unknown direction %s' % (direction,))
                sides, trig_prefix = direction_spec[direction]
                # show the full-length bar(s) as the "ready" cue
                for side in sides:
                    bar.move(side, 100, overlay=True)
                trigger.signal(getattr(cfg.tdef, trig_prefix + '_READY'))
                timer_trigger.reset()
            elif event == 'dir_r' and timer_trigger.sec() > t_dir_ready:
                bar.fill()
                bar.draw_cue()
                event = 'dir'
                timer_trigger.reset()
                timer_dir.reset()
                trigger.signal(getattr(cfg.tdef, trig_prefix + '_GO'))
            elif event == 'dir' and timer_trigger.sec() > t_dir:
                event = 'gap_s'
                bar.fill()
                trial += 1
                logger.info('trial ' + str(trial - 1) + ' done')
                trigger.signal(cfg.tdef.BLANK)
                timer_trigger.reset()
                # new random durations for the next trial
                t_dir, t_dir_ready = draw_timings()

            # protocol: grow the bar(s) linearly over t_dir, capped at 100
            if event == 'dir':
                dx = min(100, int(100.0 * timer_dir.sec() / t_dir) + 1)
                for side in sides:
                    bar.move(side, dx, overlay=True)

            # wait for start
            if event == 'start':
                bar.put_text('Waiting to start')

            bar.update()
            key = 0xFF & cv2.waitKey(1)
            if key == keys['esc'] or not state.value:
                break
    finally:
        # Always close the window and tell the GUI the protocol has stopped,
        # even if the loop above raised.
        bar.finish()
        with state.get_lock():
            state.value = 0
Beispiel #16
0
    cfg = check_cfg(imp.load_source(cfg_module, cfg_module))
    if cfg.FAKE_CLS is None:
        # chooose amp
        if cfg.AMP_NAME is None and cfg.AMP_SERIAL is None:
            amp_name, amp_serial = pu.search_lsl(ignore_markers=True)
        else:
            amp_name = cfg.AMP_NAME
            amp_serial = cfg.AMP_SERIAL
        fake_dirs = None
    else:
        amp_name = None
        amp_serial = None
        fake_dirs = [v for (k, v) in cfg.DIRECTIONS]

    # events and triggers
    tdef = trigger_def(cfg.TRIGGER_FILE)

    if cfg.TRIGGER_DEVICE is None:
        input(
            '\n** Warning: No trigger device set. Press Ctrl+C to stop or Enter to continue.'
        )
    trigger = pyLptControl.Trigger(cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        logger.error(
            'Cannot connect to USB2LPT device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.MockTrigger()
        trigger.init(50)

    # init classification
    logger.info('Initializing decoder')
Beispiel #17
0
def stream_player(server_name,
                  fif_file,
                  chunk_size,
                  auto_restart=True,
                  wait_start=True,
                  repeat=float('inf'),
                  high_resolution=False,
                  trigger_file=None):
    """
    Replay a recorded file as an LSL stream.

    Input
    =====
    server_name: LSL server name.
    fif_file: fif file to replay.
    chunk_size: number of samples to send at once (usually 16-32 is good enough).
    auto_restart: play from beginning again after reaching the end.
    wait_start: wait for user to start in the beginning.
    repeat: number of loops to play (default: forever).
    high_resolution: use perf_counter() instead of sleep() for higher time resolution
                     but uses much more cpu due to polling.
    trigger_file: used to convert event numbers into event strings for readability.

    Raises
    ======
    RuntimeError: if fif_file could not be loaded.

    Note: Run pycnbi.set_log_level('DEBUG') to print out the relative time stamps since started.

    """
    raw, events = pu.load_raw(fif_file)
    # check before touching raw.info, which would fail on None
    if raw is None:
        raise RuntimeError('Error while loading %s' % fif_file)
    sfreq = raw.info['sfreq']  # sampling frequency
    n_channels = len(raw.ch_names)  # number of channels
    tdef = trigger_def(trigger_file) if trigger_file is not None else None
    try:
        event_ch = raw.ch_names.index('TRIGGER')
    except ValueError:
        event_ch = None
    logger.info_green('Successfully loaded %s' % fif_file)
    logger.info('Server name: %s' % server_name)
    logger.info('Sampling frequency %.3f Hz' % sfreq)
    logger.info('Number of channels : %d' % n_channels)
    logger.info('Chunk size : %d' % chunk_size)
    for i, ch in enumerate(raw.ch_names):
        logger.info('%d %s' % (i, ch))
    logger.info('Trigger channel : %s' % event_ch)

    # set server information
    sinfo = pylsl.StreamInfo(server_name, channel_count=n_channels, channel_format='float32',
                             nominal_srate=sfreq, type='EEG', source_id=server_name)
    desc = sinfo.desc()
    channel_desc = desc.append_child("channels")
    for ch in raw.ch_names:
        channel_desc.append_child('channel').append_child_value('label', str(ch)) \
            .append_child_value('type', 'EEG').append_child_value('unit', 'microvolts')
    desc.append_child('amplifier').append_child('settings').append_child_value(
        'is_slave', 'false')
    desc.append_child('acquisition').append_child_value(
        'manufacturer', 'PyCNBI').append_child_value('serial_number', 'N/A')
    outlet = pylsl.StreamOutlet(sinfo, chunk_size=chunk_size)

    if wait_start:
        input('Press Enter to start streaming.')
    logger.info('Streaming started')

    # perf_counter() polls for sub-ms accuracy; time() is paired with sleep()
    clock = time.perf_counter if high_resolution else time.time
    n_samples = raw._data.shape[1]
    t_chunk = chunk_size / sfreq
    idx_chunk = 0
    t_log_start = time.perf_counter()  # reference for relative debug stamps
    t_start = clock()

    # start streaming; `played` counts completed passes over the data
    played = 0
    while played < repeat:
        idx_current = idx_chunk * chunk_size
        chunk = raw._data[:, idx_current:idx_current + chunk_size]
        data = chunk.transpose().tolist()
        # last (possibly partial) chunk of the current pass
        finished = idx_current >= n_samples - chunk_size
        t_send = t_start + idx_chunk * t_chunk
        if high_resolution:
            # if a resolution over 2 KHz is needed
            while time.perf_counter() < t_send:
                pass
        else:
            # time.sleep() can have 500 us resolution using the tweak tool provided.
            t_wait = t_send - time.time()
            if t_wait > 0.001:
                time.sleep(t_wait)
        outlet.push_chunk(data)
        logger.debug('[%8.3fs] sent %d samples (LSL %8.3f)' %
                     (time.perf_counter() - t_log_start, len(data),
                      pylsl.local_clock()))
        if event_ch is not None:
            event_values = set(chunk[event_ch]) - {0}
            if event_values:
                if tdef is None:
                    logger.info('Events: %s' % event_values)
                else:
                    for event in event_values:
                        if event in tdef.by_value:
                            logger.info('Events: %s (%s)' %
                                        (event, tdef.by_value[event]))
                        else:
                            logger.info('Events: %s (Undefined event)' % event)
        idx_chunk += 1

        if finished:
            played += 1
            if played >= repeat:
                # all requested passes done: don't offer a restart
                break
            if auto_restart is False:
                input(
                    'Reached the end of data. Press Enter to restart or Ctrl+C to stop.'
                )
            else:
                logger.info('Reached the end of data. Restarting.')
            idx_chunk = 0
            t_start = clock()
Beispiel #18
0
def get_tfr(cfg, recursive=False, n_jobs=1):
    '''
    Compute time-frequency representations (TFR) of epoched data and export them.

    All settings are read from `cfg` after passing it through check_config():
      TFR_TYPE: 'multitaper' or 'morlet' (MNE TFR functions), or 'butter'
                (band-pass filter + Hilbert envelope, grand average only).
                'fir' is not implemented.
      DATA_PATHS / DATA_FILE: directories whose fif/bdf/gdf files are merged
                (output prefix 'grandavg'), or a single raw file.
      EXPORT_PATH: output directory; required with DATA_PATHS, otherwise
                derived from DATA_FILE via qc.parse_path_list() when None.
      EVENT_FILE: optional event file that replaces the events in the data.
      T_BUFFER: seconds added on both sides of each epoch; MNE TFR results
                are cropped back to cfg.EPOCH afterwards.
      POWER_AVERAGED: True for one TFR per class, False for one per epoch.
      EXPORT_MATLAB / EXPORT_PNG: export formats.
      POWER_DIFF: if this attribute exists, also plot log|power difference|
                of the first two classes that have a computed power.

    @params:
    cfg: config module/object
    recursive: if True, load raw files in sub-dirs of cfg.DATA_PATHS recursively
    n_jobs: not used; the number of cores is taken from cfg.N_JOBS (all cores
            if None). Kept for backward compatibility.

    @raises:
    ValueError: wrong TFR type, missing EXPORT_PATH with DATA_PATHS, none of the
                requested events found, or an invalid epoch window.
    NotImplementedError: TFR type 'fir', or 'butter' without POWER_AVERAGED.
    RuntimeError: a picked channel index is out of range.
    '''

    cfg = check_config(cfg)
    tfr_type = cfg.TFR_TYPE
    export_path = cfg.EXPORT_PATH
    t_buffer = cfg.T_BUFFER
    if tfr_type == 'multitaper':
        tfr = mne.time_frequency.tfr_multitaper
    elif tfr_type == 'morlet':
        tfr = mne.time_frequency.tfr_morlet
    elif tfr_type == 'butter':
        butter_order = 4  # TODO: parameterize
        tfr = lfilter
    elif tfr_type == 'fir':
        raise NotImplementedError
    else:
        raise ValueError('Wrong TFR type %s' % tfr_type)
    # the n_jobs argument is overridden by the config (kept for compatibility)
    n_jobs = cfg.N_JOBS
    if n_jobs is None:
        n_jobs = mp.cpu_count()

    if hasattr(cfg, 'DATA_PATHS'):
        if export_path is None:
            raise ValueError('For multiple directories, cfg.EXPORT_PATH cannot be None')
        outpath = export_path
        file_prefix = 'grandavg'

        # load and merge files from all directories
        flist = []
        for ddir in cfg.DATA_PATHS:
            ddir = ddir.replace('\\', '/')
            if ddir[-1] != '/':
                ddir += '/'
            for f in qc.get_file_list(ddir, fullpath=True, recursive=recursive):
                if qc.parse_path(f).ext in ['fif', 'bdf', 'gdf']:
                    flist.append(f)
        raw, events = pu.load_multi(flist)
    else:
        logger.info('Loading %s' % cfg.DATA_FILE)
        raw, events = pu.load_raw(cfg.DATA_FILE)
        if export_path is None:
            [outpath, file_prefix, _] = qc.parse_path_list(cfg.DATA_FILE)
        else:
            outpath = export_path
            file_prefix = qc.parse_path(cfg.DATA_FILE).name

    # Custom events replace the ones found in the data. They must be read
    # after loading, otherwise the loader's events would overwrite them.
    if hasattr(cfg, 'EVENT_FILE') and cfg.EVENT_FILE is not None:
        events = mne.read_events(cfg.EVENT_FILE)

    # re-referencing
    if cfg.REREFERENCE is not None:
        pu.rereference(raw, cfg.REREFERENCE[1], cfg.REREFERENCE[0])
        assert cfg.REREFERENCE[0] in raw.ch_names

    sfreq = raw.info['sfreq']

    # set channels of interest
    picks = pu.channel_names_to_index(raw, cfg.CHANNEL_PICKS)
    spchannels = pu.channel_names_to_index(raw, cfg.SP_CHANNELS)

    # channel indices are zero-based: an index equal to the channel count is
    # already out of range
    n_channels = len(raw.info['ch_names'])
    if max(picks) >= n_channels:
        msg = 'ERROR: "picks" has a channel index %d while there are only %d channels.' %\
              (max(picks), n_channels)
        raise RuntimeError(msg)

    # Apply filters
    raw = pu.preprocess(raw, spatial=cfg.SP_FILTER, spatial_ch=spchannels, spectral=cfg.TP_FILTER,
                        spectral_ch=picks, notch=cfg.NOTCH_FILTER, notch_ch=picks,
                        multiplier=cfg.MULTIPLIER, n_jobs=n_jobs)

    # Map class names to event values. With a trigger definition file, only
    # events actually present in the data are kept.
    classes = {}
    event_set = set(events[:, -1])
    tdef = trigger_def(cfg.TRIGGER_FILE) if cfg.TRIGGER_FILE is not None else None
    for t_name in cfg.TRIGGER_DEF:
        if tdef is not None:
            t_value = tdef.by_name[t_name]
            if t_value in event_set:
                classes[t_name] = t_value
        else:
            classes[str(t_name)] = t_name
    if not classes:
        raise ValueError('No desired event was found from the data.')

    # Epoch window. Epochs are cut on [tmin - t_buffer, tmax + t_buffer]
    # (presumably to keep TFR edge effects out of [tmin, tmax]).
    tmin = cfg.EPOCH[0]
    tmin_buffer = tmin - t_buffer
    raw_tmax = raw._data.shape[1] / sfreq - 0.1
    if cfg.EPOCH[1] is None:
        # the epoch extends until the end of the recording
        if cfg.POWER_AVERAGED:
            raise ValueError('EPOCH value cannot have None for grand averaged TFR')
        if len(classes) > 1:
            raise ValueError('If the end time of EPOCH is None, only a single event can be defined.')
        event_value = next(iter(classes.values()))
        t_ref = events[np.where(events[:, 2] == event_value)[0][0], 0] / sfreq
        tmax = raw_tmax - t_ref - t_buffer
    else:
        tmax = cfg.EPOCH[1]
    tmax_buffer = tmax + t_buffer
    if tmax_buffer > raw_tmax:
        raise ValueError('Epoch length with buffer (%.3f) is larger than signal length (%.3f)' % (tmax_buffer, raw_tmax))

    try:
        epochs_all = mne.Epochs(raw, events, classes, tmin=tmin_buffer, tmax=tmax_buffer,
                                proj=False, picks=picks, baseline=None, preload=True)
    except Exception:
        logger.critical('\n*** (tfr_export) Unknown error occurred while epoching ***')
        logger.critical('tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
            (tmin, tmax, tmin_buffer, tmax_buffer, raw._data.shape[1] / sfreq))
        raise
    if epochs_all.drop_log_stats() > 0:
        # deliberate interactive stop so the user can inspect the dropped epochs
        logger.error('\n** Bad epochs found. Dropping into a Python shell.')
        logger.error(epochs_all.drop_log)
        logger.error('tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
            (tmin, tmax, tmin_buffer, tmax_buffer, raw._data.shape[1] / sfreq))
        logger.error('\nType exit to continue.\n')
        pdb.set_trace()

    power = {}
    export_dir = outpath
    qc.make_dirs(export_dir)
    freqs = cfg.FREQ_RANGE  # define frequencies of interest
    n_cycles = freqs / 2.  # different number of cycle per frequency
    for evname in classes:
        logger.info('>> Processing %s' % evname)
        if cfg.POWER_AVERAGED:
            # grand-average TFR
            epochs = epochs_all[evname][:]
            if len(epochs) == 0:
                logger.warning('No %s epochs. Skipping.' % evname)
                continue

            if tfr_type == 'butter':
                # band-pass each epoch along time (axis 2 of get_data()), then
                # average the Hilbert envelope over epochs.
                # NOTE(review): not cropped to [tmin, tmax]; includes the buffer.
                b, a = butter_bandpass(freqs[0], freqs[-1], sfreq, order=butter_order)
                tfr_filtered = lfilter(b, a, epochs.get_data(), axis=2)
                tfr_power = abs(hilbert(tfr_filtered))
                tfr_data = np.mean(tfr_power, axis=0)
            else:
                power[evname] = tfr(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=False,
                    return_itc=False, decim=1, n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                tfr_data = power[evname].data

            if cfg.EXPORT_MATLAB is True:
                # export all channels to MATLAB
                mout = '%s/%s-%s-%s.mat' % (export_dir, file_prefix, cfg.SP_FILTER, evname)
                scipy.io.savemat(mout, {'tfr':tfr_data, 'chs':epochs.ch_names,
                    'events':events, 'sfreq':sfreq, 'tmin':tmin, 'tmax':tmax, 'epochs':cfg.EPOCH, 'freqs':cfg.FREQ_RANGE})
                logger.info('Exported %s' % mout)
            if cfg.EXPORT_PNG is True:
                if tfr_type == 'butter':
                    # the butter path produces a plain array, not a plottable TFR object
                    logger.warning('PNG export is not supported for TFR type %s. Skipping.' % tfr_type)
                else:
                    # Inspect power for each channel
                    for ch in range(len(picks)):
                        chname = raw.ch_names[picks[ch]]
                        title = 'Peri-event %s - Channel %s' % (evname, chname)

                        # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                        fig = power[evname].plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                            colorbar=True, title=title, vmin=cfg.VMIN, vmax=cfg.VMAX, dB=False)
                        fout = '%s/%s-%s-%s-%s.png' % (export_dir, file_prefix, cfg.SP_FILTER, evname, chname)
                        fig.savefig(fout)
                        plt.close()
                        logger.info('Exported to %s' % fout)
        else:
            # TFR per event
            if tfr_type == 'butter':
                raise NotImplementedError('Per-epoch TFR is not implemented for TFR type butter')
            n_epochs = len(epochs_all[evname])
            if n_epochs == 0:
                logger.warning('No %s epochs. Skipping.' % evname)
                continue
            for ep in range(n_epochs):
                epochs = epochs_all[evname][ep]
                power[evname] = tfr(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=False,
                    return_itc=False, decim=1, n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                if cfg.EXPORT_MATLAB is True:
                    # export all channels to MATLAB
                    mout = '%s/%s-%s-%s-ep%02d.mat' % (export_dir, file_prefix, cfg.SP_FILTER, evname, ep + 1)
                    scipy.io.savemat(mout, {'tfr':power[evname].data, 'chs':power[evname].ch_names,
                        'events':events, 'sfreq':sfreq, 'tmin':tmin, 'tmax':tmax, 'epochs':cfg.EPOCH, 'freqs':cfg.FREQ_RANGE})
                    logger.info('Exported %s' % mout)
                if cfg.EXPORT_PNG is True:
                    # Inspect power for each channel
                    for ch in range(len(picks)):
                        chname = raw.ch_names[picks[ch]]
                        title = 'Peri-event %s - Channel %s, Trial %d' % (evname, chname, ep + 1)
                        # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                        fig = power[evname].plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                            colorbar=True, title=title, vmin=cfg.VMIN, vmax=cfg.VMAX, dB=False)
                        fout = '%s/%s-%s-%s-%s-ep%02d.png' % (export_dir, file_prefix, cfg.SP_FILTER, evname, chname, ep + 1)
                        fig.savefig(fout)
                        plt.close()
                        logger.info('Exported %s' % fout)

    if hasattr(cfg, 'POWER_DIFF'):
        # only classes whose power was actually computed can be compared
        labels = [label for label in classes if label in power]
        if len(labels) < 2:
            logger.warning('POWER_DIFF needs the power of two classes but got %d. Skipping.' % len(labels))
        else:
            export_dir = '%s/diff' % outpath
            qc.make_dirs(export_dir)
            df = power[labels[0]] - power[labels[1]]
            df.data = np.log(np.abs(df.data))
            # Inspect power diff for each channel
            for ch in range(len(picks)):
                chname = raw.ch_names[picks[ch]]
                title = 'Peri-event %s-%s - Channel %s' % (labels[0], labels[1], chname)

                # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                fig = df.plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                              colorbar=True, title=title, vmin=-3.0, vmax=3.0, dB=False)
                fout = '%s/%s-%s-diff-%s-%s-%s.jpg' % (export_dir, file_prefix, cfg.SP_FILTER, labels[0], labels[1], chname)
                logger.info('Exporting to %s' % fout)
                fig.savefig(fout)
                plt.close()
    logger.info('Finished !')
Beispiel #19
0
        scaleFactor = np.max(epoch_data, 0) - np.min(epoch_data, 0)
    elif method == 'override':  # todo: find a better name
        shiftFactor = givenShiftFactor
        scaleFactor = givenScaleFactor

    for trial in range(new_epochs_data.shape[0]):
        new_epochs_data[trial, :, :] = (new_epochs_data[trial, :, :] -
                                        shiftFactor) / scaleFactor

    return (new_epochs_data, shiftFactor, scaleFactor)


if __name__ == '__main__':
    # load data
    raw, events = pu.load_multi(FLIST, spfilter='car')
    tdef = trigger_def('triggerdef_errp.ini')
    sfreq = raw.info['sfreq']

    # epoching
    tmin = 0
    tmax = 1
    event_id = dict(correct=tdef.by_key['FEEDBACK_CORRECT'],
                    wrong=tdef.by_key['FEEDBACK_WRONG'])

    # export plots: apply offline spectral filter
    if EXPORT_PLOTS == True:
        # apply filter on entire signal (for offline analysis)
        raw.filter(l_freq=l_freq,
                   h_freq=h_freq,
                   n_jobs=n_jobs,
                   picks=picks_feat,
Beispiel #20
0
# Direction codes accepted in cfg.DIRECTIONS, mapped to the name used both as
# the on-screen text (non-bar feedback) and as the prefix of the trigger names
# looked up in the trigger definition (e.g. 'L' -> tdef.LEFT_READY, tdef.LEFT_GO).
_DIRECTION_NAMES = {'L': 'LEFT', 'R': 'RIGHT', 'U': 'UP', 'D': 'DOWN', 'B': 'BOTH'}

# STIMO serial command, console message and trigger name per direction
# (commands are only sent for left and right).
_STIMO_COMMANDS = {
    'L': (b'1', 'STIMO: Sent 1', 'LEFT_STIMO'),
    'R': (b'2', 'STIMO: Sent 2', 'RIGHT_STIMO'),
}

# key code returned by cv2.waitKey() for the ESC key
_KEY_ESC = 27


def _move_bars(visual, direction, amount, **kwargs):
    """Move the bar of `direction`; 'B' (both hands) moves the left and right bars."""
    sides = ('L', 'R') if direction == 'B' else (direction,)
    for side in sides:
        visual.move(side, amount, **kwargs)


def _run_trials(cfg, tdef, dir_sequence, ser):
    """Initialize the trigger and visual feedback, then run the trial state machine.

    `ser` is the open STIMO serial port, or None when cfg.WITH_STIMO is not True.
    """
    num_trials = len(dir_sequence)
    refresh_delay = 1.0 / cfg.REFRESH_RATE

    # init trigger
    if cfg.TRIGGER_DEVICE is None:
        input(
            '\n** Warning: No trigger device set. Press Ctrl+C to stop or Enter to continue.'
        )
    trigger = pyLptControl.Trigger(cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        print(
            '\n# Error connecting to USB2LPT device. Use a mock trigger instead?'
        )
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.MockTrigger()
        trigger.init(50)

    # visual feedback
    if cfg.FEEDBACK_TYPE == 'BAR':
        from pycnbi.protocols.viz_bars import BarVisual
        visual = BarVisual(cfg.GLASS_USE,
                           screen_pos=cfg.SCREEN_POS,
                           screen_size=cfg.SCREEN_SIZE)
    elif cfg.FEEDBACK_TYPE == 'BODY':
        if not hasattr(cfg, 'IMAGE_PATH'):
            raise ValueError('IMAGE_PATH is undefined in your config.')
        from pycnbi.protocols.viz_human import BodyVisual
        visual = BodyVisual(cfg.IMAGE_PATH,
                            use_glass=cfg.GLASS_USE,
                            screen_pos=cfg.SCREEN_POS,
                            screen_size=cfg.SCREEN_SIZE)
    else:
        raise ValueError('Unknown FEEDBACK_TYPE %s' % cfg.FEEDBACK_TYPE)
    visual.put_text('Waiting to start    ')
    use_bars = cfg.FEEDBACK_TYPE == 'BAR'

    timer_trigger = qc.Timer()
    timer_dir = qc.Timer()
    timer_refresh = qc.Timer()

    state = 'start'
    trial = 1
    direction = None
    gait_steps = 0
    t_step = 0.0

    # start
    while trial <= num_trials:
        timer_refresh.sleep_atleast(refresh_delay)
        timer_refresh.reset()

        if state == 'start' and timer_trigger.sec() > cfg.T_INIT:
            state = 'gap_s'
            visual.fill()
            timer_trigger.reset()
            trigger.signal(tdef.INIT)
        elif state == 'gap_s':
            visual.put_text('Trial %d / %d' % (trial, num_trials))
            state = 'gap'
        elif state == 'gap' and timer_trigger.sec() > cfg.T_GAP:
            state = 'cue'
            visual.fill()
            visual.draw_cue()
            trigger.signal(tdef.CUE)
            timer_trigger.reset()
        elif state == 'cue' and timer_trigger.sec() > cfg.T_CUE:
            # show the direction of this trial ("ready" phase)
            state = 'dir_r'
            direction = dir_sequence[trial - 1]
            if direction not in _DIRECTION_NAMES:
                raise RuntimeError('Unknown direction %s' % direction)
            name = _DIRECTION_NAMES[direction]
            if use_bars:
                _move_bars(visual, direction, 100)
            else:
                visual.put_text(name)
            trigger.signal(getattr(tdef, name + '_READY'))
            gait_steps = 1
            timer_trigger.reset()
        elif state == 'dir_r' and timer_trigger.sec() > cfg.T_DIR_READY:
            # "go" phase: the bar fills up over a randomized duration
            visual.draw_cue()
            state = 'dir'
            timer_trigger.reset()
            timer_dir.reset()
            t_step = cfg.T_DIR + random.random() * cfg.RANDOMIZE_LENGTH
            trigger.signal(getattr(tdef, _DIRECTION_NAMES[direction] + '_GO'))
        elif state == 'dir':
            if timer_trigger.sec() > t_step:
                if cfg.FEEDBACK_TYPE == 'BODY':
                    if cfg.WITH_STIMO is True:
                        if direction in _STIMO_COMMANDS:
                            command, message, trigger_name = _STIMO_COMMANDS[direction]
                            ser.write(command)
                            qc.print_c(message, 'g')
                            trigger.signal(getattr(tdef, trigger_name))
                    elif direction == 'L':
                        trigger.signal(tdef.LEFT_RETURN)
                    elif direction == 'R':
                        trigger.signal(tdef.RIGHT_RETURN)
                else:
                    trigger.signal(tdef.FEEDBACK)
                state = 'return'
                timer_trigger.reset()
            else:
                dx = min(100, int(100.0 * timer_dir.sec() / t_step) + 1)
                _move_bars(visual, direction, dx, overlay=True)
        elif state == 'return':
            if timer_trigger.sec() > cfg.T_RETURN:
                if gait_steps < cfg.GAIT_STEPS:
                    # next gait step: alternate between left and right
                    gait_steps += 1
                    state = 'dir'
                    visual.move('L', 0)
                    if direction == 'L':
                        direction = 'R'
                        trigger.signal(tdef.RIGHT_GO)
                    else:
                        direction = 'L'
                        trigger.signal(tdef.LEFT_GO)
                    timer_dir.reset()
                    t_step = cfg.T_DIR + random.random() * cfg.RANDOMIZE_LENGTH
                else:
                    state = 'gap_s'
                    visual.fill()
                    trial += 1
                    print('trial ' + str(trial - 1) + ' done')
                    trigger.signal(tdef.BLANK)
                timer_trigger.reset()
            else:
                # the bar empties over cfg.T_RETURN seconds
                dx = max(
                    0,
                    int(100.0 * (cfg.T_RETURN - timer_trigger.sec()) /
                        cfg.T_RETURN))
                _move_bars(visual, direction, dx, overlay=True)

        # wait for start
        if state == 'start':
            visual.put_text('Waiting to start    ')

        visual.update()
        key = 0xFF & cv2.waitKey(1)

        if key == _KEY_ESC:
            break


def config_run(cfg_module):
    """
    Run the cue-based training protocol defined in a config module.

    Each direction in cfg.DIRECTIONS is presented cfg.TRIALS_EACH times in
    random order. A trial goes through the states gap -> cue -> direction
    ready (dir_r) -> direction (dir, bar filling up) -> return (bar emptying).
    The dir/return pair is repeated for cfg.GAIT_STEPS steps, alternating
    between left and right. A trigger is sent at the state transitions.
    Pressing ESC stops the protocol early.

    Direction codes: 'L' left, 'R' right, 'U' up, 'D' down, 'B' both hands.

    @params:
    cfg_module: config module, loaded with load_cfg()

    @raises:
    ValueError: cfg.FEEDBACK_TYPE is neither 'BAR' nor 'BODY', or IMAGE_PATH
                is missing for 'BODY'.
    RuntimeError: an unknown direction code is found in cfg.DIRECTIONS.
    """
    cfg = load_cfg(cfg_module)

    dir_sequence = list(cfg.DIRECTIONS) * cfg.TRIALS_EACH
    random.shuffle(dir_sequence)
    tdef = trigger_def(cfg.TRIGGER_DEF)

    # STIMO protocol
    ser = None
    if cfg.WITH_STIMO is True:
        print('Opening STIMO serial port (%s / %d bps)' %
              (cfg.STIMO_COMPORT, cfg.STIMO_BAUDRATE))
        import serial
        ser = serial.Serial(cfg.STIMO_COMPORT, cfg.STIMO_BAUDRATE)
        print('STIMO serial port %s is_open = %s' %
              (cfg.STIMO_COMPORT, ser.is_open))

    try:
        _run_trials(cfg, tdef, dir_sequence, ser)
    finally:
        # close the port even if the protocol fails or is interrupted
        if ser is not None:
            ser.close()
            print('Closed STIMO serial port %s' % cfg.STIMO_COMPORT)
Beispiel #21
0
import pycnbi
import numpy as np
import pycnbi.utils.q_common as qc
from pycnbi.triggers.trigger_def import trigger_def

# Trigger definitions; the event names used below are resolved from this file.
tdef = trigger_def('triggerdef_16.ini')

# Input recording directories (Windows path).
# NOTE(review): get_tfr() in this file reads cfg.DATA_PATHS / cfg.TRIGGER_DEF,
# not DATA_DIRS / TRIGGERS; confirm which consumer this config targets.
DATA_DIRS = [r'D:\data\MI\rx1\train']
# Channels of interest, given here as integer indices (presumably zero-based;
# confirm with the consumer).
CHANNEL_PICKS = [5, 6, 7, 11]

'''"""""""""""""""""""""""""""
 Epochs and events of interest
"""""""""""""""""""""""""""'''
# Event values to epoch on: the left and right GO cues.
TRIGGERS = {tdef.LEFT_GO, tdef.RIGHT_GO}
# Epoch window [start, end] relative to event onset (presumably seconds).
EPOCH = [-2.0, 4.0]
# Optional external event file; None uses the events recorded in the data.
EVENT_FILE = None

'''"""""""""""""""""""""""""""
 Baseline relative to onset while plotting
 None in index 0: beginning of data
 None in index 1: end of data
"""""""""""""""""""""""""""'''
BS_TIMES = (None, 0)

'''"""""""""""""""""""""""""""
 PSD
"""""""""""""""""""""""""""'''
# Frequencies 1..39 in steps of 1 (np.arange excludes the stop value 40).
FREQ_RANGE = np.arange(1, 40, 1)

'''"""""""""""""""""""""""""""
 Unit conversion
Beispiel #22
0
    def disp_params(self, cfg_template_module, cfg_module):
        """
        Displays the parameters in the corresponding UI scrollArea.
        cfg = config module
        """

        self.clear_params()
        # Extract the parameters and their possible values from the template modules.
        params = inspect.getmembers(cfg_template_module)

        # Extract the chosen values from the subject's specific module.
        all_chosen_values = inspect.getmembers(cfg_module)

        filePath = self.ui.lineEdit_pathSearch.text()

        # Load channels
        if self.modality == 'train':
            subjectDataPath = '%s/%s/fif' % (os.environ['PYCNBI_DATA'],
                                             filePath.split('/')[-1])
            self.channels = read_params_from_txt(subjectDataPath,
                                                 'channelsList.txt')
        self.directions = ()

        # Iterates over the classes
        for par in range(2):
            param = inspect.getmembers(params[par][1])
            # Create layouts
            layout = QFormLayout()

            # Iterates over the list
            for p in param:
                # Remove useless attributes
                if '__' in p[0]:
                    continue

                # Iterates over the dict
                for key, values in p[1].items():
                    chosen_value = self.extract_value_from_module(
                        key, all_chosen_values)

                    # For the feedback directions [offline and online].
                    if 'DIRECTIONS' in key:
                        self.directions = values

                        if self.modality is 'offline':
                            nb_directions = 4
                            directions = Connect_Directions(
                                key, chosen_value, values, nb_directions)

                        elif self.modality is 'online':
                            cls_path = self.paramsWidgets[
                                'DECODER_FILE'].lineEdit_pathSearch.text()
                            cls = qc.load_obj(cls_path)
                            events = cls[
                                'cls'].classes_  # Finds the events on which the decoder has been trained on
                            events = list(map(int, events))
                            nb_directions = len(events)
                            chosen_events = [
                                event[1] for event in chosen_value
                            ]
                            chosen_value = [val[0] for val in chosen_value]

                            # Need tdef to convert int to str trigger values
                            try:
                                [tdef.by_value(i) for i in events]
                            except:
                                trigger_file = self.extract_value_from_module(
                                    'TRIGGER_FILE', all_chosen_values)
                                tdef = trigger_def(trigger_file)
                                # self.on_guichanges('tdef', tdef)
                                events = [tdef.by_value[i] for i in events]

                            directions = Connect_Directions_Online(
                                key, chosen_value, values, nb_directions,
                                chosen_events, events)

                        directions.signal_paramChanged.connect(
                            self.on_guichanges)
                        self.paramsWidgets.update({key: directions})
                        layout.addRow(key, directions.l)

                    # For providing a folder path.
                    elif 'PATH' in key:
                        pathfolderfinder = PathFolderFinder(
                            key, DEFAULT_PATH, chosen_value)
                        pathfolderfinder.signal_pathChanged.connect(
                            self.on_guichanges)
                        self.paramsWidgets.update({key: pathfolderfinder})
                        layout.addRow(key, pathfolderfinder.layout)
                        continue

                    # For providing a file path.
                    elif 'FILE' in key:
                        pathfilefinder = PathFileFinder(key, chosen_value)
                        pathfilefinder.signal_pathChanged.connect(
                            self.on_guichanges)
                        self.paramsWidgets.update({key: pathfilefinder})
                        layout.addRow(key, pathfilefinder.layout)
                        continue

                    # For the special case of choosing the trigger classes to train on
                    elif 'TRIGGER_DEF' in key:
                        trigger_file = self.extract_value_from_module(
                            'TRIGGER_FILE', all_chosen_values)
                        tdef = trigger_def(trigger_file)
                        # self.on_guichanges('tdef', tdef)
                        nb_directions = 4
                        #  Convert 'None' to real None (real None is removed when selected in the GUI)
                        tdef_values = [
                            None if i == 'None' else i
                            for i in list(tdef.by_name)
                        ]
                        directions = Connect_Directions(
                            key, chosen_value, tdef_values, nb_directions)
                        directions.signal_paramChanged.connect(
                            self.on_guichanges)
                        self.paramsWidgets.update({key: directions})
                        layout.addRow(key, directions.l)
                        continue

                    # To select specific electrodes
                    elif '_CHANNELS' in key or 'CHANNELS_' in key:
                        ch_select = Channel_Select(key, self.channels,
                                                   chosen_value)
                        ch_select.signal_paramChanged.connect(
                            self.on_guichanges)
                        self.paramsWidgets.update({key: ch_select})
                        layout.addRow(key, ch_select.layout)

                    elif 'BIAS' in key:
                        #  Add None to the list in case of no bias wanted
                        self.directions = tuple([None] + list(self.directions))
                        bias = Connect_Bias(key, self.directions, chosen_value)
                        bias.signal_paramChanged.connect(self.on_guichanges)
                        self.paramsWidgets.update({key: bias})
                        layout.addRow(key, bias.l)

                    # For all the int values.
                    elif values is int:
                        spinBox = Connect_SpinBox(key, chosen_value)
                        spinBox.signal_paramChanged.connect(self.on_guichanges)
                        self.paramsWidgets.update({key: spinBox})
                        layout.addRow(key, spinBox.w)
                        continue

                    # For all the float values.
                    elif values is float:
                        doublespinBox = Connect_DoubleSpinBox(
                            key, chosen_value)
                        doublespinBox.signal_paramChanged.connect(
                            self.on_guichanges)
                        self.paramsWidgets.update({key: doublespinBox})
                        layout.addRow(key, doublespinBox.w)
                        continue

                    # For parameters with multiple non-fixed values in a list (user can modify them)
                    elif values is list:
                        modifiable_list = Connect_Modifiable_List(
                            key, chosen_value)
                        modifiable_list.signal_paramChanged.connect(
                            self.on_guichanges)
                        self.paramsWidgets.update({key: modifiable_list})
                        layout.addRow(key, modifiable_list.frame)
                        continue

                    #  For parameters containing a string to modify
                    elif values is str:
                        lineEdit = Connect_LineEdit(key, chosen_value)
                        lineEdit.signal_paramChanged[str, str].connect(
                            self.on_guichanges)
                        lineEdit.signal_paramChanged[str, type(None)].connect(
                            self.on_guichanges)
                        self.paramsWidgets.update({key: lineEdit})
                        layout.addRow(key, lineEdit.w)
                        continue

                    # For parameters with multiple fixed values.
                    elif type(values) is tuple:
                        comboParams = Connect_ComboBox(key, chosen_value,
                                                       values)
                        comboParams.signal_paramChanged.connect(
                            self.on_guichanges)
                        comboParams.signal_additionalParamChanged.connect(
                            self.on_guichanges)
                        self.paramsWidgets.update({key: comboParams})
                        layout.addRow(key, comboParams.layout)
                        continue

                    # For parameters with multiple non-fixed values in a dict (user can modify them)
                    elif type(values) is dict:
                        try:
                            selection = chosen_value['selected']
                            comboParams = Connect_ComboBox(
                                key, chosen_value, values)
                            comboParams.signal_paramChanged.connect(
                                self.on_guichanges)
                            comboParams.signal_additionalParamChanged.connect(
                                self.on_guichanges)
                            self.paramsWidgets.update({key: comboParams})
                            layout.addRow(key, comboParams.layout)

                        except:
                            modifiable_dict = Connect_Modifiable_Dict(
                                key, chosen_value, values)
                            modifiable_dict.signal_paramChanged.connect(
                                self.on_guichanges)
                            self.paramsWidgets.update({key: modifiable_dict})
                            layout.addRow(key, modifiable_dict.frame)
                        continue

                # Add a horizontal line to separate parameters' type.
                if p != param[-1]:
                    separator = QFrame()
                    separator.setFrameShape(QFrame.HLine)
                    separator.setFrameShadow(QFrame.Sunken)
                    layout.addRow(separator)

                # Display the parameters according to their types.
                if params[par][0] == 'Basic':
                    self.ui.scrollAreaWidgetContents_Basics.setLayout(layout)
                elif params[par][0] == 'Advanced':
                    self.ui.scrollAreaWidgetContents_Adv.setLayout(layout)