コード例 #1
0
ファイル: base.py プロジェクト: jscastanoc/braindecode
 def get_monitors(self, input_time_length):
     """Return the monitors used to track training progress.

     A loss monitor and a runtime monitor are always included. The
     misclassification monitor depends on ``self.cropped``: cropped decoding
     uses ``CroppedTrialMisclassMonitor(input_time_length)``, otherwise a
     plain ``MisclassMonitor()`` is appended.
     """
     common = [LossMonitor(), RuntimeMonitor()]
     if not self.cropped:
         return common + [MisclassMonitor()]
     return common + [CroppedTrialMisclassMonitor(input_time_length)]
コード例 #2
0
def train(config):
    """Train a cropped-decoding network and return the rememberer's lowest value.

    ``config['model']`` selects the network. Only the string ``'deep'`` is
    resolved here (a ``Deep4Net`` that additionally receives ``config``); any
    other value is used unchanged as the model object.

    NOTE(review): ``n_chans``, ``n_classes``, ``input_time_length``,
    ``train_set``, ``valid_set`` and ``test_set`` are not defined in this
    function -- presumably module-level globals; confirm before reuse.

    Returns
    -------
    exp.rememberer.lowest_val
        Whatever the experiment's rememberer stored as its lowest value
        (presumably the best ``valid_misclass``; confirm).
    """
    cuda = True
    model = config['model']
    # Only the 'deep' name is resolved here; any other value is passed on
    # as-is, so it must presumably already be a network object.
    if model == 'deep':
        model = Deep4Net(n_chans,
                         n_classes,
                         input_time_length=input_time_length,
                         final_conv_length=2,
                         config=config).create_network()

    to_dense_prediction_model(model)
    if cuda:
        model.cuda()

    log.info("Model: \n{:s}".format(str(model)))
    # Dummy forward pass on one trial; output dim 2 gives the number of
    # predictions per input.
    # NOTE(review): this feeds the full trial length, not input_time_length,
    # so n_preds_per_input may not match the iterator's input window --
    # compare with slicing X[:1, :, :input_time_length, None].
    dummy_input = np_to_var(train_set.X[:1, :, :, None])
    if cuda:
        dummy_input = dummy_input.cuda()
    out = model(dummy_input)

    n_preds_per_input = out.cpu().data.numpy().shape[2]

    optimizer = optim.Adam(model.parameters())

    iterator = CropsFromTrialsIterator(batch_size=60,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input)

    # Presumably: stop after 20 epochs or once valid_misclass has not
    # decreased for 80 epochs, whichever comes first.
    stop_criterion = Or([MaxEpochs(20), NoDecrease('valid_misclass', 80)])

    monitors = [
        LossMonitor(),
        MisclassMonitor(col_suffix='sample_misclass'),
        CroppedTrialMisclassMonitor(input_time_length=input_time_length),
        RuntimeMonitor()
    ]

    model_constraint = MaxNormDefaultConstraint()

    # Average the per-crop predictions over dim 2 before the NLL loss.
    loss_function = lambda preds, targets: F.nll_loss(
        th.mean(preds, dim=2, keepdim=False), targets)

    exp = Experiment(model,
                     train_set,
                     valid_set,
                     test_set,
                     iterator=iterator,
                     loss_function=loss_function,
                     optimizer=optimizer,
                     model_constraint=model_constraint,
                     monitors=monitors,
                     stop_criterion=stop_criterion,
                     remember_best_column='valid_misclass',
                     run_after_early_stop=True,
                     cuda=cuda)
    exp.run()
    print(exp.rememberer)
    return exp.rememberer.lowest_val
コード例 #3
0
def run_exp(max_recording_mins, n_recordings, sec_to_cut,
            duration_recording_mins, max_abs_val, max_min_threshold,
            max_min_expected, shrink_val, max_min_remove, batch_set_zero_val,
            batch_set_zero_test, sampling_freq, low_cut_hz, high_cut_hz,
            exp_demean, exp_standardize, moving_demean, moving_standardize,
            channel_demean, channel_standardize, divisor, n_folds, i_test_fold,
            input_time_length, final_conv_length, pool_stride, n_blocks_to_add,
            sigmoid, model_constraint, batch_size, max_epochs,
            only_return_exp):
    """Build, and unless ``only_return_exp`` is set, run a diagnosis experiment.

    Assembles a chain of preprocessing functions, loads and splits a
    ``DiagnosisSet``, builds a (possibly deepened) cropped ``Deep4Net`` and
    wraps everything in an ``Experiment``.

    Every preprocessing function has the signature ``f(data, fs)`` and
    returns ``(new_data, new_fs)``. Axis 1 of ``data`` is treated as time
    (cutting, resampling and filtering all act on axis 1). Optional steps
    are skipped when their controlling argument is ``None`` (thresholds) or
    falsy (boolean switches).

    Selected parameters
    -------------------
    sec_to_cut: float
        Seconds removed from both the start and the end of every recording.
        A value that rounds to 0 samples keeps the recording unchanged.
    duration_recording_mins: float
        After the start/end cut, recordings are truncated to at most this
        many minutes.
    sampling_freq: float
        Target rate for ``resampy.resample``; later steps see this rate.
    sigmoid: bool
        If True, a single-output network converted by
        ``to_linear_plus_minus_net`` is trained with
        ``binary_cross_entropy_with_logits``; otherwise two classes with NLL.
    model_constraint: object or None
        Acts only as a switch: any non-``None`` value is replaced by a fresh
        ``MaxNormDefaultConstraint``.
    batch_set_zero_val: float or None
        Threshold for a ``RemoveMinMaxDiff`` batch modifier; ``None``
        disables it.
    batch_set_zero_test: bool
        When ``== True`` (and ``batch_set_zero_val`` is set), the modifier
        wraps the iterator through ``ModifiedIterator`` instead of being passed
        to ``Experiment`` as ``batch_modifier``.
    only_return_exp: bool
        If True, no data is loaded and the experiment is not run. The dataset
        and splitter are attached to the returned experiment instead.

    Returns
    -------
    exp: Experiment
    """
    cuda = True

    def cut_start_and_end(data, fs):
        # Compute the stop index explicitly: a ``[n:-n]`` slice with n == 0
        # is ``[0:0]`` (since -0 == 0) and would drop the whole recording.
        n_cut = int(sec_to_cut * fs)
        return data[:, n_cut:data.shape[1] - n_cut], fs

    preproc_functions = []
    preproc_functions.append(cut_start_and_end)
    preproc_functions.append(lambda data, fs: (data[:, :int(
        duration_recording_mins * 60 * fs)], fs))
    if max_abs_val is not None:
        preproc_functions.append(
            lambda data, fs: (np.clip(data, -max_abs_val, max_abs_val), fs))
    if max_min_threshold is not None:
        preproc_functions.append(lambda data, fs: (clean_jumps(
            data, 200, max_min_threshold, max_min_expected, cuda), fs))
    if max_min_remove is not None:
        window_len = 200
        preproc_functions.append(lambda data, fs: (set_jumps_to_zero(
            data,
            window_len=window_len,
            threshold=max_min_remove,
            cuda=cuda,
            clip_min_max_to_zero=True), fs))

    if shrink_val is not None:
        preproc_functions.append(lambda data, fs: (shrink_spikes(
            data,
            shrink_val,
            1,
            9,
        ), fs))

    # Resampling changes the rate seen by all following steps.
    preproc_functions.append(lambda data, fs: (resampy.resample(
        data, fs, sampling_freq, axis=1, filter='kaiser_fast'), sampling_freq))
    preproc_functions.append(lambda data, fs: (bandpass_cnt(
        data, low_cut_hz, high_cut_hz, fs, filt_order=4, axis=1), fs))

    if exp_demean:
        preproc_functions.append(lambda data, fs: (exponential_running_demean(
            data.T, factor_new=0.001, init_block_size=100).T, fs))
    if exp_standardize:
        preproc_functions.append(
            lambda data, fs: (exponential_running_standardize(
                data.T, factor_new=0.001, init_block_size=100).T, fs))
    if moving_demean:
        preproc_functions.append(lambda data, fs: (padded_moving_demean(
            data, axis=1, n_window=201), fs))
    if moving_standardize:
        preproc_functions.append(lambda data, fs: (padded_moving_standardize(
            data, axis=1, n_window=201), fs))
    if channel_demean:
        preproc_functions.append(lambda data, fs: (demean(data, axis=1), fs))
    if channel_standardize:
        preproc_functions.append(lambda data, fs:
                                 (standardize(data, axis=1), fs))
    if divisor is not None:
        preproc_functions.append(lambda data, fs: (data / divisor, fs))

    dataset = DiagnosisSet(n_recordings=n_recordings,
                           max_recording_mins=max_recording_mins,
                           preproc_functions=preproc_functions)
    if not only_return_exp:
        X, y = dataset.load()

    splitter = Splitter(
        n_folds,
        i_test_fold,
    )
    if not only_return_exp:
        train_set, valid_set, test_set = splitter.split(X, y)
        del X, y  # shouldn't be necessary, but just to make sure
    else:
        train_set = None
        valid_set = None
        test_set = None

    set_random_seeds(seed=20170629, cuda=cuda)
    if sigmoid:
        n_classes = 1
    else:
        n_classes = 2
    in_chans = 21

    net = Deep4Net(
        in_chans=in_chans,
        n_classes=n_classes,
        input_time_length=input_time_length,
        final_conv_length=final_conv_length,
        pool_time_length=pool_stride,
        pool_time_stride=pool_stride,
        n_filters_2=50,
        n_filters_3=80,
        n_filters_4=120,
    )
    model = net_with_more_layers(net, n_blocks_to_add, nn.MaxPool2d)
    if sigmoid:
        model = to_linear_plus_minus_net(model)
    optimizer = optim.Adam(model.parameters())
    to_dense_prediction_model(model)
    log.info("Model:\n{:s}".format(str(model)))
    if cuda:
        model.cuda()
    # determine output size: predictions per input window of
    # input_time_length samples (output dim 2)
    test_input = np_to_var(
        np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
    if cuda:
        test_input = test_input.cuda()
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]
    log.info("{:d} predictions per input/trial".format(n_preds_per_input))
    iterator = CropsFromTrialsIterator(batch_size=batch_size,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input)
    # Both losses average the per-crop predictions over dim 2 first.
    if sigmoid:
        loss_function = lambda preds, targets: binary_cross_entropy_with_logits(
            th.mean(preds, dim=2)[:, 1, 0], targets.type_as(preds))
    else:
        loss_function = lambda preds, targets: F.nll_loss(
            th.mean(preds, dim=2)[:, :, 0], targets)

    # The argument is only a switch; the constraint object is always fresh.
    if model_constraint is not None:
        model_constraint = MaxNormDefaultConstraint()
    monitors = [
        LossMonitor(),
        MisclassMonitor(col_suffix='sample_misclass'),
        CroppedTrialMisclassMonitor(input_time_length),
        RuntimeMonitor(),
    ]
    stop_criterion = MaxEpochs(max_epochs)
    batch_modifier = None
    if batch_set_zero_val is not None:
        batch_modifier = RemoveMinMaxDiff(batch_set_zero_val,
                                          clip_max_abs=True,
                                          set_zero=True)
    # ``== True`` is kept on purpose to preserve the original comparison.
    if (batch_set_zero_val is not None) and (batch_set_zero_test == True):  # noqa: E712
        # Wrapping the shared iterator applies the modifier to every batch it
        # yields; presumably this includes evaluation batches -- confirm.
        iterator = ModifiedIterator(
            iterator,
            batch_modifier,
        )
        batch_modifier = None
    exp = Experiment(model,
                     train_set,
                     valid_set,
                     test_set,
                     iterator,
                     loss_function,
                     optimizer,
                     model_constraint,
                     monitors,
                     stop_criterion,
                     remember_best_column='valid_misclass',
                     run_after_early_stop=True,
                     batch_modifier=batch_modifier,
                     cuda=cuda)
    if not only_return_exp:
        exp.run()
    else:
        exp.dataset = dataset
        exp.splitter = splitter

    return exp
