Exemplo n.º 1
0
    def _assign_label_to_anchor(self, valid_anchor, gt_bbox, neg_thr, pos_thr, min_pos_thr):
        """Assign a binary label to every valid anchor.

        Label convention: ``1`` foreground, ``0`` background, ``-1`` ignored.

        Args:
            valid_anchor: anchor boxes, one row per anchor; cast to float32
                before being passed to ``bbox_overlaps_cython``.
            gt_bbox: ground-truth boxes, one row per box; cast to float32
                before being passed to ``bbox_overlaps_cython``.
            neg_thr: anchors whose best overlap is below this become background.
            pos_thr: anchors whose best overlap is at least this become
                foreground.
            min_pos_thr: the best-matching anchor(s) of a gt box are forced to
                foreground only if their overlap is at least this value.

        Returns:
            ``(cls_label, argmax_overlaps)``: float32 labels of length
            ``num_anchor``, and for each anchor the index of the gt box it
            overlaps most (all zeros when there is nothing to match).
        """
        num_anchor = valid_anchor.shape[0]
        cls_label = np.full(shape=(num_anchor,), fill_value=-1, dtype=np.float32)

        # ``overlaps.max(axis=0)`` raises on a (0, num_gt) array, so an empty
        # anchor set must take the no-match path as well.
        if len(gt_bbox) > 0 and num_anchor > 0:
            # num_anchor x num_gt
            overlaps = bbox_overlaps_cython(
                valid_anchor.astype(np.float32, copy=False),
                gt_bbox.astype(np.float32, copy=False))
            max_overlaps = overlaps.max(axis=1)
            argmax_overlaps = overlaps.argmax(axis=1)
            gt_max_overlaps = overlaps.max(axis=0)

            # Anchors that reach the best overlap of at least one gt box (ties
            # included). ``gt_max_overlaps`` broadcasts across rows, so every
            # column is compared with its own maximum. When ``min_pos_thr`` is
            # positive it stops a gt box that overlaps nothing (max overlap 0)
            # from marking every zero-overlap anchor as foreground.
            gt_argmax_overlaps = np.where((overlaps == gt_max_overlaps) &
                                          (overlaps >= min_pos_thr))[0]
            # bg label: best overlap below the negative threshold
            cls_label[max_overlaps < neg_thr] = 0
            # fg label: for each gt, anchor with highest overlap
            cls_label[gt_argmax_overlaps] = 1
            # fg label: above threshold overlap
            cls_label[max_overlaps >= pos_thr] = 1
        else:
            cls_label[:] = 0
            # Integer dtype so both branches return the same kind of index array.
            argmax_overlaps = np.zeros(shape=(num_anchor,), dtype=np.intp)

        return cls_label, argmax_overlaps
