def test_highpass():
    """Test for checking high pass filters.

    A 0.1 Hz sine is mixed with an 8 Hz sine and passed through each
    high-pass variant of ``removeTrend`` with a 1 Hz cutoff. Every output
    must reproduce the 8 Hz component with an RMS error below 0.1.
    """
    srate = 100
    t = np.arange(0, 30, 1 / srate)
    lowfreq_signal = np.sin(2 * np.pi * 0.1 * t)
    highfreq_signal = np.sin(2 * np.pi * 8 * t)
    signal = lowfreq_signal + highfreq_signal

    # These outputs are high-passed, so they are named accordingly
    highpass_sinc = removeTrend.removeTrend(
        signal, detrendType="High pass sinc", sample_rate=srate, detrendCutoff=1
    )
    highpass_default = removeTrend.removeTrend(
        signal, detrendType="High pass", sample_rate=srate, detrendCutoff=1
    )
    highpass_strict = removeTrend.removeTrend(
        signal,
        detrendType="High pass",
        sample_rate=srate,
        detrendCutoff=1,
        matlab_strict=True,
    )

    # Each filter should keep the 8 Hz component and drop the 0.1 Hz one
    for label, filtered in (
        ("High pass sinc", highpass_sinc),
        ("High pass", highpass_default),
        ("High pass (matlab_strict)", highpass_strict),
    ):
        rmse = np.sqrt(np.mean((filtered - highfreq_signal) ** 2))
        assert rmse < 0.1, "{}: RMSE {:.3f} is not below 0.1".format(label, rmse)
def test_compare_removeTrend(matprep_artifacts):
    """Check removeTrend against the detrended output of MATLAB PREP.

    The MATLAB raw and detrended EEGLAB files come from the
    ``matprep_artifacts`` fixture. The raw data is high-pass detrended with
    ``matlab_strict=True``, and three 500-sample windows (start, middle and
    end of the recording) are compared with the MATLAB result.
    """
    # Load the MATLAB raw and detrended artifacts
    raw_mat = mne.io.read_raw_eeglab(
        matprep_artifacts["1_matprep_raw"], preload=True
    )
    detrended_mat = mne.io.read_raw_eeglab(
        matprep_artifacts["2_matprep_removetrend"], preload=True
    )
    srate = raw_mat.info["sfreq"]

    expected = detrended_mat._data
    actual = removeTrend(
        raw_mat._data, srate, detrendType="high pass", matlab_strict=True
    )

    # Compare three windows: beginning, middle and end of the recording
    win_size = 500
    mid = int(actual.shape[1] / 2)
    windows = (
        slice(None, win_size),
        slice(mid, mid + win_size),
        slice(-win_size, None),
    )
    for win in windows:
        assert np.allclose(actual[:, win], expected[:, win], equal_nan=True)
def fit(self):
    """Run the whole PREP pipeline.

    1. High-pass filter ``self.EEG_raw`` with ``removeTrend`` and store the
       result as ``self.EEG_new``.
    2. Remove line noise from the filtered data with
       ``mne.filter.notch_filter`` (``self.EEG_clean``). Add the removed
       low-frequency trend back (``self.EEG``) and write the result, scaled
       by 1e-6, into ``self.raw_eeg``.
    3. Robustly re-reference ``self.raw_eeg`` with ``Reference`` and copy its
       results onto this object.

    Steps 1 and 2 are skipped when ``self.prep_params["line_freqs"]`` is
    empty.

    Returns
    -------
    self
        The pipeline object with all result attributes populated.
    """
    line_freqs = self.prep_params["line_freqs"]
    if len(line_freqs) != 0:
        # Step 1: 1Hz high pass filtering
        self.EEG_new = removeTrend(self.EEG_raw, sample_rate=self.sfreq)

        # Step 2: Removing line noise. User-supplied filter_kwargs replace the
        # default spectrum_fit settings entirely (they are not merged).
        if self.filter_kwargs is None:
            notch_kwargs = dict(
                method="spectrum_fit",
                mt_bandwidth=2,
                p_value=0.01,
                filter_length="10s",
            )
        else:
            notch_kwargs = self.filter_kwargs
        self.EEG_clean = mne.filter.notch_filter(
            self.EEG_new,
            Fs=self.sfreq,
            freqs=line_freqs,
            **notch_kwargs
        )

        # Add trend back: EEG_raw - EEG_new is the low-frequency trend, so
        # only the line-noise component ends up removed from the raw data.
        self.EEG = self.EEG_raw - self.EEG_new + self.EEG_clean
        self.raw_eeg._data = self.EEG * 1e-6

    # Step 3: Referencing
    reference = Reference(
        self.raw_eeg,
        self.prep_params,
        ransac=self.ransac,
        random_state=self.random_state,
    )
    reference.perform_reference()

    # Expose the referencing results on the pipeline object
    self.raw_eeg = reference.raw
    self.noisy_channels_original = reference.noisy_channels_original
    self.noisy_channels_before_interpolation = (
        reference.noisy_channels_before_interpolation
    )
    self.noisy_channels_after_interpolation = (
        reference.noisy_channels_after_interpolation
    )
    self.bad_before_interpolation = reference.bad_before_interpolation
    self.EEG_before_interpolation = reference.EEG_before_interpolation
    self.reference_before_interpolation = reference.reference_signal
    self.reference_after_interpolation = reference.reference_signal_new
    self.interpolated_channels = reference.interpolated_channels
    self.still_noisy_channels = reference.still_noisy_channels
    return self
