Example #1
0
def test_windows_fixed_length_cropped(lazy_loadable_dataset):
    """Fixed-length windowing must give the same data on cropped recordings.

    Cropping a Raw object shifts its `first_samp`, so the windower has to
    account for it. Two datasets are compared:

    * one whose raw data is untouched (only its annotations are cropped) and
      whose window range is chosen through explicit sample offsets, and
    * one whose raw data is cropped with a 'crop' Preprocessor and then
      windowed from its beginning to its end.

    The first window of both must match exactly.
    """
    tmin, tmax = 100, 120

    # Untouched raw data; the window range comes from sample offsets below.
    offset_ds = copy.deepcopy(lazy_loadable_dataset)
    offset_ds.datasets[0].raw.annotations.crop(tmin, tmax)

    # Raw data actually cropped to [tmin, tmax].
    cropped_ds = copy.deepcopy(lazy_loadable_dataset)
    preprocess(cropped_ds, [Preprocessor('crop', tmin=tmin, tmax=tmax)])

    sfreq = offset_ds.datasets[0].raw.info['sfreq']
    first_sample = int(tmin * sfreq)
    last_sample = int(tmax * sfreq)

    window_kwargs = dict(
        window_size_samples=100,
        window_stride_samples=100,
        drop_last_window=True,
    )
    offset_windows = create_fixed_length_windows(
        concat_ds=offset_ds,
        start_offset_samples=first_sample,
        stop_offset_samples=last_sample,
        **window_kwargs)
    cropped_windows = create_fixed_length_windows(
        concat_ds=cropped_ds,
        start_offset_samples=0,
        stop_offset_samples=None,
        **window_kwargs)

    # Element [0][0] is the data of the first window of each windowed dataset.
    assert (offset_windows[0][0] == cropped_windows[0][0]).all()
Example #2
0
def test_epochs_kwargs(lazy_loadable_dataset):
    """Epochs keyword arguments must reach the Epochs built by both windowers.

    `picks`, `flat` and `reject` are checked on the resulting epochs for
    `create_windows_from_events` and then `create_fixed_length_windows`.
    `on_missing` is only passed through; its effect is not asserted here.
    """
    picks = ['ch0']
    on_missing = 'warning'
    flat = {'eeg': 3e-6}
    reject = {'eeg': 43e-6}

    shared_kwargs = dict(
        concat_ds=lazy_loadable_dataset,
        window_size_samples=100,
        window_stride_samples=100,
        drop_last_window=False,
        picks=picks,
        on_missing=on_missing,
        flat=flat,
        reject=reject,
    )
    # Each windower takes its own flavour of start/stop offsets.
    windowers_and_offsets = (
        (create_windows_from_events,
         dict(trial_start_offset_samples=0, trial_stop_offset_samples=0)),
        (create_fixed_length_windows,
         dict(start_offset_samples=0, stop_offset_samples=None)),
    )

    for windower, offset_kwargs in windowers_and_offsets:
        windowed = windower(**shared_kwargs, **offset_kwargs)
        first_epochs = windowed.datasets[0].windows
        assert first_epochs.ch_names == picks
        assert first_epochs.reject == reject
        assert first_epochs.flat == flat
Example #3
0
def test_fixed_length_windows_preload_false(lazy_loadable_dataset):
    """With ``preload=False`` no windowed dataset may hold preloaded epochs."""
    windows = create_fixed_length_windows(
        concat_ds=lazy_loadable_dataset,
        start_offset_samples=0,
        stop_offset_samples=100,
        window_size_samples=100,
        window_stride_samples=100,
        drop_last_window=False,
        preload=False,
    )

    assert not any(ds.windows.preload for ds in windows.datasets)
