Example #1
0
        def _generate_rel(n):
            """
            Sample ``n`` random candidate boxes and keep the non-overlapping ones.

            Candidates are drawn in a normalized 0-1 frame (using the enclosing
            ``anchors`` and ``rng``). Each candidate is assigned an image id
            sampled from ``all_gids`` via ``rng.choice`` and rescaled to that
            image's ``width``/``height`` as stored on ``self.qtrees[gid]``. A
            candidate is kept when ``self.iooas`` reports no overlaps at all,
            or when the largest reported overlap is below ``thresh``.

            Args:
                n (int): number of candidate boxes to draw.

            Returns:
                Tuple[list, list]: ``(neg_gids, neg_boxes)``. These are the
                image id and the raw ``box.data`` of every accepted candidate,
                in matching order.
            """
            # Candidate boxes start out in normalized [0, 1] coordinates.
            unit_boxes = kwimage.Boxes.random(
                num=n, scale=1.0, format='tlbr', anchors=anchors,
                anchor_std=0, rng=rng)

            # Pair every candidate with a sampled image id, then bucket the
            # candidates by that id.
            sampled_gids = np.array(sorted(rng.choice(all_gids, size=n)))
            boxes_per_gid = kwarray.group_items(unit_boxes, sampled_gids,
                                                axis=0)

            accepted_gids = []
            accepted_data = []
            for img_id, group in boxes_per_gid.items():
                tree = self.qtrees[img_id]
                # Map normalized coordinates onto this image's pixel extent.
                pixel_boxes = group.scale((tree.width, tree.height))
                for candidate in pixel_boxes:
                    _, overlap_scores = self.iooas(img_id, candidate)
                    is_negative = (len(overlap_scores) == 0 or
                                   overlap_scores.max() < thresh)
                    if is_negative:
                        accepted_gids.append(img_id)
                        accepted_data.append(candidate.data)
            return accepted_gids, accepted_data
Example #2
0
    def __init__(self,
                 index_to_label,
                 batch_size=1,
                 num_batches='auto',
                 quantile=0.5,
                 shuffle=False,
                 rng=None):
        """
        Group dataset indices by label and set up one ring sampler per label.

        Args:
            index_to_label: sized sequence giving the label of each dataset
                index. Index ``i`` is grouped under ``index_to_label[i]``.
            batch_size (int): stored as ``self.batch_size``.
            num_batches (int | str): explicit batch count. With ``'auto'``
                the count is delegated to ``self._auto_num_batches(quantile)``.
            quantile (float): used only when ``num_batches == 'auto'``.
            shuffle (bool): forwarded to every per-label ``RingSampler`` and
                stored as ``self.shuffle``.
            rng: seed or generator, normalized with
                ``kwarray.ensure_rng(rng, api='python')``. The resulting
                generator is shared by every per-label sampler.
        """
        import kwarray

        rng = kwarray.ensure_rng(rng, api='python')
        dataset_indices = np.arange(len(index_to_label))
        label_to_indices = kwarray.group_items(dataset_indices, index_to_label)

        # Number of dataset items carrying each label.
        label_to_freq = ub.map_vals(len, label_to_indices)

        # One sampler per label. They are built in label order and all share
        # the same rng, so construction order is preserved deliberately.
        label_to_subsampler = {}
        for label, indices in label_to_indices.items():
            label_to_subsampler[label] = RingSampler(
                indices, shuffle=shuffle, rng=rng)

        self.label_to_freq = label_to_freq
        self.index_to_label = index_to_label
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = rng
        self.label_to_indices = label_to_indices
        self.label_to_subsampler = label_to_subsampler

        self.num_batches = (self._auto_num_batches(quantile)
                            if num_batches == 'auto' else num_batches)

        # Iterating a dict yields its keys, in insertion order.
        self.labels = list(self.label_to_indices)
Example #3
0
def _demodata_refine_boxes(n_roi, n_img, rng=0):
    """
    Create random test data for the
    ``mmdet.models.bbox_heads.bbox_head.BBoxHead.refine_boxes`` method

    Requires ``kwarray``. If it is not installed, the calling pytest test is
    skipped via ``pytest.skip``.

    Args:
        n_roi (int): number of regions of interest to generate.
        n_img (int): number of images the rois are spread across. If this is
            0 then ``n_roi`` must also be 0.
        rng (int | None | RandomState): seed or random state, normalized with
            ``mmdet.core.bbox.demodata.ensure_rng``.

    Returns:
        tuple: ``(rois, labels, bbox_preds, pos_is_gts, img_metas)``

            * ``rois`` - float tensor. The first column is each roi's image
              id, followed by the box columns produced by ``random_boxes``.
            * ``labels`` - ``(n_roi,)`` long tensor of random 0/1 labels.
            * ``bbox_preds`` - ``n_roi`` random boxes from ``random_boxes``.
            * ``pos_is_gts`` - one uint8 tensor per image. Its length is the
              number of positively labeled rois in that image. Its random 0/1
              values are sorted descending, so the 1s come first.
            * ``img_metas`` - one ``{"img_shape": (512, 512)}`` dict per image.
    """
    import numpy as np
    import torch
    from mmdet.core.bbox.demodata import random_boxes
    from mmdet.core.bbox.demodata import ensure_rng

    try:
        import kwarray
    except ImportError:
        import pytest

        pytest.skip("kwarray is required for this test")
    scale = 512
    rng = ensure_rng(rng)
    img_metas = [{"img_shape": (scale, scale)} for _ in range(n_img)]
    # Create rois in the expected format
    roi_boxes = random_boxes(n_roi, scale=scale, rng=rng)
    if n_img == 0:
        assert n_roi == 0, "cannot have any rois if there are no images"
        img_ids = torch.empty((0, ), dtype=torch.long)
        roi_boxes = torch.empty((0, 4), dtype=torch.float32)
    else:
        img_ids = rng.randint(0, n_img, (n_roi, ))
        img_ids = torch.from_numpy(img_ids)
    # Prepend each roi's image index as an extra leading column.
    rois = torch.cat([img_ids[:, None].float(), roi_boxes], dim=1)
    # Create other args
    labels = rng.randint(0, 2, (n_roi, ))
    labels = torch.from_numpy(labels).long()
    bbox_preds = random_boxes(n_roi, scale=scale, rng=rng)
    # For each image, pretend random positive boxes are gts.
    # ``np.int`` was only a deprecated alias of the builtin ``int`` and was
    # removed in NumPy 1.24 (it now raises AttributeError), so use ``int``.
    is_label_pos = (labels.numpy() > 0).astype(int)
    lbl_per_img = kwarray.group_items(is_label_pos, img_ids.numpy())
    # Images with no rois are missing from the grouping and count as 0.
    pos_per_img = [sum(lbl_per_img.get(gid, [])) for gid in range(n_img)]
    # randomly generate with numpy then sort with torch
    _pos_is_gts = [
        rng.randint(0, 2, (npos, )).astype(np.uint8) for npos in pos_per_img
    ]
    pos_is_gts = [
        torch.from_numpy(p).sort(descending=True)[0] for p in _pos_is_gts
    ]
    return rois, labels, bbox_preds, pos_is_gts, img_metas