def run(cfg, state=mp.Value('i', 1), queue=None):
    """
    Online protocol for Alpha/Theta neurofeedback.

    Waits until the GUI releases the protocol, connects to an LSL stream and
    then polls that stream until the protocol is stopped or cfg.GLOBAL_TIME
    seconds have elapsed. The loop body is a template: user code goes where
    indicated.

    Params
    ======
    cfg: configuration object. GLOBAL_TIME and TIMER_SLEEP are read here; cfg
        is also forwarded to find_lsl_stream() and connect_lsl_stream().
    state: shared integer flag (0: stop, 1: start, 2: wait).
    queue: forwarded to redirect_stdout_to_queue() together with the logger.
    """
    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI)
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        # Sleep between polls so that waiting does not burn a full CPU core.
        time.sleep(0.01)

    # Protocol runs if state equals to 1
    if not state.value:
        sys.exit(-1)

    #----------------------------------------------------------------------
    # LSL stream connection
    #----------------------------------------------------------------------
    # Choose amp
    amp_name, amp_serial = find_lsl_stream(cfg, state)

    # Connect to lsl stream
    sr = connect_lsl_stream(cfg, amp_name, amp_serial)

    # Get sampling rate (kept for the user code inserted below)
    sfreq = sr.get_sample_rate()

    # Get trigger channel (kept for the user code inserted below)
    trg_ch = sr.get_trigger_channel()

    #----------------------------------------------------------------------
    # Main
    #----------------------------------------------------------------------
    global_timer = qc.Timer(autoreset=False)
    internal_timer = qc.Timer(autoreset=True)

    # Timestamp of the newest sample seen so far. It has to exist before the
    # first loop iteration, otherwise the comparison below raises
    # UnboundLocalError.
    last_ts = 0

    while state.value == 1 and global_timer.sec() < cfg.GLOBAL_TIME:

        #----------------------------------------------------------------------
        # Data acquisition
        #----------------------------------------------------------------------
        sr.acquire()
        window, tslist = sr.get_window()    # window = [samples x channels]
        window = window.T                   # window = [channels x samples]

        # Check if proper real-time acquisition: at least one timestamp must
        # be newer than the last one processed.
        tsnew = np.where(np.array(tslist) > last_ts)[0]
        if len(tsnew) == 0:
            logger.warning('There seems to be delay in receiving data.')
            time.sleep(1)
            continue

        #----------------------------------------------------------------------
        # ADD YOUR CODE HERE
        #----------------------------------------------------------------------

        last_ts = tslist[-1]
        internal_timer.sleep_atleast(cfg.TIMER_SLEEP)
Пример #2
0
    def __init__(self, cfg, viz, tdef, trigger, logfile=None):
        """
        Store the collaborators, read the bar settings from cfg and open the
        optional devices.

        Params:
            cfg: configuration object. REFRESH_RATE, BAR_STEP, BAR_BIAS,
                WITH_STIMO and WITH_FES are read unconditionally;
                BAR_REACH_FINISH is optional; the STIMO_* and FES_COMPORT
                settings are read only when the matching device is enabled.
            viz: visualizer; fill() is called on it once.
            tdef: trigger definitions, stored as-is.
            trigger: trigger object, stored as-is.
            logfile: path of a log file to open for writing, or None.

        Raises:
            RuntimeError: STIMO is enabled without an explicit port and no
                serial port matching 'ATEN' is found.
        """
        self.cfg = cfg
        self.tdef = tdef
        self.trigger = trigger
        self.viz = viz
        self.viz.fill()
        self.refresh_delay = 1.0 / self.cfg.REFRESH_RATE

        # Per-direction bar step sizes.
        steps = self.cfg.BAR_STEP
        self.bar_step_left = steps['left']
        self.bar_step_right = steps['right']
        self.bar_step_up = steps['up']
        self.bar_step_down = steps['down']
        self.bar_step_both = steps['both']

        # A tuple bias is converted to a list; anything else is kept as given.
        bias = self.cfg.BAR_BIAS
        self.bar_bias = list(bias) if type(bias) is tuple else bias

        # New decoder: already smoothed by the decoder so bias after.
        #self.alpha_old = self.cfg.PROB_ACC_ALPHA
        #self.alpha_new = 1.0 - self.cfg.PROB_ACC_ALPHA

        # True only when the optional BAR_REACH_FINISH setting exists and
        # equals True.
        self.premature_end = bool(
            hasattr(self.cfg, 'BAR_REACH_FINISH')
            and self.cfg.BAR_REACH_FINISH == True)

        self.tm_trigger = qc.Timer()
        self.tm_display = qc.Timer()
        self.tm_watchdog = qc.Timer()
        self.logf = None if logfile is None else open(logfile, 'w')

        # STIMO only
        if self.cfg.WITH_STIMO is True:
            port = self.cfg.STIMO_COMPORT
            if port is None:
                # No port configured: use the first port matching 'ATEN'.
                candidates = list(serial.tools.list_ports.grep('ATEN'))
                if len(candidates) == 0:
                    raise RuntimeError('No ATEN device found. Stop.')
                first = candidates[0]
                try:
                    port = first.device
                except AttributeError:  # depends on Python distribution
                    port = first[0]
            self.stimo_port = port
            self.ser = serial.Serial(self.stimo_port, self.cfg.STIMO_BAUDRATE)
            logger.info('STIMO serial port %s is_open = %s' %
                        (self.stimo_port, self.ser.is_open))

        # FES only
        if self.cfg.WITH_FES is True:
            self.stim = stimulator = fes.Motionstim8()
            stimulator.OpenSerialPort(self.cfg.FES_COMPORT)
            stimulator.InitializeChannelListMode()
            logger.info('Opened FES serial port')
Пример #3
0
def main():
    """
    Benchmark mne.decoding.PSDEstimator on a random signal.

    A random [channels x samples] array is transformed repeatedly; the time
    of every call is collected and the mean duration (ms) together with the
    corresponding rate (Hz) is logged at the end.
    """
    n_channels = 64
    sfreq = 512
    wlen = 0.5  # window length in seconds
    num_iterations = 500
    n_samples = int(np.round(sfreq * wlen))

    signal = np.random.rand(n_channels, n_samples)
    psde = mne.decoding.PSDEstimator(
        sfreq=sfreq, fmin=1, fmax=40, bandwidth=None, adaptive=False,
        low_bias=True, n_jobs=1, normalization='length', verbose=None)

    stopwatch = qc.Timer()
    durations_ms = []
    for iteration in range(num_iterations):
        stopwatch.reset()
        # The estimator is fed a 3-D array, so add a leading axis of length 1.
        psde.transform(signal.reshape((1,) + signal.shape))
        durations_ms.append(stopwatch.msec())
        if iteration % 100 == 0:
            logger.info('%d / %d' % (iteration, num_iterations))

    mean_ms = np.mean(durations_ms)
    rate_hz = 1000 / mean_ms
    logger.info('Average = %.1f ms (%.1f Hz)' % (mean_ms, rate_hz))
Пример #4
0
 def __init__(self, mock=False):
     """
     Initialise the control object.

     mock: when truthy, a notice that a fake object is in use is printed
         through self.print().
     """
     self.BUFFER_SIZE = 1024
     self.last_dir = 'L'
     self.timer = qc.Timer(autoreset=True)
     self.mock = mock
     if not self.mock:
         return
     self.print('Using a fake, mock Glass control object.')
Пример #5
0
def fit_predict_thres(cls,
                      X_train,
                      Y_train,
                      X_test,
                      Y_test,
                      cnum,
                      label_list,
                      ignore_thres=None,
                      decision_thres=None):
    """
    Fit a classifier on one cross-validation fold and score it on the test set.

    Any likelihood lower than a threshold is not counted as classification score.
    Confusion matrix, accuracy and F1 score (macro average) are computed.

    Params
    ======
    cls: classifier providing fit() and predict(); predict_proba() and
        classes_ are additionally required when ignore_thres is larger than 0.
    X_train, Y_train: training samples and labels, passed to cls.fit().
    X_test, Y_test: test samples and their true labels.
    cnum: fold number, only used in the log message.
    label_list: labels that define the row/column order of the confusion matrix.
    ignore_thres:
        if not None and larger than 0, test samples whose highest class
        likelihood is lower than ignore_thres are ignored while computing
        accuracy, confusion matrix and F1 score.
    decision_thres: must be None whenever ignore_thres is in effect; it is not
        used otherwise.

    Returns
    =======
    (score, cm, f1): accuracy, confusion matrix and macro-averaged F1 score.
    When ignore_thres is in effect, cm has one extra last column holding, for
    each label in label_list, the number of ignored predictions of that label.

    Raises
    ======
    ValueError: ignore_thres is in effect and decision_thres is not None.
    """
    timer = qc.Timer()
    cls.fit(X_train, Y_train)
    assert ignore_thres is None or ignore_thres >= 0, \
        'ignore_thres must be None or non-negative.'
    if ignore_thres is None or ignore_thres == 0:
        Y_pred = cls.predict(X_test)
        score = skmetrics.accuracy_score(Y_test, Y_pred)
        # labels must be passed by keyword: it is keyword-only in current
        # scikit-learn releases.
        cm = skmetrics.confusion_matrix(Y_test, Y_pred, labels=label_list)
        f1 = skmetrics.f1_score(Y_test, Y_pred, average='macro')
    else:
        if decision_thres is not None:
            msg = 'decision threshold and ignore_thres cannot be set at the same time.'
            logger.error(msg)
            raise ValueError(msg)
        Y_pred = np.asarray(cls.predict_proba(X_test))
        Y_pred_labels = np.argmax(Y_pred, axis=1)  # column of the winning class
        Y_pred_maxes = np.max(Y_pred, axis=1)      # likelihood of the winning class
        Y_index_overthres = np.where(Y_pred_maxes >= ignore_thres)[0]
        Y_index_underthres = np.where(Y_pred_maxes < ignore_thres)[0]

        # Translate probability columns back to class labels.
        classes = np.asarray(cls.classes_)
        Y_pred_overthres = classes[Y_pred_labels[Y_index_overthres]]
        Y_pred_underthres = classes[Y_pred_labels[Y_index_underthres]]
        Y_pred_underthres_count = np.array(
            [np.count_nonzero(Y_pred_underthres == c) for c in label_list])

        # Only the confident predictions take part in the scores.
        Y_test_overthres = np.asarray(Y_test)[Y_index_overthres]
        score = skmetrics.accuracy_score(Y_test_overthres, Y_pred_overthres)
        cm = skmetrics.confusion_matrix(Y_test_overthres, Y_pred_overthres,
                                        labels=label_list)
        # Append the per-label count of ignored predictions as a last column.
        cm = np.concatenate((cm, Y_pred_underthres_count[:, np.newaxis]),
                            axis=1)
        f1 = skmetrics.f1_score(Y_test_overthres,
                                Y_pred_overthres,
                                average='macro')

    logger.info('Cross-validation %d (%.3f) - %.1f sec' %
                (cnum, score, timer.sec()))
    return score, cm, f1
Пример #6
0
    def init_timer(self):
        """
        Create the stopwatch and start the Qt timer that drives update_loop().

        Pending Qt events are processed and flushed first; the QTimer is then
        connected to self.update_loop and started with an interval of 20.
        """
        self.tm = qc.Timer()  # leeq

        app = QtCore.QCoreApplication
        app.processEvents()
        app.flush()

        ticker = QtCore.QTimer(self)
        self.timer = ticker
        ticker.timeout.connect(self.update_loop)
        ticker.start(20)
