示例#1
0
def epochs2mat(data_dir,
               channel_picks,
               event_id,
               tmin,
               tmax,
               merge_epochs=False,
               spfilter=None,
               spchannels=None):
    """
    Export epochs of the .fif files in a directory to MATLAB .mat files.

    @params
    -------
    data_dir: directory whose .fif files are exported.
    channel_picks, event_id, tmin, tmax: forwarded unchanged to save_mat().
    merge_epochs: if True, all .fif files are loaded together with
        pu.load_multi() and exported to a single <data_dir>/epochs_all.mat.
        Otherwise each <name>.fif is exported to <name>-epochs.mat in the
        same directory.
    spfilter, spchannels: forwarded to pu.load_multi(); only used when
        merge_epochs is True.

    If no .fif file is found, a warning is logged and nothing is exported.
    """
    fiflist = [f for f in qc.get_file_list(data_dir, fullpath=True)
               if f[-4:] == '.fif']
    if not fiflist:
        # Without this guard the final log line referenced an unbound name.
        logger.warning('No fif files found in %s' % data_dir)
        return

    if merge_epochs:
        # load all raw files in the directory and merge epochs
        raw, events = pu.load_multi(fiflist,
                                    spfilter=spfilter,
                                    spchannels=spchannels)
        matfile = data_dir + '/epochs_all.mat'
        save_mat(raw, events, channel_picks, event_id, tmin, tmax, matfile)
        logger.info('Exported to %s' % matfile)
    else:
        # process individual raw file separately
        for data_file in fiflist:
            [base, fname, _] = qc.parse_path_list(data_file)
            matfile = '%s/%s-epochs.mat' % (base, fname)
            raw, events = pu.load_raw(data_file)
            save_mat(raw, events, channel_picks, event_id, tmin, tmax, matfile)
            # report every exported file, not only the last one
            logger.info('Exported to %s' % matfile)
示例#2
0
def add_lsl_events(event_dir, offset=0, recursive=False, interactive=True):
    """
    Add events recorded with LSL timestamps to raw data files.
    Useful for software triggering.

    For every <name>-eve.txt file found, the matching <name>-raw.pcl file in
    the same directory is converted with pcl2fif(), using the text file as an
    external event source. Existing output is overwritten.

    @params
    -------
    event_dir:
    Path to *-eve.txt files.

    offset:
    Timestamp offset (in seconds) in case the LSL server timestamps are shifted.
    Some OpenVibe acquisition servers send timestamps of their own running time (always
    starting from 0) instead of LSL timestamps. In this case, the only way to deal with
    this problem is to add an offset, a difference between LSL timestamp and OpenVibe
    server time stamp.

    recursive:
    Search the sub-directories returned by qc.get_dir_list(event_dir)
    instead of event_dir itself.

    interactive:
    If True, wait for the user to press Enter before processing starts.


    Kyuhwa Lee
    Swiss Federal Institute of Technology Lausanne (EPFL)
    2017
    """
    import pycnbi.utils.q_common as qc
    from pycnbi.utils.convert2fif import pcl2fif
    from builtins import input

    eve_suffix = '-eve.txt'
    offset = float(offset)
    if offset != 0:
        logger.info_yellow('Time offset = %.3f' % offset)

    search_dirs = qc.get_dir_list(event_dir) if recursive else [event_dir]
    to_process = []
    logger.info('Files to be processed')
    for d in search_dirs:
        for f in qc.get_file_list(d, True):
            if f.endswith(eve_suffix):
                to_process.append(f)
                logger.info(f)

    if not to_process:
        logger.warning('No *%s files found in %s' % (eve_suffix, event_dir))
        return

    if interactive:
        input('\nPress Enter to start')
    for f in to_process:
        # Replace only the trailing suffix; str.replace() would also rewrite
        # an identical substring elsewhere in the path.
        pclfile = f[:-len(eve_suffix)] + '-raw.pcl'
        pcl2fif(pclfile, external_event=f, offset=offset, overwrite=True)
示例#3
0
def get_tfr_each_file(cfg,
                      tfr_type='multitaper',
                      recursive=False,
                      export_path=None,
                      n_jobs=1):
    '''
    Run get_tfr() on every .fif/.bdf/.gdf file found in cfg.DATA_PATHS.

    @params:
    cfg: configuration, validated with check_cfg()
    tfr_type: 'multitaper' or 'morlet'; anything else raises ValueError
    recursive: if True, load raw files in sub-dirs recursively
    export_path: path to save plots
    n_jobs: number of cores to run in parallel

    NOTE(review): export_path and n_jobs are not used here; cfg.N_JOBS is
    what is forwarded to get_tfr(). Confirm whether that is intended.
    '''
    cfg = check_cfg(cfg)
    t_buffer = cfg.T_BUFFER  # read but currently unused

    tfr_func_names = {'multitaper': 'tfr_multitaper', 'morlet': 'tfr_morlet'}
    if tfr_type not in tfr_func_names:
        raise ValueError('Wrong TFR type %s' % tfr_type)
    tfr = getattr(mne.time_frequency, tfr_func_names[tfr_type])

    raw_exts = ('fif', 'bdf', 'gdf')
    for fifdir in cfg.DATA_PATHS:
        for fpath in qc.get_file_list(fifdir, fullpath=True, recursive=recursive):
            _, _, ext = qc.parse_path_list(fpath)
            if ext not in raw_exts:
                continue
            get_tfr(fpath, cfg, tfr, cfg.N_JOBS)
示例#4
0
def main(input_dir, channel_file=None):
    """
    Convert every supported raw EEG file under input_dir (recursively) to fif.

    Each converted file is written by any2fif() to a 'fif' sub-directory next
    to its source file. The number of converted files is printed at the end.
    """
    supported_exts = ('pcl', 'bdf', 'edf', 'gdf', 'eeg', 'xdf')
    n_converted = 0
    for src in qc.get_file_list(input_dir, fullpath=True, recursive=True):
        info = qc.parse_path(src)
        target_dir = info.dir + '/fif/'
        if info.ext not in supported_exts:
            continue
        print('Converting %s' % src)
        any2fif(src, interactive=True, outdir=target_dir, channel_file=channel_file)
        n_converted += 1

    print('\n>> %d files converted.' % n_converted)
示例#5
0
文件: trainer.py 项目: aizmeng/pycnbi
def cva_features(datadir):
    """
    (DEPRECATED FUNCTION)

    Run the MATLAB routine cva_features() on every .gdf file in datadir via
    qc.matlab(). Files whose '<file>.cva' output already exists are skipped.
    """
    for fin in qc.get_file_list(datadir, fullpath=True):
        if not fin.endswith('.gdf'):
            continue
        fout = fin + '.cva'
        if os.path.exists(fout):
            # The path must be formatted into the message. Passing it as an
            # extra argument without a placeholder makes a standard
            # logging.Logger report a formatting error instead of the message.
            logger.info('Skipping %s' % fout)
            continue
        matlab_cmd = "cva_features('%s')" % fin
        logger.info(matlab_cmd)
        qc.matlab(matlab_cmd)
