Пример #1
0
def test_preprocess_raw_str(base_concat_ds):
    """Apply a string-named ``'crop'`` preprocessor to continuous data.

    Checks that the first recording holds 2500 time points after cropping
    with ``tmax=10`` and ``include_tmax=False``, and that every dataset logged
    the applied step in ``raw_preproc_kwargs``.
    """
    crop_step = Preprocessor('crop', tmax=10, include_tmax=False)
    preprocess(base_concat_ds, [crop_step])

    first_raw = base_concat_ds.datasets[0].raw
    assert len(first_raw.times) == 2500

    expected_log = [('crop', {'tmax': 10, 'include_tmax': False})]
    for ds in base_concat_ds.datasets:
        assert ds.raw_preproc_kwargs == expected_log
Пример #2
0
def test_windows_fixed_length_cropped(lazy_loadable_dataset):
    """Check fixed-length windowing gives the same data on cropped raws.

    Cropping a Raw object changes its `first_samp` attribute, so the windower
    must account for it. The first window taken from the uncropped recording
    (using explicit sample offsets) is compared with the first window taken
    from a recording that was cropped by a preprocessor beforehand.
    """
    crop_start, crop_stop = 100, 120

    # Reference copy: the signal is left intact, only annotations are cropped.
    reference_ds = copy.deepcopy(lazy_loadable_dataset)
    reference_ds.datasets[0].raw.annotations.crop(crop_start, crop_stop)

    # Second copy: the raw itself is cropped through the preprocessing API.
    cropped_ds = copy.deepcopy(lazy_loadable_dataset)
    preprocess(
        cropped_ds, [Preprocessor('crop', tmin=crop_start, tmax=crop_stop)])

    # Convert the crop limits from seconds to samples for the reference copy.
    sfreq = reference_ds.datasets[0].raw.info['sfreq']
    start_samples = int(crop_start * sfreq)
    stop_samples = int(crop_stop * sfreq)

    shared_kwargs = dict(window_size_samples=100,
                         window_stride_samples=100,
                         drop_last_window=True)
    reference_windows = create_fixed_length_windows(
        concat_ds=reference_ds,
        start_offset_samples=start_samples,
        stop_offset_samples=stop_samples,
        **shared_kwargs)
    cropped_windows = create_fixed_length_windows(
        concat_ds=cropped_ds,
        start_offset_samples=0,
        stop_offset_samples=None,
        **shared_kwargs)

    assert (reference_windows[0][0] == cropped_windows[0][0]).all()
Пример #3
0
def test_filterbank(base_concat_ds):
    """Run ``filterbank`` on two picked channels and check the channel layout.

    With three frequency bands and ``drop_original_signals=False`` the test
    expects 8 channels: each original channel followed by its three band
    versions. It also checks the steps logged in ``raw_preproc_kwargs``.
    """
    base_concat_ds = base_concat_ds.split([[0]])['0']
    pick_step = Preprocessor(
        'pick_channels', ch_names=sorted(['C4', 'Cz']), ordered=True)
    filterbank_step = Preprocessor(
        filterbank, frequency_bands=[(0, 4), (4, 8), (8, 13)],
        drop_original_signals=False, apply_on_array=False)
    preprocess(base_concat_ds, [pick_step, filterbank_step])

    # Only the first item of the dataset is needed to inspect the shape.
    x, _ = next(iter(base_concat_ds))
    assert x.shape[0] == 8

    ch_names = base_concat_ds.datasets[0].raw.ch_names
    # Band channels carry the band as a suffix after the last underscore.
    band_suffixes = [name.split('_')[-1] for name in ch_names if '_' in name]
    assert len(np.unique(band_suffixes)) == 3
    np.testing.assert_array_equal(ch_names, [
        'C4', 'C4_0-4', 'C4_4-8', 'C4_8-13',
        'Cz', 'Cz_0-4', 'Cz_4-8', 'Cz_8-13',
    ])

    expected_log = [
        ('pick_channels', {'ch_names': ['C4', 'Cz'], 'ordered': True}),
        ('filterbank', {'frequency_bands': [(0, 4), (4, 8), (8, 13)],
                        'drop_original_signals': False}),
    ]
    for ds in base_concat_ds.datasets:
        assert ds.raw_preproc_kwargs == expected_log
