def test_preprocess_raw_str(base_concat_ds):
    """Apply an MNE method given by name ('crop') to continuous data.

    Verifies the number of time points of the first recording after cropping
    and that every dataset logged the applied step in ``raw_preproc_kwargs``.
    """
    crop = Preprocessor('crop', tmax=10, include_tmax=False)
    preprocess(base_concat_ds, [crop])

    first_raw = base_concat_ds.datasets[0].raw
    assert len(first_raw.times) == 2500

    # Each dataset is expected to remember exactly the one step applied above.
    expected_log = [('crop', {'tmax': 10, 'include_tmax': False})]
    for ds in base_concat_ds.datasets:
        assert ds.raw_preproc_kwargs == expected_log
def test_windows_fixed_length_cropped(lazy_loadable_dataset):
    """Check fixed-length windowing on cropped data.

    Cropping raw data changes the `first_samp` attribute of the Raw object,
    so the windowers must take it into account. The first window cut from an
    uncropped copy (using explicit start/stop offsets) is compared with the
    first window cut from a copy that was cropped beforehand.
    """
    tmin, tmax = 100, 120

    # Copy 1: signal left intact, only the annotations are cropped.
    ds = copy.deepcopy(lazy_loadable_dataset)
    ds.datasets[0].raw.annotations.crop(tmin, tmax)

    # Copy 2: the recording itself is cropped through the preprocessing API.
    crop_ds = copy.deepcopy(lazy_loadable_dataset)
    preprocess(crop_ds, [Preprocessor('crop', tmin=tmin, tmax=tmax)])

    # Convert the crop interval from seconds to sample offsets.
    sfreq = ds.datasets[0].raw.info['sfreq']
    tmin_samples = int(tmin * sfreq)
    tmax_samples = int(tmax * sfreq)

    # Windowing settings common to both calls.
    shared_kwargs = dict(
        window_size_samples=100,
        window_stride_samples=100,
        drop_last_window=True,
    )
    uncropped_windows = create_fixed_length_windows(
        concat_ds=ds,
        start_offset_samples=tmin_samples,
        stop_offset_samples=tmax_samples,
        **shared_kwargs)
    cropped_windows = create_fixed_length_windows(
        concat_ds=crop_ds,
        start_offset_samples=0,
        stop_offset_samples=None,
        **shared_kwargs)

    assert (uncropped_windows[0][0] == cropped_windows[0][0]).all()
def test_filterbank(base_concat_ds):
    """Run the filterbank on two channels while keeping the original signals.

    Checks the resulting channel count, the number of distinct band suffixes,
    the exact channel order and the logged preprocessing kwargs.
    """
    base_concat_ds = base_concat_ds.split([[0]])['0']

    pick = Preprocessor(
        'pick_channels', ch_names=sorted(['C4', 'Cz']), ordered=True)
    bank = Preprocessor(
        filterbank,
        frequency_bands=[(0, 4), (4, 8), (8, 13)],
        drop_original_signals=False,
        apply_on_array=False)
    preprocess(base_concat_ds, [pick, bank])

    for x, _ in base_concat_ds:
        break  # only the first sample is needed
    assert x.shape[0] == 8

    ch_names = base_concat_ds.datasets[0].raw.ch_names
    # Filtered channels carry a '_<band>' suffix; original channels have none.
    band_suffixes = [name.split('_')[-1] for name in ch_names if '_' in name]
    assert len(np.unique(band_suffixes)) == 3

    expected_names = [
        'C4', 'C4_0-4', 'C4_4-8', 'C4_8-13',
        'Cz', 'Cz_0-4', 'Cz_4-8', 'Cz_8-13',
    ]
    np.testing.assert_array_equal(
        base_concat_ds.datasets[0].raw.ch_names, expected_names)

    expected_log = [
        ('pick_channels', {'ch_names': ['C4', 'Cz'], 'ordered': True}),
        ('filterbank', {'frequency_bands': [(0, 4), (4, 8), (8, 13)],
                        'drop_original_signals': False}),
    ]
    for ds in base_concat_ds.datasets:
        assert ds.raw_preproc_kwargs == expected_log
