def test_preprocessors_with_misc_channels():
    """Check that preprocessing leaves ``misc`` (target) channels untouched.

    A synthetic recording with two EEG and two misc channels goes through a
    channel pick and a scaling function. Afterwards the misc rows must still
    equal the values they were built from. This only holds for preprocessors
    backed by mne functions that do not modify ``misc`` channels.
    """
    rng = np.random.RandomState(42)
    eeg_data = rng.randn(2, 1000)
    misc_data = rng.randn(2, 1000)
    info = mne.create_info(
        ch_names=['0', '1', 'target_0', 'target_1'],
        sfreq=50,
        ch_types=['eeg', 'eeg', 'misc', 'misc'],
    )
    raw = mne.io.RawArray(np.concatenate([eeg_data, misc_data]), info=info)
    description = pd.Series({'pathological': True, 'gender': 'M', 'age': 48})
    concat_ds = BaseConcatDataset(
        [BaseDataset(raw, description, target_name=None)])

    preprocess(concat_ds, [
        Preprocessor('pick_types', eeg=True, misc=True),
        Preprocessor(lambda x: x / 1e6),
    ])

    # The misc channels were concatenated last, so they are the final two rows.
    np.testing.assert_array_equal(
        concat_ds.datasets[0].raw.get_data()[-2:, :], misc_data)
def test_filterbank(base_concat_ds):
    """Filterbank keeps originals and adds one channel per band, grouped per electrode.

    Also checks that the preprocessing kwargs are recorded on every dataset.
    """
    single_ds = base_concat_ds.split([[0]])['0']
    preprocess(single_ds, [
        Preprocessor('pick_channels', ch_names=sorted(['C4', 'Cz']),
                     ordered=True),
        Preprocessor(filterbank, frequency_bands=[(0, 4), (4, 8), (8, 13)],
                     drop_original_signals=False, apply_on_array=False),
    ])

    for first_x, _ in single_ds:
        break
    # 2 original channels + 2 channels * 3 bands
    assert first_x.shape[0] == 8

    ch_names = single_ds.datasets[0].raw.ch_names
    band_suffixes = {name.split('_')[-1] for name in ch_names if '_' in name}
    assert len(band_suffixes) == 3

    np.testing.assert_array_equal(ch_names, [
        'C4', 'C4_0-4', 'C4_4-8', 'C4_8-13',
        'Cz', 'Cz_0-4', 'Cz_4-8', 'Cz_8-13',
    ])

    expected_kwargs = [
        ('pick_channels', {'ch_names': ['C4', 'Cz'], 'ordered': True}),
        ('filterbank', {'frequency_bands': [(0, 4), (4, 8), (8, 13)],
                        'drop_original_signals': False}),
    ]
    assert all(ds.raw_preproc_kwargs == expected_kwargs
               for ds in single_ds.datasets)
def test_scale_continuous(base_concat_ds):
    """``scale`` multiplies continuous EEG data by ``factor``."""
    factor = 1e6
    steps = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(scale, factor=factor),
    ]
    # Snapshot before preprocessing, keeping only the first 22 rows (the EEG
    # channels that survive ``pick_types``).
    eeg_before = base_concat_ds[0][0][:22]

    preprocess(base_concat_ds, steps)

    np.testing.assert_allclose(
        base_concat_ds[0][0], eeg_before * factor, rtol=1e-4, atol=1e-4)
def test_scale_windows(windows_concat_ds):
    """``scale`` multiplies windowed EEG data by ``factor``."""
    factor = 1e6
    steps = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(scale, factor=factor),
    ]
    # Snapshot the first window before preprocessing, keeping only the first
    # 22 rows (the EEG channels that survive ``pick_types``).
    window_before = windows_concat_ds[0][0][:22]

    preprocess(windows_concat_ds, steps)

    np.testing.assert_allclose(
        windows_concat_ds[0][0], window_before * factor, rtol=1e-4, atol=1e-4)