コード例 #4
0
    def fit(
        self,
        train_X,
        train_y,
        epochs,
        batch_size,
        input_time_length=None,
        validation_data=None,
        model_constraint=None,
        remember_best_column=None,
        scheduler=None,
        log_0_epoch=True,
    ):
        """
        Fit the model using the given training data.

        Will set `epochs_df` variable with a pandas dataframe to the history
        of the training process.

        Parameters
        ----------
        train_X: ndarray
            Training input data
        train_y: 1darray
            Training labels
        epochs: int
            Number of epochs to train
        batch_size: int
        input_time_length: int, optional
            Super crop size, what temporal size is pushed forward through
            the network, see cropped decoding tutorial. Required in cropped
            mode.
        validation_data: (ndarray, 1darray), optional
            X and y for validation set if wanted
        model_constraint: object, optional
            You can supply :class:`.MaxNormDefaultConstraint` if wanted.
        remember_best_column: string, optional
            In case you want to do an early stopping/reset parameters to some
            "best" epoch, define here the monitored value whose minimum
            determines the best epoch.
        scheduler: 'cosine' or None, optional
            Whether to use cosine annealing (:class:`.CosineAnnealing`).
        log_0_epoch: bool
            Whether to compute the metrics once before training as well.

        Returns
        -------
        exp:
            Underlying braindecode :class:`.Experiment`

        Raises
        ------
        ValueError
            If the model has not been compiled, or if it is in cropped mode
            and `input_time_length` is None.
        AssertionError
            If `scheduler` is neither None nor 'cosine'.
        """
        # Missing attribute and falsy value are both "not compiled".
        if not getattr(self, "compiled", False):
            raise ValueError(
                "Compile the model first by calling model.compile(loss, optimizer, metrics)"
            )

        if self.cropped and input_time_length is None:
            raise ValueError(
                "In cropped mode, need to specify input_time_length, "
                "which is the number of timesteps that will be pushed through "
                "the network in a single pass.")

        train_X = _ensure_float32(train_X)
        if self.cropped:
            # Dummy forward pass on an all-ones super-crop to find how many
            # predictions (output dim 2) the network emits per input.
            self.network.eval()
            test_input = np_to_var(
                np.ones(
                    (1, train_X[0].shape[0], input_time_length) +
                    train_X[0].shape[2:],
                    dtype=np.float32,
                ))
            # Append trailing singleton dims until the input is 4-dimensional.
            while len(test_input.size()) < 4:
                test_input = test_input.unsqueeze(-1)
            if self.is_cuda:
                test_input = test_input.cuda()
            out = self.network(test_input)
            n_preds_per_input = out.cpu().data.numpy().shape[2]
            self.iterator = CropsFromTrialsIterator(
                batch_size=batch_size,
                input_time_length=input_time_length,
                n_preds_per_input=n_preds_per_input,
                seed=self.seed_rng.randint(0, np.iinfo(np.int32).max - 1),
            )
        else:
            self.iterator = BalancedBatchSizeIterator(
                batch_size=batch_size,
                seed=self.seed_rng.randint(0, np.iinfo(np.int32).max - 1),
            )
        # NOTE(review): one fewer epoch is allowed when epoch 0 is not logged,
        # presumably so the number of training epochs stays `epochs`; confirm
        # against how MaxEpochs counts the initial evaluation.
        if log_0_epoch:
            stop_criterion = MaxEpochs(epochs)
        else:
            stop_criterion = MaxEpochs(epochs - 1)
        train_set = SignalAndTarget(train_X, train_y)
        optimizer = self.optimizer
        if scheduler is not None:
            assert scheduler == "cosine", \
                "Supply either 'cosine' or None as scheduler."
            # One full pass over the training batches, only to count updates.
            n_updates_per_epoch = sum(
                1 for _ in self.iterator.get_batches(train_set, shuffle=True))
            n_updates_per_period = n_updates_per_epoch * epochs
            if scheduler == "cosine":
                scheduler = CosineAnnealing(n_updates_per_period)
            # Weight decay is only scheduled for AdamW (matched by class name).
            schedule_weight_decay = optimizer.__class__.__name__ == "AdamW"
            optimizer = ScheduledOptimizer(
                scheduler,
                self.optimizer,
                schedule_weight_decay=schedule_weight_decay,
            )
        loss_function = self.loss
        if self.cropped:
            # Average the per-crop predictions over dim 2 before the loss.
            loss_function = lambda outputs, targets: self.loss(
                th.mean(outputs, dim=2), targets)
        if validation_data is not None:
            valid_X = _ensure_float32(validation_data[0])
            valid_y = validation_data[1]
            valid_set = SignalAndTarget(valid_X, valid_y)
        else:
            valid_set = None
        test_set = None
        self.monitors = [LossMonitor()]
        if self.cropped:
            self.monitors.append(
                CroppedTrialMisclassMonitor(input_time_length))
        else:
            self.monitors.append(MisclassMonitor())
        if self.extra_monitors is not None:
            self.monitors.extend(self.extra_monitors)
        self.monitors.append(RuntimeMonitor())
        exp = Experiment(
            self.network,
            train_set,
            valid_set,
            test_set,
            iterator=self.iterator,
            loss_function=loss_function,
            optimizer=optimizer,
            model_constraint=model_constraint,
            monitors=self.monitors,
            stop_criterion=stop_criterion,
            remember_best_column=remember_best_column,
            run_after_early_stop=False,
            cuda=self.is_cuda,
            log_0_epoch=log_0_epoch,
            do_early_stop=(remember_best_column is not None),
        )
        exp.run()
        self.epochs_df = exp.epochs_df
        return exp
