Example #1
    def forward(self, box_preds, gt_boxes):
        """
        Parameters
        ----------
        box_preds : Tensor
            Predicted bounding boxes, shape (batch, num_boxes, 4).
        gt_boxes : Tensor
            Ground-truth bounding boxes.

        Returns
        -------
        tuple of Tensors
            objectness: 0 for negative, 1 for positive, -1 for ignore; (batch, num_boxes, 1).
            center_targets: regression targets for center x and y; (batch, num_boxes, 2).
            scale_targets: regression targets for width and height; (batch, num_boxes, 2).
            weights: element-wise gradient weights for center_targets and scale_targets.
            class_targets: one-hot classification targets; (batch, num_boxes, 80).
        """
        with torch.no_grad():
            objness_t = torch.zeros_like(
                torch.unsqueeze(box_preds[:, :, 0], -1))
            center_t = torch.zeros_like(box_preds[:, :, 0:2])
            scale_t = torch.zeros_like(box_preds[:, :, 0:2])
            weight_t = torch.zeros_like(box_preds[:, :, 0:2])
            class_t = torch.ones_like(objness_t.repeat(1, 1,
                                                       self._num_class)) * -1
            ious_max = []
            for box_preds_per_img, gt_boxes_per_img in zip(
                    box_preds, gt_boxes):
                ious = bbox_overlaps(box_preds_per_img, gt_boxes_per_img)
                ious_max.append(torch.max(
                    ious, dim=-1, keepdim=True)[0])  # (h*w*num_anchors, 1)
            ious_max = torch.stack(ious_max, dim=0)
            # mark predictions whose best IoU with any gt exceeds the ignore
            # threshold with -1 (ignored); everything else stays 0 (negative).
            objness_t = (ious_max > self._ignore_iou_thresh).to(
                torch.float32) * -1
            return objness_t, center_t, scale_t, weight_t, class_t
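
All of these snippets lean on a bbox_overlaps helper that is not shown. As a point of reference, here is a minimal sketch of pairwise IoU, assuming corner-format (x1, y1, x2, y2) boxes; the original projects most likely import an optimized implementation (e.g. mmdet's bbox_overlaps) rather than this one.

import torch

def bbox_overlaps(boxes1, boxes2, eps=1e-6):
    # Pairwise IoU between (N, 4) and (M, 4) boxes -> (N, M).
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    lt = torch.max(boxes1[:, None, :2], boxes2[None, :, :2])  # intersection top-left
    rb = torch.min(boxes1[:, None, 2:], boxes2[None, :, 2:])  # intersection bottom-right
    wh = (rb - lt).clamp(min=0)                               # zero width/height if no overlap
    inter = wh[..., 0] * wh[..., 1]
    union = area1[:, None] + area2[None, :] - inter
    return inter / union.clamp(min=eps)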
Example #2
    def forward_single_image(self, gt_boxes, gt_labels, img_metas,
                             shift_anchor_boxes, shape_like, num_anchors,
                             anchors, pad_shape, all_featmaps, num_offsets):
        # shape_like: (h3*w3+h2*w2+h1*w1, 9 anchors, 2).
        center_targets = torch.zeros(shape_like).cuda()
        scale_targets = torch.zeros_like(center_targets)
        weights = torch.zeros_like(center_targets)
        objectness = torch.zeros_like(weights.split(1, dim=-1)[0])
        class_targets = torch.ones_like(objectness).repeat(
            1, 1, self._num_class) * -1

        gtx, gty, gtw, gth = point_to_center(gt_boxes,
                                             split=True,
                                             keep_axis=True)
        shift_gt_boxes = torch.cat(
            (-0.5 * gtw, -0.5 * gth, 0.5 * gtw, 0.5 * gth), dim=-1)
        # IoUs between the zero-centered anchors (9) and zero-centered gt boxes (num_gt).
        ious = bbox_overlaps(shift_anchor_boxes, shift_gt_boxes)
        # assume gt and anchor centers are aligned and pick the best-matching anchor scale.
        matches = ious.argmax(dim=0).to(torch.int32)  # (num_gt,)
        valid_gts = (gt_boxes >= 0).prod(dim=-1)  # (num_gt,)
        pad_height, pad_width = pad_shape
        for m in range(matches.shape[0]):
            # for each gt in a single image; padded (invalid) gts come last,
            # so we can stop at the first one.
            if valid_gts[m] < 1:
                break
            match = matches[m]  # matched anchor idx, note that 0 <= match < 9.
            nlayer = np.nonzero(num_anchors > match)[0][0]
            height = all_featmaps[nlayer].shape[2]
            width = all_featmaps[nlayer].shape[3]
            mgtx, mgty, mgtw, mgth = (gtx[m, 0], gty[m, 0],
                                      gtw[m, 0], gth[m, 0])
            # locate the grid cell containing the gt center at this feature-map level.
            loc_x = (mgtx / pad_width * width).to(torch.int32)
            loc_y = (mgty / pad_height * height).to(torch.int32)
            # write back to targets
            index = num_offsets[nlayer] + loc_y * width + loc_x
            center_targets[index, match,
                           0] = mgtx / pad_width * width - loc_x  # tx
            center_targets[index, match,
                           1] = mgty / pad_height * height - loc_y  # ty
            scale_targets[index, match,
                          0] = torch.log(mgtw.clamp(min=1) / anchors[match, 0])
            scale_targets[index, match,
                          1] = torch.log(mgth.clamp(min=1) / anchors[match, 1])
            weights[index,
                    match, :] = 2.0 - mgtw * mgth / pad_width / pad_height
            # with mixup, labels from the first image are weighted by lambd
            # and labels from the second by (1 - lambd).
            first_n = img_metas.get('mixup_params',
                                    dict()).get('first_n_labels', len(matches))
            lambd = img_metas.get('mixup_params', dict()).get('lambd', 1.)
            if m < first_n:
                objectness[index, match, 0] = lambd
            else:
                objectness[index, match, 0] = 1. - lambd
            class_targets[index, match, :] = 0
            # gt_labels are 1-based, so shift to a 0-based one-hot index.
            class_targets[index, match, int(gt_labels[m]) - 1] = 1
        return objectness, center_targets, scale_targets, weights, class_targets
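
To make the target encoding above concrete, here is a hypothetical worked example for a single gt box; the padded image size, feature-map size, anchor size, and gt numbers are made up for illustration.

import math

pad_w, pad_h, feat_w, feat_h = 416, 416, 13, 13  # padded image and one 13x13 level
anchor_w, anchor_h = 116, 90                     # the matched anchor (in pixels)
gtx, gty, gtw, gth = 200., 150., 120., 80.       # gt center and size (in pixels)
loc_x = int(gtx / pad_w * feat_w)                # grid column: int(6.25) = 6
loc_y = int(gty / pad_h * feat_h)                # grid row: int(4.6875) = 4
tx = gtx / pad_w * feat_w - loc_x                # 0.25, center offset inside the cell
ty = gty / pad_h * feat_h - loc_y                # 0.6875
tw = math.log(max(gtw, 1) / anchor_w)            # log(120/116) ~= 0.034
th = math.log(max(gth, 1) / anchor_h)            # log(80/90) ~= -0.118
weight = 2.0 - gtw * gth / pad_w / pad_h         # ~= 1.945; smaller boxes get larger weight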
Example #3
    def loss_single(self, cls_score, bbox_pred, iou_pred, anchors, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples):
        """Compute loss of a single scale level.

        Args:
            cls_score (Tensor): Box scores for each scale level
                with shape (N, num_anchors * num_classes, H, W).
            bbox_pred (Tensor): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            iou_pred (Tensor): IoU predictions for each scale level
                with shape (N, num_anchors, H, W).
            anchors (Tensor): Box reference for each scale level with shape
                (N, num_total_anchors, 4).
            labels (Tensor): Labels of each anchor with shape
                (N, num_total_anchors).
            label_weights (Tensor): Label weights of each anchor with shape
                (N, num_total_anchors).
            bbox_targets (Tensor): BBox regression targets of each anchor
                with shape (N, num_total_anchors, 4).
            bbox_weights (Tensor): BBox regression loss weights of each anchor
                with shape (N, num_total_anchors, 4).
            num_total_samples (int): If sampling is used, this equals the
                total number of anchors; otherwise it is the number of
                positive anchors.

        Returns:
            tuple[Tensor]: loss_cls, loss_bbox and loss_iou.
        """
        # classification loss
        anchors = anchors.reshape(-1, 4)
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        loss_cls = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=num_total_samples)
        # regression loss
        bbox_targets = bbox_targets.reshape(-1, 4)
        bbox_weights = bbox_weights.reshape(-1, 4)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)

        # supervise the IoU branch only where box regression is supervised.
        iou_targets = label_weights.new_zeros(labels.shape)
        iou_weights = label_weights.new_zeros(labels.shape)
        iou_weights[(bbox_weights.sum(dim=1) > 0).nonzero()] = 1.
        iou_pred = iou_pred.permute(0, 2, 3, 1).reshape(-1)

        # foreground anchors have labels in [0, num_classes).
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0) &
                    (labels < bg_class_ind)).nonzero().squeeze(1)

        if self.reg_decoded_bbox:
            # anchors were already flattened to (-1, 4) above
            bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
        loss_bbox = self.loss_bbox(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            avg_factor=num_total_samples)

        if len(pos_inds) > 0:
            # dx, dy, dw, dh
            pos_bbox_targets = bbox_targets[pos_inds]
            # tx, ty, tw, th
            pos_bbox_pred = bbox_pred[pos_inds]
            # x1, y1, x2, y2
            pos_anchors = anchors[pos_inds]

            if self.reg_decoded_bbox:
                pos_decode_bbox_pred = pos_bbox_pred
                gt_bboxes = pos_bbox_targets
            else:
                # x1, y1, x2 ,y2
                pos_decode_bbox_pred = self.bbox_coder.decode(
                    pos_anchors, pos_bbox_pred)

                gt_bboxes = self.bbox_coder.decode(pos_anchors,
                                                   pos_bbox_targets)

            if self.detach:
                pos_decode_bbox_pred = pos_decode_bbox_pred.detach()

            iou_targets[pos_inds] = bbox_overlaps(
                pos_decode_bbox_pred, gt_bboxes, is_aligned=True)

        loss_iou = self.loss_iou(
            iou_pred, iou_targets, iou_weights, avg_factor=num_total_samples)
        return loss_cls, loss_bbox, loss_iou
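
The permute/reshape pattern in loss_single flattens a per-level map so that every anchor becomes one row. A quick sketch of the shape bookkeeping, with made-up sizes:

import torch

N, A, C, H, W = 2, 3, 80, 8, 8             # images, anchors per cell, classes, map size
cls_score = torch.randn(N, A * C, H, W)    # channel layout produced by the head
flat = cls_score.permute(0, 2, 3, 1).reshape(-1, C)
assert flat.shape == (N * H * W * A, C)    # one row of class scores per anchor
iou_flat = torch.randn(N, A, H, W).permute(0, 2, 3, 1).reshape(-1)
assert iou_flat.shape == (N * H * W * A,)  # one IoU prediction per anchor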
Example #4
    def nms_resampling_discrete(self, proposals, scores, ids, gt_bboxes,
                                gt_labels, a_r, a_c, a_f):
        # a proposal is treated as background when its IoU with every gt is below 0.3
        select_thresh = 0.3
        out = []

        # rare, common, frequent = self.get_category_frequency(gt_labels.device)
        frequent = torch.tensor([0, 3], device=gt_labels.device)
        common = torch.tensor([1, 4, 9], device=gt_labels.device)
        rare = torch.tensor([2, 5, 6, 7, 8, 10], device=gt_labels.device)

        rare_gtbox = torch.zeros((2000, 4), device=gt_labels.device)
        rare_gtbox_idx = 0
        common_gtbox = torch.zeros((2000, 4), device=gt_labels.device)
        common_gtbox_idx = 0
        frequent_gtbox = torch.zeros((2000, 4), device=gt_labels.device)
        frequent_gtbox_idx = 0
        for gt_bbox, gt_label in zip(gt_bboxes, gt_labels):
            if gt_label in rare:
                rare_gtbox[rare_gtbox_idx, ...] = gt_bbox
                rare_gtbox_idx += 1
            elif gt_label in common:
                common_gtbox[common_gtbox_idx, ...] = gt_bbox
                common_gtbox_idx += 1
            else:
                frequent_gtbox[frequent_gtbox_idx, ...] = gt_bbox
                frequent_gtbox_idx += 1
        rare_gtbox = rare_gtbox[:rare_gtbox_idx, ...]
        common_gtbox = common_gtbox[:common_gtbox_idx, ...]

        # baseline: NMS with the frequent-class threshold a_f for all proposals
        frequent_proposals, _ = batched_nms(
            proposals, scores, ids, dict(type='nms', iou_threshold=a_f))
        if len(rare_gtbox) > 0:
            # re-run NMS with the rare-class threshold a_r and keep only the
            # surviving proposals that overlap a rare gt by >= select_thresh
            rare_proposals, _ = batched_nms(
                proposals, scores, ids, dict(type='nms', iou_threshold=a_r))
            rare_overlaps = bbox_overlaps(rare_gtbox, rare_proposals[:, :4])
            rare_max_overlaps, rare_argmax_overlaps = rare_overlaps.max(dim=0)
            rare_pos_inds = rare_max_overlaps >= select_thresh
            rare_proposals = rare_proposals[rare_pos_inds, :]
            out.append(rare_proposals)

            # drop frequent-NMS proposals already covered by rare gts,
            # so they are not duplicated in the output
            frequent_rare_overlaps = bbox_overlaps(rare_gtbox,
                                                   frequent_proposals[:, :4])
            frequent_rare_max_overlaps, frequent_rare_argmax_overlaps = frequent_rare_overlaps.max(
                dim=0)
            valid_inds = frequent_rare_max_overlaps < select_thresh
            frequent_proposals = frequent_proposals[valid_inds, :]
        if len(common_gtbox) > 0:
            # keep = self.nms_py(proposals, scores, a_c)
            common_proposals, _ = batched_nms(
                proposals, scores, ids, dict(type='nms', iou_threshold=a_c))
            common_overlaps = bbox_overlaps(common_gtbox,
                                            common_proposals[:, :4])
            common_max_overlaps, common_argmax_overlaps = common_overlaps.max(
                dim=0)
            common_pos_inds = common_max_overlaps >= select_thresh
            common_proposals = common_proposals[common_pos_inds, :]
            out.append(common_proposals)

            frequent_common_overlaps = bbox_overlaps(common_gtbox,
                                                     frequent_proposals[:, :4])
            frequent_common_max_overlaps, frequent_common_argmax_overlaps = frequent_common_overlaps.max(
                dim=0)
            valid_inds = frequent_common_max_overlaps < select_thresh
            frequent_proposals = frequent_proposals[valid_inds, :]
        out.append(frequent_proposals)
        if len(out) > 1:
            out_proposals = torch.cat(out, 0)
        else:
            out_proposals = frequent_proposals

        return out_proposals
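
A hypothetical call sketch: head stands for an instance of the class defining this method, batched_nms is assumed to be mmcv.ops.batched_nms (returning (dets, keep) with dets of shape (K, 5)), and the class ids follow the hardcoded rare/common/frequent split above. Since a higher IoU threshold makes NMS suppress fewer boxes, choosing a_r > a_c > a_f keeps more proposals around rare-category gts, which is the point of the resampling.

import torch

xy1 = torch.rand(1000, 2) * 800
wh = torch.rand(1000, 2) * 100 + 1
proposals = torch.cat([xy1, xy1 + wh], dim=-1)    # (x1, y1, x2, y2)
scores = torch.rand(1000)
ids = torch.zeros(1000, dtype=torch.long)         # a single NMS group
gt_bboxes = torch.tensor([[100., 100., 200., 200.]])
gt_labels = torch.tensor([2])                     # class 2 is in the rare set above
out = head.nms_resampling_discrete(proposals, scores, ids, gt_bboxes,
                                   gt_labels, a_r=0.9, a_c=0.8, a_f=0.7)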