def run_until_second_stop(self):
        """
        Run training and evaluation using combined training + validation sets 
        for training on both datasets. 
        
        Runs until loss on validation set decreases below loss on training set 
        of best epoch or  until as many epochs trained after as before 
        first stop.
        """
        datasets = self.datasets
        datasets['train_1'] = concatenate_sets(
            [datasets['train_1'], datasets['valid_1']])
        datasets['train_2'] = concatenate_sets(
            [datasets['train_2'], datasets['valid_2']])

        self.run_until_stop(datasets, remember_best=True)
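The stopping rule described in the docstring can be expressed with braindecode's stop criteria. The snippet below is only a minimal sketch, assuming the braindecode 0.4 stopcriteria API (Or, MaxEpochs, ColumnBelow) and an epochs_df with train_loss/valid_loss columns; make_second_stop_criterion is a hypothetical helper, not part of the code above.

from braindecode.experiments.stopcriteria import Or, MaxEpochs, ColumnBelow

def make_second_stop_criterion(epochs_df, best_epoch):
    # Stop once the valid loss falls below the train loss of the best epoch,
    # or after twice as many epochs as were trained before the first stop.
    return Or(stop_criteria=[
        MaxEpochs(max_epochs=2 * best_epoch),
        ColumnBelow(column_name='valid_loss',
                    target_value=float(epochs_df['train_loss'].iloc[best_epoch])),
    ])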
Example #2
def get_single_subj_dataset(self, subject=None, final_evaluation=False):
    if subject not in self.datasets['train'].keys():
        self.datasets['train'][subject], self.datasets['valid'][subject], self.datasets['test'][subject] = \
            get_train_val_test(global_vars.get('data_folder'), subject)
    single_subj_dataset = OrderedDict(
        (('train', self.datasets['train'][subject]),
         ('valid', self.datasets['valid'][subject]),
         ('test', self.datasets['test'][subject])))
    if final_evaluation:
        single_subj_dataset['train'] = concatenate_sets(
            [single_subj_dataset['train'], single_subj_dataset['valid']])
    return single_subj_dataset
Example #3
def _load_and_merge_data(file_paths):
    """
    Load multiple HGD subjects and merged them into a new dataset
    :param file_paths:
    :return:
    """
    if len(file_paths) == 1:  # if just 1 subject's data
        return _load_h5_data(file_paths[0])

    signal_and_target_data_list = []
    for path in file_paths:
        signal_and_target_data_list.append(_load_h5_data(path))

    return concatenate_sets(signal_and_target_data_list)
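A hedged usage sketch with hypothetical file paths; _load_h5_data is assumed to return the same SignalAndTarget-style dataset that concatenate_sets expects.

# Hypothetical paths to per-subject HGD files; adjust to the real data layout.
file_paths = ['data/hgd/subject_1.h5', 'data/hgd/subject_2.h5']
merged_set = _load_and_merge_data(file_paths)
print(len(merged_set.X), 'trials after merging', len(file_paths), 'subjects')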
Example #4
def get_single_subj_dataset(self, subject=None, final_evaluation=False):
    if subject not in self.datasets['train'].keys():
        self.datasets['train'][subject], self.datasets['valid'][subject], self.datasets['test'][subject] = \
            get_train_val_test(global_vars.get('data_folder'), subject)
    single_subj_dataset = OrderedDict((('train', self.datasets['train'][subject]),
                                       ('valid', self.datasets['valid'][subject]),
                                       ('test', self.datasets['test'][subject])))
    if final_evaluation and global_vars.get('ensemble_iterations'):
        single_subj_dataset['train'] = concatenate_sets(
            [single_subj_dataset['train'], single_subj_dataset['valid']])
    if global_vars.get('time_frequency'):
        EEG_to_TF_mne(single_subj_dataset)
        set_global_vars_by_dataset(single_subj_dataset['train'])
    return single_subj_dataset
Example #5
    def run_until_second_stop(self):
        """
        Run training and evaluation using combined training + validation set 
        for training. 
        
        Runs until loss on validation  set decreases below loss on training set 
        of best epoch or  until as many epochs trained after as before 
        first stop.
        """
        datasets = self.datasets
        datasets['train'] = concatenate_sets(
            [datasets['train'], datasets['valid']])

        # TODO: actually keep remembering, and in case of twice the number of
        # epochs, reset to the best model again (check that the valid loss is
        # not below the train loss)
        self.run_until_stop(datasets, remember_best=False)
Example #6
def load_data_and_model(n_job):
    fileName = file_for_number(n_job)
    print("file = {:s}".format(fileName))
    # %% Load data: matlab cell array
    import h5py
    log.info("Loading data...")
    with h5py.File(dir_sourceData + '/' + fileName + '.mat', 'r') as h5file:
        sessions = [h5file[obj_ref] for obj_ref in h5file['D'][0]]
        Xs = [session['ieeg'][:] for session in sessions]
        ys = [session['traj'][0] for session in sessions]
        srates = [session['srate'][0, 0] for session in sessions]

    # %% create datasets
    from braindecode.datautil.signal_target import SignalAndTarget

    # Outer added axis is the trial axis (size one always...)
    datasets = [
        SignalAndTarget([X.astype(np.float32)], [y.astype(np.float32)])
        for X, y in zip(Xs, ys)
    ]

    from braindecode.datautil.splitters import concatenate_sets

    # only for allocation
    assert len(datasets) >= 4
    train_set = concatenate_sets(datasets[:-1])
    valid_set = datasets[-2]  # dummy variable, validation set is not used
    test_set = datasets[-1]

    log.info("Loading CNN model...")
    import torch
    model = torch.load(dir_outputData + '/models/' + fileName + '_model')
    # fix for new pytorch
    for m in model.modules():
        if m.__class__.__name__ == 'Conv2d':
            m.padding_mode = 'zeros'
    log.info("Loading done.")
    return train_set, valid_set, test_set, model
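load_data_and_model relies on several module-level names that are not shown in this snippet (log, np, dir_sourceData, dir_outputData, file_for_number). The sketch below is only an assumption about what the surrounding module could look like; the real paths and naming scheme live elsewhere.

import logging
import numpy as np

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

# Hypothetical locations of the .mat recordings and the saved models.
dir_sourceData = 'data/source'
dir_outputData = 'data/output'

def file_for_number(n_job):
    # Assumed mapping from job index to a file name stem such as 'subject_3'.
    return 'subject_{:d}'.format(n_job + 1)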
def run_exp(max_recording_mins,
            n_recordings,
            sec_to_cut_at_start,
            sec_to_cut_at_end,
            duration_recording_mins,
            max_abs_val,
            clip_before_resample,
            sampling_freq,
            divisor,
            n_folds,
            i_test_fold,
            shuffle,
            merge_train_valid,
            model,
            input_time_length,
            optimizer,
            learning_rate,
            weight_decay,
            scheduler,
            model_constraint,
            batch_size,
            max_epochs,
            only_return_exp,
            time_cut_off_sec,
            start_time,
            test_on_eval,
            test_recording_mins,
            sensor_types,
            log_dir,
            np_th_seed,
            cuda=True):
    import torch.backends.cudnn as cudnn
    cudnn.benchmark = True
    if optimizer == 'adam':
        assert merge_train_valid == False
    else:
        assert optimizer == 'adamw'
        assert merge_train_valid == True

    preproc_functions = create_preproc_functions(
        sec_to_cut_at_start=sec_to_cut_at_start,
        sec_to_cut_at_end=sec_to_cut_at_end,
        duration_recording_mins=duration_recording_mins,
        max_abs_val=max_abs_val,
        clip_before_resample=clip_before_resample,
        sampling_freq=sampling_freq,
        divisor=divisor)

    dataset = DiagnosisSet(n_recordings=n_recordings,
                           max_recording_mins=max_recording_mins,
                           preproc_functions=preproc_functions,
                           train_or_eval='train',
                           sensor_types=sensor_types)

    if test_on_eval:
        if test_recording_mins is None:
            test_recording_mins = duration_recording_mins

        test_preproc_functions = create_preproc_functions(
            sec_to_cut_at_start=sec_to_cut_at_start,
            sec_to_cut_at_end=sec_to_cut_at_end,
            duration_recording_mins=test_recording_mins,
            max_abs_val=max_abs_val,
            clip_before_resample=clip_before_resample,
            sampling_freq=sampling_freq,
            divisor=divisor)
        test_dataset = DiagnosisSet(n_recordings=n_recordings,
                                    max_recording_mins=None,
                                    preproc_functions=test_preproc_functions,
                                    train_or_eval='eval',
                                    sensor_types=sensor_types)
    if not only_return_exp:
        X, y = dataset.load()
        max_shape = np.max([list(x.shape) for x in X], axis=0)
        assert max_shape[1] == int(duration_recording_mins * sampling_freq *
                                   60)
        if test_on_eval:
            test_X, test_y = test_dataset.load()
            max_shape = np.max([list(x.shape) for x in test_X], axis=0)
            assert max_shape[1] == int(test_recording_mins * sampling_freq *
                                       60)
    if not test_on_eval:
        splitter = TrainValidTestSplitter(n_folds,
                                          i_test_fold,
                                          shuffle=shuffle)
    else:
        splitter = TrainValidSplitter(n_folds,
                                      i_valid_fold=i_test_fold,
                                      shuffle=shuffle)
    if not only_return_exp:
        if not test_on_eval:
            train_set, valid_set, test_set = splitter.split(X, y)
        else:
            train_set, valid_set = splitter.split(X, y)
            test_set = SignalAndTarget(test_X, test_y)
            del test_X, test_y
        del X, y  # shouldn't be necessary, but just to make sure
        if merge_train_valid:
            train_set = concatenate_sets([train_set, valid_set])
            # just reduce valid for faster computations
            valid_set.X = valid_set.X[:8]
            valid_set.y = valid_set.y[:8]
            # np.save('/data/schirrmr/schirrmr/auto-diag/lukasrepr/compare/mne-0-16-2/train_X.npy', train_set.X)
            # np.save('/data/schirrmr/schirrmr/auto-diag/lukasrepr/compare/mne-0-16-2/train_y.npy', train_set.y)
            # np.save('/data/schirrmr/schirrmr/auto-diag/lukasrepr/compare/mne-0-16-2/valid_X.npy', valid_set.X)
            # np.save('/data/schirrmr/schirrmr/auto-diag/lukasrepr/compare/mne-0-16-2/valid_y.npy', valid_set.y)
            # np.save('/data/schirrmr/schirrmr/auto-diag/lukasrepr/compare/mne-0-16-2/test_X.npy', test_set.X)
            # np.save('/data/schirrmr/schirrmr/auto-diag/lukasrepr/compare/mne-0-16-2/test_y.npy', test_set.y)
    else:
        train_set = None
        valid_set = None
        test_set = None

    log.info("Model:\n{:s}".format(str(model)))
    if cuda:
        model.cuda()
    model.eval()
    in_chans = 21
    # determine output size
    test_input = np_to_var(
        np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
    if cuda:
        test_input = test_input.cuda()
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]
    log.info("{:d} predictions per input/trial".format(n_preds_per_input))
    iterator = CropsFromTrialsIterator(batch_size=batch_size,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input,
                                       seed=np_th_seed)
    assert optimizer in ['adam', 'adamw'], ("Expect optimizer to be either "
                                            "adam or adamw")
    schedule_weight_decay = optimizer == 'adamw'
    if optimizer == 'adam':
        optim_class = optim.Adam
        assert schedule_weight_decay == False
        assert merge_train_valid == False
    else:
        optim_class = AdamW
        assert schedule_weight_decay == True
        assert merge_train_valid == True

    optimizer = optim_class(model.parameters(),
                            lr=learning_rate,
                            weight_decay=weight_decay)
    if scheduler is not None:
        assert scheduler == 'cosine'
        n_updates_per_epoch = sum(
            [1 for _ in iterator.get_batches(train_set, shuffle=True)])
        # Adapt if you have a different number of epochs
        n_updates_per_period = n_updates_per_epoch * max_epochs
        scheduler = CosineAnnealing(n_updates_per_period)
        optimizer = ScheduledOptimizer(
            scheduler, optimizer, schedule_weight_decay=schedule_weight_decay)
    loss_function = nll_loss_on_mean

    if model_constraint is not None:
        assert model_constraint == 'defaultnorm'
        model_constraint = MaxNormDefaultConstraint()
    monitors = [
        LossMonitor(),
        MisclassMonitor(col_suffix='sample_misclass'),
        CroppedDiagnosisMonitor(input_time_length, n_preds_per_input),
        RuntimeMonitor(),
    ]

    stop_criterion = MaxEpochs(max_epochs)
    loggers = [Printer(), TensorboardWriter(log_dir)]
    batch_modifier = None
    exp = Experiment(model,
                     train_set,
                     valid_set,
                     test_set,
                     iterator,
                     loss_function,
                     optimizer,
                     model_constraint,
                     monitors,
                     stop_criterion,
                     remember_best_column='valid_misclass',
                     run_after_early_stop=True,
                     batch_modifier=batch_modifier,
                     cuda=cuda,
                     loggers=loggers)

    if not only_return_exp:
        # Until first stop
        exp.setup_training()
        exp.monitor_epoch(exp.datasets)
        exp.log_epoch()
        exp.rememberer.remember_epoch(exp.epochs_df, exp.model, exp.optimizer)

        exp.iterator.reset_rng()
        while not exp.stop_criterion.should_stop(exp.epochs_df):
            if (time.time() - start_time) > time_cut_off_sec:
                log.info(
                    "Ran out of time after {:.2f} sec.".format(time.time() -
                                                               start_time))
                return exp
            log.info("Still in time after {:.2f} sec.".format(time.time() -
                                                              start_time))
            exp.run_one_epoch(exp.datasets, remember_best=True)
        if (time.time() - start_time) > time_cut_off_sec:
            log.info("Ran out of time after {:.2f} sec.".format(time.time() -
                                                                start_time))
            return exp
        if not merge_train_valid:
            exp.setup_after_stop_training()
            # Run until second stop
            datasets = exp.datasets
            datasets['train'] = concatenate_sets(
                [datasets['train'], datasets['valid']])
            exp.monitor_epoch(datasets)
            exp.log_epoch()

            exp.iterator.reset_rng()
            while not exp.stop_criterion.should_stop(exp.epochs_df):
                if (time.time() - start_time) > time_cut_off_sec:
                    log.info("Ran out of time after {:.2f} sec.".format(
                        time.time() - start_time))
                    return exp
                log.info("Still in time after {:.2f} sec.".format(time.time() -
                                                                  start_time))
                exp.run_one_epoch(datasets, remember_best=False)

    else:
        exp.dataset = dataset
        exp.splitter = splitter
    if test_on_eval:
        exp.test_dataset = test_dataset

    return exp
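A hedged invocation sketch for run_exp; every value below is illustrative rather than taken from the original experiment configs, and the model is built the same way as in the cropped-decoding example further down this page (assuming the braindecode 0.4 import paths for ShallowFBCSPNet and to_dense_prediction_model). With only_return_exp=True and scheduler=None no recordings are loaded, but a CUDA device is still needed because the model is moved to the GPU.

import time
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.models.util import to_dense_prediction_model

# Illustrative cropped model: 21 EEG channels, binary normal/abnormal output.
model = ShallowFBCSPNet(in_chans=21, n_classes=2,
                        input_time_length=6000,
                        final_conv_length=30).create_network()
to_dense_prediction_model(model)

# 'adam' goes with merge_train_valid=False, 'adamw' with True (see the asserts above).
exp = run_exp(
    max_recording_mins=35, n_recordings=300,
    sec_to_cut_at_start=60, sec_to_cut_at_end=0,
    duration_recording_mins=20, max_abs_val=800,
    clip_before_resample=False, sampling_freq=100, divisor=None,
    n_folds=5, i_test_fold=4, shuffle=True,
    merge_train_valid=False, model=model,
    input_time_length=6000, optimizer='adam',
    learning_rate=1e-3, weight_decay=0,
    scheduler=None, model_constraint='defaultnorm',
    batch_size=64, max_epochs=35, only_return_exp=True,
    time_cut_off_sec=24 * 3600, start_time=time.time(),
    test_on_eval=False, test_recording_mins=None,
    sensor_types=['EEG'], log_dir='logs/run0',
    np_th_seed=0, cuda=True)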
Example #8
    def train_model(self,
                    model,
                    dataset,
                    state=None,
                    final_evaluation=False,
                    ensemble=False):
        if type(model) == list:
            model = AveragingEnsemble(model)
            if self.cuda:
                for mod in model.models:
                    mod.cuda()
        if self.cuda:
            torch.cuda.empty_cache()
        if final_evaluation:
            self.stop_criterion = Or([
                MaxEpochs(global_vars.get('final_max_epochs')),
                NoIncreaseDecrease(
                    f'valid_{global_vars.get("nn_objective")}',
                    global_vars.get('final_max_increase_epochs'),
                    oper=get_oper_by_loss_function(self.loss_function))
            ])
        if global_vars.get('cropping'):
            self.set_cropping_for_model(model)
        self.epochs_df = pd.DataFrame()
        if global_vars.get('do_early_stop') or global_vars.get(
                'remember_best'):
            self.rememberer = RememberBest(
                f"valid_{global_vars.get('nn_objective')}",
                oper=get_oper_by_loss_function(self.loss_function,
                                               equals=True))
        self.optimizer = optim.Adam(model.parameters())
        if self.cuda:
            assert torch.cuda.is_available(), "Cuda not available"
            if torch.cuda.device_count() > 1 and global_vars.get(
                    'parallel_gpu'):
                model.cuda()
                with torch.cuda.device(0):
                    model = nn.DataParallel(
                        model.cuda(),
                        device_ids=[
                            int(s)
                            for s in global_vars.get('gpu_select').split(',')
                        ])
            else:
                model.cuda()

        try:
            if global_vars.get('inherit_weights_normal') and state is not None:
                current_state = model.state_dict()
                for k, v in state.items():
                    if k in current_state and current_state[k].shape == v.shape:
                        current_state.update({k: v})
                model.load_state_dict(current_state)
        except Exception as e:
            print(f'failed weight inheritance\n'
                  f'state dict: {state.keys()}\n'
                  f'current model state: {model.state_dict().keys()}')
            print('load state dict failed. Exception message: %s' % (str(e)))
            pdb.set_trace()
        self.monitor_epoch(dataset, model)
        if global_vars.get('log_epochs'):
            self.log_epoch()
        if global_vars.get('remember_best'):
            self.rememberer.remember_epoch(self.epochs_df, model,
                                           self.optimizer)
        self.iterator.reset_rng()
        start = time.time()
        num_epochs = self.run_until_stop(model, dataset)
        self.setup_after_stop_training(model, final_evaluation)
        if final_evaluation:
            dataset_train_backup = deepcopy(dataset['train'])
            if ensemble:
                self.run_one_epoch(dataset, model)
                self.rememberer.remember_epoch(self.epochs_df,
                                               model,
                                               self.optimizer,
                                               force=ensemble)
                num_epochs += 1
            else:
                dataset['train'] = concatenate_sets(
                    [dataset['train'], dataset['valid']])
            num_epochs += self.run_until_stop(model, dataset)
            self.rememberer.reset_to_best_model(self.epochs_df, model,
                                                self.optimizer)
            dataset['train'] = dataset_train_backup
        end = time.time()
        self.final_time = end - start
        self.num_epochs = num_epochs
        return model
def load_train_valid_test(train_filename,
                          test_filename,
                          n_folds,
                          i_test_fold,
                          valid_set_fraction,
                          use_validation_set,
                          low_cut_hz,
                          debug=False):
    # We load all sensors to always get the same cleaning results, independent
    # of sensor selection. There is an in-built heuristic that tries to use
    # only EEG channels, and that definitely works for the datasets in our paper.
    if test_filename is None:
        assert n_folds is not None
        assert i_test_fold is not None
        assert valid_set_fraction is None
    else:
        assert n_folds is None
        assert i_test_fold is None
        assert use_validation_set == (valid_set_fraction is not None)

    train_folder = '/home/schirrmr/data/BBCI-without-last-runs/'
    log.info("Loading train...")
    full_train_set = load_bbci_data(os.path.join(train_folder, train_filename),
                                    low_cut_hz=low_cut_hz,
                                    debug=debug)

    if test_filename is not None:
        test_folder = '/home/schirrmr/data/BBCI-only-last-runs/'
        log.info("Loading test...")
        test_set = load_bbci_data(os.path.join(test_folder, test_filename),
                                  low_cut_hz=low_cut_hz,
                                  debug=debug)
        if use_validation_set:
            assert valid_set_fraction is not None
            train_set, valid_set = split_into_two_sets(full_train_set,
                                                       valid_set_fraction)
        else:
            train_set = full_train_set
            valid_set = None

    # Split data
    if n_folds is not None:
        fold_inds = get_balanced_batches(len(full_train_set.X),
                                         None,
                                         shuffle=False,
                                         n_batches=n_folds)

        fold_sets = [
            select_examples(full_train_set, inds) for inds in fold_inds
        ]

        test_set = fold_sets[i_test_fold]
        train_folds = np.arange(n_folds)
        train_folds = np.setdiff1d(train_folds, [i_test_fold])
        if use_validation_set:
            i_valid_fold = (i_test_fold - 1) % n_folds
            train_folds = np.setdiff1d(train_folds, [i_valid_fold])
            valid_set = fold_sets[i_valid_fold]
            assert i_valid_fold not in train_folds
            assert i_test_fold != i_valid_fold
        else:
            valid_set = None

        assert i_test_fold not in train_folds

        train_fold_sets = [fold_sets[i] for i in train_folds]
        train_set = concatenate_sets(train_fold_sets)
        # Some checks
        if valid_set is None:
            assert len(train_set.X) + len(test_set.X) == len(full_train_set.X)
        else:
            assert len(train_set.X) + len(valid_set.X) + len(
                test_set.X) == len(full_train_set.X)

    log.info("Train set with {:4d} trials".format(len(train_set.X)))
    if valid_set is not None:
        log.info("Valid set with {:4d} trials".format(len(valid_set.X)))
    log.info("Test set with  {:4d} trials".format(len(test_set.X)))

    return train_set, valid_set, test_set
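To make the fold bookkeeping above concrete, here is a small worked example with assumed values: for n_folds=5 and i_test_fold=0, the validation fold is (0 - 1) % 5 = 4 and the remaining training folds are 1, 2 and 3.

import numpy as np

n_folds, i_test_fold = 5, 0
train_folds = np.arange(n_folds)                         # [0 1 2 3 4]
train_folds = np.setdiff1d(train_folds, [i_test_fold])   # drop test fold -> [1 2 3 4]
i_valid_fold = (i_test_fold - 1) % n_folds               # -> 4
train_folds = np.setdiff1d(train_folds, [i_valid_fold])  # drop valid fold -> [1 2 3]
print(train_folds, i_valid_fold, i_test_fold)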
Example #10
def unify_dataset(dataset):
    return concatenate_sets([data for data in dataset.values()])
Example #11
def concat_train_val_sets(dataset):
    dataset['train'] = concatenate_sets([dataset['train'], dataset['valid']])
    del dataset['valid']
Example #12
def data_gen(subject, high_cut_hz=38, low_cut_hz=0):
    data_sub = {}
    for i in range(len(subject)):
        subject_id = subject[i]
        data_folder = r'D:\li\=.=\eeg\hw\nn-STFT\dataset\BCICIV_2a_gdf'
        ival = [-500, 4000]
        factor_new = 1e-3
        init_block_size = 1000

        train_filename = 'A{:02d}T.gdf'.format(subject_id)
        test_filename = 'A{:02d}E.gdf'.format(subject_id)
        train_filepath = os.path.join(data_folder, train_filename)
        test_filepath = os.path.join(data_folder, test_filename)
        train_label_filepath = train_filepath.replace('.gdf', '.mat')
        test_label_filepath = test_filepath.replace('.gdf', '.mat')

        train_loader = BCICompetition4Set2A(
            train_filepath, labels_filename=train_label_filepath)
        test_loader = BCICompetition4Set2A(test_filepath,
                                           labels_filename=test_label_filepath)

        train_cnt = train_loader.load()
        test_cnt = test_loader.load()

        # train set process
        train_cnt = train_cnt.drop_channels(
            ['EOG-left', 'EOG-central', 'EOG-right'])
        assert len(train_cnt.ch_names) == 22

        train_cnt = mne_apply(lambda a: a * 1e6, train_cnt)
        train_cnt = mne_apply(
            lambda a: bandpass_cnt(a,
                                   low_cut_hz,
                                   high_cut_hz,
                                   train_cnt.info['sfreq'],
                                   filt_order=3,
                                   axis=1), train_cnt)

        train_cnt = mne_apply(
            lambda a: exponential_running_standardize(
                a.T, factor_new=factor_new,
                init_block_size=init_block_size, eps=1e-4).T,
            train_cnt)

        # test set process
        test_cnt = test_cnt.drop_channels(
            ['EOG-left', 'EOG-central', 'EOG-right'])
        assert len(test_cnt.ch_names) == 22
        test_cnt = mne_apply(lambda a: a * 1e6, test_cnt)
        test_cnt = mne_apply(
            lambda a: bandpass_cnt(a,
                                   low_cut_hz,
                                   high_cut_hz,
                                   test_cnt.info['sfreq'],
                                   filt_order=3,
                                   axis=1), test_cnt)
        test_cnt = mne_apply(
            lambda a: exponential_running_standardize(
                a.T, factor_new=factor_new,
                init_block_size=init_block_size, eps=1e-4).T,
            test_cnt)

        marker_def = OrderedDict([('Left Hand', [1]), ('Right Hand', [2]),
                                  ('Foot', [3]), ('Tongue', [4])])

        train_set = create_signal_target_from_raw_mne(train_cnt, marker_def,
                                                      ival)
        test_set = create_signal_target_from_raw_mne(test_cnt, marker_def,
                                                     ival)

        data_sub[str(subject_id)] = concatenate_sets([train_set, test_set])
        if i == 0:
            dataset = data_sub[str(subject_id)]
        else:
            dataset = concatenate_sets([dataset, data_sub[str(subject_id)]])
    assert len(data_sub) == len(subject)

    return dataset
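A hedged usage sketch for data_gen; the subject ids are illustrative, and the hard-coded data_folder above must actually contain the BCIC IV 2a .gdf/.mat files for this to run.

# Merge the train and evaluation sessions of subjects 1-3 into one dataset.
dataset = data_gen(subject=[1, 2, 3], high_cut_hz=38, low_cut_hz=0)
print(len(dataset.X), 'trials in the merged dataset')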
def run_exp(max_epochs, only_return_exp):
    from collections import OrderedDict
    filenames = [
        'data/robot-hall/NiRiNBD6.ds_1-1_500Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD8.ds_1-1_500Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD9.ds_1-1_500Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD10.ds_1-1_500Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD12_cursor_250Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD13_cursorS000R01_onlyFullRuns_250Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD14_cursor_250Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD15_cursor_250Hz.BBCI.mat'
    ]
    sensor_names = [
        'Fp1', 'Fpz', 'Fp2', 'AF7', 'AF3', 'AF4', 'AF8', 'F7', 'F5', 'F3',
        'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1', 'FCz',
        'FC2', 'FC4', 'FC6', 'FT8', 'M1', 'T7', 'C5', 'C3', 'C1', 'Cz', 'C2',
        'C4', 'C6', 'T8', 'M2', 'TP7', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2',
        'CP4', 'CP6', 'TP8', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6',
        'P8', 'PO7', 'PO5', 'PO3', 'POz', 'PO4', 'PO6', 'PO8', 'O1', 'Oz', 'O2'
    ]
    name_to_start_codes = OrderedDict([('Right Hand', [1]), ('Feet', [2]),
                                       ('Rotation', [3]), ('Words', [4])])
    name_to_stop_codes = OrderedDict([('Right Hand', [10]), ('Feet', [20]),
                                      ('Rotation', [30]), ('Words', [40])])

    trial_ival = [500, 0]
    min_break_length_ms = 6000
    max_break_length_ms = 8000
    break_ival = [1000, -500]

    input_time_length = 700

    filename_to_extra_args = {
        'data/robot-hall/NiRiNBD12_cursor_250Hz.BBCI.mat':
        dict(
            name_to_start_codes=OrderedDict([('Right Hand', [1]),
                                             ('Feet', [2]), ('Rotation', [3]),
                                             ('Words', [4]), ('Rest', [5])]),
            name_to_stop_codes=OrderedDict([('Right Hand', [10]),
                                            ('Feet', [20]), ('Rotation', [30]),
                                            ('Words', [40]), ('Rest', [50])]),
            min_break_length_ms=3700,
            max_break_length_ms=3900,
        ),
        'data/robot-hall/NiRiNBD13_cursorS000R01_onlyFullRuns_250Hz.BBCI.mat':
        dict(
            name_to_start_codes=OrderedDict([('Right Hand', [1]),
                                             ('Feet', [2]), ('Rotation', [3]),
                                             ('Words', [4]), ('Rest', [5])]),
            name_to_stop_codes=OrderedDict([('Right Hand', [10]),
                                            ('Feet', [20]), ('Rotation', [30]),
                                            ('Words', [40]), ('Rest', [50])]),
            min_break_length_ms=3700,
            max_break_length_ms=3900,
        ),
        'data/robot-hall/NiRiNBD14_cursor_250Hz.BBCI.mat':
        dict(
            name_to_start_codes=OrderedDict([('Right Hand', [1]),
                                             ('Feet', [2]), ('Rotation', [3]),
                                             ('Words', [4]), ('Rest', [5])]),
            name_to_stop_codes=OrderedDict([('Right Hand', [10]),
                                            ('Feet', [20]), ('Rotation', [30]),
                                            ('Words', [40]), ('Rest', [50])]),
            min_break_length_ms=3700,
            max_break_length_ms=3900,
        ),
        'data/robot-hall/NiRiNBD15_cursor_250Hz.BBCI.mat':
        dict(
            name_to_start_codes=OrderedDict([('Right Hand', [1]),
                                             ('Feet', [2]), ('Rotation', [3]),
                                             ('Words', [4]), ('Rest', [5])]),
            name_to_stop_codes=OrderedDict([('Right Hand', [10]),
                                            ('Feet', [20]), ('Rotation', [30]),
                                            ('Words', [40]), ('Rest', [50])]),
            min_break_length_ms=3700,
            max_break_length_ms=3900,
        ),
    }
    from braindecode.datautil.trial_segment import \
        create_signal_target_with_breaks_from_mne
    from copy import deepcopy

    def load_data(filenames, sensor_names, name_to_start_codes,
                  name_to_stop_codes, trial_ival, break_ival,
                  min_break_length_ms, max_break_length_ms, input_time_length,
                  filename_to_extra_args):
        all_sets = []
        original_args = locals()
        for filename in filenames:
            kwargs = deepcopy(original_args)
            if filename in filename_to_extra_args:
                kwargs.update(filename_to_extra_args[filename])
            log.info("Loading {:s}...".format(filename))
            cnt = BBCIDataset(filename, load_sensor_names=sensor_names).load()
            cnt = cnt.drop_channels(['STI 014'])
            log.info("Resampling...")
            cnt = resample_cnt(cnt, 100)
            log.info("Standardizing...")
            cnt = mne_apply(
                lambda a: exponential_running_standardize(
                    a.T, init_block_size=50).T, cnt)

            log.info("Transform to set...")
            full_set = (create_signal_target_with_breaks_from_mne(
                cnt,
                kwargs['name_to_start_codes'],
                kwargs['trial_ival'],
                kwargs['name_to_stop_codes'],
                kwargs['min_break_length_ms'],
                kwargs['max_break_length_ms'],
                kwargs['break_ival'],
                prepad_trials_to_n_samples=kwargs['input_time_length'],
            ))
            all_sets.append(full_set)
        return all_sets

    sensor_names = [
        'Fp1', 'Fpz', 'Fp2', 'AF7', 'AF3', 'AF4', 'AF8', 'F7', 'F5', 'F3',
        'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1', 'FCz',
        'FC2', 'FC4', 'FC6', 'FT8', 'M1', 'T7', 'C5', 'C3', 'C1', 'Cz', 'C2',
        'C4', 'C6', 'T8', 'M2', 'TP7', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2',
        'CP4', 'CP6', 'TP8', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6',
        'P8', 'PO7', 'PO5', 'PO3', 'POz', 'PO4', 'PO6', 'PO8', 'O1', 'Oz', 'O2'
    ]
    #sensor_names = ['C3', 'C4']
    n_chans = len(sensor_names)
    if not only_return_exp:
        all_sets = load_data(filenames, sensor_names, name_to_start_codes,
                             name_to_stop_codes, trial_ival, break_ival,
                             min_break_length_ms, max_break_length_ms,
                             input_time_length, filename_to_extra_args)
        from braindecode.datautil.signal_target import SignalAndTarget
        from braindecode.datautil.splitters import concatenate_sets

        train_set = concatenate_sets(all_sets[:6])
        valid_set = all_sets[6]
        test_set = all_sets[7]
    else:
        train_set = None
        valid_set = None
        test_set = None
    set_random_seeds(seed=20171017, cuda=True)
    n_classes = 5
    # final_conv_length determines the size of the receptive field of the ConvNet
    model = ShallowFBCSPNet(in_chans=n_chans,
                            n_classes=n_classes,
                            input_time_length=input_time_length,
                            final_conv_length=30).create_network()
    to_dense_prediction_model(model)

    model.cuda()

    from torch import optim
    import numpy as np

    optimizer = optim.Adam(model.parameters())

    from braindecode.torch_ext.util import np_to_var
    # determine output size
    test_input = np_to_var(
        np.ones((2, n_chans, input_time_length, 1), dtype=np.float32))
    test_input = test_input.cuda()
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]
    print("{:d} predictions per input/trial".format(n_preds_per_input))

    from braindecode.datautil.iterators import CropsFromTrialsIterator
    iterator = CropsFromTrialsIterator(batch_size=32,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input)

    from braindecode.experiments.experiment import Experiment
    from braindecode.experiments.monitors import RuntimeMonitor, LossMonitor, \
        CroppedTrialMisclassMonitor, MisclassMonitor
    from braindecode.experiments.stopcriteria import MaxEpochs
    from braindecode.torch_ext.losses import log_categorical_crossentropy
    import torch.nn.functional as F
    import torch as th
    from braindecode.torch_ext.modules import Expression

    loss_function = log_categorical_crossentropy

    model_constraint = MaxNormDefaultConstraint()
    monitors = [
        LossMonitor(),
        MisclassMonitor(col_suffix='sample_misclass'),
        CroppedTrialMisclassMonitor(input_time_length),
        RuntimeMonitor(),
    ]
    stop_criterion = MaxEpochs(max_epochs)
    exp = Experiment(model,
                     train_set,
                     valid_set,
                     test_set,
                     iterator,
                     loss_function,
                     optimizer,
                     model_constraint,
                     monitors,
                     stop_criterion,
                     remember_best_column='valid_sample_misclass',
                     run_after_early_stop=True,
                     batch_modifier=None,
                     cuda=True)
    if not only_return_exp:
        exp.run()

    return exp
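A hedged usage sketch; with only_return_exp=True the robot-hall files are not loaded, although the model is still built and moved to the GPU, so a CUDA device is required either way.

# Build (but do not run) the experiment to inspect its configuration.
exp = run_exp(max_epochs=30, only_return_exp=True)
print(exp.model)

# With the data files in place, the full run would be:
# exp = run_exp(max_epochs=30, only_return_exp=False)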