Example #1
import random

import torch
from torchvision.ops import box_iou


def random_crop(image,
                boxes,
                labels,
                difficulties,
                choices=(0., .1, .3, .5, .7, .9, None)):
    """
    Performs a random crop in the manner stated in the paper. Helps to learn to detect larger and partial objects.

    Note that some objects may be cut out entirely.

    Adapted from https://github.com/amdegroot/ssd.pytorch/blob/master/utils/augmentations.py

    :param image: image, a tensor of dimensions (3, original_h, original_w)
    :param boxes: bounding boxes in boundary coordinates, a tensor of dimensions (n_objects, 4)
    :param labels: labels of objects, a tensor of dimensions (n_objects)
    :param difficulties: difficulties of detection of these objects, a tensor of dimensions (n_objects)
    :param choices: minimum-overlap values to draw from; None means no cropping
    :return: cropped image, updated bounding box coordinates, updated labels, updated difficulties
    """
    original_h = image.size(1)  # image is (3, original_h, original_w)
    original_w = image.size(2)
    # Keep choosing a minimum overlap until a successful crop is made
    while True:
        # Randomly draw the value for minimum overlap
        min_overlap = random.choice(choices)  # 'None' refers to no cropping

        # If not cropping
        if min_overlap is None:
            return image, boxes, labels, difficulties

        # Try up to 50 times for this choice of minimum overlap
        # This isn't mentioned in the paper, but 50 is the value used in the paper authors' original Caffe repo
        max_trials = 50
        for _ in range(max_trials):
            # Crop dimensions must be in [0.3, 1] of original dimensions
            # Note - it's [0.1, 1] in the paper, but actually [0.3, 1] in the authors' repo
            min_scale = 0.3
            scale_h = random.uniform(min_scale, 1)
            scale_w = random.uniform(min_scale, 1)
            new_h = int(scale_h * original_h)
            new_w = int(scale_w * original_w)

            # Aspect ratio has to be in [0.5, 2]
            aspect_ratio = new_h / new_w
            if not 0.5 < aspect_ratio < 2:
                continue

            # Crop coordinates (origin at top-left of image)
            left = random.randint(0, original_w - new_w)
            right = left + new_w
            top = random.randint(0, original_h - new_h)
            bottom = top + new_h
            crop = torch.FloatTensor([left, top, right, bottom])  # (4)

            # Calculate Jaccard overlap between the crop and the bounding boxes
            overlap = box_iou(
                crop.unsqueeze(0), boxes
            )  # (1, n_objects), n_objects is the no. of objects in this image
            overlap = overlap.squeeze(0)  # (n_objects)

            # If not a single bounding box has a Jaccard overlap greater than the minimum, try again
            if overlap.max().item() < min_overlap:
                continue

            # Crop image
            new_image = image[:, top:bottom, left:right]  # (3, new_h, new_w)

            # Find centers of original bounding boxes
            bb_centers = (boxes[:, :2] + boxes[:, 2:]) / 2.  # (n_objects, 2)

            # Find bounding boxes whose centers are in the crop
            centers_in_crop = (bb_centers[:, 0] > left) * (
                bb_centers[:, 0] < right
            ) * (bb_centers[:, 1] > top) * (
                bb_centers[:, 1] < bottom
            )  # (n_objects), a torch.bool tensor, can be used as a boolean index

            # If not a single bounding box has its center in the crop, try again
            if not centers_in_crop.any():
                continue

            # Discard bounding boxes that don't meet this criterion
            new_boxes = boxes[centers_in_crop, :]
            new_labels = labels[centers_in_crop]
            new_difficulties = difficulties[centers_in_crop]

            # Calculate bounding boxes' new coordinates in the crop
            new_boxes[:, :2] = torch.max(new_boxes[:, :2],
                                         crop[:2])  # crop[:2] is [left, top]
            new_boxes[:, :2] -= crop[:2]
            new_boxes[:, 2:] = torch.min(new_boxes[:, 2:],
                                         crop[2:])  # crop[2:] is [right, bottom]
            new_boxes[:, 2:] -= crop[:2]

            return new_image, new_boxes, new_labels, new_difficulties
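
A quick smoke test of the function above (a sketch; the tensors are made up, and box_iou comes from torchvision.ops as in the imports above):

# Hypothetical inputs: a (3, H, W) image tensor and two boxes in
# (xmin, ymin, xmax, ymax) pixel coordinates.
image = torch.rand(3, 300, 400)
boxes = torch.FloatTensor([[50., 60., 150., 200.], [200., 100., 350., 250.]])
labels = torch.LongTensor([1, 2])
difficulties = torch.ByteTensor([0, 0])

new_image, new_boxes, new_labels, new_difficulties = random_crop(
    image, boxes, labels, difficulties)
print(new_image.shape, new_boxes.shape)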
Example #2
def iou_check(box, expected, tolerance=1e-4):
    out = ops.box_iou(box, box)
    torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
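
For reference, a hand-checkable call (a sketch; assumes import torch and from torchvision import ops). Two 2x2 squares offset by one pixel intersect in a 1x1 square, so their IoU is 1 / (4 + 4 - 1) = 1/7:

box = torch.tensor([[0., 0., 2., 2.], [1., 1., 3., 3.]])
expected = torch.tensor([[1.0, 1 / 7], [1 / 7, 1.0]])
iou_check(box, expected)  # passes: the diagonal is 1, the off-diagonal 1/7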
Example #3
def forward(self, inputs: Tensor, target: Tensor) -> Tensor:
    return 1.0 - box_iou(inputs, target).diagonal()
Example #4
import itertools
from collections import defaultdict

import numpy as np


def calc_detection_voc_prec_rec(pred_bboxes,
                                pred_labels,
                                pred_scores,
                                gt_bboxes,
                                gt_labels,
                                gt_difficulties=None,
                                iou_thresh=0.5):

    pred_bboxes = iter(pred_bboxes)
    pred_labels = iter(pred_labels)
    pred_scores = iter(pred_scores)
    gt_bboxes = iter(gt_bboxes)
    gt_labels = iter(gt_labels)
    if gt_difficulties is None:
        gt_difficulties = itertools.repeat(None)
    else:
        gt_difficulties = iter(gt_difficulties)

    n_pos = defaultdict(int)
    score = defaultdict(list)
    match = defaultdict(list)

    for pred_bbox, pred_label, pred_score, gt_bbox, gt_label, gt_difficult in zip(
            pred_bboxes, pred_labels, pred_scores, gt_bboxes, gt_labels,
            gt_difficulties):
        if gt_difficult is None:
            gt_difficult = np.zeros(gt_bbox.shape[0], dtype=bool)

        for l in np.unique(np.concatenate((pred_label, gt_label)).astype(int)):
            pred_mask_l = pred_label == l
            pred_bbox_l = pred_bbox[pred_mask_l]
            pred_score_l = pred_score[pred_mask_l]
            # sort by score
            order = pred_score_l.argsort()[::-1]
            pred_bbox_l = pred_bbox_l[order]
            pred_score_l = pred_score_l[order]

            gt_mask_l = gt_label == l
            gt_bbox_l = gt_bbox[gt_mask_l]
            gt_difficult_l = gt_difficult[gt_mask_l]

            n_pos[l] += np.logical_not(gt_difficult_l).sum()
            score[l].extend(pred_score_l)

            if len(pred_bbox_l) == 0:
                continue
            if len(gt_bbox_l) == 0:
                match[l].extend((0, ) * pred_bbox_l.shape[0])
                continue

            # VOC evaluation follows integer typed bounding boxes.
            pred_bbox_l = pred_bbox_l.copy()
            pred_bbox_l[:, 2:] += 1
            gt_bbox_l = gt_bbox_l.copy()
            gt_bbox_l[:, 2:] += 1

            iou = box_iou(pred_bbox_l, gt_bbox_l)
            gt_index = iou.argmax(axis=1)
            # set -1 if there is no matching ground truth
            gt_index[iou.max(axis=1) < iou_thresh] = -1
            del iou

            selec = np.zeros(gt_bbox_l.shape[0], dtype=bool)
            for gt_idx in gt_index:
                if gt_idx >= 0:
                    if gt_difficult_l[gt_idx]:
                        match[l].append(-1)
                    else:
                        if not selec[gt_idx]:
                            match[l].append(1)
                        else:
                            match[l].append(0)
                    selec[gt_idx] = True
                else:
                    match[l].append(0)

    for iter_ in (pred_bboxes, pred_labels, pred_scores, gt_bboxes, gt_labels,
                  gt_difficulties):
        if next(iter_, None) is not None:
            raise ValueError('Lengths of input iterables must be the same.')

    n_fg_class = max(n_pos.keys()) + 1
    prec = [None] * n_fg_class
    rec = [None] * n_fg_class

    for l in n_pos.keys():
        score_l = np.array(score[l])
        match_l = np.array(match[l], dtype=np.int8)

        order = score_l.argsort()[::-1]
        match_l = match_l[order]

        tp = np.cumsum(match_l == 1)
        fp = np.cumsum(match_l == 0)

        # If an element of fp + tp is 0,
        # the corresponding element of prec[l] is nan.
        prec[l] = tp / (fp + tp)
        # If n_pos[l] is 0, rec[l] is None.
        if n_pos[l] > 0:
            rec[l] = tp / n_pos[l]

    return prec, rec
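
A toy one-image call (a sketch). Note that box_iou here must be a NumPy pairwise-IoU helper, since its result is indexed with axis=; the torchvision tensor version would not work unmodified. A minimal such helper and two predictions against one ground truth:

import numpy as np

def box_iou(a, b):  # NumPy stand-in for the helper assumed by the code above
    tl = np.maximum(a[:, None, :2], b[None, :, :2])
    br = np.minimum(a[:, None, 2:], b[None, :, 2:])
    inter = np.prod(np.clip(br - tl, 0, None), axis=2)
    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
    area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
    return inter / (area_a[:, None] + area_b[None, :] - inter)

prec, rec = calc_detection_voc_prec_rec(
    pred_bboxes=[np.array([[0., 0., 10., 10.], [20., 20., 30., 30.]])],
    pred_labels=[np.array([0, 0])],
    pred_scores=[np.array([0.9, 0.8])],
    gt_bboxes=[np.array([[0., 0., 10., 10.]])],
    gt_labels=[np.array([0])])
print(prec[0], rec[0])  # [1.  0.5] [1. 1.] -- one TP, then one FP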
Example #5
def iou_check(box, expected, tolerance=1e-4):
    out = ops.box_iou(box, box)
    assert out.size() == expected.size()
    assert ((out - expected).abs().max() < tolerance).item()
Example #6
    def forward(self, features, labels=None, gt_bboxes=None):
        """

        :param features: OrderedDict. The shape of each item is (BS, C_i, H_i, W_i)
        :param labels: shape (BS, n_objs)
        :param gt_bboxes: shape (BS, n_objs, 4)
        :return:
        """

        if self.training:
            pre_nms_top_n = self.pre_nms_top_n_in_train
            post_nms_top_n = self.post_nms_top_n_in_train
        else:
            pre_nms_top_n = self.pre_nms_top_n_in_test
            post_nms_top_n = self.post_nms_top_n_in_test

        total_anchors = []
        total_cls_pred = []
        total_reg_pred = []
        total_cls_scores = []
        total_reg_bboxes = []
        for i, feat in enumerate(features.values()):
            x = F.relu(self.conv(feat))
            cls_pred = self.cls(x)  # (BS, num_anchors, H, W)
            reg_pred = self.reg(x)  # (BS, num_anchors*4, H, W)

            BS, num_anchors, H, W = cls_pred.shape
            # (BS, H, W, num_anchors)
            cls_pred = cls_pred.permute(0, 2, 3, 1)
            # (BS, H, W, num_anchors, 4)
            reg_pred = reg_pred.permute(0, 2, 3, 1).reshape(
                (BS, H, W, num_anchors, 4))
            # (H, W, num_anchors, 4)
            anchors = self._buffers["anchor%i" % i]

            # (BS, H, W, num_anchors) -> (BS, H*W*num_anchors)
            cls_pred = cls_pred.reshape((BS, -1))
            # (BS, H, W, num_anchors, 4) -> (BS, H*W*num_anchors, 4)
            reg_pred = reg_pred.reshape((BS, -1, 4))
            # (H, W, num_anchors, 4) -> (H*W*num_anchors, 4)
            anchors = anchors.reshape((-1, 4))

            total_anchors.append(anchors)
            total_cls_pred.append(cls_pred)
            total_reg_pred.append(reg_pred)

            with torch.no_grad():
                # refine the anchors with the predicted regression deltas
                reg_bboxes = self.box_coder.decode(anchors, reg_pred.detach())
                reg_bboxes[..., 0].clamp_(0, self.image_size[0])
                reg_bboxes[..., 1].clamp_(0, self.image_size[1])
                reg_bboxes[..., 2].clamp_(0, self.image_size[0])
                reg_bboxes[..., 3].clamp_(0, self.image_size[1])
                # compute objectness scores
                cls_scores = torch.sigmoid(cls_pred.detach())

                if not self.nms_per_layer:
                    total_cls_scores.append(cls_scores)
                    total_reg_bboxes.append(reg_bboxes)
                else:
                    # NMS per layer
                    BS = cls_scores.shape[0]
                    keep_bboxes = []
                    keep_scores = []
                    # use a fresh loop variable; reusing `i` would shadow the
                    # feature-level index of the enclosing loop
                    for b in range(BS):
                        dtype = reg_bboxes.dtype
                        device = reg_bboxes.device
                        _bboxes = torch.full(
                            (post_nms_top_n // len(features), 4),
                            -1,
                            dtype=dtype,
                            device=device)
                        _scores = torch.full(
                            (post_nms_top_n // len(features), ),
                            -1,
                            dtype=cls_scores.dtype,
                            device=cls_scores.device)
                        pre_nms_top_n_indices = torch.argsort(cls_scores[b],
                                                              descending=True)
                        _num_anchors = pre_nms_top_n_indices.shape[0]
                        _pre_nms_top_n = min(pre_nms_top_n // len(features),
                                             _num_anchors)
                        pre_nms_top_n_indices = pre_nms_top_n_indices[:_pre_nms_top_n]
                        _reg_bboxes = reg_bboxes[b][pre_nms_top_n_indices]
                        _cls_scores = cls_scores[b][pre_nms_top_n_indices]
                        keep = ops.nms(_reg_bboxes, _cls_scores,
                                       self.nms_thresh)
                        n_keep = min(keep.shape[0],
                                     post_nms_top_n // len(features))
                        keep = keep[:n_keep]
                        _bboxes[:n_keep] = _reg_bboxes[keep]
                        _scores[:n_keep] = _cls_scores[keep]
                        keep_bboxes.append(_bboxes)
                        keep_scores.append(_scores)

                    total_reg_bboxes.append(torch.stack(keep_bboxes))
                    total_cls_scores.append(torch.stack(keep_scores))

        # (-1, 4)
        anchors = torch.cat(total_anchors, dim=0)
        # (BS, -1)
        cls_pred = torch.cat(total_cls_pred, dim=1)
        # (BS, -1, 4)
        reg_pred = torch.cat(total_reg_pred, dim=1)

        if not self.nms_per_layer:
            # (BS, -1)
            cls_scores = torch.cat(total_cls_scores, dim=1)
            # (BS, -1, 4)
            reg_bboxes = torch.cat(total_reg_bboxes, dim=1)

            # NMS
            BS = cls_pred.shape[0]
            keep_bboxes = []
            for i in range(BS):
                dtype = reg_bboxes.dtype
                device = reg_bboxes.device
                _bboxes = torch.full((post_nms_top_n, 4),
                                     -1,
                                     dtype=dtype,
                                     device=device)
                pre_nms_top_n_indices = torch.argsort(cls_scores[i],
                                                      descending=True)
                _num_anchors = pre_nms_top_n_indices.shape[0]
                _pre_nms_top_n = min(pre_nms_top_n, _num_anchors)
                pre_nms_top_n_indices = pre_nms_top_n_indices[:_pre_nms_top_n]
                _reg_bboxes = reg_bboxes[i][pre_nms_top_n_indices]
                keep = ops.nms(_reg_bboxes,
                               cls_scores[i][pre_nms_top_n_indices],
                               self.nms_thresh)
                n_keep = keep.shape[0]
                n_keep = min(n_keep, post_nms_top_n)
                keep = keep[:n_keep]
                _bboxes[:n_keep] = _reg_bboxes[keep]
                keep_bboxes.append(_bboxes)

            bboxes = torch.stack(keep_bboxes)  # (BS, post_nms_top_n, 4)
        else:
            # (BS, post_nms_top_n)
            cls_scores = torch.cat(total_cls_scores, dim=1)
            # (BS, post_nms_top_n, 4)
            bboxes = torch.cat(total_reg_bboxes, dim=1)
            # Sort the bboxes (rois) by score in descending order.
            # Likely reasons this matters:
            # - the order of the rois affects the rcnn, e.g. the rcnn's NMS
            #   tends to keep earlier rois
            # - rois with high rpn scores tend to get high rcnn scores as well
            # - when rois have similar scores, earlier rois suppress later
            #   ones during NMS
            for i in range(cls_scores.shape[0]):
                sorted_indices = torch.argsort(cls_scores[i], descending=True)
                bboxes[i] = bboxes[i][sorted_indices]

        if self.training:
            total_cls_pred = []
            total_reg_pred = []
            total_reg_target = []
            total_fg_bg_mask = []

            all_cls_pred = []
            all_fg_bg_mask = []

            BS = gt_bboxes.shape[0]
            for i in range(BS):
                # assign a label to each anchor
                areas = ops.boxes.box_area(anchors)
                ious = ops.box_iou(
                    anchors, gt_bboxes[i]
                )  # (num_total_anchors, num_gt_bboxes) (N, M) for short
                # anchors with zero area yield NaN IoU; replace NaNs with 0
                zero_mask = (areas == 0).reshape(-1, 1).expand_as(ious)
                ious[zero_mask] = 0

                if torch.any(torch.isnan(ious)):
                    raise Exception("some elements in ious are nan")

                # the anchor/anchors with the highest Intersection-over-Union (IoU)
                # overlap with a ground-truth box
                iou_max_gt, indices = torch.max(ious, dim=0)
                # ignore the padded entries in gt_bboxes
                iou_max_gt = torch.where(labels[i] == -1,
                                         torch.ones_like(iou_max_gt),
                                         iou_max_gt)
                highest_mask = (ious == iou_max_gt)
                fg_mask = torch.any(highest_mask, dim=1)
                # an anchor that has an IoU overlap higher than fg_iou_thresh with any ground-truth box
                iou_max, matched_idx = torch.max(ious, dim=1)
                # 1 for foreground -1 for background 0 for ignore
                fg_bg_mask = torch.zeros_like(iou_max)
                # confirm positive samples
                fg_bg_mask = torch.where(iou_max >= self.fg_iou_thresh,
                                         torch.ones_like(iou_max), fg_bg_mask)
                fg_bg_mask = torch.where(fg_mask, torch.ones_like(iou_max),
                                         fg_bg_mask)
                # confirm negative samples
                fg_bg_mask = torch.where(iou_max <= self.bg_iou_thresh,
                                         torch.full_like(iou_max, -1),
                                         fg_bg_mask)

                all_cls_pred.append(cls_pred[i].detach())
                all_fg_bg_mask.append(fg_bg_mask.detach())

                # randomly sample positives and negatives
                indices = torch.arange(fg_bg_mask.shape[0],
                                       dtype=torch.int64,
                                       device=fg_bg_mask.device)
                rand_indices = torch.rand_like(fg_bg_mask).argsort()
                fg_bg_mask = fg_bg_mask[rand_indices]  # shuffle to make the sampling random
                indices = indices[rand_indices]

                sorted_indices = fg_bg_mask.argsort(descending=True)
                fg_bg_mask = fg_bg_mask[sorted_indices]
                indices = indices[sorted_indices]
                fg_indices = indices[:self.num_pos]
                fg_mask = fg_bg_mask[:self.num_pos]
                bg_indices = indices[-self.num_neg:]
                bg_mask = fg_bg_mask[-self.num_neg:]

                indices = torch.cat([fg_indices, bg_indices], dim=0)
                fg_bg_mask = torch.cat([fg_mask, bg_mask], dim=0)

                matched_idx = matched_idx[indices]
                _anchors = anchors[indices]

                total_cls_pred.append(cls_pred[i][indices])
                total_reg_pred.append(reg_pred[i][indices])
                total_fg_bg_mask.append(fg_bg_mask)
                total_reg_target.append(
                    self.box_coder.encode(_anchors, gt_bboxes[i][matched_idx]))

            # (BS, num_samples)
            cls_pred = torch.stack(total_cls_pred)
            # (BS, num_samples, 4)
            reg_pred = torch.stack(total_reg_pred)
            # (BS, num_samples)
            fg_bg_mask = torch.stack(total_fg_bg_mask)
            # (BS, num_samples, 4)
            reg_target = torch.stack(total_reg_target)

            cls_label = torch.where(fg_bg_mask == 1, torch.ones_like(cls_pred),
                                    torch.zeros_like(cls_pred))
            cls_loss = F.binary_cross_entropy_with_logits(
                cls_pred[fg_bg_mask != 0], cls_label[fg_bg_mask != 0])
            if torch.any(torch.isnan(reg_target[fg_bg_mask == 1])):
                raise Exception("some elements in reg_target are nan")
            if torch.any(torch.isnan(reg_pred[fg_bg_mask == 1])):
                raise Exception("some elements in reg_pred are nan")
            if torch.any(fg_bg_mask == 1):
                reg_loss = F.smooth_l1_loss(reg_pred[fg_bg_mask == 1],
                                            reg_target[fg_bg_mask == 1])
            else:  # no positive samples
                reg_loss = torch.zeros_like(cls_loss)

            cls_pred = cls_pred >= 0  # logits >= 0, i.e. sigmoid >= 0.5 (matches all_cls_pred below)
            cls_label = cls_label == 1
            acc = torch.mean(
                (cls_label == cls_pred)[fg_bg_mask != 0].to(torch.float))
            num_pos = (fg_bg_mask == 1).sum()
            num_neg = (fg_bg_mask == -1).sum()

            TP = (cls_pred == True)[fg_bg_mask == 1].sum().to(torch.float32)
            FP = (cls_pred == True)[fg_bg_mask == -1].sum().to(torch.float32)
            # TN = (cls_pred == False)[fg_bg_mask == -1].sum()
            FN = (cls_pred == False)[fg_bg_mask == 1].sum().to(torch.float32)

            precision = TP / (TP + FP)
            recall = TP / (TP + FN)

            all_cls_pred = torch.stack(all_cls_pred)
            all_fg_bg_mask = torch.stack(all_fg_bg_mask)
            all_cls_pred = all_cls_pred >= 0
            all_TP = (all_cls_pred == True)[all_fg_bg_mask == 1].sum().to(
                torch.float32)
            all_FP = (all_cls_pred == True)[all_fg_bg_mask == -1].sum().to(
                torch.float32)
            all_FN = (all_cls_pred == False)[all_fg_bg_mask == 1].sum().to(
                torch.float32)
            all_precision = all_TP / (all_TP + all_FP)
            all_recall = all_TP / (all_TP + all_FN)

            if self.logger is not None:
                # print("TP {} FP {} FN {}".format(TP.detach().cpu().item(), FP.detach().cpu().item(), FN.detach().cpu().item()))
                # print("all_TP {} all_FP {} all_FN {}".format(all_TP.detach().cpu().item(), all_FP.detach().cpu().item(), all_FN.detach().cpu().item()))
                # print("precision {} recall {} all_precision {} all_recall {}".format(precision.detach().cpu().item(),
                #                                                                      recall.detach().cpu().item(),
                #                                                                      all_precision.detach().cpu().item(),
                #                                                                      all_recall.detach().cpu().item()))
                self.logger.add_scalar("rpn/TP", TP.detach().cpu().item())
                self.logger.add_scalar("rpn/FP", FP.detach().cpu().item())
                self.logger.add_scalar("rpn/FN", FN.detach().cpu().item())
                self.logger.add_scalar("rpn/all_TP",
                                       all_TP.detach().cpu().item())
                self.logger.add_scalar("rpn/all_FP",
                                       all_FP.detach().cpu().item())
                self.logger.add_scalar("rpn/all_FN",
                                       all_FN.detach().cpu().item())
                self.logger.add_scalar("rpn/acc", acc.detach().cpu().item())
                self.logger.add_scalar("rpn/num_pos",
                                       num_pos.detach().cpu().item())
                self.logger.add_scalar("rpn/num_neg",
                                       num_neg.detach().cpu().item())
                self.logger.add_scalar("rpn/precision",
                                       precision.detach().cpu().item())
                self.logger.add_scalar("rpn/recall",
                                       recall.detach().cpu().item())
                self.logger.add_scalar("rpn/all_precision",
                                       all_precision.detach().cpu().item())
                self.logger.add_scalar("rpn/all_recall",
                                       all_recall.detach().cpu().item())

            return bboxes, cls_loss, reg_loss

        return bboxes, None, None
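
The shuffle-then-argsort sampling used in the training branch above, isolated as a sketch (hypothetical helper; fg_bg_mask uses 1 = positive, -1 = negative, 0 = ignore). Shuffling first randomizes tie order; the descending argsort then puts positives at the front and negatives at the back, so the head and tail of the index list give the two samples. If there are fewer than num_pos positives, the head also contains ignores or negatives, exactly as in the code above:

import torch

def sample_pos_neg(fg_bg_mask, num_pos, num_neg):
    indices = torch.arange(fg_bg_mask.shape[0], device=fg_bg_mask.device)
    perm = torch.randperm(fg_bg_mask.shape[0], device=fg_bg_mask.device)
    mask, indices = fg_bg_mask[perm], indices[perm]  # shuffle
    order = mask.argsort(descending=True)            # 1s first, -1s last
    indices = indices[order]
    return indices[:num_pos], indices[-num_neg:]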
Example #7
def _evaluate_iou(self, preds, targets):
    """Evaluate intersection over union (IOU) for target from dataset and output prediction from model."""
    # no box detected, 0 IOU
    if preds["boxes"].shape[0] == 0:
        return torch.tensor(0.0, device=preds["boxes"].device)
    return box_iou(preds["boxes"], targets["boxes"]).diag().mean()
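
Input format sketch (the dicts follow the torchvision detection convention of a "boxes" tensor per image). Note the .diag(): predictions and targets are paired positionally rather than matched by best IoU, so this is a rough proxy, not a matched IoU:

import torch
from torchvision.ops import box_iou

preds = {"boxes": torch.tensor([[0., 0., 10., 10.]])}
targets = {"boxes": torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])}
# the (1, 2) IoU matrix collapses to its diagonal: only the (0, 0) pair survives
print(box_iou(preds["boxes"], targets["boxes"]).diag().mean())  # tensor(1.)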
Example #8
    def loss(self,
             outputs: tuple,
             gt_bboxes: list,
             gt_labels: list,
             iou_thresh: float = 0.5) -> dict:
        """ 損失関数

        Args:
            outputs (tuple): (予測オフセット, 予測存在率,  予測信頼度)
                            * 予測オフセット : (B, P, 4) (coord fmt: [Δx, Δy, Δw, Δh]) (P: PBoxの数. P = 10647 の想定.)
                            * 予測存在率     : (B, P)
                            * 予測信頼度     : (B, P, num_classes)
            gt_bboxes (list): 正解BBOX座標 [(G1, 4), (G2, 4), ...] (coord fmt: [x, y, w, h])
            gt_labels (list): 正解ラベル [(G1,), (G2,)]
            iou_thresh (float): Potitive / Negative を判定する際の iou の閾値

        Returns:
            dict: {
                loss: xxx,
                loss_loc: xxx,
                loss_obj: xxx,
                loss_conf: xxx
            }
        """
        out_locs, out_objs, out_confs = outputs
        device = out_locs.device

        # [Step 1]
        #   Build the targets
        #   - Match predictions to GTs:
        #     - the prior box whose grid cell contains (x, y) and that has the
        #       highest IoU with a BBox -> assign it to that BBox
        #   - if the max IoU is >= 0.5 but the prior is not matched to a GT,
        #     set the label to -1 (to be ignored)
        #   - if the max IoU is < 0.5, set the label to 0 (background)

        B, P, C = out_confs.size()
        target_locs = torch.zeros(B, P, 4, device=device)
        target_labels = torch.zeros(B, P, dtype=torch.long, device=device)

        pboxes, grid_length = self.pboxes.to(device).split(4, dim=1)
        for i in range(B):
            bboxes = gt_bboxes[i].to(device)
            labels = gt_labels[i].to(device)

            is_in_grid = (pboxes[:, [0]] <= bboxes[:, 0]) * (bboxes[:, 0] < pboxes[:, [0]] + grid_length) * \
                (pboxes[:, [1]] <= bboxes[:, 1]) * (bboxes[:, 1] < pboxes[:, [1]] + grid_length)
            bboxes_xyxy = box_convert(bboxes, in_fmt='xywh', out_fmt='xyxy')
            pboxes_xyxy = box_convert(pboxes, in_fmt='xywh', out_fmt='xyxy')
            ious = box_iou(pboxes_xyxy, bboxes_xyxy)
            best_ious, best_pbox_ids = (ious * is_in_grid).max(dim=0)
            max_ious, matched_bbox_ids = ious.max(dim=1)

            # For each BBox, pick the prior box with the highest IoU -> assign it to that BBox
            # (one-step indexing; the original chained indexing wrote to a copy)
            matched_bbox_ids[best_pbox_ids] = torch.arange(
                best_pbox_ids.shape[0], device=device)
            max_ious[best_pbox_ids] = 1.

            bboxes = bboxes[matched_bbox_ids]
            locs = self._calc_delta(bboxes, pboxes, grid_length)
            labels = labels[matched_bbox_ids]
            labels[max_ious < 1.] = -1  # void class (ignored)
            labels[max_ious.less(iou_thresh)] = 0  # 0 is the background class; positive classes start at 1

            target_locs[i] = locs
            target_labels[i] = labels

        # [Step 2]
        #   Build pos_mask and neg_mask
        #   - pos_mask: labels > 0
        #   - neg_mask: labels == 0
        pos_mask = target_labels > 0
        neg_mask = target_labels == 0

        N = pos_mask.sum()
        # [Step 3]
        #   Compute the localization loss on positives only
        loss_loc = (F.binary_cross_entropy_with_logits(
            out_locs[pos_mask][..., :2],
            target_locs[pos_mask][..., :2],
            reduction='sum') + F.mse_loss(out_locs[pos_mask][..., 2:],
                                          target_locs[pos_mask][..., 2:],
                                          reduction='sum')) / N

        # [Step 4]
        #   Compute the confidence loss on positives only
        loss_conf = F.binary_cross_entropy_with_logits(
            out_confs[pos_mask],
            F.one_hot(target_labels[pos_mask] - 1,
                      num_classes=self.nc).float(),
            reduction='sum') / N

        # [Step 5]
        #   Compute the objectness loss on positives & negatives
        loss_obj = F.binary_cross_entropy_with_logits(
            out_objs[pos_mask | neg_mask],
            pos_mask[pos_mask | neg_mask].float(),
            reduction='sum') / N

        # [Step 6]
        #   Sum the losses
        loss = loss_loc + loss_obj + loss_conf

        return {
            'loss': loss,
            'loss_loc': loss_loc,
            'loss_conf': loss_conf,
            'loss_obj': loss_obj
        }
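
The Step 1 assignment rule, extracted into a standalone sketch (a hypothetical function; the real method additionally applies the grid-containment test and encodes deltas). Priors keep a GT's label only when force-matched as that GT's best prior; priors whose best IoU exceeds the threshold without being force-matched are ignored (-1), and the rest become background (0):

import torch
from torchvision.ops import box_iou

def assign_targets(pboxes_xyxy, gt_xyxy, gt_labels, iou_thresh=0.5):
    ious = box_iou(pboxes_xyxy, gt_xyxy)                 # (P, G)
    max_ious, matched = ious.max(dim=1)                  # best GT per prior
    best_pbox = ious.argmax(dim=0)                       # best prior per GT
    matched[best_pbox] = torch.arange(gt_xyxy.shape[0], device=gt_xyxy.device)
    max_ious[best_pbox] = 1.                             # force-keep best priors
    labels = gt_labels[matched].clone()                  # positive classes are 1+
    labels[max_ious < 1.] = -1                           # ignored
    labels[max_ious < iou_thresh] = 0                    # background
    return labels, matched

pboxes = torch.tensor([[0., 0., 10., 10.], [8., 8., 20., 20.], [30., 30., 40., 40.]])
print(assign_targets(pboxes, torch.tensor([[0., 0., 10., 10.]]), torch.tensor([3])))
# (tensor([3, 0, 0]), tensor([0, 0, 0]))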
Example #9
    def set_ignoring(self, noobj_mask, inference, targets, head_anchors,
                     head_size):
        """
        Args:
            head_anchors: anchors of this head
        """
        batch_size = len(targets)
        head_h, head_w = head_size
        # cx, cy, w, h
        x = (1 + self.EGS_factor) * torch.sigmoid(
            inference[..., 0]) - 0.5 * self.EGS_factor
        y = (1 + self.EGS_factor) * torch.sigmoid(
            inference[..., 1]) - 0.5 * self.EGS_factor
        w = inference[..., 2]
        h = inference[..., 3]

        # set device
        FloatTensor = torch.cuda.FloatTensor if self.device == 'cuda' else torch.FloatTensor
        # generate coordinate grids
        grid_x = torch.linspace(0, head_w - 1, head_w)
        grid_x = grid_x.repeat(head_h,
                               1).repeat(batch_size * self.n_head_anchors, 1,
                                         1)
        grid_x = grid_x.view(x.shape).type(FloatTensor)

        grid_y = torch.linspace(0, head_h - 1, head_h)
        grid_y = grid_y.repeat(head_w,
                               1).t().repeat(batch_size * self.n_head_anchors,
                                             1, 1)
        grid_y = grid_y.view(y.shape).type(FloatTensor)

        # generate anchors for coordinate grids
        anchor_w = FloatTensor(head_anchors)[:, 0].unsqueeze(1)
        anchor_w = anchor_w.repeat(batch_size,
                                   1).repeat(1, 1,
                                             head_h * head_w).view(w.shape)

        anchor_h = FloatTensor(head_anchors)[:, 1].unsqueeze(1)
        anchor_h = anchor_h.repeat(batch_size,
                                   1).repeat(1, 1,
                                             head_h * head_w).view(h.shape)

        # calculate bboxes
        infer_boxes = FloatTensor(inference[..., :4].shape)
        infer_boxes[..., 0] = x.data + grid_x
        infer_boxes[..., 1] = y.data + grid_y
        infer_boxes[..., 2] = torch.exp(w.data) * anchor_w
        infer_boxes[..., 3] = torch.exp(h.data) * anchor_h

        for bs, tar in enumerate(targets):
            if len(tar) == 0:
                continue
            # (num_anchors, 4)
            ignored_boxes = infer_boxes[bs].view(-1, 4)

            # groundtruth on this head size
            gt_x = tar[:, 0:1] * head_w
            gt_y = tar[:, 1:2] * head_h
            gt_w = tar[:, 2:3] * head_w
            gt_h = tar[:, 3:4] * head_h
            gt_box = FloatTensor(torch.cat([gt_x, gt_y, gt_w, gt_h], -1))

            # calculate IoU; ops.box_iou expects xyxy boxes, but these are
            # cxcywh, so convert first
            iou_matrix = ops.box_iou(
                ops.box_convert(gt_box, in_fmt='cxcywh', out_fmt='xyxy'),
                ops.box_convert(ignored_boxes, in_fmt='cxcywh', out_fmt='xyxy'))
            # get nearest groundtruth for each anchor
            max_iou, _ = torch.max(iou_matrix, dim=0)
            max_iou = max_iou.view(infer_boxes[bs].shape[:3])
            # turn off noobj_mask according to threshold
            noobj_mask[bs][max_iou > self.ignoring_threshold] = 0
        return noobj_mask, infer_boxes
Example #10
def calculate_mAP(det_boxes, det_labels, det_scores, true_boxes, true_labels,
                  true_difficulties):
    """
    Calculate the Mean Average Precision (mAP) of detected objects.

    See https://medium.com/@jonathan_hui/map-mean-average-precision-for-object-detection-45c121a31173 for an explanation

    :param det_boxes: list of tensors, one tensor for each image containing detected objects' bounding boxes
    :param det_labels: list of tensors, one tensor for each image containing detected objects' labels
    :param det_scores: list of tensors, one tensor for each image containing detected objects' labels' scores
    :param true_boxes: list of tensors, one tensor for each image containing actual objects' bounding boxes
    :param true_labels: list of tensors, one tensor for each image containing actual objects' labels
    :param true_difficulties: list of tensors, one tensor for each image containing actual objects' difficulty (0 or 1)
    :return: list of average precisions for all classes, mean average precision (mAP)
    """
    assert len(det_boxes) == len(
        det_labels
    ) == len(det_scores) == len(true_boxes) == len(true_labels) == len(
        true_difficulties
    )  # these are all lists of tensors of the same length, i.e. number of images
    n_classes = len(label_map)

    # Store all (true) objects in a single continuous tensor while keeping track of the image it is from
    true_images = list()
    for i in range(len(true_labels)):
        true_images.extend([i] * true_labels[i].size(0))
    device = det_boxes[0].device
    true_images = torch.LongTensor(true_images).to(
        device
    )  # (n_objects), n_objects is the total no. of objects across all images
    true_boxes = torch.cat(true_boxes, dim=0)  # (n_objects, 4)
    true_labels = torch.cat(true_labels, dim=0)  # (n_objects)
    true_difficulties = torch.cat(true_difficulties, dim=0)  # (n_objects)

    assert true_images.size(0) == true_boxes.size(0) == true_labels.size(0)

    # Store all detections in a single continuous tensor while keeping track of the image it is from
    det_images = list()
    for i in range(len(det_labels)):
        det_images.extend([i] * det_labels[i].size(0))
    det_images = torch.LongTensor(det_images).to(device)  # (n_detections)
    det_boxes = torch.cat(det_boxes, dim=0)  # (n_detections, 4)
    det_labels = torch.cat(det_labels, dim=0)  # (n_detections)
    det_scores = torch.cat(det_scores, dim=0)  # (n_detections)

    assert det_images.size(0) == det_boxes.size(0) == det_labels.size(
        0) == det_scores.size(0)

    # Calculate APs for each class (except background)
    average_precisions = torch.zeros((n_classes - 1),
                                     dtype=torch.float)  # (n_classes - 1)
    for c in range(1, n_classes):
        # Extract only objects with this class
        true_class_images = true_images[true_labels == c]  # (n_class_objects)
        true_class_boxes = true_boxes[true_labels == c]  # (n_class_objects, 4)
        true_class_difficulties = true_difficulties[true_labels ==
                                                    c]  # (n_class_objects)
        # ignore 'difficult' objects; compare with 0 rather than using ~,
        # which is bitwise NOT on uint8 tensors
        n_easy_class_objects = (true_class_difficulties == 0).sum().item()

        # Keep track of which true objects with this class have already been 'detected'
        # So far, none
        true_class_boxes_detected = torch.zeros(
            (true_class_difficulties.size(0)),
            dtype=torch.uint8).to(device)  # (n_class_objects)

        # Extract only detections with this class
        det_class_images = det_images[det_labels == c]  # (n_class_detections)
        det_class_boxes = det_boxes[det_labels == c]  # (n_class_detections, 4)
        det_class_scores = det_scores[det_labels == c]  # (n_class_detections)
        n_class_detections = det_class_boxes.size(0)
        if n_class_detections == 0:
            continue

        # Sort detections in decreasing order of confidence/scores
        det_class_scores, sort_ind = torch.sort(
            det_class_scores, dim=0, descending=True)  # (n_class_detections)
        det_class_images = det_class_images[sort_ind]  # (n_class_detections)
        det_class_boxes = det_class_boxes[sort_ind]  # (n_class_detections, 4)

        # In the order of decreasing scores, check if true or false positive
        true_positives = torch.zeros(
            (n_class_detections),
            dtype=torch.float).to(device)  # (n_class_detections)
        false_positives = torch.zeros(
            (n_class_detections),
            dtype=torch.float).to(device)  # (n_class_detections)
        for d in range(n_class_detections):
            this_detection_box = det_class_boxes[d].unsqueeze(0)  # (1, 4)
            this_image = det_class_images[d]  # (), scalar

            # Find objects in the same image with this class, their difficulties, and whether they have been detected before
            object_boxes = true_class_boxes[
                true_class_images == this_image]  # (n_class_objects_in_img)
            object_difficulties = true_class_difficulties[
                true_class_images == this_image]  # (n_class_objects_in_img)
            # If no such object in this image, then the detection is a false positive
            if object_boxes.size(0) == 0:
                false_positives[d] = 1
                continue

            # Find maximum overlap of this detection with objects in this image of this class
            overlaps = box_iou(this_detection_box,
                               object_boxes)  # (1, n_class_objects_in_img)
            max_overlap, ind = torch.max(overlaps.squeeze(0),
                                         dim=0)  # (), () - scalars

            # 'ind' is the index of the object in these image-level tensors 'object_boxes', 'object_difficulties'
            # In the original class-level tensors 'true_class_boxes', etc., 'ind' corresponds to object with index...
            original_ind = torch.LongTensor(range(true_class_boxes.size(0)))[
                true_class_images == this_image][ind]
            # We need 'original_ind' to update 'true_class_boxes_detected'

            # If the maximum overlap is greater than the threshold of 0.5, it's a match
            if max_overlap.item() > 0.5:
                # If the object it matched with is 'difficult', ignore it
                if not object_difficulties[ind]:
                    # If this object has not already been detected, it's a true positive
                    if true_class_boxes_detected[original_ind] == 0:
                        true_positives[d] = 1
                        true_class_boxes_detected[
                            original_ind] = 1  # this object has now been detected/accounted for
                    # Otherwise, it's a false positive (since this object is already accounted for)
                    else:
                        false_positives[d] = 1
            # Otherwise, the detection occurs in a different location than the actual object, and is a false positive
            else:
                false_positives[d] = 1

        # Compute cumulative precision and recall at each detection in the order of decreasing scores
        cumul_true_positives = torch.cumsum(true_positives,
                                            dim=0)  # (n_class_detections)
        cumul_false_positives = torch.cumsum(false_positives,
                                             dim=0)  # (n_class_detections)
        cumul_precision = cumul_true_positives / (
            cumul_true_positives + cumul_false_positives + 1e-10
        )  # (n_class_detections)
        cumul_recall = cumul_true_positives / n_easy_class_objects  # (n_class_detections)

        # Find the mean of the maximum of the precisions corresponding to recalls above the threshold 't'
        recall_thresholds = torch.arange(start=0, end=1.1,
                                         step=.1).tolist()  # (11)
        precisions = torch.zeros((len(recall_thresholds)),
                                 dtype=torch.float).to(device)  # (11)
        for i, t in enumerate(recall_thresholds):
            recalls_above_t = cumul_recall >= t
            if recalls_above_t.any():
                precisions[i] = cumul_precision[recalls_above_t].max()
            else:
                precisions[i] = 0.
        average_precisions[c - 1] = precisions.mean()  # c is in [1, n_classes - 1]

    # Calculate Mean Average Precision (mAP)
    mean_average_precision = average_precisions.mean().item()

    # Keep class-wise average precisions in a dictionary
    average_precisions = {
        rev_label_map[c + 1]: v
        for c, v in enumerate(average_precisions.tolist())
    }

    return average_precisions, mean_average_precision
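
A toy invocation (a sketch; label_map and rev_label_map are module-level globals in the source repo, stubbed here with a background class plus one object class, so this runs as-is only if defined in the same module):

import torch

label_map = {'background': 0, 'cat': 1}  # hypothetical stand-ins
rev_label_map = {v: k for k, v in label_map.items()}

aps, mAP = calculate_mAP(
    det_boxes=[torch.tensor([[0., 0., 10., 10.]])],
    det_labels=[torch.tensor([1])],
    det_scores=[torch.tensor([0.9])],
    true_boxes=[torch.tensor([[0., 0., 10., 10.]])],
    true_labels=[torch.tensor([1])],
    true_difficulties=[torch.tensor([0], dtype=torch.uint8)])
print(aps, mAP)  # {'cat': 1.0} 1.0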
Example #11
def _iou_boxes(self, other: 'BoundingBoxes') -> Tensor:
    sz = other[0].sz
    assert len(self) == len(other)
    a = self.to_tensor(sz).cpu().unsqueeze(1)
    b = other.to_tensor(sz).cpu().unsqueeze(1)
    return torch.cat([box_iou(i, j) for i, j in zip(a, b)]).squeeze()
Example #12
def _iou_box(self, other: 'BoundingBox') -> Tensor:
    a = self.to_tensor(other.sz)
    b = other.x[None].to(a.device)
    return box_iou(a, b).squeeze(-1)
Example #13
def iou(self, other: Union['BoundingBox', 'BoundingBoxes']) -> Rank0Tensor:
    if isinstance(other, BoundingBoxes): return other.iou(self)
    a = self.x[None]
    b = other.get_resized_x(self.sz)[None].to(a.device)
    return box_iou(a, b).item()
Example #14
    def forward(self, images, labels=None, gt_bboxes=None):
        """

        :param images: shape (BS, C, H, W)
        :param labels: shape (BS, n_objs)
        :param gt_bboxes: shape (BS, n_objs, 4)
        :return:
        """

        feats = self.backbone(images)
        # rois shape (BS, num_rois, 4)
        rois, rpn_cls_loss, rpn_reg_loss = self.rpn(feats, labels, gt_bboxes)

        if self.training:
            # append the gt bboxes to the rois
            rois = torch.cat([rois, gt_bboxes], dim=1)

        # prepend a batch-id column to the rois
        BS, num_rois, _ = rois.shape
        batch_id = torch.stack(
            [torch.full_like(rois[i, :, :1], i) for i in range(BS)], dim=0)
        # (BS, num_rois, 5)
        rois = torch.cat([batch_id, rois], dim=2)
        # (BS*num_rois, 5)
        rois = rois.reshape((-1, 5))

        # roi pooling in each feature map
        if self.roi_pooling == "roi_align":
            roi_pooling = ops.roi_align
        elif self.roi_pooling == "roi_pool":
            roi_pooling = ops.roi_pool
        else:
            raise Exception("{} is not support".format(self.roi_pooling))

        if len(feats) == 1:
            _, feat = feats.popitem()
            roi_feats = roi_pooling(
                feat, rois,
                (self.roi_pooling_output_size, self.roi_pooling_output_size),
                1 / self.strides[0])
        else:
            feat_levels = np.log2(self.strides).astype(np.int64)
            feat_names = [n for n in feats.keys()]
            assert len(feat_levels) == len(feat_names)

            w = rois[:, 3] - rois[:, 1]
            h = rois[:, 4] - rois[:, 2]
            roi_levels = torch.floor(
                4 + torch.log2(torch.sqrt(w * h) / 224 + 1e-6))

            _f = feats[feat_names[0]]
            C = _f.shape[1]
            device = _f.device
            dtype = _f.dtype
            roi_feats = torch.zeros(
                (BS * num_rois, C, self.roi_pooling_output_size,
                 self.roi_pooling_output_size),
                dtype=dtype,
                device=device)

            for i, (feat_level,
                    feat_name) in enumerate(zip(feat_levels, feat_names)):
                mask_in_level = roi_levels == feat_level
                _roi_feats = roi_pooling(feats[feat_name], rois[mask_in_level],
                                         (self.roi_pooling_output_size,
                                          self.roi_pooling_output_size),
                                         1 / self.strides[i])
                roi_feats[mask_in_level] = _roi_feats

        # roi_feats shape (BS*num_rois, C, self.roi_pooling_output_size, self.roi_pooling_output_size)

        # roi head
        # (BS*num_rois, num_vector)
        box_feats = self.roi_head(roi_feats)
        # (BS*num_rois, num_classes+1)
        cls_pred = self.cls(box_feats)
        # (BS*num_rois, num_classes*4)
        reg_pred = self.reg(box_feats)
        # (BS, num_rois, num_classes+1)
        cls_pred = cls_pred.reshape((BS, num_rois, -1))
        # (BS, num_rois, num_classes, 4)
        reg_pred = reg_pred.reshape((BS, num_rois, -1, 4))

        if self.training:
            # (BS, num_rois, 5)
            rois = rois.reshape((BS, num_rois, 5))
            rois = rois[:, :, 1:]

            total_cls_pred = []
            total_reg_pred = []
            total_fg_bg_mask = []
            total_labels = []
            total_reg_target = []

            for i in range(BS):
                # assign a label to each roi
                areas = ops.boxes.box_area(rois[i])
                ignore_mask = areas == 0
                # (num_rois, num_gt_bboxes)
                ious = ops.box_iou(rois[i], gt_bboxes[i])
                # some rois have zero area, e.g. the (-1, -1, -1, -1) padding,
                # which yields NaNs in ious; replace the NaNs with 0
                zero_mask = (areas == 0).reshape(-1, 1).expand_as(ious)
                ious[zero_mask] = 0

                if torch.any(torch.isnan(ious)):
                    raise Exception("some elements in ious are nan")

                #############################################################
                # measure how well the rois cover the gt boxes (gt recall)
                num_gt = labels.shape[1]
                _ious_without_gt = ious[:-num_gt]  # drop the gt appended to the rois
                _ious_max_without_gt, _ = torch.max(_ious_without_gt, dim=0)
                gt_recall = (_ious_max_without_gt >= 0.5)[labels[i] != -1].to(
                    torch.float32).mean()
                if self.logger is not None:
                    self.logger.add_scalar("rcnn/gt_recall_0.5",
                                           gt_recall.detach().cpu().item())
                gt_recall = (_ious_max_without_gt >= 0.7)[labels[i] != -1].to(
                    torch.float32).mean()
                if self.logger is not None:
                    self.logger.add_scalar("rcnn/gt_recall_0.7",
                                           gt_recall.detach().cpu().item())
                #############################################################

                # the roi/rois with the highest Intersection-over-Union (IoU)
                # overlap with a ground-truth box
                iou_max_gt, _ = torch.max(ious, dim=0)

                # ignore the padded entries in gt_bboxes
                iou_max_gt = torch.where(labels[i] == -1,
                                         torch.ones_like(iou_max_gt),
                                         iou_max_gt)
                highest_mask = (ious == iou_max_gt)
                fg_mask = torch.any(highest_mask, dim=1)
                # a roi that has an IoU overlap higher than fg_iou_thresh with any ground-truth box
                iou_max, matched_idx = torch.max(ious, dim=1)
                # 1 for foreground -1 for background 0 for ignore
                fg_bg_mask = torch.zeros_like(iou_max)
                # confirm positive samples
                fg_bg_mask = torch.where(iou_max >= self.fg_iou_thresh,
                                         torch.ones_like(iou_max), fg_bg_mask)
                fg_bg_mask = torch.where(fg_mask, torch.ones_like(iou_max),
                                         fg_bg_mask)
                # confirm negative samples
                fg_bg_mask = torch.where(iou_max <= self.bg_iou_thresh,
                                         torch.full_like(iou_max, -1),
                                         fg_bg_mask)
                # ignore samples
                fg_bg_mask = torch.where(ignore_mask,
                                         torch.zeros_like(iou_max), fg_bg_mask)

                # randomly sample positives and negatives
                indices = torch.arange(fg_bg_mask.shape[0],
                                       dtype=torch.int64,
                                       device=fg_bg_mask.device)
                rand_indices = torch.rand_like(fg_bg_mask).argsort()
                fg_bg_mask = fg_bg_mask[rand_indices]  # shuffle to make the sampling random
                indices = indices[rand_indices]

                sorted_indices = fg_bg_mask.argsort(descending=True)
                fg_bg_mask = fg_bg_mask[sorted_indices]
                indices = indices[sorted_indices]
                fg_indices = indices[:self.num_pos]
                fg_mask = fg_bg_mask[:self.num_pos]
                bg_indices = indices[-self.num_neg:]
                bg_mask = fg_bg_mask[-self.num_neg:]

                indices = torch.cat([fg_indices, bg_indices], dim=0)
                fg_bg_mask = torch.cat([fg_mask, bg_mask], dim=0)
                matched_idx = matched_idx[indices]
                # (num_samples)
                # label; the background class is not handled here yet
                label = labels[i][matched_idx]
                # map the -1 padding labels to 0; F.one_hot rejects negatives
                _label = label.clone()
                _label[_label == -1] = 0
                # (num_samples, num_classes)
                label_mask = F.one_hot(_label, self.num_classes)

                # (num_samples*num_classes,)
                label_mask = label_mask.reshape(-1).to(torch.bool)
                _reg_pred = reg_pred[i][indices].reshape(-1, 4)
                # (num_samples, 4)
                _reg_pred = _reg_pred[label_mask]
                _rois = rois[i][indices]
                total_cls_pred.append(cls_pred[i][indices])
                total_reg_pred.append(_reg_pred)
                total_fg_bg_mask.append(fg_bg_mask)
                total_labels.append(label)
                total_reg_target.append(
                    self.box_coder.encode(_rois, gt_bboxes[i][matched_idx]))

            # (BS, num_samples, num_classes+1)
            cls_pred = torch.stack(total_cls_pred)
            # (BS, num_samples, 4)
            reg_pred = torch.stack(total_reg_pred)
            # (BS, num_samples)
            fg_bg_mask = torch.stack(total_fg_bg_mask)
            # (BS, num_samples)
            labels = torch.stack(total_labels)
            # (BS, num_samples, 4)
            reg_target = torch.stack(total_reg_target)
            if torch.any(torch.isnan(reg_target[fg_bg_mask == 1])):
                raise Exception("some elements in reg_target are nan")
            if torch.any(torch.isinf(reg_target[fg_bg_mask == 1])):
                raise Exception("some elements in reg_target are inf")
            assert torch.any(fg_bg_mask == 1)  # gt boxes were added to the rois, so positives must exist
            rcnn_reg_loss = F.smooth_l1_loss(reg_pred[fg_bg_mask == 1],
                                             reg_target[fg_bg_mask == 1])

            cls_label = labels + 1  # shift all class ids by +1 to reserve 0 for background
            cls_label = cls_label.reshape((-1, ))
            cls_pred = cls_pred.reshape((-1, self.num_classes + 1))
            fg_bg_mask = fg_bg_mask.reshape(-1, )
            # set the label of background samples to 0
            cls_label = torch.where(fg_bg_mask == -1,
                                    torch.zeros_like(cls_label), cls_label)
            rcnn_cls_loss = F.cross_entropy(cls_pred[fg_bg_mask != 0],
                                            cls_label[fg_bg_mask != 0])

            cls_pred = torch.argmax(cls_pred, dim=1)
            acc = torch.mean(
                (cls_pred == cls_label)[fg_bg_mask != 0].to(torch.float))
            num_pos = (fg_bg_mask == 1).sum()
            num_neg = (fg_bg_mask == -1).sum()
            if self.logger is not None:
                self.logger.add_scalar("rcnn/acc", acc.detach().cpu().item())
                self.logger.add_scalar("rcnn/num_pos",
                                       num_pos.detach().cpu().item())
                self.logger.add_scalar("rcnn/num_neg",
                                       num_neg.detach().cpu().item())

            return rpn_cls_loss, rpn_reg_loss, rcnn_cls_loss, rcnn_reg_loss

        cls_scores = F.softmax(cls_pred, dim=2)
        # (BS, num_rois, num_classes)
        cls_scores = cls_scores[:, :, 1:]

        # rois: (BS*num_rois, 5)
        # reg_pred: (BS, num_rois, num_classes, 4)
        # _reg_pred: (num_classes, BS*num_rois, 4)
        _reg_pred = reg_pred.permute(
            (2, 0, 1, 3)).reshape(self.num_classes, BS * num_rois, 4)
        # (num_classes, BS*num_rois, 4)
        reg_bboxes = self.box_coder.decode(rois[:, 1:], _reg_pred)
        reg_bboxes[..., 0].clamp_(0, self.image_size[0])
        reg_bboxes[..., 1].clamp_(0, self.image_size[1])
        reg_bboxes[..., 2].clamp_(0, self.image_size[0])
        reg_bboxes[..., 3].clamp_(0, self.image_size[1])
        # (BS, num_rois, num_classes, 4)
        reg_bboxes = reg_bboxes.permute((1, 0, 2)).reshape(
            (BS, num_rois, self.num_classes, 4))

        # (num_rois, num_classes)
        classes_id = torch.cat(
            [
                # (num_rois, 1)
                torch.full_like(cls_scores[0, :, :1], i)
                for i in range(self.num_classes)
            ],
            dim=1)
        # (num_rois*num_classes)
        classes_id = classes_id.reshape((-1, ))
        # (BS, num_rois*num_classes)
        cls_scores = cls_scores.reshape((BS, -1))
        # (BS, num_rois*num_classes, 4)
        reg_bboxes = reg_bboxes.reshape((BS, -1, 4))

        scores = []
        bboxes = []
        labels = []
        for i in range(BS):
            _scores = torch.full((self.max_objs_per_image, ),
                                 -1,
                                 dtype=cls_scores.dtype,
                                 device=cls_scores.device)
            _labels = torch.full((self.max_objs_per_image, ),
                                 -1,
                                 dtype=classes_id.dtype,
                                 device=classes_id.device)
            _bboxes = torch.full((self.max_objs_per_image, 4),
                                 -1,
                                 dtype=reg_bboxes.dtype,
                                 device=reg_bboxes.device)
            keep_mask = cls_scores[i] >= self.obj_thresh
            _reg_bboxes = reg_bboxes[i][keep_mask]
            _cls_scores = cls_scores[i][keep_mask]
            _classes_id = classes_id[keep_mask]
            keep = ops.boxes.batched_nms(_reg_bboxes, _cls_scores, _classes_id,
                                         self.nms_thresh)
            n_keep = keep.shape[0]
            n_keep = min(n_keep, self.max_objs_per_image)
            keep = keep[:n_keep]
            _scores[:n_keep] = _cls_scores[keep]
            _labels[:n_keep] = _classes_id[keep]
            _bboxes[:n_keep] = _reg_bboxes[keep]

            scores.append(_scores)
            labels.append(_labels)
            bboxes.append(_bboxes)

        scores = torch.stack(scores)  # (BS, max_objs)
        labels = torch.stack(labels)  # (BS, max_objs)
        bboxes = torch.stack(bboxes)  # (BS, max_objs, 4)

        return scores, labels, bboxes
Example #15
def _get_graph_centers(boxes, cls_prob, im_labels):
    """Get graph centers."""

    num_images, num_classes = im_labels.shape
    assert num_images == 1, 'batch size should be equal to 1'
    dev = cls_prob.device
    gt_boxes = torch.zeros((0, 4), dtype=boxes.dtype, device=dev)
    gt_classes = torch.zeros((0, 1), dtype=torch.long, device=dev)
    gt_scores = torch.zeros((0, 1), dtype=cls_prob.dtype, device=dev)
    for i in im_labels.nonzero()[:, 1]:
        cls_prob_tmp = cls_prob[:, i]
        idxs = (cls_prob_tmp >= 0).nonzero()[:, 0]
        idxs_tmp = _get_top_ranking_propoals(cls_prob_tmp[idxs].reshape(-1, 1))
        idxs = idxs[idxs_tmp]
        boxes_tmp = boxes[idxs, :]
        cls_prob_tmp = cls_prob_tmp[idxs]

        graph = (ops.box_iou(boxes_tmp, boxes_tmp) > 0.4).float()

        keep_idxs = []
        gt_scores_tmp = []
        count = cls_prob_tmp.size(0)
        while True:
            order = graph.sum(dim=1).argsort(descending=True)
            tmp = order[0]
            keep_idxs.append(tmp)
            inds = (graph[tmp, :] > 0).nonzero()[:, 0]
            gt_scores_tmp.append(cls_prob_tmp[inds].max())

            graph[:, inds] = 0
            graph[inds, :] = 0
            count = count - len(inds)
            if count <= 5:
                break

        gt_boxes_tmp = boxes_tmp[keep_idxs, :].view(-1, 4).to(dev)
        gt_scores_tmp = torch.tensor(gt_scores_tmp, device=dev)

        # top-5 (at most) cluster centers by score, in descending order
        keep_idxs_new = gt_scores_tmp.argsort(descending=True)[:min(
            len(gt_scores_tmp), 5)].to(dev)

        gt_boxes = torch.cat((gt_boxes, gt_boxes_tmp[keep_idxs_new, :]))
        gt_scores = torch.cat(
            (gt_scores, gt_scores_tmp[keep_idxs_new].reshape(-1, 1)))
        gt_classes = torch.cat((gt_classes, (i + 1) * torch.ones(
            (len(keep_idxs_new), 1), dtype=torch.long, device=dev)))

        # If a proposal is chosen as a cluster center,
        # we simply delete it from the candidate proposal pool,
        # because the results of different strategies are similar and this
        # strategy is more efficient
        another_tmp = idxs.to('cpu')[torch.tensor(keep_idxs)][keep_idxs_new.to(
            'cpu')].numpy()
        cls_prob = torch.from_numpy(
            np.delete(cls_prob.to('cpu').numpy(), another_tmp, axis=0)).to(dev)
        boxes = torch.from_numpy(
            np.delete(boxes.to('cpu').numpy(), another_tmp, axis=0)).to(dev)

    proposals = {
        'gt_boxes': gt_boxes.to(dev),
        'gt_classes': gt_classes.to(dev),
        'gt_scores': gt_scores.to(dev)
    }

    return proposals
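
The graph-building step above in isolation (a sketch): proposals are connected when their pairwise IoU exceeds 0.4, and the highest-degree node is greedily taken as a cluster center, with its neighborhood removed before the next round:

import torch
from torchvision import ops

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],    # overlaps box 0 (IoU ~ 0.68)
                      [50., 50., 60., 60.]])
graph = (ops.box_iou(boxes, boxes) > 0.4).float()
center = graph.sum(dim=1).argmax()           # highest-degree node
neighbors = (graph[center] > 0).nonzero()[:, 0]
print(center.item(), neighbors.tolist())     # 0 [0, 1] -- boxes 0 and 1 form one cluster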