Beispiel #1
0
    def fit(self):
        """Run the whole PREP pipeline.

        If line-noise frequencies are configured, the signal is detrended, the
        line noise is removed from the detrended signal and the trend is added
        back. The data are then referenced through ``Reference``, which also
        detects and interpolates bad channels; its results are copied onto
        this instance.

        Returns
        -------
        self
            The instance, with the pipeline results stored as attributes
            (e.g. ``raw_eeg``, ``interpolated_channels``,
            ``still_noisy_channels``).

        """
        # Steps 1-2 only run when line-noise frequencies were given. len() is
        # used rather than truthiness because ``line_freqs`` may be a NumPy
        # array, whose truth value is ambiguous for more than one element.
        line_freqs = self.prep_params["line_freqs"]
        if len(line_freqs) != 0:
            # Step 1: 1Hz high pass filtering (trend removal)
            self.EEG_new = removeTrend(self.EEG_raw, sample_rate=self.sfreq)

            # Step 2: Removing line noise. User-supplied ``filter_kwargs``
            # replace the default spectrum-fit settings entirely.
            if self.filter_kwargs is None:
                notch_kwargs = dict(
                    method="spectrum_fit",
                    mt_bandwidth=2,
                    p_value=0.01,
                    filter_length="10s",
                )
            else:
                notch_kwargs = self.filter_kwargs
            self.EEG_clean = mne.filter.notch_filter(
                self.EEG_new,
                Fs=self.sfreq,
                freqs=line_freqs,
                **notch_kwargs,
            )

            # Add trend back: (EEG_raw - EEG_new) is the removed trend, so the
            # result is the original signal minus what the notch filter removed.
            self.EEG = self.EEG_raw - self.EEG_new + self.EEG_clean
            # NOTE(review): the 1e-6 factor presumably converts EEG
            # (microvolts) back to the units stored in raw_eeg — confirm.
            self.raw_eeg._data = self.EEG * 1e-6

        # Step 3: Referencing (robust reference, bad-channel detection and
        # interpolation are all performed by ``Reference``)
        reference = Reference(
            self.raw_eeg,
            self.prep_params,
            ransac=self.ransac,
            random_state=self.random_state,
        )
        reference.perform_reference()

        # Copy the referencing results onto the pipeline object
        self.raw_eeg = reference.raw
        self.noisy_channels_original = reference.noisy_channels_original
        self.noisy_channels_before_interpolation = (
            reference.noisy_channels_before_interpolation
        )
        self.noisy_channels_after_interpolation = (
            reference.noisy_channels_after_interpolation
        )
        self.bad_before_interpolation = reference.bad_before_interpolation
        self.EEG_before_interpolation = reference.EEG_before_interpolation
        self.reference_before_interpolation = reference.reference_signal
        self.reference_after_interpolation = reference.reference_signal_new
        self.interpolated_channels = reference.interpolated_channels
        self.still_noisy_channels = reference.still_noisy_channels

        return self
Beispiel #2
0
    def perform_reference(self):
        """Estimate the true signal mean and interpolate bad channels.

        This function implements the functionality of the `performReference` function
        as part of the PREP pipeline on mne raw object.

        Returns
        -------
        self
            The instance, with re-referenced/interpolated data in ``raw`` and
            ``EEG`` and the bad-channel bookkeeping attributes populated.

        Notes
        -----
        This function calls ``robust_reference`` first.
        Currently this function only implements the functionality of default
        settings, i.e., ``doRobustPost``.

        """
        # Phase 1: Estimate the true signal mean with robust referencing
        self.robust_reference()
        # NOTE: this interpolates self.raw in place; its data are overwritten
        # with self.EEG at the start of Phase 2, so only the reference estimate
        # computed next depends on this interpolation.
        if self.noisy_channels["bad_all"]:
            self.raw.info["bads"] = self.noisy_channels["bad_all"]
            self.raw.interpolate_bads()
        # The 1e6 / 1e-6 factors convert between raw's stored units and the
        # units of self.EEG (presumably volts vs. microvolts — confirm).
        self.reference_signal = (np.nanmean(
            self.raw.get_data(picks=self.reference_channels), axis=0) * 1e6)
        rereferenced_index = [
            self.ch_names_eeg.index(ch) for ch in self.rereferenced_channels
        ]
        self.EEG = self.remove_reference(self.EEG, self.reference_signal,
                                         rereferenced_index)

        # Phase 2: Find the bad channels and interpolate
        self.raw._data = self.EEG * 1e-6
        # NOTE(review): no random_state is passed to NoisyChannels here, so
        # RANSAC-based detection may not be reproducible between runs.
        noisy_detector = NoisyChannels(self.raw)
        noisy_detector.find_all_bads(ransac=self.ransac)

        # Record Noisy channels and EEG before interpolation
        self.bad_before_interpolation = noisy_detector.get_bads(verbose=True)
        self.EEG_before_interpolation = self.EEG.copy()

        # Interpolate both newly detected bad channels and the channels found
        # unusable during robust referencing.
        bad_channels = _union(self.bad_before_interpolation,
                              self.unusable_channels)
        self.raw.info["bads"] = bad_channels
        self.raw.interpolate_bads()
        # Residual reference of the interpolated data, removed from the
        # re-referenced channels.
        reference_correct = (np.nanmean(
            self.raw.get_data(picks=self.reference_channels), axis=0) * 1e6)
        self.EEG = self.raw.get_data() * 1e6
        self.EEG = self.remove_reference(self.EEG, reference_correct,
                                         rereferenced_index)
        # reference signal after interpolation
        self.reference_signal_new = self.reference_signal + reference_correct
        # MNE Raw object after interpolation
        self.raw._data = self.EEG * 1e-6

        # Still noisy channels after interpolation
        self.interpolated_channels = bad_channels
        noisy_detector = NoisyChannels(self.raw)
        noisy_detector.find_all_bads(ransac=self.ransac)
        self.still_noisy_channels = noisy_detector.get_bads()
        self.raw.info["bads"] = self.still_noisy_channels
        return self
Beispiel #3
0
def test_bad_by_dropout(raw_tmp):
    """Test detection of channels with excessive portions of flat signal."""
    # Add large dropout portions to the signal of a random channel.
    # Draw a scalar directly: int() on a size-1 array is deprecated (NumPy 1.25).
    n_chans, n_samples = raw_tmp._data.shape
    dropout_idx = int(RNG.randint(0, n_chans))
    x1, x2 = (int(n_samples / 10), int(2 * n_samples / 10))
    raw_tmp._data[dropout_idx, x1:x2] = 0  # flatten 10% of signal

    # Test detection of channels that have excessive dropout regions
    # (dropout detection runs as part of the correlation check)
    nd = NoisyChannels(raw_tmp, do_detrend=False)
    nd.find_bad_by_correlation()
    assert nd.bad_by_dropout == [raw_tmp.ch_names[dropout_idx]]
Beispiel #4
0
def test_bad_by_nan(raw_tmp):
    """Test the detection of channels containing any NaN values."""
    # Insert a NaN value into a random channel.
    # Draw a scalar directly: int() on a size-1 array is deprecated (NumPy 1.25).
    n_chans = raw_tmp._data.shape[0]
    nan_idx = int(RNG.randint(0, n_chans))
    raw_tmp._data[nan_idx, 3] = np.nan

    # Test automatic detection of NaN channels on NoisyChannels init
    nd = NoisyChannels(raw_tmp, do_detrend=False)
    assert nd.bad_by_nan == [raw_tmp.ch_names[nan_idx]]

    # Test manual re-running of NaN channel detection
    nd.find_bad_by_nan_flat()
    assert nd.bad_by_nan == [raw_tmp.ch_names[nan_idx]]
