コード例 #1
0
ファイル: reference.py プロジェクト: adam2392/pyprep
    def perform_reference(self):
        """Estimate the true signal mean and interpolate bad channels.

        This function implements the functionality of the `performReference` function
        as part of the PREP pipeline on mne raw object.

        Phase 1 calls ``robust_reference``, interpolates the noisy channels it
        found in ``self.raw`` and takes the mean of ``self.reference_channels``
        as the reference, which is passed to ``self.remove_reference`` together
        with the indices of ``self.rereferenced_channels``. Phase 2 re-detects
        noisy channels on the re-referenced data, interpolates them together
        with ``self.unusable_channels``, removes the resulting reference
        correction and finally records the channels that are still noisy.

        ``self.EEG`` is kept at a scale of 1e6 relative to ``self.raw``'s data
        (conversions use ``* 1e6`` / ``* 1e-6``).

        Returns
        -------
        self
            The instance; ``self.raw``, ``self.EEG``, ``self.reference_signal``,
            ``self.reference_signal_new``, ``self.bad_before_interpolation``,
            ``self.EEG_before_interpolation``, ``self.interpolated_channels``
            and ``self.still_noisy_channels`` are updated.

        Notes
        -----
        This function calls ``robust_reference`` first.
        Currently this function only implements the functionality of default
        settings, i.e., ``doRobustPost``.

        """
        # Phase 1: Estimate the true signal mean with robust referencing
        self.robust_reference()
        # NOTE: this interpolates self.raw in place; its data are replaced by
        # the re-referenced self.EEG at the start of Phase 2 below.
        if self.noisy_channels["bad_all"]:
            self.raw.info["bads"] = self.noisy_channels["bad_all"]
            self.raw.interpolate_bads()
        self.reference_signal = (np.nanmean(
            self.raw.get_data(picks=self.reference_channels), axis=0) * 1e6)
        # Row indices (within the EEG channel list) of channels to re-reference
        rereferenced_index = [
            self.ch_names_eeg.index(ch) for ch in self.rereferenced_channels
        ]
        self.EEG = self.remove_reference(self.EEG, self.reference_signal,
                                         rereferenced_index)

        # Phase 2: Find the bad channels and interpolate
        # Write the re-referenced EEG back into the MNE object (rescaled)
        self.raw._data = self.EEG * 1e-6
        # NOTE(review): no random_state is passed here, so RANSAC-based
        # detection may not be reproducible between runs — confirm whether
        # NoisyChannels accepts one in this version.
        noisy_detector = NoisyChannels(self.raw)
        noisy_detector.find_all_bads(ransac=self.ransac)

        # Record Noisy channels and EEG before interpolation
        self.bad_before_interpolation = noisy_detector.get_bads(verbose=True)
        self.EEG_before_interpolation = self.EEG.copy()

        # Interpolate newly detected bad channels plus the unusable ones
        # (as determined in robust_reference).
        bad_channels = _union(self.bad_before_interpolation,
                              self.unusable_channels)
        self.raw.info["bads"] = bad_channels
        self.raw.interpolate_bads()
        # Residual reference: mean of the reference channels after interpolation
        reference_correct = (np.nanmean(
            self.raw.get_data(picks=self.reference_channels), axis=0) * 1e6)
        self.EEG = self.raw.get_data() * 1e6
        self.EEG = self.remove_reference(self.EEG, reference_correct,
                                         rereferenced_index)
        # reference signal after interpolation
        self.reference_signal_new = self.reference_signal + reference_correct
        # MNE Raw object after interpolation
        self.raw._data = self.EEG * 1e-6

        # Still noisy channels after interpolation
        self.interpolated_channels = bad_channels
        noisy_detector = NoisyChannels(self.raw)
        noisy_detector.find_all_bads(ransac=self.ransac)
        self.still_noisy_channels = noisy_detector.get_bads()
        self.raw.info["bads"] = self.still_noisy_channels
        return self
