Ejemplo n.º 1
0
def any2fif(filename, interactive=False, outdir=None, channel_file=None):
    """
    Generic file format converter.

    Dispatches to the format-specific converter according to the extension
    reported by qc.parse_path(filename).

    Params
    --------
    filename: path of the file to convert.
    interactive: forwarded unchanged to the selected converter.
    outdir: output directory. When not None it is passed to qc.make_dirs()
            first and then forwarded to the selected converter.
    channel_file: forwarded to gdf2fif() only.

    An unrecognized extension is logged as an error and the file is ignored;
    no exception is raised.
    """
    p = qc.parse_path(filename)
    if outdir is not None:
        qc.make_dirs(outdir)

    if p.ext == 'pcl':
        # Look for a sibling event file: same name with "raw" replaced by "eve".
        eve_file = '%s/%s.txt' % (p.dir, p.name.replace('raw', 'eve'))
        if os.path.exists(eve_file):
            logger.info('Adding events from %s' % eve_file)
        else:
            eve_file = None
        pcl2fif(filename,
                interactive=interactive,
                outdir=outdir,
                external_event=eve_file)
    elif p.ext == 'eeg':
        eeg2fif(filename, interactive=interactive, outdir=outdir)
    elif p.ext in ['edf', 'bdf']:
        bdf2fif(filename, interactive=interactive, outdir=outdir)
    elif p.ext == 'gdf':
        gdf2fif(filename,
                interactive=interactive,
                outdir=outdir,
                channel_file=channel_file)
    elif p.ext == 'xdf':
        xdf2fif(filename, interactive=interactive, outdir=outdir)
    else:  # unknown format
        # The list below mirrors every extension handled by the branches above.
        logger.error(
            'Ignored unrecognized file extension %s. It should be [.pcl | .eeg | .edf | .bdf | .gdf | .xdf]'
            % p.ext)
Ejemplo n.º 2
0
def record(state, amp_name, amp_serial, record_dir, eeg_only):
    """
    Record a stream into a .pcl file and convert it to fif.

    Acquires data from a StreamReceiver while state.value == 1, then saves
    the buffer with qc.save_obj() and calls pcl2fif() on the saved file.

    Params
    --------
    state: shared object; recording continues while state.value == 1.
    amp_name, amp_serial, eeg_only: forwarded to StreamReceiver.
    record_dir: output directory, passed to qc.make_dirs().

    Raises
    --------
    RuntimeError: if the output file cannot be created.
    """
    # set data file name; the timestamp is formatted on its own so that any
    # '%' character in record_dir is never interpreted by strftime
    timestamp = time.strftime('%Y%m%d-%H%M%S', time.localtime())
    filename = '%s/%s-raw.pcl' % (record_dir, timestamp)
    qc.print_c('>> Output file: %s' % (filename), 'W')

    # test writability; the placeholder is overwritten by qc.save_obj() below
    try:
        qc.make_dirs(record_dir)
        with open(filename, 'w') as fout:
            fout.write('The data will written when the recording is finished.')
    except Exception:
        # Exception (not a bare except) so Ctrl+C / SystemExit still propagate
        raise RuntimeError('Problem writing to %s. Check permission.' %
                           filename)

    # start a server for sending out data filename when software trigger is used
    # NOTE(review): the returned outlet is not used below; presumably the
    # reference must stay alive for the stream to remain available - confirm.
    outlet = start_server('StreamRecorderInfo', channel_format='string',
                          source_id=filename, stype='Markers')

    # connect to EEG stream server
    sr = StreamReceiver(amp_name=amp_name,
                        amp_serial=amp_serial,
                        eeg_only=eeg_only)

    # start recording
    qc.print_c('\n>> Recording started (PID %d).' % os.getpid(), 'W')
    qc.print_c('\n>> Press Enter to stop recording', 'G')
    tm = qc.Timer(autoreset=True)
    next_sec = 1
    while state.value == 1:
        sr.acquire()
        # report progress each time the buffer length passes the next threshold
        if sr.get_buflen() > next_sec:
            duration = str(datetime.timedelta(seconds=int(sr.get_buflen())))
            print('RECORDING %s' % duration)
            next_sec += 1
        tm.sleep_atleast(0.01)

    # record stop
    qc.print_c('>> Stop requested. Copying buffer', 'G')
    signals, times = sr.get_buffer()

    # channels = total channels from amp, including trigger channel
    data = {
        'signals': signals,
        'timestamps': times,
        'events': None,
        'sample_rate': sr.get_sample_rate(),
        'channels': sr.get_num_channels(),
        'ch_names': sr.get_channel_names()
    }
    qc.print_c('Saving raw data ...', 'W')
    qc.save_obj(filename, data)
    print('Saved to %s\n' % filename)

    qc.print_c('Converting raw file into a fif format.', 'W')
    pcl2fif(filename)
Ejemplo n.º 3
0
def plot_grand_avg(epochs, outdir, picks=None):
    """
    Save one averaged-response plot per (channel, event) pair.

    For every key of epochs.event_id, the epochs of that event are averaged
    (restricted to `picks`) and each channel of the average is plotted and
    written to '<outdir>/<channel>-<event>.png'.
    """
    qc.make_dirs(outdir)
    for event_name in epochs.event_id:
        evoked = epochs[event_name].average(picks)
        for ch_index, ch_name in enumerate(evoked.ch_names):
            label = '%s-%s' % (ch_name, event_name)
            figure = evoked.plot(picks=[ch_index],
                                 unit=True,
                                 ylim=dict(eeg=[-10, 10]),
                                 titles=label,
                                 show=False,
                                 scalings={'eeg': 1})
            figure.savefig(outdir + '/%s-%s.png' % (ch_name, event_name))
def record(recordState, amp_name, amp_serial, record_dir, eeg_only, recordLogger=logger, queue=None):
    """
    Record a stream into a .pcl file and convert it to fif.

    Acquires data from a StreamReceiver while recordState.value == 1, then
    saves the buffer with qc.save_obj() and calls pcl2fif() on the saved
    file, attaching the matching '-eve.txt' event file when it exists.

    Params
    --------
    recordState: shared object; recording continues while .value == 1.
    amp_name, amp_serial, eeg_only: forwarded to StreamReceiver.
    record_dir: output directory, passed to qc.make_dirs().
    recordLogger: logger used for progress messages.
    queue: forwarded to redirect_stdout_to_queue() together with recordLogger.

    Raises
    --------
    RuntimeError: if the output file cannot be created.
    """
    redirect_stdout_to_queue(recordLogger, queue, 'INFO')

    # set data file name
    timestamp = time.strftime('%Y%m%d-%H%M%S', time.localtime())
    pcl_file = "%s/%s-raw.pcl" % (record_dir, timestamp)
    eve_file = '%s/%s-eve.txt' % (record_dir, timestamp)
    recordLogger.info('>> Output file: %s' % (pcl_file))

    # test writability; the placeholder is overwritten by qc.save_obj() below
    try:
        qc.make_dirs(record_dir)
        with open(pcl_file, 'w') as fout:
            fout.write('The data will written when the recording is finished.')
    except Exception:
        # Exception (not a bare except) so Ctrl+C / SystemExit still propagate
        raise RuntimeError('Problem writing to %s. Check permission.' % pcl_file)

    # start a server for sending out the event file name when software trigger is used
    # NOTE(review): the returned outlet is not used below; presumably the
    # reference must stay alive for the stream to remain available - confirm.
    outlet = start_server('StreamRecorderInfo', channel_format='string',
                          source_id=eve_file, stype='Markers')

    # connect to EEG stream server
    sr = StreamReceiver(buffer_size=0, amp_name=amp_name, amp_serial=amp_serial, eeg_only=eeg_only)

    # start recording
    recordLogger.info('\n>> Recording started (PID %d).' % os.getpid())
    qc.print_c('\n>> Press Enter to stop recording', 'G')
    tm = qc.Timer(autoreset=True)
    next_sec = 1
    while recordState.value == 1:
        sr.acquire()
        # report progress each time the buffer length passes the next threshold
        if sr.get_buflen() > next_sec:
            duration = str(datetime.timedelta(seconds=int(sr.get_buflen())))
            recordLogger.info('RECORDING %s' % duration)
            next_sec += 1
        tm.sleep_atleast(0.001)

    # record stop
    recordLogger.info('>> Stop requested. Copying buffer')
    signals, times = sr.get_buffer()

    # channels = total channels from amp, including trigger channel
    data = {'signals': signals, 'timestamps': times, 'events': None,
            'sample_rate': sr.get_sample_rate(), 'channels': sr.get_num_channels(),
            'ch_names': sr.get_channel_names(), 'lsl_time_offset': sr.lsl_time_offset}
    recordLogger.info('Saving raw data ...')
    qc.save_obj(pcl_file, data)
    recordLogger.info('Saved to %s\n' % pcl_file)

    # automatically convert to fif and use event file if it exists (software trigger)
    if os.path.exists(eve_file):
        recordLogger.info('Found matching event file, adding events.')
    else:
        eve_file = None
    recordLogger.info('Converting raw file into fif.')
    pcl2fif(pcl_file, external_event=eve_file)