Example #4
0
def test_drop_bad_windows(concat_ds_targets, drop_bad_windows, preload):
    """The `drop_bad_windows` flag must be reflected on the created Epochs.

    Both windowers are run with the same `preload` / `drop_bad_windows`
    settings. The first dataset's `windows._bad_dropped` must equal the
    requested `drop_bad_windows` value.
    """
    concat_ds, _ = concat_ds_targets
    common_kwargs = dict(
        concat_ds=concat_ds,
        drop_last_window=False,
        preload=preload,
        drop_bad_windows=drop_bad_windows,
    )

    windows_from_events = create_windows_from_events(
        trial_start_offset_samples=0,
        trial_stop_offset_samples=0,
        window_size_samples=100,
        window_stride_samples=100,
        **common_kwargs)
    windows_fixed_length = create_fixed_length_windows(
        start_offset_samples=0,
        stop_offset_samples=1000,
        window_size_samples=1000,
        window_stride_samples=1000,
        **common_kwargs)

    for windowed in (windows_from_events, windows_fixed_length):
        assert windowed.datasets[0].windows._bad_dropped == drop_bad_windows
Example #5
0
def test_fixed_length_windower_n_jobs(lazy_loadable_dataset):
    """Windowing with ``n_jobs=1`` and ``n_jobs=2`` must give identical output.

    The first recording is repeated 8 times so there is work to distribute.
    Descriptions, window data, info, targets and crop indices are compared
    dataset by dataset.
    """
    longer_dataset = BaseConcatDataset([lazy_loadable_dataset.datasets[0]] * 8)

    def _windowize(n_jobs):
        # Same windowing configuration; only the parallelism differs.
        return create_fixed_length_windows(
            concat_ds=longer_dataset, start_offset_samples=0,
            stop_offset_samples=None, window_size_samples=100,
            window_stride_samples=100, drop_last_window=True, preload=True,
            n_jobs=n_jobs)

    serial = _windowize(1)
    parallel = _windowize(2)

    assert serial.description.equals(parallel.description)
    for ds_serial, ds_parallel in zip(serial.datasets, parallel.datasets):
        # Comparing the Epochs objects directly was observed to pass locally
        # but fail in CI, so their components are compared instead.
        assert np.allclose(ds_serial.windows.get_data(),
                           ds_parallel.windows.get_data())
        serial_info = pd.Series(ds_serial.windows.info).to_json()
        parallel_info = pd.Series(ds_parallel.windows.info).to_json()
        assert serial_info == parallel_info
        assert ds_serial.description.equals(ds_parallel.description)
        assert np.array_equal(ds_serial.y, ds_parallel.y)
        assert np.array_equal(ds_serial.crop_inds, ds_parallel.crop_inds)
Example #6
0
def test_fixed_length_windower(start_offset_samples, window_size_samples,
                               window_stride_samples, drop_last_window,
                               mapping):
    """Compare fixed-length windows against independently computed slices.

    A 2-channel, 1000-sample random recording is windowed. The expected
    window start indices are rebuilt with numpy, and every window is checked
    against the matching slice of the raw data. When ``window_size_samples``
    is None, the whole recording length is used. When `mapping` is given, the
    base dataset's target must still be 48 and every window's metadata
    target must be 0.
    """
    rng = np.random.RandomState(42)
    info = mne.create_info(ch_names=['0', '1'], sfreq=50, ch_types='eeg')
    signal = rng.randn(2, 1000)
    raw = mne.io.RawArray(data=signal, info=info)
    description = pd.Series({'pathological': True, 'gender': 'M', 'age': 48})
    base_ds = BaseDataset(raw, description, target_name="age")
    concat_ds = BaseConcatDataset([base_ds])

    if window_size_samples is None:
        window_size_samples = base_ds.raw.n_times
    stop_offset_samples = signal.shape[1] - start_offset_samples

    epochs_ds = create_fixed_length_windows(
        concat_ds,
        start_offset_samples=start_offset_samples,
        stop_offset_samples=stop_offset_samples,
        window_size_samples=window_size_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=drop_last_window,
        mapping=mapping)

    if mapping is not None:
        assert base_ds.target == 48
        assert all(epochs_ds.datasets[0].windows.metadata['target'] == 0)

    windows_data = epochs_ds.datasets[0].windows.get_data()

    # Last position at which a full window still fits before the stop offset.
    last_start = stop_offset_samples - window_size_samples
    expected_starts = np.arange(start_offset_samples, last_start + 1,
                                window_stride_samples)
    # When the last partial stride is kept, one extra window is expected,
    # aligned to end exactly at the stop offset.
    if not drop_last_window and expected_starts[-1] != last_start:
        expected_starts = np.append(expected_starts, last_start)

    assert len(expected_starts) == windows_data.shape[0], (
        'Number of epochs different than expected')
    assert window_size_samples == windows_data.shape[2], (
        'Window size different than expected')

    raw_data = base_ds.raw.get_data()
    for j, start in enumerate(expected_starts):
        np.testing.assert_allclose(
            raw_data[:, start:start + window_size_samples],
            windows_data[j, :],
            err_msg=f'Epochs different for epoch {j}')
