def epochs2mat(data_dir, channel_picks, event_id, tmin, tmax, merge_epochs=False, spfilter=None, spchannels=None):
    '''
    Export the epochs of every .fif file found in data_dir to MATLAB .mat file(s).

    Params
    ------
    data_dir: directory listed with qc.get_file_list(); only files ending in '.fif' are used
    channel_picks, event_id, tmin, tmax: forwarded unchanged to save_mat()
    merge_epochs: if True, all .fif files are loaded together with pu.load_multi() and
        written to '<data_dir>/epochs_all.mat'. Otherwise each file is exported
        separately to '<base>/<fname>-epochs.mat' (base/fname from qc.parse_path_list).
    spfilter, spchannels: forwarded to pu.load_multi() (merge mode only)

    Raises
    ------
    RuntimeError: if no .fif file is found in data_dir
    '''
    fiflist = [f for f in qc.get_file_list(data_dir, fullpath=True) if f[-4:] == '.fif']
    if not fiflist:
        # Without this check the per-file branch would never bind `matfile`.
        raise RuntimeError('No fif files found in %s' % data_dir)

    if merge_epochs:
        # load all raw files in the directory and merge epochs
        raw, events = pu.load_multi(fiflist, spfilter=spfilter, spchannels=spchannels)
        matfile = data_dir + '/epochs_all.mat'
        save_mat(raw, events, channel_picks, event_id, tmin, tmax, matfile)
        logger.info('Exported to %s' % matfile)
    else:
        # process individual raw file separately
        # NOTE(review): spfilter/spchannels are not applied in this mode because
        # pu.load_raw() is called without them -- confirm this is intended.
        for data_file in fiflist:
            base, fname, _ = qc.parse_path_list(data_file)
            matfile = '%s/%s-epochs.mat' % (base, fname)
            raw, events = pu.load_raw(data_file)
            save_mat(raw, events, channel_picks, event_id, tmin, tmax, matfile)
            logger.info('Exported to %s' % matfile)
