def processing_data(data_folder, subject_id, low_cut_hz, high_cut_hz, factor_new, init_block_size, ival, valid_set_fraction):
    """Build train/valid/test sets for one BCI Competition IV 2a subject.

    Reads ``A<id>T.gdf`` and ``A<id>E.gdf`` (labels from the ``.mat`` files
    of the same name) from ``data_folder``, runs the same cleaning pipeline
    on both recordings, cuts them into trials around the four class markers
    and splits the training trials into a training and a validation part.

    Args:
        data_folder: Directory holding the ``.gdf`` and ``.mat`` files.
        subject_id: Integer subject number (zero-padded to two digits).
        low_cut_hz: Lower band edge forwarded to ``bandpass_cnt``.
        high_cut_hz: Upper band edge forwarded to ``bandpass_cnt``.
        factor_new: Forwarded to ``exponential_running_standardize``.
        init_block_size: Forwarded to ``exponential_running_standardize``.
        ival: Trial window forwarded to
            ``create_signal_target_from_raw_mne``.
        valid_set_fraction: Share of the training trials that goes to the
            validation set.

    Returns:
        Tuple ``(train_set, valid_set, test_set)``.
    """

    def _clean(cnt):
        """Drop non-EEG channels, rescale, band-pass and standardize."""
        cnt = cnt.drop_channels(
            ['STI 014', 'EOG-left', 'EOG-central', 'EOG-right'])
        assert len(cnt.ch_names) == 22

        def _rescale(a):
            # Factor 1e6 for numerical stability of the following steps
            # (V -> uV if the loader yields volts).
            return a * 1e6

        def _bandpass(a):
            return bandpass_cnt(a, low_cut_hz, high_cut_hz,
                                cnt.info['sfreq'], filt_order=3, axis=1)

        def _standardize(a):
            # The standardizer is fed the transposed array; transpose back.
            standardized = exponential_running_standardize(
                a.T, factor_new=factor_new,
                init_block_size=init_block_size, eps=1e-4)
            return standardized.T

        for step in (_rescale, _bandpass, _standardize):
            cnt = mne_apply(step, cnt)
        return cnt

    train_filepath = os.path.join(
        data_folder, 'A{:02d}T.gdf'.format(subject_id))
    test_filepath = os.path.join(
        data_folder, 'A{:02d}E.gdf'.format(subject_id))

    train_loader = BCICompetition4Set2A(
        train_filepath,
        labels_filename=train_filepath.replace('.gdf', '.mat'))
    test_loader = BCICompetition4Set2A(
        test_filepath,
        labels_filename=test_filepath.replace('.gdf', '.mat'))
    train_cnt = train_loader.load()
    test_cnt = test_loader.load()

    train_cnt = _clean(train_cnt)
    test_cnt = _clean(test_cnt)

    # Class name -> marker codes used to cut out the trials.
    marker_def = OrderedDict()
    marker_def['Left Hand'] = [1]
    marker_def['Right Hand'] = [2]
    marker_def['Foot'] = [3]
    marker_def['Tongue'] = [4]

    full_train_set = create_signal_target_from_raw_mne(
        train_cnt, marker_def, ival)
    test_set = create_signal_target_from_raw_mne(
        test_cnt, marker_def, ival)

    train_set, valid_set = split_into_two_sets(
        full_train_set, first_set_fraction=1 - valid_set_fraction)

    return train_set, valid_set, test_set
Exemplo n.º 2
0
def preprocessing(data_folder, subject_id, low_cut_hz):
    """Load and clean one subject, publishing the result as module globals.

    Reads ``A<id>T.gdf`` / ``A<id>E.gdf`` (labels from the matching ``.mat``
    files), cleans both recordings with a ``low_cut_hz``-38 band-pass and
    exponential running standardization, cuts trials with the window
    ``[-500, 4000]`` and splits the training trials 80/20.

    Nothing is returned. The function assigns the globals ``train_set``,
    ``valid_set``, ``test_set``, ``n_classes``, ``n_chans`` and
    ``input_time_length``. It also reads the name ``cuda``, which is not a
    parameter and therefore has to exist at module level.

    Args:
        data_folder: Directory holding the ``.gdf`` and ``.mat`` files.
        subject_id: Integer subject number (zero-padded to two digits).
        low_cut_hz: Lower band edge forwarded to ``bandpass_cnt``.
    """
    global train_set, test_set, valid_set, n_classes, n_chans
    global n_iters, input_time_length

    def _clean(cnt):
        """Drop non-EEG channels, rescale, band-pass and standardize."""
        cnt = cnt.drop_channels(
            ['STI 014', 'EOG-left', 'EOG-central', 'EOG-right'])
        assert len(cnt.ch_names) == 22

        def _rescale(a):
            # Factor 1e6 for numerical stability of the following steps
            # (V -> uV if the loader yields volts).
            return a * 1e6

        def _bandpass(a):
            return bandpass_cnt(a, low_cut_hz, 38, cnt.info['sfreq'],
                                filt_order=3, axis=1)

        def _standardize(a):
            standardized = exponential_running_standardize(
                a.T, factor_new=1e-3, init_block_size=1000, eps=1e-4)
            return standardized.T

        for step in (_rescale, _bandpass, _standardize):
            cnt = mne_apply(step, cnt)
        return cnt

    train_filepath = os.path.join(
        data_folder, 'A{:02d}T.gdf'.format(subject_id))
    test_filepath = os.path.join(
        data_folder, 'A{:02d}E.gdf'.format(subject_id))

    train_loader = BCICompetition4Set2A(
        train_filepath,
        labels_filename=train_filepath.replace('.gdf', '.mat'))
    test_loader = BCICompetition4Set2A(
        test_filepath,
        labels_filename=test_filepath.replace('.gdf', '.mat'))
    train_cnt = train_loader.load()
    test_cnt = test_loader.load()

    train_cnt = _clean(train_cnt)
    test_cnt = _clean(test_cnt)

    # Class name -> marker codes used to cut out the trials.
    marker_def = OrderedDict()
    marker_def['Left Hand'] = [1]
    marker_def['Right Hand'] = [2]
    marker_def['Foot'] = [3]
    marker_def['Tongue'] = [4]
    ival = [-500, 4000]

    train_set = create_signal_target_from_raw_mne(train_cnt, marker_def, ival)
    test_set = create_signal_target_from_raw_mne(test_cnt, marker_def, ival)

    # First 80 % of the training trials stay in train_set, the rest become
    # the validation set.
    train_set, valid_set = split_into_two_sets(
        train_set, first_set_fraction=0.8)

    set_random_seeds(seed=20190706, cuda=cuda)

    n_classes = 4
    n_chans = int(train_set.X.shape[1])
    input_time_length = 1000
Exemplo n.º 3
0
def bandpass_mne(cnt, low_cut_hz, high_cut_hz, filt_order=3, axis=0):
    """Band-pass filter the data of ``cnt`` through ``mne_apply``.

    The array handed over by ``mne_apply`` is transposed before it reaches
    ``bandpass_cnt`` and the filtered result is transposed back, so ``axis``
    refers to the transposed layout.

    Args:
        cnt: Object accepted by ``mne_apply``; ``cnt.info['sfreq']`` supplies
            the sampling rate.
        low_cut_hz: Lower band edge forwarded to ``bandpass_cnt``.
        high_cut_hz: Upper band edge forwarded to ``bandpass_cnt``.
        filt_order: Filter order forwarded to ``bandpass_cnt``.
        axis: Filter axis of the transposed array.

    Returns:
        Whatever ``mne_apply`` returns for the filtered data.
    """

    def _filter_transposed(data):
        filtered = bandpass_cnt(data.T,
                                low_cut_hz,
                                high_cut_hz,
                                fs=cnt.info['sfreq'],
                                filt_order=filt_order,
                                axis=axis)
        return filtered.T

    return mne_apply(_filter_transposed, cnt)
Exemplo n.º 4
0
def preprocess(data):
    """Filter ``data`` with ``bandpass_cnt`` using fixed settings.

    The call uses a 4 Hz lower edge, ``None`` as upper edge (presumably
    meaning "no upper cutoff" - confirm against ``bandpass_cnt``), a sampling
    rate of 250, filter order 3, and filters along axis 0.

    Args:
        data: Array forwarded to ``bandpass_cnt``.

    Returns:
        The result of ``bandpass_cnt``.
    """
    filter_settings = dict(low_cut_hz=4,
                           high_cut_hz=None,
                           fs=250,
                           filt_order=3,
                           axis=0)
    return bandpass_cnt(data, **filter_settings)
