コード例 #1
0
    def fit(self):
        """Execute every stage of the PREP pipeline and return this instance.

        Stages, in order:

        1. NaN/flat channel detection on ``self.raw_eeg``.
        2. If ``self.prep_params["line_freqs"]`` is non-empty: detrend
           ``self.EEG_raw``, notch-filter the detrended data at those
           frequencies, then add the removed trend back and write the result
           (scaled by ``1e-6``) into ``self.raw_eeg._data``.
        3. Referencing via ``Reference``, after which the reference object's
           diagnostic attributes are copied onto this instance.

        Returns
        -------
        self
            This instance, so calls can be chained.
        """
        # NaN/flat detection is run up front, but its result is not consumed
        # anywhere else in this method.
        nan_flat_detector = NoisyChannels(
            self.raw_eeg, random_state=self.random_state
        )
        nan_flat_detector.find_bad_by_nan_flat()

        line_freqs = self.prep_params["line_freqs"]
        if len(line_freqs) != 0:
            # Step 1: 1 Hz high-pass filtering (trend removal)
            self.EEG_new = removeTrend(self.EEG_raw, sample_rate=self.sfreq)

            # Step 2: line-noise removal. Use the built-in defaults unless the
            # caller supplied their own notch-filter keyword arguments.
            if self.filter_kwargs is None:
                notch_options = {
                    "method": "spectrum_fit",
                    "mt_bandwidth": 2,
                    "p_value": 0.01,
                    "filter_length": "10s",
                }
            else:
                notch_options = self.filter_kwargs
            self.EEG_clean = mne.filter.notch_filter(
                self.EEG_new,
                Fs=self.sfreq,
                freqs=line_freqs,
                **notch_options,
            )

            # Restore the trend: original minus detrended is the trend itself,
            # which is added onto the notch-filtered signal.
            self.EEG = self.EEG_raw - self.EEG_new + self.EEG_clean
            # NOTE(review): the 1e-6 factor presumably converts microvolts
            # back to volts -- confirm against where EEG_raw is created.
            self.raw_eeg._data = self.EEG * 1e-6

        # Step 3: referencing
        reference = Reference(
            self.raw_eeg,
            self.prep_params,
            ransac=self.ransac,
            random_state=self.random_state,
        )
        reference.perform_reference()
        self.raw_eeg = reference.raw

        # (attribute on self, attribute on the Reference object)
        exported_attributes = (
            ("noisy_channels_original", "noisy_channels_original"),
            (
                "noisy_channels_before_interpolation",
                "noisy_channels_before_interpolation",
            ),
            (
                "noisy_channels_after_interpolation",
                "noisy_channels_after_interpolation",
            ),
            ("bad_before_interpolation", "bad_before_interpolation"),
            ("EEG_before_interpolation", "EEG_before_interpolation"),
            ("reference_before_interpolation", "reference_signal"),
            ("reference_after_interpolation", "reference_signal_new"),
            ("interpolated_channels", "interpolated_channels"),
            ("still_noisy_channels", "still_noisy_channels"),
        )
        for own_name, reference_name in exported_attributes:
            setattr(self, own_name, getattr(reference, reference_name))

        return self
コード例 #2
0
def test_bad_by_nan(raw_tmp):
    """Test the detection of channels containing any NaN values.

    A single NaN is written into one randomly chosen channel of ``raw_tmp``;
    exactly that channel must be reported in ``bad_by_nan``, both right after
    constructing ``NoisyChannels`` and after an explicit re-run of the
    detection.
    """
    # Insert a NaN value into a random channel
    n_chans = raw_tmp._data.shape[0]
    # randint(..., 1) returns a 1-D array; index it explicitly because calling
    # int() on an array with ndim > 0 is deprecated since NumPy 1.25.
    nan_idx = int(RNG.randint(0, n_chans, 1)[0])
    raw_tmp._data[nan_idx, 3] = np.nan

    # Test automatic detection of NaN channels on NoisyChannels init
    nd = NoisyChannels(raw_tmp, do_detrend=False)
    assert nd.bad_by_nan == [raw_tmp.ch_names[nan_idx]]

    # Test manual re-running of NaN channel detection
    nd.find_bad_by_nan_flat()
    assert nd.bad_by_nan == [raw_tmp.ch_names[nan_idx]]