def test_preprocessors_with_misc_channels():
    """Check that 'misc' target channels survive preprocessing unchanged.

    Builds a small RawArray with two EEG channels and two 'misc' channels
    holding targets, applies `pick_types` and a scaling lambda, and compares
    the last two rows of the result with the original targets.
    """
    rng = np.random.RandomState(42)
    signal_sfreq = 50
    ch_names = ['0', '1', 'target_0', 'target_1']
    ch_types = ['eeg', 'eeg', 'misc', 'misc']
    info = mne.create_info(
        ch_names=ch_names, sfreq=signal_sfreq, ch_types=ch_types)

    # Draw order matters for reproducibility: signal first, then targets.
    signal = rng.randn(2, 1000)
    targets = rng.randn(2, 1000)
    stacked = np.concatenate([signal, targets])
    raw = mne.io.RawArray(stacked, info=info)

    desc = pd.Series({'pathological': True, 'gender': 'M', 'age': 48})
    concat_ds = BaseConcatDataset([BaseDataset(raw, desc, target_name=None)])

    preprocess(concat_ds, [
        Preprocessor('pick_types', eeg=True, misc=True),
        Preprocessor(lambda x: x / 1e6),
    ])

    # Check whether preprocessing has not affected the targets.
    # This is only valid for preprocessors that use mne functions which do
    # not modify `misc` channels.
    target_rows = concat_ds.datasets[0].raw.get_data()[-2:, :]
    np.testing.assert_array_equal(target_rows, targets)
def test_no_raw_or_epochs():
    """`preprocess` must reject datasets that hold neither raw nor epochs."""

    class EmptyDataset:
        """Stand-in exposing a `datasets` list of plain integers."""

        def __init__(self):
            self.datasets = [1, 2, 3]

    ds = EmptyDataset()
    dummy_steps = ["dummy", "dummy"]
    with pytest.raises(AssertionError):
        preprocess(ds, dummy_steps)
def test_preprocess_windows_str(windows_concat_ds):
    """Apply an MNE method given by name ('crop') to windowed data.

    Verifies the second dimension of the first window after cropping and
    that every dataset logged the step in ``window_preproc_kwargs``.
    """
    crop = Preprocessor('crop', tmin=0, tmax=0.1, include_tmax=False)
    preprocess(windows_concat_ds, [crop])

    first_window = windows_concat_ds[0][0]
    assert first_window.shape[1] == 25

    expected_log = [
        ('crop', {'tmin': 0, 'tmax': 0.1, 'include_tmax': False}),
    ]
    for ds in windows_concat_ds.datasets:
        assert ds.window_preproc_kwargs == expected_log
def test_preprocess_windows_callable_on_object(windows_concat_ds):
    """Apply a callable to the windows object itself (`apply_on_array=False`).

    The first window after preprocessing must equal the first window before
    preprocessing multiplied by `factor`, within tolerance.
    """
    factor = 10
    step = Preprocessor(
        modify_windows_object, apply_on_array=False, factor=factor)

    # Snapshot taken before the dataset is modified.
    window_before = windows_concat_ds[0][0]
    preprocess(windows_concat_ds, [step])
    window_after = windows_concat_ds[0][0]

    np.testing.assert_allclose(
        window_after, window_before * factor, rtol=1e-4, atol=1e-4)
def test_scale_continuous(base_concat_ds):
    """Scale continuous data by a constant factor after picking EEG channels.

    The first sample after preprocessing must equal the first 22 rows of the
    first sample before preprocessing, multiplied by `factor`.
    """
    factor = 1e6
    eeg_only = Preprocessor('pick_types', eeg=True, meg=False, stim=False)
    rescale = Preprocessor(scale, factor=factor)

    sample_before = base_concat_ds[0][0][:22]  # only keep EEG channels
    preprocess(base_concat_ds, [eeg_only, rescale])
    sample_after = base_concat_ds[0][0]

    np.testing.assert_allclose(
        sample_after, sample_before * factor, rtol=1e-4, atol=1e-4)