Exemplo n.º 5
0
def standardize_cnt(cnt, standardize_mode=0):
    """Band-limit ``cnt`` and normalise it in one of three ways.

    The data is first band-passed between 0.1 Hz and 0.1 Hz below the
    Nyquist frequency (third order, no forward-backward filtering), which
    removes the DC component. Then, depending on ``standardize_mode``:

    * ``0`` - subtract each channel's mean over time;
    * ``1`` - subtract each channel's mean and divide by its standard
      deviation over time;
    * ``2`` - multiply by 1e6 and apply ``exponential_running_standardize``.

    Any other value returns the band-passed data unchanged.

    Args:
        cnt: Object accepted by ``mne_apply``; ``cnt.info['sfreq']`` supplies
            the sampling rate.
        standardize_mode: ``0``, ``1`` or ``2`` as described above.

    Returns:
        The processed object as returned by ``mne_apply``.
    """
    # computing frequencies
    sampling_freq = cnt.info['sfreq']
    init_freq = 0.1
    stop_freq = sampling_freq / 2 - 0.1
    filt_order = 3
    # Mode 2 below transposes the array before running the standardizer,
    # i.e. the callback receives (channels, time). Filtering and per-channel
    # statistics therefore have to run along axis 1; axis 0 would mix
    # channels instead of operating over time.
    time_axis = 1
    filtfilt = False

    # filtering DC and frequencies higher than the nyquist one
    cnt = mne_apply(
        lambda x: bandpass_cnt(data=x,
                               low_cut_hz=init_freq,
                               high_cut_hz=stop_freq,
                               fs=sampling_freq,
                               filt_order=filt_order,
                               axis=time_axis,
                               filtfilt=filtfilt), cnt)

    # removing mean and normalizing in 3 different ways
    if standardize_mode == 0:
        # x - mean
        cnt = mne_apply(
            lambda x: x - mean(x, axis=time_axis, keepdims=True), cnt)
    elif standardize_mode == 1:
        # (x - mean) / std
        # NOTE(review): a constant channel has std 0 and yields inf/NaN.
        cnt = mne_apply(
            lambda x: (x - mean(x, axis=time_axis, keepdims=True)) / std(
                x, axis=time_axis, keepdims=True), cnt)
    elif standardize_mode == 2:
        # scaling by 1e6 (V -> uV if the data is in volts) for numerical
        # stability of the next operation
        cnt = mne_apply(lambda a: a * 1e6, cnt)

        # applying exponential_running_standardize (Schirrmeister)
        cnt = mne_apply(
            lambda x: exponential_running_standardize(
                x.T, factor_new=1e-3, init_block_size=1000, eps=1e-4).T, cnt)
    return cnt
Exemplo n.º 6
0
def process_bbci_data(filename, labels_filename, low_cut_hz):
    """Load one BCI Competition IV 2a recording and return it as trials.

    The recording is stripped of its stimulus and EOG channels, scaled by
    1e6, band-passed from ``low_cut_hz`` to 38 Hz, standardized with
    ``exponential_running_standardize`` and cut into trials with the window
    ``[-500, 4000]`` around the four class markers.

    Args:
        filename: Path of the recording handed to ``BCICompetition4Set2A``.
        labels_filename: Path of the label file handed to the same loader.
        low_cut_hz: Lower band edge forwarded to ``bandpass_cnt``.

    Returns:
        The dataset produced by ``create_signal_target_from_raw_mne``.
    """
    cnt = BCICompetition4Set2A(
        filename, labels_filename=labels_filename).load()

    # Preprocessing
    cnt = cnt.drop_channels(
        ['STI 014', 'EOG-left', 'EOG-central', 'EOG-right'])
    assert len(cnt.ch_names) == 22

    def _rescale(a):
        # Factor 1e6 for numerical stability of the following steps
        # (V -> uV if the loader yields volts).
        return a * 1e6

    def _bandpass(a):
        return bandpass_cnt(a, low_cut_hz, 38, cnt.info['sfreq'],
                            filt_order=3, axis=1)

    def _standardize(a):
        standardized = exponential_running_standardize(
            a.T, factor_new=1e-3, init_block_size=1000, eps=1e-4)
        return standardized.T

    for step in (_rescale, _bandpass, _standardize):
        cnt = mne_apply(step, cnt)

    # Class name -> marker codes used to cut out the trials.
    marker_def = OrderedDict()
    marker_def['Left Hand'] = [1]
    marker_def['Right Hand'] = [2]
    marker_def['Foot'] = [3]
    marker_def['Tongue'] = [4]

    return create_signal_target_from_raw_mne(cnt, marker_def, [-500, 4000])
