Example #1

import time

import torch

# Project-local helpers assumed by this snippet (not shown in this excerpt):
# DataLoader, YOLO_Loss, save_checkpoint.
def TrainNetwork(num_epochs, split_size, batch_size, load_size, num_boxes,
                 num_classes, train_img_files_path, train_target_files_path,
                 category_list, model, device, optimizer, load_model_file,
                 lambda_coord, lambda_noobj):
    """
    Starts the training process of the model.
    
    Parameters:
        num_epochs (int): Number of epochs for training the model.
        split_size (int): Size of the grid which is applied to the images.
        batch_size (int): Batch size.
        load_size (int): Number of batches which are loaded in one function call.
        num_boxes (int): Number of boxes which are predicted per grid cell.
        num_classes (int): Number of classes which are predicted.
        train_img_files_path (str): System path to the image folder containing
            the train images.
        train_target_files_path (str): System path to the target folder containing
            the json file with the ground-truth labels.
        category_list (list): A list containing all ground-truth classes.
        model (torch.nn.Module): The YOLOv1 model.
        device (torch.device): The device used for training.
        optimizer (torch.optim.Optimizer): Algorithm for updating the model weights.
        load_model_file (str): Name of the file used to store/load train checkpoints.
        lambda_coord (float): Hyperparameter weighting the loss on the bounding
            box coordinates.
        lambda_noobj (float): Hyperparameter weighting the loss for cells that
            contain no object.
    """

    model.train()

    # Initialize the DataLoader for the train dataset
    data = DataLoader(train_img_files_path, train_target_files_path,
                      category_list, split_size, batch_size, load_size)

    loss_log = {}  # Used for tracking the loss
    torch.save(loss_log, "loss_log.pt")  # Initialize the log file

    for epoch in range(num_epochs):
        epoch_losses = []  # Stores the loss progress

        print("DATA IS BEING LOADED FOR A NEW EPOCH")
        print("")
        data.LoadFiles()  # Resets the DataLoader for a new epoch

        while len(data.img_files) > 0:
            all_batch_losses = 0.  # Accumulates the loss over the currently loaded batches

            print("LOADING NEW BATCHES")
            print("Remaining files:" + str(len(data.img_files)))
            print("")
            data.LoadData()  # Loads new batches

            for batch_idx, (img_data, target_data) in enumerate(data.data):
                img_data = img_data.to(device)
                target_data = target_data.to(device)

                optimizer.zero_grad()

                predictions = model(img_data)

                yolo_loss = YOLO_Loss(predictions, target_data, split_size,
                                      num_boxes, num_classes, lambda_coord,
                                      lambda_noobj)
                yolo_loss.loss()
                loss = yolo_loss.final_loss
                all_batch_losses += loss.item()

                loss.backward()
                optimizer.step()

                print(f'Train Epoch: {epoch + 1} of {num_epochs} '
                      f'[Batch: {batch_idx + 1}/{len(data.data)} '
                      f'({(batch_idx + 1) / len(data.data) * 100.:.0f}%)] '
                      f'Loss: {loss.item():.6f}')
                print('')

            epoch_losses.append(all_batch_losses / len(data.data))
            print("Loss progress so far:", epoch_losses)
            print("")

        loss_log = torch.load('loss_log.pt')
        mean_loss = sum(epoch_losses) / len(epoch_losses)
        loss_log['Epoch: ' + str(epoch + 1)] = mean_loss
        torch.save(loss_log, 'loss_log.pt')
        print(f"Mean loss for this epoch was {mean_loss}")
        print("")

        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint, filename=load_model_file)

        time.sleep(10)  # Short pause between epochs (presumably to let file I/O settle)
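A minimal, hypothetical call of TrainNetwork is sketched below for context. The YOLOv1 constructor, the paths, the class list, and all hyperparameter values are illustrative assumptions rather than values from the original source; only lambda_coord = 5.0 and lambda_noobj = 0.5 follow the loss weights used in the YOLOv1 paper.

# Hypothetical usage sketch; every name and value below is illustrative.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = YOLOv1(split_size=7, num_boxes=2, num_classes=3).to(device)  # hypothetical constructor
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

TrainNetwork(
    num_epochs=100,
    split_size=7,
    batch_size=10,
    load_size=5,
    num_boxes=2,
    num_classes=3,
    train_img_files_path="data/train/images/",     # illustrative path
    train_target_files_path="data/train/labels/",  # illustrative path
    category_list=["car", "truck", "pedestrian"],  # illustrative classes
    model=model,
    device=device,
    optimizer=optimizer,
    load_model_file="yolo_checkpoint.pt",
    lambda_coord=5.0,  # loss weights from the YOLOv1 paper
    lambda_noobj=0.5,
)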
Example #2

import torch

# Project-local helpers assumed by this snippet (not shown in this excerpt):
# DataLoader, extract_boxes, non_max_suppression.
def validate(test_img_files_path, test_target_files_path, category_list,
             split_size, batch_size, load_size, model, cell_dim, num_boxes,
             num_classes, device, iou_threshold_nms, iou_threshold_map,
             threshold, use_nms):
    """
    Uses the test dataset to validate the performance of the model. Calculates 
    the mean Average Precision (mAP) for object detection.
    
    Parameters:
        test_img_files_path (str): System path to the image directory containing
            the test dataset images.
        test_target_files_path (str): System path to the json file containing the
            ground-truth labels for the test dataset.
        category_list (list): A list containing all classes which should be detected.
        split_size (int): Size of the grid which is applied to the image.
        batch_size (int): Batch size.
        load_size (int): Number of batches which are loaded in one function call.
        model (torch.nn.Module): The YOLOv1 model.
        cell_dim (int): The dimension of a single cell.
        num_boxes (int): Number of bounding boxes which are predicted by
            the model.
        num_classes (int): Number of classes which are predicted.
        device (torch.device): Device which is used for training and testing the model.
        iou_threshold_nms (float): IoU threshold between predicted boxes, used
            for non-maximum suppression.
        iou_threshold_map (float): IoU threshold between predicted and
            ground-truth boxes, used for mean average precision.
        threshold (float): Threshold for the confidence score of predicted
            bounding boxes.
        use_nms (bool): Specifies whether non-maximum suppression should be
            applied to the bounding box predictions.
    """

    model.eval()

    print("DATA IS BEING LOADED FOR VALIDATION")
    print("")
    # Initialize the DataLoader for the test dataset
    data = DataLoader(test_img_files_path, test_target_files_path,
                      category_list, split_size, batch_size, load_size)
    data.LoadFiles()

    # All predicted and ground-truth bounding boxes for the whole test dataset
    # are collected in these two lists, which are finally used for evaluation.
    # Every element represents a single bounding box and has the form
    # [image index, class prediction, confidence score, x1, y1, x2, y2],
    # e.g. [0, 2, 0.91, 34.0, 50.0, 120.0, 180.0].
    all_pred_boxes = []
    all_target_boxes = []

    train_idx = 0  # Running index of the current sample across the test dataset

    # This while loop is used to fill the two lists all_pred_boxes and all_target_boxes.
    while len(data.img_files) > 0:
        print("LOADING NEW VALIDATION BATCHES")
        print("Remaining validation files:" + str(len(data.img_files)))
        print("")
        data.LoadData()

        for batch_idx, (img_data, target_data) in enumerate(data.data):
            img_data = img_data.to(device)
            target_data = target_data.to(device)

            with torch.no_grad():
                predictions = model(img_data)

            print('Extracting bounding boxes')
            print(f'Batch: {batch_idx + 1}/{len(data.data)} '
                  f'({(batch_idx + 1) / len(data.data) * 100.:.0f}%)')
            print('')
            pred_boxes = extract_boxes(predictions, num_classes, num_boxes,
                                       cell_dim, threshold)
            target_boxes = extract_boxes(target_data, num_classes, 1, cell_dim,
                                         threshold)

            for sample_idx in range(len(pred_boxes)):
                if use_nms:
                    # Applies non max suppression to the bounding box predictions
                    nms_boxes = non_max_suppression(pred_boxes[sample_idx],
                                                    iou_threshold_nms)
                else:
                    # Use the same list without changing anything
                    nms_boxes = pred_boxes[sample_idx]

                for nms_box in nms_boxes:
                    all_pred_boxes.append([train_idx] + nms_box)

                for box in target_boxes[sample_idx]:
                    all_target_boxes.append([train_idx] + box)

                train_idx += 1

    # Sanity check: count the extracted boxes per class index, first for the
    # predictions and then for the ground-truth targets (assumes class
    # indices 0 .. len(category_list) - 1).
    pred_counts = {cls_idx: 0 for cls_idx in range(len(category_list))}
    for prediction in all_pred_boxes:
        pred_counts[prediction[1]] += 1
    print("Predicted boxes per class:", pred_counts)

    target_counts = {cls_idx: 0 for cls_idx in range(len(category_list))}
    for target in all_target_boxes:
        target_counts[target[1]] += 1
    print("Ground-truth boxes per class:", target_counts)
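The docstring above promises a mean Average Precision score, while this excerpt ends with the per-class counts. As a rough illustration only, here is a minimal, self-contained mAP sketch (not the project's own implementation) that consumes the box format built above, [image index, class, confidence, x1, y1, x2, y2]:

# Minimal mAP sketch; NOT the project's implementation, just an illustration
# of how all_pred_boxes / all_target_boxes could be evaluated.

def iou_corners(a, b):
    """IoU of two [x1, y1, x2, y2] boxes in corner format."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = ((a[2] - a[0]) * (a[3] - a[1])
             + (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union if union > 0 else 0.0

def mean_average_precision(pred_boxes, target_boxes, iou_threshold, num_classes):
    """Boxes have the form [image index, class, confidence, x1, y1, x2, y2]."""
    average_precisions = []
    for c in range(num_classes):
        preds = sorted((p for p in pred_boxes if p[1] == c),
                       key=lambda p: p[2], reverse=True)  # by confidence
        targets = [t for t in target_boxes if t[1] == c]
        if not targets:
            continue  # class absent from the ground truth; skip it
        claimed = set()  # indices of targets already matched to a prediction
        hits = []        # 1 = true positive, 0 = false positive
        for p in preds:
            best_iou, best_idx = 0.0, None
            for t_idx, t in enumerate(targets):
                if t[0] != p[0]:
                    continue  # only match boxes from the same image
                iou = iou_corners(p[3:], t[3:])
                if iou > best_iou:
                    best_iou, best_idx = iou, t_idx
            if best_iou >= iou_threshold and best_idx not in claimed:
                claimed.add(best_idx)
                hits.append(1)
            else:
                hits.append(0)
        # Walk down the ranked predictions, accumulating precision over recall
        tp, ap, prev_recall = 0, 0.0, 0.0
        for i, hit in enumerate(hits, start=1):
            tp += hit
            recall = tp / len(targets)
            ap += (recall - prev_recall) * (tp / i)
            prev_recall = recall
        average_precisions.append(ap)
    return (sum(average_precisions) / len(average_precisions)
            if average_precisions else 0.0)

With such a helper, validate could plausibly finish with a call like mean_average_precision(all_pred_boxes, all_target_boxes, iou_threshold_map, num_classes).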
    """