Beispiel #5
0
def test_bad_by_deviation(raw_tmp):
    """Test detection of channels with relatively high or low amplitudes."""
    # Set scaling factors for high and low deviation test channels
    low_dev_factor = 0.1
    high_dev_factor = 4.0

    # Make the signal for a random channel have a very high amplitude.
    # Draw a scalar directly: int() on a size-1 array is deprecated (NumPy 1.25).
    n_chans = raw_tmp._data.shape[0]
    high_dev_idx = int(RNG.randint(0, n_chans))
    raw_tmp._data[high_dev_idx, :] *= high_dev_factor

    # Test detection of abnormally high-amplitude channels
    nd = NoisyChannels(raw_tmp, do_detrend=False)
    nd.find_bad_by_deviation()
    assert nd.bad_by_deviation == [raw_tmp.ch_names[high_dev_idx]]

    # Make the signal for a different channel have a very low amplitude
    low_dev_idx = (high_dev_idx - 1) if high_dev_idx > 0 else 1
    raw_tmp._data[low_dev_idx, :] *= low_dev_factor

    # Test detection of abnormally low-amplitude channels
    # NOTE: The default z-score threshold (5.0) is too strict to allow detection
    # of abnormally low-amplitude channels in some datasets. Using a relaxed Z
    # threshold of 3.29 (p < 0.001, two-tailed) until a better solution is found.
    nd = NoisyChannels(raw_tmp, do_detrend=False)
    nd.find_bad_by_deviation(deviation_threshold=3.29)
    # Sort the indices so the expectation is in channel order, as the detector
    # (presumably) reports them; without sorting, high_dev_idx == 0 gives
    # low_dev_idx == 1 and an out-of-order expected list.
    bad_by_dev_idx = sorted([low_dev_idx, high_dev_idx])
    assert nd.bad_by_deviation == [raw_tmp.ch_names[i] for i in bad_by_dev_idx]
Beispiel #6
0
def test_bad_by_SNR(raw_tmp):
    """Test detection of channels that have low signal-to-noise ratios."""
    # Replace a random channel's signal with uncorrelated values.
    # Draw a scalar directly: int() on a size-1 array is deprecated (NumPy 1.25).
    n_chans = raw_tmp._data.shape[0]
    low_snr_idx = int(RNG.randint(0, n_chans))
    raw_tmp._data[low_snr_idx, :] = _generate_signal(10, 30, raw_tmp.times, 5)

    # Add some high-frequency noise to the uncorrelated channel
    hf_noise = _generate_signal(70, 80, raw_tmp.times, 5) * 10
    raw_tmp._data[low_snr_idx, :] += hf_noise

    # Test detection of channels with a low signal-to-noise ratio
    nd = NoisyChannels(raw_tmp, do_detrend=False)
    nd.find_bad_by_SNR()
    assert nd.bad_by_SNR == [raw_tmp.ch_names[low_snr_idx]]
Beispiel #7
0
def test_find_bad_by_ransac(raw_tmp):
    """Test the RANSAC component of NoisyChannels."""
    # Set a consistent random seed for all RANSAC runs
    RANSAC_RNG = 435656

    # RANSAC identifies channels that go bad together and are highly correlated.
    # Inserting highly correlated signal in channels 0 through 5 (slice 0:6)
    # at 30 Hz
    raw_tmp._data[0:6, :] = _generate_signal(30, 30, raw_tmp.times)
    expected_bads = raw_tmp.ch_names[0:6]

    # Run different variations of RANSAC on the same data
    test_matrix = {
        # Tuple items represent (matlab_strict, channel_wise, max_chunk_size)
        "by_window": (False, False, None),
        "by_channel": (False, True, None),
        "by_channel_maxchunk": (False, True, 2),
        "by_window_strict": (True, False, None),
        "by_channel_strict": (True, True, None),
    }
    bads = {}
    corr = {}
    for name, (strict, channel_wise, max_chunk) in test_matrix.items():
        nd = NoisyChannels(raw_tmp,
                           do_detrend=False,
                           random_state=RANSAC_RNG,
                           matlab_strict=strict)
        nd.find_bad_by_ransac(channel_wise=channel_wise,
                              max_chunk_size=max_chunk)
        # Save bad channels and RANSAC correlation matrix for later comparison
        bads[name] = nd.bad_by_ransac
        corr[name] = nd._extra_info["bad_by_ransac"]["ransac_correlations"]

    # Test whether all methods detected bad channels properly
    for name, found in bads.items():
        assert found == expected_bads, name

    # Make sure non-strict correlation matrices all match
    assert np.allclose(corr["by_window"], corr["by_channel"])
    assert np.allclose(corr["by_window"], corr["by_channel_maxchunk"])

    # Make sure MATLAB-strict correlation matrices match
    assert np.allclose(corr["by_window_strict"], corr["by_channel_strict"])

    # Make sure strict and non-strict matrices differ
    assert not np.allclose(corr["by_window"], corr["by_window_strict"])

    # Ensure that RANSAC doesn't change random state if in MATLAB-strict mode.
    # Compare the MT19937 key array as well as the position: the position alone
    # wraps around every 624 draws and could hide consumed random numbers.
    rng = RandomState(RANSAC_RNG)
    init_state = rng.get_state()
    init_keys, init_pos = init_state[1].copy(), init_state[2]
    nd = NoisyChannels(raw_tmp,
                       do_detrend=False,
                       random_state=rng,
                       matlab_strict=True)
    nd.find_bad_by_ransac()
    final_state = rng.get_state()
    assert final_state[2] == init_pos
    assert np.array_equal(final_state[1], init_keys)
def pyprep_noisy(matprep_artifacts):
    """Return PyPREP NoisyChannels results on MATLAB PREP input data.

    The input is the MATLAB PREP artifact holding the CleanLined and detrended
    EEG signal exactly as it was before MATLAB PREP's first NoisyChannels pass
    during re-referencing. Running PyPREP's NoisyChannels (MATLAB-strict mode)
    on that same input means any differences found in later comparisons come
    from the noisy-channel detection code itself, not from earlier pipeline
    stages.

    """
    seed = 435656
    artifact_path = matprep_artifacts["4_matprep_pre_reference"]
    preref_raw = mne.io.read_raw_eeglab(artifact_path, preload=True)

    detector = NoisyChannels(
        preref_raw,
        do_detrend=False,
        random_state=seed,
        matlab_strict=True,
    )
    detector.find_all_bads()
    return detector
Beispiel #9
0
def test_bad_by_hf_noise(raw_tmp):
    """Test detection of channels with high-frequency noise."""
    # Add some noise between 70 & 80 Hz to the signal of a random channel.
    # Draw a scalar directly: int() on a size-1 array is deprecated (NumPy 1.25).
    n_chans = raw_tmp._data.shape[0]
    hf_noise_idx = int(RNG.randint(0, n_chans))
    hf_noise = _generate_signal(70, 80, raw_tmp.times, 5) * 10
    raw_tmp._data[hf_noise_idx, :] += hf_noise

    # Test detection of channels with high-frequency noise
    nd = NoisyChannels(raw_tmp, do_detrend=False)
    nd.find_bad_by_hfnoise()
    assert nd.bad_by_hf_noise == [raw_tmp.ch_names[hf_noise_idx]]

    # Test lack of high-frequency noise detection when sample rate < 100 Hz
    raw_tmp.resample(80)  # downsample from 160 Hz to 80 Hz
    nd = NoisyChannels(raw_tmp, do_detrend=False)
    nd.find_bad_by_hfnoise()
    assert len(nd.bad_by_hf_noise) == 0
    assert nd._extra_info["bad_by_hf_noise"]["median_channel_noisiness"] == 0
    assert nd._extra_info["bad_by_hf_noise"]["channel_noisiness_sd"] == 1