def run_exp(max_recording_mins, n_recordings, sec_to_cut,
            duration_recording_mins, max_abs_val, max_min_threshold,
            max_min_expected, shrink_val, max_min_remove, batch_set_zero_val,
            batch_set_zero_test, sampling_freq, low_cut_hz, high_cut_hz,
            exp_demean, exp_standardize, moving_demean, moving_standardize,
            channel_demean, channel_standardize, divisor, n_folds, i_test_fold,
            input_time_length, final_conv_length, pool_stride, n_blocks_to_add,
            sigmoid, model_constraint, batch_size, max_epochs,
            only_return_exp):
    """Build (and by default run) a cropped Deep4Net diagnosis experiment.

    A list of ``(data, fs) -> (data, fs)`` preprocessing callables is
    assembled from the arguments and handed to ``DiagnosisSet``. The data is
    split with ``Splitter``, a ``Deep4Net`` based model is created and
    wrapped in an ``Experiment``. CUDA is always used.

    Args:
        max_recording_mins, n_recordings: Forwarded to ``DiagnosisSet``.
        sec_to_cut: Seconds removed from both ends of every recording.
        duration_recording_mins: Recordings are truncated to this many
            minutes after the edges were cut.
        max_abs_val: If not None, data is clipped to ``+-max_abs_val``.
        max_min_threshold, max_min_expected: If the threshold is not None,
            ``clean_jumps`` is applied with window length 200.
        shrink_val: If not None, ``shrink_spikes`` is applied.
        max_min_remove: If not None, ``set_jumps_to_zero`` is applied with
            window length 200.
        batch_set_zero_val: If not None, a ``RemoveMinMaxDiff`` batch
            modifier is created with this value.
        batch_set_zero_test: If equal to ``True`` (and the value above is
            set), the modifier is applied through ``ModifiedIterator``
            instead of being passed to the experiment.
        sampling_freq: Target rate for ``resampy.resample``.
        low_cut_hz, high_cut_hz: Band edges for ``bandpass_cnt``.
        exp_demean, exp_standardize, moving_demean, moving_standardize,
            channel_demean, channel_standardize: Flags that append the
            corresponding normalisation step.
        divisor: If not None, data is divided by it as the last step.
        n_folds, i_test_fold: Forwarded to ``Splitter``.
        input_time_length, final_conv_length: Forwarded to ``Deep4Net``.
        pool_stride: Used as both pool length and pool stride.
        n_blocks_to_add: Forwarded to ``net_with_more_layers``.
        sigmoid: If true, one class and a binary cross-entropy loss are
            used; otherwise two classes and NLL loss.
        model_constraint: Any value other than None selects
            ``MaxNormDefaultConstraint``.
        batch_size: Batch size of ``CropsFromTrialsIterator``.
        max_epochs: Stop criterion.
        only_return_exp: If true, no data is loaded and the experiment is
            returned without running; ``dataset`` and ``splitter`` are
            attached to it instead.

    Returns:
        The ``Experiment`` instance.
    """
    cuda = True

    def _cut_edges(data, fs):
        # Use an explicit end index: ``data[:, n:-n]`` degenerates to the
        # empty slice ``data[:, 0:-0]`` when zero samples are to be cut.
        n_cut = int(sec_to_cut * fs)
        return data[:, n_cut:data.shape[1] - n_cut], fs

    preproc_functions = []
    preproc_functions.append(_cut_edges)
    preproc_functions.append(lambda data, fs: (data[:, :int(
        duration_recording_mins * 60 * fs)], fs))
    if max_abs_val is not None:
        preproc_functions.append(
            lambda data, fs: (np.clip(data, -max_abs_val, max_abs_val), fs))
    if max_min_threshold is not None:
        preproc_functions.append(lambda data, fs: (clean_jumps(
            data, 200, max_min_threshold, max_min_expected, cuda), fs))
    if max_min_remove is not None:
        window_len = 200
        preproc_functions.append(lambda data, fs: (set_jumps_to_zero(
            data,
            window_len=window_len,
            threshold=max_min_remove,
            cuda=cuda,
            clip_min_max_to_zero=True), fs))

    if shrink_val is not None:
        preproc_functions.append(lambda data, fs: (shrink_spikes(
            data,
            shrink_val,
            1,
            9,
        ), fs))

    # From here on the sampling rate handed down the chain is sampling_freq.
    preproc_functions.append(lambda data, fs: (resampy.resample(
        data, fs, sampling_freq, axis=1, filter='kaiser_fast'), sampling_freq))
    preproc_functions.append(lambda data, fs: (bandpass_cnt(
        data, low_cut_hz, high_cut_hz, fs, filt_order=4, axis=1), fs))

    if exp_demean:
        preproc_functions.append(lambda data, fs: (exponential_running_demean(
            data.T, factor_new=0.001, init_block_size=100).T, fs))
    if exp_standardize:
        preproc_functions.append(
            lambda data, fs: (exponential_running_standardize(
                data.T, factor_new=0.001, init_block_size=100).T, fs))
    if moving_demean:
        preproc_functions.append(lambda data, fs: (padded_moving_demean(
            data, axis=1, n_window=201), fs))
    if moving_standardize:
        preproc_functions.append(lambda data, fs: (padded_moving_standardize(
            data, axis=1, n_window=201), fs))
    if channel_demean:
        preproc_functions.append(lambda data, fs: (demean(data, axis=1), fs))
    if channel_standardize:
        preproc_functions.append(lambda data, fs:
                                 (standardize(data, axis=1), fs))
    if divisor is not None:
        preproc_functions.append(lambda data, fs: (data / divisor, fs))

    dataset = DiagnosisSet(n_recordings=n_recordings,
                           max_recording_mins=max_recording_mins,
                           preproc_functions=preproc_functions)
    if not only_return_exp:
        X, y = dataset.load()

    splitter = Splitter(
        n_folds,
        i_test_fold,
    )
    if not only_return_exp:
        train_set, valid_set, test_set = splitter.split(X, y)
        del X, y  # shouldn't be necessary, but just to make sure
    else:
        train_set = None
        valid_set = None
        test_set = None

    set_random_seeds(seed=20170629, cuda=cuda)
    if sigmoid:
        n_classes = 1
    else:
        n_classes = 2
    in_chans = 21

    net = Deep4Net(
        in_chans=in_chans,
        n_classes=n_classes,
        input_time_length=input_time_length,
        final_conv_length=final_conv_length,
        pool_time_length=pool_stride,
        pool_time_stride=pool_stride,
        n_filters_2=50,
        n_filters_3=80,
        n_filters_4=120,
    )
    model = net_with_more_layers(net, n_blocks_to_add, nn.MaxPool2d)
    if sigmoid:
        model = to_linear_plus_minus_net(model)
    optimizer = optim.Adam(model.parameters())
    to_dense_prediction_model(model)
    log.info("Model:\n{:s}".format(str(model)))
    if cuda:
        model.cuda()
    # determine output size by pushing a dummy batch through the model
    test_input = np_to_var(
        np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
    if cuda:
        test_input = test_input.cuda()
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]
    log.info("{:d} predictions per input/trial".format(n_preds_per_input))
    iterator = CropsFromTrialsIterator(batch_size=batch_size,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input)
    # Both losses average the predictions over dimension 2 first.
    if sigmoid:
        loss_function = lambda preds, targets: binary_cross_entropy_with_logits(
            th.mean(preds, dim=2)[:, 1, 0], targets.type_as(preds))
    else:
        loss_function = lambda preds, targets: F.nll_loss(
            th.mean(preds, dim=2)[:, :, 0], targets)

    # Any non-None value selects the default max-norm constraint.
    if model_constraint is not None:
        model_constraint = MaxNormDefaultConstraint()
    monitors = [
        LossMonitor(),
        MisclassMonitor(col_suffix='sample_misclass'),
        CroppedTrialMisclassMonitor(input_time_length),
        RuntimeMonitor(),
    ]
    stop_criterion = MaxEpochs(max_epochs)
    batch_modifier = None
    if batch_set_zero_val is not None:
        batch_modifier = RemoveMinMaxDiff(batch_set_zero_val,
                                          clip_max_abs=True,
                                          set_zero=True)
    if (batch_set_zero_val is not None) and (batch_set_zero_test == True):
        # Apply the modifier through the iterator instead of handing it to
        # the experiment.
        iterator = ModifiedIterator(
            iterator,
            batch_modifier,
        )
        batch_modifier = None
    exp = Experiment(model,
                     train_set,
                     valid_set,
                     test_set,
                     iterator,
                     loss_function,
                     optimizer,
                     model_constraint,
                     monitors,
                     stop_criterion,
                     remember_best_column='valid_misclass',
                     run_after_early_stop=True,
                     batch_modifier=batch_modifier,
                     cuda=cuda)
    if not only_return_exp:
        exp.run()
    else:
        exp.dataset = dataset
        exp.splitter = splitter

    return exp
Exemplo n.º 8
0
def run_exp(data_folder, subject_id, low_cut_hz, model, cuda):
    """Train a trial-wise network on one BCI Competition IV 2a subject.

    Both sessions of the subject are loaded, cleaned identically
    (``low_cut_hz``-38 Hz band-pass, exponential running standardization),
    cut into trials with the window ``[-500, 4000]`` and the training trials
    are split 80/20 into train and validation sets. The experiment stops
    after 1600 epochs or 160 epochs without a decrease of
    ``valid_misclass``.

    Args:
        data_folder: Directory holding the ``.gdf`` and ``.mat`` files.
        subject_id: Integer subject number (zero-padded to two digits).
        low_cut_hz: Lower band edge forwarded to ``bandpass_cnt``.
        model: ``'shallow'`` builds a ``ShallowFBCSPNet``, ``'deep'`` a
            ``Deep4Net``; any other value is used as the model as given.
        cuda: Whether the model is moved to the GPU; also forwarded to
            ``set_random_seeds`` and ``Experiment``.

    Returns:
        The ``Experiment`` after ``run()`` has finished.
    """

    def _clean(cnt):
        """Drop non-EEG channels, rescale, band-pass and standardize."""
        cnt = cnt.drop_channels(
            ['STI 014', 'EOG-left', 'EOG-central', 'EOG-right'])
        assert len(cnt.ch_names) == 22

        def _rescale(a):
            # Factor 1e6 for numerical stability of the following steps
            # (V -> uV if the loader yields volts).
            return a * 1e6

        def _bandpass(a):
            return bandpass_cnt(
                a, low_cut_hz, 38, cnt.info['sfreq'], filt_order=3, axis=1)

        def _standardize(a):
            standardized = exponential_running_standardize(
                a.T, factor_new=1e-3, init_block_size=1000, eps=1e-4)
            return standardized.T

        for step in (_rescale, _bandpass, _standardize):
            cnt = mne_apply(step, cnt)
        return cnt

    train_filepath = os.path.join(
        data_folder, 'A{:02d}T.gdf'.format(subject_id))
    test_filepath = os.path.join(
        data_folder, 'A{:02d}E.gdf'.format(subject_id))

    train_loader = BCICompetition4Set2A(
        train_filepath,
        labels_filename=train_filepath.replace('.gdf', '.mat'))
    test_loader = BCICompetition4Set2A(
        test_filepath,
        labels_filename=test_filepath.replace('.gdf', '.mat'))
    train_cnt = train_loader.load()
    test_cnt = test_loader.load()

    # Preprocessing
    train_cnt = _clean(train_cnt)
    test_cnt = _clean(test_cnt)

    # Class name -> marker codes used to cut out the trials.
    marker_def = OrderedDict()
    marker_def['Left Hand'] = [1]
    marker_def['Right Hand'] = [2]
    marker_def['Foot'] = [3]
    marker_def['Tongue'] = [4]
    ival = [-500, 4000]

    full_train_set = create_signal_target_from_raw_mne(
        train_cnt, marker_def, ival)
    test_set = create_signal_target_from_raw_mne(test_cnt, marker_def, ival)

    train_set, valid_set = split_into_two_sets(
        full_train_set, first_set_fraction=0.8)

    set_random_seeds(seed=20190706, cuda=cuda)

    n_classes = 4
    n_chans = int(train_set.X.shape[1])
    # Trial-wise decoding: the network sees a whole trial at once.
    input_time_length = train_set.X.shape[2]
    if model == 'shallow':
        network = ShallowFBCSPNet(n_chans,
                                  n_classes,
                                  input_time_length=input_time_length,
                                  final_conv_length='auto')
        model = network.create_network()
    elif model == 'deep':
        network = Deep4Net(n_chans,
                           n_classes,
                           input_time_length=input_time_length,
                           final_conv_length='auto')
        model = network.create_network()
    if cuda:
        model.cuda()
    log.info("Model: \n{:s}".format(str(model)))

    experiment_settings = dict(
        iterator=BalancedBatchSizeIterator(batch_size=60),
        loss_function=F.nll_loss,
        optimizer=optim.Adam(model.parameters()),
        model_constraint=MaxNormDefaultConstraint(),
        monitors=[LossMonitor(), MisclassMonitor(), RuntimeMonitor()],
        stop_criterion=Or(
            [MaxEpochs(1600), NoDecrease('valid_misclass', 160)]),
        remember_best_column='valid_misclass',
        run_after_early_stop=True,
        cuda=cuda,
    )
    exp = Experiment(model, train_set, valid_set, test_set,
                     **experiment_settings)
    exp.run()
    return exp
