コード例 #1
0
def test_to_xywh():
    """xywh() reports (xmin, ymin) followed by the box width and height."""
    box = BBox(xmin=5, xmax=6, ymin=10, ymax=30)
    left, top, width, height = box.xywh()
    # width = xmax - xmin = 1, height = ymax - ymin = 20
    assert (left, width) == (5, 1)
    assert (top, height) == (10, 20)
コード例 #2
0
def test_per_class_map():
    """Class "a" is predicted exactly (AP 1) and class "b" is missed (AP 0),
    so the overall mAP is 0.5."""
    square = BBox(0, 5, 0, 5)
    corner_only = BBox(5, 6, 5, 6)  # shares only the corner (5, 5) with `square`
    annotations = [
        AnnotatedBBox.ground_truth("a", square),
        AnnotatedBBox.prediction("a", square, 0.9),
        AnnotatedBBox.ground_truth("b", square),
        AnnotatedBBox.prediction("b", corner_only, 0.9),
    ]
    result = get_metrics([annotations], iou_threshold=0.9)
    per_class = result.per_class
    assert per_class["a"].AP == 1
    assert per_class["b"].AP == 0
    assert result.mAP == 0.5
コード例 #3
0
def test_metrics_two_predictions_one_gt_2():
    """An exact match plus a displaced duplicate of the same class yields one
    TP and one FP, while mAP stays 1."""
    gt_box = BBox(0, 2, 0, 2).normalize(10, 10)
    displaced = gt_box.move(0.5, 0.5)
    annotations = [
        AnnotatedBBox.ground_truth("a", gt_box),
        AnnotatedBBox.prediction("a", gt_box, 0.9),
        AnnotatedBBox.prediction("a", displaced, 0.5),
    ]
    stats = get_metrics([annotations])
    class_a = stats.per_class["a"]
    assert class_a.total_FP == 1
    assert class_a.total_TP == 1
    assert stats.mAP == 1
コード例 #4
0
def test_iou():
    """IoU is 0 for disjoint boxes, 1 for identical boxes, and
    intersection / union for a partial overlap."""
    bb_a = BBox(0, 2, 0, 2)
    bb_b = BBox(3, 5, 0, 1)
    bb_c = BBox(0, 1, 3, 4)
    assert iou(bb_a, bb_b) == 0
    assert iou(bb_a, bb_c) == 0
    assert iou(bb_b, bb_c) == 0
    assert iou(bb_a, bb_a) == 1

    # bb_a and bb_d overlap in a 1x1 square: intersection 1, union 4 + 4 - 1 = 7.
    bb_d = BBox(1, 3, 1, 3)
    assert iou(bb_a, bb_d) == get_intersection_area(bb_a, bb_d) / get_union_area(
        bb_a, bb_d
    )
    # The ratio above is built from the helpers under test, so a bug shared by
    # them would slip through; pin the concrete value as well.
    assert abs(iou(bb_a, bb_d) - 1 / 7) < 1e-9
    # IoU is symmetric in its arguments.
    assert abs(iou(bb_d, bb_a) - 1 / 7) < 1e-9
コード例 #5
0
def test_iou_threshold():
    """A partially overlapping prediction is a miss under a strict IoU
    threshold and a hit under a loose one."""
    box = BBox(0, 5, 0, 5)
    annotations = [
        AnnotatedBBox.ground_truth("a", box),
        AnnotatedBBox.prediction("a", box.move(2.5, 0), 0.9),
    ]
    # (iou_threshold, expected FP, expected TP, expected mAP)
    cases = [(0.9, 1, 0, 0), (0.2, 0, 1, 1)]
    for threshold, fp, tp, expected_map in cases:
        result = get_metrics([annotations], iou_threshold=threshold)
        assert result.per_class["a"].total_FP == fp
        assert result.per_class["a"].total_TP == tp
        assert result.mAP == expected_map
コード例 #6
0
def test_get_intersection_area():
    """Overlap area is 0 for disjoint boxes, symmetric, and equals the inner
    box's area when one box contains the other."""
    bb_a = BBox(0, 2, 0, 2)
    bb_b = BBox(3, 5, 0, 1)
    bb_c = BBox(0, 1, 3, 4)
    for first, second in ((bb_a, bb_b), (bb_a, bb_c), (bb_b, bb_c)):
        assert get_intersection_area(first, second) == 0

    # bb_d overlaps bb_a in a single 1x1 cell, whichever order they are passed.
    bb_d = BBox(1, 3, 1, 3)
    assert get_intersection_area(bb_a, bb_d) == 1
    assert get_intersection_area(bb_d, bb_a) == 1
    # A box intersected with itself gives its own 2x2 area.
    assert get_intersection_area(bb_a, bb_a) == 4

    # bb_e encloses both bb_a (2x2) and bb_b (2x1).
    bb_e = BBox(0, 5, 0, 5)
    for inner, area in ((bb_a, 4), (bb_b, 2)):
        assert get_intersection_area(bb_e, inner) == area
