Example #1
 def _forward_single_image(self, left_prediction: BoxList,
                           right_prediction: BoxList) -> DisparityMap:
     left_bbox = left_prediction.bbox
     right_bbox = right_prediction.bbox
     disparity_preds = left_prediction.get_field('disparity')
     mask_preds = left_prediction.get_field('mask').clone()
     # print(disparity_preds.shape)
     assert len(left_bbox) == len(right_bbox) == len(
         disparity_preds
     ), f'{len(left_bbox), len(right_bbox), len(disparity_preds)}'
     num_rois = len(left_bbox)
     if num_rois == 0:
         disparity_full_image = torch.zeros(
             (left_prediction.height, left_prediction.width))
     else:
         disparity_maps = []
         for left_roi, right_roi, disp_roi, mask_pred in zip(
                 left_bbox, right_bbox, disparity_preds, mask_preds):
             x1, y1, x2, y2 = left_roi.tolist()
             x1p, _, x2p, _ = right_roi.tolist()
             x1, y1, x2, y2 = expand_box_to_integer((x1, y1, x2, y2))
             x1p, _, x2p, _ = expand_box_to_integer((x1p, y1, x2p, y2))
             disparity_map_per_roi = torch.zeros(
                 (left_prediction.height, left_prediction.width))
             # mask = mask_pred.squeeze(0)
             # mask = SegmentationMask(BinaryMaskList(mask, size=mask.shape[::-1]), size=mask.shape[::-1],
             #                         mode='mask').crop((x1, y1, x1 + max(x2 - x1, x2p - x1p), y2))
             # resize the predicted ROI disparity to the wider of the left and
             # right boxes, crop it back to the left box, then shift it from
             # box-local to full-image disparity by adding the horizontal
             # offset between the left and right box origins
             disp_roi = DisparityMap(disp_roi).resize(
                 (max(x2 - x1, x2p - x1p), y2 - y1)).crop(
                     (0, 0, x2 - x1, y2 - y1)).data
             disp_roi = disp_roi + x1 - x1p
             disparity_map_per_roi[y1:y2, x1:x2] = disp_roi
             disparity_maps.append(disparity_map_per_roi)
         disparity_full_image = torch.stack(disparity_maps).max(dim=0)[0]
     return DisparityMap(disparity_full_image)
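
The merging step above builds one full-resolution canvas per ROI and keeps the element-wise maximum, so overlapping regions are resolved in favor of the larger (nearer) disparity. Below is a minimal, self-contained sketch of that aggregation using plain torch; the helper name `paste_roi_disparities` is illustrative and not from the repo.

import torch

def paste_roi_disparities(rois, disparities, image_hw):
    """Illustrative helper: paste per-ROI disparity patches into full-size
    maps and merge overlapping ROIs with an element-wise max, mirroring the
    aggregation in Example #1. `rois` are integer (x1, y1, x2, y2) boxes and
    each entry of `disparities` already has shape (y2 - y1, x2 - x1)."""
    full_maps = []
    for (x1, y1, x2, y2), disp in zip(rois, disparities):
        canvas = torch.zeros(image_hw)
        canvas[y1:y2, x1:x2] = disp
        full_maps.append(canvas)
    if not full_maps:  # no detections: return an all-zero map
        return torch.zeros(image_hw)
    return torch.stack(full_maps).max(dim=0)[0]

# toy usage: two overlapping ROIs on a 6x8 image
rois = [(0, 0, 4, 3), (2, 1, 6, 4)]
disparities = [torch.full((3, 4), 2.0), torch.full((3, 4), 5.0)]
print(paste_roi_disparities(rois, disparities, (6, 8)))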
Example #2
def post_process_and_resize_prediction(left_prediction: BoxList,
                                       right_prediction: BoxList,
                                       dst_size=(1280, 720),
                                       threshold=0.7,
                                       padding=1,
                                       process_disparity=True):
    left_prediction = left_prediction.clone()
    right_prediction = right_prediction.clone()
    if process_disparity and not left_prediction.has_map('disparity'):
        disparity_map_processor = DisparityMapProcessor()
        disparity_pred_full_img = disparity_map_processor(
            left_prediction, right_prediction)
        left_prediction.add_map('disparity', disparity_pred_full_img)
    left_prediction = left_prediction.resize(dst_size)
    right_prediction = right_prediction.resize(dst_size)
    mask_pred = left_prediction.get_field('mask')
    masker = Masker(threshold=threshold, padding=padding)
    mask_pred = masker([mask_pred], [left_prediction])[0].squeeze(1)
    if mask_pred.shape[0] != 0:
        # mask_preds_per_img = mask_pred.sum(dim=0)[0].clamp(max=1)
        mask_preds_per_img = mask_pred
    else:
        mask_preds_per_img = torch.zeros((1, *dst_size[::-1]))
    left_prediction.add_field('mask', mask_preds_per_img)
    return left_prediction, right_prediction
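
For intuition about the `threshold` argument: once the Masker has projected the soft ROI masks back to image resolution, each pixel probability is simply compared against the threshold. A toy illustration of that final binarization (not repo code):

import torch

soft_mask = torch.tensor([[0.1, 0.8],
                          [0.9, 0.3]])
threshold = 0.7
binary_mask = (soft_mask > threshold).to(torch.uint8)
print(binary_mask)  # tensor([[0, 1], [1, 0]], dtype=torch.uint8)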
Example #3
    def forward_for_single_feature_map(self, anchors, objectness,
                                       box_regression):
        """
        Arguments:
            anchors: list[BoxList]
            objectness: tensor of size N, A, H, W
            box_regression: tensor of size N, A * 6, H, W
        """
        device = objectness.device
        N, A, H, W = objectness.shape

        # put in the same format as anchors
        objectness = permute_and_flatten(objectness, N, A, 2, H, W)[:, :, 1]
        # objectness = objectness.sigmoid()

        box_regression = permute_and_flatten(box_regression, N, A, 6, H, W)

        num_anchors = A * H * W

        pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)
        objectness, topk_idx = objectness.topk(min(pre_nms_top_n,
                                                   objectness.shape[1]),
                                               dim=1,
                                               sorted=True)

        batch_idx = torch.arange(N, device=device)[:, None]
        box_regression = box_regression[batch_idx, topk_idx]

        image_shapes = [box.size for box in anchors]
        concat_anchors = torch.cat([a.bbox for a in anchors], dim=0)
        concat_anchors = concat_anchors.reshape(N, -1, 4)[batch_idx, topk_idx]

        proposals = self.box_coder.decode(box_regression.view(-1, 6),
                                          concat_anchors.view(-1, 4))

        proposals = proposals.view(N, -1, 6)

        left_result, right_result = [], []
        for proposal, score, im_shape in zip(proposals, objectness,
                                             image_shapes):
            left_boxlist = BoxList(proposal[:, 0:4], im_shape, mode="xyxy")
            right_boxlist = BoxList(proposal[:, [4, 1, 5, 3]],
                                    im_shape,
                                    mode='xyxy')

            left_boxlist.add_field("objectness", score)
            right_boxlist.add_field("objectness", score)
            left_boxlist = left_boxlist.clip_to_image(remove_empty=False)
            right_boxlist = right_boxlist.clip_to_image(remove_empty=False)
            left_boxlist = remove_small_boxes(left_boxlist, self.min_size)
            right_boxlist = remove_small_boxes(right_boxlist, self.min_size)
            left_boxlist, right_boxlist = double_view_boxlist_nms(
                left_boxlist,
                right_boxlist,
                self.nms_thresh,
                max_proposals=self.post_nms_top_n,
                score_field='objectness')
            left_result.append(left_boxlist)
            right_result.append(right_boxlist)
        return left_result, right_result
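
Each decoded proposal carries six values: columns 0..3 are the left box, and the indexing [4, 1, 5, 3] assembles the right box from the extra horizontal extent while reusing the left box's vertical extent (both views share the same rows). A small check of that column layout, inferred from the code rather than from repo documentation:

import torch

proposals = torch.tensor([[10., 20., 50., 80., 4., 44.]])  # (x1, y1, x2, y2, x1', x2')
left = proposals[:, 0:4]            # (x1, y1, x2, y2)
right = proposals[:, [4, 1, 5, 3]]  # (x1', y1, x2', y2) -- y's reused from the left box
print(left)   # tensor([[10., 20., 50., 80.]])
print(right)  # tensor([[ 4., 20., 44., 80.]])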
Example #4
 def prepare_boxlist(self, boxes, scores, image_shape):
     """
     Returns BoxList from `boxes` and adds probability scores information
     as an extra field
     `boxes` has shape (#detections, 4 * #classes), where each row represents
     a list of predicted bounding boxes for each of the object classes in the
     dataset (including the background class). The detections in each row
     originate from the same object proposal.
     `scores` has shape (#detection, #classes), where each row represents a list
     of object detection confidence scores for each of the object classes in the
     dataset (including the background class). `scores[i, j]`` corresponds to the
     box at `boxes[i, j * 4:(j + 1) * 4]`.
     """
     boxes = boxes.reshape(-1, 4)
     scores = scores.reshape(-1)
     boxlist = BoxList(boxes, image_shape, mode="xyxy")
     boxlist.add_field("scores", scores)
     return boxlist
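
The reshapes keep boxes and scores aligned: after flattening, row i * num_classes + j of the BoxList is the class-j detection of proposal i. A standalone sanity check of that alignment (illustrative only, not repo code):

import torch

num_det, num_classes = 2, 3
boxes = torch.arange(num_det * num_classes * 4,
                     dtype=torch.float32).view(num_det, num_classes * 4)
scores = torch.rand(num_det, num_classes)
flat_boxes = boxes.reshape(-1, 4)    # (num_det * num_classes, 4)
flat_scores = scores.reshape(-1)     # (num_det * num_classes,)
i, j = 1, 2
assert torch.equal(flat_boxes[i * num_classes + j], boxes[i, j * 4:(j + 1) * 4])
assert flat_scores[i * num_classes + j] == scores[i, j]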
Example #5
    def __getitem__(self, item):
        img = Image.open(self.image_lists[item]).convert("RGB")

        # dummy target
        w, h = img.size
        target = BoxList([[0, 0, w, h]], img.size, mode="xyxy")

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target
Example #6
 def forward(self, image_list, feature_maps):
     feature_shapes = [
         list(feature_map.shape[-2:]) for feature_map in feature_maps
     ]
     anchors_over_all_feature_maps = generate_anchors_all_pyramids(
         self.scales, self.ratios, feature_shapes, self.feature_strides, 1)
     anchors = []
     for i, (image_height,
             image_width) in enumerate(image_list.image_sizes):
         anchors_in_image = []
         for anchors_per_feature_map in anchors_over_all_feature_maps:
             boxlist = BoxList(anchors_per_feature_map,
                               (image_width, image_height),
                               mode="xyxy").to(device='cuda')
             self.add_visibility_to(boxlist)
             anchors_in_image.append(boxlist)
         anchors.append(anchors_in_image)
     return anchors
Example #7
 def make_expand_targets_per_image(self, left_targets_per_image: BoxList,
                                   right_targets_per_image: BoxList):
     # todo: do match, expand and put original bbox in extra_fields
     # todo: check necessity of clone
     left_targets_per_image = left_targets_per_image.copy_with_fields([])
     right_targets_per_image = right_targets_per_image.copy_with_fields([])
     left_bbox = left_targets_per_image.bbox
     right_bbox = right_targets_per_image.bbox
     expand_bbox, original_lr_bbox = expand_left_right_box(
         left_bbox, right_bbox)
     expand_target = BoxList(expand_bbox, left_targets_per_image.size)
     expand_target.add_field('original_lr_bbox', original_lr_bbox)
     return expand_target
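
The implementation of `expand_left_right_box` is not shown in this snippet. A plausible sketch, assuming the expanded target is the horizontal union of each matched left/right box pair with the originals kept in a side field; the helper below (`union_left_right_box`) is hypothetical and only illustrates the idea:

import torch

def union_left_right_box(left_bbox, right_bbox):
    """Hypothetical stand-in for expand_left_right_box (its real
    implementation is not shown here). Assumes the expanded box spans the
    union of the matched left/right boxes and that the originals are kept
    for later recovery."""
    x1 = torch.min(left_bbox[:, 0], right_bbox[:, 0])
    y1 = torch.min(left_bbox[:, 1], right_bbox[:, 1])
    x2 = torch.max(left_bbox[:, 2], right_bbox[:, 2])
    y2 = torch.max(left_bbox[:, 3], right_bbox[:, 3])
    expand_bbox = torch.stack([x1, y1, x2, y2], dim=1)      # (N, 4)
    original_lr_bbox = torch.cat([left_bbox, right_bbox], dim=1)  # (N, 8)
    return expand_bbox, original_lr_bbox

left = torch.tensor([[100., 50., 200., 150.]])
right = torch.tensor([[80., 50., 180., 150.]])
print(union_left_right_box(left, right)[0])  # tensor([[ 80.,  50., 200., 150.]])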
Example #8
    def select_over_all_levels(self, boxlists):
        num_images = len(boxlists)
        results = []
        for i in range(num_images):
            scores = boxlists[i].get_field("scores")
            labels = boxlists[i].get_field("labels")
            boxes = boxlists[i].bbox
            boxlist = boxlists[i]
            result = []
            # skip the background
            for j in range(1, self.num_classes):
                inds = (labels == j).nonzero().view(-1)

                scores_j = scores[inds]
                boxes_j = boxes[inds, :].view(-1, 4)
                boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
                boxlist_for_class.add_field("scores", scores_j)
                boxlist_for_class = boxlist_nms(boxlist_for_class,
                                                self.nms_thresh,
                                                score_field="scores")
                num_labels = len(boxlist_for_class)
                boxlist_for_class.add_field(
                    "labels",
                    torch.full((num_labels, ),
                               j,
                               dtype=torch.int64,
                               device=scores.device))
                result.append(boxlist_for_class)

            result = cat_boxlist(result)
            number_of_detections = len(result)

            # Limit to max_per_image detections **over all classes**
            if number_of_detections > self.fpn_post_nms_top_n > 0:
                cls_scores = result.get_field("scores")
                image_thresh, _ = torch.kthvalue(
                    cls_scores.cpu(),
                    number_of_detections - self.fpn_post_nms_top_n + 1)
                keep = cls_scores >= image_thresh.item()
                keep = torch.nonzero(keep).squeeze(1)
                result = result[keep]
            results.append(result)
        return results
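
The final cap on detections relies on torch.kthvalue: with N detections, the (N - top_n + 1)-th smallest score is exactly the top_n-th largest, so keeping scores greater than or equal to it retains roughly top_n boxes (ties may keep a few more). A standalone illustration (not repo code); the same trick appears again in Example #9:

import torch

scores = torch.tensor([0.9, 0.1, 0.75, 0.3, 0.6])
top_n = 3
thresh, _ = torch.kthvalue(scores, len(scores) - top_n + 1)
keep = torch.nonzero(scores >= thresh).squeeze(1)
print(keep)  # indices of the 3 best scores: tensor([0, 2, 4])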
Example #9
    def filter_results(self, boxlist, num_classes):
        """Returns bounding-box detection results by thresholding on scores and
        applying non-maximum suppression (NMS).
        """
        # unwrap the boxlist to avoid additional overhead.
        # if we had multi-class NMS, we could perform this directly on the boxlist
        boxes = boxlist.bbox.reshape(-1, num_classes * 4)
        scores = boxlist.get_field("scores").reshape(-1, num_classes)

        device = scores.device
        result = []
        # Apply threshold on detection probabilities and apply NMS
        # Skip j = 0, because it's the background class
        inds_all = scores > self.score_thresh
        for j in range(1, num_classes):
            inds = inds_all[:, j].nonzero().squeeze(1)
            scores_j = scores[inds, j]
            boxes_j = boxes[inds, j * 4:(j + 1) * 4]
            boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
            boxlist_for_class.add_field("scores", scores_j)
            boxlist_for_class = boxlist_nms(boxlist_for_class, self.nms)
            num_labels = len(boxlist_for_class)
            boxlist_for_class.add_field(
                "labels",
                torch.full((num_labels, ), j, dtype=torch.int64,
                           device=device))
            result.append(boxlist_for_class)

        result = cat_boxlist(result)
        number_of_detections = len(result)

        # Limit to max_per_image detections **over all classes**
        if number_of_detections > self.detections_per_img > 0:
            cls_scores = result.get_field("scores")
            image_thresh, _ = torch.kthvalue(
                cls_scores.cpu(),
                number_of_detections - self.detections_per_img + 1)
            keep = cls_scores >= image_thresh.item()
            keep = torch.nonzero(keep).squeeze(1)
            result = result[keep]
        return result
Example #10
 def get_ground_truth(self, index):
     img_id = self.ids[index]
     if not is_testing_split(self.split):
         left_annotation = self.annotations['left'][int(img_id)]
         right_annotation = self.annotations['right'][int(img_id)]
         info = self.get_img_info(index)
         height, width = info['height'], info['width']
         # left target
         left_target = BoxList(left_annotation["boxes"], (width, height),
                               mode="xyxy")
         left_target.add_field("labels", left_annotation["labels"])
         # left_target.add_field("alphas", left_annotation['alphas'])
         boxes_3d = Box3DList(left_annotation["boxes_3d"], (width, height),
                              mode='ry_lhwxyz')
         left_target.add_field("box3d", boxes_3d)
         left_target.add_map('disparity', self.get_disparity(index))
         left_target.add_field('masks', self.get_mask(index))
         left_target.add_field(
             'truncation', torch.tensor(self.truncations_list[int(img_id)]))
         left_target.add_field(
             'occlusion', torch.tensor(self.occlusions_list[int(img_id)]))
         left_target.add_field(
             'image_size',
             torch.tensor([[width, height]]).repeat(len(left_target), 1))
         left_target.add_field(
             'calib', Calib(self.get_calibration(index), (width, height)))
         left_target.add_field(
             'index',
             torch.full((len(left_target), 1), index, dtype=torch.long))
         left_target.add_field(
             'imgid',
             torch.full((len(left_target), 1),
                        int(img_id),
                        dtype=torch.long))
         left_target = left_target.clip_to_image(remove_empty=True)
         # right target
         right_target = BoxList(right_annotation["boxes"], (width, height),
                                mode="xyxy")
         right_target.add_field("labels", right_annotation["labels"])
         right_target = right_target.clip_to_image(remove_empty=True)
         target = {'left': left_target, 'right': right_target}
         return target
     else:
         fakebox = torch.tensor([[0, 0, 0, 0]])
         info = self.get_img_info(index)
         height, width = info['height'], info['width']
         # left target
         left_target = BoxList(fakebox, (width, height), mode="xyxy")
         left_target.add_field(
             'image_size',
             torch.tensor([[width, height]]).repeat(len(left_target), 1))
         left_target.add_field(
             'calib', Calib(self.get_calibration(index), (width, height)))
         left_target.add_field(
             'index',
             torch.full((len(left_target), 1), index, dtype=torch.long))
         left_target.add_field(
             'imgid',
             torch.full((len(left_target), 1),
                        int(img_id),
                        dtype=torch.long))
         # right target
         right_target = BoxList(fakebox, (width, height), mode="xyxy")
         target = {'left': left_target, 'right': right_target}
         return target
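
Scalar per-image metadata (index, image id, image size) is stored once per box so that any later slicing of the BoxList keeps the metadata rows aligned with the surviving boxes. A minimal illustration of the two replication patterns used above (values are made up, not repo code):

import torch

num_boxes, index, img_id = 3, 7, 42
width, height = 1242, 375
image_size = torch.tensor([[width, height]]).repeat(num_boxes, 1)   # (3, 2)
index_field = torch.full((num_boxes, 1), index, dtype=torch.long)   # (3, 1)
imgid_field = torch.full((num_boxes, 1), img_id, dtype=torch.long)  # (3, 1)
keep = torch.tensor([0, 2])  # e.g. boxes surviving clip_to_image
print(image_size[keep].shape, index_field[keep].shape)  # torch.Size([2, 2]) torch.Size([2, 1])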
Example #11
    def __call__(self, output_dict, proposals):
        """

        :param output_dict: dict_keys(['rcnn_cls', 'rcnn_reg'])
        :param proposals: dict_keys(['rpn_cls', 'rpn_reg', 'backbone_xyz',
                                     'backbone_features', 'rois', 'roi_scores_raw',
                                     'seg_result', 'rpn_xyz', 'rpn_features',
                                     'seg_mask', 'roi_boxes3d', 'pts_depth'])

        :return:
        """
        size = (1280, 720)  # todo:replace
        # batch_size = len(targets)
        roi_boxes3d = proposals['roi_boxes3d']  # (B, M, 7)
        batch_size = roi_boxes3d.shape[0]
        rcnn_cls = output_dict['rcnn_cls'].view(
            batch_size, -1, output_dict['rcnn_cls'].shape[1])
        rcnn_reg = output_dict['rcnn_reg'].view(
            batch_size, -1, output_dict['rcnn_reg'].shape[1])  # (B, M, C)

        # bounding box regression
        anchor_size = self.MEAN_SIZE
        if self.cfg.RCNN.SIZE_RES_ON_ROI:
            assert False

        pred_boxes3d = decode_bbox_target(
            roi_boxes3d.view(-1, 7),
            rcnn_reg.view(-1, rcnn_reg.shape[-1]),
            anchor_size=anchor_size,
            loc_scope=self.cfg.RCNN.LOC_SCOPE,
            loc_bin_size=self.cfg.RCNN.LOC_BIN_SIZE,
            num_head_bin=self.cfg.RCNN.NUM_HEAD_BIN,
            get_xz_fine=True,
            get_y_by_bin=self.cfg.RCNN.LOC_Y_BY_BIN,
            loc_y_scope=self.cfg.RCNN.LOC_Y_SCOPE,
            loc_y_bin_size=self.cfg.RCNN.LOC_Y_BIN_SIZE,
            get_ry_fine=True).view(batch_size, -1, 7)

        # scoring
        if rcnn_cls.shape[2] == 1:
            raw_scores = rcnn_cls  # (B, M, 1)

            norm_scores = torch.sigmoid(raw_scores)
            pred_classes = (norm_scores >
                            self.cfg.RCNN.SCORE_THRESH).long().squeeze(-1)
        else:
            pred_classes = torch.argmax(rcnn_cls, dim=1).view(-1)
            cls_norm_scores = F.softmax(rcnn_cls, dim=1)
            raw_scores = rcnn_cls[:, pred_classes]
            norm_scores = cls_norm_scores[:, pred_classes]

        inds = norm_scores > self.cfg.RCNN.SCORE_THRESH
        # inds = norm_scores > 0.05
        results = []
        for k in range(batch_size):
            cur_inds = inds[k].view(-1)
            if cur_inds.sum() == 0:
                # print('Low scores, random result.')
                use_rpn_proposals = True
                if not use_rpn_proposals:
                    bbox_3d = Box3DList(
                        torch.rand(1, 7).float(), size,
                        'xyzhwl_ry').convert("ry_lhwxyz")
                    bbox = torch.Tensor([0, 0, 0, 0]).repeat(1, 1).cuda()
                    bbox = BoxList(bbox, size, mode="xyxy")
                    bbox.add_field("box3d", bbox_3d)
                    # bbox.add_field("box3d_score", torch.Tensor(1).zero_())
                    bbox.add_field("box3d_score", torch.zeros(1) * (-10))
                    bbox.add_field("labels", torch.ones(1).cuda())
                    bbox.add_field("iou_score", torch.Tensor(1).zero_())
                    bbox.add_field("random", torch.ones((len(bbox))).long())
                    if self.cfg.RPN.EARLY_INTEGRATE:
                        bbox.add_field('det_id', torch.Tensor(1).zero_())
                        bbox.add_field('box3d_backend', bbox_3d)
                        bbox.add_field('box3d_backend_ids',
                                       torch.Tensor(1).zero_())
                        bbox.add_field('box3d_backend_keep',
                                       torch.Tensor(1).zero_())
                    results.append(bbox)
                    continue
                else:
                    # print('use_rpn_proposals')
                    proposal_score = proposals['roi_scores_raw'][k]
                    select_idx = proposal_score.argmax()
                    b3d = roi_boxes3d[k][select_idx]
                    bbox_3d = Box3DList(b3d, size,
                                        'xyzhwl_ry').convert("ry_lhwxyz")
                    bbox = torch.Tensor([0, 0, size[0],
                                         size[1]]).repeat(b3d.shape[0],
                                                          1).cuda()
                    bbox = BoxList(bbox, size, mode="xyxy")
                    bbox.add_field("box3d", bbox_3d)
                    # bbox.add_field("box3d_score", torch.Tensor([proposal_score[select_idx]]))
                    bbox.add_field("box3d_score", torch.zeros(1))
                    bbox.add_field("labels", 1)
                    bbox.add_field("random", torch.ones((len(bbox))).long())
                    # bbox.add_field('iou_score', scores_selected)
                    results.append(bbox)
                    continue

            pred_boxes3d_selected = pred_boxes3d[k, cur_inds]
            raw_scores_selected = raw_scores[k, cur_inds]
            norm_scores_selected = norm_scores[k, cur_inds]
            pred_classes_selected = pred_classes[k, cur_inds]
            # NMS thresh
            # rotated nms
            boxes_bev_selected = boxes3d_to_bev_torch(pred_boxes3d_selected)
            keep_idx = nms_gpu(boxes_bev_selected, raw_scores_selected,
                               self.cfg.RCNN.NMS_THRESH).view(-1)
            pred_boxes3d_selected = pred_boxes3d_selected[keep_idx]
            scores_selected = raw_scores_selected[keep_idx].squeeze(1)
            pred_classes_selected = pred_classes_selected[keep_idx]

            bbox_3d = Box3DList(pred_boxes3d_selected, size,
                                'xyzhwl_ry').convert("ry_lhwxyz")
            bbox = torch.Tensor([0, 0, size[0], size[1]
                                 ]).repeat(pred_boxes3d_selected.shape[0],
                                           1).cuda()
            bbox = BoxList(bbox, size, mode="xyxy")
            bbox.add_field("box3d", bbox_3d)
            bbox.add_field("box3d_score", scores_selected)
            bbox.add_field("labels", pred_classes_selected)
            bbox.add_field('iou_score', scores_selected)
            bbox.add_field('random', torch.zeros((len(bbox))).long())
            results.append(bbox)
        return results
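
For the single-logit head, the scoring branch reduces to a sigmoid followed by a per-image boolean mask. A toy, self-contained illustration of that gating (not repo code; the threshold value is chosen arbitrarily):

import torch

raw_scores = torch.tensor([[[2.0], [-1.0], [0.5]]])   # (B=1, M=3, 1)
norm_scores = torch.sigmoid(raw_scores)
inds = norm_scores > 0.3                               # SCORE_THRESH stand-in
cur_inds = inds[0].view(-1)                            # image 0
print(cur_inds)                            # tensor([ True, False,  True])
print(raw_scores[0, cur_inds].squeeze(1))  # tensor([2.0000, 0.5000])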
Example #12
    def forward(self,
                anchors,
                objectness,
                box_regression,
                left_targets=None,
                right_targets=None):
        device = objectness[0].device
        scores = []
        for i, score in enumerate(objectness):
            scores.append(
                score.permute(0, 2, 3,
                              1).contiguous().view(score.shape[0], -1, 2))
        scores = torch.cat(scores, 1)[:, :, 1]
        bbox_regs = []
        for i, bbox_reg in enumerate(box_regression):
            bbox_regs.append(
                bbox_reg.permute(0, 2, 3,
                                 1).contiguous().view(bbox_reg.shape[0], -1,
                                                      6))
        bbox_regs = torch.cat(bbox_regs, 1)
        anchors = list(zip(*anchors))
        combined_anchors = []
        batch_size = len(anchors[0])
        for i in range(batch_size):
            combined_anchors.append(
                cat_boxlist(
                    [anchors[level][i] for level in range(len(anchors))]))
        num_anchors = len(combined_anchors[0])
        # pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)
        # scores, topk_idx = scores.topk(min(pre_nms_top_n, scores.shape[1]), dim=1, sorted=True)

        # batch_idx = torch.arange(bsz, device=device)[:, None]
        # bbox_regs = bbox_regs[batch_idx, topk_idx]

        image_shapes = [box.size for box in combined_anchors]
        # concat_anchors = torch.cat([a.bbox for a in combined_anchors], dim=0)
        # concat_anchors = concat_anchors.reshape(bsz, -1, 4)[batch_idx, topk_idx]

        proposals = self.box_coder.decode(
            bbox_regs.view(-1, 6),
            torch.cat([a.bbox.view(-1, 4)
                       for a in combined_anchors]).to(device))

        proposals = proposals.view(batch_size, -1, 6)
        proposals_left = proposals[:, :, 0:4]
        proposals_right = proposals[:, :, [4, 1, 5, 3]]
        proposals_left = clip_boxes(proposals_left, image_shapes, batch_size)
        proposals_right = clip_boxes(proposals_right, image_shapes, batch_size)
        scores_keep = scores
        proposals_keep_left = proposals_left
        proposals_keep_right = proposals_right

        _, order = torch.sort(scores_keep, 1, True)

        left_result, right_result = [], []
        for i in range(batch_size):
            # # 3. remove predicted boxes with either height or width < threshold
            # # (NOTE: convert min_size to input image scale stored in im_info[2])
            proposals_single_left = proposals_keep_left[i]
            proposals_single_right = proposals_keep_right[i]
            scores_single = scores_keep[i]

            # # 4. sort all (proposal, score) pairs by score from highest to lowest
            # # 5. take top pre_nms_topN (e.g. 6000)
            order_single = order[i]

            if 0 < self.pre_nms_top_n < scores_keep.numel():
                order_single = order_single[:self.pre_nms_top_n]

            proposals_single_left = proposals_single_left[order_single, :]
            proposals_single_right = proposals_single_right[order_single, :]
            scores_single = scores_single[order_single].view(-1, 1)

            # 6. apply nms (e.g. threshold = 0.7)
            # 7. take after_nms_topN (e.g. 300)
            # 8. return the top proposals (-> RoIs top)
            left_boxlist = BoxList(proposals_single_left,
                                   image_shapes[i],
                                   mode="xyxy")
            right_boxlist = BoxList(proposals_single_right,
                                    image_shapes[i],
                                    mode='xyxy')

            left_boxlist.add_field("objectness", scores_single.squeeze(1))
            right_boxlist.add_field("objectness", scores_single.squeeze(1))
            left_boxlist = left_boxlist.clip_to_image(remove_empty=False)
            right_boxlist = right_boxlist.clip_to_image(remove_empty=False)
            left_boxlist = remove_small_boxes(left_boxlist, self.min_size)
            right_boxlist = remove_small_boxes(right_boxlist, self.min_size)
            left_boxlist, right_boxlist = double_view_boxlist_nms(
                left_boxlist,
                right_boxlist,
                self.nms_thresh,
                max_proposals=self.post_nms_top_n,
                score_field='objectness')
            left_result.append(left_boxlist)
            right_result.append(right_boxlist)
        return left_result, right_result
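
Before NMS, proposals for each image are reordered by descending objectness and truncated to pre_nms_top_n. A minimal standalone illustration of that step (not repo code):

import torch

scores = torch.tensor([[0.2, 0.9, 0.5, 0.7]])              # (B=1, num_props)
proposals = torch.arange(4 * 4, dtype=torch.float32).view(1, 4, 4)
_, order = torch.sort(scores, dim=1, descending=True)
pre_nms_top_n = 3
order_single = order[0][:pre_nms_top_n]                    # best 3 for image 0
print(order_single)                 # tensor([1, 3, 2])
print(proposals[0][order_single])   # proposals reordered by score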
Example #13
    def forward_for_single_feature_map(self, anchors, box_cls, box_regression):
        """
        Arguments:
            anchors: list[BoxList]
            box_cls: tensor of size N, A * C, H, W
            box_regression: tensor of size N, A * 4, H, W
        """
        device = box_cls.device
        N, _, H, W = box_cls.shape
        A = box_regression.size(1) // 4
        C = box_cls.size(1) // A

        # put in the same format as anchors
        box_cls = permute_and_flatten(box_cls, N, A, C, H, W)
        box_cls = box_cls.sigmoid()

        box_regression = permute_and_flatten(box_regression, N, A, 4, H, W)
        box_regression = box_regression.reshape(N, -1, 4)

        num_anchors = A * H * W

        candidate_inds = box_cls > self.pre_nms_thresh

        pre_nms_top_n = candidate_inds.view(N, -1).sum(1)
        pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)

        results = []
        for per_box_cls, per_box_regression, per_pre_nms_top_n, \
                per_candidate_inds, per_anchors in zip(
                    box_cls, box_regression, pre_nms_top_n, candidate_inds,
                    anchors):

            # Sort and select TopN
            # TODO most of this can be made out of the loop for
            # all images.
            # TODO:Yang: Not easy to do. Because the numbers of detections are
            # different in each image. Therefore, this part needs to be done
            # per image.
            per_box_cls = per_box_cls[per_candidate_inds]

            per_box_cls, top_k_indices = \
                    per_box_cls.topk(per_pre_nms_top_n, sorted=False)

            per_candidate_nonzeros = \
                    per_candidate_inds.nonzero()[top_k_indices, :]

            per_box_loc = per_candidate_nonzeros[:, 0]
            per_class = per_candidate_nonzeros[:, 1]
            per_class += 1

            detections = self.box_coder.decode(
                per_box_regression[per_box_loc, :].view(-1, 4),
                per_anchors.bbox[per_box_loc, :].view(-1, 4))

            boxlist = BoxList(detections, per_anchors.size, mode="xyxy")
            boxlist.add_field("labels", per_class)
            boxlist.add_field("scores", per_box_cls)
            boxlist = boxlist.clip_to_image(remove_empty=False)
            boxlist = remove_small_boxes(boxlist, self.min_size)
            results.append(boxlist)

        return results
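
The nonzero() call on the thresholded class-score matrix is what turns the per-level scores into (anchor, class) pairs: column 0 indexes the anchor whose box gets decoded and column 1 the class label (offset by +1 because index 0 would be background). A small, self-contained illustration (not repo code):

import torch

per_box_cls = torch.tensor([[0.02, 0.80],
                            [0.60, 0.01],
                            [0.03, 0.04]])             # (A=3 anchors, C=2 classes)
per_candidate_inds = per_box_cls > 0.05                # pre_nms_thresh stand-in
per_candidate_nonzeros = per_candidate_inds.nonzero()  # (anchor, class) pairs
per_box_loc = per_candidate_nonzeros[:, 0]             # tensor([0, 1])
per_class = per_candidate_nonzeros[:, 1] + 1           # tensor([2, 1])
print(per_box_loc, per_class)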