def run_autoreject(subject):
    """Interpolate bad epochs/sensors using Autoreject.

    Reads ``<root>/4_ICA/<subject>-epo.fif`` (``root`` is a module-level
    name), fits AutoReject on the epochs and repairs them in a single pass.

    Parameters
    ----------
    *subject: string
        The participant reference

    Save the resulting *-epo.fif file in the '5_autoreject' directory.
    Save .png of ERP difference and heatmap plots in the same directory.

    References
    ----------
    [1] Mainak Jas, Denis Engemann, Federico Raimondo, Yousra Bekhti, and
    Alexandre Gramfort, “Automated rejection and repair of bad trials in
    MEG/EEG.” In 6th International Workshop on Pattern Recognition in
    Neuroimaging (PRNI), 2016.

    [2] Mainak Jas, Denis Engemann, Yousra Bekhti, Federico Raimondo, and
    Alexandre Gramfort. 2017. “Autoreject: Automated artifact rejection for
    MEG and EEG data”. NeuroImage, 159, 417-429.
    """
    # Import data
    input_path = root + '/4_ICA/' + subject + '-epo.fif'
    epochs = mne.read_epochs(input_path)

    # Autoreject: fit and repair in one call. The previous code discarded the
    # result of fit_transform() and then ran transform() a second time.
    ar = AutoReject(random_state=42, n_jobs=4)
    epochs_clean = ar.fit_transform(epochs)

    # Plot difference
    evoked = epochs.average()
    evoked_clean = epochs_clean.average()

    fig, axes = plt.subplots(2, 1, figsize=(6, 6))

    for ax in axes:
        # tick_params expects booleans; the string 'off' is truthy and would
        # leave the ticks visible.
        ax.tick_params(axis='x', which='both', bottom=False, top=False)
        ax.tick_params(axis='y', which='both', left=False, right=False)

    # NOTE(review): MNE documents `ylim` as a dict keyed by channel type;
    # confirm that a bare list is accepted by the installed version.
    evoked.plot(exclude=[], axes=axes[0], ylim=[-30, 30], show=False)
    axes[0].set_title('Before autoreject')
    # show=False so the figure is not displayed before the second title is
    # set and the file is written.
    evoked_clean.plot(exclude=[], axes=axes[1], ylim=[-30, 30], show=False)
    axes[1].set_title('After autoreject')
    plt.tight_layout()
    plt.savefig(root + '/5_autoreject/' + subject + '-autoreject.png')
    plt.close()

    # Plot heatmap (saved through pyplot's current figure)
    ar.get_reject_log(epochs).plot()
    plt.savefig(root + '/5_autoreject/' + subject + '-heatmap.png')
    plt.close()

    # Save epoch data
    out_epoch = root + '/5_autoreject/' + subject + '-epo.fif'
    epochs_clean.save(out_epoch)
