Exemple #1
0
def test_filterbank(base_concat_ds):
    """Filterbank keeps the original channels and appends band-passed copies.

    With ``drop_original_signals=False`` every picked channel is kept and is
    followed by one copy per frequency band, named ``<channel>_<lo>-<hi>``.
    """
    single_ds = base_concat_ds.split([[0]])['0']
    preprocess(single_ds, [
        Preprocessor('pick_channels',
                     ch_names=sorted(['C4', 'Cz']),
                     ordered=True),
        Preprocessor(filterbank,
                     frequency_bands=[(0, 4), (4, 8), (8, 13)],
                     drop_original_signals=False,
                     apply_on_array=False),
    ])

    # 2 original channels + 2 channels x 3 bands.
    first_x, _ = single_ds[0]
    assert first_x.shape[0] == 8

    ch_names = single_ds.datasets[0].raw.ch_names
    band_labels = {name.split('_')[-1] for name in ch_names if '_' in name}
    assert len(band_labels) == 3
    np.testing.assert_array_equal(ch_names, [
        'C4', 'C4_0-4', 'C4_4-8', 'C4_8-13',
        'Cz', 'Cz_0-4', 'Cz_4-8', 'Cz_8-13',
    ])
Exemple #2
0
def bcic4_2a(subject, low_hz=None, high_hz=None, paradigm=None, phase=False):
    """Load BCI Competition IV 2a trials (MOABB ``BNCI2014001``) as arrays.

    For every requested subject the recordings are preprocessed (EEG channels
    only, volt -> microvolt, filter, exponential moving standardization), cut
    into trials starting 0.5 s before each event, optionally restricted to a
    single session, and the trials of all subjects are stacked.

    Parameters
    ----------
    subject : int | iterable of int
        Subject id or ids to load.
    low_hz, high_hz : float | None
        Cut-off frequencies forwarded to the MNE ``filter`` method as
        ``l_freq`` / ``h_freq``.
    paradigm : str | None
        When ``'session'``, keep only one session, selected by ``phase``.
    phase : str | bool
        Only used when ``paradigm == 'session'``: ``"train"`` keeps
        ``session_T``, any other value keeps ``session_E``.

    Returns
    -------
    X : numpy.ndarray
        Trial data, one entry per trial.
    y : numpy.ndarray
        Trial labels, aligned with ``X``.
    """
    X = []
    y = []

    if isinstance(subject, int):
        subject = [subject]

    # print_off()/print_on() presumably toggle console output. The
    # ``finally`` guarantees print_on() runs even if loading or
    # preprocessing raises; previously an exception left it switched off.
    try:
        for subject_id in subject:
            # Load data
            print_off()
            dataset = MOABBDataset(dataset_name="BNCI2014001",
                                   subject_ids=[subject_id])

            # Preprocess
            factor_new = 1e-3
            init_block_size = 1000
            preprocessors = [
                # keep only EEG sensors
                MNEPreproc(fn='pick_types', eeg=True, meg=False, stim=False),
                # convert from volt to microvolt
                NumpyPreproc(fn=lambda x: x * 1e+06),
                # bandpass filter
                MNEPreproc(fn='filter', l_freq=low_hz, h_freq=high_hz),
                # exponential moving standardization
                NumpyPreproc(fn=exponential_moving_standardize,
                             factor_new=factor_new,
                             init_block_size=init_block_size)
            ]
            preprocess(dataset, preprocessors)

            # All recordings must share one sampling frequency so that a
            # single offset in samples is valid for each of them.
            sfreq = dataset.datasets[0].raw.info['sfreq']
            assert all(ds.raw.info['sfreq'] == sfreq
                       for ds in dataset.datasets)

            # Trials start 0.5 s before the event.
            trial_start_offset_seconds = -0.5
            trial_start_offset_samples = int(trial_start_offset_seconds *
                                             sfreq)

            windows_dataset = create_windows_from_events(
                dataset,
                trial_start_offset_samples=trial_start_offset_samples,
                trial_stop_offset_samples=0,
                preload=True
            )

            # Optionally keep a single session.
            if paradigm == 'session':
                session = 'session_T' if phase == "train" else 'session_E'
                windows_dataset = windows_dataset.split('session')[session]

            # Merge subject
            for trial in windows_dataset:
                X.append(trial[0])
                y.append(trial[1])
    finally:
        print_on()

    return np.array(X), np.array(y)
Exemple #3
0
def test_filterbank(base_concat_ds):
    """Filterbank (legacy MNEPreproc API) appends band-passed channel copies.

    The original channels are kept (``drop_original_signals=False``) and each
    is followed by one copy per band, named ``<channel>_<lo>-<hi>``.
    """
    single_ds = base_concat_ds.split([[0]])["0"]
    preprocess(single_ds, [
        MNEPreproc('pick_channels',
                   ch_names=sorted(["C4", "Cz"]),
                   ordered=True),
        MNEPreproc(filterbank,
                   frequency_bands=[(0, 4), (4, 8), (8, 13)],
                   drop_original_signals=False),
    ])

    # 2 original channels + 2 channels x 3 bands.
    first_x, _ = single_ds[0]
    assert first_x.shape[0] == 8

    ch_names = single_ds.datasets[0].raw.ch_names
    band_labels = {name.split("_")[-1] for name in ch_names if '_' in name}
    assert len(band_labels) == 3
    np.testing.assert_array_equal(ch_names, [
        "C4", "C4_0-4", "C4_4-8", "C4_8-13",
        "Cz", "Cz_0-4", "Cz_4-8", "Cz_8-13",
    ])
