Beispiel #1
0
    def perform_reference(self):
        """Re-reference the EEG robustly and interpolate its bad channels.

        Counterpart of the `performReference` function of the PREP pipeline,
        operating on the mne raw object stored in ``self.raw``.

        Returns
        -------
        self
            The same instance, with ``EEG``, ``raw``, ``reference_signal``,
            ``reference_signal_new``, ``bad_before_interpolation``,
            ``EEG_before_interpolation``, ``interpolated_channels`` and
            ``still_noisy_channels`` updated.

        Notes
        -----
        ``robust_reference`` is run first. Only the behaviour of the default
        settings (i.e., ``doRobustPost``) is implemented.

        """
        # ``self.EEG`` is kept 1e6 times larger than the data held by
        # ``self.raw`` (presumably microvolts vs. volts -- TODO confirm).
        to_eeg_scale = 1e6
        to_raw_scale = 1e-6

        def _scaled_reference_mean():
            # NaN-aware mean over the reference channels of ``self.raw``,
            # expressed on the ``self.EEG`` scale.
            picked = self.raw.get_data(picks=self.reference_channels)
            return np.nanmean(picked, axis=0) * to_eeg_scale

        def _detect_on_raw():
            # Run all noisy-channel detectors on the current ``self.raw``.
            detector = NoisyChannels(self.raw)
            detector.find_all_bads(ransac=self.ransac)
            return detector

        # Phase 1: robust estimate of the true signal mean
        self.robust_reference()
        flagged = self.noisy_channels["bad_all"]
        if flagged:
            self.raw.info["bads"] = flagged
            self.raw.interpolate_bads()
        self.reference_signal = _scaled_reference_mean()
        reref_rows = [
            self.ch_names_eeg.index(name)
            for name in self.rereferenced_channels
        ]
        self.EEG = self.remove_reference(
            self.EEG, self.reference_signal, reref_rows)

        # Phase 2: detect bad channels on the re-referenced data, interpolate
        self.raw._data = self.EEG * to_raw_scale
        detector = _detect_on_raw()

        # Snapshot of the bad channels and of the EEG prior to interpolation
        self.bad_before_interpolation = detector.get_bads(verbose=True)
        self.EEG_before_interpolation = self.EEG.copy()

        to_interpolate = _union(
            self.bad_before_interpolation, self.unusable_channels)
        self.raw.info["bads"] = to_interpolate
        self.raw.interpolate_bads()
        # Mean that remains on the reference channels after interpolation
        residual_reference = _scaled_reference_mean()
        self.EEG = self.raw.get_data() * to_eeg_scale
        self.EEG = self.remove_reference(
            self.EEG, residual_reference, reref_rows)
        # Total reference: phase-1 estimate plus the post-interpolation part
        self.reference_signal_new = self.reference_signal + residual_reference
        # Write the final EEG back into the MNE Raw object
        self.raw._data = self.EEG * to_raw_scale

        # Channels that are still flagged once interpolation is done
        self.interpolated_channels = to_interpolate
        detector = _detect_on_raw()
        self.still_noisy_channels = detector.get_bads()
        self.raw.info["bads"] = self.still_noisy_channels
        return self