def __init__(self, raw):
    """Initialize the class.

    Parameters
    ----------
    raw : mne.io.BaseRaw
        The MNE raw object. It is copied. The EEG channels of the copy are
        detrended with ``removeTrend`` and stored in ``self.EEGData``.
    """
    # Make sure that we got an MNE object
    assert isinstance(raw, mne.io.BaseRaw)

    self.raw_mne = raw.copy()
    self.sample_rate = raw.info["sfreq"]
    eeg = self.raw_mne.get_data(picks="eeg")
    self.EEGData = removeTrend(eeg, sample_rate=self.sample_rate)
    self.EEGData_beforeFilt = self.EEGData

    # NOTE(review): counts every channel in raw.info, while EEGData holds only
    # the EEG picks -- these agree only for EEG-only recordings.
    names = np.asarray(raw.info["ch_names"])
    self.ch_names_original = names
    self.n_chans_original = len(names)
    self.n_chans_new = self.n_chans_original
    self.signal_len = len(self.raw_mne.times)

    # "new" attributes start out as the very same objects as the originals
    dims = np.shape(self.EEGData)
    self.original_dimensions = dims
    self.new_dimensions = dims
    channel_idx = np.arange(dims[0])
    self.original_channels = channel_idx
    self.new_channels = channel_idx
    self.ch_names_new = names
    self.channels_interpolate = channel_idx

    # One empty list per bad-channel detection method
    for attr in (
        "bad_by_nan",
        "bad_by_flat",
        "bad_by_deviation",
        "bad_by_hf_noise",
        "bad_by_correlation",
        "bad_by_SNR",
        "bad_by_dropout",
        "bad_by_ransac",
    ):
        setattr(self, attr, [])
def __init__(self, raw, do_detrend=True, random_state=None, matlab_strict=False):
    """Initialize the class.

    Parameters
    ----------
    raw : mne.io.BaseRaw
        The MNE raw object. Its data is loaded in place via ``load_data``.
        A copy restricted to EEG channels is kept as ``self.raw_mne``.
    do_detrend : bool
        Whether to run ``removeTrend`` on the EEG copy. Defaults to True.
    random_state : int | None | np.random.RandomState
        Passed to ``check_random_state``. Defaults to None.
    matlab_strict : bool
        Forwarded to ``removeTrend`` and stored on the instance. Defaults to
        False.
    """
    # Make sure that we got an MNE object
    assert isinstance(raw, mne.io.BaseRaw)
    raw.load_data()

    self.raw_mne = raw.copy()
    self.raw_mne.pick_types(eeg=True)
    self.sample_rate = raw.info["sfreq"]
    if do_detrend:
        detrended = removeTrend(
            self.raw_mne.get_data(), self.sample_rate, matlab_strict=matlab_strict
        )
        self.raw_mne._data = detrended
    self.matlab_strict = matlab_strict

    # Extra data for debugging: a separate empty dict per detection method
    self._extra_info = {
        key: {}
        for key in (
            "bad_by_deviation",
            "bad_by_hf_noise",
            "bad_by_correlation",
            "bad_by_dropout",
            "bad_by_ransac",
        )
    }

    self.random_state = check_random_state(random_state)

    # One empty list per bad-channel detection method
    for attr in (
        "bad_by_nan",
        "bad_by_flat",
        "bad_by_deviation",
        "bad_by_hf_noise",
        "bad_by_correlation",
        "bad_by_SNR",
        "bad_by_dropout",
        "bad_by_ransac",
    ):
        setattr(self, attr, [])

    # Original EEG channel names, channel count and number of samples
    names = np.asarray(self.raw_mne.info["ch_names"])
    self.ch_names_original = names
    self.n_chans_original = len(names)
    self.n_samples = raw._data.shape[1]

    # NaN and flat channels are flagged before anything else
    self.find_bad_by_nan_flat()
    unusable = self.bad_by_nan + self.bad_by_flat

    # Restrict the working data to the usable EEG channels
    self.usable_idx = np.isin(names, unusable, invert=True)
    usable_names = names[self.usable_idx]
    self.EEGData = self.raw_mne.get_data(picks=usable_names)
    self.EEGFiltered = None

    self.ch_names_new = np.asarray(usable_names)
    self.n_chans_new = len(self.ch_names_new)
