예제 #1
0
def add_padding(model, input_channels):
    """
    Re-assemble ``model`` with zero-padding layers (not a part of the thesis).

    Every child module that exposes both ``dilation`` and ``kernel_size`` and
    whose name does not contain ``'spat'`` is preceded by an ``nn.ZeroPad2d``
    layer built from the tuple ``(0, 0, 0, (kernel_size[0] - 1) * dilation[0])``
    and gets its ``stride`` overwritten with ``(2, 1)``. All other children are
    copied over unchanged. A final ``nn.Linear(n_preds_per_input, 1)`` layer
    named ``'last'`` is appended.

    NOTE: the children are re-used, not copied, so the stride change is also
    visible on the ``model`` that was passed in.

    :param model: module whose named children are re-assembled
    :param input_channels: number of input channels, forwarded to
        ``get_output_shape``
    :return: the new ``nn.Sequential`` (already moved to CUDA by the
        ``summary`` call below)
    """
    new_model = nn.Sequential()
    for name, module in model.named_children():
        needs_padding = (hasattr(module, "dilation")
                         and hasattr(module, 'kernel_size')
                         and 'spat' not in name)
        if needs_padding:
            dilation = module.dilation
            kernel_size = module.kernel_size
            right_padding = 0, 0, 0, (kernel_size[0] - 1) * dilation[0]
            new_model.add_module(name=f'{name}_pad',
                                 module=nn.ZeroPad2d(padding=right_padding))
            module.stride = (2, 1)
        # The module itself is always added, padded or not.
        new_model.add_module(name, module)

    # Index 1 of the output shape for a 1000-sample input sizes the last layer.
    n_preds_per_input = get_output_shape(new_model, input_channels, 1000)[1]
    new_model.add_module(name='last', module=nn.Linear(n_preds_per_input, 1))

    # NOTE(review): the summary shapes hard-code 85 channels and 1000 / 1200
    # samples instead of using ``input_channels`` -- confirm this is intended.
    summary(new_model.cuda(device='cuda'), (85, 1000))
    print(new_model)
    summary(model, (85, 1200, 1))

    return new_model
예제 #2
0
def test_get_output_shape_1d_model():
    """A kernel-3 ``Conv1d`` on a 1-channel, 5-sample input must report (1, 1, 3)."""
    conv = nn.Conv1d(1, 1, 3)
    expected_shape = (1, 1, 3)
    observed_shape = get_output_shape(conv,
                                      in_chans=1,
                                      input_window_samples=5)
    assert observed_shape == expected_shape
    optimizer_weight_decay = 0.0005
else:
    raise ValueError(f'{model_name} unknown')

# Rebuild the network, keeping every child whose name does not contain
# "softmax".
new_model = torch.nn.Sequential()
for name, module_ in model.named_children():
    if "softmax" not in name:
        new_model.add_module(name, module_)
model = new_model

# Move the rebuilt network to the GPU only when requested.
if cuda:
    model.cuda()

to_dense_prediction_model(model)
# Index 2 of the reported output shape is taken as the number of predictions
# produced per input window.
n_preds_per_input = get_output_shape(
    model, n_chans, input_window_samples
)[2]

train_set, valid_set = create_compatible_dataset(
    './data/BCICIV_4_mat/sub1_comp.mat'
)

# dataset = create_fixed_length_windows(
#     dataset,
#     start_offset_samples=0,
#     stop_offset_samples=0,
#     window_size_samples=input_window_samples,
#     window_stride_samples=n_preds_per_input,
#     drop_last_window=False,
#     drop_bad_windows=True,
# )

# splits = dataset.split("session")
# train_set = splits["train"]
예제 #4
0
def test_cropped_decoding():
    """
    Regression test for cropped decoding.

    Loads the EEGBCI recordings of subject 1 through MNE, epochs the
    ``hands``/``feet`` events, trains a dense-prediction ``ShallowFBCSPNet``
    with ``EEGClassifier`` for 4 epochs (first 60 trials for training, the
    rest for validation) and compares the recorded loss and accuracy history
    against stored reference values.

    The reference values depend on the fixed random seed and on the order of
    the statements below, so the code must not be reordered.
    """
    # 5,6,9,10,13,14 are codes for executed and imagined hands/feet
    subject_id = 1
    event_codes = [5, 6, 9, 10, 13, 14]

    # This will download the files if you don't have them yet,
    # and then return the paths to the files.
    physionet_paths = mne.datasets.eegbci.load_data(
        subject_id, event_codes, update_path=False
    )

    # Load each of the files
    parts = [
        mne.io.read_raw_edf(
            path, preload=True, stim_channel="auto", verbose="WARNING"
        )
        for path in physionet_paths
    ]

    # Concatenate them
    raw = concatenate_raws(parts)

    # Find the events in this dataset
    events, _ = mne.events_from_annotations(raw)
    # Use only EEG channels
    eeg_channel_inds = mne.pick_types(
        raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads"
    )

    # Extract trials, only using EEG channels
    epoched = mne.Epochs(
        raw,
        events,
        dict(hands=2, feet=3),
        tmin=1,
        tmax=4.1,
        proj=False,
        picks=eeg_channel_inds,
        baseline=None,
        preload=True,
    )
    # Scale the data by 1e6 (volt to microvolt, assuming MNE returns volts)
    # Pytorch expects float32 for input and int64 for labels.
    X = (epoched.get_data() * 1e6).astype(np.float32)
    y = (epoched.events[:, 2] - 2).astype(np.int64)  # 2,3 -> 0,1

    # Set if you want to use GPU
    # You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
    cuda = False
    set_random_seeds(seed=20170629, cuda=cuda)

    # This will determine how many crops are processed in parallel
    input_time_length = 450
    n_classes = 2
    in_chans = X.shape[1]
    # final_conv_length determines the size of the receptive field of the ConvNet
    model = ShallowFBCSPNet(
        in_chans=in_chans,
        n_classes=n_classes,
        input_time_length=input_time_length,
        final_conv_length=12,
    )
    to_dense_prediction_model(model)

    if cuda:
        model.cuda()

    # Perform forward pass to determine how many outputs per input
    n_preds_per_input = get_output_shape(model, in_chans, input_time_length)[2]

    # First 60 trials are used for training, the remaining ones for validation
    train_set = CroppedXyDataset(X[:60], y[:60],
                                 input_time_length=input_time_length,
                                 n_preds_per_input=n_preds_per_input)
    valid_set = CroppedXyDataset(X[60:], y=y[60:],
                                 input_time_length=input_time_length,
                                 n_preds_per_input=n_preds_per_input)
    train_split = predefined_split(valid_set)

    clf = EEGClassifier(
        model,
        cropped=True,
        criterion=CroppedLoss,
        criterion__loss_function=torch.nn.functional.nll_loss,
        optimizer=optim.Adam,
        train_split=train_split,
        batch_size=32,
        callbacks=['accuracy'],
    )

    clf.fit(train_set, y=None, epochs=4)

    # Per-epoch history must match the stored reference values
    np.testing.assert_allclose(
        clf.history[:, 'train_loss'],
        np.array(
            [
                1.455306,
                1.455934,
                1.210563,
                1.065806
            ]
        ),
        rtol=1e-4,
        atol=1e-5,
    )

    np.testing.assert_allclose(
        clf.history[:, 'valid_loss'],
        np.array(
            [
                2.547288,
                1.51785,
                1.394036,
                1.064355
            ]
        ),
        rtol=1e-4,
        atol=1e-4,
    )
    np.testing.assert_allclose(
        clf.history[:, 'train_accuracy'],
        np.array(
            [
                0.5,
                0.5,
                0.5,
                0.533333
            ]
        ),
        rtol=1e-4,
        atol=1e-5,
    )
    np.testing.assert_allclose(
        clf.history[:, 'valid_accuracy'],
        np.array(
            [
                0.533333,
                0.466667,
                0.533333,
                0.5
            ]
        ),
        rtol=1e-4,
        atol=1e-5,
    )