コード例 #5
0
def _preprocess_bcic_cnt(cnt, low_cut_hz, high_cut_hz=38):
    """Return the BCI Competition IV 2a preprocessing of one raw recording.

    Drops the stimulus and EOG channels (asserting 22 channels remain),
    scales the signal by 1e6, band-pass filters it between ``low_cut_hz``
    and ``high_cut_hz`` (3rd-order filter along axis 1), and applies
    exponential running standardization.
    """
    cnt = cnt.drop_channels(['STI 014', 'EOG-left',
                             'EOG-central', 'EOG-right'])
    assert len(cnt.ch_names) == 22
    # Scale by 1e6 for numerical stability of the following operations
    # (volts -> microvolts, assuming the loader keeps MNE's volt convention).
    cnt = mne_apply(lambda a: a * 1e6, cnt)
    sfreq = cnt.info['sfreq']
    cnt = mne_apply(
        lambda a: bandpass_cnt(a, low_cut_hz, high_cut_hz, sfreq,
                               filt_order=3,
                               axis=1), cnt)
    cnt = mne_apply(
        lambda a: exponential_running_standardize(a.T, factor_new=1e-3,
                                                  init_block_size=1000,
                                                  eps=1e-4).T,
        cnt)
    return cnt


def run_exp(data_folder, subject_id, low_cut_hz, model, cuda):
    """Train a cropped ConvNet on one BCI Competition IV 2a subject.

    Parameters
    ----------
    data_folder: str
        Folder containing ``A0<id>T.gdf``/``.mat`` (train) and
        ``A0<id>E.gdf``/``.mat`` (evaluation) files.
    subject_id: int
        Subject number used in the file names.
    low_cut_hz: float
        Lower band-pass cut-off; the upper cut-off is 38 Hz.
    model: {'shallow', 'deep'}
        Selects ``ShallowFBCSPNet`` or ``Deep4Net``.
    cuda: bool
        Whether to run on the GPU.

    Returns
    -------
    exp: Experiment
        The experiment after ``exp.run()``.

    Raises
    ------
    ValueError
        If ``model`` is neither ``'shallow'`` nor ``'deep'`` (checked before
        any data is loaded).
    """
    if model not in ('shallow', 'deep'):
        raise ValueError(
            "model must be 'shallow' or 'deep', got {!r}".format(model))

    train_filename = 'A{:01d}T.gdf'.format(subject_id)
    test_filename = 'A{:01d}E.gdf'.format(subject_id)

    train_filepath = os.path.join(data_folder, train_filename)
    test_filepath = os.path.join(data_folder, test_filename)

    train_label_filepath = train_filepath.replace('.gdf', '.mat')
    test_label_filepath = test_filepath.replace('.gdf', '.mat')

    train_loader = BCICompetition4Set2A(
        train_filepath, labels_filename=train_label_filepath)
    test_loader = BCICompetition4Set2A(
        test_filepath, labels_filename=test_label_filepath)
    train_cnt = train_loader.load()
    test_cnt = test_loader.load()

    # Preprocessing (identical for train and test recordings)
    train_cnt = _preprocess_bcic_cnt(train_cnt, low_cut_hz)
    test_cnt = _preprocess_bcic_cnt(test_cnt, low_cut_hz)

    marker_def = OrderedDict([('Left Hand', [1]), ('Right Hand', [2]),
                              ('Foot', [3]), ('Tongue', [4])])
    ival = [-500, 4000]

    train_set = create_signal_target_from_raw_mne(train_cnt, marker_def, ival)
    test_set = create_signal_target_from_raw_mne(test_cnt, marker_def, ival)

    train_set, valid_set = split_into_two_sets(train_set,
                                               first_set_fraction=0.8)

    set_random_seeds(seed=20190706, cuda=cuda)

    n_classes = 4
    n_chans = int(train_set.X.shape[1])
    input_time_length = 1000
    if model == 'shallow':
        model = ShallowFBCSPNet(n_chans, n_classes,
                                input_time_length=input_time_length,
                                final_conv_length=30).create_network()
    else:
        model = Deep4Net(n_chans, n_classes,
                         input_time_length=input_time_length,
                         final_conv_length=2).create_network()

    to_dense_prediction_model(model)
    if cuda:
        model.cuda()

    log.info("Model: \n{:s}".format(str(model)))
    # Dummy forward pass on one window of exactly input_time_length samples,
    # so n_preds_per_input matches the windows the crop iterator will feed.
    dummy_input = np_to_var(train_set.X[:1, :, :input_time_length, None])
    if cuda:
        dummy_input = dummy_input.cuda()
    out = model(dummy_input)

    n_preds_per_input = out.cpu().data.numpy().shape[2]

    optimizer = optim.Adam(model.parameters())

    iterator = CropsFromTrialsIterator(batch_size=60,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input)

    stop_criterion = Or([MaxEpochs(800),
                         NoDecrease('valid_misclass', 80)])

    monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),
                CroppedTrialMisclassMonitor(
                    input_time_length=input_time_length), RuntimeMonitor()]

    model_constraint = MaxNormDefaultConstraint()

    # Average the per-crop predictions over dim 2 before the NLL loss.
    loss_function = lambda preds, targets: F.nll_loss(
        th.mean(preds, dim=2, keepdim=False), targets)

    exp = Experiment(model, train_set, valid_set, test_set, iterator=iterator,
                     loss_function=loss_function, optimizer=optimizer,
                     model_constraint=model_constraint,
                     monitors=monitors,
                     stop_criterion=stop_criterion,
                     remember_best_column='valid_misclass',
                     run_after_early_stop=True, cuda=cuda)
    exp.run()
    return exp
コード例 #6
0
ファイル: example.py プロジェクト: ecmyhre/high-gamma-dataset
def run_exp_on_high_gamma_dataset(train_filename, test_filename, low_cut_hz,
                                  model_name, max_epochs, max_increase_epochs,
                                  np_th_seed, debug):
    """Train and evaluate a cropped ConvNet on the High-Gamma Dataset (GPU).

    Parameters
    ----------
    train_filename, test_filename: str
        Files passed to ``load_train_valid_test``.
    low_cut_hz: float
        Lower cut-off passed to ``load_train_valid_test``.
    model_name: {'deep', 'shallow'}
        Selects ``Deep4Net`` or ``ShallowFBCSPNet``.
    max_epochs: int
        Maximum number of epochs; replaced by 4 when ``debug`` is True.
    max_increase_epochs: int
        Patience passed to ``NoDecrease('valid_misclass', ...)``.
    np_th_seed: int
        Seed for ``set_random_seeds`` and the crop iterator.
    debug: bool
        Passed to the loader; also sets ``max_epochs`` to 4.

    Returns
    -------
    exp: Experiment
        The experiment after ``exp.run()``.

    Raises
    ------
    ValueError
        If ``model_name`` is neither ``'deep'`` nor ``'shallow'``. This is
        checked before any data is loaded; previously ``model`` was left
        unbound and the failure happened only after loading.
    """
    if model_name not in ('deep', 'shallow'):
        raise ValueError(
            "model_name must be 'deep' or 'shallow', got {!r}".format(
                model_name))

    train_set, valid_set, test_set = load_train_valid_test(
        train_filename=train_filename,
        test_filename=test_filename,
        low_cut_hz=low_cut_hz,
        debug=debug)
    if debug:
        max_epochs = 4

    set_random_seeds(np_th_seed, cuda=True)
    #torch.backends.cudnn.benchmark = True# sometimes crashes?
    n_classes = int(np.max(train_set.y) + 1)
    n_chans = int(train_set.X.shape[1])
    input_time_length = 1000
    if model_name == 'deep':
        model = Deep4Net(n_chans,
                         n_classes,
                         input_time_length=input_time_length,
                         final_conv_length=2).create_network()
    else:
        model = ShallowFBCSPNet(n_chans,
                                n_classes,
                                input_time_length=input_time_length,
                                final_conv_length=30).create_network()

    to_dense_prediction_model(model)
    model.cuda()
    model.eval()

    # Dummy forward pass on one window of input_time_length samples to find
    # the number of predictions per input (output dim 2).
    out = model(np_to_var(train_set.X[:1, :, :input_time_length, None]).cuda())

    n_preds_per_input = out.cpu().data.numpy().shape[2]
    optimizer = optim.Adam(model.parameters(), weight_decay=0, lr=1e-3)

    iterator = CropsFromTrialsIterator(batch_size=60,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input,
                                       seed=np_th_seed)

    monitors = [
        LossMonitor(),
        MisclassMonitor(col_suffix='sample_misclass'),
        CroppedTrialMisclassMonitor(input_time_length=input_time_length),
        RuntimeMonitor()
    ]

    model_constraint = MaxNormDefaultConstraint()

    # Average the per-crop predictions over dim 2 before the NLL loss.
    loss_function = lambda preds, targets: F.nll_loss(th.mean(preds, dim=2),
                                                      targets)

    run_after_early_stop = True
    do_early_stop = True
    remember_best_column = 'valid_misclass'
    stop_criterion = Or([
        MaxEpochs(max_epochs),
        NoDecrease('valid_misclass', max_increase_epochs)
    ])

    exp = Experiment(model,
                     train_set,
                     valid_set,
                     test_set,
                     iterator=iterator,
                     loss_function=loss_function,
                     optimizer=optimizer,
                     model_constraint=model_constraint,
                     monitors=monitors,
                     stop_criterion=stop_criterion,
                     remember_best_column=remember_best_column,
                     run_after_early_stop=run_after_early_stop,
                     cuda=True,
                     do_early_stop=do_early_stop)
    exp.run()
    return exp
