Exemplo n.º 1
0
def test_target_name_not_in_description(set_up):
    """Targets missing from the description can only be indexed once set.

    Building the dataset warns, indexing raises ``TypeError`` until the
    description provides the requested target fields, and succeeds after
    ``set_description`` is called.
    """
    raw, _, _, _, _, _ = set_up
    requested_targets = ('pathological', 'gender', 'age')
    with pytest.warns(UserWarning):
        base_dataset = BaseDataset(raw, target_name=requested_targets)

    # No description yet -> the targets cannot be looked up.
    with pytest.raises(TypeError):
        x, y = base_dataset[0]

    new_description = {'pathological': True, 'gender': 'M', 'age': 48}
    base_dataset.set_description(new_description)
    x, y = base_dataset[0]
Exemplo n.º 2
0
def set_up():
    """Build the shared objects used by the dataset tests.

    Returns
    -------
    tuple
        ``(raw, base_dataset, mne_epochs, windows_dataset, events,
        window_idxs)`` built from a 2-channel, 1000-sample random RawArray
        sampled at 50 Hz.
    """
    rng = np.random.RandomState(42)
    info = mne.create_info(ch_names=['0', '1'], sfreq=50, ch_types='eeg')
    signal = rng.randn(2, 1000)
    raw = mne.io.RawArray(data=signal, info=info)
    desc = pd.Series({'pathological': True, 'gender': 'M', 'age': 48})
    base_dataset = BaseDataset(raw, desc, target_name='age')

    # Columns: sample, second column (stored as 'x' in metadata), event id.
    events = np.array([
        [100, 0, 1],
        [200, 0, 2],
        [300, 0, 1],
        [400, 0, 4],
        [500, 0, 3],
    ])
    # One (i_window_in_trial, i_start_in_trial, i_stop_in_trial) per event.
    window_idxs = [
        (0, 0, 100),
        (0, 100, 200),
        (1, 0, 100),
        (2, 0, 100),
        (2, 50, 150),
    ]
    window_columns = dict(zip(
        ('i_window_in_trial', 'i_start_in_trial', 'i_stop_in_trial'),
        zip(*window_idxs),
    ))
    metadata = pd.DataFrame({
        'sample': events[:, 0],
        'x': events[:, 1],
        'target': events[:, 2],
        **window_columns,
    })

    mne_epochs = mne.Epochs(raw=raw, events=events, metadata=metadata)
    windows_dataset = WindowsDataset(mne_epochs, desc)

    return raw, base_dataset, mne_epochs, windows_dataset, events, window_idxs
Exemplo n.º 3
0
def test_preprocessors_with_misc_channels():
    """Preprocessing must leave the ``misc`` (target) channels unchanged.

    This only holds for preprocessors built on mne functions that do not
    modify ``misc`` channels.
    """
    rng = np.random.RandomState(42)
    signal_sfreq = 50
    channel_names = ['0', '1', 'target_0', 'target_1']
    channel_types = ['eeg', 'eeg', 'misc', 'misc']
    info = mne.create_info(ch_names=channel_names, sfreq=signal_sfreq,
                           ch_types=channel_types)
    # Draw order matters: the signal first, then the targets.
    signal = rng.randn(2, 1000)
    targets = rng.randn(2, 1000)
    raw = mne.io.RawArray(np.vstack((signal, targets)), info=info)

    description = pd.Series({'pathological': True, 'gender': 'M', 'age': 48})
    concat_ds = BaseConcatDataset(
        [BaseDataset(raw, description, target_name=None)])
    preprocess(concat_ds, [
        Preprocessor('pick_types', eeg=True, misc=True),
        Preprocessor(lambda x: x / 1e6),
    ])

    preprocessed = concat_ds.datasets[0].raw.get_data()
    np.testing.assert_array_equal(preprocessed[-2:, :], targets)
Exemplo n.º 4
0
def test_description_incorrect_type(set_up):
    """A tuple is not an accepted description and raises ``ValueError``."""
    raw, _, _, _, _, _ = set_up
    tuple_description = ('test', 4)
    with pytest.raises(ValueError):
        BaseDataset(raw=raw, description=tuple_description)