Exemple #4
0
def test_windows_fixed_length_cropped(lazy_loadable_dataset):
    """Test fixed length windowing on cropped data.

    Cropping raw data changes the `first_samp` attribute of the Raw object, and
    so it is important to test this is taken into account by the windowers.
    Windowing the uncropped recording with explicit start/stop offsets must
    give the same first window as windowing the cropped recording.
    """
    tmin, tmax = 100, 120

    ds = copy.deepcopy(lazy_loadable_dataset)
    ds.datasets[0].raw.annotations.crop(tmin, tmax)

    crop_ds = copy.deepcopy(lazy_loadable_dataset)
    crop_transform = Preprocessor('crop', tmin=tmin, tmax=tmax)
    preprocess(crop_ds, [crop_transform])

    # Extract windows
    sfreq = ds.datasets[0].raw.info['sfreq']
    tmin_samples, tmax_samples = int(tmin * sfreq), int(tmax * sfreq)

    windows1 = create_fixed_length_windows(
        concat_ds=ds, start_offset_samples=tmin_samples,
        stop_offset_samples=tmax_samples, window_size_samples=100,
        window_stride_samples=100, drop_last_window=True)
    windows2 = create_fixed_length_windows(
        concat_ds=crop_ds, start_offset_samples=0,
        stop_offset_samples=None, window_size_samples=100,
        window_stride_samples=100, drop_last_window=True)
    # assert_array_equal also verifies the shapes match (``==`` + ``.all()``
    # would silently broadcast) and reports the differing entries.
    np.testing.assert_array_equal(windows1[0][0], windows2[0][0])
Exemple #5
0
def test_no_raw_or_epochs():
    """``preprocess`` raises AssertionError when the entries are plain ints."""
    class EmptyDataset:
        def __init__(self):
            self.datasets = [1, 2, 3]

    with pytest.raises(AssertionError):
        preprocess(EmptyDataset(), ["dummy", "dummy"])
def bandpass_window_BaseConcat(dataset,
                               bandpass_range=(4, 38),
                               window_start_offset=1.0,
                               window_end_offset=-0.5):
    '''
    Preprocess a copy of ``dataset`` and cut it into event-locked windows.

    ``dataset`` is deep-copied first, so the caller's object is not modified.
    The copy keeps only EEG channels, is scaled from volt to microvolt,
    filtered to ``bandpass_range`` and exponentially moving-standardized
    before windowing.

    :param dataset: concat dataset whose sub-datasets expose ``raw``.
    :param bandpass_range: ``(low, high)`` cut-off frequencies in Hz, passed
        to the MNE ``filter`` method as ``l_freq`` / ``h_freq``.
    :param window_start_offset: offset in seconds applied to each trial start.
    :param window_end_offset: offset in seconds applied to each trial stop.
    :return: the windowed dataset returned by ``create_windows_from_events``.
    '''
    low_cut_hz, high_cut_hz = bandpass_range[0], bandpass_range[1]

    # Parameters of the exponential moving standardization.
    factor_new = 1e-3
    init_block_size = 1000

    pick_eeg = MNEPreproc(fn='pick_types', eeg=True, meg=False, stim=False)
    to_microvolt = NumpyPreproc(fn=lambda x: x * 1e6)
    bandpass = MNEPreproc(fn='filter', l_freq=low_cut_hz, h_freq=high_cut_hz)
    standardize = NumpyPreproc(fn=exponential_moving_standardize,
                               factor_new=factor_new,
                               init_block_size=init_block_size)

    ds_copy = copy.deepcopy(dataset)
    preprocess(ds_copy, [pick_eeg, to_microvolt, bandpass, standardize])

    # One sampling frequency for all recordings, so the offsets converted
    # to samples below are valid everywhere.
    sfreq = ds_copy.datasets[0].raw.info['sfreq']
    assert all(sub_ds.raw.info['sfreq'] == sfreq
               for sub_ds in ds_copy.datasets)

    return create_windows_from_events(
        ds_copy,
        trial_start_offset_samples=int(window_start_offset * sfreq),
        trial_stop_offset_samples=int(window_end_offset * sfreq),
        preload=True,
    )
def test_scale_windows(windows_concat_ds):
    """Scaling windowed data multiplies every EEG sample by ``factor``."""
    factor = 1e6
    preprocessors = [
        MNEPreproc('pick_types', eeg=True, meg=False, stim=False),
        MNEPreproc(scale, factor=factor)
    ]
    # pick_types drops the non-EEG channels, so only the first 22 (EEG)
    # channels of the untouched window can be compared with the result.
    raw_window = windows_concat_ds[0][0][:22]
    preprocess(windows_concat_ds, preprocessors)
    # Compare with the scaled reference rather than dividing by the raw data,
    # which is ill-conditioned wherever raw samples are (close to) zero.
    np.testing.assert_allclose(windows_concat_ds[0][0],
                               raw_window * factor,
                               rtol=1e-4,
                               atol=1e-4)