コード例 #7
0
def test_experiment_class():
    """Regression test: a short cropped ShallowFBCSPNet run on PhysioNet EEGBCI.

    Downloads (if needed) subject 1's recordings, trains for a few epochs on
    CPU with fixed seeds, and compares every column of ``exp.epochs_df``
    against hard-coded reference values (rtol=1e-3, atol=1e-4).

    NOTE(review): the reference numbers presumably come from an earlier run
    and depend on the seed, statement order and library versions; changing
    any of these will likely require regenerating them.
    """
    import mne
    from mne.io import concatenate_raws

    # 5,6,9,10,13,14 are the recordings (runs) with executed and imagined
    # hands/feet movements; they are passed to eegbci.load_data below.
    subject_id = 1
    event_codes = [5, 6, 9, 10, 13, 14]

    # This will download the files if you don't have them yet,
    # and then return the paths to the files.
    physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)

    # Load each of the files
    parts = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto',
                                 verbose='WARNING')
             for path in physionet_paths]

    # Concatenate them
    raw = concatenate_raws(parts)

    # Find the events in this dataset
    events, _ = mne.events_from_annotations(raw)

    # Use only EEG channels
    eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False,
                                      eog=False,
                                      exclude='bads')

    # Extract trials, only using EEG channels; event ids 2 and 3 are
    # labelled hands and feet respectively.
    epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=1, tmax=4.1,
                         proj=False, picks=eeg_channel_inds,
                         baseline=None, preload=True)
    import numpy as np
    from braindecode.datautil.signal_target import SignalAndTarget
    from braindecode.datautil.splitters import split_into_two_sets
    # Scale by 1e6 (volts -> microvolts, assuming MNE's volt convention).
    # Pytorch expects float32 for input and int64 for labels.
    X = (epoched.get_data() * 1e6).astype(np.float32)
    y = (epoched.events[:, 2] - 2).astype(np.int64)  # 2,3 -> 0,1

    # First 60 trials for training (later split 80/20 into train/valid),
    # the rest for testing.
    train_set = SignalAndTarget(X[:60], y=y[:60])
    test_set = SignalAndTarget(X[60:], y=y[60:])

    train_set, valid_set = split_into_two_sets(train_set,
                                               first_set_fraction=0.8)
    from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
    from torch import nn
    from braindecode.torch_ext.util import set_random_seeds
    from braindecode.models.util import to_dense_prediction_model

    # Set if you want to use GPU
    # You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
    cuda = False
    set_random_seeds(seed=20170629, cuda=cuda)

    # Length of each input window pushed through the network; a longer window
    # yields more crop predictions per forward pass (processed in parallel).
    input_time_length = 450
    n_classes = 2
    in_chans = train_set.X.shape[1]
    # final_conv_length determines the size of the receptive field of the ConvNet
    model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,
                            input_time_length=input_time_length,
                            final_conv_length=12).create_network()
    to_dense_prediction_model(model)

    if cuda:
        model.cuda()

    from torch import optim

    optimizer = optim.Adam(model.parameters())

    from braindecode.torch_ext.util import np_to_var
    # determine output size: predictions per input window (output dim 2)
    test_input = np_to_var(
        np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
    if cuda:
        test_input = test_input.cuda()
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]
    print("{:d} predictions per input/trial".format(n_preds_per_input))

    from braindecode.experiments.experiment import Experiment
    from braindecode.datautil.iterators import CropsFromTrialsIterator
    from braindecode.experiments.monitors import RuntimeMonitor, LossMonitor, \
        CroppedTrialMisclassMonitor, MisclassMonitor
    from braindecode.experiments.stopcriteria import MaxEpochs
    import torch.nn.functional as F
    import torch as th
    from braindecode.torch_ext.modules import Expression
    # Iterator is used to iterate over datasets both for training
    # and evaluation
    iterator = CropsFromTrialsIterator(batch_size=32,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input)

    # Loss function takes predictions as they come out of the network and the targets
    # and returns a loss; predictions are averaged over dim 2 first.
    loss_function = lambda preds, targets: F.nll_loss(
        th.mean(preds, dim=2, keepdim=False), targets)

    # Could be used to apply some constraint on the models, then should be object
    # with apply method that accepts a module
    model_constraint = None
    # Monitors log the training progress
    monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),
                CroppedTrialMisclassMonitor(input_time_length),
                RuntimeMonitor(), ]
    # Stop criterion determines when the first stop happens
    stop_criterion = MaxEpochs(4)
    exp = Experiment(model, train_set, valid_set, test_set, iterator,
                     loss_function, optimizer, model_constraint,
                     monitors, stop_criterion,
                     remember_best_column='valid_misclass',
                     run_after_early_stop=True, batch_modifier=None, cuda=cuda)

    # need to setup python logging before to be able to see anything
    import logging
    import sys
    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                        level=logging.DEBUG, stream=sys.stdout)
    exp.run()

    # Reference per-epoch metrics (one row per logged epoch) to compare
    # against exp.epochs_df column by column.
    import pandas as pd
    from io import StringIO
    compare_df = pd.read_csv(StringIO(
        'train_loss,valid_loss,test_loss,train_sample_misclass,valid_sample_misclass,'
        'test_sample_misclass,train_misclass,valid_misclass,test_misclass\n'
        '14.167170524597168,13.910758018493652,15.945781707763672,0.5,0.5,'
        '0.5333333333333333,0.5,0.5,0.5333333333333333\n'
        '1.1735659837722778,1.4342904090881348,1.8664429187774658,0.4629567736185384,'
        '0.5120320855614973,0.5336007130124778,0.5,0.5,0.5333333333333333\n'
        '1.3168460130691528,1.60431969165802,1.9181344509124756,0.49298128342245995,'
        '0.5109180035650625,0.531729055258467,0.5,0.5,0.5333333333333333\n'
        '0.8465543389320374,1.280307412147522,1.439755916595459,0.4413435828877005,'
        '0.5461229946524064,0.5283422459893048,0.47916666666666663,0.5,'
        '0.5333333333333333\n0.6977059841156006,1.1762590408325195,1.2779350280761719,'
        '0.40290775401069523,0.588903743315508,0.5307486631016043,0.5,0.5,0.5\n'
        '0.7934166193008423,1.1762590408325195,1.2779350280761719,0.4401069518716577,'
        '0.588903743315508,0.5307486631016043,0.5,0.5,0.5\n0.5982189178466797,'
        '0.8581563830375671,0.9598925113677979,0.32032085561497325,0.47660427807486627,'
        '0.4672905525846702,0.31666666666666665,0.5,0.4666666666666667\n0.5044312477111816,'
        '0.7133197784423828,0.8164243102073669,0.2591354723707665,0.45699643493761144,'
        '0.4393048128342246,0.16666666666666663,0.41666666666666663,0.43333333333333335\n'
        '0.4815250039100647,0.6736412644386292,0.8016976714134216,0.23413547237076648,'
        '0.39505347593582885,0.42932263814616756,0.15000000000000002,0.41666666666666663,0.5\n'))

    for col in compare_df:
        np.testing.assert_allclose(np.array(compare_df[col]),
                                   exp.epochs_df[col],
                                   rtol=1e-3, atol=1e-4)