コード例 #2
0
def pyprep_noisy(matprep_artifacts):
    """Return pyprep NoisyChannels results computed on MATLAB PREP input data.

    The MATLAB PREP artifact used here holds the CleanLined and detrended EEG
    signal captured right before MATLAB PREP's first NoisyChannels pass during
    re-referencing. Running pyprep's detector on exactly that signal means
    that any differences found by the comparison tests come from the
    noisy-channel detection code itself, not from an earlier pipeline stage.

    """
    preref_file = matprep_artifacts["4_matprep_pre_reference"]
    # Load the full pre-reference recording into memory
    raw_matlab = mne.io.read_raw_eeglab(preref_file, preload=True)

    # The input is already detrended, so detrending is disabled; the fixed
    # seed keeps the random parts of detection reproducible.
    seed = 435656
    detector = NoisyChannels(
        raw_matlab,
        do_detrend=False,
        random_state=seed,
        matlab_strict=True,
    )
    detector.find_all_bads()
    return detector
コード例 #3
0
    def robust_reference(self):
        """Detect bad channels and estimate the robust reference signal.

        This function implements the functionality of the `robustReference` function
        as part of the PREP pipeline on mne raw object.

        The data are first detrended and screened once for noisy channels.
        Channels bad by NaN or flatness are treated as unusable and removed
        from ``self.reference_channels``. An initial reference is the
        channel-wise median of the remaining reference channels.

        The method then iterates:

        - detect noisy channels on the re-referenced signal;
        - interpolate every channel flagged so far in the *un-referenced*
          signal;
        - re-estimate the reference as the mean of the reference channels.

        The loop stops when, after more than one iteration, no channel is
        flagged or the flagged set stops changing. It also stops after a fixed
        number of iterations.

        Whether RANSAC is used is controlled by ``self.ransac``.

        Returns
        -------
        noisy_channels : dict
            A dictionary of names of noisy channels detected from all methods
            after referencing.
        reference_signal : np.ndarray, shape (n_times,)
            Estimation of the 'true' signal mean (scaled by 1e6 relative to
            ``self.raw``'s data).

        Raises
        ------
        ValueError
            If fewer than two channels remain that are not flagged as noisy.

        """
        raw = self.raw.copy()
        raw._data = removeTrend(raw.get_data(), sample_rate=self.sfreq)

        # Determine unusable channels and remove them from the reference channels
        noisy_detector = NoisyChannels(raw,
                                       do_detrend=False,
                                       random_state=self.random_state)
        noisy_detector.find_all_bads(ransac=self.ransac)
        self.noisy_channels_original = {
            "bad_by_nan": noisy_detector.bad_by_nan,
            "bad_by_flat": noisy_detector.bad_by_flat,
            "bad_by_deviation": noisy_detector.bad_by_deviation,
            "bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
            "bad_by_correlation": noisy_detector.bad_by_correlation,
            "bad_by_ransac": noisy_detector.bad_by_ransac,
            "bad_all": noisy_detector.get_bads(),
        }
        self.noisy_channels = self.noisy_channels_original.copy()
        logger.info("Bad channels: {}".format(self.noisy_channels))

        # NOTE: MATLAB PREP (robustReference.m) additionally treats channels
        # bad by low SNR as unusable; that is not done here.
        self.unusable_channels = _union(noisy_detector.bad_by_nan,
                                        noisy_detector.bad_by_flat)
        self.reference_channels = _set_diff(self.reference_channels,
                                            self.unusable_channels)

        # Get initial estimate of the reference by the specified method
        signal = raw.get_data() * 1e6
        self.reference_signal = (
            np.nanmedian(raw.get_data(picks=self.reference_channels), axis=0) *
            1e6)
        reference_index = [
            self.ch_names_eeg.index(ch) for ch in self.reference_channels
        ]
        signal_tmp = self.remove_reference(signal, self.reference_signal,
                                           reference_index)

        # Per-criterion bad-channel lists accumulated across iterations
        bad_types = (
            "bad_by_nan",
            "bad_by_flat",
            "bad_by_deviation",
            "bad_by_hf_noise",
            "bad_by_correlation",
            "bad_by_ransac",
        )

        # Remove reference from signal, iteratively interpolating bad channels
        raw_tmp = raw.copy()
        iterations = 0
        noisy_channels_old = []
        max_iteration_num = 4

        while True:
            raw_tmp._data = signal_tmp * 1e-6
            # Detrend applied at the beginning of the function.
            noisy_detector = NoisyChannels(raw_tmp,
                                           do_detrend=False,
                                           random_state=self.random_state)
            noisy_detector.find_all_bads(ransac=self.ransac)
            for bad_type in bad_types:
                self.noisy_channels[bad_type] = _union(
                    self.noisy_channels[bad_type],
                    getattr(noisy_detector, bad_type))
            self.noisy_channels["bad_all"] = _union(
                self.noisy_channels["bad_all"], noisy_detector.get_bads())
            logger.info("Bad channels: {}".format(self.noisy_channels))

            if (iterations > 1 and (not self.noisy_channels["bad_all"] or set(
                    self.noisy_channels["bad_all"]) == set(noisy_channels_old))
                    or iterations > max_iteration_num):
                break
            noisy_channels_old = self.noisy_channels["bad_all"].copy()

            if raw_tmp.info["nchan"] - len(self.noisy_channels["bad_all"]) < 2:
                raise ValueError(
                    "RobustReference:TooManyBad "
                    "Could not perform a robust reference -- not enough good channels"
                )

            # raw_tmp still holds the re-referenced data written at the top of
            # this loop. Reset it to the un-referenced signal even when there
            # is nothing to interpolate; otherwise the mean below would be
            # taken over already-referenced data.
            raw_tmp._data = signal * 1e-6
            if self.noisy_channels["bad_all"]:
                raw_tmp.info["bads"] = self.noisy_channels["bad_all"]
                raw_tmp.interpolate_bads()
            self.reference_signal = (np.nanmean(
                raw_tmp.get_data(picks=self.reference_channels), axis=0) * 1e6)

            signal_tmp = self.remove_reference(signal, self.reference_signal,
                                               reference_index)
            iterations += 1
            logger.info("Iterations: {}".format(iterations))

        logger.info("Robust reference done")
        return self.noisy_channels, self.reference_signal
