Exemplo n.º 1
0
def test_highpass():
    """Check that every high-pass variant of ``removeTrend`` keeps the 8 Hz part.

    A 0.1 Hz sinusoid (the slow "trend") is added to an 8 Hz sinusoid. The
    mixture is run through the sinc high-pass, the default high-pass and the
    MATLAB-strict high-pass variants, all with a 1 Hz cutoff. Each result must
    match the 8 Hz component with an RMS error below 0.1.
    """
    srate = 100
    t = np.arange(0, 30, 1 / srate)
    slow_component = np.sin(2 * np.pi * 0.1 * t)
    fast_component = np.sin(2 * np.pi * 8 * t)
    mixed = slow_component + fast_component

    # One entry per high-pass variant under test
    variants = [
        {"detrendType": "High pass sinc"},
        {"detrendType": "High pass"},
        {"detrendType": "High pass", "matlab_strict": True},
    ]
    highpassed = [
        removeTrend.removeTrend(mixed, sample_rate=srate, detrendCutoff=1,
                                **options)
        for options in variants
    ]

    for filtered in highpassed:
        residual = filtered - fast_component
        assert np.sqrt(np.mean(residual**2)) < 0.1
Exemplo n.º 2
0
def test_compare_removeTrend(matprep_artifacts):
    """Check that ``removeTrend`` reproduces MATLAB PREP's detrended output.

    The raw and detrended EEGLAB files exported by MATLAB PREP are loaded.
    The raw data is detrended here with ``detrendType="high pass"`` and
    ``matlab_strict=True``. The result is compared to MATLAB's over 500-sample
    windows at the start, middle and end of the recording, with NaNs
    comparing equal.
    """
    # Paths of the MATLAB .set files
    raw_path = matprep_artifacts["1_matprep_raw"]
    detrend_path = matprep_artifacts["2_matprep_removetrend"]

    matprep_raw = mne.io.read_raw_eeglab(raw_path, preload=True)
    matprep_detrended = mne.io.read_raw_eeglab(detrend_path, preload=True)

    expected = matprep_detrended._data
    actual = removeTrend(
        matprep_raw._data,
        matprep_raw.info["sfreq"],
        detrendType="high pass",
        matlab_strict=True,
    )

    win_size = 500  # number of samples in each compared window
    mid = int(actual.shape[1] / 2)
    windows = [
        slice(None, win_size),       # start of recording
        slice(mid, mid + win_size),  # middle of recording
        slice(-win_size, None),      # end of recording
    ]
    for window in windows:
        assert np.allclose(actual[:, window], expected[:, window],
                           equal_nan=True)
Exemplo n.º 3
0
    def fit(self):
        """Run the whole PREP pipeline.

        Steps, in order:

        1. If ``self.prep_params["line_freqs"]`` is non-empty:
           - Detrend ``self.EEG_raw`` with ``removeTrend``, stored in
             ``self.EEG_new``.
           - Remove line noise from the detrended data with
             ``mne.filter.notch_filter``, stored in ``self.EEG_clean``.
           - Add the removed trend back, stored in ``self.EEG``.
           - Write the result into ``self.raw_eeg._data`` scaled by ``1e-6``.
           When no line frequencies are configured, this step is skipped.
        2. Robust re-referencing via ``Reference.perform_reference``. The
           referenced raw object and the noisy-channel / reference-signal
           diagnostics of the ``Reference`` instance are copied onto ``self``.

        Returns
        -------
        self
            The fitted instance, so calls can be chained.
        """
        # NOTE(review): ``noisy_detector`` is never read below; the code that
        # consumed its NaN/flat results is commented out. Constructing
        # ``NoisyChannels`` may still act on ``self.raw_eeg`` (some versions
        # load or detrend data). Confirm before deleting these two lines.
        noisy_detector = NoisyChannels(self.raw_eeg, random_state=self.random_state)
        noisy_detector.find_bad_by_nan_flat()
        # unusable_channels = _union(
        #     noisy_detector.bad_by_nan, noisy_detector.bad_by_flat
        # )
        # reference_channels = _set_diff(self.prep_params["ref_chs"], unusable_channels)
        # Step 1: 1Hz high pass filtering
        # (``removeTrend`` with its default settings; the cutoff is whatever
        # removeTrend defaults to)
        if len(self.prep_params["line_freqs"]) != 0:
            self.EEG_new = removeTrend(self.EEG_raw, sample_rate=self.sfreq)

            # Step 2: Removing line noise
            linenoise = self.prep_params["line_freqs"]
            if self.filter_kwargs is None:
                # Default notch settings: multitaper spectrum fitting
                self.EEG_clean = mne.filter.notch_filter(
                    self.EEG_new,
                    Fs=self.sfreq,
                    freqs=linenoise,
                    method="spectrum_fit",
                    mt_bandwidth=2,
                    p_value=0.01,
                    filter_length="10s",
                )
            else:
                # User-supplied notch filter settings
                self.EEG_clean = mne.filter.notch_filter(
                    self.EEG_new,
                    Fs=self.sfreq,
                    freqs=linenoise,
                    **self.filter_kwargs,
                )

            # Add Trend back: (EEG_raw - EEG_new) is the trend that removeTrend
            # took out; it is re-added to the line-noise-cleaned signal.
            self.EEG = self.EEG_raw - self.EEG_new + self.EEG_clean
            # NOTE(review): the 1e-6 factor suggests EEG_raw is in microvolts
            # while MNE stores volts. Confirm.
            self.raw_eeg._data = self.EEG * 1e-6

        # Step 3: Referencing
        reference = Reference(
            self.raw_eeg,
            self.prep_params,
            ransac=self.ransac,
            random_state=self.random_state,
        )
        reference.perform_reference()
        # Expose the referenced data and the referencing diagnostics
        self.raw_eeg = reference.raw
        self.noisy_channels_original = reference.noisy_channels_original
        self.noisy_channels_before_interpolation = (
            reference.noisy_channels_before_interpolation
        )
        self.noisy_channels_after_interpolation = (
            reference.noisy_channels_after_interpolation
        )
        self.bad_before_interpolation = reference.bad_before_interpolation
        self.EEG_before_interpolation = reference.EEG_before_interpolation
        self.reference_before_interpolation = reference.reference_signal
        self.reference_after_interpolation = reference.reference_signal_new
        self.interpolated_channels = reference.interpolated_channels
        self.still_noisy_channels = reference.still_noisy_channels

        return self
