示例#1
0
class EEGNet_classifier():
    """Binary classifier wrapping an ``EEGNet`` model with a fixed training setup.

    The network is trained with binary cross-entropy (``nn.BCELoss``), the Adam
    optimizer (lr=0.001) and a step schedule that multiplies the learning rate
    by 0.8 every 20 epochs. The model runs on the GPU when one is available
    and on the CPU otherwise.
    """

    def __init__(self):
        # Fall back to the CPU instead of crashing on hosts without CUDA.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.net = EEGNet().to(self.device)
        self.criterion = nn.BCELoss()
        self.optimizer = optim.Adam(self.net.parameters(), lr=0.001)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,
                                                   step_size=20,
                                                   gamma=0.8)

    def fit(self, X, y, quiet=True, epochs=200):
        """Train the network on samples ``X`` with binary labels ``y``.

        Every epoch draws a fresh shuffled 10-fold stratified split and runs
        one optimizer step per fold, using the larger (training) part of the
        fold as the batch. The learning-rate scheduler advances once per epoch.

        Parameters:
         - X:      numpy array of samples, indexed along its first axis
         - y:      numpy array of labels (assumed 1-D with values 0/1 -- TODO
                   confirm against callers)
         - quiet:  bool, if False, print the summed loss every 10th epoch
         - epochs: Number of epochs to train (default 200)

        The summed loss of the last epoch is always printed at the end,
        regardless of ``quiet``.
        """
        # predict() switches the network to eval mode; switch back here.
        self.net.train()

        running_loss = 0.0
        for epoch in range(epochs):
            running_loss = 0.0
            _skf = StratifiedKFold(n_splits=10, shuffle=True)
            for more, _ in _skf.split(X, y):
                # Get data and labels; labels become a float column vector
                inputs = torch.from_numpy(X[more])
                labels = torch.FloatTensor(np.array([y[more]]).T * 1.0)

                # Keep only the first 2 * (number of positive labels) samples.
                # NOTE(review): this yields equally many positive and negative
                # samples only if the positives come first in the selection --
                # confirm the ordering of the data passed in.
                n_keep = int(labels.sum().item()) * 2
                inputs = inputs[:n_keep].to(self.device)
                labels = labels[:n_keep].to(self.device)

                # zero the parameter gradients
                self.optimizer.zero_grad()

                # forward + backward + optimize
                outputs = self.net(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

                # .item() gives a plain float, so no device tensor is
                # accumulated across iterations
                running_loss += loss.item()

            self.scheduler.step()

            if not quiet and epoch % 10 == 0:
                print(f'Epoch {epoch}: {running_loss}')

        # With epochs <= 0 the loop never runs and there is nothing to report.
        if epochs > 0:
            print(f'Epoch {epoch}: {running_loss}')

    def predict(self, X):
        """Return the network output for the samples in numpy array ``X``.

        The network is put into evaluation mode and gradient tracking is
        disabled, so the result is a tensor on ``self.device`` that does not
        require grad.
        """
        inputs = torch.from_numpy(X).to(self.device)
        self.net.eval()
        with torch.no_grad():
            return self.net(inputs)
def train_subject_specific(subject,
                           epochs=500,
                           batch_size=32,
                           lr=0.001,
                           silent=False,
                           plot=True,
                           **kwargs):
    """
    Trains a subject specific model for the given subject

    The model is trained on the subject's training data with cross-entropy
    loss and the Adam optimizer, without a learning-rate scheduler and without
    early stopping. The subject's test data is handed to the training routine
    alongside the training data.

    Parameters:
     - subject:    Integer in the Range 1 <= subject <= 9
     - epochs:     Number of epochs to train
     - batch_size: Batch Size (used for both the training and the test loader)
     - lr:         Learning Rate
     - silent:     bool, if True, hide all output including the progress bar
     - plot:       bool, if True, generate plots
     - kwargs:     Remaining arguments passed to the EEGnet model

    Returns: (model, metrics, history)
     - model:   t.nn.Module, trained model
     - metrics: t.tensor, size=[1, 4], accuracy, precision, recall, f1
     - history: training history as returned by _train_net
    """
    # load the data
    train_samples, train_labels = get_data(subject, training=True)
    test_samples, test_labels = get_data(subject, training=False)
    train_loader = as_data_loader(train_samples,
                                  train_labels,
                                  batch_size=batch_size)
    test_loader = as_data_loader(test_samples,
                                 test_labels,
                                 batch_size=batch_size)

    # prepare the model; T is taken from the third axis of the training samples
    model = EEGNet(T=train_samples.shape[2], **kwargs)
    model.initialize_params()
    if t.cuda.is_available():
        model = model.cuda()

    # prepare loss function and optimizer
    loss_function = t.nn.CrossEntropyLoss()
    optimizer = t.optim.Adam(model.parameters(), lr=lr, eps=1e-7)
    scheduler = None

    # print the training setup, unless the caller asked for no output
    # (print_summary presumably writes to stdout, which `silent` must suppress)
    if not silent:
        print_summary(model, optimizer, loss_function, scheduler)

    # prepare progress bar
    with tqdm(desc=f"Subject {subject}",
              total=epochs,
              leave=False,
              disable=silent,
              unit='epoch',
              ascii=True) as pbar:

        # Early stopping is not allowed in this mode, because the testing data cannot be used for
        # training!
        model, metrics, _, history = _train_net(subject,
                                                model,
                                                train_loader,
                                                test_loader,
                                                loss_function,
                                                optimizer,
                                                scheduler=scheduler,
                                                epochs=epochs,
                                                early_stopping=False,
                                                plot=plot,
                                                pbar=pbar)

    if not silent:
        print(f"Subject {subject}: accuracy = {metrics[0, 0]}")
    return model, metrics, history