def test_ransac():
    """Check Ransac on a single channel type and its input validation."""
    window_start, window_stop = -0.2, 0.5
    stim_events = mne.find_events(raw)
    epochs = mne.Epochs(
        raw, stim_events, {'Visual/Left': 3}, window_start, window_stop,
        baseline=(None, 0), decim=8, reject=None, preload=True)

    # Normal case: magnetometers only, cleaning must keep every epoch.
    mag_picks = mne.pick_types(epochs.info, meg='mag', eeg=False,
                               stim=False, eog=False, exclude=[])
    ransac = Ransac(picks=mag_picks)
    cleaned = ransac.fit_transform(epochs)
    assert_true(len(cleaned) == len(epochs))

    # A bare numpy array is not accepted in place of an Epochs object.
    data_array = epochs.get_data()
    assert_raises(AttributeError, ransac.fit, data_array)

    # Picks must not mix both MEG channel types.
    all_meg_picks = mne.pick_types(epochs.info, meg=True, eeg=False,
                                   stim=False, eog=False, exclude=[])
    ransac = Ransac(picks=all_meg_picks)
    assert_raises(ValueError, ransac.fit, epochs)

    # Picks must not contain other channel types (here: stim).
    eeg_stim_picks = mne.pick_types(raw.info, meg=False, eeg=True,
                                    stim=True, eog=False, exclude=[])
    ransac = Ransac(picks=eeg_stim_picks)
    assert_raises(ValueError, ransac.fit, epochs)
def test_ransac():
    """Check Ransac input validation, then a clean single-type run."""
    ransac = Ransac()
    stim_events = mne.find_events(raw)
    eeg_subset = [u'EEG %03d' % idx for idx in range(1, 15)]

    def _epochs_for(channel_picks, **extra):
        # Every Epochs object in this test shares these settings.
        return mne.Epochs(raw, stim_events, {'Visual/Left': 3}, -0.2, 0.5,
                          picks=channel_picks, baseline=(None, 0), decim=8,
                          reject=None, add_eeg_ref=False, **extra)

    mixed_picks = mne.pick_types(raw.info, meg=True, eeg=True, stim=False,
                                 eog=False, include=eeg_subset, exclude=[])
    epochs = _epochs_for(mixed_picks)

    # A bare numpy array is rejected.
    data_array = epochs.get_data()
    assert_raises(ValueError, ransac.fit, data_array)

    # MEG and EEG together are rejected.
    assert_raises(ValueError, ransac.fit, epochs)

    # EEG plus stim channels are rejected as well.
    eeg_stim_picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=True,
                                    eog=False, include=eeg_subset,
                                    exclude=[])
    epochs = _epochs_for(eeg_stim_picks)
    assert_raises(ValueError, ransac.fit, epochs)

    # With a single channel type the cleaning runs and keeps every epoch.
    eeg_picks = mne.pick_types(raw.info, meg=False, eeg=True)
    epochs = _epochs_for(eeg_picks, preload=True)
    cleaned = ransac.fit_transform(epochs)
    assert_true(len(cleaned) == len(epochs))
def mne_ransac_bad_channels(raw, overwrite=False):
    """Mark RANSAC-detected bad channels on ``raw``, caching them in a TSV.

    Parameters
    ----------
    raw : object
        Recording exposing ``filenames`` and ``info``; the first filename is
        used to derive both TSV paths.
    overwrite : bool
        When true, RANSAC is rerun even if the cached TSV already exists.

    Returns
    -------
    raw
        The same object, with ``raw.info['bads']`` replaced.
    """
    source_file = raw.filenames[0]
    bids_chan_file = tb.fileparts(source_file, '_channels.tsv', -4)
    ransacfile = tb.fileparts(source_file, '_channels_ransac.tsv')

    if pathlib.Path(ransacfile).is_file() and not overwrite:
        # Cached result: every channel whose status is 'bad' becomes a bad.
        # NOTE(review): the fresh path below sets bads to the RANSAC
        # detections only, so the two paths can differ if the BIDS channel
        # file already lists bad channels -- confirm which is intended.
        cached = pd.read_csv(ransacfile, delimiter='\t')
        raw.info['bads'] = list(cached.loc[cached['status'] == 'bad', 'name'])
        return raw

    epochs = mne_epoch(raw).drop_bad()
    epochs.load_data()

    # Fixed seed so the detection is repeatable.
    detector = Ransac(random_state=999)
    detector.fit(epochs)
    raw.info['bads'] = detector.bad_chs_

    # Flag the detections in the channel table and store it next to the data.
    chans = pd.read_csv(bids_chan_file, delimiter='\t')
    chans.loc[chans.name.isin(detector.bad_chs_), 'status'] = 'bad'
    chans.to_csv(ransacfile, sep='\t')
    return raw