예제 #5
0
def train(data,
          dilation,
          kernel_size,
          lr,
          patient_index,
          model_string,
          correlation_monitor,
          output_dir,
          max_train_epochs=300,
          split=None,
          cropped=True,
          padding=False):
    """
    Creates and fits a model with the specified parameters onto the specified data.

    Relies on the module-level names ``home``, ``cuda`` and
    ``input_time_length``. The untrained and the trained network are both
    saved with ``torch.save`` below
    ``{home}/models/saved_models/{output_dir}/``.

    :param data: dataset on which the model is to be trained
    :param dilation: dilation parameters of the model max-pool layers
    :param kernel_size: kernel sizes of the model's max-pool layers
    :param lr: learning rate
    :param patient_index: index of the patient on whose data the model is trained
    :param model_string: string specifying the setting of the data
    :param correlation_monitor: correlation monitor object calculating the correlations while fitting
    :param output_dir: where the trained model should be saved
    :param max_train_epochs: number of epochs for which to train the model
    :param split: the fold from cross-validation for which we are currently training the model
    :param cropped: if the decoding is cropped, always True in thesis experiments
    :param padding: if padding should be added, always False in thesis experiments
    :return: the value ``get_corr_coef`` yields for
        ``correlation_monitor.validation_set`` and the fitted network
    """
    model, changed_model, model_name = get_model(data.in_channels,
                                                 input_time_length,
                                                 dilations=dilation,
                                                 kernel_sizes=kernel_size,
                                                 padding=padding)
    if cuda:
        device = 'cuda'
        model.model = changed_model.cuda()
    else:
        device = 'cpu'
        model.model = changed_model

    if padding:
        # A padded network is treated as producing one prediction per input.
        n_preds_per_input = 1
    else:
        n_preds_per_input = get_output_shape(model.model, model.input_channels,
                                             model.input_time_length)[1]
    Path(home + f'/models/saved_models/{output_dir}/').mkdir(parents=True,
                                                             exist_ok=True)
    # cutting the input into batches compatible with model
    # if data.num_of_folds != -1, then also pre-whitening or filtering takes place
    # as part of the cut_input method
    data.cut_input(input_time_length=input_time_length,
                   n_preds_per_input=n_preds_per_input,
                   shuffle=False)

    print(
        f'starting cv epoch {split} out of {data.num_of_folds} for model: {model_string}_{model_name}'
    )
    correlation_monitor.step_number = 0
    if split is not None:
        correlation_monitor.split = split

    monitor = 'validation_correlation_best'

    monitors = [
        ('correlation monitor', correlation_monitor),
        ('checkpoint',
         Checkpoint(
             monitor=monitor,
             f_history=home +
             f'/logs/model_{model_name}/histories/{model_string}_k_{model_name}_p_{patient_index}.json',
         )),
    ]
    print('cropped:', cropped)

    # object EEGRegressor from the braindecode library suited for fitting models for regression tasks
    regressor = EEGRegressor(cropped=cropped,
                             module=model.model,
                             criterion=model.loss_function,
                             optimizer=model.optimizer,
                             max_epochs=max_train_epochs,
                             verbose=1,
                             train_split=data.cv_split,
                             callbacks=monitors,
                             lr=lr,
                             device=device,
                             batch_size=32).initialize()

    # Keep a copy of the untrained network next to the trained one.
    torch.save(
        model.model, home +
        f'/models/saved_models/{output_dir}/initial_{model_string}_{model_name}_p_{patient_index}'
    )
    regressor.max_correlation = -1000

    # Fit exactly once: the padded setting indexes ``train_set`` directly,
    # the cropped setting stacks its ``X`` / ``y`` attributes. Previously the
    # second fit also ran after the padded one.
    if padding:
        regressor.fit(data.train_set[0], data.train_set[1])
    else:
        regressor.fit(np.stack(data.train_set.X), np.stack(data.train_set.y))

    torch.save(model.model,
               home + f'/models/saved_models/{output_dir}/last_model_{split}')
    if cuda:
        best_corr = get_corr_coef(correlation_monitor.validation_set,
                                  model.model.cuda(device=device))
    else:
        best_corr = get_corr_coef(correlation_monitor.validation_set,
                                  model.model)
    print(patient_index, best_corr)
    return best_corr