Пример #4
0
def test_preprocessors_with_misc_channels():
    """Check that preprocessing leaves the `misc` target channels untouched.

    Builds a four-channel raw (two `eeg`, two `misc`) from random data, runs
    a channel pick and a scaling lambda, and asserts the last two rows of the
    resulting data still equal the generated targets exactly.
    """
    rng = np.random.RandomState(42)
    sfreq = 50
    n_times = 1000
    # Draw order matters for reproducibility: signal first, then targets.
    eeg_signal = rng.randn(2, n_times)
    misc_targets = rng.randn(2, n_times)

    info = mne.create_info(ch_names=['0', '1', 'target_0', 'target_1'],
                           sfreq=sfreq,
                           ch_types=['eeg', 'eeg', 'misc', 'misc'])
    raw = mne.io.RawArray(
        np.concatenate([eeg_signal, misc_targets]), info=info)
    description = pd.Series({'pathological': True, 'gender': 'M', 'age': 48})
    concat_ds = BaseConcatDataset(
        [BaseDataset(raw, description, target_name=None)])

    preprocess(concat_ds, [
        Preprocessor('pick_types', eeg=True, misc=True),
        Preprocessor(lambda x: x / 1e6),
    ])

    # The targets must come through unchanged. This only holds for
    # preprocessors relying on mne functions that do not modify `misc`
    # channels.
    processed = concat_ds.datasets[0].raw.get_data()
    np.testing.assert_array_equal(processed[-2:, :], misc_targets)
Пример #5
0
def test_no_raw_or_epochs():
    """`preprocess` must raise AssertionError for this stand-in dataset.

    The stand-in only exposes a ``datasets`` list of plain integers, and the
    preprocessors passed along are plain strings.
    """
    class EmptyDataset:
        def __init__(self):
            self.datasets = [1, 2, 3]

    stand_in = EmptyDataset()
    dummy_preprocessors = ["dummy", "dummy"]
    with pytest.raises(AssertionError):
        preprocess(stand_in, dummy_preprocessors)
Пример #6
0
def test_preprocess_windows_str(windows_concat_ds):
    """Apply a string-named ``'crop'`` preprocessor to windowed data.

    Checks that the first window has 25 entries along axis 1 after cropping
    and that every dataset logged the step in ``window_preproc_kwargs``.
    """
    crop_step = Preprocessor('crop', tmin=0, tmax=0.1, include_tmax=False)
    preprocess(windows_concat_ds, [crop_step])

    first_window = windows_concat_ds[0][0]
    assert first_window.shape[1] == 25

    expected_log = [
        ('crop', {'tmin': 0, 'tmax': 0.1, 'include_tmax': False}),
    ]
    for ds in windows_concat_ds.datasets:
        assert ds.window_preproc_kwargs == expected_log
Пример #7
0
def test_preprocess_windows_callable_on_object(windows_concat_ds):
    """Apply a callable with ``apply_on_array=False`` to windowed data.

    The first window after preprocessing is expected to equal the first
    window before preprocessing multiplied by ``factor``.
    """
    factor = 10
    step = Preprocessor(
        modify_windows_object, apply_on_array=False, factor=factor)
    window_before = windows_concat_ds[0][0]

    preprocess(windows_concat_ds, [step])

    window_after = windows_concat_ds[0][0]
    np.testing.assert_allclose(
        window_after, window_before * factor, rtol=1e-4, atol=1e-4)
Пример #8
0
def test_scale_continuous(base_concat_ds):
    """Scale continuous data by ``factor`` after picking EEG channels.

    The first time point after preprocessing is compared with the first 22
    rows of the same time point before preprocessing, times ``factor``.
    """
    factor = 1e6
    steps = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(scale, factor=factor),
    ]
    # Only the first 22 rows are kept: these are the EEG channels.
    timepoint_before = base_concat_ds[0][0][:22]

    preprocess(base_concat_ds, steps)

    timepoint_after = base_concat_ds[0][0]
    np.testing.assert_allclose(
        timepoint_after, timepoint_before * factor, rtol=1e-4, atol=1e-4)