Beispiel #10
0
def test_bad_by_correlation(raw_tmp):
    """Test detection of channels that correlate poorly with others."""
    # Replace a random channel's signal with uncorrelated values.
    # Draw a scalar directly: int() on a size-1 array is deprecated (NumPy 1.25).
    n_chans, n_samples = raw_tmp._data.shape
    low_corr_idx = int(RNG.randint(0, n_chans))
    raw_tmp._data[low_corr_idx, :] = _generate_signal(10, 30, raw_tmp.times, 5)

    # Test detection of channels that correlate poorly with others
    nd = NoisyChannels(raw_tmp, do_detrend=False)
    nd.find_bad_by_correlation()
    assert nd.bad_by_correlation == [raw_tmp.ch_names[low_corr_idx]]

    # Add a channel with dropouts to see if correlation detection still works
    dropout_idx = (low_corr_idx - 1) if low_corr_idx > 0 else 1
    x1, x2 = (int(n_samples / 10), int(2 * n_samples / 10))
    raw_tmp._data[dropout_idx, x1:x2] = 0  # flatten 10% of signal

    # Re-test detection of channels that correlate poorly with others
    # (only new bad-by-correlation channel should be dropout)
    nd = NoisyChannels(raw_tmp, do_detrend=False)
    nd.find_bad_by_correlation()
    assert raw_tmp.ch_names[low_corr_idx] in nd.bad_by_correlation
    assert len(nd.bad_by_correlation) <= 2
Beispiel #11
0
def test_bad_by_flat(raw_tmp):
    """Test the detection of channels with flat or very weak signals."""
    # Make the signal for a random channel extremely weak.
    # Draw a scalar directly: int() on a size-1 array is deprecated (NumPy 1.25).
    n_chans = raw_tmp._data.shape[0]
    flat_idx = int(RNG.randint(0, n_chans))
    raw_tmp._data[flat_idx, :] *= 1e-12

    # Test automatic detection of flat channels on NoisyChannels init
    nd = NoisyChannels(raw_tmp, do_detrend=False)
    assert nd.bad_by_flat == [raw_tmp.ch_names[flat_idx]]

    # Test manual re-running of flat channel detection
    nd.find_bad_by_nan_flat()
    assert nd.bad_by_flat == [raw_tmp.ch_names[flat_idx]]

    # Test detection when channel is completely flat
    raw_tmp._data[flat_idx, :] = 0
    nd = NoisyChannels(raw_tmp, do_detrend=False)
    assert nd.bad_by_flat == [raw_tmp.ch_names[flat_idx]]
Beispiel #12
0
                    time) if i in bad_channels else np.sin(2 * np.pi *
                                                           freq_good * time)
    for i in range(n_chans)
]
# Scale the signal amplitude and add random noise.
X = 2e-5 * np.array(X) + 1e-5 * rng.random((n_chans, time.shape[0]))

raw = mne.io.RawArray(X, info)

raw.set_montage(montage, verbose=False)

###############################################################################
# Pass the MNE Raw object to :class:`NoisyChannels`. The resulting objects are
# where all following methods are called. Both use the same ``random_state``
# so that their RANSAC results can be compared.

nd = NoisyChannels(raw, random_state=1337)
nd2 = NoisyChannels(raw, random_state=1337)

###############################################################################
# Find bad channels with channel-wise RANSAC and print the elapsed time
start_time = perf_counter()
nd.find_bad_by_ransac(channel_wise=True)
print("--- %s seconds ---" % (perf_counter() - start_time))

# Repeat channel-wise RANSAC using a single channel at a time. This is slower
# but needs less memory.
start_time = perf_counter()
nd2.find_bad_by_ransac(channel_wise=True, max_chunk_size=1)
print("--- %s seconds ---" % (perf_counter() - start_time))
###############################################################################
Beispiel #13
0
    def robust_reference(self, max_iterations=4):
        """Detect bad channels and estimate the robust reference signal.

        This function implements the functionality of the `robustReference` function
        as part of the PREP pipeline on mne raw object.

        Parameters
        ----------
        max_iterations : int, optional
            The maximum number of iterations of noisy channel removal to perform
            during robust referencing. Defaults to ``4``.

        Returns
        -------
        noisy_channels: dict
            A dictionary of names of noisy channels detected from all methods
            after referencing.
        reference_signal: np.ndarray, shape(n, )
            Estimation of the 'true' signal mean

        Raises
        ------
        ValueError
            If fewer than two channels would remain after excluding the
            channels currently flagged as bad.

        """
        raw = self.raw.copy()
        raw._data = removeTrend(raw.get_data(),
                                self.sfreq,
                                matlab_strict=self.matlab_strict)

        # Determine unusable channels and remove them from the reference channels
        noisy_detector = NoisyChannels(
            raw,
            do_detrend=False,
            random_state=self.random_state,
            matlab_strict=self.matlab_strict,
        )
        noisy_detector.find_all_bads(**self.ransac_settings)
        self.noisy_channels_original = noisy_detector.get_bads(as_dict=True)
        self._extra_info["initial_bad"] = noisy_detector._extra_info
        logger.info("Bad channels: {}".format(self.noisy_channels_original))

        # Determine channels to use/exclude from initial reference estimation
        self.unusable_channels = _union(
            noisy_detector.bad_by_nan + noisy_detector.bad_by_flat,
            noisy_detector.bad_by_SNR,
        )
        reference_channels = _set_diff(self.reference_channels,
                                       self.unusable_channels)

        # Initialize channels to permanently flag as bad during referencing
        noisy = {
            "bad_by_nan": noisy_detector.bad_by_nan,
            "bad_by_flat": noisy_detector.bad_by_flat,
            "bad_by_deviation": [],
            "bad_by_hf_noise": [],
            "bad_by_correlation": [],
            "bad_by_SNR": [],
            "bad_by_dropout": [],
            "bad_by_ransac": [],
            "bad_all": [],
        }

        # Specify bad channel types to ignore when updating noisy channels
        # (loop-invariant, so computed once).
        # NOTE: MATLAB PREP ignores dropout channels, possibly by mistake?
        # see: https://github.com/VisLab/EEG-Clean-Tools/issues/28
        ignore = ["bad_by_SNR", "bad_all"]
        if self.matlab_strict:
            ignore += ["bad_by_dropout"]

        # Get initial estimate of the reference by the specified method
        signal = raw.get_data()
        self.reference_signal = np.nanmedian(
            raw.get_data(picks=reference_channels), axis=0)
        reference_index = [
            self.ch_names_eeg.index(ch) for ch in reference_channels
        ]
        signal_tmp = self.remove_reference(signal, self.reference_signal,
                                           reference_index)

        # Remove reference from signal, iteratively interpolating bad channels
        raw_tmp = raw.copy()
        iterations = 0
        previous_bads = set()

        while True:
            # Detrend applied at the beginning of the function.
            raw_tmp._data = signal_tmp
            noisy_detector = NoisyChannels(
                raw_tmp,
                do_detrend=False,
                random_state=self.random_state,
                matlab_strict=self.matlab_strict,
            )

            # Detect all currently bad channels
            noisy_detector.find_all_bads(**self.ransac_settings)
            noisy_new = noisy_detector.get_bads(as_dict=True)

            # Update set of all noisy channels detected so far with any new ones
            bad_chans = set()
            for bad_type in noisy_new.keys():
                noisy[bad_type] = _union(noisy[bad_type], noisy_new[bad_type])
                if bad_type not in ignore:
                    bad_chans.update(noisy[bad_type])
            noisy["bad_all"] = list(bad_chans)
            logger.info("Bad channels: {}".format(noisy))

            # Stop once the bad set is empty or stable (after the first two
            # iterations), or when the iteration budget is exhausted.
            if (iterations > 1 and
                (len(bad_chans) == 0 or bad_chans == previous_bads)
                    or iterations > max_iterations):
                logger.info("Robust reference done")
                self.noisy_channels = noisy
                break
            previous_bads = bad_chans.copy()

            if raw_tmp.info["nchan"] - len(bad_chans) < 2:
                raise ValueError(
                    "RobustReference:TooManyBad "
                    "Could not perform a robust reference -- not enough good channels"
                )

            # Re-estimate the reference from the *unreferenced* signal, with
            # bad channels interpolated if there are any. The data are reset
            # unconditionally: otherwise an iteration with no bad channels
            # would average the already re-referenced ``signal_tmp``. (This is
            # intended to mirror MATLAB PREP's robustReference, which uses the
            # raw signal when no channels are bad — confirm against MATLAB.)
            raw_tmp._data = signal.copy()
            if bad_chans:
                raw_tmp.info["bads"] = list(bad_chans)
                if self.matlab_strict:
                    _eeglab_interpolate_bads(raw_tmp)
                else:
                    raw_tmp.interpolate_bads()

            self.reference_signal = np.nanmean(
                raw_tmp.get_data(picks=reference_channels), axis=0)

            signal_tmp = self.remove_reference(signal, self.reference_signal,
                                               reference_index)
            iterations = iterations + 1
            logger.info("Iterations: {}".format(iterations))

        return self.noisy_channels, self.reference_signal
