def run_until_second_stop(self):
    """Second training phase for both datasets: merge valid into train, resume.

    For each of the two split pairs (``train_1``/``valid_1`` and
    ``train_2``/``valid_2``) the validation set is concatenated onto the
    training set. Training then continues through ``self.run_until_stop``
    with ``remember_best=True``.

    Intended stop condition (as described by the original docstring): stop
    once the validation loss drops below the training loss of the best
    epoch, or after as many epochs as were trained before the first stop.
    That condition is presumably enforced by the stop criterion used inside
    ``run_until_stop``; it is not implemented here.

    Note: the merged sets are written back into ``self.datasets``.
    """
    splits = self.datasets  # alias: assignments below update self.datasets
    key_pairs = (('train_1', 'valid_1'), ('train_2', 'valid_2'))
    for train_key, valid_key in key_pairs:
        merged = concatenate_sets([splits[train_key], splits[valid_key]])
        splits[train_key] = merged
    self.run_until_stop(splits, remember_best=True)
def get_single_subj_dataset(self, subject=None, final_evaluation=False):
    """Return the train/valid/test sets of one subject, loading them on demand.

    On first request for ``subject`` its sets are fetched with
    ``get_train_val_test`` from the configured ``data_folder`` and cached in
    ``self.datasets['train'|'valid'|'test'][subject]``.

    :param subject: key identifying the subject in the cache.
    :param final_evaluation: if True, the returned ``'train'`` entry is the
        concatenation of the subject's train and valid sets. The cached sets
        themselves are left untouched.
    :return: ``OrderedDict`` with keys ``'train'``, ``'valid'``, ``'test'``.
    """
    cache = self.datasets
    if subject not in cache['train'].keys():
        train_part, valid_part, test_part = get_train_val_test(
            global_vars.get('data_folder'), subject)
        cache['train'][subject] = train_part
        cache['valid'][subject] = valid_part
        cache['test'][subject] = test_part

    single_subj_dataset = OrderedDict(
        (split_name, cache[split_name][subject])
        for split_name in ('train', 'valid', 'test'))
    if not final_evaluation:
        return single_subj_dataset
    single_subj_dataset['train'] = concatenate_sets(
        [single_subj_dataset['train'], single_subj_dataset['valid']])
    return single_subj_dataset
def _load_and_merge_data(file_paths):
    """Load one or more HGD subject files and merge them into a single dataset.

    :param file_paths: sequence of paths, each readable by ``_load_h5_data``.
    :return: the loaded set itself when exactly one path is given; otherwise
        the ``concatenate_sets`` result of all loaded sets, in path order.
    """
    if len(file_paths) == 1:
        # Single subject: nothing to merge.
        return _load_h5_data(file_paths[0])
    loaded_sets = [_load_h5_data(path) for path in file_paths]
    return concatenate_sets(loaded_sets)
def get_single_subj_dataset(self, subject=None, final_evaluation=False):
    """Return one subject's train/valid/test sets, loading and caching them.

    On first request for ``subject`` its sets are fetched with
    ``get_train_val_test`` from the configured ``data_folder`` and cached in
    ``self.datasets['train'|'valid'|'test'][subject]``.

    :param subject: key identifying the subject in the cache.
    :param final_evaluation: together with a truthy global
        ``ensemble_iterations``, makes the returned ``'train'`` entry the
        concatenation of the train and valid sets.
    :return: ``OrderedDict`` with keys ``'train'``, ``'valid'``, ``'test'``.
        When the global ``time_frequency`` flag is set, the dict is passed
        through ``EEG_to_TF_mne``, and the global variables are refreshed
        from the resulting training set.
    """
    cache = self.datasets
    if subject not in cache['train'].keys():
        train_part, valid_part, test_part = get_train_val_test(
            global_vars.get('data_folder'), subject)
        cache['train'][subject] = train_part
        cache['valid'][subject] = valid_part
        cache['test'][subject] = test_part

    single_subj_dataset = OrderedDict(
        (split_name, cache[split_name][subject])
        for split_name in ('train', 'valid', 'test'))

    merge_valid = final_evaluation and global_vars.get('ensemble_iterations')
    if merge_valid:
        single_subj_dataset['train'] = concatenate_sets(
            [single_subj_dataset['train'], single_subj_dataset['valid']])

    if global_vars.get('time_frequency'):
        EEG_to_TF_mne(single_subj_dataset)
        set_global_vars_by_dataset(single_subj_dataset['train'])
    return single_subj_dataset
def run_until_second_stop(self):
    """Second training phase: fold the validation set into training, resume.

    The validation set is concatenated onto the training set, and training
    continues through ``self.run_until_stop`` with ``remember_best=False``.

    Intended stop condition (as described by the original docstring): stop
    once the validation loss drops below the training loss of the best
    epoch, or after as many epochs as were trained before the first stop.
    That condition is presumably enforced by the stop criterion used inside
    ``run_until_stop``; it is not implemented here.

    Note: the merged set is written back into ``self.datasets['train']``.
    """
    splits = self.datasets  # alias: the assignment below updates self.datasets
    merged_train = concatenate_sets([splits['train'], splits['valid']])
    splits['train'] = merged_train
    # TODO: keep remembering the best epoch and, if twice the number of
    # epochs is reached, reset to the best model again (check whether the
    # valid loss did not drop below the train loss).
    self.run_until_stop(splits, remember_best=False)
