示例#1
0
def test_preprocessors_with_misc_channels():
    """Check that preprocessing does not alter the `misc` (target) channels."""
    random_state = np.random.RandomState(42)
    # Draw order is kept (signal first, targets second) so the seeded values
    # stay the same.
    eeg_data = random_state.randn(2, 1000)
    targets = random_state.randn(2, 1000)

    channel_info = mne.create_info(
        ch_names=['0', '1', 'target_0', 'target_1'],
        sfreq=50,
        ch_types=['eeg', 'eeg', 'misc', 'misc'],
    )
    recording = mne.io.RawArray(
        np.concatenate([eeg_data, targets]), info=channel_info)
    description = pd.Series({'pathological': True, 'gender': 'M', 'age': 48})
    concat_ds = BaseConcatDataset(
        [BaseDataset(recording, description, target_name=None)])

    preprocess(concat_ds, [
        Preprocessor('pick_types', eeg=True, misc=True),
        Preprocessor(lambda x: x / 1e6),
    ])

    # The last two rows hold the `misc` channels and must equal the targets.
    # This only holds for preprocessors that rely on mne functions which do
    # not modify `misc` channels.
    misc_after = concat_ds.datasets[0].raw.get_data()[-2:, :]
    np.testing.assert_array_equal(misc_after, targets)
示例#2
0
def test_filterbank(base_concat_ds):
    """Run a three-band filterbank on C4/Cz; check channels and stored kwargs."""
    base_concat_ds = base_concat_ds.split([[0]])['0']
    pick_step = Preprocessor(
        'pick_channels', ch_names=sorted(['C4', 'Cz']), ordered=True)
    filterbank_step = Preprocessor(
        filterbank,
        frequency_bands=[(0, 4), (4, 8), (8, 13)],
        drop_original_signals=False,
        apply_on_array=False,
    )
    preprocess(base_concat_ds, [pick_step, filterbank_step])

    # Only the first sample is needed to check the channel count.
    for x, y in base_concat_ds:
        break
    assert x.shape[0] == 8

    ch_names = base_concat_ds.datasets[0].raw.ch_names
    # Filtered channels carry the band as a suffix after the underscore.
    band_suffixes = [name.split('_')[-1] for name in ch_names if '_' in name]
    assert len(np.unique(band_suffixes)) == 3
    np.testing.assert_array_equal(
        base_concat_ds.datasets[0].raw.ch_names,
        ['C4', 'C4_0-4', 'C4_4-8', 'C4_8-13',
         'Cz', 'Cz_0-4', 'Cz_4-8', 'Cz_8-13'],
    )

    expected_kwargs = [
        ('pick_channels', {'ch_names': ['C4', 'Cz'], 'ordered': True}),
        ('filterbank', {'frequency_bands': [(0, 4), (4, 8), (8, 13)],
                        'drop_original_signals': False}),
    ]
    for ds in base_concat_ds.datasets:
        assert ds.raw_preproc_kwargs == expected_kwargs
示例#3
0
def test_scale_continuous(base_concat_ds):
    """Scaling a continuous dataset multiplies the first sample by the factor."""
    scaling = 1e6
    pipeline = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(scale, factor=scaling),
    ]
    # Take the reference sample before preprocessing; only the first 22 rows
    # (the EEG channels) are kept.
    sample_before = base_concat_ds[0][0][:22]
    preprocess(base_concat_ds, pipeline)
    sample_after = base_concat_ds[0][0]
    np.testing.assert_allclose(
        sample_after, sample_before * scaling, rtol=1e-4, atol=1e-4)
示例#4
0
def test_scale_windows(windows_concat_ds):
    """Scaling a windowed dataset multiplies the first window by the factor."""
    scaling = 1e6
    pipeline = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(scale, factor=scaling),
    ]
    # Take the reference window before preprocessing; only the first 22 rows
    # (the EEG channels) are kept.
    window_before = windows_concat_ds[0][0][:22]
    preprocess(windows_concat_ds, pipeline)
    window_after = windows_concat_ds[0][0]
    np.testing.assert_allclose(
        window_after, window_before * scaling, rtol=1e-4, atol=1e-4)
