# Module-level imports required by the functions below. The repo-internal names
# (CoralsDataset, evaluateNetwork, saveMetrics, computeLoss) are assumed to be
# importable from this package; QApplication is assumed to come from PyQt5.
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from PyQt5.QtWidgets import QApplication
from models.deeplab import DeepLab


def testNetwork(images_folder, labels_folder, dictionary, target_classes, dataset_train,
                network_filename, output_folder):
    """
    Load a network and test it on the test dataset.
    :param network_filename: Full name of the network to load (PATH+name)
    """

    # TEST DATASET
    datasetTest = CoralsDataset(images_folder, labels_folder, dictionary, target_classes)
    datasetTest.disableAugumentation()

    # reuse the statistics and class mapping computed on the training dataset
    datasetTest.num_classes = dataset_train.num_classes
    datasetTest.weights = dataset_train.weights
    datasetTest.dataset_average = dataset_train.dataset_average
    datasetTest.dict_target = dataset_train.dict_target

    output_classes = dataset_train.num_classes

    batchSize = 4
    dataloaderTest = DataLoader(datasetTest, batch_size=batchSize, shuffle=False, num_workers=0,
                                drop_last=True, pin_memory=True)

    # DEEPLAB V3+
    net = DeepLab(backbone='resnet', output_stride=16, num_classes=output_classes)
    net.load_state_dict(torch.load(network_filename))
    print("Weights loaded.")

    metrics_test, loss = evaluateNetwork(datasetTest, dataloaderTest, "NONE", None, [0.0], 0.0, 0.0, 0.0,
                                         0, 0, 0, output_classes, net, True, output_folder)
    metrics_filename = network_filename[:len(network_filename) - 4] + "-test-metrics.txt"
    saveMetrics(metrics_test, metrics_filename)

    print("***** TEST FINISHED *****")

    return metrics_test
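# A minimal usage sketch for testNetwork (paths and file names below are
# illustrative assumptions, not taken from the repository). The trained dataset
# object is passed in so the test set reuses the class mapping, weights and
# image average computed on the training data:
#
#   dataset_train = trainingNetwork(...)  # returns the training CoralsDataset
#   metrics = testNetwork("test/images", "test/labels", dictionary, target_classes,
#                         dataset_train, network_filename="mynetwork.net",
#                         output_folder="output")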
def _load_classifier(self, modelName):

    models_dir = "models/"
    network_name = os.path.join(models_dir, modelName)

    classifier_pocillopora = DeepLab(backbone='resnet', output_stride=16, num_classes=self.nclasses)
    classifier_pocillopora.load_state_dict(torch.load(network_name))
    classifier_pocillopora.eval()

    return classifier_pocillopora
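# A minimal inference sketch with the loaded classifier (the model file name
# and the input tensor are illustrative assumptions; DeepLab expects a float
# NCHW batch):
#
#   classifier = self._load_classifier("pocillopora.net")
#   with torch.no_grad():
#       scores = classifier(image_batch)        # [N, nclasses, H, W] logits
#       preds = torch.argmax(scores, dim=1)     # per-pixel predicted labels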
def trainingNetwork(images_folder_train, labels_folder_train, images_folder_val, labels_folder_val,
                    dictionary, target_classes, num_classes, save_network_as, classifier_name,
                    epochs, batch_sz, batch_mult, learning_rate, L2_penalty, validation_frequency,
                    flagShuffle, experiment_name, progress):

    ##### DATA #####

    # setup the training dataset
    datasetTrain = CoralsDataset(images_folder_train, labels_folder_train, dictionary, target_classes, num_classes)

    print("Dataset setup..", end='')
    datasetTrain.computeAverage()
    datasetTrain.computeWeights()
    target_classes = datasetTrain.dict_target
    print("done.")

    datasetTrain.enableAugumentation()

    datasetVal = CoralsDataset(images_folder_val, labels_folder_val, dictionary, target_classes, num_classes)
    datasetVal.dataset_average = datasetTrain.dataset_average
    datasetVal.weights = datasetTrain.weights

    # AUGMENTATION IS NOT APPLIED ON THE VALIDATION SET
    datasetVal.disableAugumentation()

    # setup the data loaders
    dataloaderTrain = DataLoader(datasetTrain, batch_size=batch_sz, shuffle=flagShuffle, num_workers=0,
                                 drop_last=True, pin_memory=True)

    validation_batch_size = 4
    dataloaderVal = DataLoader(datasetVal, batch_size=validation_batch_size, shuffle=False, num_workers=0,
                               drop_last=True, pin_memory=True)

    training_images_number = len(datasetTrain.images_names)
    validation_images_number = len(datasetVal.images_names)

    ###### SETUP THE NETWORK #####
    net = DeepLab(backbone='resnet', output_stride=16, num_classes=datasetTrain.num_classes)
    models_dir = "models/"
    network_name = os.path.join(models_dir, "deeplab-resnet.pth.tar")
    state = torch.load(network_name)
    # RE-INITIALIZE THE CLASSIFICATION LAYER WITH THE RIGHT NUMBER OF CLASSES,
    # DON'T LOAD THE WEIGHTS OF THE CLASSIFICATION LAYER
    new_dictionary = state['state_dict']
    del new_dictionary['decoder.last_conv.8.weight']
    del new_dictionary['decoder.last_conv.8.bias']
    net.load_state_dict(state['state_dict'], strict=False)

    print("NETWORK USED: DEEPLAB V3+")

    # LOSS
    weights = datasetTrain.weights
    class_weights = torch.FloatTensor(weights).cuda()
    lossfn = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

    # OPTIMIZER
    # optimizer = optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.0002, momentum=0.9)
    optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=L2_penalty)

    USE_CUDA = torch.cuda.is_available()
    if USE_CUDA:
        device = torch.device("cuda")
        net.to(device)

    ##### TRAINING LOOP #####

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, verbose=True)

    best_accuracy = 0.0
    best_jaccard_score = 0.0

    print("Training Network")
    for epoch in range(epochs):  # loop over the dataset multiple times

        txt = "Epoch " + str(epoch + 1) + "/" + str(epochs)
        progress.setMessage(txt)
        progress.setProgress((100.0 * epoch) / epochs)
        QApplication.processEvents()

        net.train()
        optimizer.zero_grad()
        running_loss = 0.0
        for i, minibatch in enumerate(dataloaderTrain):
            # get the inputs
            images_batch = minibatch['image']
            labels_batch = minibatch['labels']

            if USE_CUDA:
                images_batch = images_batch.to(device)
                labels_batch = labels_batch.to(device)

            # forward+loss+backward
            outputs = net(images_batch)
            loss = lossfn(outputs, labels_batch)
            loss.backward()

            # TO AVOID MEMORY TROUBLE, UPDATE THE WEIGHTS EVERY BATCH SIZE x BATCH MULT SAMPLES
            if (i + 1) % batch_mult == 0:
                optimizer.step()
                optimizer.zero_grad()

            print(epoch, i, loss.item())
            running_loss += loss.item()

        print("Epoch: %d , Running loss = %f" % (epoch, running_loss))

        ### VALIDATION ###
        if epoch > 0 and (epoch + 1) % validation_frequency == 0:

            print("RUNNING VALIDATION.. ", end='')

            # datasetVal.weights are the same as datasetTrain's
            metrics_val, mean_loss_val = evaluateNetwork(dataloaderVal, datasetVal.weights,
                                                         datasetVal.num_classes, net,
                                                         flagTrainingDataset=False)
            accuracy = metrics_val['Accuracy']
            jaccard_score = metrics_val['JaccardScore']

            scheduler.step(mean_loss_val)

            metrics_train, mean_loss_train = evaluateNetwork(dataloaderTrain, datasetTrain.weights,
                                                             datasetTrain.num_classes, net,
                                                             flagTrainingDataset=True)
            accuracy_training = metrics_train['Accuracy']
            jaccard_training = metrics_train['JaccardScore']

            if jaccard_score > best_jaccard_score:
                best_accuracy = accuracy
                best_jaccard_score = jaccard_score
                torch.save(net.state_dict(), save_network_as)

                # performance of the best network on the validation and training datasets
                metrics_filename = save_network_as[:len(save_network_as) - 4] + "-val-metrics.txt"
                saveMetrics(metrics_val, metrics_filename)
                metrics_filename = save_network_as[:len(save_network_as) - 4] + "-train-metrics.txt"
                saveMetrics(metrics_train, metrics_filename)

            print("-> CURRENT BEST ACCURACY ", best_accuracy)

    print("***** TRAINING FINISHED *****")

    return datasetTrain
def trainingNetwork(images_folder_train, labels_folder_train, images_folder_val, labels_folder_val,
                    dictionary, target_classes, output_classes, save_network_as, classifier_name,
                    epochs, batch_sz, batch_mult, learning_rate, L2_penalty, validation_frequency,
                    loss_to_use, epochs_switch, epochs_transition, tversky_alpha, tversky_gamma,
                    optimiz, flag_shuffle, flag_training_accuracy, progress):

    ##### DATA #####

    # setup the training dataset
    datasetTrain = CoralsDataset(images_folder_train, labels_folder_train, dictionary, target_classes)

    print("Dataset setup..", end='')
    datasetTrain.computeAverage()
    datasetTrain.computeWeights()
    print(datasetTrain.dict_target)
    print(datasetTrain.weights)
    freq = 1.0 / datasetTrain.weights
    print(freq)
    print("done.")

    save_classifier_as = save_network_as.replace(".net", ".json")

    datasetTrain.enableAugumentation()

    datasetVal = CoralsDataset(images_folder_val, labels_folder_val, dictionary, target_classes)
    datasetVal.dataset_average = datasetTrain.dataset_average
    datasetVal.weights = datasetTrain.weights

    # AUGMENTATION IS NOT APPLIED ON THE VALIDATION SET
    datasetVal.disableAugumentation()

    # setup the data loaders
    dataloaderTrain = DataLoader(datasetTrain, batch_size=batch_sz, shuffle=flag_shuffle, num_workers=0,
                                 drop_last=True, pin_memory=True)

    validation_batch_size = 4
    dataloaderVal = DataLoader(datasetVal, batch_size=validation_batch_size, shuffle=False, num_workers=0,
                               drop_last=True, pin_memory=True)

    training_images_number = len(datasetTrain.images_names)
    validation_images_number = len(datasetVal.images_names)

    print("NETWORK USED: DEEPLAB V3+")

    if os.path.exists(save_network_as):
        # resume from an existing checkpoint
        net = DeepLab(backbone='resnet', output_stride=16, num_classes=output_classes)
        net.load_state_dict(torch.load(save_network_as))
        print("Checkpoint loaded.")
    else:
        ###### SETUP THE NETWORK #####
        net = DeepLab(backbone='resnet', output_stride=16, num_classes=output_classes)
        state = torch.load("models/deeplab-resnet.pth.tar")
        # RE-INITIALIZE THE CLASSIFICATION LAYER WITH THE RIGHT NUMBER OF CLASSES,
        # DON'T LOAD THE WEIGHTS OF THE CLASSIFICATION LAYER
        new_dictionary = state['state_dict']
        del new_dictionary['decoder.last_conv.8.weight']
        del new_dictionary['decoder.last_conv.8.bias']
        net.load_state_dict(state['state_dict'], strict=False)

    # OPTIMIZER
    if optimiz == "SGD":
        optimizer = optim.SGD(net.parameters(), lr=learning_rate, weight_decay=L2_penalty, momentum=0.9)
    elif optimiz == "ADAM":
        optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=L2_penalty)

    USE_CUDA = torch.cuda.is_available()
    if USE_CUDA:
        device = torch.device("cuda")
        net.to(device)

    ##### TRAINING LOOP #####

    reduce_lr_patience = 2
    if loss_to_use == "DICE+BOUNDARY":
        reduce_lr_patience = 200
        print("patience increased!")

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=reduce_lr_patience, verbose=True)

    best_accuracy = 0.0
    best_jaccard_score = 0.0

    # Cross-entropy loss
    weights = datasetTrain.weights
    class_weights = torch.FloatTensor(weights).cuda()
    CEloss = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

    # weights for the GENERALIZED DICE LOSS (GDL)
    freq = 1.0 / datasetTrain.weights[1:]
    w = 1.0 / (freq * freq)
    w = w / w.sum() + 0.00001
    w_for_GDL = torch.from_numpy(w)
    w_for_GDL = w_for_GDL.to(device)

    # Focal Tversky loss
    focal_tversky_gamma = torch.tensor(tversky_gamma)
    focal_tversky_gamma = focal_tversky_gamma.to(device)

    tversky_loss_alpha = torch.tensor(tversky_alpha)
    tversky_loss_beta = torch.tensor(1.0 - tversky_alpha)
    tversky_loss_alpha = tversky_loss_alpha.to(device)
    tversky_loss_beta = tversky_loss_beta.to(device)

    print("Training Network")
    num_iter = 0
    total_iter = epochs * int(len(datasetTrain) / dataloaderTrain.batch_size)
    for epoch in range(epochs):

        net.train()
        optimizer.zero_grad()

        loss_values = []
        for i, minibatch in enumerate(dataloaderTrain):

            txt = "Training - Iterations " + str(num_iter + 1) + "/" + str(total_iter)
            progress.setMessage(txt)
            progress.setProgress((100.0 * num_iter) / total_iter)
            QApplication.processEvents()
            num_iter += 1

            # get the inputs
            images_batch = minibatch['image']
            labels_batch = minibatch['labels']

            if USE_CUDA:
                images_batch = images_batch.to(device)
                labels_batch = labels_batch.to(device)

            # forward+loss+backward
            outputs = net(images_batch)
            loss = computeLoss(loss_to_use, CEloss, w_for_GDL,
                               tversky_loss_alpha, tversky_loss_beta, focal_tversky_gamma,
                               epoch, epochs_switch, epochs_transition,
                               labels_batch, outputs)
            loss.backward()

            # TO AVOID MEMORY TROUBLE, UPDATE THE WEIGHTS EVERY BATCH SIZE x BATCH MULT SAMPLES
            if (i + 1) % batch_mult == 0:
                optimizer.step()
                optimizer.zero_grad()

            print(epoch, i, loss.item())
            loss_values.append(loss.item())

        mean_loss_train = sum(loss_values) / len(loss_values)
        print("Epoch: %d , Mean loss = %f" % (epoch, mean_loss_train))

        ### VALIDATION ###
        if epoch > 0 and (epoch + 1) % validation_frequency == 0:

            print("RUNNING VALIDATION.. ", end='')

            metrics_val, mean_loss_val = evaluateNetwork(datasetVal, dataloaderVal, loss_to_use, CEloss, w_for_GDL,
                                                         tversky_loss_alpha, tversky_loss_beta, focal_tversky_gamma,
                                                         epoch, epochs_switch, epochs_transition,
                                                         output_classes, net, flag_compute_mIoU=False)
            accuracy = metrics_val['Accuracy']
            jaccard_score = metrics_val['JaccardScore']

            scheduler.step(mean_loss_val)

            accuracy_training = 0.0
            jaccard_training = 0.0
            if flag_training_accuracy is True:
                metrics_train, mean_loss_train = evaluateNetwork(datasetTrain, dataloaderTrain, loss_to_use, CEloss,
                                                                 w_for_GDL, tversky_loss_alpha, tversky_loss_beta,
                                                                 focal_tversky_gamma, epoch, epochs_switch,
                                                                 epochs_transition, output_classes, net,
                                                                 flag_compute_mIoU=False)
                accuracy_training = metrics_train['Accuracy']
                jaccard_training = metrics_train['JaccardScore']

            # if jaccard_score > best_jaccard_score:
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                best_jaccard_score = jaccard_score
                torch.save(net.state_dict(), save_network_as)

                # performance of the best accuracy network on the validation dataset
                metrics_filename = save_network_as[:len(save_network_as) - 4] + "-val-metrics.txt"
                saveMetrics(metrics_val, metrics_filename)

            print("-> CURRENT BEST ACCURACY ", best_accuracy)

    # main loop ended
    torch.cuda.empty_cache()
    del net
    net = None

    print("***** TRAINING FINISHED *****")
    print("BEST ACCURACY REACHED ON THE VALIDATION SET: %.3f " % best_accuracy)

    return datasetTrain
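# For reference, a self-contained focal Tversky loss sketch in its standard
# formulation: TI = TP / (TP + alpha*FP + beta*FN) per class, loss = (1 - TI)^gamma.
# This is a generic illustration with a hypothetical helper name; the repository's
# actual computeLoss, which presumably also schedules between losses via
# epochs_switch/epochs_transition, may differ.
def _focal_tversky_loss_sketch(logits, labels, alpha, beta, gamma, eps=1e-6):
    import torch
    import torch.nn.functional as F

    # logits: [N, C, H, W] raw scores; labels: [N, H, W] class indices.
    # Pixels with ignore index -1 are clamped to class 0 here; a real
    # implementation should mask them out instead.
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)                          # [N, C, H, W]
    onehot = F.one_hot(labels.clamp(min=0), num_classes)      # [N, H, W, C]
    onehot = onehot.permute(0, 3, 1, 2).float()               # [N, C, H, W]

    tp = (probs * onehot).sum(dim=(0, 2, 3))                  # per-class true positives
    fp = (probs * (1.0 - onehot)).sum(dim=(0, 2, 3))          # per-class false positives
    fn = ((1.0 - probs) * onehot).sum(dim=(0, 2, 3))          # per-class false negatives

    tversky_index = tp / (tp + alpha * fp + beta * fn + eps)  # one value per class
    return torch.pow(1.0 - tversky_index, gamma).mean()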
import time

import torch

from models.deeplab import DeepLab

if __name__ == '__main__':
    checkpoint = 'BEST_checkpoint.tar'
    print('loading {}...'.format(checkpoint))
    start = time.time()
    checkpoint = torch.load(checkpoint)
    print('elapsed {} sec'.format(time.time() - start))
    model = checkpoint['model'].module
    print(type(model))

    filename = 'deep_mobile_matting.pt'
    print('saving {}...'.format(filename))
    start = time.time()
    torch.save(model.state_dict(), filename)
    print('elapsed {} sec'.format(time.time() - start))

    print('loading {}...'.format(filename))
    start = time.time()
    model = DeepLab(backbone='mobilenet', output_stride=16, num_classes=1)
    model.load_state_dict(torch.load(filename))
    print('elapsed {} sec'.format(time.time() - start))

    # scripted_model_file = 'deep_mobile_matting_scripted.pt'
    # torch.jit.save(torch.jit.script(model), scripted_model_file)
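    # If the scripting step above is enabled, the scripted model can later be
    # reloaded without the Python class definition (a sketch; it assumes the
    # DeepLab forward pass is TorchScript-compatible):
    #
    #   scripted = torch.jit.load(scripted_model_file)
    #   scripted.eval()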