Beispiel #2
0
    def robust_reference(self, max_iterations=4):
        """Detect bad channels and estimate the robust reference signal.

        This function implements the functionality of the `robustReference` function
        as part of the PREP pipeline on mne raw object.

        Parameters
        ----------
        max_iterations : int, optional
            Upper bound on the number of detect / interpolate / re-reference
            rounds. Defaults to ``4``.

        Returns
        -------
        noisy_channels : dict
            Names of the noisy channels found by each detection method (and
            their union under ``"bad_all"``), accumulated over the initial
            detection and every iteration.
        reference_signal : 1D array
            Estimation of the 'true' signal mean.

        Raises
        ------
        ValueError
            If fewer than two channels are left once all bad channels are
            discounted.

        Notes
        -----
        Whether RANSAC is used is taken from ``self.ransac``. Calling this
        method replaces ``self.reference_channels`` by the reference channels
        that are neither NaN nor flat.

        """
        # Per-method result lists; the detector attributes carry the same
        # names as the dictionary keys.
        bad_types = (
            "bad_by_nan",
            "bad_by_flat",
            "bad_by_deviation",
            "bad_by_hf_noise",
            "bad_by_correlation",
            "bad_by_ransac",
        )

        raw = self.raw.copy()
        raw._data = removeTrend(raw.get_data(), sample_rate=self.sfreq)

        # Determine unusable channels and remove them from the reference channels
        noisy_detector = NoisyChannels(raw,
                                       do_detrend=False,
                                       random_state=self.random_state)
        noisy_detector.find_all_bads(ransac=self.ransac)
        self.noisy_channels_original = {
            bad_type: getattr(noisy_detector, bad_type)
            for bad_type in bad_types
        }
        self.noisy_channels_original["bad_all"] = noisy_detector.get_bads()
        self.noisy_channels = self.noisy_channels_original.copy()
        logger.info("Bad channels: {}".format(self.noisy_channels))

        self.unusable_channels = _union(noisy_detector.bad_by_nan,
                                        noisy_detector.bad_by_flat)

        # According to the Matlab Implementation (see robustReference.m)
        # self.unusable_channels = _union(self.unusable_channels,
        # noisy_detector.bad_by_SNR)
        # but maybe this makes no difference...

        self.reference_channels = _set_diff(self.reference_channels,
                                            self.unusable_channels)

        # Get initial estimate of the reference by the specified method
        signal = raw.get_data() * 1e6
        self.reference_signal = (
            np.nanmedian(raw.get_data(picks=self.reference_channels), axis=0) *
            1e6)
        reference_index = [
            self.ch_names_eeg.index(ch) for ch in self.reference_channels
        ]
        signal_tmp = self.remove_reference(signal, self.reference_signal,
                                           reference_index)

        # Remove reference from signal, iteratively interpolating bad channels
        raw_tmp = raw.copy()
        iterations = 0
        noisy_channels_old = []

        while True:
            raw_tmp._data = signal_tmp * 1e-6
            # Detrend was applied at the beginning of the function.
            noisy_detector = NoisyChannels(raw_tmp,
                                           do_detrend=False,
                                           random_state=self.random_state)
            noisy_detector.find_all_bads(ransac=self.ransac)

            # Accumulate: a channel flagged once stays flagged
            for bad_type in bad_types:
                self.noisy_channels[bad_type] = _union(
                    self.noisy_channels[bad_type],
                    getattr(noisy_detector, bad_type))
            self.noisy_channels["bad_all"] = _union(
                self.noisy_channels["bad_all"], noisy_detector.get_bads())
            logger.info("Bad channels: {}".format(self.noisy_channels))

            bad_all = self.noisy_channels["bad_all"]
            converged = (not bad_all
                         or set(bad_all) == set(noisy_channels_old))
            # At least two rounds are always run before convergence is checked
            if (iterations > 1 and converged) or iterations > max_iterations:
                break
            noisy_channels_old = bad_all.copy()

            if raw_tmp.info["nchan"] - len(bad_all) < 2:
                raise ValueError(
                    "RobustReference:TooManyBad "
                    "Could not perform a robust reference -- not enough good channels"
                )

            # Always restart from the un-referenced signal. Previously, when
            # no channel was flagged, ``raw_tmp`` still held the already
            # re-referenced data of the last round, so the "mean reference"
            # below was computed from the wrong array.
            raw_tmp._data = signal * 1e-6
            if bad_all:
                raw_tmp.info["bads"] = bad_all
                raw_tmp.interpolate_bads()
            self.reference_signal = (np.nanmean(
                raw_tmp.get_data(picks=self.reference_channels), axis=0) * 1e6)

            signal_tmp = self.remove_reference(signal, self.reference_signal,
                                               reference_index)
            iterations += 1
            logger.info("Iterations: {}".format(iterations))

        logger.info("Robust reference done")
        return self.noisy_channels, self.reference_signal