def __init__(self, raw, do_detrend=True, random_state=None):
    """Initialize the class.

    Parameters
    ----------
    raw : mne.io.Raw
        The MNE raw object. It is copied before any processing.
    do_detrend : bool
        Whether or not to remove a trend from the data upon initializing the
        `NoisyChannels` object. Defaults to True.
    random_state : int | None | np.random.RandomState
        The random seed at which to initialize the class. If random_state is
        an int, it will be used as a seed for RandomState. If None, the seed
        will be obtained from the operating system (see RandomState for
        details). Default is None.
    """
    # Make sure that we got an MNE object
    assert isinstance(raw, mne.io.BaseRaw)

    self.raw_mne = raw.copy()
    self.sample_rate = raw.info["sfreq"]
    if do_detrend:
        # Detrends every channel of the copy, not only the EEG channels
        all_data = self.raw_mne.get_data()
        self.raw_mne._data = removeTrend(all_data, sample_rate=self.sample_rate)

    eeg = self.raw_mne.get_data(picks="eeg")
    self.EEGData = eeg
    self.EEGData_beforeFilt = eeg

    # NOTE(review): counts every channel in raw.info, while EEGData holds only
    # the EEG picks -- these agree only for EEG-only recordings.
    names = np.asarray(raw.info["ch_names"])
    self.ch_names_original = names
    self.n_chans_original = len(names)
    self.n_chans_new = self.n_chans_original
    self.signal_len = len(self.raw_mne.times)

    # "new" attributes start out as the very same objects as the originals
    dims = np.shape(self.EEGData)
    self.original_dimensions = dims
    self.new_dimensions = dims
    channel_idx = np.arange(dims[0])
    self.original_channels = channel_idx
    self.new_channels = channel_idx
    self.ch_names_new = names
    self.channels_interpolate = channel_idx

    self.random_state = check_random_state(random_state)

    # One empty list per bad-channel detection method
    for attr in (
        "bad_by_nan",
        "bad_by_flat",
        "bad_by_deviation",
        "bad_by_hf_noise",
        "bad_by_correlation",
        "bad_by_SNR",
        "bad_by_dropout",
        "bad_by_ransac",
    ):
        setattr(self, attr, [])
def raw_clean_detrend(raw_clean):
    """Return a pre-detrended `mne.io.Raw` object with no bad channels.

    The object is a copy of the `raw_clean` fixture (Physionet BCI2000
    dataset, subject 1, run 1) whose data has been passed through
    ``removeTrend`` at the recording's sampling rate. This is only run once
    per session to save time.
    """
    detrended = raw_clean.copy()
    sfreq = raw_clean.info["sfreq"]
    detrended._data = removeTrend(raw_clean.get_data(), sfreq)
    return detrended
def test_detrend():
    """Test for local regression to remove linear trend from EEG data.

    Seeded white noise gets a constant offset and a linear ramp added. The
    "Local detrend" mode of ``removeTrend`` must recover the noise with an
    RMS error below 0.1.
    """
    srate = 100
    t = np.arange(0, 30, 1 / srate)
    randgen = np.random.RandomState(9)
    npoints = len(t)
    signal = randgen.randn(npoints)

    # Offset of 2 plus a ramp rising by 1.5 over the whole recording
    signal_trend = 2 + 1.5 * np.linspace(0, 1, npoints) + signal
    signal_detrend = removeTrend.removeTrend(
        signal_trend, detrendType="Local detrend", sample_rate=srate
    )

    error = signal_detrend - signal
    assert np.sqrt(np.mean(error ** 2)) < 0.1