Exemplo n.º 4
0
    def __init__(self, raw):
        """Set up channel bookkeeping for bad-channel detection on ``raw``.

        Processing:
        - ``raw`` is copied (the caller's object is not modified).
        - The copy's EEG channels are extracted and detrended with
          ``removeTrend``.
        - The "original" and "new" channel bookkeeping attributes are
          initialised to describe the full, unmodified channel set.
        - Every ``bad_by_*`` list starts out empty.

        Parameters
        ----------
        raw : mne.io.BaseRaw
            The recording to analyse.
        """
        # Make sure that we got an MNE object
        assert isinstance(raw, mne.io.BaseRaw)

        self.raw_mne = raw.copy()
        self.sample_rate = raw.info["sfreq"]

        # Detrended EEG data; "before filtering" refers to the same array
        eeg = removeTrend(self.raw_mne.get_data(picks="eeg"),
                          sample_rate=self.sample_rate)
        self.EEGData = eeg
        self.EEGData_beforeFilt = eeg

        # NOTE(review): names/counts are taken from *all* channels of ``raw``
        # while the data above holds only EEG channels. The two disagree when
        # non-EEG channels are present.
        names = np.asarray(raw.info["ch_names"])
        self.ch_names_original = names
        self.ch_names_new = names
        self.n_chans_original = len(names)
        self.n_chans_new = self.n_chans_original
        self.signal_len = len(self.raw_mne.times)

        dims = np.shape(eeg)
        self.original_dimensions = dims
        self.new_dimensions = dims
        channel_idx = np.arange(dims[0])
        self.original_channels = channel_idx
        self.new_channels = channel_idx
        self.channels_interpolate = channel_idx

        # Channels flagged by each detection method (filled in later)
        self.bad_by_nan = []
        self.bad_by_flat = []
        self.bad_by_deviation = []
        self.bad_by_hf_noise = []
        self.bad_by_correlation = []
        self.bad_by_SNR = []
        self.bad_by_dropout = []
        self.bad_by_ransac = []
Exemplo n.º 5
0
    def __init__(self,
                 raw,
                 do_detrend=True,
                 random_state=None,
                 matlab_strict=False):
        """Prepare an EEG recording for bad-channel detection.

        Processing:
        - ``raw`` has its data loaded in place (``raw.load_data()``), then is
          copied.
        - The copy is restricted to EEG channels and optionally detrended
          with ``removeTrend``.
        - Channels bad by NaN or flatness are flagged immediately.
        - The remaining, usable channels are extracted into ``self.EEGData``.

        Parameters
        ----------
        raw : mne.io.BaseRaw
            The recording to analyse. Its data is loaded in place.
        do_detrend : bool
            Whether to run ``removeTrend`` on the EEG copy. Defaults to True.
        random_state : int | None | np.random.RandomState
            Passed through ``check_random_state`` and stored as
            ``self.random_state``. Defaults to None.
        matlab_strict : bool
            Forwarded to ``removeTrend`` and stored as ``self.matlab_strict``.
            Defaults to False.
        """
        # Make sure that we got an MNE object
        assert isinstance(raw, mne.io.BaseRaw)

        raw.load_data()
        eeg_raw = raw.copy()
        eeg_raw.pick_types(eeg=True)
        self.raw_mne = eeg_raw
        self.sample_rate = raw.info["sfreq"]
        if do_detrend:
            detrended = removeTrend(eeg_raw.get_data(),
                                    self.sample_rate,
                                    matlab_strict=matlab_strict)
            eeg_raw._data = detrended
        self.matlab_strict = matlab_strict

        # Extra data for debugging: one initially empty dict per method
        debug_sections = (
            "bad_by_deviation",
            "bad_by_hf_noise",
            "bad_by_correlation",
            "bad_by_dropout",
            "bad_by_ransac",
        )
        self._extra_info = {section: {} for section in debug_sections}

        self.random_state = check_random_state(random_state)

        # Channel names flagged by each detection method
        self.bad_by_nan = []
        self.bad_by_flat = []
        self.bad_by_deviation = []
        self.bad_by_hf_noise = []
        self.bad_by_correlation = []
        self.bad_by_SNR = []
        self.bad_by_dropout = []
        self.bad_by_ransac = []

        # Original EEG channel names, channel count and number of samples
        all_names = np.asarray(eeg_raw.info["ch_names"])
        self.ch_names_original = all_names
        self.n_chans_original = len(all_names)
        self.n_samples = raw._data.shape[1]

        # NaN/flat channels must be known before the usable subset is built
        self.find_bad_by_nan_flat()
        unusable = self.bad_by_nan + self.bad_by_flat

        # Subset of the data containing only usable EEG channels
        self.usable_idx = ~np.isin(all_names, unusable)
        self.EEGData = eeg_raw.get_data(picks=all_names[self.usable_idx])
        self.EEGFiltered = None

        # Usable EEG channel names and count
        self.ch_names_new = np.asarray(all_names[self.usable_idx])
        self.n_chans_new = len(self.ch_names_new)