def test_scale_windows(windows_concat_ds):
    """Scale windowed data by a constant factor after picking EEG channels.

    The first window after preprocessing must equal the first 22 rows of the
    first window before preprocessing, multiplied by `factor`.
    """
    factor = 1e6
    eeg_only = Preprocessor('pick_types', eeg=True, meg=False, stim=False)
    rescale = Preprocessor(scale, factor=factor)

    window_before = windows_concat_ds[0][0][:22]  # only keep EEG channels
    preprocess(windows_concat_ds, [eeg_only, rescale])
    window_after = windows_concat_ds[0][0]

    np.testing.assert_allclose(
        window_after, window_before * factor, rtol=1e-4, atol=1e-4)
# NOTE(review): another test with this exact name is defined earlier in the
# visible source. If both live in the same module, this definition replaces
# the earlier one and only this one is collected -- confirm and rename one.
def test_scale_windows(windows_concat_ds):
    """Scale windowed data with `deprecated_scale` and check the kwargs log.

    Besides the numerical check on the first window, every dataset must have
    logged both steps in ``window_preproc_kwargs`` (the scaling step under
    the name 'scale').
    """
    factor = 1e6
    eeg_only = Preprocessor('pick_types', eeg=True, meg=False, stim=False)
    rescale = Preprocessor(deprecated_scale, factor=factor)

    window_before = windows_concat_ds[0][0][:22]  # only keep EEG channels
    preprocess(windows_concat_ds, [eeg_only, rescale])
    window_after = windows_concat_ds[0][0]

    np.testing.assert_allclose(
        window_after, window_before * factor, rtol=1e-4, atol=1e-4)

    expected_log = [
        ('pick_types', {'eeg': True, 'meg': False, 'stim': False}),
        ('scale', {'factor': 1e6}),
    ]
    for ds in windows_concat_ds.datasets:
        assert ds.window_preproc_kwargs == expected_log
def test_deprecated_preprocs(base_concat_ds):
    """Check that the deprecated wrappers warn and still preprocess the data.

    Instantiating `MNEPreproc` and `NumpyPreproc` must each emit a
    `FutureWarning` containing the expected deprecation message. The
    resulting objects are then run through `preprocess` and the first sample
    is compared with the scaled first 22 rows of the original sample.
    """
    import re  # local import: only needed to build literal match patterns

    msg1 = 'Class MNEPreproc is deprecated; will be removed in 0.7.0. Use ' \
           'Preprocessor with `apply_on_array=False` instead.'
    msg2 = 'NumpyPreproc is deprecated; will be removed in 0.7.0. Use ' \
           'Preprocessor with `apply_on_array=True` instead.'

    # `match` is interpreted as a regular expression (searched in the warning
    # text), so the messages are escaped to make every '.' match a literal
    # dot instead of any character.
    with pytest.warns(FutureWarning, match=re.escape(msg1)):
        mne_preproc = MNEPreproc('pick_types', eeg=True, meg=False, stim=False)
    factor = 1e6
    with pytest.warns(FutureWarning, match=re.escape(msg2)):
        np_preproc = NumpyPreproc(deprecated_scale, factor=factor)

    raw_timepoint = base_concat_ds[0][0][:22]  # only keep EEG channels
    preprocess(base_concat_ds, [mne_preproc, np_preproc])
    np.testing.assert_allclose(base_concat_ds[0][0], raw_timepoint * factor,
                               rtol=1e-4, atol=1e-4)
