def test_target_name_not_in_description(set_up):
    """Check target lookup when the target names are absent from the description.

    Constructing the dataset with target names but no description is expected
    to emit a ``UserWarning``; indexing it is then expected to raise
    ``TypeError`` until a description holding those names is set.
    """
    raw = set_up[0]
    wanted_targets = ('pathological', 'gender', 'age')

    with pytest.warns(UserWarning):
        base_dataset = BaseDataset(raw, target_name=wanted_targets)

    # No description yet, so fetching an item is expected to fail.
    with pytest.raises(TypeError):
        x, y = base_dataset[0]

    full_description = {
        'pathological': True,
        'gender': 'M',
        'age': 48,
    }
    base_dataset.set_description(full_description)

    # With the description in place the same access must succeed.
    x, y = base_dataset[0]
def set_up():
    """Build a small synthetic recording and the datasets derived from it.

    Returns
    -------
    tuple
        ``(raw, base_dataset, mne_epochs, windows_dataset, events,
        window_idxs)`` created from a seeded 2-channel, 1000-sample signal.
    """
    rng = np.random.RandomState(42)
    info = mne.create_info(ch_names=['0', '1'], sfreq=50, ch_types='eeg')
    raw = mne.io.RawArray(data=rng.randn(2, 1000), info=info)
    desc = pd.Series({'pathological': True, 'gender': 'M', 'age': 48})
    base_dataset = BaseDataset(raw, desc, target_name='age')

    # One row per event: (sample, x, target).
    events = np.array([
        [100, 0, 1],
        [200, 0, 2],
        [300, 0, 1],
        [400, 0, 4],
        [500, 0, 3],
    ])
    # One triple per window: (i_window_in_trial, i_start_in_trial,
    # i_stop_in_trial).
    window_idxs = [
        (0, 0, 100),
        (0, 100, 200),
        (1, 0, 100),
        (2, 0, 100),
        (2, 50, 150),
    ]
    window_in_trial, start_in_trial, stop_in_trial = zip(*window_idxs)

    metadata = pd.DataFrame({
        'sample': events[:, 0],
        'x': events[:, 1],
        'target': events[:, 2],
        'i_window_in_trial': window_in_trial,
        'i_start_in_trial': start_in_trial,
        'i_stop_in_trial': stop_in_trial,
    })

    mne_epochs = mne.Epochs(raw=raw, events=events, metadata=metadata)
    windows_dataset = WindowsDataset(mne_epochs, desc)
    return raw, base_dataset, mne_epochs, windows_dataset, events, window_idxs
def test_preprocessors_with_misc_channels():
    """Check that the preprocessing pipeline leaves ``misc`` channels untouched.

    Two EEG channels and two ``misc`` target channels are stacked into one
    recording, run through ``pick_types`` plus a scaling callable, and the
    last two rows of the result are compared with the original targets.
    """
    signal_sfreq = 50
    rng = np.random.RandomState(42)
    # Draw order matters for reproducibility: signal first, then targets.
    signal = rng.randn(2, 1000)
    targets = rng.randn(2, 1000)

    info = mne.create_info(
        ch_names=['0', '1', 'target_0', 'target_1'],
        sfreq=signal_sfreq,
        ch_types=['eeg', 'eeg', 'misc', 'misc'],
    )
    raw = mne.io.RawArray(np.concatenate([signal, targets]), info=info)
    desc = pd.Series({'pathological': True, 'gender': 'M', 'age': 48})
    concat_ds = BaseConcatDataset(
        [BaseDataset(raw, desc, target_name=None)]
    )

    preprocess(
        concat_ds,
        [
            Preprocessor('pick_types', eeg=True, misc=True),
            Preprocessor(lambda x: x / 1e6),
        ],
    )

    # Check whether preprocessing has not affected the targets.
    # This is only valid for preprocessors that use mne functions which do
    # not modify `misc` channels.
    processed = concat_ds.datasets[0].raw.get_data()
    np.testing.assert_array_equal(processed[-2:, :], targets)
def test_description_incorrect_type(set_up):
    """A tuple passed as ``description`` must be rejected with ``ValueError``."""
    raw = set_up[0]
    bad_description = ('test', 4)
    with pytest.raises(ValueError):
        BaseDataset(raw=raw, description=bad_description)