def test_scale_windows(windows_concat_ds):
    """``deprecated_scale`` multiplies windowed EEG data and its kwargs are recorded.

    NOTE(review): another ``test_scale_windows`` appears in this file. If both
    live in the same module, the later definition shadows the earlier one and
    pytest collects only one of them. Consider renaming one.
    """
    factor = 1e6
    steps = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(deprecated_scale, factor=factor),
    ]
    # First window, restricted to the first 22 rows (EEG channels).
    window_before = windows_concat_ds[0][0][:22]

    preprocess(windows_concat_ds, steps)

    np.testing.assert_allclose(
        windows_concat_ds[0][0], window_before * factor, rtol=1e-4, atol=1e-4)
    expected_kwargs = [
        ('pick_types', {'eeg': True, 'meg': False, 'stim': False}),
        ('scale', {'factor': 1e6}),
    ]
    assert all(ds.window_preproc_kwargs == expected_kwargs
               for ds in windows_concat_ds.datasets)
def test_windows_fixed_length_cropped(lazy_loadable_dataset):
    """Fixed-length windowing must give the same result on cropped data.

    Cropping a Raw object changes its ``first_samp`` attribute, and the
    windowers have to take this into account. The first window from an
    explicitly cropped copy is compared with the first window from an
    uncropped copy whose windowing is restricted by sample offsets.
    """
    tmin, tmax = 100, 120

    # Reference: signal left intact, only the annotations are cropped.
    offset_ds = copy.deepcopy(lazy_loadable_dataset)
    offset_ds.datasets[0].raw.annotations.crop(tmin, tmax)

    # Candidate: signal cropped through the preprocessing pipeline.
    cropped_ds = copy.deepcopy(lazy_loadable_dataset)
    preprocess(cropped_ds, [Preprocessor('crop', tmin=tmin, tmax=tmax)])

    sfreq = offset_ds.datasets[0].raw.info['sfreq']
    start_sample = int(tmin * sfreq)
    stop_sample = int(tmax * sfreq)
    shared = dict(window_size_samples=100, window_stride_samples=100,
                  drop_last_window=True)

    reference = create_fixed_length_windows(
        concat_ds=offset_ds, start_offset_samples=start_sample,
        stop_offset_samples=stop_sample, **shared)
    candidate = create_fixed_length_windows(
        concat_ds=cropped_ds, start_offset_samples=0,
        stop_offset_samples=None, **shared)

    assert (reference[0][0] == candidate[0][0]).all()
def test_preprocess_save_dir(base_concat_ds, windows_concat_ds, tmp_path,
                             kind, save, overwrite, n_jobs):
    """Preprocessing with or without ``save_dir`` records kwargs, crops and sets preload.

    ``kind`` selects the raw (``'raw'``) or windowed (``'windows'``) fixture.
    The remaining arguments are presumably supplied by pytest parametrization
    (decorators not visible here).
    """
    expected_kwargs = [
        ('crop', {'tmin': 0, 'tmax': 0.1, 'include_tmax': False})]
    steps = [Preprocessor('crop', tmin=0, tmax=0.1, include_tmax=False)]
    save_dir = str(tmp_path) if save else None

    if kind == 'raw':
        concat_ds, kwargs_attr = base_concat_ds, 'raw_preproc_kwargs'
    elif kind == 'windows':
        concat_ds, kwargs_attr = windows_concat_ds, 'window_preproc_kwargs'

    concat_ds = preprocess(
        concat_ds, steps, save_dir, overwrite=overwrite, n_jobs=n_jobs)
    datasets = concat_ds.datasets

    assert all(hasattr(ds, kwargs_attr) for ds in datasets)
    assert all(getattr(ds, kwargs_attr) == expected_kwargs for ds in datasets)
    # ``kind`` doubles as the attribute holding the mne object ('raw'/'windows').
    assert all(len(getattr(ds, kind).times) == 25 for ds in datasets)
    if kind == 'raw':
        assert all(hasattr(ds, 'target_name') for ds in datasets)

    if save_dir is None:
        assert all(getattr(ds, kind).preload for ds in datasets)
        return

    # Data reloaded from disk is lazy, with one sub-directory per dataset.
    assert all(not getattr(ds, kind).preload for ds in datasets)
    expected_dirs = {os.path.join(save_dir, str(idx))
                     for idx in range(len(datasets))}
    assert set(glob(save_dir + '/*')) == expected_dirs