示例#5
0
def test_scale_windows(windows_concat_ds):
    """Scale windows with `deprecated_scale` and check data and stored kwargs."""
    scaling = 1e6
    pipeline = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(deprecated_scale, factor=scaling),
    ]
    # Take the reference window before preprocessing; only the first 22 rows
    # (the EEG channels) are kept.
    window_before = windows_concat_ds[0][0][:22]
    preprocess(windows_concat_ds, pipeline)
    window_after = windows_concat_ds[0][0]
    np.testing.assert_allclose(
        window_after, window_before * scaling, rtol=1e-4, atol=1e-4)

    # The scaling step is expected to be recorded under the name 'scale'.
    expected_kwargs = [
        ('pick_types', {'eeg': True, 'meg': False, 'stim': False}),
        ('scale', {'factor': 1e6}),
    ]
    for ds in windows_concat_ds.datasets:
        assert ds.window_preproc_kwargs == expected_kwargs
示例#6
0
def test_windows_fixed_length_cropped(lazy_loadable_dataset):
    """Compare fixed-length windows from offset and from cropped recordings.

    Cropping raw data changes the `first_samp` attribute of the Raw object,
    so the windowers have to take it into account. Windows cut with explicit
    sample offsets must match windows cut from a recording cropped to the
    same interval.
    """
    tmin, tmax = 100, 120

    # Reference dataset: only the annotations are cropped.
    ds = copy.deepcopy(lazy_loadable_dataset)
    ds.datasets[0].raw.annotations.crop(tmin, tmax)

    # Second dataset: the recording itself is cropped through `preprocess`.
    crop_ds = copy.deepcopy(lazy_loadable_dataset)
    preprocess(crop_ds, [Preprocessor('crop', tmin=tmin, tmax=tmax)])

    sfreq = ds.datasets[0].raw.info['sfreq']
    start_sample = int(tmin * sfreq)
    stop_sample = int(tmax * sfreq)

    window_kwargs = dict(
        window_size_samples=100,
        window_stride_samples=100,
        drop_last_window=True,
    )
    windows1 = create_fixed_length_windows(
        concat_ds=ds,
        start_offset_samples=start_sample,
        stop_offset_samples=stop_sample,
        **window_kwargs)
    windows2 = create_fixed_length_windows(
        concat_ds=crop_ds,
        start_offset_samples=0,
        stop_offset_samples=None,
        **window_kwargs)
    assert (windows1[0][0] == windows2[0][0]).all()
示例#7
0
def test_preprocess_save_dir(base_concat_ds, windows_concat_ds, tmp_path,
                             kind, save, overwrite, n_jobs):
    """Preprocess raw or windowed data, optionally saving to `tmp_path`.

    Checks the stored preprocessing kwargs, the cropped length, the preload
    state and, when saving, the set of created sub-directories.
    """
    expected_kwargs = [
        ('crop', {'tmin': 0, 'tmax': 0.1, 'include_tmax': False})]
    crop_step = Preprocessor('crop', tmin=0, tmax=0.1, include_tmax=False)

    save_dir = str(tmp_path) if save else None
    if kind == 'raw':
        concat_ds, attr_name = base_concat_ds, 'raw_preproc_kwargs'
    elif kind == 'windows':
        concat_ds, attr_name = windows_concat_ds, 'window_preproc_kwargs'

    concat_ds = preprocess(
        concat_ds, [crop_step], save_dir, overwrite=overwrite, n_jobs=n_jobs)

    for ds in concat_ds.datasets:
        assert hasattr(ds, attr_name)
    for ds in concat_ds.datasets:
        assert getattr(ds, attr_name) == expected_kwargs
    for ds in concat_ds.datasets:
        assert len(getattr(ds, kind).times) == 25
    if kind == 'raw':
        for ds in concat_ds.datasets:
            assert hasattr(ds, 'target_name')

    if save_dir is not None:
        # Saved datasets are expected to come back without preloaded data,
        # one sub-directory per dataset.
        for ds in concat_ds.datasets:
            assert not getattr(ds, kind).preload
        expected_dirs = {os.path.join(save_dir, str(i))
                         for i in range(len(concat_ds.datasets))}
        assert set(glob(save_dir + '/*')) == expected_dirs
    else:
        for ds in concat_ds.datasets:
            assert getattr(ds, kind).preload
