Example #1
0
    def fit(self):
        """Execute every stage of the PREP pipeline and return ``self``."""
        # Flag NaN/flat channels up front. The detector's findings are not
        # consumed any further inside this method.
        detector = NoisyChannels(self.raw_eeg, random_state=self.random_state)
        detector.find_bad_by_nan_flat()

        # Steps 1 and 2 only run when line frequencies were supplied.
        line_freqs = self.prep_params["line_freqs"]
        if len(line_freqs) > 0:
            # Step 1: 1 Hz high-pass filtering (trend removal)
            self.EEG_new = removeTrend(self.EEG_raw, sample_rate=self.sfreq)

            # Step 2: line-noise removal; fall back to the built-in notch
            # settings when the caller did not provide any.
            notch_opts = self.filter_kwargs
            if notch_opts is None:
                notch_opts = {
                    "method": "spectrum_fit",
                    "mt_bandwidth": 2,
                    "p_value": 0.01,
                    "filter_length": "10s",
                }
            self.EEG_clean = mne.filter.notch_filter(
                self.EEG_new,
                Fs=self.sfreq,
                freqs=line_freqs,
                **notch_opts,
            )

            # Put the removed trend back on top of the notch-filtered signal
            self.EEG = self.EEG_raw - self.EEG_new + self.EEG_clean
            self.raw_eeg._data = self.EEG * 1e-6

        # Step 3: referencing
        referencer = Reference(
            self.raw_eeg,
            self.prep_params,
            ransac=self.ransac,
            random_state=self.random_state,
        )
        referencer.perform_reference()

        # (attribute on self, attribute on the Reference object), copied in order
        exported = (
            ("raw_eeg", "raw"),
            ("noisy_channels_original", "noisy_channels_original"),
            (
                "noisy_channels_before_interpolation",
                "noisy_channels_before_interpolation",
            ),
            (
                "noisy_channels_after_interpolation",
                "noisy_channels_after_interpolation",
            ),
            ("bad_before_interpolation", "bad_before_interpolation"),
            ("EEG_before_interpolation", "EEG_before_interpolation"),
            ("reference_before_interpolation", "reference_signal"),
            ("reference_after_interpolation", "reference_signal_new"),
            ("interpolated_channels", "interpolated_channels"),
            ("still_noisy_channels", "still_noisy_channels"),
        )
        for own_name, source_name in exported:
            setattr(self, own_name, getattr(referencer, source_name))

        return self
Example #2
0
def test_clean_input(raw_clean):
    """Check robust referencing when the input signal is clean."""
    channels = raw_clean.info["ch_names"]
    ref_params = {"ref_chs": channels, "reref_chs": channels}

    # ``find_all_bads`` is patched out so that bad channel detection is skipped
    # and the run is guaranteed to see only clean channels.
    skip_detection = mock.patch(
        "pyprep.NoisyChannels.find_all_bads", return_value=True
    )
    with skip_detection:
        reference = Reference(raw_clean, ref_params, ransac=False)
        reference.robust_reference()

    # Nothing should have been flagged at any stage
    assert len(reference.unusable_channels) == 0
    assert len(reference.noisy_channels_original["bad_all"]) == 0
    assert len(reference.noisy_channels["bad_all"]) == 0
Example #3
0
def test_basic_input(raw):
    """Test Reference output data type."""
    ch_names = raw.info["ch_names"]

    # Work on a copy so the shared fixture is left untouched
    raw_tmp = raw.copy()
    params = {"ref_chs": ch_names, "reref_chs": ch_names}
    reference = Reference(raw_tmp, params, ransac=False)
    reference.perform_reference()

    # ``isinstance`` is the idiomatic type check (``type(x) == T`` is flagged
    # by linters as E721).
    assert isinstance(reference.noisy_channels, dict)
    assert isinstance(reference.noisy_channels_original, dict)
    assert isinstance(reference.bad_before_interpolation, list)
    assert isinstance(reference.reference_signal, np.ndarray)
    assert isinstance(reference.interpolated_channels, list)
    assert isinstance(reference.still_noisy_channels, list)
    assert isinstance(reference.raw, mne.io.edf.edf.RawEDF)
Example #4
0
def test_all_bad_input(raw_clean):
    """Check that robust referencing raises when all reference channels are bad."""
    names = raw_clean.info["ch_names"]
    ref_params = {"ref_chs": names, "reref_chs": names}

    def _flag_everything(self):
        # Stand-in detector: marks every original channel as bad by deviation
        self.bad_by_deviation = self.ch_names_original.tolist()

    # Swap the real deviation detector for the stand-in so that every channel
    # is bad, which lets us exercise the 'too-few-good-channels' exception.
    patch_target = "pyprep.NoisyChannels.find_bad_by_deviation"
    with mock.patch(patch_target, new=_flag_everything):
        reference = Reference(raw_clean, ref_params, ransac=False)
        with pytest.raises(ValueError):
            reference.robust_reference()