Пример #7
0
    def init_timer(self):
        '''
        Sets up the periodic refresh: a QT timer is created, wired to
        update_loop and started so that it fires every 20 ms.
        '''
        refresh_interval_ms = 20

        self.tm = qc.Timer()  # leeq

        # Let Qt handle whatever is already queued before the timer starts.
        qt_app = QtCore.QCoreApplication
        qt_app.processEvents()
        qt_app.flush()

        self.timer = QtCore.QTimer(self)
        on_timeout = self.timer.timeout
        on_timeout.connect(self.update_loop)
        self.timer.start(refresh_interval_ms)
Пример #8
0
def get_predict_proba(cls, X_train, Y_train, X_test, Y_test, cnum):
    """
    Fit the classifier on one fold and return the likelihoods of its first
    class (column 0 of predict_proba) for every test sample.

    The likelihoods of all cross-validation folds are meant to be collected
    so that a threshold balancing the true positive rate of each class can be
    computed. Available with binary classification scenario only.

    Y_test is accepted for signature symmetry but not used; cnum only appears
    in the log message.
    """
    stopwatch = qc.Timer()
    cls.fit(X_train, Y_train)
    likelihoods = cls.predict_proba(X_test)
    n_tests = likelihoods.shape[0]
    logger.info('Cross-validation %d (%d tests) - %.1f sec' %
                (cnum, n_tests, stopwatch.sec()))
    first_class = likelihoods[:, 0]
    return first_class
Пример #9
0
    def __init__(self,
                 window_size=1,
                 buffer_size=1,
                 amp_serial=None,
                 eeg_only=False,
                 amp_name=None):
        """
        Validate the window/buffer sizes, reset the receiver state and connect.

        Params:
            window_size (in seconds): keep the latest window_size seconds of the buffer.
                Must be larger than 0, otherwise ValueError is raised.
            buffer_size (in seconds): 0 selects the maximum of 1 day (86400 s).
                A negative or too large value is replaced by the maximum, a value
                smaller than window_size by window_size (both logged as errors).
                Large buffer may lead to a delay if not pulled frequently.
            amp_name: connect to a server named 'amp_name'. None: no constraint.
            amp_serial: connect to a server with serial number 'amp_serial'. None: no constraint.
            eeg_only: ignore non-EEG servers
        """
        _MAX_BUFFER_SIZE = 86400  # max buffer size allowed by StreamReceiver (24 hours)
        _MAX_PYLSL_STREAM_BUFSIZE = 360  # max buffer size for pylsl.StreamInlet

        # --- window length ---
        if window_size <= 0:
            logger.error('Wrong window_size %d.' % window_size)
            raise ValueError()
        self.winsec = window_size

        # --- buffer length, clamped to a usable range ---
        if buffer_size == 0:
            buffer_size = _MAX_BUFFER_SIZE
        elif buffer_size < 0 or buffer_size > _MAX_BUFFER_SIZE:
            logger.error('Improper buffer size %.1f. Setting to %d.' %
                         (buffer_size, _MAX_BUFFER_SIZE))
            buffer_size = _MAX_BUFFER_SIZE
        elif buffer_size < self.winsec:
            logger.error(
                'Buffer size %.1f is smaller than window size. Setting to %.1f.'
                % (buffer_size, self.winsec))
            buffer_size = self.winsec
        self.bufsec = buffer_size
        self.bufsize = 0  # to be calculated using sampling rate

        # The pylsl inlet buffer is capped separately, in whole seconds.
        capped_bufsec = min(_MAX_PYLSL_STREAM_BUFSIZE, self.bufsec)
        self.stream_bufsec = int(math.ceil(capped_bufsec))
        self.stream_bufsize = 0  # to be calculated using sampling rate

        # --- server selection ---
        self.amp_serial = amp_serial
        self.eeg_only = eeg_only
        self.amp_name = amp_name

        # --- channel bookkeeping ---
        self.tr_channel = None  # trigger indx used by StreamReceiver class
        self.eeg_channels = []  # signal indx used by StreamReceiver class
        self._lsl_tr_channel = None  # raw trigger indx in pylsl.pull_chunk()
        self._lsl_eeg_channels = []  # raw signal indx in pylsl.pull_chunk()

        # --- runtime state ---
        self.ready = False  # False until the buffer is filled for the first time
        self.connected = False
        self.buffers = []
        self.timestamps = []
        self.watchdog = qc.Timer()
        self.multiplier = 1  # 10**6 for uV unit (automatically updated for openvibe servers)

        self.connect()
Пример #10
0
def test_receiver():
    """
    Interactive check of StreamReceiver: connect to an LSL server and print,
    in an endless loop, the stream delay, the new trigger values and the data
    of the monitored channel(s). Optionally the mean PSD per channel is
    printed as well (SHOW_PSD).
    """
    import mne
    import os

    CH_INDEX = [1]  # channel to monitor
    TIME_INDEX = None  # integer or None. None = average of raw values of the current window
    SHOW_PSD = False
    mne.set_log_level('ERROR')
    os.environ['OMP_NUM_THREADS'] = '1'  # actually improves performance for multitaper

    # connect to LSL server
    amp_name, amp_serial = pu.search_lsl()
    sr = StreamReceiver(window_size=1,
                        buffer_size=1,
                        amp_serial=amp_serial,
                        eeg_only=False,
                        amp_name=amp_name)
    sfreq = sr.get_sample_rate()
    trg_ch = sr.get_trigger_channel()
    logger.info('Trigger channel = %d' % trg_ch)

    # PSD init
    if SHOW_PSD:
        psde = mne.decoding.PSDEstimator(
            sfreq=sfreq, fmin=1, fmax=50, bandwidth=None, adaptive=False,
            low_bias=True, n_jobs=1, normalization='length', verbose=None)

    watchdog = qc.Timer()
    pacer = qc.Timer(autoreset=True)
    last_ts = 0
    while True:
        sr.acquire()
        samples, tslist = sr.get_window()  # samples = [samples x channels]
        window = samples.T  # channel x samples

        lsl_delay = pylsl.local_clock() - tslist[-1]
        qc.print_c('LSL Diff = %.3f' % lsl_delay, 'G')

        # indices of the samples that arrived since the previous iteration
        fresh = np.flatnonzero(np.array(tslist) > last_ts)
        if fresh.size == 0:
            logger.warning('There seems to be delay in receiving data.')
            time.sleep(1)
            continue

        # print event values
        trigger = np.unique(window[trg_ch, fresh[0]:])

        # for Biosemi
        # if sr.amp_name=='BioSemi':
        #    trigger= set( [255 & int(x-1) for x in trigger ] )

        if trigger.size > 0:
            logger.info('Triggers: %s' % np.array(trigger))

        logger.info('[%.1f] Receiving data...' % watchdog.sec())

        if TIME_INDEX is not None:
            # value of the monitored channel(s) at a single time index
            datatxt = qc.list2string(window[CH_INDEX, TIME_INDEX], '%-15.6f')
            stamp = '[%.3f]' % tslist[TIME_INDEX]
        else:
            # average of the monitored channel(s) over the whole window
            channel_means = np.mean(window[CH_INDEX, :], axis=1)
            datatxt = qc.list2string(channel_means, '%-15.6f')
            stamp = '[%.3f : %.3f]' % (tslist[0], tslist[-1])
        print(stamp + ' data: %s' % datatxt)

        # show PSD
        if SHOW_PSD:
            psd = psde.transform(window.reshape((1,) + window.shape))
            psd = psd.reshape(psd.shape[1:])
            for p in np.mean(psd, axis=1):
                print('%.1f' % p, end=' ')

        last_ts = tslist[-1]
        pacer.sleep_atleast(0.05)
Пример #11
0
def run(cfg, amp_name, amp_serial, state=mp.Value('i', 1), experiment_mode=True, baseline=False):
    """
    Online protocol for Alpha/Theta neurofeedback.

    Connects to an LSL stream and, for cfg['global_time'] seconds, repeatedly
    acquires a window of data. Unless baseline is set, a PSD based feature is
    computed for each window and used to mix two looping feedback sounds.

    Params
    ======
    cfg: dict-like configuration. Keys read here: window_size, buffer_size,
        alpha_band_freq, theta_band_freq, sampling_frequency, n_jobs,
        music_state_1_path, music_state_2_path, window_size_psd_max,
        global_time, timer_sleep, and optionally spatial_filter and
        spatial_channels. feature_type and music_mix_style are read when
        baseline is False. trigger_file, trigger_device, start_voice_file,
        end_voice_file, screen_pos and screen_size are read when
        experiment_mode is True.
    amp_name, amp_serial: identify the LSL stream to connect to.
    state: shared integer flag; handed to the trigger device and checked once
        after the start key press (0 aborts).
    experiment_mode: if True, triggers are sent, instructions are shown and
        the start/end voices are played; the protocol waits for a key press
        before starting (ESC aborts).
    baseline: if True, no feature is computed and no feedback sound is played.

    Raises
    ======
    ValueError: baseline is False and cfg['feature_type'] is neither
        FeatureType.THETA nor FeatureType.ALPHA_THETA.
    RuntimeError: the trigger device cannot be initialised.
    """
    # Reject an unsupported feature type before any device or audio is
    # initialised; otherwise the feedback loop would fail on an unbound
    # `feature` variable in its first iteration.
    if not baseline:
        feature_type = cfg['feature_type']
        if not (feature_type == FeatureType.THETA
                or feature_type == FeatureType.ALPHA_THETA):
            raise ValueError('Unsupported feature_type: %s' % (feature_type,))

    #----------------------------------------------------------------------
    # LSL stream connection
    #----------------------------------------------------------------------

    sr = protocol_utils.connect_lsl_stream(amp_name=amp_name,
                                           amp_serial=amp_serial,
                                           window_size=cfg['window_size'],
                                           buffer_size=cfg['buffer_size'])

    sfreq = sr.get_sample_rate()
    trg_ch = sr.get_trigger_channel()

    #----------------------------------------------------------------------
    # PSD estimators initialization
    #----------------------------------------------------------------------
    psde_alpha = protocol_utils.init_psde(*list(cfg['alpha_band_freq'].values()),
                                          sampling_frequency=cfg['sampling_frequency'],
                                          n_jobs=cfg['n_jobs'])
    psde_theta = protocol_utils.init_psde(*list(cfg['theta_band_freq'].values()),
                                          sampling_frequency=cfg['sampling_frequency'],
                                          n_jobs=cfg['n_jobs'])
    #----------------------------------------------------------------------
    # Initialize the feedback sounds
    #----------------------------------------------------------------------
    sound_1, sound_2 = protocol_utils.init_feedback_sounds(cfg['music_state_1_path'],
                                                           cfg['music_state_2_path'])


    #----------------------------------------------------------------------
    # Main
    #----------------------------------------------------------------------
    global_timer   = qc.Timer(autoreset=False)
    internal_timer = qc.Timer(autoreset=True)

    pgmixer.init()
    if experiment_mode:
        # Init trigger communication
        trigger_signals = trigger_def(cfg['trigger_file'])
        trigger = pyLptControl.Trigger(state, cfg['trigger_device'])
        if trigger.init(50) == False:
            logger.error('\n** Error connecting to trigger device.')
            raise RuntimeError


        # Preload the starting voice
        print(cfg['start_voice_file'])
        pgmixer.music.load(cfg['start_voice_file'])

        # Init feedback
        viz = BarVisual(False,
                        screen_pos=cfg['screen_pos'],
                        screen_size=cfg['screen_size'])

        viz.fill()
        viz.put_text('Close your eyes and relax')
        viz.update()

        pgmixer.music.play()

        # Wait a key press
        key = 0xFF & cv2.waitKey(0)
        if key == KEYS['esc'] or not state.value:
            sys.exit(-1)

        print('recording started')

        trigger.signal(trigger_signals.INIT)

    if not baseline:
        # Both sounds loop for the whole session; only their mix changes.
        sound_1.play(loops=-1)
        sound_2.play(loops=-1)

    # Timestamp of the newest sample already processed (None: nothing yet).
    last_ts = None

    # Previously applied (smoothed) music ratio (None: nothing applied yet).
    last_ratio = None
    # History of measured features; NaN marks slots that are not filled yet.
    measured_psd_ratios = np.full(cfg['window_size_psd_max'], np.nan)

    while global_timer.sec() < cfg['global_time']:

        #----------------------------------------------------------------------
        # Data acquisition
        #----------------------------------------------------------------------
        #  Pz = 8
        sr.acquire()
        window, tslist = sr.get_window()    # window = [samples x channels]
        window = window.T                   # window = [channels x samples]

        # Check if proper real-time acquisition
        if last_ts is not None:
            tsnew = np.where(np.array(tslist) > last_ts)[0]
            if len(tsnew) == 0:
                logger.warning('There seems to be delay in receiving data.')
                time.sleep(1)
                continue


        # Spatial filtering
        window = pu.preprocess(window,
                               sfreq=sfreq,
                               spatial=cfg.get('spatial_filter'),
                               spatial_ch=cfg.get('spatial_channels'))

        #----------------------------------------------------------------------
        # Computing the Power Spectrum Densities using multitapers
        #----------------------------------------------------------------------
        # PSD
        if not baseline:
            if cfg['feature_type'] == FeatureType.THETA:
                psd_theta = protocol_utils.compute_psd(window, psde_theta)

                feature = psd_theta
            elif cfg['feature_type'] == FeatureType.ALPHA_THETA:
                psd_alpha = protocol_utils.compute_psd(window, psde_alpha)
                psd_theta = protocol_utils.compute_psd(window, psde_theta)

                feature = psd_alpha / psd_theta
            else:
                # Only reachable if the configuration changed while running.
                raise ValueError('Unsupported feature_type: %s' %
                                 (cfg['feature_type'],))

            # Normalise the feature by the largest value in the history.
            measured_psd_ratios = add_to_queue(measured_psd_ratios, feature)
            current_music_ratio  = feature / np.max(measured_psd_ratios[~np.isnan(measured_psd_ratios)])

            # Move only a quarter of the way towards the new ratio so that
            # the sound mix changes smoothly.
            if last_ratio is not None:
                applied_music_ratio = last_ratio + (current_music_ratio - last_ratio) * 0.25
            else:
                applied_music_ratio = current_music_ratio

            mix_sounds(style=cfg['music_mix_style'],
                       sounds=(sound_1, sound_2),
                       feature_value=applied_music_ratio)

            print((f"{cfg['feature_type']}: {feature:0.3f}"
                   f"\t, current_music_ratio: {current_music_ratio:0.3f}"
                   f"\t, applied music ratio: {applied_music_ratio:0.3f}"
                   ))

            last_ratio = applied_music_ratio


        last_ts = tslist[-1]
        internal_timer.sleep_atleast(cfg['timer_sleep'])


    if not baseline:
        # NOTE(review): if these are pygame.mixer.Sound objects, fadeout()
        # takes milliseconds - confirm that 3 is the intended value.
        sound_1.fadeout(3)
        sound_2.fadeout(3)


    if experiment_mode:
        trigger.signal(trigger_signals.END)

        # Remove the text
        viz.fill()
        viz.put_text('Recording is finished')
        viz.update()

        # Ending voice
        pgmixer.music.load(cfg['end_voice_file'])
        pgmixer.music.play()
        time.sleep(5)

        # Close cv2 window
        viz.finish()

    print('done')
