示例#1
0
def train_model(clf: CnnClassifier, results_file: TextIO) -> TrainedModel:
    """
    Train a CNN classifier and track its validation accuracy.

    Runs a fixed number of epochs over ``train_batches`` (taken from the
    enclosing scope), evaluates on ``val_batches`` after every epoch, prints
    progress and writes one CSV row per epoch to ``results_file``.

    Returns the classifier bundled with the accuracy object of the last epoch.
    """

    # CSV header; one data row follows per epoch
    results_file.write('epoch,train_loss,train_loss_sd,val_accuracy\n')

    n_epochs = 100
    for epoch_idx in range(n_epochs):
        print(f"epoch {epoch_idx+1}")

        # one training step per batch, keeping every per-batch loss
        batch_losses = np.array(
            [clf.train(batch.data, batch.label) for batch in train_batches]
        )
        loss_mean = np.mean(batch_losses)
        loss_sd = np.std(batch_losses)
        print(f" train loss: {loss_mean:5.3f} +- {loss_sd:5.3f}")

        # a fresh accuracy tracker is used for each epoch's validation pass
        accuracy = Accuracy()
        for batch in val_batches:
            accuracy.update(clf.predict(batch.data), batch.label)
        print(f" val acc: {accuracy}")

        # the CSV row uses the 0-based epoch index (the console shows 1-based)
        results_file.write(f"{epoch_idx},{loss_mean},")
        results_file.write(f"{loss_sd},{accuracy.accuracy()}\n")

    return TrainedModel(clf, accuracy)
示例#2
0
# Training loop with optional early stopping on the validation accuracy.
# Relies on names defined outside this snippet: EPOCHS, EARLY_STOPPING,
# EARLY_STOPPING_NUM_OF_EPOCHS, clf, net, accuracy, best_accuracy, loss_list
# and the two batch generators.
epochs_since_best_accuracy = 0
for epoch in range(0, EPOCHS):
    print("Epoche: ", epoch + 1)

    # One training step per batch; every per-batch loss is collected.
    for batch in batchGenerator_training:
        loss = clf.train(batch.data, batch.label)
        loss_list.append(loss)

    # NOTE(review): loss_list is not reset inside this loop, so these
    # statistics cover every batch appended so far, not just the current
    # epoch -- confirm this is intended.
    loss = np.array(loss_list)
    loss_mean = np.mean(loss)
    loss_deviation = np.std(loss)
    print("Train loss: ", loss_mean, "-+", loss_deviation)

    # Validation pass: start from a clean accuracy tracker every epoch.
    accuracy.reset()
    for batch in batchGenerator_validation:
        predictions = clf.predict(batch.data)
        accuracy.update(predictions.cpu().detach().numpy(), batch.label)

    print("Val " + str(accuracy))

    if EARLY_STOPPING:
        epochs_since_best_accuracy += 1
        if best_accuracy < accuracy.accuracy():
            # New best model: remember the score, checkpoint the weights and
            # restart the patience counter.
            best_accuracy = accuracy.accuracy()
            torch.save(net.state_dict(), "best_model.pth")
            epochs_since_best_accuracy = 0
        if epochs_since_best_accuracy > EARLY_STOPPING_NUM_OF_EPOCHS:
            print(
                str(EARLY_STOPPING_NUM_OF_EPOCHS) +
                " epochs passed without improvement in validation accuracy. Stopping the training now."
                + "best validation accuracy: " + str(best_accuracy))
            # Actually leave the epoch loop; without this the message was
            # printed but training carried on until EPOCHS was reached.
            break
示例#3
0
def train(lr, wd, operation):
    """Train a CNN with one hyper-parameter setting and record its best epoch.

    Builds a ``Net`` / ``CnnClassifier`` pair, trains for at most
    ``NR_EPOCHS`` epochs with early stopping on the validation accuracy,
    prints a summary and appends the same summary to ``RESULTS_FILE``.

    Args:
        lr: Learning rate, passed through to ``CnnClassifier``.
        wd: Weight decay, passed through to ``CnnClassifier``.
        operation: Key (or index) into ``operations``; the selected op is
            handed to both batch generators.

    Returns:
        None. Results are printed and appended to ``RESULTS_FILE``.
    """
    # The same settings lines go to the console now and to the file later.
    settings = [
        f"Weight Decay = {wd}",
        f"Augmentation = {operation}",
        f"Learning Rate = {lr}",
    ]
    print("Training a network with:")
    for line in settings:
        print(line)

    device = torch.device("cuda" if CUDA else "cpu")

    img_shape = train_data.image_shape()
    num_classes = train_data.num_classes()

    net = Net(img_shape, num_classes).to(device)
    clf = CnnClassifier(net, (0, *img_shape), num_classes, lr, wd)

    op = operations[operation]
    train_batches = BatchGenerator(train_data, 128, False, op)
    # NOTE(review): the validation generator gets the same op as training --
    # confirm that this is intended.
    val_batches = BatchGenerator(val_data, 128, False, op)

    not_improved_since = 0
    best_accuracy = 0
    # Mean training loss of the epoch that reached the best val accuracy.
    best_loss = 0
    stop_epoch = 0

    for epoch in range(NR_EPOCHS):
        print(f"Epoch {epoch}/{NR_EPOCHS}", end="\r")

        losses = [clf.train(batch.data, batch.label) for batch in train_batches]
        mean = round(np.mean(np.array(losses)), 3)

        accuracy = Accuracy()
        for batch in val_batches:
            accuracy.update(clf.predict(batch.data), batch.label)
        acc = round(accuracy.accuracy(), 3)

        # Early stopping: only a strictly better (rounded) accuracy counts
        # as an improvement.
        if acc > best_accuracy:
            stop_epoch = epoch
            not_improved_since = 0
            best_accuracy = acc
            best_loss = mean
        else:
            not_improved_since += 1
        # Stop once more than EARLY_STOPPING consecutive epochs brought no
        # improvement.
        if not_improved_since > EARLY_STOPPING:
            break

    summary = [
        f"Best val accuracy after epoch {stop_epoch + 1}",
        f"Validation Accuracy: {best_accuracy}",
        f"Train Loss: {best_loss}",
    ]
    # The empty print moves past the "\r" progress line.
    print()
    for line in summary:
        print(line)

    with open(RESULTS_FILE, "a") as file:
        file.write("Trained a network with:\n")
        for line in settings:
            file.write(line + "\n")
        file.write("---\n")
        for line in summary:
            file.write(line + "\n")
        file.write("\n#################################\n")