Example No. 1
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from PyQt5.QtWidgets import QApplication  # assuming a PyQt5 GUI, as suggested by processEvents()

# CoralsDataset, DeepLab, evaluateNetwork and saveMetrics are project-level
# helpers assumed to be importable from the surrounding codebase

def trainingNetwork(images_folder_train, labels_folder_train,
                    images_folder_val, labels_folder_val, dictionary,
                    target_classes, num_classes, save_network_as,
                    classifier_name, epochs, batch_sz, batch_mult,
                    learning_rate, L2_penalty, validation_frequency,
                    flagShuffle, experiment_name, progress):

    ##### DATA #####

    # setup the training dataset
    datasetTrain = CoralsDataset(images_folder_train, labels_folder_train,
                                 dictionary, target_classes, num_classes)

    print("Dataset setup..", end='')
    datasetTrain.computeAverage()
    datasetTrain.computeWeights()
    target_classes = datasetTrain.dict_target
    print("done.")

    datasetTrain.enableAugumentation()

    datasetVal = CoralsDataset(images_folder_val, labels_folder_val,
                               dictionary, target_classes, num_classes)
    datasetVal.dataset_average = datasetTrain.dataset_average
    datasetVal.weights = datasetTrain.weights

    # AUGMENTATION IS NOT APPLIED TO THE VALIDATION SET
    datasetVal.disableAugumentation()

    # setup the data loader
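    # (drop_last=True avoids a ragged final mini-batch; pin_memory=True speeds
    # host-to-GPU copies)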
    dataloaderTrain = DataLoader(datasetTrain,
                                 batch_size=batch_sz,
                                 shuffle=flagShuffle,
                                 num_workers=0,
                                 drop_last=True,
                                 pin_memory=True)

    validation_batch_size = 4
    dataloaderVal = DataLoader(datasetVal,
                               batch_size=validation_batch_size,
                               shuffle=False,
                               num_workers=0,
                               drop_last=True,
                               pin_memory=True)

    training_images_number = len(datasetTrain.images_names)
    validation_images_number = len(datasetVal.images_names)

    ###### SETUP THE NETWORK #####
    net = DeepLab(backbone='resnet',
                  output_stride=16,
                  num_classes=datasetTrain.num_classes)
    models_dir = "models/"
    network_name = os.path.join(models_dir, "deeplab-resnet.pth.tar")
    state = torch.load(network_name)
    # RE-INITIALIZE THE CLASSIFICATION LAYER WITH THE RIGHT NUMBER OF CLASSES; DON'T LOAD THE PRE-TRAINED CLASSIFIER WEIGHTS
    new_dictionary = state['state_dict']
    del new_dictionary['decoder.last_conv.8.weight']
    del new_dictionary['decoder.last_conv.8.bias']
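    # strict=False lets every remaining pre-trained weight load while the newly
    # built classification head keeps its fresh initialization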
    net.load_state_dict(state['state_dict'], strict=False)
    print("NETWORK USED: DEEPLAB V3+")

    # LOSS

    weights = datasetTrain.weights
    class_weights = torch.FloatTensor(weights)
    if torch.cuda.is_available():
        class_weights = class_weights.cuda()  # move the loss weights to the GPU when present
    lossfn = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

    # OPTIMIZER
    # optimizer = optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.0002, momentum=0.9)
    optimizer = optim.Adam(net.parameters(),
                           lr=learning_rate,
                           weight_decay=L2_penalty)

    USE_CUDA = torch.cuda.is_available()

    if USE_CUDA:
        device = torch.device("cuda")
        net.to(device)

    ##### TRAINING LOOP #####

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     patience=2,
                                                     verbose=True)

    best_accuracy = 0.0
    best_jaccard_score = 0.0

    print("Training Network")
    for epoch in range(epochs):  # loop over the dataset multiple times

        txt = "Epoch " + str(epoch + 1) + "/" + str(epochs)
        progress.setMessage(txt)
        progress.setProgress((100.0 * epoch) / epochs)
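        # processEvents() keeps the Qt progress dialog responsive during training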
        QApplication.processEvents()

        net.train()
        optimizer.zero_grad()
        running_loss = 0.0
        for i, minibatch in enumerate(dataloaderTrain):
            # get the inputs
            images_batch = minibatch['image']
            labels_batch = minibatch['labels']

            if USE_CUDA:
                images_batch = images_batch.to(device)
                labels_batch = labels_batch.to(device)

            # forward+loss+backward
            outputs = net(images_batch)
            loss = lossfn(outputs, labels_batch)
            loss.backward()

            # TO AVOID MEMORY TROUBLE, UPDATE WEIGHTS EVERY batch_sz x batch_mult SAMPLES (GRADIENT ACCUMULATION)
            if (i + 1) % batch_mult == 0:
                optimizer.step()
                optimizer.zero_grad()

            print(epoch, i, loss.item())
            running_loss += loss.item()

        print("Epoch: %d , Running loss = %f" % (epoch, running_loss))

        ### VALIDATION ###
        if epoch > 0 and (epoch + 1) % validation_frequency == 0:

            print("RUNNING VALIDATION.. ", end='')

            # datasetVal.weights are the same as datasetTrain's
            metrics_val, mean_loss_val = evaluateNetwork(
                dataloaderVal,
                datasetVal.weights,
                datasetVal.num_classes,
                net,
                flagTrainingDataset=False)
            accuracy = metrics_val['Accuracy']
            jaccard_score = metrics_val['JaccardScore']

            scheduler.step(mean_loss_val)

            metrics_train, mean_loss_train = evaluateNetwork(
                dataloaderTrain,
                datasetTrain.weights,
                datasetTrain.num_classes,
                net,
                flagTrainingDataset=True)
            accuracy_training = metrics_train['Accuracy']
            jaccard_training = metrics_train['JaccardScore']

            if jaccard_score > best_jaccard_score:

                best_accuracy = accuracy
                best_jaccard_score = jaccard_score
                torch.save(net.state_dict(), save_network_as)
                # performance of the best accuracy network on the validation dataset
                metrics_filename = save_network_as[:-4] + "-val-metrics.txt"
                saveMetrics(metrics_val, metrics_filename)
                metrics_filename = save_network_as[:-4] + "-train-metrics.txt"
                saveMetrics(metrics_train, metrics_filename)

            print("-> CURRENT BEST ACCURACY ", best_accuracy)

    print("***** TRAINING FINISHED *****")

    return datasetTrain
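
A minimal sketch of the gradient-accumulation pattern used in the loop above: gradients from batch_mult consecutive mini-batches are summed before each optimizer step, emulating an effective batch size of batch_sz * batch_mult. All names below are placeholder stand-ins, not the example's actual objects:

import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(16, 4)                     # stand-in for the DeepLab model
lossfn = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=1e-4)
batch_mult = 4

optimizer.zero_grad()
for i in range(16):                        # stand-in for iterating dataloaderTrain
    x = torch.randn(2, 16)                 # dummy mini-batch of 2 samples
    y = torch.randint(0, 4, (2,))          # dummy labels
    loss = lossfn(net(x), y)
    loss.backward()                        # gradients accumulate across iterations
    if (i + 1) % batch_mult == 0:
        optimizer.step()                   # one update per batch_mult mini-batches
        optimizer.zero_grad()
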
Example No. 2
# (same imports as Example No. 1, plus the project-level computeLoss helper)

def trainingNetwork(images_folder_train, labels_folder_train, images_folder_val, labels_folder_val,
                    dictionary, target_classes, output_classes, save_network_as, classifier_name,
                    epochs, batch_sz, batch_mult, learning_rate, L2_penalty, validation_frequency, loss_to_use,
                    epochs_switch, epochs_transition, tversky_alpha, tversky_gamma, optimiz,
                    flag_shuffle, flag_training_accuracy, progress):

    ##### DATA #####

    # setup the training dataset
    datasetTrain = CoralsDataset(images_folder_train, labels_folder_train, dictionary, target_classes)

    print("Dataset setup..", end='')
    datasetTrain.computeAverage()
    datasetTrain.computeWeights()
    print(datasetTrain.dict_target)
    print(datasetTrain.weights)
    freq = 1.0 / datasetTrain.weights
    print(freq)
    print("done.")

    save_classifier_as = save_network_as.replace(".net", ".json")

    datasetTrain.enableAugumentation()

    datasetVal = CoralsDataset(images_folder_val, labels_folder_val, dictionary, target_classes)
    datasetVal.dataset_average = datasetTrain.dataset_average
    datasetVal.weights = datasetTrain.weights

    # AUGMENTATION IS NOT APPLIED TO THE VALIDATION SET
    datasetVal.disableAugumentation()

    # setup the data loader
    dataloaderTrain = DataLoader(datasetTrain, batch_size=batch_sz, shuffle=flag_shuffle, num_workers=0, drop_last=True,
                                 pin_memory=True)

    validation_batch_size = 4
    dataloaderVal = DataLoader(datasetVal, batch_size=validation_batch_size, shuffle=False, num_workers=0,
                               drop_last=True, pin_memory=True)

    training_images_number = len(datasetTrain.images_names)
    validation_images_number = len(datasetVal.images_names)

    print("NETWORK USED: DEEPLAB V3+")

    if os.path.exists(save_network_as):
        net = DeepLab(backbone='resnet', output_stride=16, num_classes=output_classes)
        net.load_state_dict(torch.load(save_network_as))
        print("Checkpoint loaded.")
    else:
        ###### SETUP THE NETWORK #####
        net = DeepLab(backbone='resnet', output_stride=16, num_classes=output_classes)
        state = torch.load("models/deeplab-resnet.pth.tar")
        # RE-INITIALIZE THE CLASSIFICATION LAYER WITH THE RIGHT NUMBER OF CLASSES; DON'T LOAD THE PRE-TRAINED CLASSIFIER WEIGHTS
        new_dictionary = state['state_dict']
        del new_dictionary['decoder.last_conv.8.weight']
        del new_dictionary['decoder.last_conv.8.bias']
        net.load_state_dict(state['state_dict'], strict=False)

    # OPTIMIZER
    if optimiz == "SGD":
        optimizer = optim.SGD(net.parameters(), lr=learning_rate, weight_decay=L2_penalty, momentum=0.9)
    elif optimiz == "ADAM":
        optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=L2_penalty)
    else:
        # fail early instead of hitting a NameError on `optimizer` below
        raise ValueError("Unknown optimizer: " + str(optimiz))

    USE_CUDA = torch.cuda.is_available()
    # the device must be defined even without CUDA: it is used unconditionally
    # below for the loss weights and the Tversky parameters
    device = torch.device("cuda" if USE_CUDA else "cpu")
    net.to(device)

    ##### TRAINING LOOP #####

    reduce_lr_patience = 2
    if loss_to_use == "DICE+BOUNDARY":
        reduce_lr_patience = 200
        print("patience increased !")

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=reduce_lr_patience, verbose=True)

    best_accuracy = 0.0
    best_jaccard_score = 0.0

    # Cross-entropy loss
    weights = datasetTrain.weights
    class_weights = torch.FloatTensor(weights).to(device)
    CEloss = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

    # weights for GENERALIZED DICE LOSS (GDL)
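    # datasetTrain.weights act as inverse class frequencies (freq above recovers
    # them); entry 0 (presumably the background) is skipped, each remaining class
    # is weighted by 1/frequency^2, and the result is normalized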
    freq = 1.0 / datasetTrain.weights[1:]
    w = 1.0 / (freq * freq)
    w = w / w.sum() + 0.00001
    w_for_GDL = torch.from_numpy(w)
    w_for_GDL = w_for_GDL.to(device)

    # Focal Tversky loss
    focal_tversky_gamma = torch.tensor(tversky_gamma)
    focal_tversky_gamma = focal_tversky_gamma.to(device)

    tversky_loss_alpha = torch.tensor(tversky_alpha)
    tversky_loss_beta = torch.tensor(1.0 - tversky_alpha)
    tversky_loss_alpha = tversky_loss_alpha.to(device)
    tversky_loss_beta = tversky_loss_beta.to(device)
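    # with beta = 1 - alpha, the alpha parameter trades the penalty on false
    # positives against false negatives in the Tversky index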

    print("Training Network")
    num_iter = 0
    total_iter = epochs * len(dataloaderTrain)
    for epoch in range(epochs):

        net.train()
        optimizer.zero_grad()

        loss_values = []
        for i, minibatch in enumerate(dataloaderTrain):

            txt = "Training - Iterations " + str(num_iter + 1) + "/" + str(total_iter)
            progress.setMessage(txt)
            progress.setProgress((100.0 * num_iter) / total_iter)
            QApplication.processEvents()
            num_iter += 1

            # get the inputs
            images_batch = minibatch['image']
            labels_batch = minibatch['labels']

            if USE_CUDA:
                images_batch = images_batch.to(device)
                labels_batch = labels_batch.to(device)

            # forward+loss+backward
            outputs = net(images_batch)

            loss = computeLoss(loss_to_use, CEloss, w_for_GDL, tversky_loss_alpha, tversky_loss_beta, focal_tversky_gamma,
                               epoch, epochs_switch, epochs_transition, labels_batch, outputs)

            loss.backward()

            # TO AVOID MEMORY TROUBLE, UPDATE WEIGHTS EVERY batch_sz x batch_mult SAMPLES (GRADIENT ACCUMULATION)
            if (i + 1) % batch_mult == 0:
                optimizer.step()
                optimizer.zero_grad()

            print(epoch, i, loss.item())
            loss_values.append(loss.item())

        mean_loss_train = sum(loss_values) / len(loss_values)
        print("Epoch: %d , Mean loss = %f" % (epoch, mean_loss_train))

        ### VALIDATION ###
        if epoch > 0 and (epoch+1) % validation_frequency == 0:

            print("RUNNING VALIDATION.. ", end='')

            metrics_val, mean_loss_val = evaluateNetwork(datasetVal, dataloaderVal, loss_to_use, CEloss, w_for_GDL,
                                                         tversky_loss_alpha, tversky_loss_beta, focal_tversky_gamma,
                                                         epoch, epochs_switch, epochs_transition,
                                                         output_classes, net, flag_compute_mIoU=False)
            accuracy = metrics_val['Accuracy']
            jaccard_score = metrics_val['JaccardScore']

            scheduler.step(mean_loss_val)

            accuracy_training = 0.0
            jaccard_training = 0.0

            if flag_training_accuracy:
                metrics_train, mean_loss_train = evaluateNetwork(datasetTrain, dataloaderTrain, loss_to_use, CEloss, w_for_GDL,
                                                                 tversky_loss_alpha, tversky_loss_beta, focal_tversky_gamma,
                                                                 epoch, epochs_switch, epochs_transition,
                                                                 output_classes, net, flag_compute_mIoU=False)
                accuracy_training = metrics_train['Accuracy']
                jaccard_training = metrics_train['JaccardScore']

            #if jaccard_score > best_jaccard_score:
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                best_jaccard_score = jaccard_score
                torch.save(net.state_dict(), save_network_as)
                # performance of the best accuracy network on the validation dataset
                metrics_filename = save_network_as[:-4] + "-val-metrics.txt"
                saveMetrics(metrics_val, metrics_filename)

            print("-> CURRENT BEST ACCURACY ", best_accuracy)

    # main loop ended
    torch.cuda.empty_cache()
    del net
    net = None

    print("***** TRAINING FINISHED *****")
    print("BEST ACCURACY REACHED ON THE VALIDATION SET: %.3f " % best_accuracy)

    return datasetTrain
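
Both variants step ReduceLROnPlateau with the mean validation loss, so the learning rate only drops after the validation loss stops improving for patience epochs. A self-contained sketch of that interaction, with a placeholder model and fabricated loss values:

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(8, 2)                              # placeholder network
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# default factor=0.1: the learning rate is multiplied by 0.1 once the monitored
# value fails to improve for more than patience consecutive epochs
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)

val_losses = [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7]     # plateaus after epoch 2
for epoch, val_loss in enumerate(val_losses):
    # ... a real training pass and validation would run here ...
    scheduler.step(val_loss)
    print(epoch, optimizer.param_groups[0]['lr'])    # LR drops once the plateau persists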