예제 #6
0
def train_nets(model_string,
               patient_indices,
               dilation,
               kernel_size,
               lr,
               num_of_folds,
               trajectory_index,
               low_pass,
               shift,
               variable,
               result_df,
               max_train_epochs,
               high_pass=False,
               high_pass_valid=False,
               padding=False,
               cropped=True,
               low_pass_train=False,
               shift_by=None,
               saved_model_dir='lr_0.001',
               whiten=False,
               indices=None,
               dummy_dataset=False):
    """
    Performs num_of_folds cross-validation on each of the patients.

    Relies on the module-level names ``home``, ``cuda`` and
    ``input_time_length``.

    :param model_string: specifies the setting in which the model was trained
    :param patient_indices: specifies the indices for patients for which a model should be trained
    :param dilation: dilation parameter of the max-pool layers in the network
    :param kernel_size: the kernel sizes of the max-pool layers in the network
    :param lr: learning rate
    :param num_of_folds: number of cross-validation folds. If -1, then only one 80-20 split is performed.
    :param trajectory_index: 0 for velocity, 1 for absolute velocity
    :param low_pass: specifies if validation data should be low-passed
    :param shift: specifies if predicted time-point should be shifted
    :param variable: 'vel' for velocity, 'absVel' for absolute velocity (not read inside this function)
    :param result_df: pandas.DataFrame where the results for the different patients are to be saved
    :param max_train_epochs: number of epochs for which to train the network
    :param high_pass: specifies if the train set and validation set should be high-passed
    :param high_pass_valid: specifies if the validation set should be high-passed
    :param padding: specifies if padding should be added to the network. Always False in this thesis.
    :param cropped: specifies if the input should be cropped. Always True in this thesis.
    :param low_pass_train: specifies if the training set should be low-passed
    :param shift_by: specifies by how much to shift the predicted time-point with across the receptive field
    :param saved_model_dir: specifies where the models should be saved
    :param whiten: specifies if the dataset should be whitened
    :param indices: specifies the indices for the different folds
    :param dummy_dataset: if True, ``patient_indices`` is replaced by ``[1]``
        and ``{home}/data/dummy_dataset.mat`` is used as the data file

    :return: None, only saves the learning statistics
    """

    best_valid_correlations = []
    curr_patient_indices = None
    if dummy_dataset:
        patient_indices = [1]
    for patient_index in patient_indices:
        if indices is not None:
            curr_patient_indices = indices[f'P_{patient_index}']
        if dummy_dataset:
            data_file = f'{home}/data/dummy_dataset.mat'
        else:
            data_file = f'{home}/previous_work/P{patient_index}_data.mat'
        print('data_file', data_file)
        input_channels = get_num_of_channels(data_file,
                                             dummy_dataset=dummy_dataset)
        model, changed_model, model_name = get_model(input_channels,
                                                     input_time_length,
                                                     dilations=dilation,
                                                     kernel_sizes=kernel_size,
                                                     padding=padding)
        small_window = 522
        if padding:
            data = OnePredictionData(
                home + f'/previous_work/P{patient_index}_data.mat',
                num_of_folds=num_of_folds,
                low_pass=low_pass,
                input_time_length=input_time_length,
                trajectory_index=trajectory_index,
                high_pass=high_pass,
                valid_high_pass=high_pass_valid)
        else:
            n_preds_per_input = get_output_shape(changed_model, input_channels,
                                                 input_time_length)[1]
            # Input samples that feed into a single prediction.
            small_window = input_time_length - n_preds_per_input + 1
            if shift_by is None:
                shift_index = int(small_window / 2)
            else:
                shift_index = int((small_window / 2) - shift_by)
                print('shift_index:', shift_index)
            print('dummy dataset', dummy_dataset)
            data = Data(data_file,
                        num_of_folds=num_of_folds,
                        low_pass=low_pass,
                        trajectory_index=trajectory_index,
                        shift_data=shift,
                        high_pass=high_pass,
                        shift_by=int(shift_index),
                        valid_high_pass=high_pass_valid,
                        low_pass_training=low_pass_train,
                        pre_whiten=whiten,
                        indices=curr_patient_indices,
                        dummy_dataset=dummy_dataset)
        output_dir = f'{saved_model_dir}/{model_string}_{model_name}/{model_string}_{model_name}_p_{patient_index}'
        correlation_monitor = CorrelationMonitor1D(
            input_time_length=input_time_length, output_dir=output_dir)
        if cuda:
            model.model = changed_model.cuda()
        else:
            model.model = changed_model

        if data.num_of_folds == -1:
            # only one 80-20 train-validation split
            best_corr = train(data,
                              dilation,
                              kernel_size,
                              lr,
                              patient_index,
                              model_string,
                              correlation_monitor,
                              max_train_epochs=max_train_epochs,
                              output_dir=output_dir,
                              split=None,
                              cropped=cropped,
                              padding=padding)
            print('shift by:', shift_by)
            best_valid_correlations.append(best_corr)
            # NOTE(review): results are only written once exactly 12
            # correlations have been collected -- confirm this matches the
            # number of patients.
            if len(best_valid_correlations) == 12:
                Path(
                    f'{home}/outputs/{saved_model_dir}/{model_string}_{model_name}/{model_string}_{model_name}'
                ).mkdir(parents=True, exist_ok=True)
                result_df[
                    f'{model_string}_{model_name}'] = best_valid_correlations
                best_valid_correlations = []
                result_df.to_csv(
                    f'{home}/outputs/{saved_model_dir}/{model_string}_{model_name}/{model_string}_{model_name}/results.csv',
                    sep=';')

        else:
            fold_corrs = []
            for i in range(data.num_of_folds):
                # data.num_of_folds cross-validation
                best_corr = train(data,
                                  dilation,
                                  kernel_size,
                                  lr,
                                  patient_index,
                                  model_string,
                                  correlation_monitor,
                                  output_dir,
                                  split=i,
                                  max_train_epochs=max_train_epochs,
                                  cropped=cropped,
                                  padding=padding)
                fold_corrs.append(best_corr)
            best_valid_correlations.append(fold_corrs)
            print('whole_patient:', patient_index, fold_corrs)
            patient_df = pandas.DataFrame()
            patient_df[f'P_{patient_index}'] = fold_corrs
            result_df = pandas.concat([result_df, patient_df], axis=1)
            # Create the target directory first; to_csv does not create
            # missing parent directories.
            performance_dir = Path(
                f'{home}/outputs/performances_{data.num_of_folds}/{model_string}_{model_name}'
            )
            performance_dir.mkdir(parents=True, exist_ok=True)
            result_df.to_csv(performance_dir / 'performances.csv', sep=';')
