示例#1
0
def lsl_channel_list(inlet):
    """
    Read the XML description of an LSL stream header and return its channel labels.

    Input:
        inlet: pylsl.StreamInlet object
    Returns:
        ch_list: [ name1, name2, ... ] -- text of each <label> element found under
                 <desc>/<channels> of the stream info, in document order.
    Raises:
        TypeError: if inlet is not a pylsl.StreamInlet (or subclass).
    """
    # isinstance() also accepts subclasses, unlike an exact type() comparison.
    if not isinstance(inlet, pylsl.StreamInlet):
        logger.error('lsl_channel_list(): wrong input type %s' % type(inlet))
        raise TypeError('lsl_channel_list(): wrong input type %s' % type(inlet))
    # Parse the XML dump instead of walking pylsl.XMLElement
    # (desc().child('channels').first_child() ...): that traversal may throw an
    # access violation error due to a bug in pylsl.XMLElement.
    root = ET.fromstring(inlet.info().as_xml())
    channels = root.find('desc').find('channels')
    return [ch.find('label').text for ch in channels]
示例#2
0
def log_decoding_helper(state, event_queue, amp_name=None, amp_serial=None, autostop=False):
    """
    Background worker that records trigger events with a StreamReceiver.

    Polls until state.value becomes non-zero, then keeps acquiring data while
    state.value == 1. With autostop=True, acquisition also ends (and state.value
    is reset to 0) at the first empty acquisition that follows a non-empty one.
    Finally, the samples of the first buffer channel (trigger channel) that are
    non-zero are put on event_queue as a tuple (event_times, event_values) of lists.
    """
    logger.info('Event acquisition subprocess started.')

    # Block until the parent process raises the start flag.
    while state.value == 0:
        time.sleep(0.01)

    receiver = StreamReceiver(buffer_size=0, amp_name=amp_name, amp_serial=amp_serial)
    pacer = qc.Timer(autoreset=True)
    got_data = False
    while state.value == 1:
        _, stamps = receiver.acquire()
        if autostop:
            has_samples = len(stamps) > 0
            if got_data and not has_samples:
                # the stream went silent after having delivered data
                state.value = 0
                break
            got_data = got_data or has_samples
        pacer.sleep_atleast(0.001)
    logger.info('Event acquisition subprocess finishing up ...')

    buffers, times = receiver.get_buffer()
    trigger = buffers[:, 0]  # first channel is the trigger channel
    hits = np.flatnonzero(trigger)
    event_times = times[hits].reshape(-1).tolist()
    event_values = trigger[hits].tolist()
    if len(event_times) != len(event_values):
        logger.error('event_times length (%d) is different from event_values length (%d)' % (len(event_times), len(event_values)))
    event_queue.put((event_times, event_values))
示例#3
0
def any2fif(filename, interactive=False, outdir=None, channel_file=None):
    """
    Generic file format converter: convert a supported recording to fif.

    Dispatches on the file extension:
        .pcl        -> pcl2fif (the event file <dir>/<name with 'raw' -> 'eve'>.txt
                       is passed as external_event if it exists)
        .eeg        -> eeg2fif
        .edf, .bdf  -> bdf2fif
        .gdf        -> gdf2fif (channel_file is forwarded)
        .xdf        -> xdf2fif
    Unrecognized extensions are only logged as an error.

    Params
    ------
        filename: path of the file to convert
        interactive: forwarded to the format-specific converter
        outdir: output directory, created if given; forwarded to the converter
        channel_file: only used for .gdf files
    """
    p = qc.parse_path(filename)
    if outdir is not None:
        qc.make_dirs(outdir)

    if p.ext == 'pcl':
        eve_file = '%s/%s.txt' % (p.dir, p.name.replace('raw', 'eve'))
        if os.path.exists(eve_file):
            logger.info('Adding events from %s' % eve_file)
        else:
            eve_file = None
        pcl2fif(filename,
                interactive=interactive,
                outdir=outdir,
                external_event=eve_file)
    elif p.ext == 'eeg':
        eeg2fif(filename, interactive=interactive, outdir=outdir)
    elif p.ext in ['edf', 'bdf']:
        bdf2fif(filename, interactive=interactive, outdir=outdir)
    elif p.ext == 'gdf':
        gdf2fif(filename,
                interactive=interactive,
                outdir=outdir,
                channel_file=channel_file)
    elif p.ext == 'xdf':
        xdf2fif(filename, interactive=interactive, outdir=outdir)
    else:  # unknown format
        # List every extension handled above so the hint is accurate.
        logger.error(
            'Ignored unrecognized file extension %s. It should be [.pcl | .eeg | .edf | .bdf | .gdf | .xdf]'
            % p.ext)
示例#4
0
def get_decoder_info(classifier):
    """
    Get only the classifier information without connecting to a server

    Params
    ------
        classifier: model file (loaded with qc.load_obj())

    Returns
    -------
        info dictionary with keys labels, cls, psde, w_seconds, w_frames, wstep,
        sfreq, psd_shape and psd_size. psd_shape / psd_size are the shape / size of
        psde.transform() applied to one window of len(model['picks']) channels
        x w_frames samples.

    Raises
    ------
        ValueError: if the model file could not be loaded.
    """

    model = qc.load_obj(classifier)
    if model is None:
        # Report the file that failed; `model` is None here and would log "None".
        logger.error('>> Error loading %s' % classifier)
        raise ValueError('Error loading %s' % classifier)

    cls = model['cls']
    psde = model['psde']
    labels = list(cls.classes_)
    w_seconds = model['w_seconds']
    w_frames = model['w_frames']
    wstep = model['wstep']
    sfreq = model['sfreq']
    # Run the feature extractor on a dummy all-zero window to learn the feature size.
    psd_temp = psde.transform(np.zeros((1, len(model['picks']), w_frames)))

    return dict(labels=labels, cls=cls, psde=psde, w_seconds=w_seconds, w_frames=w_frames,
                wstep=wstep, sfreq=sfreq, psd_shape=psd_temp.shape, psd_size=psd_temp.size)
示例#5
0
def get_timelags(epochs, wlen, wstep, downsample=1, picks=None):
    """
    (DEPRECATED FUNCTION)
    Get concatenated timelag features.

    This function is no longer implemented: every call logs an error and raises
    NotImplementedError. The parameters are kept only for interface compatibility.

    Input
    =====
    epochs: input signals
    wlen: window length (# time points) in downsampled data
    wstep: window step in downsampled data
    downsample: downsample signal to be 1/downsample length
    picks: ignored

    Raises
    ======
    NotImplementedError: always.
    """
    logger.error('This function is deprecated.')
    raise NotImplementedError('get_timelags() is deprecated.')
示例#6
0
def rereference(raw, ref_new, ref_old=None):
    """
    Reference to new channels. raw object is modified in-place for efficiency.

    raw: mne.io.RawArray | numpy.ndarray (channels x samples)

    ref_new: None | list of str (RawArray) | list of int (numpy array)
        Channel(s) to re-reference, e.g. M1, M2.
        The average of these channel values is subtracted from all channel values.
        For numpy arrays, indices must be integers (Python int or numpy integer).

    ref_old: None | str
        Channel to recover, assuming this channel was originally used as a reference.
        Not supported for numpy arrays.

    Returns:
        True
    Raises:
        NotImplementedError: ref_old is given together with a numpy array.
        ValueError: numpy array given with non-integer channel indices.
    """

    # Re-reference and recover the original reference channel values if possible
    if isinstance(raw, np.ndarray):
        if ref_old is not None:
            logger.error(
                'Recovering original reference channel is not yet supported for numpy arrays.'
            )
            raise NotImplementedError
        first = ref_new[0]
        # bool is an int subclass but would act as a mask, so reject it explicitly.
        if isinstance(first, bool) or not isinstance(first, (int, np.integer)):
            logger.error('Channels must be integer values for numpy arrays')
            raise ValueError
        # raw[ref_new] (integer list) is a copy, so the mean is not affected by
        # the in-place subtraction.
        raw -= np.mean(raw[ref_new], axis=0)
    else:
        # NOTE(review): recent MNE versions expose these as mne.add_reference_channels /
        # mne.set_eeg_reference; confirm against the installed MNE version.
        if ref_old is not None:
            # Add a blank (zero-valued) channel
            mne.io.add_reference_channels(raw, ref_old, copy=False)
        # Re-reference
        mne.io.set_eeg_reference(raw, ref_new, copy=False)

    return True