def concat_ds_targets():
    """Return a concat dataset of BNCI2014001 subject 4 and its event targets.

    Returns
    -------
    tuple
        ``(concat_ds, targets)`` where ``concat_ds`` wraps the first three
        fetched recordings and ``targets`` is the last column of the events
        extracted from the first recording, minus one.
    """
    raws, description = fetch_data_with_moabb(
        dataset_name="BNCI2014001", subject_ids=4)

    events = mne.events_from_annotations(raws[0])[0]
    targets = events[:, -1] - 1

    # Only the first three recordings are used.
    base_datasets = []
    for rec_idx in range(3):
        base_datasets.append(
            BaseDataset(raws[rec_idx], description.iloc[rec_idx]))

    return BaseConcatDataset(base_datasets), targets
def windows_ds():
    """Return fixed-length windows cut from BNCI2014001 subject 4.

    The first three fetched recordings are concatenated and split into
    non-overlapping 500-sample windows, keeping the last window and without
    preloading.
    """
    raws, description = fetch_data_with_moabb(
        dataset_name='BNCI2014001', subject_ids=4)

    base_datasets = []
    for rec_idx in range(3):
        base_datasets.append(
            BaseDataset(raws[rec_idx], description.iloc[rec_idx]))
    concat_ds = BaseConcatDataset(base_datasets)

    return create_fixed_length_windows(
        concat_ds=concat_ds,
        start_offset_samples=0,
        stop_offset_samples=None,
        window_size_samples=500,
        window_stride_samples=500,
        drop_last_window=False,
        preload=False,
    )
def __init__(self, subject_ids=None):
    """Download the data and build one train and one test dataset per subject.

    Parameters
    ----------
    subject_ids : int | list of int | None
        Subject(s) to load. A single int is wrapped in a list; ``None``
        selects ``self.possible_subjects``.
    """
    data_path = self.download()

    if isinstance(subject_ids, int):
        subject_ids = [subject_ids]
    if subject_ids is None:
        subject_ids = self.possible_subjects
    self._validate_subjects(subject_ids)

    datasets = []
    for subject_id in subject_ids:
        file_path = f'{data_path}/sub{subject_id}_comp.mat'
        raw_train, raw_test = self._load_data_to_mne(file_path)

        file_name = file_path.split('/')[-1]
        # Only the first character after 'sub' is kept as the subject label.
        subject = file_name.split('sub')[1][0]

        # Train recording first, then test, for every subject.
        for session, raw in (('train', raw_train), ('test', raw_test)):
            description = dict(
                subject=subject,
                file_name=file_name,
                session=session)
            datasets.append(BaseDataset(raw, description=description))

    super().__init__(datasets)
def test_target_name_list(set_up):
    """A list of target names must be stored unchanged on the dataset."""
    raw = set_up[0]
    target_names = ['pathological', 'gender', 'age']
    description = {
        'pathological': True,
        'gender': 'M',
        'age': 48,
    }

    base_dataset = BaseDataset(
        raw=raw, description=description, target_name=target_names)

    assert base_dataset.target_name == target_names
def target_windows_ds():
    """Return event-based windows cut from BNCI2014001 subject 4.

    The first three fetched recordings are concatenated and windowed with
    ``create_windows_from_events`` using zero trial offsets and no explicit
    window size or stride.
    """
    raws, description = fetch_data_with_moabb(
        dataset_name='BNCI2014001', subject_ids=4)

    base_datasets = []
    for rec_idx in range(3):
        base_datasets.append(
            BaseDataset(raws[rec_idx], description.iloc[rec_idx]))
    concat_ds = BaseConcatDataset(base_datasets)

    return create_windows_from_events(
        concat_ds,
        trial_start_offset_samples=0,
        trial_stop_offset_samples=0,
        window_size_samples=None,
        window_stride_samples=None,
        drop_last_window=False,
    )