예제 #7
0
    plt.figure(figsize=(32, 12))
    t = np.arange(preds_per_trial.shape[0]) / srate
    plt.plot(t, preds_per_trial)
    plt.plot(t, targets_per_trial)
    plt.legend(('Predicted', 'Actual'), fontsize=14)
    plt.title('Fold = {:d}, CC = {:f}'.format(0, cc_folds[0]))
    plt.xlabel('time [s]')


if __name__ == '__main__':
    model_file = '/models/saved_models/best_model_1'
    model = load_model(model_file)

    data_file = 'ALL_11_FR1_day1_absVel'
    data = Data(home + f'/previous_work/{data_file}.mat', -1)
    n_preds_per_input = get_output_shape(model, data.in_channels, 1200)[1]
    data.cut_input(input_time_length, n_preds_per_input, False)
    train_set, test_set = data.train_set, data.test_set

    select_modules = [
        'conv_spat', 'conv_2', 'conv_3', 'conv_4', 'conv_classifier'
    ]  # Specify intermediate outputs
    modules = list(model.named_children())  # Extract modules from model
    model_pert = SelectiveSequential(select_modules, modules)  # Wrap modules
    model_pert.eval()
    model_pert.double()
    model.eval()

    pred_fn = lambda x: [
        layer_out.data.numpy() for layer_out in model_pert.forward(
            torch.autograd.Variable(torch.from_numpy(x)).double())
예제 #8
0
def exp(subject_id):
    """
    Run a cropped-decoding experiment on BNCI2014001 for ``subject_id``.

    The dataset is loaded through ``MOABBDataset``, preprocessed, cut into
    windows and split by session (``session_T`` for training, ``session_E``
    for evaluation). A dense-prediction ``ShallowFBCSPNet`` is trained with
    AdamW and a cosine-annealing schedule; after every epoch the model is
    evaluated on the evaluation session.

    NOTE: ``argparse`` parses the process command line inside this function;
    only ``args.gpuidx`` is read afterwards.

    :param subject_id: forwarded to ``MOABBDataset`` as ``subject_ids``
    :return: pandas.DataFrame with one row per epoch and the columns
        ``test_loss`` and ``test_accuracy``
    """
    dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=subject_id)

    from braindecode.datautil.preprocess import exponential_moving_standardize
    from braindecode.datautil.preprocess import MNEPreproc, NumpyPreproc, preprocess

    low_cut_hz = 0.  # low cut frequency for filtering
    high_cut_hz = 49.  # high cut frequency for filtering
    # Parameters for exponential moving standardization
    factor_new = 1e-3
    init_block_size = 1000

    preprocessors = [
        # keep only EEG sensors
        MNEPreproc(fn='pick_types', eeg=True, meg=False, stim=False),
        # convert from volt to microvolt, directly modifying the numpy array
        NumpyPreproc(fn=lambda x: x * 1e6),
        # bandpass filter
        MNEPreproc(fn='filter', l_freq=low_cut_hz, h_freq=high_cut_hz),
        # exponential moving standardization (currently disabled)
        # NumpyPreproc(fn=exponential_moving_standardize, factor_new=factor_new,
        #     init_block_size=init_block_size)
    ]

    # Transform the data
    preprocess(dataset, preprocessors)

    # ------------------------------------------------------------------
    # Create model and compute windowing parameters
    # ------------------------------------------------------------------
    # In contrast to trialwise decoding, the model has to be created before
    # the dataset is cut into windows, because the receptive field of the
    # network determines how large the window stride should be.
    #
    # The compute/input window size has to be larger than the network's
    # receptive field. Here we choose 1000 samples, which are 4 seconds for
    # the 250 Hz sampling rate.
    input_window_samples = 1000

    import torch
    from braindecode.util import set_random_seeds
    from braindecode.models import ShallowFBCSPNet, Deep4Net

    cuda = torch.cuda.is_available(
    )  # check if GPU is available, if True chooses to use it
    # NOTE(review): the second GPU is hard-coded here -- confirm.
    device = 'cuda:1' if cuda else 'cpu'
    if cuda:
        torch.backends.cudnn.benchmark = True
    seed = 20190706  # random seed to make results reproducible
    # Set random seed to be able to reproduce results
    set_random_seeds(seed=seed, cuda=cuda)

    n_classes = 4
    # Extract number of chans from dataset
    n_chans = dataset[0][0].shape[0]

    # final_conv_length=30 keeps the receptive field of the ConvNet smaller
    # than ``input_window_samples`` so it can be used for cropped decoding.
    model = ShallowFBCSPNet(
        n_chans,
        n_classes,
        input_window_samples=input_window_samples,
        final_conv_length=30,
    )

    print(model)

    # Send model to GPU
    if cuda:
        model.cuda(device)

    # Transform the model with strides to a model that outputs dense
    # predictions, so it can be used to obtain predictions for all crops.
    from braindecode.models.util import to_dense_prediction_model, get_output_shape
    to_dense_prediction_model(model)

    n_preds_per_input = get_output_shape(model, n_chans,
                                         input_window_samples)[2]
    print("n_preds_per_input : ", n_preds_per_input)
    print(model)

    # ------------------------------------------------------------------
    # Cut the data into windows
    # ------------------------------------------------------------------
    # In contrast to trialwise decoding, an explicit window size and window
    # stride have to be supplied to ``create_windows_from_events``.
    import numpy as np
    from braindecode.datautil.windowers import create_windows_from_events

    trial_start_offset_seconds = -0.5
    # Extract sampling frequency, check that they are same in all datasets
    sfreq = dataset.datasets[0].raw.info['sfreq']
    assert all([ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets])

    # Calculate the trial start offset in samples.
    trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

    windows_dataset = create_windows_from_events(
        dataset,
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=0,
        window_size_samples=input_window_samples,
        window_stride_samples=n_preds_per_input,
        drop_last_window=False,
        preload=True,
    )

    # ------------------------------------------------------------------
    # Split the dataset by session
    # ------------------------------------------------------------------
    from braindecode.datasets.base import BaseConcatDataset
    splitted = windows_dataset.split('session')

    train_set = splitted['session_T']
    valid_set = splitted['session_E']

    lr = 0.0625 * 0.01
    weight_decay = 0
    batch_size = 8
    n_epochs = 100

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(valid_set,
                                              batch_size=batch_size,
                                              shuffle=False)

    from torch.optim import lr_scheduler
    import torch.optim as optim

    import argparse
    parser = argparse.ArgumentParser(
        description='cross subject domain adaptation')
    parser.add_argument('--batch-size',
                        type=int,
                        default=50,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=50,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()
    args.gpuidx = 0
    args.seed = 0
    args.use_tensorboard = False
    args.save_model = False

    optimizer = optim.AdamW(model.parameters(),
                            lr=lr,
                            weight_decay=weight_decay)
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs - 1)

    import pandas as pd
    results_columns = ['test_loss', 'test_accuracy']

    # Collect one result dict per epoch and build the frame once at the end;
    # DataFrame.append no longer exists in pandas >= 2.0.
    epoch_results = []
    for epochidx in range(1, n_epochs):
        print(epochidx)
        train_crop(10, model, device, train_loader, optimizer, scheduler, cuda,
                   args.gpuidx)
        test_loss, test_score = eval_crop(model, device, test_loader)
        results = {'test_loss': test_loss, 'test_accuracy': test_score}
        epoch_results.append(results)
        print(results)

    return pd.DataFrame(epoch_results, columns=results_columns)