示例#8
0
def test_preprocess_raw_str(base_concat_ds):
    """Crop raw data via a string-named method; check length and kwargs."""
    crop_step = Preprocessor('crop', tmax=10, include_tmax=False)
    preprocess(base_concat_ds, [crop_step])

    n_times = len(base_concat_ds.datasets[0].raw.times)
    assert n_times == 2500

    expected_kwargs = [('crop', {'tmax': 10, 'include_tmax': False})]
    for ds in base_concat_ds.datasets:
        assert ds.raw_preproc_kwargs == expected_kwargs
示例#9
0
def test_set_raw_preproc_kwargs(base_concat_ds):
    """`_set_preproc_kwargs` stores the crop settings on a raw dataset."""
    first_ds = base_concat_ds.datasets[0]
    crop_step = Preprocessor('crop', tmax=10, include_tmax=False)
    _set_preproc_kwargs(first_ds, [crop_step])

    assert hasattr(first_ds, 'raw_preproc_kwargs')
    expected_kwargs = [('crop', {'tmax': 10, 'include_tmax': False})]
    assert first_ds.raw_preproc_kwargs == expected_kwargs
示例#10
0
def test_set_window_preproc_kwargs(windows_concat_ds):
    """`_set_preproc_kwargs` stores the crop settings on a windows dataset."""
    first_ds = windows_concat_ds.datasets[0]
    crop_step = Preprocessor('crop', tmax=10, include_tmax=False)
    _set_preproc_kwargs(first_ds, [crop_step])

    assert hasattr(first_ds, 'window_preproc_kwargs')
    expected_kwargs = [('crop', {'tmax': 10, 'include_tmax': False})]
    assert first_ds.window_preproc_kwargs == expected_kwargs
示例#11
0
def test_preprocess_windows_str(windows_concat_ds):
    """Crop windows via a string-named method; check length and kwargs."""
    crop_step = Preprocessor('crop', tmin=0, tmax=0.1, include_tmax=False)
    preprocess(windows_concat_ds, [crop_step])

    first_window = windows_concat_ds[0][0]
    assert first_window.shape[1] == 25

    expected_kwargs = [
        ('crop', {'tmin': 0, 'tmax': 0.1, 'include_tmax': False})]
    for ds in windows_concat_ds.datasets:
        assert ds.window_preproc_kwargs == expected_kwargs
示例#12
0
def test_filterbank_order_channels_by_freq(base_concat_ds):
    """Check the channel order of a filterbank grouped by frequency band."""
    base_concat_ds = base_concat_ds.split([[0]])['0']
    pick_step = Preprocessor(
        'pick_channels', ch_names=sorted(['C4', 'Cz']), ordered=True)
    filterbank_step = Preprocessor(
        filterbank,
        frequency_bands=[(0, 4), (4, 8), (8, 13)],
        drop_original_signals=False,
        order_by_frequency_band=True,
        apply_on_array=False,
    )
    preprocess(base_concat_ds, [pick_step, filterbank_step])

    # Original channels come first, then one pair of channels per band.
    expected_names = [
        'C4', 'Cz',
        'C4_0-4', 'Cz_0-4',
        'C4_4-8', 'Cz_4-8',
        'C4_8-13', 'Cz_8-13',
    ]
    np.testing.assert_array_equal(
        base_concat_ds.datasets[0].raw.ch_names, expected_names)