def fake_regression_dataset(n_fake_recs, n_fake_chs, fake_sfreq,
                            fake_duration_s):
    """Create a concat dataset of dummy recordings with random targets.

    Parameters
    ----------
    n_fake_recs : int
        Number of recordings to generate.
    n_fake_chs : int
        Number of channels per recording.
    fake_sfreq : float
        Sampling frequency handed to ``create_mne_dummy_raw``.
    fake_duration_s : float
        Duration used to derive ``n_times = fake_duration_s * fake_sfreq``.

    Returns
    -------
    BaseConcatDataset
        One ``BaseDataset`` per recording; the first is labelled ``"eval"``
        in its ``session`` entry, all others ``"train"``.

    Notes
    -----
    ``n_classes`` is not a parameter: it is read from the enclosing scope
    (NOTE(review): presumably a module-level setting - confirm it is defined
    before this function is called).
    """
    recordings = []
    for rec_idx in range(n_fake_recs):
        raw, _ = create_mne_dummy_raw(
            n_channels=n_fake_chs,
            n_times=fake_duration_s * fake_sfreq,
            sfreq=fake_sfreq,
            savedir=None,
        )

        target = np.random.randint(0, 100, n_classes)
        if n_classes == 1:
            # A single target is stored as a scalar, not a length-1 array.
            target = target[0]

        session = "train" if rec_idx else "eval"
        description = pd.Series(
            data=[target, session], index=["target", "session"])
        recordings.append(
            BaseDataset(raw, description, target_name="target"))

    return BaseConcatDataset(recordings)
def test_target_in_subject_info(set_up):
    """A target name missing from the description must raise ``ValueError``."""
    raw = set_up[0]
    desc = pd.Series({'pathological': True, 'gender': 'M', 'age': 48})
    expected_message = "'does_not_exist' not in description"

    with pytest.raises(ValueError, match=expected_message):
        BaseDataset(raw, desc, target_name='does_not_exist')