Пример #9
0
def test_scale_windows(windows_concat_ds):
    """Scale windowed data by ``factor`` after picking EEG channels.

    The first window after preprocessing is compared with the first 22 rows
    of the same window before preprocessing, times ``factor``.
    """
    factor = 1e6
    steps = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(scale, factor=factor),
    ]
    # Only the first 22 rows are kept: these are the EEG channels.
    window_before = windows_concat_ds[0][0][:22]

    preprocess(windows_concat_ds, steps)

    window_after = windows_concat_ds[0][0]
    np.testing.assert_allclose(
        window_after, window_before * factor, rtol=1e-4, atol=1e-4)
Пример #10
0
def test_scale_windows(windows_concat_ds):
    """Scale windowed data with ``deprecated_scale`` and check the log.

    Besides comparing the scaled first window with the original one, the
    test checks that both steps were logged in ``window_preproc_kwargs``,
    with the scaling step recorded under the name ``'scale'``.
    """
    factor = 1e6
    steps = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(deprecated_scale, factor=factor),
    ]
    # Only the first 22 rows are kept: these are the EEG channels.
    window_before = windows_concat_ds[0][0][:22]

    preprocess(windows_concat_ds, steps)

    window_after = windows_concat_ds[0][0]
    np.testing.assert_allclose(
        window_after, window_before * factor, rtol=1e-4, atol=1e-4)

    expected_log = [
        ('pick_types', {'eeg': True, 'meg': False, 'stim': False}),
        ('scale', {'factor': 1e6}),
    ]
    for ds in windows_concat_ds.datasets:
        assert ds.window_preproc_kwargs == expected_log
Пример #11
0
def test_deprecated_preprocs(base_concat_ds):
    """Deprecated preprocessor classes warn on creation but still work.

    Constructing ``MNEPreproc`` and ``NumpyPreproc`` must each emit a
    ``FutureWarning`` with the expected message. The resulting objects are
    then run through ``preprocess`` and the scaled output is checked.
    """
    mne_warning = (
        'Class MNEPreproc is deprecated; will be removed in 0.7.0. Use '
        'Preprocessor with `apply_on_array=False` instead.'
    )
    numpy_warning = (
        'NumpyPreproc is deprecated; will be removed in 0.7.0. Use '
        'Preprocessor with `apply_on_array=True` instead.'
    )
    factor = 1e6

    with pytest.warns(FutureWarning, match=mne_warning):
        mne_preproc = MNEPreproc('pick_types', eeg=True, meg=False, stim=False)
    with pytest.warns(FutureWarning, match=numpy_warning):
        np_preproc = NumpyPreproc(deprecated_scale, factor=factor)

    # Only the first 22 rows are kept: these are the EEG channels.
    timepoint_before = base_concat_ds[0][0][:22]
    preprocess(base_concat_ds, [mne_preproc, np_preproc])

    timepoint_after = base_concat_ds[0][0]
    np.testing.assert_allclose(
        timepoint_after, timepoint_before * factor, rtol=1e-4, atol=1e-4)
Пример #12
0
def test_preprocess_save_dir(base_concat_ds, windows_concat_ds, tmp_path,
                             kind, save, overwrite, n_jobs):
    """Check `preprocess` with and without a save directory.

    ``kind`` selects the raw (``'raw'``) or windowed (``'windows'``) fixture.
    When ``save`` is truthy the data is written under ``tmp_path``; in that
    case the returned datasets must not be preloaded and one subdirectory per
    dataset must exist. Otherwise the returned datasets must be preloaded.
    """
    expected_log = [
        ('crop', {'tmin': 0, 'tmax': 0.1, 'include_tmax': False})]
    crop_step = Preprocessor('crop', tmin=0, tmax=0.1, include_tmax=False)

    if save:
        save_dir = str(tmp_path)
    else:
        save_dir = None

    if kind == 'raw':
        concat_ds = base_concat_ds
        preproc_kwargs_name = 'raw_preproc_kwargs'
    elif kind == 'windows':
        concat_ds = windows_concat_ds
        preproc_kwargs_name = 'window_preproc_kwargs'

    concat_ds = preprocess(
        concat_ds, [crop_step], save_dir, overwrite=overwrite, n_jobs=n_jobs)

    for ds in concat_ds.datasets:
        # Each dataset must carry the log of the applied steps ...
        assert hasattr(ds, preproc_kwargs_name)
        assert getattr(ds, preproc_kwargs_name) == expected_log
        # ... and hold the cropped data.
        assert len(getattr(ds, kind).times) == 25
        if kind == 'raw':
            assert hasattr(ds, 'target_name')

    if save_dir is None:
        for ds in concat_ds.datasets:
            assert getattr(ds, kind).preload
        return

    for ds in concat_ds.datasets:
        assert not getattr(ds, kind).preload
    # One numbered subdirectory per dataset, and nothing else.
    expected_subdirs = {
        os.path.join(save_dir, str(i))
        for i, _ in enumerate(concat_ds.datasets)
    }
    assert set(glob(save_dir + '/*')) == expected_subdirs
