def perform_reference(self):
    """Estimate the true signal mean and interpolate bad channels.

    This function implements the functionality of the `performReference`
    function as part of the PREP pipeline on mne raw object.

    Returns
    -------
    self
        The instance, to allow method chaining.

    Notes
    -----
    This function calls ``robust_reference`` first.
    Currently this function only implements the functionality of default
    settings, i.e., ``doRobustPost``.

    ``self.EEG`` and the reference signals are kept scaled by ``1e6``
    relative to the data stored in ``self.raw`` (hence the ``1e6`` /
    ``1e-6`` conversions below).
    """
    # Phase 1: Estimate the true signal mean with robust referencing
    self.robust_reference()

    # Interpolate the robust-reference bads on a scratch copy only: the
    # interpolated data is needed solely to estimate the reference, and
    # interpolating self.raw in place here would change it beyond what is
    # later recorded in ``self.interpolated_channels``.
    raw_interp = self.raw.copy()
    if self.noisy_channels["bad_all"]:
        raw_interp.info["bads"] = self.noisy_channels["bad_all"]
        raw_interp.interpolate_bads()
    self.reference_signal = (
        np.nanmean(raw_interp.get_data(picks=self.reference_channels), axis=0)
        * 1e6
    )
    del raw_interp

    rereferenced_index = [
        self.ch_names_eeg.index(ch) for ch in self.rereferenced_channels
    ]
    self.EEG = self.remove_reference(
        self.EEG, self.reference_signal, rereferenced_index
    )

    # Phase 2: Find the bad channels and interpolate
    self.raw._data = self.EEG * 1e-6
    # Pass the configured random_state (as robust_reference does) so that
    # RANSAC-based detection is reproducible.
    noisy_detector = NoisyChannels(self.raw, random_state=self.random_state)
    noisy_detector.find_all_bads(ransac=self.ransac)

    # Record noisy channels and EEG before interpolation
    self.bad_before_interpolation = noisy_detector.get_bads(verbose=True)
    self.EEG_before_interpolation = self.EEG.copy()

    bad_channels = _union(self.bad_before_interpolation, self.unusable_channels)
    self.raw.info["bads"] = bad_channels
    self.raw.interpolate_bads()
    reference_correct = (
        np.nanmean(self.raw.get_data(picks=self.reference_channels), axis=0)
        * 1e6
    )
    self.EEG = self.raw.get_data() * 1e6
    self.EEG = self.remove_reference(
        self.EEG, reference_correct, rereferenced_index
    )
    # reference signal after interpolation
    self.reference_signal_new = self.reference_signal + reference_correct
    # MNE Raw object after interpolation
    self.raw._data = self.EEG * 1e-6

    # Still noisy channels after interpolation
    self.interpolated_channels = bad_channels
    noisy_detector = NoisyChannels(self.raw, random_state=self.random_state)
    noisy_detector.find_all_bads(ransac=self.ransac)
    self.still_noisy_channels = noisy_detector.get_bads()
    self.raw.info["bads"] = self.still_noisy_channels
    return self