Exemple #8
0
def load_train_test_hgd(subject_id):
    """Load, preprocess and window the High-Gamma dataset for one subject.

    Only the ``'train'`` run split of MOABB ``Schirrmeister2017`` is used.
    After preprocessing and windowing, the trials are split in order: the
    first 75 % form the training set and the remaining trials the
    validation set.

    Parameters
    ----------
    subject_id : int
        Subject to load.

    Returns
    -------
    train_set, valid_set : torch.utils.data.Subset
        Consecutive, non-overlapping subsets of the windowed trials.
    """
    from torch.utils.data import Subset

    # Channels kept after re-referencing, in this order.
    hgd_names = [
        'Fp2', 'Fp1', 'F4', 'F3', 'C4', 'C3', 'P4', 'P3', 'O2', 'O1', 'F8',
        'F7', 'T8', 'T7', 'P8', 'P7', 'M2', 'M1', 'Fz', 'Cz', 'Pz'
    ]
    log.info("Loading dataset..")
    # using the moabb dataset to load our data
    dataset = MOABBDataset(dataset_name="Schirrmeister2017",
                           subject_ids=[subject_id])
    sfreq = 32  # target sampling frequency (Hz) for the resampling step
    train_whole_set = dataset.split('run')['train']

    log.info("Preprocessing dataset..")
    preprocessors = [
        # re-reference to the average of all channels
        MNEPreproc(
            fn='set_eeg_reference',
            ref_channels='average',
        ),
        # keep only the channels listed above, in that order
        MNEPreproc(fn='pick_channels', ch_names=hgd_names, ordered=True),
        # convert from volt to microvolt, directly modifying the numpy array
        NumpyPreproc(fn=lambda x: x * 1e6),
        # clip large amplitudes, then rescale
        NumpyPreproc(fn=lambda x: np.clip(x, -800, 800)),
        NumpyPreproc(fn=lambda x: x / 10),
        MNEPreproc(fn='resample', sfreq=sfreq),
        # clip and rescale again at the new sampling frequency
        NumpyPreproc(fn=lambda x: np.clip(x, -80, 80)),
        NumpyPreproc(fn=lambda x: x / 3),
        # exponential moving demeaning: 10 s init block,
        # factor_new = 1 / (number of samples in 5 s)
        NumpyPreproc(fn=exponential_moving_demean,
                     init_block_size=int(sfreq * 10),
                     factor_new=1 / (sfreq * 5)),
    ]

    # Preprocess the data
    preprocess(train_whole_set, preprocessors)

    # Extract one window per trial (no start/stop offsets) and map the
    # event names to integer targets.
    class_mapping = {'right_hand': 0, 'rest': 1}
    windows_dataset = create_windows_from_events(
        train_whole_set,
        trial_start_offset_samples=0,
        trial_stop_offset_samples=0,
        preload=True,
        mapping=class_mapping,
    )

    # Sequential split: first 75 % for training, the rest for validation.
    n_split = int(np.round(0.75 * len(windows_dataset)))
    valid_set = Subset(windows_dataset, range(n_split, len(windows_dataset)))
    train_set = Subset(windows_dataset, range(0, n_split))
    return train_set, valid_set
def test_scale_continuous(base_concat_ds):
    """Scaling continuous data multiplies every EEG sample by ``factor``."""
    factor = 1e6
    preprocessors = [
        MNEPreproc('pick_types', eeg=True, meg=False, stim=False),
        NumpyPreproc(scale, factor=factor)
    ]
    # pick_types drops the non-EEG channels, so only the first 22 (EEG)
    # channels of the untouched time point can be compared with the result.
    raw_timepoint = base_concat_ds[0][0][:22]
    preprocess(base_concat_ds, preprocessors)
    # Compare with the scaled reference rather than dividing by the raw data,
    # which is ill-conditioned wherever raw samples are (close to) zero.
    np.testing.assert_allclose(base_concat_ds[0][0],
                               raw_timepoint * factor,
                               rtol=1e-4,
                               atol=1e-4)
Exemple #10
0
def test_scale_continuous(base_concat_ds):
    """Scaling continuous data multiplies every EEG sample by ``factor``."""
    factor = 1e6
    pick_eeg = Preprocessor('pick_types', eeg=True, meg=False, stim=False)
    scaler = Preprocessor(scale, factor=factor)

    # First time point before scaling; the first 22 channels are EEG.
    reference = base_concat_ds[0][0][:22]
    preprocess(base_concat_ds, [pick_eeg, scaler])

    scaled = base_concat_ds[0][0]
    np.testing.assert_allclose(scaled, reference * factor,
                               rtol=1e-4, atol=1e-4)