def compute_features(cfg):
    '''
    Compute features using config specification.
    Performs preprocessing, epoching and feature computation.

    Input
    =====
    Config file object

    Output
    ======
    Feature data in dictionary
    - X_data: feature vectors
    - Y_data: feature labels
    - wlen: window length in seconds
    - w_frames: window length in frames
    - psde: MNE PSD estimator object
    - picks: channels used for feature computation
    - sfreq: sampling frequency
    - ch_names: channel names
    - times: feature timestamp (leading edge of a window)

    Raises
    ======
    RuntimeError: invalid PICKED_CHANNELS / EXCLUDED_CHANNELS values, or epoching failure
    ValueError: no channel selected, or a picked channel index is out of range
    NotImplementedError: cfg.FEATURES['selected'] is not 'PSD'
    '''
    # Preprocessing, epoching and PSD computation
    # endswith() is required: a 4-character slice can never equal '.fiff'.
    ftrain = [f for f in qc.get_file_list(cfg.DATA_PATH, fullpath=True)
              if f.endswith(('.fif', '.fiff'))]
    if len(ftrain) > 1 and cfg.PICKED_CHANNELS is not None and \
            isinstance(cfg.PICKED_CHANNELS[0], (int, np.integer)):
        logger.error(
            'When loading multiple EEG files, PICKED_CHANNELS must be list of string, not integers because they may have different channel order.'
        )
        raise RuntimeError
    raw, events = pu.load_multi(ftrain)

    reref = cfg.REREFERENCE[cfg.REREFERENCE['selected']]
    if reref is not None:
        pu.rereference(raw, reref['New'], reref['Old'])

    event_file = cfg.LOAD_EVENTS[cfg.LOAD_EVENTS['selected']]
    if event_file is not None:
        events = mne.read_events(event_file)

    trigger_def_int = {getattr(cfg.tdef, a) for a in cfg.TRIGGER_DEF}
    triggers = {cfg.tdef.by_value[c]: c for c in trigger_def_int}

    # Pick channels
    if cfg.PICKED_CHANNELS is None:
        chlist = [int(x) for x in pick_types(raw.info, stim=False, eeg=True)]
    else:
        chlist = cfg.PICKED_CHANNELS
    picks = []
    for c in chlist:
        if isinstance(c, (int, np.integer)):
            picks.append(int(c))
        elif isinstance(c, str):
            picks.append(raw.ch_names.index(c))
        else:
            logger.error(
                'PICKED_CHANNELS has a value of unknown type %s.\nPICKED_CHANNELS=%s'
                % (type(c), cfg.PICKED_CHANNELS))
            raise RuntimeError

    if cfg.EXCLUDED_CHANNELS is not None:
        for c in cfg.EXCLUDED_CHANNELS:
            if isinstance(c, str):
                if c not in raw.ch_names:
                    logger.warning('Exclusion channel %s does not exist. Ignored.' % c)
                    continue
                c_int = raw.ch_names.index(c)
            elif isinstance(c, (int, np.integer)):
                c_int = int(c)
            else:
                logger.error(
                    'EXCLUDED_CHANNELS has a value of unknown type %s.\nEXCLUDED_CHANNELS=%s'
                    % (type(c), cfg.EXCLUDED_CHANNELS))
                raise RuntimeError
            if c_int in picks:
                picks.remove(c_int)

    if not picks:
        logger.error('No channel is left for feature computation.')
        raise ValueError
    # Valid indices are 0 .. len(ch_names) - 1.
    if max(picks) >= len(raw.ch_names):
        logger.error(
            '"picks" has a channel index %d while there are only %d channels.' %
            (max(picks), len(raw.ch_names)))
        raise ValueError

    for param in ('SP_CHANNELS', 'TP_CHANNELS', 'NOTCH_CHANNELS'):
        if getattr(cfg, param, None) is not None:
            logger.warning(
                '%s parameter is not supported yet. Will be set to PICKED_CHANNELS.' % param)

    if 'decim' not in cfg.FEATURES['PSD']:
        cfg.FEATURES['PSD']['decim'] = 1
        logger.warning('PSD["decim"] undefined. Set to 1.')

    # Read epochs
    try:
        if isinstance(cfg.EPOCH[0], (list, tuple)):
            # Experimental: multiple epoch ranges
            epochs_train = [
                Epochs(raw, events, triggers, tmin=ep[0], tmax=ep[1], proj=False,
                       picks=picks, baseline=None, preload=True, verbose=False,
                       detrend=None)
                for ep in cfg.EPOCH
            ]
        else:
            # Usual method: single epoch range
            epochs_train = Epochs(raw, events, triggers, tmin=cfg.EPOCH[0],
                                  tmax=cfg.EPOCH[1], proj=False, picks=picks,
                                  baseline=None, preload=True, verbose=False,
                                  detrend=None, on_missing='warning')
    except Exception:
        logger.exception('Problem while epoching.')
        raise RuntimeError('Problem while epoching.')

    # Compute features
    feature_type = cfg.FEATURES['selected']
    if feature_type == 'PSD':
        # Epochs are built without decimation, so their sampling rate equals raw's;
        # reading it from raw also works when epochs_train is a list of Epochs.
        preprocess = dict(sfreq=raw.info['sfreq'],
                          spatial=cfg.SP_FILTER,
                          spatial_ch=None,
                          spectral=cfg.TP_FILTER[cfg.TP_FILTER['selected']],
                          spectral_ch=None,
                          notch=cfg.NOTCH_FILTER[cfg.NOTCH_FILTER['selected']],
                          notch_ch=None,
                          multiplier=cfg.MULTIPLIER,
                          ch_names=None,
                          rereference=None,
                          decim=cfg.FEATURES['PSD']['decim'],
                          n_jobs=cfg.N_JOBS)
        featdata = get_psd_feature(epochs_train, cfg.EPOCH, cfg.FEATURES['PSD'],
                                   picks=None, preprocess=preprocess, n_jobs=cfg.N_JOBS)
    elif feature_type == 'TIMELAG':
        # TODO: Implement multiple epochs for timelag feature
        logger.error('MULTIPLE EPOCHS NOT IMPLEMENTED YET FOR TIMELAG FEATURE.')
        raise NotImplementedError
    elif feature_type == 'WAVELET':
        # TODO: Implement multiple epochs for wavelet feature
        logger.error('MULTIPLE EPOCHS NOT IMPLEMENTED YET FOR WAVELET FEATURE.')
        raise NotImplementedError
    else:
        logger.error('%s feature type is not supported.' % feature_type)
        raise NotImplementedError

    featdata['picks'] = picks
    featdata['sfreq'] = raw.info['sfreq']
    featdata['ch_names'] = raw.ch_names
    return featdata