def robust_avg_ref(epochs, ransac_parameters, apply=True):
    """Add a robust average-reference projector to ``epochs``.

    The average reference is computed on a RANSAC-cleaned version of the
    data (bad channels interpolated) so that outlier channels do not leak
    into the reference. The resulting projector is then attached to the
    original ``epochs``.

    Parameters
    ----------
    epochs : object
        Epochs to reference; its ``info["projs"]`` list is extended in place.
    ransac_parameters : dict
        Keyword arguments forwarded to ``Ransac``. Must not contain
        ``verbose``, which is always set to ``"tqdm"`` here.
    apply : bool
        When true, ``epochs.apply_proj()`` is called before returning.

    Returns
    -------
    epochs
        The same object that was passed in.
    """
    ransac = Ransac(**ransac_parameters, verbose="tqdm")

    # The cleaned epochs are only a scratch object for building the
    # projector, so no extra copy of ``epochs`` is made beforehand.
    epochs_interp = ransac.fit_transform(epochs)

    # Remember how many projectors exist already so that only the ones added
    # by set_eeg_reference are transferred, rather than whatever is at index 0.
    n_projs_before = len(epochs_interp.info["projs"])
    # As in the original code, this relies on set_eeg_reference adding the
    # projector to ``epochs_interp`` in place.
    set_eeg_reference(epochs_interp, ref_channels="average", projection=True)
    new_projs = epochs_interp.info["projs"][n_projs_before:]
    del epochs_interp

    epochs.info["projs"].extend(new_projs)
    if apply:
        epochs.apply_proj()
    return epochs
def interpolate_bads(epochs, ransac_parameters):
    """Interpolate RANSAC-detected bad channels and save a before/after plot.

    The evoked response is plotted before and after cleaning, with the
    detected channels marked as bad in both panels, and written to
    ``interpolate_bad_channels.pdf`` inside ``_out_folder``.

    Returns the cleaned epochs produced by ``Ransac.fit_transform``.
    """
    ransac = Ransac(**ransac_parameters, verbose="tqdm")

    evoked_before = epochs.average()
    epochs = ransac.fit_transform(epochs)
    evoked_after = epochs.average()

    fig, (top_axis, bottom_axis) = plt.subplots(2)
    panels = (
        (evoked_before, top_axis, 'Before RANSAC'),
        (evoked_after, bottom_axis, 'After RANSAC'),
    )
    for evoked, axis, title in panels:
        # Mark the detected channels as bad in each panel.
        evoked.info["bads"] = ransac.bad_chs_
        evoked.plot(exclude=[], axes=axis, show=False)
        axis.set_title(title)

    fig.tight_layout()
    fig.savefig(_out_folder / Path("interpolate_bad_channels.pdf"), dpi=800)
    plt.close()
    return epochs