Exemple #11
0
def test_scale_windows(windows_concat_ds):
    """Scaling windowed data multiplies every EEG sample by ``factor``."""
    factor = 1e6
    pick_eeg = Preprocessor('pick_types', eeg=True, meg=False, stim=False)
    scaler = Preprocessor(scale, factor=factor)

    # First window before scaling; the first 22 channels are EEG.
    reference = windows_concat_ds[0][0][:22]
    preprocess(windows_concat_ds, [pick_eeg, scaler])

    scaled = windows_concat_ds[0][0]
    np.testing.assert_allclose(scaled, reference * factor,
                               rtol=1e-4, atol=1e-4)
Exemple #12
0
def test_preprocess_windows_callable_on_object(windows_concat_ds):
    """With ``apply_on_array=False``, ``modify_windows_object`` scales windows."""
    factor = 10
    object_preproc = Preprocessor(modify_windows_object,
                                  apply_on_array=False,
                                  factor=factor)

    before = windows_concat_ds[0][0]
    preprocess(windows_concat_ds, [object_preproc])

    after = windows_concat_ds[0][0]
    np.testing.assert_allclose(after, before * factor, rtol=1e-4, atol=1e-4)
Exemple #13
0
def load_train_valid_tuh(n_subjects, n_seconds, ids_to_load):
    """Load preprocessed TUH recordings and cut them into training windows.

    The ``'train'`` session split is cropped, clipped, scaled and demeaned,
    then split by subject: the first ``round(0.75 * n_subjects)`` subjects
    form the training set and the following ones, up to ``n_subjects``, the
    validation set. From every recording, ``n_seconds`` of signal starting
    one minute in are cut into 128-sample windows with a 64-sample stride.

    Parameters
    ----------
    n_subjects : int
        Number of subjects (in split order) to use in total.
    n_seconds : int | float
        Seconds of signal per recording to window.
    ids_to_load : list of int
        Dataset ids forwarded to ``load_concat_dataset``.

    Returns
    -------
    train_set, valid_set : BaseConcatDataset
        Windowed training and validation sets.
    """
    path = '/home/schirrmr/data/preproced-tuh/all-sensors-32-hz/'
    log.info("Load concat dataset...")
    dataset = load_concat_dataset(path, preload=False, ids_to_load=ids_to_load)
    whole_train_set = dataset.split('session')['train']
    # Crop to the windowed span (1 min offset + n_seconds) plus a margin.
    n_max_minutes = int(np.ceil(n_seconds / 60) + 2)
    sfreq = whole_train_set.datasets[0].raw.info['sfreq']
    log.info("Preprocess concat dataset...")
    preprocess(whole_train_set, [
        MNEPreproc('crop', tmin=0, tmax=n_max_minutes * 60, include_tmax=True),
        NumpyPreproc(fn=lambda x: np.clip(x, -80, 80)),
        NumpyPreproc(fn=lambda x: x / 3),
        NumpyPreproc(fn=exponential_moving_demean,
                     init_block_size=int(sfreq * 10),
                     factor_new=1 / (sfreq * 5)),
    ])
    subject_datasets = whole_train_set.split('subject')

    n_split = int(np.round(n_subjects * 0.75))
    keys = list(subject_datasets.keys())
    # Indexing (rather than slicing) keeps the IndexError if fewer subjects
    # are available than requested.
    train_set = BaseConcatDataset([
        d for i in range(n_split) for d in subject_datasets[keys[i]].datasets
    ])
    valid_set = BaseConcatDataset([
        d for i in range(n_split, n_subjects)
        for d in subject_datasets[keys[i]].datasets
    ])

    # Offsets are derived from the data's sampling frequency instead of a
    # hard-coded 32 Hz, so they stay correct for other sampling rates.
    start_offset_samples = int(60 * sfreq)
    stop_offset_samples = start_offset_samples + int(n_seconds * sfreq)

    def _make_windows(concat_ds):
        # Fixed-length windows from 1 min to 1 min + n_seconds.
        return create_fixed_length_windows(
            concat_ds,
            start_offset_samples=start_offset_samples,
            stop_offset_samples=stop_offset_samples,
            preload=True,
            window_size_samples=128,
            window_stride_samples=64,
            drop_last_window=True,
        )

    train_set = _make_windows(train_set)
    valid_set = _make_windows(valid_set)
    return train_set, valid_set
Exemple #14
0
def test_filterbank_order_channels_by_freq(base_concat_ds):
    """With ``order_by_frequency_band=True`` channels are grouped per band."""
    first_ds = base_concat_ds.split([[0]])["0"]
    picker = MNEPreproc('pick_channels',
                        ch_names=sorted(["C4", "Cz"]),
                        ordered=True)
    bank = MNEPreproc(filterbank,
                      frequency_bands=[(0, 4), (4, 8), (8, 13)],
                      drop_original_signals=False,
                      order_by_frequency_band=True)
    preprocess(first_ds, [picker, bank])

    expected = [
        "C4", "Cz",            # original signals
        "C4_0-4", "Cz_0-4",
        "C4_4-8", "Cz_4-8",
        "C4_8-13", "Cz_8-13",
    ]
    np.testing.assert_array_equal(first_ds.datasets[0].raw.ch_names,
                                  expected)