def test_preprocess_raw_str(base_concat_ds):
    """A string-named mne method (``crop``) is applied to raw data and its kwargs recorded."""
    preprocess(base_concat_ds,
               [Preprocessor('crop', tmax=10, include_tmax=False)])

    # Cropping to 10 s without the end sample is expected to leave 2500 samples.
    assert len(base_concat_ds.datasets[0].raw.times) == 2500
    expected_kwargs = [('crop', {'tmax': 10, 'include_tmax': False})]
    assert all(ds.raw_preproc_kwargs == expected_kwargs
               for ds in base_concat_ds.datasets)
def test_set_raw_preproc_kwargs(base_concat_ds):
    """``_set_preproc_kwargs`` stores ``(name, kwargs)`` pairs on a raw dataset."""
    target_ds = base_concat_ds.datasets[0]
    _set_preproc_kwargs(
        target_ds, [Preprocessor('crop', tmax=10, include_tmax=False)])

    assert hasattr(target_ds, 'raw_preproc_kwargs')
    assert target_ds.raw_preproc_kwargs == [
        ('crop', {'tmax': 10, 'include_tmax': False})]
def test_set_window_preproc_kwargs(windows_concat_ds):
    """``_set_preproc_kwargs`` stores ``(name, kwargs)`` pairs on a windows dataset."""
    target_ds = windows_concat_ds.datasets[0]
    _set_preproc_kwargs(
        target_ds, [Preprocessor('crop', tmax=10, include_tmax=False)])

    assert hasattr(target_ds, 'window_preproc_kwargs')
    assert target_ds.window_preproc_kwargs == [
        ('crop', {'tmax': 10, 'include_tmax': False})]
def test_preprocess_windows_str(windows_concat_ds):
    """A string-named mne method (``crop``) is applied per window and its kwargs recorded."""
    preprocess(windows_concat_ds,
               [Preprocessor('crop', tmin=0, tmax=0.1, include_tmax=False)])

    # 0.1 s without the end sample is expected to leave 25 samples per window.
    assert windows_concat_ds[0][0].shape[1] == 25
    expected_kwargs = [
        ('crop', {'tmin': 0, 'tmax': 0.1, 'include_tmax': False})]
    assert all(ds.window_preproc_kwargs == expected_kwargs
               for ds in windows_concat_ds.datasets)
def test_filterbank_order_channels_by_freq(base_concat_ds):
    """With ``order_by_frequency_band=True`` channels are grouped by band, not electrode."""
    single_ds = base_concat_ds.split([[0]])['0']
    preprocess(single_ds, [
        Preprocessor('pick_channels', ch_names=sorted(['C4', 'Cz']),
                     ordered=True),
        Preprocessor(filterbank, frequency_bands=[(0, 4), (4, 8), (8, 13)],
                     drop_original_signals=False,
                     order_by_frequency_band=True, apply_on_array=False),
    ])

    expected_names = [
        'C4', 'Cz',
        'C4_0-4', 'Cz_0-4',
        'C4_4-8', 'Cz_4-8',
        'C4_8-13', 'Cz_8-13',
    ]
    np.testing.assert_array_equal(
        single_ds.datasets[0].raw.ch_names, expected_names)
def test_preprocess_windows_callable_on_object(windows_concat_ds):
    """A callable applied to the windows object itself (``apply_on_array=False``) takes effect."""
    factor = 10
    window_before = windows_concat_ds[0][0]

    preprocess(windows_concat_ds, [
        Preprocessor(modify_windows_object, apply_on_array=False,
                     factor=factor),
    ])

    np.testing.assert_allclose(
        windows_concat_ds[0][0], window_before * factor, rtol=1e-4, atol=1e-4)
def test_zscore_windows(windows_concat_ds):
    """``zscore`` gives zero mean and unit std along the last axis of each window."""
    preprocess(windows_concat_ds, [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(zscore),
    ])

    tolerance = dict(rtol=1e-4, atol=1e-4)
    for ds in windows_concat_ds.datasets:
        data = ds.windows.get_data()
        reduced_shape = data.shape[:-1]
        np.testing.assert_allclose(
            data.mean(axis=-1), np.zeros(reduced_shape), **tolerance)
        np.testing.assert_allclose(
            data.std(axis=-1), np.ones(reduced_shape), **tolerance)

    expected_kwargs = [
        ('pick_types', {'eeg': True, 'meg': False, 'stim': False}),
        ('zscore', {}),
    ]
    assert all(ds.window_preproc_kwargs == expected_kwargs
               for ds in windows_concat_ds.datasets)
