Example #1
def epochs2mat(data_dir,
               channel_picks,
               event_id,
               tmin,
               tmax,
               merge_epochs=False,
               spfilter=None,
               spchannels=None):
    if merge_epochs:
        # load all raw files in the directory and merge epochs
        fiflist = []
        for data_file in qc.get_file_list(data_dir, fullpath=True):
            if data_file[-4:] != '.fif':
                continue
            fiflist.append(data_file)
        raw, events = pu.load_multi(fiflist,
                                    spfilter=spfilter,
                                    spchannels=spchannels)
        matfile = data_dir + '/epochs_all.mat'
        save_mat(raw, events, channel_picks, event_id, tmin, tmax, matfile)
        logger.info('Exported to %s' % matfile)
    else:
        # process individual raw file separately
        for data_file in qc.get_file_list(data_dir, fullpath=True):
            if data_file[-4:] != '.fif':
                continue
            [base, fname, fext] = qc.parse_path_list(data_file)
            matfile = '%s/%s-epochs.mat' % (base, fname)
            raw, events = pu.load_raw(data_file)
            save_mat(raw, events, channel_picks, event_id, tmin, tmax, matfile)
            logger.info('Exported to %s' % matfile)

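A minimal usage sketch for the function above. The directory, channel names and event mapping are hypothetical, not taken from the source:

# Hypothetical call: merge all .fif files in the directory and
# export the epochs to a single epochs_all.mat file.
epochs2mat('/data/subject01',
           channel_picks=['Cz', 'C3', 'C4'],   # assumed channel names
           event_id={'LEFT': 11, 'RIGHT': 9},  # assumed trigger mapping
           tmin=-1.0,
           tmax=3.0,
           merge_epochs=True)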
Example #2
def compute_features(cfg):
    '''
    Compute features using config specification.

    Performs preprocessing, epoching and feature computation.

    Input
    =====
    Config file object

    Output
    ======
    Feature data in dictionary
    - X_data: feature vectors
    - Y_data: feature labels
    - wlen: window length in seconds
    - w_frames: window length in frames
    - psde: MNE PSD estimator object
    - picks: channels used for feature computation
    - sfreq: sampling frequency
    - ch_names: channel names
    - times: feature timestamp (leading edge of a window)
    '''
    # Preprocessing, epoching and PSD computation
    ftrain = []
    for f in qc.get_file_list(cfg.DATA_PATH, fullpath=True):
        # endswith() is needed here: a 4-character slice can never match '.fiff'
        if f.endswith(('.fif', '.fiff')):
            ftrain.append(f)
    if (len(ftrain) > 1 and cfg.PICKED_CHANNELS is not None
            and isinstance(cfg.PICKED_CHANNELS[0], int)):
        logger.error(
            'When loading multiple EEG files, PICKED_CHANNELS must be a list '
            'of channel names, not integer indices, because the files may '
            'have different channel orders.')
        raise RuntimeError('PICKED_CHANNELS must be channel names.')
    raw, events = pu.load_multi(ftrain)

    reref = cfg.REREFERENCE[cfg.REREFERENCE['selected']]
    if reref is not None:
        pu.rereference(raw, reref['New'], reref['Old'])

    if cfg.LOAD_EVENTS[cfg.LOAD_EVENTS['selected']] is not None:
        events = mne.read_events(cfg.LOAD_EVENTS[cfg.LOAD_EVENTS['selected']])

    trigger_def_int = set()
    for a in cfg.TRIGGER_DEF:
        trigger_def_int.add(getattr(cfg.tdef, a))
    triggers = {cfg.tdef.by_value[c]: c for c in trigger_def_int}

    # Pick channels
    if cfg.PICKED_CHANNELS is None:
        chlist = [int(x) for x in pick_types(raw.info, stim=False, eeg=True)]
    else:
        chlist = cfg.PICKED_CHANNELS
    picks = []
    for c in chlist:
        if isinstance(c, int):
            picks.append(c)
        elif isinstance(c, str):
            picks.append(raw.ch_names.index(c))
        else:
            logger.error(
                'PICKED_CHANNELS has a value of unknown type %s.\nPICKED_CHANNELS=%s'
                % (type(c), cfg.PICKED_CHANNELS))
            raise RuntimeError
    if cfg.EXCLUDED_CHANNELS is not None:
        for c in cfg.EXCLUDED_CHANNELS:
            if isinstance(c, str):
                if c not in raw.ch_names:
                    logger.warning(
                        'Exclusion channel %s does not exist. Ignored.' % c)
                    continue
                c_int = raw.ch_names.index(c)
            elif isinstance(c, int):
                c_int = c
            else:
                logger.error(
                    'EXCLUDED_CHANNELS has a value of unknown type %s.\nEXCLUDED_CHANNELS=%s'
                    % (type(c), cfg.EXCLUDED_CHANNELS))
                raise RuntimeError
            if c_int in picks:
                picks.remove(c_int)
    if max(picks) >= len(raw.ch_names):
        logger.error(
            '"picks" has a channel index %d while there are only %d channels.'
            % (max(picks), len(raw.ch_names)))
        raise ValueError
    if hasattr(cfg, 'SP_CHANNELS') and cfg.SP_CHANNELS is not None:
        logger.warning(
            'SP_CHANNELS parameter is not supported yet. Will be set to PICKED_CHANNELS.'
        )
    if hasattr(cfg, 'TP_CHANNELS') and cfg.TP_CHANNELS is not None:
        logger.warning(
            'TP_CHANNELS parameter is not supported yet. Will be set to PICKED_CHANNELS.'
        )
    if hasattr(cfg, 'NOTCH_CHANNELS') and cfg.NOTCH_CHANNELS is not None:
        logger.warning(
            'NOTCH_CHANNELS parameter is not supported yet. Will be set to PICKED_CHANNELS.'
        )
    if 'decim' not in cfg.FEATURES['PSD']:
        cfg.FEATURES['PSD']['decim'] = 1
        logger.warning('PSD["decim"] undefined. Set to 1.')

    # Read epochs
    try:
        # Experimental: multiple epoch ranges
        if isinstance(cfg.EPOCH[0], list):
            epochs_train = []
            for ep in cfg.EPOCH:
                epoch = Epochs(raw,
                               events,
                               triggers,
                               tmin=ep[0],
                               tmax=ep[1],
                               proj=False,
                               picks=picks,
                               baseline=None,
                               preload=True,
                               verbose=False,
                               detrend=None)
                epochs_train.append(epoch)
        else:
            # Usual method: single epoch range
            epochs_train = Epochs(raw,
                                  events,
                                  triggers,
                                  tmin=cfg.EPOCH[0],
                                  tmax=cfg.EPOCH[1],
                                  proj=False,
                                  picks=picks,
                                  baseline=None,
                                  preload=True,
                                  verbose=False,
                                  detrend=None,
                                  on_missing='warning')
    except Exception:
        logger.exception('Problem while epoching.')
        raise RuntimeError('Epoching failed.')

    label_set = np.unique(list(triggers.values()))

    # Compute features
    if cfg.FEATURES['selected'] == 'PSD':
        preprocess = dict(sfreq=epochs_train.info['sfreq'],
                          spatial=cfg.SP_FILTER,
                          spatial_ch=None,
                          spectral=cfg.TP_FILTER[cfg.TP_FILTER['selected']],
                          spectral_ch=None,
                          notch=cfg.NOTCH_FILTER[cfg.NOTCH_FILTER['selected']],
                          notch_ch=None,
                          multiplier=cfg.MULTIPLIER,
                          ch_names=None,
                          rereference=None,
                          decim=cfg.FEATURES['PSD']['decim'],
                          n_jobs=cfg.N_JOBS)
        featdata = get_psd_feature(epochs_train,
                                   cfg.EPOCH,
                                   cfg.FEATURES['PSD'],
                                   picks=None,
                                   preprocess=preprocess,
                                   n_jobs=cfg.N_JOBS)
    elif cfg.FEATURES['selected'] == 'TIMELAG':
        # TODO: implement multiple epochs for timelag feature
        logger.error(
            'MULTIPLE EPOCHS NOT IMPLEMENTED YET FOR TIMELAG FEATURE.')
        raise NotImplementedError
    elif cfg.FEATURES['selected'] == 'WAVELET':
        # TODO: implement multiple epochs for wavelet feature
        logger.error(
            'MULTIPLE EPOCHS NOT IMPLEMENTED YET FOR WAVELET FEATURE.')
        raise NotImplementedError
    else:
        logger.error('%s feature type is not supported.'
                     % cfg.FEATURES['selected'])
        raise NotImplementedError

    featdata['picks'] = picks
    featdata['sfreq'] = raw.info['sfreq']
    featdata['ch_names'] = raw.ch_names
    return featdata
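A sketch of how compute_features might be driven. How cfg is constructed depends on the package's config loader, so only the documented return keys are shown:

# Hypothetical usage; cfg is a config object with the attributes read above
# (DATA_PATH, PICKED_CHANNELS, EPOCH, FEATURES, ...).
featdata = compute_features(cfg)
X, Y = featdata['X_data'], featdata['Y_data']   # feature vectors and labels
print('%d feature vectors at %.1f Hz from %d channels'
      % (len(X), featdata['sfreq'], len(featdata['ch_names'])))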
Example #3
mne.set_log_level('ERROR')

if __name__ == '__main__':
    rawlist = []
    for f in qc.get_file_list(DATA_PATH, fullpath=True):
        if f[-4:] == '.fif':
            rawlist.append(f)
    if len(rawlist) == 0:
        raise RuntimeError('No fif files found in the path.')

    # make output directory
    out_path = DATA_PATH + '/epochs'
    qc.make_dirs(out_path)

    # load data
    raw, events = pu.load_multi(rawlist, multiplier=MULTIPLIER)
    raw.pick_types(meg=False, eeg=True, stim=False)
    sfreq = raw.info['sfreq']
    if REF_CH_NEW is not None:
        pu.rereference(raw, REF_CH_NEW, REF_CH_OLD)

    # pick channels
    if CHANNEL_PICKS is None:
        picks = [raw.ch_names.index(c) for c in raw.ch_names if c not in EXCLUDES]
    elif isinstance(CHANNEL_PICKS[0], str):
        picks = [raw.ch_names.index(c) for c in CHANNEL_PICKS]
    else:
        assert isinstance(CHANNEL_PICKS[0], int)
        picks = CHANNEL_PICKS

    # do epoching
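The snippet is cut off after the comment above. A minimal sketch of the epoching step it announces, assuming module-level TMIN, TMAX and EVENT_ID constants alongside DATA_PATH and the others:

    # Hypothetical continuation: EVENT_ID, TMIN and TMAX are assumed
    # module-level constants, analogous to the other examples here.
    epochs = mne.Epochs(raw, events, event_id=EVENT_ID,
                        tmin=TMIN, tmax=TMAX,
                        picks=picks, baseline=None,
                        proj=False, preload=True)
    epochs.save('%s/epochs-epo.fif' % out_path, overwrite=True)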
Example #4
def get_tfr(cfg, recursive=False, n_jobs=1):
    '''
    Compute and export time-frequency representations (TFR).

    Most parameters are read from cfg rather than passed as arguments:
    TFR_TYPE ('multitaper', 'morlet' or 'butter'), EXPORT_PATH (path to
    save plots) and N_JOBS (number of cores; overrides the n_jobs argument).

    recursive: if True, load raw files in sub-dirs recursively
    '''

    cfg = check_config(cfg)
    tfr_type = cfg.TFR_TYPE
    export_path = cfg.EXPORT_PATH
    t_buffer = cfg.T_BUFFER
    if tfr_type == 'multitaper':
        tfr = mne.time_frequency.tfr_multitaper
    elif tfr_type == 'morlet':
        tfr = mne.time_frequency.tfr_morlet
    elif tfr_type == 'butter':
        butter_order = 4  # TODO: parameterize
        tfr = lfilter
    elif tfr_type == 'fir':
        raise NotImplementedError
    else:
        raise ValueError('Wrong TFR type %s' % tfr_type)
    n_jobs = cfg.N_JOBS
    if n_jobs is None:
        n_jobs = mp.cpu_count()

    if hasattr(cfg, 'DATA_PATHS'):
        if export_path is None:
            raise ValueError(
                'For multiple directories, cfg.EXPORT_PATH cannot be None')
        else:
            outpath = export_path
        # custom event file
        if hasattr(cfg, 'EVENT_FILE') and cfg.EVENT_FILE is not None:
            events = mne.read_events(cfg.EVENT_FILE)
        file_prefix = 'grandavg'

        # load and merge files from all directories
        flist = []
        for ddir in cfg.DATA_PATHS:
            ddir = ddir.replace('\\', '/')
            if ddir[-1] != '/':
                ddir += '/'
            for f in qc.get_file_list(ddir, fullpath=True,
                                      recursive=recursive):
                if qc.parse_path(f).ext in ['fif', 'bdf', 'gdf']:
                    flist.append(f)
        raw, events = pu.load_multi(flist)
    else:
        logger.info('Loading %s' % cfg.DATA_FILE)
        raw, events = pu.load_raw(cfg.DATA_FILE)

        # custom events
        if hasattr(cfg, 'EVENT_FILE') and cfg.EVENT_FILE is not None:
            events = mne.read_events(cfg.EVENT_FILE)

        if export_path is None:
            [outpath, file_prefix, _] = qc.parse_path_list(cfg.DATA_FILE)
        else:
            outpath = export_path
            file_prefix = qc.parse_path(cfg.DATA_FILE).name

    # re-referencing
    if cfg.REREFERENCE is not None:
        pu.rereference(raw, cfg.REREFERENCE[1], cfg.REREFERENCE[0])
        assert cfg.REREFERENCE[0] in raw.ch_names

    sfreq = raw.info['sfreq']

    # set channels of interest
    picks = pu.channel_names_to_index(raw, cfg.CHANNEL_PICKS)
    spchannels = pu.channel_names_to_index(raw, cfg.SP_CHANNELS)

    if max(picks) >= len(raw.info['ch_names']):
        msg = '"picks" has a channel index %d while there are only %d channels.' %\
              (max(picks), len(raw.info['ch_names']))
        raise RuntimeError(msg)

    # Apply filters
    raw = pu.preprocess(raw,
                        spatial=cfg.SP_FILTER,
                        spatial_ch=spchannels,
                        spectral=cfg.TP_FILTER,
                        spectral_ch=picks,
                        notch=cfg.NOTCH_FILTER,
                        notch_ch=picks,
                        multiplier=cfg.MULTIPLIER,
                        n_jobs=n_jobs)

    # Read epochs
    classes = {}
    for t in cfg.TRIGGERS:
        if t in set(events[:, -1]):
            if hasattr(cfg, 'tdef'):
                classes[cfg.tdef.by_value[t]] = t
            else:
                classes[str(t)] = t
    if len(classes) == 0:
        raise ValueError('No desired event was found from the data.')

    try:
        tmin = cfg.EPOCH[0]
        tmin_buffer = tmin - t_buffer
        raw_tmax = raw.n_times / sfreq - 0.1
        if cfg.EPOCH[1] is None:
            if cfg.POWER_AVERAGED:
                raise ValueError(
                    'EPOCH value cannot have None for grand averaged TFR')
            else:
                if len(cfg.TRIGGERS) > 1:
                    raise ValueError(
                        'If the end time of EPOCH is None, only a single event can be defined.'
                    )
                t_ref = events[np.where(
                    events[:, 2] == list(cfg.TRIGGERS)[0])[0][0], 0] / sfreq
                tmax = raw_tmax - t_ref - t_buffer
        else:
            tmax = cfg.EPOCH[1]
        tmax_buffer = tmax + t_buffer
        if tmax_buffer > raw_tmax:
            raise ValueError(
                'Epoch length with buffer (%.3f) is larger than signal length (%.3f)'
                % (tmax_buffer, raw_tmax))
        epochs_all = mne.Epochs(raw,
                                events,
                                classes,
                                tmin=tmin_buffer,
                                tmax=tmax_buffer,
                                proj=False,
                                picks=picks,
                                baseline=None,
                                preload=True)
        if epochs_all.drop_log_stats() > 0:
            logger.error(
                '\n** Bad epochs found. Dropping into a Python shell.')
            logger.error(epochs_all.drop_log)
            logger.error('tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
                (tmin, tmax, tmin_buffer, tmax_buffer, raw.n_times / sfreq))
            logger.error('\nType exit to continue.\n')
            pdb.set_trace()
    except Exception:
        logger.critical(
            '\n*** (tfr_export) Unknown error occurred while epoching ***')
        logger.critical('tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
            (tmin, tmax, tmin_buffer, tmax_buffer, raw.n_times / sfreq))
        pdb.set_trace()

    power = {}
    for evname in classes:
        export_dir = outpath
        qc.make_dirs(export_dir)
        logger.info('>> Processing %s' % evname)
        freqs = cfg.FREQ_RANGE  # define frequencies of interest
        n_cycles = freqs / 2.  # different number of cycle per frequency
        if cfg.POWER_AVERAGED:
            # grand-average TFR
            epochs = epochs_all[evname][:]
            if len(epochs) == 0:
                logger.warning('No %s epochs. Skipping.' % evname)
                continue

            if tfr_type == 'butter':
                b, a = butter_bandpass(cfg.FREQ_RANGE[0],
                                       cfg.FREQ_RANGE[-1],
                                       sfreq,
                                       order=butter_order)
                tfr_filtered = lfilter(b, a, epochs, axis=2)
                tfr_hilbert = hilbert(tfr_filtered)
                tfr_power = abs(tfr_hilbert)
                tfr_data = np.mean(tfr_power, axis=0)
            elif tfr_type == 'fir':
                raise NotImplementedError
            else:
                power[evname] = tfr(epochs,
                                    freqs=freqs,
                                    n_cycles=n_cycles,
                                    use_fft=False,
                                    return_itc=False,
                                    decim=1,
                                    n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                tfr_data = power[evname].data

            if cfg.EXPORT_MATLAB is True:
                # export all channels to MATLAB
                mout = '%s/%s-%s-%s.mat' % (export_dir, file_prefix,
                                            cfg.SP_FILTER, evname)
                scipy.io.savemat(
                    mout, {
                        'tfr': tfr_data,
                        'chs': epochs.ch_names,
                        'events': events,
                        'sfreq': sfreq,
                        'tmin': tmin,
                        'tmax': tmax,
                        'epochs': cfg.EPOCH,
                        'freqs': cfg.FREQ_RANGE
                    })
                logger.info('Exported %s' % mout)
            if cfg.EXPORT_PNG is True:
                # Inspect power for each channel
                for ch in np.arange(len(picks)):
                    chname = raw.ch_names[picks[ch]]
                    title = 'Peri-event %s - Channel %s' % (evname, chname)

                    # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                    fig = power[evname].plot([ch],
                                             baseline=cfg.BS_TIMES,
                                             mode=cfg.BS_MODE,
                                             show=False,
                                             colorbar=True,
                                             title=title,
                                             vmin=cfg.VMIN,
                                             vmax=cfg.VMAX,
                                             dB=False)
                    fout = '%s/%s-%s-%s-%s.png' % (
                        export_dir, file_prefix, cfg.SP_FILTER, evname, chname)
                    fig.savefig(fout)
                    plt.close()
                    logger.info('Exported to %s' % fout)
        else:
            # TFR per event
            for ep in range(len(epochs_all[evname])):
                epochs = epochs_all[evname][ep]
                if len(epochs) == 0:
                    logger.warning('No %s epochs. Skipping.' % evname)
                    continue
                power[evname] = tfr(epochs,
                                    freqs=freqs,
                                    n_cycles=n_cycles,
                                    use_fft=False,
                                    return_itc=False,
                                    decim=1,
                                    n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                if cfg.EXPORT_MATLAB is True:
                    # export all channels to MATLAB
                    mout = '%s/%s-%s-%s-ep%02d.mat' % (
                        export_dir, file_prefix, cfg.SP_FILTER, evname, ep + 1)
                    scipy.io.savemat(
                        mout, {
                            'tfr': power[evname].data,
                            'chs': power[evname].ch_names,
                            'events': events,
                            'sfreq': sfreq,
                            'tmin': tmin,
                            'tmax': tmax,
                            'epochs': cfg.EPOCH,
                            'freqs': cfg.FREQ_RANGE
                        })
                    logger.info('Exported %s' % mout)
                if cfg.EXPORT_PNG is True:
                    # Inspect power for each channel
                    for ch in np.arange(len(picks)):
                        chname = raw.ch_names[picks[ch]]
                        title = 'Peri-event %s - Channel %s, Trial %d' % (
                            evname, chname, ep + 1)
                        # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                        fig = power[evname].plot([ch],
                                                 baseline=cfg.BS_TIMES,
                                                 mode=cfg.BS_MODE,
                                                 show=False,
                                                 colorbar=True,
                                                 title=title,
                                                 vmin=cfg.VMIN,
                                                 vmax=cfg.VMAX,
                                                 dB=False)
                        fout = '%s/%s-%s-%s-%s-ep%02d.png' % (
                            export_dir, file_prefix, cfg.SP_FILTER, evname,
                            chname, ep + 1)
                        fig.savefig(fout)
                        plt.close()
                        logger.info('Exported %s' % fout)

    if hasattr(cfg, 'POWER_DIFF'):
        export_dir = '%s/diff' % outpath
        qc.make_dirs(export_dir)
        labels = list(classes.keys())
        df = power[labels[0]] - power[labels[1]]
        df.data = np.log(np.abs(df.data))
        # Inspect power diff for each channel
        for ch in np.arange(len(picks)):
            chname = raw.ch_names[picks[ch]]
            title = 'Peri-event %s-%s - Channel %s' % (labels[0], labels[1],
                                                       chname)

            # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
            fig = df.plot([ch],
                          baseline=cfg.BS_TIMES,
                          mode=cfg.BS_MODE,
                          show=False,
                          colorbar=True,
                          title=title,
                          vmin=-3.0,
                          vmax=3.0,
                          dB=False)
            fout = '%s/%s-%s-diff-%s-%s-%s.jpg' % (export_dir, file_prefix,
                                                   cfg.SP_FILTER, labels[0],
                                                   labels[1], chname)
            logger.info('Exporting to %s' % fout)
            fig.savefig(fout)
            plt.close()
    logger.info('Finished!')
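The 'butter' branch above calls butter_bandpass, which is not part of the snippet. A minimal sketch of such a helper, assuming the (lowcut, highcut, fs, order) signature used in the call and SciPy's standard Butterworth design:

from scipy.signal import butter

def butter_bandpass(lowcut, highcut, fs, order=4):
    # Normalize the cutoff frequencies by the Nyquist rate and design a
    # band-pass Butterworth filter; returns (b, a) coefficients for lfilter.
    nyq = 0.5 * fs
    b, a = butter(order, [lowcut / nyq, highcut / nyq], btype='band')
    return b, a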