示例#13
0
def test_preprocess_windows_callable_on_object(windows_concat_ds):
    """Apply a callable to the windows object itself and check the scaling."""
    factor = 10
    object_step = Preprocessor(
        modify_windows_object, apply_on_array=False, factor=factor)
    # Reference window taken before the dataset is preprocessed.
    window_before = windows_concat_ds[0][0]
    preprocess(windows_concat_ds, [object_step])
    window_after = windows_concat_ds[0][0]
    np.testing.assert_allclose(
        window_after, window_before * factor, rtol=1e-4, atol=1e-4)
示例#14
0
def test_zscore_windows(windows_concat_ds):
    """Z-score windowed data; check mean, standard deviation and kwargs."""
    pipeline = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(zscore),
    ]
    preprocess(windows_concat_ds, pipeline)

    tolerances = dict(rtol=1e-4, atol=1e-4)
    for ds in windows_concat_ds.datasets:
        data = ds.windows.get_data()
        # Statistics are taken along the last axis, so the expected arrays
        # have the data's shape without that axis.
        reduced_shape = data.shape[:-1]
        np.testing.assert_allclose(
            data.mean(axis=-1), np.zeros(reduced_shape), **tolerances)
        np.testing.assert_allclose(
            data.std(axis=-1), np.ones(reduced_shape), **tolerances)

    expected_kwargs = [
        ('pick_types', {'eeg': True, 'meg': False, 'stim': False}),
        ('zscore', {}),
    ]
    for ds in windows_concat_ds.datasets:
        assert ds.window_preproc_kwargs == expected_kwargs
示例#15
0
def test_zscore_continuous(base_concat_ds):
    """Z-score continuous data channel-wise; check mean and std per channel."""
    pipeline = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(zscore, channel_wise=True),
    ]
    preprocess(base_concat_ds, pipeline)

    tolerances = dict(rtol=1e-4, atol=1e-4)
    for ds in base_concat_ds.datasets:
        data = ds.raw.get_data()
        # Statistics are taken along the last axis, so the expected arrays
        # have the data's shape without that axis.
        reduced_shape = data.shape[:-1]
        np.testing.assert_allclose(
            data.mean(axis=-1), np.zeros(reduced_shape), **tolerances)
        np.testing.assert_allclose(
            data.std(axis=-1), np.ones(reduced_shape), **tolerances)
示例#16
0
def test_windows_from_events_cropped(lazy_loadable_dataset):
    """Compare event-based windows from uncropped and cropped recordings.

    Cropping raw data changes the `first_samp` attribute of the Raw object,
    so the windowers have to take it into account. The test also checks that
    a stop offset reaching beyond the recording raises a `ValueError`.
    """
    tmin, tmax = 100, 120

    # Reference dataset: only the annotations are cropped.
    ds = copy.deepcopy(lazy_loadable_dataset)
    ds.datasets[0].raw.annotations.crop(tmin, tmax)

    # Second dataset: the recording itself is cropped through `preprocess`.
    crop_ds = copy.deepcopy(lazy_loadable_dataset)
    preprocess(crop_ds, [Preprocessor('crop', tmin=tmin, tmax=tmax)])

    shared_kwargs = dict(
        trial_start_offset_samples=0,
        window_size_samples=100,
        window_stride_samples=100,
        drop_last_window=False,
    )

    windows1 = create_windows_from_events(
        concat_ds=ds, trial_stop_offset_samples=0, **shared_kwargs)
    windows2 = create_windows_from_events(
        concat_ds=crop_ds, trial_stop_offset_samples=0, **shared_kwargs)
    assert (windows1[0][0] == windows2[0][0]).all()

    # Events that fall outside of the recording must trigger an error.
    for dataset, stop_offset in ((ds, 10000), (crop_ds, 2001)):
        with pytest.raises(ValueError,
                           match='"trial_stop_offset_samples" too large'):
            create_windows_from_events(
                concat_ds=dataset,
                trial_stop_offset_samples=stop_offset,
                **shared_kwargs)
