def test_io():
    """Test IO functionality."""
    event_id = None
    tmin, tmax = -0.2, 0.5
    events = mne.find_events(raw)

    savedir = _TempDir()
    fname = op.join(savedir, 'autoreject.hdf5')

    include = [u'EEG %03d' % i for i in range(1, 45, 3)]
    picks = mne.pick_types(raw.info, meg=False, eeg=False, stim=False,
                           eog=True, include=include, exclude=[])
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                        picks=picks, baseline=(None, 0), decim=4,
                        reject=None, preload=True)[:10]
    ar = AutoReject(cv=2, random_state=42, n_interpolate=[1],
                    consensus=[0.5], verbose=False)
    ar.save(fname)  # save without fitting

    # check that fit after saving is the same as fit without saving
    ar2 = read_auto_reject(fname)
    ar.fit(epochs)
    ar2.fit(epochs)
    assert np.sum([ar.threshes_[k] - ar2.threshes_[k]
                   for k in ar.threshes_.keys()]) == 0.

    pytest.raises(ValueError, ar.save, fname)
    ar.save(fname, overwrite=True)
    ar3 = read_auto_reject(fname)
    epochs_clean1, reject_log1 = ar.transform(epochs, return_log=True)
    epochs_clean2, reject_log2 = ar3.transform(epochs, return_log=True)
    assert_array_equal(epochs_clean1.get_data(), epochs_clean2.get_data())
    assert_array_equal(reject_log1.labels, reject_log2.labels)
def autore(epo_eeg_cust):
    """Correct or reject artifactual epochs with autoreject.

    Parameters
    ----------
    epo_eeg_cust : mne.Epochs
        Epochs data.

    Returns
    -------
    clean : mne.Epochs
        Artifact-free epochs data.
    """
    ar = AutoReject(n_jobs=4)
    ar.fit(epo_eeg_cust)
    epo_ar, reject_log = ar.transform(epo_eeg_cust, return_log=True)
    clean = epo_ar.copy()

    # Used for plotting
    # scalings = dict(eeg=50)
    # reject_log.plot_epochs(epo_eeg_cust, scalings=scalings)
    # epo_ar.average().plot()
    return clean
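
# A minimal usage sketch (hypothetical): `epo_eeg` stands in for any
# preloaded mne.Epochs object produced earlier in the pipeline.
# epo_clean = autore(epo_eeg)
# print('Kept %d of %d epochs' % (len(epo_clean), len(epo_eeg)))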
def main():

    #################################################
    ## SETUP

    ## Get list of subject files
    subj_files = listdir(DAT_PATH)
    subj_files = [file for file in subj_files if EXT.lower() in file.lower()]

    ## Set up FOOOF Objects
    # Initialize FOOOF settings & objects
    fooof_settings = FOOOFSettings(peak_width_limits=PEAK_WIDTH_LIMITS,
                                   max_n_peaks=MAX_N_PEAKS,
                                   min_peak_amplitude=MIN_PEAK_AMP,
                                   peak_threshold=PEAK_THRESHOLD,
                                   aperiodic_mode=APERIODIC_MODE)
    fm = FOOOF(*fooof_settings, verbose=False)
    fg = FOOOFGroup(*fooof_settings, verbose=False)

    # Save out a settings file
    fg.save('0-FOOOF_Settings', pjoin(RES_PATH, 'FOOOF'), save_settings=True)

    # Set up the dictionary to store all the FOOOF results
    fg_dict = dict()
    for load_label in LOAD_LABELS:
        fg_dict[load_label] = dict()
        for side_label in SIDE_LABELS:
            fg_dict[load_label][side_label] = dict()
            for seg_label in SEG_LABELS:
                fg_dict[load_label][side_label][seg_label] = []

    ## Initialize group level data stores
    n_subjs, n_conds, n_times = len(subj_files), 3, N_TIMES
    group_fooofed_alpha_freqs = np.zeros(shape=[n_subjs])
    dropped_components = np.ones(shape=[n_subjs, 50]) * 999
    dropped_trials = np.ones(shape=[n_subjs, 1500]) * 999
    canonical_group_avg_dat = np.zeros(shape=[n_subjs, n_conds, n_times])
    fooofed_group_avg_dat = np.zeros(shape=[n_subjs, n_conds, n_times])

    # Set channel types
    ch_types = {'LHor': 'eog', 'RHor': 'eog', 'IVer': 'eog', 'SVer': 'eog',
                'LMas': 'misc', 'RMas': 'misc', 'Nose': 'misc', 'EXG8': 'misc'}

    #################################################
    ## RUN ACROSS ALL SUBJECTS

    # Run analysis across each subject
    for s_ind, subj_file in enumerate(subj_files):

        # Get subject label and print status
        subj_label = subj_file.split('.')[0]
        print('\nCURRENTLY RUNNING SUBJECT: ', subj_label, '\n')

        #################################################
        ## LOAD / ORGANIZE / SET-UP DATA

        # Load subject data, applying fixes for channels, etc.
        eeg_dat = mne.io.read_raw_edf(pjoin(DAT_PATH, subj_file),
                                      preload=True, verbose=False)

        # Fix channel name labels
        eeg_dat.info['ch_names'] = [chl[2:] for chl in
                                    eeg_dat.ch_names[:-1]] + \
                                   [eeg_dat.ch_names[-1]]
        for ind, chi in enumerate(eeg_dat.info['chs']):
            eeg_dat.info['chs'][ind]['ch_name'] = eeg_dat.info['ch_names'][ind]

        # Update channel types
        eeg_dat.set_channel_types(ch_types)

        # Set reference - average reference
        eeg_dat = eeg_dat.set_eeg_reference(ref_channels='average',
                                            projection=False, verbose=False)

        # Set channel montage
        chs = mne.channels.read_montage('standard_1020', eeg_dat.ch_names)
        eeg_dat.set_montage(chs)

        # Get event information & check all used event codes
        evs = mne.find_events(eeg_dat, shortest_event=1, verbose=False)

        # Pull out sampling rate
        srate = eeg_dat.info['sfreq']

        #################################################
        ## Pre-Processing: ICA

        # High-pass filter data for running ICA
        eeg_dat.filter(l_freq=1., h_freq=None, fir_design='firwin')

        if RUN_ICA:

            print("\nICA: CALCULATING SOLUTION\n")

            # ICA settings
            method = 'fastica'
            n_components = 0.99
            random_state = 47
            reject = {'eeg': 20e-4}

            # Initialize ICA object
            ica = ICA(n_components=n_components, method=method,
                      random_state=random_state)

            # Fit ICA
            ica.fit(eeg_dat, reject=reject)

            # Save out ICA solution
            ica.save(pjoin(RES_PATH, 'ICA', subj_label + '-ica.fif'))

        # Otherwise: load previously saved ICA to apply
        else:
            print("\nICA: USING PRECOMPUTED\n")
            ica = read_ica(pjoin(RES_PATH, 'ICA', subj_label + '-ica.fif'))

        # Find components to drop, based on correlation with EOG channels
        drop_inds = []
        for chi in EOG_CHS:
            inds, _ = ica.find_bads_eog(eeg_dat, ch_name=chi, threshold=2.5,
                                        l_freq=1, h_freq=10, verbose=False)
            drop_inds.extend(inds)
        drop_inds = list(set(drop_inds))

        # Set which components to drop, and collect record of this
        ica.exclude = drop_inds
        dropped_components[s_ind, 0:len(drop_inds)] = drop_inds

        # Apply ICA to data
        eeg_dat = ica.apply(eeg_dat)

        #################################################
        ## SORT OUT EVENT CODES

        # Extract a list of all the event labels
        all_trials = [it for it2 in EV_DICT.values() for it in it2]

        # Create list of new event codes to be used to label correct trials (300s)
        all_trials_new = [it + 100 for it in all_trials]
        # This is an annoying way to collapse across the doubled event markers from above
        all_trials_new = [it - 1 if not ind % 2 == 0 else it
                          for ind, it in enumerate(all_trials_new)]

        # Get labelled dictionary of new event names
        ev_dict2 = {k: v for k, v in zip(EV_DICT.keys(), set(all_trials_new))}

        # Initialize variables to store new event definitions
        evs2 = np.empty(shape=[0, 3], dtype='int64')
        lags = np.array([])

        # Loop through, creating new events for all correct trials
        t_min, t_max = -0.4, 3.0
        for ref_id, targ_id, new_id in zip(all_trials, CORR_CODES * 6,
                                           all_trials_new):
            t_evs, t_lags = mne.event.define_target_events(
                evs, ref_id, targ_id, srate, t_min, t_max, new_id)
            if len(t_evs) > 0:
                evs2 = np.vstack([evs2, t_evs])
                lags = np.concatenate([lags, t_lags])

        #################################################
        ## FOOOF

        # Set channel of interest
        ch_ind = eeg_dat.ch_names.index(CHL)

        # Calculate PSDs over ~ first 2 minutes of data
        fmin, fmax = 1, 50
        tmin, tmax = 5, 125
        psds, freqs = mne.time_frequency.psd_welch(
            eeg_dat, fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax,
            n_fft=int(2 * srate), n_overlap=int(srate),
            n_per_seg=int(2 * srate), verbose=False)

        # Fit FOOOF across all channels
        fg.fit(freqs, psds, FREQ_RANGE, n_jobs=-1)

        # Save out FOOOF results
        fg.save(subj_label + '_fooof', pjoin(RES_PATH, 'FOOOF'),
                save_results=True)

        # Extract individualized CF from specified channel, add to group collection
        fm = fg.get_fooof(ch_ind, False)
        fooof_freq, _, _ = get_band_peak(fm.peak_params_, [7, 14])
        group_fooofed_alpha_freqs[s_ind] = fooof_freq

        # If no alpha peak was extracted by FOOOF, reset to 10
        if np.isnan(fooof_freq):
            fooof_freq = 10

        #################################################
        ## ALPHA FILTERING

        # CANONICAL: Filter data to canonical alpha band: 8-12 Hz
        alpha_dat = eeg_dat.copy()
        alpha_dat.filter(8, 12, fir_design='firwin', verbose=False)
        alpha_dat.apply_hilbert(envelope=True, verbose=False)

        # FOOOF: Filter data to FOOOF derived alpha band
        fooof_dat = eeg_dat.copy()
        fooof_dat.filter(fooof_freq - 2, fooof_freq + 2, fir_design='firwin')
        fooof_dat.apply_hilbert(envelope=True)

        #################################################
        ## EPOCH TRIALS

        # Set epoch timings
        tmin, tmax = -0.85, 1.1

        # Epoch trials - raw data for trial rejection
        epochs = mne.Epochs(eeg_dat, evs2, ev_dict2, tmin=tmin, tmax=tmax,
                            baseline=None, preload=True, verbose=False)

        # Epoch trials - filtered versions
        epochs_alpha = mne.Epochs(alpha_dat, evs2, ev_dict2,
                                  tmin=tmin, tmax=tmax,
                                  baseline=(-0.5, -0.35),
                                  preload=True, verbose=False)
        epochs_fooof = mne.Epochs(fooof_dat, evs2, ev_dict2,
                                  tmin=tmin, tmax=tmax,
                                  baseline=(-0.5, -0.35),
                                  preload=True, verbose=False)

        #################################################
        ## PRE-PROCESSING: AUTO-REJECT

        if RUN_AUTOREJECT:

            print('\nAUTOREJECT: CALCULATING SOLUTION\n')

            # Initialize and run autoreject across epochs
            ar = AutoReject(n_jobs=4, verbose=False)
            ar.fit(epochs)

            # Save out AR solution
            ar.save(pjoin(RES_PATH, 'AR', subj_label + '-ar.hdf5'),
                    overwrite=True)

        # Otherwise: load & apply previously saved AR solution
        else:
            print('\nAUTOREJECT: USING PRECOMPUTED\n')
            ar = read_auto_reject(pjoin(RES_PATH, 'AR',
                                        subj_label + '-ar.hdf5'))
            ar.verbose = 'tqdm'

        # Apply autoreject to the original epochs object it was learnt on
        epochs, rej_log = ar.transform(epochs, return_log=True)

        # Apply autoreject to the copies of the data:
        # apply interpolation, then drop the same epochs
        _apply_interp(rej_log, epochs_alpha, ar.threshes_, ar.picks_,
                      ar.verbose)
        epochs_alpha.drop(rej_log.bad_epochs)
        _apply_interp(rej_log, epochs_fooof, ar.threshes_, ar.picks_,
                      ar.verbose)
        epochs_fooof.drop(rej_log.bad_epochs)

        # Collect which epochs were dropped
        dropped_trials[s_ind, 0:sum(rej_log.bad_epochs)] = \
            np.where(rej_log.bad_epochs)[0]

        #################################################
        ## SET UP CHANNEL CLUSTERS

        # Set channel clusters - these channels will be used to extract
        # data contralateral to stimulus presentation
        le_chs = ['P3', 'P5', 'P7', 'P9', 'O1', 'PO3', 'PO7']   # Left side
        le_inds = [epochs.ch_names.index(chn) for chn in le_chs]
        ri_chs = ['P4', 'P6', 'P8', 'P10', 'O2', 'PO4', 'PO8']  # Right side
        ri_inds = [epochs.ch_names.index(chn) for chn in ri_chs]

        #################################################
        ## TRIAL-RELATED ANALYSIS: CANONICAL vs. FOOOF

        ## Pull out channels of interest for each load level
        # Channels extracted are those contralateral to stimulus presentation

        # Canonical data
        lo1_a = np.concatenate([epochs_alpha['LeLo1']._data[:, ri_inds, :],
                                epochs_alpha['RiLo1']._data[:, le_inds, :]], 0)
        lo2_a = np.concatenate([epochs_alpha['LeLo2']._data[:, ri_inds, :],
                                epochs_alpha['RiLo2']._data[:, le_inds, :]], 0)
        lo3_a = np.concatenate([epochs_alpha['LeLo3']._data[:, ri_inds, :],
                                epochs_alpha['RiLo3']._data[:, le_inds, :]], 0)

        # FOOOFed data
        lo1_f = np.concatenate([epochs_fooof['LeLo1']._data[:, ri_inds, :],
                                epochs_fooof['RiLo1']._data[:, le_inds, :]], 0)
        lo2_f = np.concatenate([epochs_fooof['LeLo2']._data[:, ri_inds, :],
                                epochs_fooof['RiLo2']._data[:, le_inds, :]], 0)
        lo3_f = np.concatenate([epochs_fooof['LeLo3']._data[:, ri_inds, :],
                                epochs_fooof['RiLo3']._data[:, le_inds, :]], 0)

        ## Calculate average across trials and channels - add to group data

        # Canonical data
        canonical_group_avg_dat[s_ind, 0, :] = np.mean(lo1_a, 1).mean(0)
        canonical_group_avg_dat[s_ind, 1, :] = np.mean(lo2_a, 1).mean(0)
        canonical_group_avg_dat[s_ind, 2, :] = np.mean(lo3_a, 1).mean(0)

        # FOOOFed data
        fooofed_group_avg_dat[s_ind, 0, :] = np.mean(lo1_f, 1).mean(0)
        fooofed_group_avg_dat[s_ind, 1, :] = np.mean(lo2_f, 1).mean(0)
        fooofed_group_avg_dat[s_ind, 2, :] = np.mean(lo3_f, 1).mean(0)

        #################################################
        ## FOOOFING TRIAL AVERAGED DATA

        # Loop over loads & trial segments
        for seg_label, seg_time in zip(SEG_LABELS, SEG_TIMES):
            tmin, tmax = seg_time[0], seg_time[1]

            # Calculate PSDs across trials, fit FOOOF models to averages
            for le_label, ri_label, load_label in zip(
                    ['LeLo1', 'LeLo2', 'LeLo3'],
                    ['RiLo1', 'RiLo2', 'RiLo3'],
                    LOAD_LABELS):

                ## Calculate trial-wise PSDs for left & right side trials
                trial_freqs, le_trial_psds = periodogram(
                    epochs[le_label]._data[:, :, _time_mask(
                        epochs.times, tmin, tmax, srate)],
                    srate, window='hann', nfft=4 * srate)
                trial_freqs, ri_trial_psds = periodogram(
                    epochs[ri_label]._data[:, :, _time_mask(
                        epochs.times, tmin, tmax, srate)],
                    srate, window='hann', nfft=4 * srate)

                ## FIT ALL CHANNELS VERSION
                if FIT_ALL_CHANNELS:

                    ## Average spectra across trials within a given load & side
                    le_avg_psd_contra = avg_func(le_trial_psds[:, ri_inds, :], 0)
                    le_avg_psd_ipsi = avg_func(le_trial_psds[:, le_inds, :], 0)
                    ri_avg_psd_contra = avg_func(ri_trial_psds[:, le_inds, :], 0)
                    ri_avg_psd_ipsi = avg_func(ri_trial_psds[:, ri_inds, :], 0)

                    ## Combine spectra across left & right trials for given load
                    ch_psd_contra = np.vstack([le_avg_psd_contra,
                                               ri_avg_psd_contra])
                    ch_psd_ipsi = np.vstack([le_avg_psd_ipsi,
                                             ri_avg_psd_ipsi])

                    ## Fit FOOOFGroup to all channels, average and collect results
                    fg.fit(trial_freqs, ch_psd_contra, FREQ_RANGE)
                    fm = avg_fg(fg)
                    fg_dict[load_label]['Contra'][seg_label].append(fm.copy())
                    fg.fit(trial_freqs, ch_psd_ipsi, FREQ_RANGE)
                    fm = avg_fg(fg)
                    fg_dict[load_label]['Ipsi'][seg_label].append(fm.copy())

                ## COLLAPSE ACROSS CHANNELS VERSION
                else:

                    ## Average spectra across trials and channels within a
                    ## given load & side
                    le_avg_psd_contra = avg_func(
                        avg_func(le_trial_psds[:, ri_inds, :], 0), 0)
                    le_avg_psd_ipsi = avg_func(
                        avg_func(le_trial_psds[:, le_inds, :], 0), 0)
                    ri_avg_psd_contra = avg_func(
                        avg_func(ri_trial_psds[:, le_inds, :], 0), 0)
                    ri_avg_psd_ipsi = avg_func(
                        avg_func(ri_trial_psds[:, ri_inds, :], 0), 0)

                    ## Collapse spectra across left & right trials for given load
                    avg_psd_contra = avg_func(np.vstack([le_avg_psd_contra,
                                                         ri_avg_psd_contra]), 0)
                    avg_psd_ipsi = avg_func(np.vstack([le_avg_psd_ipsi,
                                                       ri_avg_psd_ipsi]), 0)

                    ## Fit FOOOF, and collect results
                    fm.fit(trial_freqs, avg_psd_contra, FREQ_RANGE)
                    fg_dict[load_label]['Contra'][seg_label].append(fm.copy())
                    fm.fit(trial_freqs, avg_psd_ipsi, FREQ_RANGE)
                    fg_dict[load_label]['Ipsi'][seg_label].append(fm.copy())

    #################################################
    ## SAVE OUT RESULTS

    # Save out group data
    np.save(pjoin(RES_PATH, 'Group', 'alpha_freqs_group'),
            group_fooofed_alpha_freqs)
    np.save(pjoin(RES_PATH, 'Group', 'canonical_group'),
            canonical_group_avg_dat)
    np.save(pjoin(RES_PATH, 'Group', 'fooofed_group'), fooofed_group_avg_dat)
    np.save(pjoin(RES_PATH, 'Group', 'dropped_trials'), dropped_trials)
    np.save(pjoin(RES_PATH, 'Group', 'dropped_components'), dropped_components)

    # Save out second round of FOOOFing
    for load_label in LOAD_LABELS:
        for side_label in SIDE_LABELS:
            for seg_label in SEG_LABELS:
                fg = combine_fooofs(fg_dict[load_label][side_label][seg_label])
                fg.save('Group_' + load_label + '_' + side_label + '_' + seg_label,
                        pjoin(RES_PATH, 'FOOOF'), save_results=True)
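
# A minimal entry-point sketch (an assumption; the original script may
# invoke main() differently):
if __name__ == '__main__':
    main()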
def run_autoreject(subject):
    """Interpolate bad epochs/sensors using Autoreject.

    Parameters
    ----------
    subject : str
        The participant reference.

    Notes
    -----
    Saves the resulting *-epo.fif file in the '5_autoreject' directory,
    along with .png files of the ERP difference and heatmap plots.

    References
    ----------
    [1] Mainak Jas, Denis Engemann, Federico Raimondo, Yousra Bekhti, and
        Alexandre Gramfort, "Automated rejection and repair of bad trials
        in MEG/EEG." In 6th International Workshop on Pattern Recognition
        in Neuroimaging (PRNI), 2016.
    [2] Mainak Jas, Denis Engemann, Yousra Bekhti, Federico Raimondo, and
        Alexandre Gramfort. 2017. "Autoreject: Automated artifact rejection
        for MEG and EEG data". NeuroImage, 159, 417-429.
    """
    # Import data
    input_path = root + '/4_ICA/' + subject + '-epo.fif'
    epochs = mne.read_epochs(input_path)

    # Autoreject: fit on the epochs and clean them in a single pass
    ar = AutoReject(random_state=42, n_jobs=4)
    epochs_clean = ar.fit_transform(epochs)

    # Plot difference
    evoked = epochs.average()
    evoked_clean = epochs_clean.average()

    fig, axes = plt.subplots(2, 1, figsize=(6, 6))
    for ax in axes:
        ax.tick_params(axis='x', which='both', bottom=False, top=False)
        ax.tick_params(axis='y', which='both', left=False, right=False)

    evoked.plot(exclude=[], axes=axes[0], ylim=dict(eeg=[-30, 30]),
                show=False)
    axes[0].set_title('Before autoreject')
    evoked_clean.plot(exclude=[], axes=axes[1], ylim=dict(eeg=[-30, 30]))
    axes[1].set_title('After autoreject')
    plt.tight_layout()
    plt.savefig(root + '/5_autoreject/' + subject + '-autoreject.png')
    plt.close()

    # Plot heatmap
    ar.get_reject_log(epochs).plot()
    plt.savefig(root + '/5_autoreject/' + subject + '-heatmap.png')
    plt.close()

    # Save epoch data
    out_epoch = root + '/5_autoreject/' + subject + '-epo.fif'
    epochs_clean.save(out_epoch)
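
# A hypothetical driver loop (assumption): `subjects` is an illustrative
# list of participant references matching the '4_ICA' file naming above.
# subjects = ['S01', 'S02']
# for subject in subjects:
#     run_autoreject(subject)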
def AR_local(cleaned_epochs_ICA, verbose=False):
    """Apply local Autoreject to correct or reject bad epochs.

    Arguments:
        cleaned_epochs_ICA: list of Epochs after global Autoreject and ICA.
        verbose: whether to plot data before and after AR, boolean,
          set to False by default.

    Returns:
        cleaned_epochs_AR: list of Epochs after local Autoreject.
    """
    bad_epochs_AR = []

    # default values for n_interpolates and consensus_percs
    n_interpolates = np.array([1, 4, 32])
    consensus_percs = np.linspace(0, 1.0, 11)

    for clean_epochs in cleaned_epochs_ICA:  # per subject
        picks = mne.pick_types(clean_epochs[0].info, meg=False, eeg=True,
                               stim=False, eog=False, exclude=[])
        ar_verbose = 'progressbar' if verbose else False
        ar = AutoReject(n_interpolates, consensus_percs, picks=picks,
                        thresh_method='random_search', random_state=42,
                        verbose=ar_verbose)
        # fitting AR to get bad epochs
        ar.fit(clean_epochs)
        reject_log = ar.get_reject_log(clean_epochs, picks=picks)
        bad_epochs_AR.append(reject_log)

    # taking the epochs marked bad for both subjects of the dyad
    log1 = bad_epochs_AR[0]
    log2 = bad_epochs_AR[1]
    bad1 = np.where(log1.bad_epochs)[0]
    bad2 = np.where(log2.bad_epochs)[0]
    bad = list(set(bad1.tolist()).intersection(bad2.tolist()))
    if verbose:
        print('%s percent of bad epochs'
              % int(len(bad) / len(list(log1.bad_epochs)) * 100))

    # picking good epochs for the two subjects
    cleaned_epochs_AR = []
    for clean_epochs in cleaned_epochs_ICA:  # per subject
        clean_epochs_ep = clean_epochs.drop(indices=bad)
        # interpolating bads or removing epochs
        clean_epochs_AR = ar.transform(clean_epochs_ep)
        cleaned_epochs_AR.append(clean_epochs_AR)

    # equalizing the number of epochs between the two subjects
    mne.epochs.equalize_epoch_counts(cleaned_epochs_AR)

    # visualization before/after AR
    evoked_before = [clean_epochs.average()
                     for clean_epochs in cleaned_epochs_ICA]
    evoked_after_AR = [clean.average() for clean in cleaned_epochs_AR]

    if verbose:
        for i, j in zip(evoked_before, evoked_after_AR):
            fig, axes = plt.subplots(2, 1, figsize=(6, 6))
            for ax in axes:
                ax.tick_params(axis='x', which='both', bottom=False, top=False)
                ax.tick_params(axis='y', which='both', left=False, right=False)
            ylim = dict(grad=(-170, 200))
            i.pick_types(eeg=True, exclude=[])
            i.plot(exclude=[], axes=axes[0], ylim=ylim, show=False)
            axes[0].set_title('Before autoreject')
            j.pick_types(eeg=True, exclude=[])
            j.plot(exclude=[], axes=axes[1], ylim=ylim)
            # Known issue: the title does not display for the second axis!
            axes[1].set_title('After autoreject')
            plt.tight_layout()
    return cleaned_epochs_AR
                    detrend=0, preload=True)

ar = AutoReject(n_interpolates, consensus_percs, picks=picks,
                thresh_method='random_search', random_state=42)

# Note that fitting and transforming can be done on different compatible
# portions of data if needed.
ar.fit(epochs)
# epochs_ar, reject_log = ar.fit_transform(epochs, return_log=True)
epochs_clean = ar.transform(epochs)
reject_log = ar.get_reject_log(epochs)

evoked_clean = epochs_clean.average()
evoked = epochs.average()

# visualize rejected epochs
scalings = dict(eeg=40e-6)
reject_log.plot_epochs(epochs, scalings=scalings)

# epochs after cleaning
epochs_clean.plot(scalings=scalings)

# notify when done
os.system('say "... I am ready for you Neil."')

# Visualize repaired data
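# A plotting sketch to finish the comparison above (assumes matplotlib is
# available; uses the `evoked` and `evoked_clean` objects computed earlier).
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 1, figsize=(6, 6))
evoked.plot(exclude=[], axes=axes[0], show=False)
axes[0].set_title('Before autoreject')
evoked_clean.plot(exclude=[], axes=axes[1], show=False)
axes[1].set_title('After autoreject')
plt.tight_layout()
plt.show()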
epochs.load_data()
epochs = epochs.drop(epochs_2_drop, reason="bad behaviour")
epochs.save(op.join(sub_path, "clean-" + epo.split(sep)[-1]), overwrite=True)

print("AMOUNT OF EPOCHS AFTER MATCHING WITH BEH:", len(epochs))
print("DOES IT MATCH?", len(beh_ixs) == len(epochs))
print("\n")

if len(beh_ixs) == len(epochs):
    ar = AutoReject(consensus=np.linspace(0, 1.0, 27),
                    n_interpolate=np.array([1, 4, 32]),
                    thresh_method="bayesian_optimization",
                    cv=10,
                    n_jobs=-1,
                    random_state=42,
                    verbose="progressbar")
    ar.fit(epochs)

    epo_type = epo.split(sep)[-1].split("-")[3]
    name = "{}-{}-{}".format(subject_id, numero, epo_type)
    ar_fname = op.join(qc_folder, "{}-autoreject.h5".format(name))
    ar.save(ar_fname, overwrite=True)
    epochs_ar, rej_log = ar.transform(epochs, return_log=True)

    rej_log.plot(show=False)
    plt.savefig(op.join(qc_folder, "{}-autoreject-log.png".format(name)))
    plt.close("all")

    cleaned = op.join(sub_path, "autoreject-" + epo.split(sep)[-1])
    epochs_ar.save(cleaned, overwrite=True)
    print("CLEANED EPOCHS SAVED:", cleaned)
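    # A later-session sketch (assumption): the solution saved to `ar_fname`
    # above can be reloaded and applied to compatible epochs without refitting.
    # from autoreject import read_auto_reject
    # ar_loaded = read_auto_reject(ar_fname)
    # epochs_ar2, rej_log2 = ar_loaded.transform(epochs, return_log=True)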
def test_autoreject():
    """Test basic _AutoReject functionality."""
    event_id = None
    tmin, tmax = -0.2, 0.5
    events = mne.find_events(raw)

    ##########################################################################
    # picking epochs
    include = [u'EEG %03d' % i for i in range(1, 45, 3)]
    picks = mne.pick_types(raw.info, meg=False, eeg=False, stim=False,
                           eog=True, include=include, exclude=[])
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                        picks=picks, baseline=(None, 0), decim=10,
                        reject=None, preload=False)[:10]

    ar = _AutoReject()
    assert_raises(ValueError, ar.fit, epochs)
    epochs.load_data()
    ar.fit(epochs)
    assert_true(len(ar.picks_) == len(picks) - 1)

    # epochs with no picks.
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                        baseline=(None, 0), decim=10,
                        reject=None, preload=True)[:20]
    # let's drop some channels to speed up
    pre_picks = mne.pick_types(epochs.info, meg=True, eeg=True)
    pre_picks = np.r_[
        mne.pick_types(epochs.info, meg='mag', eeg=False)[::15],
        mne.pick_types(epochs.info, meg='grad', eeg=False)[::60],
        mne.pick_types(epochs.info, meg=False, eeg=True)[::16],
        mne.pick_types(epochs.info, meg=False, eeg=False, eog=True)]
    pick_ch_names = [epochs.ch_names[pp] for pp in pre_picks]
    bad_ch_names = [epochs.ch_names[ix] for ix in range(len(epochs.ch_names))
                    if ix not in pre_picks]
    epochs_with_bads = epochs.copy()
    epochs_with_bads.info['bads'] = bad_ch_names
    epochs.pick_channels(pick_ch_names)

    epochs_fit = epochs[:12]  # make sure to use different size of epochs
    epochs_new = epochs[12:]
    epochs_with_bads_fit = epochs_with_bads[:12]

    X = epochs_fit.get_data()
    n_epochs, n_channels, n_times = X.shape
    X = X.reshape(n_epochs, -1)

    ar = _GlobalAutoReject()
    assert_raises(ValueError, ar.fit, X)
    ar = _GlobalAutoReject(n_channels=n_channels)
    assert_raises(ValueError, ar.fit, X)
    ar = _GlobalAutoReject(n_times=n_times)
    assert_raises(ValueError, ar.fit, X)
    ar_global = _GlobalAutoReject(
        n_channels=n_channels, n_times=n_times, thresh=40e-6)
    ar_global.fit(X)

    param_name = 'thresh'
    param_range = np.linspace(40e-6, 200e-6, 10)
    assert_raises(ValueError, validation_curve, X, None,
                  param_name, param_range)

    ##########################################################################
    # picking AutoReject
    picks = mne.pick_types(
        epochs.info, meg='mag', eeg=True, stim=False, eog=False,
        include=[], exclude=[])
    non_picks = mne.pick_types(
        epochs.info, meg='grad', eeg=False, stim=False, eog=False,
        include=[], exclude=[])
    ch_types = ['mag', 'eeg']

    ar = _AutoReject(picks=picks)  # XXX : why do we need this??

    ar = AutoReject(cv=3, picks=picks, random_state=42,
                    n_interpolate=[1, 2], consensus=[0.5, 1])
    assert_raises(AttributeError, ar.fit, X)
    assert_raises(ValueError, ar.transform, X)
    assert_raises(ValueError, ar.transform, epochs)

    ar.fit(epochs_fit)
    reject_log = ar.get_reject_log(epochs_fit)
    for ch_type in ch_types:
        # test that kappa & rho are selected
        assert_true(ar.n_interpolate_[ch_type] in ar.n_interpolate)
        assert_true(ar.consensus_[ch_type] in ar.consensus)
        assert_true(ar.n_interpolate_[ch_type] ==
                    ar.local_reject_[ch_type].n_interpolate_[ch_type])
        assert_true(ar.consensus_[ch_type] ==
                    ar.local_reject_[ch_type].consensus_[ch_type])

    # test complementarity of goods and bads
    assert_array_equal(len(reject_log.bad_epochs), len(epochs_fit))

    # test that transform does not change state of ar
    epochs_clean = ar.transform(epochs_fit)  # apply same data
    assert_true(repr(ar))
    assert_true(repr(ar.local_reject_))
    reject_log2 = ar.get_reject_log(epochs_fit)
    assert_array_equal(reject_log.labels, reject_log2.labels)
    assert_array_equal(reject_log.bad_epochs, reject_log2.bad_epochs)
    assert_array_equal(reject_log.ch_names, reject_log2.ch_names)

    epochs_new_clean = ar.transform(epochs_new)  # apply to new data
    reject_log_new = ar.get_reject_log(epochs_new)
    assert_array_equal(len(reject_log_new.bad_epochs), len(epochs_new))
    assert_true(len(reject_log_new.bad_epochs) != len(reject_log.bad_epochs))

    picks_by_type = _get_picks_by_type(epochs.info, ar.picks)

    # test correct entries in fix log
    assert_true(np.isnan(reject_log_new.labels[:, non_picks]).sum() > 0)
    assert_true(np.isnan(reject_log_new.labels[:, picks]).sum() == 0)
    assert_equal(reject_log_new.labels.shape,
                 (len(epochs_new), len(epochs_new.ch_names)))

    # test correct interpolations by type
    for ch_type, this_picks in picks_by_type:
        interp_counts = np.sum(
            reject_log_new.labels[:, this_picks] == 2, axis=1)
        labels = reject_log_new.labels.copy()
        not_this_picks = np.setdiff1d(np.arange(labels.shape[1]), this_picks)
        labels[:, not_this_picks] = np.nan
        interp_channels = _get_interp_chs(
            labels, reject_log.ch_names, this_picks)
        assert_array_equal(
            interp_counts, [len(cc) for cc in interp_channels])

    # check that the new data were actually modified
    assert_true(not np.array_equal(epochs_new_clean.get_data(),
                                   epochs_new.get_data()))

    # test that transform ignores bad channels
    epochs_with_bads_fit.pick_types(meg='mag', eeg=True, eog=True, exclude=[])
    ar_bads = AutoReject(cv=3, random_state=42,
                         n_interpolate=[1, 2], consensus=[0.5, 1])
    ar_bads.fit(epochs_with_bads_fit)
    epochs_with_bads_clean = ar_bads.transform(epochs_with_bads_fit)

    good_w_bads_ix = mne.pick_types(epochs_with_bads_clean.info,
                                    meg='mag', eeg=True, eog=True,
                                    exclude='bads')
    good_wo_bads_ix = mne.pick_types(epochs_clean.info,
                                     meg='mag', eeg=True, eog=True,
                                     exclude='bads')
    assert_array_equal(
        epochs_with_bads_clean.get_data()[:, good_w_bads_ix, :],
        epochs_clean.get_data()[:, good_wo_bads_ix, :])

    bad_ix = [epochs_with_bads_clean.ch_names.index(ch)
              for ch in epochs_with_bads_clean.info['bads']]
    epo_ix = ~ar_bads.get_reject_log(epochs_with_bads_fit).bad_epochs
    assert_array_equal(
        epochs_with_bads_clean.get_data()[:, bad_ix, :],
        epochs_with_bads_fit.get_data()[epo_ix, :, :][:, bad_ix, :])

    assert_equal(epochs_clean.ch_names, epochs_fit.ch_names)

    assert_true(isinstance(ar.threshes_, dict))
    assert_true(len(ar.picks) == len(picks))
    assert_true(len(ar.threshes_.keys()) == len(ar.picks))
    pick_eog = mne.pick_types(epochs.info, meg=False, eeg=False, eog=True)[0]
    assert_true(epochs.ch_names[pick_eog] not in ar.threshes_.keys())

    assert_raises(
        IndexError, ar.transform,
        epochs.copy().pick_channels(
            [epochs.ch_names[pp] for pp in picks[:3]]))

    epochs.load_data()
    assert_raises(ValueError, compute_thresholds, epochs, 'dfdfdf')
    index, ch_names = zip(*[(ii, epochs_fit.ch_names[pp])
                            for ii, pp in enumerate(picks)])
    threshes_a = compute_thresholds(
        epochs_fit, picks=picks, method='random_search')
    assert_equal(set(threshes_a.keys()), set(ch_names))
    threshes_b = compute_thresholds(
        epochs_fit, picks=picks, method='bayesian_optimization')
    assert_equal(set(threshes_b.keys()), set(ch_names))
picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False,
                       eog=False, include=[], exclude=[])

# Make epochs from the raw data
epochs = mne.Epochs(raw, picks=picks, events=events, event_id=event_id,
                    tmin=tmin, tmax=tmax, preload=True, reject=None)

# Set up AutoReject
ar = AutoReject(n_interpolates, consensus_percs,
                thresh_method='random_search', random_state=seed)

# Fit, i.e. calculate the AutoReject solution
ar.fit(epochs)
epochs_clean = ar.transform(epochs)  # Clean the epochs
epochs_clean.save(data_path + '%s-epo.fif' % subject)  # Save the epochs
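
# Optional sanity check (a sketch, not part of the original pipeline):
# read the saved file back and confirm the epoch count survived the round trip.
# epochs_check = mne.read_epochs(data_path + '%s-epo.fif' % subject)
# assert len(epochs_check) == len(epochs_clean)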
# %%
# Note that :class:`autoreject.AutoReject` by design supports multiple
# channel types. If no picks are passed, separate solutions will be computed
# for each channel type and internally combined. This then readily supports
# cleaning unseen epochs from the different channel types used during fit.
# Here we only use a subset of channels to save time.

ar = AutoReject(n_interpolates, consensus_percs, picks=picks,
                thresh_method='random_search', random_state=42)

# Note that fitting and transforming can be done on different compatible
# portions of data if needed.
ar.fit(epochs['Auditory/Left'])
epochs_clean = ar.transform(epochs['Auditory/Left'])
evoked_clean = epochs_clean.average()
evoked = epochs['Auditory/Left'].average()

# %%
# Now, we will manually mark the bad channels just for plotting.
evoked.info['bads'] = ['MEG 2443']
evoked_clean.info['bads'] = ['MEG 2443']

# %%
# Let us plot the results.
import matplotlib.pyplot as plt  # noqa
set_matplotlib_defaults(plt)
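
# %%
# The per-epoch decisions can also be inspected directly; a short sketch
# (assuming the fitted ``ar`` from above, not part of the published example):

reject_log = ar.get_reject_log(epochs['Auditory/Left'])
reject_log.plot()  # heatmap of good / interpolated / bad segments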
def run_epochs(subject, autoreject=True):
    raw_fname = op.join(meg_dir, subject,
                        f'{subject}_audvis-filt_raw_sss.fif')
    annot_fname = op.join(meg_dir, subject, f'{subject}_audvis-annot.fif')
    raw = mne.io.read_raw_fif(raw_fname, preload=False)
    annot = mne.read_annotations(annot_fname)
    raw.set_annotations(annot)

    if autoreject:
        epo_fname = op.join(meg_dir, subject,
                            f'{subject}_audvis-filt-sss-ar-epo.fif')
    else:
        epo_fname = op.join(meg_dir, subject,
                            f'{subject}_audvis-filt-sss-epo.fif')

    # ICA
    ica_fname = op.join(meg_dir, subject, f'{subject}_audvis-ica.fif')
    ica = mne.preprocessing.read_ica(ica_fname)

    try:  # ECG
        ecg_epochs = mne.preprocessing.create_ecg_epochs(
            raw, l_freq=10, h_freq=20, baseline=(None, None), preload=True)
        ecg_inds, scores_ecg = ica.find_bads_ecg(
            ecg_epochs, method='ctps', threshold='auto', verbose='INFO')
    except ValueError:
        # not found
        pass
    else:
        print(f'Found {len(ecg_inds)} ({ecg_inds}) ECG indices for {subject}')
        if len(ecg_inds) != 0:
            ica.exclude.extend(ecg_inds[:n_max_ecg])
            # for future inspection
            ecg_epochs.average().save(
                op.join(meg_dir, subject, f'{subject}_audvis-ecg-ave.fif'))
        del ecg_epochs, ecg_inds, scores_ecg  # release memory

    try:  # EOG
        eog_epochs = mne.preprocessing.create_eog_epochs(
            raw, baseline=(None, None), preload=True)
        eog_inds, scores_eog = ica.find_bads_eog(eog_epochs)
    except ValueError:
        # not found
        pass
    else:
        print(f'Found {len(eog_inds)} ({eog_inds}) EOG indices for {subject}')
        if len(eog_inds) != 0:
            ica.exclude.extend(eog_inds[:n_max_eog])
            # for future inspection
            eog_epochs.average().save(
                op.join(meg_dir, subject, f'{subject}_audvis-eog-ave.fif'))
        del eog_epochs, eog_inds, scores_eog  # release memory

    # applying ICA on Raw
    raw.load_data()
    ica.apply(raw)

    # extract events for epoching
    # modify stim_channel for your need
    events = mne.find_events(raw, stim_channel="STI 014")
    picks = mne.pick_types(raw.info, meg=True)
    epochs = mne.Epochs(
        raw, events=events, picks=picks, event_id=event_id,
        tmin=tmin, tmax=tmax, baseline=baseline,
        decim=4,  # raw sampling rate is 600 Hz, subsample to 150 Hz
        preload=True,  # for autoreject
        reject_tmax=reject_tmax, reject_by_annotation=True)
    del raw, annot

    # autoreject (local)
    if autoreject:
        # keep the bad sensors/channels because autoreject can repair them
        # via interpolation
        picks = mne.pick_types(epochs.info, meg=True, exclude=[])
        ar = AutoReject(picks=picks, n_jobs=n_jobs, verbose=False)
        print(f'Run autoreject (local) for {subject} (it takes a long time)')
        ar.fit(epochs)
        print(f'Drop bad epochs and interpolate bad sensors for {subject}')
        epochs = ar.transform(epochs)

    print(f'Dropped {round(epochs.drop_log_stats(), 2)}% epochs for {subject}')
    epochs.save(epo_fname, overwrite=True)
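
# A hypothetical driver (assumption): run all subjects, optionally in
# parallel with joblib; `subjects` is an illustrative list.
# from joblib import Parallel, delayed
# subjects = ['sub-01', 'sub-02']
# Parallel(n_jobs=2)(delayed(run_epochs)(subject) for subject in subjects)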
def AR_local(cleaned_epochs_ICA: list, strategy: str = 'union',
             threshold: float = 50.0, verbose: bool = False) -> list:
    """Apply local Autoreject to repair or reject bad epochs.

    Arguments:
        cleaned_epochs_ICA: list of Epochs after global Autoreject and ICA.
        strategy: more or less generous strategy to reject bad epochs,
          'union' or 'intersection'. 'union' immediately rejects epochs
          that are bad for subject 1 or subject 2, whereas 'intersection'
          rejects the epochs that are bad for both subjects, tries to repair
          the remaining bad epochs per subject, rejects the non-repairable
          ones per subject, and finally equalizes the number of epochs
          between subjects. Set to 'union' by default.
        threshold: percentage of removed epochs that is accepted. Above this
          threshold, the remaining data are considered too short a sample
          for further analyses. Set to 50.0 by default.
        verbose: option to plot data before and after AR, boolean, set to
          False by default.
          # use verbose = False until the next Autoreject update

    Note:
        To reject or repair epochs, parameters are more or less
        conservative, see
        http://autoreject.github.io/generated/autoreject.AutoReject.

    Returns:
        cleaned_epochs_AR: list of Epochs after local Autoreject.
        dic_AR: dictionary with the percentage of rejected epochs for each
          subject and for the dyad.
    """
    bad_epochs_AR = []
    AR = []
    dic_AR = {}
    dic_AR['strategy'] = strategy
    dic_AR['threshold'] = threshold

    # default values for n_interpolates and consensus_percs
    n_interpolates = np.array([1, 4, 32])
    consensus_percs = np.linspace(0, 1.0, 11)
    # more generous values
    # n_interpolates = np.array([16, 32, 64])
    # n_interpolates = np.array([1, 4, 8, 16, 32, 64])
    # consensus_percs = np.linspace(0.5, 1.0, 11)

    for clean_epochs in cleaned_epochs_ICA:  # per subject
        picks = mne.pick_types(
            clean_epochs[0].info, meg=False, eeg=True, stim=False,
            eog=False, exclude=[])
        ar = AutoReject(n_interpolates, consensus_percs, picks=picks,
                        thresh_method='random_search', random_state=42,
                        verbose='tqdm_notebook')
        AR.append(ar)
        # fitting AR to get bad epochs
        ar.fit(clean_epochs)
        reject_log = ar.get_reject_log(clean_epochs, picks=picks)
        bad_epochs_AR.append(reject_log)

    # collecting the bad epochs of the two subjects (dyad)
    log1 = bad_epochs_AR[0]
    log2 = bad_epochs_AR[1]
    bad1 = np.where(log1.bad_epochs)[0]
    bad2 = np.where(log2.bad_epochs)[0]

    if strategy == 'union':
        bad = list(set(bad1.tolist()).union(set(bad2.tolist())))
    elif strategy == 'intersection':
        bad = list(set(bad1.tolist()).intersection(set(bad2.tolist())))
    else:
        raise ValueError("strategy should be 'union' or 'intersection'")

    # storing the percentage of rejected epochs per subject
    dic_AR['S1'] = float(len(bad1) / len(cleaned_epochs_ICA[0]) * 100)
    dic_AR['S2'] = float(len(bad2) / len(cleaned_epochs_ICA[1]) * 100)

    # picking good epochs for the two subjects
    cleaned_epochs_AR = []
    for clean_epochs in cleaned_epochs_ICA:  # per subject
        # keep a copy of the original data
        clean_epochs_ep = copy.deepcopy(clean_epochs)
        clean_epochs_ep = clean_epochs_ep.drop(indices=bad)
        # interpolating bads or removing epochs
        ar = AR[cleaned_epochs_ICA.index(clean_epochs)]
        clean_epochs_AR = ar.transform(clean_epochs_ep)
        cleaned_epochs_AR.append(clean_epochs_AR)

    if strategy == 'intersection':
        # equalizing the number of epochs between the two participants
        mne.epochs.equalize_epoch_counts(cleaned_epochs_AR)

    dic_AR['dyad'] = float((len(cleaned_epochs_ICA[0])
                            - len(cleaned_epochs_AR[0]))
                           / len(cleaned_epochs_ICA[0]) * 100)
    if dic_AR['dyad'] >= threshold:
        raise RuntimeError('percentage of rejected epochs is above threshold!')
    if verbose:
        print('%s percent of bad epochs' % dic_AR['dyad'])

    # visualization before/after AR
    evoked_before = []
    for clean_epochs in cleaned_epochs_ICA:  # per subject
        evoked_before.append(clean_epochs.average())
    evoked_after_AR = []
    for clean in cleaned_epochs_AR:
        evoked_after_AR.append(clean.average())

    if verbose:
        for i, j in zip(evoked_before, evoked_after_AR):
            fig, axes = plt.subplots(2, 1, figsize=(6, 6))
            for ax in axes:
                ax.tick_params(axis='x', which='both', bottom=False, top=False)
                ax.tick_params(axis='y', which='both', left=False, right=False)
            ylim = dict(grad=(-170, 200))
            i.pick_types(eeg=True, exclude=[])
            i.plot(exclude=[], axes=axes[0], ylim=ylim, show=False)
            axes[0].set_title('Before autoreject')
            j.pick_types(eeg=True, exclude=[])
            j.plot(exclude=[], axes=axes[1], ylim=ylim)
            # Known issue: the title does not display for the second axis!
            axes[1].set_title('After autoreject')
            plt.tight_layout()

    return cleaned_epochs_AR, dic_AR
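
# Hypothetical usage (assumption): `epochs_s1` and `epochs_s2` stand in for
# the two participants' Epochs after ICA, as produced earlier in the pipeline.
# cleaned_epochs_AR, dic_AR = AR_local([epochs_s1, epochs_s2],
#                                      strategy='intersection',
#                                      threshold=50.0, verbose=False)
# print('rejected: S1 %.1f%%, S2 %.1f%%, dyad %.1f%%'
#       % (dic_AR['S1'], dic_AR['S2'], dic_AR['dyad']))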
###############################################################################
# Note that :class:`autoreject.AutoReject` by design supports multiple
# channel types. If no picks are passed, separate solutions will be computed
# for each channel type and internally combined. This then readily supports
# cleaning unseen epochs from the different channel types used during fit.
# Here we only use a subset of channels to save time.

ar = AutoReject(n_interpolates, consensus_percs, picks=picks,
                thresh_method='random_search', random_state=42)

# Note that fitting and transforming can be done on different compatible
# portions of data if needed.
ar.fit(epochs['Auditory/Left'])
epochs_clean = ar.transform(epochs['Auditory/Left'])
evoked_clean = epochs_clean.average()
evoked = epochs['Auditory/Left'].average()

###############################################################################
# Now, we will manually mark the bad channels just for plotting.
evoked.info['bads'] = ['MEG 2443']
evoked_clean.info['bads'] = ['MEG 2443']

###############################################################################
# Let us plot the results.
import matplotlib.pyplot as plt  # noqa
set_matplotlib_defaults(plt)