예제 #9
0
def exp(subject_id):
    """
    Run a cropped-decoding experiment on a pickled dataset for ``subject_id``.

    The entry ``data[subject_id]`` of ``../datasets/bcic_datasets_[0,49].pkl``
    is split into a first half (training, shuffled) and a second half
    (testing). A dense-prediction ``ShallowFBCSPNet`` with 22 input channels
    is trained with AdamW and a cosine-annealing schedule; after every epoch
    the model is evaluated on the test half.

    NOTE: ``argparse`` parses the process command line inside this function;
    only ``args.gpuidx`` is read afterwards.

    :param subject_id: key/index into the unpickled dataset collection
    :return: pandas.DataFrame with one row per epoch and the columns
        ``test_loss`` and ``test_accuracy``
    """
    import torch
    input_window_samples = 1000

    cuda = torch.cuda.is_available()  # check if GPU is available, if True chooses to use it
    device = 'cuda:0' if cuda else 'cpu'
    if cuda:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    seed = 20190706  # random seed to make results reproducible
    # Set random seed to be able to reproduce results
    random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

    n_classes = 4

    PATH = '../datasets/'
    # SECURITY NOTE: pickle.load executes arbitrary code from the file; only
    # use this with dataset files from a trusted source.
    with open(PATH + 'bcic_datasets_[0,49].pkl', 'rb') as f:
        data = pickle.load(f)

    print('subject:' + str(subject_id))

    # make train/test split: first half for training, second half for testing
    test_train_split = 0.5

    dataset = data[subject_id]

    dataset_size = len(dataset)
    indices = list(range(dataset_size))

    test_split = int(np.floor(test_train_split * dataset_size))

    train_indices, test_indices = indices[:test_split], indices[test_split:]

    np.random.shuffle(train_indices)

    # Exploratory analysis of the first dataset entry
    sample_data = data[0].dataset
    sample_data.psd()
    from mne.viz import plot_epochs_image
    import mne
    plot_epochs_image(sample_data, picks=['C3', 'C4'])

    label = sample_data.read_label()

    sample_data.plot_projs_topomap()

    train_sampler = SubsetRandomSampler(train_indices)
    test_sampler = SubsetRandomSampler(test_indices)

    from braindecode.models import ShallowFBCSPNet
    model = ShallowFBCSPNet(
        22,
        n_classes,
        input_window_samples=input_window_samples,
        final_conv_length=30,
    )

    from braindecode.models.util import to_dense_prediction_model, get_output_shape
    to_dense_prediction_model(model)

    n_preds_per_input = get_output_shape(model, 22, input_window_samples)[2]
    print("n_preds_per_input : ", n_preds_per_input)
    print(model)

    epochs = 100
    batch_size = 8

    train_set = torch.utils.data.Subset(dataset, indices=train_indices)
    test_set = torch.utils.data.Subset(dataset, indices=test_indices)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
    # Send model to GPU
    if cuda:
        model.cuda(device=device)

    from torch.optim import lr_scheduler
    import torch.optim as optim

    import argparse
    parser = argparse.ArgumentParser(description='cross subject domain adaptation')
    parser.add_argument('--batch-size', type=int, default=50, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=50, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()
    args.gpuidx = 0
    args.seed = 0
    args.use_tensorboard = False
    args.save_model = False

    lr = 0.0625 * 0.01
    weight_decay = 0
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs - 1)

    import pandas as pd
    results_columns = ['test_loss', 'test_accuracy']

    # Collect one result dict per epoch and build the frame once at the end;
    # DataFrame.append no longer exists in pandas >= 2.0.
    epoch_results = []
    for epochidx in range(1, epochs):
        print(epochidx)
        train_crop(10, model, device, train_loader, optimizer, scheduler, cuda, args.gpuidx)
        test_loss, test_score = eval_crop(model, device, test_loader)
        results = {'test_loss': test_loss, 'test_accuracy': test_score}
        epoch_results.append(results)
        print(results)

    return pd.DataFrame(epoch_results, columns=results_columns)
# Move the network onto the GPU when CUDA is enabled
if cuda:
	model.cuda()

# Convert the strided model into one that outputs dense predictions, so it
# can be used to obtain predictions for all crops
to_dense_prediction_model(model)


# =============================================================================
# Windowing and dividing the data into validation and training sets
# =============================================================================
# The shape of the model output for a dummy input tells us the model's
# receptive field; the receptive field size defines the crop size.
args.n_preds_per_input = get_output_shape(
	model,
	args.nchan,
	args.input_window_samples,
)[2]

(valid_set, train_set) = windowing_data(dataset, args)
# The un-windowed dataset is no longer needed
del dataset


