def perform_reference(self):
    """Robustly re-reference the EEG data and interpolate its bad channels.

    Counterpart of the PREP pipeline's `performReference` function, applied
    to an mne raw object.

    Notes
    -----
    `robust_reference` is always run first. Only the behaviour of the
    default PREP settings (i.e., doRobustPost) is implemented.

    Returns
    -------
    self
        The instance itself, with the referencing results stored as
        attributes.

    """

    def _mean_over_reference_channels():
        # Sample-wise nan-aware mean of the reference channels currently held
        # in self.raw, scaled by 1e6 like every other signal in this method.
        picked = self.raw.get_data(picks=self.reference_channels)
        return np.nanmean(picked, axis=0) * 1e6

    # Phase 1: estimate the true signal mean with robust referencing.
    self.robust_reference()
    flagged = self.noisy_channels["bad_all"]
    if flagged:
        self.raw.info["bads"] = flagged
        self.raw.interpolate_bads()
    self.reference_signal = _mean_over_reference_channels()

    rereferenced_index = []
    for name in self.rereferenced_channels:
        rereferenced_index.append(self.ch_names_eeg.index(name))
    self.EEG = self.remove_reference(
        self.EEG, self.reference_signal, rereferenced_index
    )

    # Phase 2: detect the channels that are bad after referencing and
    # interpolate them.
    self.raw._data = self.EEG * 1e-6
    detector = NoisyChannels(self.raw)
    detector.find_all_bads(ransac=self.ransac)

    # Keep a record of the bad channels and of the EEG prior to interpolation.
    self.bad_before_interpolation = detector.get_bads(verbose=True)
    self.EEG_before_interpolation = self.EEG.copy()

    to_interpolate = _union(
        self.bad_before_interpolation, self.unusable_channels
    )
    self.raw.info["bads"] = to_interpolate
    self.raw.interpolate_bads()

    # Residual reference that remains once the bad channels are interpolated.
    reference_correct = _mean_over_reference_channels()
    self.EEG = self.raw.get_data() * 1e6
    self.EEG = self.remove_reference(
        self.EEG, reference_correct, rereferenced_index
    )

    # Total reference: first estimate plus the post-interpolation correction.
    self.reference_signal_new = self.reference_signal + reference_correct

    # Push the re-referenced, interpolated data back into the raw object.
    self.raw._data = self.EEG * 1e-6

    # Channels that are still flagged as noisy after interpolation.
    self.interpolated_channels = to_interpolate
    detector = NoisyChannels(self.raw)
    detector.find_all_bads(ransac=self.ransac)
    self.still_noisy_channels = detector.get_bads()
    self.raw.info["bads"] = self.still_noisy_channels
    return self
def fit(self):
    """Run all PREP pipeline stages in order and return the instance.

    The stages are: detrending, line-noise removal (with the trend added
    back afterwards) and robust referencing. Results of the referencing
    stage are copied onto this object as attributes.

    Returns
    -------
    self
        The instance itself.

    """
    # Preliminary scan for NaN / flat channels.
    # NOTE(review): the two lists derived here are not used anywhere else in
    # this method -- confirm whether they were meant to feed the referencing.
    detector = NoisyChannels(self.raw, random_state=self.random_state)
    detector.find_bad_by_nan_flat()
    unusable = _union(detector.bad_by_nan, detector.bad_by_flat)
    usable_reference = _set_diff(self.prep_params["ref_chs"], unusable)  # noqa: F841

    # Step 1: 1Hz high pass filtering
    self.EEG_new = removeTrend(self.EEG_raw, sample_rate=self.sfreq)

    # Step 2: removing line noise
    notch_options = {
        "Fs": self.sfreq,
        "freqs": self.prep_params["line_freqs"],
        "method": "spectrum_fit",
        "mt_bandwidth": 2,
        "p_value": 0.01,
    }
    self.EEG_clean = mne.filter.notch_filter(self.EEG_new, **notch_options)

    # Put the trend removed in step 1 back onto the line-noise-free data.
    trend = self.EEG_raw - self.EEG_new
    self.EEG = trend + self.EEG_clean
    self.raw._data = self.EEG * 1e-6

    # Step 3: referencing
    reference = Reference(self.raw, self.prep_params, ransac=self.ransac)
    reference.perform_reference()

    # (attribute on self, attribute on the Reference object it is copied from)
    exported = (
        ("raw", "raw"),
        ("noisy_channels_original", "noisy_channels_original"),
        ("bad_before_interpolation", "bad_before_interpolation"),
        ("EEG_before_interpolation", "EEG_before_interpolation"),
        ("reference_before_interpolation", "reference_signal"),
        ("reference_after_interpolation", "reference_signal_new"),
        ("interpolated_channels", "interpolated_channels"),
        ("still_noisy_channels", "still_noisy_channels"),
    )
    for target, source in exported:
        setattr(self, target, getattr(reference, source))

    return self