def test_findnoisychannels(raw, montage):
    """Check every NoisyChannels detector on deliberately corrupted data.

    Parameters
    ----------
    raw
        Recording to test on; it is modified in place (montage set, bad
        channels interpolated and dropped).
    montage
        Montage passed to ``raw.set_montage``.

    """

    def _random_channel(inst):
        # Draw one channel of ``inst``; return its row index and its name.
        idx = int(np.random.randint(0, inst._data.shape[0]))
        return idx, inst.ch_names[idx]

    raw.set_montage(montage)
    nd = NoisyChannels(raw)
    nd.find_all_bads(ransac=True)
    bads = nd.get_bads()
    # Remove noisy channels by interpolating the bads, for at most 10 rounds
    max_rounds = 10
    for _ in range(max_rounds):
        if not bads:
            # Nothing left to interpolate; further rounds would be wasted
            break
        raw.info["bads"] = bads
        raw.interpolate_bads()
        nd = NoisyChannels(raw)
        nd.find_all_bads(ransac=True)
        bads = nd.get_bads()

    # make sure no bad channels exist in the data
    raw.drop_channels(ch_names=bads)

    # Test for NaN and flat channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Put a NaN into one random channel and flatten a *different* one; two
    # independent draws could hit the same channel and break the assertions.
    nan_idx, flat_idx = (
        int(i) for i in np.random.choice(m, size=2, replace=False)
    )
    nan_label = raw_tmp.ch_names[nan_idx]
    flat_label = raw_tmp.ch_names[flat_idx]
    raw_tmp._data[nan_idx, n - 1] = np.nan
    raw_tmp._data[flat_idx, :] = np.ones(n)
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_nan_flat()
    assert nd.bad_by_nan == [nan_label]
    assert nd.bad_by_flat == [flat_label]

    # Test for high and low deviations in EEG data
    raw_tmp = raw.copy()
    # Now insert one random channel with very low deviations
    rand_chn_idx, rand_chn_lab = _random_channel(raw_tmp)
    raw_tmp._data[rand_chn_idx, :] = raw_tmp._data[rand_chn_idx, :] / 10
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation
    # Inserting one random channel with a high deviation
    raw_tmp = raw.copy()
    rand_chn_idx, rand_chn_lab = _random_channel(raw_tmp)
    arbitrary_scaling = 5
    raw_tmp._data[rand_chn_idx, :] *= arbitrary_scaling
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation

    # Test for correlation between EEG channels
    raw_tmp = raw.copy()
    n = raw_tmp._data.shape[1]
    rand_chn_idx, rand_chn_lab = _random_channel(raw_tmp)
    # Use cosine instead of sine to create a signal; a separate random
    # integer frequency in [low, high) is drawn for every sample.
    low = 10
    high = 30
    n_freq = 5
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = np.random.randint(low, high, n)
        signal[0, :] += np.cos(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_correlation()
    assert rand_chn_lab in nd.bad_by_correlation

    # Test for high freq noise detection
    raw_tmp = raw.copy()
    n = raw_tmp._data.shape[1]
    rand_chn_idx, rand_chn_lab = _random_channel(raw_tmp)
    # Use freqs between 90 and 100 Hz to insert hf noise
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = np.random.randint(90, 100, n)
        signal[0, :] += np.sin(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_hfnoise()
    assert rand_chn_lab in nd.bad_by_hf_noise

    # Test for signal to noise ratio in EEG data
    raw_tmp = raw.copy()
    rand_chn_idx, rand_chn_lab = _random_channel(raw_tmp)
    # inserting an uncorrelated high frequency (90 Hz) signal in one channel
    raw_tmp[rand_chn_idx, :] = np.sin(2 * np.pi * raw.times * 90) * 1e-6
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_SNR()
    assert rand_chn_lab in nd.bad_by_SNR

    # Test for finding bad channels by RANSAC
    raw_tmp = raw.copy()
    # Ransac identifies channels that go bad together and are highly correlated.
    # Inserting the same 30 Hz signal in channels 0 through 5
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp)
    # Seed the global generator right before the RANSAC run
    np.random.seed(30)
    nd.find_bad_by_ransac()
    bads = nd.bad_by_ransac
    assert bads == raw_tmp.ch_names[0:6]
Beispiel #4
0
def test_findnoisychannels(raw, montage):
    """Test find noisy channels.

    Parameters
    ----------
    raw
        Recording to test on; it is modified in place (montage set, bad
        channels interpolated and dropped).
    montage
        Montage passed to ``raw.set_montage``.

    """
    # One seeded generator drives every random choice in this test
    rng = np.random.RandomState(30)

    def _random_channel(inst):
        # Draw one channel of ``inst``; return its row index and its name.
        # The draw keeps ``size=1`` and is indexed, which leaves the random
        # stream untouched while avoiding the deprecated ``int()`` conversion
        # of a 1-element array.
        idx = int(rng.randint(0, inst._data.shape[0], 1)[0])
        return idx, inst.ch_names[idx]

    raw.set_montage(montage)
    nd = NoisyChannels(raw, random_state=rng)
    nd.find_all_bads(ransac=True)
    bads = nd.get_bads()
    # Remove noisy channels by interpolating the bads, for at most 10 rounds
    max_rounds = 10
    for _ in range(max_rounds):
        if not bads:
            # Nothing left to interpolate, so no later round can change that
            break
        raw.info["bads"] = bads
        raw.interpolate_bads()
        nd = NoisyChannels(raw, random_state=rng)
        nd.find_all_bads(ransac=True)
        bads = nd.get_bads()

    # make sure no bad channels exist in the data
    raw.drop_channels(ch_names=bads)

    # Test for NaN and flat channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Insert a nan value for a random channel and make another random channel
    # completely flat (ones)
    idxs = rng.choice(np.arange(m), size=2, replace=False)
    rand_chn_idx1 = idxs[0]
    rand_chn_idx2 = idxs[1]
    rand_chn_lab1 = raw_tmp.ch_names[rand_chn_idx1]
    rand_chn_lab2 = raw_tmp.ch_names[rand_chn_idx2]
    raw_tmp._data[rand_chn_idx1, n - 1] = np.nan
    raw_tmp._data[rand_chn_idx2, :] = np.ones(n)
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_nan_flat()
    assert nd.bad_by_nan == [rand_chn_lab1]
    assert nd.bad_by_flat == [rand_chn_lab2]

    # Test for high and low deviations in EEG data
    raw_tmp = raw.copy()
    # Now insert one random channel with very low deviations
    rand_chn_idx, rand_chn_lab = _random_channel(raw_tmp)
    raw_tmp._data[rand_chn_idx, :] = raw_tmp._data[rand_chn_idx, :] / 10
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation
    # Inserting one random channel with a high deviation
    raw_tmp = raw.copy()
    rand_chn_idx, rand_chn_lab = _random_channel(raw_tmp)
    arbitrary_scaling = 5
    raw_tmp._data[rand_chn_idx, :] *= arbitrary_scaling
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation

    # Test for correlation between EEG channels
    raw_tmp = raw.copy()
    n = raw_tmp._data.shape[1]
    rand_chn_idx, rand_chn_lab = _random_channel(raw_tmp)
    # Use cosine instead of sine to create a signal; a separate random
    # integer frequency in [low, high) is drawn for every sample.
    low = 10
    high = 30
    n_freq = 5
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = rng.randint(low, high, n)
        signal[0, :] += np.cos(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_correlation()
    assert rand_chn_lab in nd.bad_by_correlation

    # Test for high freq noise detection
    raw_tmp = raw.copy()
    n = raw_tmp._data.shape[1]
    rand_chn_idx, rand_chn_lab = _random_channel(raw_tmp)
    # Use freqs between 90 and 100 Hz to insert hf noise
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = rng.randint(90, 100, n)
        signal[0, :] += np.sin(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_hfnoise()
    assert rand_chn_lab in nd.bad_by_hf_noise

    # Test for signal to noise ratio in EEG data
    raw_tmp = raw.copy()
    rand_chn_idx, rand_chn_lab = _random_channel(raw_tmp)
    # inserting an uncorrelated high frequency (90 Hz) signal in one channel
    raw_tmp[rand_chn_idx, :] = np.sin(2 * np.pi * raw.times * 90) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_SNR()
    assert rand_chn_lab in nd.bad_by_SNR

    # Test for finding bad channels by RANSAC
    raw_tmp = raw.copy()
    # Ransac identifies channels that go bad together and are highly correlated.
    # Inserting the same 30 Hz signal in channels 0 through 5
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_ransac()
    bads = nd.bad_by_ransac
    assert bads == raw_tmp.ch_names[0:6]

    # Test for finding bad channels by channel-wise RANSAC
    raw_tmp = raw.copy()
    # Same corruption as above: identical 30 Hz signal in channels 0 through 5
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_ransac(channel_wise=True)
    bads = nd.bad_by_ransac
    assert bads == raw_tmp.ch_names[0:6]

    # Test not-enough-memory and n_samples type exceptions
    raw_tmp = raw.copy()
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)

    # Set n_samples very very high to trigger a memory error
    n_samples = int(1e100)
    with pytest.raises(MemoryError):
        nd.find_bad_by_ransac(n_samples=n_samples)

    # Set n_samples to a float to trigger a type error
    n_samples = 35.5
    with pytest.raises(TypeError):
        nd.find_bad_by_ransac(n_samples=n_samples)

    # Test IOError when not enough channels for ransac predictions
    raw_tmp = raw.copy()
    # Make flat all channels except 2
    num_bad_channels = raw._data.shape[0] - 2
    raw_tmp._data[0:num_bad_channels, :] = np.zeros_like(
        raw_tmp._data[0:num_bad_channels, :]
    )
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_all_bads(ransac=False)
    with pytest.raises(IOError):
        nd.find_bad_by_ransac()
Beispiel #5
0
    def perform_reference(self, max_iterations=4):
        """Estimate the true signal mean and interpolate bad channels.

        Counterpart of the `performReference` function of the PREP pipeline,
        operating on the mne raw object stored in ``self.raw``.

        Parameters
        ----------
        max_iterations : int, optional
            The maximum number of iterations of noisy channel removal to perform
            during robust referencing. Defaults to ``4``.

        Returns
        -------
        self
            The same instance, with the re-referenced and interpolated data
            in ``EEG`` / ``raw`` and the bookkeeping attributes filled in.

        Notes
        -----
        ``robust_reference`` is called first. Only the behaviour of the
        default settings, i.e. ``doRobustPost``, is implemented.

        """

        def _fill_bads(inst):
            # Interpolate the channels listed in ``inst.info["bads"]`` with
            # the interpolation routine selected by ``matlab_strict``.
            if self.matlab_strict:
                _eeglab_interpolate_bads(inst)
            else:
                inst.interpolate_bads()

        def _reference_mean(inst):
            # NaN-aware mean over the reference channels of ``inst``.
            return np.nanmean(
                inst.get_data(picks=self.reference_channels), axis=0)

        def _detect_on_raw():
            # Run all noisy-channel detectors on the current ``self.raw``.
            detector = NoisyChannels(self.raw,
                                     random_state=self.random_state,
                                     matlab_strict=self.matlab_strict)
            detector.find_all_bads(**self.ransac_settings)
            return detector

        # Phase 1: robust estimate of the true signal mean
        self.robust_reference(max_iterations)
        # Work on a throw-away copy: interpolating ``self.raw`` here would
        # interpolate more channels than are later recorded as interpolated.
        scratch = self.raw.copy()
        scratch.info["bads"] = self.noisy_channels["bad_all"]
        _fill_bads(scratch)
        self.reference_signal = _reference_mean(scratch)
        del scratch
        reref_rows = [
            self.ch_names_eeg.index(name)
            for name in self.rereferenced_channels
        ]
        self.EEG = self.remove_reference(
            self.EEG, self.reference_signal, reref_rows)

        # Phase 2: detect bad channels on the re-referenced data, interpolate
        self.raw._data = self.EEG
        detector = _detect_on_raw()

        # Snapshot of the bad channels and of the EEG prior to interpolation
        self.bad_before_interpolation = detector.get_bads(verbose=True)
        self.EEG_before_interpolation = self.EEG.copy()
        self.noisy_channels_before_interpolation = detector.get_bads(
            as_dict=True)
        self._extra_info["interpolated"] = detector._extra_info

        to_interpolate = _union(
            self.bad_before_interpolation, self.unusable_channels)
        self.raw.info["bads"] = to_interpolate
        _fill_bads(self.raw)
        # Mean that remains on the reference channels after interpolation
        residual_reference = _reference_mean(self.raw)
        self.EEG = self.raw.get_data()
        self.EEG = self.remove_reference(
            self.EEG, residual_reference, reref_rows)
        # Total reference: phase-1 estimate plus the post-interpolation part
        self.reference_signal_new = self.reference_signal + residual_reference
        # Write the final EEG back into the MNE Raw object
        self.raw._data = self.EEG

        # Channels that are still flagged once interpolation is done
        self.interpolated_channels = to_interpolate
        detector = _detect_on_raw()
        self.still_noisy_channels = detector.get_bads()
        self.raw.info["bads"] = self.still_noisy_channels
        self.noisy_channels_after_interpolation = detector.get_bads(
            as_dict=True)
        self._extra_info["remaining_bad"] = detector._extra_info

        return self
Beispiel #6
0
    def robust_reference(self, max_iterations=4):
        """Detect bad channels and estimate the robust reference signal.

        This function implements the functionality of the `robustReference` function
        as part of the PREP pipeline on mne raw object.

        Parameters
        ----------
        max_iterations : int, optional
            The maximum number of iterations of noisy channel removal to perform
            during robust referencing. Defaults to ``4``.

        Returns
        -------
        noisy_channels: dict
            A dictionary of names of noisy channels detected from all methods
            after referencing.
        reference_signal: np.ndarray, shape(n, )
            Estimation of the 'true' signal mean

        Raises
        ------
        ValueError
            If fewer than two channels are left once all channels flagged as
            bad are discounted.

        """
        raw = self.raw.copy()
        raw._data = removeTrend(raw.get_data(),
                                self.sfreq,
                                matlab_strict=self.matlab_strict)

        # Determine unusable channels and remove them from the reference channels
        noisy_detector = NoisyChannels(
            raw,
            do_detrend=False,
            random_state=self.random_state,
            matlab_strict=self.matlab_strict,
        )
        noisy_detector.find_all_bads(**self.ransac_settings)
        self.noisy_channels_original = noisy_detector.get_bads(as_dict=True)
        self._extra_info["initial_bad"] = noisy_detector._extra_info
        logger.info("Bad channels: {}".format(self.noisy_channels_original))

        # Determine channels to use/exclude from initial reference estimation
        self.unusable_channels = _union(
            noisy_detector.bad_by_nan + noisy_detector.bad_by_flat,
            noisy_detector.bad_by_SNR,
        )
        reference_channels = _set_diff(self.reference_channels,
                                       self.unusable_channels)

        # Initialize channels to permanently flag as bad during referencing
        noisy = {
            "bad_by_nan": noisy_detector.bad_by_nan,
            "bad_by_flat": noisy_detector.bad_by_flat,
            "bad_by_deviation": [],
            "bad_by_hf_noise": [],
            "bad_by_correlation": [],
            "bad_by_SNR": [],
            "bad_by_dropout": [],
            "bad_by_ransac": [],
            "bad_all": [],
        }

        # Get initial estimate of the reference by the specified method
        signal = raw.get_data()
        self.reference_signal = np.nanmedian(
            raw.get_data(picks=reference_channels), axis=0)
        reference_index = [
            self.ch_names_eeg.index(ch) for ch in reference_channels
        ]
        signal_tmp = self.remove_reference(signal, self.reference_signal,
                                           reference_index)

        # Bad channel types to ignore when updating noisy channels
        # NOTE: MATLAB PREP ignores dropout channels, possibly by mistake?
        # see: https://github.com/VisLab/EEG-Clean-Tools/issues/28
        ignore = {"bad_by_SNR", "bad_all"}
        if self.matlab_strict:
            ignore.add("bad_by_dropout")

        # Remove reference from signal, iteratively interpolating bad channels
        raw_tmp = raw.copy()
        iterations = 0
        previous_bads = set()

        while True:
            raw_tmp._data = signal_tmp
            # Detrend was applied at the beginning of the function.
            noisy_detector = NoisyChannels(
                raw_tmp,
                do_detrend=False,
                random_state=self.random_state,
                matlab_strict=self.matlab_strict,
            )

            # Detect all currently bad channels
            noisy_detector.find_all_bads(**self.ransac_settings)
            noisy_new = noisy_detector.get_bads(as_dict=True)

            # Update set of all noisy channels detected so far with any new ones
            bad_chans = set()
            for bad_type, new_bads in noisy_new.items():
                noisy[bad_type] = _union(noisy[bad_type], new_bads)
                if bad_type not in ignore:
                    bad_chans.update(noisy[bad_type])
            noisy["bad_all"] = list(bad_chans)
            logger.info("Bad channels: {}".format(noisy))

            converged = not bad_chans or bad_chans == previous_bads
            # At least two rounds are always run before convergence is checked
            if (iterations > 1 and converged) or iterations > max_iterations:
                logger.info("Robust reference done")
                self.noisy_channels = noisy
                break
            previous_bads = bad_chans.copy()

            if raw_tmp.info["nchan"] - len(bad_chans) < 2:
                raise ValueError(
                    "RobustReference:TooManyBad "
                    "Could not perform a robust reference -- not enough good channels"
                )

            # Always restart from the un-referenced signal. Previously, when
            # no channel was flagged, ``raw_tmp`` still held the already
            # re-referenced data of the last round, so the "mean reference"
            # below was computed from the wrong array. The copy protects
            # ``signal`` from the interpolation performed on ``raw_tmp``.
            raw_tmp._data = signal.copy()
            if bad_chans:
                raw_tmp.info["bads"] = list(bad_chans)
                if self.matlab_strict:
                    _eeglab_interpolate_bads(raw_tmp)
                else:
                    raw_tmp.interpolate_bads()

            self.reference_signal = np.nanmean(
                raw_tmp.get_data(picks=reference_channels), axis=0)

            signal_tmp = self.remove_reference(signal, self.reference_signal,
                                               reference_index)
            iterations += 1
            logger.info("Iterations: {}".format(iterations))

        return self.noisy_channels, self.reference_signal
Beispiel #7
0
    def preproc(self, event_dict, baseline_start, stim_dur, montage, out_dir='preproc', subjects='all', tasks='all', 
                hp_cutoff=1, lp_cutoff=50, line_noise=60, seed=42, eog_chan='none'):
        '''Preprocesses a single EEG file. Assigns a list of epoched data to Dataset instance,
        where each entry in the list is a subject with concatenated task data. Here is the basic 
        structure of the preprocessing workflow:
        
            - Set the montage
            - Band-pass filter (high-pass filter by default)
            - Automatically detect bad channels
            - Notch filter out line-noise
            - Reference data to average of all EEG channels
            - Automated removal of eye-related artifacts using ICA
            - Spherical interpolation of detected bad channels
            - Extract events and epoch the data accordingly
            - Align the events based on type (still need to implement this!)
            - Create a list of epoched data, with subject as the element concatenated across tasks
        
        Parameters
        ----------
        event_dict: dict
            Maps integers to semantic labels for events within the experiment
            
        baseline_start: int or float
            Specify start of the baseline period (in seconds)
            
        stim_dur: int or float
            Stimulus duration (in seconds)
                Note: may need to make more general to allow various durations
                
        montage: mne.channels.montage.DigMontage
            Maps sensor locations to coordinates
            
        subjects: list or 'all'
            Specify which subjects to iterate through
            
        tasks: list or 'all'
            Specify which tasks to iterate through
            
        hp_cutoff: int or float
            The low frequency bound for the highpass filter in Hz
            
        line_noise: int or float
            The frequency of electrical noise to filter out in Hz
            
        seed: int
            Set the seed for replicable results
            
        eog_chan: str
            If there are no EOG channels present, select an EEG channel
            near the eyes for eye-related artifact detection
        '''

        missing = [] # initialize missing file list
        subj_iter = self.gen_iter(subjects, self.n_subj) # get subject iterator
        task_iter = self.gen_iter(tasks, self.n_task) # get task iterator

        # Iterate through subjects (initialize subject epoch list)
        epochs_subj = []
        for subj in subj_iter:

            # Iterate through tasks (initialize within-subject task epoch list)
            epochs_task = []
            for task in task_iter:
                # Specify the file format
                self.get_file_format(subj, task)

                try: # Handles missing files
                    raw = self.wget_raw_edf() # read
                except:
                    print(f'---\nThis file does not exist: {self.file_path}\n---')
                    # Need to write the missing file list out
                    missing.append(self.file_path)
                    break
                    
                # Standardize montage
                mne.datasets.eegbci.standardize(raw)
                # Set montage and strip channel names of "." characters
                raw.set_montage(montage)
                raw.rename_channels(lambda x: x.strip('.'))

                # Apply high-pass filter
                np.random.seed(seed)
                raw.filter(l_freq=hp_cutoff, h_freq=lp_cutoff, picks=['eeg', 'eog'])

                # Instantiate NoisyChannels object
                noise_chans = NoisyChannels(raw, do_detrend=False)

                # Detect bad channels through multiple methods
                noise_chans.find_bad_by_nan_flat()
                noise_chans.find_bad_by_deviation()
                noise_chans.find_bad_by_SNR()

                # Set the bad channels in the raw object
                raw.info['bads'] = noise_chans.get_bads()
                print(f'Bad channels detected: {noise_chans.get_bads()}')

                # Define the frequencies for the notch filter (60Hz and its harmonics)
                #notch_filt = np.arange(line_noise, raw.info['sfreq'] // 2, line_noise)

                # Apply notch filter
                #print(f'Apply notch filter at {line_noise} Hz and its harmonics')
                #raw.notch_filter(notch_filt, picks=['eeg', 'eog'], verbose=False)

                # Reference to the average of all the good channels 
                # Automatically excludes raw.info['bads']
                raw.set_eeg_reference(ref_channels='average')

                # Instantiate ICA object
                ica = mne.preprocessing.ICA(max_iter=1000)
                # Run ICA
                ica.fit(raw)

                # Find which ICs match the EOG pattern
                if eog_chan == 'none':
                    eog_indices, eog_scores = ica.find_bads_eog(raw, verbose=False)
                else:
                    eog_indices, eog_scores = ica.find_bads_eog(raw, eog_chan, verbose=False)

                # Apply the IC blink removals (if any)
                ica.apply(raw, exclude=eog_indices)
                print(f'Removed IC index {eog_indices}')

                # Interpolate bad channels
                raw.interpolate_bads()
                
                # Specify pre-processing directory
                preproc_dir = Path(f'{out_dir}/ {subj}')

                # If directory does not exist, one will be created.
                if not os.path.isdir(preproc_dir):
                    os.makedirs(preproc_dir)

               # raw.save(Path(preproc_dir, f'subj{subj}_task{task}_raw.fif'), 
                #         overwrite=True)

                # Find events
                events = mne.events_from_annotations(raw)[0]

                # Epoch the data
                preproc_epoch = mne.Epochs(raw, events, tmin=baseline_start, tmax=stim_dur, 
                                   event_id=event_dict, event_repeated='error', 
                                   on_missing='ignore', preload=True)
                
                # Equalize event counts
                preproc_epoch.equalize_event_counts(event_dict.keys())
                
                # Rearrange and align the epochs
                align = [preproc_epoch[i] for i in event_dict.keys()]
                align_epoch = mne.concatenate_epochs(align)
                
                # Add to epoch list
                epochs_task.append(align_epoch)

            # Assuming some data exists for a subject
            # Concatenate epochs within subject
            concat_epoch = mne.concatenate_epochs(epochs_task)
            epochs_subj.append(concat_epoch)
        # Attaches a list with each entry corresponding to epochs for a subject
        self.epoch_list = epochs_subj
Beispiel #8
0
    def perform_reference(self):
        """Run robust referencing, then interpolate and re-check bad channels.

        Mirrors the ``performReference`` step of the PREP pipeline, operating
        on the MNE raw object held by this instance.

        Returns
        -------
        self
            The instance itself. The reference signals, the noisy-channel
            records and the re-referenced data are stored as attributes.

        Notes
        -----
        ``robust_reference`` is invoked as the very first step.
        Only the default behaviour (``doRobustPost``) is implemented.

        """

        def _collect_flags(detector):
            # Snapshot the per-criterion bad-channel lists held by ``detector``.
            return {
                "bad_by_nan": detector.bad_by_nan,
                "bad_by_flat": detector.bad_by_flat,
                "bad_by_deviation": detector.bad_by_deviation,
                "bad_by_hf_noise": detector.bad_by_hf_noise,
                "bad_by_correlation": detector.bad_by_correlation,
                "bad_by_SNR": detector.bad_by_SNR,
                "bad_by_dropout": detector.bad_by_dropout,
                "bad_by_ransac": detector.bad_by_ransac,
                "bad_all": detector.get_bads(),
            }

        # ---- Phase 1: robust estimate of the true signal mean ----
        self.robust_reference()
        # Interpolate on a throwaway copy: interpolating ``self.raw`` here
        # would touch more channels than are later recorded as interpolated.
        scratch = self.raw.copy()
        scratch.info["bads"] = self.noisy_channels["bad_all"]
        scratch.interpolate_bads()
        scratch_ref = scratch.get_data(picks=self.reference_channels)
        # Data are scaled by 1e6 on the way out of the raw object and by 1e-6
        # on the way back in (presumably V <-> uV -- confirm against callers).
        self.reference_signal = np.nanmean(scratch_ref, axis=0) * 1e6
        del scratch, scratch_ref

        reref_idx = [
            self.ch_names_eeg.index(name) for name in self.rereferenced_channels
        ]
        self.EEG = self.remove_reference(self.EEG, self.reference_signal, reref_idx)

        # ---- Phase 2: detect bad channels on the referenced data ----
        self.raw._data = self.EEG * 1e-6
        detector = NoisyChannels(self.raw, random_state=self.random_state)
        detector.find_all_bads(ransac=self.ransac)

        # Keep a record of the state just before interpolation
        self.bad_before_interpolation = detector.get_bads(verbose=True)
        self.EEG_before_interpolation = self.EEG.copy()
        self.noisy_channels_before_interpolation = _collect_flags(detector)

        # Interpolate detected bads together with the unusable channels
        to_interpolate = _union(self.bad_before_interpolation, self.unusable_channels)
        self.raw.info["bads"] = to_interpolate
        self.raw.interpolate_bads()
        interpolated_ref = self.raw.get_data(picks=self.reference_channels)
        reference_correct = np.nanmean(interpolated_ref, axis=0) * 1e6
        self.EEG = self.raw.get_data() * 1e6
        self.EEG = self.remove_reference(self.EEG, reference_correct, reref_idx)

        # Reference signal after interpolation
        self.reference_signal_new = self.reference_signal + reference_correct
        # Write the re-referenced, interpolated data back into the raw object
        self.raw._data = self.EEG * 1e-6

        # ---- Re-run detection: which channels are still noisy? ----
        self.interpolated_channels = to_interpolate
        detector = NoisyChannels(self.raw, random_state=self.random_state)
        detector.find_all_bads(ransac=self.ransac)
        self.still_noisy_channels = detector.get_bads()
        self.raw.info["bads"] = self.still_noisy_channels
        self.noisy_channels_after_interpolation = _collect_flags(detector)

        return self