def load_data_and_model(n_job):
    """Load the recording sessions and the trained CNN for job ``n_job``.

    The sessions are read from ``<dir_sourceData>/<file>.mat``, an HDF5 file
    whose ``'D'`` entry references one group per session (fields ``ieeg``,
    ``traj``, ``srate``). Each session becomes a ``SignalAndTarget`` with a
    single trial. The model is read from
    ``<dir_outputData>/models/<file>_model``.

    :param n_job: job number, mapped to a file name by ``file_for_number``.
    :return: ``(train_set, valid_set, test_set, model)``:
        ``train_set`` concatenates all sessions but the last;
        ``valid_set`` is the second-to-last session, a placeholder that is
        also contained in ``train_set``;
        ``test_set`` is the last session.
    """
    file_name = file_for_number(n_job)
    print("file = {:s}".format(file_name))

    # Load data: MATLAB cell array of sessions stored under 'D'.
    import h5py
    log.info("Loading data...")
    mat_path = dir_sourceData + '/' + file_name + '.mat'
    with h5py.File(mat_path, 'r') as h5file:
        sessions = [h5file[ref] for ref in h5file['D'][0]]
        Xs = [sess['ieeg'][:] for sess in sessions]
        ys = [sess['traj'][0] for sess in sessions]
        # Read for completeness; not used further in this function.
        sampling_rates = [sess['srate'][0, 0] for sess in sessions]

    # Build one set per session.
    from braindecode.datautil.signal_target import SignalAndTarget
    # The outer list is the trial axis and always has length one.
    datasets = [
        SignalAndTarget([X.astype(np.float32)], [y.astype(np.float32)])
        for X, y in zip(Xs, ys)
    ]

    from braindecode.datautil.splitters import concatenate_sets
    assert len(datasets) >= 4
    train_set = concatenate_sets(datasets[:-1])
    valid_set = datasets[-2]  # placeholder only; validation set is not used
    test_set = datasets[-1]

    log.info("Loading CNN model...")
    import torch
    # NOTE: torch.load unpickles the file -- only load trusted model files.
    model = torch.load(dir_outputData + '/models/' + file_name + '_model')
    # Original note: "fix for new pytorch". Presumably models pickled with an
    # older PyTorch lack Conv2d.padding_mode, so it is set explicitly.
    for module in model.modules():
        if type(module).__name__ == 'Conv2d':
            module.padding_mode = 'zeros'
    log.info("Loading done.")
    return train_set, valid_set, test_set, model