示例#7
0
    def move(self,
             dir,
             dx,
             overlay=False,
             barcolor=None,
             caption='',
             caption_color='W'):
        """
        Show image-based feedback of size dx toward the left ('L') or right ('R').

        Without an explicit barcolor the color code is 'G' when dx equals
        self.xl2 and 'R' otherwise; it is passed to self.glass.fullbar_color().
        With self.pc_feedback, self.img becomes self.left_images[dx] or
        self.right_images[dx]; with self.glass_feedback, self.glass.move_bar() is
        called. Any other direction is only logged as an error. The caption is
        drawn and self.update() is called in every case.
        """
        if barcolor is not None:
            code = barcolor
        elif dx == self.xl2:
            code = 'G'
        else:
            code = 'R'

        self.glass.fullbar_color(code)
        color = self.color[code]  # looked up but not used below

        if dir in ('L', 'R'):
            if self.pc_feedback:
                frames = self.left_images if dir == 'L' else self.right_images
                self.img = frames[dx]
            if self.glass_feedback:
                self.glass.move_bar(dir, dx, overlay)
        else:
            logger.error('Unknown direction %s' % dir)
        self.put_text(caption, caption_color)
        self.update()
示例#8
0
文件: trainer.py 项目: aizmeng/pycnbi
def fit_predict_thres(cls,
                      X_train,
                      Y_train,
                      X_test,
                      Y_test,
                      cnum,
                      label_list,
                      ignore_thres=None,
                      decision_thres=None):
    """
    Fit cls on the training set and score it on the test set.
    Confusion matrix, accuracy and F1 score (macro average) are computed.

    Params
    ======
    cls: classifier with fit / predict / predict_proba and classes_
    X_train, Y_train: training samples and labels
    X_test, Y_test: test samples and labels (Y_test must support integer-array
        indexing, e.g. a numpy array, when ignore_thres is used)
    cnum: fold number, only used in the log message
    label_list: labels defining the row/column order of the confusion matrix
    ignore_thres:
        if not None and larger than 0, test samples whose highest predicted
        probability is lower than ignore_thres are excluded when computing
        accuracy, confusion matrix and F1. The confusion matrix then gets one
        extra last column holding, for each label in label_list, how many ignored
        samples were *predicted* as that label.
        NOTE(review): the matrix rows are true labels while this column counts
        predicted labels -- confirm this is intended.
    decision_thres: must be None when ignore_thres > 0 (ValueError otherwise).

    Returns
    =======
    score, cm, f1
    """
    timer = qc.Timer()
    cls.fit(X_train, Y_train)
    assert ignore_thres is None or ignore_thres >= 0
    if ignore_thres is None or ignore_thres == 0:
        Y_pred = cls.predict(X_test)
        score = skmetrics.accuracy_score(Y_test, Y_pred)
        # labels must be passed by keyword: scikit-learn deprecated positional
        # arguments here in 0.24 and rejects them from 1.0 on.
        cm = skmetrics.confusion_matrix(Y_test, Y_pred, labels=label_list)
        f1 = skmetrics.f1_score(Y_test, Y_pred, average='macro')
    else:
        if decision_thres is not None:
            logger.error(
                'decision threshold and ignore_thres cannot be set at the same time.'
            )
            raise ValueError
        Y_prob = cls.predict_proba(X_test)
        Y_pred_labels = np.argmax(Y_prob, axis=1)  # column index of the winning class
        Y_pred_maxes = np.max(Y_prob, axis=1)  # its probability
        Y_index_overthres = np.where(Y_pred_maxes >= ignore_thres)[0]
        Y_index_underthres = np.where(Y_pred_maxes < ignore_thres)[0]
        classes = np.asarray(cls.classes_)
        Y_pred_overthres = classes[Y_pred_labels[Y_index_overthres]]
        Y_pred_underthres = classes[Y_pred_labels[Y_index_underthres]]
        Y_pred_underthres_count = np.array(
            [np.count_nonzero(Y_pred_underthres == c) for c in label_list])
        Y_test_overthres = Y_test[Y_index_overthres]
        score = skmetrics.accuracy_score(Y_test_overthres, Y_pred_overthres)
        cm = skmetrics.confusion_matrix(Y_test_overthres, Y_pred_overthres,
                                        labels=label_list)
        cm = np.concatenate((cm, Y_pred_underthres_count[:, np.newaxis]),
                            axis=1)
        f1 = skmetrics.f1_score(Y_test_overthres,
                                Y_pred_overthres,
                                average='macro')

    logger.info('Cross-validation %d (%.3f) - %.1f sec' %
                (cnum, score, timer.sec()))
    return score, cm, f1
示例#9
0
 def set_pin(self, pin):
     """
     Set the bit of a single pin by writing 2 ** (pin - 1) with set_data().

     pin: 1-based pin number (pin 1 -> value 1, pin 2 -> value 2, pin 3 -> value 4, ...)

     Returns:
         False for a software trigger (not supported), True for a FAKE trigger,
         otherwise whatever set_data() returns.
     """
     if self.lpttype == 'SOFTWARE':
         logger.error('set_pin() not supported for software trigger.')
         return False
     if self.lpttype == 'FAKE':
         logger.info('FAKE trigger pin %s' % pin)
         return True
     # Propagate set_data()'s result instead of discarding it.
     return self.set_data(2 ** (pin - 1))
示例#10
0
文件: trainer.py 项目: aizmeng/pycnbi
def balance_samples(X, Y, balance_type, verbose=False):
    """
    Balance the number of samples per class by over- or undersampling.

    Params
    ======
    X: samples (the first axis indexes samples)
    Y: labels, one per sample
    balance_type:
        'OVER'  - for every class smaller than the largest one, randomly draw
                  (with replacement) extra samples until it matches the largest
                  class; the extra samples are appended after all original ones.
        'UNDER' - keep the smallest class as a whole (placed first) and randomly
                  draw, without replacement, as many samples from every other class.
        None | False - return X and Y unchanged.
    verbose: unused.

    Returns
    =======
    X_balanced, Y_balanced: numpy arrays

    Raises
    ======
    ValueError: unknown balance_type.
    """
    if balance_type is None or balance_type is False:
        return X, Y
    if balance_type not in ('OVER', 'UNDER'):
        logger.error('Unknown balancing type %s' % balance_type)
        raise ValueError('Unknown balancing type %s' % balance_type)

    # Work on arrays so that `Y == c` and fancy indexing also work for list input.
    X = np.asarray(X)
    Y = np.asarray(Y)
    label_set, label_counts = np.unique(Y, return_counts=True)
    if len(label_set) == 0:  # no samples: nothing to balance
        return np.array(X), np.array(Y)

    if balance_type == 'OVER':
        # Oversample from classes that lack samples
        target = label_counts.max()
        X_parts, Y_parts = [X], [Y]
        for c, n in zip(label_set, label_counts):
            if n < target:
                extra_idx = np.random.choice(np.flatnonzero(Y == c), target - n)
                X_parts.append(X[extra_idx])
                Y_parts.append(Y[extra_idx])
    else:
        # Undersample from classes that are excessive. Draw without replacement so
        # that no sample appears twice in the reduced set.
        target = label_counts.min()
        i_min = int(np.argmin(label_counts))
        order = [i_min] + [i for i in range(len(label_set)) if i != i_min]
        X_parts, Y_parts = [], []
        for i in order:
            idx = np.flatnonzero(Y == label_set[i])
            if len(idx) > target:
                idx = np.random.choice(idx, target, replace=False)
            X_parts.append(X[idx])
            Y_parts.append(Y[idx])
    X_balanced = np.concatenate(X_parts, axis=0)
    Y_balanced = np.concatenate(Y_parts, axis=0)

    logger.info_green('\nNumber of samples after %ssampling' %
                      balance_type.lower())
    for c, n in zip(label_set, label_counts):
        logger.info('%s: %d -> %d' % (c, n, np.count_nonzero(Y_balanced == c)))

    return X_balanced, Y_balanced