Exemplo n.º 5
0
def concat_ds_targets():
    """Return a 3-recording BNCI2014001 concat dataset and zero-based targets.

    The targets are the event codes of the first recording, shifted down by
    one.
    """
    raws, description = fetch_data_with_moabb(
        dataset_name="BNCI2014001", subject_ids=4)
    events, _ = mne.events_from_annotations(raws[0])
    targets = events[:, -1] - 1

    n_recordings = 3
    recordings = [BaseDataset(raws[i_rec], description.iloc[i_rec])
                  for i_rec in range(n_recordings)]
    return BaseConcatDataset(recordings), targets
Exemplo n.º 6
0
def windows_ds():
    """Cut the first 3 BNCI2014001 recordings into fixed 500-sample windows.

    Windows are non-overlapping (stride 500), the last partial window is
    kept, and data is not preloaded.
    """
    raws, description = fetch_data_with_moabb(
        dataset_name='BNCI2014001', subject_ids=4)
    concat_ds = BaseConcatDataset(
        [BaseDataset(raws[i_rec], description.iloc[i_rec])
         for i_rec in range(3)])

    window_samples = 500
    windows = create_fixed_length_windows(
        concat_ds=concat_ds,
        start_offset_samples=0,
        stop_offset_samples=None,
        window_size_samples=window_samples,
        window_stride_samples=window_samples,
        drop_last_window=False,
        preload=False,
    )
    return windows
Exemplo n.º 7
0
 def __init__(self, subject_ids=None):
     """Load train and test recordings for the requested subjects.

     For every subject id ``i`` the file ``{data_path}/sub{i}_comp.mat`` is
     loaded with ``self._load_data_to_mne``. This yields two BaseDatasets,
     one for the train session and one for the test session.

     Parameters
     ----------
     subject_ids : int | list | None
         Subject(s) to load. A single int is wrapped in a list; ``None``
         loads ``self.possible_subjects``. Ids are checked with
         ``self._validate_subjects``.
     """
     data_path = self.download()
     if isinstance(subject_ids, int):
         subject_ids = [subject_ids]
     if subject_ids is None:
         subject_ids = self.possible_subjects
     self._validate_subjects(subject_ids)
     datasets = []
     for subject_id in subject_ids:
         file_name = f'sub{subject_id}_comp.mat'
         raw_train, raw_test = self._load_data_to_mne(
             f'{data_path}/{file_name}')
         # Take the subject from the id itself. Re-parsing the path and keeping
         # only the first character after 'sub' truncated multi-digit ids.
         subject = str(subject_id)
         for raw, session in ((raw_train, 'train'), (raw_test, 'test')):
             description = dict(
                 subject=subject,
                 file_name=file_name,
                 session=session)
             datasets.append(BaseDataset(raw, description=description))
     super().__init__(datasets)
Exemplo n.º 8
0
def test_target_name_list(set_up):
    """A list of target names is stored unchanged on the dataset."""
    raw, _, _, _, _, _ = set_up
    target_names = ['pathological', 'gender', 'age']
    description = dict(pathological=True, gender='M', age=48)
    base_dataset = BaseDataset(raw=raw, description=description,
                               target_name=target_names)
    assert base_dataset.target_name == target_names
Exemplo n.º 9
0
def target_windows_ds():
    """Cut the first 3 BNCI2014001 recordings into one window per trial.

    Uses ``create_windows_from_events`` with zero trial offsets and no
    explicit window size or stride.
    """
    raws, description = fetch_data_with_moabb(dataset_name='BNCI2014001',
                                              subject_ids=4)
    concat_ds = BaseConcatDataset(
        [BaseDataset(raws[i_rec], description.iloc[i_rec])
         for i_rec in range(3)])

    windowing_kwargs = dict(
        trial_start_offset_samples=0,
        trial_stop_offset_samples=0,
        window_size_samples=None,
        window_stride_samples=None,
        drop_last_window=False,
    )
    return create_windows_from_events(concat_ds, **windowing_kwargs)