def run_exp(max_recording_mins, n_recordings, sec_to_cut_at_start,
            sec_to_cut_at_end, duration_recording_mins, max_abs_val,
            clip_before_resample, sampling_freq, divisor, n_folds,
            i_test_fold, shuffle, merge_train_valid, model,
            input_time_length, optimizer, learning_rate, weight_decay,
            scheduler, model_constraint, batch_size, max_epochs,
            only_return_exp, time_cut_off_sec, start_time, test_on_eval,
            test_recording_mins, sensor_types, log_dir, np_th_seed,
            cuda=True):
    """Set up a cropped diagnosis experiment and, optionally, train it.

    Builds preprocessing functions via ``create_preproc_functions`` and a
    ``DiagnosisSet`` over the 'train' recordings. When ``test_on_eval`` is
    set, a second ``DiagnosisSet`` is built over the 'eval' recordings.

    Unless ``only_return_exp`` is set, the data is loaded and split, and
    training runs in two phases:

    1. until ``exp.stop_criterion`` (initially ``MaxEpochs(max_epochs)``)
       says stop, remembering the best epoch;
    2. only when ``merge_train_valid`` is False: after
       ``exp.setup_after_stop_training()``, training continues on train+valid
       until ``exp.stop_criterion`` says stop again.

    Both phases return early once ``time.time() - start_time`` exceeds
    ``time_cut_off_sec``.

    The optimizer choice and train/valid merging are coupled. ``'adam'``
    requires ``merge_train_valid=False``. ``'adamw'`` requires
    ``merge_train_valid=True`` and turns on weight-decay scheduling.

    :return: the ``Experiment``. With ``only_return_exp`` it is untrained,
        its train/valid/test sets are None, and it carries ``dataset`` and
        ``splitter`` attributes. With ``test_on_eval`` it also carries
        ``test_dataset``.
    """
    import torch.backends.cudnn as cudnn
    cudnn.benchmark = True
    # Enforce the optimizer <-> merge_train_valid coupling up front.
    if optimizer == 'adam':
        assert merge_train_valid == False
    else:
        assert optimizer == 'adamw'
        assert merge_train_valid == True

    preproc_functions = create_preproc_functions(
        sec_to_cut_at_start=sec_to_cut_at_start,
        sec_to_cut_at_end=sec_to_cut_at_end,
        duration_recording_mins=duration_recording_mins,
        max_abs_val=max_abs_val,
        clip_before_resample=clip_before_resample,
        sampling_freq=sampling_freq,
        divisor=divisor)

    dataset = DiagnosisSet(n_recordings=n_recordings,
                           max_recording_mins=max_recording_mins,
                           preproc_functions=preproc_functions,
                           train_or_eval='train',
                           sensor_types=sensor_types)

    if test_on_eval:
        # Test recordings default to the same cropped duration as training.
        if test_recording_mins is None:
            test_recording_mins = duration_recording_mins

        test_preproc_functions = create_preproc_functions(
            sec_to_cut_at_start=sec_to_cut_at_start,
            sec_to_cut_at_end=sec_to_cut_at_end,
            duration_recording_mins=test_recording_mins,
            max_abs_val=max_abs_val,
            clip_before_resample=clip_before_resample,
            sampling_freq=sampling_freq,
            divisor=divisor)
        test_dataset = DiagnosisSet(n_recordings=n_recordings,
                                    max_recording_mins=None,
                                    preproc_functions=test_preproc_functions,
                                    train_or_eval='eval',
                                    sensor_types=sensor_types)
    if not only_return_exp:
        X, y = dataset.load()
        # Sanity check: the longest recording (axis 1 of each example) must
        # have exactly duration_recording_mins worth of samples.
        max_shape = np.max([list(x.shape) for x in X], axis=0)
        assert max_shape[1] == int(duration_recording_mins *
                                   sampling_freq * 60)
        if test_on_eval:
            test_X, test_y = test_dataset.load()
            max_shape = np.max([list(x.shape) for x in test_X], axis=0)
            assert max_shape[1] == int(test_recording_mins *
                                       sampling_freq * 60)
    # With a separate eval set, only train/valid need to be split off here;
    # i_test_fold is then reused as the validation fold index.
    if not test_on_eval:
        splitter = TrainValidTestSplitter(n_folds, i_test_fold,
                                          shuffle=shuffle)
    else:
        splitter = TrainValidSplitter(n_folds, i_valid_fold=i_test_fold,
                                      shuffle=shuffle)
    if not only_return_exp:
        if not test_on_eval:
            train_set, valid_set, test_set = splitter.split(X, y)
        else:
            train_set, valid_set = splitter.split(X, y)
            test_set = SignalAndTarget(test_X, test_y)
            del test_X, test_y
        del X, y  # shouldn't be necessary, but just to make sure
        if merge_train_valid:
            train_set = concatenate_sets([train_set, valid_set])
            # just reduce valid for faster computations
            # (valid is already contained in train_set at this point)
            valid_set.X = valid_set.X[:8]
            valid_set.y = valid_set.y[:8]
    else:
        train_set = None
        valid_set = None
        test_set = None

    log.info("Model:\n{:s}".format(str(model)))
    if cuda:
        model.cuda()
    model.eval()
    # NOTE(review): hard-coded channel count; presumably matches what
    # DiagnosisSet yields for the configured sensor_types -- confirm.
    in_chans = 21
    # determine output size: run a dummy batch of 2 examples through the
    # model and read the number of predictions per input from axis 2.
    test_input = np_to_var(
        np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
    if cuda:
        test_input = test_input.cuda()
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]
    log.info("{:d} predictions per input/trial".format(n_preds_per_input))
    iterator = CropsFromTrialsIterator(batch_size=batch_size,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input,
                                       seed=np_th_seed)
    assert optimizer in ['adam', 'adamw'], ("Expect optimizer to be either "
                                            "adam or adamw")
    schedule_weight_decay = optimizer == 'adamw'
    if optimizer == 'adam':
        optim_class = optim.Adam
        assert schedule_weight_decay == False
        assert merge_train_valid == False
    else:
        optim_class = AdamW
        assert schedule_weight_decay == True
        assert merge_train_valid == True

    # From here on, ``optimizer`` holds the optimizer object, not its name.
    optimizer = optim_class(model.parameters(), lr=learning_rate,
                            weight_decay=weight_decay)
    if scheduler is not None:
        assert scheduler == 'cosine'
        # NOTE(review): with only_return_exp=True, train_set is None here;
        # this branch presumably fails in that case -- confirm.
        n_updates_per_epoch = sum(
            [1 for _ in iterator.get_batches(train_set, shuffle=True)])
        # Adapt if you have a different number of epochs
        n_updates_per_period = n_updates_per_epoch * max_epochs
        scheduler = CosineAnnealing(n_updates_per_period)
        optimizer = ScheduledOptimizer(
            scheduler, optimizer,
            schedule_weight_decay=schedule_weight_decay)
    loss_function = nll_loss_on_mean

    if model_constraint is not None:
        assert model_constraint == 'defaultnorm'
        model_constraint = MaxNormDefaultConstraint()
    monitors = [
        LossMonitor(),
        MisclassMonitor(col_suffix='sample_misclass'),
        CroppedDiagnosisMonitor(input_time_length, n_preds_per_input),
        RuntimeMonitor(),
    ]

    stop_criterion = MaxEpochs(max_epochs)
    loggers = [Printer(), TensorboardWriter(log_dir)]
    batch_modifier = None
    exp = Experiment(model, train_set, valid_set, test_set, iterator,
                     loss_function, optimizer, model_constraint,
                     monitors, stop_criterion,
                     remember_best_column='valid_misclass',
                     run_after_early_stop=True,
                     batch_modifier=batch_modifier,
                     cuda=cuda, loggers=loggers)

    if not only_return_exp:
        # Until first stop
        exp.setup_training()
        exp.monitor_epoch(exp.datasets)
        exp.log_epoch()
        exp.rememberer.remember_epoch(exp.epochs_df, exp.model,
                                      exp.optimizer)

        exp.iterator.reset_rng()
        while not exp.stop_criterion.should_stop(exp.epochs_df):
            # Wall-clock budget is checked before every epoch.
            if (time.time() - start_time) > time_cut_off_sec:
                log.info(
                    "Ran out of time after {:.2f} sec.".format(
                        time.time() - start_time))
                return exp
            log.info("Still in time after {:.2f} sec.".format(
                time.time() - start_time))
            exp.run_one_epoch(exp.datasets, remember_best=True)
        if (time.time() - start_time) > time_cut_off_sec:
            log.info("Ran out of time after {:.2f} sec.".format(
                time.time() - start_time))
            return exp
        if not merge_train_valid:
            exp.setup_after_stop_training()
            # Run until second stop, training on train+valid.
            # ``datasets`` is exp.datasets itself, so it is updated in place.
            datasets = exp.datasets
            datasets['train'] = concatenate_sets(
                [datasets['train'], datasets['valid']])
            exp.monitor_epoch(datasets)
            exp.log_epoch()

            exp.iterator.reset_rng()
            while not exp.stop_criterion.should_stop(exp.epochs_df):
                if (time.time() - start_time) > time_cut_off_sec:
                    log.info("Ran out of time after {:.2f} sec.".format(
                        time.time() - start_time))
                    return exp
                log.info("Still in time after {:.2f} sec.".format(
                    time.time() - start_time))
                exp.run_one_epoch(datasets, remember_best=False)

    else:
        # Hand the unloaded dataset and splitter to the caller so data can be
        # loaded later.
        exp.dataset = dataset
        exp.splitter = splitter
    if test_on_eval:
        exp.test_dataset = test_dataset

    return exp