コード例 #8
0
def run_experiment(train_set, valid_set, test_set, model_name, optimizer_name,
                   init_lr, scheduler_name, use_norm_constraint, weight_decay,
                   schedule_weight_decay, restarts, max_epochs,
                   max_increase_epochs, np_th_seed):
    """Build, train and return a cropped-decoding experiment on the GPU.

    Parameters
    ----------
    train_set, valid_set, test_set
        Datasets with ``X`` and ``y`` attributes. ``valid_set`` may be None,
        in which case early stopping is disabled and training runs for
        ``max_epochs`` epochs.
    model_name : str
        ``'deep'``, ``'shallow'`` or one of ``'resnet-he-uniform'``,
        ``'resnet-he-normal'``, ``'resnet-xavier-uniform'``,
        ``'resnet-xavier-normal'`` (the suffix selects the conv weight init).
    optimizer_name : str
        ``'adam'`` or ``'adamw'``.
    init_lr : float
        Initial learning rate.
    scheduler_name : str or None
        None, ``'cosine'`` or ``'cut_cosine'``.
    use_norm_constraint : bool
        If True, use ``MaxNormDefaultConstraint`` as model constraint.
    weight_decay : float
        Weight decay passed to the optimizer.
    schedule_weight_decay : bool
        Passed to ``ScheduledOptimizer``; when a scheduler is used it must be
        True exactly when ``optimizer_name == 'adamw'``.
    restarts : sequence of int or None
        Epoch counts of the annealing periods. Exactly one of ``restarts``
        and ``max_epochs`` must be given; if ``restarts`` is given,
        ``max_epochs`` becomes their sum.
    max_epochs : int or None
        Maximum number of epochs.
    max_increase_epochs : int or None
        Patience for ``NoDecrease('valid_misclass', ...)``; required when
        ``valid_set`` is given.
    np_th_seed : int
        Seed for numpy/torch and for the batch iterator.

    Returns
    -------
    Experiment
        The experiment after ``run()`` has finished.

    Raises
    ------
    ValueError
        For an unknown model, optimizer or scheduler name.
    """
    set_random_seeds(np_th_seed, cuda=True)
    # torch.backends.cudnn.benchmark = True  # sometimes crashes?
    if valid_set is not None:
        assert max_increase_epochs is not None
    # Exactly one of max_epochs / restarts must be specified.
    assert (max_epochs is None) != (restarts is None)
    if max_epochs is None:
        max_epochs = np.sum(restarts)
    n_classes = int(np.max(train_set.y) + 1)
    n_chans = int(train_set.X.shape[1])
    input_time_length = 1000
    if model_name == 'deep':
        model = Deep4Net(n_chans,
                         n_classes,
                         input_time_length=input_time_length,
                         final_conv_length=2).create_network()
    elif model_name == 'shallow':
        model = ShallowFBCSPNet(n_chans,
                                n_classes,
                                input_time_length=input_time_length,
                                final_conv_length=30).create_network()
    elif model_name in [
            'resnet-he-uniform', 'resnet-he-normal', 'resnet-xavier-normal',
            'resnet-xavier-uniform'
    ]:
        # Slice off the prefix: str.lstrip('resnet-') strips a *set of
        # characters*, not a prefix, and only happened to work for these names.
        init_name = model_name[len('resnet-'):]
        from torch.nn import init
        init_fn = {
            'he-uniform': lambda w: init.kaiming_uniform(w, a=0),
            'he-normal': lambda w: init.kaiming_normal(w, a=0),
            'xavier-uniform': lambda w: init.xavier_uniform(w, gain=1),
            'xavier-normal': lambda w: init.xavier_normal(w, gain=1)
        }[init_name]
        model = EEGResNet(in_chans=n_chans,
                          n_classes=n_classes,
                          input_time_length=input_time_length,
                          final_pool_length=10,
                          n_first_filters=48,
                          conv_weight_init_fn=init_fn).create_network()
    else:
        raise ValueError("Unknown model name {:s}".format(model_name))
    if 'resnet' not in model_name:
        # Only the deep/shallow nets are converted with to_dense_prediction_model.
        to_dense_prediction_model(model)
    model.cuda()
    model.eval()

    # Dummy forward pass to find out how many predictions the network emits
    # per input window (size of output dimension 2).
    out = model(np_to_var(train_set.X[:1, :, :input_time_length, None]).cuda())

    n_preds_per_input = out.cpu().data.numpy().shape[2]

    if optimizer_name == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               weight_decay=weight_decay,
                               lr=init_lr)
    elif optimizer_name == 'adamw':
        optimizer = AdamW(model.parameters(),
                          weight_decay=weight_decay,
                          lr=init_lr)
    else:
        # Previously fell through and failed later with UnboundLocalError.
        raise ValueError("Unknown optimizer name {}".format(optimizer_name))

    iterator = CropsFromTrialsIterator(batch_size=60,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input,
                                       seed=np_th_seed)

    if scheduler_name is not None:
        assert schedule_weight_decay == (optimizer_name == 'adamw')
        # Both schedulers share the same period computation; only the
        # annealing class differs.
        if scheduler_name == 'cosine':
            scheduler_class = CosineAnnealing
        elif scheduler_name == 'cut_cosine':
            scheduler_class = CutCosineAnnealing
        else:
            raise ValueError("Unknown scheduler")
        # Count the updates in one epoch by iterating the training batches once.
        n_updates_per_epoch = sum(
            1 for _ in iterator.get_batches(train_set, shuffle=True))
        if restarts is None:
            n_updates_per_period = n_updates_per_epoch * max_epochs
        else:
            # Presumably one annealing period per entry of ``restarts``.
            n_updates_per_period = np.array(restarts) * n_updates_per_epoch
        scheduler = scheduler_class(n_updates_per_period)
        optimizer = ScheduledOptimizer(
            scheduler,
            optimizer,
            schedule_weight_decay=schedule_weight_decay)
    monitors = [
        LossMonitor(),
        MisclassMonitor(col_suffix='sample_misclass'),
        CroppedTrialMisclassMonitor(input_time_length=input_time_length),
        RuntimeMonitor()
    ]

    if use_norm_constraint:
        model_constraint = MaxNormDefaultConstraint()
    else:
        model_constraint = None
    # Average the per-crop predictions over dim 2 before the NLL loss.
    loss_function = lambda preds, targets: F.nll_loss(th.mean(preds, dim=2),
                                                      targets)

    if valid_set is not None:
        run_after_early_stop = True
        do_early_stop = True
        remember_best_column = 'valid_misclass'
        stop_criterion = Or([
            MaxEpochs(max_epochs),
            NoDecrease('valid_misclass', max_increase_epochs)
        ])
    else:
        run_after_early_stop = False
        do_early_stop = False
        remember_best_column = None
        stop_criterion = MaxEpochs(max_epochs)

    exp = Experiment(model,
                     train_set,
                     valid_set,
                     test_set,
                     iterator=iterator,
                     loss_function=loss_function,
                     optimizer=optimizer,
                     model_constraint=model_constraint,
                     monitors=monitors,
                     stop_criterion=stop_criterion,
                     remember_best_column=remember_best_column,
                     run_after_early_stop=run_after_early_stop,
                     cuda=True,
                     do_early_stop=do_early_stop)
    exp.run()
    return exp
コード例 #9
0
ファイル: base.py プロジェクト: Shasvat-Desai/braindecode
    def fit(self,
            train_X,
            train_y,
            epochs,
            batch_size,
            input_time_length=None,
            validation_data=None,
            model_constraint=None,
            remember_best_column=None,
            scheduler=None):
        """Train the compiled network on the given data.

        Parameters
        ----------
        train_X : array-like
            Training inputs. In cropped mode the first trial is pushed
            through the network (with trailing axes added until it is 4-D)
            to determine the number of predictions per input.
        train_y : array-like
            Training targets.
        epochs : int
            Number of training epochs.
        batch_size : int
            Batch size for the iterator.
        input_time_length : int, optional
            Timesteps per network pass; required when ``self.cropped``.
        validation_data : tuple (X, y), optional
            Validation inputs and targets.
        model_constraint : object, optional
            Passed through to ``Experiment``.
        remember_best_column : str, optional
            If given, early stopping is enabled on this column.
        scheduler : str, optional
            Only ``'cosine'`` is supported.

        Returns
        -------
        Experiment
            The finished experiment; its ``epochs_df`` is also stored as
            ``self.epochs_df``.

        Raises
        ------
        ValueError
            If the model is not compiled, ``input_time_length`` is missing in
            cropped mode, or ``scheduler`` is not ``'cosine'``.
        """
        if not self.compiled:
            raise ValueError(
                "Compile the model first by calling model.compile(loss, optimizer, metrics)"
            )

        if self.cropped and input_time_length is None:
            raise ValueError(
                "In cropped mode, need to specify input_time_length, "
                "which is the number of timesteps that will be pushed through "
                "the network in a single pass.")
        # Validate up front with an exception (an assert would vanish under -O).
        if scheduler is not None and scheduler != 'cosine':
            raise ValueError(
                "Unknown scheduler {!r}, only 'cosine' is supported".format(
                    scheduler))
        if self.cropped:
            # Push one trial through the network to find out how many
            # predictions it produces per input window.
            test_input = np_to_var(train_X[0:1], dtype=np.float32)
            while len(test_input.size()) < 4:
                test_input = test_input.unsqueeze(-1)
            if self.cuda:
                test_input = test_input.cuda()
            out = self.network(test_input)
            n_preds_per_input = out.cpu().data.numpy().shape[2]
            iterator = CropsFromTrialsIterator(
                batch_size=batch_size,
                input_time_length=input_time_length,
                n_preds_per_input=n_preds_per_input,
                seed=self.seed_rng.randint(0, 4294967295))
        else:
            iterator = BalancedBatchSizeIterator(batch_size=batch_size,
                                                 seed=self.seed_rng.randint(
                                                     0, 4294967295))
        # -1 since we dont print 0 epoch, which matters for this stop criterion
        stop_criterion = MaxEpochs(epochs - 1)
        train_set = SignalAndTarget(train_X, train_y)
        optimizer = self.optimizer
        if scheduler is not None:
            # Count the updates in one epoch by iterating the batches once.
            n_updates_per_epoch = sum(
                1 for _ in iterator.get_batches(train_set, shuffle=True))
            n_updates_per_period = n_updates_per_epoch * epochs
            scheduler = CosineAnnealing(n_updates_per_period)
            # Weight decay is scheduled together with the LR only for AdamW.
            schedule_weight_decay = optimizer.__class__.__name__ == 'AdamW'
            optimizer = ScheduledOptimizer(
                scheduler,
                self.optimizer,
                schedule_weight_decay=schedule_weight_decay)
        loss_function = self.loss
        if self.cropped:
            # Average the per-crop outputs over dim 2 before applying the loss.
            loss_function = lambda outputs, targets:\
                self.loss(th.mean(outputs, dim=2), targets)
        if validation_data is not None:
            valid_set = SignalAndTarget(validation_data[0], validation_data[1])
        else:
            valid_set = None
        test_set = None
        if self.cropped:
            monitor_dict = {
                'acc': lambda: CroppedTrialMisclassMonitor(input_time_length)
            }
        else:
            monitor_dict = {'acc': MisclassMonitor}
        self.monitors = [LossMonitor()]
        extra_monitors = [monitor_dict[m]() for m in self.metrics]
        self.monitors += extra_monitors
        self.monitors += [RuntimeMonitor()]
        exp = Experiment(self.network,
                         train_set,
                         valid_set,
                         test_set,
                         iterator=iterator,
                         loss_function=loss_function,
                         optimizer=optimizer,
                         model_constraint=model_constraint,
                         monitors=self.monitors,
                         stop_criterion=stop_criterion,
                         remember_best_column=remember_best_column,
                         run_after_early_stop=False,
                         cuda=self.cuda,
                         print_0_epoch=False,
                         do_early_stop=(remember_best_column is not None))
        exp.run()
        self.epochs_df = exp.epochs_df
        return exp