Пример #12
0
def raw2psd(rawfile=None,
            fmin=1,
            fmax=40,
            wlen=0.5,
            wstep=1,
            tmin=0.0,
            tmax=None,
            channel_picks=None,
            excludes=[],
            n_jobs=1):
    """
    Compute PSD features over a sliding window on the entire raw file.
    Leading edge of the window is the time reference, i.e. do not use future data.

    The result is written next to the raw file as 'psd-<name>-header.pkl'
    (metadata) and 'psd-<name>-data.npy' (windows x channels x frequencies).

    Input
    =====
    rawfile: fif file.
    channel_picks: None (use all EEG channels) or list of channel names
    tmin (sec): start of the first PSD window, converted to samples with the
        sampling rate of the file.
    tmax (sec): time at which the sliding stops, converted the same way.
        None = until the end.
    fmin (Hz): minimum PSD frequency
    fmax (Hz): maximum PSD frequency
    wlen (sec): sliding window length for computing PSD (sec)
    wstep (int): sliding window step (time samples)
    excludes (list): list of channels to exclude
    n_jobs: currently not used; the PSD estimator is created with n_jobs=1.
    """

    raw, eve = pu.load_raw(rawfile)
    sfreq = raw.info['sfreq']
    wframes = int(round(sfreq * wlen))
    raw_eeg = raw.pick_types(meg=False, eeg=True, stim=False, exclude=excludes)
    if channel_picks is None:
        rawdata = raw_eeg._data
        # chlist must hold channel indices (not names): it is used to index
        # raw.ch_names when the header is exported.
        chlist = list(range(rawdata.shape[0]))
    else:
        chlist = [raw.ch_names.index(ch) for ch in channel_picks]
        rawdata = raw_eeg._data[np.array(chlist)]

    if tmax is None:
        t_end = rawdata.shape[1]
    else:
        t_end = int(round(tmax * sfreq))
    # The first window ends wframes samples after tmin.
    t_start = int(round(tmin * sfreq)) + wframes
    psde = mne.decoding.PSDEstimator(sfreq, fmin=fmin, fmax=fmax, n_jobs=1,
                                     bandwidth=None, low_bias=True,
                                     adaptive=False, normalization='length',
                                     verbose=None)
    print('[PID %d] %s' % (os.getpid(), rawfile))
    psd_all = []
    evelist = []
    times = []
    t_len = t_end - t_start
    last_eve = 0
    y_i = 0
    t_last = t_start
    tm = qc.Timer()
    for t in range(t_start, t_end, wstep):
        # compute PSD
        window = rawdata[:, t - wframes:t]
        psd = psde.transform(
            window.reshape((1, window.shape[0], window.shape[1])))
        psd = psd.reshape(psd.shape[1], psd.shape[2])
        psd_all.append(psd)
        times.append(t)

        # matching events at the current window: the latest event whose onset
        # has been reached stays active until the next one.
        if y_i < eve.shape[0] and t >= eve[y_i][0]:
            last_eve = eve[y_i][2]
            y_i += 1
        evelist.append(last_eve)

        # progress report, roughly once per second
        if tm.sec() >= 1:
            fps = (t - t_last) / wstep
            # fps is 0 if the very first window already took a second; skip
            # the report then instead of dividing by zero.
            if fps > 0:
                perc = (t - t_start) / t_len
                est = (t_end - t) / wstep / fps
                logger.info('[PID %d] %.1f%% (%.1f FPS, %ds left)' %
                            (os.getpid(), perc * 100.0, fps, est))
            t_last = t
            tm.reset()
    logger.info('Finished.')

    # export data
    try:
        chnames = [raw.ch_names[ch] for ch in chlist]
        psd_all = np.array(psd_all)
        [basedir, fname, fext] = qc.parse_path_list(rawfile)
        fout_header = '%s/psd-%s-header.pkl' % (basedir, fname)
        fout_psd = '%s/psd-%s-data.npy' % (basedir, fname)
        header = {
            'psdfile': fout_psd,
            'times': np.array(times),
            'sfreq': sfreq,
            'channels': chnames,
            'wframes': wframes,
            'events': evelist
        }
        qc.save_obj(fout_header, header)
        np.save(fout_psd, psd_all)
        logger.info('Exported to:\n(header) %s\n(numpy array) %s' %
                    (fout_header, fout_psd))
    except Exception:
        # The PSDs may have taken long to compute, so open an interactive
        # shell to let the user save them manually instead of losing them.
        logger.exception('(%s) Unexpected error occurred while exporting data. Dropping you into a shell for recovery.' %\
            os.path.basename(__file__))
        embed()
Пример #13
0
def get_psd(epochs,
            psde,
            wlen,
            wstep,
            picks=None,
            flatten=True,
            preprocess=None,
            decim=1,
            n_jobs=1):
    """
    Compute multi-taper PSDs over a sliding window

    Input
    =====
    epochs: MNE Epochs object
    psde: MNE PSDEstimator object
    wlen: window length in frames
    wstep: window step in frames
    picks: channels to be used; use all if None
    flatten: boolean, see Returns section
    preprocess: forwarded unchanged to slice_win()
    decim: currently not used
    n_jobs: nubmer of cores to use, None = use all cores

    Output
    ======
    if flatten==True:
        X_data: [epochs] x [windows] x [channels*freqs]
    else:
        X_data: [epochs] x [windows] x [channels] x [freqs]
    y_data: [epochs] x [windows], integer label of the epoch for every window

    TODO:
        Accept input as numpy array as well, in addition to Epochs object
    """

    tm = qc.Timer()

    if n_jobs is None:
        n_jobs = mp.cpu_count()
    pool = None
    if n_jobs > 1:
        logger.info('Opening a pool of %d workers' % n_jobs)
        pool = mp.Pool(n_jobs)

    try:
        # compute PSD from sliding windows of each epoch
        labels = epochs.events[:, -1]
        epochs_data = epochs.get_data()
        w_starts = np.arange(0, epochs_data.shape[2] - wlen, wstep)
        results = []
        for ep in range(len(labels)):
            title = 'Epoch %d / %d, Frames %d-%d' % (
                ep + 1, len(labels), w_starts[0], w_starts[-1] + wlen - 1)
            args = [
                epochs_data[ep], w_starts, wlen, psde, picks, title, True,
                preprocess
            ]
            if pool is None:
                # no multiprocessing
                results.append(slice_win(*args))
            else:
                # parallel psd computation
                results.append(pool.apply_async(slice_win, args))

        # Collect the per-epoch results and stack them once at the end
        # instead of growing the arrays with repeated concatenation.
        X_list = []
        y_list = []
        for ep, result in enumerate(results):
            r = result if pool is None else result.get()  # windows x features
            X_list.append(r)
            y_list.append(np.full(r.shape[0], labels[ep]))  # one label per window
        X_data = np.stack(X_list)  # epochs x windows x features
        # np.int was removed from NumPy; the builtin int is the same type.
        y_data = np.stack(y_list).astype(int)  # epochs x windows
    finally:
        # close pool, also when a worker or the stacking raised
        if pool is not None:
            pool.close()
            pool.join()

    logger.info('Feature computation took %d seconds.' % tm.sec())

    if flatten:
        return X_data, y_data

    # Split the flat feature axis into channels x freqs.
    xs = X_data.shape
    nch = len(epochs.ch_names)
    return X_data.reshape(xs[0], xs[1], nch, int(xs[2] / nch)), y_data