示例#11
0
    def move(self,
             dir,
             dx,
             overlay=False,
             barcolor=None,
             caption='',
             caption_color='W'):
        """
        Draw a feedback bar of length dx in direction dir and show a caption.

        dir: 'L' (left), 'U' (up), 'R' (right), 'D' (down) or 'B' (left and right).
             Any other value is only logged as an error.
        dx: bar length in image coordinates
        overlay: if False, self.draw_cue() is called first.
        barcolor: color code; None picks 'G' when dx == self.xl2, otherwise 'R'.
        caption, caption_color: passed to self.put_text().

        With self.pc_feedback, filled rectangles are drawn on self.img next to the
        box (self.xl1, self.yl1)-(self.xr1, self.yr1); with self.glass_feedback,
        self.glass.move_bar() is called ('B' is sent as 'S').
        """
        if not overlay:
            self.draw_cue()

        if barcolor is not None:
            code = barcolor
        elif dx == self.xl2:
            code = 'G'
        else:
            code = 'R'
        self.glass.fullbar_color(code)
        color = self.color[code]

        if dir not in ('L', 'U', 'R', 'D', 'B'):
            logger.error('Unknown direction %s' % dir)
        else:
            if self.pc_feedback:
                left, right = self.xl1, self.xr1
                top, bottom = self.yl1, self.yr1
                left_bar = ((left - dx, top), (left, bottom))
                right_bar = ((right, top), (right + dx, bottom))
                rects = {
                    'L': [left_bar],
                    'U': [((left, top - dx), (right, top))],
                    'R': [right_bar],
                    'D': [((left, bottom), (right, bottom + dx))],
                    'B': [left_bar, right_bar],
                }[dir]
                for corner1, corner2 in rects:
                    cv2.rectangle(self.img, corner1, corner2, color, -1)
            if self.glass_feedback:
                self.glass.move_bar('S' if dir == 'B' else dir, dx, overlay)
        self.put_text(caption, caption_color)
示例#12
0
 def check_connect(self):
     """
     Block until the LSL server is connected.

     While self.connected is false: log an error, call self.connect() and wait
     one second before checking again. There is no retry limit.
     """
     while True:
         if self.connected:
             return
         logger.error(
             'LSL server not connected yet. Trying to connect automatically.'
         )
         self.connect()
         time.sleep(1)
示例#13
0
def get_index_max(seq):
    """
    Get the index of the maximum item in a list/tuple, or the key of the maximum
    value in a dict (dict subclasses such as OrderedDict included).

    Ties resolve to the first maximal index/key in iteration order.
    Returns None (after logging an error) for unsupported input types; an empty
    list/tuple/dict raises ValueError (from max()).
    """
    if isinstance(seq, dict):
        return max(seq, key=seq.__getitem__)
    if isinstance(seq, (list, tuple)):
        return max(range(len(seq)), key=seq.__getitem__)
    logger.error('Unsupported input %s' % type(seq))
    return None
示例#14
0
def make_dirs(dirname, delete=False):
    """
    Recursively create directories.

    If delete is true and dirname exists, it is removed first so that an empty
    directory is created. A failure to remove it completely (OSError) is logged
    and ignored. An already existing path is not an error.
    """
    if delete and os.path.exists(dirname):
        try:
            shutil.rmtree(dirname)
        except OSError:
            logger.error(
                'Directory was not completely removed. (Perhaps a Dropbox folder?). Continuing.'
            )
    if not os.path.exists(dirname):
        # exist_ok guards against another process creating it in the meantime.
        os.makedirs(dirname, exist_ok=True)
def load_raw(rawfile, spfilter=None, spchannels=None, events_ext=None, multiplier=1, verbose='ERROR'):
    """
    Loads data from a fif-format file.
    Non-fif files (.eeg, .bdf, .gdf, .pcl) must be converted to fif format first.

    Parameters:
    rawfile: (absolute) data file path with extension .fif or .fiff
    spfilter: 'car' | 'laplacian' | None
    spchannels: None | list (for CAR) | dict (for LAPLACIAN)
        'car': channel indices used for CAR filtering. If None, use all channels except
               the trigger channel (index 0).
        'laplacian': {channel:[neighbor1, neighbor2, ...], ...}
        *** Note ***
        Since PyCNBI puts trigger channel as index 0, data channel starts from index 1.
    events_ext: None, or an events file (read with mne.read_events()) whose events
                are returned instead of those found in the trigger channel.
    multiplier: Multiply all values except triggers (to convert unit).
    verbose: MNE verbosity level used while reading the file.

    Returns:
    raw: mne.io.Raw object. First channel (index 0) is always trigger channel.
    events: mne-compatible events numpy array object (N x [frame, 0, type]);
            an empty (0 x 3) array if no trigger channel is found.

    Raises:
    IOError: rawfile does not exist or is not a regular file.
    AssertionError: rawfile does not have a fif extension.
    """

    if not os.path.exists(rawfile):
        logger.error('File %s not found' % rawfile)
        raise IOError('File %s not found' % rawfile)
    if not os.path.isfile(rawfile):
        logger.error('%s is not a file' % rawfile)
        raise IOError('%s is not a file' % rawfile)

    extension = qc.parse_path(rawfile).ext
    assert extension in ['fif', 'fiff'], 'only fif format is supported'
    raw = mne.io.Raw(rawfile, preload=True, verbose=verbose)
    # Compare by value: `is not 1` tests object identity and is unreliable for
    # floats or numpy scalars equal to 1.
    if spfilter is not None or multiplier != 1:
        preprocess(raw, spatial=spfilter, spatial_ch=spchannels, multiplier=multiplier)
    if events_ext is not None:
        events = mne.read_events(events_ext)
    else:
        tch = find_event_channel(raw)
        if tch is not None:
            events = mne.find_events(raw, stim_channel=raw.ch_names[tch], shortest_event=1, uint_cast=True, consecutive='increasing')
            # find_events() reports sample indices that include raw.first_samp;
            # make them relative to the start of the loaded data.
            events[:, 0] -= raw.first_samp
        else:
            # Keep the usual N x 3 layout so callers can index events[:, k].
            events = np.empty((0, 3), dtype=np.int64)

    return raw, events
示例#16
0
 def start(self):
     """
     Start all daemon processes in self.procs.

     If is_running() reports running processes, an error listing their PIDs is
     logged and nothing is started. With self.wait_init set, blocks until every
     flag in self.running has become non-zero (polled every millisecond).
     Finally logs self.startmsg.
     """
     if self.is_running() > 0:
         pids = ', '.join('%d' % p.pid for p in self.procs)
         logger.error('Cannot start. Daemon already running. (PID' + pids + ')')
         return
     for p in self.procs:
         p.start()
     if self.wait_init:
         for flag in self.running:
             while flag.value == 0:
                 time.sleep(0.001)
     logger.info(self.startmsg)