# =============================================================================
# Training the model
# =============================================================================
clf = EEGClassifier(
	model,
	cropped=True,
	criterion=CroppedLoss,
	criterion__loss_function=torch.nn.functional.nll_loss,
	optimizer=torch.optim.AdamW,
	train_split=predefined_split(valid_set),
예제 #11
0
def test_get_output_shape_2d_model():
    """get_output_shape reports the full 4-tuple for a model ending in Conv2d."""
    # Append a trailing singleton dimension so the Conv2d receives a 4D tensor.
    add_trailing_dim = Expression(lambda x: x.unsqueeze(-1))
    conv = nn.Conv2d(1, 1, (3, 1))
    model = nn.Sequential(add_trailing_dim, conv)

    expected_shape = (1, 1, 3, 1)
    actual_shape = get_output_shape(model,
                                    in_chans=1,
                                    input_window_samples=5)
    assert actual_shape == expected_shape
예제 #12
0
def exp(subject_id):
    """Train a dense-prediction ShallowFBCSPNet and evaluate it after every epoch.

    The function relies on names defined elsewhere in the module: ``splitted``
    (with ``'session_T'`` / ``'session_E'`` entries), ``n_chans``,
    ``n_classes``, ``input_window_samples``, ``cuda``, ``device``,
    ``ShallowFBCSPNet``, ``train_crop``, ``eval_crop`` and ``np``.

    :param subject_id: subject identifier; in this function it is only echoed
        to stdout.
    :return: ``pandas.DataFrame`` with one row per evaluated epoch and the
        columns ``test_loss`` and ``test_accuracy``.
    """
    import torch
    test_subj = np.r_[subject_id]

    print('test subj:' + str(test_subj))

    # Session T is used for training and session E for testing; no separate
    # validation split is carved out.
    train_set = splitted['session_T']
    test_set = splitted['session_E']

    model = ShallowFBCSPNet(
        n_chans,
        n_classes,
        input_window_samples=input_window_samples,
        final_conv_length='auto',
    )

    from braindecode.models.util import to_dense_prediction_model, get_output_shape
    to_dense_prediction_model(model)

    # Probe with the same channel count the model was built with (this was a
    # hard-coded 22 before, which only matches 22-channel recordings).
    n_preds_per_input = get_output_shape(model, n_chans, input_window_samples)[2]
    print("n_preds_per_input : ", n_preds_per_input)
    print(model)

    batch_size = 8
    epochs = 200

    # Hyperparameters actually handed to the optimizer below.
    lr = 0.01
    weight_decay = 0.5 * 0.001

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

    # Send model to GPU
    if cuda:
        model.cuda()

    from torch.optim import lr_scheduler
    import torch.optim as optim

    # NOTE(review): parsing sys.argv inside this function means unknown
    # command-line flags abort the run; only ``args.gpuidx`` is consumed below.
    import argparse
    parser = argparse.ArgumentParser(description='cross subject domain adaptation')
    parser.add_argument('--batch-size', type=int, default=50, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=50, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()
    args.gpuidx = 0
    args.seed = 0
    args.use_tensorboard = False
    args.save_model = False

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs - 1)

    import pandas as pd
    results_columns = ['test_loss', 'test_accuracy']

    # DataFrame.append was removed in pandas 2.0, so the per-epoch rows are
    # collected in a list and the frame is built once at the end.
    rows = []
    # The loop starts at 1, so it runs ``epochs - 1`` times, matching T_max.
    for epochidx in range(1, epochs):
        print(epochidx)
        # The leading 10 is passed through unchanged; its meaning is defined
        # by train_crop.
        train_crop(10, model, device, train_loader, optimizer, scheduler, cuda, args.gpuidx)
        test_loss, test_score = eval_crop(model, device, test_loader)
        results = {'test_loss': test_loss, 'test_accuracy': test_score}
        rows.append(results)
        print(results)

    return pd.DataFrame(rows, columns=results_columns)
예제 #13
0
    all_ys = []
    all_zs = []
    corr_coefs_full = []
    corr_coefs_hp = []
    for patient_index in range(1, 13):
        input_channels = get_num_of_channels(
            home + f'/previous_work/P{patient_index}_data.mat')
        changed_model_full = load_model(
            f'/models/saved_models/{saved_model_dir}/sbp0_shuffled_sm_{variable}_k3_d3/sbp0_shuffled_sm_{variable}_k3_d3_p_{patient_index}//best_model_split_0'
        )
        changed_model_hp = load_model(
            f'/models/saved_models/{saved_model_dir}/sbp0_shuffled_hp_sm2_{variable}_k3_d3/sbp0_shuffled_hp_sm2_{variable}_k3_d3_p_{patient_index}//best_model_split_0'
        )
        model_name = 'k3_d3'
        n_preds_per_input = get_output_shape(changed_model_full,
                                             input_channels,
                                             input_time_length)[1]
        small_window = input_time_length - n_preds_per_input + 1
        if shift_by is None:
            shift_index = int(small_window / 2)
        else:
            shift_index = int((small_window / 2) - shift_by)
        train_file = open(f'{home}/models/indices/train.dict', 'rb')
        valid_file = open(f'{home}/models/indices/valid.dict', 'rb')

        train_indices = pickle.load(train_file)
        valid_indices = pickle.load(valid_file)

        data_full = Data(home + f'/previous_work/P{patient_index}_data.mat',
                         num_of_folds=-1,
                         low_pass=low_pass,
elif model_name == "deep":
    model = Deep4Net(
        n_chans,
        n_classes,
        input_time_length=input_time_length,
        final_conv_length=2,
    )
    lr = 1 * 0.01
    weight_decay = 0.5 * 0.001

# Move the model to the GPU when the module-level `cuda` flag is truthy.
if cuda:
    model.cuda()

# Switch the model to dense prediction, then read element [2] of its output
# shape as the number of predictions per input window.
to_dense_prediction_model(model)
n_preds_per_input = get_output_shape(model, n_chans, input_time_length)[2]

# Load the BNCI2014001 recordings for the single requested subject.
dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])

# Bind the standardization parameters so the transform can be called with the
# data as its only remaining argument.
standardize_func = partial(exponential_running_standardize,
                           factor_new=factor_new,
                           init_block_size=init_block_size)
# Despite the name this is a list of (transform name, kwargs) pairs, listed in
# the order: channel selection, scaling by 1e6, band filtering,
# standardization.
raw_transform_dict = [
    ("pick_types", dict(eeg=True, meg=False, stim=False)),
    ('apply_function', dict(fun=lambda x: x * 1e6, channel_wise=False)),
    ('filter', dict(l_freq=low_cut_hz, h_freq=high_cut_hz)),
    ('apply_function', dict(fun=standardize_func, channel_wise=False))
]
transform_concat_ds(dataset, raw_transform_dict)