Пример #14
0
def run(cfg, state=mp.Value('i', 1), queue=None):
    """
    Offline protocol

    Shows an instruction, plays a start voice, waits for a key press and then
    idles for cfg.GLOBAL_TIME seconds (or until ESC is pressed or the state
    flag leaves 1) between an INIT and an END trigger. Finally an end message
    is shown and the end voice is played.

    Params
    ======
    cfg: configuration object. TRIGGER_FILE, TRIGGER_DEVICE, START_VOICE,
        END_VOICE, GLASS_USE, SCREEN_POS, SCREEN_SIZE and GLOBAL_TIME are
        read; the loaded trigger definitions are stored in cfg.tdef.
    state: shared integer flag (0: stop, 1: start, 2: wait). It is set to 0
        when ESC is pressed during the recording.
    queue: forwarded to redirect_stdout_to_queue() together with the logger.

    Raises
    ======
    RuntimeError: the trigger device cannot be initialised.
    """

    # visualizer
    keys = {
        'left': 81,
        'right': 83,
        'up': 82,
        'down': 84,
        'pgup': 85,
        'pgdn': 86,
        'home': 80,
        'end': 87,
        'space': 32,
        'esc': 27,
        ',': 44,
        '.': 46,
        's': 115,
        'c': 99,
        '[': 91,
        ']': 93,
        '1': 49,
        '!': 33,
        '2': 50,
        '@': 64,
        '3': 51,
        '#': 35
    }

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI)
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        # Sleep between polls so that waiting does not burn a full CPU core.
        time.sleep(0.01)

    # Protocol runs if state equals to 1
    if not state.value:
        sys.exit(-1)

    global_timer = qc.Timer(autoreset=False)

    # Init trigger communication
    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)
    trigger = pyLptControl.Trigger(state, cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        logger.error('\n** Error connecting to trigger device.')
        raise RuntimeError

    # Preload the starting voice
    pgmixer.init()
    pgmixer.music.load(cfg.START_VOICE)

    # Init feedback
    viz = BarVisual(cfg.GLASS_USE,
                    screen_pos=cfg.SCREEN_POS,
                    screen_size=cfg.SCREEN_SIZE)
    viz.fill()
    viz.put_text('Close your eyes and relax')
    viz.update()

    # Play the start voice
    pgmixer.music.play()

    # Wait a key press
    key = 0xFF & cv2.waitKey(0)
    if key == keys['esc'] or not state.value:
        sys.exit(-1)

    viz.fill()
    viz.put_text('Recording in progress')
    viz.update()

    #----------------------------------------------------------------------
    # Main
    #----------------------------------------------------------------------
    trigger.signal(cfg.tdef.INIT)

    while state.value == 1 and global_timer.sec() < cfg.GLOBAL_TIME:

        # Keep only the low byte, exactly as for the start key press above:
        # the raw return value can carry extra high bits on some platforms,
        # in which case it would never equal the ESC code.
        key = 0xFF & cv2.waitKey(1)
        if key == keys['esc']:
            with state.get_lock():
                state.value = 0

    trigger.signal(cfg.tdef.END)

    # Remove the text
    viz.fill()
    viz.put_text('Recording is finished')
    viz.update()

    # Ending voice
    pgmixer.music.load(cfg.END_VOICE)
    pgmixer.music.play()
    time.sleep(5)

    # Close cv2 window
    viz.finish()
Пример #15
0
def record(recordState,
           amp_name,
           amp_serial,
           record_dir,
           eeg_only,
           recordLogger=logger,
           queue=None):
    """
    Record a stream into a pickle file, then convert it to fif.

    Input:
        recordState: shared value providing get_lock(); it is set to 1 when
                     recording starts and recording goes on while it stays 1
        amp_name: amplifier name, forwarded to StreamReceiver
        amp_serial: amplifier serial, forwarded to StreamReceiver
        record_dir: directory receiving the output files
        eeg_only: forwarded to StreamReceiver
        recordLogger: logger used for progress messages
        queue: forwarded to redirect_stdout_to_queue

    Raises:
        RuntimeError: if the output file cannot be created in record_dir.
    """
    redirect_stdout_to_queue(recordLogger, queue, 'INFO')

    # set data file name
    timestamp = time.strftime('%Y%m%d-%H%M%S', time.localtime())
    pcl_file = "%s/%s-raw.pcl" % (record_dir, timestamp)
    eve_file = '%s/%s-eve.txt' % (record_dir, timestamp)
    recordLogger.info('>> Output file: %s' % (pcl_file))

    # test writability; the context manager guarantees the probe file is
    # closed, and Ctrl+C / SystemExit are no longer masked as RuntimeError
    try:
        qc.make_dirs(record_dir)
        with open(pcl_file, 'w') as probe:
            probe.write(
                'The data will written when the recording is finished.')
    except Exception:
        raise RuntimeError('Problem writing to %s. Check permission.' %
                           pcl_file)

    # start a server for sending out data pcl_file when software trigger is used
    outlet = start_server('StreamRecorderInfo',
                          channel_format='string',
                          source_id=eve_file,
                          stype='Markers')

    # connect to EEG stream server
    sr = StreamReceiver(buffer_size=0,
                        amp_name=amp_name,
                        amp_serial=amp_serial,
                        eeg_only=eeg_only)

    # start recording
    recordLogger.info('\n>> Recording started (PID %d).' % os.getpid())

    with recordState.get_lock():
        recordState.value = 1

    tm = qc.Timer(autoreset=True)
    next_sec = 1
    while recordState.value == 1:
        sr.acquire()
        # log the elapsed duration each time the buffer passes next_sec
        if sr.get_buflen() > next_sec:
            duration = str(datetime.timedelta(seconds=int(sr.get_buflen())))
            recordLogger.info('RECORDING %s' % duration)
            next_sec += 1
        tm.sleep_atleast(0.001)

    # record stop
    recordLogger.info('>> Stop requested. Copying buffer')
    signals, times = sr.get_buffer()

    # channels = total channels from amp, including trigger channel
    data = {
        'signals': signals,
        'timestamps': times,
        'events': None,
        'sample_rate': sr.get_sample_rate(),
        'channels': sr.get_num_channels(),
        'ch_names': sr.get_channel_names(),
        'lsl_time_offset': sr.lsl_time_offset
    }
    recordLogger.info('Saving raw data ...')
    qc.save_obj(pcl_file, data)
    recordLogger.info('Saved to %s\n' % pcl_file)

    # automatically convert to fif and use event file if it exists (software trigger)
    if os.path.exists(eve_file):
        recordLogger.info('Found matching event file, adding events.')
    else:
        eve_file = None
    recordLogger.info('Converting raw file into fif.')
    pcl2fif(pcl_file, external_event=eve_file)
Пример #16
0
    def __init__(self,
                 image_path,
                 use_glass=False,
                 glass_feedback=True,
                 pc_feedback=True,
                 screen_pos=None,
                 screen_size=None):
        """
        Set up the screen buffer, the Glass connection and the feedback
        images, then open the full-screen "Protocol" window.

        Input:
            image_path: either a directory containing 'left' and 'right'
                image sub-directories, or a gzip-compressed pickle file
                (name ending in .pkl) holding 'left_images' and
                'right_images'
            use_glass: if False, mock Glass will be used
            glass_feedback: show feedback to the user?
            pc_feedback: show feedback on the pc screen?
            screen_pos: screen position in (x,y)
            screen_size: screen size in (x,y)
        """
        # screen size and message setting
        if screen_size is None:
            if sys.platform.startswith('win'):
                from win32api import GetSystemMetrics
                screen_width = GetSystemMetrics(0)
                screen_height = GetSystemMetrics(1)
            else:
                screen_width = 1024
                screen_height = 768
            screen_size = (screen_width, screen_height)
        else:
            screen_width, screen_height = screen_size
        if screen_pos is None:
            screen_x, screen_y = (0, 0)
        else:
            screen_x, screen_y = screen_pos

        self.text_size = 2
        self.img = np.zeros((screen_height, screen_width, 3), np.uint8)
        self.glass = bgi_client.GlassControl(mock=not use_glass)
        self.glass.connect('127.0.0.1', 59900)
        self.set_glass_feedback(glass_feedback)
        self.set_pc_feedback(pc_feedback)
        self.set_cue_color(boxcol='B', crosscol='W')
        self.width = self.img.shape[1]
        self.height = self.img.shape[0]
        # NOTE(review): self.barwidth is not assigned in this method;
        # presumably a class attribute or set by a setter above - confirm.
        hw = int(self.barwidth / 2)
        self.cx = int(self.width / 2)
        self.cy = int(self.height / 2)
        self.xl1 = self.cx - hw
        self.xl2 = self.xl1 - self.barwidth
        self.xr1 = self.cx + hw
        self.xr2 = self.xr1 + self.barwidth
        self.yl1 = self.cy - hw
        self.yl2 = self.yl1 - self.barwidth
        self.yr1 = self.cy + hw
        self.yr2 = self.yr1 + self.barwidth

        if os.path.isdir(image_path):
            # load images
            left_image_path = '%s/left' % image_path
            right_image_path = '%s/right' % image_path
            tm = qc.Timer()
            logger.info('Reading images from %s' % left_image_path)
            self.left_images = read_images(left_image_path, screen_size)
            logger.info('Reading images from %s' % right_image_path)
            self.right_images = read_images(right_image_path, screen_size)
            logger.info('Took %.1f s' % tm.sec())
        else:
            # load pickled images
            # note: this is painfully slow in Python 2 even with cPickle (3s vs 27s)
            assert image_path[-4:] == '.pkl', 'The file must be of .pkl format'
            logger.info('Loading image binary file %s ...' % image_path)
            tm = qc.Timer()
            # SECURITY: pickle.load can execute arbitrary code; only load
            # image files that come from a trusted source.
            with gzip.open(image_path, 'rb') as fp:
                image_data = pickle.load(fp)
            self.left_images = image_data['left_images']
            self.right_images = image_data['right_images']

            # Placement rectangle that centres an image on the screen; it is
            # derived from the first left image and reused for every image.
            feedback_w = self.left_images[0].shape[1] / 2
            feedback_h = self.left_images[0].shape[0] / 2
            loc_x = [int(self.cx - feedback_w), int(self.cx + feedback_w)]
            loc_y = [int(self.cy - feedback_h), int(self.cy + feedback_h)]

            # adjust to the current screen size
            logger.info('Fitting images into the current screen size')
            for images in (self.left_images, self.right_images):
                for i, img in enumerate(images):
                    # paste each image into its own black full-screen canvas
                    img_fit = np.zeros((screen_height, screen_width, 3),
                                       np.uint8)
                    img_fit[loc_y[0]:loc_y[1], loc_x[0]:loc_x[1]] = img
                    images[i] = img_fit

            logger.info('Took %.1f s' % tm.sec())
            logger.info('Done.')
        cv2.namedWindow("Protocol", cv2.WND_PROP_FULLSCREEN)
        cv2.moveWindow("Protocol", screen_x, screen_y)
        cv2.setWindowProperty("Protocol", cv2.WND_PROP_FULLSCREEN,
                              cv2.WINDOW_FULLSCREEN)
Пример #17
0
 with intel-mkl library:
  sklearn 0.17(pip) = 25.0 Hz, 0.19.1(conda) = 24.0 Hz

512 Hz, 64 channels, 256-sample window:
  sklearn 0.17(pip) = 27.9 Hz, 0.19.1(conda) = 24.0 Hz

@author: leeq
"""

import neurodecode
import numpy as np
import neurodecode.utils.q_common as qc
from neurodecode.decoder.decoder import BCIDecoder, BCIDecoderDaemon

if __name__ == '__main__':
    # Benchmark: time num_decode consecutive decoder.get_prob() calls and
    # report the mean latency together with the equivalent rate.
    decoder_file = 'PATH_TO_CLASSIFIER_FILE'
    decoder = BCIDecoder(decoder_file, buffer_size=1.0)
    num_decode = 200
    tm = qc.Timer()
    times = []
    for n_done in range(1, num_decode + 1):
        tm.reset()
        prob = decoder.get_prob()
        times.append(tm.msec())
        # progress marker after every 10th decoding
        if not n_done % 10:
            print(n_done, end=' ')

    ms = np.mean(times)
    fps = 1000 / ms
    print('\nAverage = %.1f ms (%.1f Hz)' % (ms, fps))
Пример #18
0
def run(cfg, state=mp.Value('i', 1), queue=None):
    """
    Run the cue-based protocol with bar feedback.

    Each trial walks through the events gap_s -> gap -> cue -> dir_r -> dir
    and sends the matching trigger from cfg.tdef at the transitions. The loop
    ends when len(cfg.DIRECTIONS) * cfg.TRIALS_EACH trials are done, when ESC
    is pressed, or when state.value becomes 0.

    Input:
        cfg: config object; reads REFRESH_RATE, TRIGGER_FILE, TRIGGER_DEVICE,
             DIRECTIONS, TRIALS_EACH, TRIAL_PAUSE, TIMINGS, GLASS_USE,
             SCREEN_POS and SCREEN_SIZE; the loaded trigger definition is
             stored in cfg.tdef
        state: shared value (0: stop, 1: start, 2: wait); set to 0 on return
        queue: forwarded to redirect_stdout_to_queue

    Raises:
        RuntimeError: if a direction is not one of 'L', 'R', 'U', 'D', 'B'.
    """
    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI)
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        pass
    # Protocol starts only if state is not 0
    if not state.value:
        sys.exit()

    refresh_delay = 1.0 / cfg.REFRESH_RATE

    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)

    # key codes as returned by (0xFF & cv2.waitKey())
    keys = {
        'left': 81,
        'right': 83,
        'up': 82,
        'down': 84,
        'pgup': 85,
        'pgdn': 86,
        'home': 80,
        'end': 87,
        'space': 32,
        'esc': 27,
        ',': 44,
        '.': 46,
        's': 115,
        'c': 99,
        '[': 91,
        ']': 93,
        '1': 49,
        '!': 33,
        '2': 50,
        '@': 64,
        '3': 51,
        '#': 35
    }

    # direction code -> prefix of the trigger names (<PREFIX>_READY / _GO)
    trigger_prefix = {
        'L': 'LEFT',
        'R': 'RIGHT',
        'U': 'UP',
        'D': 'DOWN',
        'B': 'BOTH'
    }
    # direction code -> bars to animate ('B' moves both left and right)
    bar_sides = {
        'L': ('L',),
        'R': ('R',),
        'U': ('U',),
        'D': ('D',),
        'B': ('L', 'R')
    }

    def randomized(base, spread):
        # cfg.TIMINGS[base] plus a uniform offset within +/- cfg.TIMINGS[spread]
        return cfg.TIMINGS[base] + random.uniform(-cfg.TIMINGS[spread],
                                                  cfg.TIMINGS[spread])

    dir_sequence = []
    for _ in range(cfg.TRIALS_EACH):
        dir_sequence.extend(cfg.DIRECTIONS)
    random.shuffle(dir_sequence)
    num_trials = len(cfg.DIRECTIONS) * cfg.TRIALS_EACH

    event = 'start'
    trial = 1
    direction = None  # set when the cue phase of a trial ends

    # Hardware trigger
    if cfg.TRIGGER_DEVICE is None:
        logger.warning(
            'No trigger device set. Press Ctrl+C to stop or Enter to continue.'
        )
        #input()
    trigger = pyLptControl.Trigger(state, cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        logger.error(
            '\n** Error connecting to USB2LPT device. Use a mock trigger instead?'
        )
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.MockTrigger()
        trigger.init(50)

    # timers
    timer_trigger = qc.Timer()
    timer_dir = qc.Timer()
    timer_refresh = qc.Timer()
    t_dir = randomized('DIR', 'DIR_RANDOMIZE')
    t_dir_ready = randomized('READY', 'READY_RANDOMIZE')

    bar = BarVisual(cfg.GLASS_USE,
                    screen_pos=cfg.SCREEN_POS,
                    screen_size=cfg.SCREEN_SIZE)
    bar.fill()
    bar.glass_draw_cue()

    # start
    while trial <= num_trials:
        timer_refresh.sleep_atleast(refresh_delay)
        timer_refresh.reset()

        # segment= { 'cue':(s,e), 'dir':(s,e), 'label':0-4 } (zero-based)
        if event == 'start' and timer_trigger.sec() > cfg.TIMINGS['INIT']:
            event = 'gap_s'
            bar.fill()
            timer_trigger.reset()
            trigger.signal(cfg.tdef.INIT)
        elif event == 'gap_s':
            if cfg.TRIAL_PAUSE:
                bar.put_text('Press any key')
                bar.update()
                # mask to the low byte, as done for the key check at the end
                # of the loop, so that ESC compares equal to keys['esc']
                key = 0xFF & cv2.waitKey()
                if key == keys['esc'] or not state.value:
                    break
                bar.fill()
            bar.put_text('Trial %d / %d' % (trial, num_trials))
            event = 'gap'
            timer_trigger.reset()
        elif event == 'gap' and timer_trigger.sec() > cfg.TIMINGS['GAP']:
            event = 'cue'
            bar.fill()
            bar.draw_cue()
            trigger.signal(cfg.tdef.CUE)
            timer_trigger.reset()
        elif event == 'cue' and timer_trigger.sec() > cfg.TIMINGS['CUE']:
            event = 'dir_r'
            direction = dir_sequence[trial - 1]
            if direction not in trigger_prefix:
                # %s, not %d: directions are strings, so %d would raise a
                # TypeError in place of this error
                raise RuntimeError('Unknown direction %s' % direction)
            for side in bar_sides[direction]:
                bar.move(side, 100, overlay=True)
            trigger.signal(
                getattr(cfg.tdef, trigger_prefix[direction] + '_READY'))
            timer_trigger.reset()
        elif event == 'dir_r' and timer_trigger.sec() > t_dir_ready:
            bar.fill()
            bar.draw_cue()
            event = 'dir'
            timer_trigger.reset()
            timer_dir.reset()
            # direction was validated when the cue phase ended
            trigger.signal(
                getattr(cfg.tdef, trigger_prefix[direction] + '_GO'))
        elif event == 'dir' and timer_trigger.sec() > t_dir:
            event = 'gap_s'
            bar.fill()
            trial += 1
            logger.info('trial ' + str(trial - 1) + ' done')
            trigger.signal(cfg.tdef.BLANK)
            timer_trigger.reset()
            # draw fresh randomized durations for the next trial
            t_dir = randomized('DIR', 'DIR_RANDOMIZE')
            t_dir_ready = randomized('READY', 'READY_RANDOMIZE')

        # protocol: grow the bar(s) with the time elapsed in the 'dir' event
        if event == 'dir':
            dx = min(100, int(100.0 * timer_dir.sec() / t_dir) + 1)
            for side in bar_sides[direction]:
                bar.move(side, dx, overlay=True)

        # wait for start
        if event == 'start':
            bar.put_text('Waiting to start')

        bar.update()
        key = 0xFF & cv2.waitKey(1)
        if key == keys['esc'] or not state.value:
            break

    bar.finish()

    with state.get_lock():
        state.value = 0
Пример #19
0
def print_perc(percent):
    """
    Render percent / 10 followed by '%' as white text and blit it onto
    `screen`, centred at (w // 4, h // 4).

    NOTE(review): everything from the "LSL stream connection" banner onwards
    is indented as part of this function, but it reads like the body of a
    separate protocol function (it uses cfg and state, which are not
    parameters here, and calls print_perc itself). Its `def` line appears to
    be missing - confirm against the original file.
    """
    white = (255, 255, 255)
    # create a font object.
    # 1st parameter is the font file
    # which is present in pygame.
    # 2nd parameter is size of the font
    font = pygame.font.Font('freesansbold.ttf', 22)

    # create a text surface object,
    # on which the text is drawn.
    text = font.render(str(percent / 10) + '%', True, white)

    # create a rectangular object for the
    # text surface object
    textRect = text.get_rect()

    # set the center of the rectangular object.
    # NOTE(review): w, h and screen are assigned further down in this same
    # block, so as written they are read here before being assigned -
    # presumably they were meant to come from the enclosing scope; confirm.
    textRect.center = (w // 4, h // 4)

    screen.blit(text, textRect)

    #----------------------------------------------------------------------
    # LSL stream connection
    #----------------------------------------------------------------------
    # choose amp
    amp_name, amp_serial = find_lsl_stream(cfg, state)

    # Connect to lsl stream
    sr = connect_lsl_stream(cfg, amp_name, amp_serial)

    # Get sampling rate
    sfreq = sr.get_sample_rate()

    # Get trigger channel
    trg_ch = sr.get_trigger_channel()

    #----------------------------------------------------------------------
    # Main
    #----------------------------------------------------------------------
    global_timer = qc.Timer(autoreset=False)
    internal_timer = qc.Timer(autoreset=True)

    # Loop while state is 1 and cfg.GLOBAL_TIME seconds have not elapsed
    while state.value == 1 and global_timer.sec() < cfg.GLOBAL_TIME:

        #----------------------------------------------------------------------
        # Data acquisition
        #----------------------------------------------------------------------
        sr.acquire()
        raw, tslist = sr.get_window()  # [samples x channels]
        raw = raw.T  # [channels x samples]

        # Check if proper real-time acquisition
        # NOTE(review): last_ts is never assigned in the visible code -
        # confirm it is defined (and updated) elsewhere.
        tsnew = np.where(np.array(tslist) > last_ts)[0]
        if len(tsnew) == 0:
            logger.warning('There seems to be delay in receiving data.')
            time.sleep(1)
            continue

        #----------------------------------------------------------------------
        # Data processing
        #----------------------------------------------------------------------

        # Compute the GFP
        gfp = np.abs(detrend(np.mean(raw, 0)))  # [1 x samples]

        # Find GFP's peak
        gfp_peaks = argrelextrema(gfp, np.greater,
                                  order=15)  # Order needs to be optimized

        # Assign dominant microstate
        count = 0

        # NOTE(review): the template file is re-read on every loop iteration;
        # it could presumably be loaded once before the loop.
        micro_template = np.loadtxt("./Maps_4states_s2.txt", dtype=float)

        #  Missing first and last microstate --> TO change
        for p in range(1, len(gfp_peaks[0]) - 1):
            # NOTE(review): np.array() needs an argument and ndarrays have no
            # append(); this looks like it was meant to be a plain list.
            correletion = np.array()

            for i in range(len(micro_template)):
                # NOTE(review): p indexes gfp_peaks[0], not the samples of
                # raw; presumably raw[:, gfp_peaks[0][p]] was intended.
                # np.corrcoef of two vectors returns a 2x2 matrix - confirm
                # which entry should be compared.
                correletion.append(np.corrcoef(raw[:, p], micro_template[i]))

            if np.argmax(correletion) == cfg.MICRO2REGULATE:
                # Segment runs from the midpoint to the previous peak up to
                # the midpoint to the next peak.
                # NOTE(review): "/ 2" is true division under Python 3, so
                # start/end may be floats, which range() does not accept.
                start = gfp_peaks[0][p - 1] + (gfp_peaks[0][p] -
                                               gfp_peaks[0][p - 1]) / 2
                end = gfp_peaks[0][p] + (gfp_peaks[0][p + 1] -
                                         gfp_peaks[0][p]) / 2
                count = count + len(range(start, end))

        # Percentage of the microstate of interest
        percent = count / raw.shape[1] * 100

        # Feedback
        # NOTE(review): pygame is initialised and both images are reloaded on
        # every loop iteration, and x, y, x1, y1 are reset each time, so the
        # scrolling offsets do not accumulate across iterations - confirm
        # whether this setup was meant to run once before the loop.

        pygame.init()

        pygame.display.set_caption('EEG microstate')

        background = pygame.image.load('stars2.png')

        background_size = background.get_size()
        background_rect = background.get_rect()
        screen = pygame.display.set_mode(background_size)
        w, h = background_size

        # Offsets of the two stacked copies of the background
        x = 0
        y = 0

        x1 = 0
        y1 = -h

        ship = pygame.image.load("space.png")
        shiprect = ship.get_rect()
        shiprect.center = (w // 2, h // 2)
        running = True
        i = 0

        # set the pygame window name
        pygame.display.set_caption('EEG microstate')
        screen.blit(background, background_rect)

        # NOTE(review): running is set to False on QUIT but is not read by
        # the loop condition.
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
        # Shift both background copies down by `percent` pixels
        y1 += percent
        y += percent
        screen.blit(background, (x, y))
        screen.blit(background, (x1, y1))
        # Wrap a copy back above the screen once it has scrolled past it
        if y > h:
            y = -h
        if y1 > h:
            y1 = -h

        screen.blit(ship, shiprect)
        print_perc(percent)
        pygame.display.flip()
        pygame.display.update()
Пример #20
0
def cross_validate(cfg, featdata, cv_file=None):
    """
    Perform cross validation, log a text report and optionally export it.

    Input:
        cfg: config object; reads CLASSIFIER, CV_PERFORM, CV, N_JOBS, tdef,
             SP_FILTER, SP_CHANNELS, TP_FILTER, NOTCH_FILTER, FEATURES,
             EPOCH, EXPORT_CLS and DATA_PATH
        featdata: dict with 'X_data' (3-dimensional: trials x samples x
                  features), 'Y_data', 'wlen', 'sfreq', 'ch_names', 'picks'
        cv_file: path of the report file; if None, cv_result.txt is written
                 under cfg.DATA_PATH (in a 'classifier' sub-directory when
                 cfg.EXPORT_CLS is True)

    Raises:
        ValueError: unknown classifier type.
        NotImplementedError: unsupported cross-validation type.
    """
    # Init a classifier
    selected_classifier = cfg.CLASSIFIER['selected']
    if selected_classifier == 'GB':
        cls = GradientBoostingClassifier(
            loss='deviance',
            learning_rate=cfg.CLASSIFIER['GB']['learning_rate'],
            presort='auto',
            n_estimators=cfg.CLASSIFIER['GB']['trees'],
            subsample=1.0,
            max_depth=cfg.CLASSIFIER['GB']['depth'],
            random_state=cfg.CLASSIFIER['GB']['seed'],
            max_features='sqrt',
            verbose=0,
            warm_start=False)
    elif selected_classifier == 'XGB':
        cls = XGBClassifier(
            loss='deviance',
            learning_rate=cfg.CLASSIFIER['XGB']['learning_rate'],
            presort='auto',
            n_estimators=cfg.CLASSIFIER['XGB']['trees'],
            subsample=1.0,
            max_depth=cfg.CLASSIFIER['XGB']['depth'],
            # the seed entry, not the whole XGB settings dict
            random_state=cfg.CLASSIFIER['XGB']['seed'],
            max_features='sqrt',
            verbose=0,
            warm_start=False)
    elif selected_classifier == 'RF':
        cls = RandomForestClassifier(
            n_estimators=cfg.CLASSIFIER['RF']['trees'],
            max_features='auto',
            max_depth=cfg.CLASSIFIER['RF']['depth'],
            n_jobs=cfg.N_JOBS,
            random_state=cfg.CLASSIFIER['RF']['seed'],
            oob_score=False,
            class_weight='balanced_subsample')
    elif selected_classifier == 'LDA':
        cls = LDA()
    elif selected_classifier == 'rLDA':
        cls = rLDA(cfg.CLASSIFIER['rLDA']['r_coeff'])
    else:
        logger.error('Unknown classifier type %s' % selected_classifier)
        raise ValueError

    # Setup features
    X_data = featdata['X_data']
    Y_data = featdata['Y_data']
    wlen = featdata['wlen']

    # Choose CV type
    ntrials, nsamples, fsize = X_data.shape
    selected_cv = cfg.CV_PERFORM['selected']
    if selected_cv == 'LeaveOneOut':
        logger.info_green('%d-fold leave-one-out cross-validation' % ntrials)
        if SKLEARN_OLD:
            cv = LeaveOneOut(len(Y_data))
        else:
            cv = LeaveOneOut()
    elif selected_cv == 'StratifiedShuffleSplit':
        logger.info_green(
            '%d-fold stratified cross-validation with test set ratio %.2f' %
            (cfg.CV_PERFORM[selected_cv]['folds'],
             cfg.CV_PERFORM[selected_cv]['test_ratio']))
        if SKLEARN_OLD:
            cv = StratifiedShuffleSplit(
                Y_data[:, 0],
                cfg.CV_PERFORM[selected_cv]['folds'],
                test_size=cfg.CV_PERFORM[selected_cv]['test_ratio'],
                random_state=cfg.CV_PERFORM[selected_cv]['seed'])
        else:
            cv = StratifiedShuffleSplit(
                n_splits=cfg.CV_PERFORM[selected_cv]['folds'],
                test_size=cfg.CV_PERFORM[selected_cv]['test_ratio'],
                random_state=cfg.CV_PERFORM[selected_cv]['seed'])
    else:
        # Report the selected name itself; looking it up in cfg.CV_PERFORM
        # could fail for exactly the unsupported names handled here.
        logger.error('%s is not supported yet. Sorry.' % selected_cv)
        raise NotImplementedError
    logger.info('%d trials, %d samples per trial, %d feature dimension' %
                (ntrials, nsamples, fsize))

    # Settings of the selected CV scheme
    cv_cfg = cfg.CV_PERFORM[selected_cv]

    # Do it!
    timer_cv = qc.Timer()
    scores, cm_txt = crossval_epochs(cv,
                                     X_data,
                                     Y_data,
                                     cls,
                                     cfg.tdef.by_value,
                                     cfg.CV['BALANCE_SAMPLES'],
                                     n_jobs=cfg.N_JOBS,
                                     ignore_thres=cfg.CV['IGNORE_THRES'],
                                     decision_thres=cfg.CV['DECISION_THRES'])
    t_cv = timer_cv.sec()

    # Export results
    txt = 'Cross validation took %d seconds.\n' % t_cv
    txt += '\n- Class information\n'
    txt += '%d epochs, %d samples per epoch, %d feature dimension (total %d samples)\n' %\
        (ntrials, nsamples, fsize, ntrials * nsamples)
    for ev in np.unique(Y_data):
        txt += '%s: %d trials\n' % (cfg.tdef.by_value[ev],
                                    len(np.where(Y_data[:, 0] == ev)[0]))
    if cfg.CV['BALANCE_SAMPLES']:
        # read the same setting that is tested above and passed to
        # crossval_epochs
        txt += 'The number of samples was balanced using %ssampling.\n' % \
            cfg.CV['BALANCE_SAMPLES'].lower()
    txt += '\n- Experiment condition\n'
    txt += 'Sampling frequency: %.3f Hz\n' % featdata['sfreq']
    txt += 'Spatial filter: %s (channels: %s)\n' % (cfg.SP_FILTER,
                                                    cfg.SP_CHANNELS)
    txt += 'Spectral filter: %s\n' % cfg.TP_FILTER[cfg.TP_FILTER['selected']]
    txt += 'Notch filter: %s\n' % cfg.NOTCH_FILTER[
        cfg.NOTCH_FILTER['selected']]
    txt += 'Channels: ' + ','.join(
        [str(featdata['ch_names'][p]) for p in featdata['picks']]) + '\n'
    txt += 'PSD range: %.1f - %.1f Hz\n' % (cfg.FEATURES['PSD']['fmin'],
                                            cfg.FEATURES['PSD']['fmax'])
    txt += 'Window step: %.2f msec\n' % (
        1000.0 * cfg.FEATURES['PSD']['wstep'] / featdata['sfreq'])
    if isinstance(wlen, list):
        for i, w in enumerate(wlen):
            txt += 'Window size: %.1f msec\n' % (w * 1000.0)
            txt += 'Epoch range: %s sec\n' % (cfg.EPOCH[i])
    else:
        txt += 'Window size: %.1f msec\n' % (cfg.FEATURES['PSD']['wlen'] *
                                             1000.0)
        txt += 'Epoch range: %s sec\n' % (cfg.EPOCH)
    txt += 'Decimation factor: %d\n' % cfg.FEATURES['PSD']['decim']

    # Compute stats
    cv_mean, cv_std = np.mean(scores), np.std(scores)
    # a CV scheme configured without a seed is reported as None instead of
    # failing on the missing key
    cv_seed = cv_cfg['seed'] if 'seed' in cv_cfg else None
    txt += '\n- Average CV accuracy over %d epochs (random seed=%s)\n' % (
        ntrials, cv_seed)
    # compare the scheme name (not its settings dict) so the line is emitted
    if selected_cv in ['LeaveOneOut', 'StratifiedShuffleSplit']:
        txt += "mean %.3f, std: %.3f\n" % (cv_mean, cv_std)
    txt += 'Classifier: %s, ' % selected_classifier
    if selected_classifier == 'RF':
        txt += '%d trees, %s max depth, random state %s\n' % (
            cfg.CLASSIFIER['RF']['trees'], cfg.CLASSIFIER['RF']['depth'],
            cfg.CLASSIFIER['RF']['seed'])
    elif selected_classifier == 'GB' or selected_classifier == 'XGB':
        # report the settings of the classifier that was actually built
        boost_cfg = cfg.CLASSIFIER[selected_classifier]
        txt += '%d trees, %s max depth, %s learing_rate, random state %s\n' % (
            boost_cfg['trees'], boost_cfg['depth'],
            boost_cfg['learning_rate'], boost_cfg['seed'])
    elif selected_classifier == 'rLDA':
        txt += 'regularization coefficient %.2f\n' % cfg.CLASSIFIER['rLDA'][
            'r_coeff']
    # NOTE(review): the label says "Decision threshold" but the value printed
    # is IGNORE_THRES (DECISION_THRES also exists) - confirm which is meant.
    if cfg.CV['IGNORE_THRES'] is not None:
        txt += 'Decision threshold: %.2f\n' % cfg.CV['IGNORE_THRES']
    txt += '\n- Confusion Matrix\n' + cm_txt
    logger.info(txt)

    # Export to a file
    if 'export_result' in cv_cfg and cv_cfg['export_result'] is True:
        if cv_file is not None:
            out_path = cv_file
        elif cfg.EXPORT_CLS is True:
            qc.make_dirs('%s/classifier' % cfg.DATA_PATH)
            out_path = '%s/classifier/cv_result.txt' % cfg.DATA_PATH
        else:
            out_path = '%s/cv_result.txt' % cfg.DATA_PATH
        # the context manager closes the file even if the write fails
        with open(out_path, 'w') as fout:
            fout.write(txt)
Пример #21
0
    def classify(self,
                 decoder,
                 true_label,
                 title_text,
                 bar_dirs,
                 state='start',
                 prob_history=None):
        """
        Run a single trial
        """
        true_label_index = bar_dirs.index(true_label)
        self.tm_trigger.reset()
        if self.bar_bias is not None:
            bias_idx = bar_dirs.index(self.bar_bias[0])

        if self.logf is not None:
            self.logf.write('True label: %s\n' % true_label)

        tm_classify = qc.Timer(autoreset=True)
        self.stimo_timer = qc.Timer()
        while True:
            self.tm_display.sleep_atleast(self.refresh_delay)
            self.tm_display.reset()
            if state == 'start' and self.tm_trigger.sec(
            ) > self.cfg.TIMINGS['INIT']:
                state = 'gap_s'
                if self.cfg.TRIALS_PAUSE:
                    self.viz.put_text('Press any key')
                    self.viz.update()
                    key = cv2.waitKeyEx()
                    if key == KEYS['esc']:
                        return
                self.viz.fill()
                self.tm_trigger.reset()
                self.trigger.signal(self.tdef.INIT)

            elif state == 'gap_s':
                if self.cfg.TIMINGS['GAP'] > 0:
                    self.viz.put_text(title_text)
                state = 'gap'
                self.tm_trigger.reset()

            elif state == 'gap' and self.tm_trigger.sec(
            ) > self.cfg.TIMINGS['GAP']:
                state = 'cue'
                self.viz.fill()
                self.viz.draw_cue()
                self.viz.glass_draw_cue()
                self.trigger.signal(self.tdef.CUE)
                self.tm_trigger.reset()

            elif state == 'cue' and self.tm_trigger.sec(
            ) > self.cfg.TIMINGS['READY']:
                state = 'dir_r'

                if self.cfg.SHOW_CUE is True:
                    if self.cfg.FEEDBACK_TYPE == 'BAR':
                        self.viz.move(true_label,
                                      100,
                                      overlay=False,
                                      barcolor='G')
                    elif self.cfg.FEEDBACK_TYPE == 'BODY':
                        self.viz.put_text(DIRS[true_label], 'R')
                    if true_label == 'L':  # left
                        self.trigger.signal(self.tdef.LEFT_READY)
                    elif true_label == 'R':  # right
                        self.trigger.signal(self.tdef.RIGHT_READY)
                    elif true_label == 'U':  # up
                        self.trigger.signal(self.tdef.UP_READY)
                    elif true_label == 'D':  # down
                        self.trigger.signal(self.tdef.DOWN_READY)
                    elif true_label == 'B':  # both hands
                        self.trigger.signal(self.tdef.BOTH_READY)
                    else:
                        raise RuntimeError('Unknown direction %s' % true_label)
                self.tm_trigger.reset()
                '''
                if self.cfg.FEEDBACK_TYPE == 'BODY':
                    self.viz.set_pc_feedback(False)
                self.viz.move(true_label, 100, overlay=False, barcolor='G')

                if self.cfg.FEEDBACK_TYPE == 'BODY':
                    self.viz.set_pc_feedback(True)
                    if self.cfg.SHOW_CUE is True:
                        self.viz.put_text(dirs[true_label], 'R')

                if true_label == 'L':  # left
                    self.trigger.signal(self.tdef.LEFREADY)
                elif true_label == 'R':  # right
                    self.trigger.signal(self.tdef.RIGHT_READY)
                elif true_label == 'U':  # up
                    self.trigger.signal(self.tdef.UP_READY)
                elif true_label == 'D':  # down
                    self.trigger.signal(self.tdef.DOWN_READY)
                elif true_label == 'B':  # both hands
                    self.trigger.signal(self.tdef.BOTH_READY)
                else:
                    raise RuntimeError('Unknown direction %s' % true_label)
                self.tm_trigger.reset()
                '''
            elif state == 'dir_r' and self.tm_trigger.sec(
            ) > self.cfg.TIMINGS['DIR_CUE']:
                self.viz.fill()
                self.viz.draw_cue()
                self.viz.glass_draw_cue()
                state = 'dir'

                # initialize bar scores
                bar_label = bar_dirs[0]
                bar_score = 0
                probs = [1.0 / len(bar_dirs)] * len(bar_dirs)
                self.viz.move(bar_label, bar_score, overlay=False)
                probs_acc = np.zeros(len(probs))

                if true_label == 'L':  # left
                    self.trigger.signal(self.tdef.LEFT_GO)
                elif true_label == 'R':  # right
                    self.trigger.signal(self.tdef.RIGHT_GO)
                elif true_label == 'U':  # up
                    self.trigger.signal(self.tdef.UP_GO)
                elif true_label == 'D':  # down
                    self.trigger.signal(self.tdef.DOWN_GO)
                elif true_label == 'B':  # both
                    self.trigger.signal(self.tdef.BOTH_GO)
                else:
                    raise RuntimeError('Unknown truedirection %s' % true_label)

                self.tm_watchdog.reset()
                self.tm_trigger.reset()

            elif state == 'dir':
                if self.tm_trigger.sec() > self.cfg.TIMINGS['CLASSIFY'] or (
                        self.premature_end and bar_score >= 100):
                    if not hasattr(
                            self.cfg,
                            'SHOW_RESULT') or self.cfg.SHOW_RESULT is True:
                        # show classfication result
                        if self.cfg.WITH_STIMO is True:
                            if self.cfg.STIMO_FULLGAIT_CYCLE is not None and bar_label == 'U':
                                res_color = 'G'
                            elif self.cfg.TRIALS_RETRY is False or bar_label == true_label:
                                res_color = 'G'
                            else:
                                res_color = 'Y'
                        else:
                            res_color = 'Y'
                        if self.cfg.FEEDBACK_TYPE == 'BODY':
                            self.viz.move(bar_label,
                                          bar_score,
                                          overlay=False,
                                          barcolor=res_color,
                                          caption=DIRS[bar_label],
                                          caption_color=res_color)
                        else:
                            self.viz.move(bar_label,
                                          100,
                                          overlay=False,
                                          barcolor=res_color)
                    else:
                        if self.cfg.FEEDBACK_TYPE == 'BODY':
                            self.viz.move(bar_label,
                                          bar_score,
                                          overlay=False,
                                          barcolor=res_color,
                                          caption='TRIAL END',
                                          caption_color=res_color)
                        else:
                            self.viz.move(bar_label,
                                          0,
                                          overlay=False,
                                          barcolor=res_color)
                    self.trigger.signal(self.tdef.FEEDBACK)

                    # STIMO
                    if self.cfg.WITH_STIMO is True and self.cfg.STIMO_CONTINUOUS is False:
                        if self.cfg.STIMO_FULLGAIT_CYCLE is not None:
                            if bar_label == 'U':
                                self.ser.write(
                                    self.cfg.STIMO_FULLGAIT_PATTERN[0])
                                logger.info('STIMO: Sent 1')
                                time.sleep(self.cfg.STIMO_FULLGAIT_CYCLE)
                                self.ser.write(
                                    self.cfg.STIMO_FULLGAIT_PATTERN[1])
                                logger.info('STIMO: Sent 2')
                                time.sleep(self.cfg.STIMO_FULLGAIT_CYCLE)
                        elif self.cfg.TRIALS_RETRY is False or bar_label == true_label:
                            if bar_label == 'L':
                                self.ser.write(b'1')
                                logger.info('STIMO: Sent 1')
                            elif bar_label == 'R':
                                self.ser.write(b'2')
                                logger.info('STIMO: Sent 2')

                    # FES event mode mode
                    if self.cfg.WITH_FES is True and self.cfg.FES_CONTINUOUS is False:
                        if bar_label == 'L':
                            stim_code = [0, 30, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)
                            logger.info('FES: Sent Left')
                            time.sleep(0.5)
                            stim_code = [0, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)

                        elif bar_label == 'R':
                            stim_code = [30, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)
                            time.sleep(0.5)
                            logger.info('FES: Sent Right')
                            stim_code = [0, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)

                    if self.cfg.DEBUG_PROBS:
                        msg = 'DEBUG: Accumulated probabilities = %s' % qc.list2string(
                            probs_acc, '%.3f')
                        logger.info(msg)
                        if self.logf is not None:
                            self.logf.write(msg + '\n')
                    if self.logf is not None:
                        self.logf.write('%s detected as %s (%d)\n\n' %
                                        (true_label, bar_label, bar_score))
                        self.logf.flush()

                    # end of trial
                    state = 'feedback'
                    self.tm_trigger.reset()
                else:
                    # classify
                    probs_new = decoder.get_prob_smooth_unread()
                    if probs_new is None:
                        if self.tm_watchdog.sec() > 3:
                            logger.warning(
                                'No classification being done. Are you receiving data streams?'
                            )
                            self.tm_watchdog.reset()
                    else:
                        self.tm_watchdog.reset()

                        if prob_history is not None:
                            prob_history[true_label].append(
                                probs_new[true_label_index])

                        probs_acc += np.array(probs_new)
                        '''
                        New decoder: already smoothed by the decoder so bias after.
                        '''
                        probs = list(probs_new)
                        if self.bar_bias is not None:
                            probs[bias_idx] += self.bar_bias[1]
                            newsum = sum(probs)
                            probs = [p / newsum for p in probs]
                        '''
                        # Method 2: bias and smoothen
                        if self.bar_bias is not None:
                            # print('BEFORE: %.3f %.3f'% (probs_new[0], probs_new[1]) )
                            probs_new[bias_idx] += self.bar_bias[1]
                            newsum = sum(probs_new)
                            probs_new = [p / newsum for p in probs_new]
                        # print('AFTER: %.3f %.3f'% (probs_new[0], probs_new[1]) )
                        for i in range(len(probs_new)):
                            probs[i] = probs[i] * self.alpha_old + probs_new[i] * self.alpha_new
                        '''
                        ''' Original method
                        # Method 1: smoothen and bias
                        for i in range( len(probs_new) ):
                            probs[i] = probs[i] * self.alpha_old + probs_new[i] * self.alpha_new
                        # bias bar
                        if self.bar_bias is not None:
                            probs[bias_idx] += self.bar_bias[1]
                            newsum = sum(probs)
                            probs = [p/newsum for p in probs]
                        '''

                        # determine the direction
                        # TODO: np.argmax(probs)
                        max_pidx = qc.get_index_max(probs)
                        max_label = bar_dirs[max_pidx]

                        if self.cfg.POSITIVE_FEEDBACK is False or \
                                (self.cfg.POSITIVE_FEEDBACK and true_label == max_label):
                            dx = probs[max_pidx]
                            if max_label == 'R':
                                dx *= self.bar_step_right
                            elif max_label == 'L':
                                dx *= self.bar_step_left
                            elif max_label == 'U':
                                dx *= self.bar_step_up
                            elif max_label == 'D':
                                dx *= self.bar_step_down
                            elif max_label == 'B':
                                dx *= self.bar_step_both
                            else:
                                logger.debug('Direction %s using bar step %d' %
                                             (max_label, self.bar_step_left))
                                dx *= self.bar_step_left

                            # slow start
                            selected = self.cfg.BAR_SLOW_START['selected']
                            if self.cfg.BAR_SLOW_START[
                                    selected] and self.tm_trigger.sec(
                                    ) < self.cfg.BAR_SLOW_START[selected]:
                                dx *= self.tm_trigger.sec(
                                ) / self.cfg.BAR_SLOW_START[selected][0]

                            # add likelihoods
                            if max_label == bar_label:
                                bar_score += dx
                            else:
                                bar_score -= dx
                                # change of direction
                                if bar_score < 0:
                                    bar_score = -bar_score
                                    bar_label = max_label
                            bar_score = int(bar_score)
                            if bar_score > 100:
                                bar_score = 100
                            if self.cfg.FEEDBACK_TYPE == 'BODY':
                                if self.cfg.SHOW_CUE:
                                    self.viz.move(bar_label,
                                                  bar_score,
                                                  overlay=False,
                                                  caption=DIRS[true_label],
                                                  caption_color='G')
                                else:
                                    self.viz.move(bar_label,
                                                  bar_score,
                                                  overlay=False)
                            else:
                                self.viz.move(bar_label,
                                              bar_score,
                                              overlay=False)

                            # send the confidence value continuously
                            if self.cfg.WITH_STIMO and self.cfg.STIMO_CONTINUOUS:
                                if self.stimo_timer.sec(
                                ) >= self.cfg.STIMO_COOLOFF:
                                    if bar_label == 'U':
                                        stimo_code = bar_score
                                    else:
                                        stimo_code = 0
                                    self.ser.write(bytes([stimo_code]))
                                    logger.info('Sent STIMO code %d' %
                                                stimo_code)
                                    self.stimo_timer.reset()
                            # with FES
                            if self.cfg.WITH_FES is True and self.cfg.FES_CONTINUOUS is True:
                                if self.stimo_timer.sec(
                                ) >= self.cfg.STIMO_COOLOFF:
                                    if bar_label == 'L':
                                        stim_code = [
                                            bar_score, 0, 0, 0, 0, 0, 0, 0
                                        ]
                                    else:
                                        stim_code = [
                                            0, bar_score, 0, 0, 0, 0, 0, 0
                                        ]
                                    self.stim.UpdateChannelSettings(stim_code)
                                    logger.info('Sent FES code %d' % bar_score)
                                    self.stimo_timer.reset()

                        if self.cfg.DEBUG_PROBS:
                            if self.bar_bias is not None:
                                biastxt = '[Bias=%s%.3f]  ' % (
                                    self.bar_bias[0], self.bar_bias[1])
                            else:
                                biastxt = ''
                            msg = '%s%s  prob %s   acc %s   bar %s%d  (%.1f ms)' % \
                                  (biastxt, bar_dirs, qc.list2string(probs_new, '%.2f'), qc.list2string(probs, '%.2f'),
                                   bar_label, bar_score, tm_classify.msec())
                            logger.info(msg)
                            if self.logf is not None:
                                self.logf.write(msg + '\n')

            elif state == 'feedback' and self.tm_trigger.sec(
            ) > self.cfg.TIMINGS['FEEDBACK']:
                self.trigger.signal(self.tdef.BLANK)
                if self.cfg.FEEDBACK_TYPE == 'BODY':
                    state = 'return'
                    self.tm_trigger.reset()
                else:
                    state = 'gap_s'
                    self.viz.fill()
                    self.viz.update()
                    return bar_label

            elif state == 'return':
                self.viz.set_glass_feedback(False)
                if self.cfg.WITH_STIMO:
                    self.viz.move(bar_label,
                                  bar_score,
                                  overlay=False,
                                  barcolor='B')
                else:
                    self.viz.move(bar_label,
                                  bar_score,
                                  overlay=False,
                                  barcolor='Y')
                self.viz.set_glass_feedback(True)
                bar_score -= 5
                if bar_score <= 0:
                    state = 'gap_s'
                    self.viz.fill()
                    self.viz.update()
                    return bar_label

            self.viz.update()
            key = cv2.waitKeyEx(1)
            if key == KEYS['esc']:
                return
            elif key == KEYS['space']:
                dx = 0
                bar_score = 0
                probs = [1.0 / len(bar_dirs)] * len(bar_dirs)
                self.viz.move(bar_dirs[0], bar_score, overlay=False)
                self.viz.update()
                logger.info('probs and dx reset.')
                self.tm_trigger.reset()
            elif key in ARROW_KEYS and ARROW_KEYS[key] in bar_dirs:
                # change bias on the fly
                if self.bar_bias is None:
                    self.bar_bias = [ARROW_KEYS[key], BIAS_INCREMENT]
                else:
                    if ARROW_KEYS[key] == self.bar_bias[0]:
                        self.bar_bias[1] += BIAS_INCREMENT
                    elif self.bar_bias[1] >= BIAS_INCREMENT:
                        self.bar_bias[1] -= BIAS_INCREMENT
                    else:
                        self.bar_bias = [ARROW_KEYS[key], BIAS_INCREMENT]
                if self.bar_bias[1] == 0:
                    self.bar_bias = None
                else:
                    bias_idx = bar_dirs.index(self.bar_bias[0])
# Example #22
# 0
def train_decoder(cfg, featdata, feat_file=None):
    """
    Train the final decoder using all data, export it, and report the most
    important features.

    The classifier named by cfg.CLASSIFIER['selected'] is fitted on the
    concatenation of featdata['X_data'] / featdata['Y_data'] (optionally
    balanced), then saved with qc.save_obj() to
    '<cfg.DATA_PATH>/classifier/classifier-<architecture>.pkl'.
    The top cfg.FEAT_TOPN features are logged and, if
    cfg.EXPORT_GOOD_FEATURES is set, all of them are written to a text file.

    Parameters
    ----------
    cfg : config object
        Reads CLASSIFIER, FEATURES, CV, N_JOBS, tdef, EPOCH, SP_FILTER,
        TP_FILTER, NOTCH_FILTER, MULTIPLIER, REREFERENCE, DATA_PATH,
        FEAT_TOPN and EXPORT_GOOD_FEATURES.
        Side effect: cfg.FEATURES['PSD']['wlen'] is overwritten with
        featdata['wlen'] when it is None.
    featdata : dict
        Must provide 'X_data', 'Y_data', 'wlen', 'w_frames', 'ch_names',
        'psde', 'sfreq' and 'picks'.
        NOTE(review): 'X_data'/'Y_data' are assumed to be sequences of
        per-run arrays that np.concatenate() can merge -- confirm with caller.
    feat_file : str or None
        Path of the good-features text file. When None,
        '<cfg.DATA_PATH>/classifier/good_features.txt' is used.

    Raises
    ------
    ValueError
        If the selected classifier is unknown, or if the selected feature
        type is not 'PSD' (the only type this function knows how to export).
    """
    # Init a classifier
    selected_classifier = cfg.CLASSIFIER['selected']
    if selected_classifier == 'GB':
        params = cfg.CLASSIFIER[selected_classifier]
        cls = GradientBoostingClassifier(
            loss='deviance',
            learning_rate=params['learning_rate'],
            n_estimators=params['trees'],
            subsample=1.0,
            max_depth=params['depth'],
            random_state=params['seed'],
            max_features='sqrt',
            verbose=0,
            warm_start=False,
            presort='auto')
    elif selected_classifier == 'XGB':
        params = cfg.CLASSIFIER[selected_classifier]
        cls = XGBClassifier(
            loss='deviance',
            learning_rate=params['learning_rate'],
            n_estimators=params['trees'],
            subsample=1.0,
            max_depth=params['depth'],
            # The seed comes from the selected classifier's own settings,
            # like every other parameter (previously read from cfg.GB).
            random_state=params['seed'],
            max_features='sqrt',
            verbose=0,
            warm_start=False,
            presort='auto')
    elif selected_classifier == 'RF':
        params = cfg.CLASSIFIER[selected_classifier]
        cls = RandomForestClassifier(
            n_estimators=params['trees'],
            max_features='auto',
            max_depth=params['depth'],
            n_jobs=cfg.N_JOBS,
            random_state=params['seed'],
            oob_score=False,
            class_weight='balanced_subsample')
    elif selected_classifier == 'LDA':
        cls = LDA()
    elif selected_classifier == 'rLDA':
        # 'r_coeff' is a dictionary key, not a variable.
        cls = rLDA(cfg.CLASSIFIER[selected_classifier]['r_coeff'])
    else:
        logger.error('Unknown classifier %s' % selected_classifier)
        raise ValueError('Unknown classifier %s' % selected_classifier)

    # Only PSD features can be exported below; fail before the (possibly
    # long) training instead of crashing on an undefined variable afterwards.
    if cfg.FEATURES['selected'] != 'PSD':
        logger.error('Unsupported feature type %s' % cfg.FEATURES['selected'])
        raise ValueError('Unsupported feature type %s' %
                         cfg.FEATURES['selected'])

    # Setup features
    X_data = featdata['X_data']
    Y_data = featdata['Y_data']
    wlen = featdata['wlen']
    if cfg.FEATURES['PSD']['wlen'] is None:
        cfg.FEATURES['PSD']['wlen'] = wlen
    w_frames = featdata['w_frames']
    ch_names = featdata['ch_names']
    picks = featdata['picks']
    X_data_merged = np.concatenate(X_data)
    Y_data_merged = np.concatenate(Y_data)
    # Collect the label values before balancing, from the merged labels so
    # that runs of different lengths do not break np.unique().
    label_values = np.unique(Y_data_merged)
    if cfg.CV['BALANCE_SAMPLES']:
        X_data_merged, Y_data_merged = balance_samples(
            X_data_merged,
            Y_data_merged,
            cfg.CV['BALANCE_SAMPLES'],
            verbose=True)

    # Start training the decoder
    logger.info_green('Training the decoder')
    timer = qc.Timer()
    cls.n_jobs = cfg.N_JOBS
    cls.fit(X_data_merged, Y_data_merged)
    logger.info('Trained %d samples x %d dimension in %.1f sec' %\
          (X_data_merged.shape[0], X_data_merged.shape[1], timer.sec()))
    cls.n_jobs = 1  # always set n_jobs=1 for testing

    # Export the decoder
    classes = {c: cfg.tdef.by_value[c] for c in label_values}
    data = dict(cls=cls,
                ch_names=ch_names,
                psde=featdata['psde'],
                sfreq=featdata['sfreq'],
                picks=picks,
                classes=classes,
                epochs=cfg.EPOCH,
                w_frames=w_frames,
                w_seconds=cfg.FEATURES['PSD']['wlen'],
                wstep=cfg.FEATURES['PSD']['wstep'],
                spatial=cfg.SP_FILTER,
                spatial_ch=picks,
                spectral=cfg.TP_FILTER[cfg.TP_FILTER['selected']],
                spectral_ch=picks,
                notch=cfg.NOTCH_FILTER[cfg.NOTCH_FILTER['selected']],
                notch_ch=picks,
                multiplier=cfg.MULTIPLIER,
                ref_ch=cfg.REREFERENCE[cfg.REREFERENCE['selected']],
                decim=cfg.FEATURES['PSD']['decim'])
    clsfile = '%s/classifier/classifier-%s.pkl' % (cfg.DATA_PATH,
                                                   platform.architecture()[0])
    qc.make_dirs('%s/classifier' % cfg.DATA_PATH)
    qc.save_obj(clsfile, data)
    logger.info('Decoder saved to %s' % clsfile)

    # Reverse-lookup frequency from FFT: the frequency resolution is the
    # inverse of the (first) window length.
    psd_wlen = cfg.FEATURES['PSD']['wlen']
    if isinstance(psd_wlen, list):
        fq_res = 1.0 / psd_wlen[0]
    else:
        fq_res = 1.0 / psd_wlen
    fq = 0
    fqlist = []
    while fq <= cfg.FEATURES['PSD']['fmax']:
        if fq >= cfg.FEATURES['PSD']['fmin']:
            fqlist.append(fq)
        fq += fq_res

    # Show top distinctive features
    logger.info_green('Good features ordered by importance')
    if selected_classifier in ['RF', 'GB', 'XGB']:
        keys, values = qc.sort_by_value(list(cls.feature_importances_),
                                        rev=True)
    elif selected_classifier in ['LDA', 'rLDA']:
        keys, values = qc.sort_by_value(cls.w, rev=True)
    keys = np.array(keys)
    values = np.array(values)

    # Build one channel label per feature channel. The labels go into a new
    # list so that the original names stay available while it is filled.
    if not isinstance(wlen, list):
        feat_ch_names = [ch_names[c] for c in picks]
    else:
        feat_ch_names = []
        for w in range(len(wlen)):
            for c in picks:
                feat_ch_names.append('w%d-%s' % (w, ch_names[c]))

    chlist, hzlist = features.feature2chz(keys, fqlist, ch_names=feat_ch_names)
    valnorm = values[:cfg.FEAT_TOPN].copy()
    valsum = np.sum(valnorm)
    if valsum == 0:
        valsum = 1  # avoid division by zero when all importances are 0
    valnorm = valnorm / valsum * 100.0

    # show top-N features
    for i, (ch, hz) in enumerate(zip(chlist, hzlist)):
        if i >= cfg.FEAT_TOPN:
            break
        txt = '%-3s %5.1f Hz  normalized importance %-6s  raw importance %-6s  feature %-5d' %\
              (ch, hz, '%.2f%%' % valnorm[i], '%.2f%%' % (values[i] * 100.0), keys[i])
        logger.info(txt)

    if cfg.EXPORT_GOOD_FEATURES:
        if feat_file is None:
            gf_path = '%s/classifier/good_features.txt' % cfg.DATA_PATH
        else:
            gf_path = feat_file
        # The context manager closes the file even if a write fails.
        with open(gf_path, 'w') as gfout:
            gfout.write('Importance(%) Channel Frequency Index\n')
            for i, (ch, hz) in enumerate(zip(chlist, hzlist)):
                gfout.write('%.3f\t%s\t%s\t%d\n' %
                            (values[i] * 100.0, ch, hz, keys[i]))