Ejemplo n.º 5
0
def get_tfr(fif_file, cfg, tfr, n_jobs=1):
    """
    Compute a time-frequency representation of one raw file and export it.

    The raw file is preprocessed with pu.preprocess(), wrapped into a single
    epoch spanning the recording, and passed to `tfr`. Results are written
    under '<dir of fif_file>/plot_<name of fif_file>/'.

    Params
    --------
    fif_file: path of the raw file, loaded with pu.load_raw().
    cfg: config object. Reads CHANNEL_PICKS, SP_CHANNELS, SP_FILTER,
         TP_FILTER, NOTCH_FILTER, MULTIPLIER, EVENT_START, FREQ_RANGE,
         EXPORT_MATLAB, EXPORT_PNG, BS_TIMES, BS_MODE, VMIN and VMAX.
    tfr: callable invoked as tfr(epochs, freqs=..., n_cycles=..., ...).
    n_jobs: forwarded to pu.preprocess() and to `tfr`.

    Raises
    --------
    RuntimeError: if a picked channel index is out of range.
    """
    raw, events = pu.load_raw(fif_file)
    p = qc.parse_path(fif_file)
    fname = p.name
    outpath = p.dir

    export_dir = '%s/plot_%s' % (outpath, fname)
    qc.make_dirs(export_dir)

    # set channels of interest
    picks = pu.channel_names_to_index(raw, cfg.CHANNEL_PICKS)
    spchannels = pu.channel_names_to_index(raw, cfg.SP_CHANNELS)

    # valid indices are 0 .. n_channels-1, so an index equal to n_channels
    # is already out of range
    n_channels = len(raw.info['ch_names'])
    if max(picks) >= n_channels:
        msg = 'ERROR: "picks" has a channel index %d while there are only %d channels.' %\
              (max(picks), n_channels)
        raise RuntimeError(msg)

    # Apply filters
    pu.preprocess(raw, spatial=cfg.SP_FILTER, spatial_ch=spchannels, spectral=cfg.TP_FILTER,
                  spectral_ch=picks, notch=cfg.NOTCH_FILTER, notch_ch=picks,
                  multiplier=cfg.MULTIPLIER, n_jobs=n_jobs)

    # MNE TFR functions do not support Raw instances yet, so convert to Epoch
    if cfg.EVENT_START is None:
        # no start event configured: write a marker into the first sample of
        # channel 0 and use one synthetic event at sample 0
        raw._data[0][0] = 1
        events = np.array([[0, 0, 1]])
        classes = None
    else:
        classes = {'START': cfg.EVENT_START}
    tmax = (raw._data.shape[1] - 1) / raw.info['sfreq']
    epochs_all = mne.Epochs(raw, events, classes, tmin=0, tmax=tmax,
                            picks=picks, baseline=None, preload=True)
    print('\n>> Processing %s' % fif_file)
    freqs = cfg.FREQ_RANGE  # define frequencies of interest
    n_cycles = freqs / 2.  # different number of cycle per frequency
    power = tfr(epochs_all, freqs=freqs, n_cycles=n_cycles, use_fft=False,
                return_itc=False, decim=1, n_jobs=n_jobs)

    if cfg.EXPORT_MATLAB is True:
        # export all channels to MATLAB
        mout = '%s/%s-%s.mat' % (export_dir, fname, cfg.SP_FILTER)
        scipy.io.savemat(mout, {'tfr': power.data, 'chs': power.ch_names, 'events': events,
                                'sfreq': raw.info['sfreq'], 'freqs': cfg.FREQ_RANGE})

    if cfg.EXPORT_PNG is True:
        # Plot power of each channel; `ch` indexes into the picked channels,
        # `pick` is the matching index into the raw channel list
        for ch, pick in enumerate(picks):
            ch_name = raw.ch_names[pick]
            title = 'Channel %s' % (ch_name)
            # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
            fig = power.plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                             colorbar=True, title=title, vmin=cfg.VMIN, vmax=cfg.VMAX, dB=False)
            fout = '%s/%s-%s-%s.png' % (export_dir, fname, cfg.SP_FILTER, ch_name)
            fig.savefig(fout)
            print('Exported %s' % fout)

    print('Finished !')
Ejemplo n.º 6
0
def fif_resample(fif_dir, sfreq_target):
    """
    Resample every fif file listed for fif_dir to sfreq_target.

    Each resampled recording is saved under the same file name in
    '<fif_dir>/fif_resample<sfreq_target>'. Files whose parsed extension is
    not 'fif' are skipped.
    """
    out_dir = fif_dir + ('/fif_resample%d' % sfreq_target)
    qc.make_dirs(out_dir)
    for path in qc.get_file_list(fif_dir):
        parsed = qc.parse_path(path)
        if parsed.ext == 'fif':
            logger.info('Resampling %s' % path)
            raw, _ = pu.load_raw(path)
            raw.resample(sfreq_target)
            target = '%s/%s.fif' % (out_dir, parsed.name)
            raw.save(target)
            logger.info('Exported to %s' % target)
Ejemplo n.º 7
0
def fif2mat(data_dir):
    """
    Export every '.fif' file listed for data_dir to a MATLAB .mat file.

    Output goes to '<data_dir>/mat_files/<name>.mat' and holds the keys
    'signals', 'events', 'sfreq' and 'ch_names'.
    """
    out_dir = '%s/mat_files' % data_dir
    qc.make_dirs(out_dir)
    for fif_path in qc.get_file_list(data_dir, fullpath=True):
        if fif_path[-4:] == '.fif':
            raw, events = pu.load_raw(fif_path)
            # MATLAB uses 1-based indexing
            events[:, 0] += 1
            payload = {
                'signals': raw._data,
                'events': events,
                'sfreq': raw.info['sfreq'],
                'ch_names': raw.ch_names,
            }
            stem = qc.parse_path(fif_path).name
            mat_path = '%s/%s.mat' % (out_dir, stem)
            scipy.io.savemat(mat_path, payload)
            logger.info('Exported to %s' % mat_path)
    logger.info('Finished exporting.')
Ejemplo n.º 8
0
def fif2mat(input_path):
    """
    Convert a fif file, or every '.fif' file listed for a directory, to .mat.

    Output is written by fif2mat_file() into a 'mat_files' directory next to
    the input. A warning is logged when a directory yields no '.fif' file.

    Raises ValueError when input_path is neither a directory nor a file.
    """
    is_directory = os.path.isdir(input_path)
    if not is_directory and not os.path.isfile(input_path):
        raise ValueError('Neither directory nor file: %s' % input_path)

    if is_directory:
        out_dir = '%s/mat_files' % input_path
        qc.make_dirs(out_dir)
        converted_any = False
        for candidate in qc.get_file_list(input_path, fullpath=True):
            if candidate[-4:] == '.fif':
                fif2mat_file(candidate, out_dir)
                converted_any = True
        if not converted_any:
            logger.warning('No fif files found in the path.')
    else:
        out_dir = '%s/mat_files' % qc.parse_path(input_path).dir
        qc.make_dirs(out_dir)
        fif2mat_file(input_path, out_dir)
    logger.info('Finished.')
Ejemplo n.º 9
0
def fix_channel_names(fif_dir, new_channel_names):
    '''
    Change channel names of fif files in a given directory.

    Input
    -----
    @fif_dir: path to fif files
    @new_channel_names: list of new channel names; its length must equal the
                        number of channels of every fif file

    Output
    ------
    Modified fif files are saved in a 'corrected' sub-directory of each
    file's directory ('<fif_dir>/corrected' is created up front).
    A warning is logged if no fif file is found.

    Raises
    ------
    RuntimeError if a file's channel count differs from len(new_channel_names).

    Kyuhwa Lee, 2019.
    '''
    # Scan the directory once and keep only fif files, so that non-fif files
    # are neither re-listed nor reported as being loaded.
    flist = [f for f in qc.get_file_list(fif_dir)
             if qc.parse_path(f).ext == 'fif']
    if not flist:
        logger.warning('No fif files found in %s' % fif_dir)
        return

    qc.make_dirs('%s/corrected' % fif_dir)
    for f in flist:
        logger.info('\nLoading %s' % f)
        p = qc.parse_path(f)
        raw, _ = pu.load_raw(f)
        if len(raw.ch_names) != len(new_channel_names):
            raise RuntimeError(
                'The number of new channels does not match that of fif file.'
            )
        # give every file its own copy so the caller's list is never aliased
        raw.info['ch_names'] = list(new_channel_names)
        # the per-channel records carry the name too; keep both in sync
        for ch, new_ch in zip(raw.info['chs'], new_channel_names):
            ch['ch_name'] = new_ch
        out_fif = '%s/corrected/%s.fif' % (p.dir, p.name)
        logger.info('Exporting to %s' % out_fif)
        raw.save(out_fif)
Ejemplo n.º 10
0
import trainer
from multiprocessing import cpu_count

# Only report MNE messages of level ERROR.
mne.set_log_level('ERROR')

# Script entry point: collect fif files from DATA_PATH, load them as one raw
# object, optionally re-reference, and resolve the channel picks.
# NOTE(review): DATA_PATH, MULTIPLIER, REF_CH_NEW, REF_CH_OLD, CHANNEL_PICKS
# and EXCLUDES are not defined in the visible part of the file - confirm
# where they come from.
if __name__ == '__main__':
    # keep only paths ending in '.fif'
    rawlist = []
    for f in qc.get_file_list(DATA_PATH, fullpath=True):
        if f[-4:] == '.fif':
            rawlist.append(f)
    if len(rawlist) == 0:
        raise RuntimeError('No fif files found in the path.')

    # make output directory
    out_path = DATA_PATH + '/epochs'
    qc.make_dirs(out_path)

    # load data
    raw, events = pu.load_multi(rawlist, multiplier=MULTIPLIER)
    # keep EEG channels only (MEG and stim channels are dropped)
    raw.pick_types(meg=False, eeg=True, stim=False)
    sfreq = raw.info['sfreq']
    if REF_CH_NEW is not None:
        pu.rereference(raw, REF_CH_NEW, REF_CH_OLD)

    # pick channels
    if CHANNEL_PICKS is None:
        # no explicit picks: use every channel that is not listed in EXCLUDES
        picks = [
            raw.ch_names.index(c) for c in raw.ch_names if c not in EXCLUDES
        ]
    elif type(CHANNEL_PICKS[0]) == str:
        # picks given by name: translate each name to its channel index
        picks = [raw.ch_names.index(c) for c in CHANNEL_PICKS]