コード例 #10
0
def test_experiment_class():
    """Regression test: train a cropped ShallowFBCSPNet on PhysioNet EEGBCI
    subject 1 (hands vs. feet) for a few epochs on CPU and compare the
    resulting ``epochs_df`` against stored reference values.

    The statement order (seeding, model creation, iterator) must stay as is,
    since the reference values depend on the exact random number sequence.
    """
    import mne
    from mne.io import concatenate_raws

    # 5,6,9,10,13,14 are the codes (run numbers) for executed and imagined
    # hands/feet
    subject_id = 1
    event_codes = [5, 6, 9, 10, 13, 14]

    # This will download the files if you don't have them yet,
    # and then return the paths to the files.
    physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)

    # Load each of the files
    parts = [
        mne.io.read_raw_edf(path,
                            preload=True,
                            stim_channel='auto',
                            verbose='WARNING') for path in physionet_paths
    ]

    # Concatenate them
    raw = concatenate_raws(parts)

    # Find the events in this dataset
    events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')

    # Use only EEG channels
    eeg_channel_inds = mne.pick_types(raw.info,
                                      meg=False,
                                      eeg=True,
                                      stim=False,
                                      eog=False,
                                      exclude='bads')

    # Extract trials, only using EEG channels
    epoched = mne.Epochs(raw,
                         events,
                         dict(hands=2, feet=3),
                         tmin=1,
                         tmax=4.1,
                         proj=False,
                         picks=eeg_channel_inds,
                         baseline=None,
                         preload=True)
    import numpy as np
    from braindecode.datautil.signal_target import SignalAndTarget
    from braindecode.datautil.splitters import split_into_two_sets
    # Convert data from volt to microvolt (factor 1e6).
    # Pytorch expects float32 for input and int64 for labels.
    X = (epoched.get_data() * 1e6).astype(np.float32)
    y = (epoched.events[:, 2] - 2).astype(np.int64)  # 2,3 -> 0,1

    train_set = SignalAndTarget(X[:60], y=y[:60])
    test_set = SignalAndTarget(X[60:], y=y[60:])

    train_set, valid_set = split_into_two_sets(train_set,
                                               first_set_fraction=0.8)
    from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
    from braindecode.torch_ext.util import set_random_seeds
    from braindecode.models.util import to_dense_prediction_model

    # Set if you want to use GPU
    # You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
    cuda = False
    set_random_seeds(seed=20170629, cuda=cuda)

    # Number of samples per network input; together with the receptive field
    # this determines how many crop predictions are made per forward pass.
    input_time_length = 450
    n_classes = 2
    in_chans = train_set.X.shape[1]
    # final_conv_length determines the size of the receptive field of the ConvNet
    model = ShallowFBCSPNet(in_chans=in_chans,
                            n_classes=n_classes,
                            input_time_length=input_time_length,
                            final_conv_length=12).create_network()
    to_dense_prediction_model(model)

    if cuda:
        model.cuda()

    from torch import optim

    optimizer = optim.Adam(model.parameters())

    from braindecode.torch_ext.util import np_to_var
    # determine output size
    test_input = np_to_var(
        np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
    if cuda:
        test_input = test_input.cuda()
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]
    print("{:d} predictions per input/trial".format(n_preds_per_input))

    from braindecode.experiments.experiment import Experiment
    from braindecode.datautil.iterators import CropsFromTrialsIterator
    from braindecode.experiments.monitors import RuntimeMonitor, LossMonitor, \
        CroppedTrialMisclassMonitor, MisclassMonitor
    from braindecode.experiments.stopcriteria import MaxEpochs
    import torch.nn.functional as F
    import torch as th
    # Iterator is used to iterate over datasets both for training
    # and evaluation
    iterator = CropsFromTrialsIterator(batch_size=32,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input)

    # Loss function takes predictions as they come out of the network and the targets
    # and returns a loss; predictions are averaged over the crop axis (dim 2).
    loss_function = lambda preds, targets: F.nll_loss(
        th.mean(preds, dim=2, keepdim=False), targets)

    # Could be used to apply some constraint on the models, then should be object
    # with apply method that accepts a module
    model_constraint = None
    # Monitors log the training progress
    monitors = [
        LossMonitor(),
        MisclassMonitor(col_suffix='sample_misclass'),
        CroppedTrialMisclassMonitor(input_time_length),
        RuntimeMonitor(),
    ]
    # Stop criterion determines when the first stop happens
    stop_criterion = MaxEpochs(4)
    exp = Experiment(model,
                     train_set,
                     valid_set,
                     test_set,
                     iterator,
                     loss_function,
                     optimizer,
                     model_constraint,
                     monitors,
                     stop_criterion,
                     remember_best_column='valid_misclass',
                     run_after_early_stop=True,
                     batch_modifier=None,
                     cuda=cuda)

    # need to setup python logging before to be able to see anything
    import logging
    import sys
    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                        level=logging.DEBUG,
                        stream=sys.stdout)
    exp.run()

    import pandas as pd
    from io import StringIO
    # Data rows carry a leading epoch number that the header lacks, so pandas
    # uses that first column as the index.
    compare_df = pd.read_csv(
        StringIO(
            u'train_loss,valid_loss,test_loss,train_sample_misclass,valid_sample_misclass,'
            'test_sample_misclass,train_misclass,valid_misclass,test_misclass\n'
            '0,0.8692976435025532,0.7483791708946228,0.6975634694099426,'
            '0.5389371657754011,0.47103386809269165,0.4425133689839572,'
            '0.6041666666666667,0.5,0.4\n1,2.3362590074539185,'
            '2.317707061767578,2.1407743096351624,0.4827874331550802,'
            '0.5,0.4666666666666667,0.5,0.5,0.4666666666666667\n'
            '2,0.5981490015983582,0.785034716129303,0.7005959153175354,'
            '0.3391822638146168,0.47994652406417115,0.41996434937611404,'
            '0.22916666666666663,0.41666666666666663,0.43333333333333335\n'
            '3,0.6355261653661728,0.785034716129303,'
            '0.7005959153175354,0.3673351158645276,0.47994652406417115,'
            '0.41996434937611404,0.2666666666666667,0.41666666666666663,'
            '0.43333333333333335\n4,0.625280424952507,'
            '0.802731990814209,0.7048938572406769,0.3367201426024955,'
            '0.43137254901960786,0.4229946524064171,0.3666666666666667,'
            '0.5833333333333333,0.33333333333333337\n'))

    for col in compare_df:
        np.testing.assert_allclose(np.array(compare_df[col]),
                                   exp.epochs_df[col],
                                   rtol=1e-4,
                                   atol=1e-5)