def test_ransac():
    """Check Ransac on a single channel type and its input validation."""
    window_start, window_stop = -0.2, 0.5
    stim_events = mne.find_events(raw)
    epochs = mne.Epochs(
        raw, stim_events, {'Visual/Left': 3}, window_start, window_stop,
        baseline=(None, 0), decim=8, reject=None, preload=True)

    # Normal case: magnetometers only, cleaning must keep every epoch.
    mag_picks = mne.pick_types(epochs.info, meg='mag', eeg=False,
                               stim=False, eog=False, exclude=[])
    ransac = Ransac(picks=mag_picks)
    cleaned = ransac.fit_transform(epochs)
    assert len(cleaned) == len(epochs)

    # A bare numpy array is not accepted in place of an Epochs object.
    data_array = epochs.get_data()
    with pytest.raises(AttributeError):
        ransac.fit(data_array)

    # Picks must not mix both MEG channel types.
    all_meg_picks = mne.pick_types(epochs.info, meg=True, eeg=False,
                                   stim=False, eog=False, exclude=[])
    ransac = Ransac(picks=all_meg_picks)
    with pytest.raises(ValueError):
        ransac.fit(epochs)

    # Picks must not contain other channel types (here: stim).
    eeg_stim_picks = mne.pick_types(raw.info, meg=False, eeg=True,
                                    stim=True, eog=False, exclude=[])
    ransac = Ransac(picks=eeg_stim_picks)
    with pytest.raises(ValueError):
        ransac.fit(epochs)
def applyRansac(self):
    '''
    Detect bad channels with the RAndom SAmple Consensus (RANSAC) method.

    RANSAC is fitted on the EEG channels that are not already marked as
    bad. The detected channels are printed and added to
    self.info['bads']; channels that were marked bad beforehand are kept.

    Returns
    - - - -
    self.info['bads']: list with the previously marked bad channels
    followed by the channels newly detected by the RANSAC algorithm
    '''

    # Only EEG channels that are not yet marked bad are evaluated.
    picks = mne.pick_types(self.info, eeg=True, exclude='bads')

    # Only the detection is needed here, so fit() is used instead of
    # fit_transform(), whose interpolated copy was never used.
    ransac = Ransac(verbose=False, picks=picks, n_jobs=1)
    ransac.fit(self)

    print('The following electrodes are selected as bad by Ransac:')
    print('\n'.join(ransac.bad_chs_))

    # Existing bads were excluded from picks above, so plain assignment
    # would drop them. Append the new detections instead.
    known_bads = list(self.info['bads'])
    new_bads = [ch for ch in ransac.bad_chs_ if ch not in known_bads]
    self.info['bads'] = known_bads + new_bads
    return self.info['bads']