Exemplo n.º 9
0
def run_exp(data_folder, subject_id, low_cut_hz, model, cuda):
    """Train a cropped-decoding network on one BCI Competition IV 2a subject.

    Both sessions of the subject are loaded, cleaned identically
    (``low_cut_hz``-38 Hz band-pass, exponential running standardization),
    cut into trials with the window ``[-500, 4000]`` and the training trials
    are split 80/20 into train and validation sets. The network is converted
    to a dense prediction model fed with crops of 1000 samples. Training
    stops after 800 epochs or 80 epochs without a decrease of
    ``valid_misclass``.

    Args:
        data_folder: Directory holding the ``.gdf`` and ``.mat`` files.
        subject_id: Integer subject number (zero-padded to two digits).
        low_cut_hz: Lower band edge forwarded to ``bandpass_cnt``.
        model: ``'shallow'`` builds a ``ShallowFBCSPNet``, ``'deep'`` a
            ``Deep4Net``; any other value is used as the model as given.
        cuda: Whether model and dummy input are moved to the GPU; also
            forwarded to ``set_random_seeds`` and ``Experiment``.

    Returns:
        The ``Experiment`` after ``run()`` has finished.
    """
    # Fixed experiment settings.
    ival = [-500, 4000]
    input_time_length = 1000
    max_epochs, max_increase_epochs = 800, 80
    batch_size = 60
    high_cut_hz = 38
    factor_new, init_block_size = 1e-3, 1000
    valid_set_fraction = 0.2

    def _clean(cnt):
        """Drop non-EEG channels, rescale, band-pass and standardize."""
        cnt = cnt.drop_channels(
            ['STI 014', 'EOG-left', 'EOG-central', 'EOG-right'])
        assert len(cnt.ch_names) == 22

        def _rescale(a):
            # Factor 1e6 for numerical stability of the following steps
            # (V -> uV if the loader yields volts).
            return a * 1e6

        def _bandpass(a):
            return bandpass_cnt(a, low_cut_hz, high_cut_hz,
                                cnt.info['sfreq'], filt_order=3, axis=1)

        def _standardize(a):
            standardized = exponential_running_standardize(
                a.T, factor_new=factor_new,
                init_block_size=init_block_size, eps=1e-4)
            return standardized.T

        for step in (_rescale, _bandpass, _standardize):
            cnt = mne_apply(step, cnt)
        return cnt

    train_filepath = os.path.join(
        data_folder, 'A{:02d}T.gdf'.format(subject_id))
    test_filepath = os.path.join(
        data_folder, 'A{:02d}E.gdf'.format(subject_id))

    train_loader = BCICompetition4Set2A(
        train_filepath,
        labels_filename=train_filepath.replace('.gdf', '.mat'))
    test_loader = BCICompetition4Set2A(
        test_filepath,
        labels_filename=test_filepath.replace('.gdf', '.mat'))
    train_cnt = train_loader.load()
    test_cnt = test_loader.load()

    # Preprocessing
    train_cnt = _clean(train_cnt)
    test_cnt = _clean(test_cnt)

    # Class name -> marker codes used to cut out the trials.
    marker_def = OrderedDict()
    marker_def['Left Hand'] = [1]
    marker_def['Right Hand'] = [2]
    marker_def['Foot'] = [3]
    marker_def['Tongue'] = [4]

    full_train_set = create_signal_target_from_raw_mne(
        train_cnt, marker_def, ival)
    test_set = create_signal_target_from_raw_mne(test_cnt, marker_def, ival)

    train_set, valid_set = split_into_two_sets(
        full_train_set, first_set_fraction=1 - valid_set_fraction)

    set_random_seeds(seed=20190706, cuda=cuda)

    n_classes = 4
    n_chans = int(train_set.X.shape[1])
    if model == 'shallow':
        network = ShallowFBCSPNet(n_chans,
                                  n_classes,
                                  input_time_length=input_time_length,
                                  final_conv_length=30)
        model = network.create_network()
    elif model == 'deep':
        network = Deep4Net(n_chans,
                           n_classes,
                           input_time_length=input_time_length,
                           final_conv_length=2)
        model = network.create_network()

    to_dense_prediction_model(model)
    if cuda:
        model.cuda()

    log.info("Model: \n{:s}".format(str(model)))

    # Push one trial through the network to learn how many predictions it
    # emits per input.
    dummy_input = np_to_var(train_set.X[:1, :, :, None])
    if cuda:
        dummy_input = dummy_input.cuda()
    dummy_output = model(dummy_input)
    n_preds_per_input = dummy_output.cpu().data.numpy().shape[2]

    def loss_function(preds, targets):
        # Average the predictions over dimension 2 before the NLL loss.
        mean_preds = th.mean(preds, dim=2, keepdim=False)
        return F.nll_loss(mean_preds, targets)

    experiment_settings = dict(
        iterator=CropsFromTrialsIterator(
            batch_size=batch_size,
            input_time_length=input_time_length,
            n_preds_per_input=n_preds_per_input),
        loss_function=loss_function,
        optimizer=optim.Adam(model.parameters()),
        model_constraint=MaxNormDefaultConstraint(),
        monitors=[
            LossMonitor(),
            MisclassMonitor(col_suffix='sample_misclass'),
            CroppedTrialMisclassMonitor(input_time_length=input_time_length),
            RuntimeMonitor(),
        ],
        stop_criterion=Or([
            MaxEpochs(max_epochs),
            NoDecrease('valid_misclass', max_increase_epochs),
        ]),
        remember_best_column='valid_misclass',
        run_after_early_stop=True,
        cuda=cuda,
    )
    exp = Experiment(model, train_set, valid_set, test_set,
                     **experiment_settings)
    exp.run()
    return exp
Exemplo n.º 10
0
def run_exp(data_folder, subject_id, low_cut_hz, test_model, model_PATH, model, cuda):
    """Evaluate a ``SelfShallow`` network on one subject's ``E`` session.

    The recording ``A<id>E.gdf`` (labels from the matching ``.mat`` file) is
    cleaned (``low_cut_hz``-38 Hz band-pass, exponential running
    standardization), cut into trials with the window ``[-500, 4000]`` and
    transformed by ``data_all_chan_cwtandraw``. All trials are pushed
    through the model in a single batch and the accuracy is printed.

    Args:
        data_folder: Directory holding the ``.gdf`` and ``.mat`` files.
        subject_id: Integer subject number (zero-padded to two digits).
        low_cut_hz: Lower band edge forwarded to ``bandpass_cnt``.
        test_model: If true, weights are loaded from ``model_PATH``.
        model_PATH: Path of the saved state dict.
        model: Ignored; a fresh ``SelfShallow`` instance is always created.
        cuda: Whether model and data are moved to the GPU.

    Returns:
        None. The accuracy is only printed.
    """
    ival = [-500, 4000]
    high_cut_hz = 38
    factor_new = 1e-3
    init_block_size = 1000

    test_filename = "A{:02d}E.gdf".format(subject_id)
    test_filepath = os.path.join(data_folder, test_filename)
    test_label_filepath = test_filepath.replace(".gdf", ".mat")

    test_loader = BCICompetition4Set2A(
        test_filepath, labels_filename=test_label_filepath
    )
    test_cnt = test_loader.load()
    # Preprocessing
    test_cnt = test_cnt.drop_channels(["EOG-left", "EOG-central", "EOG-right"])
    assert len(test_cnt.ch_names) == 22
    # Factor 1e6 for numerical stability of the following steps
    # (V -> uV if the loader yields volts).
    test_cnt = mne_apply(lambda a: a * 1e6, test_cnt)
    test_cnt = mne_apply(
        lambda a: bandpass_cnt(
            a,
            low_cut_hz,
            high_cut_hz,
            test_cnt.info["sfreq"],
            filt_order=3,
            axis=1,
        ),
        test_cnt,
    )
    test_cnt = mne_apply(
        lambda a: exponential_running_standardize(
            a.T,
            factor_new=factor_new,
            init_block_size=init_block_size,
            eps=1e-4,
        ).T,
        test_cnt,
    )
    marker_def = OrderedDict(
        [
            ("Left Hand", [1]),
            ("Right Hand", [2]),
            ("Foot", [3]),
            ("Tongue", [4]),
        ]
    )
    test_set = create_signal_target_from_raw_mne(test_cnt, marker_def, ival)
    test_set = data_all_chan_cwtandraw(test_set)
    set_random_seeds(seed=20200104, cuda=cuda)

    model = SelfShallow()

    if test_model:
        # Without CUDA, remap tensors that were saved from a GPU to the CPU;
        # with CUDA the default placement is kept.
        model.load_state_dict(
            torch.load(model_PATH, map_location=None if cuda else "cpu")
        )

    if cuda:
        model.cuda()
    # Evaluation mode must be active on every device, not just on the GPU.
    model.eval()
    log.info("Model: \n{:s}".format(str(model)))
    all_test_labels = test_set.y
    # Keep the input on the same device as the model.
    all_test_data = torch.from_numpy(test_set.X)
    if cuda:
        all_test_data = all_test_data.cuda()
    # Inference only: no autograd graph is needed. The model returns several
    # outputs; only the first one (the predictions) is used here.
    with torch.no_grad():
        preds = model(all_test_data)[0]

    preds = preds.cpu()
    preds = preds.detach().numpy()
    all_preds = np.argmax(preds, axis=1).squeeze()
    accury = np.mean(all_test_labels==all_preds)
    print('accury:',accury)