def robust_reference(self):
    """Detect bad channels and estimate the robust reference signal.

    This function implements the functionality of the `robustReference`
    function as part of the PREP pipeline on mne raw object. Whether RANSAC
    is used for bad-channel detection is controlled by ``self.ransac``.

    Side effects: sets ``self.noisy_channels_original``,
    ``self.noisy_channels``, ``self.unusable_channels`` and
    ``self.reference_signal``. It also removes unusable (NaN/flat) channels
    from ``self.reference_channels``.

    Returns
    -------
    noisy_channels: dictionary
        A dictionary of names of noisy channels detected from all methods
        after referencing
    reference_signal: 1D Array
        Estimation of the 'true' signal mean

    Raises
    ------
    ValueError
        If fewer than two channels remain that are not flagged as bad.
    """
    raw = self.raw.copy()
    raw._data = removeTrend(raw.get_data(), sample_rate=self.sfreq)

    # Determine unusable channels and remove them from the reference channels
    noisy_detector = NoisyChannels(
        raw, do_detrend=False, random_state=self.random_state
    )
    noisy_detector.find_all_bads(ransac=self.ransac)
    self.noisy_channels_original = {
        "bad_by_nan": noisy_detector.bad_by_nan,
        "bad_by_flat": noisy_detector.bad_by_flat,
        "bad_by_deviation": noisy_detector.bad_by_deviation,
        "bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
        "bad_by_correlation": noisy_detector.bad_by_correlation,
        "bad_by_ransac": noisy_detector.bad_by_ransac,
        "bad_all": noisy_detector.get_bads(),
    }
    self.noisy_channels = self.noisy_channels_original.copy()
    logger.info("Bad channels: {}".format(self.noisy_channels))

    self.unusable_channels = _union(
        noisy_detector.bad_by_nan, noisy_detector.bad_by_flat
    )
    # NOTE: MATLAB's robustReference.m also treats bad-by-SNR channels as
    # unusable here; this implementation does not.
    self.reference_channels = _set_diff(
        self.reference_channels, self.unusable_channels
    )

    # Initial reference estimate: channel-wise median of the reference
    # channels. Data is scaled by 1e6 here (presumably V -> uV) and by 1e-6
    # whenever it is written back into an MNE object.
    signal = raw.get_data() * 1e6
    self.reference_signal = (
        np.nanmedian(raw.get_data(picks=self.reference_channels), axis=0) * 1e6
    )
    reference_index = [
        self.ch_names_eeg.index(ch) for ch in self.reference_channels
    ]
    signal_tmp = self.remove_reference(
        signal, self.reference_signal, reference_index
    )

    # Bad-channel categories accumulated across iterations ("bad_all" is
    # handled separately because it comes from get_bads()).
    bad_types = (
        "bad_by_nan",
        "bad_by_flat",
        "bad_by_deviation",
        "bad_by_hf_noise",
        "bad_by_correlation",
        "bad_by_ransac",
    )

    # Remove reference from signal, iteratively interpolating bad channels
    raw_tmp = raw.copy()
    iterations = 0
    noisy_channels_old = []
    max_iteration_num = 4

    while True:
        raw_tmp._data = signal_tmp * 1e-6
        # Detrend applied at the beginning of the function.
        noisy_detector = NoisyChannels(
            raw_tmp, do_detrend=False, random_state=self.random_state
        )
        noisy_detector.find_all_bads(ransac=self.ransac)

        # Union newly detected bad channels into the running totals
        for bad_type in bad_types:
            self.noisy_channels[bad_type] = _union(
                self.noisy_channels[bad_type], getattr(noisy_detector, bad_type)
            )
        self.noisy_channels["bad_all"] = _union(
            self.noisy_channels["bad_all"], noisy_detector.get_bads()
        )
        logger.info("Bad channels: {}".format(self.noisy_channels))

        # Stop once (after at least two passes) no channels are bad or the
        # bad set stopped changing, or when the iteration cap is exceeded.
        bad_all = self.noisy_channels["bad_all"]
        converged = not bad_all or set(bad_all) == set(noisy_channels_old)
        if (iterations > 1 and converged) or iterations > max_iteration_num:
            break
        noisy_channels_old = bad_all.copy()

        if raw_tmp.info["nchan"] - len(bad_all) < 2:
            raise ValueError(
                "RobustReference:TooManyBad "
                "Could not perform a robust reference -- not enough good channels"
            )

        # Interpolate bad channels from the unreferenced signal
        if bad_all:
            raw_tmp._data = signal * 1e-6
            raw_tmp.info["bads"] = bad_all
            raw_tmp.interpolate_bads()

        # Re-estimate the reference as the mean of the reference channels
        self.reference_signal = (
            np.nanmean(raw_tmp.get_data(picks=self.reference_channels), axis=0)
            * 1e6
        )
        signal_tmp = self.remove_reference(
            signal, self.reference_signal, reference_index
        )
        iterations = iterations + 1
        logger.info("Iterations: {}".format(iterations))

    logger.info("Robust reference done")
    return self.noisy_channels, self.reference_signal