Ejemplo n.º 11
0
        # train a model
        RF = dict(trees=1000, maxdepth=100)
        cls = RandomForestClassifier(n_estimators=RF['trees'],
                                     max_features='auto',
                                     max_depth=RF['maxdepth'],
                                     n_jobs=n_jobs)
        cls.fit(X, Y)
        cls.n_jobs = 1  # n_jobs should be 1 for online decoding
        print(
            'Trained a Random Forest classifer with %d trees and %d maxdepth' %
            (RF['trees'], RF['maxdepth']))
        ch_names = [raw.info['ch_names'][c] for c in picks_feat]
        data = dict(sfreq=raw.info['sfreq'], ch_names=ch_names, picks=picks_feat,\
                    cls=cls, l_freq=l_freq, h_freq=h_freq, decim_factor=decim_factor)
        outdir = DATADIR + '/errp_classifier'
        qc.make_dirs(outdir)
        clsfile = outdir + '/errp_classifier.pcl'
        qc.save_obj(clsfile, data)
        print('Saved as %s' % clsfile)

    if True:
        # hoang's code
        label = epochs.events[:, 2]

        cls = rLDA_binary(0.3)

        train_data = epochs._data
        train_label = label

        ### Normalization
        (train_data_normalized, trainShiftFactor,
Ejemplo n.º 12
0
        '''
        data= dict(sfreq=raw.info['sfreq'], ch_names=ch_names, picks=picks_feat, \
                w=w,b=b, l_freq=l_freq, h_freq=h_freq, decim_factor=decim_factor, pca=pca,
                shiftFactor=trainShiftFactor, scaleFactor=trainScaleFactor)
        '''
        # remember line 195:
        # t_lower = tmin-paddingLength
        # t_upper = tmax+paddingLength

        ##########################################################################
        data = dict(cls=cls, sfreq=raw.info['sfreq'], ch_names=ch_names, picks=picks_feat,\
                    l_freq=l_freq, h_freq=h_freq, decim_factor=decim_factor,\
                    shiftFactor=trainShiftFactor, scaleFactor=trainScaleFactor, pca=pca, threshold=best_threshold[0],
                    tmin=tmin, tmax=tmax, paddingIdx=paddingIdx, iir_params=dict(a=a, b=b))
        outdir = DATADIR + '/errp_classifier'
        qc.make_dirs(outdir)
        clsfile = outdir + '/errp_classifier.pcl'
        qc.save_obj(clsfile, data)
        print('Saved as %s' % clsfile)
print('Done')

#    def balance_idx(label):
#        labelsetWrong = np.where(label==3)[0]
#        labelsetCorrect = np.where(label==4)[0]
#
#        diff = len(labelsetCorrect) - len(labelsetWrong)
#
#        if diff > 0:
#            smallestSet = labelsetWrong
#            largestSet = labelsetCorrect
#        elif diff<0:
Ejemplo n.º 13
0
def train_decoder(cfg, featdata, feat_file=None):
    """
    Train the final decoder using all data

    Params
    --------
    cfg: config object. Reads CLASSIFIER, FEATURES, CV, N_JOBS, tdef, EPOCH,
         SP_FILTER, TP_FILTER, NOTCH_FILTER, MULTIPLIER, REREFERENCE,
         DATA_PATH, EXPORT_GOOD_FEATURES and FEAT_TOPN.
         cfg.FEATURES['PSD']['wlen'] is set to featdata['wlen'] when it is None.
    featdata: dict with the keys 'X_data', 'Y_data', 'wlen', 'w_frames',
              'ch_names', 'picks', 'psde' and 'sfreq'.
    feat_file: where to write the feature ranking when
               cfg.EXPORT_GOOD_FEATURES is set. Defaults to
               '<cfg.DATA_PATH>/classifier/good_features.txt'.

    The trained decoder is saved to
    '<cfg.DATA_PATH>/classifier/classifier-<architecture>.pkl'.

    Raises
    --------
    ValueError: if cfg.CLASSIFIER['selected'] is not a known classifier.
    NotImplementedError: if cfg.FEATURES['selected'] is not 'PSD'.
    """
    # Init a classifier
    selected_classifier = cfg.CLASSIFIER['selected']
    if selected_classifier == 'GB':
        cls = GradientBoostingClassifier(
            loss='deviance',
            learning_rate=cfg.CLASSIFIER[selected_classifier]['learning_rate'],
            n_estimators=cfg.CLASSIFIER[selected_classifier]['trees'],
            subsample=1.0,
            max_depth=cfg.CLASSIFIER[selected_classifier]['depth'],
            random_state=cfg.CLASSIFIER[selected_classifier]['seed'],
            max_features='sqrt',
            verbose=0,
            warm_start=False,
            presort='auto')
    elif selected_classifier == 'XGB':
        cls = XGBClassifier(
            loss='deviance',
            learning_rate=cfg.CLASSIFIER[selected_classifier]['learning_rate'],
            n_estimators=cfg.CLASSIFIER[selected_classifier]['trees'],
            subsample=1.0,
            max_depth=cfg.CLASSIFIER[selected_classifier]['depth'],
            # read the seed from the XGB section, like every other setting
            random_state=cfg.CLASSIFIER[selected_classifier]['seed'],
            max_features='sqrt',
            verbose=0,
            warm_start=False,
            presort='auto')
    elif selected_classifier == 'RF':
        cls = RandomForestClassifier(
            n_estimators=cfg.CLASSIFIER[selected_classifier]['trees'],
            max_features='auto',
            max_depth=cfg.CLASSIFIER[selected_classifier]['depth'],
            n_jobs=cfg.N_JOBS,
            random_state=cfg.CLASSIFIER[selected_classifier]['seed'],
            oob_score=False,
            class_weight='balanced_subsample')
    elif selected_classifier == 'LDA':
        cls = LDA()
    elif selected_classifier == 'rLDA':
        cls = rLDA(cfg.CLASSIFIER[selected_classifier]['r_coeff'])
    else:
        logger.error('Unknown classifier %s' % selected_classifier)
        raise ValueError('Unknown classifier %s' % selected_classifier)

    # Setup features
    X_data = featdata['X_data']
    Y_data = featdata['Y_data']
    wlen = featdata['wlen']
    if cfg.FEATURES['PSD']['wlen'] is None:
        cfg.FEATURES['PSD']['wlen'] = wlen
    w_frames = featdata['w_frames']
    ch_names = featdata['ch_names']
    X_data_merged = np.concatenate(X_data)
    Y_data_merged = np.concatenate(Y_data)
    if cfg.CV['BALANCE_SAMPLES']:
        X_data_merged, Y_data_merged = balance_samples(
            X_data_merged,
            Y_data_merged,
            cfg.CV['BALANCE_SAMPLES'],
            verbose=True)

    # Start training the decoder
    logger.info_green('Training the decoder')
    timer = qc.Timer()
    cls.n_jobs = cfg.N_JOBS
    cls.fit(X_data_merged, Y_data_merged)
    logger.info('Trained %d samples x %d dimension in %.1f sec' %
                (X_data_merged.shape[0], X_data_merged.shape[1], timer.sec()))
    cls.n_jobs = 1  # always set n_jobs=1 for testing

    # Export the decoder
    classes = {c: cfg.tdef.by_value[c] for c in np.unique(Y_data)}
    if cfg.FEATURES['selected'] == 'PSD':
        data = dict(cls=cls,
                    ch_names=ch_names,
                    psde=featdata['psde'],
                    sfreq=featdata['sfreq'],
                    picks=featdata['picks'],
                    classes=classes,
                    epochs=cfg.EPOCH,
                    w_frames=w_frames,
                    w_seconds=cfg.FEATURES['PSD']['wlen'],
                    wstep=cfg.FEATURES['PSD']['wstep'],
                    spatial=cfg.SP_FILTER,
                    spatial_ch=featdata['picks'],
                    spectral=cfg.TP_FILTER[cfg.TP_FILTER['selected']],
                    spectral_ch=featdata['picks'],
                    notch=cfg.NOTCH_FILTER[cfg.NOTCH_FILTER['selected']],
                    notch_ch=featdata['picks'],
                    multiplier=cfg.MULTIPLIER,
                    ref_ch=cfg.REREFERENCE[cfg.REREFERENCE['selected']],
                    decim=cfg.FEATURES['PSD']['decim'])
    else:
        # Only the PSD export layout is defined; fail with a clear message
        # instead of an unbound-variable error at save time.
        raise NotImplementedError(
            'Exporting a decoder for feature type %s is not supported.' %
            cfg.FEATURES['selected'])
    clsfile = '%s/classifier/classifier-%s.pkl' % (cfg.DATA_PATH,
                                                   platform.architecture()[0])
    qc.make_dirs('%s/classifier' % cfg.DATA_PATH)
    qc.save_obj(clsfile, data)
    logger.info('Decoder saved to %s' % clsfile)

    # Reverse-lookup frequency from FFT
    # (frequency resolution is 1 / window length; the first window is used
    # when several window lengths are configured)
    fq = 0
    if isinstance(cfg.FEATURES['PSD']['wlen'], list):
        fq_res = 1.0 / cfg.FEATURES['PSD']['wlen'][0]
    else:
        fq_res = 1.0 / cfg.FEATURES['PSD']['wlen']
    fqlist = []
    while fq <= cfg.FEATURES['PSD']['fmax']:
        if fq >= cfg.FEATURES['PSD']['fmin']:
            fqlist.append(fq)
        fq += fq_res

    # Show top distinctive features
    if cfg.FEATURES['selected'] == 'PSD':
        logger.info_green('Good features ordered by importance')
        if selected_classifier in ['RF', 'GB', 'XGB']:
            keys, values = qc.sort_by_value(list(cls.feature_importances_),
                                            rev=True)
        elif selected_classifier in ['LDA', 'rLDA']:
            keys, values = qc.sort_by_value(cls.w, rev=True)
        keys = np.array(keys)
        values = np.array(values)

        # Names of the feature channels. A separate list is built so that
        # the source names in ch_names stay available while it is filled.
        if not isinstance(wlen, list):
            feat_ch_names = [ch_names[c] for c in featdata['picks']]
        else:
            feat_ch_names = []
            for w in range(len(wlen)):
                for c in featdata['picks']:
                    feat_ch_names.append('w%d-%s' % (w, ch_names[c]))

        chlist, hzlist = features.feature2chz(keys, fqlist,
                                              ch_names=feat_ch_names)
        valnorm = values[:cfg.FEAT_TOPN].copy()
        valsum = np.sum(valnorm)
        if valsum == 0:
            valsum = 1  # avoid dividing by zero when all importances are 0
        valnorm = valnorm / valsum * 100.0

        # show top-N features
        for i, (ch, hz) in enumerate(zip(chlist, hzlist)):
            if i >= cfg.FEAT_TOPN:
                break
            txt = '%-3s %5.1f Hz  normalized importance %-6s  raw importance %-6s  feature %-5d' %\
                  (ch, hz, '%.2f%%' % valnorm[i], '%.2f%%' % (values[i] * 100.0), keys[i])
            logger.info(txt)

        if cfg.EXPORT_GOOD_FEATURES:
            if feat_file is None:
                gf_path = '%s/classifier/good_features.txt' % cfg.DATA_PATH
            else:
                gf_path = feat_file
            # the context manager closes the file even if a write fails
            with open(gf_path, 'w') as gfout:
                gfout.write('Importance(%) Channel Frequency Index\n')
                for i, (ch, hz) in enumerate(zip(chlist, hzlist)):
                    gfout.write('%.3f\t%s\t%s\t%d\n' %
                                (values[i] * 100.0, ch, hz, keys[i]))
Ejemplo n.º 14
0
def cross_validate(cfg, featdata, cv_file=None):
    """
    Perform cross validation

    Params
    --------
    cfg: config object. Reads CLASSIFIER, CV_PERFORM, CV, N_JOBS, tdef,
         SP_FILTER, SP_CHANNELS, TP_FILTER, NOTCH_FILTER, FEATURES, EPOCH,
         EXPORT_CLS and DATA_PATH.
    featdata: dict with the keys 'X_data', 'Y_data', 'wlen', 'sfreq',
              'ch_names' and 'picks'. X_data must unpack into
              (trials, samples, features) via its .shape.
    cv_file: path of the result file. Used only when the selected CV entry
             has 'export_result' set to True; defaults to 'cv_result.txt'
             under cfg.DATA_PATH (inside 'classifier/' when cfg.EXPORT_CLS
             is True).

    Raises
    --------
    ValueError: if cfg.CLASSIFIER['selected'] is not a known classifier.
    NotImplementedError: if cfg.CV_PERFORM['selected'] is not supported.
    """
    # Init a classifier
    selected_classifier = cfg.CLASSIFIER['selected']
    if selected_classifier == 'GB':
        cls = GradientBoostingClassifier(
            loss='deviance',
            learning_rate=cfg.CLASSIFIER['GB']['learning_rate'],
            presort='auto',
            n_estimators=cfg.CLASSIFIER['GB']['trees'],
            subsample=1.0,
            max_depth=cfg.CLASSIFIER['GB']['depth'],
            random_state=cfg.CLASSIFIER['GB']['seed'],
            max_features='sqrt',
            verbose=0,
            warm_start=False)
    elif selected_classifier == 'XGB':
        cls = XGBClassifier(
            loss='deviance',
            learning_rate=cfg.CLASSIFIER['XGB']['learning_rate'],
            presort='auto',
            n_estimators=cfg.CLASSIFIER['XGB']['trees'],
            subsample=1.0,
            max_depth=cfg.CLASSIFIER['XGB']['depth'],
            # pass the seed itself, not the whole XGB settings dict
            random_state=cfg.CLASSIFIER['XGB']['seed'],
            max_features='sqrt',
            verbose=0,
            warm_start=False)
    elif selected_classifier == 'RF':
        cls = RandomForestClassifier(
            n_estimators=cfg.CLASSIFIER['RF']['trees'],
            max_features='auto',
            max_depth=cfg.CLASSIFIER['RF']['depth'],
            n_jobs=cfg.N_JOBS,
            random_state=cfg.CLASSIFIER['RF']['seed'],
            oob_score=False,
            class_weight='balanced_subsample')
    elif selected_classifier == 'LDA':
        cls = LDA()
    elif selected_classifier == 'rLDA':
        cls = rLDA(cfg.CLASSIFIER['rLDA']['r_coeff'])
    else:
        logger.error('Unknown classifier type %s' % selected_classifier)
        raise ValueError('Unknown classifier type %s' % selected_classifier)

    # Setup features
    X_data = featdata['X_data']
    Y_data = featdata['Y_data']
    wlen = featdata['wlen']

    # Choose CV type
    ntrials, nsamples, fsize = X_data.shape
    selected_cv = cfg.CV_PERFORM['selected']
    if selected_cv == 'LeaveOneOut':
        logger.info_green('%d-fold leave-one-out cross-validation' % ntrials)
        if SKLEARN_OLD:
            cv = LeaveOneOut(len(Y_data))
        else:
            cv = LeaveOneOut()
    elif selected_cv == 'StratifiedShuffleSplit':
        logger.info_green(
            '%d-fold stratified cross-validation with test set ratio %.2f' %
            (cfg.CV_PERFORM[selected_cv]['folds'],
             cfg.CV_PERFORM[selected_cv]['test_ratio']))
        if SKLEARN_OLD:
            cv = StratifiedShuffleSplit(
                Y_data[:, 0],
                cfg.CV_PERFORM[selected_cv]['folds'],
                test_size=cfg.CV_PERFORM[selected_cv]['test_ratio'],
                random_state=cfg.CV_PERFORM[selected_cv]['seed'])
        else:
            cv = StratifiedShuffleSplit(
                n_splits=cfg.CV_PERFORM[selected_cv]['folds'],
                test_size=cfg.CV_PERFORM[selected_cv]['test_ratio'],
                random_state=cfg.CV_PERFORM[selected_cv]['seed'])
    else:
        # report the name itself; indexing cfg.CV_PERFORM with an
        # unsupported name could raise KeyError before this error is reached
        logger.error('%s is not supported yet. Sorry.' % selected_cv)
        raise NotImplementedError(
            '%s is not supported yet. Sorry.' % selected_cv)
    logger.info('%d trials, %d samples per trial, %d feature dimension' %
                (ntrials, nsamples, fsize))

    # Do it!
    timer_cv = qc.Timer()
    scores, cm_txt = crossval_epochs(cv,
                                     X_data,
                                     Y_data,
                                     cls,
                                     cfg.tdef.by_value,
                                     cfg.CV['BALANCE_SAMPLES'],
                                     n_jobs=cfg.N_JOBS,
                                     ignore_thres=cfg.CV['IGNORE_THRES'],
                                     decision_thres=cfg.CV['DECISION_THRES'])
    t_cv = timer_cv.sec()

    # Export results
    txt = 'Cross validation took %d seconds.\n' % t_cv
    txt += '\n- Class information\n'
    txt += '%d epochs, %d samples per epoch, %d feature dimension (total %d samples)\n' %\
        (ntrials, nsamples, fsize, ntrials * nsamples)
    for ev in np.unique(Y_data):
        txt += '%s: %d trials\n' % (cfg.tdef.by_value[ev],
                                    len(np.where(Y_data[:, 0] == ev)[0]))
    if cfg.CV['BALANCE_SAMPLES']:
        # same setting that was handed to crossval_epochs() above
        txt += 'The number of samples was balanced using %ssampling.\n' % str(
            cfg.CV['BALANCE_SAMPLES']).lower()
    txt += '\n- Experiment condition\n'
    txt += 'Sampling frequency: %.3f Hz\n' % featdata['sfreq']
    txt += 'Spatial filter: %s (channels: %s)\n' % (cfg.SP_FILTER,
                                                    cfg.SP_CHANNELS)
    txt += 'Spectral filter: %s\n' % cfg.TP_FILTER[cfg.TP_FILTER['selected']]
    txt += 'Notch filter: %s\n' % cfg.NOTCH_FILTER[
        cfg.NOTCH_FILTER['selected']]
    txt += 'Channels: ' + ','.join(
        [str(featdata['ch_names'][p]) for p in featdata['picks']]) + '\n'
    txt += 'PSD range: %.1f - %.1f Hz\n' % (cfg.FEATURES['PSD']['fmin'],
                                            cfg.FEATURES['PSD']['fmax'])
    txt += 'Window step: %.2f msec\n' % (
        1000.0 * cfg.FEATURES['PSD']['wstep'] / featdata['sfreq'])
    if type(wlen) is list:
        for i, w in enumerate(wlen):
            txt += 'Window size: %.1f msec\n' % (w * 1000.0)
            txt += 'Epoch range: %s sec\n' % (cfg.EPOCH[i])
    else:
        txt += 'Window size: %.1f msec\n' % (cfg.FEATURES['PSD']['wlen'] *
                                             1000.0)
        txt += 'Epoch range: %s sec\n' % (cfg.EPOCH)
    txt += 'Decimation factor: %d\n' % cfg.FEATURES['PSD']['decim']

    # Compute stats
    cv_mean, cv_std = np.mean(scores), np.std(scores)
    # .get(): the leave-one-out branch above never reads a seed, so its
    # settings may not define one
    txt += '\n- Average CV accuracy over %d epochs (random seed=%s)\n' % (
        ntrials, cfg.CV_PERFORM[selected_cv].get('seed'))
    if selected_cv in ['LeaveOneOut', 'StratifiedShuffleSplit']:
        txt += "mean %.3f, std: %.3f\n" % (cv_mean, cv_std)
    txt += 'Classifier: %s, ' % selected_classifier
    if selected_classifier == 'RF':
        txt += '%d trees, %s max depth, random state %s\n' % (
            cfg.CLASSIFIER['RF']['trees'], cfg.CLASSIFIER['RF']['depth'],
            cfg.CLASSIFIER['RF']['seed'])
    elif selected_classifier == 'GB' or selected_classifier == 'XGB':
        # report the settings of the classifier that was actually trained
        boost_cfg = cfg.CLASSIFIER[selected_classifier]
        txt += '%d trees, %s max depth, %s learing_rate, random state %s\n' % (
            boost_cfg['trees'], boost_cfg['depth'],
            boost_cfg['learning_rate'], boost_cfg['seed'])
    elif selected_classifier == 'rLDA':
        txt += 'regularization coefficient %.2f\n' % cfg.CLASSIFIER['rLDA'][
            'r_coeff']
    if cfg.CV['IGNORE_THRES'] is not None:
        txt += 'Decision threshold: %.2f\n' % cfg.CV['IGNORE_THRES']
    txt += '\n- Confusion Matrix\n' + cm_txt
    logger.info(txt)

    # Export to a file
    if 'export_result' in cfg.CV_PERFORM[selected_cv] and cfg.CV_PERFORM[
            selected_cv]['export_result'] is True:
        if cv_file is None:
            if cfg.EXPORT_CLS is True:
                qc.make_dirs('%s/classifier' % cfg.DATA_PATH)
                cv_path = '%s/classifier/cv_result.txt' % cfg.DATA_PATH
            else:
                cv_path = '%s/cv_result.txt' % cfg.DATA_PATH
        else:
            cv_path = cv_file
        # the context manager closes the file even if the write fails
        with open(cv_path, 'w') as fout:
            fout.write(txt)
Ejemplo n.º 15
0
def pcl2fif(filename,
            interactive=False,
            outdir=None,
            external_event=None,
            offset=0,
            overwrite=False,
            precision='single'):
    """
    PyCNBI Python pickle file

    Converts a pickled recording (loaded with qc.load_obj) into a fif file
    whose channel 0 is the event channel named 'TRIGGER'.

    Params
    --------
    filename: path of the pickle file. The loaded dict must provide
              'signals', 'sample_rate' and 'timestamps'; 'ch_names' is
              optional (names 'CH1', 'CH2', ... are generated without it).
    interactive: accepted for interface compatibility; not used here.
    outdir: If None, 'fif/' is appended to the directory part returned by
            qc.parse_path_list(filename).
    external_event: Event file in text format. Each row should be: "SAMPLE_INDEX 0 EVENT_TYPE"
                    When given, the recorded event channel is cleared first.
    offset: forwarded to event_timestamps_to_indices().
    overwrite: forwarded to raw.save().
    precision: Data matrix format. 'single' improves backward compatability.

    Returns True.
    """
    fdir, fname, _ = qc.parse_path_list(filename)
    if outdir is None:
        outdir = fdir + 'fif/'
    elif outdir[-1] != '/':
        outdir += '/'

    data = qc.load_obj(filename)

    if isinstance(data['signals'], list):
        signals_raw = np.array(data['signals'][0]).T  # to channels x samples
    else:
        signals_raw = data['signals'].T  # to channels x samples
    sample_rate = data['sample_rate']

    if 'ch_names' not in data:
        ch_names = ['CH%d' % (x + 1) for x in range(signals_raw.shape[0])]
    else:
        # work on a copy: the list is re-ordered below
        ch_names = list(data['ch_names'])

    # search for event channel
    trig_ch = pu.find_event_channel(signals_raw, ch_names)

    # move trigger channel to index 0
    if trig_ch is None:
        # assuming no event channel exists, add a event channel to index 0 for consistency.
        logger.warning(
            'No event channel was found. Adding a blank event channel to index 0.'
        )
        eventch = np.zeros([1, signals_raw.shape[1]])
        signals = np.concatenate((eventch, signals_raw), axis=0)
        ch_names = ['TRIGGER'] + [
            'CH%d' % (x + 1) for x in range(signals_raw.shape[0])
        ]
    elif trig_ch == 0:
        signals = signals_raw
    else:
        # move event channel to 0
        logger.info('Moving event channel %d to 0.' % trig_ch)
        signals = np.concatenate(
            (signals_raw[[trig_ch]], signals_raw[:trig_ch],
             signals_raw[trig_ch + 1:]),
            axis=0)
        assert signals_raw.shape == signals.shape
        ch_names.pop(trig_ch)
        ch_names.insert(0, 'TRIGGER')
        logger.info('New channel list:')
        for c in ch_names:
            logger.info('%s' % c)

    # Count EEG channels from the signal matrix itself (every row except the
    # event channel); data['channels'] is not reliable any more.
    num_eeg_channels = signals.shape[0] - 1

    ch_info = ['stim'] + ['eeg'] * num_eeg_channels
    info = mne.create_info(ch_names, sample_rate, ch_info)

    # create Raw object
    raw = mne.io.RawArray(signals, info)
    raw._times = data['timestamps']  # seems to have no effect

    if external_event is not None:
        raw._data[0] = 0  # erase current events
        events_index = event_timestamps_to_indices(filename, external_event,
                                                   offset)
        if len(events_index) == 0:
            logger.warning('No events were found in the event file')
        else:
            logger.info('Found %d events' % len(events_index))
            raw.add_events(events_index, stim_channel='TRIGGER')

    qc.make_dirs(outdir)
    fiffile = outdir + fname + '.fif'

    raw.save(fiffile, verbose=False, overwrite=overwrite, fmt=precision)
    logger.info('Saved to %s' % fiffile)

    saveChannels2txt(outdir, ch_names)

    return True
Ejemplo n.º 16
0
def get_tfr(cfg, recursive=False, n_jobs=1):
    '''
    Compute time-frequency representations (TFR) of epoched data and export
    them as MATLAB files and/or PNG plots.

    @params:
    cfg: config object, passed through check_cfg() first. cfg.TFR_TYPE selects
         the method ('multitaper', 'morlet' or 'butter'; 'fir' is not
         implemented), cfg.EXPORT_PATH is the output directory and
         cfg.N_JOBS the number of cores (None means all cores).
         Data comes from cfg.DATA_DIRS if that attribute exists, otherwise
         from cfg.DATA_FILE.
    recursive: if True, load raw files in sub-dirs recursively
    n_jobs: kept for interface compatibility; the value actually used is
            always cfg.N_JOBS

    @raises:
    ValueError: unknown TFR type, missing cfg.EXPORT_PATH with cfg.DATA_DIRS,
                no requested trigger present in the data, invalid cfg.EPOCH
                settings, or cfg.POWER_DIFF with fewer than two classes.
    NotImplementedError: TFR type 'fir'.
    RuntimeError: a channel pick index is out of range.
    '''

    cfg = check_cfg(cfg)
    tfr_type = cfg.TFR_TYPE
    export_path = cfg.EXPORT_PATH
    t_buffer = cfg.T_BUFFER
    if tfr_type == 'multitaper':
        tfr = mne.time_frequency.tfr_multitaper
    elif tfr_type == 'morlet':
        tfr = mne.time_frequency.tfr_morlet
    elif tfr_type == 'butter':
        butter_order = 4 # TODO: parameterize
        tfr = lfilter
    elif tfr_type == 'fir':
        raise NotImplementedError
    else:
        raise ValueError('Wrong TFR type %s' % tfr_type)
    n_jobs = cfg.N_JOBS
    if n_jobs is None:
        n_jobs = mp.cpu_count()

    # optional custom event file; applied AFTER the data is loaded so that the
    # events returned by the loader cannot overwrite it
    custom_event_file = getattr(cfg, 'EVENT_FILE', None)

    if hasattr(cfg, 'DATA_DIRS'):
        if export_path is None:
            raise ValueError('For multiple directories, cfg.EXPORT_PATH cannot be None')
        outpath = export_path
        file_prefix = 'grandavg'

        # load and merge files from all directories
        flist = []
        for ddir in cfg.DATA_DIRS:
            ddir = ddir.replace('\\', '/')
            if ddir[-1] != '/':
                ddir += '/'
            for f in qc.get_file_list(ddir, fullpath=True, recursive=recursive):
                if qc.parse_path(f).ext in ['fif', 'bdf', 'gdf']:
                    flist.append(f)
        raw, events = pu.load_multi(flist)
    else:
        print('Loading', cfg.DATA_FILE)
        raw, events = pu.load_raw(cfg.DATA_FILE)

        # file_prefix is needed for every export file name, so derive it from
        # the data file even when an explicit export path is given
        [data_dir, file_prefix, _] = qc.parse_path_list(cfg.DATA_FILE)
        outpath = data_dir if export_path is None else export_path

    # custom events
    if custom_event_file is not None:
        events = mne.read_events(custom_event_file)

    # re-referencing
    if cfg.REREFERENCE is not None:
        pu.rereference(raw, cfg.REREFERENCE[1], cfg.REREFERENCE[0])

    sfreq = raw.info['sfreq']

    # set channels of interest
    picks = pu.channel_names_to_index(raw, cfg.CHANNEL_PICKS)
    spchannels = pu.channel_names_to_index(raw, cfg.SP_CHANNELS)

    # valid indices are 0 .. n_channels - 1
    if max(picks) >= len(raw.info['ch_names']):
        msg = 'ERROR: "picks" has a channel index %d while there are only %d channels.' %\
              (max(picks), len(raw.info['ch_names']))
        raise RuntimeError(msg)

    # Apply filters
    pu.preprocess(raw, spatial=cfg.SP_FILTER, spatial_ch=spchannels, spectral=cfg.TP_FILTER,
                  spectral_ch=picks, notch=cfg.NOTCH_FILTER, notch_ch=picks,
                  multiplier=cfg.MULTIPLIER, n_jobs=n_jobs)

    # Read epochs
    classes = {}
    present_events = set(events[:, -1])
    for t in cfg.TRIGGERS:
        if t in present_events:
            if hasattr(cfg, 'tdef'):
                classes[cfg.tdef.by_value[t]] = t
            else:
                classes[str(t)] = t
    if len(classes) == 0:
        raise ValueError('No desired event was found from the data.')

    # pre-bind so the diagnostic print in the error handler can never hit an
    # unbound name, whatever statement inside the try block failed
    tmin = tmax = tmin_buffer = tmax_buffer = float('nan')
    try:
        tmin = cfg.EPOCH[0]
        tmin_buffer = tmin - t_buffer
        raw_tmax = raw._data.shape[1] / sfreq - 0.1
        if cfg.EPOCH[1] is None:
            if cfg.POWER_AVERAGED:
                raise ValueError('EPOCH value cannot have None for grand averaged TFR')
            else:
                if len(cfg.TRIGGERS) > 1:
                    raise ValueError('If the end time of EPOCH is None, only a single event can be defined.')
                t_ref = events[np.where(events[:,2] == list(cfg.TRIGGERS)[0])[0][0], 0] / sfreq
                tmax = raw_tmax - t_ref - t_buffer
        else:
            tmax = cfg.EPOCH[1]
        tmax_buffer = tmax + t_buffer
        if tmax_buffer > raw_tmax:
            raise ValueError('Epoch length with buffer (%.3f) is larger than signal length (%.3f)' % (tmax_buffer, raw_tmax))

        #print('Epoch tmin = %.1f, tmax = %.1f, raw length = %.1f' % (tmin, tmax, raw_tmax))
        epochs_all = mne.Epochs(raw, events, classes, tmin=tmin_buffer, tmax=tmax_buffer,
                                proj=False, picks=picks, baseline=None, preload=True)
        if epochs_all.drop_log_stats() > 0:
            print('\n** Bad epochs found. Dropping into a Python shell.')
            print(epochs_all.drop_log)
            print('tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
                (tmin, tmax, tmin_buffer, tmax_buffer, raw._data.shape[1] / sfreq))
            print('\nType exit to continue.\n')
            pdb.set_trace()
    except Exception:
        print('\n*** (tfr_export) ERROR OCCURRED WHILE EPOCHING ***')
        traceback.print_exc()
        print('tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
            (tmin, tmax, tmin_buffer, tmax_buffer, raw._data.shape[1] / sfreq))
        pdb.set_trace()
        # nothing below can work without valid epochs; surface the real error
        # instead of failing later on an undefined name
        raise

    #export_dir = '%s/plot_%s' % (outpath, evname)
    export_dir = outpath
    qc.make_dirs(export_dir)
    freqs = cfg.FREQ_RANGE  # define frequencies of interest
    n_cycles = freqs / 2.  # different number of cycle per frequency

    power = {}
    for evname in classes:
        print('\n>> Processing %s' % evname)
        if cfg.POWER_AVERAGED:
            # grand-average TFR
            epochs = epochs_all[evname][:]
            if len(epochs) == 0:
                print('No %s epochs. Skipping.' % evname)
                continue

            if tfr_type == 'butter':
                # band-pass + Hilbert envelope, averaged over epochs
                b, a = butter_bandpass(cfg.FREQ_RANGE[0], cfg.FREQ_RANGE[-1], sfreq, order=butter_order)
                tfr_filtered = lfilter(b, a, epochs, axis=2)
                tfr_hilbert = hilbert(tfr_filtered)
                tfr_power = abs(tfr_hilbert)
                tfr_data = np.mean(tfr_power, axis=0)
            elif tfr_type == 'fir':
                raise NotImplementedError
            else:
                power[evname] = tfr(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=False,
                    return_itc=False, decim=1, n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                tfr_data = power[evname].data

            if cfg.EXPORT_MATLAB is True:
                # export all channels to MATLAB
                mout = '%s/%s-%s-%s.mat' % (export_dir, file_prefix, cfg.SP_FILTER, evname)
                scipy.io.savemat(mout, {'tfr':tfr_data, 'chs':epochs.ch_names,
                    'events':events, 'sfreq':sfreq, 'epochs':cfg.EPOCH, 'freqs':cfg.FREQ_RANGE})
                print('Exported %s' % mout)
            if cfg.EXPORT_PNG is True:
                if evname not in power:
                    # the 'butter' path yields a plain array, not a plottable
                    # TFR object, so there is nothing to plot from
                    print('PNG export is not available for TFR type %s. Skipping.' % tfr_type)
                else:
                    # Inspect power for each channel
                    for ch, pick in enumerate(picks):
                        chname = raw.ch_names[pick]
                        title = 'Peri-event %s - Channel %s' % (evname, chname)

                        # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                        fig = power[evname].plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                            colorbar=True, title=title, vmin=cfg.VMIN, vmax=cfg.VMAX, dB=False)
                        fout = '%s/%s-%s-%s-%s.png' % (export_dir, file_prefix, cfg.SP_FILTER, evname, chname)
                        fig.savefig(fout)
                        fig.clf()
                        print('Exported to %s' % fout)
        else:
            # TFR per event
            # NOTE(review): with TFR type 'butter', tfr is lfilter, which does
            # not accept the keyword arguments used below - confirm intent.
            for ep in range(len(epochs_all[evname])):
                epochs = epochs_all[evname][ep]
                if len(epochs) == 0:
                    print('No %s epochs. Skipping.' % evname)
                    continue
                power[evname] = tfr(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=False,
                    return_itc=False, decim=1, n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                if cfg.EXPORT_MATLAB is True:
                    # export all channels to MATLAB
                    mout = '%s/%s-%s-%s-ep%02d.mat' % (export_dir, file_prefix, cfg.SP_FILTER, evname, ep + 1)
                    scipy.io.savemat(mout, {'tfr':power[evname].data, 'chs':power[evname].ch_names,
                        'events':events, 'sfreq':sfreq, 'tmin':tmin, 'tmax':tmax, 'freqs':cfg.FREQ_RANGE})
                    print('Exported %s' % mout)
                if cfg.EXPORT_PNG is True:
                    # Inspect power for each channel
                    for ch, pick in enumerate(picks):
                        chname = raw.ch_names[pick]
                        title = 'Peri-event %s - Channel %s, Trial %d' % (evname, chname, ep + 1)
                        # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                        fig = power[evname].plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                            colorbar=True, title=title, vmin=cfg.VMIN, vmax=cfg.VMAX, dB=False)
                        fout = '%s/%s-%s-%s-%s-ep%02d.png' % (export_dir, file_prefix, cfg.SP_FILTER, evname, chname, ep + 1)
                        fig.savefig(fout)
                        fig.clf()
                        print('Exported %s' % fout)

    if hasattr(cfg, 'POWER_DIFF'):
        export_dir = '%s/diff' % outpath
        qc.make_dirs(export_dir)
        # dict views are not indexable on Python 3
        labels = list(classes.keys())
        if len(labels) < 2:
            raise ValueError('POWER_DIFF requires at least two event classes.')
        df = power[labels[0]] - power[labels[1]]
        df.data = np.log(np.abs(df.data))
        # Inspect power diff for each channel
        for ch, pick in enumerate(picks):
            chname = raw.ch_names[pick]
            title = 'Peri-event %s-%s - Channel %s' % (labels[0], labels[1], chname)

            # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
            # colour limits must satisfy vmin <= vmax
            fig = df.plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                          colorbar=True, title=title, vmin=-3.0, vmax=3.0, dB=False)
            fout = '%s/%s-%s-diff-%s-%s-%s.jpg' % (export_dir, file_prefix, cfg.SP_FILTER, labels[0], labels[1], chname)
            print('Exporting to %s' % fout)
            fig.savefig(fout)
            fig.clf()
    print('Finished !')
Ejemplo n.º 17
0
def cross_validate(cfg, featdata, cv_file=None):
    """
    Perform cross validation, print a text report and optionally save it.

    Args:
        cfg: config object. cfg.CLASSIFIER selects the classifier
            ('GB', 'XGB', 'RF', 'LDA' or 'rLDA') and cfg.CV_PERFORM the
            scheme ('LeaveOneOut' or 'StratifiedShuffleSplit').
            cfg.PSD['wlen'] is filled in from featdata['wlen'] when it is
            None.
        featdata: dict read with the keys 'X_data' (unpacked as
            trials x samples x features), 'Y_data' (indexed as [:, 0]),
            'wlen', 'ch_names', 'picks' and 'sfreq'.
        cv_file: path of the report file. When None, the report goes to
            cv_result.txt under cfg.DATADIR (inside a 'classifier'
            sub-directory if cfg.EXPORT_CLS is True).

    Raises:
        ValueError: unknown cfg.CLASSIFIER.
        NotImplementedError: unsupported cfg.CV_PERFORM.
    """
    # Init a classifier
    if cfg.CLASSIFIER == 'GB':
        cls = GradientBoostingClassifier(loss='deviance',
                                         learning_rate=cfg.GB['learning_rate'],
                                         n_estimators=cfg.GB['trees'],
                                         subsample=1.0,
                                         max_depth=cfg.GB['max_depth'],
                                         random_state=cfg.GB['seed'],
                                         max_features='sqrt',
                                         verbose=0,
                                         warm_start=False,
                                         presort='auto')
    elif cfg.CLASSIFIER == 'XGB':
        # XGB deliberately reuses the cfg.GB hyper-parameters
        cls = XGBClassifier(loss='deviance',
                            learning_rate=cfg.GB['learning_rate'],
                            n_estimators=cfg.GB['trees'],
                            subsample=1.0,
                            max_depth=cfg.GB['max_depth'],
                            random_state=cfg.GB['seed'],
                            max_features='sqrt',
                            verbose=0,
                            warm_start=False,
                            presort='auto')
    elif cfg.CLASSIFIER == 'RF':
        cls = RandomForestClassifier(n_estimators=cfg.RF['trees'],
                                     max_features='auto',
                                     max_depth=cfg.RF['max_depth'],
                                     n_jobs=cfg.N_JOBS,
                                     random_state=cfg.RF['seed'],
                                     oob_score=True,
                                     class_weight='balanced_subsample')
    elif cfg.CLASSIFIER == 'LDA':
        cls = LDA()
    elif cfg.CLASSIFIER == 'rLDA':
        cls = rLDA(cfg.RLDA_REGULARIZE_COEFF)
    else:
        raise ValueError('Unknown classifier type %s' % cfg.CLASSIFIER)

    # Setup features
    X_data = featdata['X_data']
    Y_data = featdata['Y_data']
    wlen = featdata['wlen']
    if cfg.PSD['wlen'] is None:
        cfg.PSD['wlen'] = wlen

    # Choose CV type
    ntrials, nsamples, fsize = X_data.shape
    if cfg.CV_PERFORM == 'LeaveOneOut':
        print('\n>> %d-fold leave-one-out cross-validation' % ntrials)
        # the constructor signature depends on the installed sklearn version
        if SKLEARN_OLD:
            cv = LeaveOneOut(len(Y_data))
        else:
            cv = LeaveOneOut()
    elif cfg.CV_PERFORM == 'StratifiedShuffleSplit':
        print(
            '\n>> %d-fold stratified cross-validation with test set ratio %.2f'
            % (cfg.CV_FOLDS, cfg.CV_TEST_RATIO))
        if SKLEARN_OLD:
            cv = StratifiedShuffleSplit(Y_data[:, 0],
                                        cfg.CV_FOLDS,
                                        test_size=cfg.CV_TEST_RATIO,
                                        random_state=cfg.CV_RANDOM_SEED)
        else:
            cv = StratifiedShuffleSplit(n_splits=cfg.CV_FOLDS,
                                        test_size=cfg.CV_TEST_RATIO,
                                        random_state=cfg.CV_RANDOM_SEED)
    else:
        raise NotImplementedError('%s is not supported yet. Sorry.' %
                                  cfg.CV_PERFORM)
    print('%d trials, %d samples per trial, %d feature dimension' %
          (ntrials, nsamples, fsize))

    # Do it!
    timer_cv = qc.Timer()
    scores, cm_txt = crossval_epochs(cv,
                                     X_data,
                                     Y_data,
                                     cls,
                                     cfg.tdef.by_value,
                                     cfg.BALANCE_SAMPLES,
                                     n_jobs=cfg.N_JOBS,
                                     ignore_thres=cfg.CV_IGNORE_THRES,
                                     decision_thres=cfg.CV_DECISION_THRES)
    t_cv = timer_cv.sec()

    # Export results
    txt = '\n>> Cross validation took %d seconds.\n' % t_cv
    txt += '\n- Class information\n'
    txt += '%d epochs, %d samples per epoch, %d feature dimension (total %d samples)\n' %\
        (ntrials, nsamples, fsize, ntrials * nsamples)
    for ev in np.unique(Y_data):
        txt += '%s: %d trials\n' % (cfg.tdef.by_value[ev],
                                    len(np.where(Y_data[:, 0] == ev)[0]))
    if cfg.BALANCE_SAMPLES:
        txt += 'The number of samples was balanced across classes. Method: %s\n' % cfg.BALANCE_SAMPLES
    txt += '\n- Experiment conditions\n'
    # NOTE(review): cfg.SP_FILTER is printed for both the filter and its
    # channels; the second value was presumably meant to be the spatial
    # filter channel list - confirm which cfg attribute holds it.
    txt += 'Spatial filter: %s (channels: %s)\n' % (cfg.SP_FILTER,
                                                    cfg.SP_FILTER)
    txt += 'Spectral filter: %s\n' % cfg.TP_FILTER
    txt += 'Notch filter: %s\n' % cfg.NOTCH_FILTER
    txt += 'Channels: ' + ','.join(
        [str(featdata['ch_names'][p]) for p in featdata['picks']]) + '\n'
    txt += 'PSD range: %.1f - %.1f Hz\n' % (cfg.PSD['fmin'], cfg.PSD['fmax'])
    txt += 'Window step: %.2f msec\n' % (1000.0 * cfg.PSD['wstep'] /
                                         featdata['sfreq'])
    if isinstance(wlen, list):
        # one window size / epoch range pair per entry
        for i, w in enumerate(wlen):
            txt += 'Window size: %.1f msec\n' % (w * 1000.0)
            txt += 'Epoch range: %s sec\n' % (cfg.EPOCH[i])
    else:
        txt += 'Window size: %.1f msec\n' % (cfg.PSD['wlen'] * 1000.0)
        txt += 'Epoch range: %s sec\n' % (cfg.EPOCH)

    # Compute stats
    cv_mean, cv_std = np.mean(scores), np.std(scores)
    txt += '\n- Average CV accuracy over %d epochs (random seed=%s)\n' % (
        ntrials, cfg.CV_RANDOM_SEED)
    if cfg.CV_PERFORM in ['LeaveOneOut', 'StratifiedShuffleSplit']:
        txt += "mean %.3f, std: %.3f\n" % (cv_mean, cv_std)
    txt += 'Classifier: %s, ' % cfg.CLASSIFIER
    if cfg.CLASSIFIER == 'RF':
        txt += '%d trees, %s max depth, random state %s\n' % (
            cfg.RF['trees'], cfg.RF['max_depth'], cfg.RF['seed'])
    elif cfg.CLASSIFIER == 'GB' or cfg.CLASSIFIER == 'XGB':
        txt += '%d trees, %s max depth, %s learing_rate, random state %s\n' % (
            cfg.GB['trees'], cfg.GB['max_depth'], cfg.GB['learning_rate'],
            cfg.GB['seed'])
    elif cfg.CLASSIFIER == 'rLDA':
        txt += 'regularization coefficient %.2f\n' % cfg.RLDA_REGULARIZE_COEFF
    # NOTE(review): the label says "Decision threshold" but the value printed
    # is CV_IGNORE_THRES; CV_DECISION_THRES is never reported - confirm.
    if cfg.CV_IGNORE_THRES is not None:
        txt += 'Decision threshold: %.2f\n' % cfg.CV_IGNORE_THRES
    txt += '\n- Confusion Matrix\n' + cm_txt
    print(txt)

    # Export to a file
    if hasattr(
            cfg, 'CV_EXPORT_RESULT'
    ) and cfg.CV_EXPORT_RESULT is True and cfg.CV_PERFORM is not None:
        if cv_file is not None:
            out_path = cv_file
        elif cfg.EXPORT_CLS is True:
            qc.make_dirs('%s/classifier' % cfg.DATADIR)
            out_path = '%s/classifier/cv_result.txt' % cfg.DATADIR
        else:
            out_path = '%s/cv_result.txt' % cfg.DATADIR
        # context manager guarantees the handle is closed even if write fails
        with open(out_path, 'w') as fout:
            fout.write(txt)
Ejemplo n.º 18
0
def createClassifier(loadedraw,
                     events,
                     tmin,
                     tmax,
                     tlow,
                     thigh,
                     regcoeff,
                     useLeaveOneOut,
                     APPLY_CAR,
                     APPLY_PCA,
                     l_freq,
                     h_freq,
                     MAX_FPR,
                     picks_feat,
                     baselineRange,
                     decim_factor,
                     cv_container,
                     FILTER_METHOD,
                     best_threshold,
                     verbose=False):
    """
    Train an rLDA classifier on the given recording and save it to disk.

    The raw data is epoched with preprocess(), turned into training features
    with compute_features(), and an rLDA(regcoeff) model is fitted against the
    event codes of the epochs. The model is stored together with the
    preprocessing parameters in DATADIR/errp_classifier/errp_classifier.pcl.

    n_jobs and DATADIR are not parameters; they are resolved from the
    enclosing scope. useLeaveOneOut, APPLY_PCA, MAX_FPR, cv_container and
    verbose are accepted but not used in this body. Only the first element of
    best_threshold is stored.
    """
    # only sfreq, epochs and wframes of the returned tuple are used below
    (_tdef, sfreq, _event_id, _b, _a, _zi, _t_lower, _t_upper, epochs,
     wframes) = preprocess(
         loadedraw=loadedraw,
         events=events,
         APPLY_CAR=APPLY_CAR,
         l_freq=l_freq,
         h_freq=h_freq,
         filter_method=FILTER_METHOD,
         tmin=tmin,
         tmax=tmax,
         tlow=tlow,
         thigh=thigh,
         n_jobs=n_jobs,
         picks_feat=picks_feat,
         baselineRange=baselineRange,
         verbose=False)

    # shift/scale factors and the PCA are fitted here, hence passed as None
    features, pca, shift_factor, scale_factor = compute_features(
        signals=epochs._data,
        dataset_type='train',
        sfreq=sfreq,
        l_freq=l_freq,
        h_freq=h_freq,
        decim_factor=decim_factor,
        shiftFactor=None,
        scaleFactor=None,
        pca=None,
        tmin=tmin,
        tmax=tmax,
        tlow=tlow,
        thigh=thigh,
        filter_method=FILTER_METHOD)

    # the third column of the events array is used as the class label
    cls = rLDA(regcoeff)
    cls.fit(features, epochs.events[:, 2])

    all_names = loadedraw.info['ch_names']
    model = {
        'apply_car': APPLY_CAR,
        'sfreq': loadedraw.info['sfreq'],
        'picks': picks_feat,
        'decim_factor': decim_factor,
        'ch_names': [all_names[c] for c in picks_feat],
        'tmin': tmin,
        'tmax': tmax,
        'tlow': tlow,
        'thigh': thigh,
        'l_freq': l_freq,
        'h_freq': h_freq,
        'baselineRange': baselineRange,
        'shiftFactor': shift_factor,
        'scaleFactor': scale_factor,
        'cls': cls,
        'pca': pca,
        'threshold': best_threshold[0],
        'filter_method': FILTER_METHOD,
        'wframes': wframes,
    }

    out_dir = DATADIR + '/errp_classifier'
    qc.make_dirs(out_dir)
    out_file = out_dir + '/errp_classifier.pcl'
    qc.save_obj(out_file, model)
    print('Saved as %s' % out_file)
    print('Using %d epochs' % epochs._data.shape[0])
Ejemplo n.º 19
0
def run(file_flag):
    """
    Record data from Shimmer bluetooth nodes until Ctrl+C is pressed.

    The target nodes are found by a bluetooth scan (scan_flag is hard-coded
    to 1), then each node is connected, configured and told to start
    streaming. In the acquisition loop one sample is read per node, its first
    field is replaced by pylsl.local_clock(), and the result is written or
    printed depending on file_flag. On exit the nodes are told to stop
    streaming and the file and sockets are closed.

    Args:
        file_flag: 1 to write samples to ../DATA/IMU_<timestamp>.dat (with the
            node list in a matching .log file); 0 to print them instead.

    The process exits through sys.exit(1) when scanning fails, no target node
    is found, a node cannot be configured, or the file/plot cannot be created.
    """
    # Shimer MAC addresses
    # shimm_addr= ["00:06:66:46:9A:67", "00:06:66:46:B6:4A"]#, "00:06:66:46:BD:8D", "00:06:66:46:9A:1A", "00:06:66:46:BD:BF"]
    shimm_addr = ["00:06:66:46:B7:D4"]
    # nodes listed here are created with mode 0x2 instead of 0x1 (see below)
    emg_addr = []  # ["00:06:66:46:9A:1A", "00:06:66:46:BD:BF"]

    # Configuration parameters
    scan_flag = 1
    plot_flag = 0

    sock_port = 1
    nodes = []
    plt_axx = 500
    plt_ylim = 4000
    plt_rate = 20

    # not referenced anywhere else in this function
    rng_size = 50

    # rng_acc_x=RingBuffer(50)
    # Add sample to ringbuffer
    # rng_acc_x.append(pack_0)
    # buff1= np.zeros((n_nodes,10,rng_size),dtype=np.int)
    # buff2= np.zeros((n_nodes,10,rng_size),dtype=np.int)
    # buff_flag= 1
    # buff= [[[0 for x in range(10)] for y in range(2)] for z in range(rng_size)]
    # buff_idx= 0

    if plot_flag == 1:
        # plot parameters
        sample_idx = 0
        analogData = AnalogData(plt_axx)

    # Get the list of available nodes
    if scan_flag == 0:
        target_addr = shimm_addr
    else:
        try:
            # keep only discovered devices that are in the known address list
            target_addr = []
            print("Scanning bluetooth devices...")
            nearby_devices = bluetooth.discover_devices()
            for bdaddr in nearby_devices:
                print("			" + str(bdaddr) + " - " +
                      bluetooth.lookup_name(bdaddr))
                if bdaddr in shimm_addr:
                    target_addr.append(bdaddr)
        except:
            print("[Error] Problem while scanning bluetooth")
            sys.exit(1)

    n_nodes = len(target_addr)
    if n_nodes > 0:
        print(("Found %d target Shimmer nodes") % (len(target_addr)))
    else:
        print("Could not find target bluetooth device nearby. Exiting")
        sys.exit(1)

    print("Configuring the nodes...")
    for node_idx, bdaddr in enumerate(target_addr):
        try:
            # Connecting to the sensors
            sock = bluetooth.BluetoothSocket(bluetooth.RFCOMM)
            if bdaddr in emg_addr:
                n = shimmer_node(bdaddr, sock, 0x2)
            else:
                n = shimmer_node(bdaddr, sock, 0x1)
            nodes.append(n)

            print((bdaddr, sock_port), end=' ')
            nodes[-1].sock.connect((bdaddr, sock_port))
            print(" Shimmer %d (" % (node_idx) +
                  bluetooth.lookup_name(bdaddr) + ") [Connected]")
            # send the set sensors command
            nodes[-1].sock.send(
                struct.pack('BBB', 0x08, nodes[-1].senscfg_hi,
                            nodes[-1].senscfg_lo))
            nodes[-1].wait_for_ack()

            # send the set sampling rate command
            nodes[-1].sock.send(struct.pack('BB', 0x05, 0x14))  # 51.2Hz
            nodes[-1].wait_for_ack()

            # Inquiry command
            print("	Shimmer %d (" % (node_idx) +
                  bluetooth.lookup_name(bdaddr) + ") [Configured]")
            nodes[-1].sock.send(struct.pack('B', 0x01))
            nodes[-1].wait_for_ack()
            # the inquiry reply is read but not used afterwards
            inq = nodes[-1].read_inquiry()
        except bluetooth.btcommon.BluetoothError as e:
            print(("BluetoothError during read_data: {0}".format(e.strerror)))
            print("Unable to connect to the nodes. Exiting")
            sys.exit(1)

    # Create file and plot
    try:
        if file_flag == 1:
            # Create buffer
            now = datetime.datetime.now()
            qc.make_dirs('../DATA')
            # the .log file maps node index to node address
            logname = "../DATA/IMU_" + now.strftime("%Y%m%d%H%M") + ".log"
            print("[cnbi_shimmer] Creating file: %s" % (logname))
            outfile = open(logname, "w")
            for node_idx, shim in enumerate(nodes):
                outfile.write(str(node_idx) + ": " + str(shim.addr) + "\n")
            outfile.close()

            # the .dat file stays open for the whole recording
            fname = "../DATA/IMU_" + now.strftime("%Y%m%d%H%M") + ".dat"
            print("[cnbi_shimmer] Creating file: %s" % (fname))
            outfile = open(fname, "w")

        # Create plot
        if plot_flag == 1:
            analogPlot = AnalogPlot(analogData)
            plt.axis([0, plt_axx, 0, plt_ylim])
            plt.ion()
            plt.show()
    except:
        print("[Error]: Error creating file/plot!! Exiting")
        # close the socket
        print("Closing nodes")
        for node_idx, shim in enumerate(nodes):
            shim.sock.close()
            print("	Shimmer %d [Ok]" % (node_idx))
        sys.exit(1)

    print(
        "[cnbi_shimmer] Recording started. Press Ctrl+C to finish recording.")
    # send start streaming command
    # (sent to every node first; the acks are collected in a second pass)
    for shim in nodes:
        shim.sock.send(struct.pack('B', 0x07))

    for node_idx, shim in enumerate(nodes):
        shim.wait_for_ack()
        shim.up = 1
        print("	Shimmer %d [Ok]" % (node_idx))

    # Main acquisition loop
    while True:
        try:
            sample = []
            sample_lslclock = []
            for shim in nodes:
                if shim.up == 1:
                    sample.append(shim.read_data())
                else:
                    # node not up: substitute a row of zeros
                    sample.append([0] * (shim.n_fields))

            # replace the first field of each sample with the local LSL clock
            for samp in sample:
                sample_lslclock.append([pylsl.local_clock()] + list(samp[1:]))

            # one line per acquisition step
            if file_flag == 1:
                simplejson.dump(sample_lslclock,
                                outfile,
                                separators=(',', ';'))
                outfile.write('\n')

            # print sample
            # plt.title(str(sample[0][0]))

            # leeq
            if plot_flag == 1:
                analogData.add([sample[0][1], sample[0][2]])
                sample_idx = sample_idx + 1
                if sample_idx % plt_rate == 0:
                    analogPlot.update(analogData)

            # NOTE(review): index 1 is the SECOND node; with the single
            # address in shimm_addr this raises IndexError, which neither
            # handler below catches - confirm whether index 0 was intended.
            if file_flag == 0:
                print(qc.list2string(sample_lslclock[1], '%9.1f', ' '))

        # Exit if key is pressed
        except KeyboardInterrupt:
            print("\n[cnbi_shimmer] Stopping acquisition....")
            break
        # NOTE(review): the error is only printed and the loop retries at
        # once, so a persistent bluetooth failure repeats until Ctrl+C.
        except bluetooth.btcommon.BluetoothError as e:
            print(("[Error] BluetoothError during read_data: {0}".format(
                e.strerror)))

    # send stop streaming command
    print("[cnbi_shimmer] Stopping streaming")
    try:
        for shim in nodes:
            shim.sock.send(struct.pack('B', 0x20))
        for node_idx, shim in enumerate(nodes):
            shim.wait_for_ack()
            print("	Shimmer %d [Ok]" % (node_idx))
    except bluetooth.btcommon.BluetoothError as e:
        print(("[Error] BluetoothError during read_data: {0}".format(
            e.strerror)))
    # The string literal below is disabled legacy code kept as a bare
    # expression statement; it has no effect at runtime.
    '''
        n_nodes =	len(target_addr)
        while n_nodes>0:
        sample= []
        for node_idx,shim in enumerate(nodes):
        pckt= shim.wait_stop_streaming()
        print "	Shimmer %d [waiting]" % (node_idx)
        if len(pckt) != 1:
            sample.append(pckt)
        else:
        sample.append(str("0"*(shim.samplesize)))
        nodes.remove(shim)
        n_nodes= n_nodes-1
        print "	Shimmer %d [Ok]" % (node_idx)
        simplejson.dump(sample, outfile, separators=(',',';'))
        analogData.add([sample[0][1],sample[1][1]])
        analogPlot.update(analogData)
    '''

    # Closing file
    if file_flag == 1:
        print("[cnbi_shimmer] Closing file: %s" % (fname))
        try:
            outfile.close()
        except:
            print("			[Error] Problem closing file!")

    # close the socket
    print("[cnbi_shimmer] Closing nodes")
    for node_idx, shim in enumerate(nodes):
        shim.sock.close()
        print("	Shimmer %d [Ok]" % (node_idx))

    print("[cnbi_shimmer] Recording Finished. Please close this window.")
    # presumably waits for a key press so the console window stays open
    getch()