Example 1
def test_filterbank(base_concat_ds):
    base_concat_ds = base_concat_ds.split([[0]])['0']
    preprocessors = [
        Preprocessor('pick_channels',
                     ch_names=sorted(['C4', 'Cz']),
                     ordered=True),
        Preprocessor(filterbank,
                     frequency_bands=[(0, 4), (4, 8), (8, 13)],
                     drop_original_signals=False,
                     apply_on_array=False)
    ]
    preprocess(base_concat_ds, preprocessors)
    for x, y in base_concat_ds:
        break
    assert x.shape[0] == 8
    freq_band_annots = [
        ch.split('_')[-1] for ch in base_concat_ds.datasets[0].raw.ch_names
        if '_' in ch
    ]
    assert len(np.unique(freq_band_annots)) == 3
    np.testing.assert_array_equal(base_concat_ds.datasets[0].raw.ch_names, [
        'C4',
        'C4_0-4',
        'C4_4-8',
        'C4_8-13',
        'Cz',
        'Cz_0-4',
        'Cz_4-8',
        'Cz_8-13',
    ])
Example 2
def test_scale_continuous(base_concat_ds):
    factor = 1e6
    preprocessors = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(scale, factor=factor)
    ]
    raw_timepoint = base_concat_ds[0][0][:22]  # only keep EEG channels
    preprocess(base_concat_ds, preprocessors)
    np.testing.assert_allclose(base_concat_ds[0][0],
                               raw_timepoint * factor,
                               rtol=1e-4,
                               atol=1e-4)
Example 3
def test_scale_windows(windows_concat_ds):
    factor = 1e6
    preprocessors = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(scale, factor=factor)
    ]
    raw_window = windows_concat_ds[0][0][:22]  # only keep EEG channels
    preprocess(windows_concat_ds, preprocessors)
    np.testing.assert_allclose(windows_concat_ds[0][0],
                               raw_window * factor,
                               rtol=1e-4,
                               atol=1e-4)
Example 4
def test_windows_fixed_length_cropped(lazy_loadable_dataset):
    """Test fixed length windowing on cropped data.

    Cropping raw data changes the `first_samp` attribute of the Raw object, so
    it is important to test that this is taken into account by the windowers.
    """
    tmin, tmax = 100, 120

    ds = copy.deepcopy(lazy_loadable_dataset)
    ds.datasets[0].raw.annotations.crop(tmin, tmax)

    crop_ds = copy.deepcopy(lazy_loadable_dataset)
    crop_transform = Preprocessor('crop', tmin=tmin, tmax=tmax)
    preprocess(crop_ds, [crop_transform])

    # Extract windows
    sfreq = ds.datasets[0].raw.info['sfreq']
    tmin_samples, tmax_samples = int(tmin * sfreq), int(tmax * sfreq)

    windows1 = create_fixed_length_windows(
        concat_ds=ds, start_offset_samples=tmin_samples,
        stop_offset_samples=tmax_samples, window_size_samples=100,
        window_stride_samples=100, drop_last_window=True)
    windows2 = create_fixed_length_windows(
        concat_ds=crop_ds, start_offset_samples=0,
        stop_offset_samples=None, window_size_samples=100,
        window_stride_samples=100, drop_last_window=True)
    assert (windows1[0][0] == windows2[0][0]).all()
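The `first_samp` behaviour the docstring refers to can be reproduced with plain MNE. The snippet below is a minimal illustrative sketch; the synthetic Raw object is an assumption for demonstration and is not part of the test suite:

import numpy as np
import mne

info = mne.create_info(ch_names=['C3'], sfreq=100., ch_types='eeg')
raw = mne.io.RawArray(np.zeros((1, 500)), info)  # 5 s of flat signal
print(raw.first_samp)        # 0 before cropping
raw.crop(tmin=1., tmax=4.)
print(raw.first_samp)        # 100 (= tmin * sfreq) after cropping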
Example 5
def test_filterbank_order_channels_by_freq(base_concat_ds):
    base_concat_ds = base_concat_ds.split([[0]])['0']
    preprocessors = [
        Preprocessor('pick_channels',
                     ch_names=sorted(['C4', 'Cz']),
                     ordered=True),
        Preprocessor(filterbank,
                     frequency_bands=[(0, 4), (4, 8), (8, 13)],
                     drop_original_signals=False,
                     order_by_frequency_band=True,
                     apply_on_array=False)
    ]
    preprocess(base_concat_ds, preprocessors)
    np.testing.assert_array_equal(base_concat_ds.datasets[0].raw.ch_names, [
        'C4', 'Cz', 'C4_0-4', 'Cz_0-4', 'C4_4-8', 'Cz_4-8', 'C4_8-13',
        'Cz_8-13'
    ])
Example 6
def test_zscore_windows(windows_concat_ds):
    preprocessors = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(zscore)
    ]
    preprocess(windows_concat_ds, preprocessors)
    for ds in windows_concat_ds.datasets:
        windowed_data = ds.windows.get_data()
        shape = windowed_data.shape
        # zero mean
        expected = np.zeros(shape[:-1])
        np.testing.assert_allclose(windowed_data.mean(axis=-1),
                                   expected,
                                   rtol=1e-4,
                                   atol=1e-4)
        # unit variance
        expected = np.ones(shape[:-1])
        np.testing.assert_allclose(windowed_data.std(axis=-1),
                                   expected,
                                   rtol=1e-4,
                                   atol=1e-4)
Example 7
def test_zscore_continuous(base_concat_ds):
    preprocessors = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(zscore, channel_wise=True)
    ]
    preprocess(base_concat_ds, preprocessors)
    for ds in base_concat_ds.datasets:
        raw_data = ds.raw.get_data()
        shape = raw_data.shape
        # zero mean
        expected = np.zeros(shape[:-1])
        np.testing.assert_allclose(raw_data.mean(axis=-1),
                                   expected,
                                   rtol=1e-4,
                                   atol=1e-4)
        # unit variance
        expected = np.ones(shape[:-1])
        np.testing.assert_allclose(raw_data.std(axis=-1),
                                   expected,
                                   rtol=1e-4,
                                   atol=1e-4)
Example 8
def test_preprocess_windows_callable_on_object(windows_concat_ds):
    factor = 10
    preprocessors = [
        Preprocessor(modify_windows_object,
                     apply_on_array=False,
                     factor=factor)
    ]
    raw_window = windows_concat_ds[0][0]
    preprocess(windows_concat_ds, preprocessors)
    np.testing.assert_allclose(windows_concat_ds[0][0],
                               raw_window * factor,
                               rtol=1e-4,
                               atol=1e-4)
Example 9
def test_windows_from_events_cropped(lazy_loadable_dataset):
    """Test windowing from events on cropped data.

    Cropping raw data changes the `first_samp` attribute of the Raw object, so
    it is important to test that this is taken into account by the windowers.
    """
    tmin, tmax = 100, 120

    ds = copy.deepcopy(lazy_loadable_dataset)
    ds.datasets[0].raw.annotations.crop(tmin, tmax)

    crop_ds = copy.deepcopy(lazy_loadable_dataset)
    crop_transform = Preprocessor('crop', tmin=tmin, tmax=tmax)
    preprocess(crop_ds, [crop_transform])

    # Extract windows
    windows1 = create_windows_from_events(
        concat_ds=ds, trial_start_offset_samples=0, trial_stop_offset_samples=0,
        window_size_samples=100, window_stride_samples=100,
        drop_last_window=False)
    windows2 = create_windows_from_events(
        concat_ds=crop_ds, trial_start_offset_samples=0,
        trial_stop_offset_samples=0, window_size_samples=100,
        window_stride_samples=100, drop_last_window=False)
    assert (windows1[0][0] == windows2[0][0]).all()

    # Make sure events that fall outside of recording will trigger an error
    with pytest.raises(
            ValueError, match='"trial_stop_offset_samples" too large'):
        create_windows_from_events(
            concat_ds=ds, trial_start_offset_samples=0,
            trial_stop_offset_samples=10000, window_size_samples=100,
            window_stride_samples=100, drop_last_window=False)
    with pytest.raises(
            ValueError, match='"trial_stop_offset_samples" too large'):
        create_windows_from_events(
            concat_ds=crop_ds, trial_start_offset_samples=0,
            trial_stop_offset_samples=2001, window_size_samples=100,
            window_stride_samples=100, drop_last_window=False)
Example 10
# ds has a pandas DataFrame with additional description of its internal datasets
display(ds.description)

##############################################################################
# We can iterate through ds which yields one time point of a continuous signal x,
# and a target y (which can be None if targets are not defined for the entire
# continuous signal).
for x, y in ds:
    print(x.shape, y)
    break

##############################################################################
# We can apply preprocessing transforms that are defined in mne and work
# in-place, such as resampling, bandpass filtering, or electrode selection.
preprocessors = [
    Preprocessor('pick_types', eeg=True, meg=False, stim=True),
    Preprocessor('resample', sfreq=100)
]
print(ds.datasets[0].raw.info["sfreq"])
preprocess(ds, preprocessors)
print(ds.datasets[0].raw.info["sfreq"])

###############################################################################
# We can easily split ds based on a criterion applied to the description
# DataFrame:
subsets = ds.split("session")
print({subset_name: len(subset) for subset_name, subset in subsets.items()})

###############################################################################
# Next, we use a windower to extract windows from the dataset based on events:
windows_ds = create_windows_from_events(
    ds, trial_start_offset_samples=0, trial_stop_offset_samples=0,
    preload=True)  # argument values shown here are illustrative
Example 11
def train(subject_id):

    print('\n--------------------------------------------------\n')
    print(
        'Training on BCI_IV_2a dataset | Cross-subject | ID: {:02d}\n'.format(
            subject_id))

    ##### subject_range = [subject_id]
    subject_range = [x for x in range(1, 10)]

    dataset = MOABBDataset(dataset_name="BNCI2014001",
                           subject_ids=subject_range)

    ######################################################################
    # Preprocessing

    low_cut_hz = 4.  # low cut frequency for filtering
    high_cut_hz = 38.  # high cut frequency for filtering
    # Parameters for exponential moving standardization
    factor_new = 1e-3
    init_block_size = 1000

    preprocessors = [
        Preprocessor('pick_types', eeg=True, eog=False, meg=False,
                     stim=False),  # Keep EEG sensors
        Preprocessor(lambda x: x * 1e6),  # Convert from V to uV
        Preprocessor('filter', l_freq=low_cut_hz,
                     h_freq=high_cut_hz),  # Bandpass filter
        #Preprocessor('set_eeg_reference', ref_channels='average', ch_type='eeg'),
        Preprocessor('resample', sfreq=125),
        Preprocessor(covariance_align),

        ## Preprocessor(exponential_moving_standardize,  # Exponential moving standardization
        ## factor_new=factor_new, init_block_size=init_block_size)
        ## Preprocessor('pick_channels', ch_names=short_ch_names, ordered=True),
    ]

    # Transform the data
    print('Preprocessing dataset\n')
    preprocess(dataset, preprocessors)

    ######################################################################
    # Cut Compute Windows
    # ~~~~~~~~~~~~~~~~~~~

    trial_start_offset_seconds = -0.5
    trial_stop_offset_seconds = 0.0
    # Extract the sampling frequency and check that it is the same in all datasets
    sfreq = dataset.datasets[0].raw.info['sfreq']
    assert all([ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets])
    # Calculate the trial start offset in samples.
    trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)
    trial_stop_offset_samples = int(trial_stop_offset_seconds * sfreq)

    # Create windows using the braindecode function for this. It needs
    # parameters that define how trials should be used.
    print('Windowing dataset\n')
    windows_dataset = create_windows_from_events(
        dataset,
        # picks=["Fz", "FC3", "FC1", "FCz", "FC2", "FC4", "C5", "C3", "C1", "Cz", "C2", "C4", "C6", "CP3", "CP1", "CPz", "CP2", "CP4", "P1", "Pz", "P2", "POz"],
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=trial_stop_offset_samples,
        preload=True,
    )

    print('Computing covariances of each WindowsDataset')
    windows_dataset.compute_covariances_concat()

    # print(windows_dataset.datasets[0].windows)

    ######################################################################
    # Merge multiple datasets into a single WindowDataset
    # metadata_all = [ds.windows.metadata for ds in windows_dataset.datasets]
    # metadata_full = pd.concat(metadata_all)
    """
	epochs_all = [ds.windows for ds in windows_dataset.datasets]
	epochs_full = mne.concatenate_epochs(epochs_all)
	full_dataset = WindowsDataset(windows=epochs_full, description=None, transform=None)
	windows_dataset = full_dataset
	"""
    ######################################################################
    # Split dataset into train and valid

    # keep only session 1:
    # temp = windows_dataset.split( 'session' )
    # windows_dataset = temp['session_T']

    # print(windows_dataset.datasets[0].windows)
    # print(windows_dataset.datasets[0].windows.get_data().shape)
    # quit()

    subject_column = windows_dataset.description['subject'].values
    inds_train = list(np.where(subject_column != subject_id)[0])
    inds_valid = list(np.where(subject_column == subject_id)[0])
    splitted = windows_dataset.split([inds_train, inds_valid])
    train_set = splitted['0']
    valid_set = splitted['1']

    #######

    epochs_all = [ds.windows for ds in train_set.datasets]
    epochs_full = mne.concatenate_epochs(epochs_all)
    trialwise_weights_all = [ds.trialwise_weights for ds in train_set.datasets]
    trialwise_weights_full = np.hstack(trialwise_weights_all)
    full_dataset = WindowsDataset(windows=epochs_full,
                                  description=None,
                                  transform=None)
    full_dataset.trialwise_weights = trialwise_weights_full
    train_set = full_dataset
    # print(train_set.windows.metadata)
    ######################################################################
    # Create model

    cuda = torch.cuda.is_available()  # check if a GPU is available; if so, use it
    device = 'cuda' if cuda else 'cpu'
    if cuda:
        torch.backends.cudnn.benchmark = True
    seed = 20200220  # random seed to make results reproducible
    # Set random seed to be able to reproduce results
    set_random_seeds(seed=seed, cuda=cuda)

    n_classes = 4
    # Extract number of chans and time steps from dataset
    n_chans = train_set[0][0].shape[0]
    input_window_samples = train_set[0][0].shape[1]
    """
	model = ShallowFBCSPNet(
		n_chans,
		n_classes,
		input_window_samples=input_window_samples,
		final_conv_length='auto')
	"""
    """
	model = EEGNetv1(
			n_chans,
			n_classes,
			input_window_samples=input_window_samples,
			final_conv_length="auto",
			pool_mode="mean",
			second_kernel_size=(2, 32),
			third_kernel_size=(8, 4),
			drop_prob=0.25)
	"""
    """
	model = HybridNet(n_chans, n_classes,
					input_window_samples=input_window_samples)
	"""
    """
	model = TCN(n_chans, n_classes,
				n_blocks=6,
				n_filters=32,
				kernel_size=9,
				drop_prob=0.0,
				add_log_softmax=True)
	"""

    model = EEGNetv4(
        n_chans,
        n_classes,
        input_window_samples=input_window_samples,
        final_conv_length="auto",
        pool_mode="mean",
        F1=8,
        D=2,
        F2=16,  # usually set to F1*D (?)
        kernel_length=64,
        third_kernel_size=(8, 4),
        drop_prob=0.2)

    if cuda:
        model.cuda()

    ######################################################################
    # Training

    # These values we found to work well for the shallow network:
    lr = 0.01  # 0.0625 * 0.01
    weight_decay = 0.0005

    # For deep4 they should be:
    # lr = 1 * 0.01
    # weight_decay = 0.5 * 0.001

    batch_size = 64
    n_epochs = 100

    # clf = EEGClassifier(
    clf = EEGClassifier_weighted(
        model,
        criterion=torch.nn.NLLLoss,
        optimizer=torch.optim.SGD,  #AdamW,
        train_split=predefined_split(
            valid_set),  # using valid_set for validation
        optimizer__lr=lr,
        optimizer__momentum=0.9,
        optimizer__weight_decay=weight_decay,
        batch_size=batch_size,
        callbacks=[
            "accuracy",  #("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
        ],
        device=device,
    )
    # Model training for a specified number of epochs. `y` is None as it is already supplied
    # in the dataset.
    clf.fit(train_set, y=None, epochs=n_epochs)

    results_columns = [
        'train_loss', 'valid_loss', 'train_accuracy', 'valid_accuracy'
    ]
    df = pd.DataFrame(clf.history[:, results_columns],
                      columns=results_columns,
                      index=clf.history[:, 'epoch'])

    val_accs = df['valid_accuracy'].values
    max_val_acc = 100.0 * np.max(val_accs)

    return max_val_acc
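A hypothetical usage sketch (not part of the original file): call train() once per held-out subject of BCI IV 2a and report the mean of the best validation accuracies.

# Hypothetical usage sketch: leave-one-subject-out over the 9 BCI IV 2a subjects.
if __name__ == '__main__':
    accs = [train(subject_id) for subject_id in range(1, 10)]
    print('Mean best validation accuracy: {:.2f}%'.format(sum(accs) / len(accs)))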
Example 12
def test_preprocess_windows_str(windows_concat_ds):
    preprocessors = [
        Preprocessor('crop', tmin=0, tmax=0.1, include_tmax=False)
    ]
    preprocess(windows_concat_ds, preprocessors)
    assert windows_concat_ds[0][0].shape[1] == 25
Example 13
def test_preprocess_raw_str(base_concat_ds):
    preprocessors = [Preprocessor('crop', tmax=10, include_tmax=False)]
    preprocess(base_concat_ds, preprocessors)
    assert len(base_concat_ds.datasets[0].raw.times) == 2500
Example 14
def test_method_not_available(base_concat_ds):
    preprocessors = [Preprocessor('this_method_is_not_real')]
    with pytest.raises(AttributeError):
        preprocess(base_concat_ds, preprocessors)
Example 15
#    These preprocessing steps are now applied directly to the loaded
#    data, rather than on the fly as transformations, as is done in
#    PyTorch libraries like
#    `torchvision <https://pytorch.org/docs/stable/torchvision/index.html>`__.
#

from braindecode.datautil.preprocess import (exponential_moving_standardize, preprocess, Preprocessor)

low_cut_hz = 4.  # low cut frequency for filtering
high_cut_hz = 38.  # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000

preprocessors = [
    Preprocessor('pick_types', eeg=True, meg=False, stim=False),  # Keep EEG sensors
    Preprocessor(lambda x: x * 1e6),  # Convert from V to uV
    Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),  # Bandpass filter
    Preprocessor(exponential_moving_standardize,  # Exponential moving standardization
                 factor_new=factor_new, init_block_size=init_block_size)
]

# Transform the data
preprocess(dataset, preprocessors)  # in-place modification


######################################################################
# Cut Compute Windows
# ~~~~~~~~~~~~~~~~~~~
#

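##############################################################################
# An illustrative sketch (not part of the original script) of how the compute
# windows announced above are typically cut with braindecode, here with one
# window per trial and no extra offsets; the import path mirrors the
# ``datautil`` API used elsewhere in this document.
from braindecode.datautil.windowers import create_windows_from_events

windows_dataset = create_windows_from_events(
    dataset,
    trial_start_offset_samples=0,
    trial_stop_offset_samples=0,
    preload=True,
)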

######################################################################
# Next, we preprocess the raw data. We convert the data to microvolts and
# apply a lowpass filter. We omit the downsampling step of [1]_ as the Sleep
# Physionet data is already sampled at a relatively low 100 Hz.
#

from braindecode.datautil.preprocess import preprocess, Preprocessor

high_cut_hz = 30

preprocessors = [
    Preprocessor(lambda x: x * 1e6),
    Preprocessor('filter', l_freq=None, h_freq=high_cut_hz)
]

# Transform the data
preprocess(dataset, preprocessors)


######################################################################
# Extract windows
# ~~~~~~~~~~~~~~~
#


######################################################################
# We extract 30-s windows to be used in the classification task.

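##############################################################################
# A minimal, illustrative sketch (not part of the original script) of cutting
# non-overlapping 30-s windows with create_fixed_length_windows; the 100 Hz
# sampling rate is assumed from the text above, and the import path mirrors
# the ``datautil`` API used elsewhere in this document.
from braindecode.datautil.windowers import create_fixed_length_windows

window_size_samples = 30 * 100  # 30 s at the assumed 100 Hz sampling rate

windows_dataset = create_fixed_length_windows(
    dataset,
    start_offset_samples=0,
    stop_offset_samples=None,
    window_size_samples=window_size_samples,
    window_stride_samples=window_size_samples,  # non-overlapping windows
    drop_last_window=True,
)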
def custom_crop(raw, tmin=0.0, tmax=None, include_tmax=True):
    # Crop the recording to [tmin, tmax]. The interval can be incomplete if the
    # recording is shorter than tmax; by default MNE fails if tmax exceeds the
    # recording duration, so we clip tmax to the actual duration first.
    tmax = min((raw.n_times - 1) / raw.info['sfreq'], tmax)
    raw.crop(tmin=tmin, tmax=tmax, include_tmax=include_tmax)


tmin = 1 * 60
tmax = 6 * 60
sfreq = 100

preprocessors = [
    Preprocessor(custom_crop, tmin=tmin, tmax=tmax, include_tmax=False,
                 apply_on_array=False),
    Preprocessor('set_eeg_reference', ref_channels='average', ch_type='eeg'),
    Preprocessor(custom_rename_channels, mapping=ch_mapping,
                 apply_on_array=False),
    Preprocessor('pick_channels', ch_names=short_ch_names, ordered=True),
    Preprocessor(lambda x: x * 1e6),
    Preprocessor('resample', sfreq=sfreq),
]


###############################################################################
# The preprocessing loop works as follows. For every recording, we apply the
# preprocessors as defined above. Then, we update the description of the recording,
# since we have altered the duration, the reference, and the sampling frequency.
# Afterwards, we split the continuous signals into compute windows. We store
# each recording to a unique subdirectory that is named corresponding to the