def load_5f_halt(args):
    """Load and preprocess the train/validation/test data of 5F or HaLT.

    Every selected ``.mat`` recording is standardized, wrapped in an MNE
    ``RawArray``, stripped of its stimulus channel, resampled to
    ``args.sfreq`` and annotated with 1 s trials. Trials are divided into
    training, validation and test partitions, each stored as its own
    ``BaseDataset`` with a ``{'subject', 'partition'}`` description.

    Parameters
    ----------
    args : Namespace
        Input arguments. The attributes read here are ``dataset`` (``'5f'``
        or ``'halt'``), ``project_dir``, ``inter_subject``, ``test_sub``
        and ``sfreq``.

    Returns
    -------
    dataset : BaseConcatDataset
        BaseConcatDataset of raw MNE arrays.

    Raises
    ------
    ValueError
        If ``args.dataset`` is neither ``'5f'`` nor ``'halt'``.
    """
    import os
    from scipy import io
    import numpy as np
    import mne
    from sklearn.utils import resample
    from braindecode.datautil import exponential_moving_standardize
    from braindecode.datasets import BaseDataset, BaseConcatDataset

    ### Event codes per dataset ###
    event_descs = {
        '5f': {1: 'thumb', 2: 'index_finger', 3: 'middle_finger',
               4: 'ring_finger', 5: 'pinkie_finger'},
        'halt': {1: 'left_hand', 2: 'right_hand', 3: 'passive_neutral',
                 4: 'left_leg', 5: 'tongue', 6: 'right_leg'},
    }
    # Fail early with a clear message instead of hitting an unbound
    # `data_dir` further down.
    if args.dataset not in event_descs:
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected '5f' or 'halt'.")
    event_desc = event_descs[args.dataset]

    def _annotate_trials(raw, trial_events, sfreq):
        """Attach 1 s annotations built from ``trial_events`` to ``raw``."""
        annotations = mne.annotations_from_events(
            trial_events, sfreq, event_desc=event_desc)
        # Creating 1s trials
        annotations.duration = np.repeat(1., len(trial_events))
        raw.set_annotations(annotations)

    ### Channel types ###
    # Rejecting three of the 22 recorded channels (indices 10, 11, 21):
    # presumably A1, A2 and X5 - the original note listed A1 twice (see paper).
    ch_names = ['Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2',
                'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'Fz', 'Cz', 'Pz', 'stim']
    ch_types = ['eeg'] * (len(ch_names) - 1) + ['stim']
    idx_chan = np.ones(22, dtype=bool)
    idx_chan[[10, 11, 21]] = False

    ### Subjects ###
    data_dir = os.path.join(args.project_dir, 'datasets', args.dataset, 'data')
    files = sorted(os.listdir(data_dir))
    # Evaluated once; the explicit comparison keeps the original semantics
    # for non-bool values (e.g. None is treated as inter-subject).
    intra_subject = args.inter_subject == False  # noqa: E712
    if intra_subject:
        # Loading only one subject for intra-subject analysis
        used_files = [f for f in files if 'Subject' + args.test_sub in f]
    else:
        used_files = files

    ### Loading and preprocessing the .mat data ###
    datasets = []
    for fname in used_files:
        print(f'\n\nData file --> {fname}\n\n')
        current_sub = fname.partition('Subject')[2][0]
        is_test_sub = args.test_sub == current_sub

        recording = io.loadmat(os.path.join(data_dir, fname),
                               chars_as_strings=True)['o']
        sfreq = np.asarray(recording[0][0]['sampFreq'][0])
        marker = np.transpose(np.asarray(recording[0][0]['marker']))
        data = np.transpose(np.asarray(recording[0][0]['data']))[idx_chan, :]
        data = exponential_moving_standardize(data)
        # The marker row becomes the trailing 'stim' channel.
        data = np.append(data, marker, 0)
        del marker, recording

        ### Converting to MNE format ###
        info = mne.create_info(ch_names, sfreq, ch_types)
        raw_train = mne.io.RawArray(data, info)
        raw_train.info['highpass'] = 0.53
        raw_train.info['lowpass'] = 70
        del data

        ### Get events and downsample data ###
        # Events are extracted before resampling, so their sample indices
        # refer to the original `sfreq` (used again for the annotations).
        events = mne.find_events(raw_train, stim_channel='stim',
                                 output='onset', consecutive='increasing')
        # Drop unused events
        events = events[events[:, 2] <= 6]
        # Drop stimuli channel
        raw_train.pick_types(eeg=True)
        # Downsampling the data
        raw_train.resample(args.sfreq)

        ### Dividing events into training, validation and test ###
        # For intra-subject decoding, 10 trials per condition are used for
        # validation, 10 trials for testing, and the remaining trials are
        # used for training.
        # For inter-subject decoding 75 trials per condition of the subject
        # of interest are used for validation and 75 for testing. All the
        # data from the other subjects is used for training.
        n_events = events.shape[0]
        idx_train = np.zeros(n_events, dtype=bool)
        idx_val = np.zeros(n_events, dtype=bool)
        idx_test = np.zeros(n_events, dtype=bool)
        n_conditions = len(np.unique(events[:, 2]))
        # Condition codes are taken to be 1..n_conditions.
        for code in range(1, n_conditions + 1):
            trials = np.where(events[:, 2] == code)[0]
            if intra_subject:
                # `resample` without replacement yields a random permutation.
                shuf = resample(trials, replace=False)
                idx_val[shuf[:10]] = True
                idx_test[shuf[10:20]] = True
                idx_train[shuf[20:]] = True
            elif is_test_sub:
                idx_val[trials[:75]] = True
                idx_test[trials[75:150]] = True
            else:
                idx_train[trials] = True

        ### Selecting the partitions produced by this recording ###
        if intra_subject:
            raw_val = raw_train.copy()
            raw_test = raw_train.copy()
            partitions = [
                ('training', raw_train, events[idx_train, :]),
                ('validation', raw_val, events[idx_val, :]),
                ('test', raw_test, events[idx_test, :]),
            ]
        elif is_test_sub:
            raw_val = raw_train.copy()
            raw_test = raw_train.copy()
            partitions = [
                ('validation', raw_val, events[idx_val, :]),
                ('test', raw_test, events[idx_test, :]),
            ]
        else:
            partitions = [('training', raw_train, events[idx_train, :])]

        ### Annotating and converting to BaseDataset format ###
        for partition, raw, trial_events in partitions:
            _annotate_trials(raw, trial_events, sfreq)
            description = {'subject': current_sub, 'partition': partition}
            datasets.append(BaseDataset(raw, description))

    ### Output ###
    return BaseConcatDataset(datasets)
def test_target_name_incorrect_type(set_up):
    """A dict passed as ``target_name`` must be rejected with ``ValueError``."""
    raw = set_up[0]
    expected_message = 'target_name has to be None, str, tuple or list'
    invalid_target_name = {'target': 1}

    with pytest.raises(ValueError, match=expected_message):
        BaseDataset(raw, target_name=invalid_target_name)