def train_model(self, model, dataset, state=None, final_evaluation=False,
                ensemble=False):
    """Train ``model`` on ``dataset`` and return the trained model.

    Training first runs until ``self.run_until_stop`` stops. When
    ``final_evaluation`` is set, a second phase follows:

    * with ``ensemble``: one extra epoch is run and force-remembered;
    * otherwise: the validation set is concatenated onto the training set and
      training resumes until the next stop.

    Afterwards the model is reset to the best remembered state, and
    ``dataset['train']`` is restored from a deep copy taken before the phase.

    :param model: a model, or a list of models that gets wrapped in an
        ``AveragingEnsemble``.
    :param dataset: mapping with at least ``'train'`` and ``'valid'`` entries
        (assumed from the keys accessed here -- confirm with callers). It is
        modified temporarily during final evaluation and restored afterwards.
    :param state: optional state dict. When the global
        ``inherit_weights_normal`` flag is set, its entries whose key and shape
        match the model are copied into the model.
    :param final_evaluation: use the final-evaluation stop criterion and run
        the second training phase described above.
    :param ensemble: during final evaluation, run a single extra epoch instead
        of retraining on train+valid.
    :return: the trained model (wrapped in ``nn.DataParallel`` when multi-GPU
        training is enabled).

    Side effects: sets ``self.epochs_df``, ``self.optimizer``,
    ``self.final_time`` and ``self.num_epochs``. It may also replace
    ``self.stop_criterion`` and ``self.rememberer``.
    """
    if isinstance(model, list):
        model = AveragingEnsemble(model)
        if self.cuda:
            for sub_model in model.models:
                sub_model.cuda()
    if self.cuda:
        torch.cuda.empty_cache()
    if final_evaluation:
        self.stop_criterion = Or([
            MaxEpochs(global_vars.get('final_max_epochs')),
            NoIncreaseDecrease(
                f'valid_{global_vars.get("nn_objective")}',
                global_vars.get('final_max_increase_epochs'),
                oper=get_oper_by_loss_function(self.loss_function))
        ])
    if global_vars.get('cropping'):
        self.set_cropping_for_model(model)
    self.epochs_df = pd.DataFrame()
    if global_vars.get('do_early_stop') or global_vars.get('remember_best'):
        self.rememberer = RememberBest(
            f"valid_{global_vars.get('nn_objective')}",
            oper=get_oper_by_loss_function(self.loss_function, equals=True))
    if self.cuda:
        assert torch.cuda.is_available(), "Cuda not available"
        if torch.cuda.device_count() > 1 and global_vars.get('parallel_gpu'):
            model.cuda()
            with torch.cuda.device(0):
                model = nn.DataParallel(
                    model.cuda(),
                    device_ids=[
                        int(s)
                        for s in global_vars.get('gpu_select').split(',')
                    ])
        else:
            model.cuda()
    # Build the optimizer only after the model sits on its target device, as
    # the torch.optim documentation recommends. DataParallel exposes the
    # wrapped module's parameters, so the same tensors are optimized.
    self.optimizer = optim.Adam(model.parameters())
    try:
        if global_vars.get('inherit_weights_normal') and state is not None:
            # Copy only the entries whose name and shape match this model.
            current_state = model.state_dict()
            for k, v in state.items():
                if k in current_state and current_state[k].shape == v.shape:
                    current_state[k] = v
            model.load_state_dict(current_state)
    except Exception as e:
        # Best effort: report the failure and keep training with the model's
        # current weights instead of stopping in an interactive debugger,
        # which would hang unattended runs.
        state_keys = state.keys() if state is not None else None
        print(f'failed weight inheritance\n'
              f'state dict: {state_keys}\n'
              f'current model state: {model.state_dict().keys()}')
        print('load state dict failed. Exception message: %s' % (str(e)))
    self.monitor_epoch(dataset, model)
    if global_vars.get('log_epochs'):
        self.log_epoch()
    if global_vars.get('remember_best'):
        self.rememberer.remember_epoch(self.epochs_df, model, self.optimizer)
    self.iterator.reset_rng()
    start = time.time()
    num_epochs = self.run_until_stop(model, dataset)
    self.setup_after_stop_training(model, final_evaluation)
    if final_evaluation:
        # Back up the train split so the caller's dataset is unchanged after
        # the train+valid phase below.
        dataset_train_backup = deepcopy(dataset['train'])
        if ensemble:
            self.run_one_epoch(dataset, model)
            self.rememberer.remember_epoch(self.epochs_df, model,
                                           self.optimizer, force=ensemble)
            num_epochs += 1
        else:
            dataset['train'] = concatenate_sets(
                [dataset['train'], dataset['valid']])
            num_epochs += self.run_until_stop(model, dataset)
        self.rememberer.reset_to_best_model(self.epochs_df, model,
                                            self.optimizer)
        dataset['train'] = dataset_train_backup
    end = time.time()
    self.final_time = end - start
    self.num_epochs = num_epochs
    return model