コード例 #11
0
def run_exp(max_recording_mins, n_recordings, sec_to_cut,
            duration_recording_mins, max_abs_val, max_min_threshold,
            max_min_expected, shrink_val, max_min_remove, batch_set_zero_val,
            batch_set_zero_test, sampling_freq, low_cut_hz, high_cut_hz,
            exp_demean, exp_standardize, moving_demean, moving_standardize,
            channel_demean, channel_standardize, divisor, n_folds, i_test_fold,
            model_name, input_time_length, final_conv_length, batch_size,
            max_epochs, only_return_exp):
    """Build (and optionally run) a cropped-decoding experiment on recordings.

    Preprocessing parameters (each optional step is skipped when its
    parameter is None/False):
        sec_to_cut: seconds removed at both the start and end of a recording.
        duration_recording_mins: recordings are truncated to this many minutes.
        max_abs_val: clip amplitudes to [-max_abs_val, max_abs_val].
        max_min_threshold, max_min_expected: passed to ``clean_jumps``.
        max_min_remove: threshold for ``set_jumps_to_zero``.
        shrink_val: passed to ``shrink_spikes``.
        sampling_freq: target rate for resampling.
        low_cut_hz, high_cut_hz: bandpass edges.
        exp_demean, exp_standardize, moving_demean, moving_standardize,
        channel_demean, channel_standardize: normalization switches.
        divisor: data is divided by this value.

    Data selection:
        max_recording_mins: only recordings shorter than this are used.
        n_recordings, n_folds, i_test_fold: fold split; the fold before the
        test fold is used for validation.

    Model / training:
        model_name: 'shallow' or 'deep'.
        input_time_length, final_conv_length, batch_size, max_epochs.
        batch_set_zero_val, batch_set_zero_test: configure a
        ``RemoveMinMaxDiff`` batch modifier; if ``batch_set_zero_test`` is
        True it is applied through the iterator instead.
        only_return_exp: if True, no data is loaded and the experiment is
        returned without being run.

    Returns
    -------
    Experiment

    Raises
    ------
    ValueError
        If ``model_name`` is unknown.
    """
    cuda = True

    preproc_functions = []

    def cut_sec_at_both_ends(data, fs):
        # Use an explicit end index: ``data[:, n:-n]`` is empty for n == 0,
        # which would silently drop the whole recording when sec_to_cut is 0.
        n_cut = int(sec_to_cut * fs)
        return data[:, n_cut:data.shape[1] - n_cut], fs

    preproc_functions.append(cut_sec_at_both_ends)
    preproc_functions.append(lambda data, fs: (data[:, :int(
        duration_recording_mins * 60 * fs)], fs))
    if max_abs_val is not None:
        preproc_functions.append(
            lambda data, fs: (np.clip(data, -max_abs_val, max_abs_val), fs))
    if max_min_threshold is not None:
        preproc_functions.append(lambda data, fs: (clean_jumps(
            data, 200, max_min_threshold, max_min_expected, cuda), fs))
    if max_min_remove is not None:
        window_len = 200
        preproc_functions.append(lambda data, fs: (set_jumps_to_zero(
            data,
            window_len=window_len,
            threshold=max_min_remove,
            cuda=cuda,
            clip_min_max_to_zero=True), fs))

    if shrink_val is not None:
        preproc_functions.append(lambda data, fs: (shrink_spikes(
            data,
            shrink_val,
            1,
            9,
        ), fs))

    # Resampling changes the sampling rate seen by all later steps.
    preproc_functions.append(lambda data, fs: (resampy.resample(
        data, fs, sampling_freq, axis=1, filter='kaiser_fast'), sampling_freq))
    preproc_functions.append(lambda data, fs: (bandpass_cnt(
        data, low_cut_hz, high_cut_hz, fs, filt_order=4, axis=1), fs))

    if exp_demean:
        preproc_functions.append(lambda data, fs: (exponential_running_demean(
            data.T, factor_new=0.001, init_block_size=100).T, fs))
    if exp_standardize:
        preproc_functions.append(
            lambda data, fs: (exponential_running_standardize(
                data.T, factor_new=0.001, init_block_size=100).T, fs))
    if moving_demean:
        preproc_functions.append(lambda data, fs: (padded_moving_demean(
            data, axis=1, n_window=201), fs))
    if moving_standardize:
        preproc_functions.append(lambda data, fs: (padded_moving_standardize(
            data, axis=1, n_window=201), fs))
    if channel_demean:
        preproc_functions.append(lambda data, fs: (demean(data, axis=1), fs))
    if channel_standardize:
        preproc_functions.append(lambda data, fs:
                                 (standardize(data, axis=1), fs))
    if divisor is not None:
        preproc_functions.append(lambda data, fs: (data / divisor, fs))

    all_file_names, labels = get_all_sorted_file_names_and_labels()
    # NOTE(review): hard-coded, user-specific paths.
    lengths = np.load(
        '/home/schirrmr/code/auto-diagnosis/sorted-recording-lengths.npy')
    mask = lengths < max_recording_mins * 60
    cleaned_file_names = np.array(all_file_names)[mask]
    cleaned_labels = labels[mask]

    diffs_per_rec = np.load(
        '/home/schirrmr/code/auto-diagnosis/diffs_per_recording.npy')

    def create_set(inds):
        """Load and preprocess the recordings at ``inds`` into a dataset."""
        X = []
        for i in inds:
            log.info("Load {:s}".format(cleaned_file_names[i]))
            x = load_data(cleaned_file_names[i], preproc_functions)
            X.append(x)
        y = cleaned_labels[inds].astype(np.int64)
        return SignalAndTarget(X, y)

    if not only_return_exp:
        folds = get_balanced_batches(n_recordings,
                                     None,
                                     False,
                                     n_batches=n_folds)
        test_inds = folds[i_test_fold]
        valid_inds = folds[i_test_fold - 1]
        all_inds = list(range(n_recordings))
        train_inds = np.setdiff1d(all_inds, np.union1d(test_inds, valid_inds))

        # Map fold positions to recording numbers ordered by descending diff.
        rec_nr_sorted_by_diff = np.argsort(diffs_per_rec)[::-1]
        train_inds = rec_nr_sorted_by_diff[train_inds]
        valid_inds = rec_nr_sorted_by_diff[valid_inds]
        test_inds = rec_nr_sorted_by_diff[test_inds]

        train_set = create_set(train_inds)
        valid_set = create_set(valid_inds)
        test_set = create_set(test_inds)
    else:
        train_set = None
        valid_set = None
        test_set = None

    set_random_seeds(seed=20170629, cuda=cuda)
    n_classes = 2
    in_chans = 21
    if model_name == 'shallow':
        model = ShallowFBCSPNet(
            in_chans=in_chans,
            n_classes=n_classes,
            input_time_length=input_time_length,
            final_conv_length=final_conv_length).create_network()
    elif model_name == 'deep':
        model = Deep4Net(in_chans,
                         n_classes,
                         input_time_length=input_time_length,
                         final_conv_length=final_conv_length).create_network()
    else:
        # Previously fell through and failed with UnboundLocalError.
        raise ValueError("Unknown model name {:s}".format(model_name))

    to_dense_prediction_model(model)
    log.info("Model:\n{:s}".format(str(model)))
    if cuda:
        model.cuda()
    # Create the optimizer only after moving the model to the GPU, as the
    # torch.optim documentation recommends.
    optimizer = optim.Adam(model.parameters())
    # determine output size
    test_input = np_to_var(
        np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
    if cuda:
        test_input = test_input.cuda()
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]
    log.info("{:d} predictions per input/trial".format(n_preds_per_input))
    iterator = CropsFromTrialsIterator(batch_size=batch_size,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input)
    # Average the per-crop predictions over dim 2, then drop the trailing
    # axis before the NLL loss.
    loss_function = lambda preds, targets: F.nll_loss(
        th.mean(preds, dim=2)[:, :, 0], targets)
    model_constraint = None
    monitors = [
        LossMonitor(),
        MisclassMonitor(col_suffix='sample_misclass'),
        CroppedTrialMisclassMonitor(input_time_length),
        RuntimeMonitor(),
    ]
    stop_criterion = MaxEpochs(max_epochs)
    batch_modifier = None
    if batch_set_zero_val is not None:
        batch_modifier = RemoveMinMaxDiff(batch_set_zero_val,
                                          clip_max_abs=True,
                                          set_zero=True)
    if (batch_set_zero_val is not None) and (batch_set_zero_test == True):
        # Apply the modifier inside the iterator instead of passing it to
        # the Experiment.
        iterator = ModifiedIterator(
            iterator,
            batch_modifier,
        )
        batch_modifier = None
    exp = Experiment(model,
                     train_set,
                     valid_set,
                     test_set,
                     iterator,
                     loss_function,
                     optimizer,
                     model_constraint,
                     monitors,
                     stop_criterion,
                     remember_best_column='valid_misclass',
                     run_after_early_stop=True,
                     batch_modifier=batch_modifier,
                     cuda=cuda)
    if not only_return_exp:
        exp.run()
    else:
        exp.dataset = None
        exp.splitter = None

    return exp
