Example #1
def boxlist_nms(boxlist,
                nms_thresh,
                max_proposals=-1,
                score_field="scores",
                nms_type='nms',
                vote_thresh=0.65):
    """
    Performs non-maximum suppression on a boxlist, with scores specified
    in a boxlist field via score_field. `nms_type` selects plain NMS
    ('nms'), box voting ('vote'), or soft box voting (any other value).
    """
    if nms_thresh <= 0:
        return boxlist
    mode = boxlist.mode
    boxlist = boxlist.convert("xyxy")
    boxes = boxlist.bbox
    score = boxlist.get_field(score_field)
    if nms_type == 'nms':
        keep = _box_nms(boxes, score, nms_thresh)
        if max_proposals > 0:
            keep = keep[:max_proposals]
        boxlist = boxlist[keep]
    else:
        # Voting merges overlapping boxes (IoU > vote_thresh) by averaging
        # their coordinates instead of discarding the lower-scored ones.
        if nms_type == 'vote':
            boxes_vote, scores_vote = bbox_vote(boxes, score, vote_thresh)
        else:
            boxes_vote, scores_vote = soft_bbox_vote(boxes, score, vote_thresh)
        if len(boxes_vote) > 0:
            boxlist.bbox = boxes_vote
            boxlist.add_field('scores', scores_vote)
    return boxlist.convert(mode)
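
A hedged usage sketch (not part of the original example): it assumes the maskrcnn-benchmark style BoxList constructor BoxList(bbox, image_size, mode) and that the helpers above (_box_nms, bbox_vote) are importable; the box values are illustrative only.

import torch

# two heavily overlapping detections plus one separate box (illustrative)
boxes = torch.tensor([[10., 10., 50., 50.],
                      [12., 11., 52., 49.],
                      [100., 100., 150., 150.]])
scores = torch.tensor([0.9, 0.8, 0.7])

boxlist = BoxList(boxes, (200, 200), mode="xyxy")  # assumed constructor
boxlist.add_field("scores", scores)

# plain NMS drops the lower-scored near-duplicate
kept = boxlist_nms(boxlist, nms_thresh=0.5)

# box voting merges the near-duplicates into one averaged box instead
voted = boxlist_nms(boxlist, nms_thresh=0.5, nms_type='vote', vote_thresh=0.65)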
Example #2
def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"):
    """
    Performs non-maximum suppression on a boxlist, with scores specified
    in a boxlist field via score_field.

    Arguments:
        boxlist(BoxList)
        nms_thresh (float)
        max_proposals (int): if > 0, then only the top max_proposals are kept
            after non-maximum suppression
        score_field (str)
    """
    if nms_thresh <= 0:
        return boxlist
    mode = boxlist.mode
    boxlist = boxlist.convert("xyxy")
    boxes = boxlist.bbox
    score = boxlist.get_field(score_field)
    keep = _box_nms(boxes, score, nms_thresh)
    if max_proposals > 0:
        keep = keep[:max_proposals]
    boxlist = boxlist[keep]
    return boxlist.convert(mode)
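
Both examples call _box_nms, which the source repositories typically provide as a compiled C++/CUDA extension. Where that extension is unavailable, torchvision's operator is a functionally equivalent stand-in (a sketch, assuming boxes in "xyxy" format; this is not the repositories' own implementation):

import torch
from torchvision.ops import nms

def _box_nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
    # returns the indices of the kept boxes, sorted by descending score
    return nms(boxes, scores, iou_threshold)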
Example #3
    def __call__(self, locations, box_cls, box_regression, centerness,
                 cof_preds, feat_mask, targets):
        """
        Arguments:
            locations (list[BoxList])
            box_cls (list[Tensor])
            box_regression (list[Tensor])
            centerness (list[Tensor])
            cof_preds (list[Tensor])
            feat_mask (Tensor)
            targets (list[BoxList])

        Returns:
            cls_loss (Tensor)
            reg_loss (Tensor)
            centerness_loss (Tensor)
            loss_mask (Tensor)
        """
        N = box_cls[0].size(0)
        num_classes = box_cls[0].size(1)
        labels, reg_targets, labels_list, bbox_gt_list, gt_inds = self.prepare_targets(
            locations, targets)
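        # labels/reg_targets are organized per FPN level; labels_list,
        # bbox_gt_list and gt_inds are kept per image for the mask loss below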

        # ------ decode box ------
        sampled_boxes = []
        for l, b, s in zip(locations, box_regression, self.fpn_strides):
            sampled_boxes.append(self.decode_for_single_feature_map(l, b, s))

        flatten_sampled_boxes = [
            torch.cat([
                boxes_level_img.reshape(-1, 4)
                for boxes_level_img in sampled_boxes_per_img
            ]) for sampled_boxes_per_img in zip(*sampled_boxes)
        ]

        box_cls_flatten = []
        box_regression_flatten = []
        centerness_flatten = []
        labels_flatten = []
        reg_targets_flatten = []
        for l in range(len(labels)):
            box_cls_flatten.append(box_cls[l].permute(0, 2, 3, 1).reshape(
                -1, num_classes))
            box_regression_flatten.append(box_regression[l].permute(
                0, 2, 3, 1).reshape(-1, 4))
            labels_flatten.append(labels[l].reshape(-1))
            reg_targets_flatten.append(reg_targets[l].reshape(-1, 4))
            centerness_flatten.append(centerness[l].reshape(-1))

        box_cls_flatten = torch.cat(box_cls_flatten, dim=0)
        box_regression_flatten = torch.cat(box_regression_flatten, dim=0)
        centerness_flatten = torch.cat(centerness_flatten, dim=0)
        labels_flatten = torch.cat(labels_flatten, dim=0)
        reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0)

        pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1)

        box_regression_flatten = box_regression_flatten[pos_inds]
        reg_targets_flatten = reg_targets_flatten[pos_inds]
        centerness_flatten = centerness_flatten[pos_inds]

        num_gpus = get_num_gpus()
        # sync num_pos from all gpus
        total_num_pos = reduce_sum(
            pos_inds.new_tensor([pos_inds.numel()])).item()
        num_pos_avg_per_gpu = max(total_num_pos / float(num_gpus), 1.0)

        cls_loss = self.cls_loss_func(
            box_cls_flatten, labels_flatten.int()) / num_pos_avg_per_gpu

        if pos_inds.numel() > 0:
            centerness_targets = self.compute_centerness_targets(
                reg_targets_flatten)

            # average sum_centerness_targets from all gpus,
            # which is used to normalize centerness-weighed reg loss
            sum_centerness_targets_avg_per_gpu = \
                reduce_sum(centerness_targets.sum()).item() / float(num_gpus)

            reg_loss = self.box_reg_loss_func(
                box_regression_flatten, reg_targets_flatten,
                centerness_targets) / sum_centerness_targets_avg_per_gpu
            centerness_loss = self.centerness_loss_func(
                centerness_flatten, centerness_targets) / num_pos_avg_per_gpu
        else:
            reg_loss = box_regression_flatten.sum()
            # participate in the all-reduce even with no positives so that
            # ranks with positives are not left waiting
            reduce_sum(centerness_flatten.new_tensor([0.0]))
            centerness_loss = centerness_flatten.sum()

        # ------ mask loss ------
        num_imgs = len(flatten_sampled_boxes)
        flatten_cls_scores1 = []
        for l in range(len(labels)):
            flatten_cls_scores1.append(box_cls[l].permute(0, 2, 3, 1).reshape(
                num_imgs, -1, num_classes))

        flatten_cls_scores1 = torch.cat(flatten_cls_scores1, dim=1)

        # each location predicts four 32-channel sets of mask coefficients
        flatten_cof_preds = [
            cof_pred.permute(0, 2, 3, 1).reshape(len(labels_list), -1, 32 * 4)
            for cof_pred in cof_preds
        ]
        flatten_cof_preds = torch.cat(flatten_cof_preds, dim=1)

        loss_mask = 0
        for i in range(num_imgs):
            # per-image labels (note: shadows the per-level `labels` list,
            # which is no longer needed at this point)
            labels = torch.cat(
                [labels_level.flatten() for labels_level in labels_list[i]])
            # detections are mapped into the half-resolution mask feature space
            bbox_dt = flatten_sampled_boxes[i] / 2
            bbox_dt = bbox_dt.detach()
            pos_inds = labels > 0

            cof_pred = flatten_cof_preds[i][pos_inds]
            img_mask = feat_mask[i]
            mask_h = feat_mask[i].shape[1]
            mask_w = feat_mask[i].shape[2]
            idx_gt = gt_inds[i]
            bbox_dt = bbox_dt[pos_inds, :4]
            gt_masks = targets[i].get_field("masks").get_mask_tensor().to(
                dtype=torch.float32, device=feat_mask.device)
            gt_masks = gt_masks.reshape(-1, gt_masks.shape[-2],
                                        gt_masks.shape[-1])

            # drop tiny/degenerate boxes (area <= 1 px at half resolution)
            area = (bbox_dt[:, 2] - bbox_dt[:, 0]) * (bbox_dt[:, 3] - bbox_dt[:, 1])
            bbox_dt = bbox_dt[area > 1.0, :]
            idx_gt = idx_gt[area > 1.0]
            cof_pred = cof_pred[area > 1.0]
            if bbox_dt.shape[0] == 0:
                continue

            bbox_gt = targets[i].bbox
            cls_score = flatten_cls_scores1[i, pos_inds, labels[pos_inds] - 1]
            cls_score = cls_score.sigmoid().detach()
            cls_score = cls_score[area > 1.0]
            ious = bbox_overlaps(bbox_gt[idx_gt] / 2, bbox_dt, is_aligned=True)
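            # weight each positive by detached class score x IoU with its GT,
            # normalized so the weights average to 1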
            weighting = cls_score * ious
            weighting = weighting / torch.sum(weighting) * len(weighting)
            # thin near-duplicate positives with a loose NMS (IoU 0.9)
            keep = _box_nms(bbox_dt, cls_score, 0.9)
            bbox_dt = bbox_dt[keep]
            weighting = weighting[keep]
            idx_gt = idx_gt[keep]
            cof_pred = cof_pred[keep]

            gt_mask = F.interpolate(gt_masks.unsqueeze(0),
                                    scale_factor=0.5,
                                    mode='bilinear',
                                    align_corners=False).squeeze(0)

            shape = np.minimum(feat_mask[i].shape, gt_mask.shape)
            gt_mask_new = gt_mask.new_zeros(gt_mask.shape[0], mask_h, mask_w)
            gt_mask_new[:, :shape[1], :shape[2]] = gt_mask[:, :shape[1], :shape[2]]
            gt_mask_new = gt_mask_new.gt(0.5).float()

            gt_mask_new = torch.index_select(gt_mask_new, 0, idx_gt)
            gt_mask_new = gt_mask_new.permute(1, 2, 0).contiguous()

            # ------ spp: build masks from the four 32-channel coefficient sets ------
            img_mask1 = img_mask.permute(1, 2, 0)
            pos_masks00 = torch.sigmoid(img_mask1 @ cof_pred[:, 0:32].t())
            pos_masks01 = torch.sigmoid(img_mask1 @ cof_pred[:, 32:64].t())
            pos_masks10 = torch.sigmoid(img_mask1 @ cof_pred[:, 64:96].t())
            pos_masks11 = torch.sigmoid(img_mask1 @ cof_pred[:, 96:128].t())
            pred_masks = torch.stack(
                [pos_masks00, pos_masks01, pos_masks10, pos_masks11], dim=0)
            pred_masks = self.crop_cuda(pred_masks, bbox_dt)
            gt_mask_crop = self.crop_gt_cuda(gt_mask_new, bbox_dt)

            pre_loss = F.binary_cross_entropy(pred_masks,
                                              gt_mask_crop,
                                              reduction='none')

            # normalize the per-pixel loss by box area and by the number of boxes
            pos_get_csize = center_size(bbox_dt)
            box_width = pos_get_csize[:, 2]
            box_height = pos_get_csize[:, 3]
            pre_loss = pre_loss.sum(dim=(0, 1))
            pre_loss = pre_loss / box_width / box_height / pos_get_csize.shape[0]
            loss_mask += torch.sum(pre_loss * weighting.detach())

        loss_mask = loss_mask / num_imgs
        # heuristic damping: halve the mask loss when it grows large
        if loss_mask > 1.0:
            loss_mask = loss_mask * 0.5

        return cls_loss, reg_loss, centerness_loss, loss_mask
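
compute_centerness_targets is not shown above; in FCOS-style losses it is typically the geometric mean of the left/right and top/bottom regression ratios, so locations centered in their box score near 1. A sketch under that assumption (reg_targets holds per-location (l, t, r, b) distances):

import torch

def compute_centerness_targets(reg_targets: torch.Tensor) -> torch.Tensor:
    # ratio of the smaller to the larger horizontal / vertical distance
    left_right = reg_targets[:, [0, 2]]
    top_bottom = reg_targets[:, [1, 3]]
    centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
                 (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
    return torch.sqrt(centerness)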