Example #7
0
        fake_descrition = pd.Series(
            data=[target, train_or_eval],
            index=["target", "session"])
        base_ds = BaseDataset(raw, fake_descrition, target_name="target")
        datasets.append(base_ds)
    dataset = BaseConcatDataset(datasets)
    return dataset

# Build a small synthetic regression dataset (5 recordings, 21 channels,
# 100 Hz, 60 s each).
dataset = fake_regression_dataset(
    n_fake_recs=5, n_fake_chs=21, fake_sfreq=100, fake_duration_s=60)

# Slide windows over each whole recording. `stop_offset_samples=None` requests
# "until the end of the recording". NOTE: `0` has been interpreted
# differently across braindecode releases (and newer releases warn about it),
# so the explicit `None` is used, as in the library's other windowing calls.
windows_dataset = create_fixed_length_windows(
    dataset,
    start_offset_samples=0,
    stop_offset_samples=None,
    window_size_samples=input_window_samples,
    window_stride_samples=n_preds_per_input,
    drop_last_window=False,
    drop_bad_windows=True,
)

# Split on the "session" description column filled by the fake dataset.
# Presumably its values are "train"/"eval" -- confirm against
# fake_regression_dataset.
splits = windows_dataset.split("session")
train_set = splits["train"]
valid_set = splits["eval"]

regressor = EEGRegressor(
    model,
    cropped=True,
    criterion=CroppedLoss,
    criterion__loss_function=torch.nn.functional.mse_loss,
    optimizer=torch.optim.AdamW,
Example #8
0
        base_ds = BaseDataset(raw, fake_descrition, target_name="target")
        datasets.append(base_ds)
    dataset = BaseConcatDataset(datasets)
    return dataset


# Synthetic regression data: 5 recordings of 21 channels, 100 Hz, 60 s each.
dataset = fake_regression_dataset(
    n_fake_recs=5, n_fake_chs=21, fake_sfreq=100, fake_duration_s=60)

# Cut every full recording (stop offset None) into supercrops of
# `input_time_length` samples, advancing by `n_preds_per_input` samples.
# NOTE(review): this uses the legacy supercrop_* / drop_samples argument
# names of an older braindecode API -- keep them unless the library is
# upgraded.
windows_dataset = create_fixed_length_windows(
    dataset,
    start_offset_samples=0,
    stop_offset_samples=None,
    supercrop_size_samples=input_time_length,
    supercrop_stride_samples=n_preds_per_input,
    drop_samples=False,
    drop_bad_windows=True,
)

# Partition by the "session" description column into training/validation.
splits = windows_dataset.split("session")
train_set, valid_set = splits["train"], splits["eval"]

regressor = EEGRegressor(
    model,
    cropped=True,
    criterion=CroppedLoss,
    criterion__loss_function=torch.nn.functional.mse_loss,
    optimizer=torch.optim.AdamW,