def robust_reference(self):
    """Detect bad channels and estimate the robust reference signal.

    This function implements the functionality of the `robustReference`
    function as part of the PREP pipeline on mne raw object. Whether RANSAC
    is used is controlled by ``self.ransac``; the method takes no arguments.

    Returns
    -------
    noisy_channels : dict
        A dictionary of names of noisy channels detected from all methods
        after referencing.
    reference_signal : np.ndarray
        1D estimation of the 'true' signal mean.

    Raises
    ------
    ValueError
        If fewer than two channels remain that are not flagged as bad.

    Notes
    -----
    ``self.reference_channels`` is updated in place to exclude channels that
    are bad by NaN or flat.
    """
    raw = self.raw.copy()
    raw._data = removeTrend(raw.get_data(), sample_rate=self.sfreq)

    # Determine unusable channels and remove them from the reference channels
    noisy_detector = NoisyChannels(
        raw, do_detrend=False, random_state=self.random_state
    )
    noisy_detector.find_all_bads(ransac=self.ransac)
    self.noisy_channels_original = {
        "bad_by_nan": noisy_detector.bad_by_nan,
        "bad_by_flat": noisy_detector.bad_by_flat,
        "bad_by_deviation": noisy_detector.bad_by_deviation,
        "bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
        "bad_by_correlation": noisy_detector.bad_by_correlation,
        "bad_by_ransac": noisy_detector.bad_by_ransac,
        "bad_all": noisy_detector.get_bads(),
    }
    self.noisy_channels = self.noisy_channels_original.copy()
    logger.info("Bad channels: {}".format(self.noisy_channels))

    self.unusable_channels = _union(
        noisy_detector.bad_by_nan, noisy_detector.bad_by_flat
    )
    # According to the Matlab Implementation (see robustReference.m)
    # self.unusable_channels = _union(self.unusable_channels,
    # noisy_detector.bad_by_SNR)
    # but maybe this makes no difference...
    self.reference_channels = _set_diff(
        self.reference_channels, self.unusable_channels
    )

    # Get initial estimate of the reference by the specified method
    signal = raw.get_data() * 1e6
    self.reference_signal = (
        np.nanmedian(raw.get_data(picks=self.reference_channels), axis=0) * 1e6
    )
    reference_index = [
        self.ch_names_eeg.index(ch) for ch in self.reference_channels
    ]
    signal_tmp = self.remove_reference(
        signal, self.reference_signal, reference_index
    )

    # Bad-channel categories accumulated across iterations; each key matches
    # the NoisyChannels attribute of the same name.
    bad_types = (
        "bad_by_nan",
        "bad_by_flat",
        "bad_by_deviation",
        "bad_by_hf_noise",
        "bad_by_correlation",
        "bad_by_ransac",
    )

    # Remove reference from signal, iteratively interpolating bad channels
    raw_tmp = raw.copy()
    iterations = 0
    noisy_channels_old = []
    max_iteration_num = 4
    while True:
        raw_tmp._data = signal_tmp * 1e-6
        # Detrend applied at the beginning of the function.
        noisy_detector = NoisyChannels(
            raw_tmp, do_detrend=False, random_state=self.random_state
        )
        noisy_detector.find_all_bads(ransac=self.ransac)
        for bad_type in bad_types:
            self.noisy_channels[bad_type] = _union(
                self.noisy_channels[bad_type], getattr(noisy_detector, bad_type)
            )
        self.noisy_channels["bad_all"] = _union(
            self.noisy_channels["bad_all"], noisy_detector.get_bads()
        )
        logger.info("Bad channels: {}".format(self.noisy_channels))

        if (
            iterations > 1
            and (
                not self.noisy_channels["bad_all"]
                or set(self.noisy_channels["bad_all"]) == set(noisy_channels_old)
            )
        ) or iterations > max_iteration_num:
            break
        noisy_channels_old = self.noisy_channels["bad_all"].copy()

        if raw_tmp.info["nchan"] - len(self.noisy_channels["bad_all"]) < 2:
            raise ValueError(
                "RobustReference:TooManyBad "
                "Could not perform a robust reference -- not enough good channels"
            )

        # Always estimate the new mean from the un-referenced signal. raw_tmp
        # still holds the re-referenced signal_tmp from the top of this
        # iteration, so it must be restored even when nothing is flagged.
        raw_tmp._data = signal * 1e-6
        if self.noisy_channels["bad_all"]:
            raw_tmp.info["bads"] = self.noisy_channels["bad_all"]
            raw_tmp.interpolate_bads()

        self.reference_signal = (
            np.nanmean(raw_tmp.get_data(picks=self.reference_channels), axis=0)
            * 1e6
        )
        signal_tmp = self.remove_reference(
            signal, self.reference_signal, reference_index
        )
        iterations = iterations + 1
        logger.info("Iterations: {}".format(iterations))

    logger.info("Robust reference done")
    return self.noisy_channels, self.reference_signal