コード例 #3
0
def test_bad_by_flat(raw_tmp):
    """Test the detection of channels with flat or very weak signals.

    One randomly chosen channel of ``raw_tmp`` is first scaled down by
    ``1e-12`` and later set to all zeros; in both cases exactly that channel
    must be reported in ``bad_by_flat``.
    """
    # Make the signal for a random channel extremely weak
    n_chans = raw_tmp._data.shape[0]
    # randint(..., 1) returns a 1-D array; index it explicitly because calling
    # int() on an array with ndim > 0 is deprecated since NumPy 1.25.
    flat_idx = int(RNG.randint(0, n_chans, 1)[0])
    raw_tmp._data[flat_idx, :] = raw_tmp._data[flat_idx, :] * 1e-12

    # Test automatic detection of flat channels on NoisyChannels init
    nd = NoisyChannels(raw_tmp, do_detrend=False)
    assert nd.bad_by_flat == [raw_tmp.ch_names[flat_idx]]

    # Test manual re-running of flat channel detection
    nd.find_bad_by_nan_flat()
    assert nd.bad_by_flat == [raw_tmp.ch_names[flat_idx]]

    # Test detection when channel is completely flat
    raw_tmp._data[flat_idx, :] = 0
    nd = NoisyChannels(raw_tmp, do_detrend=False)
    assert nd.bad_by_flat == [raw_tmp.ch_names[flat_idx]]