コード例 #7
0
def test_get_union_area():
    """Union area follows inclusion-exclusion: area_a + area_b - overlap."""
    bb_a = BBox(0, 2, 0, 2)
    bb_b = BBox(3, 5, 0, 1)
    bb_c = BBox(0, 1, 3, 4)
    # Disjoint pairs: the union is just the sum of the two areas.
    for first, second in ((bb_a, bb_b), (bb_a, bb_c), (bb_b, bb_c)):
        assert get_union_area(first, second) == first.area + second.area
    # A box united with itself counts its area once.
    assert get_union_area(bb_a, bb_a) == bb_a.area

    # Partial overlap: subtract the shared region once.
    bb_d = BBox(1, 3, 1, 3)
    shared = get_intersection_area(bb_a, bb_d)
    assert get_union_area(bb_a, bb_d) == bb_a.area + bb_d.area - shared

    # bb_b lies entirely inside bb_e, so the union is bb_e itself.
    bb_e = BBox(0, 5, 0, 5)
    assert get_union_area(bb_e, bb_b) == bb_e.area
コード例 #8
0
def test_metrics_perfect_prediction():
    """One prediction exactly matching its ground truth gives mAP 1."""
    box = BBox(0, 5, 0, 5).normalize(10, 10)
    truth = AnnotatedBBox.ground_truth("a", box)
    guess = AnnotatedBBox.prediction("a", box, 0.9)
    assert get_metrics([[truth, guess]]).mAP == 1.0
コード例 #9
0
def test_metrics_missing_gt():
    """Predictions with no ground truth in the image are all false positives."""
    box = BBox(0, 5, 0, 5).normalize(10, 10)
    predictions = [
        AnnotatedBBox.prediction(label, box, score)
        for label, score in (("a", 0.9), ("b", 0.9), ("b", 0.8))
    ]
    metrics = get_metrics([predictions])
    for label, expected_fp in {"a": 1, "b": 2}.items():
        assert metrics.per_class[label].total_FP == expected_fp
    assert metrics.mAP == 0.0
コード例 #10
0
def test_metrics_multiple_images_perfect_prediction():
    """Perfect predictions on two separate images still yield mAP 1."""
    box = BBox(0, 5, 0, 5).normalize(10, 10)

    def perfect_image():
        # One ground truth plus one exactly matching prediction.
        return [
            AnnotatedBBox.ground_truth("a", box),
            AnnotatedBBox.prediction("a", box, 0.9),
        ]

    metrics = get_metrics([perfect_image(), perfect_image()])
    assert metrics.mAP == 1.0
コード例 #11
0
def test_metrics_do_not_contain_numpy_type():
    """Every number exposed by the metrics is a plain Python scalar, never a
    NumPy floating or integer scalar."""
    annotations = [
        AnnotatedBBox.ground_truth("a", BBox(0, 5, 0, 5)),
        AnnotatedBBox.prediction("a", BBox(0, 5, 0, 5), 0.9),
        AnnotatedBBox.ground_truth("b", BBox(0, 5, 0, 5)),
        AnnotatedBBox.prediction("b", BBox(5, 6, 5, 6), 0.9),
    ]
    metrics = get_metrics([annotations], iou_threshold=0.9)
    assert not isinstance(metrics.mAP, np.floating)
    for class_metrics in metrics.per_class.values():
        curves = (
            class_metrics.precision,
            class_metrics.recall,
            class_metrics.interpolated_precision,
            class_metrics.interpolated_recall,
        )
        for curve in curves:
            assert not any(isinstance(point, np.floating) for point in curve)
        assert not isinstance(class_metrics.AP, np.floating)
        counts = (
            class_metrics.total_GT,
            class_metrics.total_TP,
            class_metrics.total_FP,
        )
        for count in counts:
            assert not isinstance(count, np.integer)
