def test_trialwise_decoding():
    # 5, 6, 9, 10, 13, 14 are codes for executed and imagined hands/feet
    subject_id = 1
    event_codes = [5, 6, 9, 10, 13, 14]
    # event_codes = [6]

    # This will download the files if you don't have them yet,
    # and then return the paths to the files.
    physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)

    # Load each of the files
    parts = [
        mne.io.read_raw_edf(path,
                            preload=True,
                            stim_channel='auto',
                            verbose='WARNING') for path in physionet_paths
    ]

    # Concatenate them
    raw = concatenate_raws(parts)

    # Find the events in this dataset
    # events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')
    events, _ = mne.events_from_annotations(raw)

    # Use only EEG channels
    eeg_channel_inds = mne.pick_types(raw.info,
                                      meg=False,
                                      eeg=True,
                                      stim=False,
                                      eog=False,
                                      exclude='bads')

    # Extract trials, only using EEG channels
    epoched = mne.Epochs(raw,
                         events,
                         dict(hands=2, feet=3),
                         tmin=1,
                         tmax=4.1,
                         proj=False,
                         picks=eeg_channel_inds,
                         baseline=None,
                         preload=True)

    # Convert data from volt to microvolt
    # Pytorch expects float32 for input and int64 for labels.
    # X:[90,64,497]
    X = (epoched.get_data() * 1e6).astype(np.float32)
    # y:[90]
    y = (epoched.events[:, 2] - 2).astype(np.int64)  # 2,3 -> 0,1

    # X_train:[60,64,497], y_train:[60]
    train_set = SignalAndTarget(X[:60], y=y[:60])
    # X_test:[30,64,497], y_test:[30]
    test_set = SignalAndTarget(X[60:], y=y[60:])

    # Set if you want to use GPU
    # You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
    cuda = False
    set_random_seeds(seed=20170629, cuda=cuda)
    n_classes = 2
    in_chans = train_set.X.shape[1]
    # final_conv_length = auto ensures we only get a single output in the time dimension
    # def __init__(self, in_chans=64, n_classes=2, input_time_length=497, n_filters_time=40, filter_time_length=25, n_filters_spat=40, pool_time_length=75, pool_time_stride=15, final_conv_length='auto', conv_nonlin=square, pool_mode="mean", pool_nonlin=safe_log, split_first_layer=True, batch_norm=True, batch_norm_alpha=0.1, drop_prob=0.5, ):
    # It feels like create_network() is really part of __init__; the code now reaches the built
    # network through self.model, which still seems inelegant, mainly because forward is baked
    # into nn.Sequential. The actual network is an nn.Sequential rather than a ShallowFBCSPNet
    # instance; I prefer the original definition style, since this way you cannot see the
    # intermediate outputs (see the forward-hook sketch after this function).
    # model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes, input_time_length=train_set.X.shape[2], final_conv_length='auto').create_network()  # the original way
    model = ShallowFBCSPNet(in_chans=in_chans,
                            n_classes=n_classes,
                            input_time_length=train_set.X.shape[2],
                            final_conv_length='auto').model
    if cuda:
        model.cuda()

    optimizer = optim.Adam(model.parameters())

    rng = RandomState((2017, 6, 30))
    losses = []
    accuracies = []
    for i_epoch in range(6):
        i_trials_in_batch = get_balanced_batches(len(train_set.X),
                                                 rng,
                                                 shuffle=True,
                                                 batch_size=10)
        # Set model to training mode
        model.train()
        for i_trials in i_trials_in_batch:
            # Have to add empty fourth dimension to X
            batch_X = train_set.X[i_trials][:, :, :, None]
            batch_y = train_set.y[i_trials]
            net_in = np_to_var(batch_X)
            if cuda:
                net_in = net_in.cuda()
            net_target = np_to_var(batch_y)
            if cuda:
                net_target = net_target.cuda()
            # Remove gradients of last backward pass from all parameters
            optimizer.zero_grad()
            # Compute outputs of the network
            # net_in: [10, 64, 497, 1] = [batch, channels, time, 1]
            outputs = model(net_in)
            # model=Sequential(
            #                   (dimshuffle): Expression(expression=_transpose_time_to_spat)
            #                   (conv_time): Conv2d(1, 40, kernel_size=(25, 1), stride=(1, 1))
            #                   (conv_spat): Conv2d(40, 40, kernel_size=(1, 64), stride=(1, 1), bias=False)
            #                   (bnorm): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            #                   (conv_nonlin): Expression(expression=square)
            #                   (pool): AvgPool2d(kernel_size=(75, 1), stride=(15, 1), padding=0)
            #                   (pool_nonlin): Expression(expression=safe_log)
            #                   (drop): Dropout(p=0.5)
            #                   (conv_classifier): Conv2d(40, 2, kernel_size=(27, 1), stride=(1, 1))
            #                   (softmax): LogSoftmax()
            #                   (squeeze): Expression(expression=_squeeze_final_output)
            #                 )
            # Compute the loss
            loss = F.nll_loss(outputs, net_target)
            # Do the backpropagation
            loss.backward()
            # Update parameters with the optimizer
            optimizer.step()

        # Print some statistics each epoch
        model.eval()
        print("Epoch {:d}".format(i_epoch))
        for setname, dataset in (('Train', train_set), ('Test', test_set)):
            # Here, we will use the entire dataset at once, which is still possible
            # for such smaller datasets. Otherwise we would have to use batches.
            net_in = np_to_var(dataset.X[:, :, :, None])
            if cuda:
                net_in = net_in.cuda()
            net_target = np_to_var(dataset.y)
            if cuda:
                net_target = net_target.cuda()
            outputs = model(net_in)
            loss = F.nll_loss(outputs, net_target)
            losses.append(float(var_to_np(loss)))
            print("{:6s} Loss: {:.5f}".format(setname, float(var_to_np(loss))))
            predicted_labels = np.argmax(var_to_np(outputs), axis=1)
            accuracy = np.mean(dataset.y == predicted_labels)
            accuracies.append(accuracy * 100)
            print("{:6s} Accuracy: {:.1f}%".format(setname, accuracy * 100))

    np.testing.assert_allclose(np.array(losses),
                               np.array([
                                   1.1775966882705688, 1.2602351903915405,
                                   0.7068756818771362, 0.9367912411689758,
                                   0.394258975982666, 0.6598362326622009,
                                   0.3359280526638031, 0.656258761882782,
                                   0.2790488004684448, 0.6104397177696228,
                                   0.27319177985191345, 0.5949864983558655
                               ]),
                               rtol=1e-4,
                               atol=1e-5)

    np.testing.assert_allclose(np.array(accuracies),
                               np.array([
                                   51.666666666666671, 53.333333333333336,
                                   63.333333333333329, 56.666666666666664,
                                   86.666666666666671, 66.666666666666657,
                                   90.0, 63.333333333333329,
                                   96.666666666666671, 56.666666666666664,
                                   96.666666666666671, 66.666666666666657
                               ]),
                               rtol=1e-4,
                               atol=1e-5)
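
The comment in the function above complains that the built network hides its intermediate
outputs. A minimal sketch of one way to inspect them with standard PyTorch forward hooks
(model and net_in are assumed to be built exactly as in the test above; this helper is
not part of braindecode):

def print_intermediate_shapes(model, net_in):
    # Register a hook on every child of the nn.Sequential, run one forward
    # pass to print each submodule's output shape, then remove the hooks.
    handles = []
    for name, module in model.named_children():
        def hook(module, inputs, output, name=name):
            print("{:s}: {}".format(name, tuple(output.shape)))
        handles.append(module.register_forward_hook(hook))
    try:
        model(net_in)
    finally:
        for handle in handles:
            handle.remove()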
Example #2
def runModel(mode):
    cudnn.benchmark = True

    start = time.time()

    #mode = str(sys.argv[1])
    #X,y,test_X,test_y = loadSubNormData(mode='all')
    #X,y,test_X,test_y = loadNEDCdata(mode=mode)

    #data = np.load('sessionsData/data%s-sessions.npy'%mode[:3])
    #labels = np.load('sessionsData/labels%s-sessions.npy'%mode[:3])

    data = np.load('data%s.npy' % mode[:3])
    labels = np.load('labels%s.npy' % mode[:3])

    X, y, test_X, test_y = splitDataRandom_Loaded(data, labels, mode)

    print('Mode - %s Total n: %d, Test n: %d' %
          (mode, len(y) + len(test_y), len(test_y)))
    #return 0

    #X = addDataNoise(X,band=[1,4])
    #test_X = addDataNoise(test_X,band=[1,4])

    max_shape = np.max([list(x.shape) for x in X], axis=0)

    assert max_shape[1] == int(config.duration_recording_mins *
                               config.sampling_freq * 60)

    n_classes = 2
    n_recordings = None  # set to an integer if you want to restrict the set size
    sensor_types = ["EEG"]
    n_chans = 19  #21
    max_recording_mins = 35  # exclude larger recordings from training set
    sec_to_cut = 60  # cut away at start of each recording
    duration_recording_mins = 5  #20  # how many minutes to use per recording
    test_recording_mins = 5  #20
    max_abs_val = 800  # for clipping
    sampling_freq = 100
    divisor = 10  # divide signal by this
    test_on_eval = True  # test on evaluation set or on training set
    # in case of test on eval, n_folds and i_test_fold determine
    # the validation fold in the training set for training until first stop
    n_folds = 10
    i_test_fold = 9
    shuffle = True
    model_name = 'linear'  # 'deep' | 'shallow' | 'linear'
    n_start_chans = 25
    n_chan_factor = 2  # relevant for deep model only
    input_time_length = 6000
    final_conv_length = 1
    model_constraint = 'defaultnorm'
    init_lr = 1e-3
    batch_size = 64
    max_epochs = 35  # until first stop, then continue training on train+valid
    cuda = True  # False

    if model_name == 'shallow':
        model = ShallowFBCSPNet(
            in_chans=n_chans,
            n_classes=n_classes,
            n_filters_time=n_start_chans,
            n_filters_spat=n_start_chans,
            input_time_length=input_time_length,
            final_conv_length=final_conv_length).create_network()
    elif model_name == 'deep':
        model = Deep4Net(n_chans,
                         n_classes,
                         n_filters_time=n_start_chans,
                         n_filters_spat=n_start_chans,
                         input_time_length=input_time_length,
                         n_filters_2=int(n_start_chans * n_chan_factor),
                         n_filters_3=int(n_start_chans * (n_chan_factor**2.0)),
                         n_filters_4=int(n_start_chans * (n_chan_factor**3.0)),
                         final_conv_length=final_conv_length,
                         stride_before_pool=True).create_network()
    elif model_name in ('deep_smac', 'deep_smac_bnorm'):
        if model_name == 'deep_smac':
            do_batch_norm = False
        else:
            assert model_name == 'deep_smac_bnorm'
            do_batch_norm = True
        double_time_convs = False
        drop_prob = 0.244445
        filter_length_2 = 12
        filter_length_3 = 14
        filter_length_4 = 12
        filter_time_length = 21
        final_conv_length = 1
        first_nonlin = elu
        first_pool_mode = 'mean'
        first_pool_nonlin = identity
        later_nonlin = elu
        later_pool_mode = 'mean'
        later_pool_nonlin = identity
        n_filters_factor = 1.679066
        n_filters_start = 32
        pool_time_length = 1
        pool_time_stride = 2
        split_first_layer = True
        n_chan_factor = n_filters_factor
        n_start_chans = n_filters_start
        model = Deep4Net(n_chans,
                         n_classes,
                         n_filters_time=n_start_chans,
                         n_filters_spat=n_start_chans,
                         input_time_length=input_time_length,
                         n_filters_2=int(n_start_chans * n_chan_factor),
                         n_filters_3=int(n_start_chans * (n_chan_factor**2.0)),
                         n_filters_4=int(n_start_chans * (n_chan_factor**3.0)),
                         final_conv_length=final_conv_length,
                         batch_norm=do_batch_norm,
                         double_time_convs=double_time_convs,
                         drop_prob=drop_prob,
                         filter_length_2=filter_length_2,
                         filter_length_3=filter_length_3,
                         filter_length_4=filter_length_4,
                         filter_time_length=filter_time_length,
                         first_nonlin=first_nonlin,
                         first_pool_mode=first_pool_mode,
                         first_pool_nonlin=first_pool_nonlin,
                         later_nonlin=later_nonlin,
                         later_pool_mode=later_pool_mode,
                         later_pool_nonlin=later_pool_nonlin,
                         pool_time_length=pool_time_length,
                         pool_time_stride=pool_time_stride,
                         split_first_layer=split_first_layer,
                         stride_before_pool=True).create_network()
    elif model_name == 'shallow_smac':
        conv_nonlin = identity
        do_batch_norm = True
        drop_prob = 0.328794
        filter_time_length = 56
        final_conv_length = 22
        n_filters_spat = 73
        n_filters_time = 24
        pool_mode = 'max'
        pool_nonlin = identity
        pool_time_length = 84
        pool_time_stride = 3
        split_first_layer = True
        model = ShallowFBCSPNet(
            in_chans=n_chans,
            n_classes=n_classes,
            n_filters_time=n_filters_time,
            n_filters_spat=n_filters_spat,
            input_time_length=input_time_length,
            final_conv_length=final_conv_length,
            conv_nonlin=conv_nonlin,
            batch_norm=do_batch_norm,
            drop_prob=drop_prob,
            filter_time_length=filter_time_length,
            pool_mode=pool_mode,
            pool_nonlin=pool_nonlin,
            pool_time_length=pool_time_length,
            pool_time_stride=pool_time_stride,
            split_first_layer=split_first_layer,
        ).create_network()
    elif model_name == 'linear':
        model = nn.Sequential()
        model.add_module("conv_classifier",
                         nn.Conv2d(n_chans, n_classes, (600, 1)))
        model.add_module('softmax', nn.LogSoftmax(dim=1))
        model.add_module('squeeze', Expression(lambda x: x.squeeze(3)))
    else:
        assert False, "unknown model name {:s}".format(model_name)

    to_dense_prediction_model(model)

    if config.cuda:
        model.cuda()
    test_input = np_to_var(
        np.ones((2, config.n_chans, config.input_time_length, 1),
                dtype=np.float32))
    if config.cuda:
        test_input = test_input.cuda()

    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]
    iterator = CropsFromTrialsIterator(
        batch_size=config.batch_size,
        input_time_length=config.input_time_length,
        n_preds_per_input=n_preds_per_input)

    #model.add_module('softmax', nn.LogSoftmax(dim=1))

    model.eval()

    mode[2] = str(mode[2])
    mode[3] = str(mode[3])
    modelName = '-'.join(mode[:4])

    #params = th.load('sessionsData/%sModel%s-sessions.pt'%(modelName,mode[4]))
    #params = th.load('%sModel%s.pt'%(modelName,mode[4]))
    params = th.load('linear/%sModel%s.pt' % (modelName, mode[4]))

    model.load_state_dict(params)

    if config.test_on_eval:
        #test_X, test_y = test_dataset.load()
        #test_X, test_y = loadNEDCdata(mode='eval')
        max_shape = np.max([list(x.shape) for x in test_X], axis=0)
        assert max_shape[1] == int(config.test_recording_mins *
                                   config.sampling_freq * 60)
    if not config.test_on_eval:
        splitter = TrainValidTestSplitter(config.n_folds,
                                          config.i_test_fold,
                                          shuffle=config.shuffle)
        train_set, valid_set, test_set = splitter.split(X, y)
    else:
        splitter = TrainValidSplitter(config.n_folds,
                                      i_valid_fold=config.i_test_fold,
                                      shuffle=config.shuffle)
        train_set, valid_set = splitter.split(X, y)
        test_set = SignalAndTarget(test_X, test_y)
        del test_X, test_y
    del X, y  # shouldn't be necessary, but just to make sure

    datasets = OrderedDict(
        (('train', train_set), ('valid', valid_set), ('test', test_set)))

    for setname in ('train', 'valid', 'test'):
        #setname = 'test'
        #print("Compute predictions for {:s}...".format(setname))
        dataset = datasets[setname]
        if config.cuda:
            preds_per_batch = [
                var_to_np(model(np_to_var(b[0]).cuda()))
                for b in iterator.get_batches(dataset, shuffle=False)
            ]
        else:
            preds_per_batch = [
                var_to_np(model(np_to_var(b[0])))
                for b in iterator.get_batches(dataset, shuffle=False)
            ]
        preds_per_trial = compute_preds_per_trial(
            preds_per_batch,
            dataset,
            input_time_length=iterator.input_time_length,
            n_stride=iterator.n_preds_per_input)
        mean_preds_per_trial = [
            np.mean(preds, axis=(0, 2)) for preds in preds_per_trial
        ]
        mean_preds_per_trial = np.array(mean_preds_per_trial)

        all_pred_labels = np.argmax(mean_preds_per_trial, axis=1).squeeze()
        all_target_labels = dataset.y
        acc_per_class = []
        for i_class in range(n_classes):
            mask = all_target_labels == i_class
            acc = np.mean(all_pred_labels[mask] == all_target_labels[mask])
            acc_per_class.append(acc)
        misclass = 1 - np.mean(acc_per_class)
        #print('Acc:{}, Class 0:{}, Class 1:{}'.format(np.mean(acc_per_class),acc_per_class[0],acc_per_class[1]))

        if setname == 'test':
            testResult = np.mean(acc_per_class)

    return testResult
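
The 'linear' branch above wraps a lambda in Expression so that plain functions can live
inside an nn.Sequential. Expression comes from braindecode's torch extensions; a minimal
stand-in with the same behavior is sketched here only for orientation (braindecode's own
class also provides a readable repr):

import torch.nn as nn

class Expression(nn.Module):
    # Wrap an arbitrary function as a module so it composes with other layers.
    def __init__(self, expression_fn):
        super(Expression, self).__init__()
        self.expression_fn = expression_fn

    def forward(self, *x):
        return self.expression_fn(*x)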
def test_cropped_decoding():
    import mne
    from mne.io import concatenate_raws

    # 5, 6, 9, 10, 13, 14 are codes for executed and imagined hands/feet
    subject_id = 1
    event_codes = [5, 6, 9, 10, 13, 14]

    # This will download the files if you don't have them yet,
    # and then return the paths to the files.
    physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)

    # Load each of the files
    parts = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto',
                                 verbose='WARNING')
             for path in physionet_paths]

    # Concatenate them
    raw = concatenate_raws(parts)

    # Find the events in this dataset
    events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')

    # Use only EEG channels
    eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False,
                                      eog=False,
                                      exclude='bads')

    # Extract trials, only using EEG channels
    epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=1, tmax=4.1,
                         proj=False, picks=eeg_channel_inds,
                         baseline=None, preload=True)
    import numpy as np
    from braindecode.datautil.signal_target import SignalAndTarget
    # Convert data from volt to microvolt
    # Pytorch expects float32 for input and int64 for labels.
    X = (epoched.get_data() * 1e6).astype(np.float32)
    y = (epoched.events[:, 2] - 2).astype(np.int64)  # 2,3 -> 0,1

    train_set = SignalAndTarget(X[:60], y=y[:60])
    test_set = SignalAndTarget(X[60:], y=y[60:])
    from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
    from torch import nn
    from braindecode.torch_ext.util import set_random_seeds
    from braindecode.models.util import to_dense_prediction_model

    # Set if you want to use GPU
    # You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
    cuda = False
    set_random_seeds(seed=20170629, cuda=cuda)

    # This will determine how many crops are processed in parallel
    input_time_length = 450
    n_classes = 2
    in_chans = train_set.X.shape[1]
    # final_conv_length determines the size of the receptive field of the ConvNet
    model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,
                            input_time_length=input_time_length,
                            final_conv_length=12).create_network()
    to_dense_prediction_model(model)

    if cuda:
        model.cuda()

    from torch import optim

    optimizer = optim.Adam(model.parameters())
    from braindecode.torch_ext.util import np_to_var
    # determine output size
    test_input = np_to_var(
        np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
    if cuda:
        test_input = test_input.cuda()
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]
    print("{:d} predictions per input/trial".format(n_preds_per_input))
    from braindecode.datautil.iterators import CropsFromTrialsIterator
    iterator = CropsFromTrialsIterator(batch_size=32,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input)
    from braindecode.torch_ext.util import np_to_var, var_to_np
    import torch.nn.functional as F
    from numpy.random import RandomState
    import torch as th
    from braindecode.experiments.monitors import compute_preds_per_trial_from_crops
    rng = RandomState((2017, 6, 30))
    losses = []
    accuracies = []
    for i_epoch in range(4):
        # Set model to training mode
        model.train()
        for batch_X, batch_y in iterator.get_batches(train_set, shuffle=False):
            net_in = np_to_var(batch_X)
            if cuda:
                net_in = net_in.cuda()
            net_target = np_to_var(batch_y)
            if cuda:
                net_target = net_target.cuda()
            # Remove gradients of last backward pass from all parameters
            optimizer.zero_grad()
            outputs = model(net_in)
            # Mean predictions across trial
            # Note that this will give identical gradients to computing
            # a per-prediction loss (at least for the combination of log softmax activation
            # and negative log likelihood loss which we are using here)
            outputs = th.mean(outputs, dim=2, keepdim=False)
            loss = F.nll_loss(outputs, net_target)
            loss.backward()
            optimizer.step()

        # Print some statistics each epoch
        model.eval()
        print("Epoch {:d}".format(i_epoch))
        for setname, dataset in (('Train', train_set), ('Test', test_set)):
            # Collect all predictions and losses
            all_preds = []
            all_losses = []
            batch_sizes = []
            for batch_X, batch_y in iterator.get_batches(dataset,
                                                         shuffle=False):
                net_in = np_to_var(batch_X)
                if cuda:
                    net_in = net_in.cuda()
                net_target = np_to_var(batch_y)
                if cuda:
                    net_target = net_target.cuda()
                outputs = model(net_in)
                all_preds.append(var_to_np(outputs))
                outputs = th.mean(outputs, dim=2, keepdim=False)
                loss = F.nll_loss(outputs, net_target)
                loss = float(var_to_np(loss))
                all_losses.append(loss)
                batch_sizes.append(len(batch_X))
            # Compute mean per-input loss
            loss = np.mean(np.array(all_losses) * np.array(batch_sizes) /
                           np.mean(batch_sizes))
            print("{:6s} Loss: {:.5f}".format(setname, loss))
            losses.append(loss)
            # Assign the predictions to the trials
            preds_per_trial = compute_preds_per_trial_from_crops(
                all_preds, input_time_length, dataset.X)
            # preds per trial are now trials x classes x timesteps/predictions
            # Now mean across timesteps for each trial to get per-trial predictions
            meaned_preds_per_trial = np.array(
                [np.mean(p, axis=1) for p in preds_per_trial])
            predicted_labels = np.argmax(meaned_preds_per_trial, axis=1)
            accuracy = np.mean(predicted_labels == dataset.y)
            accuracies.append(accuracy * 100)
            print("{:6s} Accuracy: {:.1f}%".format(
                setname, accuracy * 100))
    np.testing.assert_allclose(
        np.array(losses),
        np.array([1.703004002571106,
                  1.6295261979103088,
                  0.71168938279151917,
                  0.70825588703155518,
                  0.58231228590011597,
                  0.60176041722297668,
                  0.46629951894283295,
                  0.51184913516044617]),
        rtol=1e-4, atol=1e-5)
    np.testing.assert_allclose(
        np.array(accuracies),
        np.array(
            [50.0,
             46.666666666666664,
             60.0,
             53.333333333333336,
             68.333333333333329,
             66.666666666666657,
             88.333333333333329,
             83.333333333333343]),
        rtol=1e-4, atol=1e-5)
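
A standalone numeric check of the claim in the training loop above: nll_loss negates and
averages the target entries of its input, so it is linear in the log-probabilities, and
averaging predictions over crops before the loss equals averaging per-crop losses (hence
identical gradients). This snippet is illustrative and not part of the test:

import torch as th
import torch.nn.functional as F

crop_log_probs = F.log_softmax(th.randn(4, 2, 7), dim=1)  # [batch, class, n_crops]
targets = th.tensor([0, 1, 0, 1])
loss_mean_first = F.nll_loss(crop_log_probs.mean(dim=2), targets)
loss_per_crop = th.stack([F.nll_loss(crop_log_probs[:, :, i], targets)
                          for i in range(crop_log_probs.shape[2])]).mean()
assert th.allclose(loss_mean_first, loss_per_crop)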
Example #4
class ShallowFBCSPNet_GeneralTrainer(BaseEstimator, ClassifierMixin):
    """
        Initialize the parameters of the network
        Full list of parameters described in 
        ref: https://robintibor.github.io/braindecode/source/braindecode.models.html
    """
    def __init__(self,
                 n_filters_time=10,
                 filter_time_length=75,
                 n_filters_spat=5,
                 pool_time_length=60,
                 pool_time_stride=30,
                 nb_epoch=160):

        # init meta info
        self.cuda = torch.cuda.is_available()
        #set_random_seeds(seed=20180505, cuda=self.cuda)  # TODO: Fix random seed
        set_random_seeds(seed=randint(1, 20180505),
                         cuda=self.cuda)  # TODO: Fix random seed

        # copy all network parameters
        self.n_filters_time = n_filters_time
        self.filter_time_length = filter_time_length
        self.n_filters_spat = n_filters_spat
        self.pool_time_length = pool_time_length
        self.pool_time_stride = pool_time_stride
        self.nb_epoch = nb_epoch

        return

    """
        Fit the network
        Params:
            X, data array in the format (...)
            y, labels
        ref: http://danielhnyk.cz/creating-your-own-estimator-scikit-learn/
    """

    def fit(self, X, y):

        # define a number of train/test trials
        nb_train_trials = int(np.floor(7 / 8 * X.shape[0]))

        # split the dataset
        train_set = SignalAndTarget(X[:nb_train_trials], y=y[:nb_train_trials])
        test_set = SignalAndTarget(X[nb_train_trials:], y=y[nb_train_trials:])

        # number of classes and input channels
        n_classes = np.unique(y).size
        in_chans = train_set.X.shape[1]

        # final_conv_length = auto ensures we only get a single output in the time dimension
        self.model = ShallowFBCSPNet(
            in_chans=in_chans,
            n_classes=n_classes,
            input_time_length=train_set.X.shape[2],
            n_filters_time=self.n_filters_time,
            filter_time_length=self.filter_time_length,
            n_filters_spat=self.n_filters_spat,
            pool_time_length=self.pool_time_length,
            pool_time_stride=self.pool_time_stride,
            final_conv_length='auto').create_network()

        # setup model for cuda
        if self.cuda:
            self.model.cuda()

        # setup optimizer
        self.optimizer = optim.Adam(self.model.parameters())

        # random generator
        self.rng = RandomState(None)

        # array that tracks results
        self.loss_rec = np.zeros((self.nb_epoch, 2))
        self.accuracy_rec = np.zeros((self.nb_epoch, 2))

        # run all epoch
        for i_epoch in range(self.nb_epoch):

            self._batchTrain(i_epoch, train_set)
            self._evalTraining(i_epoch, train_set, test_set)

        return self

    """
        Training iteration, train the network on the train_set
        Params:
            i_epoch, current epoch iteration
            train_set, training set
    """

    def _batchTrain(self, i_epoch, train_set):

        # get a set of balanced batches
        i_trials_in_batch = get_balanced_batches(len(train_set.X),
                                                 self.rng,
                                                 shuffle=True,
                                                 batch_size=32)

        # Set model to training mode
        self.model.train()

        # go through all batches
        for i_trials in i_trials_in_batch:

            # Have to add empty fourth dimension to X
            batch_X = train_set.X[i_trials][:, :, :, None]
            batch_y = train_set.y[i_trials]

            net_in = np_to_var(batch_X)
            net_target = np_to_var(batch_y)

            # if cuda, copy to cuda memory
            if self.cuda:
                net_in = net_in.cuda()
                net_target = net_target.cuda()

            # Remove gradients of last backward pass from all parameters
            self.optimizer.zero_grad()
            # Compute outputs of the network
            outputs = self.model(net_in)
            # Compute the loss
            loss = F.nll_loss(outputs, net_target)
            # Do the backpropagation
            loss.backward()
            # Update parameters with the optimizer
            self.optimizer.step()

        return

    """
        Evaluation iteration, computes the performance of the network
        Params:
            i_epoch, current epoch iteration
            train_set, training set
    """

    def _evalTraining(self, i_epoch, train_set, test_set):

        # Print some statistics each epoch
        self.model.eval()
        print("Epoch {:d}".format(i_epoch))

        sets = {'Train': 0, 'Test': 1}

        # run evaluation on both train and test sets
        for setname, dataset in (('Train', train_set), ('Test', test_set)):

            # get balanced sets
            i_trials_in_batch = get_balanced_batches(len(dataset.X),
                                                     self.rng,
                                                     batch_size=32,
                                                     shuffle=False)

            outputs = []
            net_targets = []

            # for all trials in set
            for i_trials in i_trials_in_batch:

                # adapt datasets
                batch_X = dataset.X[i_trials][:, :, :, None]
                batch_y = dataset.y[i_trials]

                # apply some conversion
                net_in = np_to_var(batch_X)
                net_target = np_to_var(batch_y)

                # convert
                if self.cuda:
                    net_in = net_in.cuda()
                    net_target = net_target.cuda()

                net_target = var_to_np(net_target)
                output = var_to_np(self.model(net_in))
                outputs.append(output)
                net_targets.append(net_target)

            net_targets = np_to_var(np.concatenate(net_targets))
            outputs = np_to_var(np.concatenate(outputs))
            loss = F.nll_loss(outputs, net_targets)

            print("{:6s} Loss: {:.5f}".format(setname, float(var_to_np(loss))))

            self.loss_rec[i_epoch, sets[setname]] = var_to_np(loss)
            predicted_labels = np.argmax(var_to_np(outputs), axis=1)
            accuracy = np.mean(dataset.y == predicted_labels)

            print("{:6s} Accuracy: {:.1f}%".format(setname, accuracy * 100))
            self.accuracy_rec[i_epoch, sets[setname]] = accuracy

        return

    def predict(self, X):
        self.model.eval()

        # batch the trials to keep memory usage bounded
        i_trials_in_batch = get_balanced_batches(len(X),
                                                 self.rng,
                                                 batch_size=32,
                                                 shuffle=False)

        outputs = []

        for i_trials in i_trials_in_batch:
            # Have to add empty fourth dimension to X
            batch_X = X[i_trials][:, :, :, None]

            net_in = np_to_var(batch_X)

            if self.cuda:
                net_in = net_in.cuda()

            output = var_to_np(self.model(net_in))
            outputs.append(output)

        return outputs
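
A hedged usage sketch for the scikit-learn style estimator above, with synthetic
placeholder data (shapes follow fit: trials x channels x samples, float32 inputs,
int64 labels; the sizes below are arbitrary):

import numpy as np

X = np.random.randn(16, 15, 3584).astype(np.float32)   # 16 trials, 15 channels
y = np.random.randint(0, 2, size=16).astype(np.int64)  # binary labels

clf = ShallowFBCSPNet_GeneralTrainer(nb_epoch=2)
clf.fit(X, y)                  # trains and prints per-epoch loss/accuracy
crop_outputs = clf.predict(X)  # list of per-batch log-probability arrays, not labels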
Example #6
def run_exp_on_high_gamma_dataset(train_filename, test_filename, low_cut_hz,
                                  model_name, max_epochs, max_increase_epochs,
                                  np_th_seed, debug):
    train_set, valid_set, test_set = load_train_valid_test(
        train_filename=train_filename,
        test_filename=test_filename,
        low_cut_hz=low_cut_hz,
        debug=debug)
    if debug:
        max_epochs = 4

    set_random_seeds(np_th_seed, cuda=True)
    #torch.backends.cudnn.benchmark = True# sometimes crashes?
    n_classes = int(np.max(train_set.y) + 1)
    n_chans = int(train_set.X.shape[1])
    input_time_length = 1000
    if model_name == 'deep':
        model = Deep4Net(n_chans,
                         n_classes,
                         input_time_length=input_time_length,
                         final_conv_length=2).create_network()
    elif model_name == 'shallow':
        model = ShallowFBCSPNet(n_chans,
                                n_classes,
                                input_time_length=input_time_length,
                                final_conv_length=30).create_network()

    to_dense_prediction_model(model)
    model.cuda()
    model.eval()

    out = model(np_to_var(train_set.X[:1, :, :input_time_length, None]).cuda())

    n_preds_per_input = out.cpu().data.numpy().shape[2]
    optimizer = optim.Adam(model.parameters(), weight_decay=0, lr=1e-3)

    iterator = CropsFromTrialsIterator(batch_size=60,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input,
                                       seed=np_th_seed)

    monitors = [
        LossMonitor(),
        MisclassMonitor(col_suffix='sample_misclass'),
        CroppedTrialMisclassMonitor(input_time_length=input_time_length),
        RuntimeMonitor()
    ]

    model_constraint = MaxNormDefaultConstraint()

    loss_function = lambda preds, targets: F.nll_loss(th.mean(preds, dim=2),
                                                      targets)

    run_after_early_stop = True
    do_early_stop = True
    remember_best_column = 'valid_misclass'
    stop_criterion = Or([
        MaxEpochs(max_epochs),
        NoDecrease('valid_misclass', max_increase_epochs)
    ])

    exp = Experiment(model,
                     train_set,
                     valid_set,
                     test_set,
                     iterator=iterator,
                     loss_function=loss_function,
                     optimizer=optimizer,
                     model_constraint=model_constraint,
                     monitors=monitors,
                     stop_criterion=stop_criterion,
                     remember_best_column=remember_best_column,
                     run_after_early_stop=run_after_early_stop,
                     cuda=True,
                     do_early_stop=do_early_stop)
    exp.run()
    return exp
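
For orientation, a call sketch for the experiment runner above. The file paths and
hyperparameter values are placeholders, not taken from the source:

exp = run_exp_on_high_gamma_dataset(
    train_filename='data/train/1.mat',  # placeholder path
    test_filename='data/test/1.mat',    # placeholder path
    low_cut_hz=4,
    model_name='shallow',
    max_epochs=800,
    max_increase_epochs=80,
    np_th_seed=0,
    debug=False)
print(exp.epochs_df)  # per-epoch monitor values collected by Experiment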
Example #7
def test_trialwise_decoding():
    import mne
    from mne.io import concatenate_raws

    # 5, 6, 9, 10, 13, 14 are codes for executed and imagined hands/feet
    subject_id = 1
    event_codes = [5, 6, 9, 10, 13, 14]

    # This will download the files if you don't have them yet,
    # and then return the paths to the files.
    physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)

    # Load each of the files
    parts = [
        mne.io.read_raw_edf(path,
                            preload=True,
                            stim_channel='auto',
                            verbose='WARNING') for path in physionet_paths
    ]

    # Concatenate them
    raw = concatenate_raws(parts)

    # Find the events in this dataset
    events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')

    # Use only EEG channels
    eeg_channel_inds = mne.pick_types(raw.info,
                                      meg=False,
                                      eeg=True,
                                      stim=False,
                                      eog=False,
                                      exclude='bads')

    # Extract trials, only using EEG channels
    epoched = mne.Epochs(raw,
                         events,
                         dict(hands=2, feet=3),
                         tmin=1,
                         tmax=4.1,
                         proj=False,
                         picks=eeg_channel_inds,
                         baseline=None,
                         preload=True)

    import numpy as np

    # Convert data from volt to microvolt
    # Pytorch expects float32 for input and int64 for labels.
    X = (epoched.get_data() * 1e6).astype(np.float32)
    y = (epoched.events[:, 2] - 2).astype(np.int64)  # 2,3 -> 0,1

    from braindecode.datautil.signal_target import SignalAndTarget

    train_set = SignalAndTarget(X[:60], y=y[:60])
    test_set = SignalAndTarget(X[60:], y=y[60:])

    from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
    from torch import nn
    from braindecode.torch_ext.util import set_random_seeds

    # Set if you want to use GPU
    # You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
    cuda = False
    set_random_seeds(seed=20170629, cuda=cuda)
    n_classes = 2
    in_chans = train_set.X.shape[1]
    # final_conv_length = auto ensures we only get a single output in the time dimension
    model = ShallowFBCSPNet(in_chans=in_chans,
                            n_classes=n_classes,
                            input_time_length=train_set.X.shape[2],
                            final_conv_length='auto').create_network()
    if cuda:
        model.cuda()

    from torch import optim

    optimizer = optim.Adam(model.parameters())

    from braindecode.torch_ext.util import np_to_var, var_to_np
    from braindecode.datautil.iterators import get_balanced_batches
    import torch.nn.functional as F
    from numpy.random import RandomState

    rng = RandomState((2017, 6, 30))
    losses = []
    accuracies = []
    for i_epoch in range(6):
        i_trials_in_batch = get_balanced_batches(len(train_set.X),
                                                 rng,
                                                 shuffle=True,
                                                 batch_size=30)
        # Set model to training mode
        model.train()
        for i_trials in i_trials_in_batch:
            # Have to add empty fourth dimension to X
            batch_X = train_set.X[i_trials][:, :, :, None]
            batch_y = train_set.y[i_trials]
            net_in = np_to_var(batch_X)
            if cuda:
                net_in = net_in.cuda()
            net_target = np_to_var(batch_y)
            if cuda:
                net_target = net_target.cuda()
            # Remove gradients of last backward pass from all parameters
            optimizer.zero_grad()
            # Compute outputs of the network
            outputs = model(net_in)
            # Compute the loss
            loss = F.nll_loss(outputs, net_target)
            # Do the backpropagation
            loss.backward()
            # Update parameters with the optimizer
            optimizer.step()

        # Print some statistics each epoch
        model.eval()
        print("Epoch {:d}".format(i_epoch))
        for setname, dataset in (('Train', train_set), ('Test', test_set)):
            # Here, we will use the entire dataset at once, which is still possible
            # for such smaller datasets. Otherwise we would have to use batches.
            net_in = np_to_var(dataset.X[:, :, :, None])
            if cuda:
                net_in = net_in.cuda()
            net_target = np_to_var(dataset.y)
            if cuda:
                net_target = net_target.cuda()
            outputs = model(net_in)
            loss = F.nll_loss(outputs, net_target)
            losses.append(float(var_to_np(loss)))
            print("{:6s} Loss: {:.5f}".format(setname, float(var_to_np(loss))))
            predicted_labels = np.argmax(var_to_np(outputs), axis=1)
            accuracy = np.mean(dataset.y == predicted_labels)
            accuracies.append(accuracy * 100)
            print("{:6s} Accuracy: {:.1f}%".format(setname, accuracy * 100))

    np.testing.assert_allclose(np.array(losses),
                               np.array([
                                   1.1775966882705688, 1.2602351903915405,
                                   0.7068756818771362, 0.9367912411689758,
                                   0.394258975982666, 0.6598362326622009,
                                   0.3359280526638031, 0.656258761882782,
                                   0.2790488004684448, 0.6104397177696228,
                                   0.27319177985191345, 0.5949864983558655
                               ]),
                               rtol=1e-4,
                               atol=1e-5)

    np.testing.assert_allclose(np.array(accuracies),
                               np.array([
                                   51.666666666666671, 53.333333333333336,
                                   63.333333333333329, 56.666666666666664,
                                   86.666666666666671, 66.666666666666657,
                                   90.0, 63.333333333333329,
                                   96.666666666666671, 56.666666666666664,
                                   96.666666666666671, 66.666666666666657
                               ]),
                               rtol=1e-4,
                               atol=1e-5)
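
get_balanced_batches, used throughout these examples, returns a list of index arrays
whose lengths differ by at most one. A rough standalone equivalent, an approximation
for illustration rather than braindecode's exact implementation:

import numpy as np
from numpy.random import RandomState

def rough_balanced_batches(n_trials, rng, shuffle, batch_size):
    inds = np.arange(n_trials)
    if shuffle:
        rng.shuffle(inds)
    n_batches = int(np.ceil(n_trials / float(batch_size)))
    return np.array_split(inds, n_batches)

# e.g. rough_balanced_batches(7, RandomState(0), shuffle=False, batch_size=3)
# -> [array([0, 1, 2]), array([3, 4]), array([5, 6])]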
Example #8
def run_exp(max_recording_mins, n_recordings, sec_to_cut,
            duration_recording_mins, max_abs_val, shrink_val, sampling_freq,
            divisor, n_folds, i_test_fold, final_conv_length, model_constraint,
            batch_size, max_epochs, n_filters_time, n_filters_spat,
            filter_time_length, conv_nonlin, pool_time_length,
            pool_time_stride, pool_mode, pool_nonlin, split_first_layer,
            do_batch_norm, drop_prob, time_cut_off_sec, start_time,
            input_time_length, only_return_exp):
    kwargs = locals()
    for model_param in [
            'final_conv_length',
            'n_filters_time',
            'n_filters_spat',
            'filter_time_length',
            'conv_nonlin',
            'pool_time_length',
            'pool_time_stride',
            'pool_mode',
            'pool_nonlin',
            'split_first_layer',
            'do_batch_norm',
            'drop_prob',
    ]:
        kwargs.pop(model_param)
    nonlin_dict = {
        'elu': elu,
        'relu': relu,
        'relu6': relu6,
        'tanh': tanh,
        'square': square,
        'identity': identity,
        'log': safe_log,
    }
    assert input_time_length == 6000
    # copy over from early seizure
    # make proper
    n_classes = 2
    in_chans = 21
    cuda = True
    set_random_seeds(seed=20170629, cuda=cuda)
    model = ShallowFBCSPNet(in_chans=in_chans,
                            n_classes=n_classes,
                            input_time_length=input_time_length,
                            final_conv_length=final_conv_length,
                            n_filters_time=n_filters_time,
                            filter_time_length=filter_time_length,
                            n_filters_spat=n_filters_spat,
                            pool_time_length=pool_time_length,
                            pool_time_stride=pool_time_stride,
                            conv_nonlin=nonlin_dict[conv_nonlin],
                            pool_mode=pool_mode,
                            pool_nonlin=nonlin_dict[pool_nonlin],
                            split_first_layer=split_first_layer,
                            batch_norm=do_batch_norm,
                            batch_norm_alpha=0.1,
                            drop_prob=drop_prob).create_network()

    to_dense_prediction_model(model)
    if cuda:
        model.cuda()
    model.eval()
    test_input = np_to_var(
        np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
    if cuda:
        test_input = test_input.cuda()

    try:
        out = model(test_input)
    except RuntimeError:
        raise ValueError("Model receptive field too large...")
    n_preds_per_input = out.cpu().data.numpy().shape[2]
    n_receptive_field = input_time_length - n_preds_per_input

    if n_receptive_field > 6000:
        raise ValueError("Model receptive field ({:d}) too large...".format(
            n_receptive_field))
        # In the future, optionally increase the input time length here instead

    model = ShallowFBCSPNet(in_chans=in_chans,
                            n_classes=n_classes,
                            input_time_length=input_time_length,
                            final_conv_length=final_conv_length,
                            n_filters_time=n_filters_time,
                            filter_time_length=filter_time_length,
                            n_filters_spat=n_filters_spat,
                            pool_time_length=pool_time_length,
                            pool_time_stride=pool_time_stride,
                            conv_nonlin=nonlin_dict[conv_nonlin],
                            pool_mode=pool_mode,
                            pool_nonlin=nonlin_dict[pool_nonlin],
                            split_first_layer=split_first_layer,
                            batch_norm=do_batch_norm,
                            batch_norm_alpha=0.1,
                            drop_prob=drop_prob).create_network()
    return common.run_exp(model=model, **kwargs)
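
run_exp above uses the locals()/pop idiom to forward everything except the model
parameters to common.run_exp. The idiom in isolation:

def callee(a, b):
    return a + b

def caller(a, b, c):
    kwargs = locals()  # {'a': a, 'b': b, 'c': c} at this point
    kwargs.pop('c')    # drop the argument the callee must not see
    return callee(**kwargs)

assert caller(1, 2, 3) == 3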
Example #9
def run_experiment(train_set, valid_set, test_set, model_name, optimizer_name,
                   init_lr, scheduler_name, use_norm_constraint, weight_decay,
                   schedule_weight_decay, restarts, max_epochs,
                   max_increase_epochs, np_th_seed):
    set_random_seeds(np_th_seed, cuda=True)
    #torch.backends.cudnn.benchmark = True# sometimes crashes?
    if valid_set is not None:
        assert max_increase_epochs is not None
    assert (max_epochs is None) != (restarts is None)
    if max_epochs is None:
        max_epochs = np.sum(restarts)
    n_classes = int(np.max(train_set.y) + 1)
    n_chans = int(train_set.X.shape[1])
    input_time_length = 1000
    if model_name == 'deep':
        model = Deep4Net(n_chans,
                         n_classes,
                         input_time_length=input_time_length,
                         final_conv_length=2).create_network()
    elif model_name == 'shallow':
        model = ShallowFBCSPNet(n_chans,
                                n_classes,
                                input_time_length=input_time_length,
                                final_conv_length=30).create_network()
    elif model_name in [
            'resnet-he-uniform', 'resnet-he-normal', 'resnet-xavier-normal',
            'resnet-xavier-uniform'
    ]:
        # str.lstrip strips a set of characters, not a prefix; slice off the prefix instead
        init_name = model_name[len('resnet-'):]
        from torch.nn import init
        init_fn = {
            'he-uniform': lambda w: init.kaiming_uniform(w, a=0),
            'he-normal': lambda w: init.kaiming_normal(w, a=0),
            'xavier-uniform': lambda w: init.xavier_uniform(w, gain=1),
            'xavier-normal': lambda w: init.xavier_normal(w, gain=1)
        }[init_name]
        model = EEGResNet(in_chans=n_chans,
                          n_classes=n_classes,
                          input_time_length=input_time_length,
                          final_pool_length=10,
                          n_first_filters=48,
                          conv_weight_init_fn=init_fn).create_network()
    else:
        raise ValueError("Unknown model name {:s}".format(model_name))
    if 'resnet' not in model_name:
        to_dense_prediction_model(model)
    model.cuda()
    model.eval()

    out = model(np_to_var(train_set.X[:1, :, :input_time_length, None]).cuda())

    n_preds_per_input = out.cpu().data.numpy().shape[2]

    if optimizer_name == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               weight_decay=weight_decay,
                               lr=init_lr)
    elif optimizer_name == 'adamw':
        optimizer = AdamW(model.parameters(),
                          weight_decay=weight_decay,
                          lr=init_lr)

    iterator = CropsFromTrialsIterator(batch_size=60,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input,
                                       seed=np_th_seed)

    if scheduler_name is not None:
        assert schedule_weight_decay == (optimizer_name == 'adamw')
        if scheduler_name == 'cosine':
            n_updates_per_epoch = sum(
                [1 for _ in iterator.get_batches(train_set, shuffle=True)])
            if restarts is None:
                n_updates_per_period = n_updates_per_epoch * max_epochs
            else:
                n_updates_per_period = np.array(restarts) * n_updates_per_epoch
            scheduler = CosineAnnealing(n_updates_per_period)
            optimizer = ScheduledOptimizer(
                scheduler,
                optimizer,
                schedule_weight_decay=schedule_weight_decay)
        elif scheduler_name == 'cut_cosine':
            # TODO: integrate with if clause before, now just separate
            # to avoid messing with code
            n_updates_per_epoch = sum(
                [1 for _ in iterator.get_batches(train_set, shuffle=True)])
            if restarts is None:
                n_updates_per_period = n_updates_per_epoch * max_epochs
            else:
                n_updates_per_period = np.array(restarts) * n_updates_per_epoch
            scheduler = CutCosineAnnealing(n_updates_per_period)
            optimizer = ScheduledOptimizer(
                scheduler,
                optimizer,
                schedule_weight_decay=schedule_weight_decay)
        else:
            raise ValueError("Unknown scheduler")
    monitors = [
        LossMonitor(),
        MisclassMonitor(col_suffix='sample_misclass'),
        CroppedTrialMisclassMonitor(input_time_length=input_time_length),
        RuntimeMonitor()
    ]

    if use_norm_constraint:
        model_constraint = MaxNormDefaultConstraint()
    else:
        model_constraint = None
    loss_function = lambda preds, targets: F.nll_loss(th.mean(preds, dim=2),
                                                      targets)

    if valid_set is not None:
        run_after_early_stop = True
        do_early_stop = True
        remember_best_column = 'valid_misclass'
        stop_criterion = Or([
            MaxEpochs(max_epochs),
            NoDecrease('valid_misclass', max_increase_epochs)
        ])
    else:
        run_after_early_stop = False
        do_early_stop = False
        remember_best_column = None
        stop_criterion = MaxEpochs(max_epochs)

    exp = Experiment(model,
                     train_set,
                     valid_set,
                     test_set,
                     iterator=iterator,
                     loss_function=loss_function,
                     optimizer=optimizer,
                     model_constraint=model_constraint,
                     monitors=monitors,
                     stop_criterion=stop_criterion,
                     remember_best_column=remember_best_column,
                     run_after_early_stop=run_after_early_stop,
                     cuda=True,
                     do_early_stop=do_early_stop)
    exp.run()
    return exp
class ShallowFBCSPNet_SpecializedTrainer(BaseEstimator, ClassifierMixin):

    model = None

    def __init__(self, network=None, filename=None):
        self.cuda = True
        if network is not None:
            self._decorateNetwork(network)
        elif filename is not None:
            self._loadFromFile(filename)
        else:
            print("unsupported option")
            sys.exit(-1)

        # set default parameters
        self.configure()

    def configure(self,
                  nb_epoch=160,
                  initial_lr=0.00006,
                  trainTestRatio=(7 / 8)):
        self.nb_epoch = nb_epoch
        self.lr = initial_lr
        self.trainTestRatio = trainTestRatio

    def _decorateNetwork(self, network):

        self.model = network  # TODO make a deep copy

        # deactivate training for all layers
        #for param in network.conv_classifier.parameters():
        #    param.requires_grad = False

        # replace the last layer with a brand new one (its parameters have requires_grad=True by default)
        self.model.conv_classifier = nn.Conv2d(5, 2, (116, 1),
                                               bias=True).cuda()

        # save/load only the model parameters (preferred solution) TODO: ask yannick
        torch.save(self.model.state_dict(), "myModel.pth")

        return

    def _loadFromFile(self, filename):

        # TODO: integrate this in saved file parameters somehow
        #n_filters_time=10
        #filter_time_length=75
        #n_filters_spat=5
        #pool_time_length=60
        #pool_time_stride=30
        #in_chans = 15
        #input_time_length = 3584

        # final_conv_length = auto ensures we only get a single output in the time dimension
        self.model = ShallowFBCSPNet(
            in_chans=15,
            n_classes=2,
            input_time_length=3584,
            n_filters_time=10,
            filter_time_length=75,
            n_filters_spat=5,
            pool_time_length=60,
            pool_time_stride=30,
            final_conv_length='auto').create_network()

        # setup model for cuda
        if self.cuda:
            print("That's the new one")
            self.model.cuda()

        # load the saved network (makes it possible to rerun from the same starting point)
        self.model.load_state_dict(torch.load("myModel.pth"))
        return

    """
        Fit the network
        Params:
            X, data array in the format (...)
            y, labels
        ref: http://danielhnyk.cz/creating-your-own-estimator-scikit-learn/
    """

    def fit(self, X, y):

        self.nb_epoch = 160

        # prepare an optimizer
        self.optimizer = optim.Adam(self.model.conv_classifier.parameters(),
                                    lr=self.lr)

        # define a number of train/test trials
        nb_train_trials = int(np.floor(self.trainTestRatio * X.shape[0]))

        # split the dataset
        train_set = SignalAndTarget(X[:nb_train_trials], y=y[:nb_train_trials])
        test_set = SignalAndTarget(X[nb_train_trials:], y=y[nb_train_trials:])

        # random generator
        self.rng = RandomState(None)

        # array that tracks results
        self.loss_rec = np.zeros((self.nb_epoch, 2))
        self.accuracy_rec = np.zeros((self.nb_epoch, 2))

        # run all epoch
        for i_epoch in range(self.nb_epoch):

            self._batchTrain(i_epoch, train_set)
            self._evalTraining(i_epoch, train_set, test_set)

        return self

    """
        Training iteration, train the network on the train_set
        Params:
            i_epoch, current epoch iteration
            train_set, training set
    """

    def _batchTrain(self, i_epoch, train_set):

        # get a set of balanced batches
        i_trials_in_batch = get_balanced_batches(len(train_set.X),
                                                 self.rng,
                                                 shuffle=True,
                                                 batch_size=32)

        self.adjust_learning_rate(self.optimizer, i_epoch)

        # Set model to training mode
        self.model.train()

        # go through all batches
        for i_trials in i_trials_in_batch:

            # Have to add empty fourth dimension to X
            batch_X = train_set.X[i_trials][:, :, :, None]
            batch_y = train_set.y[i_trials]

            net_in = np_to_var(batch_X)
            net_target = np_to_var(batch_y)

            # if cuda, copy to cuda memory
            if self.cuda:
                net_in = net_in.cuda()
                net_target = net_target.cuda()

            # Remove gradients of last backward pass from all parameters
            self.optimizer.zero_grad()
            # Compute outputs of the network
            outputs = self.model(net_in)
            # Compute the loss
            loss = F.nll_loss(outputs, net_target)
            # Do the backpropagation
            loss.backward()
            # Update parameters with the optimizer
            self.optimizer.step()

        return

    """
        Evaluation iteration, computes the performance of the network
        Params:
            i_epoch, current epoch iteration
            train_set, training set
    """

    def _evalTraining(self, i_epoch, train_set, test_set):

        # Print some statistics each epoch
        self.model.eval()
        print("Epoch {:d}".format(i_epoch))

        sets = {'Train': 0, 'Test': 1}

        # run evaluation on both train and test sets
        for setname, dataset in (('Train', train_set), ('Test', test_set)):

            # get balanced sets
            i_trials_in_batch = get_balanced_batches(len(dataset.X),
                                                     self.rng,
                                                     batch_size=32,
                                                     shuffle=False)

            outputs = []
            net_targets = []

            # for all trials in set
            for i_trials in i_trials_in_batch:

                # adapt datasets
                batch_X = dataset.X[i_trials][:, :, :, None]
                batch_y = dataset.y[i_trials]

                # apply some conversion
                net_in = np_to_var(batch_X)
                net_target = np_to_var(batch_y)

                # convert
                if self.cuda:
                    net_in = net_in.cuda()
                    net_target = net_target.cuda()

                net_target = var_to_np(net_target)
                output = var_to_np(self.model(net_in))
                outputs.append(output)
                net_targets.append(net_target)

            net_targets = np_to_var(np.concatenate(net_targets))
            outputs = np_to_var(np.concatenate(outputs))
            loss = F.nll_loss(outputs, net_targets)

            print("{:6s} Loss: {:.5f}".format(setname, float(var_to_np(loss))))

            self.loss_rec[i_epoch, sets[setname]] = var_to_np(loss)
            predicted_labels = np.argmax(var_to_np(outputs), axis=1)
            accuracy = np.mean(dataset.y == predicted_labels)

            print("{:6s} Accuracy: {:.1f}%".format(setname, accuracy * 100))
            self.accuracy_rec[i_epoch, sets[setname]] = accuracy

        return

    def predict(self, X):
        self.model.eval()

        # batch the trials to keep memory usage bounded
        i_trials_in_batch = get_balanced_batches(len(X),
                                                 self.rng,
                                                 batch_size=32,
                                                 shuffle=False)

        outputs = []

        for i_trials in i_trials_in_batch:
            # Have to add empty fourth dimension to X
            batch_X = X[i_trials][:, :, :, None]

            net_in = np_to_var(batch_X)

            if self.cuda:
                net_in = net_in.cuda()

            output = var_to_np(self.model(net_in))
            outputs.append(output)

        return outputs

    def adjust_learning_rate(self, optimizer, epoch):
        """Sets the learning rate to the initial LR decayed by 10% every 30 epochs"""
        lr = self.lr * (0.1**(epoch // 30))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
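
A quick check of the step schedule implemented by adjust_learning_rate above, using the
default initial_lr of 0.00006:

for epoch in (0, 29, 30, 59, 60):
    # roughly 6e-05 for epochs 0-29, 6e-06 for 30-59, 6e-07 from 60 on
    print(epoch, 0.00006 * (0.1 ** (epoch // 30)))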