Exemple #15
0
def test_windows_from_events_cropped(lazy_loadable_dataset):
    """Test windowing from events on cropped data.

    Cropping raw data changes the `first_samp` attribute of the Raw object, and
    so it is important to test this is taken into account by the windowers.
    Also checks that stop offsets reaching past the recording raise.
    """
    tmin, tmax = 100, 120

    ds = copy.deepcopy(lazy_loadable_dataset)
    ds.datasets[0].raw.annotations.crop(tmin, tmax)

    crop_ds = copy.deepcopy(lazy_loadable_dataset)
    crop_transform = MNEPreproc('crop', tmin=tmin, tmax=tmax)
    preprocess(crop_ds, [crop_transform])

    # Extract windows
    windows1 = create_windows_from_events(concat_ds=ds,
                                          trial_start_offset_samples=0,
                                          trial_stop_offset_samples=0,
                                          window_size_samples=100,
                                          window_stride_samples=100,
                                          drop_last_window=False)
    windows2 = create_windows_from_events(concat_ds=crop_ds,
                                          trial_start_offset_samples=0,
                                          trial_stop_offset_samples=0,
                                          window_size_samples=100,
                                          window_stride_samples=100,
                                          drop_last_window=False)
    # assert_array_equal also verifies the shapes match (``==`` + ``.all()``
    # would silently broadcast) and reports the differing entries.
    np.testing.assert_array_equal(windows1[0][0], windows2[0][0])

    # Make sure events that fall outside of recording will trigger an error
    with pytest.raises(ValueError,
                       match='"trial_stop_offset_samples" too large'):
        create_windows_from_events(concat_ds=ds,
                                   trial_start_offset_samples=0,
                                   trial_stop_offset_samples=10000,
                                   window_size_samples=100,
                                   window_stride_samples=100,
                                   drop_last_window=False)
    with pytest.raises(ValueError,
                       match='"trial_stop_offset_samples" too large'):
        create_windows_from_events(concat_ds=crop_ds,
                                   trial_start_offset_samples=0,
                                   trial_stop_offset_samples=2001,
                                   window_size_samples=100,
                                   window_stride_samples=100,
                                   drop_last_window=False)
Exemple #16
0
def test_filterbank_order_channels_by_freq(base_concat_ds):
    """With ``order_by_frequency_band=True`` channels are grouped per band."""
    first_ds = base_concat_ds.split([[0]])['0']
    picker = Preprocessor('pick_channels',
                          ch_names=sorted(['C4', 'Cz']),
                          ordered=True)
    bank = Preprocessor(filterbank,
                        frequency_bands=[(0, 4), (4, 8), (8, 13)],
                        drop_original_signals=False,
                        order_by_frequency_band=True,
                        apply_on_array=False)
    preprocess(first_ds, [picker, bank])

    expected = [
        'C4', 'Cz',            # original signals
        'C4_0-4', 'Cz_0-4',
        'C4_4-8', 'Cz_4-8',
        'C4_8-13', 'Cz_8-13',
    ]
    np.testing.assert_array_equal(first_ds.datasets[0].raw.ch_names,
                                  expected)
Exemple #17
0
def test_deprecated_preprocs(base_concat_ds):
    """Legacy MNEPreproc/NumpyPreproc emit a deprecation warning but work."""
    mne_msg = ('MNEPreproc is deprecated. Use Preprocessor with '
               '`apply_on_array=False` instead.')
    numpy_msg = ('NumpyPreproc is deprecated. Use Preprocessor with '
                 '`apply_on_array=True` instead.')
    factor = 1e6

    with pytest.warns(UserWarning, match=mne_msg):
        mne_preproc = MNEPreproc('pick_types', eeg=True, meg=False, stim=False)
    with pytest.warns(UserWarning, match=numpy_msg):
        np_preproc = NumpyPreproc(scale, factor=factor)

    # First time point before scaling; the first 22 channels are EEG.
    reference = base_concat_ds[0][0][:22]
    preprocess(base_concat_ds, [mne_preproc, np_preproc])
    np.testing.assert_allclose(base_concat_ds[0][0], reference * factor,
                               rtol=1e-4, atol=1e-4)
def test_zscore_windows(windows_concat_ds):
    """After ``zscore`` every window/channel trace has mean 0 and std 1."""
    preprocess(windows_concat_ds, [
        MNEPreproc('pick_types', eeg=True, meg=False, stim=False),
        MNEPreproc(zscore),
    ])
    for windows_ds in windows_concat_ds.datasets:
        data = windows_ds.windows.get_data()
        trace_shape = data.shape[:-1]  # every axis except time
        np.testing.assert_allclose(data.mean(axis=-1), np.zeros(trace_shape),
                                   rtol=1e-4, atol=1e-4)
        np.testing.assert_allclose(data.std(axis=-1), np.ones(trace_shape),
                                   rtol=1e-4, atol=1e-4)
def test_zscore_continuous(base_concat_ds):
    """Channel-wise ``zscore`` leaves each raw channel with mean 0, std 1."""
    preprocess(base_concat_ds, [
        MNEPreproc('pick_types', eeg=True, meg=False, stim=False),
        MNEPreproc('apply_function', fun=zscore, channel_wise=True),
    ])
    for base_ds in base_concat_ds.datasets:
        data = base_ds.raw.get_data()
        trace_shape = data.shape[:-1]  # every axis except time
        np.testing.assert_allclose(data.mean(axis=-1), np.zeros(trace_shape),
                                   rtol=1e-4, atol=1e-4)
        np.testing.assert_allclose(data.std(axis=-1), np.ones(trace_shape),
                                   rtol=1e-4, atol=1e-4)