Exemplo n.º 10
0
def fake_regression_dataset(n_fake_recs, n_fake_chs, fake_sfreq, fake_duration_s):
    """Build a BaseConcatDataset of random dummy recordings with random targets.

    The first recording is tagged ``session='eval'`` and the rest
    ``'train'``. Each recording's target is drawn with
    ``np.random.randint(0, 100, n_classes)``; when ``n_classes == 1`` the
    target is reduced to a scalar.

    NOTE(review): ``n_classes`` is not a parameter. It is resolved from an
    enclosing (presumably module-level) scope that is not visible here —
    confirm it is defined wherever this helper is used.
    """
    recordings = []
    for i_rec in range(n_fake_recs):
        session = "train" if i_rec else "eval"
        raw, _ = create_mne_dummy_raw(
            n_channels=n_fake_chs, n_times=fake_duration_s*fake_sfreq,
            sfreq=fake_sfreq, savedir=None)
        target = np.random.randint(0, 100, n_classes)
        if n_classes == 1:
            target = target[0]
        description = pd.Series(
            data=[target, session],
            index=["target", "session"])
        recordings.append(
            BaseDataset(raw, description, target_name="target"))
    return BaseConcatDataset(recordings)
Exemplo n.º 11
0
def test_target_in_subject_info(set_up):
    """Requesting a target absent from the description raises ``ValueError``."""
    raw, _, _, _, _, _ = set_up
    description = pd.Series(dict(pathological=True, gender='M', age=48))
    expected_msg = "'does_not_exist' not in description"
    with pytest.raises(ValueError, match=expected_msg):
        BaseDataset(raw, description, target_name='does_not_exist')