def test_zscore_continuous(base_concat_ds):
    """Channel-wise ``zscore`` gives each continuous channel zero mean and unit std."""
    preprocess(base_concat_ds, [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(zscore, channel_wise=True),
    ])

    tolerance = dict(rtol=1e-4, atol=1e-4)
    for ds in base_concat_ds.datasets:
        data = ds.raw.get_data()
        reduced_shape = data.shape[:-1]
        np.testing.assert_allclose(
            data.mean(axis=-1), np.zeros(reduced_shape), **tolerance)
        np.testing.assert_allclose(
            data.std(axis=-1), np.ones(reduced_shape), **tolerance)
def test_windows_from_events_cropped(lazy_loadable_dataset):
    """Event-based windowing must give the same result on cropped data.

    Cropping a Raw object changes its ``first_samp`` attribute, and the
    windowers have to take this into account. Events whose windows fall
    outside the recording must raise a ``ValueError``.
    """
    tmin, tmax = 100, 120

    # Reference: signal left intact, only the annotations are cropped.
    offset_ds = copy.deepcopy(lazy_loadable_dataset)
    offset_ds.datasets[0].raw.annotations.crop(tmin, tmax)

    # Candidate: signal cropped through the preprocessing pipeline.
    cropped_ds = copy.deepcopy(lazy_loadable_dataset)
    preprocess(cropped_ds, [Preprocessor('crop', tmin=tmin, tmax=tmax)])

    shared = dict(trial_start_offset_samples=0, window_size_samples=100,
                  window_stride_samples=100, drop_last_window=False)

    reference = create_windows_from_events(
        concat_ds=offset_ds, trial_stop_offset_samples=0, **shared)
    candidate = create_windows_from_events(
        concat_ds=cropped_ds, trial_stop_offset_samples=0, **shared)
    assert (reference[0][0] == candidate[0][0]).all()

    # Stop offsets that push windows past the end of the recording must fail.
    too_large = '"trial_stop_offset_samples" too large'
    with pytest.raises(ValueError, match=too_large):
        create_windows_from_events(
            concat_ds=offset_ds, trial_stop_offset_samples=10000, **shared)
    with pytest.raises(ValueError, match=too_large):
        create_windows_from_events(
            concat_ds=cropped_ds, trial_stop_offset_samples=2001, **shared)
def test_preprocess_overwrite(base_concat_ds, tmp_path, overwrite):
    """Saving into a populated ``save_dir`` overwrites only when ``overwrite=True``.

    Otherwise ``preprocess`` must raise ``FileExistsError``.
    """
    steps = [Preprocessor('crop', tmax=10, include_tmax=False)]
    save_dir = str(tmp_path)

    # Pre-populate save_dir/<idx> with the unprocessed datasets.
    for idx, single in enumerate(base_concat_ds.datasets):
        subdir = os.path.join(save_dir, str(idx))
        os.makedirs(subdir)
        BaseConcatDataset([single]).save(subdir, overwrite=True)

    if not overwrite:
        with pytest.raises(FileExistsError):
            preprocess(base_concat_ds, steps, save_dir, overwrite=False)
        return

    preprocess(base_concat_ds, steps, save_dir, overwrite=True)
    # Reload from disk to confirm the serialized data is the cropped version.
    reloaded = load_concat_dataset(save_dir, True)
    assert all(len(ds.raw.times) == 2500 for ds in reloaded.datasets)