Exemple #20
0
def get_sleep_physionet(subject_ids=range(83), recording_ids=None):
    """Load, preprocess and window the Sleep Physionet dataset.

    Parameters
    ----------
    subject_ids : iterable of int
        Candidate subjects; ids in ``bug_subjects`` below are always skipped.
    recording_ids : list of int | None
        Recordings to load per subject. ``None`` (the default) means
        ``[1, 2]``.

    Returns
    -------
    windows_dataset
        Non-overlapping 3000-sample windows (30 s at the assumed 100 Hz),
        labelled via ``mapping`` and z-scored with ``zscore``.
    info : pandas.DataFrame
        Contents of ``SC-subjects.xls`` read from the working directory.
    """
    # ``None`` sentinel instead of a list literal: a list default would be
    # shared between calls.
    if recording_ids is None:
        recording_ids = [1, 2]

    # Subjects that are always excluded (presumably recordings that fail to
    # load or window - TODO confirm). A set gives O(1) membership tests.
    bug_subjects = {13, 36, 39, 48, 52, 59, 65, 68, 69, 73, 75, 76, 78, 79}
    subject_ids = [id_ for id_ in subject_ids if id_ not in bug_subjects]

    dataset = SleepPhysionet(subject_ids=subject_ids,
                             recording_ids=recording_ids,
                             crop_wake_mins=30)
    high_cut_hz = 30

    preprocessors = [
        # convert from volt to microvolt, directly modifying the numpy array
        NumpyPreproc(fn=lambda x: x * 1e6),
        # low-pass filter (l_freq=None)
        MNEPreproc(fn='filter', l_freq=None, h_freq=high_cut_hz),
    ]
    # Transform the data
    preprocess(dataset, preprocessors)

    mapping = {  # We merge stages 3 and 4 following AASM standards.
        'Sleep stage W': 0,
        'Sleep stage 1': 1,
        'Sleep stage 2': 2,
        'Sleep stage 3': 3,
        'Sleep stage 4': 3,
        'Sleep stage R': 4
    }

    window_size_s = 30
    sfreq = 100  # assumed sampling frequency in Hz (not read from the data)
    window_size_samples = window_size_s * sfreq

    windows_dataset = create_windows_from_events(
        dataset,
        trial_start_offset_samples=0,
        trial_stop_offset_samples=0,
        window_size_samples=window_size_samples,
        window_stride_samples=window_size_samples,
        preload=True,
        mapping=mapping)

    preprocess(windows_dataset, [MNEPreproc(fn=zscore)])

    info = pd.read_excel('SC-subjects.xls')

    return windows_dataset, info
Exemple #21
0
def build_epoch(subjects,
                recording,
                crop_wake_mins,
                preprocessing,
                train=True):
    """Load Sleep Physionet recordings and cut them into sleep-stage windows.

    ``preprocessing`` is a container of step names. ``"microvolt_scaling"``
    multiplies the data by 1e6 and ``"filtering"`` applies a 30 Hz low-pass
    filter; other names are ignored. ``train`` is accepted but unused.
    Windows are 3000 samples long (30 s at an assumed 100 Hz) and do not
    overlap.
    """
    dataset = SleepPhysionet(subject_ids=subjects,
                             recording_ids=recording,
                             crop_wake_mins=crop_wake_mins)

    if preprocessing:
        steps = []
        if "microvolt_scaling" in preprocessing:
            steps.append(NumpyPreproc(fn=lambda x: x * 1e6))
        if "filtering" in preprocessing:
            steps.append(MNEPreproc(fn='filter', l_freq=None, h_freq=30))
        # Transform the data
        preprocess(dataset, steps)

    # Stages 3 and 4 share one label, following AASM standards.
    stage_to_label = {
        'Sleep stage W': 0,
        'Sleep stage 1': 1,
        'Sleep stage 2': 2,
        'Sleep stage 3': 3,
        'Sleep stage 4': 3,
        'Sleep stage R': 4
    }

    samples_per_window = 30 * 100  # 30 s x assumed 100 Hz

    return create_windows_from_events(
        dataset,
        trial_start_offset_samples=0,
        trial_stop_offset_samples=0,
        window_size_samples=samples_per_window,
        window_stride_samples=samples_per_window,
        preload=True,
        mapping=stage_to_label)
Exemple #22
0
def preprocess_data(dataset,
                    low_cut_hz=4,
                    high_cut_hz=38,
                    factor_new=1e-3,
                    init_block_size=1000):
    """Run the standard EEG pipeline on ``dataset`` and return the same object.

    Steps, in order: keep EEG channels only, scale volt -> microvolt,
    filter between ``low_cut_hz`` and ``high_cut_hz`` Hz, then apply
    exponential moving standardization configured by ``factor_new`` and
    ``init_block_size``.
    """
    pick_eeg = MNEPreproc(fn='pick_types', eeg=True, meg=False, stim=False)
    to_microvolt = NumpyPreproc(fn=lambda x: x * 1e6)
    bandpass = MNEPreproc(fn='filter', l_freq=low_cut_hz, h_freq=high_cut_hz)
    standardize = NumpyPreproc(fn=exponential_moving_standardize,
                               factor_new=factor_new,
                               init_block_size=init_block_size)

    preprocess(dataset, [pick_eeg, to_microvolt, bandpass, standardize])
    return dataset