コード例 #12
0
def run_exp(max_epochs, only_return_exp, cuda=True):
    """Build (and optionally run) a cropped ShallowFBCSPNet experiment on the
    robot-hall recordings.

    Each recording is loaded restricted to ``sensor_names``, has its
    ``'STI 014'`` channel dropped, is passed through ``resample_cnt(cnt, 100)``
    and an exponential running standardization (``init_block_size=50``), and
    is then segmented into trials and breaks.  All recordings except the last
    two form the training set; the second-to-last is the validation set and
    the last one the test set.

    Parameters
    ----------
    max_epochs: int
        Number of epochs to train for (passed to ``MaxEpochs``).
    only_return_exp: bool
        If True, no data is loaded and the experiment is not run; it is built
        with ``None`` datasets and returned as-is.
    cuda: bool, optional
        Whether the model, the dummy input and the experiment use the GPU.
        Defaults to True.

    Returns
    -------
    exp: Experiment
        The experiment, already run unless ``only_return_exp`` is True.
    """
    from collections import OrderedDict
    from copy import deepcopy

    import numpy as np
    from torch import optim

    from braindecode.datautil.trial_segment import \
        create_signal_target_with_breaks_from_mne
    from braindecode.datautil.iterators import CropsFromTrialsIterator
    from braindecode.experiments.experiment import Experiment
    from braindecode.experiments.monitors import RuntimeMonitor, LossMonitor, \
        CroppedTrialMisclassMonitor, MisclassMonitor
    from braindecode.experiments.stopcriteria import MaxEpochs
    from braindecode.torch_ext.losses import log_categorical_crossentropy
    from braindecode.torch_ext.util import np_to_var

    filenames = [
        'data/robot-hall/NiRiNBD6.ds_1-1_500Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD8.ds_1-1_500Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD9.ds_1-1_500Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD10.ds_1-1_500Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD12_cursor_250Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD13_cursorS000R01_onlyFullRuns_250Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD14_cursor_250Hz.BBCI.mat',
        'data/robot-hall/NiRiNBD15_cursor_250Hz.BBCI.mat'
    ]
    # The last four entries above are the cursor sessions, which use the
    # per-file overrides built below.
    cursor_filenames = filenames[4:]

    sensor_names = [
        'Fp1', 'Fpz', 'Fp2', 'AF7', 'AF3', 'AF4', 'AF8', 'F7', 'F5', 'F3',
        'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1', 'FCz',
        'FC2', 'FC4', 'FC6', 'FT8', 'M1', 'T7', 'C5', 'C3', 'C1', 'Cz', 'C2',
        'C4', 'C6', 'T8', 'M2', 'TP7', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2',
        'CP4', 'CP6', 'TP8', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6',
        'P8', 'PO7', 'PO5', 'PO3', 'POz', 'PO4', 'PO6', 'PO8', 'O1', 'Oz', 'O2'
    ]
    # Default marker codes: four classes, start code n, stop code 10 * n.
    name_to_start_codes = OrderedDict([('Right Hand', [1]), ('Feet', [2]),
                                       ('Rotation', [3]), ('Words', [4])])
    name_to_stop_codes = OrderedDict([('Right Hand', [10]), ('Feet', [20]),
                                      ('Rotation', [30]), ('Words', [40])])

    trial_ival = [500, 0]
    min_break_length_ms = 6000
    max_break_length_ms = 8000
    break_ival = [1000, -500]

    input_time_length = 700

    def make_cursor_overrides():
        """Return a fresh override dict for one cursor recording: an extra
        'Rest' class (codes 5 / 50) and shorter break-length bounds."""
        return dict(
            name_to_start_codes=OrderedDict([('Right Hand', [1]),
                                             ('Feet', [2]), ('Rotation', [3]),
                                             ('Words', [4]), ('Rest', [5])]),
            name_to_stop_codes=OrderedDict([('Right Hand', [10]),
                                            ('Feet', [20]), ('Rotation', [30]),
                                            ('Words', [40]), ('Rest', [50])]),
            min_break_length_ms=3700,
            max_break_length_ms=3900,
        )

    # One independent override dict per file, so recordings never share
    # mapping objects.
    filename_to_extra_args = {
        filename: make_cursor_overrides() for filename in cursor_filenames
    }

    def load_data(filenames, sensor_names, name_to_start_codes,
                  name_to_stop_codes, trial_ival, break_ival,
                  min_break_length_ms, max_break_length_ms, input_time_length,
                  filename_to_extra_args):
        """Load, preprocess and segment every file; return one set per file,
        in the order of ``filenames``."""
        # Only these arguments are read per file; copy just them instead of
        # deep-copying the file list and the whole override table each time.
        default_args = dict(
            name_to_start_codes=name_to_start_codes,
            name_to_stop_codes=name_to_stop_codes,
            trial_ival=trial_ival,
            break_ival=break_ival,
            min_break_length_ms=min_break_length_ms,
            max_break_length_ms=max_break_length_ms,
            input_time_length=input_time_length,
        )
        all_sets = []
        for filename in filenames:
            kwargs = deepcopy(default_args)
            kwargs.update(filename_to_extra_args.get(filename, {}))
            log.info("Loading {:s}...".format(filename))
            cnt = BBCIDataset(filename, load_sensor_names=sensor_names).load()
            cnt = cnt.drop_channels(['STI 014'])
            log.info("Resampling...")
            cnt = resample_cnt(cnt, 100)
            log.info("Standardizing...")
            # Transpose around the standardization, which is applied to a.T.
            cnt = mne_apply(
                lambda a: exponential_running_standardize(
                    a.T, init_block_size=50).T, cnt)

            log.info("Transform to set...")
            full_set = create_signal_target_with_breaks_from_mne(
                cnt,
                kwargs['name_to_start_codes'],
                kwargs['trial_ival'],
                kwargs['name_to_stop_codes'],
                kwargs['min_break_length_ms'],
                kwargs['max_break_length_ms'],
                kwargs['break_ival'],
                prepad_trials_to_n_samples=kwargs['input_time_length'],
            )
            all_sets.append(full_set)
        return all_sets

    n_chans = len(sensor_names)
    if not only_return_exp:
        from braindecode.datautil.splitters import concatenate_sets
        all_sets = load_data(filenames, sensor_names, name_to_start_codes,
                             name_to_stop_codes, trial_ival, break_ival,
                             min_break_length_ms, max_break_length_ms,
                             input_time_length, filename_to_extra_args)
        # Hold out the last two recordings: validation, then test.
        train_set = concatenate_sets(all_sets[:-2])
        valid_set = all_sets[-2]
        test_set = all_sets[-1]
    else:
        train_set = None
        valid_set = None
        test_set = None

    set_random_seeds(seed=20171017, cuda=cuda)
    # NOTE(review): the first four recordings define 4 named classes and the
    # cursor recordings 5. If create_signal_target_with_breaks_from_mne appends
    # a 'Break' label after the named classes, labels differ between files and
    # may exceed n_classes - confirm the labelling.
    n_classes = 5
    # final_conv_length determines the size of the receptive field of the ConvNet
    model = ShallowFBCSPNet(in_chans=n_chans,
                            n_classes=n_classes,
                            input_time_length=input_time_length,
                            final_conv_length=30).create_network()
    to_dense_prediction_model(model)
    if cuda:
        model.cuda()

    optimizer = optim.Adam(model.parameters())

    # Push a dummy batch through the network to find out how many
    # predictions it emits per input window (axis 2 of the output).
    test_input = np_to_var(
        np.ones((2, n_chans, input_time_length, 1), dtype=np.float32))
    if cuda:
        test_input = test_input.cuda()
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]
    log.info("{:d} predictions per input/trial".format(n_preds_per_input))

    iterator = CropsFromTrialsIterator(batch_size=32,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input)

    loss_function = log_categorical_crossentropy
    model_constraint = MaxNormDefaultConstraint()
    monitors = [
        LossMonitor(),
        MisclassMonitor(col_suffix='sample_misclass'),
        CroppedTrialMisclassMonitor(input_time_length),
        RuntimeMonitor(),
    ]
    stop_criterion = MaxEpochs(max_epochs)
    exp = Experiment(model,
                     train_set,
                     valid_set,
                     test_set,
                     iterator,
                     loss_function,
                     optimizer,
                     model_constraint,
                     monitors,
                     stop_criterion,
                     # presumably the validation column written by
                     # MisclassMonitor(col_suffix='sample_misclass')
                     remember_best_column='valid_sample_misclass',
                     run_after_early_stop=True,
                     batch_modifier=None,
                     cuda=cuda)
    if not only_return_exp:
        exp.run()

    return exp