示例#17
0
def convert2mat(filename, matfile):
    """
    Convert to mat using MATLAB BioSig sload().

    The output is always <filename without extension>.mat next to the input; the
    matfile argument is ignored (kept for backward compatibility). Nothing is done
    if that .mat file already exists.

    Returns:
        path of the .mat file
    Raises:
        RuntimeError: if MATLAB did not create the .mat file.
    """
    basename = os.path.splitext(filename)[0]
    matfile = basename + '.mat'
    if not os.path.exists(matfile):
        logger.info('Converting input to mat file')
        # MATLAB strings are delimited by single quotes; escape embedded ones by doubling.
        src = filename.replace("'", "''")
        dst = basename.replace("'", "''")
        run = "[sig,header]=sload('%s'); save('%s.mat','sig','header');" % (src, dst)
        qc.matlab(run)
        if not os.path.exists(matfile):
            logger.error('mat file convertion error.')
            # Raise instead of sys.exit() so callers can handle the failure.
            raise RuntimeError('mat file conversion error: %s was not created' % matfile)
    return matfile
示例#18
0
 def set_data(self, value):
     """
     Write value to the trigger device selected by self.lpttype.

     Returns:
         False for a software trigger (not supported); True after the value was
         logged (FAKE) or written to the device (USB2LPT, DESKTOP, ARDUINO).
     Raises:
         RuntimeError: unknown trigger device type.
     """
     if self.lpttype == 'SOFTWARE':
         logger.error('set_data() not supported for software trigger.')
         return False
     if self.lpttype == 'FAKE':
         logger.info('FAKE trigger value %s' % value)
         return True
     if self.lpttype == 'USB2LPT':
         self.lpt.setdata(value)
     elif self.lpttype == 'DESKTOP':
         self.lpt.setdata(self.portaddr, value)
     elif self.lpttype == 'ARDUINO':
         # sent as a single byte: bytes([value]) requires 0 <= value <= 255
         self.ser.write(bytes([value]))
     else:
         raise RuntimeError('Wrong trigger device')
     # Report success consistently with the other device types.
     return True
示例#19
0
    def reset(self):
        """
        Reset classifier to the initial state.

        Re-creates the shared PSD buffer (None when self.fake), the shared class
        probabilities (uniform over self.labels), the shared flags and the
        decoder process object(s) in self.procs. The processes are created here
        but not started.
        """
        # Share the numpy array self.psd between processes. The shared memory size
        # is taken from the decoder info of the classifier.
        if self.fake == True:
            psd_ctypes = None
            self.psd = None
        else:
            info = get_decoder_info(self.classifier)
            psd_size = info['psd_size']
            psd_ctypes = sharedctypes.RawArray('d', np.zeros(psd_size))
            # float64 view on the shared buffer (np.frombuffer does not copy)
            self.psd = np.frombuffer(psd_ctypes, dtype=np.float64, count=psd_size)

        self.probs = mp.Array('d', [1.0 / len(self.labels)] * len(self.labels))
        self.probs_smooth = mp.Array('d', [1.0 / len(self.labels)] * len(self.labels))
        self.pread = mp.Value('i', 1)
        self.t_problast = mp.Value('d', 0)
        self.return_psd = mp.Value('i', 0)
        self.procs = []
        mp.freeze_support()

        if self.parallel:
            logger.error('Parallel decoding is under a rigorous test. Please do not use it for now.')
            raise NotImplementedError
            # NOTE: the rest of this branch is unreachable until the raise above is removed.
            num_strides = self.parallel['num_strides']
            period = self.parallel['period']
            # One independent flag per process; `[mp.Value(...)] * n` would make
            # every entry refer to the same shared Value.
            self.running = [mp.Value('i', 0) for _ in range(num_strides)]
            if num_strides > 1:
                stride = period / num_strides
            else:
                stride = 0
            t_start = time.time()
            for i in range(num_strides):
                self.procs.append(mp.Process(target=self.daemon, args=\
                    [self.classifier, self.probs, self.probs_smooth, self.pread, self.t_problast,\
                     self.running[i], self.return_psd, psd_ctypes, self.psdlock,\
                     dict(t_start=(t_start+i*stride), period=period)]))
        else:
            self.running = [mp.Value('i', 0)]
            self.procs = [mp.Process(target=self.daemon, args=\
                [self.classifier, self.probs, self.probs_smooth, self.pread, self.t_problast,\
                 self.running[0], self.return_psd, psd_ctypes, self.psdlock, None])]
示例#20
0
    def init(self, duration):
        """
        Prepare the trigger device.

        duration: delay in milliseconds; stored in seconds as self.delay and used
                  for the threading.Timer that calls self.signal_off.

        Returns True on success. For DESKTOP / USB2LPT devices, returns False
        (after logging and setting self.lpt to None) if self.lpt.init() returns -1.
        SOFTWARE and FAKE triggers need no preparation and return True.
        """
        if self.lpttype == 'SOFTWARE':
            logger.info('Ignoring delay parameter for software trigger.')
            return True
        if self.lpttype == 'FAKE':
            return True

        self.delay = duration / 1000.0
        uses_lpt_driver = self.lpttype in ['DESKTOP', 'USB2LPT']
        if uses_lpt_driver and self.lpt.init() == -1:
            logger.error('Connecting to LPT port failed. Check the driver status.')
            self.lpt = None
            return False

        self.action = False
        self.offtimer = threading.Timer(self.delay, self.signal_off)
        return True
示例#21
0
def check_config(cfg):
    """
    Validate a protocol config object and fill in optional defaults.

    Raises RuntimeError if a required attribute (COMMON list), the TIMINGS dict or
    one of its required keys is missing, or if TRIGGER_DEVICE is None. Missing
    optional attributes are set on cfg to their default value, with a warning.

    Returns cfg (modified in place).
    """
    critical_vars = {
        'COMMON': [
            'TRIGGER_DEVICE', 'TRIGGER_FILE', 'SCREEN_SIZE', 'DIRECTIONS',
            'DIR_RANDOM', 'TRIALS_EACH'
        ],
        'TIMINGS': [
            'INIT', 'GAP', 'CUE', 'READY', 'READY_RANDOMIZE', 'DIR',
            'DIR_RANDOMIZE'
        ]
    }
    # DIR_RANDOM is required (see COMMON above), so it needs no default here.
    optional_vars = {
        'FEEDBACK_TYPE': 'BAR',
        'FEEDBACK_IMAGE_PATH': None,
        'SCREEN_POS': (0, 0),
        'GLASS_USE': False,
        'TRIAL_PAUSE': False,
        'REFRESH_RATE': 30
    }

    for key in critical_vars['COMMON']:
        if not hasattr(cfg, key):
            raise RuntimeError('%s is a required parameter' % key)

    if not hasattr(cfg, 'TIMINGS'):
        logger.error('"TIMINGS" not defined in config.')
        raise RuntimeError
    for v in critical_vars['TIMINGS']:
        if v not in cfg.TIMINGS:
            logger.error('%s not defined in config.' % v)
            raise RuntimeError

    for key, default in optional_vars.items():
        if not hasattr(cfg, key):
            setattr(cfg, key, default)
            logger.warning('Setting undefined %s=%s' % (key, default))

    if cfg.TRIGGER_DEVICE is None:
        logger.warning(
            'The trigger device is set to None! No events will be saved.')
        raise RuntimeError(
            'The trigger device is set to None! No events will be saved.')

    return cfg