Exemplo n.º 6
0
    def __init__(self, raw, do_detrend=True, random_state=None):
        """Initialize the class.

        The input recording is copied, optionally detrended, and its EEG data
        extracted. All channel bookkeeping attributes describe the full,
        unmodified channel set, and every ``bad_by_*`` list starts empty.

        Parameters
        ----------
        raw : mne.io.Raw
            The MNE raw object. It is copied, not modified.
        do_detrend : bool
            Whether to run ``removeTrend`` on the copied data when the
            `NoisyChannels` object is created. Defaults to True.
        random_state : int | None | np.random.RandomState
            Seed (or generator) for the class' randomness. An int seeds a new
            RandomState; None lets the operating system provide the seed
            (see RandomState for details). Default is None.

        """
        # Make sure that we got an MNE object
        assert isinstance(raw, mne.io.BaseRaw)

        raw_copy = raw.copy()
        self.raw_mne = raw_copy
        self.sample_rate = raw.info["sfreq"]
        if do_detrend:
            trendless = removeTrend(raw_copy.get_data(),
                                    sample_rate=self.sample_rate)
            raw_copy._data = trendless

        # EEG data; "before filtering" refers to the same array
        eeg = raw_copy.get_data(picks="eeg")
        self.EEGData = eeg
        self.EEGData_beforeFilt = eeg

        # NOTE(review): names/counts come from *all* channels of ``raw`` while
        # ``eeg`` holds only EEG channels. They disagree if non-EEG channels
        # are present.
        names = np.asarray(raw.info["ch_names"])
        self.ch_names_original = names
        self.ch_names_new = names
        self.n_chans_original = len(names)
        self.n_chans_new = self.n_chans_original
        self.signal_len = len(raw_copy.times)

        dims = np.shape(eeg)
        self.original_dimensions = dims
        self.new_dimensions = dims
        channel_idx = np.arange(dims[0])
        self.original_channels = channel_idx
        self.new_channels = channel_idx
        self.channels_interpolate = channel_idx

        self.random_state = check_random_state(random_state)

        # Channels flagged by each detection method (filled in later)
        self.bad_by_nan = []
        self.bad_by_flat = []
        self.bad_by_deviation = []
        self.bad_by_hf_noise = []
        self.bad_by_correlation = []
        self.bad_by_SNR = []
        self.bad_by_dropout = []
        self.bad_by_ransac = []
Exemplo n.º 7
0
def raw_clean_detrend(raw_clean):
    """Return a detrended copy of the ``raw_clean`` recording.

    Based on the data from the `raw_clean` fixture, which uses the data for
    subject 1, run 1 from the Physionet BCI2000 dataset. The input object is
    left untouched; only the copy's data is replaced by the output of
    ``removeTrend``.

    This is only run once per session to save time.

    """
    sfreq = raw_clean.info["sfreq"]
    trendless = removeTrend(raw_clean.get_data(), sfreq)

    detrended_copy = raw_clean.copy()
    detrended_copy._data = trendless
    return detrended_copy
Exemplo n.º 8
0
def test_detrend():
    """Check that local-regression detrending removes an offset plus linear ramp.

    Seeded Gaussian noise (3000 samples at 100 Hz) is shifted by 2 and given a
    linear ramp rising by 1.5. After ``"Local detrend"`` the output must be
    within 0.1 RMS of the original noise.
    """
    srate = 100
    t = np.arange(0, 30, 1 / srate)
    rng = np.random.RandomState(9)
    n_samples = len(t)

    noise = rng.randn(n_samples)
    trend = 2 + 1.5 * np.linspace(0, 1, n_samples)
    recovered = removeTrend.removeTrend(
        trend + noise, detrendType="Local detrend", sample_rate=srate
    )

    residual = recovered - noise
    assert np.sqrt(np.mean(residual ** 2)) < 0.1