def epoch_and_clean_trials(subject,
                           diagdir,
                           bidsdir,
                           datadir,
                           derivdir,
                           epochlength=3,
                           eventid=None):
    """
    Chunk the data into epochs starting at the eventid specified per trial,
    lasting ``epochlength`` seconds. When ``eventid`` is exactly
    ``{'press/left': 1, 'press/right': 4}`` the epoch instead ends at the
    event and reaches ``epochlength`` seconds back in time.
    Do automatic artifact detection, rejection and fixing for eyeblinks,
    heartbeat, and high- and low-amplitude artifacts.
    The cleaned epochs are saved below ``derivdir``; the reject log, the
    average of the cleaned epochs and their PSD are saved as figures below
    ``diagdir``.

    :param subject: str, subject identifier. takes the form '001'
    :param diagdir: str, path to a directory where diagnostic plots can be
        saved
    :param bidsdir: str, path to a directory with BIDS data. Needed to load
        event logs from the experiment
    :param datadir: str, path to a directory with SSS-processed data
    :param derivdir: str, path to a directory where cleaned epochs can be
        saved
    :param epochlength: int, length of epoch
    :param eventid: dict, the event to start an Epoch from. ``None`` (the
        default) means ``{'visualfix/fixCross': 10}``
    """
    # A fresh dict per call: a dict literal as default argument would be one
    # object shared by every call (and by the helpers it is handed to).
    if eventid is None:
        eventid = {'visualfix/fixCross': 10}

    # construct name of the first split
    raw_fname = Path(datadir) / f'sub-{subject}/meg' / \
        f'sub-{subject}_task-memento_proc-sss_meg.fif'
    logging.info(f"Reading in SSS-processed data from subject sub-{subject}. "
                 f"Attempting the following path: {raw_fname}")
    raw = mne.io.read_raw_fif(raw_fname)
    # only the events array is needed here
    events, _ = get_events(raw)
    # filter the data to remove high-frequency noise. Minimal high-pass filter
    # based on
    # https://www.sciencedirect.com/science/article/pii/S0165027021000157
    # ensure the data is loaded prior to filtering
    raw.load_data()
    if subject == '017':
        logging.info('Setting additional bad channels for subject 17')
        raw.info['bads'] = ['MEG0313', 'MEG0513', 'MEG0523']
        raw.interpolate_bads()
    # high-pass doesn't make sense, raw data has 0.1Hz high-pass filter already!
    _filter_data(raw, h_freq=100)
    # ICA to detect and repair artifacts
    logging.info('Removing eyeblink and heartbeat artifacts')
    # the same RandomState is handed to the ICA helper and to AutoReject
    rng = np.random.RandomState(28)
    remove_eyeblinks_and_heartbeat(
        raw=raw,
        subject=subject,
        figdir=diagdir,
        events=events,
        eventid=eventid,
        rng=rng,
    )
    # get the actual epochs: chunk the trial into epochs starting from the
    # event ID. Do not baseline correct the data.
    logging.info(f'Creating epochs of length {epochlength}')
    if eventid == {'press/left': 1, 'press/right': 4}:
        # when centered on the response, move back in time
        epochs = mne.Epochs(raw, events, event_id=eventid,
                            tmin=-epochlength, tmax=0,
                            picks='meg', baseline=None)
    else:
        epochs = mne.Epochs(raw, events, event_id=eventid,
                            tmin=0, tmax=epochlength,
                            picks='meg', baseline=None)
    # ADD SUBJECT SPECIFIC TRIAL NUMBER TO THE EPOCH! ONLY THIS WAY WE CAN
    # LATER RECOVER WHICH TRIAL PARAMETERS WE'RE LOOKING AT BASED ON THE LOGS AS
    # THE EPOCH REJECTION WILL REMOVE TRIALS
    logging.info("Retrieving trial metadata.")
    from pymento_meg.proc.epoch import get_trial_features
    metadata = get_trial_features(bids_path=bidsdir,
                                  subject=subject,
                                  column='trial_no')
    # transform to integers
    metadata = metadata.astype(int)
    # this does not work if we start at fixation cross for subject 5, because 1
    # fixation cross trigger is missing from the data, and it becomes impossible
    # to associate the trial metadata to the correct trials in the data
    epochs.metadata = metadata
    epochs.load_data()
    ## downsample the data to 200Hz
    #logging.info('Resampling epoched data down to 200 Hz')
    #epochs.resample(sfreq=200, verbose=True)
    # use autoreject to repair bad epochs
    ar = AutoReject(
        random_state=rng,
        n_interpolate=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
    epochs_clean, reject_log = ar.fit_transform(epochs, return_log=True)
    # save the cleaned, epoched data to disk.
    outpath = _construct_path([
        Path(derivdir),
        f"sub-{subject}",
        "meg",
        f"sub-{subject}_task-memento_cleaned_epo.fif",
    ])
    logging.info(f"Saving cleaned, epoched data to {outpath}")
    epochs_clean.save(outpath, overwrite=True)
    # visualize the bad sensors for each trial. The log returned by
    # fit_transform() already describes `epochs`, so it is reused instead of
    # being recomputed with ar.get_reject_log(epochs).
    fig = reject_log.plot()
    fname = _construct_path([
        Path(diagdir),
        f"sub-{subject}",
        "meg",
        f"epoch-rejectlog_sub-{subject}.png",
    ])
    fig.savefig(fname)
    # plot the average of all cleaned epochs
    fig = epochs_clean.average().plot()
    fname = _construct_path([
        Path(diagdir),
        f"sub-{subject}",
        "meg",
        f"clean-epoch_average_sub-{subject}.png",
    ])
    fig.savefig(fname)
    # plot psd of cleaned epochs
    psd = epochs_clean.plot_psd()
    fname = _construct_path([
        Path(diagdir),
        f"sub-{subject}",
        "meg",
        f"psd_cleaned-epochs-{subject}.png",
    ])
    psd.savefig(fname)
def AR_local(cleaned_epochs_ICA, verbose=False):
    """
    Applies local Autoreject to correct or reject bad epochs.

    One AutoReject estimator is fitted per subject. Epochs flagged as bad
    for both subjects are dropped from copies of the inputs, each copy is
    then repaired with that subject's own estimator, and the epoch counts
    of the results are equalized.

    Arguments:
        cleaned_epochs_ICA: list of Epochs after global Autoreject and ICA.
          Indices 0 and 1 are read, so at least two entries are required.
          The input Epochs are not modified.
        verbose: to plot data before and after AR, boolean set to False
          by default.

    Returns:
        cleaned_epochs_AR: list of Epochs after local Autoreject.
    """
    # defaults values for n_interpolates and consensus_percs
    n_interpolates = np.array([1, 4, 32])
    consensus_percs = np.linspace(0, 1.0, 11)
    ar_verbose = 'progressbar' if verbose else False

    fitted_ars = []
    bad_epochs_AR = []
    for clean_epochs in cleaned_epochs_ICA:  # per subj
        picks = mne.pick_types(clean_epochs[0].info, meg=False, eeg=True,
                               stim=False, eog=False, exclude=[])
        ar = AutoReject(n_interpolates, consensus_percs, picks=picks,
                        thresh_method='random_search', random_state=42,
                        verbose=ar_verbose)

        # fitting AR to get bad epochs
        ar.fit(clean_epochs)
        # keep every fitted estimator: each subject must later be
        # transformed with its own thresholds
        fitted_ars.append(ar)
        bad_epochs_AR.append(ar.get_reject_log(clean_epochs, picks=picks))

    # bad epochs shared by the two subjects of the dyad
    log1 = bad_epochs_AR[0]
    log2 = bad_epochs_AR[1]
    bad1 = np.where(log1.bad_epochs)[0]
    bad2 = np.where(log2.bad_epochs)[0]
    bad = sorted(set(bad1.tolist()).intersection(bad2.tolist()))
    if verbose:
        print('%s percent of bad epochs' %
              int(len(bad) / len(log1.bad_epochs) * 100))

    # picking good epochs for the two subj
    cleaned_epochs_AR = []
    for clean_epochs, ar in zip(cleaned_epochs_ICA, fitted_ars):  # per subj
        # Epochs.drop() works in place; drop on a copy so the caller's data
        # (and the 'before' plots below) stay untouched
        clean_epochs_ep = clean_epochs.copy().drop(indices=bad)
        # interpolating bads or removing epochs
        cleaned_epochs_AR.append(ar.transform(clean_epochs_ep))
    # equalizing epochs length between two subjects
    mne.epochs.equalize_epoch_counts(cleaned_epochs_AR)

    # Visualisation before / after AR (averages are only needed for plotting)
    if verbose:
        for epochs_before, epochs_after in zip(cleaned_epochs_ICA,
                                               cleaned_epochs_AR):
            evoked_before = epochs_before.average()
            evoked_after = epochs_after.average()
            fig, axes = plt.subplots(2, 1, figsize=(6, 6))

            for ax in axes:
                # tick_params expects booleans; the string 'off' is truthy
                ax.tick_params(axis='x', which='both', bottom=False,
                               top=False)
                ax.tick_params(axis='y', which='both', left=False,
                               right=False)

            # NOTE(review): keyed on 'grad' although only EEG channels are
            # picked below - presumably a leftover from an MEG example;
            # confirm the intended limits.
            ylim = dict(grad=(-170, 200))
            evoked_before.pick_types(eeg=True, exclude=[])
            evoked_before.plot(exclude=[], axes=axes[0], ylim=ylim,
                               show=False)
            axes[0].set_title('Before autoreject')
            evoked_after.pick_types(eeg=True, exclude=[])
            # show=False: the figure used to be displayed before the title
            # of the second axis was set, so that title never appeared
            evoked_after.plot(exclude=[], axes=axes[1], ylim=ylim,
                              show=False)
            axes[1].set_title('After autoreject')
            plt.tight_layout()
            plt.show()

    return cleaned_epochs_AR
preload=True) ar = AutoReject(n_interpolates, consensus_percs, picks=picks, thresh_method='random_search', random_state=42) # Note that fitting and transforming can be done on different compatible # portions of data if needed. ar.fit(epochs) # epochs_ar, reject_log = ar.fit_transform(epochs, return_log=True) epochs_clean = ar.transform(epochs) reject_log = ar.get_reject_log(epochs) evoked_clean = epochs_clean.average() evoked = epochs.average() # visualize rejected epochs scalings = dict(eeg=40e-6) reject_log.plot_epochs(epochs, scalings=scalings) # epochs after cleaning epochs_clean.plot(scalings=scalings) # notify when done os.system('say "... I am ready for you Neil."') # Visualize repaired data ylim = dict(eeg=(-2, 2))
def test_autoreject():
    """Test basic _AutoReject functionality.

    Uses the module-level ``raw`` recording. Covers: fitting on
    non-preloaded epochs raising, parameter validation of
    ``_GlobalAutoReject``, selection of ``n_interpolate``/``consensus`` per
    channel type, statelessness of ``transform``, the contents of the reject
    log on new data, handling of channels marked as bad, and
    ``compute_thresholds`` with both search methods.
    """
    event_id = None
    tmin, tmax = -0.2, 0.5
    events = mne.find_events(raw)

    ##########################################################################
    # picking epochs
    include = [u'EEG %03d' % i for i in range(1, 45, 3)]
    picks = mne.pick_types(raw.info, meg=False, eeg=False, stim=False,
                           eog=True, include=include, exclude=[])
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                        picks=picks, baseline=(None, 0), decim=10,
                        reject=None, preload=False)[:10]

    ar = _AutoReject()
    # epochs are not preloaded yet
    assert_raises(ValueError, ar.fit, epochs)
    epochs.load_data()

    ar.fit(epochs)
    assert_true(len(ar.picks_) == len(picks) - 1)

    # epochs with no picks.
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                        baseline=(None, 0), decim=10,
                        reject=None, preload=True)[:20]
    # let's drop some channels to speed up
    pre_picks = np.r_[
        mne.pick_types(epochs.info, meg='mag', eeg=False)[::15],
        mne.pick_types(epochs.info, meg='grad', eeg=False)[::60],
        mne.pick_types(epochs.info, meg=False, eeg=True)[::16],
        mne.pick_types(epochs.info, meg=False, eeg=False, eog=True)]
    pick_ch_names = [epochs.ch_names[pp] for pp in pre_picks]
    bad_ch_names = [epochs.ch_names[ix] for ix in range(len(epochs.ch_names))
                    if ix not in pre_picks]
    # same recording, but the dropped channels are only marked as bad
    epochs_with_bads = epochs.copy()
    epochs_with_bads.info['bads'] = bad_ch_names
    epochs.pick_channels(pick_ch_names)

    epochs_fit = epochs[:12]  # make sure to use different size of epochs
    epochs_new = epochs[12:]
    epochs_with_bads_fit = epochs_with_bads[:12]

    X = epochs_fit.get_data()
    n_epochs, n_channels, n_times = X.shape
    X = X.reshape(n_epochs, -1)

    # n_channels, n_times and thresh are all required for fitting
    ar = _GlobalAutoReject()
    assert_raises(ValueError, ar.fit, X)
    ar = _GlobalAutoReject(n_channels=n_channels)
    assert_raises(ValueError, ar.fit, X)
    ar = _GlobalAutoReject(n_times=n_times)
    assert_raises(ValueError, ar.fit, X)
    ar_global = _GlobalAutoReject(
        n_channels=n_channels, n_times=n_times, thresh=40e-6)
    ar_global.fit(X)

    param_name = 'thresh'
    param_range = np.linspace(40e-6, 200e-6, 10)
    assert_raises(ValueError, validation_curve, X, None,
                  param_name, param_range)

    ##########################################################################
    # picking AutoReject

    picks = mne.pick_types(
        epochs.info, meg='mag', eeg=True, stim=False, eog=False,
        include=[], exclude=[])
    non_picks = mne.pick_types(
        epochs.info, meg='grad', eeg=False, stim=False, eog=False,
        include=[], exclude=[])
    ch_types = ['mag', 'eeg']

    ar = _AutoReject(picks=picks)  # XXX : why do we need this??

    ar = AutoReject(cv=3, picks=picks, random_state=42,
                    n_interpolate=[1, 2], consensus=[0.5, 1])
    assert_raises(AttributeError, ar.fit, X)
    assert_raises(ValueError, ar.transform, X)
    assert_raises(ValueError, ar.transform, epochs)

    ar.fit(epochs_fit)
    reject_log = ar.get_reject_log(epochs_fit)
    for ch_type in ch_types:
        # test that kappa & rho are selected
        assert_true(
            ar.n_interpolate_[ch_type] in ar.n_interpolate)
        assert_true(
            ar.consensus_[ch_type] in ar.consensus)

        assert_true(
            ar.n_interpolate_[ch_type] ==
            ar.local_reject_[ch_type].n_interpolate_[ch_type])
        assert_true(
            ar.consensus_[ch_type] ==
            ar.local_reject_[ch_type].consensus_[ch_type])

    # test complementarity of goods and bads
    assert_array_equal(len(reject_log.bad_epochs), len(epochs_fit))

    # test that transform does not change state of ar
    epochs_clean = ar.transform(epochs_fit)  # apply same data
    assert_true(repr(ar))
    assert_true(repr(ar.local_reject_))
    reject_log2 = ar.get_reject_log(epochs_fit)
    assert_array_equal(reject_log.labels, reject_log2.labels)
    assert_array_equal(reject_log.bad_epochs, reject_log2.bad_epochs)
    assert_array_equal(reject_log.ch_names, reject_log2.ch_names)

    epochs_new_clean = ar.transform(epochs_new)  # apply to new data

    reject_log_new = ar.get_reject_log(epochs_new)
    assert_array_equal(len(reject_log_new.bad_epochs), len(epochs_new))

    assert_true(
        len(reject_log_new.bad_epochs) != len(reject_log.bad_epochs))

    picks_by_type = _get_picks_by_type(epochs.info, ar.picks)
    # test correct entries in fix log
    assert_true(
        np.isnan(reject_log_new.labels[:, non_picks]).sum() > 0)
    assert_true(
        np.isnan(reject_log_new.labels[:, picks]).sum() == 0)
    assert_equal(reject_log_new.labels.shape,
                 (len(epochs_new), len(epochs_new.ch_names)))

    # test correct interpolations by type
    for ch_type, this_picks in picks_by_type:
        interp_counts = np.sum(
            reject_log_new.labels[:, this_picks] == 2, axis=1)
        labels = reject_log_new.labels.copy()
        not_this_picks = np.setdiff1d(np.arange(labels.shape[1]), this_picks)
        labels[:, not_this_picks] = np.nan
        interp_channels = _get_interp_chs(
            labels, reject_log.ch_names, this_picks)
        assert_array_equal(
            interp_counts, [len(cc) for cc in interp_channels])

    # transform must have altered the data (dropped and/or interpolated).
    # np.array_equal copes with differing shapes, whereas `==` raises on a
    # shape mismatch in recent NumPy; the previous scalar-check also passed
    # for any same-shaped result, identical or not.
    assert_true(not np.array_equal(epochs_new_clean.get_data(),
                                   epochs_new.get_data()))

    # test that transform ignores bad channels
    epochs_with_bads_fit.pick_types(meg='mag', eeg=True, eog=True, exclude=[])
    ar_bads = AutoReject(cv=3, random_state=42,
                         n_interpolate=[1, 2], consensus=[0.5, 1])
    ar_bads.fit(epochs_with_bads_fit)
    epochs_with_bads_clean = ar_bads.transform(epochs_with_bads_fit)

    good_w_bads_ix = mne.pick_types(epochs_with_bads_clean.info,
                                    meg='mag', eeg=True, eog=True,
                                    exclude='bads')
    good_wo_bads_ix = mne.pick_types(epochs_clean.info,
                                     meg='mag', eeg=True, eog=True,
                                     exclude='bads')
    assert_array_equal(epochs_with_bads_clean.get_data()[:, good_w_bads_ix, :],
                       epochs_clean.get_data()[:, good_wo_bads_ix, :])

    # channels marked bad must come out of transform unchanged
    bad_ix = [epochs_with_bads_clean.ch_names.index(ch)
              for ch in epochs_with_bads_clean.info['bads']]
    epo_ix = ~ar_bads.get_reject_log(epochs_with_bads_fit).bad_epochs
    assert_array_equal(
        epochs_with_bads_clean.get_data()[:, bad_ix, :],
        epochs_with_bads_fit.get_data()[epo_ix, :, :][:, bad_ix, :])

    assert_equal(epochs_clean.ch_names, epochs_fit.ch_names)

    assert_true(isinstance(ar.threshes_, dict))
    assert_true(len(ar.picks) == len(picks))
    assert_true(len(ar.threshes_.keys()) == len(ar.picks))
    pick_eog = mne.pick_types(epochs.info, meg=False, eeg=False, eog=True)[0]
    assert_true(epochs.ch_names[pick_eog] not in ar.threshes_.keys())
    assert_raises(
        IndexError, ar.transform,
        epochs.copy().pick_channels(
            [epochs.ch_names[pp] for pp in picks[:3]]))

    epochs.load_data()
    assert_raises(ValueError, compute_thresholds, epochs, 'dfdfdf')
    ch_names = [epochs_fit.ch_names[pp] for pp in picks]
    threshes_a = compute_thresholds(
        epochs_fit, picks=picks, method='random_search')
    assert_equal(set(threshes_a.keys()), set(ch_names))
    threshes_b = compute_thresholds(
        epochs_fit, picks=picks, method='bayesian_optimization')
    assert_equal(set(threshes_b.keys()), set(ch_names))