def custom_crop(raw, tmin=0.0, tmax=None, include_tmax=True):
    """Crop ``raw`` in place to ``[tmin, tmax]`` seconds, tolerating short recordings.

    By default mne fails if ``tmax`` is larger than the recording duration.
    Here ``tmax`` is clipped to the time of the last sample instead, so the
    result can be shorter than ``tmax - tmin`` when the recording is short.

    Parameters
    ----------
    raw : mne Raw-like object
        Uses ``n_times``, ``info['sfreq']`` and ``crop``; modified in place.
    tmin : float
        Start time in seconds.
    tmax : float | None
        End time in seconds. ``None`` means "until the end of the recording".
    include_tmax : bool
        Forwarded to ``raw.crop``.
    """
    # Time (s) of the last sample, i.e. the largest tmax ``crop`` accepts.
    last_sample_time = (raw.n_times - 1) / raw.info['sfreq']
    # ``min(x, None)`` raises TypeError on Python 3, so the default tmax=None
    # has to be handled explicitly.
    tmax = last_sample_time if tmax is None else min(last_sample_time, tmax)
    raw.crop(tmin=tmin, tmax=tmax, include_tmax=include_tmax)


# Keep minutes 1 to 6 of each recording (values in seconds), resample to 100 Hz.
tmin = 1 * 60
tmax = 6 * 60
sfreq = 100

preprocessors = [
    Preprocessor(custom_crop, tmin=tmin, tmax=tmax, include_tmax=False,
                 apply_on_array=False),
    Preprocessor('set_eeg_reference', ref_channels='average', ch_type='eeg'),
    Preprocessor(custom_rename_channels, mapping=ch_mapping,
                 apply_on_array=False),
    Preprocessor('pick_channels', ch_names=short_ch_names, ordered=True),
    Preprocessor(lambda x: x * 1e6),  # presumably V -> uV
    Preprocessor('resample', sfreq=sfreq),
]

###############################################################################
# The preprocessing loop works as follows. For every recording, we apply the
# preprocessors as defined above. Then, we update the description of the
# recording, since we have altered the duration, the reference, and the
# sampling frequency.
def test_preprocess_raw_str(base_concat_ds):
    """A string-named mne method (``crop``) is applied to the raw data.

    NOTE(review): another ``test_preprocess_raw_str`` (which also checks the
    recorded kwargs) appears in this file. If both live in the same module,
    the later definition shadows the earlier one. Consider merging or renaming.
    """
    crop = Preprocessor('crop', tmax=10, include_tmax=False)
    preprocess(base_concat_ds, [crop])
    n_samples = len(base_concat_ds.datasets[0].raw.times)
    assert n_samples == 2500
def test_preprocess_windows_str(windows_concat_ds):
    """A string-named mne method (``crop``) is applied to every window.

    NOTE(review): another ``test_preprocess_windows_str`` (which also checks
    the recorded kwargs) appears in this file. If both live in the same
    module, the later definition shadows the earlier one. Consider merging or
    renaming.
    """
    crop = Preprocessor('crop', tmin=0, tmax=0.1, include_tmax=False)
    preprocess(windows_concat_ds, [crop])
    n_window_samples = windows_concat_ds[0][0].shape[1]
    assert n_window_samples == 25
def test_method_not_available(base_concat_ds):
    """A method name the data object does not have raises ``AttributeError``."""
    bogus = Preprocessor('this_method_is_not_real')
    with pytest.raises(AttributeError):
        preprocess(base_concat_ds, [bogus])
crop_wake_mins=30)

######################################################################
# Preprocessing
# ~~~~~~~~~~~~~
#
# Next, we preprocess the raw data. We convert the data to microvolts and apply
# a lowpass filter. Since the Sleep Physionet data is already sampled at 100 Hz
# we don't need to apply resampling.

from braindecode.preprocessing.preprocess import preprocess, Preprocessor, scale

high_cut_hz = 30

# NOTE(review): ``n_jobs`` and ``dataset`` are defined earlier in the example
# (not visible in this fragment).
preprocessors = [
    # Multiply the signals by 1e6 (V -> uV, per the text above).
    Preprocessor(scale, factor=1e6, apply_on_array=True),
    # Lowpass only: ``l_freq=None`` leaves the low frequencies unfiltered.
    Preprocessor('filter', l_freq=None, h_freq=high_cut_hz, n_jobs=n_jobs)
]

# Transform the data
preprocess(dataset, preprocessors)