Exemplo n.º 9
0
    def robust_reference(self):
        """Detect bad channels and estimate the robust reference signal.

        This function implements the functionality of the `robustReference` function
        as part of the PREP pipeline on mne raw object.

        The method takes no arguments. It reads these instance attributes:
        ``self.raw``, ``self.sfreq``, ``self.ransac`` (whether RANSAC is used
        for bad-channel detection), ``self.random_state``,
        ``self.reference_channels`` and ``self.ch_names_eeg``.

        Side effects: sets ``self.noisy_channels_original``,
        ``self.noisy_channels``, ``self.unusable_channels`` and
        ``self.reference_signal``. It also overwrites
        ``self.reference_channels`` with the NaN/flat channels removed.

        Returns
        -------
            noisy_channels: dictionary
                Names of noisy channels per detection method, accumulated
                (as unions) over the initial detection and all referencing
                iterations.
            reference_signal: 1D Array
                Estimation of the 'true' signal mean, scaled by 1e6 relative
                to the raw data (presumably microvolts; confirm).

        Raises
        ------
            ValueError
                If fewer than two channels are left after excluding the bad
                ones.

        """
        # Detrend a copy once; the NoisyChannels detectors below are created
        # with do_detrend=False so the data is not detrended again.
        raw = self.raw.copy()
        raw._data = removeTrend(raw.get_data(), sample_rate=self.sfreq)

        # Determine unusable channels and remove them from the reference channels
        noisy_detector = NoisyChannels(raw,
                                       do_detrend=False,
                                       random_state=self.random_state)
        noisy_detector.find_all_bads(ransac=self.ransac)
        self.noisy_channels_original = {
            "bad_by_nan": noisy_detector.bad_by_nan,
            "bad_by_flat": noisy_detector.bad_by_flat,
            "bad_by_deviation": noisy_detector.bad_by_deviation,
            "bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
            "bad_by_correlation": noisy_detector.bad_by_correlation,
            "bad_by_ransac": noisy_detector.bad_by_ransac,
            "bad_all": noisy_detector.get_bads(),
        }
        # Shallow copy: keys below are rebound to new lists via _union, so the
        # "original" dict keeps the first-pass results.
        self.noisy_channels = self.noisy_channels_original.copy()
        logger.info("Bad channels: {}".format(self.noisy_channels))

        self.unusable_channels = _union(noisy_detector.bad_by_nan,
                                        noisy_detector.bad_by_flat)

        # According to the Matlab Implementation (see robustReference.m)
        # self.unusable_channels = _union(self.unusable_channels,
        # noisy_detector.bad_by_SNR)
        # but maybe this makes no difference...

        self.reference_channels = _set_diff(self.reference_channels,
                                            self.unusable_channels)

        # Get initial estimate of the reference by the specified method
        # (NaN-ignoring median over the usable reference channels)
        signal = raw.get_data() * 1e6
        self.reference_signal = (
            np.nanmedian(raw.get_data(picks=self.reference_channels), axis=0) *
            1e6)
        reference_index = [
            self.ch_names_eeg.index(ch) for ch in self.reference_channels
        ]
        signal_tmp = self.remove_reference(signal, self.reference_signal,
                                           reference_index)

        # Remove reference from signal, iteratively interpolating bad channels
        raw_tmp = raw.copy()
        iterations = 0
        noisy_channels_old = []
        max_iteration_num = 4

        while True:
            raw_tmp._data = signal_tmp * 1e-6
            noisy_detector = NoisyChannels(raw_tmp,
                                           do_detrend=False,
                                           random_state=self.random_state)
            # Detrend applied at the beginning of the function.
            noisy_detector.find_all_bads(ransac=self.ransac)
            # Accumulate newly detected bad channels into the running totals
            self.noisy_channels["bad_by_nan"] = _union(
                self.noisy_channels["bad_by_nan"], noisy_detector.bad_by_nan)
            self.noisy_channels["bad_by_flat"] = _union(
                self.noisy_channels["bad_by_flat"], noisy_detector.bad_by_flat)
            self.noisy_channels["bad_by_deviation"] = _union(
                self.noisy_channels["bad_by_deviation"],
                noisy_detector.bad_by_deviation)
            self.noisy_channels["bad_by_hf_noise"] = _union(
                self.noisy_channels["bad_by_hf_noise"],
                noisy_detector.bad_by_hf_noise)
            self.noisy_channels["bad_by_correlation"] = _union(
                self.noisy_channels["bad_by_correlation"],
                noisy_detector.bad_by_correlation,
            )
            self.noisy_channels["bad_by_ransac"] = _union(
                self.noisy_channels["bad_by_ransac"],
                noisy_detector.bad_by_ransac)
            self.noisy_channels["bad_all"] = _union(
                self.noisy_channels["bad_all"], noisy_detector.get_bads())
            logger.info("Bad channels: {}".format(self.noisy_channels))

            # Stop once the bad set is empty or unchanged (only checked once
            # iterations > 1), or after more than max_iteration_num iterations
            if (iterations > 1 and (not self.noisy_channels["bad_all"] or set(
                    self.noisy_channels["bad_all"]) == set(noisy_channels_old))
                    or iterations > max_iteration_num):
                break
            noisy_channels_old = self.noisy_channels["bad_all"].copy()

            if raw_tmp.info["nchan"] - len(self.noisy_channels["bad_all"]) < 2:
                raise ValueError(
                    "RobustReference:TooManyBad "
                    "Could not perform a robust reference -- not enough good channels"
                )

            # Interpolate all bad channels in the *un-referenced* signal.
            # NOTE(review): both signal_tmp assignments in this if/else are
            # overwritten by remove_reference() below and are never read.
            # When there are no bad channels, raw_tmp still holds the
            # reference-removed data set at the top of the loop, so the
            # reference below is estimated from that data. Confirm intent.
            if self.noisy_channels["bad_all"]:
                raw_tmp._data = signal * 1e-6
                raw_tmp.info["bads"] = self.noisy_channels["bad_all"]
                raw_tmp.interpolate_bads()
                signal_tmp = raw_tmp.get_data() * 1e6
            else:
                signal_tmp = signal
            # Re-estimate the reference as a NaN-ignoring mean
            self.reference_signal = (np.nanmean(
                raw_tmp.get_data(picks=self.reference_channels), axis=0) * 1e6)

            signal_tmp = self.remove_reference(signal, self.reference_signal,
                                               reference_index)
            iterations = iterations + 1
            logger.info("Iterations: {}".format(iterations))

        logger.info("Robust reference done")
        return self.noisy_channels, self.reference_signal