コード例 #4
0
def test_findnoisychannels(raw, montage):
    """Test each bad-channel detection method of NoisyChannels.

    The recording is first cleaned by repeatedly interpolating and finally
    dropping detected bad channels. Each subsequent section then corrupts a
    fresh copy of the cleaned data in a specific way and asserts that the
    matching detector reports the corrupted channel(s).
    """
    raw.set_montage(montage)
    nd = NoisyChannels(raw)
    nd.find_all_bads(ransac=True)
    bads = nd.get_bads()
    # Remove any noisy channels by interpolating the bads for 10 iterations
    iterations = 10
    for _ in range(iterations):
        raw.info["bads"] = bads
        raw.interpolate_bads()
        nd = NoisyChannels(raw)
        nd.find_all_bads(ransac=True)
        bads = nd.get_bads()

    # make sure no bad channels exist in the data
    raw.drop_channels(ch_names=bads)

    # Test for NaN and flat channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Insert a NaN value into one random channel and make another one
    # constant. The two indices are drawn without replacement: if both landed
    # on the same channel, the constant overwrite would erase the NaN and the
    # assertions below would fail spuriously.
    chosen = np.random.choice(m, size=2, replace=False)
    rand_chn_idx1 = int(chosen[0])
    rand_chn_idx2 = int(chosen[1])
    rand_chn_lab1 = raw_tmp.ch_names[rand_chn_idx1]
    rand_chn_lab2 = raw_tmp.ch_names[rand_chn_idx2]
    raw_tmp._data[rand_chn_idx1, n - 1] = np.nan
    raw_tmp._data[rand_chn_idx2, :] = np.ones(n)
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_nan_flat()
    assert nd.bad_by_nan == [rand_chn_lab1]
    assert nd.bad_by_flat == [rand_chn_lab2]

    # Test for high and low deviations in EEG data
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Now insert one random channel with very low deviations.
    # (randint(..., 1) returns a 1-D array; it is indexed explicitly because
    # int() on an array with ndim > 0 is deprecated since NumPy 1.25.)
    rand_chn_idx = int(np.random.randint(0, m, 1)[0])
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    raw_tmp._data[rand_chn_idx, :] = raw_tmp._data[rand_chn_idx, :] / 10
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation
    # Inserting one random channel with a high deviation
    raw_tmp = raw.copy()
    rand_chn_idx = int(np.random.randint(0, m, 1)[0])
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    arbitrary_scaling = 5
    raw_tmp._data[rand_chn_idx, :] *= arbitrary_scaling
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation

    # Test for correlation between EEG channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = int(np.random.randint(0, m, 1)[0])
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Use cosine instead of sine to create a signal
    low = 10
    high = 30
    n_freq = 5
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        # One random integer frequency per sample, not one per component
        freq = np.random.randint(low, high, n)
        signal[0, :] += np.cos(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_correlation()
    assert rand_chn_lab in nd.bad_by_correlation

    # Test for high freq noise detection
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = int(np.random.randint(0, m, 1)[0])
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Use freqs between 90 and 100 Hz to insert hf noise
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = np.random.randint(90, 100, n)
        signal[0, :] += np.sin(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_hfnoise()
    assert rand_chn_lab in nd.bad_by_hf_noise

    # Test for signal to noise ratio in EEG data
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = int(np.random.randint(0, m, 1)[0])
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # inserting an uncorrelated high frequency (90 Hz) signal in one channel
    raw_tmp[rand_chn_idx, :] = np.sin(2 * np.pi * raw.times * 90) * 1e-6
    nd = NoisyChannels(raw_tmp)
    nd.find_bad_by_SNR()
    assert rand_chn_lab in nd.bad_by_SNR

    # Test for finding bad channels by RANSAC
    raw_tmp = raw.copy()
    # Ransac identifies channels that go bad together and are highly correlated.
    # Inserting highly correlated signal in channels 0 through 5 at 30 Hz
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp)
    # Seed the global NumPy RNG right before the RANSAC run
    np.random.seed(30)
    nd.find_bad_by_ransac()
    bads = nd.bad_by_ransac
    assert bads == raw_tmp.ch_names[0:6]
コード例 #5
0
def test_findnoisychannels(raw, montage):
    """Test find noisy channels.

    The recording is first cleaned by repeatedly interpolating and finally
    dropping detected bad channels. Each subsequent section then corrupts a
    fresh copy of the cleaned data in a specific way and asserts that the
    matching detector reports the corrupted channel(s). The last sections
    check the exceptions raised by ``find_bad_by_ransac``. A single seeded
    ``RandomState`` is shared by the test and every ``NoisyChannels``
    instance.
    """
    # Set a random state for the test
    rng = np.random.RandomState(30)

    raw.set_montage(montage)
    nd = NoisyChannels(raw, random_state=rng)
    nd.find_all_bads(ransac=True)
    bads = nd.get_bads()
    # Remove any noisy channels by interpolating the bads for up to 10
    # iterations
    iterations = 10
    for _ in range(iterations):
        if len(bads) == 0:
            # Nothing left to interpolate; `bads` cannot change any more, so
            # the remaining iterations would be no-ops.
            break
        raw.info["bads"] = bads
        raw.interpolate_bads()
        nd = NoisyChannels(raw, random_state=rng)
        nd.find_all_bads(ransac=True)
        bads = nd.get_bads()

    # make sure no bad channels exist in the data
    raw.drop_channels(ch_names=bads)

    # Test for NaN and flat channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Insert a nan value for a random channel and make another random channel
    # completely flat (ones)
    idxs = rng.choice(np.arange(m), size=2, replace=False)
    rand_chn_idx1 = idxs[0]
    rand_chn_idx2 = idxs[1]
    rand_chn_lab1 = raw_tmp.ch_names[rand_chn_idx1]
    rand_chn_lab2 = raw_tmp.ch_names[rand_chn_idx2]
    raw_tmp._data[rand_chn_idx1, n - 1] = np.nan
    raw_tmp._data[rand_chn_idx2, :] = np.ones(n)
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_nan_flat()
    assert nd.bad_by_nan == [rand_chn_lab1]
    assert nd.bad_by_flat == [rand_chn_lab2]

    # Test for high and low deviations in EEG data
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Now insert one random channel with very low deviations.
    # (randint(..., 1) returns a 1-D array; it is indexed explicitly because
    # int() on an array with ndim > 0 is deprecated since NumPy 1.25. The
    # randint call itself is unchanged, so the seeded stream is unaffected.)
    rand_chn_idx = int(rng.randint(0, m, 1)[0])
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    raw_tmp._data[rand_chn_idx, :] = raw_tmp._data[rand_chn_idx, :] / 10
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation
    # Inserting one random channel with a high deviation
    raw_tmp = raw.copy()
    rand_chn_idx = int(rng.randint(0, m, 1)[0])
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    arbitrary_scaling = 5
    raw_tmp._data[rand_chn_idx, :] *= arbitrary_scaling
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation

    # Test for correlation between EEG channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = int(rng.randint(0, m, 1)[0])
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Use cosine instead of sine to create a signal
    low = 10
    high = 30
    n_freq = 5
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        # One random integer frequency per sample, not one per component
        freq = rng.randint(low, high, n)
        signal[0, :] += np.cos(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_correlation()
    assert rand_chn_lab in nd.bad_by_correlation

    # Test for high freq noise detection
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = int(rng.randint(0, m, 1)[0])
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Use freqs between 90 and 100 Hz to insert hf noise
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = rng.randint(90, 100, n)
        signal[0, :] += np.sin(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_hfnoise()
    assert rand_chn_lab in nd.bad_by_hf_noise

    # Test for signal to noise ratio in EEG data
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = int(rng.randint(0, m, 1)[0])
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # inserting an uncorrelated high frequency (90 Hz) signal in one channel
    raw_tmp[rand_chn_idx, :] = np.sin(2 * np.pi * raw.times * 90) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_SNR()
    assert rand_chn_lab in nd.bad_by_SNR

    # Test for finding bad channels by RANSAC
    raw_tmp = raw.copy()
    # Ransac identifies channels that go bad together and are highly correlated.
    # Inserting highly correlated signal in channels 0 through 5 at 30 Hz
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_ransac()
    bads = nd.bad_by_ransac
    assert bads == raw_tmp.ch_names[0:6]

    # Test for finding bad channels by channel-wise RANSAC
    raw_tmp = raw.copy()
    # Ransac identifies channels that go bad together and are highly correlated.
    # Inserting highly correlated signal in channels 0 through 5 at 30 Hz
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_ransac(channel_wise=True)
    bads = nd.bad_by_ransac
    assert bads == raw_tmp.ch_names[0:6]

    # Test not-enough-memory and n_samples type exceptions
    raw_tmp = raw.copy()
    raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)

    # Set n_samples very very high to trigger a memory error
    n_samples = int(1e100)
    with pytest.raises(MemoryError):
        nd.find_bad_by_ransac(n_samples=n_samples)

    # Set n_samples to a float to trigger a type error
    n_samples = 35.5
    with pytest.raises(TypeError):
        nd.find_bad_by_ransac(n_samples=n_samples)

    # Test IOError when not enough channels for ransac predictions
    raw_tmp = raw.copy()
    # Make flat all channels except 2
    num_bad_channels = raw._data.shape[0] - 2
    raw_tmp._data[0:num_bad_channels, :] = np.zeros_like(
        raw_tmp._data[0:num_bad_channels, :]
    )
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_all_bads(ransac=False)
    with pytest.raises(IOError):
        nd.find_bad_by_ransac()
コード例 #6
0
ファイル: rsrm_prep.py プロジェクト: VIXSoh/SRM
    def preproc(self, event_dict, baseline_start, stim_dur, montage, out_dir='preproc', subjects='all', tasks='all',
                hp_cutoff=1, lp_cutoff=50, line_noise=60, seed=42, eog_chan='none'):
        '''Preprocess the EEG recordings of the selected subjects and tasks.

        Stores a list of epoched data in ``self.epoch_list``, where each entry
        is one subject with its task data concatenated, and the paths of files
        that could not be read in ``self.missing_files``. The workflow applied
        to every (subject, task) recording is:

            - Standardize channel names and set the montage
            - Band-pass filter between ``hp_cutoff`` and ``lp_cutoff``
            - Automatically detect bad channels (NaN/flat, deviation, SNR)
            - Reference data to the average of the EEG channels
            - Automated removal of eye-related artifacts using ICA
            - Interpolation of the detected bad channels
            - Extract events and epoch the data accordingly
            - Equalize event counts and order the epochs by event type

        Notch filtering of line noise is currently disabled (see the
        commented-out code in the body), so ``line_noise`` has no effect.

        Parameters
        ----------
        event_dict: dict
            Maps integers to semantic labels for events within the experiment

        baseline_start: int or float
            Specify start of the baseline period (in seconds)

        stim_dur: int or float
            Stimulus duration (in seconds)
                Note: may need to make more general to allow various durations

        montage: mne.channels.montage.DigMontage
            Maps sensor locations to coordinates

        out_dir: str
            Parent directory in which one sub-directory per subject is created

        subjects: list or 'all'
            Specify which subjects to iterate through

        tasks: list or 'all'
            Specify which tasks to iterate through

        hp_cutoff: int or float
            The low frequency bound of the band-pass filter in Hz

        lp_cutoff: int or float
            The high frequency bound of the band-pass filter in Hz

        line_noise: int or float
            The frequency of electrical noise to filter out in Hz
            (currently unused)

        seed: int
            Seed passed to ``np.random.seed`` before each recording is filtered

        eog_chan: str
            If there are no EOG channels present, select an EEG channel
            near the eyes for eye-related artifact detection
        '''

        missing = []  # paths of files that could not be read
        subj_iter = self.gen_iter(subjects, self.n_subj)  # get subject iterator
        # NOTE(review): task_iter is built once and re-used for every subject;
        # if gen_iter returns a one-shot iterator it would be exhausted after
        # the first subject -- confirm against gen_iter.
        task_iter = self.gen_iter(tasks, self.n_task)  # get task iterator

        # Iterate through subjects (initialize subject epoch list)
        epochs_subj = []
        for subj in subj_iter:

            # Iterate through tasks (initialize within-subject task epoch list)
            epochs_task = []
            for task in task_iter:
                # Specify the file format
                self.get_file_format(subj, task)

                try:  # Handles missing files
                    raw = self.wget_raw_edf()  # read
                except Exception:
                    # `Exception` rather than a bare `except` so that
                    # KeyboardInterrupt / SystemExit still propagate.
                    print(f'---\nThis file does not exist: {self.file_path}\n---')
                    missing.append(self.file_path)
                    # Stop processing the remaining tasks of this subject
                    break

                # Standardize montage
                mne.datasets.eegbci.standardize(raw)
                # Set montage and strip channel names of "." characters
                raw.set_montage(montage)
                raw.rename_channels(lambda x: x.strip('.'))

                # Apply band-pass filter
                np.random.seed(seed)
                raw.filter(l_freq=hp_cutoff, h_freq=lp_cutoff, picks=['eeg', 'eog'])

                # Instantiate NoisyChannels object
                noise_chans = NoisyChannels(raw, do_detrend=False)

                # Detect bad channels through multiple methods
                noise_chans.find_bad_by_nan_flat()
                noise_chans.find_bad_by_deviation()
                noise_chans.find_bad_by_SNR()

                # Set the bad channels in the raw object
                raw.info['bads'] = noise_chans.get_bads()
                print(f'Bad channels detected: {noise_chans.get_bads()}')

                # Notch filtering is currently disabled:
                # Define the frequencies for the notch filter (60Hz and its harmonics)
                #notch_filt = np.arange(line_noise, raw.info['sfreq'] // 2, line_noise)

                # Apply notch filter
                #print(f'Apply notch filter at {line_noise} Hz and its harmonics')
                #raw.notch_filter(notch_filt, picks=['eeg', 'eog'], verbose=False)

                # Reference to the average of all the good channels
                # Automatically excludes raw.info['bads']
                raw.set_eeg_reference(ref_channels='average')

                # Instantiate ICA object
                ica = mne.preprocessing.ICA(max_iter=1000)
                # Run ICA
                ica.fit(raw)

                # Find which ICs match the EOG pattern
                if eog_chan == 'none':
                    eog_indices, eog_scores = ica.find_bads_eog(raw, verbose=False)
                else:
                    eog_indices, eog_scores = ica.find_bads_eog(raw, eog_chan, verbose=False)

                # Apply the IC blink removals (if any)
                ica.apply(raw, exclude=eog_indices)
                print(f'Removed IC index {eog_indices}')

                # Interpolate bad channels
                raw.interpolate_bads()

                # Per-subject pre-processing directory. The previous f-string
                # contained a stray space after the separator, which produced
                # directory names with a leading blank.
                preproc_dir = Path(out_dir) / str(subj)

                # Create the directory (and any parents) if it does not exist;
                # exist_ok avoids the check-then-create race.
                preproc_dir.mkdir(parents=True, exist_ok=True)

                # Saving of the cleaned recording is currently disabled:
                # raw.save(Path(preproc_dir, f'subj{subj}_task{task}_raw.fif'),
                #          overwrite=True)

                # Find events
                events = mne.events_from_annotations(raw)[0]

                # Epoch the data
                preproc_epoch = mne.Epochs(raw, events, tmin=baseline_start, tmax=stim_dur,
                                   event_id=event_dict, event_repeated='error',
                                   on_missing='ignore', preload=True)

                # Equalize event counts (pass a real list, not a dict view)
                event_names = list(event_dict)
                preproc_epoch.equalize_event_counts(event_names)

                # Rearrange and align the epochs: one group per event type,
                # in the order the types appear in event_dict
                align = [preproc_epoch[name] for name in event_names]
                align_epoch = mne.concatenate_epochs(align)

                # Add to epoch list
                epochs_task.append(align_epoch)

            # A subject whose first file was missing has nothing to
            # concatenate; skip it instead of failing on an empty list.
            if not epochs_task:
                print(f'---\nNo epochs available for subject {subj}; skipping\n---')
                continue

            # Concatenate epochs within subject
            concat_epoch = mne.concatenate_epochs(epochs_task)
            epochs_subj.append(concat_epoch)

        # Expose the files that could not be read
        self.missing_files = missing
        # Attaches a list with each entry corresponding to epochs for a subject
        self.epoch_list = epochs_subj