示例#6
0
def fix_channel_names(fif_dir, new_channel_names):
    '''
    Change channel names of fif files in a given directory.

    Input
    -----
    @fif_dir: path to fif files
    @new_channel_names: list of new channel names; its length must equal the
        number of channels of every fif file

    Output
    ------
    Modified fif files are saved in fif_dir/corrected/

    Raises RuntimeError if a file's channel count differs from
    len(new_channel_names). Logs a warning if no fif file is found.

    Kyuhwa Lee, 2019.
    '''
    flist = [f for f in qc.get_file_list(fif_dir) if qc.parse_path(f).ext == 'fif']
    if not flist:
        logger.warning('No fif files found in %s' % fif_dir)
        return

    qc.make_dirs('%s/corrected' % fif_dir)
    # Reuse the filtered list so only fif files are logged and loaded.
    for f in flist:
        logger.info('\nLoading %s' % f)
        p = qc.parse_path(f)
        raw, eve = pu.load_raw(f)
        if len(raw.ch_names) != len(new_channel_names):
            raise RuntimeError(
                'The number of new channels does not match that of fif file %s.' % f
            )
        raw.info['ch_names'] = new_channel_names
        for ch, new_ch in zip(raw.info['chs'], new_channel_names):
            ch['ch_name'] = new_ch
        out_fif = '%s/corrected/%s.fif' % (p.dir, p.name)
        logger.info('Exporting to %s' % out_fif)
        raw.save(out_fif)
示例#7
0
def fif_resample(fif_dir, sfreq_target):
    """
    Resample every .fif file in fif_dir to the sampling frequency sfreq_target.

    Resampled copies are saved as <fif_dir>/fif_resample<sfreq_target>/<name>.fif.
    """
    out_dir = fif_dir + '/fif_resample%d' % sfreq_target
    qc.make_dirs(out_dir)
    for src in qc.get_file_list(fif_dir):
        parsed = qc.parse_path(src)
        if parsed.ext == 'fif':
            logger.info('Resampling %s' % src)
            raw, _ = pu.load_raw(src)
            raw.resample(sfreq_target)
            dst = '%s/%s.fif' % (out_dir, parsed.name)
            raw.save(dst)
            logger.info('Exported to %s' % dst)
示例#8
0
文件: fif2mat.py 项目: aizmeng/pycnbi
def fif2mat(data_dir):
    """
    Export every .fif file in data_dir to a MATLAB file in <data_dir>/mat_files.

    Each .mat file contains 'signals' (raw._data), 'events' (the sample column
    is shifted by +1 for MATLAB's 1-based indexing), 'sfreq' and 'ch_names'.
    """
    out_dir = '%s/mat_files' % data_dir
    qc.make_dirs(out_dir)
    for rawfile in qc.get_file_list(data_dir, fullpath=True):
        if not rawfile.endswith('.fif'):
            continue
        raw, events = pu.load_raw(rawfile)
        # MATLAB uses 1-based indexing
        events[:, 0] += 1
        payload = {
            'signals': raw._data,
            'events': events,
            'sfreq': raw.info['sfreq'],
            'ch_names': raw.ch_names,
        }
        matfile = '%s/%s.mat' % (out_dir, qc.parse_path(rawfile).name)
        scipy.io.savemat(matfile, payload)
        logger.info('Exported to %s' % matfile)
    logger.info('Finished exporting.')
示例#9
0
def fif2mat(input_path):
    """
    Export .fif data to MATLAB files with fif2mat_file().

    input_path may be:
    - a directory: every .fif file directly inside it is exported to
      <input_path>/mat_files (a warning is logged if there is none);
    - a single file: exported to a mat_files directory next to it.
    Raises ValueError if input_path is neither.
    """
    is_dir = os.path.isdir(input_path)
    if not is_dir and not os.path.isfile(input_path):
        raise ValueError('Neither directory nor file: %s' % input_path)

    if is_dir:
        out_dir = '%s/mat_files' % input_path
        qc.make_dirs(out_dir)
        fif_files = [f for f in qc.get_file_list(input_path, fullpath=True)
                     if f[-4:] == '.fif']
        for rawfile in fif_files:
            fif2mat_file(rawfile, out_dir)
        if not fif_files:
            logger.warning('No fif files found in the path.')
    else:
        out_dir = '%s/mat_files' % qc.parse_path(input_path).dir
        qc.make_dirs(out_dir)
        fif2mat_file(input_path, out_dir)
    logger.info('Finished.')
示例#10
0
def _fit_to_screen(img, screen_width, screen_height):
    """Shrink img to fit the screen, keeping its aspect ratio, and center it on a black canvas."""
    img_h, img_w = img.shape[0], img.shape[1]
    rx = img_w / screen_width
    ry = img_h / screen_height
    if max(rx, ry) > 1:
        # image is larger than the screen: scale down along the dominant axis
        if rx > ry:
            target_w, target_h = screen_width, int(img_h / rx)
        elif rx < ry:
            target_w, target_h = int(img_w / ry), screen_height
        else:
            target_w, target_h = screen_width, screen_height
    else:
        target_w, target_h = img_w, img_h
    target_w, target_h = int(target_w), int(target_h)
    img_res = cv2.resize(img, (target_w, target_h), interpolation=cv2.INTER_LANCZOS4)
    img_out = np.zeros((screen_height, screen_width, img.shape[2]), dtype=img.dtype)
    ox = int((screen_width - target_w) / 2)
    oy = int((screen_height - target_h) / 2)
    img_out[oy:oy + target_h, ox:ox + target_w, :] = img_res
    return img_out


def read_images(img_path, screen_size=None):
    """
    Load all .png images in img_path with OpenCV.

    If screen_size=(width, height) is given, each image is shrunk (if larger
    than the screen, keeping its aspect ratio) and centered on a black canvas
    of exactly the screen size. A dot is printed per loaded image.

    Returns the list of images. Raises IOError if a .png file cannot be
    decoded (cv2.imread returns None in that case).
    """
    pnglist = []
    for f in qc.get_file_list(img_path):
        if f[-4:] != '.png':
            continue

        img = cv2.imread(f)
        if img is None:
            raise IOError('Cannot read image file %s' % f)
        if screen_size is not None:
            screen_width, screen_height = screen_size
            img = _fit_to_screen(img, screen_width, screen_height)
        pnglist.append(img)
        print('.', end='')
    print('Done')
    return pnglist