detrend=0, preload=True) picks = mne.pick_types(epochs.info, meg='grad', eeg=False, stim=False, eog=False, include=[], exclude=[]) ############################################################################### # We import ``Ransac`` and run the familiar ``fit_transform`` method. from autoreject import Ransac # noqa from autoreject.utils import interpolate_bads # noqa ransac = Ransac(verbose='progressbar', picks=picks, n_jobs=1) epochs_clean = ransac.fit_transform(epochs) ############################################################################### # We can also get the list of bad channels computed by ``Ransac``. print('\n'.join(ransac.bad_chs_)) ############################################################################### # Then we compute the ``evoked`` before and after interpolation. evoked = epochs.average() evoked_clean = epochs_clean.average() ############################################################################### # We will manually mark the bad channels just for plotting.
# Step 2: Estimate ICA # Apply HPF to all channels and a 60Hz Notch filter to eogs raw.filter(preprocess_options['ica_highpass'], None, skip_by_annotation=['boundary']) raw.filter(None, 40, picks=['eog']) raw.notch_filter([60, 120], picks=['eog']) # Make ICA Epochs epochs = mne.Epochs(raw, events, event_id=event_id, tmin=preprocess_options['tmin'], tmax=preprocess_options['tmax'], baseline=(None, None), reject=None, preload=True) epochs.set_montage(bv_montage) # Autodetect bad channels ransac = Ransac(verbose=False, n_jobs=4, min_corr=.60, unbroken_time=.6) ransac.fit(epochs.copy().filter(None, 40)) if len(ransac.bad_chs_): for chan in ransac.bad_chs_: epochs.info['bads'].append(chan) print(f'RANSAC Bad Channels: {ransac.bad_chs_}') # Save RANSAC ransac_file = deriv_path / f'{sub}_task-{task}_ref_FCz_ransac.pkl' with open(ransac_file, 'wb') as f: pickle.dump(ransac, f) # Make RANSAC json json_info = { 'Description': 'RANSAC object computed from epoched data', 'parameteres': {
def preprocessing(file):
    """Run the SSVEP / FOOOF analysis for one EDF recording.

    Diagnostic plots and a per-participant CSV are written to
    ``FOOOF/<pid>/``, where ``pid`` is taken from the file name.

    Parameters
    ----------
    file : path-like object with a ``name`` attribute
        EDF file to analyse. ``file.name[7:11]`` is used as the participant
        id and ``file.name[3:6]`` as the group label.

    Returns
    -------
    None
        All results are written to disk.
    """
    # Imported locally, as in the original code; needed for LinAlgError below.
    import scipy.linalg

    print("setting default objects")
    # avoid MNE being too verbose
    mne.set_log_level('ERROR')
    use_autoreject = True
    fooof_object = fooof.FOOOF(background_mode='fixed', )

    # Start file by file analysis
    print("reading raw EDF")
    raw = mne.io.read_raw_edf(
        file,
        montage=mne.channels.read_montage('biosemi64'),
        eog=[f'EXG{n}' for n in range(1, 9)])
    raw.info['subject_info'] = {
        'pid': file.name[7:11],
        'group': file.name[3:6],
        'filename': file.name
    }

    # make raw plots
    file_id = raw.info['subject_info']['filename']
    pid = raw.info['subject_info']['pid']
    path = ('FOOOF/%s' % (pid))
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(path, exist_ok=True)
    plot = raw.plot_psd()
    plot.savefig('%s/rawplot_psd_%s.png' % (path, file_id))
    plot.clf()
    plot = raw.plot(duration=30., start=60., block=True)
    plot.savefig('%s/rawplot_%s.png' % (path, file_id))
    plot.clf()

    # Define Events (only event codes below 255 are kept)
    events = mne.find_events(raw)
    events = events[events[:, 2] < 255, :]
    raw.info['events'] = mne.find_events(raw, stim_channel='STI 014')
    epochs = (mne.Epochs(
        raw,
        events,
        tmin=-1,
        tmax=15,
        preload=True,
    ).set_eeg_reference().load_data().resample(256).apply_proj())

    # Evoked Plots
    evoked = epochs.average()
    fig = mne.viz.plot_evoked(evoked, spatial_colors=True, selectable=True)
    fig.savefig('%s/evoked_plot_%s.png' % (path, file_id))
    fig.clf()

    if use_autoreject:
        # assess sensors
        print("Running ransac autoreject")
        picks = mne.pick_types(
            epochs.info,
            eeg=True,
        )
        ransac = Ransac(verbose=False, picks=picks, n_jobs=1)
        epochs = ransac.fit_transform(epochs)

    ssvep = ssvepy.Ssvep(
        epochs,
        [1.2, 6],
        compute_tfr=False,
        fmin=0.5,
        fmax=45,
        noisebandwidth=3,
    )
    # Average the PSD within two groups of epochs, split on whether the
    # epoch coordinate is below 200; the grouping dimension becomes "faces".
    ssvep.psd = (ssvep.psd.groupby(
        ssvep.psd.coords["epoch"] < 200).mean("epoch").rename(
            {"epoch": "faces"}))
    ssvep.snr = ssvep._get_snr(ssvep.psd.coords["frequency"].data)
    # Keep the PSD as it was before the background is subtracted below.
    ssvep.original_psd = ssvep.psd.copy()

    # One NaN-filled array per FOOOF output, on the 1.2 Hz / 6 Hz grid.
    fooof_data = xr.Dataset({
        key: xr.full_like(ssvep.psd.interp(frequency=[1.2, 6]), np.nan)
        for key in ['peak_amp', 'r_squared', 'offset', 'slope']
    })

    for faces_present, cond_data in ssvep.psd.groupby("faces"):
        for channel, data in cond_data.groupby("channel"):
            try:
                print("Running fooof on channel: {}".format(channel))
                fooof_object.fit(data.coords["frequency"].data,
                                 data.data.squeeze())
            except scipy.linalg.LinAlgError:
                # Leave this channel's outputs as NaN and move on.
                print(" Fooof_object failed for %s" % (file_id))
                continue
            # Remove the fitted background (10 ** _bg_fit) from the PSD.
            ssvep.psd.loc[faces_present, channel, :] -= (
                10**fooof_object._bg_fit)
            fooof_data.peak_amp.loc[faces_present, channel, 1.2] = \
                match_peaks(1.2, fooof_object, max_harmonics=3)
            fooof_data.peak_amp.loc[faces_present, channel, 6] = \
                match_peaks(6, fooof_object, max_harmonics=3)
            fooof_data.r_squared.loc[faces_present, channel, :] = \
                fooof_object.r_squared_
            fooof_data.offset.loc[faces_present, channel, :] = \
                fooof_object.background_params_[0]
            fooof_data.slope.loc[faces_present, channel, :] = \
                fooof_object.background_params_[1]

    # average the SNRs over each base frequency and its listed multiples:
    snr_data = xr.concat(
        (ssvep.snr.interp(
            frequency=[1.2, 2.4, 3.6]).mean("frequency").expand_dims(dim={
                "frequency": [1.2]
            }).transpose("faces", "channel", "frequency"),
         ssvep.snr.interp(
             frequency=[6, 12, 18]).mean("frequency").expand_dims(dim={
                 "frequency": [6]
             }).transpose("faces", "channel", "frequency")),
        "frequency")

    # convert to df
    participant_df = (pd.merge(
        snr_data.to_dataframe(name='snr').reset_index(),
        fooof_data.to_dataframe().reset_index(),
        on=['frequency', 'faces', 'channel']).assign(
            participant=raw.info['subject_info']['pid'],
            group=raw.info['subject_info']['group'],
        ))
    participant_df.to_csv('%s/output_%s.csv' % (path, file_id), header=True)