def test_preprocess_save_dir(base_concat_ds, windows_concat_ds, tmp_path,
                             kind, save, overwrite, n_jobs):
    """Run `preprocess` with and without a save directory.

    `kind` selects the continuous ('raw') or windowed ('windows') fixture.
    After cropping, each returned dataset must carry the matching
    ``*_preproc_kwargs`` log and 25 time points. Without a save directory the
    data must be preloaded; with one, it must not be preloaded and the
    directory must contain exactly one subdirectory per dataset, named by
    index.
    """
    expected_log = [
        ('crop', {'tmin': 0, 'tmax': 0.1, 'include_tmax': False})]
    steps = [Preprocessor('crop', tmin=0, tmax=0.1, include_tmax=False)]

    save_dir = None
    if save:
        save_dir = str(tmp_path)

    if kind == 'raw':
        concat_ds = base_concat_ds
        log_attr = 'raw_preproc_kwargs'
    elif kind == 'windows':
        concat_ds = windows_concat_ds
        log_attr = 'window_preproc_kwargs'

    concat_ds = preprocess(
        concat_ds, steps, save_dir, overwrite=overwrite, n_jobs=n_jobs)

    for ds in concat_ds.datasets:
        assert hasattr(ds, log_attr)
        assert getattr(ds, log_attr) == expected_log
        # `kind` doubles as the name of the attribute holding the data.
        assert len(getattr(ds, kind).times) == 25
        if kind == 'raw':
            assert hasattr(ds, 'target_name')

    if save_dir is None:
        for ds in concat_ds.datasets:
            assert getattr(ds, kind).preload
        return

    for ds in concat_ds.datasets:
        assert not getattr(ds, kind).preload
    expected_subdirs = {
        os.path.join(save_dir, str(i))
        for i in range(len(concat_ds.datasets))}
    assert set(glob(save_dir + '/*')) == expected_subdirs
def test_windows_from_events_cropped(lazy_loadable_dataset):
    """Check event-based windowing on cropped data.

    Cropping raw data changes the `first_samp` attribute of the Raw object,
    so the windowers must take it into account. The first window from a copy
    with cropped annotations is compared with the first window from a copy
    whose recording was cropped; afterwards, too-large stop offsets must
    raise a `ValueError` for both copies.
    """
    tmin, tmax = 100, 120

    # Copy 1: signal left intact, only the annotations are cropped.
    ds = copy.deepcopy(lazy_loadable_dataset)
    ds.datasets[0].raw.annotations.crop(tmin, tmax)

    # Copy 2: the recording itself is cropped through the preprocessing API.
    crop_ds = copy.deepcopy(lazy_loadable_dataset)
    preprocess(crop_ds, [Preprocessor('crop', tmin=tmin, tmax=tmax)])

    # Windowing settings common to all calls below.
    shared_kwargs = dict(
        trial_start_offset_samples=0,
        window_size_samples=100,
        window_stride_samples=100,
        drop_last_window=False,
    )

    annot_cropped_windows = create_windows_from_events(
        concat_ds=ds, trial_stop_offset_samples=0, **shared_kwargs)
    raw_cropped_windows = create_windows_from_events(
        concat_ds=crop_ds, trial_stop_offset_samples=0, **shared_kwargs)
    assert (annot_cropped_windows[0][0] == raw_cropped_windows[0][0]).all()

    # Make sure events that fall outside of recording will trigger an error
    too_large = '"trial_stop_offset_samples" too large'
    with pytest.raises(ValueError, match=too_large):
        create_windows_from_events(
            concat_ds=ds, trial_stop_offset_samples=10000, **shared_kwargs)
    with pytest.raises(ValueError, match=too_large):
        create_windows_from_events(
            concat_ds=crop_ds, trial_stop_offset_samples=2001,
            **shared_kwargs)
def test_filterbank_order_channels_by_freq(base_concat_ds):
    """Run the filterbank with `order_by_frequency_band=True`.

    The resulting channel list must start with the original channels and
    then group the filtered channels band by band.
    """
    base_concat_ds = base_concat_ds.split([[0]])['0']

    pick = Preprocessor(
        'pick_channels', ch_names=sorted(['C4', 'Cz']), ordered=True)
    bank = Preprocessor(
        filterbank,
        frequency_bands=[(0, 4), (4, 8), (8, 13)],
        drop_original_signals=False,
        order_by_frequency_band=True,
        apply_on_array=False)
    preprocess(base_concat_ds, [pick, bank])

    expected_names = [
        'C4', 'Cz',
        'C4_0-4', 'Cz_0-4',
        'C4_4-8', 'Cz_4-8',
        'C4_8-13', 'Cz_8-13',
    ]
    np.testing.assert_array_equal(
        base_concat_ds.datasets[0].raw.ch_names, expected_names)