Пример #13
0
def test_windows_from_events_cropped(lazy_loadable_dataset):
    """Check event-based windowing gives the same data on cropped raws.

    Cropping a Raw object changes its `first_samp` attribute, so the windower
    must account for it. The first window from the uncropped recording is
    compared with the first window from a recording cropped beforehand, and
    stop offsets reaching past the recording must raise a ValueError.
    """
    crop_start, crop_stop = 100, 120

    # Reference copy: the signal is left intact, only annotations are cropped.
    reference_ds = copy.deepcopy(lazy_loadable_dataset)
    reference_ds.datasets[0].raw.annotations.crop(crop_start, crop_stop)

    # Second copy: the raw itself is cropped through the preprocessing API.
    cropped_ds = copy.deepcopy(lazy_loadable_dataset)
    preprocess(
        cropped_ds, [Preprocessor('crop', tmin=crop_start, tmax=crop_stop)])

    shared_kwargs = dict(trial_start_offset_samples=0,
                         window_size_samples=100,
                         window_stride_samples=100,
                         drop_last_window=False)

    reference_windows = create_windows_from_events(
        concat_ds=reference_ds, trial_stop_offset_samples=0, **shared_kwargs)
    cropped_windows = create_windows_from_events(
        concat_ds=cropped_ds, trial_stop_offset_samples=0, **shared_kwargs)
    assert (reference_windows[0][0] == cropped_windows[0][0]).all()

    # Events that fall outside of the recording have to trigger an error.
    error_pattern = '"trial_stop_offset_samples" too large'
    for concat_ds, stop_offset in ((reference_ds, 10000), (cropped_ds, 2001)):
        with pytest.raises(ValueError, match=error_pattern):
            create_windows_from_events(
                concat_ds=concat_ds,
                trial_stop_offset_samples=stop_offset,
                **shared_kwargs)
Пример #14
0
def test_filterbank_order_channels_by_freq(base_concat_ds):
    """Check the channel order of ``filterbank`` when ordering by band.

    With ``order_by_frequency_band=True`` the expected order is the original
    channels first, followed by both channels for each band in turn.
    """
    base_concat_ds = base_concat_ds.split([[0]])['0']
    pick_step = Preprocessor(
        'pick_channels', ch_names=sorted(['C4', 'Cz']), ordered=True)
    filterbank_step = Preprocessor(
        filterbank,
        frequency_bands=[(0, 4), (4, 8), (8, 13)],
        drop_original_signals=False,
        order_by_frequency_band=True,
        apply_on_array=False)

    preprocess(base_concat_ds, [pick_step, filterbank_step])

    expected_order = [
        'C4', 'Cz',
        'C4_0-4', 'Cz_0-4',
        'C4_4-8', 'Cz_4-8',
        'C4_8-13', 'Cz_8-13',
    ]
    np.testing.assert_array_equal(
        base_concat_ds.datasets[0].raw.ch_names, expected_order)
Пример #15
0
def test_zscore_windows(windows_concat_ds):
    """Z-score windowed data and check mean, std and the logged steps.

    After preprocessing, the mean along the last axis must be close to zero
    and the standard deviation along the last axis close to one for every
    dataset. Both steps must be logged in ``window_preproc_kwargs``.
    """
    steps = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(zscore),
    ]
    preprocess(windows_concat_ds, steps)

    for ds in windows_concat_ds.datasets:
        data = ds.windows.get_data()
        # Statistics are taken over the last axis, so it drops from the shape.
        stats_shape = data.shape[:-1]
        np.testing.assert_allclose(
            data.mean(axis=-1), np.zeros(stats_shape), rtol=1e-4, atol=1e-4)
        np.testing.assert_allclose(
            data.std(axis=-1), np.ones(stats_shape), rtol=1e-4, atol=1e-4)

    expected_log = [
        ('pick_types', {'eeg': True, 'meg': False, 'stim': False}),
        ('zscore', {}),
    ]
    for ds in windows_concat_ds.datasets:
        assert ds.window_preproc_kwargs == expected_log
