def TrainNetwork(num_epochs, split_size, batch_size, load_size, num_boxes,
                 num_classes, train_img_files_path, train_target_files_path,
                 category_list, model, device, optimizer, load_model_file,
                 lambda_coord, lambda_noobj):
    """
    Starts the training process of the model.

    Parameters:
        num_epochs (int): Number of epochs for training the model.
        split_size (int): Size of the grid which is applied to the images.
        batch_size (int): Batch size.
        load_size (int): Number of batches which are loaded in one function call.
        num_boxes (int): Number of boxes which are predicted per grid cell.
        num_classes (int): Number of classes which are predicted.
        train_img_files_path (str): System path to the image folder containing the train images.
        train_target_files_path (str): System path to the target folder containing the json file with the ground-truth labels.
        category_list (list): A list containing all ground-truth classes.
        model (): The YOLOv1 model.
        device (): The device used for training.
        optimizer (): Algorithm for updating the model weights.
        load_model_file (str): Name of the file used to store/load train checkpoints.
        lambda_coord (float): Hyperparameter weighting the loss on the bounding box coordinates.
        lambda_noobj (float): Hyperparameter weighting the loss for cells that contain no object.
    """
    model.train()

    # Initialize the DataLoader for the train dataset
    data = DataLoader(train_img_files_path, train_target_files_path,
                      category_list, split_size, batch_size, load_size)

    loss_log = {}  # Used for tracking the loss
    torch.save(loss_log, "loss_log.pt")  # Initialize the log file

    for epoch in range(num_epochs):
        epoch_losses = []  # Stores the loss progress

        print("DATA IS BEING LOADED FOR A NEW EPOCH")
        print("")
        data.LoadFiles()  # Resets the DataLoader for a new epoch

        while len(data.img_files) > 0:
            all_batch_losses = 0.  # Used to track the training loss

            print("LOADING NEW BATCHES")
            print("Remaining files: " + str(len(data.img_files)))
            print("")
            data.LoadData()  # Loads new batches

            for batch_idx, (img_data, target_data) in enumerate(data.data):
                img_data = img_data.to(device)
                target_data = target_data.to(device)

                optimizer.zero_grad()

                predictions = model(img_data)

                yolo_loss = YOLO_Loss(predictions, target_data, split_size,
                                      num_boxes, num_classes, lambda_coord,
                                      lambda_noobj)
                yolo_loss.loss()
                loss = yolo_loss.final_loss
                all_batch_losses += loss.item()

                loss.backward()
                optimizer.step()

                print('Train Epoch: {} of {} [Batch: {}/{} ({:.0f}%)] Loss: {:.6f}'.format(
                    epoch + 1, num_epochs, batch_idx + 1, len(data.data),
                    (batch_idx + 1) / len(data.data) * 100., loss.item()))
                print('')

            epoch_losses.append(all_batch_losses / len(data.data))
            print("Loss progress so far:", epoch_losses)
            print("")

        # Update the loss log with the mean loss of this epoch
        loss_log = torch.load('loss_log.pt')
        mean_loss = sum(epoch_losses) / len(epoch_losses)
        loss_log['Epoch: ' + str(epoch + 1)] = mean_loss
        torch.save(loss_log, 'loss_log.pt')

        print(f"Mean loss for this epoch was {mean_loss}")
        print("")

        # Save a checkpoint after every epoch
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint, filename=load_model_file)
        time.sleep(10)
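
# ---------------------------------------------------------------------------
# Hedged usage sketch for TrainNetwork (not part of the original training code).
# Everything below is an assumption for illustration only: the `YOLOv1` class
# name, the Adam optimizer and its hyperparameters, the dataset paths, and the
# category list are placeholders and should be replaced by the project's
# actual setup before running.
# ---------------------------------------------------------------------------
def _example_training_run():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    category_list = ["boat", "buoy", "swimmer"]  # placeholder classes
    model = YOLOv1(split_size=7, num_boxes=2,
                   num_classes=len(category_list)).to(device)  # hypothetical model class
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=5e-4)
    TrainNetwork(
        num_epochs=100,
        split_size=7,
        batch_size=10,
        load_size=100,
        num_boxes=2,
        num_classes=len(category_list),
        train_img_files_path="data/train/images/",         # placeholder path
        train_target_files_path="data/train/labels.json",  # placeholder path
        category_list=category_list,
        model=model,
        device=device,
        optimizer=optimizer,
        load_model_file="YOLO_checkpoint.pt",               # placeholder filename
        lambda_coord=5.0,   # weighting factors from the original YOLO paper
        lambda_noobj=0.5,
    )
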
def validate(test_img_files_path, test_target_files_path, category_list,
             split_size, batch_size, load_size, model, cell_dim, num_boxes,
             num_classes, device, iou_threshold_nms, iou_threshold_map,
             threshold, use_nms):
    """
    Uses the test dataset to validate the performance of the model. Calculates
    the mean Average Precision (mAP) for object detection.

    Parameters:
        test_img_files_path (str): System path to the image directory containing the test dataset images.
        test_target_files_path (str): System path to the json file containing the ground-truth labels for the test dataset.
        category_list (list): A list containing all classes which should be detected.
        split_size (int): Size of the grid which is applied to the image.
        batch_size (int): Batch size.
        load_size (int): Number of batches which are loaded in one function call.
        model (): The YOLOv1 model.
        cell_dim (int): The dimension of a single cell.
        num_boxes (int): Number of bounding boxes which are predicted by the model.
        num_classes (int): Number of classes which are predicted.
        device (): Device which is used for training and testing the model.
        iou_threshold_nms (float): Threshold for the IoU between the predicted boxes and the ground-truth boxes for non maximum suppression.
        iou_threshold_map (float): Threshold for the IoU between the predicted boxes and the ground-truth boxes for mean average precision.
        threshold (float): Threshold for the confidence score of predicted bounding boxes.
        use_nms (bool): Specifies whether non max suppression should be applied to the bounding box predictions.
    """
    model.eval()

    print("DATA IS BEING LOADED FOR VALIDATION")
    print("")

    # Initialize the DataLoader for the test dataset
    data = DataLoader(test_img_files_path, test_target_files_path,
                      category_list, split_size, batch_size, load_size)
    data.LoadFiles()

    # All predicted and ground-truth bounding boxes for the whole test dataset
    # are stored in these two lists, which are finally used for evaluation.
    # Every element represents a single bounding box and has the following form:
    # [image index, class prediction, confidence score, x1, y1, x2, y2]
    all_pred_boxes = []
    all_target_boxes = []

    train_idx = 0  # Tracks the sample index for each image in the test dataset

    # This while loop fills the two lists all_pred_boxes and all_target_boxes.
    while len(data.img_files) > 0:
        print("LOADING NEW VALIDATION BATCHES")
        print("Remaining validation files: " + str(len(data.img_files)))
        print("")
        data.LoadData()

        for batch_idx, (img_data, target_data) in enumerate(data.data):
            img_data = img_data.to(device)
            target_data = target_data.to(device)

            with torch.no_grad():
                predictions = model(img_data)

            print('Extracting bounding boxes')
            print('Batch: {}/{} ({:.0f}%)'.format(
                batch_idx + 1, len(data.data),
                (batch_idx + 1) / len(data.data) * 100.))
            print('')

            pred_boxes = extract_boxes(predictions, num_classes, num_boxes,
                                       cell_dim, threshold)
            target_boxes = extract_boxes(target_data, num_classes, 1,
                                         cell_dim, threshold)

            for sample_idx in range(len(pred_boxes)):
                if use_nms:
                    # Applies non max suppression to the bounding box predictions
                    nms_boxes = non_max_suppression(pred_boxes[sample_idx],
                                                    iou_threshold_nms)
                else:
                    # Use the same list without changing anything
                    nms_boxes = pred_boxes[sample_idx]

                for nms_box in nms_boxes:
                    all_pred_boxes.append([train_idx] + nms_box)

                for box in target_boxes[sample_idx]:
                    all_target_boxes.append([train_idx] + box)

                train_idx += 1

    # Debug output: count how many boxes were predicted for each class index
    pred = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0,
            7: 0, 8: 0, 9: 0, 10: 0, 11: 0, 12: 0, 13: 0}
    for prediction in all_pred_boxes:
        cls_idx = prediction[1]
        pred[cls_idx] += 1
    print(pred)

    # Debug output: count how many ground-truth boxes exist for each class index
    pred = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0,
            7: 0, 8: 0, 9: 0, 10: 0, 11: 0, 12: 0, 13: 0}
    for prediction in all_target_boxes:
        cls_idx = prediction[1]
        pred[cls_idx] += 1
    print(pred)
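
# ---------------------------------------------------------------------------
# Hedged sketch of the final evaluation step (not part of the original code).
# One possible way to turn all_pred_boxes / all_target_boxes (entries of the
# form [image index, class, confidence, x1, y1, x2, y2]) into an mAP value,
# assuming the third-party torchmetrics package is installed; the repository
# presumably has its own mAP helper, which should be preferred.
# ---------------------------------------------------------------------------
def _example_map_from_box_lists(all_pred_boxes, all_target_boxes, iou_threshold_map):
    from collections import defaultdict
    from torchmetrics.detection.mean_ap import MeanAveragePrecision

    # Group boxes by image index so each image contributes one prediction dict
    # and one target dict.
    preds_per_img, targets_per_img = defaultdict(list), defaultdict(list)
    for box in all_pred_boxes:
        preds_per_img[box[0]].append(box)
    for box in all_target_boxes:
        targets_per_img[box[0]].append(box)

    preds, targets = [], []
    for img_idx in sorted(targets_per_img):
        p = preds_per_img.get(img_idx, [])
        preds.append({
            "boxes": torch.tensor([b[3:7] for b in p], dtype=torch.float).reshape(-1, 4),
            "scores": torch.tensor([b[2] for b in p], dtype=torch.float),
            "labels": torch.tensor([int(b[1]) for b in p], dtype=torch.int64),
        })
        t = targets_per_img[img_idx]
        targets.append({
            "boxes": torch.tensor([b[3:7] for b in t], dtype=torch.float).reshape(-1, 4),
            "labels": torch.tensor([int(b[1]) for b in t], dtype=torch.int64),
        })

    # Evaluate at the single IoU threshold used for mAP in validate().
    metric = MeanAveragePrecision(box_format="xyxy", iou_thresholds=[iou_threshold_map])
    metric.update(preds, targets)
    return metric.compute()["map"].item()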