コード例 #12
0
def test_bb_intersect():
    """boxes_intersect needs a positive-area overlap; boxes that only share an
    edge or a corner do not count as intersecting."""
    bb_a = BBox(0, 2, 0, 2)
    bb_b = BBox(3, 5, 0, 1)
    bb_c = BBox(0, 1, 3, 4)
    bb_d = BBox(2, 4, 2, 4)  # touches bb_a only at the corner (2, 2)
    bb_e = BBox(1, 4, 1, 4)  # shares only an edge with bb_b and with bb_c
    bb_f = BBox(0, 5, 0, 5)  # encloses every other box

    disjoint_pairs = [
        (bb_a, bb_b),
        (bb_a, bb_c),
        (bb_b, bb_c),
        (bb_d, bb_a),
        (bb_e, bb_b),
        (bb_e, bb_c),
    ]
    overlapping_pairs = [
        (bb_e, bb_a),
        (bb_e, bb_d),
        (bb_f, bb_a),
        (bb_f, bb_b),
        (bb_f, bb_c),
        (bb_f, bb_d),
        (bb_f, bb_e),
    ]
    for first, second in disjoint_pairs:
        assert not boxes_intersect(first, second)
    for first, second in overlapping_pairs:
        assert boxes_intersect(first, second)
コード例 #13
0
def test_normalize_bbox():
    """normalize(100, 100) maps every coordinate c to c / 100; as_tuple()
    returns (xmin, xmax, ymin, ymax)."""
    normalized = BBox(xmin=0, xmax=50, ymin=50, ymax=75).normalize(100, 100)
    assert normalized.as_tuple() == (0, 0.5, 0.5, 0.75)
コード例 #14
0
def test_metrics_missing_prediction():
    """A ground truth with no matching prediction at all yields mAP 0."""
    lone_gt = AnnotatedBBox.ground_truth("a", BBox(0, 5, 0, 5).normalize(10, 10))
    assert get_metrics([[lone_gt]]).mAP == 0.0
コード例 #15
0
def test_invalid_bbox():
    """BBox rejects a zero-width extent (xmin == xmax) and an inverted one
    (xmin > xmax)."""
    # NOTE(review): pytest.raises(Exception) also passes on unrelated errors
    # such as TypeError; narrow it to the exception BBox actually raises once
    # that is confirmed.
    bad_extents = [
        dict(xmin=0, xmax=0, ymin=0.5, ymax=0.6),
        dict(xmin=1, xmax=0, ymin=0.3, ymax=0.5),
    ]
    for extent in bad_extents:
        with pytest.raises(Exception):
            BBox(**extent)
コード例 #16
0
def test_from_xywh():
    """from_xywh keeps (x, y) as the min corner and adds w/h for the max corner."""
    built = BBox.from_xywh(5, 6, 10, 30)
    assert (built.xmin, built.ymin) == (5, 6)
    assert (built.xmax, built.ymax) == (15, 36)
コード例 #17
0
def test_class_constructor():
    """from_xywh returns an instance of whichever class it is invoked on."""
    cases = (
        (BBox, (1, 1, 5, 5)),
        (NormalizedBBox, (0.1, 0.1, 0.5, 0.5)),
    )
    for cls, xywh in cases:
        assert isinstance(cls.from_xywh(*xywh), cls)
コード例 #18
0
def test_rescale_at_center():
    """Rescaling by 2 doubles width and height while keeping the centre fixed."""
    original = BBox.from_xywh(100, 200, 50, 100)
    # Centre stays at (125, 250); the new corner is centre - new_size / 2.
    scaled = original.rescale_at_center(2)
    assert scaled.xywh() == (75, 150, 100, 200)
コード例 #19
0
def test_denormalized_as_normalized():
    """as_normalized on a (denormalized) BBox yields a NormalizedBBox equal to
    what normalize() produces for the same dimensions."""
    source = BBox.from_xywh(5, 10, 20, 50)
    converted = source.as_normalized(100, 100)
    reference = source.normalize(100, 100)
    assert isinstance(converted, NormalizedBBox)
    assert converted == reference
コード例 #20
0
def test_denormalized_as_denormalized():
    """as_denormalized on an already denormalized BBox returns that same object."""
    box = BBox.from_xywh(5, 10, 20, 50)
    same = box.as_denormalized(100, 100)
    assert same is box
コード例 #21
0
def test_clip_bbox():
    """clip pulls out-of-range coordinates back to the bounds and leaves boxes
    already inside the bounds unchanged."""
    # NOTE(review): the argument order looks like (xmax, ymax, xmin, ymin);
    # confirm against BBox.clip.
    cases = [
        (BBox(xmin=-10.3, xmax=50.0, ymin=-20.3, ymax=75.9), (20, 30, 0, 0)),
        (BBox(xmin=0, xmax=20, ymin=0, ymax=30), (20, 30, 0, 0)),
        (BBox(xmin=0, xmax=20, ymin=0, ymax=30), (21, 31, 0, 0)),
    ]
    for box, bounds in cases:
        assert box.clip(*bounds).as_tuple() == (0, 20, 0, 30)