예제 #1
0
def evaluate(model,
             dataset,
             summary_writer,
             images_results_dir,
             iou_thres,
             conf_thres,
             nms_thres,
             log_every=None,
             limit=None,
             plot=False,
             save=False):
    """Run ``model`` over ``dataset`` and log average precision per class.

    The model is moved to ``DEVICE`` and switched to eval mode; all forward
    passes happen under ``torch.no_grad()``. Batches whose size differs from
    ``dataset.batch_size`` are skipped with a warning.

    Args:
        model: Callable on an image batch returning
            ``(coordinates, class_scores, confidence)``; must expose
            ``class_names``.
        dataset: Source of samples; must expose ``batch_size`` and
            ``collate_fn``. Each batch unpacks to
            ``(images, ground_truth_boxes, image_paths)``.
        summary_writer: Passed through to
            ``log_average_precision_for_classes``.
        images_results_dir: Directory handed to ``save_batch`` when ``save``
            is true.
        iou_thres: Passed to ``Metrics.add_detections_for_batch``.
        conf_thres: Passed to ``non_max_suppression``.
        nms_thres: Passed to ``non_max_suppression``.
        log_every: If not ``None``, average precision is logged every
            ``log_every`` batches (never at batch 0).
        limit: If not ``None``, stop after iterating this many batches.
        plot: If true, call ``plot_batch`` for every evaluated batch.
        save: If true, call ``save_batch`` for every evaluated batch.

    Raises:
        AssertionError: If ``save`` is true and ``dir_exists_and_is_empty``
            rejects ``images_results_dir`` (the check is an ``assert`` and is
            therefore skipped when Python runs with ``-O``).
    """
    if save:
        assert dir_exists_and_is_empty(
            images_results_dir
        ), f'{images_results_dir} is not empty or does not exist.'

    logger.info(
        f'Start evaluating model with iou_thres: {iou_thres}, conf_thres: {conf_thres} and nms_thres: {nms_thres}'
    )

    metrics = Metrics()

    model.to(DEVICE)
    model.eval()
    with torch.no_grad():
        data_loader = DataLoader(dataset,
                                 batch_size=dataset.batch_size,
                                 shuffle=True,
                                 collate_fn=dataset.collate_fn)
        class_names = model.class_names

        # Number of batches the progress bar (and the final log step) expects.
        total = limit if limit is not None else len(data_loader)
        for batch_i, (images, ground_truth_boxes,
                      image_paths) in tqdm(enumerate(data_loader),
                                           total=total):
            if len(images) != dataset.batch_size:
                logger.warning(
                    f"Skipping batch {batch_i} because it does not have correct size ({dataset.batch_size})"
                )
                continue

            images = images.to(DEVICE)

            coordinates, class_scores, confidence = model(images)

            class_scores = torch.sigmoid(class_scores)

            # One tensor per prediction: coordinates, confidence, class scores.
            prediction = torch.cat(
                (coordinates, confidence.unsqueeze(-1), class_scores), -1)

            detections = non_max_suppression(prediction=prediction,
                                             conf_thres=conf_thres,
                                             nms_thres=nms_thres)

            if plot:
                plot_batch(detections, ground_truth_boxes, images, class_names)

            if save:
                save_batch(image_paths, images_results_dir, detections,
                           ground_truth_boxes, images, class_names)

            ground_truth_map_objects = list(
                GroundTruth.from_ground_truths(image_paths,
                                               ground_truth_boxes))
            detection_map_objects = list(
                Detection.from_detections(image_paths, detections))

            metrics.add_detections_for_batch(detection_map_objects,
                                             ground_truth_map_objects,
                                             iou_thres=iou_thres)

            # batch_i is zero-based, so batch_i + 1 batches have been consumed
            # at this point; comparing batch_i itself would run one batch too
            # many.
            if limit is not None and batch_i + 1 >= limit:
                logger.info(
                    f"Stop evaluation here after {batch_i + 1} batches")
                break

            if batch_i != 0 and log_every is not None and batch_i % log_every == 0:
                log_average_precision_for_classes(metrics, class_names,
                                                  summary_writer, batch_i)

        # Final summary over everything accumulated in ``metrics``.
        log_average_precision_for_classes(metrics, class_names, summary_writer,
                                          total)
