예제 #1
0
    def _cal_anchor_target(self, label, valid_anchor, gt_bbox, anchor_label):
        num_anchor = valid_anchor.shape[0]
        reg_target = np.zeros(shape=(num_anchor, 4), dtype=np.float32)
        reg_weight = np.zeros(shape=(num_anchor, 4), dtype=np.float32)
        fg_index = np.where(label == 1)[0]
        if len(fg_index) > 0:
            reg_target[fg_index] = bbox_transform(valid_anchor[fg_index], gt_bbox[anchor_label[fg_index], :4])
            reg_weight[fg_index, :] = 1.0

        return reg_target, reg_weight
예제 #2
0
파일: input.py 프로젝트: zymale/simpledet
    def _assign_label_to_anchor(self, valid_anchor, gt_bbox, neg_thr, pos_thr,
                                min_pos_thr):
        num_anchor = valid_anchor.shape[0]
        cls_label = np.full(shape=(num_anchor, ),
                            fill_value=-1,
                            dtype=np.float32)
        reg_target = np.zeros(shape=(num_anchor, 4), dtype=np.float32)
        reg_weight = np.zeros(shape=(num_anchor, 4), dtype=np.float32)

        if len(gt_bbox) > 0:
            # num_anchor x num_gt
            overlaps = bbox_overlaps_cython(
                valid_anchor.astype(np.float32, copy=False),
                gt_bbox.astype(np.float32, copy=False))
            max_overlaps = overlaps.max(axis=1)
            argmax_overlaps = overlaps.argmax(axis=1)
            gt_max_overlaps = overlaps.max(axis=0)
            # TODO: speed up this
            # TODO: fix potentially assigning wrong anchors as positive
            # A correct implementation is given as
            # gt_argmax_overlaps = np.where((overlaps.transpose() == gt_max_overlaps[:, None]) &
            #                               (overlaps.transpose() >= min_pos_thr))[1]
            gt_argmax_overlaps = np.where((overlaps == gt_max_overlaps)
                                          & (overlaps >= min_pos_thr))
            # anchor class
            cls_label[max_overlaps < neg_thr] = 0
            # fg label: for each gt, anchor with highest overlap
            cls_label[gt_argmax_overlaps[0]] = gt_bbox[gt_argmax_overlaps[1],
                                                       4]
            # fg label: above threshold IoU
            cls_label[max_overlaps >= pos_thr] = gt_bbox[argmax_overlaps[
                max_overlaps >= pos_thr], 4]

            # anchor regression
            reg_target[:] = bbox_transform(valid_anchor,
                                           gt_bbox[argmax_overlaps, :4])
            reg_weight[cls_label >= 1, :] = 1.0
        else:
            cls_label[:] = 0

        return cls_label, reg_target, reg_weight