Example #5
0
def test_basic_input(raw, montage):
    """Test Reference output data type."""
    ch_names = raw.info["ch_names"]

    # Work on a copy so the shared fixture is left untouched
    raw_tmp = raw.copy()
    raw_tmp.set_montage(montage)
    params = {"ref_chs": ch_names, "reref_chs": ch_names}
    reference = Reference(raw_tmp, params, ransac=False)
    reference.perform_reference()

    # ``isinstance`` is the idiomatic type check (``type(x) == T`` is flagged
    # by linters as E721).
    assert isinstance(reference.noisy_channels, dict)
    assert isinstance(reference.noisy_channels_original, dict)
    assert isinstance(reference.bad_before_interpolation, list)
    assert isinstance(reference.reference_signal, np.ndarray)
    assert isinstance(reference.interpolated_channels, list)
    assert isinstance(reference.still_noisy_channels, list)
    assert isinstance(reference.raw, mne.io.edf.edf.RawEDF)

    # Make sure the set of reference channels weren't modified by re-referencing
    assert params["ref_chs"] == reference.reference_channels
Example #6
0
def test_remove_reference():
    """Exercise ``Reference.remove_reference`` with invalid and valid inputs."""
    data = np.array([[1, 2, 3, 4], [0, 1, 2, 3], [3, 4, 5, 6]])
    ref = np.array([1, 1, 2, 2])

    # Each of these (signal, reference) pairs must be rejected with ValueError
    rejected_pairs = [(ref, ref), (data, data), (data, ref[0:3])]
    for bad_signal, bad_reference in rejected_pairs:
        with pytest.raises(ValueError):
            Reference.remove_reference(bad_signal, bad_reference)

    # Supplying the index as an ndarray must raise TypeError
    with pytest.raises(TypeError):
        Reference.remove_reference(data, ref, np.array([1, 2]))

    # Valid call: compare against the expected output
    expected = np.array([[1, 2, 3, 4], [-1, 0, 0, 1], [2, 3, 3, 4]])
    outcome = Reference.remove_reference(data, ref, [1, 2])
    assert np.array_equal(outcome, expected)
Example #7
0
def test_all_bad_input(raw):
    """Test robust reference when all reference channels are bad."""
    ch_names = raw.info["ch_names"]

    # Work on a copy so the shared fixture is left untouched
    raw_tmp = raw.copy()
    m, n = raw_tmp.get_data().shape

    # Randomly pick two distinct channels to corrupt. ``random.sample`` needs a
    # sequence: sampling from a set was deprecated in Python 3.9 and raises
    # TypeError from Python 3.11 onwards, so sample from ``range`` instead.
    nan_chn_idx, flat_chn_idx = random.sample(range(m), 2)

    # Insert a NaN value at the last sample of the first chosen channel
    raw_tmp._data[nan_chn_idx, n - 1] = np.nan

    # Make the second chosen channel flat (a constant value everywhere)
    raw_tmp._data[flat_chn_idx, :] = np.ones_like(raw_tmp._data[1, :]) * 1e-6

    # Use only the two corrupted channels as reference channels, so robust
    # referencing is expected to raise ValueError.
    reference_channels = [ch_names[nan_chn_idx], ch_names[flat_chn_idx]]
    params = {"ref_chs": reference_channels, "reref_chs": reference_channels}
    reference = Reference(raw_tmp, params, ransac=False)
    with pytest.raises(ValueError):
        reference.robust_reference()
def pyprep_reference(matprep_artifacts):
    """Return a PyPREP ``Reference`` run on MATLAB PREP's post-CleanLine data.

    The input is the MATLAB PREP artifact holding the CleanLined EEG signal as
    it stood immediately before MATLAB PREP calls ``performReference``. Because
    both pipelines therefore start from the same signal, differences between
    the CleanLine implementations of MATLAB PREP and PyPREP cannot influence
    the tests that use this fixture.

    """
    # Load the post-CleanLine MATLAB PREP recording
    cleanlined = mne.io.read_raw_eeglab(
        matprep_artifacts["3_matprep_cleanline"], preload=True
    )
    channels = cleanlined.info["ch_names"]

    # Robustly re-reference the MATLAB data with a fixed seed; the returned
    # object keeps the internal noisy-channel info for later inspection.
    rereferencer = Reference(
        cleanlined,
        {"ref_chs": channels, "reref_chs": channels},
        random_state=435656,
        matlab_strict=True,
    )
    rereferencer.perform_reference()

    return rereferencer