Exemple #23
0
def our_preprocess(dataset):
    """Run the standard EEG pipeline (4-38 Hz) on ``dataset``.

    Steps, in order: keep EEG channels only, scale volt -> microvolt,
    filter between 4 and 38 Hz, then exponential moving standardization.
    The processed ``dataset`` object is returned for convenience, like
    ``preprocess_data`` does.
    """
    low_cut_hz = 4.  # low cut frequency for filtering
    high_cut_hz = 38.  # high cut frequency for filtering
    # Parameters for exponential moving standardization
    factor_new = 1e-3
    init_block_size = 1000

    preprocessors = [
        # keep only EEG sensors
        MNEPreproc(fn='pick_types', eeg=True, meg=False, stim=False),
        # convert from volt to microvolt, directly modifying the numpy array
        NumpyPreproc(fn=lambda x: x * 1e6),
        # bandpass filter
        MNEPreproc(fn='filter', l_freq=low_cut_hz, h_freq=high_cut_hz),
        # exponential moving standardization
        NumpyPreproc(fn=exponential_moving_standardize,
                     factor_new=factor_new,
                     init_block_size=init_block_size)
    ]

    # Transform the data
    preprocess(dataset, preprocessors)
    return dataset
from braindecode.datautil.serialization import load_concat_dataset
from braindecode.datautil.windowers import create_windows_from_events


###############################################################################
# First, we load a dataset using MOABB (subject 1 of BNCI2014001).
ds = MOABBDataset(dataset_name='BNCI2014001', subject_ids=[1])

###############################################################################
# Preprocessing steps can be applied to the dataset before saving. This step
# is optional and may be skipped; here we only resample to 10 Hz.
preprocess(concat_ds=ds,
           preprocessors=[MNEPreproc(fn='resample', sfreq=10)])

###############################################################################
# We save the dataset to an existing directory. One '.fif' file is written
# for every dataset in the concat dataset, plus two JSON files: the first
# holds the description of the dataset, the second the name of the target.
# To save to the same directory several times (for example when trying
# different preprocessing), set ``overwrite`` to replace existing files.
ds.save(path='./', overwrite=False)

