Ejemplo n.º 1
0
    def test_single__bad_index_raises_error(self):
        """An out-of-range ``index`` makes ``Predict.single`` raise IndexError."""
        class_id = 12
        bad_index = 5
        # Precondition: the probed fixture dimension is smaller than the index.
        # NOTE(review): ``preds[0][0][0].shape[0]`` looks like the length of a
        # single box rather than the batch size -- confirm this is intended.
        probed_dim = self.preds[0][0][0].shape[0]
        assert probed_dim < bad_index

        with pytest.raises(IndexError):
            Predict.single(class_id, self.preds, index=bad_index)
Ejemplo n.º 2
0
    def test_single__explicit_choose_item_in_batch(self):
        """Passing ``index=1`` yields differently shaped boxes than no index."""
        class_id = 12

        default_pick = Predict.single(class_id, self.preds)
        explicit_pick = Predict.single(class_id, self.preds, index=1)

        # The shapes can only be compared when both calls found predictions.
        if not (default_pick and explicit_pick):
            return

        default_boxes = default_pick[0]
        explicit_boxes = explicit_pick[0]
        assert default_boxes.shape != explicit_boxes.shape, (
            "shouldn't be the same because different training example items")
Ejemplo n.º 3
0
    def test_single__calls_single_nms(self, mock_single_nms):
        """``Predict.single`` forwards class id and threshold to the NMS mock.

        NOTE(review): ``mock_single_nms`` is presumably injected by a patch
        decorator that is not visible in this excerpt -- confirm.
        """
        class_id = 6
        threshold = 0.5

        Predict.single(class_id, self.preds, conf_thresh=threshold)

        assert mock_single_nms.called
        # Positional arguments of the most recent call to the mock.
        positional = mock_single_nms.call_args[0]
        assert positional[0] == class_id
        # First element of the second argument has 4 entries.
        assert len(positional[1][0]) == 4
        assert positional[-1] == threshold
Ejemplo n.º 4
0
    def test_all(self):
        """``Predict.all`` returns aligned boxes, scores and class ids.

        The combined result is expected to hold more boxes than the result
        for the first single class that produced any predictions.
        """
        # Find the first class for which ``Predict.single`` returns something.
        single_preds = None
        for c in range(cfg.NUM_CLASSES):
            single_preds = Predict.single(c, self.preds)
            if single_preds:
                break
        # Fail with a readable message rather than an opaque unpacking error
        # further down when no class produced predictions.
        assert single_preds, \
            "no class produced single-class predictions for this fixture"

        ret_bbs, ret_scores, ret_cls_ids = Predict.all(self.preds)

        # Only the boxes are compared; the other two entries are unused.
        single_bbs, _, _ = single_preds
        # more than 1 cls of obj detected
        assert ret_bbs.shape[0] > single_bbs.shape[0]
        assert len(ret_bbs.shape) == 2
        assert ret_bbs.shape[1] == 4
        assert ret_bbs.shape[0] == ret_scores.shape[0]
        assert ret_bbs.shape[0] == ret_cls_ids.shape[0]
Ejemplo n.º 5
0
def plot_nms_single_preds(model_output,
                          idx,
                          dataset,
                          preds,
                          cls_id,
                          limit=5,
                          ax=None):
    """Plot the NMS predictions of a single object class for one image.

    Args:
        model_output: 4-item sequence unpacked as image ids, images,
            ground-truth boxes and ground-truth categories.
        idx: position of the image to plot; also passed to
            ``Predict.single`` as ``index``.
        dataset: object whose ``categories()`` supplies the category names.
        preds: predictions handed to ``Predict.single``.
        cls_id: class id whose predictions are plotted.
        limit: maximum number of predicted boxes / class ids to plot.
        ax: forwarded unchanged to ``plot_single_predictions``.
    """
    category_names = dataset.categories()
    _, images, gt_bbs, gt_cats = model_output
    image = images[idx]
    single_result = Predict.single(cls_id, preds, index=idx)
    # ``Predict.single`` may come back falsy (no predictions found); plot the
    # image without any predicted targets in that case.
    if not single_result:
        targets = ([], [])
    else:
        pred_boxes, _, pred_cls_ids = single_result
        targets = (pred_boxes[:limit], pred_cls_ids[:limit])
    plot_single_predictions(image,
                            gt_bbs,
                            gt_cats,
                            idx,
                            category_names,
                            targets=targets,
                            ax=ax)