def test_zscore_windows(windows_concat_ds):
    """Z-score windowed data and verify the statistics along the last axis.

    After preprocessing, the mean along the last axis must be close to zero
    and the standard deviation close to one for every dataset, and both
    steps must be logged in ``window_preproc_kwargs``.
    """
    steps = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(zscore),
    ]
    preprocess(windows_concat_ds, steps)

    tolerances = dict(rtol=1e-4, atol=1e-4)
    for ds in windows_concat_ds.datasets:
        windowed_data = ds.windows.get_data()
        stats_shape = windowed_data.shape[:-1]
        # zero mean
        np.testing.assert_allclose(
            windowed_data.mean(axis=-1), np.zeros(stats_shape), **tolerances)
        # unit variance
        np.testing.assert_allclose(
            windowed_data.std(axis=-1), np.ones(stats_shape), **tolerances)

    expected_log = [
        ('pick_types', {'eeg': True, 'meg': False, 'stim': False}),
        ('zscore', {}),
    ]
    for ds in windows_concat_ds.datasets:
        assert ds.window_preproc_kwargs == expected_log
def test_zscore_continuous(base_concat_ds):
    """Z-score continuous data with `channel_wise=True`.

    After preprocessing, the mean along the last axis of each recording must
    be close to zero and the standard deviation close to one.
    """
    steps = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(zscore, channel_wise=True),
    ]
    preprocess(base_concat_ds, steps)

    tolerances = dict(rtol=1e-4, atol=1e-4)
    for ds in base_concat_ds.datasets:
        raw_data = ds.raw.get_data()
        stats_shape = raw_data.shape[:-1]
        # zero mean
        np.testing.assert_allclose(
            raw_data.mean(axis=-1), np.zeros(stats_shape), **tolerances)
        # unit variance
        np.testing.assert_allclose(
            raw_data.std(axis=-1), np.ones(stats_shape), **tolerances)
def test_preprocess_overwrite(base_concat_ds, tmp_path, overwrite):
    """Check `preprocess` against a save directory that already has content.

    Each dataset is first saved to its own index-named subdirectory. With
    `overwrite` false, `preprocess` must raise `FileExistsError`; with
    `overwrite` true, the data reloaded from disk must be the cropped one.
    """
    steps = [Preprocessor('crop', tmax=10, include_tmax=False)]

    # Create temporary directory with preexisting files
    save_dir = str(tmp_path)
    for idx, single_ds in enumerate(base_concat_ds.datasets):
        subdir = os.path.join(save_dir, str(idx))
        wrapped = BaseConcatDataset([single_ds])
        os.makedirs(subdir)
        wrapped.save(subdir, overwrite=True)

    if not overwrite:
        with pytest.raises(FileExistsError):
            preprocess(base_concat_ds, steps, save_dir, overwrite=False)
        return

    preprocess(base_concat_ds, steps, save_dir, overwrite=True)
    # Make sure the serialized data is preprocessed
    reloaded = load_concat_dataset(save_dir, True)
    for ds in reloaded.datasets:
        assert len(ds.raw.times) == 2500
# Afterwards, we split the continuous signals into compute windows. We store
# each recording to a unique subdirectory that is named corresponding to the
# rec id. To save memory, after windowing and storing, we delete the raw
# dataset and the windows dataset, respectively.
window_size_samples = 1000
window_stride_samples = 1000
create_compute_windows = True

out_i = 0
errors = []
OUT_PATH = './tuh_sample/'

# One single-element index group per recording, so that `split` yields one
# subset per recording.
single_recording_groups = [[rec_idx] for rec_idx in range(len(tuh.datasets))]
tuh_splits = tuh.split(single_recording_groups)

for rec_i, tuh_subset in tuh_splits.items():
    # implement preprocess for BaseDatasets? Would remove necessity
    # to split above
    preprocess(tuh_subset, preprocessors)

    # update description of the recording(s)
    n_in_subset = len(tuh_subset.datasets)
    tuh_subset.description.sfreq = [sfreq] * n_in_subset
    tuh_subset.description.reference = ['ar'] * n_in_subset
    tuh_subset.description.n_samples = list(map(len, tuh_subset.datasets))

    if create_compute_windows:
        # generate compute windows here and store them to disk
        tuh_windows = create_fixed_length_windows(
            tuh_subset,
            start_offset_samples=0,
            stop_offset_samples=None,
            window_size_samples=window_size_samples,
            window_stride_samples=window_stride_samples,
            drop_last_window=False,
        )