def preprocess_eeg(id_num, random_seed=None):
    """Run the full EEG preprocessing pipeline for one participant.

    Stages:
    1. Load the BIDS recording and attach a standard 10-05 montage.
    2. Fix duplicated and mislabeled trigger annotations.
    3. Remove line noise (CleanLine-style ``spectrum_fit``).
    4. Re-reference robustly with PREP.
    5. Apply a 1-50 Hz band-pass filter.
    6. Remove blink components with ICA.
    7. Optionally compute CSD estimates.
    8. Write the result to EDF.

    Diagnostic PSD and channel plots are saved after each stage.

    Relies on module-level settings defined elsewhere in this file:
    ``task``, ``datatype``, ``bids_root``, ``plotdir``, ``interpolate_bads``,
    ``max_interpolated``, ``perform_csd``, ``noisy_bad_dir``, ``outdir``,
    ``outfile_fmt`` and ``ica_err_dir``.

    Parameters
    ----------
    id_num : str
        Participant ID, passed as the subject to ``BIDSPath`` and used in
        output names.
    random_seed : int or None
        Seed for ``random``, PREP's RANSAC and ICA. If None, a random
        32-bit seed is drawn from ``os.urandom``.

    Returns
    -------
    id_info : dict
        Per-participant summary: seed, annotation fixes, trigger counts and
        bad/interpolated channels. Channel fields are "NA" for incomplete
        recordings.
    """
    # Set important variables
    bids_path = BIDSPath(id_num, task=task, datatype=datatype, root=bids_root)
    plot_path = os.path.join(plotdir, "sub_{0}".format(id_num))
    if os.path.exists(plot_path):
        shutil.rmtree(plot_path)
    os.makedirs(plot_path)
    # Only draw a fresh seed when none was given (0 is a valid seed)
    if random_seed is None:
        random_seed = int(binascii.b2a_hex(os.urandom(4)), 16)
    random.seed(random_seed)
    id_info = {"id": id_num, "random_seed": random_seed}

    ### Load and prepare EEG data #############################################

    header = "### Processing sub-{0} (seed: {1}) ###".format(
        id_num, random_seed)
    print("\n" + "#" * len(header))
    print(header)
    print("#" * len(header) + "\n")

    # Load EEG data
    raw = read_raw_bids(bids_path, verbose=True)

    # Check if recording is complete
    complete = len(raw.annotations) >= 600

    # Add a montage to the data
    montage_kind = "standard_1005"
    montage = mne.channels.make_standard_montage(montage_kind)
    mne.datasets.eegbci.standardize(raw)
    raw.set_montage(montage)

    # Extract some info
    eeg_index = mne.pick_types(raw.info, eeg=True, eog=False, meg=False)
    ch_names = raw.info["ch_names"]
    ch_names_eeg = list(np.asarray(ch_names)[eeg_index])
    sample_rate = raw.info["sfreq"]

    # Make a copy of the data
    raw_copy = raw.copy()
    raw_copy.load_data()

    # Trim duplicated data (only needed for sub-005)
    annot = raw_copy.annotations
    file_starts = [a for a in annot if a['description'] == "file start"]
    if len(file_starts):
        duplicate_start = file_starts[0]['onset']
        raw_copy.crop(tmax=duplicate_start)

    # Make backup of EOG and EMG channels to re-append after PREP
    raw_other = raw_copy.copy()
    raw_other.pick_types(eog=True, emg=True, stim=False)

    # Prepare copy of raw data for PREP
    raw_copy.pick_types(eeg=True)

    # Plot data prior to any processing
    if complete:
        save_psd_plot(id_num, "psd_0_raw", plot_path, raw_copy)
        save_channel_plot(id_num, "ch_0_raw", plot_path, raw_copy)

    ### Clean up events #######################################################

    print("\n\n=== Processing Event Annotations... ===\n")

    event_names = [
        "stim_on", "red_on", "trace_start", "trace_end", "accuracy_submit",
        "vividness_submit"
    ]
    doubled = set()  # indices of duplicate triggers to drop
    wrong_label = set()  # indices of mislabeled triggers to rename
    new_onsets = []
    new_durations = []
    new_descriptions = []

    # Find and flag any duplicate triggers. The loop must reach the final
    # annotation so that the on_last branch can actually run.
    annot = raw_copy.annotations
    trigger_count = len(annot)
    for i in range(1, trigger_count):
        a = annot[i]
        on_last = i + 1 == trigger_count
        prev_trigger = annot[i - 1]['description']
        next_onset = annot[i + 1]['onset'] if not on_last else a['onset'] + 100
        # Determine whether duplicates are doubles or mislabeled
        if a['description'] == prev_trigger:
            if (next_onset - a['onset']) < 0.002:
                doubled.add(i)
            else:
                wrong_label.add(i)

    # Rename annotations to have meaningful names & fix duplicates
    for i, a in enumerate(annot):
        if i in doubled or a['description'] not in event_names:
            continue
        if i in wrong_label:
            # Relabel as the event that follows it in the trial sequence.
            # NOTE(review): a mislabeled "vividness_submit" (last name) would
            # raise IndexError here -- confirm the intended handling.
            index = event_names.index(a['description'])
            a['description'] = event_names[index + 1]
        new_onsets.append(a['onset'])
        new_durations.append(a['duration'])
        new_descriptions.append(a['description'])

    # Replace old annotations with new fixed ones
    if len(annot):
        new_annot = mne.Annotations(
            new_onsets, new_durations, new_descriptions,
            orig_time=raw_copy.annotations[0]['orig_time'])
        raw_copy.set_annotations(new_annot)

    # Check annotations to verify we have equal numbers of each
    orig_counts = Counter(annot.description)
    counts = Counter(raw_copy.annotations.description)
    print("Updated Annotation Counts:")
    for a in event_names:
        out = " - '{0}': {1} -> {2}"
        print(out.format(a, orig_counts[a], counts[a]))

    # Get info
    id_info['annot_doubled'] = len(doubled)
    id_info['annot_wrong'] = len(wrong_label)
    # Every trigger type except vividness ratings should occur equally often.
    # Missing types count as 0, so they also make the counts unequal.
    count_vals = [
        counts[name] for name in event_names if name != 'vividness_submit'
    ]
    id_info['equal_triggers'] = all(x == count_vals[0] for x in count_vals)
    id_info['stim_on'] = counts['stim_on']
    id_info['red_on'] = counts['red_on']
    id_info['trace_start'] = counts['trace_start']
    id_info['trace_end'] = counts['trace_end']
    id_info['acc_submit'] = counts['accuracy_submit']
    id_info['vivid_submit'] = counts['vividness_submit']

    if not complete:
        remaining_info = {
            'initial_bad': "NA",
            'num_initial_bad': "NA",
            'interpolated': "NA",
            'num_interpolated': "NA",
            'remaining_bad': "NA",
            'num_remaining_bad': "NA"
        }
        id_info.update(remaining_info)
        e = "\n\n### Incomplete recording for sub-{0}, skipping... ###\n\n"
        print(e.format(id_num))
        return id_info

    ### Run components of PREP manually #######################################

    print("\n\n=== Performing CleanLine... ===")

    # Try to remove line noise using CleanLine approach: filter the
    # high-passed data, then add the removed low-frequency trend back.
    linenoise = np.arange(60, sample_rate / 2, 60)
    EEG_raw = raw_copy.get_data() * 1e6
    EEG_new = removeTrend(EEG_raw, sample_rate=raw.info["sfreq"])
    EEG_clean = mne.filter.notch_filter(
        EEG_new,
        Fs=raw.info["sfreq"],
        freqs=linenoise,
        filter_length="10s",
        method="spectrum_fit",
        mt_bandwidth=2,
        p_value=0.01,
    )
    EEG_final = EEG_raw - EEG_new + EEG_clean
    raw_copy._data = EEG_final * 1e-6
    del linenoise, EEG_raw, EEG_new, EEG_clean, EEG_final

    # Plot data following cleanline
    save_psd_plot(id_num, "psd_1_cleanline", plot_path, raw_copy)
    save_channel_plot(id_num, "ch_1_cleanline", plot_path, raw_copy)

    # Perform robust re-referencing
    prep_params = {"ref_chs": ch_names_eeg, "reref_chs": ch_names_eeg}
    reference = Reference(raw_copy,
                          prep_params,
                          ransac=True,
                          random_state=random_seed)
    print("\n\n=== Performing Robust Re-referencing... ===\n")
    reference.perform_reference()

    # If not interpolating bad channels, use pre-interpolation channel data
    if not interpolate_bads:
        reference.raw._data = reference.EEG_before_interpolation * 1e-6
        reference.interpolated_channels = []
        reference.still_noisy_channels = reference.bad_before_interpolation
        reference.raw.info["bads"] = reference.bad_before_interpolation

    # Plot data following robust re-reference
    save_psd_plot(id_num, "psd_2_reref", plot_path, reference.raw)
    save_channel_plot(id_num, "ch_2_reref", plot_path, reference.raw)

    # Re-append removed EMG/EOG/trigger channels
    raw_prepped = reference.raw.add_channels([raw_other])

    # Get info
    initial_bad = reference.noisy_channels_original["bad_all"]
    id_info['initial_bad'] = " ".join(initial_bad)
    id_info['num_initial_bad'] = len(initial_bad)

    interpolated = reference.interpolated_channels
    id_info['interpolated'] = " ".join(interpolated)
    id_info['num_interpolated'] = len(interpolated)

    remaining_bad = reference.still_noisy_channels
    id_info['remaining_bad'] = " ".join(remaining_bad)
    id_info['num_remaining_bad'] = len(remaining_bad)

    # Print re-referencing info
    print("\nRe-Referencing Info:")
    print(" - Bad channels original: {0}".format(initial_bad))
    if interpolate_bads:
        print(" - Bad channels after re-referencing: {0}".format(interpolated))
        print(" - Bad channels after interpolation: {0}".format(remaining_bad))
    else:
        print(
            " - Bad channels after re-referencing: {0}".format(remaining_bad))

    # Check if too many channels were interpolated for the participant
    prop_interpolated = len(
        reference.interpolated_channels) / len(ch_names_eeg)
    e = "### NOTE: Too many interpolated channels for sub-{0} ({1}) ###"
    if max_interpolated < prop_interpolated:
        print("\n")
        print(e.format(id_num, len(reference.interpolated_channels)))
        print("\n")

    ### Filter data and apply ICA to remove blinks ############################

    # Apply highpass & lowpass filters
    print("\n\n=== Applying Highpass & Lowpass Filters... ===")
    raw_prepped.filter(1.0, 50.0, fir_design='firwin')

    # Plot data following frequency filters
    save_psd_plot(id_num, "psd_3_filtered", plot_path, raw_prepped)
    save_channel_plot(id_num, "ch_3_filtered", plot_path, raw_prepped)

    # Perform ICA using EOG data on eye blinks
    print("\n\n=== Removing Blinks Using ICA... ===\n")
    ica = ICA(n_components=20, random_state=random_seed, method='picard')
    ica.fit(raw_prepped, decim=5)
    eog_indices, eog_scores = ica.find_bads_eog(raw_prepped)
    ica.exclude = eog_indices

    if not len(ica.exclude):
        err = " - Encountered an ICA error for sub-{0}, skipping for now..."
        print("\n")
        print(err.format(id_num))
        print("\n")
        save_bad_fif(raw_prepped, id_num, ica_err_dir)
        return id_info

    # Plot ICA info & diagnostics before removing from signal
    save_ica_plots(id_num, plot_path, raw_prepped, ica, eog_scores)

    # Remove eye blink independent components based on ICA
    ica.apply(raw_prepped)

    # Plot data following ICA
    save_psd_plot(id_num, "psd_4_ica", plot_path, raw_prepped)
    save_channel_plot(id_num, "ch_4_ica", plot_path, raw_prepped)

    ### Compute Current Source Density (CSD) estimates ########################

    if perform_csd:
        print("\n")
        print("=== Computing Current Source Density (CSD) Estimates... ===\n")
        raw_prepped = mne.preprocessing.compute_current_source_density(
            raw_prepped.drop_channels(remaining_bad))

        # Plot data following CSD
        save_psd_plot(id_num, "psd_5_csd", plot_path, raw_prepped)
        save_channel_plot(id_num, "ch_5_csd", plot_path, raw_prepped)

    ### Write preprocessed data to new EDF ####################################

    if max_interpolated < prop_interpolated:
        if not os.path.isdir(noisy_bad_dir):
            os.makedirs(noisy_bad_dir)
        outpath = os.path.join(noisy_bad_dir, outfile_fmt.format(id_num))
    else:
        outpath = os.path.join(outdir, outfile_fmt.format(id_num))
    write_mne_edf(outpath, raw_prepped)

    print("\n\n### sub-{0} complete! ###\n\n".format(id_num))

    return id_info