Exemplo n.º 12
0
def load_5f_halt(args):
    """Load and preprocess the training/validation/test data of the 5F or HaLT
    datasets.

    For every selected ``.mat`` file, the processing steps are:

    1. Keep the used EEG channels and apply exponential moving
       standardization.
    2. Convert to an MNE ``RawArray`` with a stimulus channel and extract
       the events.
    3. Drop the stimulus channel and resample to ``args.sfreq``.
    4. Split the trials into training / validation / test recordings,
       marked with 1 s annotations.

    Parameters
    ----------
    args : Namespace
        Input arguments. Reads ``dataset`` ('5f' or 'halt'), ``project_dir``,
        ``inter_subject``, ``test_sub`` and ``sfreq``.

    Returns
    -------
    dataset : BaseConcatDataset
        BaseConcatDataset of raw MNE arrays, one BaseDataset per partition,
        each described by ``subject`` and ``partition``.

    Raises
    ------
    ValueError
        If ``args.dataset`` is neither '5f' nor 'halt'.
    """

    import os
    from scipy import io
    import numpy as np
    import mne
    from sklearn.utils import resample
    from braindecode.datautil import exponential_moving_standardize
    from braindecode.datasets import BaseDataset, BaseConcatDataset

    ### Dataset selection ###
    # Event code -> annotation description for each supported dataset.
    event_descs = {
        '5f': {1: 'thumb', 2: 'index_finger', 3: 'middle_finger',
               4: 'ring_finger', 5: 'pinkie_finger'},
        'halt': {1: 'left_hand', 2: 'right_hand', 3: 'passive_neutral',
                 4: 'left_leg', 5: 'tongue', 6: 'right_leg'},
    }
    if args.dataset not in event_descs:
        # Previously this fell through and failed later with a NameError.
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected '5f' or 'halt'.")
    event_desc = event_descs[args.dataset]
    data_dir = os.path.join(args.project_dir, 'datasets', args.dataset, 'data')
    intra_subject = not args.inter_subject

    ### Channel types ###
    # The recordings hold 22 channels. Channels 10, 11 and 21 (A1, A2, X5,
    # see paper) are rejected, and a stimulus channel is appended afterwards.
    ch_names = ['Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2',
                'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'Fz', 'Cz', 'Pz', 'stim']
    ch_types = ['eeg'] * 19 + ['stim']
    idx_chan = np.ones(22, dtype=bool)
    idx_chan[[10, 11, 21]] = False

    def _annotated(raw, trial_events, sfreq):
        """Set one 1 s annotation per row of ``trial_events`` on ``raw``."""
        annotations = mne.annotations_from_events(trial_events, sfreq,
                                                  event_desc=event_desc)
        annotations.duration = np.repeat(1., len(trial_events))
        raw.set_annotations(annotations)
        return raw

    ### Subjects ###
    files = sorted(os.listdir(data_dir))
    if intra_subject:
        # Loading only one subject for intra-subject analysis
        used_files = [f for f in files if 'Subject' + args.test_sub in f]
    else:
        used_files = files

    dataset = []
    ### Loading and preprocessing the .mat data ###
    for fname in used_files:
        print('\n\nData file --> ' + fname + '\n\n')
        current_sub = fname.partition('Subject')[2][0]
        mat = io.loadmat(os.path.join(data_dir, fname),
                         chars_as_strings=True)['o'][0][0]
        # sampFreq is stored as a matrix; reduce it to a plain float.
        sfreq = float(np.squeeze(mat['sampFreq']))
        marker = np.transpose(np.asarray(mat['marker']))
        data = np.transpose(np.asarray(mat['data']))[idx_chan, :]
        del mat
        data = exponential_moving_standardize(data)
        data = np.append(data, marker, 0)
        del marker

        ### Converting to MNE format ###
        info = mne.create_info(ch_names, sfreq, ch_types)
        raw = mne.io.RawArray(data, info)
        # NOTE(review): recent MNE versions may refuse direct assignment to
        # info['highpass'] / info['lowpass']; confirm with the MNE in use.
        raw.info['highpass'] = 0.53
        raw.info['lowpass'] = 70
        del data

        ### Get events and downsample data ###
        events = mne.find_events(raw, stim_channel='stim', output='onset',
                                 consecutive='increasing')
        # Drop unused events (codes above 6)
        events = events[events[:, 2] <= 6]
        # Drop stimuli channel
        raw.pick_types(eeg=True)
        # Downsample. The events keep their original sample indices and are
        # paired with the original ``sfreq`` in ``_annotated``.
        raw.resample(args.sfreq)

        ### Dividing events into training, validation and test ###
        # For intra-subject decoding, 10 trials per condition are used for
        # validation, 10 trials for testing, and the remaining trials are used
        # for training.
        # For inter-subject decoding 75 trials per condition of the subject of
        # interest are used for validation and 75 for testing. All the data
        # from the other subjects is used for training.
        labels = events[:, 2]
        idx_train = np.zeros(len(events), dtype=bool)
        idx_val = np.zeros(len(events), dtype=bool)
        idx_test = np.zeros(len(events), dtype=bool)
        is_test_sub = args.test_sub == current_sub
        # Iterate over the labels actually present. This matches the old
        # 1..K loop when labels are consecutive, but no class is lost when
        # one label is missing.
        for label in np.unique(labels):
            trials = np.flatnonzero(labels == label)
            if intra_subject:
                shuf = resample(trials, replace=False)
                idx_val[shuf[:10]] = True
                idx_test[shuf[10:20]] = True
                idx_train[shuf[20:]] = True
            elif is_test_sub:
                idx_val[trials[:75]] = True
                idx_test[trials[75:150]] = True
            else:
                idx_train[trials] = True
        events_train = events[idx_train]
        events_val = events[idx_val]
        events_test = events[idx_test]

        ### Annotating (1 s trials) and converting to BaseDataset ###
        description_train = {'subject': current_sub, 'partition': 'training'}
        description_val = {'subject': current_sub, 'partition': 'validation'}
        description_test = {'subject': current_sub, 'partition': 'test'}
        if intra_subject:
            # Copy before annotating so each partition gets its own raw.
            raw_val = _annotated(raw.copy(), events_val, sfreq)
            raw_test = _annotated(raw.copy(), events_test, sfreq)
            raw_train = _annotated(raw, events_train, sfreq)
            dataset.append(BaseDataset(raw_train, description_train))
            dataset.append(BaseDataset(raw_val, description_val))
            dataset.append(BaseDataset(raw_test, description_test))
        elif is_test_sub:
            raw_val = _annotated(raw.copy(), events_val, sfreq)
            raw_test = _annotated(raw.copy(), events_test, sfreq)
            dataset.append(BaseDataset(raw_val, description_val))
            dataset.append(BaseDataset(raw_test, description_test))
        else:
            raw_train = _annotated(raw, events_train, sfreq)
            dataset.append(BaseDataset(raw_train, description_train))

    ### Output ###
    return BaseConcatDataset(dataset)
Exemplo n.º 13
0
def test_target_name_incorrect_type(set_up):
    """A dict ``target_name`` is rejected with a descriptive ``ValueError``."""
    raw, _, _, _, _, _ = set_up
    expected_msg = 'target_name has to be None, str, tuple or list'
    invalid_target_name = {'target': 1}
    with pytest.raises(ValueError, match=expected_msg):
        BaseDataset(raw, target_name=invalid_target_name)