def robust_reference(self):
    """Detect bad channels and estimate the robust reference signal.

    This function implements the functionality of the `robustReference`
    function as part of the PREP pipeline on mne raw object.

    The method takes no arguments; whether RANSAC is used is read from
    ``self.ransac``.

    Returns
    -------
    noisy_channels: dictionary
        A dictionary of names of noisy channels detected from all methods
        after referencing
    reference_signal: 1D Array
        Estimation of the 'true' signal mean

    Raises
    ------
    ValueError
        If fewer than two channels remain once all channels flagged as bad
        are excluded.

    """
    raw = self.raw.copy()
    raw._data = removeTrend(raw.get_data(), sample_rate=self.sfreq)

    # Determine unusable channels and remove them from the reference channels
    noisy_detector = NoisyChannels(raw, do_detrend=False)
    noisy_detector.find_all_bads(ransac=self.ransac)

    # Per-criterion keys; each one is also an attribute name on the detector.
    criteria = (
        "bad_by_nan",
        "bad_by_flat",
        "bad_by_deviation",
        "bad_by_hf_noise",
        "bad_by_correlation",
        "bad_by_ransac",
    )
    self.noisy_channels_original = {
        key: getattr(noisy_detector, key) for key in criteria
    }
    self.noisy_channels_original["bad_all"] = noisy_detector.get_bads()
    self.noisy_channels = self.noisy_channels_original.copy()
    logger.info("Bad channels: {}".format(self.noisy_channels))

    self.unusable_channels = _union(
        noisy_detector.bad_by_nan, noisy_detector.bad_by_flat
    )
    # NOTE: channels flagged by SNR are not added to the unusable set here.
    self.reference_channels = _set_diff(
        self.reference_channels, self.unusable_channels
    )

    # Get initial estimate of the reference by the specified method
    signal = raw.get_data() * 1e6
    self.reference_signal = (
        np.nanmedian(raw.get_data(picks=self.reference_channels), axis=0) * 1e6
    )
    reference_index = [
        self.ch_names_eeg.index(ch) for ch in self.reference_channels
    ]
    signal_tmp = self.remove_reference(
        signal, self.reference_signal, reference_index
    )

    # Remove reference from signal, iteratively interpolating bad channels
    raw_tmp = raw.copy()
    iterations = 0
    noisy_channels_old = []
    max_iteration_num = 4

    while True:
        # Look for bad channels in the currently referenced signal.
        raw_tmp._data = signal_tmp * 1e-6
        noisy_detector = NoisyChannels(raw_tmp)
        noisy_detector.find_all_bads(ransac=self.ransac)

        # Accumulate the newly flagged channels onto those found so far.
        for key in criteria:
            self.noisy_channels[key] = _union(
                self.noisy_channels[key], getattr(noisy_detector, key)
            )
        self.noisy_channels["bad_all"] = _union(
            self.noisy_channels["bad_all"], noisy_detector.get_bads()
        )
        logger.info("Bad channels: {}".format(self.noisy_channels))

        bad_all = self.noisy_channels["bad_all"]
        # Stop once the set of bad channels is empty or unchanged (only
        # checked after the first two passes), or when the iteration cap
        # is exceeded.
        settled = iterations > 1 and (
            not bad_all or set(bad_all) == set(noisy_channels_old)
        )
        if settled or iterations > max_iteration_num:
            break
        noisy_channels_old = bad_all.copy()

        if raw_tmp.info["nchan"] - len(bad_all) < 2:
            raise ValueError(
                "RobustReference:TooManyBad "
                "Could not perform a robust reference -- not enough good channels"
            )

        # Always restart from the unreferenced signal. Previously, when no
        # channel was flagged, raw_tmp still held the already-referenced
        # data, so the new reference was estimated from referenced data.
        raw_tmp._data = signal * 1e-6
        if bad_all:
            raw_tmp.info["bads"] = bad_all
            raw_tmp.interpolate_bads()

        self.reference_signal = (
            np.nanmean(raw_tmp.get_data(picks=self.reference_channels), axis=0)
            * 1e6
        )
        signal_tmp = self.remove_reference(
            signal, self.reference_signal, reference_index
        )
        iterations = iterations + 1
        logger.info("Iterations: {}".format(iterations))

    logger.info("Robust reference done")
    return self.noisy_channels, self.reference_signal