def load_train_valid_test(train_filename, test_filename, n_folds, i_test_fold,
                          valid_set_fraction, use_validation_set, low_cut_hz,
                          debug=False,
                          train_folder='/home/schirrmr/data/BBCI-without-last-runs/',
                          test_folder='/home/schirrmr/data/BBCI-only-last-runs/'):
    """Load a BBCI recording and split it into train, (valid) and test sets.

    Two mutually exclusive modes:

    * ``test_filename is None``: cross-validation within ``train_filename``.
      The trials are cut into ``n_folds`` folds (unshuffled). Fold
      ``i_test_fold`` is the test set. If ``use_validation_set``, fold
      ``(i_test_fold - 1) % n_folds`` is the validation set. The remaining
      folds form the training set.
    * ``test_filename`` given: the test set is loaded from ``test_folder``.
      If ``use_validation_set``, the loaded training data is split into
      train/valid with ``split_into_two_sets(..., valid_set_fraction)``.

    :param low_cut_hz: forwarded to ``load_bbci_data``.
    :param debug: forwarded to ``load_bbci_data``.
    :param train_folder: directory containing ``train_filename``. The default
        is the previously hard-coded location.
    :param test_folder: directory containing ``test_filename``. The default
        is the previously hard-coded location.
    :return: ``(train_set, valid_set, test_set)``; ``valid_set`` is None when
        no validation set is used.
    """
    # All sensors are loaded so that cleaning results are the same regardless
    # of the later sensor selection. The loader has a built-in heuristic that
    # tries to use only EEG channels; it works for the datasets in our paper.
    if test_filename is None:
        assert n_folds is not None
        assert i_test_fold is not None
        assert valid_set_fraction is None
    else:
        assert n_folds is None
        assert i_test_fold is None
        assert use_validation_set == (valid_set_fraction is not None)

    log.info("Loading train...")
    full_train_set = load_bbci_data(os.path.join(train_folder, train_filename),
                                    low_cut_hz=low_cut_hz, debug=debug)

    if test_filename is not None:
        log.info("Loading test...")
        test_set = load_bbci_data(os.path.join(test_folder, test_filename),
                                  low_cut_hz=low_cut_hz, debug=debug)
        if use_validation_set:
            assert valid_set_fraction is not None
            train_set, valid_set = split_into_two_sets(full_train_set,
                                                       valid_set_fraction)
        else:
            train_set = full_train_set
            valid_set = None

    # Split data into folds (cross-validation mode only).
    if n_folds is not None:
        fold_inds = get_balanced_batches(len(full_train_set.X), None,
                                         shuffle=False, n_batches=n_folds)
        fold_sets = [select_examples(full_train_set, inds)
                     for inds in fold_inds]
        test_set = fold_sets[i_test_fold]
        train_folds = np.arange(n_folds)
        train_folds = np.setdiff1d(train_folds, [i_test_fold])
        if use_validation_set:
            # The fold just before the test fold (cyclically) is validation.
            i_valid_fold = (i_test_fold - 1) % n_folds
            train_folds = np.setdiff1d(train_folds, [i_valid_fold])
            valid_set = fold_sets[i_valid_fold]
            assert i_valid_fold not in train_folds
            assert i_test_fold != i_valid_fold
        else:
            valid_set = None

        assert i_test_fold not in train_folds

        train_fold_sets = [fold_sets[i] for i in train_folds]
        train_set = concatenate_sets(train_fold_sets)
        # Some checks: every trial ends up in exactly one split.
        if valid_set is None:
            assert len(train_set.X) + len(test_set.X) == len(full_train_set.X)
        else:
            assert len(train_set.X) + len(valid_set.X) + len(
                test_set.X) == len(full_train_set.X)

    log.info("Train set with {:4d} trials".format(len(train_set.X)))
    if valid_set is not None:
        log.info("Valid set with {:4d} trials".format(len(valid_set.X)))
    log.info("Test set with {:4d} trials".format(len(test_set.X)))

    return train_set, valid_set, test_set