예제 #2
0
def train(model,
          dataset,
          model_dir,
          summary_writer,
          epochs,
          lr,
          conf_thres,
          nms_thres,
          iou_thres,
          lambda_coord=5,
          lambda_no_obj=0.5,
          gradient_accumulations=2,
          clip_gradients=False,
          limit=None,
          debug=False,
          print_every=10,
          save_every=None,
          log_to_neptune=False):
    """Train ``model`` on ``dataset`` with ``YoloLoss`` and an Adam optimizer.

    The model is moved to ``DEVICE`` and put in train mode. For every batch
    the loss is back-propagated, detections are fed into a ``Metrics``
    accumulator, and performance is logged / checkpoints are written at the
    configured intervals. Batches whose size differs from
    ``dataset.batch_size`` are skipped with a warning.

    Args:
        model: Callable on an image batch returning
            ``(coordinates, class_scores, confidence)``; must expose
            ``class_names``, ``grid_sizes``, ``get_trainable_parameters()``
            and ``parameters()``.
        dataset: Source of samples; must expose ``batch_size`` and
            ``collate_fn``. Each batch unpacks to
            ``(images, ground_truth_boxes, image_paths)``.
        model_dir: Passed to ``save_model`` as the checkpoint location.
        summary_writer: Passed to ``yolo_loss.capture`` and
            ``log_performance``.
        epochs: Number of epochs to run (1-based epoch counter).
        lr: Learning rate for ``torch.optim.Adam``.
        conf_thres: Passed to ``non_max_suppression``.
        nms_thres: Passed to ``non_max_suppression``.
        iou_thres: Passed to ``Metrics.add_detections_for_batch``.
        lambda_coord: Passed to ``YoloLoss``.
        lambda_no_obj: Passed to ``YoloLoss``.
        gradient_accumulations: The optimizer steps (and gradients are
            zeroed) whenever ``batch_i % gradient_accumulations == 0``, which
            includes batch 0 of every epoch.
        clip_gradients: If true, clip the gradient norm to 1 after every
            backward pass.
        limit: If not ``None``, log, save and return once this many batches
            of the current epoch have been iterated.
        debug: If true, call ``plot_batch`` for every trained batch.
        print_every: Loss capture and performance logging happen whenever
            ``batch_i % print_every == 0``.
        save_every: If not ``None``, a checkpoint is written whenever
            ``batch_i % save_every == 0``.
        log_to_neptune: If true, load ``NEPTUNE_API_TOKEN`` from
            ``$HOME/workspace/setup-box/neptune.env`` and initialise neptune;
            the flag is also forwarded to ``log_performance``.

    Returns:
        None. Returns early (after saving a checkpoint) when ``limit`` is hit.
    """
    if log_to_neptune:
        env_path = Path(os.environ['HOME'], 'workspace/setup-box/neptune.env')
        load_dotenv(dotenv_path=env_path)

        neptune.init('petersiemen/sandbox',
                     api_token=os.getenv("NEPTUNE_API_TOKEN"))

    data_loader = DataLoader(dataset,
                             batch_size=dataset.batch_size,
                             shuffle=True,
                             collate_fn=dataset.collate_fn)

    # ``total`` drives the progress bar and is reported next to ``batch_i``,
    # so it must count batches (len(data_loader)), not samples (len(dataset)).
    total = limit if limit is not None else len(data_loader)

    logger.info(
        f'Start training on {total} batches. Using lr: {lr}, '
        f'lambda_coord: {lambda_coord}, lambda_no_obj: {lambda_no_obj}, '
        f'conf_thres: {conf_thres}, nms_thres:{nms_thres}, iou_thres: {iou_thres}, '
        f'gradient_accumulations: {gradient_accumulations}, '
        f'clip_gradients: {clip_gradients}')
    metrics = Metrics()

    model.to(DEVICE)
    model.train()

    optimizer = torch.optim.Adam(model.get_trainable_parameters(), lr=lr)
    grid_sizes = model.grid_sizes
    class_names = model.class_names

    for epoch in range(1, epochs + 1):
        for batch_i, (images, ground_truth_boxes,
                      image_paths) in tqdm(enumerate(data_loader),
                                           total=total):
            if len(images) != dataset.batch_size:
                logger.warning(
                    f"Skipping batch {batch_i} because it does not have correct size ({dataset.batch_size})"
                )
                continue

            images = images.to(DEVICE)

            coordinates, class_scores, confidence = model(images)

            obj_mask, noobj_mask, cls_mask, target_coordinates, target_confidence, target_class_scores = build_targets(
                coordinates, class_scores, ground_truth_boxes, grid_sizes)
            yolo_loss = YoloLoss(coordinates,
                                 confidence,
                                 class_scores,
                                 obj_mask,
                                 noobj_mask,
                                 cls_mask,
                                 target_coordinates,
                                 target_confidence,
                                 target_class_scores,
                                 lambda_coord=lambda_coord,
                                 lambda_no_obj=lambda_no_obj)

            # Detections only feed the metrics (and the debug plot); the loss
            # above already holds the raw outputs, so skip autograd tracking
            # for this post-processing.
            with torch.no_grad():
                class_probabilities = torch.sigmoid(class_scores)
                prediction = torch.cat(
                    (coordinates, confidence.unsqueeze(-1),
                     class_probabilities), -1)

                detections = non_max_suppression(prediction=prediction,
                                                 conf_thres=conf_thres,
                                                 nms_thres=nms_thres)

            ground_truth_map_objects = list(
                GroundTruth.from_ground_truths(image_paths,
                                               ground_truth_boxes))
            detection_map_objects = list(
                Detection.from_detections(image_paths, detections))

            metrics.add_detections_for_batch(detection_map_objects,
                                             ground_truth_map_objects,
                                             iou_thres=iou_thres)

            if debug:
                plot_batch(detections, ground_truth_boxes, images, class_names)

            loss = yolo_loss.get()
            # backward pass to calculate the weight gradients
            loss.backward()

            if clip_gradients:
                logger.debug("Clipping gradients with max_norm = 1")
                clip_grad_norm_(model.parameters(), max_norm=1)

            if batch_i % print_every == 0:  # log every print_every batches
                yolo_loss.capture(summary_writer, batch_i, during='train')
                log_performance(epoch, epochs, batch_i, total, yolo_loss,
                                metrics, class_names, summary_writer,
                                log_to_neptune)

            # Gradients accumulate across batches until the next step
            if batch_i % gradient_accumulations == 0:
                logger.debug(
                    f"Updating weights for batch {batch_i} (gradient_accumulations :{gradient_accumulations})"
                )
                # update the weights
                optimizer.step()
                # zero the parameter (weight) gradients
                optimizer.zero_grad()

            # Drop references to the batch before the next one is loaded.
            del images
            del ground_truth_boxes

            # batch_i is zero-based: batch_i + 1 batches have been iterated.
            if limit is not None and batch_i + 1 >= limit:
                logger.info(
                    'Stop here after training {} batches (limit: {})'.format(
                        batch_i + 1, limit))
                log_performance(epoch, epochs, batch_i, total, yolo_loss,
                                metrics, class_names, summary_writer,
                                log_to_neptune)
                save_model(model_dir, model, epoch, batch_i)
                return

            if save_every is not None and batch_i % save_every == 0:
                save_model(model_dir, model, epoch, batch_i)

        # save model after every epoch
        save_model(model_dir, model, epoch, None)