Пример #16
0
def test_zscore_continuous(base_concat_ds):
    """Z-score continuous data with ``channel_wise=True``.

    After preprocessing, the mean along the last axis must be close to zero
    and the standard deviation along the last axis close to one for every
    dataset.
    """
    steps = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(zscore, channel_wise=True),
    ]
    preprocess(base_concat_ds, steps)

    tolerances = dict(rtol=1e-4, atol=1e-4)
    for ds in base_concat_ds.datasets:
        data = ds.raw.get_data()
        # Statistics are taken over the last axis, so it drops from the shape.
        stats_shape = data.shape[:-1]
        np.testing.assert_allclose(
            data.mean(axis=-1), np.zeros(stats_shape), **tolerances)
        np.testing.assert_allclose(
            data.std(axis=-1), np.ones(stats_shape), **tolerances)
Пример #17
0
def test_preprocess_overwrite(base_concat_ds, tmp_path, overwrite):
    """Check how `preprocess` treats a save directory that already has data.

    Each dataset is first saved into its own numbered subdirectory. With
    ``overwrite`` falsy, `preprocess` must raise ``FileExistsError``; with
    ``overwrite`` truthy, the data reloaded from disk must be the cropped one.
    """
    crop_step = Preprocessor('crop', tmax=10, include_tmax=False)

    # Populate the target directory so that files already exist there.
    save_dir = str(tmp_path)
    for idx, ds in enumerate(base_concat_ds.datasets):
        subdir = os.path.join(save_dir, str(idx))
        single_ds = BaseConcatDataset([ds])
        os.makedirs(subdir)
        single_ds.save(subdir, overwrite=True)

    if not overwrite:
        with pytest.raises(FileExistsError):
            preprocess(base_concat_ds, [crop_step], save_dir,
                       overwrite=False)
        return

    preprocess(base_concat_ds, [crop_step], save_dir, overwrite=True)
    # The serialized data has to be the preprocessed version.
    reloaded = load_concat_dataset(save_dir, True)
    for ds in reloaded.datasets:
        assert len(ds.raw.times) == 2500
Пример #18
0
# Afterwards, we split the continuous signals into compute windows. We store
# each recording in a unique subdirectory named after its recording id. To
# save memory, after windowing and storing, we delete the raw dataset and the
# windows dataset, respectively.
window_size_samples = 1000  # length of one compute window, in samples
window_stride_samples = 1000  # equal to the window size
create_compute_windows = True  # toggles the windowing step in the loop below

out_i = 0
errors = []
OUT_PATH = './tuh_sample/'
# Split into one subset per recording so that each one can be preprocessed
# on its own.
tuh_splits = tuh.split([[i] for i in range(len(tuh.datasets))])
for rec_i, tuh_subset in tuh_splits.items():
    # TODO: implement preprocess for BaseDatasets? That would remove the need
    # for the split above.
    preprocess(tuh_subset, preprocessors)

    # Update the description of the recording(s), one entry per dataset.
    # NOTE(review): these are attribute assignments on `description`; confirm
    # that they actually persist as new description columns.
    tuh_subset.description.sfreq = len(tuh_subset.datasets) * [sfreq]
    tuh_subset.description.reference = len(tuh_subset.datasets) * ['ar']
    tuh_subset.description.n_samples = [len(d) for d in tuh_subset.datasets]

    if create_compute_windows:
        # Generate the compute windows over the whole recording; the last,
        # possibly differing window is kept (drop_last_window=False).
        tuh_windows = create_fixed_length_windows(
            tuh_subset,
            start_offset_samples=0,
            stop_offset_samples=None,
            window_size_samples=window_size_samples,
            window_stride_samples=window_stride_samples,
            drop_last_window=False)