示例#17
0
def test_preprocess_overwrite(base_concat_ds, tmp_path, overwrite):
    """Check `preprocess` against a save directory that already has files.

    With `overwrite` set, the serialized data must be the cropped version;
    otherwise a `FileExistsError` is expected.
    """
    crop_steps = [Preprocessor('crop', tmax=10, include_tmax=False)]

    # Fill the target directory with one saved sub-dataset per recording.
    save_dir = str(tmp_path)
    for idx, single_ds in enumerate(base_concat_ds.datasets):
        subdir = os.path.join(save_dir, str(idx))
        wrapped = BaseConcatDataset([single_ds])
        os.makedirs(subdir)
        wrapped.save(subdir, overwrite=True)

    if not overwrite:
        with pytest.raises(FileExistsError):
            preprocess(base_concat_ds, crop_steps, save_dir,
                       overwrite=False)
        return

    preprocess(base_concat_ds, crop_steps, save_dir, overwrite=True)
    # Reload from disk to confirm the serialized data was preprocessed.
    reloaded = load_concat_dataset(save_dir, True)
    for ds in reloaded.datasets:
        assert len(ds.raw.times) == 2500
示例#18
0
def custom_crop(raw, tmin=0.0, tmax=None, include_tmax=True):
    """Crop ``raw`` to ``tmin``..``tmax``, clamping ``tmax`` to the recording.

    By default mne fails if ``tmax`` is bigger than the recording duration,
    so ``tmax`` is limited to the time of the last sample. The resulting crop
    can therefore be shorter than requested when the recording is shorter
    than ``tmax``.

    Parameters
    ----------
    raw :
        Object exposing ``n_times``, ``info['sfreq']`` and
        ``crop(tmin=..., tmax=..., include_tmax=...)``.
    tmin : float
        Start of the crop, forwarded unchanged to ``raw.crop``.
    tmax : float | None
        Requested end of the crop. ``None`` means "up to the last sample".
    include_tmax : bool
        Forwarded unchanged to ``raw.crop``.
    """
    last_sample_time = (raw.n_times - 1) / raw.info['sfreq']
    if tmax is None:
        # min() cannot compare a float with None, so the documented default
        # is resolved explicitly to the end of the recording.
        tmax = last_sample_time
    else:
        tmax = min(last_sample_time, tmax)
    raw.crop(tmin=tmin, tmax=tmax, include_tmax=include_tmax)


# Crop boundaries: from minute 1 to minute 6, expressed as 1 * 60 and 6 * 60.
tmin, tmax = 1 * 60, 6 * 60
# Target sampling frequency handed to the resampling step.
sfreq = 100

preprocessors = [
    # Crop each recording; `custom_crop` works on the raw object itself.
    Preprocessor(
        custom_crop,
        tmin=tmin,
        tmax=tmax,
        include_tmax=False,
        apply_on_array=False,
    ),
    # Re-reference the EEG channels to their average.
    Preprocessor('set_eeg_reference', ref_channels='average', ch_type='eeg'),
    # Rename channels according to `ch_mapping`.
    Preprocessor(
        custom_rename_channels,
        mapping=ch_mapping,
        apply_on_array=False,
    ),
    # Keep only the channels in `short_ch_names`, in that order.
    Preprocessor('pick_channels', ch_names=short_ch_names, ordered=True),
    # Multiply the signals by 1e6 (presumably V -> uV; confirm).
    Preprocessor(lambda x: x * 1e6),
    # Resample to `sfreq`.
    Preprocessor('resample', sfreq=sfreq),
]