Exemplo n.º 10
0
def preprocess_eeg(id_num, random_seed=None):
    """Preprocess one participant's EEG recording and write the result to disk.

    Pipeline, in order:
    - Load the BIDS recording and apply the ``standard_1005`` montage.
    - Trim duplicated data after a ``"file start"`` annotation.
    - Clean up doubled and mislabeled event triggers.
    - Remove line noise (``spectrum_fit`` notch filtering on detrended data).
    - Robust re-referencing with PREP's ``Reference``.
    - 1-50 Hz filtering.
    - ICA-based blink removal.
    - Optional CSD estimation.
    - EDF export.
    Diagnostic plots are saved after each stage.

    Relies on module-level settings defined elsewhere in the script:
    ``task``, ``datatype``, ``bids_root``, ``plotdir``, ``interpolate_bads``,
    ``max_interpolated``, ``perform_csd``, ``noisy_bad_dir``, ``outdir``,
    ``outfile_fmt`` and ``ica_err_dir``. It also uses the ``save_*`` and
    ``write_mne_edf`` helpers.

    Parameters
    ----------
    id_num
        Participant ID; passed to ``BIDSPath`` and used in output names.
    random_seed : int or None, optional
        Seed for Python's ``random`` module, PREP's RANSAC and ICA. When
        None, a random 32-bit seed is drawn from ``os.urandom``. A seed of 0
        is honoured.

    Returns
    -------
    dict
        Per-participant summary: seed, trigger-cleanup counts, per-event
        counts, and bad/interpolated channel info. Channel fields are
        ``"NA"`` when the recording is incomplete (fewer than 600
        annotations). If ICA finds no EOG components, the data is saved via
        ``save_bad_fif`` and the summary is returned early.
    """
    # Set important variables
    bids_path = BIDSPath(id_num, task=task, datatype=datatype, root=bids_root)
    plot_path = os.path.join(plotdir, "sub_{0}".format(id_num))
    if os.path.exists(plot_path):
        shutil.rmtree(plot_path)
    # makedirs (not mkdir) also creates the top-level plot dir if missing
    os.makedirs(plot_path)
    # Draw a seed only when none was given: 0 is a valid, reproducible seed
    if random_seed is None:
        random_seed = int(binascii.b2a_hex(os.urandom(4)), 16)
    random.seed(random_seed)
    id_info = {"id": id_num, "random_seed": random_seed}

    ### Load and prepare EEG data #############################################

    header = "### Processing sub-{0} (seed: {1}) ###".format(
        id_num, random_seed)
    print("\n" + "#" * len(header))
    print(header)
    print("#" * len(header) + "\n")

    # Load EEG data
    raw = read_raw_bids(bids_path, verbose=True)

    # Check if recording is complete
    complete = len(raw.annotations) >= 600

    # Add a montage to the data
    montage_kind = "standard_1005"
    montage = mne.channels.make_standard_montage(montage_kind)
    mne.datasets.eegbci.standardize(raw)
    raw.set_montage(montage)

    # Extract some info
    eeg_index = mne.pick_types(raw.info, eeg=True, eog=False, meg=False)
    ch_names = raw.info["ch_names"]
    ch_names_eeg = list(np.asarray(ch_names)[eeg_index])
    sample_rate = raw.info["sfreq"]

    # Make a copy of the data
    raw_copy = raw.copy()
    raw_copy.load_data()

    # Trim duplicated data (only needed for sub-005)
    annot = raw_copy.annotations
    file_starts = [a for a in annot if a['description'] == "file start"]
    if len(file_starts):
        duplicate_start = file_starts[0]['onset']
        raw_copy.crop(tmax=duplicate_start)

    # Make backup of EOG and EMG channels to re-append after PREP
    raw_other = raw_copy.copy()
    raw_other.pick_types(eog=True, emg=True, stim=False)

    # Prepare copy of raw data for PREP
    raw_copy.pick_types(eeg=True)

    # Plot data prior to any processing
    if complete:
        save_psd_plot(id_num, "psd_0_raw", plot_path, raw_copy)
        save_channel_plot(id_num, "ch_0_raw", plot_path, raw_copy)

    ### Clean up events #######################################################

    print("\n\n=== Processing Event Annotations... ===\n")

    event_names = [
        "stim_on", "red_on", "trace_start", "trace_end", "accuracy_submit",
        "vividness_submit"
    ]
    doubled = []
    wrong_label = []
    new_onsets = []
    new_durations = []
    new_descriptions = []

    # Find and flag any duplicate triggers. Every annotation after the first
    # is compared with its predecessor, including the final one. For the
    # final one the "next onset" is set 100 s ahead, so a repeated final
    # trigger counts as mislabeled rather than doubled.
    annot = raw_copy.annotations
    trigger_count = len(annot)
    for i in range(1, trigger_count):
        a = annot[i]
        on_last = i + 1 == trigger_count
        prev_trigger = annot[i - 1]['description']
        next_onset = annot[i + 1]['onset'] if not on_last else a['onset'] + 100
        # Determine whether duplicates are doubles or mislabeled
        if a['description'] == prev_trigger:
            if (next_onset - a['onset']) < 0.002:
                doubled.append(a)
            else:
                wrong_label.append(a)

    # Rename annotations to have meaningful names & fix duplicates
    for a in raw_copy.annotations:
        if a in doubled or a['description'] not in event_names:
            continue
        description = a['description']
        if a in wrong_label:
            # A mislabeled trigger repeats its predecessor's label, so it is
            # relabeled as the next event in the sequence.
            # NOTE(review): a mislabeled "vividness_submit" (last event name)
            # has no successor and raises IndexError here. Confirm intent.
            index = event_names.index(description)
            description = event_names[index + 1]
        new_onsets.append(a['onset'])
        new_durations.append(a['duration'])
        new_descriptions.append(description)

    # Replace old annotations with new fixed ones
    if len(annot):
        new_annot = mne.Annotations(
            new_onsets,
            new_durations,
            new_descriptions,
            orig_time=raw_copy.annotations[0]['orig_time'])
        raw_copy.set_annotations(new_annot)

    # Check annotations to verify we have equal numbers of each
    orig_counts = Counter(annot.description)
    counts = Counter(raw_copy.annotations.description)
    print("Updated Annotation Counts:")
    for a in event_names:
        out = " - '{0}': {1} -> {2}"
        print(out.format(a, orig_counts[a], counts[a]))

    # Get info
    id_info['annot_doubled'] = len(doubled)
    id_info['annot_wrong'] = len(wrong_label)

    # Compare the counts of every event except vividness_submit. Selection is
    # by event *name*, and missing events count as 0 (Counter default).
    count_vals = [
        counts[name] for name in event_names if name != 'vividness_submit'
    ]
    id_info['equal_triggers'] = all(x == count_vals[0] for x in count_vals)
    id_info['stim_on'] = counts['stim_on']
    id_info['red_on'] = counts['red_on']
    id_info['trace_start'] = counts['trace_start']
    id_info['trace_end'] = counts['trace_end']
    id_info['acc_submit'] = counts['accuracy_submit']
    id_info['vivid_submit'] = counts['vividness_submit']

    if not complete:
        remaining_info = {
            'initial_bad': "NA",
            'num_initial_bad': "NA",
            'interpolated': "NA",
            'num_interpolated': "NA",
            'remaining_bad': "NA",
            'num_remaining_bad': "NA"
        }
        id_info.update(remaining_info)
        e = "\n\n### Incomplete recording for sub-{0}, skipping... ###\n\n"
        print(e.format(id_num))
        return id_info

    ### Run components of PREP manually #######################################

    print("\n\n=== Performing CleanLine... ===")

    # Try to remove line noise using CleanLine approach: notch-filter the
    # detrended signal, then add the removed trend back.
    linenoise = np.arange(60, sample_rate / 2, 60)
    EEG_raw = raw_copy.get_data() * 1e6
    EEG_new = removeTrend(EEG_raw, sample_rate=sample_rate)
    EEG_clean = mne.filter.notch_filter(
        EEG_new,
        Fs=sample_rate,
        freqs=linenoise,
        filter_length="10s",
        method="spectrum_fit",
        mt_bandwidth=2,
        p_value=0.01,
    )
    EEG_final = EEG_raw - EEG_new + EEG_clean
    raw_copy._data = EEG_final * 1e-6
    del linenoise, EEG_raw, EEG_new, EEG_clean, EEG_final

    # Plot data following cleanline
    save_psd_plot(id_num, "psd_1_cleanline", plot_path, raw_copy)
    save_channel_plot(id_num, "ch_1_cleanline", plot_path, raw_copy)

    # Perform robust re-referencing
    prep_params = {"ref_chs": ch_names_eeg, "reref_chs": ch_names_eeg}
    reference = Reference(raw_copy,
                          prep_params,
                          ransac=True,
                          random_state=random_seed)
    print("\n\n=== Performing Robust Re-referencing... ===\n")
    reference.perform_reference()

    # If not interpolating bad channels, use pre-interpolation channel data
    if not interpolate_bads:
        reference.raw._data = reference.EEG_before_interpolation * 1e-6
        reference.interpolated_channels = []
        reference.still_noisy_channels = reference.bad_before_interpolation
        reference.raw.info["bads"] = reference.bad_before_interpolation

    # Plot data following robust re-reference
    save_psd_plot(id_num, "psd_2_reref", plot_path, reference.raw)
    save_channel_plot(id_num, "ch_2_reref", plot_path, reference.raw)

    # Re-append removed EMG/EOG/trigger channels
    raw_prepped = reference.raw.add_channels([raw_other])

    # Get info
    initial_bad = reference.noisy_channels_original["bad_all"]
    id_info['initial_bad'] = " ".join(initial_bad)
    id_info['num_initial_bad'] = len(initial_bad)

    interpolated = reference.interpolated_channels
    id_info['interpolated'] = " ".join(interpolated)
    id_info['num_interpolated'] = len(interpolated)

    remaining_bad = reference.still_noisy_channels
    id_info['remaining_bad'] = " ".join(remaining_bad)
    id_info['num_remaining_bad'] = len(remaining_bad)

    # Print re-referencing info
    print("\nRe-Referencing Info:")
    print(" - Bad channels original: {0}".format(initial_bad))
    if interpolate_bads:
        print(" - Bad channels after re-referencing: {0}".format(interpolated))
        print(" - Bad channels after interpolation: {0}".format(remaining_bad))
    else:
        print(
            " - Bad channels after re-referencing: {0}".format(remaining_bad))

    # Check if too many channels were interpolated for the participant
    prop_interpolated = len(
        reference.interpolated_channels) / len(ch_names_eeg)
    e = "### NOTE: Too many interpolated channels for sub-{0} ({1}) ###"
    if max_interpolated < prop_interpolated:
        print("\n")
        print(e.format(id_num, len(reference.interpolated_channels)))
        print("\n")

    ### Filter data and apply ICA to remove blinks ############################

    # Apply highpass & lowpass filters
    print("\n\n=== Applying Highpass & Lowpass Filters... ===")
    raw_prepped.filter(1.0, 50.0, fir_design='firwin')

    # Plot data following frequency filters
    save_psd_plot(id_num, "psd_3_filtered", plot_path, raw_prepped)
    save_channel_plot(id_num, "ch_3_filtered", plot_path, raw_prepped)

    # Perform ICA using EOG data on eye blinks
    print("\n\n=== Removing Blinks Using ICA... ===\n")
    ica = ICA(n_components=20, random_state=random_seed, method='picard')
    ica.fit(raw_prepped, decim=5)
    eog_indices, eog_scores = ica.find_bads_eog(raw_prepped)
    ica.exclude = eog_indices

    if not len(ica.exclude):
        err = " - Encountered an ICA error for sub-{0}, skipping for now..."
        print("\n")
        print(err.format(id_num))
        print("\n")
        save_bad_fif(raw_prepped, id_num, ica_err_dir)
        return id_info

    # Plot ICA info & diagnostics before removing from signal
    save_ica_plots(id_num, plot_path, raw_prepped, ica, eog_scores)

    # Remove eye blink independent components based on ICA
    ica.apply(raw_prepped)

    # Plot data following ICA
    save_psd_plot(id_num, "psd_4_ica", plot_path, raw_prepped)
    save_channel_plot(id_num, "ch_4_ica", plot_path, raw_prepped)

    ### Compute Current Source Density (CSD) estimates ########################

    if perform_csd:
        print("\n")
        print("=== Computing Current Source Density (CSD) Estimates... ===\n")
        raw_prepped = mne.preprocessing.compute_current_source_density(
            raw_prepped.drop_channels(remaining_bad))

        # Plot data following CSD
        save_psd_plot(id_num, "psd_5_csd", plot_path, raw_prepped)
        save_channel_plot(id_num, "ch_5_csd", plot_path, raw_prepped)

    ### Write preprocessed data to new EDF ####################################

    if max_interpolated < prop_interpolated:
        if not os.path.isdir(noisy_bad_dir):
            os.makedirs(noisy_bad_dir)
        outpath = os.path.join(noisy_bad_dir, outfile_fmt.format(id_num))
    else:
        outpath = os.path.join(outdir, outfile_fmt.format(id_num))
    write_mne_edf(outpath, raw_prepped)

    print("\n\n### sub-{0} complete! ###\n\n".format(id_num))

    return id_info