示例#22
0
def bdf2fif(filename, interactive=False, outdir=None):
    """
    EDF or BioSemi BDF format -> fif.

    The last channel is expected to be 'STI 014' (otherwise an error is logged and
    the debugger is entered). It is renamed to 'TRIGGER', moved to index 0 (for
    consistency) and refilled with the events found in it. The result is saved as
    <outdir><name>.fif in double precision, and the channel names are written
    with saveChannels2txt().

    filename: path of the .edf / .bdf file
    interactive: unused
    outdir: output directory (default: the input file's directory)
    """
    fdir, fname, fext = qc.parse_path_list(filename)
    if outdir is None:
        outdir = fdir
    elif not outdir.endswith('/'):  # also safe for an empty string
        outdir += '/'

    fiffile = outdir + fname + '.fif'
    raw = mne.io.read_raw_edf(filename, preload=True)

    # process event channel
    if raw.info['chs'][-1]['ch_name'] != 'STI 014':
        logger.error(
            "The last channel (%s) doesn't seem to be an event channel. Entering debugging mode."
            % raw.info['chs'][-1]['ch_name'])
        pdb.set_trace()
    raw.info['chs'][-1]['ch_name'] = 'TRIGGER'
    events = mne.find_events(raw,
                             stim_channel='TRIGGER',
                             shortest_event=1,
                             uint_cast=True,
                             consecutive=True)
    events[:, 2] -= events[:, 1]  # set offset to 0
    events[:, 1] = 0
    # move the event channel to index 0 (for consistency)
    raw._data = np.concatenate(
        (raw._data[-1, :].reshape(1, -1), raw._data[:-1, :]))
    raw._data[0] *= 0  # init the event channel
    raw.info['chs'] = [raw.info['chs'][-1]] + raw.info['chs'][:-1]

    # add events
    raw.add_events(events, 'TRIGGER')

    # save and close
    raw.save(fiffile, verbose=False, overwrite=True, fmt='double')
    logger.info('Saved to %s' % fiffile)

    # Channel names in their final (trigger-first) order; ch_names was never
    # defined in this function before, which raised NameError here.
    ch_names = [ch['ch_name'] for ch in raw.info['chs']]
    saveChannels2txt(outdir, ch_names)
def check_config(cfg):
    """
    Ensure that the config object contains the required parameters.

    Required: DATA_PATH -- if missing, an error is logged and RuntimeError raised.
    Optional parameters (none at present) are filled in with their defaults.
    Returns the same cfg object.
    """
    required = ['DATA_PATH']
    defaults = {}

    for name in required:
        if hasattr(cfg, name):
            continue
        logger.error('%s is a required parameter' % name)
        raise RuntimeError

    for name, value in defaults.items():
        if hasattr(cfg, name):
            continue
        setattr(cfg, name, value)
        logger.warning('Setting undefined parameter %s=%s' %
                       (name, getattr(cfg, name)))

    return cfg
示例#24
0
    def print_c(msg, color=None, end='\n'):
        """
        Colored print using colorama.

        msg: object to print (converted with str())
        color: None for a plain print, or a single-letter color code
               (case-insensitive): B(lue), R(ed), G(reen), Y(ellow), W(hite), C(yan).
               The text is printed bright and the style is reset afterwards.
        end: passed to print().

        Raises ValueError if color is not a single known color code.

        Fullset:
            https://pypi.python.org/pypi/colorama
            Fore: BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, RESET.
            Back: BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, RESET.
            Style: DIM, NORMAL, BRIGHT, RESET_ALL

        TODO:
            Make it using *args and **kwargs to support fully featured print().
        """
        if color is None:
            print(str(msg), end=end)
            return
        color = str(color)
        if len(color) != 1:
            # Show the offending value; its type is always str at this point.
            raise ValueError(
                'Color parameter must be a single color code, not %s' % color)
        fore = {
            'B': colorama.Fore.BLUE,
            'R': colorama.Fore.RED,
            'G': colorama.Fore.GREEN,
            'Y': colorama.Fore.YELLOW,
            'W': colorama.Fore.WHITE,
            'C': colorama.Fore.CYAN,
        }.get(color.upper())
        if fore is None:
            logger.error('print_c(): Unknown color code %s' % color)
            raise ValueError
        print(colorama.Style.BRIGHT + fore + str(msg) + colorama.Style.RESET_ALL,
              end=end)
示例#25
0
def run(record_dir, amp_name, amp_serial, eeg_only=False):
    """
    Interactively record data with a child process.

    The first Enter starts the recorder process (target=record); the second Enter
    sets the shared state flag to 0 and joins the process with a 10-second
    timeout. If it is still alive afterwards, errors are logged and qc.shell()
    is opened.
    """
    logger.info('\nOutput directory: %s' % (record_dir))

    # spawn the recorder as a child process
    logger.info('\n>> Press Enter to start recording.')
    input()
    state = mp.Value('i', 1)
    recorder = mp.Process(target=record,
                          args=[state, amp_name, amp_serial, record_dir, eeg_only])
    recorder.start()

    # clean up
    time.sleep(1)  # required on some Python distribution
    input()
    state.value = 0
    logger.info('(main) Waiting for recorder process to finish.')
    recorder.join(10)
    if recorder.is_alive():
        for line in ('Recorder process not finishing. Are you running from Spyder?',
                     'Dropping into a shell'):
            logger.error(line)
        qc.shell()
    sys.stdout.flush()
    logger.info('Recording finished.')
示例#26
0
    def __init__(self,
                 window_size=1,
                 buffer_size=1,
                 amp_serial=None,
                 eeg_only=False,
                 amp_name=None):
        """
        Params:
            window_size (in seconds): keep the latest window_size seconds of the buffer.
                Must be positive (ValueError otherwise).
            buffer_size (in seconds): 0 selects the maximum of 86400 s (1 day). Values
                outside [0, 86400] are replaced by the maximum and values smaller than
                window_size by window_size. Large buffer may lead to a delay if not
                pulled frequently.
            amp_name: connect to a server named 'amp_name'. None: no constraint.
            amp_serial: connect to a server with serial number 'amp_serial'. None: no constraint.
            eeg_only: ignore non-EEG servers

        self.connect() is called at the end of construction.
        """
        _MAX_BUFFER_SIZE = 86400  # max buffer size allowed by StreamReceiver (24 hours)
        _MAX_PYLSL_STREAM_BUFSIZE = 360  # max buffer size for pylsl.StreamInlet

        # --- window / buffer lengths (seconds) ---
        if window_size <= 0:
            logger.error('Wrong window_size %d.' % window_size)
            raise ValueError()
        self.winsec = window_size

        if buffer_size == 0:
            bufsec = _MAX_BUFFER_SIZE
        elif buffer_size < 0 or buffer_size > _MAX_BUFFER_SIZE:
            logger.error('Improper buffer size %.1f. Setting to %d.' %
                         (buffer_size, _MAX_BUFFER_SIZE))
            bufsec = _MAX_BUFFER_SIZE
        elif buffer_size < self.winsec:
            logger.error(
                'Buffer size %.1f is smaller than window size. Setting to %.1f.'
                % (buffer_size, self.winsec))
            bufsec = self.winsec
        else:
            bufsec = buffer_size
        self.bufsec = bufsec
        # the pylsl inlet buffer is capped separately and rounded up to whole seconds
        self.stream_bufsec = int(math.ceil(min(_MAX_PYLSL_STREAM_BUFSIZE, bufsec)))
        self.bufsize = 0  # to be calculated using sampling rate
        self.stream_bufsize = 0  # to be calculated using sampling rate

        # --- server selection ---
        self.amp_serial = amp_serial
        self.eeg_only = eeg_only
        self.amp_name = amp_name

        # --- channel indices ---
        self.tr_channel = None  # trigger indx used by StreamReceiver class
        self.eeg_channels = []  # signal indx used by StreamReceiver class
        self._lsl_tr_channel = None  # raw trigger indx in pylsl.pull_chunk()
        self._lsl_eeg_channels = []  # raw signal indx in pylsl.pull_chunk()

        # --- runtime state ---
        self.ready = False  # False until the buffer is filled for the first time
        self.connected = False
        self.buffers = []
        self.timestamps = []
        self.watchdog = qc.Timer()
        self.multiplier = 1  # 10**6 for uV unit (automatically updated for openvibe servers)

        self.connect()