def robust_reference(self, max_iterations=4):
    """Detect bad channels and estimate the robust reference signal.

    This function implements the functionality of the `robustReference`
    function as part of the PREP pipeline on mne raw object.

    Parameters
    ----------
    max_iterations : int, optional
        The maximum number of iterations of noisy channel removal to perform
        during robust referencing. Defaults to ``4``.

    Returns
    -------
    noisy_channels: dict
        A dictionary of names of noisy channels detected from all methods
        after referencing.
    reference_signal: np.ndarray, shape(n, )
        Estimation of the 'true' signal mean

    Raises
    ------
    ValueError
        If fewer than two channels remain that are not flagged as bad.
    """
    raw = self.raw.copy()
    raw._data = removeTrend(raw.get_data(), self.sfreq,
                            matlab_strict=self.matlab_strict)

    # Determine unusable channels and remove them from the reference channels
    noisy_detector = NoisyChannels(
        raw,
        do_detrend=False,
        random_state=self.random_state,
        matlab_strict=self.matlab_strict,
    )
    noisy_detector.find_all_bads(**self.ransac_settings)
    self.noisy_channels_original = noisy_detector.get_bads(as_dict=True)
    self._extra_info["initial_bad"] = noisy_detector._extra_info
    logger.info("Bad channels: {}".format(self.noisy_channels_original))

    # Determine channels to use/exclude from initial reference estimation
    self.unusable_channels = _union(
        noisy_detector.bad_by_nan + noisy_detector.bad_by_flat,
        noisy_detector.bad_by_SNR,
    )
    reference_channels = _set_diff(self.reference_channels,
                                   self.unusable_channels)

    # Initialize channels to permanently flag as bad during referencing
    noisy = {
        "bad_by_nan": noisy_detector.bad_by_nan,
        "bad_by_flat": noisy_detector.bad_by_flat,
        "bad_by_deviation": [],
        "bad_by_hf_noise": [],
        "bad_by_correlation": [],
        "bad_by_SNR": [],
        "bad_by_dropout": [],
        "bad_by_ransac": [],
        "bad_all": [],
    }

    # Get initial estimate of the reference by the specified method
    signal = raw.get_data()
    self.reference_signal = np.nanmedian(
        raw.get_data(picks=reference_channels), axis=0)
    reference_index = [
        self.ch_names_eeg.index(ch) for ch in reference_channels
    ]
    signal_tmp = self.remove_reference(signal, self.reference_signal,
                                       reference_index)

    # Bad channel types to ignore when building the set of channels to
    # interpolate. This does not change between iterations, so it is built
    # once here.
    # NOTE: MATLAB PREP ignores dropout channels, possibly by mistake?
    # see: https://github.com/VisLab/EEG-Clean-Tools/issues/28
    ignore = {"bad_by_SNR", "bad_all"}
    if self.matlab_strict:
        ignore.add("bad_by_dropout")

    # Remove reference from signal, iteratively interpolating bad channels
    raw_tmp = raw.copy()
    iterations = 0
    previous_bads = set()

    while True:
        raw_tmp._data = signal_tmp
        noisy_detector = NoisyChannels(
            raw_tmp,
            do_detrend=False,
            random_state=self.random_state,
            matlab_strict=self.matlab_strict,
        )
        # Detrend applied at the beginning of the function.

        # Detect all currently bad channels
        noisy_detector.find_all_bads(**self.ransac_settings)
        noisy_new = noisy_detector.get_bads(as_dict=True)

        # Update set of all noisy channels detected so far with any new ones
        bad_chans = set()
        for bad_type in noisy_new.keys():
            noisy[bad_type] = _union(noisy[bad_type], noisy_new[bad_type])
            if bad_type not in ignore:
                bad_chans.update(noisy[bad_type])
        noisy["bad_all"] = list(bad_chans)
        logger.info("Bad channels: {}".format(noisy))

        # Stop once (after at least two passes) no channels are bad or the
        # bad set stopped changing, or when the iteration cap is exceeded.
        converged = len(bad_chans) == 0 or bad_chans == previous_bads
        if (iterations > 1 and converged) or iterations > max_iterations:
            logger.info("Robust reference done")
            self.noisy_channels = noisy
            break
        previous_bads = bad_chans.copy()

        if raw_tmp.info["nchan"] - len(bad_chans) < 2:
            raise ValueError(
                "RobustReference:TooManyBad "
                "Could not perform a robust reference -- not enough good channels"
            )

        # Interpolate bad channels from the unreferenced signal
        if len(bad_chans) > 0:
            raw_tmp._data = signal.copy()
            raw_tmp.info["bads"] = list(bad_chans)
            if self.matlab_strict:
                _eeglab_interpolate_bads(raw_tmp)
            else:
                raw_tmp.interpolate_bads()

        # Re-estimate the reference as the mean of the reference channels
        self.reference_signal = np.nanmean(
            raw_tmp.get_data(picks=reference_channels), axis=0)
        signal_tmp = self.remove_reference(signal, self.reference_signal,
                                           reference_index)
        iterations = iterations + 1
        logger.info("Iterations: {}".format(iterations))

    return self.noisy_channels, self.reference_signal