Exemplo n.º 11
0
def run_exp(data_folder, subject_id, low_cut_hz, model, cuda):
    """Train a trial-wise network on one BCI Competition IV 2a subject.

    Both sessions of the subject are loaded, cleaned identically
    (``low_cut_hz``-38 Hz band-pass, exponential running standardization),
    cut into trials with the window ``[-500, 4000]`` and the training trials
    are split 80/20 into train and validation sets. Training stops after
    1600 epochs or 160 epochs without a decrease of ``valid_misclass``.

    Args:
        data_folder: Directory holding the ``.gdf`` and ``.mat`` files.
        subject_id: Integer subject number (zero-padded to two digits).
        low_cut_hz: Lower band edge forwarded to ``bandpass_cnt``.
        model: ``"shallow"`` builds a ``ShallowFBCSPNet``, ``"deep"`` a
            ``Deep4Net``; any other value is used as the model as given.
        cuda: Whether the model is moved to the GPU; also forwarded to
            ``set_random_seeds`` and ``Experiment``.

    Returns:
        The ``Experiment`` after ``run()`` has finished.
    """
    # Fixed experiment settings.
    ival = [-500, 4000]
    max_epochs, max_increase_epochs = 1600, 160
    batch_size = 60
    high_cut_hz = 38
    factor_new, init_block_size = 1e-3, 1000
    valid_set_fraction = 0.2

    def _clean(cnt):
        """Drop the EOG channels, rescale, band-pass and standardize."""
        cnt = cnt.drop_channels(["EOG-left", "EOG-central", "EOG-right"])
        assert len(cnt.ch_names) == 22

        def _rescale(a):
            # Factor 1e6 for numerical stability of the following steps
            # (V -> uV if the loader yields volts).
            return a * 1e6

        def _bandpass(a):
            return bandpass_cnt(
                a,
                low_cut_hz,
                high_cut_hz,
                cnt.info["sfreq"],
                filt_order=3,
                axis=1,
            )

        def _standardize(a):
            standardized = exponential_running_standardize(
                a.T,
                factor_new=factor_new,
                init_block_size=init_block_size,
                eps=1e-4,
            )
            return standardized.T

        for step in (_rescale, _bandpass, _standardize):
            cnt = mne_apply(step, cnt)
        return cnt

    train_filepath = os.path.join(
        data_folder, "A{:02d}T.gdf".format(subject_id)
    )
    test_filepath = os.path.join(
        data_folder, "A{:02d}E.gdf".format(subject_id)
    )

    train_loader = BCICompetition4Set2A(
        train_filepath,
        labels_filename=train_filepath.replace(".gdf", ".mat"),
    )
    test_loader = BCICompetition4Set2A(
        test_filepath,
        labels_filename=test_filepath.replace(".gdf", ".mat"),
    )
    train_cnt = train_loader.load()
    test_cnt = test_loader.load()

    # Preprocessing
    train_cnt = _clean(train_cnt)
    test_cnt = _clean(test_cnt)

    # Class name -> marker codes used to cut out the trials.
    marker_def = OrderedDict()
    marker_def["Left Hand"] = [1]
    marker_def["Right Hand"] = [2]
    marker_def["Foot"] = [3]
    marker_def["Tongue"] = [4]

    full_train_set = create_signal_target_from_raw_mne(
        train_cnt, marker_def, ival
    )
    test_set = create_signal_target_from_raw_mne(test_cnt, marker_def, ival)

    train_set, valid_set = split_into_two_sets(
        full_train_set, first_set_fraction=1 - valid_set_fraction
    )

    set_random_seeds(seed=20190706, cuda=cuda)

    n_classes = 4
    n_chans = int(train_set.X.shape[1])
    # Trial-wise decoding: the network sees a whole trial at once.
    input_time_length = train_set.X.shape[2]
    if model == "shallow":
        network = ShallowFBCSPNet(
            n_chans,
            n_classes,
            input_time_length=input_time_length,
            final_conv_length="auto",
        )
        model = network.create_network()
    elif model == "deep":
        network = Deep4Net(
            n_chans,
            n_classes,
            input_time_length=input_time_length,
            final_conv_length="auto",
        )
        model = network.create_network()
    if cuda:
        model.cuda()
    log.info("Model: \n{:s}".format(str(model)))

    experiment_settings = dict(
        iterator=BalancedBatchSizeIterator(batch_size=batch_size),
        loss_function=F.nll_loss,
        optimizer=optim.Adam(model.parameters()),
        model_constraint=MaxNormDefaultConstraint(),
        monitors=[LossMonitor(), MisclassMonitor(), RuntimeMonitor()],
        stop_criterion=Or(
            [
                MaxEpochs(max_epochs),
                NoDecrease("valid_misclass", max_increase_epochs),
            ]
        ),
        remember_best_column="valid_misclass",
        run_after_early_stop=True,
        cuda=cuda,
    )
    exp = Experiment(
        model, train_set, valid_set, test_set, **experiment_settings
    )
    exp.run()
    return exp
Exemplo n.º 12
0
    test_filepath, labels_filename=test_label_filepath
)
# Load both recordings; the loaders are created earlier in the script.
train_cnt = train_loader.load()
test_cnt = test_loader.load()

# Preprocessing

# Remove the three EOG channels; exactly 22 channels must remain.
train_cnt = train_cnt.drop_channels(["EOG-left", "EOG-central", "EOG-right"])
assert len(train_cnt.ch_names) == 22
# Scale by 1e6 for numerical stability of the next operations
# (V -> uV if the loader yields volts).
train_cnt = mne_apply(lambda a: a * 1e6, train_cnt)
# Band-pass between low_cut_hz and high_cut_hz along axis 1.
train_cnt = mne_apply(
    lambda a: bandpass_cnt(
        a,
        low_cut_hz,
        high_cut_hz,
        train_cnt.info["sfreq"],
        filt_order=3,
        axis=1,
    ),
    train_cnt,
)
# The standardizer is fed the transposed array; the result is transposed back.
train_cnt = mne_apply(
    lambda a: exponential_running_standardize(
        a.T, factor_new=factor_new, init_block_size=init_block_size, eps=1e-4
    ).T,
    train_cnt,
)