def run_preproc(datadir='/data'):
    """Epoch, clean and export every EEG recording found below *datadir*.

    Settings come from the module-level ``defaults``, optionally overridden
    by ``<datadir>/eegprep.conf``. For every ``sub-*`` directory in
    ``<datadir>/BIDS`` each ``.set``/``.bdf`` file in its ``eeg`` folder is
    read, epoched and passed through AutoReject; diagnostic figures and the
    reject log are written to ``derivatives/eegprep/reports/sub-<id>``. The
    cleaned epochs are then concatenated per (session, task), optionally
    downsampled, and written to ``derivatives/eegprep/sub-<id>`` in the
    configured ``out_file_format`` ('fif', 'mat' or 'npy').

    :param datadir: directory containing the ``BIDS`` folder and, optionally,
        ``eegprep.conf``
    """
    print('data directory: {}'.format(datadir))
    conf_file_path = join(datadir, 'eegprep.conf')
    config = Configuration()
    config.setDefaults(defaults)
    if os.path.isfile(conf_file_path):
        with open(conf_file_path) as fh:
            conf_string = fh.read()
        config.updateFromString(conf_string)
    print('configuration:')
    print(config)

    bidsdir = join(datadir, 'BIDS')
    eegprepdir = join(bidsdir, 'derivatives', 'eegprep')

    # reader per supported raw-file extension
    rawtypes = {'.set': mne.io.read_raw_eeglab, '.bdf': mne.io.read_raw_edf}

    subjectdirs = sorted(glob.glob(join(bidsdir, 'sub-*')))
    for subjectdir in subjectdirs:
        if not os.path.isdir(subjectdir):
            # a stray file matching 'sub-*' is skipped rather than aborting
            # the run (an assert would also vanish under `python -O`)
            print('skipping non-directory entry: {}'.format(subjectdir))
            continue

        sub = basename(subjectdir)[4:]

        # prepare derivatives directory
        derivdir = join(eegprepdir, 'sub-' + sub)
        os.makedirs(derivdir, exist_ok=True)
        reportsdir = join(eegprepdir, 'reports', 'sub-' + sub)
        os.makedirs(reportsdir, exist_ok=True)

        # cleaned epochs of this subject, keyed by (session, task, run)
        subject_epochs = {}
        for fname in sorted(glob.glob(join(subjectdir, 'eeg', '*'))):
            _, ext = splitext(fname)
            if ext not in rawtypes:
                continue
            sub, ses, task, run = filename2tuple(basename(fname))

            print('\nProcessing raw file: ' + basename(fname))

            # read data
            raw = rawtypes[ext](fname, preload=True, verbose=False)
            events = mne.find_events(raw)  #raw, consecutive=False, min_duration=0.005)

            # Set channel types and select reference channels
            channelFile = fname.replace('eeg' + ext, 'channels.tsv')
            channels = pandas.read_csv(channelFile, index_col='name', sep='\t')
            bids2mne = {
                'MISC': 'misc',
                'EEG': 'eeg',
                'VEOG': 'eog',
                'TRIG': 'stim',
                'REF': 'eeg',
            }
            channels['mne'] = channels.type.replace(bids2mne)

            # the below fails if the specified channels are not in the data
            raw.set_channel_types(channels.mne.to_dict())

            # set bad channels
            raw.info['bads'] = channels[channels.status=='bad'].index.tolist()

            # pick channels to use for epoching
            epoching_picks = mne.pick_types(raw.info, eeg=True, eog=False,
                                            stim=False, exclude='bads')

            # Filtering
            #raw.filter(l_freq=0.05, h_freq=40, fir_design='firwin')

            montage = mne.channels.read_montage(guess_montage(raw.ch_names))
            print(montage)
            raw.set_montage(montage)

            # plot raw data: roughly every 20th channel, from mid-recording.
            # The step is clamped to 1 because with fewer than 20 channels
            # floor(nchans / 20) is 0, which numpy.arange rejects.
            nchans = len(raw.ch_names)
            step = max(1, nchans // 20)
            pick_channels = numpy.arange(0, nchans, step)
            start = numpy.round(raw.times.max()/2)
            fig = raw.plot(start=start, order=pick_channels)
            fname_plot = 'sub-{}_ses-{}_task-{}_run-{}_raw.png'.format(sub, ses, task, run)
            fig.savefig(join(reportsdir, fname_plot))

            # Set reference
            refChannels = channels[channels.type=='REF'].index.tolist()
            raw.set_eeg_reference(ref_channels=refChannels)

            ##  epoching
            epochs_params = dict(
                events=events,
                tmin=-0.1,
                tmax=0.8,
                reject=None,  # dict(eeg=250e-6, eog=150e-6)
                picks=epoching_picks,
                detrend=0,
            )
            file_epochs = mne.Epochs(raw, preload=True, **epochs_params)
            file_epochs.drop_channels(refChannels)

            # autoreject (under development)
            ar = AutoReject(n_jobs=4)
            clean_epochs = ar.fit_transform(file_epochs)

            # The log has to describe the epochs that went INTO autoreject;
            # computed on the cleaned epochs it no longer contains the
            # trials that were dropped.
            rejectlog = ar.get_reject_log(file_epochs)
            fname_log = 'sub-{}_ses-{}_task-{}_run-{}_reject-log.npz'.format(sub, ses, task, run)
            save_rejectlog(join(reportsdir, fname_log), rejectlog)
            fig = plot_rejectlog(rejectlog)
            fname_plot = 'sub-{}_ses-{}_task-{}_run-{}_bad-epochs.png'.format(sub, ses, task, run)
            fig.savefig(join(reportsdir, fname_plot))

            # store for now
            subject_epochs[(ses, task, run)] = clean_epochs

            # create evoked plots for at most 6 randomly chosen conditions.
            # random.sample needs a sequence; a dict view raises TypeError
            # on Python >= 3.11.
            conds = list(clean_epochs.event_id)
            selected_conds = random.sample(conds, min(len(conds), 6))
            picks = mne.pick_types(clean_epochs.info, eeg=True)
            for cond in selected_conds:
                evoked = clean_epochs[cond].average()
                fname_plot = 'sub-{}_ses-{}_task-{}_run-{}_evoked-{}.png'.format(sub, ses, task, run, cond)
                fig = evoked.plot_joint(picks=picks)
                fig.savefig(join(reportsdir, fname_plot))

        # positions of session and task inside the (ses, task, run) keys
        sessSeg = 0
        taskSeg = 1
        sessions = sorted({k[sessSeg] for k in subject_epochs})
        for session in sessions:
            tasks = list({k[taskSeg] for k in subject_epochs
                          if k[sessSeg]==session})
            for task in tasks:
                print('\nGathering epochs for session {} task {}'.format(session, task))
                epochs_selection = [v for (k, v) in subject_epochs.items()
                                    if k[:2]==(session, task)]

                task_epochs = mne.epochs.concatenate_epochs(epochs_selection)

                # downsample if configured to do so
                # important to do this after concatenation because
                # downsampling may cause rejection for 'TOOSHORT'
                if config['downsample'] < task_epochs.info['sfreq']:
                    task_epochs = task_epochs.copy().resample(config['downsample'], npad='auto')

                ext = config['out_file_format']
                fname = join(derivdir, 'sub-{}_ses-{}_task-{}_epo.{}'.format(sub, session, task, ext))
                if ext == 'fif':
                    task_epochs.save(fname)
                elif ext in ('mat', 'npy'):
                    # the array copies are only needed for these formats
                    variables = {
                        'epochs': task_epochs.get_data(),
                        'events': task_epochs.events,
                        'timepoints': task_epochs.times
                    }
                    if ext == 'mat':
                        scipy.io.savemat(fname, mdict=variables)
                    else:
                        # NOTE(review): numpy.savez appends '.npz' when the
                        # name does not end with it, so the file on disk is
                        # '<...>_epo.npy.npz' - confirm this is intended.
                        numpy.savez(fname, **variables)
                else:
                    # previously an unknown format was ignored without notice
                    print('WARNING: unknown out_file_format {!r}; nothing saved for {}'.format(ext, fname))
# Mark the same channel as bad in both averages so the two panels compare
# like with like.
evoked.info['bads'] = ['MEG 2443']
evoked_clean.info['bads'] = ['MEG 2443']

# %%
# Let us plot the results.

import matplotlib.pyplot as plt  # noqa
set_matplotlib_defaults(plt)

fig, axes = plt.subplots(2, 1, figsize=(6, 6))

for ax in axes:
    # tick_params expects booleans; the string 'off' is truthy and would
    # leave the ticks visible.
    ax.tick_params(axis='x', which='both', bottom=False, top=False)
    ax.tick_params(axis='y', which='both', left=False, right=False)

ylim = dict(grad=(-170, 200))
evoked.pick_types(meg='grad', exclude=[])
evoked.plot(exclude=[], axes=axes[0], ylim=ylim, show=False)
axes[0].set_title('Before autoreject')
evoked_clean.pick_types(meg='grad', exclude=[])
# show=False so that the figure is only displayed once both titles and the
# layout are in place.
evoked_clean.plot(exclude=[], axes=axes[1], ylim=ylim, show=False)
axes[1].set_title('After autoreject')
plt.tight_layout()
plt.show()

# %%
# To top things up, we can also visualize the bad sensors for each trial using
# a heatmap.

ar.get_reject_log(epochs['Auditory/Left']).plot()
def AR_local(cleaned_epochs_ICA: list, strategy: str = 'union',
             threshold: float = 50.0, verbose: bool = False) -> tuple:
    """
    Applies local Autoreject to repair or reject bad epochs.

    Arguments:
        cleaned_epochs_ICA: list of Epochs after global Autoreject and ICA.
          Indices 0 and 1 are read, so at least two entries are required.
          The input Epochs are not modified (copies are transformed).
        strategy: more or less generous strategy to reject bad epochs: 'union'
          or 'intersection'. 'union' rejects bad epochs from subject 1 and
          subject 2 immediately, whereas 'intersection' rejects shared bad
          epochs between subjects, tries to repair remaining bad epochs per
          subject, rejects the non-repairable ones per subject and finally
          equalizes the number of epochs between subjects. Set to 'union' by
          default.
        threshold: percentage of epochs removed that is accepted. Above this
          threshold, data are considered as a too shortened sample for
          further analyses. Set to 50.0 by default.
        verbose: option to plot data before and after AR, boolean, set to
          False by default.
          # use verbose = false until next Autoreject update

    Note:
        To reject or repair epochs, parameters are more or less conservative,
        see http://autoreject.github.io/generated/autoreject.AutoReject.

    Returns:
        cleaned_epochs_AR: list of Epochs after local Autoreject.
        dic_AR: dictionary with the strategy, the threshold and the
          percentage of rejected epochs for each subject ('S1', 'S2') and
          for the dyad ('dyad', measured on the first subject).

    Raises:
        ValueError: if strategy is neither 'union' nor 'intersection'.
        RuntimeError: if the percentage of rejected epochs reaches threshold.
    """
    # Validate before the expensive fitting. The exception used to be
    # created without `raise`, so a bad value only surfaced later as an
    # unbound `bad` variable.
    if strategy not in ('union', 'intersection'):
        raise ValueError('not good strategy input!')

    bad_epochs_AR = []
    AR = []
    dic_AR = {}
    dic_AR['strategy'] = strategy
    dic_AR['threshold'] = threshold

    # defaults values for n_interpolates and consensus_percs
    n_interpolates = np.array([1, 4, 32])
    consensus_percs = np.linspace(0, 1.0, 11)
    # more generous values
    # n_interpolates = np.array([16, 32, 64])
    # n_interpolates = np.array([1, 4, 8, 16, 32, 64])
    # consensus_percs = np.linspace(0.5, 1.0, 11)

    for clean_epochs in cleaned_epochs_ICA:  # per subj
        picks = mne.pick_types(
            clean_epochs[0].info,
            meg=False,
            eeg=True,
            stim=False,
            eog=False,
            exclude=[])

        ar = AutoReject(n_interpolates, consensus_percs, picks=picks,
                        thresh_method='random_search', random_state=42,
                        verbose='tqdm_notebook')
        AR.append(ar)

        # fitting AR to get bad epochs
        ar.fit(clean_epochs)
        reject_log = ar.get_reject_log(clean_epochs, picks=picks)
        bad_epochs_AR.append(reject_log)

    # indices of the bad epochs of each subject of the dyad
    log1 = bad_epochs_AR[0]
    log2 = bad_epochs_AR[1]
    bad1 = set(np.where(log1.bad_epochs)[0].tolist())
    bad2 = set(np.where(log2.bad_epochs)[0].tolist())

    if strategy == 'union':
        bad = sorted(bad1 | bad2)
    else:  # 'intersection'
        bad = sorted(bad1 & bad2)

    # storing the percentage of epochs rejection
    dic_AR['S1'] = float((len(bad1) / len(cleaned_epochs_ICA[0])) * 100)
    dic_AR['S2'] = float((len(bad2) / len(cleaned_epochs_ICA[1])) * 100)

    # picking good epochs for the two subj
    cleaned_epochs_AR = []
    # zip pairs every subject with its own fitted estimator without relying
    # on list.index(), which compares Epochs objects
    for clean_epochs, ar in zip(cleaned_epochs_ICA, AR):  # per subj
        # Epochs.drop() works in place: operate on a copy of the original
        clean_epochs_ep = clean_epochs.copy().drop(indices=bad)
        # interpolating bads or removing epochs
        cleaned_epochs_AR.append(ar.transform(clean_epochs_ep))

    if strategy == 'intersection':
        # equalizing epochs length between two participants
        mne.epochs.equalize_epoch_counts(cleaned_epochs_AR)

    n_before = len(cleaned_epochs_ICA[0])
    dic_AR['dyad'] = float(
        ((n_before - len(cleaned_epochs_AR[0])) / n_before) * 100)
    if dic_AR['dyad'] >= threshold:
        # this exception was also created without `raise`, which left
        # `threshold` without any effect
        raise RuntimeError('percentage of rejected epochs above threshold!')
    if verbose:
        print('%s percent of bad epochs' % dic_AR['dyad'])

    # Visualisation before / after AR (averages are only needed for plotting)
    if verbose:
        for epochs_before, epochs_after in zip(cleaned_epochs_ICA,
                                               cleaned_epochs_AR):
            evoked_before = epochs_before.average()
            evoked_after = epochs_after.average()
            fig, axes = plt.subplots(2, 1, figsize=(6, 6))

            for ax in axes:
                # tick_params expects booleans; the string 'off' is truthy
                ax.tick_params(axis='x', which='both', bottom=False,
                               top=False)
                ax.tick_params(axis='y', which='both', left=False,
                               right=False)

            # NOTE(review): keyed on 'grad' although only EEG channels are
            # picked below - presumably a leftover from an MEG example;
            # confirm the intended limits.
            ylim = dict(grad=(-170, 200))
            evoked_before.pick_types(eeg=True, exclude=[])
            evoked_before.plot(exclude=[], axes=axes[0], ylim=ylim,
                               show=False)
            axes[0].set_title('Before autoreject')
            evoked_after.pick_types(eeg=True, exclude=[])
            # show=False: the figure used to be displayed before the title
            # of the second axis was set, so that title never appeared
            evoked_after.plot(exclude=[], axes=axes[1], ylim=ylim,
                              show=False)
            axes[1].set_title('After autoreject')
            plt.tight_layout()
            plt.show()

    return cleaned_epochs_AR, dic_AR