##############################################################################
Exemple #25
0
def train(subject_id):
    """Train EEGNetv4 on BCI IV 2a with ``subject_id`` held out for validation.

    All nine subjects of MOABB ``BNCI2014001`` are loaded, preprocessed and
    windowed. The trials of ``subject_id`` form the validation set. The trials
    of the remaining subjects are merged into a single ``WindowsDataset``,
    with their ``trialwise_weights`` attached for ``EEGClassifier_weighted``.

    Parameters
    ----------
    subject_id : int
        Subject whose trials are used for validation.

    Returns
    -------
    float
        Highest validation accuracy reached during training, in percent.
    """
    print('\n--------------------------------------------------\n')
    print(
        'Training on BCI_IV_2a dataset | Cross-subject | ID: {:02d}\n'.format(
            subject_id))

    # Cross-subject setting: load every subject, hold out ``subject_id`` later.
    subject_range = list(range(1, 10))

    dataset = MOABBDataset(dataset_name="BNCI2014001",
                           subject_ids=subject_range)

    ######################################################################
    # Preprocessing

    low_cut_hz = 4.  # low cut frequency for filtering
    high_cut_hz = 38.  # high cut frequency for filtering

    preprocessors = [
        Preprocessor('pick_types', eeg=True, eog=False, meg=False,
                     stim=False),  # Keep EEG sensors
        Preprocessor(lambda x: x * 1e6),  # Convert from V to uV
        Preprocessor('filter', l_freq=low_cut_hz,
                     h_freq=high_cut_hz),  # Bandpass filter
        Preprocessor('resample', sfreq=125),
        # project helper; presumably aligns each recording's covariance
        # (TODO confirm against covariance_align)
        Preprocessor(covariance_align),
    ]

    # Transform the data
    print('Preprocessing dataset\n')
    preprocess(dataset, preprocessors)

    ######################################################################
    # Cut Compute Windows

    trial_start_offset_seconds = -0.5
    trial_stop_offset_seconds = 0.0
    # All recordings must share one sampling frequency so the offsets below,
    # converted to samples, are valid for every one of them.
    sfreq = dataset.datasets[0].raw.info['sfreq']
    assert all(ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets)
    trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)
    trial_stop_offset_samples = int(trial_stop_offset_seconds * sfreq)

    print('Windowing dataset\n')
    windows_dataset = create_windows_from_events(
        dataset,
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=trial_stop_offset_samples,
        preload=True,
    )

    print('Computing covariances of each WindowsDataset')
    windows_dataset.compute_covariances_concat()

    ######################################################################
    # Split into train (other subjects) and valid (held-out subject)

    subject_column = windows_dataset.description['subject'].values
    inds_train = list(np.where(subject_column != subject_id)[0])
    inds_valid = list(np.where(subject_column == subject_id)[0])
    splitted = windows_dataset.split([inds_train, inds_valid])
    train_set = splitted['0']
    valid_set = splitted['1']

    # Merge the training recordings into one WindowsDataset; the weights are
    # stacked in the same dataset order as the concatenated epochs.
    epochs_all = [ds.windows for ds in train_set.datasets]
    epochs_full = mne.concatenate_epochs(epochs_all)
    trialwise_weights_all = [ds.trialwise_weights for ds in train_set.datasets]
    trialwise_weights_full = np.hstack(trialwise_weights_all)
    full_dataset = WindowsDataset(windows=epochs_full,
                                  description=None,
                                  transform=None)
    full_dataset.trialwise_weights = trialwise_weights_full
    train_set = full_dataset

    ######################################################################
    # Create model

    cuda = torch.cuda.is_available()  # use the GPU when one is available
    device = 'cuda' if cuda else 'cpu'
    if cuda:
        torch.backends.cudnn.benchmark = True
    seed = 20200220  # random seed to make results reproducible
    set_random_seeds(seed=seed, cuda=cuda)

    n_classes = 4
    # Extract number of chans and time steps from dataset
    n_chans = train_set[0][0].shape[0]
    input_window_samples = train_set[0][0].shape[1]

    model = EEGNetv4(
        n_chans,
        n_classes,
        input_window_samples=input_window_samples,
        final_conv_length="auto",
        pool_mode="mean",
        F1=8,
        D=2,
        F2=16,  # conventionally F1 * D
        kernel_length=64,
        third_kernel_size=(8, 4),
        drop_prob=0.2)

    if cuda:
        model.cuda()

    ######################################################################
    # Training (SGD with momentum)

    lr = 0.01
    weight_decay = 0.0005
    batch_size = 64
    n_epochs = 100

    clf = EEGClassifier_weighted(
        model,
        criterion=torch.nn.NLLLoss,
        optimizer=torch.optim.SGD,
        train_split=predefined_split(
            valid_set),  # using valid_set for validation
        optimizer__lr=lr,
        optimizer__momentum=0.9,
        optimizer__weight_decay=weight_decay,
        batch_size=batch_size,
        callbacks=[
            "accuracy",
        ],
        device=device,
    )
    # `y` is None because the targets are supplied by the dataset itself.
    clf.fit(train_set, y=None, epochs=n_epochs)

    results_columns = [
        'train_loss', 'valid_loss', 'train_accuracy', 'valid_accuracy'
    ]
    df = pd.DataFrame(clf.history[:, results_columns],
                      columns=results_columns,
                      index=clf.history[:, 'epoch'])

    val_accs = df['valid_accuracy'].values
    max_val_acc = 100.0 * np.max(val_accs)

    return max_val_acc
Exemple #26
0
#

from braindecode.datautil.preprocess import (MNEPreproc, NumpyPreproc,
                                             preprocess)

high_cut_hz = 30  # cut-off of the low-pass filter, in Hz

# Pipeline: convert from volt to microvolt (directly on the numpy array),
# then low-pass filter (l_freq=None) at ``high_cut_hz``.
preprocessors = [
    NumpyPreproc(fn=lambda x: x * 1e6),
    MNEPreproc(fn='filter', l_freq=None, h_freq=high_cut_hz),
]

# Transform the data in place.
preprocess(dataset, preprocessors)

######################################################################
# Extract windows
# ~~~~~~~~~~~~~~~
#

######################################################################
# We extract 30-s windows to be used in the classification task.

from braindecode.datautil.windowers import create_windows_from_events

mapping = {  # We merge stages 3 and 4 following AASM standards.
    'Sleep stage W': 0,
    'Sleep stage 1': 1,
    'Sleep stage 2': 2,
Exemple #27
0
def test_preprocess_windows_str(windows_concat_ds):
    """Cropping windows via the method name 'crop' leaves 25 samples."""
    crop = Preprocessor('crop', tmin=0, tmax=0.1, include_tmax=False)
    preprocess(windows_concat_ds, [crop])
    n_times = windows_concat_ds[0][0].shape[1]
    assert n_times == 25
Exemple #28
0
def test_preprocess_raw_str(base_concat_ds):
    """Cropping raw data to [0, 10) s via 'crop' leaves 2500 samples."""
    crop = Preprocessor('crop', tmax=10, include_tmax=False)
    preprocess(base_concat_ds, [crop])
    n_times = len(base_concat_ds.datasets[0].raw.times)
    assert n_times == 2500
Exemple #29
0
def test_method_not_available(base_concat_ds):
    """Naming a method the data object lacks raises AttributeError."""
    bogus = Preprocessor('this_method_is_not_real')
    with pytest.raises(AttributeError):
        preprocess(base_concat_ds, [bogus])
Exemple #30
0
def test_not_list():
    """A dict instead of a list of preprocessors raises AssertionError."""
    not_a_list = {'test': 1}
    with pytest.raises(AssertionError):
        preprocess(None, not_a_list)