示例#27
0
def load_multi(src, spfilter=None, spchannels=None, multiplier=1):
    """
    Load multiple data files and concatenate them into a single series

    - Assumes all files have the same sampling rate and channel order.
    - Event locations are re-detected on the concatenated signal, so they
      reflect the new sample offsets.

    @params:
        src: directory or list/tuple of files. For a directory, every file
             whose extension (per qc.parse_path_list) is 'fif' is loaded.
        spfilter: apply spatial filter while loading.
        spchannels: list of channel names to apply spatial filter.
        multiplier: to change units for better numerical stability.

    @returns:
        With a single file: the result of load_raw() for that file.
        With several files: (raw_merged, events), where raw_merged is an
        mne.io.RawArray holding all samples back to back (channel names and
        sampling rate taken from the first file) and events is the output of
        mne.find_events() on its trigger channel.

    @raises:
        IOError: src is a string but not an existing directory.
        TypeError: src is neither a string nor a list/tuple.
        RuntimeError: no files to load.

    See load_raw() for more low-level details.

    """
    if isinstance(src, str):
        if not os.path.isdir(src):
            logger.error('%s is not a directory or does not exist.' % src)
            raise IOError
        # NOTE(review): files are concatenated in the order returned by
        # qc.get_file_list(); confirm that this order is chronological.
        flist = [f for f in qc.get_file_list(src)
                 if qc.parse_path_list(f)[2] == 'fif']
    elif isinstance(src, (list, tuple)):
        flist = src
    else:
        logger.error('Unknown input type %s' % type(src))
        raise TypeError

    if len(flist) == 0:
        logger.error('load_multi(): No fif files found in %s.' % src)
        raise RuntimeError
    elif len(flist) == 1:
        return load_raw(flist[0],
                        spfilter=spfilter,
                        spchannels=spchannels,
                        multiplier=multiplier)

    # load raw files
    rawlist = []
    for f in flist:
        logger.info('Loading %s' % f)
        raw, _ = load_raw(f,
                          spfilter=spfilter,
                          spchannels=spchannels,
                          multiplier=multiplier)
        rawlist.append(raw)

    # Concatenate all samples in a single allocation; concatenating inside the
    # loop would re-copy the growing array for every file.
    signals = np.concatenate([r._data for r in rawlist], axis=1)

    # create a concatenated raw object; channel layout comes from the first file
    raw = rawlist[0]
    trigch = find_event_channel(raw)
    ch_types = ['eeg'] * len(raw.ch_names)
    if trigch is None:
        # No trigger channel detected: keep the historical default name.
        stim_channel = 'TRIGGER'
    else:
        ch_types[trigch] = 'stim'
        # Use the detected trigger channel instead of assuming it is named
        # 'TRIGGER'.
        stim_channel = raw.ch_names[trigch]
    info = mne.create_info(raw.ch_names, raw.info['sfreq'], ch_types)
    raw_merged = mne.io.RawArray(signals, info)

    # re-calculate event positions on the merged signal
    events = mne.find_events(raw_merged,
                             stim_channel=stim_channel,
                             shortest_event=1,
                             uint_cast=True,
                             consecutive='increasing',
                             output='onset',
                             initial_event=True)

    return raw_merged, events
示例#28
0
def _preprocess_pick_indices(picks, ch_names):
    """
    Convert channel picks to channel indices for preprocess().

    If the first pick is a string, every pick is treated as a channel name and
    looked up in ch_names; otherwise picks are returned unchanged (assumed to
    already be indices).
    """
    if len(picks) > 0 and isinstance(picks[0], str):
        assert ch_names is not None, 'preprocess(): ch_names must not be None'
        return [ch_names.index(c) for c in picks]
    return picks


def preprocess(raw,
               sfreq=None,
               spatial=None,
               spatial_ch=None,
               spectral=None,
               spectral_ch=None,
               notch=None,
               notch_ch=None,
               multiplier=1,
               ch_names=None,
               rereference=None,
               decim=None,
               n_jobs=1):
    """
    Apply spatial, spectral, notch filters and convert unit.
    raw is modified in-place.

    Input
    ------
    raw: mne.io.Raw | mne.io.RawArray | mne.Epochs | numpy.array
         2-D data is (n_channels x n_samples); 3-D data is
         (n_epochs x n_channels x n_samples).
         numpy.array type assumes the data has only pure EEG channels without
         event channels.

    sfreq: required only if raw is numpy array.

    spatial: None | 'car' | 'laplacian'
        Spatial filter type.

    spatial_ch: None | list (for CAR) | dict (for LAPLACIAN)
        Reference channels for spatial filtering. May contain channel names.
        'car': channel indices used for CAR filtering. If None, use all channels
               except the trigger channel.
        'laplacian': {channel:[neighbor1, neighbor2, ...], ...}. Keys and
               neighbours may be channel names (requires ch_names for numpy
               input) or indices. A single centre channel is allowed.
        *** Note ***
        Since PyCNBI puts trigger channel as index 0, data channel starts from
        index 1.

    spectral: None | [l_freq, h_freq]
        Spectral filter.
        if l_freq is None: lowpass filter is applied.
        if h_freq is None: highpass filter is applied.
        if l_freq < h_freq: bandpass filter is applied.
        if l_freq > h_freq: band-stop filter is applied.

    spectral_ch: None | list
        Channel picks for spectral filtering. May contain channel names.

    notch: None | float | list of frequency in floats
        Notch filter.

    notch_ch: None | list
        Channel picks for notch filtering. May contain channel names.

    multiplier: float
        If not 1, multiply data values excluding trigger values. Applied along
        the channel axis for both 2-D and 3-D data.

    ch_names: None | list
        If raw is numpy array and channel picks are list of strings, ch_names
        will be used as a look-up table to convert channel picks to channel
        numbers. Ignored (replaced by raw.ch_names) for MNE objects.

    rereference: Not supported yet.

    decim: None | int
        Apply low-pass filter and decimate (downsample). Ignored if 1.

    n_jobs: int
        Number of jobs passed to the spectral and notch filters.

    Output
    ------
    Same input data structure. For numpy input with decim, a new array is
    returned.

    Raises
    ------
    NotImplementedError: rereference is not None.
    TypeError: spatial='laplacian' but spatial_ch is not a dict.
    ValueError: unknown spatial filter type.

    Note: To save computation time, input data may be modified in-place.
    TODO: Add an option to disable in-place modification.
    """
    is_array = isinstance(raw, np.ndarray)

    # Check datatype
    if is_array:
        # Numpy array: assume we don't have event channel
        data = raw
        assert sfreq is not None and sfreq > 0, 'Wrong sfreq value.'
    else:
        # MNE Raw/Epochs object: take channel names and sfreq from the object
        ch_names = raw.ch_names
        data = raw._data
        sfreq = raw.info['sfreq']
    assert 2 <= len(
        data.shape
    ) <= 3, 'Unknown data shape. The dimension must be 2 or 3.'

    # Channels live on axis 0 for 2-D data and on axis 1 for 3-D (epochs) data.
    is_3d = data.ndim == 3
    n_channels = data.shape[1] if is_3d else data.shape[0]
    eeg_channels = list(range(n_channels))
    if not is_array:
        # exclude the event channel from all processing
        tch = find_event_channel(raw)
        if tch is None:
            logger.warning('No trigger channel found. Using all channels.')
        else:
            eeg_channels.pop(tch)

    # Re-reference channels
    if rereference is not None:
        logger.error('re-referencing not implemented yet. Sorry.')
        raise NotImplementedError

    # Do unit conversion along the channel axis (not the epoch axis for 3-D)
    if multiplier != 1:
        if is_3d:
            data[:, eeg_channels, :] *= multiplier
        else:
            data[eeg_channels] *= multiplier

    # Apply spatial filter
    if spatial is None:
        pass
    elif spatial == 'car':
        if spatial_ch is None:
            spatial_ch = eeg_channels
        spatial_ch_i = _preprocess_pick_indices(spatial_ch, ch_names)

        # CAR over a single channel would just zero it out, so require >1
        if len(spatial_ch_i) > 1:
            if is_3d:
                means = np.mean(data[:, spatial_ch_i, :], axis=1)
                data[:, spatial_ch_i, :] -= means[:, np.newaxis, :]
            else:
                data[spatial_ch_i] -= np.mean(data[spatial_ch_i], axis=0)
    elif spatial == 'laplacian':
        if not isinstance(spatial_ch, dict):
            logger.error(
                'preprocess(): For Lapcacian, spatial_ch must be of form {CHANNEL:[NEIGHBORS], ...}'
            )
            raise TypeError
        centres = list(spatial_ch)
        if centres and isinstance(centres[0], str):
            assert ch_names is not None, 'preprocess(): ch_names must not be None'
            spatial_ch_i = {
                ch_names.index(c): [ch_names.index(n) for n in spatial_ch[c]]
                for c in spatial_ch
            }
        else:
            spatial_ch_i = spatial_ch

        if spatial_ch_i:
            # read neighbours from an untouched copy so filtered channels do
            # not feed into each other
            rawcopy = data.copy()
            for src, nei in spatial_ch_i.items():
                nei = list(nei)
                if is_3d:
                    data[:, src, :] = rawcopy[:, src, :] - np.mean(
                        rawcopy[:, nei, :], axis=1)
                else:
                    data[src] = rawcopy[src] - np.mean(rawcopy[nei], axis=0)
    else:
        logger.error('preprocess(): Unknown spatial filter %s' % spatial)
        raise ValueError

    # Downsample
    if decim is not None and decim != 1:
        if is_array:
            data = mne.filter.resample(data,
                                       down=decim,
                                       npad='auto',
                                       window='boxcar',
                                       n_jobs=1)
        else:
            # resample() of Raw* and Epochs object internally calls mne.filter.resample()
            raw = raw.resample(raw.info['sfreq'] / decim,
                               npad='auto',
                               window='boxcar',
                               n_jobs=1)
            data = raw._data
        sfreq /= decim

    # Apply spectral filter
    if spectral is not None:
        if spectral_ch is None:
            spectral_ch = eeg_channels
        spectral_ch_i = _preprocess_pick_indices(spectral_ch, ch_names)

        # fir_design='firwin' is especially important for ICA analysis. See:
        # http://martinos.org/mne/dev/generated/mne.preprocessing.ICA.html?highlight=score_sources#mne.preprocessing.ICA.score_sources
        mne.filter.filter_data(data,
                               sfreq,
                               spectral[0],
                               spectral[1],
                               picks=spectral_ch_i,
                               filter_length='auto',
                               l_trans_bandwidth='auto',
                               h_trans_bandwidth='auto',
                               n_jobs=n_jobs,
                               method='fir',
                               iir_params=None,
                               copy=False,
                               phase='zero',
                               fir_window='hamming',
                               fir_design='firwin',
                               verbose='ERROR')

    # Apply notch filter
    if notch is not None:
        if notch_ch is None:
            notch_ch = eeg_channels
        notch_ch_i = _preprocess_pick_indices(notch_ch, ch_names)

        mne.filter.notch_filter(data,
                                Fs=sfreq,
                                freqs=notch,
                                notch_widths=5,
                                picks=notch_ch_i,
                                method='fft',
                                n_jobs=n_jobs,
                                copy=False)

    if is_array:
        raw = data
    return raw