# Same channel removal, channel-count check and scaling for the test recording.
test_cnt = test_cnt.drop_channels(["EOG-left", "EOG-central", "EOG-right"])
assert len(test_cnt.ch_names) == 22
test_cnt = mne_apply(lambda a: a * 1e6, test_cnt)
Exemplo n.º 13
0
def get_bci_iv_2a_train_val_test(data_folder, subject_id, low_cut_hz):
    """Load one BCI Competition IV 2a subject as train/valid/test sets.

    Reads ``A<subject_id>T.gdf`` and ``A<subject_id>E.gdf`` from
    ``data_folder`` (labels come from the same paths with ``.gdf`` replaced
    by ``.mat``), applies identical preprocessing to both recordings, cuts
    trials around the event markers and splits the training trials into a
    training and a validation part.

    Parameters
    ----------
    data_folder
        Directory containing the ``.gdf`` recordings and ``.mat`` labels.
    subject_id
        Integer subject number; formatted with two digits in the file name.
    low_cut_hz
        Lower edge handed to ``bandpass_cnt``.

    Returns
    -------
    tuple
        ``(train_set, valid_set, test_set)``. The validation share is read
        from ``global_vars.get('valid_set_fraction')``.

    Raises
    ------
    AssertionError
        If a recording does not end up with exactly 22 channels.
    """
    # Window around each event from which the data for one trial is taken.
    trial_window = [-500, 4000]
    # Parts of the signal above 38 Hz are filtered out.
    upper_cut_hz = 38
    # Settings forwarded to exponential_running_standardize.
    std_factor_new = 1e-3
    std_init_block_size = 1000

    def _clean(cnt):
        """Drop non-EEG channels, rescale, band-pass and standardize."""
        cnt = cnt.drop_channels(['EOG-left', 'EOG-central', 'EOG-right'])
        # Some recordings carry an extra stimulus channel; remove it if so.
        if len(cnt.ch_names) > 22:
            cnt = cnt.drop_channels(['STI 014'])
        assert len(cnt.ch_names) == 22

        # Scale the raw values by 1e6 before filtering.
        cnt = mne_apply(lambda a: a * 1e6, cnt)
        sfreq = cnt.info['sfreq']
        cnt = mne_apply(
            lambda a: bandpass_cnt(
                a, low_cut_hz, upper_cut_hz, sfreq, filt_order=3, axis=1),
            cnt)
        # The standardizer receives the transposed array, hence .T twice.
        return mne_apply(
            lambda a: exponential_running_standardize(
                a.T,
                factor_new=std_factor_new,
                init_block_size=std_init_block_size,
                eps=1e-4).T,
            cnt)

    train_path = os.path.join(data_folder, 'A{:02d}T.gdf'.format(subject_id))
    test_path = os.path.join(data_folder, 'A{:02d}E.gdf'.format(subject_id))

    train_loader = BCICompetition4Set2A(
        train_path, labels_filename=train_path.replace('.gdf', '.mat'))
    test_loader = BCICompetition4Set2A(
        test_path, labels_filename=test_path.replace('.gdf', '.mat'))
    raw_train = train_loader.load()
    raw_test = test_loader.load()

    train_cnt = _clean(raw_train)
    test_cnt = _clean(raw_test)

    # Class name -> marker codes that belong to it.
    marker_def = OrderedDict([
        ('Left Hand', [1]),
        ('Right Hand', [2]),
        ('Foot', [3]),
        ('Tongue', [4]),
    ])

    full_train_set = create_signal_target_from_raw_mne(
        train_cnt, marker_def, trial_window)
    test_set = create_signal_target_from_raw_mne(
        test_cnt, marker_def, trial_window)

    train_fraction = 1 - global_vars.get('valid_set_fraction')
    train_set, valid_set = split_into_two_sets(
        full_train_set, first_set_fraction=train_fraction)

    return train_set, valid_set, test_set
Exemplo n.º 14
0
def data_gen(subject, high_cut_hz=38, low_cut_hz=0,
             data_folder=r'D:\li\=.=\eeg\hw\nn-STFT\dataset\BCICIV_2a_gdf'):
    """Build one data set holding all trials of the given subjects.

    For every subject id the training (``A<id>T.gdf``) and evaluation
    (``A<id>E.gdf``) recordings are loaded, preprocessed identically, cut
    into trials, and concatenated. The per-subject sets are then
    concatenated in the order given by ``subject``.

    Parameters
    ----------
    subject
        Sequence of integer subject ids.
    high_cut_hz, low_cut_hz
        Band edges handed to ``bandpass_cnt``.
    data_folder
        Directory with the ``.gdf`` recordings and ``.mat`` label files.
        Defaults to the path that used to be hard-coded in this function.

    Returns
    -------
    The concatenation of all subjects' train and test sets, as produced by
    ``concatenate_sets``.

    Raises
    ------
    ValueError
        If ``subject`` is empty (previously this surfaced as an
        ``UnboundLocalError`` at the ``return`` statement).
    AssertionError
        If a recording does not have exactly 22 channels after the EOG
        channels are dropped, or if ``subject`` contains duplicate ids.
    """
    if len(subject) == 0:
        raise ValueError("subject must contain at least one subject id")

    ival = [-500, 4000]
    factor_new = 1e-3
    init_block_size = 1000
    # Class name -> marker codes that belong to it.
    marker_def = OrderedDict([
        ('Left Hand', [1]),
        ('Right Hand', [2]),
        ('Foot', [3]),
        ('Tongue', [4]),
    ])

    def _preprocess(cnt):
        """Drop EOG channels, rescale, band-pass and standardize ``cnt``."""
        cnt = cnt.drop_channels(['EOG-left', 'EOG-central', 'EOG-right'])
        assert len(cnt.ch_names) == 22
        cnt = mne_apply(lambda a: a * 1e6, cnt)
        sfreq = cnt.info['sfreq']
        cnt = mne_apply(
            lambda a: bandpass_cnt(
                a, low_cut_hz, high_cut_hz, sfreq, filt_order=3, axis=1),
            cnt)
        # The standardizer receives the transposed array, hence .T twice.
        return mne_apply(
            lambda a: exponential_running_standardize(
                a.T,
                factor_new=factor_new,
                init_block_size=init_block_size,
                eps=1e-4).T,
            cnt)

    data_sub = {}
    dataset = None
    for subject_id in subject:
        train_filepath = os.path.join(
            data_folder, 'A{:02d}T.gdf'.format(subject_id))
        test_filepath = os.path.join(
            data_folder, 'A{:02d}E.gdf'.format(subject_id))

        train_loader = BCICompetition4Set2A(
            train_filepath,
            labels_filename=train_filepath.replace('.gdf', '.mat'))
        test_loader = BCICompetition4Set2A(
            test_filepath,
            labels_filename=test_filepath.replace('.gdf', '.mat'))

        train_cnt = train_loader.load()
        test_cnt = test_loader.load()

        train_cnt = _preprocess(train_cnt)
        test_cnt = _preprocess(test_cnt)

        train_set = create_signal_target_from_raw_mne(
            train_cnt, marker_def, ival)
        test_set = create_signal_target_from_raw_mne(
            test_cnt, marker_def, ival)

        subject_data = concatenate_sets([train_set, test_set])
        data_sub[str(subject_id)] = subject_data
        if dataset is None:
            dataset = subject_data
        else:
            dataset = concatenate_sets([dataset, subject_data])

    # Keyed by subject id, so duplicates in ``subject`` shrink the dict.
    assert len(data_sub) == len(subject)

    return dataset
def concat_prepare_cnn(input_signal):
    """Clean, filter and standardize a recording and cut it into trials.

    Steps, in order:

    1. Extract the events with ``mne.find_events`` and drop the ``"stim"``
       channel.
    2. In every channel, replace the samples equal to that channel's minimum
       with the mean of the remaining samples.
    3. Rebuild an ``mne.io.RawArray`` and attach the events to its ``info``.
    4. Scale by 1e6, band-pass (1-40 Hz, order 3) and apply exponential
       running standardization.
    5. Cut trials for marker code 30 (class ``"ec"``) using the window
       ``[0, 20000]``.

    Parameters
    ----------
    input_signal
        MNE raw object that has a ``"stim"`` channel.

    Returns
    -------
    The set returned by ``create_signal_target_from_raw_mne``.

    Notes
    -----
    A channel whose samples are all equal ends up entirely NaN, because
    every sample matches the minimum and no value is left to average.
    """
    # Lower band edge. Per the original author's note, the data set calls
    # for high-pass filtering so the model does not pick up eye movements.
    low_cut_hz = 1

    # Upper band edge. Original author's note: this is a hyperparameter
    # that probably should be tuned per data set.
    high_cut_hz = 40

    # Update factor and initial block size for the exponential running
    # standardization.
    factor_new = 1e-3
    init_block_size = 1000

    # Interval (in milliseconds, per the original author's note) that is
    # analysed for each trial.
    ival = [0, 20000]

    # Events have to be read before the stim channel is removed.
    gdf_events = mne.find_events(input_signal)
    input_signal = input_signal.drop_channels(["stim"])

    raw_training_signal = input_signal.get_data()
    print("data shape:", raw_training_signal.shape)

    for i_chan in range(raw_training_signal.shape[0]):
        # Mark the minimum-valued samples as NaN, then fill every NaN with
        # the mean of the remaining samples of that channel.
        this_chan = raw_training_signal[i_chan]
        raw_training_signal[i_chan] = np.where(
            this_chan == np.min(this_chan), np.nan, this_chan)
        mask = np.isnan(raw_training_signal[i_chan])
        chan_mean = np.nanmean(raw_training_signal[i_chan])
        raw_training_signal[i_chan, mask] = chan_mean

    # Rebuild the raw object around the cleaned array and re-attach the
    # events extracted above.
    input_signal = mne.io.RawArray(raw_training_signal,
                                   input_signal.info,
                                   verbose="WARNING")
    input_signal.info["events"] = gdf_events

    train_cnt = input_signal
    # Scale by 1e6 for numerical stability of the next operations.
    # NOTE(review): MNE stores data in volts, so this presumably yields
    # microvolts (the earlier comment said "millvolt") -- confirm.
    train_cnt = mne_apply(lambda a: a * 1e6, train_cnt)
    sfreq = train_cnt.info["sfreq"]
    train_cnt = mne_apply(
        lambda a: bandpass_cnt(
            a,
            low_cut_hz,
            high_cut_hz,
            sfreq,
            filt_order=3,
            axis=1,
        ),
        train_cnt,
    )

    # The standardizer receives the transposed array, hence .T twice.
    train_cnt = mne_apply(
        lambda a: exponential_running_standardize(
            a.T,
            factor_new=factor_new,
            init_block_size=init_block_size,
            eps=1e-4,
        ).T,
        train_cnt,
    )

    marker_def = OrderedDict([("ec", [30])])

    train_set = create_signal_target_from_raw_mne(train_cnt, marker_def, ival)
    return train_set