# NOTE(review): a test with this exact name is defined earlier in the visible
# source (with an additional kwargs-log assertion). If both live in the same
# module, this later definition replaces the earlier one -- confirm.
def test_preprocess_raw_str(base_concat_ds):
    """Apply an MNE method given by name ('crop') to continuous data.

    Verifies the number of time points of the first recording after cropping.
    """
    crop = Preprocessor('crop', tmax=10, include_tmax=False)
    preprocess(base_concat_ds, [crop])

    first_raw = base_concat_ds.datasets[0].raw
    assert len(first_raw.times) == 2500
# NOTE(review): a test with this exact name is defined earlier in the visible
# source (with an additional kwargs-log assertion). If both live in the same
# module, this later definition replaces the earlier one -- confirm.
def test_preprocess_windows_str(windows_concat_ds):
    """Apply an MNE method given by name ('crop') to windowed data.

    Verifies the second dimension of the first window after cropping.
    """
    crop = Preprocessor('crop', tmin=0, tmax=0.1, include_tmax=False)
    preprocess(windows_concat_ds, [crop])

    first_window = windows_concat_ds[0][0]
    assert first_window.shape[1] == 25
def test_method_not_available(base_concat_ds):
    """A method name that does not exist must raise `AttributeError`."""
    bogus_step = Preprocessor('this_method_is_not_real')
    with pytest.raises(AttributeError):
        preprocess(base_concat_ds, [bogus_step])
#
# Next, we preprocess the raw data: the signals are converted to microvolts
# and lowpass filtered. No resampling is applied because the Sleep Physionet
# data is already sampled at 100 Hz.

from braindecode.preprocessing.preprocess import preprocess, Preprocessor, scale

high_cut_hz = 30

preprocessors = [
    # Multiply the signals by 1e6 (conversion to microvolts).
    Preprocessor(scale, factor=1e6, apply_on_array=True),
    # Lowpass filter: no lower cutoff, upper cutoff at `high_cut_hz`.
    Preprocessor('filter', l_freq=None, h_freq=high_cut_hz, n_jobs=n_jobs),
]

# Transform the data
preprocess(dataset, preprocessors)

######################################################################
# Extracting windows
# ~~~~~~~~~~~~~~~~~~
#
# We extract 30-s windows to be used in both the pretext and downstream tasks.
# As RP (and SSL in general) doesn't require labelled data, the pretext task
# could be performed using unlabelled windows extracted with
# :func:`braindecode.preprocessing.windowers.create_fixed_length_windows`.
# Here however, purely for convenience, we directly extract labelled windows so
# that we can reuse them in the sleep staging downstream task later.

from braindecode.preprocessing.windowers import create_windows_from_events

window_size_s = 30
# Iterating through ds yields one time point of a continuous signal x and a
# target y (which can be None if targets are not defined for the entire
# continuous signal).
for x, y in ds:
    print(x.shape, y)
    break

##############################################################################
# We can apply preprocessing transforms that are defined in mne and work
# in-place, such as resampling, bandpass filtering, or electrode selection.
pick_eeg_and_stim = Preprocessor('pick_types', eeg=True, meg=False, stim=True)
resample_to_100 = Preprocessor('resample', sfreq=100)
preprocessors = [pick_eeg_and_stim, resample_to_100]

# Print the sampling frequency before and after preprocessing.
print(ds.datasets[0].raw.info["sfreq"])
preprocess(ds, preprocessors)
print(ds.datasets[0].raw.info["sfreq"])

###############################################################################
# We can easily split ds based on a criterion applied to the description
# DataFrame:
subsets = ds.split("session")
print({name: len(part) for name, part in subsets.items()})

###############################################################################
# Next, we use a windower to extract windows from the dataset based on events:
windows_ds = create_windows_from_events(
    ds,
    trial_start_offset_samples=0,
    trial_stop_offset_samples=100,
    window_size_samples=400,
    window_stride_samples=100,
    drop_last_window=False,
)
def test_not_list():
    """Passing preprocessors as a dict instead of a list must be rejected."""
    not_a_list = {'test': 1}
    with pytest.raises(AssertionError):
        preprocess(None, not_a_list)
