Example 1
def test_windows_from_events_different_events(tmpdir_factory):
    description_expected = 5 * ['T0', 'T1'] + 4 * ['T2', 'T3'] + 2 * ['T1']
    raw = _get_raw(tmpdir_factory, description_expected[:10])
    base_ds = BaseDataset(raw, description=pd.Series({'file_id': 1}))

    raw_1 = _get_raw(tmpdir_factory, description_expected[10:])
    base_ds_1 = BaseDataset(raw_1, description=pd.Series({'file_id': 2}))
    concat_ds = BaseConcatDataset([base_ds, base_ds_1])

    windows = create_windows_from_events(
        concat_ds=concat_ds, trial_start_offset_samples=0,
        trial_stop_offset_samples=0, window_size_samples=100,
        window_stride_samples=100, drop_last_window=False)
    description = []
    events = []
    for ds in windows.datasets:
        description += ds.windows.metadata['target'].to_list()
        events += ds.windows.events[:, 0].tolist()

    assert len(description) == 20
    np.testing.assert_array_equal(description,
                                  5 * [0, 1] + 4 * [2, 3] + 2 * [1])
    np.testing.assert_array_equal(
        np.concatenate(
            [raw.time_as_index(raw.annotations.onset, use_rounding=True),
             raw_1.time_as_index(raw_1.annotations.onset, use_rounding=True)]),
        events)
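
The `_get_raw` helper these tests depend on is not part of the snippet. A minimal sketch of what it might look like, assuming it builds a seeded two-channel raw, attaches one 2 s annotation per entry of `description` (100 samples at 50 Hz, matching the one-window-per-trial counts asserted above), and round-trips the result through a .fif file so it can be loaded lazily; every name and constant here is a hypothetical reconstruction:

import mne
import numpy as np


def _get_raw(tmpdir_factory, description=None):
    sfreq = 50
    info = mne.create_info(ch_names=['0', '1'], sfreq=sfreq, ch_types='eeg')
    raw = mne.io.RawArray(np.random.RandomState(42).randn(2, 5000), info=info)
    if description is not None:
        # one trial every 4 s, each lasting 2 s (= 100 samples)
        onsets = 4.0 * np.arange(len(description))
        raw.set_annotations(mne.Annotations(
            onset=onsets, duration=[2.0] * len(description),
            description=description))
    # save and reload so the data can stay on disk (preload=False)
    fname = str(tmpdir_factory.mktemp('data').join('raw.fif'))
    raw.save(fname, overwrite=True)
    return mne.io.read_raw_fif(fname, preload=False)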
Example 2
def load_concat_dataset(path, preload, ids_to_load=None, target_name=None):
    """Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets from
    files.

    Parameters
    ----------
    path: str
        Path to the directory containing the -raw.fif / -epo.fif and .json
        files.
    preload: bool
        Whether to preload the data.
    ids_to_load: list of int | None
        Ids of specific files to load.
    target_name: str | None
        Load a specific description column as target. If not given, use the
        saved target name.

    Returns
    -------
    concat_dataset: BaseConcatDataset of BaseDatasets or WindowsDatasets
    """
    # assume we have a single concat dataset to load
    concat_of_raws = os.path.isfile(os.path.join(path, '0-raw.fif'))
    assert concat_of_raws or target_name is None, (
        'Setting a new target is only supported for raws.')
    concat_of_epochs = os.path.isfile(os.path.join(path, '0-epo.fif'))
    paths = [path]
    # assume we have multiple concat datasets to load
    if not (concat_of_raws or concat_of_epochs):
        concat_of_raws = os.path.isfile(os.path.join(path, '0', '0-raw.fif'))
        concat_of_epochs = os.path.isfile(os.path.join(path, '0', '0-epo.fif'))
        path = os.path.join(path, '*', '')
        paths = glob(path)
        paths = sorted(paths, key=lambda p: int(p.split(os.sep)[-2]))
        if ids_to_load is not None:
            paths = [paths[i] for i in ids_to_load]
        ids_to_load = None
    # if we have neither a single nor multiple datasets, something went wrong
    assert concat_of_raws or concat_of_epochs, (
        f'Expect either raw or epo to exist in {path} or in '
        f'{os.path.join(path, "0")}')

    datasets = []
    for path in paths:
        if concat_of_raws and target_name is None:
            target_file_name = os.path.join(path, 'target_name.json')
            with open(target_file_name, 'r') as f:
                target_name = json.load(f)['target_name']

        all_signals, description = _load_signals_and_description(
            path=path, preload=preload, raws=concat_of_raws,
            ids_to_load=ids_to_load
        )
        for i_signal, signal in enumerate(all_signals):
            if concat_of_raws:
                datasets.append(
                    BaseDataset(signal, description.iloc[i_signal],
                                target_name=target_name))
            else:
                datasets.append(
                    WindowsDataset(signal, description.iloc[i_signal])
                )
    return BaseConcatDataset(datasets)
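
A typical round trip, assuming the directory was produced by `BaseConcatDataset.save`, which writes the numbered -raw.fif / -epo.fif and .json files this loader looks for (the path here is illustrative):

# save a dataset, then reload it lazily
concat_ds.save('./saved_datasets')
loaded = load_concat_dataset(
    path='./saved_datasets',
    preload=False,        # keep the signals on disk until accessed
    ids_to_load=None,     # or e.g. [0, 2] to load a subset
    target_name=None)     # overriding the target only works for raws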
Example 3
def lazy_loadable_dataset(tmpdir_factory):
    """Make a dataset of fif files that can be loaded lazily.
    """
    raw = _get_raw(tmpdir_factory)
    base_ds = BaseDataset(raw, description=pd.Series({'file_id': 1}))
    concat_ds = BaseConcatDataset([base_ds])

    return concat_ds
Example 4
def concat_ds_targets():
    raws, description = fetch_data_with_moabb(
        dataset_name="BNCI2014001", subject_ids=4)
    events, _ = mne.events_from_annotations(raws[0])
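    # MOABB event ids start at 1; shift them to 0-based class labels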
    targets = events[:, -1] - 1
    ds = BaseDataset(raws[0], description.iloc[0])
    concat_ds = BaseConcatDataset([ds])
    return concat_ds, targets
Example 5
def test_overlapping_trial_offsets(raw_info_targets):
    raw, info, targets = raw_info_targets
    base_ds = BaseDataset(raw, info)
    windower = EventWindower(trial_start_offset_samples=-2000,
                             trial_stop_offset_samples=1000,
                             supercrop_size_samples=1000,
                             supercrop_stride_samples=1000)
    with pytest.raises(AssertionError, match='trials overlap not implemented'):
        windower(base_ds)
Example 6
def test_stride_has_no_effect(raw_info_targets):
    raw, info, targets = raw_info_targets
    base_ds = BaseDataset(raw, info)
    windower = EventWindower(trial_start_offset_samples=0,
                             trial_stop_offset_samples=1000,
                             supercrop_size_samples=1000,
                             supercrop_stride_samples=1000)
    windows = windower(base_ds)
    description = windows.events[:, -1]
    assert len(description) == len(targets)
    np.testing.assert_array_equal(description, targets)
Example 7
def test_dropping_last_incomplete_supercrop(raw_info_targets):
    raw, info, targets = raw_info_targets
    base_ds = BaseDataset(raw, info)
    windower = EventWindower(trial_start_offset_samples=-250,
                             trial_stop_offset_samples=250,
                             supercrop_size_samples=250,
                             supercrop_stride_samples=300,
                             drop_samples=True)
    windows = windower(base_ds)
    description = windows.events[:, -1]
    assert len(description) == len(targets)
    np.testing.assert_array_equal(description, targets)
Example 8
def test_shifting_last_supercrop_back_in(raw_info_targets):
    raw, info, targets = raw_info_targets
    base_ds = BaseDataset(raw, info)
    windower = EventWindower(trial_start_offset_samples=-250,
                             trial_stop_offset_samples=250,
                             supercrop_size_samples=250,
                             supercrop_stride_samples=300)
    windows = windower(base_ds)
    description = windows.events[:, -1]
    assert len(description) == len(targets) * 2
    np.testing.assert_array_equal(description[0::2], targets)
    np.testing.assert_array_equal(description[1::2], targets)
Example 9
def test_maximally_overlapping_supercrops(raw_info_targets):
    raw, info, targets = raw_info_targets
    base_ds = BaseDataset(raw, info)
    windower = EventWindower(trial_start_offset_samples=-2,
                             trial_stop_offset_samples=1000,
                             supercrop_size_samples=1000,
                             supercrop_stride_samples=1)
    windows = windower(base_ds)
    description = windows.events[:, -1]
    assert len(description) == len(targets) * 3
    np.testing.assert_array_equal(description[0::3], targets)
    np.testing.assert_array_equal(description[1::3], targets)
    np.testing.assert_array_equal(description[2::3], targets)
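
Why the assertion expects three supercrops per trial: assuming the stop offset is taken relative to trial onset, as the other windower tests suggest, each trial window spans from 2 samples before onset to 1000 samples after it, i.e. 1002 samples, so a 1000-sample supercrop with stride 1 fits at exactly three start positions. Restating that arithmetic:

trial_start_offset, trial_stop_offset = -2, 1000
size, stride = 1000, 1
extent = trial_stop_offset - trial_start_offset  # 1002 samples per trial
n_supercrops = (extent - size) // stride + 1     # start offsets -2, -1, 0
assert n_supercrops == 3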
Example 10
def test_windows_from_events_mapping_filter(tmpdir_factory):
    raw = _get_raw(tmpdir_factory, 5 * ['T0', 'T1'])
    base_ds = BaseDataset(raw, description=pd.Series({'file_id': 1}))
    concat_ds = BaseConcatDataset([base_ds])

    windows = create_windows_from_events(
        concat_ds=concat_ds, trial_start_offset_samples=0,
        trial_stop_offset_samples=0, window_size_samples=100,
        window_stride_samples=100, drop_last_window=False, mapping={'T1': 0})
    description = windows.datasets[0].windows.metadata['target'].to_list()

    assert len(description) == 5
    np.testing.assert_array_equal(description, np.zeros(5))
    # dataset should contain only 'T1' events
    np.testing.assert_array_equal(
        (raw.time_as_index(raw.annotations.onset[1::2], use_rounding=True)),
        windows.datasets[0].windows.events[:, 0])
Example 11
def test_fixed_length_windower(start_offset_samples, window_size_samples,
                               window_stride_samples, drop_last_window,
                               mapping):
    rng = np.random.RandomState(42)
    info = mne.create_info(ch_names=['0', '1'], sfreq=50, ch_types='eeg')
    data = rng.randn(2, 1000)
    raw = mne.io.RawArray(data=data, info=info)
    desc = pd.Series({'pathological': True, 'gender': 'M', 'age': 48})
    base_ds = BaseDataset(raw, desc, target_name="age")
    concat_ds = BaseConcatDataset([base_ds])

    if window_size_samples is None:
        window_size_samples = base_ds.raw.n_times
    stop_offset_samples = data.shape[1] - start_offset_samples
    epochs_ds = create_fixed_length_windows(
        concat_ds,
        start_offset_samples=start_offset_samples,
        stop_offset_samples=stop_offset_samples,
        window_size_samples=window_size_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=drop_last_window,
        mapping=mapping)

    if mapping is not None:
        assert base_ds.target == 48
        assert all(epochs_ds.datasets[0].windows.metadata['target'] == 0)

    epochs_data = epochs_ds.datasets[0].windows.get_data()

    idxs = np.arange(start_offset_samples,
                     stop_offset_samples - window_size_samples + 1,
                     window_stride_samples)
    if (not drop_last_window
            and idxs[-1] != stop_offset_samples - window_size_samples):
        idxs = np.append(idxs, stop_offset_samples - window_size_samples)

    assert len(idxs) == epochs_data.shape[0], (
        'Number of epochs different than expected')
    assert window_size_samples == epochs_data.shape[2], (
        'Window size different than expected')
    for j, idx in enumerate(idxs):
        np.testing.assert_allclose(
            base_ds.raw.get_data()[:, idx:idx + window_size_samples],
            epochs_data[j, :],
            err_msg=f'Epochs different for epoch {j}')
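
This test clearly receives its arguments from a `pytest.mark.parametrize` decorator that the snippet does not show. A hypothetical parametrization consistent with the assertions above (in particular, `mapping` must send the target value 48 to class 0):

import pytest

# Hypothetical decorator; it would sit directly above the test's def line.
@pytest.mark.parametrize(
    'start_offset_samples,window_size_samples,window_stride_samples,'
    'drop_last_window,mapping',
    [
        (0, 100, 90, True, None),
        (0, 100, 50, False, None),
        (0, None, 50, True, {48: 0}),  # window size falls back to n_times
    ])
def test_fixed_length_windower(start_offset_samples, window_size_samples,
                               window_stride_samples, drop_last_window,
                               mapping):
    ...  # body as in the example above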
Example 12
def test_fixed_length_windower():
    rng = np.random.RandomState(42)
    info = mne.create_info(ch_names=['0', '1'], sfreq=50, ch_types='eeg')
    data = rng.randn(2, 1000)
    raw = mne.io.RawArray(data=data, info=info)
    desc = pd.Series({'pathological': True, 'gender': 'M', 'age': 48})
    base_ds = BaseDataset(raw, desc, target_name="age")

    # test case:
    # (window_size_samples, stride_size_samples, drop_last_samples,
    #  trial_start_offset_samples, n_windows)
    test_cases = [(100, 90, True, 0., 11), (100, 50, True, 0., 19),
                  (None, 50, True, 0., 1)]

    for i, test_case in enumerate(test_cases):
        (window_size, stride_size, drop_last_samples,
         trial_start_offset_samples, n_windows) = test_case
        if window_size is None:
            window_size = base_ds.raw.n_times
        windower = FixedLengthWindower(
            supercrop_size_samples=window_size,
            supercrop_stride_samples=stride_size,
            drop_samples=drop_last_samples,
            trial_start_offset_samples=trial_start_offset_samples,
            trial_stop_offset_samples=-trial_start_offset_samples +
            window_size)

        epochs = windower(base_ds)
        epochs_data = epochs.get_data()
        idxs = np.arange(0,
                         base_ds.raw.get_data().shape[1] - window_size + 1,
                         stride_size)

        assert len(idxs) == n_windows == epochs_data.shape[0], \
            f"Number of epochs different than expected for test case {i}"
        assert window_size == epochs_data.shape[2], \
            f"Window size different than expected for test case {i}"
        for j, idx in enumerate(idxs):
            np.testing.assert_allclose(
                base_ds.raw.get_data()[:, idx:idx + window_size],
                epochs_data[j, :],
                err_msg=f"Epochs different for test case {i} for epoch {j}")
Example 13
def dataset_target_time_series():
    rng = np.random.RandomState(42)
    signal_sfreq = 50
    info = mne.create_info(ch_names=['0', '1', 'target_0', 'target_1'],
                           sfreq=signal_sfreq,
                           ch_types=['eeg', 'eeg', 'misc', 'misc'])
    signal = rng.randn(2, 1000)
    targets = np.full((2, 1000), np.nan)
    targets_sfreq = 10
    targets_stride = int(signal_sfreq / targets_sfreq)
    targets[:, ::targets_stride] = rng.randn(
        2, targets.shape[1] // targets_stride)

    raw = mne.io.RawArray(np.concatenate([signal, targets]), info=info)
    desc = pd.Series({'pathological': True, 'gender': 'M', 'age': 48})
    base_dataset = BaseDataset(raw, desc, target_name=None)
    concat_ds = BaseConcatDataset([base_dataset])
    windows_dataset = create_windows_from_target_channels(
        concat_ds,
        window_size_samples=100,
    )
    return concat_ds, windows_dataset, targets, signal
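
A sketch of how the resulting windows dataset could be consumed, assuming the braindecode convention that indexing a windows dataset yields the window signal, its target, and the crop indices:

import numpy as np

concat_ds, windows_dataset, targets, signal = dataset_target_time_series()
X, y, crop_inds = windows_dataset[0]  # first 100-sample window
# target_name=None plus 'misc' channels means y is read from the target
# channels rather than from the description row
print(X.shape, np.asarray(y).shape, crop_inds)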
Example 14
def create_from_X_y(X,
                    y,
                    drop_last_window,
                    sfreq=None,
                    ch_names=None,
                    window_size_samples=None,
                    window_stride_samples=None):
    """Create a BaseConcatDataset of WindowsDatasets from X and y to be used for
    decoding with skorch and braindecode, where X is a list of pre-cut trials
    and y are corresponding targets.

    Parameters
    ----------
    X: array-like
        list of pre-cut trials as n_trials x n_channels x n_times
    y: array-like
        targets corresponding to the trials
    drop_last_window: bool
        whether or not to have a last overlapping window when the windows
        do not equally divide the continuous signal
    sfreq: float | None
        common sampling frequency of all trials
    ch_names: array-like | None
        channel names of the trials
    window_size_samples: int | None
        window size
    window_stride_samples: int | None
        stride between windows

    Returns
    -------
    windows_datasets: BaseConcatDataset
        X and y transformed to a dataset format that is compatible with skorch
        and braindecode
    """
    n_samples_per_x = []
    base_datasets = []
    if sfreq is None:
        sfreq = 100
        log.info("No sampling frequency given, set to 100 Hz.")
    if ch_names is None:
        ch_names = [str(i) for i in range(X.shape[1])]
        log.info(f"No channel names given, set to 0-{X.shape[1]}).")

    for x, target in zip(X, y):
        n_samples_per_x.append(x.shape[1])
        info = mne.create_info(ch_names=ch_names, sfreq=sfreq)
        raw = mne.io.RawArray(x, info)
        base_dataset = BaseDataset(raw,
                                   pd.Series({"target": target}),
                                   target_name="target")
        base_datasets.append(base_dataset)
    base_datasets = BaseConcatDataset(base_datasets)

    if window_size_samples is None and window_stride_samples is None:
        if not len(np.unique(n_samples_per_x)) == 1:
            raise ValueError(f"if 'window_size_samples' and "
                             f"'window_stride_samples' are None, "
                             f"all trials have to have the same length")
        window_size_samples = n_samples_per_x[0]
        window_stride_samples = n_samples_per_x[0]
    windows_datasets = create_fixed_length_windows(
        base_datasets,
        start_offset_samples=0,
        stop_offset_samples=0,
        window_size_samples=window_size_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=drop_last_window)
    return windows_datasets