def iou_loss(preds: Tensor, target: Tensor) -> Tensor:
    """Compute the IoU loss, i.e. one minus the intersection over union.

    Args:
        preds: batch of predicted bounding boxes, each given as
            ``[x_min, y_min, x_max, y_max]``
        target: batch of ground-truth bounding boxes, each given as
            ``[x_min, y_min, x_max, y_max]``

    Returns:
        IoU loss, computed as ``1 - iou(preds, target)``.

    Example:
        >>> import torch
        >>> from pl_bolts.losses.object_detection import iou_loss
        >>> preds = torch.tensor([[100, 100, 200, 200]])
        >>> target = torch.tensor([[150, 150, 250, 250]])
        >>> iou_loss(preds, target)
        tensor([[0.8571]])
    """
    overlap = iou(preds, target)
    # A perfect overlap (IoU == 1) yields zero loss; disjoint boxes yield 1.
    return 1 - overlap
def test_iou_multi(preds, target, expected_iou):
    """Check ``iou`` against a known expected value for multiple boxes.

    Uses ``torch.testing.assert_close`` instead of the deprecated
    ``torch.testing.assert_allclose``, while keeping the old comparison
    semantics.
    """
    actual = iou(preds, target)
    # assert_allclose converted a non-tensor expected value to the actual
    # tensor's dtype; mirror that so list/float parametrizations keep working
    # and dtype differences do not turn into spurious failures.
    expected = torch.as_tensor(expected_iou, dtype=actual.dtype, device=actual.device)
    # assert_allclose treated NaN == NaN as equal by default; keep that.
    torch.testing.assert_close(actual, expected, equal_nan=True)
def test_iou_no_overlap(preds, target, expected_iou):
    """Check ``iou`` against a known expected value for non-overlapping boxes.

    Uses ``torch.testing.assert_close`` instead of the deprecated
    ``torch.testing.assert_allclose``, while keeping the old comparison
    semantics.
    """
    actual = iou(preds, target)
    # assert_allclose converted a non-tensor expected value to the actual
    # tensor's dtype; mirror that so list/float parametrizations keep working
    # and dtype differences do not turn into spurious failures.
    expected = torch.as_tensor(expected_iou, dtype=actual.dtype, device=actual.device)
    # assert_allclose treated NaN == NaN as equal by default; keep that.
    torch.testing.assert_close(actual, expected, equal_nan=True)