def runModel(mode):
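    # NOTE (inferred from the indexing below, not from upstream callers):
    # `mode` is a mutable sequence (list) of at least five elements;
    # mode[:3] selects the data files, mode[:4] forms the saved-model name,
    # and mode[4] indexes the saved model file.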
    cudnn.benchmark = True  # let cuDNN auto-tune conv algorithms for fixed-size inputs

    start = time.time()

    #mode = str(sys.argv[1])
    #X,y,test_X,test_y = loadSubNormData(mode='all')
    #X,y,test_X,test_y = loadNEDCdata(mode=mode)

    #data = np.load('sessionsData/data%s-sessions.npy'%mode[:3])
    #labels = np.load('sessionsData/labels%s-sessions.npy'%mode[:3])

    data = np.load('data%s.npy' % mode[:3])
    labels = np.load('labels%s.npy' % mode[:3])

    X, y, test_X, test_y = splitDataRandom_Loaded(data, labels, mode)
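    # splitDataRandom_Loaded is assumed to split the loaded arrays into
    # train and test sets according to `mode`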

    print('Mode - %s Total n: %d, Test n: %d' %
          (mode, len(y) + len(test_y), len(test_y)))

    #X = addDataNoise(X,band=[1,4])
    #test_X = addDataNoise(test_X,band=[1,4])

    max_shape = np.max([list(x.shape) for x in X], axis=0)

    n_classes = 2
    n_recordings = None  # set to an integer, if you want to restrict the set size
    sensor_types = ["EEG"]
    n_chans = 19  #21
    max_recording_mins = 35  # exclude larger recordings from training set
    sec_to_cut = 60  # cut away at start of each recording
    duration_recording_mins = 5  #20  # how many minutes to use per recording
    test_recording_mins = 5  #20
    max_abs_val = 800  # for clipping
    sampling_freq = 100
    divisor = 10  # divide signal by this
    test_on_eval = True  # test on the evaluation set rather than the training set
    # when testing on the evaluation set, n_folds and i_test_fold determine
    # the validation fold within the training set (used until the first stop)
    n_folds = 10
    i_test_fold = 9
    shuffle = True
    model_name = 'linear'  # one of: 'linear', 'shallow', 'deep', 'deep_smac', 'deep_smac_bnorm', 'shallow_smac'
    n_start_chans = 25
    n_chan_factor = 2  # relevant for deep model only
    input_time_length = 6000
    final_conv_length = 1
    model_constraint = 'defaultnorm'
    init_lr = 1e-3
    batch_size = 64
    max_epochs = 35  # until first stop; then continue training on train+valid
    cuda = True  # False

    # sanity check: loaded recordings must match the configured duration
    # (samples = minutes * Hz * 60)
    assert max_shape[1] == int(duration_recording_mins * sampling_freq * 60)

    if model_name == 'shallow':
        model = ShallowFBCSPNet(
            in_chans=n_chans,
            n_classes=n_classes,
            n_filters_time=n_start_chans,
            n_filters_spat=n_start_chans,
            input_time_length=input_time_length,
            final_conv_length=final_conv_length).create_network()
    elif model_name == 'deep':
        model = Deep4Net(n_chans,
                         n_classes,
                         n_filters_time=n_start_chans,
                         n_filters_spat=n_start_chans,
                         input_time_length=input_time_length,
                         n_filters_2=int(n_start_chans * n_chan_factor),
                         n_filters_3=int(n_start_chans * (n_chan_factor**2.0)),
                         n_filters_4=int(n_start_chans * (n_chan_factor**3.0)),
                         final_conv_length=final_conv_length,
                         stride_before_pool=True).create_network()
    elif model_name in ('deep_smac', 'deep_smac_bnorm'):
        # hyperparameters below presumably come from a SMAC search;
        # only the '_bnorm' variant enables batch normalization
        do_batch_norm = model_name == 'deep_smac_bnorm'
        double_time_convs = False
        drop_prob = 0.244445
        filter_length_2 = 12
        filter_length_3 = 14
        filter_length_4 = 12
        filter_time_length = 21
        final_conv_length = 1
        first_nonlin = elu
        first_pool_mode = 'mean'
        first_pool_nonlin = identity
        later_nonlin = elu
        later_pool_mode = 'mean'
        later_pool_nonlin = identity
        n_filters_factor = 1.679066
        n_filters_start = 32
        pool_time_length = 1
        pool_time_stride = 2
        split_first_layer = True
        n_chan_factor = n_filters_factor
        n_start_chans = n_filters_start
        model = Deep4Net(n_chans,
                         n_classes,
                         n_filters_time=n_start_chans,
                         n_filters_spat=n_start_chans,
                         input_time_length=input_time_length,
                         n_filters_2=int(n_start_chans * n_chan_factor),
                         n_filters_3=int(n_start_chans * (n_chan_factor**2.0)),
                         n_filters_4=int(n_start_chans * (n_chan_factor**3.0)),
                         final_conv_length=final_conv_length,
                         batch_norm=do_batch_norm,
                         double_time_convs=double_time_convs,
                         drop_prob=drop_prob,
                         filter_length_2=filter_length_2,
                         filter_length_3=filter_length_3,
                         filter_length_4=filter_length_4,
                         filter_time_length=filter_time_length,
                         first_nonlin=first_nonlin,
                         first_pool_mode=first_pool_mode,
                         first_pool_nonlin=first_pool_nonlin,
                         later_nonlin=later_nonlin,
                         later_pool_mode=later_pool_mode,
                         later_pool_nonlin=later_pool_nonlin,
                         pool_time_length=pool_time_length,
                         pool_time_stride=pool_time_stride,
                         split_first_layer=split_first_layer,
                         stride_before_pool=True).create_network()
    elif model_name == 'shallow_smac':
        conv_nonlin = identity
        do_batch_norm = True
        drop_prob = 0.328794
        filter_time_length = 56
        final_conv_length = 22
        n_filters_spat = 73
        n_filters_time = 24
        pool_mode = 'max'
        pool_nonlin = identity
        pool_time_length = 84
        pool_time_stride = 3
        split_first_layer = True
        model = ShallowFBCSPNet(
            in_chans=n_chans,
            n_classes=n_classes,
            n_filters_time=n_filters_time,
            n_filters_spat=n_filters_spat,
            input_time_length=input_time_length,
            final_conv_length=final_conv_length,
            conv_nonlin=conv_nonlin,
            batch_norm=do_batch_norm,
            drop_prob=drop_prob,
            filter_time_length=filter_time_length,
            pool_mode=pool_mode,
            pool_nonlin=pool_nonlin,
            pool_time_length=pool_time_length,
            pool_time_stride=pool_time_stride,
            split_first_layer=split_first_layer,
        ).create_network()
    elif model_name == 'linear':
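        # 'linear' baseline: a single Conv2d spanning 600 samples (6 s at
        # 100 Hz) acts as a per-crop logistic-regression readout; LogSoftmax
        # over dim 1 yields class log-probabilities, and the trailing
        # singleton dimension is squeezed away.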
        model = nn.Sequential()
        model.add_module("conv_classifier",
                         nn.Conv2d(n_chans, n_classes, (600, 1)))
        model.add_module('softmax', nn.LogSoftmax(dim=1))
        model.add_module('squeeze', Expression(lambda x: x.squeeze(3)))
    else:
        raise ValueError("unknown model name {:s}".format(model_name))

    # convert to a dense-prediction model for cropped decoding: strides are
    # replaced by dilations so the network emits one prediction per sample
    to_dense_prediction_model(model)

    if cuda:
        model.cuda()
    test_input = np_to_var(
        np.ones((2, n_chans, input_time_length, 1), dtype=np.float32))
    if cuda:
        test_input = test_input.cuda()

    # a single dummy forward pass reveals how many predictions the dense
    # model emits per input window; the crop iterator needs this value
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]
    iterator = CropsFromTrialsIterator(
        batch_size=batch_size,
        input_time_length=input_time_length,
        n_preds_per_input=n_preds_per_input)

    model.eval()

    # '-'.join below requires strings; entries 2 and 3 of mode may be numeric
    mode[2] = str(mode[2])
    mode[3] = str(mode[3])
    modelName = '-'.join(mode[:4])

    #params = th.load('sessionsData/%sModel%s-sessions.pt'%(modelName,mode[4]))
    #params = th.load('%sModel%s.pt'%(modelName,mode[4]))
    params = th.load('linear/%sModel%s.pt' % (modelName, mode[4]))

    model.load_state_dict(params)

    if test_on_eval:
        #test_X, test_y = test_dataset.load()
        #test_X, test_y = loadNEDCdata(mode='eval')
        max_shape = np.max([list(x.shape) for x in test_X], axis=0)
        assert max_shape[1] == int(test_recording_mins * sampling_freq * 60)
        splitter = TrainValidSplitter(n_folds,
                                      i_valid_fold=i_test_fold,
                                      shuffle=shuffle)
        train_set, valid_set = splitter.split(X, y)
        test_set = SignalAndTarget(test_X, test_y)
        del test_X, test_y
    else:
        splitter = TrainValidTestSplitter(n_folds,
                                          i_test_fold,
                                          shuffle=shuffle)
        train_set, valid_set, test_set = splitter.split(X, y)
    del X, y  # shouldn't be necessary, but just to make sure

    datasets = OrderedDict(
        (('train', train_set), ('valid', valid_set), ('test', test_set)))
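    # score all three splits below; only the test-set balanced accuracy is
    # returned by runModel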

    for setname in ('train', 'valid', 'test'):
        dataset = datasets[setname]
        # gather dense (per-crop) predictions batch by batch
        if cuda:
            preds_per_batch = [
                var_to_np(model(np_to_var(b[0]).cuda()))
                for b in iterator.get_batches(dataset, shuffle=False)
            ]
        else:
            preds_per_batch = [
                var_to_np(model(np_to_var(b[0])))
                for b in iterator.get_batches(dataset, shuffle=False)
            ]
        # stitch crop predictions back into per-trial predictions, then
        # average the log-probabilities over crops and time for each trial
        preds_per_trial = compute_preds_per_trial(
            preds_per_batch,
            dataset,
            input_time_length=iterator.input_time_length,
            n_stride=iterator.n_preds_per_input)
        mean_preds_per_trial = [
            np.mean(preds, axis=(0, 2)) for preds in preds_per_trial
        ]
        mean_preds_per_trial = np.array(mean_preds_per_trial)

        all_pred_labels = np.argmax(mean_preds_per_trial, axis=1).squeeze()
        all_target_labels = dataset.y
        # balanced accuracy: average of the per-class accuracies
        acc_per_class = []
        for i_class in range(n_classes):
            mask = all_target_labels == i_class
            acc = np.mean(all_pred_labels[mask] == all_target_labels[mask])
            acc_per_class.append(acc)
        misclass = 1 - np.mean(acc_per_class)
        #print('Acc:{}, Class 0:{}, Class 1:{}'.format(np.mean(acc_per_class),acc_per_class[0],acc_per_class[1]))

        if setname == 'test':
            testResult = np.mean(acc_per_class)

    return testResult


class ShallowFBCSPNet_SpecializedTrainer(BaseEstimator, ClassifierMixin):

    model = None

    def __init__(self, network=None, filename=None):
        self.cuda = True
        if network is not None:
            self._decorateNetwork(network)
        elif filename is not None:
            self._loadFromFile(filename)
        else:
            raise ValueError("either a network or a filename must be provided")

        # set default parameters
        self.configure()

    def configure(self,
                  nb_epoch=160,
                  initial_lr=0.00006,
                  trainTestRatio=(7 / 8)):
        self.nb_epoch = nb_epoch
        self.lr = initial_lr
        self.trainTestRatio = trainTestRatio

    def _decorateNetwork(self, network):

        self.model = network  # TODO make a deep copy

        # deactivate training for all layers
        #for param in network.conv_classifier.parameters():
        #    param.requires_grad = False

        # replace the last layer with a brand new one (training is enabled by
        # default on a fresh layer)
        self.model.conv_classifier = nn.Conv2d(5, 2, (116, 1), bias=True)
        if self.cuda:
            self.model.conv_classifier.cuda()

        # save/load only the model parameters (preferred solution)
        # TODO: ask yannick
        torch.save(self.model.state_dict(), "myModel.pth")

        return

    def _loadFromFile(self, filename):

        # TODO: integrate this in saved file parameters somehow
        #n_filters_time=10
        #filter_time_length=75
        #n_filters_spat=5
        #pool_time_length=60
        #pool_time_stride=30
        #in_chans = 15
        #input_time_length = 3584

        # final_conv_length = auto ensures we only get a single output in the time dimension
        self.model = ShallowFBCSPNet(
            in_chans=15,
            n_classes=2,
            input_time_length=3584,
            n_filters_time=10,
            filter_time_length=75,
            n_filters_spat=5,
            pool_time_length=60,
            pool_time_stride=30,
            final_conv_length='auto').create_network()

        # set up the model for CUDA
        if self.cuda:
            self.model.cuda()

        # load the saved weights (makes it possible to rerun from the same
        # starting point)
        self.model.load_state_dict(torch.load("myModel.pth"))
        return

    """
        Fit the network
        Params:
            X, data array in the format (...)
            y, labels
        ref: http://danielhnyk.cz/creating-your-own-estimator-scikit-learn/
    """

    def fit(self, X, y):

        self.nb_epoch = 160

        # prepare an optimizer
        self.optimizer = optim.Adam(self.model.conv_classifier.parameters(),
                                    lr=self.lr)

        # define a number of train/test trials
        nb_train_trials = int(np.floor(self.trainTestRatio * X.shape[0]))

        # split the dataset
        train_set = SignalAndTarget(X[:nb_train_trials], y=y[:nb_train_trials])
        test_set = SignalAndTarget(X[nb_train_trials:], y=y[nb_train_trials:])

        # random generator
        self.rng = RandomState(None)

        # array that tracks results
        self.loss_rec = np.zeros((self.nb_epoch, 2))
        self.accuracy_rec = np.zeros((self.nb_epoch, 2))

        # run all epochs
        for i_epoch in range(self.nb_epoch):

            self._batchTrain(i_epoch, train_set)
            self._evalTraining(i_epoch, train_set, test_set)

        return self

    """
        Training iteration, train the network on the train_set
        Params:
            i_epoch, current epoch iteration
            train_set, training set
    """

    def _batchTrain(self, i_epoch, train_set):

        # get a set of balanced batches
        i_trials_in_batch = get_balanced_batches(len(train_set.X),
                                                 self.rng,
                                                 shuffle=True,
                                                 batch_size=32)

        self.adjust_learning_rate(self.optimizer, i_epoch)

        # Set model to training mode
        self.model.train()

        # go through all batches
        for i_trials in i_trials_in_batch:

            # Have to add empty fourth dimension to X
            batch_X = train_set.X[i_trials][:, :, :, None]
            batch_y = train_set.y[i_trials]

            net_in = np_to_var(batch_X)
            net_target = np_to_var(batch_y)

            # if cuda, copy to cuda memory
            if self.cuda:
                net_in = net_in.cuda()
                net_target = net_target.cuda()

            # Remove gradients of last backward pass from all parameters
            self.optimizer.zero_grad()
            # Compute outputs of the network
            outputs = self.model(net_in)
            # Compute the loss (NLL on the model's log-softmax outputs)
            loss = F.nll_loss(outputs, net_target)
            # Do the backpropagation
            loss.backward()
            # Update parameters with the optimizer
            self.optimizer.step()

        return

    """
        Evaluation iteration, computes the performance the network
        Params:
            i_epoch, current epoch iteration
            train_set, training set
    """

    def _evalTraining(self, i_epoch, train_set, test_set):

        # Print some statistics each epoch
        self.model.eval()
        print("Epoch {:d}".format(i_epoch))

        sets = {'Train': 0, 'Test': 1}

        # run evaluation on both train and test sets
        for setname, dataset in (('Train', train_set), ('Test', test_set)):

            # split trials into equally-sized batches, preserving order
            i_trials_in_batch = get_balanced_batches(len(dataset.X),
                                                     self.rng,
                                                     batch_size=32,
                                                     shuffle=False)

            outputs = []
            net_targets = []

            # for all trials in set
            for i_trials in i_trials_in_batch:

                # adapt datasets
                batch_X = dataset.X[i_trials][:, :, :, None]
                batch_y = dataset.y[i_trials]

                # convert numpy arrays to torch variables
                net_in = np_to_var(batch_X)
                net_target = np_to_var(batch_y)

                # if cuda, move to GPU memory
                if self.cuda:
                    net_in = net_in.cuda()
                    net_target = net_target.cuda()

                net_target = var_to_np(net_target)
                output = var_to_np(self.model(net_in))
                outputs.append(output)
                net_targets.append(net_target)

            net_targets = np_to_var(np.concatenate(net_targets))
            outputs = np_to_var(np.concatenate(outputs))
            loss = F.nll_loss(outputs, net_targets)

            print("{:6s} Loss: {:.5f}".format(setname, float(var_to_np(loss))))

            self.loss_rec[i_epoch, sets[setname]] = var_to_np(loss)
            predicted_labels = np.argmax(var_to_np(outputs), axis=1)
            # compare against the targets collected batch by batch, so the
            # ordering matches the concatenated outputs
            accuracy = np.mean(var_to_np(net_targets) == predicted_labels)

            print("{:6s} Accuracy: {:.1f}%".format(setname, accuracy * 100))
            self.accuracy_rec[i_epoch, sets[setname]] = accuracy

        return

    def predict(self, X):
        self.model.eval()

        # batch the trials (no shuffling, so trial order is preserved)
        i_trials_in_batch = get_balanced_batches(len(X),
                                                 self.rng,
                                                 batch_size=32,
                                                 shuffle=False)

        outputs = []

        for i_trials in i_trials_in_batch:
            # add the empty fourth dimension expected by the network
            batch_X = X[i_trials][:, :, :, None]

            net_in = np_to_var(batch_X)

            if self.cuda:
                net_in = net_in.cuda()

            output = var_to_np(self.model(net_in))
            outputs.append(output)

        # return one predicted label per trial
        return np.argmax(np.concatenate(outputs), axis=1)

    def adjust_learning_rate(self, optimizer, epoch):
        """Set the learning rate to the initial LR decayed by a factor of 10
        every 30 epochs (epochs 0-29: lr, 30-59: lr/10, ...)."""
        lr = self.lr * (0.1**(epoch // 30))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
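

# ---------------------------------------------------------------------------
# Minimal usage sketch for ShallowFBCSPNet_SpecializedTrainer. Illustrative
# only: it assumes a CUDA-capable GPU (self.cuda is hard-coded to True), an
# existing "myModel.pth" checkpoint (e.g. written by _decorateNetwork), and
# that np/RandomState are imported at the top of this file as used above.
# The random data is just a smoke test, so expect chance-level (~50%) accuracy.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_rng = RandomState(42)
    # shapes must match the network built in _loadFromFile:
    # 15 channels, 3584 samples per trial, 2 classes
    X_demo = demo_rng.randn(64, 15, 3584).astype(np.float32)
    y_demo = demo_rng.randint(0, 2, size=64).astype(np.int64)

    trainer = ShallowFBCSPNet_SpecializedTrainer(filename="myModel.pth")
    trainer.configure(nb_epoch=2, initial_lr=6e-5)
    trainer.fit(X_demo, y_demo)
    predicted = trainer.predict(X_demo)
    print("predicted labels:", predicted)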