Ejemplo n.º 1
0
def test_deep4net():
    """Smoke-test Deep4Net: a random batch must yield (n_samples, n_classes) outputs."""
    rng = np.random.RandomState(42)
    n_channels, n_in_times, n_classes, n_samples = 18, 600, 2, 7

    # Input laid out as (samples, channels, time, 1); the trailing singleton
    # axis presumably matches the 4-D input this Deep4Net version expects.
    raw = rng.randn(n_samples, n_channels, n_in_times, 1)
    batch = torch.Tensor(raw.astype(np.float32))

    # Positional order: channels, classes, input length.
    net = Deep4Net(n_channels, n_classes, n_in_times, final_conv_length="auto")
    out = net(batch)
    assert out.shape == (n_samples, n_classes)
        in_chans=n_chans,
        n_classes=n_classes,
        input_window_samples=input_window_samples,
        n_filters_time=40,
        n_filters_spat=40,
        final_conv_length=35,
    )
    optimizer_lr = 0.000625
    optimizer_weight_decay = 0
elif model_name == "deep":
    model = Deep4Net(
        in_chans=n_chans,
        n_classes=n_classes,
        input_window_samples=input_window_samples,
        n_filters_time=25,
        n_filters_spat=25,
        stride_before_pool=True,
        n_filters_2=int(n_chans * 2),
        n_filters_3=int(n_chans * (2 ** 2.0)),
        n_filters_4=int(n_chans * (2 ** 3.0)),
        final_conv_length=1,
    )
    optimizer_lr = 0.01
    optimizer_weight_decay = 0.0005
else:
    raise ValueError(f'{model_name} unknown')

# Rebuild `model` as a plain Sequential, copying every top-level child in
# order except those whose name contains "softmax".
new_model = torch.nn.Sequential()
for name, module_ in model.named_children():
    if "softmax" not in name:
        new_model.add_module(name, module_)