mne.set_log_level('ERROR')

if __name__ == '__main__':
    # Collect every .fif recording in DATA_PATH.
    rawlist = [f for f in qc.get_file_list(DATA_PATH, fullpath=True) if f[-4:] == '.fif']
    if len(rawlist) == 0:
        raise RuntimeError('No fif files found in the path.')

    # make output directory
    out_path = DATA_PATH + '/epochs'
    qc.make_dirs(out_path)

    # load data
    raw, events = pu.load_multi(rawlist, multiplier=MULTIPLIER)
    raw.pick_types(meg=False, eeg=True, stim=False)
    sfreq = raw.info['sfreq']
    if REF_CH_NEW is not None:
        pu.rereference(raw, REF_CH_NEW, REF_CH_OLD)

    # pick channels
    if CHANNEL_PICKS is None:
        # every remaining channel except the ones listed in EXCLUDES
        picks = [i for i, c in enumerate(raw.ch_names) if c not in EXCLUDES]
    elif isinstance(CHANNEL_PICKS[0], str):
        picks = [raw.ch_names.index(c) for c in CHANNEL_PICKS]
    elif isinstance(CHANNEL_PICKS[0], int):
        picks = CHANNEL_PICKS
    else:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise TypeError('CHANNEL_PICKS must contain channel names or indices, not %s'
                        % type(CHANNEL_PICKS[0]))
    # do epoching
    # NOTE(review): no epoching code follows in this script -- confirm it is complete.