###############################################################################
# The preprocessing loop works as follows. For every recording, we apply the
# preprocessors as defined above. Then, we update the description of the
# recording, since we have altered the duration, the reference, and the
# sampling frequency.
示例#19
0
def test_preprocess_raw_str(base_concat_ds):
    """Crop raw data via a string-named method and check the new length."""
    crop_step = Preprocessor('crop', tmax=10, include_tmax=False)
    preprocess(base_concat_ds, [crop_step])
    n_times = len(base_concat_ds.datasets[0].raw.times)
    assert n_times == 2500
示例#20
0
def test_preprocess_windows_str(windows_concat_ds):
    """Crop windows via a string-named method and check the new length."""
    crop_step = Preprocessor('crop', tmin=0, tmax=0.1, include_tmax=False)
    preprocess(windows_concat_ds, [crop_step])
    first_window = windows_concat_ds[0][0]
    assert first_window.shape[1] == 25
示例#21
0
def test_method_not_available(base_concat_ds):
    """A preprocessor naming a non-existent method raises `AttributeError`."""
    unknown_step = Preprocessor('this_method_is_not_real')
    with pytest.raises(AttributeError):
        preprocess(base_concat_ds, [unknown_step])
示例#22
0
                         crop_wake_mins=30)

######################################################################
# Preprocessing
# ~~~~~~~~~~~~~
#
# Next, we preprocess the raw data. We convert the data to microvolts and apply
# a lowpass filter. Since the Sleep Physionet data is already sampled at 100 Hz
# we don't need to apply resampling.

from braindecode.preprocessing.preprocess import preprocess, Preprocessor, scale

# Upper cutoff of the lowpass filter applied below.
high_cut_hz = 30

preprocessors = [
    # Scale the signal array by 1e6 (the conversion to microvolts).
    Preprocessor(scale, factor=1e6, apply_on_array=True),
    # Lowpass filter: no lower cutoff, upper cutoff at `high_cut_hz`.
    Preprocessor(
        'filter',
        l_freq=None,
        h_freq=high_cut_hz,
        n_jobs=n_jobs,
    ),
]

# Apply the preprocessing steps to the dataset.
preprocess(dataset, preprocessors)

######################################################################
# Extracting windows
# ~~~~~~~~~~~~~~~~~~
#
# We extract 30-s windows to be used in both the pretext and downstream tasks.
# As RP (and SSL in general) don't require labelled data, the pretext task
# could be performed using unlabelled windows extracted with
# :func:`braindecode.preprocessing.create_fixed_length_windows`.
# Here however, purely for convenience, we directly extract labelled windows so
示例#23
0
# `dataset` has a pandas DataFrame with additional description of its
# internal datasets
dataset.description

##############################################################################
# We can iterate through the dataset, which yields one time point of a
# continuous signal x, and a target y (which can be None if targets are not
# defined for the entire continuous signal).
# Only the first (x, y) pair is inspected, hence the immediate `break`.
for x, y in dataset:
    print(x.shape, y)
    break

##############################################################################
# We can apply preprocessing transforms that are defined in mne and work
# in-place, such as resampling, bandpass filtering, or electrode selection.
preprocessors = [
    # Keep the EEG and stimulus channels, drop MEG channels.
    Preprocessor(
        'pick_types',
        eeg=True,
        meg=False,
        stim=True,
    ),
    # Resample to a sampling frequency of 100.
    Preprocessor(
        'resample',
        sfreq=100,
    ),
]
# Print the sampling frequency before and after preprocessing.
print(dataset.datasets[0].raw.info["sfreq"])
preprocess(dataset, preprocessors)
print(dataset.datasets[0].raw.info["sfreq"])

###############################################################################
# We can easily split the dataset based on a criterion applied to the
# description DataFrame:
# Split by the "session" criterion and report the length of each subset.
subsets = dataset.split("session")
print({name: len(part) for name, part in subsets.items()})