Ejemplo n.º 6
0
def plot_nms_preds(model_output, idx, dataset, preds, limit=5, ax=None):
    """Plot the predictions returned by ``Predict.all`` for one image.

    Args:
        model_output: 4-item sequence unpacked as image ids, images,
            ground-truth boxes and ground-truth categories.
        idx: position of the image to plot; also passed to ``Predict.all``
            as ``index``.
        dataset: object whose ``categories()`` supplies the category names.
        preds: predictions handed to ``Predict.all``.
        limit: maximum number of predicted boxes / class ids to plot.
        ax: forwarded unchanged to ``plot_single_predictions``.
    """
    category_names = dataset.categories()
    _, images, gt_bbs, gt_cats = model_output
    image = images[idx]
    pred_boxes, _, pred_cls_ids = Predict.all(preds, index=idx)
    # Keep only the first ``limit`` predictions for plotting.
    targets = (pred_boxes[:limit], pred_cls_ids[:limit])
    plot_single_predictions(image,
                            gt_bbs,
                            gt_cats,
                            idx,
                            category_names,
                            targets=targets,
                            ax=ax)
Ejemplo n.º 7
0
    def test_single__returns_correct_shapes(self):
        """When ``Predict.single`` finds something, its outputs line up."""
        class_id = 12

        result = Predict.single(class_id, self.preds)

        # The result is sometimes empty/None; shapes are only checked when
        # something was actually detected.
        if not result:
            return

        boxes, scores, class_ids = result
        assert len(boxes.shape) == 2
        assert len(scores.shape) == 1
        assert boxes.shape[0] == scores.shape[0]
        assert class_ids.shape[0] == scores.shape[0]
Ejemplo n.º 8
0
    def test_nms__doesnt_suppress_if_under_overlap_thresh(self):
        """The ``overlap`` argument decides whether the 0.8-IoU box survives."""
        # box_a has the lowest score and an IoU of 0.8 with box_b;
        # box_c does not overlap box_a at all.
        box_a = [0., 0., 5., 10.]
        box_b = [0., 0., 4., 10.]
        box_c = [5., 5., 10., 10.]
        iou_ab = 0.8
        assert Bboxer.single_bb_iou(box_a, box_b) == iou_ab
        assert Bboxer.single_bb_iou(box_a, box_c) == 0.
        boxes = torch.tensor([box_a, box_b, box_c])
        scores = torch.tensor([.1, .2, .3])
        assert boxes.shape[0] == scores.shape[0]

        # threshold just below the IoU: only 2 boxes are counted as kept
        keep, count = Predict.nms(boxes, scores, overlap=0.79)
        self.assert_arr_equals(keep, [2, 1, 0])
        assert count == 2

        # threshold equal to the IoU: all 3 boxes are counted as kept
        keep, count = Predict.nms(boxes, scores, overlap=iou_ab)
        self.assert_arr_equals(keep, [2, 1, 0])
        assert count == 3
Ejemplo n.º 9
0
    def test_single_nms(self):
        """``Predict.single_nms`` on the first fixture item gives aligned outputs."""
        class_id = 0
        all_bbs, all_cats = self.preds
        # Work on the first item of the fixture only.
        first_cats = all_cats[0]
        first_bbs = all_bbs[0]

        boxes, scores, class_ids = Predict.single_nms(
            class_id, first_bbs, first_cats)

        assert len(boxes.shape) == 2
        assert len(scores.shape) == 1
        n_scores = scores.shape[0]
        assert boxes.shape[0] == n_scores
        assert boxes.shape[1] == 4, "4 points p/ bb"
        total_score = scores.sum().item()
        assert total_score != 0, \
            "pred confidence should always be greater than 0"
        assert class_ids.shape[0] == n_scores
Ejemplo n.º 10
0
    def test_nms__suppresses_overlapping_bb_with_lower_score(self):
        """With default arguments, ``nms`` counts only 2 of the 3 boxes."""
        # box_a has the lowest score and an IoU of .8 with box_b, so it is
        # expected to be suppressed; box_c does not overlap box_a.
        box_a = [0., 0., 5., 10.]
        box_b = [0., 0., 4., 10.]
        box_c = [5., 5., 10., 10.]
        assert Bboxer.single_bb_iou(box_a, box_b) == .8
        assert Bboxer.single_bb_iou(box_a, box_c) == 0.
        boxes = torch.tensor([box_a, box_b, box_c])
        scores = torch.tensor([.1, .2, .3])
        assert boxes.shape[0] == scores.shape[0]

        keep, count = Predict.nms(boxes, scores)

        self.assert_arr_equals(keep, [2, 1, 0])
        assert count == 2
Ejemplo n.º 11
0
    def test_single_nms__both_returned_bc_pass_conf_thresh_and_no_overlap(
            self):
        """Two non-overlapping class-0 boxes both come back, best score first."""
        class_id = 0
        box_a = [0., 0., 5., 10.]
        box_b = [5., 5., 10., 10.]
        assert Bboxer.single_bb_iou(box_a, box_b) == 0
        item_bbs = torch.tensor([box_a, box_b])
        assert list(item_bbs.shape) == [2, 4]
        # One 20-wide score row per box; only column 0 is non-zero.
        item_cats = torch.stack([
            torch.cat((torch.tensor([score]), torch.zeros(19)))
            for score in (.5, .6)
        ])
        assert list(item_cats.shape) == [2, 20]

        boxes, scores, class_ids = Predict.single_nms(
            class_id, item_bbs, item_cats)

        self.assert_arr_equals(boxes, [box_b, box_a])
        self.assert_arr_equals(scores, [.6, .5])
        self.assert_arr_equals(class_ids, [0, 0])