def unify_dataset(dataset):
    """Concatenate every set stored in ``dataset`` into one set.

    :param dataset: mapping (e.g. split name -> set). Its values are
        concatenated in the mapping's iteration order.
    :return: the ``concatenate_sets`` result over all values.
    """
    # A plain list() copy replaces the former identity comprehension.
    return concatenate_sets(list(dataset.values()))
def concat_train_val_sets(dataset):
    """Fold the ``'valid'`` split into ``'train'`` in place, then drop it.

    :param dataset: mutable mapping with ``'train'`` and ``'valid'`` entries;
        modified in place. Nothing is returned.
    """
    merged = concatenate_sets([dataset['train'], dataset['valid']])
    dataset['train'] = merged
    del dataset['valid']
def _preprocess_bci_iv_2a_raw(cnt, low_cut_hz, high_cut_hz, factor_new,
                              init_block_size):
    """Drop the EOG channels, scale, band-pass filter and standardize ``cnt``."""
    cnt = cnt.drop_channels(['EOG-left', 'EOG-central', 'EOG-right'])
    assert len(cnt.ch_names) == 22
    # Scale by 1e6 (presumably volts -> microvolts; confirm with the loader).
    cnt = mne_apply(lambda a: a * 1e6, cnt)
    # Read the sampling rate once so the filter lambda does not depend on the
    # later rebinding of ``cnt``.
    sfreq = cnt.info['sfreq']
    cnt = mne_apply(
        lambda a: bandpass_cnt(a, low_cut_hz, high_cut_hz, sfreq,
                               filt_order=3, axis=1), cnt)
    cnt = mne_apply(
        lambda a: exponential_running_standardize(
            a.T, factor_new=factor_new, init_block_size=init_block_size,
            eps=1e-4).T, cnt)
    return cnt


def _load_bci_iv_2a_session(data_folder, filename, marker_def, ival,
                            low_cut_hz, high_cut_hz, factor_new,
                            init_block_size):
    """Load one BCI IV 2a .gdf session (with .mat labels) and epoch it."""
    filepath = os.path.join(data_folder, filename)
    label_filepath = filepath.replace('.gdf', '.mat')
    cnt = BCICompetition4Set2A(filepath,
                               labels_filename=label_filepath).load()
    cnt = _preprocess_bci_iv_2a_raw(cnt, low_cut_hz, high_cut_hz,
                                    factor_new, init_block_size)
    return create_signal_target_from_raw_mne(cnt, marker_def, ival)


def data_gen(subject, high_cut_hz=38, low_cut_hz=0,
             data_folder=r'D:\li\=.=\eeg\hw\nn-STFT\dataset\BCICIV_2a_gdf'):
    """Load and merge the BCI Competition IV 2a data of several subjects.

    For each subject id, both the training session (``A{id:02d}T.gdf``) and
    the evaluation session (``A{id:02d}E.gdf``) are processed as follows:

    * loaded together with their ``.mat`` label files;
    * stripped of the three EOG channels;
    * scaled by 1e6;
    * band-pass filtered to ``[low_cut_hz, high_cut_hz]``;
    * exponentially standardized;
    * epoched over ``[-500, 4000]`` ms around the four class markers.

    :param subject: sequence of distinct integer subject ids.
    :param high_cut_hz: upper band-pass edge in Hz.
    :param low_cut_hz: lower band-pass edge in Hz.
    :param data_folder: directory with the .gdf/.mat files. The default is
        the previously hard-coded location.
    :return: a single set with all subjects' train+eval trials, concatenated
        in the order of ``subject``.
    :raises ValueError: if ``subject`` is empty.
    """
    if len(subject) == 0:
        raise ValueError('subject must contain at least one subject id')

    ival = [-500, 4000]
    factor_new = 1e-3
    init_block_size = 1000
    marker_def = OrderedDict([('Left Hand', [1]), ('Right Hand', [2]),
                              ('Foot', [3]), ('Tongue', [4])])
    session_args = (marker_def, ival, low_cut_hz, high_cut_hz, factor_new,
                    init_block_size)

    data_sub = {}
    for subject_id in subject:
        train_set = _load_bci_iv_2a_session(
            data_folder, 'A{:02d}T.gdf'.format(subject_id), *session_args)
        test_set = _load_bci_iv_2a_session(
            data_folder, 'A{:02d}E.gdf'.format(subject_id), *session_args)
        data_sub[str(subject_id)] = concatenate_sets([train_set, test_set])

    # Duplicate subject ids would collapse into one entry of data_sub.
    assert len(data_sub) == len(subject)
    # dicts keep insertion order, i.e. the order of ``subject``.
    return concatenate_sets(list(data_sub.values()))