Exemplo n.º 16
0
def run_exp(data_folder, session_id, subject_id, low_cut_hz, model, cuda):
    """Train a trial-wise decoder on one GIGAscience motor-imagery session.

    Loads ``sess<session_id>_subj<subject_id>_EEG_MI.mat`` from
    ``data_folder``, preprocesses the ``EEG_MI_train`` and ``EEG_MI_test``
    recordings identically (22 channels picked, resampled to 250 Hz, scaled
    by 1e6, band-passed, exponentially standardized), cuts trials, splits
    off a validation set and runs an ``Experiment`` with early stopping on
    ``valid_misclass``.

    Parameters
    ----------
    data_folder
        Directory containing the ``.mat`` session files.
    session_id, subject_id
        Integers used (two digits each) to build the file name.
    low_cut_hz
        Lower edge handed to ``bandpass_cnt``; the upper edge is 38 Hz.
    model
        ``'shallow'`` (ShallowFBCSPNet) or ``'deep'`` (Deep4Net). Any
        non-string object is used as the network as-is.
    cuda
        Whether to move the model to the GPU and seed/run with CUDA.

    Returns
    -------
    The ``Experiment`` instance after ``run()`` has finished.

    Raises
    ------
    ValueError
        If ``model`` is a string other than ``'shallow'`` or ``'deep'``.
        Raised before any data is loaded; previously this failed late with
        an ``AttributeError`` on the string.
    AssertionError
        If a recording does not have exactly 22 channels after picking.
    """
    if isinstance(model, str) and model not in ('shallow', 'deep'):
        raise ValueError(
            "Unknown model {!r}; expected 'shallow' or 'deep'".format(model))

    ival = [-500, 4000]
    max_epochs = 1600
    max_increase_epochs = 160
    batch_size = 10
    high_cut_hz = 38
    factor_new = 1e-3
    init_block_size = 1000
    valid_set_fraction = .2

    # The 22 channels kept out of the recording's full channel set.
    picked_channels = [
        'FC5', 'FC3', 'FC1', 'Fz', 'FC2', 'FC4', 'FC6', 'C5', 'C3', 'C1', 'Cz',
        'C2', 'C4', 'C6', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6', 'Pz'
    ]

    def _preprocess(cnt):
        """Pick channels, resample, rescale, band-pass and standardize."""
        cnt = cnt.pick_channels(picked_channels)
        # Resample a copy together with its events, then store the
        # resampled events on the resampled recording.
        resampled, resampled_events = cnt.copy().resample(
            250, npad='auto', events=cnt.info['events'])
        resampled.info['events'] = resampled_events
        cnt = resampled

        assert len(cnt.ch_names) == 22
        # Scale by 1e6 for numerical stability of the next operations.
        cnt = mne_apply(lambda a: a * 1e6, cnt)
        sfreq = cnt.info['sfreq']
        cnt = mne_apply(
            lambda a: bandpass_cnt(
                a, low_cut_hz, high_cut_hz, sfreq, filt_order=3, axis=1),
            cnt)
        # The standardizer receives the transposed array, hence .T twice.
        return mne_apply(
            lambda a: exponential_running_standardize(
                a.T,
                factor_new=factor_new,
                init_block_size=init_block_size,
                eps=1e-4).T,
            cnt)

    # GIGAscience
    filename = 'sess{:02d}_subj{:02d}_EEG_MI.mat'.format(
        session_id, subject_id)
    filepath = os.path.join(data_folder, filename)
    train_variable = 'EEG_MI_train'
    test_variable = 'EEG_MI_test'

    train_loader = GIGAscience(filepath, train_variable)
    test_loader = GIGAscience(filepath, test_variable)
    train_cnt = train_loader.load()
    test_cnt = test_loader.load()

    # Preprocessing
    train_cnt = _preprocess(train_cnt)
    test_cnt = _preprocess(test_cnt)

    marker_def = OrderedDict([('Right Hand', [1]), ('Left Hand', [2])])

    train_set = create_signal_target_from_raw_mne(train_cnt, marker_def, ival)
    test_set = create_signal_target_from_raw_mne(test_cnt, marker_def, ival)

    train_set, valid_set = split_into_two_sets(
        train_set, first_set_fraction=1 - valid_set_fraction)

    set_random_seeds(seed=20190706, cuda=cuda)

    n_classes = 2
    n_chans = int(train_set.X.shape[1])
    input_time_length = train_set.X.shape[2]
    if model == 'shallow':
        model = ShallowFBCSPNet(n_chans,
                                n_classes,
                                input_time_length=input_time_length,
                                final_conv_length='auto').create_network()
    elif model == 'deep':
        model = Deep4Net(n_chans,
                         n_classes,
                         input_time_length=input_time_length,
                         final_conv_length='auto').create_network()
    if cuda:
        model.cuda()
    log.info("Model: \n{:s}".format(str(model)))

    optimizer = optim.Adam(model.parameters())

    iterator = BalancedBatchSizeIterator(batch_size=batch_size)

    # Stop after max_epochs, or once valid_misclass has not decreased for
    # max_increase_epochs epochs.
    stop_criterion = Or([
        MaxEpochs(max_epochs),
        NoDecrease('valid_misclass', max_increase_epochs)
    ])

    monitors = [LossMonitor(), MisclassMonitor(), RuntimeMonitor()]

    model_constraint = MaxNormDefaultConstraint()

    exp = Experiment(model,
                     train_set,
                     valid_set,
                     test_set,
                     iterator=iterator,
                     loss_function=F.nll_loss,
                     optimizer=optimizer,
                     model_constraint=model_constraint,
                     monitors=monitors,
                     stop_criterion=stop_criterion,
                     remember_best_column='valid_misclass',
                     run_after_early_stop=True,
                     cuda=cuda)
    exp.run()
    return exp