Beispiel #14
0
    def perform_reference(self, max_iterations=4):
        """Estimate the true signal mean and interpolate bad channels.

        This function implements the functionality of the `performReference` function
        as part of the PREP pipeline on mne raw object.

        Parameters
        ----------
        max_iterations : int, optional
            The maximum number of iterations of noisy channel removal to perform
            during robust referencing. Defaults to ``4``.

        Returns
        -------
        self
            The instance, with the re-referenced and interpolated data in
            ``raw`` and ``EEG`` and the bad-channel bookkeeping attributes
            (``interpolated_channels``, ``still_noisy_channels``, ...) set.

        Notes
        -----
        This function calls ``robust_reference`` first.
        Currently this function only implements the functionality of default
        settings, i.e., ``doRobustPost``.

        """
        # Phase 1: Estimate the true signal mean with robust referencing
        self.robust_reference(max_iterations)
        # If we interpolate the raw here we would be interpolating
        # more than what we later actually account for (in interpolated channels).
        # The interpolation therefore happens on a throwaway copy used only to
        # estimate the reference signal.
        dummy = self.raw.copy()
        dummy.info["bads"] = self.noisy_channels["bad_all"]
        if self.matlab_strict:
            _eeglab_interpolate_bads(dummy)
        else:
            dummy.interpolate_bads()
        self.reference_signal = np.nanmean(
            dummy.get_data(picks=self.reference_channels), axis=0)
        del dummy
        rereferenced_index = [
            self.ch_names_eeg.index(ch) for ch in self.rereferenced_channels
        ]
        self.EEG = self.remove_reference(self.EEG, self.reference_signal,
                                         rereferenced_index)

        # Phase 2: Find the bad channels and interpolate
        # NOTE: this makes self.raw._data and self.EEG the same array object;
        # EEG_before_interpolation below therefore takes an explicit copy
        # before self.raw is interpolated (interpolation presumably writes
        # into self.raw._data in place).
        self.raw._data = self.EEG
        noisy_detector = NoisyChannels(self.raw,
                                       random_state=self.random_state,
                                       matlab_strict=self.matlab_strict)
        noisy_detector.find_all_bads(**self.ransac_settings)

        # Record Noisy channels and EEG before interpolation
        self.bad_before_interpolation = noisy_detector.get_bads(verbose=True)
        self.EEG_before_interpolation = self.EEG.copy()
        self.noisy_channels_before_interpolation = noisy_detector.get_bads(
            as_dict=True)
        self._extra_info["interpolated"] = noisy_detector._extra_info

        # Interpolate both newly detected bad channels and the channels found
        # unusable during robust referencing.
        bad_channels = _union(self.bad_before_interpolation,
                              self.unusable_channels)
        self.raw.info["bads"] = bad_channels
        if self.matlab_strict:
            _eeglab_interpolate_bads(self.raw)
        else:
            self.raw.interpolate_bads()
        # Residual reference of the interpolated data, removed from the
        # re-referenced channels and folded into the reported reference.
        reference_correct = np.nanmean(
            self.raw.get_data(picks=self.reference_channels), axis=0)
        self.EEG = self.raw.get_data()
        self.EEG = self.remove_reference(self.EEG, reference_correct,
                                         rereferenced_index)
        # reference signal after interpolation
        self.reference_signal_new = self.reference_signal + reference_correct
        # MNE Raw object after interpolation
        self.raw._data = self.EEG

        # Still noisy channels after interpolation
        self.interpolated_channels = bad_channels
        noisy_detector = NoisyChannels(self.raw,
                                       random_state=self.random_state,
                                       matlab_strict=self.matlab_strict)
        noisy_detector.find_all_bads(**self.ransac_settings)
        self.still_noisy_channels = noisy_detector.get_bads()
        self.raw.info["bads"] = self.still_noisy_channels
        self.noisy_channels_after_interpolation = noisy_detector.get_bads(
            as_dict=True)
        self._extra_info["remaining_bad"] = noisy_detector._extra_info

        return self
Beispiel #15
0
def test_find_bad_by_ransac_err(raw_tmp):
    """Check the errors raised by `find_bad_by_ransac` for unusable input."""
    # Invalid ``n_samples`` values: an absurdly large count should raise a
    # MemoryError and a non-integer count a TypeError. Each case gets a
    # freshly constructed detector.
    invalid_n_samples = [(int(1e100), MemoryError), (35.5, TypeError)]
    for bad_value, expected_exc in invalid_n_samples:
        detector = NoisyChannels(raw_tmp, do_detrend=False)
        with pytest.raises(expected_exc):
            detector.find_bad_by_ransac(n_samples=bad_value)

    n_chans = raw_tmp._data.shape[0]

    # IOError when too few good channels remain for the RANSAC sample size:
    # flag the first 80% of the channels as bad by deviation.
    detector = NoisyChannels(raw_tmp, do_detrend=False)
    n_flagged = int(n_chans * 0.8)
    detector.bad_by_deviation = raw_tmp.info["ch_names"][:n_flagged]
    with pytest.raises(IOError):
        detector.find_bad_by_ransac()

    # IOError when there are not enough channels for RANSAC predictions:
    # every channel except the last two is made flat.
    raw_tmp._data[:n_chans - 2, :] = 0
    detector = NoisyChannels(raw_tmp, do_detrend=False)
    with pytest.raises(IOError):
        detector.find_bad_by_ransac()
Beispiel #16
0
                    time) if i in bad_channels else np.sin(2 * np.pi *
                                                           freq_good * time)
    for i in range(n_chans)
]
# Scale the signal amplitude and add random noise.
# NOTE(review): np.random.random is used without a seed in this section and no
# random_state is passed to NoisyChannels below, so the data and the RANSAC
# result may differ between runs unless seeding happens earlier — confirm.
X = 2e-5 * np.array(X) + 1e-5 * np.random.random((n_chans, time.shape[0]))