def run_exp(max_epochs, only_return_exp):
    """Train a ShallowFBCSPNet on the NiRiNBD robot-hall recordings.

    Loads the eight BBCI .mat files listed below. Each is resampled via
    ``resample_cnt(cnt, 100)`` and exponentially standardized, then cut into
    trials and breaks. The first six files form the training set, the
    seventh the validation set and the eighth the test set. The network is
    converted with ``to_dense_prediction_model`` and trained on crops from
    ``CropsFromTrialsIterator``, on CUDA.

    :param max_epochs: number of epochs for the ``MaxEpochs`` stop criterion.
    :param only_return_exp: if True, skip data loading and training and only
        return the configured Experiment (train/valid/test sets are None).
    :return: the braindecode ``Experiment``.
    """
    from collections import OrderedDict
    filenames = [
        'data/robot-hall/NiRiNBD6.ds_1-1_500Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD8.ds_1-1_500Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD9.ds_1-1_500Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD10.ds_1-1_500Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD12_cursor_250Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD13_cursorS000R01_onlyFullRuns_250Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD14_cursor_250Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD15_cursor_250Hz.BBCI.mat'
    ]
    # NOTE(review): this list is overwritten by an identical definition
    # further down before it is ever used.
    sensor_names = [
        'Fp1', 'Fpz', 'Fp2', 'AF7', 'AF3', 'AF4', 'AF8', 'F7', 'F5', 'F3',
        'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1',
        'FCz', 'FC2', 'FC4', 'FC6', 'FT8', 'M1', 'T7', 'C5', 'C3', 'C1',
        'Cz', 'C2', 'C4', 'C6', 'T8', 'M2', 'TP7', 'CP5', 'CP3', 'CP1',
        'CPz', 'CP2', 'CP4', 'CP6', 'TP8', 'P7', 'P5', 'P3', 'P1', 'Pz',
        'P2', 'P4', 'P6', 'P8', 'PO7', 'PO5', 'PO3', 'POz', 'PO4', 'PO6',
        'PO8', 'O1', 'Oz', 'O2'
    ]
    # Default trial markers: each class has a start code and a stop code
    # (start code * 10).
    name_to_start_codes = OrderedDict([('Right Hand', [1]), ('Feet', [2]),
                                       ('Rotation', [3]), ('Words', [4])])
    name_to_stop_codes = OrderedDict([('Right Hand', [10]), ('Feet', [20]),
                                      ('Rotation', [30]), ('Words', [40])])

    trial_ival = [500, 0]
    min_break_length_ms = 6000
    max_break_length_ms = 8000
    break_ival = [1000, -500]

    input_time_length = 700

    # Per-file overrides: the cursor sessions add a 'Rest' class (codes 5/50)
    # and use shorter break lengths.
    filename_to_extra_args = {
        'data/robot-hall/NiRiNBD12_cursor_250Hz.BBCI.mat': dict(
            name_to_start_codes=OrderedDict([('Right Hand', [1]),
                                             ('Feet', [2]),
                                             ('Rotation', [3]),
                                             ('Words', [4]),
                                             ('Rest', [5])]),
            name_to_stop_codes=OrderedDict([('Right Hand', [10]),
                                            ('Feet', [20]),
                                            ('Rotation', [30]),
                                            ('Words', [40]),
                                            ('Rest', [50])]),
            min_break_length_ms=3700,
            max_break_length_ms=3900,
        ),
        'data/robot-hall/NiRiNBD13_cursorS000R01_onlyFullRuns_250Hz.BBCI.mat': dict(
            name_to_start_codes=OrderedDict([('Right Hand', [1]),
                                             ('Feet', [2]),
                                             ('Rotation', [3]),
                                             ('Words', [4]),
                                             ('Rest', [5])]),
            name_to_stop_codes=OrderedDict([('Right Hand', [10]),
                                            ('Feet', [20]),
                                            ('Rotation', [30]),
                                            ('Words', [40]),
                                            ('Rest', [50])]),
            min_break_length_ms=3700,
            max_break_length_ms=3900,
        ),
        'data/robot-hall/NiRiNBD14_cursor_250Hz.BBCI.mat': dict(
            name_to_start_codes=OrderedDict([('Right Hand', [1]),
                                             ('Feet', [2]),
                                             ('Rotation', [3]),
                                             ('Words', [4]),
                                             ('Rest', [5])]),
            name_to_stop_codes=OrderedDict([('Right Hand', [10]),
                                            ('Feet', [20]),
                                            ('Rotation', [30]),
                                            ('Words', [40]),
                                            ('Rest', [50])]),
            min_break_length_ms=3700,
            max_break_length_ms=3900,
        ),
        'data/robot-hall/NiRiNBD15_cursor_250Hz.BBCI.mat': dict(
            name_to_start_codes=OrderedDict([('Right Hand', [1]),
                                             ('Feet', [2]),
                                             ('Rotation', [3]),
                                             ('Words', [4]),
                                             ('Rest', [5])]),
            name_to_stop_codes=OrderedDict([('Right Hand', [10]),
                                            ('Feet', [20]),
                                            ('Rotation', [30]),
                                            ('Words', [40]),
                                            ('Rest', [50])]),
            min_break_length_ms=3700,
            max_break_length_ms=3900,
        ),
    }
    from braindecode.datautil.trial_segment import \
        create_signal_target_with_breaks_from_mne
    from copy import deepcopy

    def load_data(filenames, sensor_names, name_to_start_codes,
                  name_to_stop_codes, trial_ival, break_ival,
                  min_break_length_ms, max_break_length_ms,
                  input_time_length, filename_to_extra_args):
        """Load each file into a trials-with-breaks set; return them in order.

        Per file, the defaults given as arguments are overridden by
        ``filename_to_extra_args[filename]`` when present.
        """
        all_sets = []
        # Snapshot of the arguments (plus the still-empty all_sets) used as
        # per-file defaults. Must stay before any other local is created.
        original_args = locals()
        for filename in filenames:
            kwargs = deepcopy(original_args)
            if filename in filename_to_extra_args:
                kwargs.update(filename_to_extra_args[filename])
            log.info("Loading {:s}...".format(filename))
            cnt = BBCIDataset(filename,
                              load_sensor_names=sensor_names).load()
            cnt = cnt.drop_channels(['STI 014'])
            log.info("Resampling...")
            cnt = resample_cnt(cnt, 100)
            log.info("Standardizing...")
            cnt = mne_apply(lambda a: exponential_running_standardize(
                a.T, init_block_size=50).T, cnt)

            log.info("Transform to set...")
            full_set = (create_signal_target_with_breaks_from_mne(
                cnt, kwargs['name_to_start_codes'],
                kwargs['trial_ival'],
                kwargs['name_to_stop_codes'],
                kwargs['min_break_length_ms'],
                kwargs['max_break_length_ms'],
                kwargs['break_ival'],
                prepad_trials_to_n_samples=kwargs['input_time_length'],
            ))
            all_sets.append(full_set)
        return all_sets

    sensor_names = [
        'Fp1', 'Fpz', 'Fp2', 'AF7', 'AF3', 'AF4', 'AF8', 'F7', 'F5', 'F3',
        'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1',
        'FCz', 'FC2', 'FC4', 'FC6', 'FT8', 'M1', 'T7', 'C5', 'C3', 'C1',
        'Cz', 'C2', 'C4', 'C6', 'T8', 'M2', 'TP7', 'CP5', 'CP3', 'CP1',
        'CPz', 'CP2', 'CP4', 'CP6', 'TP8', 'P7', 'P5', 'P3', 'P1', 'Pz',
        'P2', 'P4', 'P6', 'P8', 'PO7', 'PO5', 'PO3', 'POz', 'PO4', 'PO6',
        'PO8', 'O1', 'Oz', 'O2'
    ]
    # Alternative two-channel setup: sensor_names = ['C3', 'C4']
    n_chans = len(sensor_names)
    if not only_return_exp:
        all_sets = load_data(filenames, sensor_names, name_to_start_codes,
                             name_to_stop_codes, trial_ival, break_ival,
                             min_break_length_ms, max_break_length_ms,
                             input_time_length, filename_to_extra_args)
        from braindecode.datautil.signal_target import SignalAndTarget
        from braindecode.datautil.splitters import concatenate_sets

        # Split by file: files 1-6 train, file 7 valid, file 8 test.
        train_set = concatenate_sets(all_sets[:6])
        valid_set = all_sets[6]
        test_set = all_sets[7]
    else:
        train_set = None
        valid_set = None
        test_set = None
    set_random_seeds(seed=20171017, cuda=True)
    # Four task classes plus 'Rest' from the cursor sessions.
    n_classes = 5
    # final_conv_length determines the size of the receptive field of the
    # ConvNet
    model = ShallowFBCSPNet(in_chans=n_chans, n_classes=n_classes,
                            input_time_length=input_time_length,
                            final_conv_length=30).create_network()
    to_dense_prediction_model(model)

    model.cuda()

    from torch import optim
    import numpy as np

    optimizer = optim.Adam(model.parameters())

    from braindecode.torch_ext.util import np_to_var
    # determine output size: run a dummy batch of 2 examples through the
    # model and read the number of predictions per input from axis 2.
    test_input = np_to_var(
        np.ones((2, n_chans, input_time_length, 1), dtype=np.float32))
    test_input = test_input.cuda()
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]
    print("{:d} predictions per input/trial".format(n_preds_per_input))

    from braindecode.datautil.iterators import CropsFromTrialsIterator
    iterator = CropsFromTrialsIterator(batch_size=32,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input)

    from braindecode.experiments.experiment import Experiment
    from braindecode.experiments.monitors import RuntimeMonitor, LossMonitor, \
        CroppedTrialMisclassMonitor, MisclassMonitor
    from braindecode.experiments.stopcriteria import MaxEpochs
    from braindecode.torch_ext.losses import log_categorical_crossentropy
    # NOTE(review): F, th and Expression are imported but unused below.
    import torch.nn.functional as F
    import torch as th
    from braindecode.torch_ext.modules import Expression

    loss_function = log_categorical_crossentropy

    model_constraint = MaxNormDefaultConstraint()

    monitors = [
        LossMonitor(),
        MisclassMonitor(col_suffix='sample_misclass'),
        CroppedTrialMisclassMonitor(input_time_length),
        RuntimeMonitor(),
    ]
    stop_criterion = MaxEpochs(max_epochs)
    exp = Experiment(model, train_set, valid_set, test_set, iterator,
                     loss_function, optimizer, model_constraint,
                     monitors, stop_criterion,
                     remember_best_column='valid_sample_misclass',
                     run_after_early_stop=True, batch_modifier=None,
                     cuda=True)
    if not only_return_exp:
        exp.run()
    return exp