Example #9
0
def preprocess_eeg(id_num, random_seed=None):
    """Preprocess one participant's EEG recording and write the result to EDF.

    The recording is loaded from BIDS, its event annotations are cleaned up,
    and, for complete recordings only, line noise is removed (CleanLine
    approach), the data is robustly re-referenced, band-pass filtered, cleaned
    of eye blinks via ICA and optionally converted to CSD estimates.
    Diagnostic plots are saved after each stage.

    Besides its arguments, the function reads several names it does not define
    itself (``task``, ``datatype``, ``bids_root``, ``plotdir``,
    ``interpolate_bads``, ``max_interpolated``, ``perform_csd``,
    ``noisy_bad_dir``, ``outdir``, ``outfile_fmt``, ``ica_err_dir``); these are
    expected to be provided by the enclosing module.

    Parameters
    ----------
    id_num
        Participant identifier. Passed to ``BIDSPath`` as its first argument
        and used to name the plot folder and output file.
    random_seed : int, optional
        Seed for ``random``, robust re-referencing and ICA. If ``None``, a
        random 32-bit seed is drawn from ``os.urandom``. An explicit seed of
        ``0`` is honoured.

    Returns
    -------
    id_info : dict
        Summary for the participant: id, seed, annotation counts and fixes,
        and bad/interpolated channel info. The channel fields are ``"NA"``
        when the recording is incomplete; they are filled in (but no EDF is
        written) when ICA finds no EOG-related components.

    """
    # Set important variables
    bids_path = BIDSPath(id_num, task=task, datatype=datatype, root=bids_root)
    plot_path = os.path.join(plotdir, "sub_{0}".format(id_num))
    # Start from an empty plot folder on every run
    if os.path.exists(plot_path):
        shutil.rmtree(plot_path)
    os.mkdir(plot_path)
    # Only generate a seed when none was given; testing truthiness here would
    # silently discard an explicit seed of 0.
    if random_seed is None:
        random_seed = int(binascii.b2a_hex(os.urandom(4)), 16)
    random.seed(random_seed)
    id_info = {"id": id_num, "random_seed": random_seed}

    ### Load and prepare EEG data #############################################

    header = "### Processing sub-{0} (seed: {1}) ###".format(
        id_num, random_seed)
    print("\n" + "#" * len(header))
    print(header)
    print("#" * len(header) + "\n")

    # Load EEG data
    raw = read_raw_bids(bids_path, verbose=True)

    # Check if recording is complete
    complete = len(raw.annotations) >= 600

    # Add a montage to the data
    montage_kind = "standard_1005"
    montage = mne.channels.make_standard_montage(montage_kind)
    mne.datasets.eegbci.standardize(raw)
    raw.set_montage(montage)

    # Extract some info
    eeg_index = mne.pick_types(raw.info, eeg=True, eog=False, meg=False)
    ch_names = raw.info["ch_names"]
    ch_names_eeg = list(np.asarray(ch_names)[eeg_index])
    sample_rate = raw.info["sfreq"]

    # Make a copy of the data
    raw_copy = raw.copy()
    raw_copy.load_data()

    # Trim duplicated data (only needed for sub-005)
    annot = raw_copy.annotations
    file_starts = [a for a in annot if a['description'] == "file start"]
    if len(file_starts):
        duplicate_start = file_starts[0]['onset']
        raw_copy.crop(tmax=duplicate_start)

    # Make backup of EOG and EMG channels to re-append after PREP
    raw_other = raw_copy.copy()
    raw_other.pick_types(eog=True, emg=True, stim=False)

    # Prepare copy of raw data for PREP
    raw_copy.pick_types(eeg=True)

    # Plot data prior to any processing
    if complete:
        save_psd_plot(id_num, "psd_0_raw", plot_path, raw_copy)
        save_channel_plot(id_num, "ch_0_raw", plot_path, raw_copy)

    ### Clean up events #######################################################

    print("\n\n=== Processing Event Annotations... ===\n")

    event_names = [
        "stim_on", "red_on", "trace_start", "trace_end", "accuracy_submit",
        "vividness_submit"
    ]
    doubled = []
    wrong_label = []
    new_onsets = []
    new_durations = []
    new_descriptions = []

    # Find and flag any duplicate triggers. The loop runs through the final
    # annotation (the previous upper bound of ``trigger_count - 1`` skipped it
    # and left the ``on_last`` handling unreachable).
    annot = raw_copy.annotations
    trigger_count = len(annot)
    for i in range(1, trigger_count):
        a = annot[i]
        on_last = i + 1 == trigger_count
        prev_trigger = annot[i - 1]['description']
        # The last trigger has no successor, so use a far-away dummy onset;
        # a duplicate there is then treated as mislabeled, not doubled.
        next_onset = annot[i + 1]['onset'] if not on_last else a['onset'] + 100
        # Determine whether duplicates are doubles or mislabeled
        if a['description'] == prev_trigger:
            if (next_onset - a['onset']) < 0.002:
                doubled.append(a)
            else:
                wrong_label.append(a)

    # Rename annotations to have meaningful names & fix duplicates
    for a in raw_copy.annotations:
        if a in doubled or a['description'] not in event_names:
            continue
        if a in wrong_label:
            # A mislabeled trigger is relabeled as the next event in sequence.
            # NOTE(review): this would raise IndexError for a mislabeled
            # "vividness_submit" (the last event name) - confirm intent.
            index = event_names.index(a['description'])
            a['description'] = event_names[index + 1]
        new_onsets.append(a['onset'])
        new_durations.append(a['duration'])
        new_descriptions.append(a['description'])

    # Replace old annotations with new fixed ones
    if len(annot):
        new_annot = mne.Annotations(
            new_onsets,
            new_durations,
            new_descriptions,
            orig_time=raw_copy.annotations[0]['orig_time'])
        raw_copy.set_annotations(new_annot)

    # Check annotations to verify we have equal numbers of each
    orig_counts = Counter(annot.description)
    counts = Counter(raw_copy.annotations.description)
    print("Updated Annotation Counts:")
    for a in event_names:
        out = " - '{0}': {1} -> {2}"
        print(out.format(a, orig_counts[a], counts[a]))

    # Get info
    id_info['annot_doubled'] = len(doubled)
    id_info['annot_wrong'] = len(wrong_label)

    # Counts that differ from the 'vividness_submit' count; the triggers are
    # considered equal when all of these share a single value.
    count_vals = [
        n for n in counts.values() if n != counts['vividness_submit']
    ]
    id_info['equal_triggers'] = all(x == count_vals[0] for x in count_vals)
    id_info['stim_on'] = counts['stim_on']
    id_info['red_on'] = counts['red_on']
    id_info['trace_start'] = counts['trace_start']
    id_info['trace_end'] = counts['trace_end']
    id_info['acc_submit'] = counts['accuracy_submit']
    id_info['vivid_submit'] = counts['vividness_submit']

    # Incomplete recordings are reported with placeholder values and skipped
    if not complete:
        remaining_info = {
            'initial_bad': "NA",
            'num_initial_bad': "NA",
            'interpolated': "NA",
            'num_interpolated': "NA",
            'remaining_bad': "NA",
            'num_remaining_bad': "NA"
        }
        id_info.update(remaining_info)
        e = "\n\n### Incomplete recording for sub-{0}, skipping... ###\n\n"
        print(e.format(id_num))
        return id_info

    ### Run components of PREP manually #######################################

    print("\n\n=== Performing CleanLine... ===")

    # Try to remove line noise using CleanLine approach: notch-filter the
    # detrended signal at 60 Hz and its harmonics below Nyquist, then add the
    # removed trend back. Data is scaled by 1e6 for processing and back after.
    linenoise = np.arange(60, sample_rate / 2, 60)
    EEG_raw = raw_copy.get_data() * 1e6
    EEG_new = removeTrend(EEG_raw, sample_rate=sample_rate)
    EEG_clean = mne.filter.notch_filter(
        EEG_new,
        Fs=sample_rate,
        freqs=linenoise,
        filter_length="10s",
        method="spectrum_fit",
        mt_bandwidth=2,
        p_value=0.01,
    )
    EEG_final = EEG_raw - EEG_new + EEG_clean
    raw_copy._data = EEG_final * 1e-6
    # Drop the intermediate arrays before the next processing stage
    del linenoise, EEG_raw, EEG_new, EEG_clean, EEG_final

    # Plot data following cleanline
    save_psd_plot(id_num, "psd_1_cleanline", plot_path, raw_copy)
    save_channel_plot(id_num, "ch_1_cleanline", plot_path, raw_copy)

    # Perform robust re-referencing
    prep_params = {"ref_chs": ch_names_eeg, "reref_chs": ch_names_eeg}
    reference = Reference(raw_copy,
                          prep_params,
                          ransac=True,
                          random_state=random_seed)
    print("\n\n=== Performing Robust Re-referencing... ===\n")
    reference.perform_reference()

    # If not interpolating bad channels, use pre-interpolation channel data
    if not interpolate_bads:
        reference.raw._data = reference.EEG_before_interpolation * 1e-6
        reference.interpolated_channels = []
        reference.still_noisy_channels = reference.bad_before_interpolation
        reference.raw.info["bads"] = reference.bad_before_interpolation

    # Plot data following robust re-reference
    save_psd_plot(id_num, "psd_2_reref", plot_path, reference.raw)
    save_channel_plot(id_num, "ch_2_reref", plot_path, reference.raw)

    # Re-append removed EMG/EOG/trigger channels
    raw_prepped = reference.raw.add_channels([raw_other])

    # Get info
    initial_bad = reference.noisy_channels_original["bad_all"]
    id_info['initial_bad'] = " ".join(initial_bad)
    id_info['num_initial_bad'] = len(initial_bad)

    interpolated = reference.interpolated_channels
    id_info['interpolated'] = " ".join(interpolated)
    id_info['num_interpolated'] = len(interpolated)

    remaining_bad = reference.still_noisy_channels
    id_info['remaining_bad'] = " ".join(remaining_bad)
    id_info['num_remaining_bad'] = len(remaining_bad)

    # Print re-referencing info
    print("\nRe-Referencing Info:")
    print(" - Bad channels original: {0}".format(initial_bad))
    if interpolate_bads:
        print(" - Bad channels after re-referencing: {0}".format(interpolated))
        print(" - Bad channels after interpolation: {0}".format(remaining_bad))
    else:
        print(
            " - Bad channels after re-referencing: {0}".format(remaining_bad))

    # Check if too many channels were interpolated for the participant
    prop_interpolated = len(
        reference.interpolated_channels) / len(ch_names_eeg)
    e = "### NOTE: Too many interpolated channels for sub-{0} ({1}) ###"
    if max_interpolated < prop_interpolated:
        print("\n")
        print(e.format(id_num, len(reference.interpolated_channels)))
        print("\n")

    ### Filter data and apply ICA to remove blinks ############################

    # Apply highpass & lowpass filters
    print("\n\n=== Applying Highpass & Lowpass Filters... ===")
    raw_prepped.filter(1.0, 50.0, fir_design='firwin')

    # Plot data following frequency filters
    save_psd_plot(id_num, "psd_3_filtered", plot_path, raw_prepped)
    save_channel_plot(id_num, "ch_3_filtered", plot_path, raw_prepped)

    # Perform ICA using EOG data on eye blinks
    print("\n\n=== Removing Blinks Using ICA... ===\n")
    ica = ICA(n_components=20, random_state=random_seed, method='picard')
    ica.fit(raw_prepped, decim=5)
    eog_indices, eog_scores = ica.find_bads_eog(raw_prepped)
    ica.exclude = eog_indices

    # No EOG-related components found: save the data separately and stop
    # before writing the preprocessed EDF.
    if not len(ica.exclude):
        err = " - Encountered an ICA error for sub-{0}, skipping for now..."
        print("\n")
        print(err.format(id_num))
        print("\n")
        save_bad_fif(raw_prepped, id_num, ica_err_dir)
        return id_info

    # Plot ICA info & diagnostics before removing from signal
    save_ica_plots(id_num, plot_path, raw_prepped, ica, eog_scores)

    # Remove eye blink independent components based on ICA
    ica.apply(raw_prepped)

    # Plot data following ICA
    save_psd_plot(id_num, "psd_4_ica", plot_path, raw_prepped)
    save_channel_plot(id_num, "ch_4_ica", plot_path, raw_prepped)

    ### Compute Current Source Density (CSD) estimates ########################

    if perform_csd:
        print("\n")
        print("=== Computing Current Source Density (CSD) Estimates... ===\n")
        # Remaining bad channels are dropped before computing CSD
        raw_prepped = mne.preprocessing.compute_current_source_density(
            raw_prepped.drop_channels(remaining_bad))

        # Plot data following CSD
        save_psd_plot(id_num, "psd_5_csd", plot_path, raw_prepped)
        save_channel_plot(id_num, "ch_5_csd", plot_path, raw_prepped)

    ### Write preprocessed data to new EDF ####################################

    # Recordings with too many interpolated channels go to a separate folder
    if max_interpolated < prop_interpolated:
        if not os.path.isdir(noisy_bad_dir):
            os.makedirs(noisy_bad_dir)
        outpath = os.path.join(noisy_bad_dir, outfile_fmt.format(id_num))
    else:
        outpath = os.path.join(outdir, outfile_fmt.format(id_num))
    write_mne_edf(outpath, raw_prepped)

    print("\n\n### sub-{0} complete! ###\n\n".format(id_num))

    return id_info