raw = mne.io.RawArray(X, info)

raw.set_montage(montage, verbose=False)

###############################################################################
# Pass the MNE Raw object to :class:`NoisyChannels`. The resulting object is
# where all following methods are called.

nd = NoisyChannels(raw)

###############################################################################
# Find bad channels with RANSAC and print the elapsed time
start_time = perf_counter()
nd.find_bad_by_ransac()
print("--- %s seconds ---" % (perf_counter() - start_time))

###############################################################################
# Now the RANSAC results are stored on the ``nd`` instance (e.g. in
# ``nd.bad_by_ransac``) and we can continue processing our ``raw`` object.
# For more information, we can access attributes of the ``nd`` instance:

# Check channels that go bad together by correlation (RANSAC); they should
# match ``bad_ch_names`` (defined earlier, presumably the simulated bad channels)
print(nd.bad_by_ransac)
assert set(bad_ch_names) == set(nd.bad_by_ransac)
Beispiel #17
0
                    time) if i in bad_channels else np.sin(2 * np.pi *
                                                           freq_good * time)
    for i in range(n_chans)
]
# Scale the signal amplitude and add random noise.
X = 2e-5 * np.array(X) + 1e-5 * np.random.random((n_chans, time.shape[0]))

raw = mne.io.RawArray(X, info)

raw.set_montage(montage, verbose=False)

###############################################################################
# Pass the MNE Raw object to :class:`NoisyChannels`. The resulting objects are
# where all following methods are called; two instances are created so both
# RANSAC variants can be timed on the same data.

nd = NoisyChannels(raw)
nd2 = NoisyChannels(raw)

###############################################################################
# Find bad channels with RANSAC and print the elapsed time
start_time = perf_counter()
nd.find_bad_by_ransac()
print("--- %s seconds ---" % (perf_counter() - start_time))

# Repeat RANSAC in a channel wise manner. This is slower but needs less memory.
start_time = perf_counter()
nd2.find_bad_by_ransac(channel_wise=True)
print("--- %s seconds ---" % (perf_counter() - start_time))

###############################################################################
# Now the bad channels are saved in `bads` and we can continue processing our
Beispiel #18
0
                      session=settings.session,
                      datatype='eeg',
                      suffix='eeg')
 raw = read_raw_bids(bids_path=bids_path, extra_params=dict(preload=True))
 raw.info['bads'] = [
     'FT7', 'FT8', 'T7', 'T8', 'TP7', 'TP8', 'AF7', 'AF8', 'M1', 'M2',
     'BIP1', 'BIP2', 'BIP3', 'BIP4', 'BIP5', 'BIP6', 'BIP7', 'BIP8', 'BIP9',
     'BIP10', 'BIP11', 'BIP12', 'BIP13', 'BIP14', 'BIP15', 'BIP16', 'BIP17',
     'BIP18', 'BIP19', 'BIP20', 'BIP21', 'BIP22', 'BIP23', 'BIP24'
 ]
 raw.drop_channels(ch_names=raw.info['bads'])
 montage = mne.channels.make_standard_montage('standard_1020')
 raw.set_montage(montage)
 filt_raw = raw.copy()
 filt_raw.drop_channels('EOG')
 nd = NoisyChannels(filt_raw)
 nd.find_bad_by_ransac(n_samples=50, channel_wise=True)
 raw.info['bads'] = nd.bad_by_ransac
 raw.notch_filter(np.arange(50, 250, 50))
 filt_raw.notch_filter(np.arange(50, 250, 50))
 raw.filter(l_freq=settings.high_filter,
            h_freq=settings.low_filter,
            fir_design='firwin')
 filt_raw.filter(l_freq=1, h_freq=settings.low_filter)
 filt_raw.info['bads'] = raw.info['bads']
 raw.interpolate_bads(reset_bads=True, mode='accurate')
 filt_raw.interpolate_bads(reset_bads=True, mode='accurate')
 raw.resample(200, npad='auto')
 filt_raw.resample(200, npad='auto')
 raw.set_eeg_reference('average')
 filt_raw.set_eeg_reference('average')
Beispiel #19
0
    def preproc(self, event_dict, baseline_start, stim_dur, montage, out_dir='preproc', subjects='all', tasks='all',
                hp_cutoff=1, lp_cutoff=50, line_noise=60, seed=42, eog_chan='none'):
        '''Preprocess the selected EEG recordings and epoch them per subject.

        Every selected task recording of every selected subject is read,
        cleaned and epoched; the task epochs of a subject are then concatenated.
        The result is assigned to ``self.epoch_list``: a list with one
        ``mne.Epochs`` entry per subject for which at least one recording
        could be read. Workflow applied to each recording:

            - Standardize channel names and set the montage
            - Band-pass filter from ``hp_cutoff`` to ``lp_cutoff`` Hz
            - Automatically detect bad channels (NaN/flat, deviation, SNR)
            - Reference data to the average of the good EEG channels
            - Automated removal of eye-related artifacts using ICA
            - Interpolation of the detected bad channels
            - Extract events from the annotations and epoch the data
            - Equalize event counts and group the epochs by event type
              (in the key order of ``event_dict``)

        Line-noise notch filtering is currently disabled (the call is
        commented out below), so ``line_noise`` has no effect.

        Parameters
        ----------
        event_dict: dict
            Maps event labels (str) to the integer event codes in the
            recordings. Passed to ``mne.Epochs`` as ``event_id``; its keys
            are used to select and order the epochs.

        baseline_start: int or float
            Start of each epoch relative to its event (``tmin``, in seconds)

        stim_dur: int or float
            End of each epoch relative to its event (``tmax``, in seconds),
            i.e. the stimulus duration
                Note: may need to make more general to allow various durations

        montage: mne.channels.montage.DigMontage
            Maps sensor locations to coordinates

        out_dir: str or path-like
            Directory under which one folder per subject is created

        subjects: list or 'all'
            Specify which subjects to iterate through

        tasks: list or 'all'
            Specify which tasks to iterate through (for every subject)

        hp_cutoff: int or float
            Lower edge of the band-pass filter in Hz

        lp_cutoff: int or float
            Upper edge of the band-pass filter in Hz

        line_noise: int or float
            Frequency of electrical line noise in Hz (currently unused, see above)

        seed: int
            Seed for NumPy's global RNG and for the ICA decomposition

        eog_chan: str
            If there are no EOG channels present, select an EEG channel
            near the eyes for eye-related artifact detection. ``'none'``
            lets ICA use the EOG channels of the recording.

        Notes
        -----
        Paths of recordings that could not be read are stored in
        ``self.missing_files``. For a given subject, reading stops at the
        first recording that cannot be read.
        '''
        missing = []  # paths of recordings that could not be read
        subj_iter = self.gen_iter(subjects, self.n_subj)  # get subject iterator

        # Iterate through subjects (initialize subject epoch list)
        epochs_subj = []
        for subj in subj_iter:
            # Build the task iterator per subject: a one-shot iterator shared
            # across subjects would leave every subject after the first empty.
            task_iter = self.gen_iter(tasks, self.n_task)

            # Iterate through tasks (initialize within-subject task epoch list)
            epochs_task = []
            for task in task_iter:
                # Specify the file format
                self.get_file_format(subj, task)

                try:  # Handles missing / unreadable files
                    raw = self.wget_raw_edf()  # read
                except Exception as err:
                    print(f'---\nThis file does not exist: {self.file_path}\n---')
                    print(f'Reason: {err}')
                    missing.append(self.file_path)
                    break

                # Standardize montage
                mne.datasets.eegbci.standardize(raw)
                # Set montage and strip channel names of "." characters
                raw.set_montage(montage)
                raw.rename_channels(lambda x: x.strip('.'))

                # Apply band-pass filter (hp_cutoff .. lp_cutoff)
                np.random.seed(seed)
                raw.filter(l_freq=hp_cutoff, h_freq=lp_cutoff, picks=['eeg', 'eog'])

                # Detect bad channels through multiple methods
                noise_chans = NoisyChannels(raw, do_detrend=False)
                noise_chans.find_bad_by_nan_flat()
                noise_chans.find_bad_by_deviation()
                noise_chans.find_bad_by_SNR()

                # Mark the bad channels in the raw object
                bads = noise_chans.get_bads()
                raw.info['bads'] = bads
                print(f'Bad channels detected: {bads}')

                # Notch filtering of line noise (and harmonics) is disabled:
                # notch_filt = np.arange(line_noise, raw.info['sfreq'] // 2, line_noise)
                # print(f'Apply notch filter at {line_noise} Hz and its harmonics')
                # raw.notch_filter(notch_filt, picks=['eeg', 'eog'], verbose=False)

                # Reference to the average of the good channels
                # (raw.info['bads'] is excluded automatically)
                raw.set_eeg_reference(ref_channels='average')

                # Seed ICA explicitly so its result does not depend on what
                # else may have consumed the global RNG since np.random.seed.
                ica = mne.preprocessing.ICA(max_iter=1000, random_state=seed)
                ica.fit(raw)

                # Find which ICs match the EOG pattern (None -> use EOG channels)
                eog_ch_name = None if eog_chan == 'none' else eog_chan
                eog_indices, _ = ica.find_bads_eog(raw, ch_name=eog_ch_name, verbose=False)

                # Apply the IC blink removals (if any)
                ica.apply(raw, exclude=eog_indices)
                print(f'Removed IC index {eog_indices}')

                # Interpolate bad channels
                raw.interpolate_bads()

                # Per-subject pre-processing directory (created if missing)
                preproc_dir = Path(out_dir) / str(subj)
                preproc_dir.mkdir(parents=True, exist_ok=True)

                # raw.save(Path(preproc_dir, f'subj{subj}_task{task}_raw.fif'),
                #          overwrite=True)

                # Find events
                events = mne.events_from_annotations(raw)[0]

                # Epoch the data
                preproc_epoch = mne.Epochs(raw, events, tmin=baseline_start, tmax=stim_dur,
                                           event_id=event_dict, event_repeated='error',
                                           on_missing='ignore', preload=True)

                # Equalize event counts, then group the epochs by event label
                event_labels = list(event_dict)
                preproc_epoch.equalize_event_counts(event_labels)
                align = [preproc_epoch[label] for label in event_labels]
                align_epoch = mne.concatenate_epochs(align)

                # Add to epoch list
                epochs_task.append(align_epoch)

            if not epochs_task:
                # No recording could be read for this subject: nothing to concatenate.
                print(f'---\nNo data available for subject {subj}; skipping\n---')
                continue

            # Concatenate epochs within subject
            epochs_subj.append(mne.concatenate_epochs(epochs_task))

        # Recordings that could not be read
        self.missing_files = missing
        # Attaches a list with each entry corresponding to epochs for a subject
        self.epoch_list = epochs_subj