Пример #19
0
def test_preprocess_raw_str(base_concat_ds):
    """Apply a string-named ``'crop'`` preprocessor to continuous data.

    The first recording must hold 2500 time points after cropping with
    ``tmax=10`` and ``include_tmax=False``.
    """
    crop_step = Preprocessor('crop', tmax=10, include_tmax=False)
    preprocess(base_concat_ds, [crop_step])

    first_raw = base_concat_ds.datasets[0].raw
    assert len(first_raw.times) == 2500
Пример #20
0
def test_preprocess_windows_str(windows_concat_ds):
    """Apply a string-named ``'crop'`` preprocessor to windowed data.

    The first window must have 25 entries along axis 1 after cropping.
    """
    crop_step = Preprocessor('crop', tmin=0, tmax=0.1, include_tmax=False)
    preprocess(windows_concat_ds, [crop_step])

    first_window = windows_concat_ds[0][0]
    assert first_window.shape[1] == 25
Пример #21
0
def test_method_not_available(base_concat_ds):
    """A string preprocessor naming an unknown method must fail.

    `preprocess` is expected to raise ``AttributeError`` in that case.
    """
    unknown_step = Preprocessor('this_method_is_not_real')
    with pytest.raises(AttributeError):
        preprocess(base_concat_ds, [unknown_step])
Пример #22
0
#
# Next, we preprocess the raw data. We convert the data to microvolts and apply
# a lowpass filter. Since the Sleep Physionet data is already sampled at 100 Hz
# we don't need to apply resampling.

from braindecode.preprocessing.preprocess import preprocess, Preprocessor, scale

high_cut_hz = 30  # upper cutoff passed as `h_freq` to the filter below

preprocessors = [
    # Multiply the signal array by 1e6 (conversion to microvolts).
    Preprocessor(scale, factor=1e6, apply_on_array=True),
    # Low-pass only: no lower cutoff is set (`l_freq=None`).
    Preprocessor('filter', l_freq=None, h_freq=high_cut_hz, n_jobs=n_jobs)
]

# Transform the data
preprocess(dataset, preprocessors)

######################################################################
# Extracting windows
# ~~~~~~~~~~~~~~~~~~
#
# We extract 30-s windows to be used in both the pretext and downstream tasks.
# As RP (and SSL in general) doesn't require labelled data, the pretext task
# could be performed using unlabelled windows extracted with
# :func:`braindecode.preprocessing.create_fixed_length_windows`.
# Here however, purely for convenience, we directly extract labelled windows so
# that we can reuse them in the sleep staging downstream task later.

from braindecode.preprocessing.windowers import create_windows_from_events

window_size_s = 30  # window length in seconds
Пример #23
0
# We can iterate through ds, which yields one time point of a continuous
# signal x and a target y (which can be None if targets are not defined for
# the entire continuous signal).
for x, y in ds:
    print(x.shape, y)
    break

##############################################################################
# We can apply preprocessing transforms that are defined in mne and work
# in-place, such as resampling, bandpass filtering, or electrode selection.
preprocessors = [
    Preprocessor('pick_types', eeg=True, meg=False, stim=True),
    Preprocessor('resample', sfreq=100)
]
# Print the sampling frequency before and after preprocessing.
print(ds.datasets[0].raw.info["sfreq"])
preprocess(ds, preprocessors)
print(ds.datasets[0].raw.info["sfreq"])

###############################################################################
# We can easily split ds based on a criterion applied to the description
# DataFrame:
subsets = ds.split("session")
print({subset_name: len(subset) for subset_name, subset in subsets.items()})

###############################################################################
# Next, we use a windower to extract windows from the dataset based on events:
windows_ds = create_windows_from_events(
    ds, trial_start_offset_samples=0, trial_stop_offset_samples=100,
    window_size_samples=400, window_stride_samples=100,
    drop_last_window=False)
Пример #24
0
def test_not_list():
    """`preprocess` must raise AssertionError when given a dict."""
    not_a_list = {'test': 1}
    with pytest.raises(AssertionError):
        preprocess(None, not_a_list)
Пример #25
0
# We can iterate through dataset, which yields one time point of a continuous
# signal x and a target y (which can be None if targets are not defined for
# the entire continuous signal).
for x, y in dataset:
    print(x.shape, y)
    break