######################################################################
# Extracting windows
# ~~~~~~~~~~~~~~~~~~
#
# We extract 30-s windows to be used in both the pretext and downstream tasks.
# As RP (and SSL in general) doesn't require labelled data, the pretext task
# could be performed using unlabelled windows extracted with
# :func:`braindecode.preprocessing.windowers.create_fixed_length_windows`.
# Here however, purely for convenience, we directly extract labelled windows so
# ds has a pandas DataFrame with additional description of its internal datasets dataset.description ############################################################################## # We can iterate through ds which yields one time point of a continuous signal x, # and a target y (which can be None if targets are not defined for the entire # continuous signal). for x, y in dataset: print(x.shape, y) break ############################################################################## # We can apply preprocessing transforms that are defined in mne and work # in-place, such as resampling, bandpass filtering, or electrode selection. preprocessors = [ Preprocessor('pick_types', eeg=True, meg=False, stim=True), Preprocessor('resample', sfreq=100) ] print(dataset.datasets[0].raw.info["sfreq"]) preprocess(dataset, preprocessors) print(dataset.datasets[0].raw.info["sfreq"]) ############################################################################### # We can easily split ds based on a criteria applied to the description # DataFrame: subsets = dataset.split("session") print({subset_name: len(subset) for subset_name, subset in subsets.items()}) ############################################################################### # Next, we use a windower to extract events from the dataset based on events: windows_dataset = create_windows_from_events(dataset,
def test_set_preproc_kwargs_wrong_type(base_concat_ds):
    """Passing a concat dataset instead of a single dataset raises ``TypeError``."""
    crop = Preprocessor('crop', tmax=10, include_tmax=False)
    with pytest.raises(TypeError):
        _set_preproc_kwargs(base_concat_ds, [crop])
# from braindecode.preprocessing.preprocess import ( exponential_moving_standardize, preprocess, Preprocessor) low_cut_hz = 8. # low cut frequency for filtering high_cut_hz = 48. # high cut frequency for filtering # Parameters for exponential moving standardization factor_new = 1e-3 init_block_size = 1000 channels_list = dataset.datasets[0].raw.info['ch_names'][:62] # channels_list.remove('T8') # channels_list.remove('T9') preprocessors = [ Preprocessor('pick_types', eeg=True, meg=False, stim=False), # Keep EEG sensors Preprocessor('pick_channels', ch_names=channels_list), # select 62 of 64 channels Preprocessor(lambda x: x * 1e6), # Convert from V to uV Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz), # Bandpass filter Preprocessor(exponential_moving_standardize, # Exponential moving standardization factor_new=factor_new, init_block_size=init_block_size) ] # Transform the data preprocess(dataset, preprocessors) ###################################################################### # Cut Compute Windows # ~~~~~~~~~~~~~~~~~~~ #
from braindecode.datasets.moabb import MOABBDataset

# Load a single subject of the BNCI2014001 dataset through MOABB.
subject_id = 3
dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])

from braindecode.preprocessing.preprocess import (
    exponential_moving_standardize, preprocess, Preprocessor)

low_cut_hz = 4.  # low cut frequency for filtering
high_cut_hz = 38.  # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000

# Steps are applied in list order: channel pick, unit conversion, bandpass,
# then standardization.
preprocessors = [
    Preprocessor('pick_types', eeg=True, meg=False, stim=False),  # Keep EEG sensors
    Preprocessor(lambda x: x * 1e6),  # Convert from V to uV
    Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),  # Bandpass filter
    Preprocessor(exponential_moving_standardize,  # Exponential moving standardization
                 factor_new=factor_new, init_block_size=init_block_size)
]

# Transform the data
preprocess(dataset, preprocessors)

######################################################################
# Create model and compute windowing parameters
# ---------------------------------------------
#
# ~~~~~~~~~~~~~ # ###################################################################### # Next, we preprocess the raw data. We convert the data to microvolts and apply # a lowpass filter. We omit the downsampling step of [1]_ as the Sleep # Physionet data is already sampled at a lower 100 Hz. # from braindecode.preprocessing.preprocess import preprocess, Preprocessor high_cut_hz = 30 preprocessors = [ Preprocessor(lambda x: x * 1e6), Preprocessor('filter', l_freq=None, h_freq=high_cut_hz) ] # Transform the data preprocess(dataset, preprocessors) ###################################################################### # Extract windows # ~~~~~~~~~~~~~~~ # ###################################################################### # We extract 30-s windows to be used in the classification task.