Beispiel #20
0
def test_findnoisychannels(raw, montage):
    """Test find noisy channels.

    The fixture data are first cleaned by repeatedly interpolating the
    detected bad channels (up to 10 times); channels still bad afterwards are
    dropped. Each detector of ``NoisyChannels`` is then checked on a fresh
    copy in which a channel was deliberately corrupted (NaN/flat, low/high
    deviation, uncorrelated signal, high-frequency noise, low SNR, RANSAC),
    followed by the RANSAC error cases.

    Parameters
    ----------
    raw : mne.io.Raw
        EEG recording fixture; modified in place (montage, interpolation,
        dropped channels).
    montage : mne.channels.DigMontage
        Sensor positions applied to ``raw``.
    """
    # Set a random state for the test
    rng = np.random.RandomState(30)

    raw.set_montage(montage)
    nd = NoisyChannels(raw, random_state=rng)
    nd.find_all_bads(ransac=True)
    bads = nd.get_bads()
    # remove any noisy channels by interpolating the bads for up to 10 iterations
    iterations = 10
    for _ in range(iterations):
        if not bads:
            # Nothing left to interpolate; remaining iterations would be no-ops.
            break
        raw.info["bads"] = bads
        raw.interpolate_bads()
        nd = NoisyChannels(raw, random_state=rng)
        nd.find_all_bads(ransac=True)
        bads = nd.get_bads()

    # make sure no bad channels exist in the data
    raw.drop_channels(ch_names=bads)

    # Test for NaN and flat channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Insert a nan value for a random channel and make another (distinct)
    # random channel completely flat (ones)
    idxs = rng.choice(np.arange(m), size=2, replace=False)
    rand_chn_idx1 = idxs[0]
    rand_chn_idx2 = idxs[1]
    rand_chn_lab1 = raw_tmp.ch_names[rand_chn_idx1]
    rand_chn_lab2 = raw_tmp.ch_names[rand_chn_idx2]
    raw_tmp._data[rand_chn_idx1, n - 1] = np.nan
    raw_tmp._data[rand_chn_idx2, :] = np.ones(n)
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_nan_flat()
    assert nd.bad_by_nan == [rand_chn_lab1]
    assert nd.bad_by_flat == [rand_chn_lab2]

    # Test for high and low deviations in EEG data
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Now insert one random channel with very low deviations
    # (scalar draw: int() of a size-1 array is deprecated in NumPy)
    rand_chn_idx = int(rng.randint(0, m))
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    raw_tmp._data[rand_chn_idx, :] = raw_tmp._data[rand_chn_idx, :] / 10
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation
    # Inserting one random channel with a high deviation
    raw_tmp = raw.copy()
    rand_chn_idx = int(rng.randint(0, m))
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    arbitrary_scaling = 5
    raw_tmp._data[rand_chn_idx, :] *= arbitrary_scaling
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation

    # Test for correlation between EEG channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = int(rng.randint(0, m))
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Use cosine instead of sine to create a signal.
    # NOTE: one frequency is drawn per sample (size n), so each term is
    # noise-like rather than a pure tone.
    low = 10
    high = 30
    n_freq = 5
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = rng.randint(low, high, n)
        signal[0, :] += np.cos(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_correlation()
    assert rand_chn_lab in nd.bad_by_correlation

    # Test for high freq noise detection
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = int(rng.randint(0, m))
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Use freqs between 90 and 100 Hz to insert hf noise
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = rng.randint(90, 100, n)
        signal[0, :] += np.sin(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_hfnoise()
    assert rand_chn_lab in nd.bad_by_hf_noise

    # Test for signal to noise ratio in EEG data
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = int(rng.randint(0, m))
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # inserting an uncorrelated high frequency (90 Hz) signal in one channel
    raw_tmp[rand_chn_idx, :] = np.sin(2 * np.pi * raw.times * 90) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_SNR()
    assert rand_chn_lab in nd.bad_by_SNR

    # Test for finding bad channels by RANSAC
    raw_tmp = raw.copy()
    # Ransac identifies channels that go bad together and are highly correlated.
    # Inserting highly correlated signal in channels 0 through 5 at 30 Hz
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_ransac()
    bads = nd.bad_by_ransac
    assert bads == raw_tmp.ch_names[0:6]

    # Test for finding bad channels by channel-wise RANSAC
    raw_tmp = raw.copy()
    # Ransac identifies channels that go bad together and are highly correlated.
    # Inserting highly correlated signal in channels 0 through 5 at 30 Hz
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_ransac(channel_wise=True)
    bads = nd.bad_by_ransac
    assert bads == raw_tmp.ch_names[0:6]

    # Test not-enough-memory and n_samples type exceptions
    raw_tmp = raw.copy()
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)

    # Set n_samples very very high to trigger a memory error
    n_samples = int(1e100)
    with pytest.raises(MemoryError):
        nd.find_bad_by_ransac(n_samples=n_samples)

    # Set n_samples to a float to trigger a type error
    n_samples = 35.5
    with pytest.raises(TypeError):
        nd.find_bad_by_ransac(n_samples=n_samples)

    # Test IOError when not enough channels for ransac predictions
    raw_tmp = raw.copy()
    # Make flat all channels except 2
    num_bad_channels = raw._data.shape[0] - 2
    raw_tmp._data[0:num_bad_channels, :] = np.zeros_like(
        raw_tmp._data[0:num_bad_channels, :]
    )
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_all_bads(ransac=False)
    with pytest.raises(IOError):
        nd.find_bad_by_ransac()
def test_findnoisychannels(raw, montage):
    """Test the bad-channel detectors of ``NoisyChannels`` (global-RNG variant).

    The fixture data are first cleaned by repeatedly interpolating the
    detected bad channels (up to 10 times); channels still bad afterwards are
    dropped. Each detector is then checked on a fresh copy in which a channel
    was deliberately corrupted (NaN/flat, low/high deviation, uncorrelated
    signal, high-frequency noise, low SNR, RANSAC).

    NOTE(review): another function named ``test_findnoisychannels`` appears
    earlier in this file; if both live in the same module, this definition
    replaces the earlier one and only this one is collected by pytest.

    Parameters
    ----------
    raw : mne.io.Raw
        EEG recording fixture; modified in place (montage, interpolation,
        dropped channels).
    montage : mne.channels.DigMontage
        Sensor positions applied to ``raw``.
    """
    raw.set_montage(montage)
    nd = NoisyChannels(raw)
    nd.find_all_bads(ransac=True)
    bads = nd.get_bads()
    # remove any noisy channels by interpolating the bads for up to 10 iterations
    iterations = 10
    for _ in range(iterations):
        if not bads:
            # Nothing left to interpolate.
            break
        raw.info["bads"] = bads
        raw.interpolate_bads()
        nd = NoisyChannels(raw)
        nd.find_all_bads(ransac=True)
        bads = nd.get_bads()

    # make sure no bad channels exist in the data
    raw.drop_channels(ch_names=bads)

    # Test for NaN and flat channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Insert a nan value for one random channel and make a *different* random
    # channel flat; drawing the two independently could pick the same channel,
    # and the flat write would then erase the NaN.
    rand_chn_idx1, rand_chn_idx2 = (
        int(idx) for idx in np.random.choice(m, size=2, replace=False)
    )
    rand_chn_lab1 = raw_tmp.ch_names[rand_chn_idx1]
    rand_chn_lab2 = raw_tmp.ch_names[rand_chn_idx2]
    raw_tmp._data[rand_chn_idx1, n - 1] = np.nan
    raw_tmp._data[rand_chn_idx2, :] = np.ones(n)
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_nan_flat()
    assert nd.bad_by_nan == [rand_chn_lab1]
    assert nd.bad_by_flat == [rand_chn_lab2]

    # Test for high and low deviations in EEG data
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Now insert one random channel with very low deviations
    rand_chn_idx = int(np.random.randint(0, m))
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    raw_tmp._data[rand_chn_idx, :] = raw_tmp._data[rand_chn_idx, :] / 10
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation
    # Inserting one random channel with a high deviation
    raw_tmp = raw.copy()
    rand_chn_idx = int(np.random.randint(0, m))
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    arbitrary_scaling = 5
    raw_tmp._data[rand_chn_idx, :] *= arbitrary_scaling
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation

    # Test for correlation between EEG channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = int(np.random.randint(0, m))
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Use cosine instead of sine to create a signal.
    # NOTE: one frequency is drawn per sample (size n), so each term is
    # noise-like rather than a pure tone.
    low = 10
    high = 30
    n_freq = 5
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = np.random.randint(low, high, n)
        signal[0, :] += np.cos(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_correlation()
    assert rand_chn_lab in nd.bad_by_correlation

    # Test for high freq noise detection
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = int(np.random.randint(0, m))
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Use freqs between 90 and 100 Hz to insert hf noise
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = np.random.randint(90, 100, n)
        signal[0, :] += np.sin(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_hfnoise()
    assert rand_chn_lab in nd.bad_by_hf_noise

    # Test for signal to noise ratio in EEG data
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = int(np.random.randint(0, m))
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # inserting an uncorrelated high frequency (90 Hz) signal in one channel
    raw_tmp[rand_chn_idx, :] = np.sin(2 * np.pi * raw.times * 90) * 1e-6
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_SNR()
    assert rand_chn_lab in nd.bad_by_SNR

    # Test for finding bad channels by RANSAC
    raw_tmp = raw.copy()
    # Ransac identifies channels that go bad together and are highly correlated.
    # Inserting highly correlated signal in channels 0 through 5 at 30 Hz
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp)
    # Seed the global RNG right before RANSAC so its channel sampling is reproducible
    np.random.seed(30)
    nd.find_bad_by_ransac()
    bads = nd.bad_by_ransac
    assert bads == raw_tmp.ch_names[0:6]
Beispiel #22
0
    def robust_reference(self):
        """Detect bad channels and estimate the robust reference signal.

        This function implements the functionality of the `robustReference` function
        as part of the PREP pipeline on mne raw object.

        The data are detrended once. An initial reference is the NaN-ignoring
        median across the usable reference channels. Then, iteratively, bad
        channels are detected on the re-referenced signal and accumulated. The
        reference is re-estimated as the NaN-ignoring mean over the reference
        channels of the *unreferenced* signal with the accumulated bad channels
        interpolated. Iteration stops once the set of bad channels no longer
        changes (or is empty) after the first two passes, or after the
        iteration limit.

        Whether RANSAC is used for bad-channel detection is controlled by
        ``self.ransac``. Sets ``self.noisy_channels_original``,
        ``self.noisy_channels``, ``self.unusable_channels``,
        ``self.reference_channels`` and ``self.reference_signal``.

        Returns
        -------
            noisy_channels: dictionary
                A dictionary of names of noisy channels detected from all methods
                after referencing
            reference_signal: 1D Array
                Estimation of the 'true' signal mean (scaled by 1e6, like the
                working signal)

        Raises
        ------
            ValueError
                If fewer than two channels would remain after excluding the
                detected bad channels.

        """
        raw = self.raw.copy()
        raw._data = removeTrend(raw.get_data(), sample_rate=self.sfreq)

        # Determine unusable channels and remove them from the reference channels
        noisy_detector = NoisyChannels(raw,
                                       do_detrend=False,
                                       random_state=self.random_state)
        noisy_detector.find_all_bads(ransac=self.ransac)
        self.noisy_channels_original = {
            "bad_by_nan": noisy_detector.bad_by_nan,
            "bad_by_flat": noisy_detector.bad_by_flat,
            "bad_by_deviation": noisy_detector.bad_by_deviation,
            "bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
            "bad_by_correlation": noisy_detector.bad_by_correlation,
            "bad_by_ransac": noisy_detector.bad_by_ransac,
            "bad_all": noisy_detector.get_bads(),
        }
        self.noisy_channels = self.noisy_channels_original.copy()
        logger.info("Bad channels: {}".format(self.noisy_channels))

        self.unusable_channels = _union(noisy_detector.bad_by_nan,
                                        noisy_detector.bad_by_flat)

        # According to the Matlab Implementation (see robustReference.m)
        # self.unusable_channels = _union(self.unusable_channels,
        # noisy_detector.bad_by_SNR)
        # but maybe this makes no difference...

        self.reference_channels = _set_diff(self.reference_channels,
                                            self.unusable_channels)

        # Get initial estimate of the reference by the specified method
        signal = raw.get_data() * 1e6
        self.reference_signal = (
            np.nanmedian(raw.get_data(picks=self.reference_channels), axis=0) *
            1e6)
        reference_index = [
            self.ch_names_eeg.index(ch) for ch in self.reference_channels
        ]
        signal_tmp = self.remove_reference(signal, self.reference_signal,
                                           reference_index)

        # Remove reference from signal, iteratively interpolating bad channels
        raw_tmp = raw.copy()
        iterations = 0
        noisy_channels_old = []
        max_iteration_num = 4

        while True:
            raw_tmp._data = signal_tmp * 1e-6
            noisy_detector = NoisyChannels(raw_tmp,
                                           do_detrend=False,
                                           random_state=self.random_state)
            # Detrend applied at the beginning of the function.
            noisy_detector.find_all_bads(ransac=self.ransac)
            self.noisy_channels["bad_by_nan"] = _union(
                self.noisy_channels["bad_by_nan"], noisy_detector.bad_by_nan)
            self.noisy_channels["bad_by_flat"] = _union(
                self.noisy_channels["bad_by_flat"], noisy_detector.bad_by_flat)
            self.noisy_channels["bad_by_deviation"] = _union(
                self.noisy_channels["bad_by_deviation"],
                noisy_detector.bad_by_deviation)
            self.noisy_channels["bad_by_hf_noise"] = _union(
                self.noisy_channels["bad_by_hf_noise"],
                noisy_detector.bad_by_hf_noise)
            self.noisy_channels["bad_by_correlation"] = _union(
                self.noisy_channels["bad_by_correlation"],
                noisy_detector.bad_by_correlation,
            )
            self.noisy_channels["bad_by_ransac"] = _union(
                self.noisy_channels["bad_by_ransac"],
                noisy_detector.bad_by_ransac)
            self.noisy_channels["bad_all"] = _union(
                self.noisy_channels["bad_all"], noisy_detector.get_bads())
            logger.info("Bad channels: {}".format(self.noisy_channels))

            # Stop once the bad-channel set is empty or stable (after the first
            # two passes), or when the iteration budget is exhausted.
            bad_set_settled = (not self.noisy_channels["bad_all"]
                               or set(self.noisy_channels["bad_all"])
                               == set(noisy_channels_old))
            if (iterations > 1 and bad_set_settled) or iterations > max_iteration_num:
                break
            noisy_channels_old = self.noisy_channels["bad_all"].copy()

            if raw_tmp.info["nchan"] - len(self.noisy_channels["bad_all"]) < 2:
                raise ValueError(
                    "RobustReference:TooManyBad "
                    "Could not perform a robust reference -- not enough good channels"
                )

            # Estimate the reference from the *unreferenced* signal, with the
            # accumulated bad channels interpolated (robustReference.m uses the
            # plain signal when there are no bad channels). raw_tmp must be
            # reset in both cases: it currently holds the re-referenced data.
            raw_tmp._data = signal * 1e-6
            if self.noisy_channels["bad_all"]:
                raw_tmp.info["bads"] = self.noisy_channels["bad_all"]
                raw_tmp.interpolate_bads()
            self.reference_signal = (np.nanmean(
                raw_tmp.get_data(picks=self.reference_channels), axis=0) * 1e6)

            signal_tmp = self.remove_reference(signal, self.reference_signal,
                                               reference_index)
            iterations = iterations + 1
            logger.info("Iterations: {}".format(iterations))

        logger.info("Robust reference done")
        return self.noisy_channels, self.reference_signal
Beispiel #23
0
    def perform_reference(self):
        """Re-reference the data robustly, then interpolate the bad channels.

        Mirrors the `performReference` function of the PREP pipeline,
        operating on an MNE raw object. Two phases are run:

        1. ``robust_reference`` determines the bad channels. A copy of the raw
           data with those channels interpolated yields the reference, i.e.
           the NaN-mean over the reference channels (scaled by 1e6). That
           reference is removed from the re-referenced channels of ``self.EEG``.
        2. Bad channels are detected on the re-referenced data and
           interpolated in ``self.raw``, together with
           ``self.unusable_channels``. A correction to the reference is
           computed from the interpolated data and removed. A final detection
           pass records the channels that are still noisy and marks them in
           ``self.raw.info["bads"]``.

        Notes
        -----
        Only the default behaviour (``doRobustPost``) is implemented.

        Returns
        -------
        self
        """

        def _summarize(detector):
            # Per-method bad-channel lists of a finished NoisyChannels run.
            return {
                "bad_by_nan": detector.bad_by_nan,
                "bad_by_flat": detector.bad_by_flat,
                "bad_by_deviation": detector.bad_by_deviation,
                "bad_by_hf_noise": detector.bad_by_hf_noise,
                "bad_by_correlation": detector.bad_by_correlation,
                "bad_by_SNR": detector.bad_by_SNR,
                "bad_by_dropout": detector.bad_by_dropout,
                "bad_by_ransac": detector.bad_by_ransac,
                "bad_all": detector.get_bads(),
            }

        # --- Phase 1: robust estimate of the reference ---
        self.robust_reference()
        # Interpolate on a scratch copy only: interpolating self.raw at this
        # point would touch more channels than are later recorded as
        # interpolated.
        scratch = self.raw.copy()
        scratch.info["bads"] = self.noisy_channels["bad_all"]
        scratch.interpolate_bads()
        ref_data = scratch.get_data(picks=self.reference_channels)
        self.reference_signal = np.nanmean(ref_data, axis=0) * 1e6
        del scratch, ref_data

        rereferenced_index = [
            self.ch_names_eeg.index(name) for name in self.rereferenced_channels
        ]
        self.EEG = self.remove_reference(
            self.EEG, self.reference_signal, rereferenced_index
        )

        # --- Phase 2: detect bad channels again and interpolate them ---
        self.raw._data = self.EEG * 1e-6
        detector = NoisyChannels(self.raw, random_state=self.random_state)
        detector.find_all_bads(ransac=self.ransac)

        # Snapshot of the detection and of the EEG before interpolation
        self.bad_before_interpolation = detector.get_bads(verbose=True)
        self.EEG_before_interpolation = self.EEG.copy()
        self.noisy_channels_before_interpolation = _summarize(detector)

        bad_channels = _union(self.bad_before_interpolation, self.unusable_channels)
        self.raw.info["bads"] = bad_channels
        self.raw.interpolate_bads()
        correction = (
            np.nanmean(self.raw.get_data(picks=self.reference_channels), axis=0) * 1e6
        )
        self.EEG = self.raw.get_data() * 1e6
        self.EEG = self.remove_reference(self.EEG, correction, rereferenced_index)
        # Total reference removed, and the MNE raw object after interpolation
        self.reference_signal_new = self.reference_signal + correction
        self.raw._data = self.EEG * 1e-6

        # --- Channels that remain noisy after interpolation ---
        self.interpolated_channels = bad_channels
        detector = NoisyChannels(self.raw, random_state=self.random_state)
        detector.find_all_bads(ransac=self.ransac)
        self.still_noisy_channels = detector.get_bads()
        self.raw.info["bads"] = self.still_noisy_channels
        self.noisy_channels_after_interpolation = _summarize(detector)

        return self