def get_tfr(cfg, recursive=False, n_jobs=1):
    '''
    Compute time-frequency representations (TFR) of epoched data and export them.

    @params:
    cfg: config object, validated by check_config(). cfg.TFR_TYPE selects the method:
         'multitaper' or 'morlet' (mne.time_frequency functions), or 'butter'
         (Butterworth band-pass + Hilbert envelope, only with cfg.POWER_AVERAGED).
    recursive: if True, load raw files in sub-dirs recursively (cfg.DATA_PATHS mode)
    n_jobs: ignored; the number of parallel jobs always comes from cfg.N_JOBS
            (None -> all CPU cores). Kept for backward compatibility.

    Exports into cfg.EXPORT_PATH (or, when it is None, the first element returned by
    qc.parse_path_list(cfg.DATA_FILE), presumably the data file's directory):
    - .mat files with the TFR data if cfg.EXPORT_MATLAB is True
    - one .png plot per picked channel if cfg.EXPORT_PNG is True
    - log-magnitude power-difference .jpg plots of the first two event types
      if cfg has a POWER_DIFF attribute

    Raises:
    NotImplementedError: TFR_TYPE 'fir', or 'butter' without POWER_AVERAGED
    ValueError: unknown TFR_TYPE, missing EXPORT_PATH for multiple directories,
                no desired event in the data, or an invalid epoch range
    RuntimeError: a picked channel index is out of range
    '''
    cfg = check_config(cfg)
    tfr_type = cfg.TFR_TYPE
    export_path = cfg.EXPORT_PATH
    t_buffer = cfg.T_BUFFER
    if tfr_type == 'multitaper':
        tfr = mne.time_frequency.tfr_multitaper
    elif tfr_type == 'morlet':
        tfr = mne.time_frequency.tfr_morlet
    elif tfr_type == 'butter':
        if not cfg.POWER_AVERAGED:
            # The per-event branch needs an MNE-style tfr(epochs, freqs=..., ...) callable.
            raise NotImplementedError('TFR type butter is only supported with POWER_AVERAGED=True')
        butter_order = 4  # TODO: parameterize
        tfr = None  # the Butterworth path filters the epoch data directly below
    elif tfr_type == 'fir':
        raise NotImplementedError
    else:
        raise ValueError('Wrong TFR type %s' % tfr_type)

    # NOTE: the n_jobs argument is overridden by cfg.N_JOBS.
    n_jobs = cfg.N_JOBS
    if n_jobs is None:
        n_jobs = mp.cpu_count()

    if hasattr(cfg, 'DATA_PATHS'):
        if export_path is None:
            raise ValueError('For multiple directories, cfg.EXPORT_PATH cannot be None')
        outpath = export_path
        file_prefix = 'grandavg'
        # load and merge files from all directories
        flist = []
        for ddir in cfg.DATA_PATHS:
            ddir = ddir.replace('\\', '/')
            if ddir[-1] != '/':
                ddir += '/'
            for f in qc.get_file_list(ddir, fullpath=True, recursive=recursive):
                if qc.parse_path(f).ext in ['fif', 'bdf', 'gdf']:
                    flist.append(f)
        raw, events = pu.load_multi(flist)
    else:
        logger.info('Loading %s' % cfg.DATA_FILE)
        raw, events = pu.load_raw(cfg.DATA_FILE)
        if export_path is None:
            [outpath, file_prefix, _] = qc.parse_path_list(cfg.DATA_FILE)
        else:
            outpath = export_path
            file_prefix = qc.parse_path(cfg.DATA_FILE).name

    # Custom event file. Read AFTER loading the data so it replaces the events
    # returned by the loaders (before, load_multi overwrote it in DATA_PATHS mode).
    if hasattr(cfg, 'EVENT_FILE') and cfg.EVENT_FILE is not None:
        events = mne.read_events(cfg.EVENT_FILE)

    # re-referencing
    if cfg.REREFERENCE is not None:
        pu.rereference(raw, cfg.REREFERENCE[1], cfg.REREFERENCE[0])
        assert cfg.REREFERENCE[0] in raw.ch_names
    sfreq = raw.info['sfreq']

    # set channels of interest
    picks = pu.channel_names_to_index(raw, cfg.CHANNEL_PICKS)
    spchannels = pu.channel_names_to_index(raw, cfg.SP_CHANNELS)
    n_channels = len(raw.info['ch_names'])
    # Valid indices are 0 .. n_channels - 1.
    if max(picks) >= n_channels:
        msg = 'ERROR: "picks" has a channel index %d while there are only %d channels.' %\
              (max(picks), n_channels)
        raise RuntimeError(msg)

    # Apply filters
    raw = pu.preprocess(raw, spatial=cfg.SP_FILTER, spatial_ch=spchannels,
                        spectral=cfg.TP_FILTER, spectral_ch=picks,
                        notch=cfg.NOTCH_FILTER, notch_ch=picks,
                        multiplier=cfg.MULTIPLIER, n_jobs=n_jobs)

    # Read epochs: keep only the requested triggers that occur in the data.
    event_values = set(events[:, -1])
    classes = {}
    for t in cfg.TRIGGERS:
        if t in event_values:
            if hasattr(cfg, 'tdef'):
                classes[cfg.tdef.by_value[t]] = t
            else:
                classes[str(t)] = t
    if len(classes) == 0:
        raise ValueError('No desired event was found from the data.')

    # Epoch range. Validation errors are raised directly instead of being swallowed.
    tmin = cfg.EPOCH[0]
    tmin_buffer = tmin - t_buffer
    raw_length = raw._data.shape[1] / sfreq
    raw_tmax = raw_length - 0.1
    if cfg.EPOCH[1] is None:
        if cfg.POWER_AVERAGED:
            raise ValueError('EPOCH value cannot have None for grand averaged TFR')
        if len(cfg.TRIGGERS) > 1:
            raise ValueError(
                'If the end time of EPOCH is None, only a single event can be defined.')
        # Onset (s) of the first occurrence of the single trigger.
        t_ref = events[np.where(events[:, 2] == list(cfg.TRIGGERS)[0])[0][0], 0] / sfreq
        tmax = raw_tmax - t_ref - t_buffer
    else:
        tmax = cfg.EPOCH[1]
    tmax_buffer = tmax + t_buffer
    if tmax_buffer > raw_tmax:
        raise ValueError(
            'Epoch length with buffer (%.3f) is larger than signal length (%.3f)' %
            (tmax_buffer, raw_tmax))

    try:
        epochs_all = mne.Epochs(raw, events, classes, tmin=tmin_buffer, tmax=tmax_buffer,
                                proj=False, picks=picks, baseline=None, preload=True)
    except Exception:
        # Log context and propagate; continuing would only fail later on the
        # undefined epochs_all.
        logger.critical('\n*** (tfr_export) Unknown error occurred while epoching ***')
        logger.critical('tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
            (tmin, tmax, tmin_buffer, tmax_buffer, raw_length))
        raise
    if epochs_all.drop_log_stats() > 0:
        logger.error('\n** Bad epochs found. Dropping into a Python shell.')
        logger.error(epochs_all.drop_log)
        logger.error('tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
            (tmin, tmax, tmin_buffer, tmax_buffer, raw_length))
        logger.error('\nType exit to continue.\n')
        pdb.set_trace()

    freqs = cfg.FREQ_RANGE  # define frequencies of interest
    n_cycles = freqs / 2.  # different number of cycle per frequency
    export_dir = outpath
    qc.make_dirs(export_dir)
    power = {}
    for evname in classes:
        logger.info('>> Processing %s' % evname)
        if cfg.POWER_AVERAGED:
            # grand-average TFR
            epochs = epochs_all[evname][:]
            if len(epochs) == 0:
                logger.warning('No %s epochs. Skipping.' % evname)
                continue
            if tfr_type == 'butter':
                b, a = butter_bandpass(cfg.FREQ_RANGE[0], cfg.FREQ_RANGE[-1], sfreq,
                                       order=butter_order)
                # scipy filters operate on arrays: (epochs, channels, times) -> axis=2
                tfr_filtered = lfilter(b, a, epochs.get_data(), axis=2)
                tfr_hilbert = hilbert(tfr_filtered)
                tfr_power = abs(tfr_hilbert)
                tfr_data = np.mean(tfr_power, axis=0)
            else:
                power[evname] = tfr(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=False,
                                    return_itc=False, decim=1, n_jobs=n_jobs)
                # drop the t_buffer margins
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                tfr_data = power[evname].data

            if cfg.EXPORT_MATLAB is True:
                # export all channels to MATLAB
                mout = '%s/%s-%s-%s.mat' % (export_dir, file_prefix, cfg.SP_FILTER, evname)
                scipy.io.savemat(mout, {'tfr': tfr_data, 'chs': epochs.ch_names,
                                        'events': events, 'sfreq': sfreq, 'tmin': tmin,
                                        'tmax': tmax, 'epochs': cfg.EPOCH,
                                        'freqs': cfg.FREQ_RANGE})
                logger.info('Exported %s' % mout)
            if cfg.EXPORT_PNG is True:
                if evname not in power:
                    # Plots need an MNE TFR object, which the butter path does not create.
                    logger.warning('PNG export is not supported for TFR type %s. Skipping.'
                                   % tfr_type)
                    continue
                # Inspect power for each channel
                for ch, pick in enumerate(picks):
                    chname = raw.ch_names[pick]
                    title = 'Peri-event %s - Channel %s' % (evname, chname)
                    # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                    fig = power[evname].plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE,
                                             show=False, colorbar=True, title=title,
                                             vmin=cfg.VMIN, vmax=cfg.VMAX, dB=False)
                    fout = '%s/%s-%s-%s-%s.png' % (export_dir, file_prefix, cfg.SP_FILTER,
                                                   evname, chname)
                    fig.savefig(fout)
                    plt.close(fig)
                    logger.info('Exported to %s' % fout)
        else:
            # TFR per event
            # NOTE(review): power[evname] keeps only the LAST trial, which is what
            # POWER_DIFF below uses in this mode -- confirm this is intended.
            ev_epochs = epochs_all[evname]
            for ep in range(len(ev_epochs)):
                epochs = ev_epochs[ep]
                if len(epochs) == 0:
                    logger.warning('No %s epochs. Skipping.' % evname)
                    continue
                power[evname] = tfr(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=False,
                                    return_itc=False, decim=1, n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)

                if cfg.EXPORT_MATLAB is True:
                    # export all channels to MATLAB
                    mout = '%s/%s-%s-%s-ep%02d.mat' % (export_dir, file_prefix,
                                                       cfg.SP_FILTER, evname, ep + 1)
                    scipy.io.savemat(mout, {'tfr': power[evname].data,
                                            'chs': power[evname].ch_names,
                                            'events': events, 'sfreq': sfreq,
                                            'tmin': tmin, 'tmax': tmax,
                                            'epochs': cfg.EPOCH, 'freqs': cfg.FREQ_RANGE})
                    logger.info('Exported %s' % mout)
                if cfg.EXPORT_PNG is True:
                    # Inspect power for each channel
                    for ch, pick in enumerate(picks):
                        chname = raw.ch_names[pick]
                        title = 'Peri-event %s - Channel %s, Trial %d' % (evname, chname, ep + 1)
                        # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                        fig = power[evname].plot([ch], baseline=cfg.BS_TIMES,
                                                 mode=cfg.BS_MODE, show=False,
                                                 colorbar=True, title=title,
                                                 vmin=cfg.VMIN, vmax=cfg.VMAX, dB=False)
                        fout = '%s/%s-%s-%s-%s-ep%02d.png' % (export_dir, file_prefix,
                                                              cfg.SP_FILTER, evname,
                                                              chname, ep + 1)
                        fig.savefig(fout)
                        plt.close(fig)
                        logger.info('Exported %s' % fout)

    if hasattr(cfg, 'POWER_DIFF'):
        # dict views are not indexable on Python 3
        labels = list(classes.keys())
        if len(labels) < 2 or labels[0] not in power or labels[1] not in power:
            logger.warning('POWER_DIFF needs TFR power of at least two event types. Skipping.')
        else:
            export_dir = '%s/diff' % outpath
            qc.make_dirs(export_dir)
            df = power[labels[0]] - power[labels[1]]
            df.data = np.log(np.abs(df.data))
            # Inspect power diff for each channel
            for ch, pick in enumerate(picks):
                chname = raw.ch_names[pick]
                title = 'Peri-event %s-%s - Channel %s' % (labels[0], labels[1], chname)
                # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                # symmetric colour range; vmin must be below vmax
                fig = df.plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                              colorbar=True, title=title, vmin=-3.0, vmax=3.0, dB=False)
                fout = '%s/%s-%s-diff-%s-%s-%s.jpg' % (export_dir, file_prefix,
                                                       cfg.SP_FILTER, labels[0],
                                                       labels[1], chname)
                logger.info('Exporting to %s' % fout)
                fig.savefig(fout)
                plt.close(fig)
    logger.info('Finished !')