Exemplo n.º 11
0
    def robust_reference(self, max_iterations=4):
        """Detect bad channels and estimate the robust reference signal.

        This function implements the functionality of the `robustReference` function
        as part of the PREP pipeline on mne raw object.

        Algorithm:
        1. Screen a detrended copy of ``self.raw`` for bad channels.
        2. Estimate the initial reference as the NaN-ignoring median over the
           reference channels, excluding NaN/flat/low-SNR channels.
        3. Repeat:
           a. Remove the current reference.
           b. Re-detect bad channels, accumulating them across iterations.
           c. Interpolate all bad channels in the un-referenced data.
           d. Re-estimate the reference as the NaN-ignoring mean over those
              reference channels.
        The loop stops once the accumulated bad set is empty or unchanged
        (checked only when ``iterations > 1``), or when
        ``iterations > max_iterations``.

        Instance attributes read: ``raw``, ``sfreq``, ``matlab_strict``,
        ``random_state``, ``ransac_settings``, ``reference_channels``,
        ``ch_names_eeg``, ``_extra_info``.

        Attributes set: ``noisy_channels_original``,
        ``_extra_info["initial_bad"]``, ``unusable_channels``,
        ``reference_signal``, ``noisy_channels``.

        Parameters
        ----------
        max_iterations : int, optional
            The maximum number of iterations of noisy channel removal to perform
            during robust referencing. Defaults to ``4``.

        Returns
        -------
        noisy_channels: dict
            Names of noisy channels per detection type, accumulated over all
            referencing iterations. ``"bad_all"`` is the union of every type
            except ``"bad_by_SNR"``. With ``matlab_strict`` it also excludes
            ``"bad_by_dropout"``.
        reference_signal: np.ndarray, shape(n, )
            Estimation of the 'true' signal mean, in the units returned by
            ``raw.get_data()``.

        Raises
        ------
        ValueError
            If fewer than two channels are left after excluding the bad ones.

        """
        # Detrend a copy once; detectors below use do_detrend=False
        raw = self.raw.copy()
        raw._data = removeTrend(raw.get_data(),
                                self.sfreq,
                                matlab_strict=self.matlab_strict)

        # Determine unusable channels and remove them from the reference channels
        noisy_detector = NoisyChannels(
            raw,
            do_detrend=False,
            random_state=self.random_state,
            matlab_strict=self.matlab_strict,
        )
        noisy_detector.find_all_bads(**self.ransac_settings)
        self.noisy_channels_original = noisy_detector.get_bads(as_dict=True)
        self._extra_info["initial_bad"] = noisy_detector._extra_info
        logger.info("Bad channels: {}".format(self.noisy_channels_original))

        # Determine channels to use/exclude from initial reference estimation
        self.unusable_channels = _union(
            noisy_detector.bad_by_nan + noisy_detector.bad_by_flat,
            noisy_detector.bad_by_SNR,
        )
        # Local only: self.reference_channels is not modified here
        reference_channels = _set_diff(self.reference_channels,
                                       self.unusable_channels)

        # Initialize channels to permanently flag as bad during referencing
        # (only NaN/flat channels carry over from the initial detection)
        noisy = {
            "bad_by_nan": noisy_detector.bad_by_nan,
            "bad_by_flat": noisy_detector.bad_by_flat,
            "bad_by_deviation": [],
            "bad_by_hf_noise": [],
            "bad_by_correlation": [],
            "bad_by_SNR": [],
            "bad_by_dropout": [],
            "bad_by_ransac": [],
            "bad_all": [],
        }

        # Get initial estimate of the reference by the specified method
        signal = raw.get_data()
        self.reference_signal = np.nanmedian(
            raw.get_data(picks=reference_channels), axis=0)
        reference_index = [
            self.ch_names_eeg.index(ch) for ch in reference_channels
        ]
        signal_tmp = self.remove_reference(signal, self.reference_signal,
                                           reference_index)

        # Remove reference from signal, iteratively interpolating bad channels
        raw_tmp = raw.copy()
        iterations = 0
        previous_bads = set()

        while True:
            raw_tmp._data = signal_tmp
            noisy_detector = NoisyChannels(
                raw_tmp,
                do_detrend=False,
                random_state=self.random_state,
                matlab_strict=self.matlab_strict,
            )
            # Detrend applied at the beginning of the function.

            # Detect all currently bad channels
            noisy_detector.find_all_bads(**self.ransac_settings)
            noisy_new = noisy_detector.get_bads(as_dict=True)

            # Specify bad channel types to ignore when updating noisy channels
            # NOTE: MATLAB PREP ignores dropout channels, possibly by mistake?
            # see: https://github.com/VisLab/EEG-Clean-Tools/issues/28
            # (loop-invariant; rebuilt each iteration)
            ignore = ["bad_by_SNR", "bad_all"]
            if self.matlab_strict:
                ignore += ["bad_by_dropout"]

            # Update set of all noisy channels detected so far with any new ones
            bad_chans = set()
            for bad_type in noisy_new.keys():
                noisy[bad_type] = _union(noisy[bad_type], noisy_new[bad_type])
                if bad_type not in ignore:
                    bad_chans.update(noisy[bad_type])
            noisy["bad_all"] = list(bad_chans)
            logger.info("Bad channels: {}".format(noisy))

            # Stop when the bad set is empty/unchanged (after >1 iterations)
            # or the iteration limit is exceeded
            if (iterations > 1 and
                (len(bad_chans) == 0 or bad_chans == previous_bads)
                    or iterations > max_iterations):
                logger.info("Robust reference done")
                self.noisy_channels = noisy
                break
            previous_bads = bad_chans.copy()

            if raw_tmp.info["nchan"] - len(bad_chans) < 2:
                raise ValueError(
                    "RobustReference:TooManyBad "
                    "Could not perform a robust reference -- not enough good channels"
                )

            # Interpolate bad channels in the un-referenced signal. If there
            # are none, raw_tmp keeps the reference-removed data set above.
            if len(bad_chans) > 0:
                raw_tmp._data = signal.copy()
                raw_tmp.info["bads"] = list(bad_chans)
                if self.matlab_strict:
                    _eeglab_interpolate_bads(raw_tmp)
                else:
                    raw_tmp.interpolate_bads()

            # Re-estimate the reference as a NaN-ignoring mean
            self.reference_signal = np.nanmean(
                raw_tmp.get_data(picks=reference_channels), axis=0)

            signal_tmp = self.remove_reference(signal, self.reference_signal,
                                               reference_index)
            iterations = iterations + 1
            logger.info("Iterations: {}".format(iterations))

        return self.noisy_channels, self.reference_signal