Ejemplo n.º 12
0
    def test_single_nms__filters_out_by_cls_id(self):
        """Only the box whose top-scoring column is ``cls_id`` is returned."""
        class_id = 0
        box_a = [0., 0., 5., 10.]
        box_b = [5., 5., 10., 10.]
        assert Bboxer.single_bb_iou(box_a, box_b) == 0
        item_bbs = torch.tensor([box_a, box_b])
        assert list(item_bbs.shape) == [2, 4]
        # The column holding each row's maximum dictates its class:
        # row for box_a peaks at column 1, row for box_b peaks at column 0.
        padding = torch.zeros(18)
        row_a = torch.cat((torch.tensor([0, .5]), padding))
        row_b = torch.cat((torch.tensor([.6, 0]), padding))
        item_cats = torch.stack((row_a, row_b))
        assert list(item_cats.shape) == [2, 20]

        boxes, scores, class_ids = Predict.single_nms(
            class_id, item_bbs, item_cats)

        self.assert_arr_equals(boxes, [box_b])
        self.assert_arr_equals(scores, [.6])
        self.assert_arr_equals(class_ids, [0])
Ejemplo n.º 13
0
    def test_single_nms__can_change_conf_thresh(self):
        """An explicit threshold of 0.55 drops the .5 box and keeps the .6 one."""
        threshold = 0.55
        class_id = 0
        box_a = [0., 0., 5., 10.]
        box_b = [5., 5., 10., 10.]
        assert Bboxer.single_bb_iou(box_a, box_b) == 0
        item_bbs = torch.tensor([box_a, box_b])
        assert list(item_bbs.shape) == [2, 4]
        # Row for box_a: no entry is above the threshold.
        row_a = torch.cat((torch.tensor([.5]), torch.zeros(19)))
        above_a = (row_a > threshold).sum().item()
        assert above_a == 0
        # Row for box_b: exactly one entry is above the threshold.
        row_b = torch.cat((torch.tensor([.6]), torch.zeros(19)))
        above_b = (row_b > threshold).sum().item()
        assert above_b == 1
        item_cats = torch.stack((row_a, row_b))
        assert list(item_cats.shape) == [2, 20]

        boxes, scores, class_ids = Predict.single_nms(
            class_id, item_bbs, item_cats, threshold)

        self.assert_arr_equals(boxes, [box_b])
        self.assert_arr_equals(scores, [.6])
        self.assert_arr_equals(class_ids, [0])
Ejemplo n.º 14
0
    def test_single_nms__one_returned_bc_pass_conf_thresh_too_low(self):
        """A box scoring below ``cfg.NMS_CONF_THRESH`` is left out.

        NOTE(review): no threshold is passed, so ``single_nms`` presumably
        defaults to ``cfg.NMS_CONF_THRESH`` -- confirm against its signature.
        """
        class_id = 0
        box_a = [0., 0., 5., 10.]
        box_b = [5., 5., 10., 10.]
        assert Bboxer.single_bb_iou(box_a, box_b) == 0
        item_bbs = torch.tensor([box_a, box_b])
        assert list(item_bbs.shape) == [2, 4]
        # Score for box_a, deliberately below the configured threshold.
        low_score = .09
        assert low_score < cfg.NMS_CONF_THRESH
        row_a = torch.cat((torch.tensor([low_score]), torch.zeros(19)))
        above_a = (row_a > cfg.NMS_CONF_THRESH).sum().item()
        assert above_a == 0
        row_b = torch.cat((torch.tensor([.6]), torch.zeros(19)))
        item_cats = torch.stack((row_a, row_b))
        assert list(item_cats.shape) == [2, 20]

        boxes, scores, class_ids = Predict.single_nms(
            class_id, item_bbs, item_cats)

        self.assert_arr_equals(boxes, [box_b])
        self.assert_arr_equals(scores, [.6])
        self.assert_arr_equals(class_ids, [0])
Ejemplo n.º 15
0
 def test_single__returns_empty_for_bg(self):
     """``Predict.single`` returns a falsy result for class id 20."""
     # NOTE(review): 20 is presumably the background class id, going by the
     # test name -- confirm against the project config.
     bg_cls_id = 20
     result = Predict.single(bg_cls_id, self.preds)
     assert not result
Ejemplo n.º 16
0
    def test_all__bg_is_not_predicted(self):
        """None of the class ids returned by ``Predict.all`` equals 20."""
        _, _, class_ids = Predict.all(self.preds)

        # NOTE(review): 20 is presumably the background class id -- confirm.
        bg_hits = (class_ids == 20).sum().item()
        assert bg_hits == 0