Exemplo n.º 2
0
    def _assign_label_to_anchor_group(self, valid_anchor, gt_bbox, gt_class,
                                      neg_thr, pos_thr, min_pos_thr):
        """Assign a class label to every valid anchor.

        Works like the binary assignment, except that foreground anchors
        receive the class id of a matched gt box (from ``gt_class``) instead
        of ``1``. Background anchors get ``0``; ignored anchors keep ``-1``.

        Args:
            valid_anchor: anchor boxes, one row per anchor.
            gt_bbox: ground-truth boxes, one row per box.
            gt_class: class id per gt box; it is flattened, and is assumed to
                be in the same order as ``gt_bbox`` with at least as many
                entries.
            neg_thr: anchors whose best overlap is below this become background.
            pos_thr: anchors whose best overlap is at least this take the class
                of their best-overlapping gt box.
            min_pos_thr: the best-matching anchor(s) of a gt box take that
                box's class only if their overlap is at least this value.

        Returns:
            ``(cls_label, argmax_overlaps)``: float32 labels of length
            ``num_anchor``, and for each anchor the index of the gt box it
            overlaps most (all zeros when there is nothing to match).
        """
        num_anchor = valid_anchor.shape[0]
        cls_label = np.full(shape=(num_anchor, ),
                            fill_value=-1,
                            dtype=np.float32)

        # ``overlaps.max(axis=0)`` raises on a (0, num_gt) array, so an empty
        # anchor set must take the no-match path as well.
        if len(gt_bbox) > 0 and num_anchor > 0:
            # num_anchor x num_gt
            overlaps = bbox_overlaps_cython(
                valid_anchor.astype(np.float32, copy=False),
                gt_bbox.astype(np.float32, copy=False))
            max_overlaps = overlaps.max(axis=1)
            argmax_overlaps = overlaps.argmax(axis=1)
            gt_max_overlaps = overlaps.max(axis=0)
            # One class id per gt box. Indexing it directly gives the same
            # values as a num_anchor x num_gt tiled copy, without allocating it.
            gt_labels = gt_class.reshape(-1).astype(np.float32)
            max_overlaps_label = gt_labels[argmax_overlaps]
            # (anchor, gt) pairs where the anchor reaches that gt's best overlap
            anchor_inds, gt_inds = np.where((overlaps == gt_max_overlaps)
                                            & (overlaps >= min_pos_thr))
            # bg label: best overlap below the negative threshold
            cls_label[max_overlaps < neg_thr] = 0
            # fg label: for each gt, anchor with highest overlap.
            # NOTE(review): if such an anchor overlaps another gt even more (but
            # below pos_thr), its class comes from gt_inds while the returned
            # argmax_overlaps points at the other gt — confirm this is intended.
            cls_label[anchor_inds] = gt_labels[gt_inds]
            # fg label: above threshold overlap, class of the best gt box
            fg_label_idxs = np.where(max_overlaps >= pos_thr)[0]
            cls_label[fg_label_idxs] = max_overlaps_label[fg_label_idxs]
        else:
            cls_label[:] = 0
            # Integer dtype so both branches return the same kind of index array.
            argmax_overlaps = np.zeros(shape=(num_anchor, ), dtype=np.intp)

        return cls_label, argmax_overlaps
Exemplo n.º 3
0
    def _filter_anchor_by_scale_range(self, cls_label, valid_anchor, gt_bbox,
                                      valid_range, invalid_anchor_threshd):
        """Mark anchors near out-of-range gt boxes as ignored (``-1``).

        A gt box is out of range when its area ``(x2 - x1 + 1) * (y2 - y1 + 1)``
        is below ``valid_range[0] ** 2`` or above ``valid_range[1] ** 2``. A
        negative ``valid_range[1]`` means there is no upper bound; squaring it
        would otherwise mark every box as out of range.

        Args:
            cls_label: per-anchor labels, modified in place.
            valid_anchor: anchor boxes, one row per entry of ``cls_label``.
            gt_bbox: ground-truth boxes as ``(x1, y1, x2, y2, ...)`` rows.
            valid_range: ``(min_side, max_side)`` bounds on sqrt(area).
            invalid_anchor_threshd: anchors whose best overlap with any
                out-of-range gt box exceeds this are set to ``-1``.

        Returns:
            None; the result is written into ``cls_label``.
        """
        if len(gt_bbox) == 0:
            return
        gt_bbox_sizes = (gt_bbox[:, 2] - gt_bbox[:, 0] +
                         1.0) * (gt_bbox[:, 3] - gt_bbox[:, 1] + 1.0)
        min_side, max_side = valid_range[0], valid_range[1]
        out_of_range = gt_bbox_sizes < min_side ** 2
        if not max_side < 0:
            out_of_range |= gt_bbox_sizes > max_side ** 2
        invalid_gt_bbox = gt_bbox[out_of_range]
        if len(invalid_gt_bbox) == 0:
            return

        invalid_overlaps = bbox_overlaps_cython(
            valid_anchor.astype(np.float32, copy=False),
            invalid_gt_bbox.astype(np.float32, copy=False))
        # best overlap of each anchor with any out-of-range gt box
        invalid_max_overlaps = invalid_overlaps.max(axis=1)

        # ignore anchors overlapped with invalid gt boxes
        cls_label[invalid_max_overlaps > invalid_anchor_threshd] = -1