# We can iterate through ds which yields one time point of a continuous signal x, # and a target y (which can be None if targets are not defined for the entire # continuous signal). for x, y in dataset: print(x.shape, y) break ############################################################################## # We can apply preprocessing transforms that are defined in mne and work # in-place, such as resampling, bandpass filtering, or electrode selection. preprocessors = [ Preprocessor('pick_types', eeg=True, meg=False, stim=True), Preprocessor('resample', sfreq=100) ] print(dataset.datasets[0].raw.info["sfreq"]) preprocess(dataset, preprocessors) print(dataset.datasets[0].raw.info["sfreq"]) ############################################################################### # We can easily split ds based on a criteria applied to the description # DataFrame: subsets = dataset.split("session") print({subset_name: len(subset) for subset_name, subset in subsets.items()}) ############################################################################### # Next, we use a windower to extract events from the dataset based on events: windows_dataset = create_windows_from_events(dataset, trial_start_offset_samples=0, trial_stop_offset_samples=100, window_size_samples=400, window_stride_samples=100,
# Cutoff frequencies of the bandpass filter
low_cut_hz = 4.
high_cut_hz = 38.
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000

preprocessors = [
    # Keep EEG sensors
    Preprocessor('pick_types', eeg=True, meg=False, stim=False),
    # Convert from V to uV
    Preprocessor(lambda x: x * 1e6),
    # Bandpass filter
    Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),
    # Exponential moving standardization
    Preprocessor(
        exponential_moving_standardize,
        factor_new=factor_new,
        init_block_size=init_block_size,
    ),
]

# Transform the data
preprocess(dataset, preprocessors)

######################################################################
# Create model and compute windowing parameters
# ---------------------------------------------
#


######################################################################
# In contrast to trialwise decoding, we first have to create the model
# before we can cut the dataset into windows. This is because we need to
# know the receptive field of the network to know how large the window
# stride should be.
#
# Next, we preprocess the raw data. We convert the data to microvolts and apply # a lowpass filter. We omit the downsampling step of [1]_ as the Sleep # Physionet data is already sampled at a lower 100 Hz. # from braindecode.preprocessing.preprocess import preprocess, Preprocessor high_cut_hz = 30 preprocessors = [ Preprocessor(lambda x: x * 1e6), Preprocessor('filter', l_freq=None, h_freq=high_cut_hz) ] # Transform the data preprocess(dataset, preprocessors) ###################################################################### # Extract windows # ~~~~~~~~~~~~~~~ # ###################################################################### # We extract 30-s windows to be used in the classification task. from braindecode.preprocessing import create_windows_from_events mapping = { # We merge stages 3 and 4 following AASM standards.
from braindecode.datasets.moabb import MOABBDataset
# NOTE(review): MNEPreproc is deprecated in favour of Preprocessor and is no
# longer used in the visible part of this script. The import is retained only
# because the rest of the file is not visible here -- drop it if nothing
# further down needs it.
from braindecode.preprocessing.preprocess import (  # noqa: F401
    preprocess, Preprocessor, MNEPreproc)
from braindecode.datautil.serialization import load_concat_dataset
from braindecode.preprocessing.windowers import create_windows_from_events

###############################################################################
# First, we load some dataset using MOABB.
ds = MOABBDataset(
    dataset_name='BNCI2014001',
    subject_ids=[1],
)

###############################################################################
# We can apply preprocessing steps to the dataset. It is also possible to skip
# this step and not apply any preprocessing.
preprocess(
    concat_ds=ds,
    preprocessors=[Preprocessor('resample', sfreq=10)]
)

###############################################################################
# We save the dataset to an existing directory. It will create a '.fif' file
# for every dataset in the concat dataset. Additionally it will create two
# JSON files, the first holding the description of the dataset, the second
# holding the name of the target. If you want to store to the same directory
# several times, for example due to trying different preprocessing, you can
# choose to overwrite the existing files.
ds.save(
    path='./',
    overwrite=False,
)

##############################################################################
# We load the saved dataset from a directory. Signals can be preloaded in