示例#11
0
    n_jobs = 1
    fmin = 1
    fmax = 40
    wlen = 0.5
    wstep = 32
    tmin = 0.0
    tmax = 30
    channel_picks = None
    excludes = ['TRIGGER', 'M1', 'M2', 'EOG']

    if n_jobs > 1:
        import multiprocessing as mp

        pool = mp.Pool(n_jobs)
        procs = []
        for rawfile in qc.get_file_list(data_dir):
            if rawfile[-8:] != '-raw.fif': continue
            cmd = [
                rawfile, fmin, fmax, wlen, wstep, tmin, tmax, channel_picks,
                excludes
            ]
            procs.append(pool.apply_async(raw2psd, cmd))
        for proc in procs:
            proc.get()
        pool.close()
        pool.join()
    else:
        for rawfile in qc.get_file_list(data_dir):
            if rawfile[-8:] != '-raw.fif': continue
            raw2psd(rawfile,
                    fmin=fmin,
示例#12
0
import pycnbi
import pycnbi.utils.q_common as qc
from epochs2psd import epochs2psd

# parameters
from pycnbi.triggers.trigger_def import trigger_def

data_dir = r'D:\data\MI\rx1\offline\gait-pulling\20161104\test'
channel_picks = None
# epoch window passed to epochs2psd (tmin, tmax)
tmin, tmax = 0.0, 3.0
# frequency range passed to epochs2psd (fmin, fmax)
fmin, fmax = 1, 40
# window length and step passed to epochs2psd
w_len, w_step = 0.5, 16

tdef = trigger_def('triggerdef_16.ini')
events = dict(left=tdef.LEFT_GO, right=tdef.RIGHT_GO)

if __name__ == '__main__':
    # compute PSD features for every .fif file in data_dir
    fif_files = (path for path in qc.get_file_list(data_dir) if path.endswith('.fif'))
    for fif_file in fif_files:
        print(fif_file)
        epochs2psd(fif_file, channel_picks, events, tmin, tmax,
                   fmin, fmax, w_len, w_step)
示例#13
0
    # reset trigger channel
    raw._data[0] *= 0
    raw.add_events(eve, 'TRIGGER')
    raw.save(rawfile_out, overwrite=True)

    logger.info('=== After merging ===')
    for key in np.unique(eve[:, 2]):
        if key in tdef.by_value:
            logger.info(
                '%s: %d events' %
                (tdef.by_value[key], len(np.where(eve[:, 2] == key)[0])))
        else:
            logger.info('%s: %d events' %
                        (key, len(np.where(eve[:, 2] == key)[0])))


# sample code
if __name__ == '__main__':
    fif_dir = r'D:\data\STIMO\GO004\offline\all'
    trigger_file = 'triggerdef_gait_chuv.ini'
    # presumably: new event name -> list of existing event names to merge
    events = {'BOTH_GO': ['LEFT_GO', 'RIGHT_GO']}

    # merged copies are written to a sub-directory
    out_dir = fif_dir + '/merged'
    qc.make_dirs(out_dir)
    for rawfile_in in qc.get_file_list(fif_dir):
        p = qc.parse_path(rawfile_in)
        if p.ext == 'fif':
            rawfile_out = '%s/%s.%s' % (out_dir, p.name, p.ext)
            merge_events(trigger_file, events, rawfile_in, rawfile_out)
示例#14
0
def get_tfr(cfg, recursive=False, n_jobs=1):
    '''
    Compute time-frequency representations (TFR) of epoched EEG and export them.

    Data source:
    - cfg.DATA_DIRS: every .fif/.bdf/.gdf file is found (recursively if
      requested) and merged with pu.load_multi();
    - otherwise the single file cfg.DATA_FILE is loaded.

    Processing and export:
    - The signal is optionally re-referenced (cfg.REREFERENCE) and filtered
      with pu.preprocess().
    - It is epoched around cfg.TRIGGERS over cfg.EPOCH, padded by
      cfg.T_BUFFER on both sides.
    - The method selected by cfg.TFR_TYPE is applied and the result cropped
      back to cfg.EPOCH.
    - Results go to cfg.EXPORT_PATH (or the data file's directory) as .mat
      files (cfg.EXPORT_MATLAB) and/or per-channel .png plots (cfg.EXPORT_PNG).

    @params:
    cfg: configuration object, validated with check_cfg(). Relevant fields
        include TFR_TYPE ('multitaper', 'morlet' or 'butter'; 'fir' is not
        implemented), EXPORT_PATH, T_BUFFER, N_JOBS, POWER_AVERAGED,
        POWER_DIFF and EVENT_FILE.
    recursive: if True, load raw files in sub-dirs of cfg.DATA_DIRS recursively
    n_jobs: number of cores to run in parallel. Used only when cfg has no
        N_JOBS attribute; N_JOBS = None means all CPU cores.

    Raises ValueError for an unknown TFR type, for a missing EXPORT_PATH
    with DATA_DIRS, or when no desired event is found. Epoching errors are
    shown in an interactive debugger and then re-raised.
    '''

    cfg = check_cfg(cfg)
    tfr_type = cfg.TFR_TYPE
    export_path = cfg.EXPORT_PATH
    t_buffer = cfg.T_BUFFER
    if tfr_type == 'multitaper':
        tfr = mne.time_frequency.tfr_multitaper
    elif tfr_type == 'morlet':
        tfr = mne.time_frequency.tfr_morlet
    elif tfr_type == 'butter':
        butter_order = 4 # TODO: parameterize
        tfr = lfilter
    elif tfr_type == 'fir':
        raise NotImplementedError
    else:
        raise ValueError('Wrong TFR type %s' % tfr_type)
    # cfg.N_JOBS takes precedence; the argument is only a fallback.
    n_jobs = getattr(cfg, 'N_JOBS', n_jobs)
    if n_jobs is None:
        n_jobs = mp.cpu_count()

    if hasattr(cfg, 'DATA_DIRS'):
        if export_path is None:
            raise ValueError('For multiple directories, cfg.EXPORT_PATH cannot be None')
        outpath = export_path
        file_prefix = 'grandavg'

        # load and merge files from all directories
        flist = []
        for ddir in cfg.DATA_DIRS:
            ddir = ddir.replace('\\', '/')
            if ddir[-1] != '/':
                ddir += '/'
            for f in qc.get_file_list(ddir, fullpath=True, recursive=recursive):
                if qc.parse_path(f).ext in ['fif', 'bdf', 'gdf']:
                    flist.append(f)
        raw, events = pu.load_multi(flist)

        # Custom event file: read it AFTER loading, otherwise the events
        # returned by load_multi() silently replace it.
        if hasattr(cfg, 'EVENT_FILE') and cfg.EVENT_FILE is not None:
            events = mne.read_events(cfg.EVENT_FILE)
    else:
        print('Loading', cfg.DATA_FILE)
        raw, events = pu.load_raw(cfg.DATA_FILE)

        # custom events
        if hasattr(cfg, 'EVENT_FILE') and cfg.EVENT_FILE is not None:
            events = mne.read_events(cfg.EVENT_FILE)

        # file_prefix is needed for the output names even when EXPORT_PATH is set
        [data_file_dir, file_prefix, _] = qc.parse_path_list(cfg.DATA_FILE)
        outpath = data_file_dir if export_path is None else export_path

    # re-referencing
    if cfg.REREFERENCE is not None:
        pu.rereference(raw, cfg.REREFERENCE[1], cfg.REREFERENCE[0])

    sfreq = raw.info['sfreq']

    # set channels of interest
    picks = pu.channel_names_to_index(raw, cfg.CHANNEL_PICKS)
    spchannels = pu.channel_names_to_index(raw, cfg.SP_CHANNELS)

    # valid channel indices are 0 .. n_channels - 1
    if max(picks) >= len(raw.info['ch_names']):
        msg = 'ERROR: "picks" has a channel index %d while there are only %d channels.' %\
              (max(picks), len(raw.info['ch_names']))
        raise RuntimeError(msg)

    # Apply filters
    pu.preprocess(raw, spatial=cfg.SP_FILTER, spatial_ch=spchannels, spectral=cfg.TP_FILTER,
                  spectral_ch=picks, notch=cfg.NOTCH_FILTER, notch_ch=picks,
                  multiplier=cfg.MULTIPLIER, n_jobs=n_jobs)

    # Read epochs
    classes = {}
    event_values = set(events[:, -1])
    for t in cfg.TRIGGERS:
        if t in event_values:
            if hasattr(cfg, 'tdef'):
                classes[cfg.tdef.by_value[t]] = t
            else:
                classes[str(t)] = t
    if len(classes) == 0:
        raise ValueError('No desired event was found from the data.')

    # Pre-set so the error report in the except clause can always be printed.
    tmin = tmax = tmin_buffer = tmax_buffer = float('nan')
    try:
        tmin = cfg.EPOCH[0]
        tmin_buffer = tmin - t_buffer
        raw_tmax = raw._data.shape[1] / sfreq - 0.1
        if cfg.EPOCH[1] is None:
            if cfg.POWER_AVERAGED:
                raise ValueError('EPOCH value cannot have None for grand averaged TFR')
            if len(cfg.TRIGGERS) > 1:
                raise ValueError('If the end time of EPOCH is None, only a single event can be defined.')
            # epoch runs from the first occurrence of the event to the end of the signal
            t_ref = events[np.where(events[:, 2] == list(cfg.TRIGGERS)[0])[0][0], 0] / sfreq
            tmax = raw_tmax - t_ref - t_buffer
        else:
            tmax = cfg.EPOCH[1]
        tmax_buffer = tmax + t_buffer
        if tmax_buffer > raw_tmax:
            raise ValueError('Epoch length with buffer (%.3f) is larger than signal length (%.3f)' % (tmax_buffer, raw_tmax))

        epochs_all = mne.Epochs(raw, events, classes, tmin=tmin_buffer, tmax=tmax_buffer,
                                proj=False, picks=picks, baseline=None, preload=True)
        if epochs_all.drop_log_stats() > 0:
            print('\n** Bad epochs found. Dropping into a Python shell.')
            print(epochs_all.drop_log)
            print('tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
                (tmin, tmax, tmin_buffer, tmax_buffer, raw._data.shape[1] / sfreq))
            print('\nType exit to continue.\n')
            pdb.set_trace()
    except Exception:
        print('\n*** (tfr_export) ERROR OCCURRED WHILE EPOCHING ***')
        traceback.print_exc()
        print('tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
            (tmin, tmax, tmin_buffer, tmax_buffer, raw._data.shape[1] / sfreq))
        pdb.set_trace()
        # Epoching failed, so epochs_all is not usable: re-raise after the
        # debugging session instead of failing later with a NameError.
        raise

    power = {}
    for evname in classes:
        export_dir = outpath
        qc.make_dirs(export_dir)
        print('\n>> Processing %s' % evname)
        freqs = cfg.FREQ_RANGE  # define frequencies of interest
        n_cycles = freqs / 2.  # different number of cycle per frequency
        if cfg.POWER_AVERAGED:
            # grand-average TFR
            epochs = epochs_all[evname][:]
            if len(epochs) == 0:
                print('No %s epochs. Skipping.' % evname)
                continue

            if tfr_type == 'butter':
                b, a = butter_bandpass(cfg.FREQ_RANGE[0], cfg.FREQ_RANGE[-1], sfreq, order=butter_order)
                # filter the epoch data array along the time axis
                tfr_filtered = lfilter(b, a, epochs.get_data(), axis=2)
                tfr_hilbert = hilbert(tfr_filtered)
                tfr_power = abs(tfr_hilbert)
                tfr_data = np.mean(tfr_power, axis=0)
            elif tfr_type == 'fir':
                raise NotImplementedError
            else:
                power[evname] = tfr(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=False,
                    return_itc=False, decim=1, n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                tfr_data = power[evname].data

            if cfg.EXPORT_MATLAB is True:
                # export all channels to MATLAB
                mout = '%s/%s-%s-%s.mat' % (export_dir, file_prefix, cfg.SP_FILTER, evname)
                scipy.io.savemat(mout, {'tfr':tfr_data, 'chs':epochs.ch_names,
                    'events':events, 'sfreq':sfreq, 'epochs':cfg.EPOCH, 'freqs':cfg.FREQ_RANGE})
                print('Exported %s' % mout)
            if cfg.EXPORT_PNG is True:
                if evname not in power:
                    # the 'butter' path yields a plain array, not a plottable TFR object
                    print('PNG export is not supported for TFR type %s. Skipping.' % tfr_type)
                else:
                    # Inspect power for each channel
                    for ch in np.arange(len(picks)):
                        chname = raw.ch_names[picks[ch]]
                        title = 'Peri-event %s - Channel %s' % (evname, chname)

                        # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                        fig = power[evname].plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                            colorbar=True, title=title, vmin=cfg.VMIN, vmax=cfg.VMAX, dB=False)
                        fout = '%s/%s-%s-%s-%s.png' % (export_dir, file_prefix, cfg.SP_FILTER, evname, chname)
                        fig.savefig(fout)
                        fig.clf()
                        print('Exported to %s' % fout)
        else:
            if tfr_type == 'butter':
                raise NotImplementedError('Per-epoch TFR is not supported for TFR type butter. Set POWER_AVERAGED = True.')
            # TFR per event
            for ep in range(len(epochs_all[evname])):
                epochs = epochs_all[evname][ep]
                if len(epochs) == 0:
                    print('No %s epochs. Skipping.' % evname)
                    continue
                power[evname] = tfr(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=False,
                    return_itc=False, decim=1, n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                if cfg.EXPORT_MATLAB is True:
                    # export all channels to MATLAB
                    mout = '%s/%s-%s-%s-ep%02d.mat' % (export_dir, file_prefix, cfg.SP_FILTER, evname, ep + 1)
                    scipy.io.savemat(mout, {'tfr':power[evname].data, 'chs':power[evname].ch_names,
                        'events':events, 'sfreq':sfreq, 'tmin':tmin, 'tmax':tmax, 'freqs':cfg.FREQ_RANGE})
                    print('Exported %s' % mout)
                if cfg.EXPORT_PNG is True:
                    # Inspect power for each channel
                    for ch in np.arange(len(picks)):
                        chname = raw.ch_names[picks[ch]]
                        title = 'Peri-event %s - Channel %s, Trial %d' % (evname, chname, ep + 1)
                        # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                        fig = power[evname].plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                            colorbar=True, title=title, vmin=cfg.VMIN, vmax=cfg.VMAX, dB=False)
                        fout = '%s/%s-%s-%s-%s-ep%02d.png' % (export_dir, file_prefix, cfg.SP_FILTER, evname, chname, ep + 1)
                        fig.savefig(fout)
                        fig.clf()
                        print('Exported %s' % fout)

    if hasattr(cfg, 'POWER_DIFF'):
        # list() is required: dict keys are not indexable on Python 3
        labels = list(classes.keys())
        if len(labels) < 2 or labels[0] not in power or labels[1] not in power:
            print('POWER_DIFF requires TFR power of at least two classes. Skipping.')
        else:
            export_dir = '%s/diff' % outpath
            qc.make_dirs(export_dir)
            df = power[labels[0]] - power[labels[1]]
            df.data = np.log(np.abs(df.data))
            # Inspect power diff for each channel
            for ch in np.arange(len(picks)):
                chname = raw.ch_names[picks[ch]]
                title = 'Peri-event %s-%s - Channel %s' % (labels[0], labels[1], chname)

                # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                # color range is symmetric: vmin must be below vmax
                fig = df.plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                              colorbar=True, title=title, vmin=-3.0, vmax=3.0, dB=False)
                fout = '%s/%s-%s-diff-%s-%s-%s.jpg' % (export_dir, file_prefix, cfg.SP_FILTER, labels[0], labels[1], chname)
                print('Exporting to %s' % fout)
                fig.savefig(fout)
                fig.clf()
    print('Finished !')
示例#15
0
文件: trainer.py 项目: LSYhhhh/pycnbi
def _is_int_index(value):
    """Return True for Python/NumPy integers (bool excluded) usable as channel indices."""
    return isinstance(value, (int, np.integer)) and not isinstance(value, bool)


def _pick_channel_indices(raw, cfg):
    """Resolve cfg.CHANNEL_PICKS minus cfg.EXCLUDES into a list of channel indices of raw."""
    if cfg.CHANNEL_PICKS is None:
        chlist = [int(x) for x in pick_types(raw.info, stim=False, eeg=True)]
    else:
        chlist = cfg.CHANNEL_PICKS
    picks = []
    for c in chlist:
        if _is_int_index(c):
            picks.append(int(c))
        elif isinstance(c, str):
            picks.append(raw.ch_names.index(c))
        else:
            raise RuntimeError(
                'CHANNEL_PICKS has a value of unknown type %s.\nCHANNEL_PICKS=%s'
                % (type(c), cfg.CHANNEL_PICKS))
    if cfg.EXCLUDES is not None:
        for c in cfg.EXCLUDES:
            if isinstance(c, str):
                if c not in raw.ch_names:
                    qc.print_c(
                        'Warning: Exclusion channel %s does not exist. Ignored.'
                        % c, 'Y')
                    continue
                c_int = raw.ch_names.index(c)
            elif _is_int_index(c):
                c_int = int(c)
            else:
                raise RuntimeError(
                    'EXCLUDES has a value of unknown type %s.\nEXCLUDES=%s' %
                    (type(c), cfg.EXCLUDES))
            if c_int in picks:
                del picks[picks.index(c_int)]
    if not picks:
        raise ValueError('No channel left after applying CHANNEL_PICKS and EXCLUDES.')
    # valid channel indices are 0 .. n_channels - 1
    if max(picks) >= len(raw.ch_names):
        raise ValueError(
            '"picks" has a channel index %d while there are only %d channels.'
            % (max(picks), len(raw.ch_names)))
    return picks


def _epoch_and_preprocess(raw, events, triggers, picks, tmin, tmax, cfg):
    """Epoch raw over [tmin, tmax] around triggers and apply cfg's spatial/spectral/notch filters."""
    epochs = Epochs(raw,
                    events,
                    triggers,
                    tmin=tmin,
                    tmax=tmax,
                    proj=False,
                    picks=picks,
                    baseline=None,
                    preload=True,
                    verbose=False,
                    detrend=None)
    # Channels are already selected by 'picks' param so use all channels.
    pu.preprocess(epochs,
                  spatial=cfg.SP_FILTER,
                  spatial_ch=None,
                  spectral=cfg.TP_FILTER,
                  spectral_ch=None,
                  notch=cfg.NOTCH_FILTER,
                  notch_ch=None,
                  multiplier=cfg.MULTIPLIER,
                  n_jobs=cfg.N_JOBS)
    return epochs


def compute_features(cfg, interactive=False):
    """
    Load training EEG files, epoch them and compute classifier features.

    All .fif/.fiff files in cfg.DATADIR are merged with pu.load_multi(),
    optionally re-referenced (cfg.REF_CH), epoched around the events of
    cfg.TRIGGER_DEF over cfg.EPOCH and filtered with pu.preprocess().
    cfg.EPOCH is either a [tmin, tmax] pair or (experimental) a list of such
    pairs. Only cfg.FEATURES == 'PSD' is implemented.

    @params
    cfg: training configuration object
    interactive: if True, open an interactive shell (embed()) when epoching
        fails, before the error is raised

    Returns the dict produced by get_psd_feature() with 'picks', 'sfreq' and
    'ch_names' added.

    Raises RuntimeError when no input file is found, for invalid
    CHANNEL_PICKS/EXCLUDES entries, or when epoching fails. Raises ValueError
    for empty or out-of-range channel picks. Raises NotImplementedError for
    unsupported feature types.
    """
    # Load file list. f[-4:] can never equal '.fiff', so test the suffix properly.
    ftrain = [f for f in qc.get_file_list(cfg.DATADIR, fullpath=True)
              if f.endswith(('.fif', '.fiff'))]
    if not ftrain:
        raise RuntimeError('No fif files found in %s' % cfg.DATADIR)

    # Preprocessing, epoching and PSD computation
    if len(ftrain) > 1 and cfg.CHANNEL_PICKS is not None and _is_int_index(
            cfg.CHANNEL_PICKS[0]):
        raise RuntimeError(
            'When loading multiple EEG files, CHANNEL_PICKS must be list of string, not integers because they may have different channel order.'
        )
    raw, events = pu.load_multi(ftrain)
    if cfg.REF_CH is not None:
        pu.rereference(raw, cfg.REF_CH[1], cfg.REF_CH[0])
    if cfg.LOAD_EVENTS_FILE is not None:
        events = mne.read_events(cfg.LOAD_EVENTS_FILE)
    triggers = {cfg.tdef.by_value[c]: c for c in set(cfg.TRIGGER_DEF)}

    # Pick channels
    picks = _pick_channel_indices(raw, cfg)
    for param in ('SP_CHANNELS', 'TP_CHANNELS', 'NOTCH_CHANNELS'):
        if getattr(cfg, param, None) is not None:
            qc.print_c(
                'compute_features(): %s parameter is not supported yet. Will be set to CHANNEL_PICKS.'
                % param, 'Y')

    # Read epochs
    try:
        if type(cfg.EPOCH[0]) is list:
            # Experimental: multiple epoch ranges
            epochs_train = [
                _epoch_and_preprocess(raw, events, triggers, picks, ep[0], ep[1], cfg)
                for ep in cfg.EPOCH
            ]
        else:
            # Usual method: single epoch range
            epochs_train = _epoch_and_preprocess(raw, events, triggers, picks,
                                                 cfg.EPOCH[0], cfg.EPOCH[1], cfg)
    except Exception:
        qc.print_c('\n*** (trainer.py) ERROR OCCURRED WHILE EPOCHING ***\n',
                   'R')
        # Catch and throw errors from child processes
        traceback.print_exc()
        if interactive:
            print('Dropping into a shell.\n')
            embed()
        raise RuntimeError('Error occurred while epoching. See the traceback above.')

    # Compute features
    if cfg.FEATURES == 'PSD':
        featdata = get_psd_feature(epochs_train,
                                   cfg.EPOCH,
                                   cfg.PSD,
                                   feat_picks=None,
                                   n_jobs=cfg.N_JOBS)
    elif cfg.FEATURES == 'TIMELAG':
        # TODO: Implement multiple epochs for timelag feature
        raise NotImplementedError(
            'MULTIPLE EPOCHS NOT IMPLEMENTED YET FOR TIMELAG FEATURE.')
    elif cfg.FEATURES == 'WAVELET':
        # TODO: Implement multiple epochs for wavelet feature
        raise NotImplementedError(
            'MULTIPLE EPOCHS NOT IMPLEMENTED YET FOR WAVELET FEATURE.')
    else:
        raise NotImplementedError('%s feature type is not supported.' %
                                  cfg.FEATURES)

    featdata['picks'] = picks
    featdata['sfreq'] = raw.info['sfreq']
    featdata['ch_names'] = raw.ch_names
    return featdata
示例#16
0
# Epoch window bounds. NOTE(review): presumably seconds relative to each
# event; they are not used in the visible part of this script - confirm.
TMIN = 0.0
TMAX = 2.0

import pycnbi.utils.pycnbi_utils as pu
import pycnbi.utils.q_common as qc
import scipy.io
import mne
import numpy as np
import trainer
from multiprocessing import cpu_count

mne.set_log_level('ERROR')

if __name__ == '__main__':
    # collect the .fif recordings in DATA_PATH
    rawlist = [path for path in qc.get_file_list(DATA_PATH, fullpath=True)
               if path.endswith('.fif')]
    if not rawlist:
        raise RuntimeError('No fif files found in the path.')

    # output directory for the exported epochs
    out_path = DATA_PATH + '/epochs'
    qc.make_dirs(out_path)

    # merge all recordings into one Raw and keep only the EEG channels
    raw, events = pu.load_multi(rawlist, multiplier=MULTIPLIER)
    raw.pick_types(meg=False, eeg=True, stim=False)
    sfreq = raw.info['sfreq']
    if REF_CH_NEW is not None:
        pu.rereference(raw, REF_CH_NEW, REF_CH_OLD)
示例#17
0
from __future__ import print_function, division
"""
Compute confusion matrix and accuracy from online result logs.

Kyuhwa Lee ([email protected])
Swiss Federal Institute of Technology of Lausanne (EPFL)
"""

LOG_DIR = r'D:\data\MI\rx1\classifier\gait-ULR-250ms'

import pycnbi
import pycnbi.utils.q_common as qc

dtlist = []
gtlist = []
for f in qc.get_file_list(LOG_DIR):
    [basedir, fname, fext] = qc.parse_path_list(f)
    # only online result logs (*online*.txt) are considered
    if 'online' not in fname or fext != 'txt':
        continue
    print(f)

    # Each line is "<ground truth>,<detected>"; a blank line ends the results.
    # 'with' guarantees the log file is closed after reading.
    with open(f) as log_file:
        for l in log_file:
            line = l.strip()
            if len(line) == 0:
                break
            gt, dt = line.split(',')
            gtlist.append(gt)
            dtlist.append(dt)

print('Ground-truth: %s' % ''.join(gtlist))
print('Detected as : %s' % ''.join(dtlist))
cfmat, acc = qc.confusion_matrix(gtlist, dtlist)
print('\nAverage accuracy: %.3f' % acc)
示例#18
0
    dups = np.where(0 == np.diff(eve[:, 0]))[0]
    assert len(dups) == 0
    assert max(eve[:, 2]) <= max(tdef.by_value.keys())

    # reset trigger channel
    raw._data[0] *= 0
    raw.add_events(eve, 'TRIGGER')
    raw.save(eeg_out, overwrite=True)

    print('\nResulting events')
    for key in np.unique(eve[:, 2]):
        print('%s: %d' %
              (tdef.by_value[key], len(np.where(eve[:, 2] == key)[0])))


if __name__ == '__main__':
    fif_dir = r'D:\data\STIMO\GO004\offline\all'
    trigger_file = 'triggerdef_gait_chuv.ini'
    # presumably: new event name -> list of existing event names to merge
    events = {'BOTH_GO': ['LEFT_GO', 'RIGHT_GO']}

    # merged copies are written to a sub-directory; the unused 'fiflist' was removed
    out_dir = fif_dir + '/merged'
    qc.make_dirs(out_dir)
    for f in qc.get_file_list(fif_dir):
        p = qc.parse_path(f)
        if p.ext != 'fif':
            continue
        eeg_in = f
        eeg_out = '%s/%s.%s' % (out_dir, p.name, p.ext)
        merge_events(trigger_file, events, eeg_in, eeg_out)
示例#19
0
import pycnbi
import os, sys, random
import pycnbi.utils.q_common as qc
import numpy as np
import mne
import matplotlib.pyplot as plt
import multiprocessing as mp
from sklearn.cross_validation import StratifiedShuffleSplit, LeaveOneOut
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
from sklearn.decomposition import PCA
from pycnbi.decoder.rlda import rLDA_binary
from pycnbi.triggers.trigger_def import trigger_def

# Full paths of the files in DATADIR. NOTE(review): DATADIR is not defined
# in the visible part of this file - confirm it is set before this line.
FLIST = qc.get_file_list(DATADIR, fullpath=True)
# Number of CPU cores; presumably used as the parallel job count.
n_jobs = mp.cpu_count()


# get grand averages of epochs
def plot_grand_avg(epochs, outdir, picks=None):
    """
    Save one grand-average plot per (event, channel) pair as a PNG file.

    For every event name in `epochs.event_id`, the epochs of that event are
    averaged, and each channel of the averaged response is plotted on its own
    and saved as `<outdir>/<channel>-<event>.png`.

    @params:
        epochs: MNE Epochs object.
        outdir: output directory; qc.make_dirs() is called on it first.
        picks: channels forwarded to Epochs.average(); None uses its default.
    """
    qc.make_dirs(outdir)
    for ev in epochs.event_id.keys():
        epavg = epochs[ev].average(picks)
        for chidx, ch in enumerate(epavg.ch_names):
            fig = epavg.plot(picks=[chidx], unit=True, ylim=dict(eeg=[-10, 10]),
                             titles='%s-%s' % (ch, ev), show=False,
                             scalings={'eeg': 1})
            fig.savefig(outdir + '/%s-%s.png' % (ch, ev))
            # Close each figure once saved; otherwise all (#events x #channels)
            # figures stay open in matplotlib and memory keeps growing.
            plt.close(fig)


# compute features for the classifier
# Example #20
# 0
def _decode_hdf5_triggers(type_ids, times):
    """
    Decode the bit-wise trigger records of an HDF5 recording.

    Each record carries one bit number (TypeID shifted by the smallest TypeID
    in the file). Consecutive records sharing a timestamp form one trigger,
    whose value is the sum of 2**bit over those records.

    @params:
        type_ids: raw AsynchronData/TypeID values, flattened.
        times: AsynchronData/Time values, flattened (same length).

    @returns:
        (trigger_times, trigger_values): equal-length arrays; trigger_times
        holds the timestamp at which each group starts, trigger_values the
        decoded values as floats.

    NOTE(review): change points are found by padding `times` with 0 on both
    sides, so a leading or trailing run of records with timestamp 0 is not
    reported as a trigger.
    """
    # Bit numbers relative to the smallest TypeID present.
    bits = np.ravel(type_ids - min(type_ids))

    # Non-zero differences between the 0-padded series mark where the
    # timestamp changes, i.e. where a new group of records starts.
    shifted = np.insert(times, 0, 0)
    padded = np.append(times, 0)
    change_idx = np.ravel(np.nonzero(padded - shifted))

    # One value per pair of consecutive change points; built in one pass
    # instead of the quadratic np.append-in-a-loop.
    starts, stops = change_idx[:-1], change_idx[1:]
    trigger_values = np.array(
        [np.sum(2 ** bits[start:stop]) for start, stop in zip(starts, stops)],
        dtype=float)
    # Every start is < change_idx[-1] <= len(times), so indexing `times` is safe.
    trigger_times = times[starts]
    return trigger_times, trigger_values


def hdf5_to_python(data_dir):
    """
    Export the trigger events of every .hdf5 recording in `data_dir` to .mat.

    For each `<name>.hdf5` file, the trigger records are decoded and saved to
    `<name>.mat` as a single variable `events`: an (n_events x 2) array whose
    rows are [trigger time index, trigger value].

    The HDF5 file is structured as follows:
    - AsynchronData
    |- AsynchronSignalType: XML
    |- Time: time index of the trigger data
    |- TypeID: corresponding trigger data, in a shifted 8-bit pattern
    |- Value: No idea !

    - RawData
    |- AcquisitionTaskDescription: XML, one node per channel
    |- DAQDeviceCapabilities: XML, one node per channel
    |- DAQDeviceDescription: XML, mostly useless
    |- Sample: Array of channel x time indexes
    |- SessionDescription: XML, mostly useless
    |- SubjectDescription: XML, mostly useless

    - SavedFeatures
    |- NumberOfFeatures: empty ?

    - Version
    |- Float, useless.
    """
    for rawfile in qc.get_file_list(data_dir, fullpath=True):
        if rawfile.split('.')[-1] != 'hdf5':
            continue

        # Open read-only (h5py < 3.0 defaulted to append mode, which may modify
        # or lock the input); `with` closes the file even if reading fails.
        with h5py.File(rawfile, 'r') as f:
            # Read properties from XML. The sampling frequency is parsed (so a
            # malformed header still fails here) but is not used below.
            tree = XET.fromstring(f['RawData']['AcquisitionTaskDescription'][0])
            samplingFreq = int(tree.findall('SamplingFrequency')[0].text)  # todo: typecheck ?

            # `dset[()]` reads the whole dataset; `Dataset.value` was removed
            # in h5py 3.0.
            type_ids = f['AsynchronData']['TypeID'][()].ravel()
            timestamps = f['AsynchronData']['Time'][()].ravel()

        # triggerData and triggerIndexes are ready to be inputted to mne.create_event
        triggerIndexes, triggerData = _decode_hdf5_triggers(type_ids, timestamps)

        logger.info('%s\n%d events found. Event types: %s' % (rawfile, len(triggerIndexes), set(triggerData)))
        merged = np.vstack((triggerIndexes, triggerData)).T
        matfile = rawfile.replace('.hdf5', '.mat')
        matdata = dict(events=merged)
        scipy.io.savemat(matfile, matdata)
        logger.info('Data exported to %s' % matfile)
# Example #21
# 0
def load_multi(src, spfilter=None, spchannels=None, multiplier=1):
    """
    Load multiple data files and concatenate them into a single series

    - Assumes all files have the same sampling rate and channel order.
    - Event locations are updated accordingly with new offset.

    @params:
        src: directory or list of files.
        spfilter: apply spatial filter while loading.
        spchannels: list of channel names to apply spatial filter.
        multiplier: to change units for better numerical stability.

    @returns:
        (raw, events): the concatenated MNE RawArray and the events found on
        its trigger channel. With a single input file, load_raw()'s result is
        returned unchanged.

    @raises:
        IOError: `src` is a string but not an existing directory.
        TypeError: `src` is neither a string nor a list/tuple.
        RuntimeError: no file to load.

    See load_raw() for more low-level details.
    """
    if isinstance(src, str):
        if not os.path.isdir(src):
            msg = '%s is not a directory or does not exist.' % src
            logger.error(msg)
            raise IOError(msg)
        flist = [f for f in qc.get_file_list(src)
                 if qc.parse_path_list(f)[2] == 'fif']
    elif isinstance(src, (list, tuple)):
        flist = src
    else:
        msg = 'Unknown input type %s' % type(src)
        logger.error(msg)
        raise TypeError(msg)

    if len(flist) == 0:
        msg = 'load_multi(): No fif files found in %s.' % src
        logger.error(msg)
        raise RuntimeError(msg)
    elif len(flist) == 1:
        return load_raw(flist[0],
                        spfilter=spfilter,
                        spchannels=spchannels,
                        multiplier=multiplier)

    # load raw files
    rawlist = []
    for f in flist:
        logger.info('Loading %s' % f)
        raw, _ = load_raw(f,
                          spfilter=spfilter,
                          spchannels=spchannels,
                          multiplier=multiplier)
        rawlist.append(raw)

    # Concatenate all samples along the time axis in a single pass (repeated
    # pairwise concatenation would re-copy the growing array every time).
    signals = np.concatenate([r._data for r in rawlist], axis=1)

    # Create a concatenated raw object using the first file's channel layout.
    raw = rawlist[0]
    trigch = find_event_channel(raw)
    ch_types = ['eeg'] * len(raw.ch_names)
    if trigch is not None:
        ch_types[trigch] = 'stim'
        # Use the detected trigger channel's actual name rather than assuming
        # it is called 'TRIGGER'.
        stim_channel = raw.ch_names[trigch]
    else:
        # Fall back to the conventional name, as before.
        stim_channel = 'TRIGGER'
    info = mne.create_info(raw.ch_names, raw.info['sfreq'], ch_types)
    raw_merged = mne.io.RawArray(signals, info)

    # Re-detect events on the merged signal so their sample positions
    # reflect the concatenation offsets.
    events = mne.find_events(raw_merged,
                             stim_channel=stim_channel,
                             shortest_event=1,
                             uint_cast=True,
                             consecutive='increasing',
                             output='onset',
                             initial_event=True)

    return raw_merged, events
# Example #22
# 0
def _resolve_picks(raw, cfg):
    '''
    Turn cfg.PICKED_CHANNELS minus cfg.EXCLUDED_CHANNELS into channel indices.

    String entries are looked up in raw.ch_names; integer entries are used as
    indices. When PICKED_CHANNELS is None, all EEG (non-stim) channels are used.

    Raises RuntimeError for an entry of unsupported type, and ValueError when a
    picked name does not exist, no channel remains, or an index is out of range.
    '''
    if cfg.PICKED_CHANNELS is None:
        chlist = [int(x) for x in pick_types(raw.info, stim=False, eeg=True)]
    else:
        chlist = cfg.PICKED_CHANNELS
    picks = []
    for c in chlist:
        if isinstance(c, (int, np.integer)):
            picks.append(int(c))
        elif isinstance(c, str):
            if c not in raw.ch_names:
                msg = 'Picked channel %s does not exist.' % c
                logger.error(msg)
                raise ValueError(msg)
            picks.append(raw.ch_names.index(c))
        else:
            logger.error(
                'PICKED_CHANNELS has a value of unknown type %s.\nPICKED_CHANNELS=%s'
                % (type(c), cfg.PICKED_CHANNELS))
            raise RuntimeError
    if cfg.EXCLUDED_CHANNELS is not None:
        for c in cfg.EXCLUDED_CHANNELS:
            if isinstance(c, str):
                if c not in raw.ch_names:
                    logger.warning(
                        'Exclusion channel %s does not exist. Ignored.' % c)
                    continue
                c_int = raw.ch_names.index(c)
            elif isinstance(c, (int, np.integer)):
                c_int = int(c)
            else:
                logger.error(
                    'EXCLUDED_CHANNELS has a value of unknown type %s.\nEXCLUDED_CHANNELS=%s'
                    % (type(c), cfg.EXCLUDED_CHANNELS))
                raise RuntimeError
            if c_int in picks:
                picks.remove(c_int)
    if not picks:
        msg = 'No channel left after applying PICKED_CHANNELS and EXCLUDED_CHANNELS.'
        logger.error(msg)
        raise ValueError(msg)
    # Valid indices are 0 .. len(ch_names) - 1, hence ">=" (not ">").
    if max(picks) >= len(raw.ch_names):
        logger.error(
            '"picks" has a channel index %d while there are only %d channels.'
            % (max(picks), len(raw.ch_names)))
        raise ValueError
    return picks


def compute_features(cfg):
    '''
    Compute features using config specification.

    Performs preprocessing, epoching and feature computation.

    Input
    =====
    Config file object

    Output
    ======
    Feature data in dictionary
    - X_data: feature vectors
    - Y_data: feature labels
    - wlen: window length in seconds
    - w_frames: window length in frames
    - psde: MNE PSD estimator object
    - picks: channels used for feature computation
    - sfreq: sampling frequency
    - ch_names: channel names
    - times: feature timestamp (leading edge of a window)

    Raises
    ======
    - RuntimeError: invalid channel specification or epoching failure
    - ValueError: unknown picked channel, no channel left, or index out of range
    - NotImplementedError: selected feature type other than PSD
    '''
    # Preprocessing, epoching and PSD computation.
    # endswith() is needed here: a 4-character suffix slice can never equal
    # the 5-character '.fiff'.
    ftrain = [f for f in qc.get_file_list(cfg.DATA_PATH, fullpath=True)
              if f.endswith(('.fif', '.fiff'))]
    if len(ftrain) > 1 and cfg.PICKED_CHANNELS is not None and any(
            isinstance(c, (int, np.integer)) for c in cfg.PICKED_CHANNELS):
        logger.error(
            'When loading multiple EEG files, PICKED_CHANNELS must be list of string, not integers because they may have different channel order.'
        )
        raise RuntimeError
    raw, events = pu.load_multi(ftrain)

    reref = cfg.REREFERENCE[cfg.REREFERENCE['selected']]
    if reref is not None:
        pu.rereference(raw, reref['New'], reref['Old'])

    if cfg.LOAD_EVENTS[cfg.LOAD_EVENTS['selected']] is not None:
        events = mne.read_events(cfg.LOAD_EVENTS[cfg.LOAD_EVENTS['selected']])

    # Map trigger name -> trigger value for every trigger listed in TRIGGER_DEF.
    trigger_def_int = set()
    for a in cfg.TRIGGER_DEF:
        trigger_def_int.add(getattr(cfg.tdef, a))
    triggers = {cfg.tdef.by_value[c]: c for c in trigger_def_int}

    # Pick channels
    picks = _resolve_picks(raw, cfg)
    if hasattr(cfg, 'SP_CHANNELS') and cfg.SP_CHANNELS is not None:
        logger.warning(
            'SP_CHANNELS parameter is not supported yet. Will be set to PICKED_CHANNELS.'
        )
    if hasattr(cfg, 'TP_CHANNELS') and cfg.TP_CHANNELS is not None:
        logger.warning(
            'TP_CHANNELS parameter is not supported yet. Will be set to PICKED_CHANNELS.'
        )
    if hasattr(cfg, 'NOTCH_CHANNELS') and cfg.NOTCH_CHANNELS is not None:
        logger.warning(
            'NOTCH_CHANNELS parameter is not supported yet. Will be set to PICKED_CHANNELS.'
        )
    if 'decim' not in cfg.FEATURES['PSD']:
        cfg.FEATURES['PSD']['decim'] = 1
        logger.warning('PSD["decim"] undefined. Set to 1.')

    # Read epochs
    try:
        # Experimental: multiple epoch ranges
        if type(cfg.EPOCH[0]) is list:
            epochs_train = []
            for ep in cfg.EPOCH:
                epoch = Epochs(raw,
                               events,
                               triggers,
                               tmin=ep[0],
                               tmax=ep[1],
                               proj=False,
                               picks=picks,
                               baseline=None,
                               preload=True,
                               verbose=False,
                               detrend=None)
                epochs_train.append(epoch)
        else:
            # Usual method: single epoch range
            epochs_train = Epochs(raw,
                                  events,
                                  triggers,
                                  tmin=cfg.EPOCH[0],
                                  tmax=cfg.EPOCH[1],
                                  proj=False,
                                  picks=picks,
                                  baseline=None,
                                  preload=True,
                                  verbose=False,
                                  detrend=None,
                                  on_missing='warning')
    except Exception as e:
        # Narrowed from a bare except so KeyboardInterrupt/SystemExit propagate.
        logger.exception('Problem while epoching.')
        raise RuntimeError('Problem while epoching.') from e

    # Compute features. cfg.FEATURES is a dict; dispatch on its 'selected' key
    # (comparing the dict itself to a string would never match).
    selected = cfg.FEATURES['selected']
    if selected == 'PSD':
        # Epochs are created without decimation, so their sampling rate equals
        # the raw one; reading it from `raw` also works when epochs_train is a
        # list (multiple epoch ranges), which has no `.info`.
        preprocess = dict(sfreq=raw.info['sfreq'],
                          spatial=cfg.SP_FILTER,
                          spatial_ch=None,
                          spectral=cfg.TP_FILTER[cfg.TP_FILTER['selected']],
                          spectral_ch=None,
                          notch=cfg.NOTCH_FILTER[cfg.NOTCH_FILTER['selected']],
                          notch_ch=None,
                          multiplier=cfg.MULTIPLIER,
                          ch_names=None,
                          rereference=None,
                          decim=cfg.FEATURES['PSD']['decim'],
                          n_jobs=cfg.N_JOBS)
        featdata = get_psd_feature(epochs_train,
                                   cfg.EPOCH,
                                   cfg.FEATURES['PSD'],
                                   picks=None,
                                   preprocess=preprocess,
                                   n_jobs=cfg.N_JOBS)
    elif selected == 'TIMELAG':
        # TODO: Implement multiple epochs for timelag feature
        logger.error(
            'MULTIPLE EPOCHS NOT IMPLEMENTED YET FOR TIMELAG FEATURE.')
        raise NotImplementedError
    elif selected == 'WAVELET':
        # TODO: Implement multiple epochs for wavelet feature
        logger.error(
            'MULTIPLE EPOCHS NOT IMPLEMENTED YET FOR WAVELET FEATURE.')
        raise NotImplementedError
    else:
        logger.error('%s feature type is not supported.' % selected)
        raise NotImplementedError

    featdata['picks'] = picks
    featdata['sfreq'] = raw.info['sfreq']
    featdata['ch_names'] = raw.ch_names
    return featdata