コード例 #4
0
def test_findnoisychannels(raw, montage):
    """Test find noisy channels."""
    raw.set_montage(montage)
    nd = NoisyChannels(raw)
    nd.find_all_bads(ransac=True)
    bads = nd.get_bads()
    # remove any noisy channels by interpolating the bads for 10 iterations
    iterations = 10
    for _ in range(iterations):
        # Only interpolate when something was flagged; otherwise there is
        # nothing to interpolate.
        if len(bads) > 0:
            raw.info["bads"] = bads
            raw.interpolate_bads()
        nd = NoisyChannels(raw)
        nd.find_all_bads(ransac=True)
        bads = nd.get_bads()

    # make sure no bad channels exist in the data
    raw.drop_channels(ch_names=bads)

    # Test for NaN and flat channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Insert a nan value in one random channel and make a *different* random
    # channel completely flat (ones). Drawing both indices independently
    # could pick the same channel; the flat row would then overwrite the NaN
    # and the bad_by_nan assertion would fail spuriously.
    rand_chn_idx1, rand_chn_idx2 = np.random.choice(m, size=2, replace=False)
    rand_chn_lab1 = raw_tmp.ch_names[rand_chn_idx1]
    rand_chn_lab2 = raw_tmp.ch_names[rand_chn_idx2]
    raw_tmp._data[rand_chn_idx1, n - 1] = np.nan
    raw_tmp._data[rand_chn_idx2, :] = np.ones(n)
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_nan_flat()
    assert nd.bad_by_nan == [rand_chn_lab1]
    assert nd.bad_by_flat == [rand_chn_lab2]

    # Test for high and low deviations in EEG data
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Now insert one random channel with very low deviations
    # (scalar draw: converting a size-1 array with int() is deprecated)
    rand_chn_idx = np.random.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    raw_tmp._data[rand_chn_idx, :] = raw_tmp._data[rand_chn_idx, :] / 10
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation
    # Inserting one random channel with a high deviation
    raw_tmp = raw.copy()
    rand_chn_idx = np.random.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    arbitrary_scaling = 5
    raw_tmp._data[rand_chn_idx, :] *= arbitrary_scaling
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation

    # Test for correlation between EEG channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = np.random.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Use cosine instead of sine to create a signal
    low = 10
    high = 30
    n_freq = 5
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = np.random.randint(low, high, n)
        signal[0, :] += np.cos(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_correlation()
    assert rand_chn_lab in nd.bad_by_correlation

    # Test for high freq noise detection
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = np.random.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Use freqs between 90 and 100 Hz to insert hf noise
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = np.random.randint(90, 100, n)
        signal[0, :] += np.sin(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_hfnoise()
    assert rand_chn_lab in nd.bad_by_hf_noise

    # Test for signal to noise ratio in EEG data
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = np.random.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # inserting an uncorrelated high frequency (90 Hz) signal in one channel
    raw_tmp[rand_chn_idx, :] = np.sin(2 * np.pi * raw.times * 90) * 1e-6
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_SNR()
    assert rand_chn_lab in nd.bad_by_SNR

    # Test for finding bad channels by RANSAC
    raw_tmp = raw.copy()
    # Ransac identifies channels that go bad together and are highly correlated.
    # Inserting highly correlated signal in channels 0 through 5 at 30 Hz
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp)
    np.random.seed(30)
    nd.find_bad_by_ransac()
    bads = nd.bad_by_ransac
    assert bads == raw_tmp.ch_names[0:6]
コード例 #5
0
def test_findnoisychannels(raw, montage):
    """Test find noisy channels."""
    # Set a random state for the test
    rng = np.random.RandomState(30)

    raw.set_montage(montage)
    nd = NoisyChannels(raw, random_state=rng)
    nd.find_all_bads(ransac=True)
    bads = nd.get_bads()
    # remove any noisy channels by interpolating the bads for 10 iterations
    iterations = 10
    for _ in range(iterations):
        # Once nothing is flagged the data no longer change, so stop early
        if len(bads) == 0:
            break
        raw.info["bads"] = bads
        raw.interpolate_bads()
        nd = NoisyChannels(raw, random_state=rng)
        nd.find_all_bads(ransac=True)
        bads = nd.get_bads()

    # make sure no bad channels exist in the data
    raw.drop_channels(ch_names=bads)

    # Test for NaN and flat channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Insert a nan value for a random channel and make another (distinct)
    # random channel completely flat (ones)
    rand_chn_idx1, rand_chn_idx2 = rng.choice(np.arange(m), size=2,
                                              replace=False)
    rand_chn_lab1 = raw_tmp.ch_names[rand_chn_idx1]
    rand_chn_lab2 = raw_tmp.ch_names[rand_chn_idx2]
    raw_tmp._data[rand_chn_idx1, n - 1] = np.nan
    raw_tmp._data[rand_chn_idx2, :] = np.ones(n)
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_nan_flat()
    assert nd.bad_by_nan == [rand_chn_lab1]
    assert nd.bad_by_flat == [rand_chn_lab2]

    # Test for high and low deviations in EEG data
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Now insert one random channel with very low deviations.
    # Draw a scalar directly: int() on a size-1 array is deprecated in NumPy.
    rand_chn_idx = rng.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    raw_tmp._data[rand_chn_idx, :] = raw_tmp._data[rand_chn_idx, :] / 10
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation
    # Inserting one random channel with a high deviation
    raw_tmp = raw.copy()
    rand_chn_idx = rng.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    arbitrary_scaling = 5
    raw_tmp._data[rand_chn_idx, :] *= arbitrary_scaling
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation

    # Test for correlation between EEG channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = rng.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Use cosine instead of sine to create a signal
    low = 10
    high = 30
    n_freq = 5
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = rng.randint(low, high, n)
        signal[0, :] += np.cos(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_correlation()
    assert rand_chn_lab in nd.bad_by_correlation

    # Test for high freq noise detection
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = rng.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Use freqs between 90 and 100 Hz to insert hf noise
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = rng.randint(90, 100, n)
        signal[0, :] += np.sin(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_hfnoise()
    assert rand_chn_lab in nd.bad_by_hf_noise

    # Test for signal to noise ratio in EEG data
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = rng.randint(0, m)
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # inserting an uncorrelated high frequency (90 Hz) signal in one channel
    raw_tmp[rand_chn_idx, :] = np.sin(2 * np.pi * raw.times * 90) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_SNR()
    assert rand_chn_lab in nd.bad_by_SNR

    # Test for finding bad channels by RANSAC
    raw_tmp = raw.copy()
    # Ransac identifies channels that go bad together and are highly correlated.
    # Inserting highly correlated signal in channels 0 through 5 at 30 Hz
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_ransac()
    bads = nd.bad_by_ransac
    assert bads == raw_tmp.ch_names[0:6]

    # Test for finding bad channels by channel-wise RANSAC
    raw_tmp = raw.copy()
    # Ransac identifies channels that go bad together and are highly correlated.
    # Inserting highly correlated signal in channels 0 through 5 at 30 Hz
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_ransac(channel_wise=True)
    bads = nd.bad_by_ransac
    assert bads == raw_tmp.ch_names[0:6]

    # Test not-enough-memory and n_samples type exceptions
    raw_tmp = raw.copy()
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)

    # Set n_samples very very high to trigger a memory error
    n_samples = int(1e100)
    with pytest.raises(MemoryError):
        nd.find_bad_by_ransac(n_samples=n_samples)

    # Set n_samples to a float to trigger a type error
    n_samples = 35.5
    with pytest.raises(TypeError):
        nd.find_bad_by_ransac(n_samples=n_samples)

    # Test IOError when not enough channels for ransac predictions
    raw_tmp = raw.copy()
    # Make all channels except the last two flat
    num_bad_channels = raw_tmp._data.shape[0] - 2
    raw_tmp._data[0:num_bad_channels, :] = 0.0
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_all_bads(ransac=False)
    with pytest.raises(IOError):
        nd.find_bad_by_ransac()
コード例 #6
0
ファイル: reference.py プロジェクト: sappelhoff/pyprep
    def perform_reference(self, max_iterations=4):
        """Estimate the true signal mean and interpolate bad channels.

        This function implements the functionality of the `performReference` function
        as part of the PREP pipeline on mne raw object.

        Phase 1 runs ``robust_reference`` and computes the reference as the mean
        of ``self.reference_channels`` on an interpolated *copy* of
        ``self.raw``. That reference is passed to ``self.remove_reference`` with
        the indices of ``self.rereferenced_channels``.

        Phase 2 re-detects noisy channels on the re-referenced data and
        interpolates them, together with ``self.unusable_channels``, in
        ``self.raw``. It then removes the resulting reference correction and
        records which channels are still noisy.

        When ``self.matlab_strict`` is set, interpolation uses
        ``_eeglab_interpolate_bads`` instead of MNE's ``interpolate_bads``.

        Parameters
        ----------
        max_iterations : int, optional
            The maximum number of iterations of noisy channel removal to perform
            during robust referencing. Defaults to ``4``.

        Returns
        -------
        self
            The instance, with referencing/interpolation results stored as
            attributes (e.g. ``reference_signal``, ``reference_signal_new``,
            ``interpolated_channels``, ``still_noisy_channels``).

        Notes
        -----
        This function calls ``robust_reference`` first.
        Currently this function only implements the functionality of default
        settings, i.e., ``doRobustPost``.

        """
        # Phase 1: Estimate the true signal mean with robust referencing
        self.robust_reference(max_iterations)
        # If we interpolate the raw here we would be interpolating
        # more than what we later actually account for (in interpolated channels).
        dummy = self.raw.copy()
        dummy.info["bads"] = self.noisy_channels["bad_all"]
        if self.matlab_strict:
            _eeglab_interpolate_bads(dummy)
        else:
            dummy.interpolate_bads()
        self.reference_signal = np.nanmean(
            dummy.get_data(picks=self.reference_channels), axis=0)
        del dummy
        # Row indices (within the EEG channel list) of channels to re-reference
        rereferenced_index = [
            self.ch_names_eeg.index(ch) for ch in self.rereferenced_channels
        ]
        self.EEG = self.remove_reference(self.EEG, self.reference_signal,
                                         rereferenced_index)

        # Phase 2: Find the bad channels and interpolate
        # NOTE(review): self.raw._data now refers to the same array object as
        # self.EEG (no copy); if interpolation below writes in place it also
        # changes that array — hence EEG_before_interpolation is copied first.
        self.raw._data = self.EEG
        noisy_detector = NoisyChannels(self.raw,
                                       random_state=self.random_state,
                                       matlab_strict=self.matlab_strict)
        noisy_detector.find_all_bads(**self.ransac_settings)

        # Record Noisy channels and EEG before interpolation
        self.bad_before_interpolation = noisy_detector.get_bads(verbose=True)
        self.EEG_before_interpolation = self.EEG.copy()
        self.noisy_channels_before_interpolation = noisy_detector.get_bads(
            as_dict=True)
        self._extra_info["interpolated"] = noisy_detector._extra_info

        # Interpolate newly detected bad channels plus the unusable ones
        # (as determined in robust_reference).
        bad_channels = _union(self.bad_before_interpolation,
                              self.unusable_channels)
        self.raw.info["bads"] = bad_channels
        if self.matlab_strict:
            _eeglab_interpolate_bads(self.raw)
        else:
            self.raw.interpolate_bads()
        # Residual reference: mean of the reference channels after interpolation
        reference_correct = np.nanmean(
            self.raw.get_data(picks=self.reference_channels), axis=0)
        self.EEG = self.raw.get_data()
        self.EEG = self.remove_reference(self.EEG, reference_correct,
                                         rereferenced_index)
        # reference signal after interpolation
        self.reference_signal_new = self.reference_signal + reference_correct
        # MNE Raw object after interpolation
        self.raw._data = self.EEG

        # Still noisy channels after interpolation
        self.interpolated_channels = bad_channels
        noisy_detector = NoisyChannels(self.raw,
                                       random_state=self.random_state,
                                       matlab_strict=self.matlab_strict)
        noisy_detector.find_all_bads(**self.ransac_settings)
        self.still_noisy_channels = noisy_detector.get_bads()
        self.raw.info["bads"] = self.still_noisy_channels
        self.noisy_channels_after_interpolation = noisy_detector.get_bads(
            as_dict=True)
        self._extra_info["remaining_bad"] = noisy_detector._extra_info

        return self
コード例 #7
0
ファイル: reference.py プロジェクト: sappelhoff/pyprep
    def robust_reference(self, max_iterations=4):
        """Detect bad channels and estimate the robust reference signal.

        This function implements the functionality of the `robustReference` function
        as part of the PREP pipeline on mne raw object.

        The data are detrended and screened once for noisy channels. Channels
        bad by NaN, flatness or low SNR are excluded from the initial reference
        estimate, which is the channel-wise median of the remaining reference
        channels.

        The method then iterates:

        - detect noisy channels on the re-referenced signal;
        - interpolate every channel flagged so far in the *un-referenced*
          signal;
        - re-estimate the reference as the mean of the reference channels.

        Parameters
        ----------
        max_iterations : int, optional
            The maximum number of iterations of noisy channel removal to perform
            during robust referencing. Defaults to ``4``.

        Returns
        -------
        noisy_channels: dict
            A dictionary of names of noisy channels detected from all methods
            after referencing.
        reference_signal: np.ndarray, shape(n, )
            Estimation of the 'true' signal mean

        Raises
        ------
        ValueError
            If fewer than two channels remain that are not flagged as noisy.

        """
        raw = self.raw.copy()
        raw._data = removeTrend(raw.get_data(),
                                self.sfreq,
                                matlab_strict=self.matlab_strict)

        # Determine unusable channels and remove them from the reference channels
        noisy_detector = NoisyChannels(
            raw,
            do_detrend=False,
            random_state=self.random_state,
            matlab_strict=self.matlab_strict,
        )
        noisy_detector.find_all_bads(**self.ransac_settings)
        self.noisy_channels_original = noisy_detector.get_bads(as_dict=True)
        self._extra_info["initial_bad"] = noisy_detector._extra_info
        logger.info("Bad channels: {}".format(self.noisy_channels_original))

        # Determine channels to use/exclude from initial reference estimation
        self.unusable_channels = _union(
            noisy_detector.bad_by_nan + noisy_detector.bad_by_flat,
            noisy_detector.bad_by_SNR,
        )
        reference_channels = _set_diff(self.reference_channels,
                                       self.unusable_channels)

        # Initialize channels to permanently flag as bad during referencing
        noisy = {
            "bad_by_nan": noisy_detector.bad_by_nan,
            "bad_by_flat": noisy_detector.bad_by_flat,
            "bad_by_deviation": [],
            "bad_by_hf_noise": [],
            "bad_by_correlation": [],
            "bad_by_SNR": [],
            "bad_by_dropout": [],
            "bad_by_ransac": [],
            "bad_all": [],
        }

        # Bad-channel types that are accumulated in ``noisy`` but not used to
        # decide which channels get interpolated (loop-invariant).
        # NOTE: MATLAB PREP ignores dropout channels, possibly by mistake?
        # see: https://github.com/VisLab/EEG-Clean-Tools/issues/28
        ignore = ["bad_by_SNR", "bad_all"]
        if self.matlab_strict:
            ignore += ["bad_by_dropout"]

        # Get initial estimate of the reference by the specified method
        signal = raw.get_data()
        self.reference_signal = np.nanmedian(
            raw.get_data(picks=reference_channels), axis=0)
        reference_index = [
            self.ch_names_eeg.index(ch) for ch in reference_channels
        ]
        signal_tmp = self.remove_reference(signal, self.reference_signal,
                                           reference_index)

        # Remove reference from signal, iteratively interpolating bad channels
        raw_tmp = raw.copy()
        iterations = 0
        previous_bads = set()

        while True:
            raw_tmp._data = signal_tmp
            # Detrend applied at the beginning of the function.
            noisy_detector = NoisyChannels(
                raw_tmp,
                do_detrend=False,
                random_state=self.random_state,
                matlab_strict=self.matlab_strict,
            )

            # Detect all currently bad channels
            noisy_detector.find_all_bads(**self.ransac_settings)
            noisy_new = noisy_detector.get_bads(as_dict=True)

            # Update set of all noisy channels detected so far with any new ones
            bad_chans = set()
            for bad_type in noisy_new.keys():
                noisy[bad_type] = _union(noisy[bad_type], noisy_new[bad_type])
                if bad_type not in ignore:
                    bad_chans.update(noisy[bad_type])
            noisy["bad_all"] = list(bad_chans)
            logger.info("Bad channels: {}".format(noisy))

            if (iterations > 1 and
                (len(bad_chans) == 0 or bad_chans == previous_bads)
                    or iterations > max_iterations):
                logger.info("Robust reference done")
                self.noisy_channels = noisy
                break
            previous_bads = bad_chans.copy()

            if raw_tmp.info["nchan"] - len(bad_chans) < 2:
                raise ValueError(
                    "RobustReference:TooManyBad "
                    "Could not perform a robust reference -- not enough good channels"
                )

            # raw_tmp still holds the re-referenced data written at the top of
            # this loop. Reset it to the un-referenced signal even when there
            # is nothing to interpolate; otherwise the mean below would be
            # taken over already-referenced data. Copy, because interpolation
            # may modify the data in place and ``signal`` is reused.
            raw_tmp._data = signal.copy()
            if len(bad_chans) > 0:
                raw_tmp.info["bads"] = list(bad_chans)
                if self.matlab_strict:
                    _eeglab_interpolate_bads(raw_tmp)
                else:
                    raw_tmp.interpolate_bads()

            self.reference_signal = np.nanmean(
                raw_tmp.get_data(picks=reference_channels), axis=0)

            signal_tmp = self.remove_reference(signal, self.reference_signal,
                                               reference_index)
            iterations += 1
            logger.info("Iterations: {}".format(iterations))

        return self.noisy_channels, self.reference_signal
コード例 #8
0
ファイル: reference.py プロジェクト: yjmantilla/pyprep
    def perform_reference(self):
        """Estimate the true signal mean and interpolate bad channels.

        This function implements the functionality of the `performReference` function
        as part of the PREP pipeline on mne raw object.

        Phase 1 runs ``robust_reference`` and derives the reference from a
        scratch copy of ``self.raw`` in which the detected noisy channels are
        interpolated. That reference is then removed from ``self.EEG``.

        Phase 2 re-detects noisy channels on the re-referenced data and
        interpolates them, together with ``self.unusable_channels``, in
        ``self.raw``. It then removes the remaining reference offset and
        records which channels are still noisy. The per-criterion bad-channel
        lists before and after interpolation are stored as dictionaries.

        Returns
        -------
        self
            The instance, updated in place.

        Notes
        -----
        This function calls ``robust_reference`` first.
        Currently this function only implements the functionality of default
        settings, i.e., ``doRobustPost``.

        """

        def _bads_by_type(detector):
            # Collect every per-criterion bad-channel list from a detector.
            return {
                "bad_by_nan": detector.bad_by_nan,
                "bad_by_flat": detector.bad_by_flat,
                "bad_by_deviation": detector.bad_by_deviation,
                "bad_by_hf_noise": detector.bad_by_hf_noise,
                "bad_by_correlation": detector.bad_by_correlation,
                "bad_by_SNR": detector.bad_by_SNR,
                "bad_by_dropout": detector.bad_by_dropout,
                "bad_by_ransac": detector.bad_by_ransac,
                "bad_all": detector.get_bads(),
            }

        # ---- Phase 1: robust estimate of the true signal mean ----
        self.robust_reference()
        # Interpolate on a scratch copy so that self.raw itself is untouched
        # here; only the channels interpolated in Phase 2 are reported.
        scratch = self.raw.copy()
        scratch.info["bads"] = self.noisy_channels["bad_all"]
        scratch.interpolate_bads()
        ref_data = scratch.get_data(picks=self.reference_channels)
        self.reference_signal = np.nanmean(ref_data, axis=0) * 1e6
        del scratch

        reref_idx = [
            self.ch_names_eeg.index(name)
            for name in self.rereferenced_channels
        ]
        self.EEG = self.remove_reference(self.EEG, self.reference_signal,
                                         reref_idx)

        # ---- Phase 2: detect bad channels on re-referenced data ----
        self.raw._data = self.EEG * 1e-6
        detector = NoisyChannels(self.raw, random_state=self.random_state)
        detector.find_all_bads(ransac=self.ransac)

        # Snapshot the noisy channels and the data before interpolating
        self.bad_before_interpolation = detector.get_bads(verbose=True)
        self.EEG_before_interpolation = self.EEG.copy()
        self.noisy_channels_before_interpolation = _bads_by_type(detector)

        to_interpolate = _union(self.bad_before_interpolation,
                                self.unusable_channels)
        self.raw.info["bads"] = to_interpolate
        self.raw.interpolate_bads()
        ref_correction = np.nanmean(
            self.raw.get_data(picks=self.reference_channels), axis=0) * 1e6
        interpolated_eeg = self.raw.get_data() * 1e6
        self.EEG = self.remove_reference(interpolated_eeg, ref_correction,
                                         reref_idx)
        # Reference signal including the post-interpolation correction
        self.reference_signal_new = self.reference_signal + ref_correction
        # Push the re-referenced, interpolated data back into the MNE object
        self.raw._data = self.EEG * 1e-6

        # ---- Channels that remain noisy even after interpolation ----
        self.interpolated_channels = to_interpolate
        detector = NoisyChannels(self.raw, random_state=self.random_state)
        detector.find_all_bads(ransac=self.ransac)
        self.still_noisy_channels = detector.get_bads()
        self.raw.info["bads"] = self.still_noisy_channels
        self.noisy_channels_after_interpolation = _bads_by_type(detector)

        return self