Пример #1
0
 def validation_step(self, batch, batch_idx):
     """Score one validation batch with mean IoU and mean GIoU.

     Args:
         batch: A pair ``(images, targets)``.
         batch_idx: Index of the batch; not used by this method.

     Returns:
         A dict with keys ``"val_iou"`` and ``"val_giou"``, each holding the
         mean of the per-sample scores over the batch.
     """
     images, targets = batch
     # Only the images are passed to the model here; per the original note,
     # RetinaNet takes just images when it is in eval() mode.
     predictions = self.model(images)

     # One score per (target, prediction) pair. Like zip(), map() with two
     # iterables stops at the shorter one.
     per_sample_iou = list(map(_evaluate_iou, targets, predictions))
     mean_iou = torch.stack(per_sample_iou).mean()

     per_sample_giou = list(map(_evaluate_giou, targets, predictions))
     mean_giou = torch.stack(per_sample_giou).mean()

     return {"val_iou": mean_iou, "val_giou": mean_giou}
Пример #2
0
def val_step(model, val_loader, device, num_batches=None,
             log_interval: int = 100):
    """
    Run one validation pass and return the IoU / GIoU averaged over batches.

    The model is moved to ``device``, switched to eval mode and run under
    ``torch.no_grad()``. For every batch the per-sample IoU and GIoU (from
    ``_evaluate_iou`` / ``_evaluate_giou``) are averaged into one value per
    batch; the returned metrics are the mean of those per-batch values over
    all processed batches. No loss is computed.

    Args:
        model : PyTorch detection model (the original notes say FasterRCNN);
            called with a list of images only.
        val_loader : Validation loader yielding ``(inputs, targets)`` pairs,
            where each target is a dict of tensors. Must support ``len()``.
        device : Device to run on, e.g. "cuda" or "cpu".
        num_batches : (optional) Stop after this many batches. The check runs
            after a batch is processed, so at least one batch is always
            evaluated when the loader is non-empty.
        log_interval : (optional) Default 100. Print batch timing every
            ``log_interval`` batches and on the last batch of the loader.

    Returns:
        OrderedDict with keys ``"iou"`` and ``"giou"``.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """

    model = model.to(device)
    start_val_step = time.time()
    last_idx = len(val_loader) - 1
    batch_time_m = utils.AverageMeter()
    model.eval()
    batch_start = time.time()

    # One entry per processed batch. Keeping every batch (rather than only the
    # most recent one) is what makes the final value an average over the run.
    batch_ious = []
    batch_gious = []

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            images = [image.to(device) for image in inputs]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            out = model(images)
            batch_ious.append(
                torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, out)]).mean())
            batch_gious.append(
                torch.stack([_evaluate_giou(t, o) for t, o in zip(targets, out)]).mean())

            batch_time_m.update(time.time() - batch_start)
            batch_start = time.time()

            if batch_idx == last_idx or batch_idx % log_interval == 0:
                print("Batch Validation Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  ".format(
                      batch_time=batch_time_m,))

            if num_batches is not None and len(batch_ious) >= num_batches:
                print(f"Done till {num_batches} Validation batches")
                break

    if not batch_ious:
        # Fail with a clear message instead of an UnboundLocalError.
        raise ValueError("val_loader yielded no batches; cannot compute validation metrics")

    metrics = OrderedDict()
    metrics["iou"] = torch.stack(batch_ious).mean()
    metrics["giou"] = torch.stack(batch_gious).mean()

    end_val_step = time.time()
    print(f"Time taken for validation step = {end_val_step - start_val_step} sec")
    return metrics