detrend=0, preload=True) picks = mne.pick_types(epochs.info, meg='grad', eeg=False, stim=False, eog=False, include=[], exclude=[]) ############################################################################### # We import ``Ransac`` and run the familiar ``fit_transform`` method. from autoreject import Ransac # noqa from autoreject.utils import interpolate_bads # noqa ransac = Ransac(verbose=True, picks=picks, n_jobs=1) epochs_clean = ransac.fit_transform(epochs) ############################################################################### # We can also get the list of bad channels computed by ``Ransac``. print('\n'.join(ransac.bad_chs_)) ############################################################################### # Then we compute the ``evoked`` before and after interpolation. evoked = epochs.average() evoked_clean = epochs_clean.average() ############################################################################### # We will manually mark the bad channels just for plotting.
raw.info['projs'] = list() # remove proj, don't proj while interpolating epochs = Epochs(raw, events, event_id, tmin, tmax, baseline=(None, 0), reject=None, verbose=False, detrend=0, preload=True) picks = mne.pick_types(epochs.info, meg='grad', eeg=False, stim=False, eog=False, include=[], exclude=[]) ############################################################################### # We import ``Ransac`` and run the familiar ``fit_transform`` method. from autoreject import Ransac # noqa from autoreject.utils import interpolate_bads # noqa ransac = Ransac(verbose='progressbar', picks=picks, n_jobs=1) epochs_clean = ransac.fit_transform(epochs) ############################################################################### # We can also get the list of bad channels computed by ``Ransac``. print('\n'.join(ransac.bad_chs_)) ############################################################################### # Then we compute the ``evoked`` before and after interpolation. evoked = epochs.average() evoked_clean = epochs_clean.average() ############################################################################### # We will manually mark the bad channels just for plotting.
def run_epochs(subject):
    """Run epochs.

    Transform raw data into epochs and match with experimental conditions.
    The raw epochs are saved, then cleaned with RANSAC; an evoked
    before/after figure, a bad-sensor heatmap and the cleaned epochs are
    saved as well.

    Parameters
    ----------
    *subject: string
        The participant reference

    Save the resulting *-epo.fif file in the '3_epochs' directory, together
    with '<subject>-evoked.png', '<subject>-heatmap.png' and
    '<subject>clean-epo.fif'.
    """
    # Load filtered data
    input_path = root + '/2_rawfilter/' + subject + '-raw.fif'
    raw = mne.io.read_raw_fif(input_path)

    # Load e-prime df
    eprime_df = data_path + subject + '/' + subject + fname['eprime']
    eprime = pd.read_csv(eprime_df, skiprows=1, sep='\t')
    eprime = eprime[behav_var]

    # Remove training rows after pause for TNT
    eprime = eprime.drop(eprime.index[[97, 195, 293]])
    eprime.reset_index(inplace=True)

    # Find stim presentation in raw data
    events = mne.find_events(raw, stim_channel='STI 014')

    # Compensate for delay (as measured manually with photodiode)
    events[:, 0] += int(.015 * raw.info['sfreq'])

    # Keep only Obj Pres triggers
    events = events[events[:, 2] == 7, :]

    # Match stim presentation with conditions from eprime df.
    # Row i of the (re-indexed) e-prime table describes event i.
    for i in range(len(events)):
        instruction = eprime['Cond1'][i]
        if instruction == 'Think':
            events[i, 2] = 1 if eprime['Cond2'][i] == 'Emotion' else 2
        elif instruction == 'No-Think':
            events[i, 2] = 3 if eprime['Cond2'][i] == 'Emotion' else 4
        else:
            # Code 5 has no entry in event_id below.
            events[i, 2] = 5

    # Set event id (named event_id to avoid shadowing the builtin ``id``)
    event_id = {'Think/EMO': 1, 'Think/NEU': 2,
                'No-Think/EMO': 3, 'No-Think/NEU': 4}

    # Epoch raw data
    tmin, tmax = -1.5, 4
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax, preload=True)

    # Save epochs
    epochs.save(root + '/3_epochs/' + subject + '-epo.fif')

    epochs.info['projs'] = list()  # remove proj

    ransac = Ransac(verbose='progressbar', n_jobs=1)
    epochs_clean = ransac.fit_transform(epochs)

    evoked = epochs.average().crop(-0.2, 3.0)\
        .apply_baseline(baseline=(None, 0))
    evoked_clean = epochs_clean.average()\
        .crop(-0.2, 3.0).apply_baseline(baseline=(None, 0))

    evoked.info['bads'] = ransac.bad_chs_
    evoked_clean.info['bads'] = ransac.bad_chs_

    # Evoked differences
    fig, axes = plt.subplots(2, 1, figsize=(6, 6))

    for ax in axes:
        # Booleans are used because the 'off' strings are not reliably
        # honoured by matplotlib.
        ax.tick_params(axis='x', which='both', bottom=False, top=False)
        ax.tick_params(axis='y', which='both', left=False, right=False)

    ylim = dict(grad=(-200, 200))
    evoked.plot(exclude=[], axes=axes[0], ylim=ylim, show=False)
    axes[0].set_title('Before RANSAC')
    # show=False on both panels so the figure is not displayed (and possibly
    # closed) before it is written to disk.
    evoked_clean.plot(exclude=[], axes=axes[1], ylim=ylim, show=False)
    axes[1].set_title('After RANSAC')
    fig.tight_layout()
    # Save this figure explicitly instead of whatever is "current".
    fig.savefig(root + '/3_epochs/' + subject + '-evoked.png')
    plt.close(fig)

    # Heatmap: label every 10th picked sensor, starting at the 8th.
    ch_names = [epochs.ch_names[ii] for ii in ransac.picks][7::10]
    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.imshow(ransac.bad_log.T, cmap='Reds',
              interpolation='nearest')
    ax.grid(False)
    ax.set_xlabel('Trials', size=15)
    ax.set_ylabel('Sensors', size=15)
    plt.setp(ax, yticks=range(7, len(ransac.picks), 10),
             yticklabels=ch_names)
    ax.tick_params(axis=u'both', which=u'both', length=0)
    fig.tight_layout(rect=[None, None, None, 1.1])
    ax.set_title('Bad sensors', size=25)
    fig.savefig(root + '/3_epochs/' + subject + '-heatmap.png')
    plt.close(fig)

    # Save epochs
    epochs_clean.save(root + '/3_epochs/' + subject + 'clean-epo.fif')