##############################################################################
# We can apply preprocessing transforms that are defined in mne and work
# in-place, such as resampling, bandpass filtering, or electrode selection.
preprocessors = [
    Preprocessor('pick_types', eeg=True, meg=False, stim=True),
    Preprocessor('resample', sfreq=100)
]
# Print the sampling frequency before and after preprocessing.
print(dataset.datasets[0].raw.info["sfreq"])
preprocess(dataset, preprocessors)
print(dataset.datasets[0].raw.info["sfreq"])

###############################################################################
# We can easily split dataset based on a criterion applied to the description
# DataFrame:
subsets = dataset.split("session")
print({subset_name: len(subset) for subset_name, subset in subsets.items()})

###############################################################################
# Next, we use a windower to extract windows from the dataset based on events:
windows_dataset = create_windows_from_events(dataset,
                                             trial_start_offset_samples=0,
                                             trial_stop_offset_samples=100,
                                             window_size_samples=400,
                                             window_stride_samples=100,
low_cut_hz = 4.  # low cut frequency for filtering
high_cut_hz = 38.  # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000

# The steps are listed in the order: channel selection, unit conversion,
# filtering, standardization.
preprocessors = [
    Preprocessor('pick_types', eeg=True, meg=False, stim=False),  # Keep EEG sensors
    Preprocessor(lambda x: x * 1e6),  # Convert from V to uV
    Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),  # Bandpass filter
    Preprocessor(exponential_moving_standardize,  # Exponential moving standardization
                 factor_new=factor_new, init_block_size=init_block_size)
]

# Transform the data
preprocess(dataset, preprocessors)


######################################################################
# Create model and compute windowing parameters
# ---------------------------------------------
#


######################################################################
# In contrast to trialwise decoding, we first have to create the model
# before we can cut the dataset into windows. This is because we need to
# know the receptive field of the network to know how large the window
# stride should be.
#
# Next, we preprocess the raw data. We convert the data to microvolts and apply
# a lowpass filter. We omit the downsampling step of [1]_ as the Sleep
# Physionet data is already sampled at a lower 100 Hz.
#

from braindecode.preprocessing.preprocess import preprocess, Preprocessor

high_cut_hz = 30  # upper cutoff passed as `h_freq` to the filter below

preprocessors = [
    # Multiply the signal by 1e6 (conversion to microvolts).
    Preprocessor(lambda x: x * 1e6),
    # Low-pass only: no lower cutoff is set (`l_freq=None`).
    Preprocessor('filter', l_freq=None, h_freq=high_cut_hz)
]

# Transform the data
preprocess(dataset, preprocessors)


######################################################################
# Extract windows
# ~~~~~~~~~~~~~~~
#


######################################################################
# We extract 30-s windows to be used in the classification task.

from braindecode.preprocessing import create_windows_from_events


mapping = {  # We merge stages 3 and 4 following AASM standards.
Пример #28
0
from braindecode.datasets.moabb import MOABBDataset
from braindecode.preprocessing.preprocess import preprocess, MNEPreproc
from braindecode.datautil.serialization import load_concat_dataset
from braindecode.preprocessing.windowers import create_windows_from_events

###############################################################################
# First, we load a dataset using MOABB.
ds = MOABBDataset(
    dataset_name='BNCI2014001',
    subject_ids=[1],
)

###############################################################################
# We can apply preprocessing steps to the dataset. It is also possible to skip
# this step and not apply any preprocessing.
# NOTE(review): `MNEPreproc` is deprecated according to its own warning
# message, which recommends `Preprocessor` with `apply_on_array=False`.
preprocess(concat_ds=ds, preprocessors=[MNEPreproc(fn='resample', sfreq=10)])

###############################################################################
# We save the dataset to an existing directory. This will create a '.fif' file
# for every dataset in the concat dataset. Additionally, it will create two
# JSON files, the first holding the description of the dataset, the second
# holding the name of the target. If you want to store to the same directory
# several times, for example when trying different preprocessing, you can
# choose to overwrite the existing files.
ds.save(
    path='./',
    overwrite=False,
)

##############################################################################
# We load the saved dataset from a directory. Signals can be preloaded in