def exp(subject_id):
    """Train a Deep4Net on the training subjects and track it on ``subject_id``.

    The training pool is subjects 1, 3, 7 and 8 minus ``subject_id``. Each
    training subject's windows (from the module-level ``splitted`` mapping)
    are randomly split 99% / 1%; only the 99% part is used for fitting. The
    windows of ``subject_id`` are passed to skorch as the validation set, so
    the ``valid_*`` columns of the returned history are metrics on that
    subject. A loss / misclassification figure is drawn but neither shown
    nor saved here.

    Parameters
    ----------
    subject_id : int
        Subject used for validation (evaluation).

    Returns
    -------
    df : pandas.DataFrame
        Per-epoch ``train_loss``, ``valid_loss``, ``train_accuracy``,
        ``valid_accuracy``, ``train_misclass`` and ``valid_misclass``
        (misclass in percent), indexed by epoch.
    """
    import torch
    import matplotlib.pyplot as plt
    import pandas as pd
    from matplotlib.lines import Line2D
    from skorch.callbacks import LRScheduler
    from skorch.helper import predefined_split

    from braindecode import EEGClassifier
    from braindecode.models import Deep4Net
    from braindecode.util import set_random_seeds

    # Seed BEFORE any randomness, including the random_split below.
    # Previously the seed was set after splitting, so the training subset
    # differed between runs despite the "reproducible" seed.
    cuda = torch.cuda.is_available()  # use the GPU when one is available
    device = 'cuda:0' if cuda else 'cpu'
    if cuda:
        torch.backends.cudnn.benchmark = True
    seed = 20200220  # random seed to make results reproducible
    set_random_seeds(seed=seed, cuda=cuda)

    ######################################################################
    # Data split
    # ----------
    test_subj = np.r_[subject_id]
    print('test subj:' + str(test_subj))
    train_subj = np.setdiff1d(np.r_[1, 3, 7, 8], test_subj)

    # Keep 99% of each training subject for fitting. The remaining 1% is
    # still split off but is not used: validation runs on the test subject.
    tr = []
    for ids in train_subj:
        train_size = int(0.99 * len(splitted[ids]))
        held_out_size = len(splitted[ids]) - train_size
        tr_i, _ = torch.utils.data.random_split(
            splitted[ids], [train_size, held_out_size])
        tr.append(tr_i)

    train_set = torch.utils.data.ConcatDataset(tr)
    valid_set = BaseConcatDataset([splitted[ids] for ids in test_subj])

    ######################################################################
    # Model
    # -----
    n_classes = 3
    # Number of channels and time steps, read from the first training window.
    n_chans = train_set[0][0].shape[0]
    input_window_samples = train_set[0][0].shape[1]

    model = Deep4Net(
        n_chans,
        n_classes,
        input_window_samples=input_window_samples,
        final_conv_length="auto",
    )
    print(model)

    if cuda:
        model.cuda()

    ######################################################################
    # Training
    # --------
    # EEGClassifier inherits from skorch.NeuralNetClassifier, so skorch's
    # training logic applies.
    #
    # NOTE(review): the original comments describe these as the values for
    # the shallow network and say Deep4 should use lr = 1 * 0.01 and
    # weight_decay = 0.5 * 0.001, yet the model here is Deep4Net. The values
    # are kept unchanged; confirm which pair is intended.
    lr = 0.0625 * 0.01
    weight_decay = 0
    batch_size = 8
    n_epochs = 100

    clf = EEGClassifier(
        model,
        criterion=torch.nn.NLLLoss,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(
            valid_set),  # validate on the test subject
        optimizer__lr=lr,
        optimizer__weight_decay=weight_decay,
        batch_size=batch_size,
        callbacks=[
            "accuracy",
            ("lr_scheduler",
             LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
        ],
        device=device,
    )
    # `y` is None because labels are supplied by the dataset itself.
    clf.fit(train_set, y=None, epochs=n_epochs)

    ######################################################################
    # Plot results
    # ------------
    results_columns = [
        'train_loss', 'valid_loss', 'train_accuracy', 'valid_accuracy'
    ]
    df = pd.DataFrame(clf.history[:, results_columns],
                      columns=results_columns,
                      index=clf.history[:, 'epoch'])

    # Percent misclassification, for easier visual comparison with the loss.
    df = df.assign(train_misclass=100 - 100 * df.train_accuracy,
                   valid_misclass=100 - 100 * df.valid_accuracy)

    # The 'seaborn' style was renamed 'seaborn-v0_8' in Matplotlib 3.6 and
    # the old name was later removed. Use whichever this installation has.
    if 'seaborn-v0_8' in plt.style.available:
        plt.style.use('seaborn-v0_8')
    else:
        plt.style.use('seaborn')

    fig, ax1 = plt.subplots(figsize=(8, 3))
    df.loc[:, ['train_loss', 'valid_loss']].plot(ax=ax1,
                                                 style=['-', ':'],
                                                 marker='o',
                                                 color='tab:blue',
                                                 legend=False,
                                                 fontsize=14)

    ax1.tick_params(axis='y', labelcolor='tab:blue', labelsize=14)
    ax1.set_ylabel("Loss", color='tab:blue', fontsize=14)

    ax2 = ax1.twinx()  # second y-axis sharing the same x-axis

    df.loc[:, ['train_misclass', 'valid_misclass']].plot(ax=ax2,
                                                         style=['-', ':'],
                                                         marker='o',
                                                         color='tab:red',
                                                         legend=False)
    ax2.tick_params(axis='y', labelcolor='tab:red', labelsize=14)
    ax2.set_ylabel("Misclassification Rate [%]", color='tab:red', fontsize=14)
    ax2.set_ylim(ax2.get_ylim()[0], 85)  # make some room for legend
    ax1.set_xlabel("Epoch", fontsize=14)

    # Custom legend: line style distinguishes train (solid) from valid (dotted).
    handles = [
        Line2D([0], [0], color='black', linewidth=1, linestyle='-',
               label='Train'),
        Line2D([0], [0], color='black', linewidth=1, linestyle=':',
               label='Valid'),
    ]
    plt.legend(handles, [h.get_label() for h in handles], fontsize=14)
    plt.tight_layout()

    return df
Ejemplo n.º 4
0
def test_deep4net(input_sizes):
    """Check a Deep4Net forward pass for the sizes given by the ``input_sizes`` fixture."""
    n_channels = input_sizes['n_channels']
    n_classes = input_sizes['n_classes']
    n_in_times = input_sizes['n_in_times']
    net = Deep4Net(n_channels, n_classes, n_in_times, final_conv_length="auto")
    check_forward_pass(net, input_sizes)
def create_example_model(n_channels,
                         n_classes,
                         window_len_samples,
                         kind='shallow',
                         cuda=False):
    """Create model, loss and optimizer.

    Parameters
    ----------
    n_channels : int
        Number of channels in the input.
    n_classes : int
        Number of classes in the output.
    window_len_samples : int
        Window length of the input, in samples.
    kind : str
        'shallow' (ShallowFBCSPNet) or 'deep' (Deep4Net).
    cuda : bool
        If True, move the model to a CUDA device.

    Returns
    -------
    model : torch.nn.Module
        Model to train.
    loss : nn.NLLLoss
        Loss function.
    optimizer : optim.Adam
        Adam optimizer over the model's parameters, default settings.

    Raises
    ------
    ValueError
        If ``kind`` is neither 'shallow' nor 'deep'.
    """
    if kind == 'shallow':
        model = ShallowFBCSPNet(n_channels,
                                n_classes,
                                input_window_samples=window_len_samples,
                                n_filters_time=40,
                                filter_time_length=25,
                                n_filters_spat=40,
                                pool_time_length=75,
                                pool_time_stride=15,
                                final_conv_length='auto',
                                split_first_layer=True,
                                batch_norm=True,
                                batch_norm_alpha=0.1,
                                drop_prob=0.5)
    elif kind == 'deep':
        model = Deep4Net(n_channels,
                         n_classes,
                         input_window_samples=window_len_samples,
                         final_conv_length='auto',
                         n_filters_time=25,
                         n_filters_spat=25,
                         filter_time_length=10,
                         pool_time_length=3,
                         pool_time_stride=3,
                         n_filters_2=50,
                         filter_length_2=10,
                         n_filters_3=100,
                         filter_length_3=10,
                         n_filters_4=200,
                         filter_length_4=10,
                         first_pool_mode="max",
                         later_pool_mode="max",
                         drop_prob=0.5,
                         double_time_convs=False,
                         split_first_layer=True,
                         batch_norm=True,
                         batch_norm_alpha=0.1,
                         stride_before_pool=False)
    else:
        # Name the offending value so the caller can see what was passed.
        raise ValueError(
            f"Unknown model kind {kind!r}; expected 'shallow' or 'deep'")

    if cuda:
        model.cuda()

    optimizer = optim.Adam(model.parameters())
    loss = nn.NLLLoss()

    return model, loss, optimizer
Ejemplo n.º 6
0
set_random_seeds(seed=seed, cuda=cuda)

# Pick the architecture together with the optimizer settings used for it.
if model_name == "shallow":
    model = ShallowFBCSPNet(
        n_chans,
        n_classes,
        input_time_length=input_time_length,
        final_conv_length='auto',
    )
    lr = 0.0625 * 0.01
    weight_decay = 0

elif model_name == "deep":
    model = Deep4Net(
        n_chans,
        n_classes,
        input_time_length=input_time_length,
        final_conv_length='auto',
    )
    lr = 1 * 0.01
    weight_decay = 0.5 * 0.001

else:
    # Fail early with a clear message instead of a NameError on `model` below.
    raise ValueError(f'{model_name} unknown')

if cuda:
    model.cuda()

dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])

# Preprocessing function: exponential running standardization with the
# configured factor_new / init_block_size.
standardize_func = partial(exponential_running_standardize,
                           factor_new=factor_new,
                           init_block_size=init_block_size)
raw_transform_dict = [
    ("pick_types", dict(eeg=True, meg=False, stim=False)),