# Sampling frequency of every recording contained in the dataset.
sfreqs = [ds.raw.info['sfreq'] for ds in dataset.datasets]
예제 #15
0
def prepare_for_gradients(patient_index,
                          model_name,
                          trained_mode,
                          eval_mode,
                          saved_model_dir,
                          model_file=None,
                          shift=False,
                          high_pass=False,
                          trajectory_index=0,
                          multi_layer=False,
                          motor_channels=None,
                          low_pass=False,
                          shift_by=None,
                          whiten=False):
    """
    Collects everything needed for gradient visualization.
    A checkpoint is resolved from the model name (unless ``model_file`` is
    given) and loaded, the patient's data is prepared as model input, and the
    model handed back is switched to eval mode.

    Relies on the module-level names ``home``, ``output_dir`` and
    ``input_time_length``.

    :param patient_index: index of the patient whose data file and checkpoint are used
    :param model_name: name of the model to be loaded; may contain one '/' separating group and run name
    :param trained_mode: 'untrained' selects the initial checkpoint, anything else a trained one
    :param eval_mode: 'validation' takes the inputs from the test set, anything else from the train set
    :param saved_model_dir: accepted but not used inside this function
    :param model_file: explicit checkpoint path; when None the path is derived from the other arguments
    :param shift: forwarded to ``Data`` as ``shift_data``
    :param high_pass: forwarded to ``Data`` as ``high_pass``
    :param trajectory_index: 0 for velocity, 1 for absolute velocity
    :param multi_layer: if True the loaded model is returned as is, otherwise it is passed through ``create_new_model`` with 'conv_classifier'
    :param motor_channels: only printed by this function
    :param low_pass: forwarded to ``Data`` as ``low_pass``
    :param shift_by: offset subtracted from the default shift index; also inserts a ``shift_<n>`` component into the model name
    :param whiten: forwarded to ``Data`` as ``pre_whiten``
    :return: tuple ``(corrcoef, new_model, X_reshaped, small_window, output, data.motor_channels, data.non_motor_channels)``
    """
    shifted = shift_by is not None
    if shifted:
        # Insert the shift tag between the two '/'-separated name components.
        shift_tag = f'shift_{shift_by}'
        name_parts = model_name.split('/')
        model_name = '/'.join([name_parts[0], shift_tag, name_parts[1]])
    # Position of the run name inside the '/'-separated model name.
    run_name_pos = 2 if shifted else 1
    random_valid = not shifted

    if model_file is None:
        if '/' in model_name:
            run_name = model_name.split(
                '/')[run_name_pos] + f'_p_{patient_index}'
        else:
            run_name = f'{model_name}_p_{patient_index}'

        if trained_mode == 'untrained':
            model_file = f'/models/saved_models/{model_name}/{run_name}/initial_{run_name}'
        elif shifted:
            model_file = f'/models/saved_models/{model_name}/{run_name}/best_model_split_0'
        else:
            model_file = f'/models/saved_models/{model_name}/{run_name}/last_model'

    output = f'{output_dir}/hp_graphs/{model_name}/{eval_mode}/{trained_mode}/'
    model = load_model(model_file)
    print(model_file)
    print('motor channels:', motor_channels)

    data_path = home + f'/previous_work/P{patient_index}_data.mat'
    in_channels = get_num_of_channels(data_path)
    n_preds_per_input = get_output_shape(model, in_channels, 1200)[1]

    # Number of input samples beyond the predictions produced per input.
    window_excess = input_time_length - n_preds_per_input
    shift_window = window_excess + 1
    small_window = min(window_excess * 2, 1200)
    print('small window:', small_window, model_name)
    print('shift window:', shift_window, shift)

    half_shift_window = shift_window / 2
    if shifted:
        shift_index = int(half_shift_window - shift_by)
    else:
        shift_index = int(half_shift_window)

    data = Data(data_path,
                -1,
                low_pass=low_pass,
                trajectory_index=trajectory_index,
                shift_data=shift,
                high_pass=high_pass,
                shift_by=shift_index,
                pre_whiten=whiten,
                random_valid=random_valid)
    data.cut_input(input_time_length, n_preds_per_input, False)

    train_set, test_set = data.train_set, data.test_set
    # The correlation coefficient is always computed on the train set, even
    # when the returned inputs come from the test set.
    corrcoef = get_corr_coef(train_set, model)

    eval_set = test_set if eval_mode == 'validation' else train_set
    X_reshaped = np.asarray(eval_set.X)
    print(X_reshaped.shape)
    X_reshaped = reshape_Xs(input_time_length, X_reshaped)

    if multi_layer:
        new_model = model
    else:
        new_model = create_new_model(model,
                                     'conv_classifier',
                                     input_channels=None)
    new_model.eval()

    return (corrcoef, new_model, X_reshaped, small_window, output,
            data.motor_channels, data.non_motor_channels)
