def test_windows_from_events_different_events(tmpdir_factory):
    description_expected = 5 * ['T0', 'T1'] + 4 * ['T2', 'T3'] + 2 * ['T1']
    raw = _get_raw(tmpdir_factory, description_expected[:10])
    base_ds = BaseDataset(raw, description=pd.Series({'file_id': 1}))
    raw_1 = _get_raw(tmpdir_factory, description_expected[10:])
    base_ds_1 = BaseDataset(raw_1, description=pd.Series({'file_id': 2}))
    concat_ds = BaseConcatDataset([base_ds, base_ds_1])
    windows = create_windows_from_events(
        concat_ds=concat_ds, trial_start_offset_samples=0,
        trial_stop_offset_samples=0, window_size_samples=100,
        window_stride_samples=100, drop_last_window=False)
    description = []
    events = []
    for ds in windows.datasets:
        description += ds.windows.metadata['target'].to_list()
        events += ds.windows.events[:, 0].tolist()
    assert len(description) == 20
    np.testing.assert_array_equal(
        description, 5 * [0, 1] + 4 * [2, 3] + 2 * [1])
    # window onsets have to match the annotation onsets of each original raw
    np.testing.assert_array_equal(
        np.concatenate(
            [raw.time_as_index(raw.annotations.onset, use_rounding=True),
             raw_1.time_as_index(raw_1.annotations.onset, use_rounding=True)]),
        events)


def load_concat_dataset(path, preload, ids_to_load=None, target_name=None):
    """Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets
    from files.

    Parameters
    ----------
    path: str
        Path to the directory of the .fif / -epo.fif and .json files.
    preload: bool
        Whether to preload the data.
    ids_to_load: None | list(int)
        Ids of specific files to load.
    target_name: None | str
        Load a specific description column as target. If not given, take the
        saved target name.

    Returns
    -------
    concat_dataset: BaseConcatDataset of BaseDatasets or WindowsDatasets
    """
    # assume we have a single concat dataset to load
    concat_of_raws = os.path.isfile(os.path.join(path, '0-raw.fif'))
    assert not (not concat_of_raws and target_name is not None), (
        'Setting a new target is only supported for raws.')
    concat_of_epochs = os.path.isfile(os.path.join(path, '0-epo.fif'))
    paths = [path]
    # assume we have multiple concat datasets to load
    if not (concat_of_raws or concat_of_epochs):
        concat_of_raws = os.path.isfile(os.path.join(path, '0', '0-raw.fif'))
        concat_of_epochs = os.path.isfile(os.path.join(path, '0', '0-epo.fif'))
        path = os.path.join(path, '*', '')
        paths = glob(path)
        # sort the subdirectories numerically, so '2' comes before '10'
        paths = sorted(paths, key=lambda p: int(p.split(os.sep)[-2]))
        if ids_to_load is not None:
            paths = [paths[i] for i in ids_to_load]
        ids_to_load = None
    # if we have neither a single nor multiple datasets, something went wrong
    assert concat_of_raws or concat_of_epochs, (
        f'Expected either raw or epo files to exist in {path} or in '
        f'{os.path.join(path, "0")}')

    datasets = []
    for path in paths:
        if concat_of_raws and target_name is None:
            target_file_name = os.path.join(path, 'target_name.json')
            with open(target_file_name, 'r') as f:
                target_name = json.load(f)['target_name']
        all_signals, description = _load_signals_and_description(
            path=path, preload=preload, raws=concat_of_raws,
            ids_to_load=ids_to_load)
        for i_signal, signal in enumerate(all_signals):
            if concat_of_raws:
                datasets.append(
                    BaseDataset(signal, description.iloc[i_signal],
                                target_name=target_name))
            else:
                datasets.append(
                    WindowsDataset(signal, description.iloc[i_signal]))
    return BaseConcatDataset(datasets)


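# A minimal usage sketch for ``load_concat_dataset``. The helper name, path
# and ids below are hypothetical, for illustration only; it assumes a dataset
# was previously stored under ``path`` in the layout checked above, i.e.
# ``0-raw.fif`` / ``0-epo.fif`` files plus ``target_name.json``.
def _example_load_concat_dataset():
    # lazily load only the first and third stored datasets (no preload)
    loaded = load_concat_dataset(
        path='./saved_datasets/', preload=False, ids_to_load=[0, 2])
    # the description concatenates the per-dataset description rows
    return loaded.description

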
def lazy_loadable_dataset(tmpdir_factory):
    """Make a dataset of fif files that can be loaded lazily."""
    raw = _get_raw(tmpdir_factory)
    base_ds = BaseDataset(raw, description=pd.Series({'file_id': 1}))
    concat_ds = BaseConcatDataset([base_ds])
    return concat_ds


def concat_ds_targets():
    raws, description = fetch_data_with_moabb(
        dataset_name="BNCI2014001", subject_ids=4)
    events, _ = mne.events_from_annotations(raws[0])
    # event ids are 1-based, targets are 0-based
    targets = events[:, -1] - 1
    ds = BaseDataset(raws[0], description.iloc[0])
    concat_ds = BaseConcatDataset([ds])
    return concat_ds, targets


def test_overlapping_trial_offsets(raw_info_targets):
    raw, info, targets = raw_info_targets
    base_ds = BaseDataset(raw, info)
    windower = EventWindower(
        trial_start_offset_samples=-2000,
        trial_stop_offset_samples=1000,
        supercrop_size_samples=1000,
        supercrop_stride_samples=1000)
    with pytest.raises(AssertionError, match='trials overlap not implemented'):
        windower(base_ds)


def test_stride_has_no_effect(raw_info_targets):
    raw, info, targets = raw_info_targets
    base_ds = BaseDataset(raw, info)
    windower = EventWindower(
        trial_start_offset_samples=0,
        trial_stop_offset_samples=1000,
        supercrop_size_samples=1000,
        supercrop_stride_samples=1000)
    windows = windower(base_ds)
    description = windows.events[:, -1]
    assert len(description) == len(targets)
    np.testing.assert_array_equal(description, targets)


def test_dropping_last_incomplete_supercrop(raw_info_targets):
    raw, info, targets = raw_info_targets
    base_ds = BaseDataset(raw, info)
    windower = EventWindower(
        trial_start_offset_samples=-250,
        trial_stop_offset_samples=250,
        supercrop_size_samples=250,
        supercrop_stride_samples=300,
        drop_samples=True)
    windows = windower(base_ds)
    description = windows.events[:, -1]
    assert len(description) == len(targets)
    np.testing.assert_array_equal(description, targets)


def test_shifting_last_supercrop_back_in(raw_info_targets):
    raw, info, targets = raw_info_targets
    base_ds = BaseDataset(raw, info)
    windower = EventWindower(
        trial_start_offset_samples=-250,
        trial_stop_offset_samples=250,
        supercrop_size_samples=250,
        supercrop_stride_samples=300)
    windows = windower(base_ds)
    description = windows.events[:, -1]
    # each trial yields two supercrops: one at the regular stride and a last
    # one shifted back to end exactly at the trial stop, so targets interleave
    assert len(description) == len(targets) * 2
    np.testing.assert_array_equal(description[0::2], targets)
    np.testing.assert_array_equal(description[1::2], targets)


def test_maximally_overlapping_supercrops(raw_info_targets):
    raw, info, targets = raw_info_targets
    base_ds = BaseDataset(raw, info)
    windower = EventWindower(
        trial_start_offset_samples=-2,
        trial_stop_offset_samples=1000,
        supercrop_size_samples=1000,
        supercrop_stride_samples=1)
    windows = windower(base_ds)
    description = windows.events[:, -1]
    assert len(description) == len(targets) * 3
    np.testing.assert_array_equal(description[0::3], targets)
    np.testing.assert_array_equal(description[1::3], targets)
    np.testing.assert_array_equal(description[2::3], targets)


def test_windows_from_events_mapping_filter(tmpdir_factory):
    raw = _get_raw(tmpdir_factory, 5 * ['T0', 'T1'])
    base_ds = BaseDataset(raw, description=pd.Series({'file_id': 1}))
    concat_ds = BaseConcatDataset([base_ds])
    windows = create_windows_from_events(
        concat_ds=concat_ds, trial_start_offset_samples=0,
        trial_stop_offset_samples=0, window_size_samples=100,
        window_stride_samples=100, drop_last_window=False,
        mapping={'T1': 0})
    description = windows.datasets[0].windows.metadata['target'].to_list()
    assert len(description) == 5
    np.testing.assert_array_equal(description, np.zeros(5))
    # dataset should contain only 'T1' events
    np.testing.assert_array_equal(
        raw.time_as_index(raw.annotations.onset[1::2], use_rounding=True),
        windows.datasets[0].windows.events[:, 0])


def test_fixed_length_windower(start_offset_samples, window_size_samples,
                               window_stride_samples, drop_last_window,
                               mapping):
    rng = np.random.RandomState(42)
    info = mne.create_info(ch_names=['0', '1'], sfreq=50, ch_types='eeg')
    data = rng.randn(2, 1000)
    raw = mne.io.RawArray(data=data, info=info)
    desc = pd.Series({'pathological': True, 'gender': 'M', 'age': 48})
    base_ds = BaseDataset(raw, desc, target_name='age')
    concat_ds = BaseConcatDataset([base_ds])

    if window_size_samples is None:
        window_size_samples = base_ds.raw.n_times
    stop_offset_samples = data.shape[1] - start_offset_samples
    epochs_ds = create_fixed_length_windows(
        concat_ds, start_offset_samples=start_offset_samples,
        stop_offset_samples=stop_offset_samples,
        window_size_samples=window_size_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=drop_last_window, mapping=mapping)

    if mapping is not None:
        assert base_ds.target == 48
        assert all(epochs_ds.datasets[0].windows.metadata['target'] == 0)

    epochs_data = epochs_ds.datasets[0].windows.get_data()

    idxs = np.arange(
        start_offset_samples,
        stop_offset_samples - window_size_samples + 1,
        window_stride_samples)
    if not drop_last_window and \
            idxs[-1] != stop_offset_samples - window_size_samples:
        # if the stride does not reach the stop offset, expect an additional
        # (overlapping) window that ends exactly there
        idxs = np.append(idxs, stop_offset_samples - window_size_samples)

    assert len(idxs) == epochs_data.shape[0], (
        'Number of epochs different than expected')
    assert window_size_samples == epochs_data.shape[2], (
        'Window size different than expected')
    for j, idx in enumerate(idxs):
        np.testing.assert_allclose(
            base_ds.raw.get_data()[:, idx:idx + window_size_samples],
            epochs_data[j, :],
            err_msg=f'Epochs different for epoch {j}')


def test_fixed_length_windower_legacy():
    # exercises the older FixedLengthWindower API; renamed so it does not
    # shadow the parametrized test of the same name above
    rng = np.random.RandomState(42)
    info = mne.create_info(ch_names=['0', '1'], sfreq=50, ch_types='eeg')
    data = rng.randn(2, 1000)
    raw = mne.io.RawArray(data=data, info=info)
    desc = pd.Series({'pathological': True, 'gender': 'M', 'age': 48})
    base_ds = BaseDataset(raw, desc, target_name='age')

    # test case:
    # (window_size_samples, stride_size_samples, drop_samples,
    #  trial_start_offset_samples, n_windows)
    test_cases = [
        (100, 90, True, 0., 11),
        (100, 50, True, 0., 19),
        (None, 50, True, 0., 1),
    ]
    for i, test_case in enumerate(test_cases):
        (window_size, stride_size, drop_last_samples,
         trial_start_offset_samples, n_windows) = test_case
        if window_size is None:
            window_size = base_ds.raw.n_times
        windower = FixedLengthWindower(
            supercrop_size_samples=window_size,
            supercrop_stride_samples=stride_size,
            drop_samples=drop_last_samples,
            trial_start_offset_samples=trial_start_offset_samples,
            trial_stop_offset_samples=(
                -trial_start_offset_samples + window_size))
        epochs = windower(base_ds)
        epochs_data = epochs.get_data()

        idxs = np.arange(
            0, base_ds.raw.get_data().shape[1] - window_size + 1, stride_size)
        assert len(idxs) == epochs_data.shape[0], \
            f'Number of epochs different than expected for test case {i}'
        assert window_size == epochs_data.shape[2], \
            f'Window size different than expected for test case {i}'
        for j, idx in enumerate(idxs):
            np.testing.assert_allclose(
                base_ds.raw.get_data()[:, idx:idx + window_size],
                epochs_data[j, :],
                err_msg=f'Epochs different for test case {i} for epoch {j}')


def dataset_target_time_series():
    rng = np.random.RandomState(42)
    signal_sfreq = 50
    info = mne.create_info(
        ch_names=['0', '1', 'target_0', 'target_1'], sfreq=signal_sfreq,
        ch_types=['eeg', 'eeg', 'misc', 'misc'])
    signal = rng.randn(2, 1000)
    targets = np.full((2, 1000), np.nan)
    targets_sfreq = 10
    targets_stride = int(signal_sfreq / targets_sfreq)
    targets[:, ::targets_stride] = rng.randn(
        2, int(targets.shape[1] / targets_stride))
    raw = mne.io.RawArray(np.concatenate([signal, targets]), info=info)
    desc = pd.Series({'pathological': True, 'gender': 'M', 'age': 48})
    base_dataset = BaseDataset(raw, desc, target_name=None)
    concat_ds = BaseConcatDataset([base_dataset])
    windows_dataset = create_windows_from_target_channels(
        concat_ds,
        window_size_samples=100,
    )
    return concat_ds, windows_dataset, targets, signal


def create_from_X_y(X, y, drop_last_window, sfreq=None, ch_names=None,
                    window_size_samples=None, window_stride_samples=None):
    """Create a BaseConcatDataset of WindowsDatasets from X and y to be used
    for decoding with skorch and braindecode, where X is a list of pre-cut
    trials and y are the corresponding targets.

    Parameters
    ----------
    X: array-like
        list of pre-cut trials as n_trials x n_channels x n_times
    y: array-like
        targets corresponding to the trials
    drop_last_window: bool
        whether to drop the last window when windows do not equally divide
        the continuous signal; if False, a last overlapping window is kept
    sfreq: float | None
        common sampling frequency of all trials
    ch_names: array-like
        channel names of the trials
    window_size_samples: int
        window size
    window_stride_samples: int
        stride between windows

    Returns
    -------
    windows_datasets: BaseConcatDataset
        X and y transformed to a dataset format that is compatible with skorch
        and braindecode
    """
    n_samples_per_x = []
    base_datasets = []
    if sfreq is None:
        sfreq = 100
        log.info("No sampling frequency given, set to 100 Hz.")
    if ch_names is None:
        ch_names = [str(i) for i in range(X.shape[1])]
        log.info(f"No channel names given, set to 0-{X.shape[1] - 1}.")
    for x, target in zip(X, y):
        n_samples_per_x.append(x.shape[1])
        info = mne.create_info(ch_names=ch_names, sfreq=sfreq)
        raw = mne.io.RawArray(x, info)
        base_dataset = BaseDataset(raw, pd.Series({"target": target}),
                                   target_name="target")
        base_datasets.append(base_dataset)
    base_datasets = BaseConcatDataset(base_datasets)

    if window_size_samples is None and window_stride_samples is None:
        if not len(np.unique(n_samples_per_x)) == 1:
            raise ValueError(
                "if 'window_size_samples' and 'window_stride_samples' are "
                "None, all trials have to have the same length")
        window_size_samples = n_samples_per_x[0]
        window_stride_samples = n_samples_per_x[0]
    windows_datasets = create_fixed_length_windows(
        base_datasets,
        start_offset_samples=0,
        stop_offset_samples=0,
        window_size_samples=window_size_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=drop_last_window)
    return windows_datasets


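# A minimal usage sketch for ``create_from_X_y``. The helper name, array
# shapes and sampling frequency are illustrative assumptions, not part of the
# API above. With window size and stride left as None and equally long
# trials, each trial becomes exactly one window.
def _example_create_from_X_y():
    X = np.random.randn(5, 3, 200)  # 5 trials, 3 channels, 200 samples each
    y = np.random.randint(0, 2, size=5)  # one binary target per trial
    windows_ds = create_from_X_y(X, y, drop_last_window=False, sfreq=100.)
    # one window per trial, since the window size defaults to the trial length
    assert len(windows_ds) == 5
    return windows_ds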