def test_findnoisychannels(raw, montage):
    """Test each NoisyChannels detector on artificially corrupted channels."""
    raw.set_montage(montage)
    nd = NoisyChannels(raw)
    nd.find_all_bads(ransac=True)
    bads = nd.get_bads()
    # remove any noisy channels by interpolating the bads for 10 iterations
    iterations = 10
    for _ in range(iterations):
        raw.info["bads"] = bads
        raw.interpolate_bads()
        nd = NoisyChannels(raw)
        nd.find_all_bads(ransac=True)
        bads = nd.get_bads()

    # make sure no bad channels exist in the data
    raw.drop_channels(ch_names=bads)

    # Test for NaN and flat channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Insert a nan value in one random channel and make a *different* random
    # channel completely flat. Drawing the two indices independently could
    # pick the same channel, whose NaN would then be overwritten by the ones.
    rand_chn_idx1, rand_chn_idx2 = np.random.choice(m, size=2, replace=False)
    rand_chn_lab1 = raw_tmp.ch_names[rand_chn_idx1]
    rand_chn_lab2 = raw_tmp.ch_names[rand_chn_idx2]
    raw_tmp._data[rand_chn_idx1, n - 1] = np.nan
    raw_tmp._data[rand_chn_idx2, :] = np.ones(n)
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_nan_flat()
    assert nd.bad_by_nan == [rand_chn_lab1]
    assert nd.bad_by_flat == [rand_chn_lab2]

    # Test for high and low deviations in EEG data
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Now insert one random channel with very low deviations
    # (scalar randint avoids the deprecated size-1 array -> int conversion)
    rand_chn_idx = np.random.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    raw_tmp._data[rand_chn_idx, :] = raw_tmp._data[rand_chn_idx, :] / 10
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation

    # Inserting one random channel with a high deviation
    raw_tmp = raw.copy()
    rand_chn_idx = np.random.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    arbitrary_scaling = 5
    raw_tmp._data[rand_chn_idx, :] *= arbitrary_scaling
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation

    # Test for correlation between EEG channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = np.random.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Use cosine instead of sine to create a signal
    low = 10
    high = 30
    n_freq = 5
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = np.random.randint(low, high, n)
        signal[0, :] += np.cos(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_correlation()
    assert rand_chn_lab in nd.bad_by_correlation

    # Test for high freq noise detection
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = np.random.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Use freqs between 90 and 100 Hz to insert hf noise
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = np.random.randint(90, 100, n)
        signal[0, :] += np.sin(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_hfnoise()
    assert rand_chn_lab in nd.bad_by_hf_noise

    # Test for signal to noise ratio in EEG data
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = np.random.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # inserting an uncorrelated high frequency (90 Hz) signal in one channel
    raw_tmp[rand_chn_idx, :] = np.sin(2 * np.pi * raw.times * 90) * 1e-6
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_SNR()
    assert rand_chn_lab in nd.bad_by_SNR

    # Test for finding bad channels by RANSAC
    raw_tmp = raw.copy()
    # Ransac identifies channels that go bad together and are highly correlated.
    # Inserting highly correlated signal in channels 0 through 5 at 30 Hz
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp)
    np.random.seed(30)
    nd.find_bad_by_ransac()
    bads = nd.bad_by_ransac
    assert bads == raw_tmp.ch_names[0:6]
def test_findnoisychannels(raw, montage):
    """Test find noisy channels.

    Each NoisyChannels detector is exercised on a copy of ``raw`` in which
    one or more channels have been corrupted in a known way. A single seeded
    RandomState drives both the channel selection and NoisyChannels.
    """
    # Set a random state for the test
    rng = np.random.RandomState(30)
    raw.set_montage(montage)
    nd = NoisyChannels(raw, random_state=rng)
    nd.find_all_bads(ransac=True)
    bads = nd.get_bads()
    # remove any noisy channels by interpolating the bads for 10 iterations
    iterations = 10
    for _ in range(iterations):
        if not bads:
            # Nothing left to interpolate; further iterations would be no-ops.
            break
        raw.info["bads"] = bads
        raw.interpolate_bads()
        nd = NoisyChannels(raw, random_state=rng)
        nd.find_all_bads(ransac=True)
        bads = nd.get_bads()

    # make sure no bad channels exist in the data
    raw.drop_channels(ch_names=bads)

    # Test for NaN and flat channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Insert a nan value for a random channel and make another random channel
    # completely flat (ones)
    idxs = rng.choice(np.arange(m), size=2, replace=False)
    rand_chn_idx1 = idxs[0]
    rand_chn_idx2 = idxs[1]
    rand_chn_lab1 = raw_tmp.ch_names[rand_chn_idx1]
    rand_chn_lab2 = raw_tmp.ch_names[rand_chn_idx2]
    raw_tmp._data[rand_chn_idx1, n - 1] = np.nan
    raw_tmp._data[rand_chn_idx2, :] = np.ones(n)
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_nan_flat()
    assert nd.bad_by_nan == [rand_chn_lab1]
    assert nd.bad_by_flat == [rand_chn_lab2]

    # Test for high and low deviations in EEG data
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Now insert one random channel with very low deviations
    # (scalar randint draws the same value as size=1 but avoids the deprecated
    # conversion of a size-1 array to int)
    rand_chn_idx = rng.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    raw_tmp._data[rand_chn_idx, :] = raw_tmp._data[rand_chn_idx, :] / 10
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation

    # Inserting one random channel with a high deviation
    raw_tmp = raw.copy()
    rand_chn_idx = rng.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    arbitrary_scaling = 5
    raw_tmp._data[rand_chn_idx, :] *= arbitrary_scaling
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation

    # Test for correlation between EEG channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = rng.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Use cosine instead of sine to create a signal
    low = 10
    high = 30
    n_freq = 5
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = rng.randint(low, high, n)
        signal[0, :] += np.cos(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_correlation()
    assert rand_chn_lab in nd.bad_by_correlation

    # Test for high freq noise detection
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = rng.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Use freqs between 90 and 100 Hz to insert hf noise
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = rng.randint(90, 100, n)
        signal[0, :] += np.sin(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_hfnoise()
    assert rand_chn_lab in nd.bad_by_hf_noise

    # Test for signal to noise ratio in EEG data
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = rng.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # inserting an uncorrelated high frequency (90 Hz) signal in one channel
    raw_tmp[rand_chn_idx, :] = np.sin(2 * np.pi * raw.times * 90) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_SNR()
    assert rand_chn_lab in nd.bad_by_SNR

    # Test for finding bad channels by RANSAC
    raw_tmp = raw.copy()
    # Ransac identifies channels that go bad together and are highly correlated.
    # Inserting highly correlated signal in channels 0 through 5 at 30 Hz
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_ransac()
    bads = nd.bad_by_ransac
    assert bads == raw_tmp.ch_names[0:6]

    # Test for finding bad channels by channel-wise RANSAC
    raw_tmp = raw.copy()
    # Ransac identifies channels that go bad together and are highly correlated.
    # Inserting highly correlated signal in channels 0 through 5 at 30 Hz
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_ransac(channel_wise=True)
    bads = nd.bad_by_ransac
    assert bads == raw_tmp.ch_names[0:6]

    # Test not-enough-memory and n_samples type exceptions
    raw_tmp = raw.copy()
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)

    # Set n_samples very very high to trigger a memory error
    n_samples = int(1e100)
    with pytest.raises(MemoryError):
        nd.find_bad_by_ransac(n_samples=n_samples)

    # Set n_samples to a float to trigger a type error
    n_samples = 35.5
    with pytest.raises(TypeError):
        nd.find_bad_by_ransac(n_samples=n_samples)

    # Test IOError when not enough channels for ransac predictions
    raw_tmp = raw.copy()
    # Make flat all channels except 2
    num_bad_channels = raw._data.shape[0] - 2
    raw_tmp._data[0:num_bad_channels, :] = np.zeros_like(
        raw_tmp._data[0:num_bad_channels, :]
    )
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_all_bads(ransac=False)
    with pytest.raises(IOError):
        nd.find_bad_by_ransac()
def perform_reference(self, max_iterations=4):
    """Estimate the true signal mean and interpolate bad channels.

    This function implements the functionality of the `performReference`
    function as part of the PREP pipeline on mne raw object.

    Parameters
    ----------
    max_iterations : int, optional
        The maximum number of iterations of noisy channel removal to perform
        during robust referencing. Defaults to ``4``.

    Returns
    -------
    self
        The instance, to allow method chaining.

    Notes
    -----
    This function calls ``robust_reference`` first.
    Currently this function only implements the functionality of default
    settings, i.e., ``doRobustPost``.
    """

    def interpolate(inst):
        # Interpolate inst.info["bads"], using the EEGLAB-style routine when
        # MATLAB-strict mode is enabled and MNE's implementation otherwise.
        if self.matlab_strict:
            _eeglab_interpolate_bads(inst)
        else:
            inst.interpolate_bads()

    def detect_noisy():
        # Run all bad-channel checks on the current contents of self.raw.
        detector = NoisyChannels(
            self.raw,
            random_state=self.random_state,
            matlab_strict=self.matlab_strict,
        )
        detector.find_all_bads(**self.ransac_settings)
        return detector

    # Phase 1: estimate the true signal mean with robust referencing.
    self.robust_reference(max_iterations)

    # Interpolate a throw-away copy only: interpolating self.raw here would
    # touch more channels than are later recorded as interpolated.
    scratch = self.raw.copy()
    scratch.info["bads"] = self.noisy_channels["bad_all"]
    interpolate(scratch)
    self.reference_signal = np.nanmean(
        scratch.get_data(picks=self.reference_channels), axis=0
    )
    del scratch

    reref_idx = [
        self.ch_names_eeg.index(name) for name in self.rereferenced_channels
    ]
    self.EEG = self.remove_reference(self.EEG, self.reference_signal, reref_idx)

    # Phase 2: find the bad channels of the re-referenced data, interpolate.
    self.raw._data = self.EEG
    before = detect_noisy()

    # Snapshot of the noisy channels and of the EEG before interpolation.
    self.bad_before_interpolation = before.get_bads(verbose=True)
    self.EEG_before_interpolation = self.EEG.copy()
    self.noisy_channels_before_interpolation = before.get_bads(as_dict=True)
    self._extra_info["interpolated"] = before._extra_info

    to_interpolate = _union(self.bad_before_interpolation, self.unusable_channels)
    self.raw.info["bads"] = to_interpolate
    interpolate(self.raw)

    correction = np.nanmean(
        self.raw.get_data(picks=self.reference_channels), axis=0
    )
    self.EEG = self.raw.get_data()
    self.EEG = self.remove_reference(self.EEG, correction, reref_idx)

    # Reference signal after interpolation, and the matching Raw contents.
    self.reference_signal_new = self.reference_signal + correction
    self.raw._data = self.EEG

    # Channels that remain noisy even after interpolation.
    self.interpolated_channels = to_interpolate
    after = detect_noisy()
    self.still_noisy_channels = after.get_bads()
    self.raw.info["bads"] = self.still_noisy_channels
    self.noisy_channels_after_interpolation = after.get_bads(as_dict=True)
    self._extra_info["remaining_bad"] = after._extra_info
    return self
def robust_reference(self, max_iterations=4):
    """Detect bad channels and estimate the robust reference signal.

    This function implements the functionality of the `robustReference`
    function as part of the PREP pipeline on mne raw object.

    Parameters
    ----------
    max_iterations : int, optional
        The maximum number of iterations of noisy channel removal to perform
        during robust referencing. Defaults to ``4``.

    Returns
    -------
    noisy_channels: dict
        A dictionary of names of noisy channels detected from all methods
        after referencing.
    reference_signal: np.ndarray, shape(n, )
        Estimation of the 'true' signal mean

    Raises
    ------
    ValueError
        If fewer than two channels remain that are not flagged as bad.
    """
    raw = self.raw.copy()
    raw._data = removeTrend(
        raw.get_data(), self.sfreq, matlab_strict=self.matlab_strict
    )

    # Determine unusable channels and remove them from the reference channels
    noisy_detector = NoisyChannels(
        raw,
        do_detrend=False,
        random_state=self.random_state,
        matlab_strict=self.matlab_strict,
    )
    noisy_detector.find_all_bads(**self.ransac_settings)
    self.noisy_channels_original = noisy_detector.get_bads(as_dict=True)
    self._extra_info["initial_bad"] = noisy_detector._extra_info
    logger.info("Bad channels: {}".format(self.noisy_channels_original))

    # Determine channels to use/exclude from initial reference estimation
    self.unusable_channels = _union(
        noisy_detector.bad_by_nan + noisy_detector.bad_by_flat,
        noisy_detector.bad_by_SNR,
    )
    reference_channels = _set_diff(self.reference_channels, self.unusable_channels)

    # Initialize channels to permanently flag as bad during referencing
    noisy = {
        "bad_by_nan": noisy_detector.bad_by_nan,
        "bad_by_flat": noisy_detector.bad_by_flat,
        "bad_by_deviation": [],
        "bad_by_hf_noise": [],
        "bad_by_correlation": [],
        "bad_by_SNR": [],
        "bad_by_dropout": [],
        "bad_by_ransac": [],
        "bad_all": [],
    }

    # Get initial estimate of the reference by the specified method
    signal = raw.get_data()
    self.reference_signal = np.nanmedian(
        raw.get_data(picks=reference_channels), axis=0
    )
    reference_index = [
        self.ch_names_eeg.index(ch) for ch in reference_channels
    ]
    signal_tmp = self.remove_reference(
        signal, self.reference_signal, reference_index
    )

    # Specify bad channel types to ignore when updating noisy channels
    # NOTE: MATLAB PREP ignores dropout channels, possibly by mistake?
    # see: https://github.com/VisLab/EEG-Clean-Tools/issues/28
    ignore = ["bad_by_SNR", "bad_all"]
    if self.matlab_strict:
        ignore += ["bad_by_dropout"]

    # Remove reference from signal, iteratively interpolating bad channels
    raw_tmp = raw.copy()
    iterations = 0
    previous_bads = set()

    while True:
        raw_tmp._data = signal_tmp
        noisy_detector = NoisyChannels(
            raw_tmp,
            do_detrend=False,
            random_state=self.random_state,
            matlab_strict=self.matlab_strict,
        )
        # Detrend applied at the beginning of the function.

        # Detect all currently bad channels
        noisy_detector.find_all_bads(**self.ransac_settings)
        noisy_new = noisy_detector.get_bads(as_dict=True)

        # Update set of all noisy channels detected so far with any new ones
        bad_chans = set()
        for bad_type in noisy_new.keys():
            noisy[bad_type] = _union(noisy[bad_type], noisy_new[bad_type])
            if bad_type not in ignore:
                bad_chans.update(noisy[bad_type])

        noisy["bad_all"] = list(bad_chans)
        logger.info("Bad channels: {}".format(noisy))

        if (
            iterations > 1
            and (len(bad_chans) == 0 or bad_chans == previous_bads)
        ) or iterations > max_iterations:
            logger.info("Robust reference done")
            self.noisy_channels = noisy
            break
        previous_bads = bad_chans.copy()

        if raw_tmp.info["nchan"] - len(bad_chans) < 2:
            raise ValueError(
                "RobustReference:TooManyBad "
                "Could not perform a robust reference -- not enough good channels"
            )

        # Always estimate the new mean from the un-referenced signal. raw_tmp
        # still holds the re-referenced signal_tmp from the top of this
        # iteration, so it must be restored even when no channel is flagged
        # (MATLAB robustReference uses ``signalTmp = signal`` in that case).
        raw_tmp._data = signal.copy()
        if len(bad_chans) > 0:
            raw_tmp.info["bads"] = list(bad_chans)
            if self.matlab_strict:
                _eeglab_interpolate_bads(raw_tmp)
            else:
                raw_tmp.interpolate_bads()

        self.reference_signal = np.nanmean(
            raw_tmp.get_data(picks=reference_channels), axis=0
        )
        signal_tmp = self.remove_reference(
            signal, self.reference_signal, reference_index
        )
        iterations = iterations + 1
        logger.info("Iterations: {}".format(iterations))

    return self.noisy_channels, self.reference_signal
def preproc(self, event_dict, baseline_start, stim_dur, montage,
            out_dir='preproc', subjects='all', tasks='all',
            hp_cutoff=1, lp_cutoff=50, line_noise=60,
            seed=42, eog_chan='none'):
    '''Preprocess the EEG recordings of the selected subjects and tasks.

    Assigns a list of epoched data to the Dataset instance
    (``self.epoch_list``), where each entry in the list is a subject with
    concatenated task data. Paths of recordings that could not be loaded are
    stored in ``self.missing_files``.

    Here is the basic structure of the preprocessing workflow:

        - Set the montage
        - Band-pass filter (``hp_cutoff`` to ``lp_cutoff``)
        - Automatically detect bad channels (NaN/flat, deviation, SNR)
        - Notch filter out line-noise (currently disabled)
        - Reference data to average of all good EEG channels
        - Automated removal of eye-related artifacts using ICA
        - Spherical interpolation of detected bad channels
        - Extract events and epoch the data accordingly
        - Group the epochs by event type
        - Create a list of epoched data, with subject as the element
          concatenated across tasks

    Parameters
    ----------
    event_dict: dict
        Maps event labels (used to select epochs by type) to the integer
        event codes passed to ``mne.Epochs`` as ``event_id``
    baseline_start: int or float
        Specify start of the baseline period (in seconds), passed as ``tmin``
    stim_dur: int or float
        Stimulus duration (in seconds), passed as ``tmax``
        Note: may need to make more general to allow various durations
    montage: mne.channels.montage.DigMontage
        Maps sensor locations to coordinates
    out_dir: str
        Directory under which a per-subject folder is created (saving the
        preprocessed data is currently disabled)
    subjects: list or 'all'
        Specify which subjects to iterate through
    tasks: list or 'all'
        Specify which tasks to iterate through
    hp_cutoff: int or float
        The low frequency bound of the band-pass filter in Hz
    lp_cutoff: int or float
        The high frequency bound of the band-pass filter in Hz
    line_noise: int or float
        The frequency of electrical noise to filter out in Hz (unused while
        the notch filter is disabled)
    seed: int
        Set the seed for replicable results
    eog_chan: str
        If there are no EOG channels present, select an EEG channel near the
        eyes for eye-related artifact detection
    '''
    missing = []  # recordings that could not be loaded
    subj_iter = self.gen_iter(subjects, self.n_subj)  # get subject iterator
    # Materialize the task selection once so it is reused for every subject.
    # NOTE(review): if gen_iter returns a one-shot iterator, iterating it
    # directly would leave every subject after the first with no tasks.
    task_list = list(self.gen_iter(tasks, self.n_task))

    # Iterate through subjects (initialize subject epoch list)
    epochs_subj = []
    for subj in subj_iter:
        # Iterate through tasks (initialize within-subject task epoch list)
        epochs_task = []
        for task in task_list:
            # Specify the file format
            self.get_file_format(subj, task)

            try:  # Handles missing files
                raw = self.wget_raw_edf()  # read
            except Exception:
                print(f'---\nThis file does not exist: {self.file_path}\n---')
                missing.append(self.file_path)
                # Skip only this recording; the subject's other tasks may
                # still be available.
                continue

            # Standardize montage
            mne.datasets.eegbci.standardize(raw)
            # Set montage and strip channel names of "." characters
            raw.set_montage(montage)
            raw.rename_channels(lambda x: x.strip('.'))

            # Apply band-pass filter
            np.random.seed(seed)
            raw.filter(l_freq=hp_cutoff, h_freq=lp_cutoff, picks=['eeg', 'eog'])

            # Instantiate NoisyChannels object
            noise_chans = NoisyChannels(raw, do_detrend=False)

            # Detect bad channels through multiple methods
            noise_chans.find_bad_by_nan_flat()
            noise_chans.find_bad_by_deviation()
            noise_chans.find_bad_by_SNR()

            # Set the bad channels in the raw object
            raw.info['bads'] = noise_chans.get_bads()
            print(f'Bad channels detected: {noise_chans.get_bads()}')

            # Notch filter (disabled): frequencies would be line_noise and its
            # harmonics up to Nyquist.
            #notch_filt = np.arange(line_noise, raw.info['sfreq'] // 2, line_noise)
            #print(f'Apply notch filter at {line_noise} Hz and its harmonics')
            #raw.notch_filter(notch_filt, picks=['eeg', 'eog'], verbose=False)

            # Reference to the average of all the good channels
            # Automatically excludes raw.info['bads']
            raw.set_eeg_reference(ref_channels='average')

            # Instantiate ICA object and run ICA
            ica = mne.preprocessing.ICA(max_iter=1000)
            ica.fit(raw)

            # Find which ICs match the EOG pattern
            if eog_chan == 'none':
                eog_indices, _ = ica.find_bads_eog(raw, verbose=False)
            else:
                eog_indices, _ = ica.find_bads_eog(raw, eog_chan, verbose=False)

            # Apply the IC blink removals (if any)
            ica.apply(raw, exclude=eog_indices)
            print(f'Removed IC index {eog_indices}')

            # Interpolate bad channels
            raw.interpolate_bads()

            # Specify pre-processing directory; created if it does not exist.
            preproc_dir = Path(out_dir) / str(subj)
            preproc_dir.mkdir(parents=True, exist_ok=True)

            # raw.save(Path(preproc_dir, f'subj{subj}_task{task}_raw.fif'),
            #          overwrite=True)

            # Find events
            events = mne.events_from_annotations(raw)[0]

            # Epoch the data
            preproc_epoch = mne.Epochs(raw, events, tmin=baseline_start,
                                       tmax=stim_dur, event_id=event_dict,
                                       event_repeated='error',
                                       on_missing='ignore', preload=True)

            # Equalize event counts
            preproc_epoch.equalize_event_counts(event_dict.keys())

            # Rearrange and align the epochs by event type
            align = [preproc_epoch[i] for i in event_dict.keys()]
            align_epoch = mne.concatenate_epochs(align)

            # Add to epoch list
            epochs_task.append(align_epoch)

        # Concatenating an empty list would fail; skip subjects with no data.
        if not epochs_task:
            print(f'No data available for subject {subj}; skipping')
            continue

        # Concatenate epochs within subject
        concat_epoch = mne.concatenate_epochs(epochs_task)
        epochs_subj.append(concat_epoch)

    self.missing_files = missing
    # Attaches a list with each entry corresponding to epochs for a subject
    self.epoch_list = epochs_subj
def perform_reference(self):
    """Estimate the true signal mean and interpolate bad channels.

    This function implements the functionality of the `performReference`
    function as part of the PREP pipeline on mne raw object.

    Returns
    -------
    self
        The instance, to allow method chaining.

    Notes
    -----
    This function calls ``robust_reference`` first.
    Currently this function only implements the functionality of default
    settings, i.e., ``doRobustPost``.

    ``self.EEG`` and the reference signals are scaled by ``1e6`` relative to
    the data stored in ``self.raw``.
    """

    def summarize(detector):
        # Per-category snapshot of the bad channels found by ``detector``.
        return {
            "bad_by_nan": detector.bad_by_nan,
            "bad_by_flat": detector.bad_by_flat,
            "bad_by_deviation": detector.bad_by_deviation,
            "bad_by_hf_noise": detector.bad_by_hf_noise,
            "bad_by_correlation": detector.bad_by_correlation,
            "bad_by_SNR": detector.bad_by_SNR,
            "bad_by_dropout": detector.bad_by_dropout,
            "bad_by_ransac": detector.bad_by_ransac,
            "bad_all": detector.get_bads(),
        }

    def run_detector():
        # Run all bad-channel checks on the current contents of self.raw.
        detector = NoisyChannels(self.raw, random_state=self.random_state)
        detector.find_all_bads(ransac=self.ransac)
        return detector

    # Phase 1: estimate the true signal mean with robust referencing.
    self.robust_reference()

    # Interpolate a throw-away copy only: interpolating self.raw here would
    # touch more channels than are later recorded as interpolated.
    scratch = self.raw.copy()
    scratch.info["bads"] = self.noisy_channels["bad_all"]
    scratch.interpolate_bads()
    mean_signal = np.nanmean(scratch.get_data(picks=self.reference_channels), axis=0)
    self.reference_signal = mean_signal * 1e6
    del scratch

    reref_idx = [
        self.ch_names_eeg.index(name) for name in self.rereferenced_channels
    ]
    self.EEG = self.remove_reference(self.EEG, self.reference_signal, reref_idx)

    # Phase 2: find the bad channels of the re-referenced data, interpolate.
    self.raw._data = self.EEG * 1e-6
    before = run_detector()

    # Snapshot of the noisy channels and of the EEG before interpolation.
    self.bad_before_interpolation = before.get_bads(verbose=True)
    self.EEG_before_interpolation = self.EEG.copy()
    self.noisy_channels_before_interpolation = summarize(before)

    to_interpolate = _union(self.bad_before_interpolation, self.unusable_channels)
    self.raw.info["bads"] = to_interpolate
    self.raw.interpolate_bads()

    mean_correction = np.nanmean(
        self.raw.get_data(picks=self.reference_channels), axis=0
    )
    correction = mean_correction * 1e6
    self.EEG = self.raw.get_data() * 1e6
    self.EEG = self.remove_reference(self.EEG, correction, reref_idx)

    # Reference signal after interpolation, and the matching Raw contents.
    self.reference_signal_new = self.reference_signal + correction
    self.raw._data = self.EEG * 1e-6

    # Channels that remain noisy even after interpolation.
    self.interpolated_channels = to_interpolate
    after = run_detector()
    self.still_noisy_channels = after.get_bads()
    self.raw.info["bads"] = self.still_noisy_channels
    self.noisy_channels_after_interpolation = summarize(after)
    return self