예제 #16
0
def exp(subject_id):
    """Run cropped decoding with ShallowFBCSPNet on BNCI2014001 and plot the curves.

    The recordings are preprocessed, cut into windows whose stride equals the
    number of predictions the dense model emits per window, split by session
    (``session_T`` for training, ``session_E`` for validation) and used to fit
    an ``EEGClassifier``. A loss / misclassification figure is built with
    matplotlib (it is neither shown nor saved here).

    Relies on ``MOABBDataset`` being available at module level.

    :param subject_id: forwarded unchanged to ``MOABBDataset`` as ``subject_ids``.
    :return: ``pandas.DataFrame`` indexed by epoch with the columns
        ``train_loss``, ``valid_loss``, ``train_accuracy``, ``valid_accuracy``,
        ``train_misclass`` and ``valid_misclass``.
    """
    dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=subject_id)

    from braindecode.datautil.preprocess import MNEPreproc, NumpyPreproc, preprocess

    low_cut_hz = 0.  # low cut frequency for filtering
    high_cut_hz = 38.  # high cut frequency for filtering

    # NOTE: no exponential moving standardization step is applied here.
    preprocessors = [
        # keep only EEG sensors
        MNEPreproc(fn='pick_types', eeg=True, meg=False, stim=False),
        # convert from volt to microvolt, directly modifying the numpy array
        NumpyPreproc(fn=lambda x: x * 1e6),
        # bandpass filter
        MNEPreproc(fn='filter', l_freq=low_cut_hz, h_freq=high_cut_hz),
    ]

    # Transform the data
    preprocess(dataset, preprocessors)

    ######################################################################
    # Create model and compute windowing parameters
    # ---------------------------------------------
    #
    # In contrast to trialwise decoding, the model has to exist before the
    # dataset is cut into windows: the receptive field of the network decides
    # how large the window stride should be.
    #
    # The compute/input window size must be larger than the network's
    # receptive field and can otherwise be chosen for computational
    # efficiency. 1000 samples are 4 seconds at a 250 Hz sampling rate.
    input_window_samples = 1000

    import torch
    from braindecode.util import set_random_seeds
    from braindecode.models import ShallowFBCSPNet

    cuda = torch.cuda.is_available()  # use the GPU whenever one is available
    if cuda:
        # Prefer the second GPU as before, but fall back to the first one on
        # single-GPU hosts where 'cuda:1' does not exist.
        device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
        torch.backends.cudnn.benchmark = True
    else:
        device = 'cpu'
    seed = 20190706  # random seed to make results reproducible
    set_random_seeds(seed=seed, cuda=cuda)

    n_classes = 4
    # Extract number of chans from dataset
    n_chans = dataset[0][0].shape[0]

    # final_conv_length=30 keeps the receptive field of the ConvNet smaller
    # than ``input_window_samples`` so it can be used for cropped decoding.
    model = ShallowFBCSPNet(
        n_chans,
        n_classes,
        input_window_samples=input_window_samples,
        final_conv_length=30,
    )

    print(model)

    # Send model to GPU
    if cuda:
        model.cuda(device)

    # Transform the model with strides into one that outputs dense
    # predictions, so predictions can be obtained for all crops.
    from braindecode.models.util import to_dense_prediction_model, get_output_shape
    to_dense_prediction_model(model)

    n_preds_per_input = get_output_shape(model, n_chans,
                                         input_window_samples)[2]
    print("n_preds_per_input : ", n_preds_per_input)
    print(model)

    ######################################################################
    # Cut the data into windows
    # -------------------------
    #
    # In contrast to trialwise decoding, an explicit window size and window
    # stride are supplied to ``create_windows_from_events``.
    from braindecode.datautil.windowers import create_windows_from_events

    trial_start_offset_seconds = -0.5
    # Extract sampling frequency, check that they are same in all datasets
    sfreq = dataset.datasets[0].raw.info['sfreq']
    assert all(ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets)

    # Calculate the trial start offset in samples.
    trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

    windows_dataset = create_windows_from_events(
        dataset,
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=0,
        window_size_samples=input_window_samples,
        window_stride_samples=n_preds_per_input,
        drop_last_window=False,
        preload=True,
    )

    ######################################################################
    # Split the dataset
    # -----------------
    splitted = windows_dataset.split('session')

    train_set = splitted['session_T']
    valid_set = splitted['session_E']

    ######################################################################
    # Training
    # --------
    #
    # For cropped decoding the EEGClassifier gets ``cropped=True`` and
    # ``CroppedLoss`` as criterion, with ``criterion__loss_function`` applied
    # to the averaged predictions.
    from skorch.callbacks import LRScheduler
    from skorch.helper import predefined_split

    from braindecode import EEGClassifier
    from braindecode.training.losses import CroppedLoss

    # Values found good for the shallow network. For Deep4Net the source
    # suggests lr = 1 * 0.01 and weight_decay = 0.5 * 0.001 instead.
    lr = 0.0625 * 0.01
    weight_decay = 0

    batch_size = 8
    n_epochs = 100

    clf = EEGClassifier(
        model,
        cropped=True,
        criterion=CroppedLoss,
        criterion__loss_function=torch.nn.functional.nll_loss,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(valid_set),
        optimizer__lr=lr,
        optimizer__weight_decay=weight_decay,
        iterator_train__shuffle=True,
        batch_size=batch_size,
        callbacks=[
            "accuracy",
            ("lr_scheduler",
             LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
        ],
        device=device,
    )
    # `y` is None as the targets are already supplied in the dataset.
    clf.fit(train_set, y=None, epochs=n_epochs)

    ######################################################################
    # Plot Results
    # ------------
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D
    import pandas as pd

    # Extract loss and accuracy values for plotting from history object
    results_columns = [
        'train_loss', 'valid_loss', 'train_accuracy', 'valid_accuracy'
    ]
    df = pd.DataFrame(clf.history[:, results_columns],
                      columns=results_columns,
                      index=clf.history[:, 'epoch'])

    # get percent of misclass for better visual comparison to loss
    df = df.assign(train_misclass=100 - 100 * df.train_accuracy,
                   valid_misclass=100 - 100 * df.valid_accuracy)

    # Newer matplotlib releases ship this style only as 'seaborn-v0_8'; use
    # whichever name the installed version provides.
    style_name = 'seaborn' if 'seaborn' in plt.style.available else 'seaborn-v0_8'
    plt.style.use(style_name)

    fig, ax1 = plt.subplots(figsize=(8, 3))
    df.loc[:, ['train_loss', 'valid_loss']].plot(ax=ax1,
                                                 style=['-', ':'],
                                                 marker='o',
                                                 color='tab:blue',
                                                 legend=False,
                                                 fontsize=14)

    ax1.tick_params(axis='y', labelcolor='tab:blue', labelsize=14)
    ax1.set_ylabel("Loss", color='tab:blue', fontsize=14)

    ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis

    df.loc[:, ['train_misclass', 'valid_misclass']].plot(ax=ax2,
                                                         style=['-', ':'],
                                                         marker='o',
                                                         color='tab:red',
                                                         legend=False)
    ax2.tick_params(axis='y', labelcolor='tab:red', labelsize=14)
    ax2.set_ylabel("Misclassification Rate [%]", color='tab:red', fontsize=14)
    ax2.set_ylim(ax2.get_ylim()[0], 85)  # make some room for legend
    ax1.set_xlabel("Epoch", fontsize=14)

    # Legend entries that only encode the line style (train vs. valid).
    handles = [
        Line2D([0], [0],
               color='black',
               linewidth=1,
               linestyle='-',
               label='Train'),
        Line2D([0], [0],
               color='black',
               linewidth=1,
               linestyle=':',
               label='Valid'),
    ]
    plt.legend(handles, [h.get_label() for h in handles], fontsize=14)
    plt.tight_layout()

    return df