###############################################################################
# Next, we use a windower to extract windows from the dataset based on events:
windows_dataset = create_windows_from_events(dataset,
示例#24
0
def test_set_preproc_kwargs_wrong_type(base_concat_ds):
    """Passing the concat dataset itself to `_set_preproc_kwargs` must fail."""
    crop_step = Preprocessor('crop', tmax=10, include_tmax=False)
    with pytest.raises(TypeError):
        _set_preproc_kwargs(base_concat_ds, [crop_step])
示例#25
0
#

from braindecode.preprocessing.preprocess import (
    exponential_moving_standardize, preprocess, Preprocessor)

# Band edges of the bandpass filter.
low_cut_hz = 8.
high_cut_hz = 48.
# Settings for the exponential moving standardization.
factor_new = 1e-3
init_block_size = 1000
# Names of the first 62 channels of the first recording.
channels_list = dataset.datasets[0].raw.info['ch_names'][:62]
# Individual channels could additionally be removed here, e.g.
# channels_list.remove('T8') or channels_list.remove('T9').

preprocessors = [
    # Keep EEG sensors only.
    Preprocessor('pick_types', eeg=True, meg=False, stim=False),
    # Select 62 of 64 channels.
    Preprocessor('pick_channels', ch_names=channels_list),
    # Convert from V to uV.
    Preprocessor(lambda x: x * 1e6),
    # Bandpass filter.
    Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),
    # Exponential moving standardization.
    Preprocessor(
        exponential_moving_standardize,
        factor_new=factor_new,
        init_block_size=init_block_size,
    ),
]

# Apply the preprocessing steps to the dataset.
preprocess(dataset, preprocessors)


######################################################################
# Cut Compute Windows
# ~~~~~~~~~~~~~~~~~~~
#
from braindecode.datasets.moabb import MOABBDataset

# Load the BNCI2014001 recordings of a single subject through MOABB.
subject_id = 3
dataset = MOABBDataset(
    dataset_name="BNCI2014001",
    subject_ids=[subject_id],
)

from braindecode.preprocessing.preprocess import (
    exponential_moving_standardize, preprocess, Preprocessor)

# Band edges of the bandpass filter.
low_cut_hz = 4.
high_cut_hz = 38.
# Settings for the exponential moving standardization.
factor_new = 1e-3
init_block_size = 1000

preprocessors = [
    # Keep EEG sensors only.
    Preprocessor('pick_types', eeg=True, meg=False, stim=False),
    # Convert from V to uV.
    Preprocessor(lambda x: x * 1e6),
    # Bandpass filter.
    Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),
    # Exponential moving standardization.
    Preprocessor(
        exponential_moving_standardize,
        factor_new=factor_new,
        init_block_size=init_block_size,
    ),
]

# Apply the preprocessing steps to the dataset.
preprocess(dataset, preprocessors)


######################################################################
# Create model and compute windowing parameters
# ---------------------------------------------
#
# ~~~~~~~~~~~~~
#


######################################################################
# Next, we preprocess the raw data. We convert the data to microvolts and apply
# a lowpass filter. We omit the downsampling step of [1]_ as the Sleep
# Physionet data is already sampled at a lower 100 Hz.
#

from braindecode.preprocessing.preprocess import preprocess, Preprocessor

# Upper cutoff of the lowpass filter applied below.
high_cut_hz = 30

preprocessors = [
    # Scale the signals by 1e6 (the conversion to microvolts).
    Preprocessor(lambda x: x * 1e6),
    # Lowpass filter: no lower cutoff, upper cutoff at `high_cut_hz`.
    Preprocessor(
        'filter',
        l_freq=None,
        h_freq=high_cut_hz,
    ),
]

# Apply the preprocessing steps to the dataset.
preprocess(dataset, preprocessors)


######################################################################
# Extract windows
# ~~~~~~~~~~~~~~~
#


######################################################################
# We extract 30-s windows to be used in the classification task.