示例#29
0
文件: trainer.py 项目: aizmeng/pycnbi
def train_decoder(cfg, featdata, feat_file=None):
    """
    Train the final decoder using all data and save it to disk.

    Input
    ------
    cfg: trainer config object. Reads CLASSIFIER, FEATURES, CV, N_JOBS, tdef,
         EPOCH, SP_FILTER, TP_FILTER, NOTCH_FILTER, MULTIPLIER, REREFERENCE,
         DATA_PATH, EXPORT_GOOD_FEATURES and FEAT_TOPN.
    featdata: dict with keys 'X_data' and 'Y_data' (sequences of arrays that
         are concatenated along the first axis), 'wlen', 'w_frames',
         'ch_names', 'psde', 'sfreq' and 'picks'.
    feat_file: output path of the good-features report. Defaults to
         <DATA_PATH>/classifier/good_features.txt.

    Side effects
    ------
    - Sets cfg.FEATURES['PSD']['wlen'] from featdata if it was None.
    - Writes <DATA_PATH>/classifier/classifier-<arch>.pkl.
    - If cfg.EXPORT_GOOD_FEATURES, writes the good-features report.

    Raises
    ------
    ValueError: unknown classifier.
    NotImplementedError: cfg.FEATURES['selected'] is not 'PSD'.
    """
    # Init a classifier
    selected_classifier = cfg.CLASSIFIER['selected']
    if selected_classifier == 'GB':
        params = cfg.CLASSIFIER[selected_classifier]
        cls = GradientBoostingClassifier(
            loss='deviance',
            learning_rate=params['learning_rate'],
            n_estimators=params['trees'],
            subsample=1.0,
            max_depth=params['depth'],
            random_state=params['seed'],
            max_features='sqrt',
            verbose=0,
            warm_start=False,
            presort='auto')
    elif selected_classifier == 'XGB':
        params = cfg.CLASSIFIER[selected_classifier]
        # The seed is read from the same section as every other XGB parameter.
        cls = XGBClassifier(
            loss='deviance',
            learning_rate=params['learning_rate'],
            n_estimators=params['trees'],
            subsample=1.0,
            max_depth=params['depth'],
            random_state=params['seed'],
            max_features='sqrt',
            verbose=0,
            warm_start=False,
            presort='auto')
    elif selected_classifier == 'RF':
        params = cfg.CLASSIFIER[selected_classifier]
        cls = RandomForestClassifier(
            n_estimators=params['trees'],
            max_features='auto',
            max_depth=params['depth'],
            n_jobs=cfg.N_JOBS,
            random_state=params['seed'],
            oob_score=False,
            class_weight='balanced_subsample')
    elif selected_classifier == 'LDA':
        cls = LDA()
    elif selected_classifier == 'rLDA':
        cls = rLDA(cfg.CLASSIFIER[selected_classifier]['r_coeff'])
    else:
        logger.error('Unknown classifier %s' % selected_classifier)
        raise ValueError

    # Only PSD features can be exported below; fail before spending time on
    # training instead of crashing afterwards.
    if cfg.FEATURES['selected'] != 'PSD':
        logger.error('Unsupported feature type %s' % cfg.FEATURES['selected'])
        raise NotImplementedError

    # Setup features
    X_data = featdata['X_data']
    Y_data = featdata['Y_data']
    wlen = featdata['wlen']
    if cfg.FEATURES['PSD']['wlen'] is None:
        cfg.FEATURES['PSD']['wlen'] = wlen
    w_frames = featdata['w_frames']
    ch_names = featdata['ch_names']
    X_data_merged = np.concatenate(X_data)
    Y_data_merged = np.concatenate(Y_data)
    # Collect labels from the flat array; np.unique() on the list of label
    # arrays fails when those arrays differ in length.
    class_labels = np.unique(Y_data_merged)
    if cfg.CV['BALANCE_SAMPLES']:
        X_data_merged, Y_data_merged = balance_samples(
            X_data_merged,
            Y_data_merged,
            cfg.CV['BALANCE_SAMPLES'],
            verbose=True)

    # Start training the decoder
    logger.info_green('Training the decoder')
    timer = qc.Timer()
    cls.n_jobs = cfg.N_JOBS
    cls.fit(X_data_merged, Y_data_merged)
    logger.info('Trained %d samples x %d dimension in %.1f sec' %
                (X_data_merged.shape[0], X_data_merged.shape[1], timer.sec()))
    cls.n_jobs = 1  # always set n_jobs=1 for testing

    # Export the decoder
    classes = {c: cfg.tdef.by_value[c] for c in class_labels}
    data = dict(cls=cls,
                ch_names=ch_names,
                psde=featdata['psde'],
                sfreq=featdata['sfreq'],
                picks=featdata['picks'],
                classes=classes,
                epochs=cfg.EPOCH,
                w_frames=w_frames,
                w_seconds=cfg.FEATURES['PSD']['wlen'],
                wstep=cfg.FEATURES['PSD']['wstep'],
                spatial=cfg.SP_FILTER,
                spatial_ch=featdata['picks'],
                spectral=cfg.TP_FILTER[cfg.TP_FILTER['selected']],
                spectral_ch=featdata['picks'],
                notch=cfg.NOTCH_FILTER[cfg.NOTCH_FILTER['selected']],
                notch_ch=featdata['picks'],
                multiplier=cfg.MULTIPLIER,
                ref_ch=cfg.REREFERENCE[cfg.REREFERENCE['selected']],
                decim=cfg.FEATURES['PSD']['decim'])
    clsfile = '%s/classifier/classifier-%s.pkl' % (cfg.DATA_PATH,
                                                   platform.architecture()[0])
    qc.make_dirs('%s/classifier' % cfg.DATA_PATH)
    qc.save_obj(clsfile, data)
    logger.info('Decoder saved to %s' % clsfile)

    # Reverse-lookup frequency from FFT
    psd_cfg = cfg.FEATURES['PSD']
    psd_wlen = psd_cfg['wlen']
    fq_res = 1.0 / (psd_wlen[0] if isinstance(psd_wlen, list) else psd_wlen)
    fqlist = []
    k = 0
    # Compute bin k as k * fq_res rather than summing fq_res repeatedly; the
    # running sum drifts and can wrongly include or drop a bin near fmax.
    while k * fq_res <= psd_cfg['fmax']:
        fq = k * fq_res
        if fq >= psd_cfg['fmin']:
            fqlist.append(fq)
        k += 1

    # Show top distinctive features
    logger.info_green('Good features ordered by importance')
    if selected_classifier in ['RF', 'GB', 'XGB']:
        keys, values = qc.sort_by_value(list(cls.feature_importances_),
                                        rev=True)
    else:  # 'LDA' or 'rLDA'; any other classifier was rejected above
        # NOTE(review): relies on the fitted classifier exposing a `w` weight
        # vector; confirm this holds for the LDA class imported here.
        keys, values = qc.sort_by_value(cls.w, rev=True)
    keys = np.array(keys)
    values = np.array(values)

    # Feature channel labels; with several window lengths each picked channel
    # appears once per window, prefixed w0-, w1-, ...
    picked_names = [ch_names[c] for c in featdata['picks']]
    if isinstance(wlen, list):
        feat_ch_names = ['w%d-%s' % (w, name)
                         for w in range(len(wlen))
                         for name in picked_names]
    else:
        feat_ch_names = picked_names

    chlist, hzlist = features.feature2chz(keys, fqlist, ch_names=feat_ch_names)
    valnorm = values[:cfg.FEAT_TOPN].copy()
    valsum = np.sum(valnorm)
    if valsum == 0:
        valsum = 1
    valnorm = valnorm / valsum * 100.0

    # show top-N features
    for i, (ch, hz) in enumerate(zip(chlist, hzlist)):
        if i >= cfg.FEAT_TOPN:
            break
        txt = '%-3s %5.1f Hz  normalized importance %-6s  raw importance %-6s  feature %-5d' %\
              (ch, hz, '%.2f%%' % valnorm[i], '%.2f%%' % (values[i] * 100.0), keys[i])
        logger.info(txt)

    # export all features (not only top-N), closing the file even on error
    if cfg.EXPORT_GOOD_FEATURES:
        if feat_file is None:
            gf_path = '%s/classifier/good_features.txt' % cfg.DATA_PATH
        else:
            gf_path = feat_file
        with open(gf_path, 'w') as gfout:
            gfout.write('Importance(%) Channel Frequency Index\n')
            for i, (ch, hz) in enumerate(zip(chlist, hzlist)):
                gfout.write('%.3f\t%s\t%s\t%d\n' %
                            (values[i] * 100.0, ch, hz, keys[i]))
