Example #1
    def prepare_iou_targets(self, proposals, box_regression, targets):
        concat_boxes = torch.cat([a.bbox for a in proposals], dim=0)
        boxes_per_image = [len(box) for box in proposals]
        bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS
        box_coder = BoxCoder(weights=bbox_reg_weights)

        # [N, 4 * num_classes]
        pred_boxes = box_coder.decode(
            box_regression.view(sum(boxes_per_image), -1), concat_boxes
        )
        pred_boxes = pred_boxes.split(boxes_per_image, dim=0)

        proposals = list(proposals)
        device = box_regression.device
        for pred_boxes_per_image, proposals_per_image, targets_per_image in zip(pred_boxes, proposals, targets):

            # NOTE: Steps here may generate some wrong indices for box regression.
            # However, these cases would be filtered in the loss by sampled_pos_inds_subset
            labels_per_image = proposals_per_image.get_field("labels")
            if self.cls_agnostic_bbox_reg:
                labels = labels_per_image.new_zeros(labels_per_image.shape)
                map_inds = 4 * labels[:, None] + torch.tensor([4, 5, 6, 7], device=device)
            else:
                map_inds = 4 * labels_per_image[:, None] + torch.tensor(
                    [0, 1, 2, 3], device=device)
            pred_boxes_per_image = torch.gather(pred_boxes_per_image, 1, map_inds)
            if pred_boxes_per_image.shape[0] < 1:
                matched_ious = proposals_per_image.get_field("ious")
                proposals_per_image.add_field("iou_pred_targets_final", matched_ious)
                continue

            pred_boxlist_per_image = BoxList(pred_boxes_per_image, proposals_per_image.size, mode='xyxy')
            # [target_num, pred_boxes_num]
            match_quality_matrix = boxlist_iou(targets_per_image, pred_boxlist_per_image)
            # [pred_boxes_num]
            matched_ious, _ = match_quality_matrix.max(dim=0)
            # matched_ious, matches = match_quality_matrix.max(dim=0)
            # background proposals (label == 0) get an IoU target of 1; as noted above,
            # they are filtered out in the loss
            sampled_neg_inds_subset = torch.nonzero(labels_per_image == 0).squeeze(1)
            matched_ious[sampled_neg_inds_subset] = 1
            proposals_per_image.add_field("iou_pred_targets_final", matched_ious)

        return proposals
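The `map_inds` gather above picks, for every proposal, the four columns of the [N, 4 * num_classes] per-class box array that belong to that proposal's class (or the fixed class-agnostic columns 4..7). A minimal standalone sketch of that column mapping, with made-up shapes and labels:

import torch

num_classes, N = 3, 2
per_class_boxes = torch.arange(N * 4 * num_classes).float().view(N, -1)  # [N, 4 * num_classes]
labels = torch.tensor([2, 1])                                 # class index per proposal (illustrative)
map_inds = 4 * labels[:, None] + torch.tensor([0, 1, 2, 3])   # columns 4*c .. 4*c+3
boxes_for_label = torch.gather(per_class_boxes, 1, map_inds)  # [N, 4], rows stay aligned with proposals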
Example #2
    def test_box_decoder(self):
        """ Match unit test UtilsBoxesTest.TestBboxTransformRandom in
            caffe2/operators/generate_proposals_op_util_boxes_test.cc
        """
        box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
        bbox = torch.from_numpy(
            np.array([
                [175.62031555, 20.91103172, 253.352005, 155.0145874],
                [169.24636841, 4.85241556, 228.8605957, 105.02092743],
                [181.77426147, 199.82876587, 192.88427734, 214.0255127],
                [174.36262512, 186.75761414, 296.19091797, 231.27906799],
                [22.73153877, 92.02596283, 135.5695343, 208.80291748],
            ]).astype(np.float32))

        deltas = torch.from_numpy(
            np.array([
                [0.47861834, 0.13992102, 0.14961673, 0.71495209],
                [0.29915856, -0.35664671, 0.89018666, 0.70815367],
                [-0.03852064, 0.44466892, 0.49492538, 0.71409376],
                [0.28052918, 0.02184832, 0.65289006, 1.05060139],
                [-0.38172557, -0.08533806, -0.60335309, 0.79052375],
            ]).astype(np.float32))

        gt_bbox = (np.array([
            [206.949539, -30.715202, 297.387665, 244.448486],
            [143.871216, -83.342888, 290.502289, 121.053398],
            [177.430283, 198.666245, 196.295273, 228.703079],
            [152.251892, 145.431564, 387.215454, 274.594238],
            [5.062420, 11.040955, 66.328903, 269.686218],
        ]).astype(np.float32))

        results = box_coder.decode(deltas, bbox)

        np.testing.assert_allclose(results.detach().numpy(),
                                   gt_bbox,
                                   atol=1e-4)
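The expected values above follow the standard Faster R-CNN box decoding with unit weights and the 1-pixel box-extent convention. A rough NumPy sketch of that transform (an illustration of the formula the test exercises, not the library code itself):

import numpy as np

def decode_boxes(deltas, boxes):
    """boxes: [N, 4] reference boxes (x1, y1, x2, y2); deltas: [N, 4] as (dx, dy, dw, dh)."""
    widths = boxes[:, 2] - boxes[:, 0] + 1.0
    heights = boxes[:, 3] - boxes[:, 1] + 1.0
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights
    pred_ctr_x = deltas[:, 0] * widths + ctr_x     # shift the center by width-relative dx
    pred_ctr_y = deltas[:, 1] * heights + ctr_y
    pred_w = np.exp(deltas[:, 2]) * widths         # scale width by exp(dw)
    pred_h = np.exp(deltas[:, 3]) * heights
    return np.stack([pred_ctr_x - 0.5 * pred_w,
                     pred_ctr_y - 0.5 * pred_h,
                     pred_ctr_x + 0.5 * pred_w - 1,
                     pred_ctr_y + 0.5 * pred_h - 1], axis=1)

# decode_boxes(deltas.numpy(), bbox.numpy()) should reproduce gt_bbox up to float precision.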
Example #3
class NewROIBoxHead(ROIBoxHead):
    def __init__(self, cfg, in_channels):
        super(NewROIBoxHead, self).__init__(cfg, in_channels)
        self.bbox_dict = dict(bbox=None, target=None)

        self.box_coder = BoxCoder(weights=cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS)

    def forward(self, features, proposals, targets=None):
        if self.training:
            # during training, Faster R-CNN subsamples the proposals with a
            # fixed positive / negative ratio
            with torch.no_grad():
                proposals = self.loss_evaluator.subsample(proposals, targets)

        # extract features that will be fed to the final classifier. The
        # feature_extractor generally corresponds to the pooler + heads
        x = self.feature_extractor(features, proposals)
        # final classifier that converts the features into predictions
        class_logits, box_regression = self.predictor(x)

        # save bbox result for rois_gan
        if self.training:
            result, box_result = self.reduced_bbox_result([box_regression],
                                                          proposals)
            self.bbox_dict.update(bbox=result)
            self.bbox_dict.update(target=box_result)

        if not self.training:
            result = self.post_processor((class_logits, box_regression),
                                         proposals)
            return x, result, {}

        loss_classifier, loss_box_reg = self.loss_evaluator([class_logits],
                                                            [box_regression])
        return (
            x,
            proposals,
            dict(loss_classifier=loss_classifier, loss_box_reg=loss_box_reg),
        )

    def reduced_bbox_result(self, box_regression, proposals):

        box_regression = cat(box_regression, dim=0)
        device = box_regression.device

        labels = cat([proposal.get_field("labels") for proposal in proposals],
                     dim=0)
        regression_targets = cat([
            proposal.get_field("regression_targets") for proposal in proposals
        ],
                                 dim=0)

        sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1)
        labels_pos = labels[sampled_pos_inds_subset]

        map_inds = 4 * labels_pos[:, None] + torch.tensor([0, 1, 2, 3],
                                                          device=device)

        image_shapes = [box.size for box in proposals]
        boxes_per_image = [len(box) for box in proposals]
        concat_boxes = torch.cat([a.bbox for a in proposals], dim=0)

        prefix_sum_boxes = [boxes_per_image[0]]
        for box_per_images in boxes_per_image[1:]:
            prefix_sum_boxes.append(box_per_images + prefix_sum_boxes[-1])

        reduced_boxes_per_image = [0] * len(prefix_sum_boxes)
        i, j = 0, 0
        while i < len(sampled_pos_inds_subset):
            if sampled_pos_inds_subset[i] < prefix_sum_boxes[j]:
                reduced_boxes_per_image[j] += 1
                i += 1
            else:
                j += 1

        proposals = self.box_coder.decode(
            box_regression[sampled_pos_inds_subset[:, None], map_inds],
            concat_boxes[sampled_pos_inds_subset])

        proposals = proposals.split(reduced_boxes_per_image, dim=0)

        box_targets = self.box_coder.decode(
            regression_targets[sampled_pos_inds_subset],
            concat_boxes[sampled_pos_inds_subset])

        box_targets = box_targets.split(reduced_boxes_per_image, dim=0)

        result = []
        for boxes, image_shape in zip(proposals, image_shapes):
            boxlist = BoxList(boxes, image_shape, mode="xyxy")
            boxlist = boxlist.clip_to_image(remove_empty=False)
            result.append(boxlist)

        box_result = []
        for boxes, image_shape in zip(box_targets, image_shapes):
            boxlist = BoxList(boxes, image_shape, mode="xyxy")
            boxlist = boxlist.clip_to_image(remove_empty=False)
            box_result.append(boxlist)

        return result, box_result
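The prefix-sum loop in reduced_bbox_result only counts how many of the sampled positive indices fall into each image's slice of the concatenated proposals. A shorter equivalent sketch (an assumption for illustration, not the project's code) using torch.bucketize:

import torch

boxes_per_image = [3, 5, 2]                                        # illustrative counts
sampled_pos_inds_subset = torch.tensor([0, 2, 4, 7, 9])            # indices into the concatenated boxes
boundaries = torch.cumsum(torch.tensor(boxes_per_image), dim=0)    # [3, 8, 10]
image_ids = torch.bucketize(sampled_pos_inds_subset, boundaries, right=True)
reduced_boxes_per_image = torch.bincount(
    image_ids, minlength=len(boxes_per_image)).tolist()            # [2, 2, 1]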
Example #4
class DuplicationRemovalNetwork(nn.Module):
    def __init__(
        self,
        cfg,
        is_teacher=False,
    ):
        super(DuplicationRemovalNetwork, self).__init__()
        self.cfg = cfg.clone()
        # if reg_iou = True, then this network is used to regress
        # the iou to the GT. if not True, this predict
        # true-object/duplicate
        self.reg_iou = self.cfg.MODEL.RELATION_NMS.REG_IOU
        self.first_n = cfg.MODEL.RELATION_NMS.FIRST_N
        self.NMS_thread = cfg.MODEL.RELATION_NMS.THREAD
        self.nms_rank_fc = nn.Linear(
            cfg.MODEL.RELATION_NMS.ROI_FEAT_DIM,
            cfg.MODEL.RELATION_NMS.APPEARANCE_FEAT_DIM,
            bias=True)
        self.roi_feat_embedding_fc = nn.Linear(
            cfg.MODEL.RELATION_NMS.ROI_FEAT_DIM,
            cfg.MODEL.RELATION_NMS.APPEARANCE_FEAT_DIM,
            bias=True)
        self.target_thresh = cfg.MODEL.RELATION_NMS.THREAD
        self.geo_feature_dim = cfg.MODEL.RELATION_NMS.GEO_FEAT_DIM

        if cfg.MODEL.RELATION_NMS.USE_IOU:
            self.geo_feature_dim = int(self.geo_feature_dim / 4 * 5)
        self.relation_module = RelationModule(
            cfg.MODEL.RELATION_NMS.APPEARANCE_FEAT_DIM,
            geo_feature_dim=self.geo_feature_dim,
            fc_dim=(self.geo_feature_dim, 16),
            group=cfg.MODEL.RELATION_NMS.GROUP,
            dim=cfg.MODEL.RELATION_NMS.HID_DIM,
            topk=cfg.MODEL.RELATION_NMS.TOPK,
            iou_method=cfg.MODEL.RELATION_NMS.IOU_METHOD)

        self.nms_fg_weight = torch.tensor([1., cfg.MODEL.RELATION_NMS.WEIGHT])
        self.mt_fg_weight = torch.tensor([1., 10.])
        self.alpha = cfg.MODEL.RELATION_NMS.ALPHA
        self.gamma = cfg.MODEL.RELATION_NMS.GAMMA
        self.boxcoder = BoxCoder(weights=(10., 10., 5., 5.))
        self.class_agnostic = cfg.MODEL.RELATION_NMS.CLASS_AGNOSTIC
        self.fg_class = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES - 1
        self.classifier = nn.Linear(128, len(self.target_thresh), bias=True)
        self.relu1 = nn.ReLU(inplace=True)
        self.fg_thread = cfg.MODEL.RELATION_NMS.FG_THREAD
        self.detections_per_img = cfg.MODEL.ROI_HEADS.DETECTIONS_PER_IMG
        self.nms = cfg.MODEL.RELATION_NMS.POS_NMS
        self.nms_loss_type = cfg.MT.NMS_LOSS_TYPE
        self.mode = None

    def set_teacher_mode(self, mode):
        self.mode = mode

    def forward(self, x):
        appearance_feature, proposals, cls_score, box_reg, targets = x
        self.device = appearance_feature.device

        with torch.no_grad():
            sorted_boxlists = self.prepare_ranking(cls_score,
                                                   box_reg,
                                                   proposals,
                                                   targets,
                                                   reg_iou=self.reg_iou)
        # concatenate values from different images
        boxes_per_image = [len(f) for f in proposals]
        idxs = [f.get_field('sorted_idx') for f in sorted_boxlists]
        scores = torch.cat([f.get_field('scores') for f in sorted_boxlists])
        bboxes = torch.cat(
            [f.bbox.reshape(-1, self.fg_class, 4) for f in sorted_boxlists])
        objectness = torch.cat([
            f.get_field('objectness').reshape(-1, self.fg_class)
            for f in sorted_boxlists
        ])
        all_scores = torch.cat(
            [f.get_field('all_scores') for f in sorted_boxlists])

        # add iou information
        image_sizes = [f.size for f in sorted_boxlists]
        sorted_boxes_per_image = [f.shape[0] for f in idxs]
        appearance_feature = self.roi_feat_embedding_fc(appearance_feature)
        appearance_feature = appearance_feature.split(boxes_per_image, dim=0)
        sorted_features = []
        nms_rank_embedding = []
        for id, feature, box_per_image in zip(idxs, appearance_feature,
                                              boxes_per_image):
            feature = feature[id]
            size = feature.size()
            if size[0] <= self.first_n:
                first_n = size[0]
            else:
                first_n = self.first_n
            sorted_features.append(feature)
            #[rank_dim * batch , feat_dim]
            nms_rank_embedding.append(
                extract_rank_embedding(
                    first_n,
                    self.cfg.MODEL.RELATION_NMS.ROI_FEAT_DIM,
                    device=feature.device))
        #  [first_n * batchsize, num_fg_classes, 128]
        sorted_features = torch.cat(sorted_features, dim=0)
        nms_rank_embedding = torch.cat(nms_rank_embedding, dim=0)
        nms_rank_embedding = self.nms_rank_fc(nms_rank_embedding)
        sorted_features = sorted_features + nms_rank_embedding[:, None, :]

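        # NOTE: the IoU matrices below assume exactly two foreground classes (self.fg_class == 2)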
        boxes_cls_1 = BoxList(bboxes[:, 0, :], image_sizes[0])
        boxes_cls_2 = BoxList(bboxes[:, 1, :], image_sizes[0])
        iou_1 = boxlist_iou(boxes_cls_1, boxes_cls_1)
        iou_2 = boxlist_iou(boxes_cls_2, boxes_cls_2)
        if self.cfg.MODEL.RELATION_NMS.USE_IOU:
            iou = [iou_1, iou_2]
        else:
            iou = None
        nms_position_matrix = extract_multi_position_matrix(
            bboxes,
            None,
            self.geo_feature_dim,
            1000,
            clswise=self.cfg.MODEL.RELATION_NMS.CLS_WISE_RELATION,
        )
        nms_attention_1 = self.relation_module(sorted_features,
                                               nms_position_matrix, iou)
        sorted_features = sorted_features + nms_attention_1
        sorted_features = self.relu1(sorted_features)
        # [first_n * num_fg_classes, 128]
        sorted_features = sorted_features.view(
            -1, self.cfg.MODEL.RELATION_NMS.APPEARANCE_FEAT_DIM)
        sorted_features = self.classifier(sorted_features)
        # logit_reshape, [first_n, num_fg_classes, num_thread]
        sorted_features = sorted_features.view(-1, self.fg_class,
                                               len(self.target_thresh))
        if not self.reg_iou:
            sorted_features = torch.sigmoid(sorted_features)
        scores = torch.cat([scores[:, :, None]] * len(self.target_thresh),
                           dim=-1)
        loss_dict = {}
        if self.training:
            if self.reg_iou:
                # when regressing IoU, do not multiply sorted_features by scores
                reg_label = torch.cat(
                    [f.get_field('labels_iou_reg') for f in sorted_boxlists])
                reg_label = reg_label.to(scores.device)
                reg_label = reg_label.type(torch.cuda.FloatTensor)
                sorted_features = sorted_features.to(scores.device)
                sorted_features = sorted_features.type(torch.cuda.FloatTensor)
                if reg_label.numel() > 0:
                    reg_iou_loss = F.mse_loss(reg_label, sorted_features)
                else:
                    reg_iou_loss = torch.tensor(0.).to(scores.device)
                loss_dict['nms_loss'] = reg_iou_loss
            else:
                sorted_features = scores * sorted_features
                labels = torch.cat(
                    [f.get_field('labels') for f in sorted_boxlists])

                labels = labels.to(scores.device)
                labels = labels.type(torch.cuda.FloatTensor)

                # WEIGHTED NMS
                nms_loss = F.binary_cross_entropy(scores * sorted_features,
                                                  labels)
                loss_dict['nms_loss'] = nms_loss
            return None, loss_dict
        else:
            input_scores = scores
            if self.reg_iou:
                scores = sorted_features * (scores > self.fg_thread).float()
            else:
                scores = sorted_features * scores
            scores = self.merge_multi_thread_score_test(scores)
            scores = scores.split(sorted_boxes_per_image, dim=0)
            bboxes = bboxes.split(sorted_boxes_per_image, dim=0)
            input_scores = input_scores.split(sorted_boxes_per_image, dim=0)
            objectness = objectness.split(sorted_boxes_per_image, dim=0)
            all_scores = all_scores.split(sorted_boxes_per_image, dim=0)
            result = []
            for i_score, score, bbox, obj, image_size, prob_boxhead in zip(
                    input_scores, scores, bboxes, objectness, image_sizes,
                    all_scores):
                result_per_image = []
                # for nuclei
                index = (score[:, 1] >= self.fg_thread).nonzero()[:, 0]
                # cls_scores = i_score[index, i,0]
                cls_scores = score[index, 1]
                cls_scores_all = prob_boxhead[index, 1]
                cls_boxes = bbox[index, 1, :]
                cls_obj = obj[index, 1]

                boxlist_for_class = BoxList(cls_boxes, image_size, mode='xyxy')

                boxlist_for_class.add_field('scores', cls_scores)
                boxlist_for_class.add_field('objectness', cls_obj)
                boxlist_for_class.add_field('all_scores', cls_scores_all)
                boxlist_for_class = boxlist_nms(boxlist_for_class,
                                                0.5,
                                                score_field="scores")
                num_labels = len(boxlist_for_class)
                boxlist_for_class.add_field(
                    "labels",
                    torch.full((num_labels, ), 2,
                               dtype=torch.int64).to(self.device))
                result_per_image.append(boxlist_for_class)
                index = (score[:, 0] >= self.fg_thread).nonzero()[:, 0]
                # cls_scores = i_score[index, i,0]
                cls_scores = score[index, 0]
                # pdb.set_trace()

                cls_scores_all = prob_boxhead[index, 0]
                cls_boxes = bbox[index, 0, :]
                cls_obj = obj[index, 0]

                boxlist_for_class = BoxList(cls_boxes, image_size, mode='xyxy')
                # Pos greedy NMS if POS_NMS!=-1
                # boxlist_for_class.add_field('idx', index)
                boxlist_for_class.add_field('scores', cls_scores)
                boxlist_for_class.add_field('objectness', cls_obj)
                boxlist_for_class.add_field('all_scores', cls_scores_all)
                # pdb.set_trace()
                if self.nms:
                    # for nuclei
                    boxlist_for_class = boxlist_nms(boxlist_for_class,
                                                    self.nms,
                                                    score_field="scores")
                # pdb.set_trace()
                num_labels = len(boxlist_for_class)
                boxlist_for_class.add_field(
                    "labels",
                    torch.full((num_labels, ), 1,
                               dtype=torch.int64).to(self.device))
                result_per_image.append(boxlist_for_class)
                result_per_image = cat_boxlist(result_per_image)
                number_of_detections = len(result_per_image)

                # Limit to max_per_image detections **over all classes**
                if number_of_detections > self.detections_per_img > 0:
                    cls_scores = result_per_image.get_field("scores")
                    image_thresh, _ = torch.kthvalue(
                        cls_scores.cpu(),
                        number_of_detections - self.detections_per_img + 1)
                    keep = cls_scores >= image_thresh.item()
                    keep = torch.nonzero(keep).squeeze(1)
                    result_per_image = result_per_image[keep]
                result.append(result_per_image)

            return result, {}

    def prepare_reg_label(self, sorted_boxes, sorted_score, targets):
        '''
        :param sorted_boxes: [first_n, fg_cls_num, 4]
        :param sorted_score: [first_n, fg_cls_num]
        :param targets: BoxList with ground-truth boxes and a 'labels' field
        :return: regression targets, [first_n, fg_cls_num, num_thresh]
        '''
        TO_REMOVE = 1
        labels = targets.get_field('labels')

        # output = np.zeros((sorted_boxes.shape[0].numpy(),))
        # pdb.set_trace()
        # output_list = []
        output_reg_list = []
        for i in range(self.fg_class):
            cls_label_indice = torch.nonzero(labels == (i + 1))
            cls_target_bbox = targets.bbox[cls_label_indice[:, 0]]

            # todo: avoid None gt situation
            num_valid_gt = len(cls_label_indice)

            if num_valid_gt == 0:

                output = np.zeros(
                    (sorted_boxes.shape[0], len(self.target_thresh)))
                # output_reg = output.copy()
                # output_list.append(output)
                output_reg_list.append(output)
            else:
                output_list_per_class = []
                output_reg_list_per_class = []
                eye_matrix = np.eye(num_valid_gt)
                score_per_class = sorted_score[:, i:i + 1].cpu().numpy()
                boxes = sorted_boxes[:, i, :]
                boxes = boxes.view(-1, 4)
                area1 = (boxes[:, 2] - boxes[:, 0] +
                         TO_REMOVE) * (boxes[:, 3] - boxes[:, 1] + TO_REMOVE)
                area2 = (cls_target_bbox[:, 2] - cls_target_bbox[:, 0] +
                         TO_REMOVE) * (cls_target_bbox[:, 3] -
                                       cls_target_bbox[:, 1] + TO_REMOVE)
                lt = torch.max(boxes[:, None, :2],
                               cls_target_bbox[:, :2])  # [N,M,2]
                rb = torch.min(boxes[:, None, 2:],
                               cls_target_bbox[:, 2:])  # [N,M,2]
                wh = (rb - lt + TO_REMOVE).clamp(min=0)  # [N,M,2]
                inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
                # [first_n, num_gt]
                iou = inter / (area1[:, None] + area2 - inter)
                iou = iou.cpu().numpy()
                try:
                    for thresh in self.target_thresh:
                        # pdb.set_trace()
                        output_reg = np.max(iou, 1)
                        # todo: temp comment
                        overlap_mask = (iou > thresh)
                        overlap_iou = iou * overlap_mask
                        valid_bbox_indices = np.where(overlap_mask)[0]
                        overlap_score = np.tile(score_per_class,
                                                (1, num_valid_gt))
                        overlap_score *= overlap_mask
                        max_overlap_indices = np.argmax(iou, axis=1)
                        max_overlap_mask = eye_matrix[max_overlap_indices]
                        overlap_score *= max_overlap_mask
                        overlap_iou = overlap_iou * max_overlap_mask
                        max_score_indices = np.argmax(overlap_score, axis=0)
                        max_overlap_iou = overlap_iou[
                            max_score_indices,
                            np.arange(overlap_score.shape[1])]
                        # output = np.zeros(([*sorted_boxes.shape][0],))
                        output_reg = np.zeros((sorted_boxes.shape[0], ))
                        output_idx, inter_1, inter_2 = np.intersect1d(
                            max_score_indices,
                            valid_bbox_indices,
                            return_indices=True)
                        # output[output_idx] = 1
                        output_reg[output_idx] = max_overlap_iou[inter_1]
                        # output_list_per_class.append(output)
                        output_reg_list_per_class.append(output_reg)
                except:
                    pdb.set_trace()
                # output_per_class = np.stack(output_list_per_class, axis=-1)
                output_reg_per_class = np.stack(output_reg_list_per_class,
                                                axis=-1)
                # pdb.set_trace()
                # output_list.append(output_per_class.view())
                output_reg_list.append(output_reg_per_class)

        # output =  np.stack(output_list, axis=1).astype(np.float32, copy=False)
        output_reg = np.stack(output_reg_list, axis=1).astype(np.float32,
                                                              copy=False)
        return output_reg
        # return (output, output_reg)

    def prepare_label(self, sorted_boxes, sorted_score, targets):
        '''
        :param sorted_boxes: [first_n, fg_cls_num, 4]
        :param sorted_score: [first_n, fg_cls_num]
        :param targets: BoxList with ground-truth boxes and a 'labels' field
        :return: binary labels, [first_n, fg_cls_num, num_thresh]
        '''
        TO_REMOVE = 1
        labels = targets.get_field('labels')

        # output = np.zeros((sorted_boxes.shape[0].numpy(),))

        output_list = []
        for i in range(self.fg_class):
            cls_label_indice = torch.nonzero(labels == (i + 1))
            cls_target_bbox = targets.bbox[cls_label_indice[:, 0]]

            # todo: avoid None gt situation
            num_valid_gt = len(cls_label_indice)

            if num_valid_gt == 0:

                output = np.zeros(
                    (sorted_boxes.shape[0], len(self.target_thresh)))
                output_list.append(output)
            else:
                output_list_per_class = []
                eye_matrix = np.eye(num_valid_gt)
                score_per_class = sorted_score[:, i:i + 1].cpu().numpy()
                boxes = sorted_boxes[:, i, :]
                boxes = boxes.view(-1, 4)
                area1 = (boxes[:, 2] - boxes[:, 0] +
                         TO_REMOVE) * (boxes[:, 3] - boxes[:, 1] + TO_REMOVE)
                area2 = (cls_target_bbox[:, 2] - cls_target_bbox[:, 0] +
                         TO_REMOVE) * (cls_target_bbox[:, 3] -
                                       cls_target_bbox[:, 1] + TO_REMOVE)
                lt = torch.max(boxes[:, None, :2],
                               cls_target_bbox[:, :2])  # [N,M,2]
                rb = torch.min(boxes[:, None, 2:],
                               cls_target_bbox[:, 2:])  # [N,M,2]
                wh = (rb - lt + TO_REMOVE).clamp(min=0)  # [N,M,2]
                inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
                # [first_n, num_gt]
                iou = inter / (area1[:, None] + area2 - inter)
                iou = iou.cpu().numpy()

                for thresh in self.target_thresh:
                    overlap_mask = (iou > thresh)
                    valid_bbox_indices = np.where(overlap_mask)[0]
                    overlap_score = np.tile(score_per_class, (1, num_valid_gt))
                    overlap_score *= overlap_mask
                    max_overlap_indices = np.argmax(iou, axis=1)
                    max_overlap_mask = eye_matrix[max_overlap_indices]
                    overlap_score *= max_overlap_mask
                    max_score_indices = np.argmax(overlap_score, axis=0)
                    output = np.zeros((sorted_boxes.shape[0], ))
                    output[np.intersect1d(max_score_indices,
                                          valid_bbox_indices)] = 1
                    output_list_per_class.append(output)
                output_per_class = np.stack(output_list_per_class, axis=-1)
                output_list.append(output_per_class)
        output = np.stack(output_list, axis=1).astype(np.float32, copy=False)
        return output

    def prepare_ranking(self,
                        cls_score,
                        box_regression,
                        proposals,
                        targets,
                        reg_iou=False):
        '''
        :param cls_score: [num_per_img * batchsize, num_classes]
        :param box_regression: box regression deltas for the same rois
        :param proposals: list of BoxList
        :return: list of sorted BoxList, one per image
        '''
        # at test time, targets is None (no ground truth is available)

        boxes_per_image = [len(box) for box in proposals]
        concat_boxes = torch.cat([a.bbox for a in proposals], dim=0)
        image_shapes = [box.size for box in proposals]
        objectness = [f.get_field('objectness') for f in proposals]
        proposals = self.boxcoder.decode(
            box_regression.view(sum(boxes_per_image), -1), concat_boxes)
        proposals = proposals.split(boxes_per_image, dim=0)
        cls_score = cls_score.split(boxes_per_image, dim=0)
        results = []
        if self.training:
            # if idx_t is None:
            for prob, boxes_per_img, image_shape, target, obj in zip(
                    cls_score, proposals, image_shapes, targets, objectness):

                boxlist = self.filter_results(boxes_per_img, target, prob,
                                              image_shape, self.fg_class + 1,
                                              obj, reg_iou)

                results.append(boxlist)
        else:
            # test do not have target
            for prob, boxes_per_img, image_shape, obj in zip(
                    cls_score, proposals, image_shapes, objectness):
                boxlist = self.filter_results(boxes_per_img,
                                              None,
                                              prob,
                                              image_shape,
                                              self.fg_class + 1,
                                              obj,
                                              reg_iou=reg_iou)
                results.append(boxlist)

        return results

    def filter_results(self,
                       boxes,
                       targets,
                       scores,
                       image_shape,
                       num_classes,
                       obj,
                       reg_iou=False):
        """return the sorted boxlist and sorted idx
        """
        # unwrap the boxlist to avoid additional overhead.
        # if we had multi-class NMS, we could perform this directly on the boxlist
        # boxes = boxlist.bbox.reshape(-1, num_classes * 4)
        #[n_roi, 4, cls]
        # boxes = boxlist.bbox.reshape(-1, 4, num_classes)

        boxes = boxes.reshape(-1, 4 * num_classes)
        scores = scores.reshape(-1, num_classes)
        # pdb.set_trace()
        if scores.shape[0] == 0:
            pdb.set_trace()
        cat_boxes = []
        for j in range(1, num_classes):
            # skip class 0, because it is the background class
            cls_boxes = boxes[:, j * 4:(j + 1) * 4]
            cat_boxes.append(cls_boxes)
        boxes = torch.cat([bbox[:, :, None] for bbox in cat_boxes], dim=2)
        # scores =  torch.cat([s for s in cat_score])
        scores = scores[:, 1:]
        ori_scores = scores
        num_roi = boxes.shape[0]
        if num_roi <= self.first_n:
            first_n = num_roi
            # pdb.set_trace()
        else:
            first_n = self.first_n

        sorted_scores, indices = torch.topk(scores,
                                            first_n,
                                            dim=0,
                                            largest=True,
                                            sorted=True)

        if obj.shape[0] < first_n:
            indices = indices[:obj.shape[0]]
            sorted_scores = sorted_scores[:obj.shape[0]]
        if indices.shape[1] != 2:
            pdb.set_trace()
        cp_s = ori_scores.clone().cpu().numpy()
        cp_o = obj.clone().cpu().numpy()
        box = boxes.clone().cpu().numpy()
        ori_scores = ori_scores[indices]
        sorted_obj = obj[indices]
        sorted_boxes = boxes[indices]

        if sorted_boxes.shape[0] == 0:
            pdb.set_trace()

        if self.class_agnostic:
            # [first_n, num_fg_class, 4]
            sorted_boxes = torch.squeeze(sorted_boxes, dim=-1)
        else:
            try:
                mask = torch.arange(0, num_classes - 1).to(device=self.device)
            except:
                pdb.set_trace()
            try:
                mask = mask.view(1, -1, 1, 1).expand(first_n, num_classes - 1,
                                                     4, 1)
            except:
                pdb.set_trace()
            sorted_boxes = torch.gather(sorted_boxes, dim=3,
                                        index=mask).squeeze(dim=3)
        if self.training:
            labels = self.prepare_label(sorted_boxes, sorted_scores, targets)
            labels_cls = torch.from_numpy(labels).to(sorted_scores.device)
            if reg_iou:
                labels_reg = self.prepare_reg_label(sorted_boxes,
                                                    sorted_scores, targets)
                labels_reg = torch.from_numpy(labels_reg).to(
                    sorted_scores.device)
        sorted_boxes = sorted_boxes.view(first_n * (num_classes - 1), -1)
        sorted_obj = sorted_obj.view(first_n * (num_classes - 1))
        boxlist = BoxList(
            sorted_boxes,
            image_shape,
            mode="xyxy",
        )
        boxlist.add_field('sorted_idx', indices)
        boxlist.add_field('objectness', sorted_obj)
        boxlist.extra_fields['scores'] = sorted_scores
        boxlist.extra_fields["all_scores"] = ori_scores
        # boxlist.extra_fields[""]
        if self.training:
            if reg_iou:
                boxlist.extra_fields['labels_iou_reg'] = labels_reg
            else:
                boxlist.extra_fields['labels'] = labels_cls
        boxlist = boxlist.clip_to_image(remove_empty=False)
        return boxlist

    def merge_multi_thread_score_test(self, scores):
        if self.cfg.MODEL.RELATION_NMS.MERGE_METHOD == -1:
            scores = torch.mean(scores, -1)
        elif self.cfg.MODEL.RELATION_NMS.MERGE_METHOD == -2:
            scores = torch.max(scores, -1)[0]  # keep the values, drop the argmax indices
        else:
            idx = self.cfg.MODEL.RELATION_NMS.MERGE_METHOD
            idx = min(max(idx, 0), len(self.target_thresh) - 1)
            scores = scores[:, :, idx]
        return scores
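The positive-label rule in prepare_label above boils down to: assign each sorted box to its best-overlapping GT, then for every GT mark as positive the single highest-scoring box whose IoU clears the threshold. A rough standalone NumPy restatement (illustrative names, not the project's API); prepare_reg_label follows the same matching but stores the matched IoU instead of a 0/1 label:

import numpy as np

def nms_positive_labels(iou, scores, thresh):
    """iou: [num_boxes, num_gt]; scores: [num_boxes]; returns 0/1 labels per box."""
    num_boxes, num_gt = iou.shape
    best_gt = np.argmax(iou, axis=1)                    # each box's best-overlapping GT
    assign = np.eye(num_gt)[best_gt]                    # one-hot assignment [num_boxes, num_gt]
    masked_scores = scores[:, None] * (iou > thresh) * assign
    best_box_per_gt = np.argmax(masked_scores, axis=0)  # highest-scoring box for each GT
    overlapping = np.where(iou > thresh)[0]             # boxes above the IoU threshold
    labels = np.zeros(num_boxes)
    labels[np.intersect1d(best_box_per_gt, overlapping)] = 1
    return labels

iou = np.array([[0.8, 0.1], [0.6, 0.0], [0.2, 0.9]])    # 3 boxes vs 2 GTs
scores = np.array([0.9, 0.95, 0.7])
nms_positive_labels(iou, scores, thresh=0.5)            # -> array([0., 1., 1.])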
Example #5
    def forward(self,
                features,
                proposals,
                targets=None,
                proposals_sampled=None):
        """
        Arguments:
            features (list[Tensor]): feature-maps from possibly several levels
            proposals (list[BoxList]): proposal boxes
            targets (list[BoxList], optional): the ground-truth targets.

        Returns:
            x (Tensor): the result of the feature extractor
            proposals (list[BoxList]): during training, the subsampled proposals
                are returned. During testing, the predicted boxlists are returned
            losses (dict[Tensor]): During training, returns the losses for the
                head. During testing, returns an empty dict.
        """

        if self.training:
            # during training, Faster R-CNN subsamples the proposals with a
            # fixed positive / negative ratio
            if proposals_sampled is None:
                with torch.no_grad():
                    proposals_sampled = self.loss_evaluator.subsample(
                        proposals, targets)
            proposals = proposals_sampled

        # extract features that will be fed to the final classifier. The
        # feature_extractor generally corresponds to the pooler + heads
        x = self.feature_extractor(features, proposals)
        # final classifier that converts the features into predictions
        class_logits, box_regression = self.predictor(x)

        if not self.training:
            result = self.post_processor((class_logits, box_regression),
                                         proposals)
            return x, result, {}

        # TODO: loss is not needed for mean teacher when MT_ON
        if not self.cfg.MODEL.ROI_BOX_HEAD.FREEZE_WEIGHT:
            loss_classifier, loss_box_reg = self.loss_evaluator(
                [class_logits], [box_regression], proposals)

        if self.cfg.MODEL.ROI_BOX_HEAD.OUTPUT_DECODED_PROPOSAL:
            bbox_reg_weights = self.cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS
            box_coder = BoxCoder(weights=bbox_reg_weights)
            boxes_per_image = [len(box) for box in proposals]
            concat_boxes = torch.cat([a.bbox for a in proposals], dim=0)
            decoded_proposals = box_coder.decode(
                box_regression.view(sum(boxes_per_image), -1), concat_boxes)
            decoded_proposals = decoded_proposals.split(boxes_per_image, dim=0)
            # decoded_proposals = self.post_processor((class_logits, box_regression), proposals)
            # make sure there are valid proposals
            for i, boxes in enumerate(decoded_proposals):
                if len(boxes) > 0:
                    proposals[i].bbox = boxes.reshape(-1, 4)

        loss_dict = dict()

        if self.cfg.MODEL.MT_ON:
            loss_dict.update(class_logits=class_logits,
                             box_logits=box_regression)
            # loss_dict.update(class_logits=x, box_logits=x)
            # proposals_sampled.add_field('class_logits', class_logits)
            # proposals_sampled.add_field('box_logits', box_regression)

        if not self.is_mt and not self.cfg.MODEL.ROI_BOX_HEAD.FREEZE_WEIGHT:
            loss_dict.update(
                dict(loss_classifier=loss_classifier,
                     loss_box_reg=loss_box_reg))

        return x, proposals, loss_dict
Example #6
class MaskRCNNLossComputation(object):
    def __init__(self, proposal_matcher, discretization_size, use_mil_loss,
                 use_aff, use_box_mask):
        """
        Arguments:
            proposal_matcher (Matcher)
            discretization_size (int)
        """
        self.proposal_matcher = proposal_matcher
        self.discretization_size = discretization_size
        center_weight = torch.zeros((3, 3))
        center_weight[1][1] = 1.
        aff_weights = []
        for i in range(3):
            for j in range(3):
                if i == 1 and j == 1:
                    continue
                weight = torch.zeros((3, 3))
                weight[i][j] = 1.
                aff_weights.append(center_weight - weight)
        aff_weights = [w.view(1, 1, 3, 3).to("cuda") for w in aff_weights]
        self.aff_weights = torch.cat(aff_weights, 0)
        self.box_coder = BoxCoder(weights=(10., 10., 5., 5.))
        self.use_mil_loss = use_mil_loss
        self.use_aff = use_aff
        if use_box_mask:
            assert not use_mil_loss
        self.use_box_mask = use_box_mask

    def match_targets_to_proposals(self, proposal, target):
        match_quality_matrix = boxlist_iou(target, proposal)
        matched_idxs = self.proposal_matcher(match_quality_matrix)
        # Mask R-CNN needs the "labels" and "masks" fields for creating the targets
        target = target.copy_with_fields(["labels", "masks"])
        # get the targets corresponding GT for each proposal
        # NB: need to clamp the indices because we can have a single
        # GT in the image, and matched_idxs can be -2, which goes
        # out of bounds
        matched_targets = target[matched_idxs.clamp(min=0)]
        matched_targets.add_field("matched_idxs", matched_idxs)
        return matched_targets

    def prepare_targets(self, proposals, targets):
        labels = []
        masks = []
        for proposals_per_image, targets_per_image in zip(proposals, targets):
            matched_targets = self.match_targets_to_proposals(
                proposals_per_image, targets_per_image)
            matched_idxs = matched_targets.get_field("matched_idxs")

            labels_per_image = matched_targets.get_field("labels")
            labels_per_image = labels_per_image.to(dtype=torch.int64)

            # this can probably be removed, but is left here for clarity
            # and completeness
            neg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD
            labels_per_image[neg_inds] = 0

            # mask scores are only computed on positive samples
            positive_inds = torch.nonzero(labels_per_image > 0).squeeze(1)

            segmentation_masks = matched_targets.get_field("masks")
            segmentation_masks = segmentation_masks[positive_inds]

            positive_proposals = proposals_per_image[positive_inds]

            masks_per_image = project_masks_on_boxes(segmentation_masks,
                                                     positive_proposals,
                                                     self.discretization_size)

            labels.append(labels_per_image)
            masks.append(masks_per_image)

        return labels, masks

    def prepare_targets_cr(self, proposals):
        # Sample both negative and positive proposals
        # Only with per col/row labels (without mask)
        labels = []
        for proposals_per_image in proposals:
            regression_target = proposals_per_image.get_field(
                "regression_targets")
            matched_bbox = self.box_coder.decode(regression_target,
                                                 proposals_per_image.bbox)
            M = self.discretization_size
            pos_masks_per_image = torch.ones(len(proposals_per_image),
                                             M,
                                             M,
                                             dtype=torch.float32)
            not_matched_idx = (regression_target != 0).any(1)
            # if a box fully matched the proposal, we set all values to 1
            # otherwise, we project the box on proposal
            if not_matched_idx.any():
                pos_masks_per_image[not_matched_idx] = project_boxes_on_boxes(
                    matched_bbox[not_matched_idx],
                    proposals_per_image[not_matched_idx], M)

            pos_masks_per_image = pos_masks_per_image.cuda()
            pos_labels = torch.cat(
                [pos_masks_per_image.sum(2),
                 pos_masks_per_image.sum(1)], 1)
            pos_labels = (pos_labels > 0).float()

            labels.append(pos_labels)

        return labels

    def prepare_targets_boxes(self, proposals):
        masks = []
        for proposals_per_image in proposals:
            regression_target = proposals_per_image.get_field(
                "regression_targets")
            matched_bbox = self.box_coder.decode(regression_target,
                                                 proposals_per_image.bbox)
            M = self.discretization_size
            pos_masks_per_image = torch.ones(len(proposals_per_image),
                                             M,
                                             M,
                                             dtype=torch.float32)
            not_matched_idx = (regression_target != 0).any(1)
            # if a box fully matched the proposal, we set all values to 1
            # otherwise, we project the box on proposal
            if not_matched_idx.any():
                pos_masks_per_image[not_matched_idx] = project_boxes_on_boxes(
                    matched_bbox[not_matched_idx],
                    proposals_per_image[not_matched_idx], M)

            masks.append(pos_masks_per_image.cuda())
        return masks

    def __call__(self, proposals, all_mask_logits, targets):
        """
        Arguments:
            proposals (list[BoxList])
            all_mask_logits (list[Tensor])
            targets (list[BoxList])

        Return:
            mask_loss (Tensor): scalar tensor containing the loss
        """
        labels = cat([p.get_field("proto_labels") for p in proposals])
        labels = (labels > 0).long()
        pos_inds = torch.nonzero(labels > 0).squeeze(1)
        if not self.use_mil_loss:
            mask_logits = all_mask_logits[0]
            if self.use_box_mask:
                mask_targets = self.prepare_targets_boxes(proposals)
            else:
                _, mask_targets = self.prepare_targets(proposals, targets)
            mask_targets = cat(mask_targets, dim=0)
            labels_pos = labels[pos_inds]
            if mask_targets.numel() == 0:
                return mask_logits.sum() * 0
            mask_loss = F.binary_cross_entropy_with_logits(
                mask_logits[pos_inds, labels_pos], mask_targets[pos_inds])
            return mask_loss

        labels_cr = self.prepare_targets_cr(proposals)
        labels_cr = cat(labels_cr, dim=0)
        mil_losses = []
        for mask_logits in all_mask_logits:

            mil_score = mask_logits[:, 1]
            mil_score = torch.cat(
                [mil_score.max(2)[0], mil_score.max(1)[0]], 1)
            # torch.mean (in binary_cross_entropy_with_logits) doesn't
            # accept empty tensors, so handle it separately
            if mil_score.numel() == 0:
                mil_losses.append(mask_logits.sum() * 0)
                continue

            mil_loss = F.binary_cross_entropy_with_logits(
                mil_score[pos_inds], labels_cr[pos_inds])
            mil_losses.append(mil_loss)

        if self.use_aff:
            mask_logits = all_mask_logits[0]
            mask_logits_n = mask_logits[:, 1:].sigmoid()
            aff_maps = F.conv2d(mask_logits_n,
                                self.aff_weights,
                                padding=(1, 1))
            affinity_loss = mask_logits_n * (aff_maps**2)
            affinity_loss = torch.mean(affinity_loss)
            return 1.2 * sum(mil_losses) / len(
                mil_losses) + 0.05 * affinity_loss
        else:
            return sum(mil_losses) / len(mil_losses)
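For intuition: each of the eight aff_weights kernels built in __init__ computes p(center) - p(one 8-neighbour), so the conv2d in __call__ produces per-pixel differences to the neighbourhood, and the affinity term penalises large local jumps weighted by the center probability. A small standalone sketch (toy shapes, a single mask channel assumed):

import torch
import torch.nn.functional as F

center = torch.zeros(3, 3)
center[1, 1] = 1.
kernels = []
for i in range(3):
    for j in range(3):
        if i == 1 and j == 1:
            continue
        w = torch.zeros(3, 3)
        w[i, j] = 1.
        kernels.append((center - w).view(1, 1, 3, 3))
aff_weights = torch.cat(kernels, 0)              # [8, 1, 3, 3]

prob = torch.rand(1, 1, 28, 28)                  # a single-channel mask probability map
diffs = F.conv2d(prob, aff_weights, padding=1)   # [1, 8, 28, 28]: p(center) - p(neighbour)
penalty = (prob * diffs ** 2).mean()             # mirrors the affinity_loss term above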
Example #7
class PostProcessor(nn.Module):
    """
    From a set of classification scores, box regression and proposals,
    computes the post-processed boxes, and applies NMS to obtain the
    final results
    """
    def __init__(self, cfg):
        super(PostProcessor, self).__init__()

        bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS
        self.box_coder = BoxCoder(weights=bbox_reg_weights)

        self.score_thresh = cfg.MODEL.ROI_HEADS.SCORE_THRESH
        self.nms_thresh = cfg.MODEL.ROI_HEADS.NMS
        self.detections_per_img = cfg.MODEL.ROI_HEADS.DETECTIONS_PER_IMG
        self.cls_agnostic_bbox_reg = cfg.MODEL.CLS_AGNOSTIC_BBOX_REG
        self.bbox_aug_enabled = cfg.TEST.BBOX_AUG.ENABLED

        self.use_nms = not (cfg.MODEL.MASKIOU_ON
                            and "ROI_MASKIOU_HEAD" in cfg.MODEL
                            and cfg.MODEL.ROI_MASKIOU_HEAD.USE_NMS)

        if self.use_nms:
            self.nms_func = lambda boxes, scores: _box_nms(
                boxes, scores, self.nms_thresh)
            method = cfg.MODEL.ROI_HEADS.SOFT_NMS.METHOD
            self.use_soft_nms = cfg.MODEL.ROI_HEADS.USE_SOFT_NMS and method in [
                1, 2
            ]
            if self.use_soft_nms:

                def soft_nms_func(boxes, scores):
                    sigma = cfg.MODEL.ROI_HEADS.SOFT_NMS.SIGMA
                    score_thresh = cfg.MODEL.ROI_HEADS.SOFT_NMS.SCORE_THRESH
                    indices, keep, scores_new = _box_soft_nms(
                        boxes.cpu(),
                        scores.cpu(),
                        nms_thresh=self.nms_thresh,
                        sigma=sigma,
                        score_thresh=score_thresh,
                        method=method)
                    return indices, keep, scores_new

                self.nms_func = soft_nms_func
        else:
            self.detections_per_img = -1

    def forward(self, x, boxes):
        """
        Arguments:
            x (tuple[tensor, tensor]): x contains the class logits
                and the box_regression from the model.
            boxes (list[BoxList]): bounding boxes that are used as
                reference, one for each image

        Returns:
            results (list[BoxList]): one BoxList for each image, containing
                the extra fields labels and scores
        """
        class_logits, box_regression = x
        class_prob = F.softmax(class_logits, -1)

        # TODO think about a representation of batch of boxes
        image_shapes = [box.size for box in boxes]
        boxes_per_image = [len(box) for box in boxes]
        concat_boxes = torch.cat([a.bbox for a in boxes], dim=0)

        if self.cls_agnostic_bbox_reg:
            box_regression = box_regression[:, -4:]
        proposals = self.box_coder.decode(
            box_regression.view(sum(boxes_per_image), -1), concat_boxes)
        if self.cls_agnostic_bbox_reg:
            proposals = proposals.repeat(1, class_prob.shape[1])

        num_classes = class_prob.shape[1]

        proposals = proposals.split(boxes_per_image, dim=0)
        class_prob = class_prob.split(boxes_per_image, dim=0)

        results = []
        for prob, boxes_per_img, image_shape in zip(class_prob, proposals,
                                                    image_shapes):
            boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape)
            boxlist = boxlist.clip_to_image(remove_empty=False)
            if not self.bbox_aug_enabled:  # If bbox aug is enabled, we will do it later
                boxlist = self.filter_results(boxlist, num_classes)
            results.append(boxlist)
        return results

    def prepare_boxlist(self, boxes, scores, image_shape):
        """
        Returns BoxList from `boxes` and adds probability scores information
        as an extra field
        `boxes` has shape (#detections, 4 * #classes), where each row represents
        a list of predicted bounding boxes for each of the object classes in the
        dataset (including the background class). The detections in each row
        originate from the same object proposal.
        `scores` has shape (#detection, #classes), where each row represents a list
        of object detection confidence scores for each of the object classes in the
        dataset (including the background class). `scores[i, j]` corresponds to the
        box at `boxes[i, j * 4:(j + 1) * 4]`.
        """
        boxes = boxes.reshape(-1, 4)
        scores = scores.reshape(-1)
        boxlist = BoxList(boxes, image_shape, mode="xyxy")
        boxlist.add_field("scores", scores)
        return boxlist

    def filter_results(self, boxlist, num_classes):
        """Returns bounding-box detection results by thresholding on scores and
        applying non-maximum suppression (NMS).
        """
        score_field = "scores"
        # unwrap the boxlist to avoid additional overhead.
        # if we had multi-class NMS, we could perform this directly on the boxlist
        boxes = boxlist.convert("xyxy").bbox.reshape(-1, num_classes * 4)
        scores = boxlist.get_field(score_field).reshape(-1, num_classes)

        device = scores.device
        result = []
        # Apply threshold on detection probabilities and apply NMS
        # Skip j = 0, because it's the background class
        inds_all = scores > self.score_thresh
        for j in range(1, num_classes):
            inds = inds_all[:, j].nonzero().squeeze(1)
            scores_j = scores[inds, j]
            boxes_j = boxes[inds, j * 4:(j + 1) * 4]
            boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
            boxlist_for_class.add_field(score_field, scores_j)

            if self.use_nms:
                keep = self.nms_func(boxes_j, scores_j)
                if self.use_soft_nms:
                    indices, keep, scores_j_new = keep
                    boxlist_for_class = boxlist_for_class[indices]
                    boxlist_for_class.add_field(score_field,
                                                scores_j_new.to(device=device))

                boxlist_for_class = boxlist_for_class[keep]
            num_labels = len(boxlist_for_class)
            boxlist_for_class.add_field(
                "labels",
                torch.full((num_labels, ), j, dtype=torch.int64,
                           device=device))
            result.append(boxlist_for_class)

        result = cat_boxlist(result)
        number_of_detections = len(result)

        # Limit to max_per_image detections **over all classes**
        if number_of_detections > self.detections_per_img > 0:
            cls_scores = result.get_field(score_field)
            image_thresh, _ = torch.kthvalue(
                cls_scores.cpu(),
                number_of_detections - self.detections_per_img + 1)
            keep = cls_scores >= image_thresh.item()
            keep = torch.nonzero(keep).squeeze(1)
            result = result[keep]
        return result
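The "limit to max_per_image detections" block uses torch.kthvalue to turn a top-k selection into a score threshold: with n scores and a budget of k, the (n - k + 1)-th smallest score is the cutoff. A tiny sketch (made-up scores and k):

import torch

scores = torch.tensor([0.9, 0.3, 0.75, 0.6, 0.1])
k = 3                                                   # detections_per_img
thresh, _ = torch.kthvalue(scores, scores.numel() - k + 1)
keep = torch.nonzero(scores >= thresh).squeeze(1)       # tensor([0, 2, 3]) -> the 3 highest scores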
Example #8
class ROIBoxHead(torch.nn.Module):
    """
    Generic Box Head class.
    """
    def __init__(self, cfg, in_channels):
        super(ROIBoxHead, self).__init__()
        self.feature_extractor = make_roi_box_feature_extractor(
            cfg, in_channels)
        self.predictor = make_roi_box_predictor(
            cfg, self.feature_extractor.out_channels)
        self.post_processor = make_roi_box_post_processor(cfg)
        self.loss_evaluator = make_roi_box_loss_evaluator(cfg)
        self.box_coder = BoxCoder(weights=(10., 10., 5., 5.))
        self.cls_agnostic_bbox_reg = cfg.MODEL.CLS_AGNOSTIC_BBOX_REG

    def add_gt_proposals(self, proposals, targets):
        """
        Arguments:
            proposals: list[BoxList]
            targets: list[BoxList]
        """
        # Get the device we're operating on
        device = proposals[0].bbox.device

        gt_boxes = [target.copy_with_fields([]) for target in targets]

        # later cat of bbox requires all fields to be present for all bbox
        # so we need to add a dummy for objectness that's missing
        # for gt_box in gt_boxes:
        #     gt_box.add_field("objectness", torch.ones(len(gt_box), device=device))

        proposals = [
            cat_boxlist((proposal, gt_box))
            for proposal, gt_box in zip(proposals, gt_boxes)
        ]

        return proposals

    def box_regression_to_proposals(self,
                                    box_regression,
                                    boxes,
                                    is_train=True,
                                    class_logits=None):
        # TODO think about a representation of batch of boxes
        image_shapes = [box.size for box in boxes]
        boxes_per_image = [len(box) for box in boxes]
        concat_boxes = torch.cat([a.bbox for a in boxes], dim=0)

        # if self.cls_agnostic_bbox_reg:
        #     box_regression = box_regression[:, -4:]
        device = box_regression.device
        if is_train:
            labels = cat([proposal.get_field("labels") for proposal in boxes],
                         dim=0)
            sampled_pos_inds_subset = torch.nonzero(labels >= 0).squeeze(1)
            labels_pos = labels[sampled_pos_inds_subset]
            if self.cls_agnostic_bbox_reg:
                map_inds = torch.tensor([4, 5, 6, 7], device=device)
            else:
                map_inds = 4 * labels_pos[:, None] + torch.tensor(
                    [0, 1, 2, 3], device=device)

            box_regression = box_regression[sampled_pos_inds_subset[:, None],
                                            map_inds]
        else:
            if class_logits is not None:
                if self.cls_agnostic_bbox_reg:
                    bbox_label = torch.ones(len(class_logits), device=device)
                    map_inds = torch.tensor([4, 5, 6, 7], device=device)
                else:
                    bbox_label = class_logits.argmax(dim=1)
                    map_inds = 4 * bbox_label[:, None] + torch.tensor(
                        [0, 1, 2, 3], device=device)
                sampled_pos_inds_subset = torch.nonzero(
                    bbox_label >= 0).squeeze(1)
                box_regression = box_regression[sampled_pos_inds_subset[:,
                                                                        None],
                                                map_inds]

        proposals = self.box_coder.decode(
            box_regression.view(sum(boxes_per_image), -1), concat_boxes)
        # if self.cls_agnostic_bbox_reg:
        #     proposals = proposals.repeat(1, class_prob.shape[1])

        # split the decoded boxes back per image
        proposals = proposals.split(boxes_per_image, dim=0)
        result = []
        for proposal, box in zip(proposals, boxes):
            im_shape = box.size
            boxlist = BoxList(proposal, im_shape, mode="xyxy")
            result.append(boxlist)
        # if is_train:
        #     result = self.add_gt_proposals(result, boxes)
        return result
        # return proposals

    def forward(self, features, proposals, targets=None):
        """
        Arguments:
            features (list[Tensor]): feature-maps from possibly several levels
            proposals (list[BoxList]): proposal boxes
            targets (list[BoxList], optional): the ground-truth targets.

        Returns:
            x (Tensor): the result of the feature extractor
            proposals (list[BoxList]): during training, the subsampled proposals
                are returned. During testing, the predicted boxlists are returned
            losses (dict[Tensor]): During training, returns the losses for the
                head. During testing, returns an empty dict.
        """

        if self.training:
            # During training, Faster R-CNN subsamples the proposals with a
            # fixed positive / negative ratio
            with torch.no_grad():
                proposals = self.loss_evaluator.subsample(proposals,
                                                          targets,
                                                          stage=1)

            # extract features that will be fed to the final classifier. The
            # feature_extractor generally corresponds to the pooler + heads
            x = self.feature_extractor(features, proposals)
            # final classifier that converts the features into predictions
            class_logits, box_regression = self.predictor(x, stage=1)

            # 2nd stage
            with torch.no_grad():
                proposals_stage2 = self.box_regression_to_proposals(
                    box_regression, proposals)
                proposals_stage2 = self.add_gt_proposals(
                    proposals_stage2, targets)
                proposals_stage2 = self.loss_evaluator.subsample(
                    proposals_stage2, targets, stage=2)

            x = self.feature_extractor(features, proposals_stage2)
            class_logits_stage2, box_regression_stage2 = self.predictor(
                x, stage=2)

            # 3rd stage
            with torch.no_grad():
                proposals_stage3 = self.box_regression_to_proposals(
                    box_regression_stage2, proposals_stage2)
                proposals_stage3 = self.add_gt_proposals(
                    proposals_stage3, targets)
                proposals_stage3 = self.loss_evaluator.subsample(
                    proposals_stage3, targets, stage=3)

            x = self.feature_extractor(features, proposals_stage3)
            class_logits_stage3, box_regression_stage3 = self.predictor(
                x, stage=3)
        else:
            x = self.feature_extractor(features, proposals)
            # final classifier that converts the features into predictions
            class_logits, box_regression = self.predictor(x, stage=1)
            proposals_stage2 = self.box_regression_to_proposals(
                box_regression, proposals, False, class_logits)
            x = self.feature_extractor(features, proposals_stage2)
            class_logits_stage2, box_regression_stage2 = self.predictor(
                x, stage=2)
            proposals_stage3 = self.box_regression_to_proposals(
                box_regression_stage2, proposals_stage2, False,
                class_logits_stage2)
            x = self.feature_extractor(features, proposals_stage3)
            class_logits_stage3, box_regression_stage3 = self.predictor(
                x, stage=3)
            class_logits_stage2, box_regression_stage2 = self.predictor(
                x, stage=2)
            class_logits, box_regression = self.predictor(x, stage=1)

            class_logits_average = (class_logits + class_logits_stage2 +
                                    class_logits_stage3) / 3

        if not self.training:
            result = self.post_processor(
                (class_logits_average, box_regression_stage3),
                proposals_stage3)
            return x, result, {}

        loss_classifier, loss_box_reg = self.loss_evaluator([class_logits],
                                                            [box_regression],
                                                            proposals)
        loss_classifier_stage2, loss_box_reg_stage2 = self.loss_evaluator(
            [class_logits_stage2], [box_regression_stage2], proposals_stage2)
        loss_classifier_stage3, loss_box_reg_stage3 = self.loss_evaluator(
            [class_logits_stage3], [box_regression_stage3], proposals_stage3)
        return (x, proposals,
                dict(loss_classifier=loss_classifier,
                     loss_box_reg=loss_box_reg,
                     loss_classifier_stage2=loss_classifier_stage2,
                     loss_box_reg_stage2=loss_box_reg_stage2,
                     loss_classifier_stage3=loss_classifier_stage3,
                     loss_box_reg_stage3=loss_box_reg_stage3))
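The trickiest part of box_regression_to_proposals above is the column selection 4 * label + [0, 1, 2, 3]. Below is a self-contained sketch with hypothetical tensors (2 proposals, 3 classes) showing how that advanced-indexing pattern pulls out the four deltas belonging to each proposal's predicted class.

import torch

num_classes = 3
# 2 proposals, 4 deltas per class: columns [cls0_dx, cls0_dy, cls0_dw, cls0_dh, cls1_dx, ...]
box_regression = torch.arange(2 * 4 * num_classes, dtype=torch.float32).view(2, -1)
labels = torch.tensor([2, 1])  # hypothetical predicted class per proposal

rows = torch.arange(len(labels))[:, None]                    # shape [2, 1]
map_inds = 4 * labels[:, None] + torch.tensor([0, 1, 2, 3])  # shape [2, 4]
per_class_deltas = box_regression[rows, map_inds]
# row 0 picks columns 8..11 (class 2), row 1 picks columns 4..7 (class 1)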
class GAnchorGenerator(nn.Module):
    """
    For a set of image sizes and feature maps, computes a set
    of anchors
    """
    def __init__(self,
                 aspect_ratios=(0.5, 1.0, 2.0),
                 anchor_strides=(4, 8, 16, 32, 64),
                 straddle_thresh=0,
                 octave_base_scale=8,
                 scales_per_octave=3,
                 anchor_weights=(10., 10., 5., 5.),
                 loc_filter_thr=0.01):
        super(GAnchorGenerator, self).__init__()

        if len(anchor_strides) == 1:
            anchor_stride = anchor_strides[0]
            approx_sizes = [
                octave_base_scale * 2**(i * 1.0 / scales_per_octave) *
                anchor_stride for i in range(scales_per_octave)
            ]
            approx_anchors = [
                generate_anchors(anchor_stride, approx_sizes,
                                 aspect_ratios).float()
            ]
            square_anchors = [
                generate_anchors(anchor_stride,
                                 [octave_base_scale * anchor_stride],
                                 [1.0]).float()
            ]
        else:
            approx_anchors = [
                generate_anchors(anchor_stride, [
                    octave_base_scale * 2**(i * 1.0 / scales_per_octave) *
                    anchor_stride for i in range(scales_per_octave)
                ], aspect_ratios).float() for anchor_stride in anchor_strides
            ]
            square_anchors = [
                generate_anchors(
                    anchor_stride,
                    [octave_base_scale * anchor_stride],
                    [1.0],
                ).float() for anchor_stride in anchor_strides
            ]

        self.strides = anchor_strides
        self.approx_anchors = BufferList(approx_anchors)
        self.square_anchors = BufferList(square_anchors)
        self.straddle_thresh = straddle_thresh
        self.anchor_box_coder = BoxCoder(weights=anchor_weights)
        self.loc_filter_thr = loc_filter_thr

    def num_approx_anchors_per_location(self):
        return [len(approx_anchors) for approx_anchors in self.approx_anchors]

    def grid_square_anchors(self, grid_sizes):
        anchors = []
        for size, stride, base_anchors in zip(grid_sizes, self.strides,
                                              self.square_anchors):
            grid_height, grid_width = size
            device = base_anchors.device
            shifts_x = torch.arange(0,
                                    grid_width * stride,
                                    step=stride,
                                    dtype=torch.float32,
                                    device=device)
            shifts_y = torch.arange(0,
                                    grid_height * stride,
                                    step=stride,
                                    dtype=torch.float32,
                                    device=device)
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            anchors.append(
                (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(
                    -1, 4))

        return anchors

    def grid_approx_anchors(self, grid_sizes):
        anchors = []
        for size, stride, base_anchors in zip(grid_sizes, self.strides,
                                              self.approx_anchors):
            grid_height, grid_width = size
            device = base_anchors.device
            shifts_x = torch.arange(0,
                                    grid_width * stride,
                                    step=stride,
                                    dtype=torch.float32,
                                    device=device)
            shifts_y = torch.arange(0,
                                    grid_height * stride,
                                    step=stride,
                                    dtype=torch.float32,
                                    device=device)
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            anchors.append(
                (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(
                    -1, 4))

        return anchors

    def add_visibility_to(self, boxlist):
        image_width, image_height = boxlist.size
        anchors = boxlist.bbox
        if self.straddle_thresh >= 0:
            inds_inside = (
                (anchors[..., 0] >= -self.straddle_thresh)
                & (anchors[..., 1] >= -self.straddle_thresh)
                & (anchors[..., 2] < image_width + self.straddle_thresh)
                & (anchors[..., 3] < image_height + self.straddle_thresh))
        else:
            device = anchors.device
            inds_inside = torch.ones(anchors.shape[0],
                                     dtype=torch.uint8,
                                     device=device)
        boxlist.add_field("visibility", inds_inside)

    def forward(self, image_list, feature_maps, shapes, locs):
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]

        num_levels = len(feature_maps)

        square_anchors_over_all_feature_maps = self.grid_square_anchors(
            grid_sizes)

        square_anchors = []
        for i, (image_height,
                image_width) in enumerate(image_list.image_sizes):
            square_anchors_in_image = []
            for square_anchors_per_feature_map in square_anchors_over_all_feature_maps:
                boxlist = BoxList(square_anchors_per_feature_map,
                                  (image_width, image_height),
                                  mode="xyxy")
                self.add_visibility_to(boxlist)
                square_anchors_in_image.append(boxlist)
            square_anchors.append(square_anchors_in_image)

        guided_anchors = []
        loc_masks = []
        for img_id, (image_height,
                     image_width) in enumerate(image_list.image_sizes):
            guided_anchors_in_image = []
            loc_mask_in_image = []
            for i in range(num_levels):
                squares = square_anchors[img_id][i]
                shape_pred = shapes[i][img_id]
                loc_pred = locs[i][img_id]
                guide_anchors_single, loc_mask_single = self.get_guided_anchors(
                    squares,
                    shape_pred,
                    loc_pred,
                    use_loc_filter=not self.training)
                guide_anchors_single = BoxList(guide_anchors_single,
                                               (image_width, image_height),
                                               mode="xyxy")
                self.add_visibility_to(guide_anchors_single)
                guided_anchors_in_image.append(guide_anchors_single)
                loc_mask_in_image.append(loc_mask_single)
            guided_anchors.append(guided_anchors_in_image)
            loc_masks.append(loc_mask_in_image)

        return square_anchors, guided_anchors, loc_masks

    def get_guided_anchors(self,
                           squares,
                           shape_pred,
                           loc_pred,
                           use_loc_filter=False):
        loc_pred = loc_pred.sigmoid().detach()
        if use_loc_filter:
            loc_mask = loc_pred >= self.loc_filter_thr
        else:
            loc_mask = loc_pred >= 0
        mask = loc_mask.permute(1, 2, 0).expand(-1, -1, 1)
        mask = mask.contiguous().view(-1)

        squares = squares[mask]
        anchor_deltas = shape_pred.permute(1, 2, 0).contiguous().view(
            -1, 2).detach()[mask]
        bbox_deltas = anchor_deltas.new_full(squares.bbox.size(), 0)
        bbox_deltas[:, 2:] = anchor_deltas
        guided_anchors = self.anchor_box_coder.decode(bbox_deltas,
                                                      squares.bbox)
        return guided_anchors, mask

    def get_sampled_approxs(self, image_list, feature_maps):
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]

        approx_anchors_over_all_feature_maps = self.grid_approx_anchors(
            grid_sizes)
        approx_anchors = []
        for i, (image_height,
                image_width) in enumerate(image_list.image_sizes):
            approx_anchors_in_image = []
            for approx_anchors_per_feature_map in approx_anchors_over_all_feature_maps:
                boxlist = BoxList(approx_anchors_per_feature_map,
                                  (image_width, image_height),
                                  mode="xyxy")
                self.add_visibility_to(boxlist)
                approx_anchors_in_image.append(boxlist)
            approx_anchors.append(approx_anchors_in_image)

        return approx_anchors
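For reference, the decode step inside get_guided_anchors keeps each square anchor's center fixed and only rescales its width and height from the shape-branch output. The sketch below reproduces that behaviour with plain tensor arithmetic (the values are hypothetical, and the one-pixel box-size offsets applied by BoxCoder are ignored for brevity); the 5.0 divisor matches the w/h entries of the default anchor_weights=(10., 10., 5., 5.).

import torch

# two hypothetical square anchors in (x1, y1, x2, y2)
squares = torch.tensor([[0., 0., 8., 8.],
                        [8., 8., 24., 24.]])
# hypothetical shape-branch predictions: one (dw, dh) pair per anchor
anchor_deltas = torch.tensor([[0.2, -0.1],
                              [0.0,  0.5]])

ww = wh = 5.0  # w/h weights of the anchor box coder
ctr_x = (squares[:, 0] + squares[:, 2]) / 2
ctr_y = (squares[:, 1] + squares[:, 3]) / 2
w = (squares[:, 2] - squares[:, 0]) * torch.exp(anchor_deltas[:, 0] / ww)
h = (squares[:, 3] - squares[:, 1]) * torch.exp(anchor_deltas[:, 1] / wh)
# zero (dx, dy) deltas mean the centers do not move; only w/h are adjusted
guided = torch.stack([ctr_x - w / 2, ctr_y - h / 2,
                      ctr_x + w / 2, ctr_y + h / 2], dim=1)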