Exemplo n.º 17
0
def run_exp(max_recording_mins, n_recordings, sec_to_cut,
            duration_recording_mins, max_abs_val, max_min_threshold,
            max_min_expected, shrink_val, max_min_remove, batch_set_zero_val,
            batch_set_zero_test, sampling_freq, low_cut_hz, high_cut_hz,
            exp_demean, exp_standardize, moving_demean, moving_standardize,
            channel_demean, channel_standardize, divisor, n_folds, i_test_fold,
            model_name, input_time_length, final_conv_length, batch_size,
            max_epochs, only_return_exp):
    """Build, and unless ``only_return_exp`` is set run, a cropped experiment.

    A list of ``(data, fs) -> (data, fs)`` preprocessing steps is assembled
    from the arguments and handed to ``load_data`` for every recording. The
    recordings are split into train/valid/test folds, a shallow or deep
    network is converted to a dense-prediction model, and an ``Experiment``
    using ``CropsFromTrialsIterator`` is created. CUDA is always enabled.

    Parameters
    ----------
    max_recording_mins
        Recordings whose stored length is not below this many minutes are
        excluded.
    n_recordings, n_folds, i_test_fold
        Number of recordings to split, number of folds, and index of the
        test fold; the fold before it (``i_test_fold - 1``) is validation.
    sec_to_cut, duration_recording_mins
        Seconds removed from both ends of each recording, and the maximum
        duration kept afterwards.
    max_abs_val, max_min_threshold, max_min_expected, max_min_remove, shrink_val
        Optional artifact handling; each step is skipped when its
        controlling value is ``None``.
    sampling_freq, low_cut_hz, high_cut_hz
        Target sampling rate and band edges for ``bandpass_cnt``.
    exp_demean, exp_standardize, moving_demean, moving_standardize,
    channel_demean, channel_standardize
        Flags enabling the respective normalization step.
    divisor
        If not ``None``, the data is divided by this value as a last step.
    batch_set_zero_val, batch_set_zero_test
        If ``batch_set_zero_val`` is not ``None`` a ``RemoveMinMaxDiff``
        batch modifier is created; if ``batch_set_zero_test`` additionally
        equals ``True`` it is wrapped into the iterator instead of being
        passed to the experiment.
    model_name
        ``'shallow'`` or ``'deep'``.
    input_time_length, final_conv_length, batch_size, max_epochs
        Network and training settings.
    only_return_exp
        If true, no data is loaded into sets and the experiment is returned
        without running it.

    Returns
    -------
    The ``Experiment`` instance.

    Raises
    ------
    ValueError
        If ``model_name`` is neither ``'shallow'`` nor ``'deep'``. Raised
        before anything is loaded; previously ``model`` was left unbound
        and the function crashed only after all data had been read.
    """
    if model_name not in ('shallow', 'deep'):
        raise ValueError(
            "Unknown model_name {!r}; expected 'shallow' or 'deep'".format(
                model_name))

    cuda = True

    def cut_both_ends(data, fs):
        """Remove ``sec_to_cut`` seconds from the start and the end."""
        n_cut = int(sec_to_cut * fs)
        # A plain ``data[:, n_cut:-n_cut]`` selects nothing when n_cut is 0,
        # so fall back to an open-ended slice in that case.
        return data[:, n_cut:(-n_cut or None)], fs

    preproc_functions = []
    preproc_functions.append(cut_both_ends)
    preproc_functions.append(lambda data, fs: (data[:, :int(
        duration_recording_mins * 60 * fs)], fs))
    if max_abs_val is not None:
        preproc_functions.append(
            lambda data, fs: (np.clip(data, -max_abs_val, max_abs_val), fs))
    if max_min_threshold is not None:
        preproc_functions.append(lambda data, fs: (clean_jumps(
            data, 200, max_min_threshold, max_min_expected, cuda), fs))
    if max_min_remove is not None:
        window_len = 200
        preproc_functions.append(lambda data, fs: (set_jumps_to_zero(
            data,
            window_len=window_len,
            threshold=max_min_remove,
            cuda=cuda,
            clip_min_max_to_zero=True), fs))

    if shrink_val is not None:
        preproc_functions.append(lambda data, fs: (shrink_spikes(
            data,
            shrink_val,
            1,
            9,
        ), fs))

    # From here on the sampling rate reported downstream is sampling_freq.
    preproc_functions.append(lambda data, fs: (resampy.resample(
        data, fs, sampling_freq, axis=1, filter='kaiser_fast'), sampling_freq))
    preproc_functions.append(lambda data, fs: (bandpass_cnt(
        data, low_cut_hz, high_cut_hz, fs, filt_order=4, axis=1), fs))

    # The exponential running helpers receive the transposed array, hence
    # the .T on the way in and out.
    if exp_demean:
        preproc_functions.append(lambda data, fs: (exponential_running_demean(
            data.T, factor_new=0.001, init_block_size=100).T, fs))
    if exp_standardize:
        preproc_functions.append(
            lambda data, fs: (exponential_running_standardize(
                data.T, factor_new=0.001, init_block_size=100).T, fs))
    if moving_demean:
        preproc_functions.append(lambda data, fs: (padded_moving_demean(
            data, axis=1, n_window=201), fs))
    if moving_standardize:
        preproc_functions.append(lambda data, fs: (padded_moving_standardize(
            data, axis=1, n_window=201), fs))
    if channel_demean:
        preproc_functions.append(lambda data, fs: (demean(data, axis=1), fs))
    if channel_standardize:
        preproc_functions.append(lambda data, fs:
                                 (standardize(data, axis=1), fs))
    if divisor is not None:
        preproc_functions.append(lambda data, fs: (data / divisor, fs))

    # Keep only recordings shorter than max_recording_mins.
    all_file_names, labels = get_all_sorted_file_names_and_labels()
    lengths = np.load(
        '/home/schirrmr/code/auto-diagnosis/sorted-recording-lengths.npy')
    mask = lengths < max_recording_mins * 60
    cleaned_file_names = np.array(all_file_names)[mask]
    cleaned_labels = labels[mask]

    diffs_per_rec = np.load(
        '/home/schirrmr/code/auto-diagnosis/diffs_per_recording.npy')

    def create_set(inds):
        """Load and preprocess the recordings at ``inds`` into one set."""
        X = []
        for i in inds:
            log.info("Load {:s}".format(cleaned_file_names[i]))
            x = load_data(cleaned_file_names[i], preproc_functions)
            X.append(x)
        y = cleaned_labels[inds].astype(np.int64)
        return SignalAndTarget(X, y)

    if not only_return_exp:
        folds = get_balanced_batches(n_recordings,
                                     None,
                                     False,
                                     n_batches=n_folds)
        test_inds = folds[i_test_fold]
        valid_inds = folds[i_test_fold - 1]
        all_inds = list(range(n_recordings))
        train_inds = np.setdiff1d(all_inds, np.union1d(test_inds, valid_inds))

        # Map fold positions onto recording numbers ordered by descending
        # value in diffs_per_rec.
        rec_nr_sorted_by_diff = np.argsort(diffs_per_rec)[::-1]
        train_inds = rec_nr_sorted_by_diff[train_inds]
        valid_inds = rec_nr_sorted_by_diff[valid_inds]
        test_inds = rec_nr_sorted_by_diff[test_inds]

        train_set = create_set(train_inds)
        valid_set = create_set(valid_inds)
        test_set = create_set(test_inds)
    else:
        train_set = None
        valid_set = None
        test_set = None

    set_random_seeds(seed=20170629, cuda=cuda)
    n_classes = 2
    in_chans = 21
    if model_name == 'shallow':
        model = ShallowFBCSPNet(
            in_chans=in_chans,
            n_classes=n_classes,
            input_time_length=input_time_length,
            final_conv_length=final_conv_length).create_network()
    elif model_name == 'deep':
        model = Deep4Net(in_chans,
                         n_classes,
                         input_time_length=input_time_length,
                         final_conv_length=final_conv_length).create_network()

    optimizer = optim.Adam(model.parameters())
    to_dense_prediction_model(model)
    log.info("Model:\n{:s}".format(str(model)))
    if cuda:
        model.cuda()
    # Determine the output size by pushing a dummy batch through the model.
    test_input = np_to_var(
        np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
    if cuda:
        test_input = test_input.cuda()
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]
    log.info("{:d} predictions per input/trial".format(n_preds_per_input))
    iterator = CropsFromTrialsIterator(batch_size=batch_size,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input)

    def loss_function(preds, targets):
        """NLL loss on the predictions averaged over dimension 2."""
        return F.nll_loss(th.mean(preds, dim=2)[:, :, 0], targets)

    model_constraint = None
    monitors = [
        LossMonitor(),
        MisclassMonitor(col_suffix='sample_misclass'),
        CroppedTrialMisclassMonitor(input_time_length),
        RuntimeMonitor(),
    ]
    stop_criterion = MaxEpochs(max_epochs)
    batch_modifier = None
    if batch_set_zero_val is not None:
        batch_modifier = RemoveMinMaxDiff(batch_set_zero_val,
                                          clip_max_abs=True,
                                          set_zero=True)
    # The ``== True`` comparison is kept on purpose so that only values
    # equal to True (not merely truthy ones) select this branch.
    if (batch_set_zero_val is not None) and (batch_set_zero_test == True):
        # Apply the modifier inside the iterator instead of handing it to
        # the experiment.
        iterator = ModifiedIterator(
            iterator,
            batch_modifier,
        )
        batch_modifier = None
    exp = Experiment(model,
                     train_set,
                     valid_set,
                     test_set,
                     iterator,
                     loss_function,
                     optimizer,
                     model_constraint,
                     monitors,
                     stop_criterion,
                     remember_best_column='valid_misclass',
                     run_after_early_stop=True,
                     batch_modifier=batch_modifier,
                     cuda=cuda)
    if not only_return_exp:
        exp.run()
    else:
        exp.dataset = None
        exp.splitter = None

    return exp