示例#30
0
文件: trainer.py 项目: aizmeng/pycnbi
def check_config(cfg):
    """
    Validate a trainer config object and fill in optional defaults.

    - Raises RuntimeError (after logging) when any of the following is missing:
      a mandatory top-level attribute, the selected classifier's section or one
      of its required keys, or the selected cross-validation section or one of
      its required keys.
    - Missing optional attributes are set to defaults with a warning. For
      dict-valued optional attributes (CV), missing sub-keys are filled as
      well.
    - Sets cfg.FEATURES['PSD']['decim'] to 1 if it is absent.
    - Sets cfg.N_JOBS to the CPU count if it is None.

    Returns the same (modified) cfg object.
    """
    critical_vars = {
        'COMMON': [
            'TRIGGER_FILE', 'TRIGGER_DEF', 'EPOCH', 'DATA_PATH',
            'PICKED_CHANNELS', 'SP_FILTER', 'SP_CHANNELS', 'TP_FILTER',
            'NOTCH_FILTER', 'FEATURES', 'CLASSIFIER', 'CV_PERFORM'
        ],
        'RF': ['trees', 'depth', 'seed'],
        'GB': ['trees', 'learning_rate', 'depth', 'seed'],
        'LDA': [],
        'rLDA': ['r_coeff'],
        'StratifiedShuffleSplit':
        ['test_ratio', 'folds', 'seed', 'export_result'],
        'LeaveOneOut': ['export_result']
    }
    # classifier name -> key of critical_vars listing its required parameters
    # (XGB takes the same parameters as GB)
    classifier_params = {'RF': 'RF', 'GB': 'GB', 'XGB': 'GB', 'rLDA': 'rLDA'}

    # optional variables with default values
    optional_vars = {
        'MULTIPLIER': 1,
        'EXPORT_GOOD_FEATURES': False,
        'FEAT_TOPN': 10,
        'EXPORT_CLS': False,
        'REREFERENCE': None,
        'N_JOBS': None,
        'EXCLUDED_CHANNELS': None,
        'LOAD_EVENTS': None,
        'CV': {
            'IGNORE_THRES': None,
            'DECISION_THRES': None,
            'BALANCE_SAMPLES': False
        },
    }

    for v in critical_vars['COMMON']:
        if not hasattr(cfg, v):
            logger.error('%s not defined in config.' % v)
            raise RuntimeError

    for key, default in optional_vars.items():
        if not hasattr(cfg, key):
            setattr(cfg, key, default)
            logger.warning('Setting undefined parameter %s=%s' %
                           (key, getattr(cfg, key)))
            continue
        current = getattr(cfg, key)
        if isinstance(default, dict) and isinstance(current, dict):
            # Complete a partially defined dict so later lookups such as
            # cfg.CV['BALANCE_SAMPLES'] find every expected key.
            for subkey, subdefault in default.items():
                if subkey not in current:
                    current[subkey] = subdefault
                    logger.warning('Setting undefined parameter %s[%r]=%s' %
                                   (key, subkey, subdefault))

    if 'decim' not in cfg.FEATURES['PSD']:
        cfg.FEATURES['PSD']['decim'] = 1

    # classifier parameters check (compare the selected *name*, not its dict)
    selected_classifier = cfg.CLASSIFIER['selected']
    if selected_classifier in classifier_params:
        if selected_classifier not in cfg.CLASSIFIER:
            logger.error('"%s" not defined in config.' % selected_classifier)
            raise RuntimeError
        for v in critical_vars[classifier_params[selected_classifier]]:
            if v not in cfg.CLASSIFIER[selected_classifier]:
                logger.error('%s not defined in config.' % v)
                raise RuntimeError

    # cross-validation parameters check; test the section's presence before
    # indexing into it
    cv_selected = cfg.CV_PERFORM['selected']
    if cv_selected in ('StratifiedShuffleSplit', 'LeaveOneOut'):
        if cv_selected not in cfg.CV_PERFORM:
            logger.error('"%s" not defined in config.' % cv_selected)
            raise RuntimeError
        cv_params = cfg.CV_PERFORM[cv_selected]
        if cv_params is not None:
            for v in critical_vars[cv_selected]:
                if v not in cv_params:
                    logger.error('%s not defined in config.' % v)
                    raise RuntimeError

    if cfg.N_JOBS is None:
        cfg.N_JOBS = mp.cpu_count()

    return cfg