def prepare_iou_targets(self, proposals, box_regression, targets):
    concat_boxes = torch.cat([a.bbox for a in proposals], dim=0)
    boxes_per_image = [len(box) for box in proposals]
    bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS
    box_coder = BoxCoder(weights=bbox_reg_weights)
    # [N, 4 * num_classes]
    pred_boxes = box_coder.decode(
        box_regression.view(sum(boxes_per_image), -1), concat_boxes)
    pred_boxes = pred_boxes.split(boxes_per_image, dim=0)
    proposals = list(proposals)
    device = box_regression.device
    for pred_boxes_per_image, proposals_per_image, targets_per_image in zip(
            pred_boxes, proposals, targets):
        # NOTE: these steps may produce some wrong indices for box
        # regression; such cases are filtered out in the loss via
        # sampled_pos_inds_subset.
        labels_per_image = proposals_per_image.get_field("labels")
        if self.cls_agnostic_bbox_reg:
            # Class-agnostic deltas are stored at offsets 4:8.
            labels = labels_per_image.new_zeros(labels_per_image.shape)
            map_inds = 4 * labels[:, None] + torch.tensor(
                [4, 5, 6, 7], device=device)
        else:
            map_inds = 4 * labels_per_image[:, None] + torch.tensor(
                [0, 1, 2, 3], device=device)
        pred_boxes_per_image = torch.gather(pred_boxes_per_image, 1, map_inds)
        if pred_boxes_per_image.shape[0] < 1:
            matched_ious = proposals_per_image.get_field("ious")
            proposals_per_image.add_field("iou_pred_targets_final",
                                          matched_ious)
            continue
        pred_boxlist_per_image = BoxList(pred_boxes_per_image,
                                         proposals_per_image.size,
                                         mode='xyxy')
        # [target_num, pred_boxes_num]
        match_quality_matrix = boxlist_iou(targets_per_image,
                                           pred_boxlist_per_image)
        # [pred_boxes_num]
        matched_ious, _ = match_quality_matrix.max(dim=0)
        # Background proposals (label == 0) get an IoU target of 1; these
        # entries are excluded from the IoU loss by sampled_pos_inds_subset.
        sampled_neg_inds_subset = torch.nonzero(
            labels_per_image == 0).squeeze(1)
        matched_ious[sampled_neg_inds_subset] = 1
        proposals_per_image.add_field("iou_pred_targets_final", matched_ious)
    return proposals
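# A minimal, self-contained sketch (toy values, not part of the repo) of the
# map_inds gather used above: per-class regression outputs are laid out as
# [N, 4 * num_classes], and each proposal keeps only the 4 coordinates that
# belong to its own label.
import torch

num_classes, N = 3, 2
pred = torch.arange(N * 4 * num_classes, dtype=torch.float32).view(N, -1)
labels = torch.tensor([2, 1])  # class index per proposal
map_inds = 4 * labels[:, None] + torch.tensor([0, 1, 2, 3])
per_class_boxes = torch.gather(pred, 1, map_inds)  # [N, 4]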
def test_box_decoder(self):
    """
    Match unit test UtilsBoxesTest.TestBboxTransformRandom in
    caffe2/operators/generate_proposals_op_util_boxes_test.cc
    """
    box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
    bbox = torch.from_numpy(
        np.array([
            175.62031555, 20.91103172, 253.352005, 155.0145874,
            169.24636841, 4.85241556, 228.8605957, 105.02092743,
            181.77426147, 199.82876587, 192.88427734, 214.0255127,
            174.36262512, 186.75761414, 296.19091797, 231.27906799,
            22.73153877, 92.02596283, 135.5695343, 208.80291748,
        ]).astype(np.float32).reshape(-1, 4))
    deltas = torch.from_numpy(
        np.array([
            0.47861834, 0.13992102, 0.14961673, 0.71495209,
            0.29915856, -0.35664671, 0.89018666, 0.70815367,
            -0.03852064, 0.44466892, 0.49492538, 0.71409376,
            0.28052918, 0.02184832, 0.65289006, 1.05060139,
            -0.38172557, -0.08533806, -0.60335309, 0.79052375,
        ]).astype(np.float32).reshape(-1, 4))
    gt_bbox = np.array([
        206.949539, -30.715202, 297.387665, 244.448486,
        143.871216, -83.342888, 290.502289, 121.053398,
        177.430283, 198.666245, 196.295273, 228.703079,
        152.251892, 145.431564, 387.215454, 274.594238,
        5.062420, 11.040955, 66.328903, 269.686218,
    ]).astype(np.float32).reshape(-1, 4)
    results = box_coder.decode(deltas, bbox)
    np.testing.assert_allclose(results.detach().numpy(), gt_bbox, atol=1e-4)
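# A hedged companion check, not from the repo (it assumes the
# maskrcnn_benchmark BoxCoder API: encode(reference_boxes, proposals) and
# decode(rel_codes, boxes)): deltas produced by encode should decode back to
# the original boxes.
def test_box_coder_round_trip(self):
    box_coder = BoxCoder(weights=(10.0, 10.0, 5.0, 5.0))
    anchors = torch.tensor([[10.0, 10.0, 50.0, 60.0],
                            [5.0, 5.0, 20.0, 30.0]])
    gt = torch.tensor([[12.0, 8.0, 48.0, 62.0],
                       [4.0, 6.0, 22.0, 28.0]])
    deltas = box_coder.encode(gt, anchors)
    recovered = box_coder.decode(deltas, anchors)
    np.testing.assert_allclose(recovered.numpy(), gt.numpy(), atol=1e-4)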
class NewROIBoxHead(ROIBoxHead):
    def __init__(self, cfg, in_channels):
        super(NewROIBoxHead, self).__init__(cfg, in_channels)
        self.bbox_dict = dict(bbox=None, target=None)
        self.box_coder = BoxCoder(weights=cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS)

    def forward(self, features, proposals, targets=None):
        if self.training:
            # During training, Faster R-CNN subsamples the proposals with a
            # fixed positive / negative ratio.
            with torch.no_grad():
                proposals = self.loss_evaluator.subsample(proposals, targets)

        # Extract features that will be fed to the final classifier. The
        # feature_extractor generally corresponds to the pooler + heads.
        x = self.feature_extractor(features, proposals)
        # Final classifier that converts the features into predictions.
        class_logits, box_regression = self.predictor(x)

        # Save the bbox result for rois_gan.
        if self.training:
            result, box_result = self.reduced_bbox_result([box_regression],
                                                          proposals)
            self.bbox_dict.update(bbox=result)
            self.bbox_dict.update(target=box_result)

        if not self.training:
            result = self.post_processor((class_logits, box_regression),
                                         proposals)
            return x, result, {}

        loss_classifier, loss_box_reg = self.loss_evaluator([class_logits],
                                                            [box_regression])
        return (
            x,
            proposals,
            dict(loss_classifier=loss_classifier, loss_box_reg=loss_box_reg),
        )

    def reduced_bbox_result(self, box_regression, proposals):
        box_regression = cat(box_regression, dim=0)
        device = box_regression.device
        labels = cat([proposal.get_field("labels") for proposal in proposals],
                     dim=0)
        regression_targets = cat([
            proposal.get_field("regression_targets") for proposal in proposals
        ], dim=0)
        sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1)
        labels_pos = labels[sampled_pos_inds_subset]
        map_inds = 4 * labels_pos[:, None] + torch.tensor([0, 1, 2, 3],
                                                          device=device)
        image_shapes = [box.size for box in proposals]
        boxes_per_image = [len(box) for box in proposals]
        concat_boxes = torch.cat([a.bbox for a in proposals], dim=0)

        # Count how many sampled positives fall into each image's contiguous
        # range of proposal indices.
        prefix_sum_boxes = [boxes_per_image[0]]
        for box_per_images in boxes_per_image[1:]:
            prefix_sum_boxes.append(box_per_images + prefix_sum_boxes[-1])
        reduced_boxes_per_image = [0] * len(prefix_sum_boxes)
        i, j = 0, 0
        while i < len(sampled_pos_inds_subset):
            if sampled_pos_inds_subset[i] < prefix_sum_boxes[j]:
                reduced_boxes_per_image[j] += 1
                i += 1
            else:
                j += 1

        proposals = self.box_coder.decode(
            box_regression[sampled_pos_inds_subset[:, None], map_inds],
            concat_boxes[sampled_pos_inds_subset])
        proposals = proposals.split(reduced_boxes_per_image, dim=0)
        box_targets = self.box_coder.decode(
            regression_targets[sampled_pos_inds_subset],
            concat_boxes[sampled_pos_inds_subset])
        box_targets = box_targets.split(reduced_boxes_per_image, dim=0)

        result = []
        for boxes, image_shape in zip(proposals, image_shapes):
            boxlist = BoxList(boxes, image_shape, mode="xyxy")
            boxlist = boxlist.clip_to_image(remove_empty=False)
            result.append(boxlist)
        box_result = []
        for boxes, image_shape in zip(box_targets, image_shapes):
            boxlist = BoxList(boxes, image_shape, mode="xyxy")
            boxlist = boxlist.clip_to_image(remove_empty=False)
            box_result.append(boxlist)
        return result, box_result
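# A minimal sketch (toy data, not the repo's code) of what the prefix-sum
# while-loop above computes: the number of positive indices that fall into
# each image's contiguous range of proposals. torch.bucketize yields the
# same counts directly.
import torch

boxes_per_image = [4, 3, 5]
prefix = torch.tensor(boxes_per_image).cumsum(0)         # tensor([4, 7, 12])
pos_inds = torch.tensor([0, 2, 5, 6, 11])                # global positive indices
img_ids = torch.bucketize(pos_inds, prefix, right=True)  # image id of each index
counts = torch.bincount(img_ids, minlength=len(boxes_per_image)).tolist()
# counts == [2, 2, 1]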
class DuplicationRemovalNetwork(nn.Module):
    def __init__(self, cfg, is_teacher=False):
        super(DuplicationRemovalNetwork, self).__init__()
        self.cfg = cfg.clone()
        # If reg_iou is True, this network regresses the IoU to the ground
        # truth; otherwise it classifies true-object vs. duplicate.
        self.reg_iou = self.cfg.MODEL.RELATION_NMS.REG_IOU
        self.first_n = cfg.MODEL.RELATION_NMS.FIRST_N
        self.NMS_thread = cfg.MODEL.RELATION_NMS.THREAD
        self.nms_rank_fc = nn.Linear(
            cfg.MODEL.RELATION_NMS.ROI_FEAT_DIM,
            cfg.MODEL.RELATION_NMS.APPEARANCE_FEAT_DIM,
            bias=True)
        self.roi_feat_embedding_fc = nn.Linear(
            cfg.MODEL.RELATION_NMS.ROI_FEAT_DIM,
            cfg.MODEL.RELATION_NMS.APPEARANCE_FEAT_DIM,
            bias=True)
        self.target_thresh = cfg.MODEL.RELATION_NMS.THREAD
        self.geo_feature_dim = cfg.MODEL.RELATION_NMS.GEO_FEAT_DIM
        if cfg.MODEL.RELATION_NMS.USE_IOU:
            self.geo_feature_dim = int(self.geo_feature_dim / 4 * 5)
        self.relation_module = RelationModule(
            cfg.MODEL.RELATION_NMS.APPEARANCE_FEAT_DIM,
            geo_feature_dim=self.geo_feature_dim,
            fc_dim=(self.geo_feature_dim, 16),
            group=cfg.MODEL.RELATION_NMS.GROUP,
            dim=cfg.MODEL.RELATION_NMS.HID_DIM,
            topk=cfg.MODEL.RELATION_NMS.TOPK,
            iou_method=cfg.MODEL.RELATION_NMS.IOU_METHOD)
        self.nms_fg_weight = torch.tensor([1., cfg.MODEL.RELATION_NMS.WEIGHT])
        self.mt_fg_weight = torch.tensor([1., 10.])
        self.alpha = cfg.MODEL.RELATION_NMS.ALPHA
        self.gamma = cfg.MODEL.RELATION_NMS.GAMMA
        self.boxcoder = BoxCoder(weights=(10., 10., 5., 5.))
        self.class_agnostic = cfg.MODEL.RELATION_NMS.CLASS_AGNOSTIC
        self.fg_class = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES - 1
        self.classifier = nn.Linear(128, len(self.target_thresh), bias=True)
        self.relu1 = nn.ReLU(inplace=True)
        self.fg_thread = cfg.MODEL.RELATION_NMS.FG_THREAD
        self.detections_per_img = cfg.MODEL.ROI_HEADS.DETECTIONS_PER_IMG
        self.nms = cfg.MODEL.RELATION_NMS.POS_NMS
        self.nms_loss_type = cfg.MT.NMS_LOSS_TYPE
        self.mode = None

    def set_teacher_mode(self, mode):
        self.mode = mode

    def forward(self, x):
        appearance_feature, proposals, cls_score, box_reg, targets = x
        self.device = appearance_feature.device
        with torch.no_grad():
            sorted_boxlists = self.prepare_ranking(cls_score,
                                                   box_reg,
                                                   proposals,
                                                   targets,
                                                   reg_iou=self.reg_iou)
        # Concatenate values from the different images.
        boxes_per_image = [len(f) for f in proposals]
        idxs = [f.get_field('sorted_idx') for f in sorted_boxlists]
        scores = torch.cat([f.get_field('scores') for f in sorted_boxlists])
        bboxes = torch.cat(
            [f.bbox.reshape(-1, self.fg_class, 4) for f in sorted_boxlists])
        objectness = torch.cat([
            f.get_field('objectness').reshape(-1, self.fg_class)
            for f in sorted_boxlists
        ])
        all_scores = torch.cat(
            [f.get_field('all_scores') for f in sorted_boxlists])
        # Add IoU information.
        image_sizes = [f.size for f in sorted_boxlists]
        sorted_boxes_per_image = [f.shape[0] for f in idxs]
        appearance_feature = self.roi_feat_embedding_fc(appearance_feature)
        appearance_feature = appearance_feature.split(boxes_per_image, dim=0)
        sorted_features = []
        nms_rank_embedding = []
        for idx, feature, box_per_image in zip(idxs, appearance_feature,
                                               boxes_per_image):
            feature = feature[idx]
            size = feature.size()
            first_n = size[0] if size[0] <= self.first_n else self.first_n
            sorted_features.append(feature)
            # [rank_dim * batch, feat_dim]
            nms_rank_embedding.append(
                extract_rank_embedding(
                    first_n,
                    self.cfg.MODEL.RELATION_NMS.ROI_FEAT_DIM,
                    device=feature.device))
        # [first_n * batchsize, num_fg_classes, 128]
        sorted_features = torch.cat(sorted_features, dim=0)
        nms_rank_embedding = torch.cat(nms_rank_embedding, dim=0)
        nms_rank_embedding = self.nms_rank_fc(nms_rank_embedding)
        sorted_features = sorted_features + nms_rank_embedding[:, None, :]

        boxes_cls_1 = BoxList(bboxes[:, 0, :], image_sizes[0])
        boxes_cls_2 = BoxList(bboxes[:, 1, :], image_sizes[0])
        iou_1 = boxlist_iou(boxes_cls_1, boxes_cls_1)
        iou_2 = boxlist_iou(boxes_cls_2, boxes_cls_2)
        if self.cfg.MODEL.RELATION_NMS.USE_IOU:
            iou = [iou_1, iou_2]
        else:
            iou = None
        nms_position_matrix = extract_multi_position_matrix(
            bboxes,
            None,
            self.geo_feature_dim,
            1000,
            clswise=self.cfg.MODEL.RELATION_NMS.CLS_WISE_RELATION,
        )
        nms_attention_1 = self.relation_module(sorted_features,
                                               nms_position_matrix, iou)
        sorted_features = sorted_features + nms_attention_1
        sorted_features = self.relu1(sorted_features)
        # [first_n * num_fg_classes, 128]
        sorted_features = sorted_features.view(
            -1, self.cfg.MODEL.RELATION_NMS.APPEARANCE_FEAT_DIM)
        sorted_features = self.classifier(sorted_features)
        # Reshaped logits: [first_n, num_fg_classes, num_thresh]
        sorted_features = sorted_features.view(-1, self.fg_class,
                                               len(self.target_thresh))
        if not self.reg_iou:
            sorted_features = torch.sigmoid(sorted_features)
        scores = torch.cat([scores[:, :, None]] * len(self.target_thresh),
                           dim=-1)
        loss_dict = {}
        if self.training:
            if self.reg_iou:
                # When regressing the IoU, do not multiply the logits by the
                # scores.
                reg_label = torch.cat(
                    [f.get_field('labels_iou_reg') for f in sorted_boxlists])
                reg_label = reg_label.to(scores.device, dtype=torch.float32)
                sorted_features = sorted_features.to(scores.device,
                                                     dtype=torch.float32)
                # mse_loss on empty tensors produces NaN, so guard on numel.
                if reg_label.numel() > 0:
                    reg_iou_loss = F.mse_loss(reg_label, sorted_features)
                else:
                    reg_iou_loss = torch.tensor(0.).to(scores.device)
                loss_dict['nms_loss'] = reg_iou_loss
            else:
                sorted_features = scores * sorted_features
                labels = torch.cat(
                    [f.get_field('labels') for f in sorted_boxlists])
                labels = labels.to(scores.device, dtype=torch.float32)
                # Weighted NMS loss.
                nms_loss = F.binary_cross_entropy(scores * sorted_features,
                                                  labels)
                loss_dict['nms_loss'] = nms_loss
            return None, loss_dict
        else:
            input_scores = scores
            if self.reg_iou:
                scores = sorted_features * (scores > self.fg_thread).float()
            else:
                scores = sorted_features * scores
            scores = self.merge_multi_thread_score_test(scores)
            scores = scores.split(sorted_boxes_per_image, dim=0)
            bboxes = bboxes.split(sorted_boxes_per_image, dim=0)
            input_scores = input_scores.split(sorted_boxes_per_image, dim=0)
            objectness = objectness.split(sorted_boxes_per_image, dim=0)
            all_scores = all_scores.split(sorted_boxes_per_image, dim=0)
            result = []
            for i_score, score, bbox, obj, image_size, prob_boxhead in zip(
                    input_scores, scores, bboxes, objectness, image_sizes,
                    all_scores):
                result_per_image = []
                # For nuclei.
                index = (score[:, 1] >= self.fg_thread).nonzero()[:, 0]
                cls_scores = score[index, 1]
                cls_scores_all = prob_boxhead[index, 1]
                cls_boxes = bbox[index, 1, :]
                cls_obj = obj[index, 1]
                boxlist_for_class = BoxList(cls_boxes, image_size, mode='xyxy')
                boxlist_for_class.add_field('scores', cls_scores)
                boxlist_for_class.add_field('objectness', cls_obj)
                boxlist_for_class.add_field('all_scores', cls_scores_all)
                boxlist_for_class = boxlist_nms(boxlist_for_class,
                                                0.5,
                                                score_field="scores")
                num_labels = len(boxlist_for_class)
                boxlist_for_class.add_field(
                    "labels",
                    torch.full((num_labels, ), 2,
                               dtype=torch.int64).to(self.device))
                result_per_image.append(boxlist_for_class)
                index = (score[:, 0] >= self.fg_thread).nonzero()[:, 0]
                cls_scores = score[index, 0]
                cls_scores_all = prob_boxhead[index, 0]
                cls_boxes = bbox[index, 0, :]
                cls_obj = obj[index, 0]
                boxlist_for_class = BoxList(cls_boxes, image_size, mode='xyxy')
                boxlist_for_class.add_field('scores', cls_scores)
                boxlist_for_class.add_field('objectness', cls_obj)
                boxlist_for_class.add_field('all_scores', cls_scores_all)
                # Positive greedy NMS if POS_NMS != -1.
                if self.nms:
                    boxlist_for_class = boxlist_nms(boxlist_for_class,
                                                    self.nms,
                                                    score_field="scores")
                num_labels = len(boxlist_for_class)
                boxlist_for_class.add_field(
                    "labels",
                    torch.full((num_labels, ), 1,
                               dtype=torch.int64).to(self.device))
                result_per_image.append(boxlist_for_class)
                result_per_image = cat_boxlist(result_per_image)
                number_of_detections = len(result_per_image)
                # Limit to max_per_image detections **over all classes**.
                if number_of_detections > self.detections_per_img > 0:
                    cls_scores = result_per_image.get_field("scores")
                    image_thresh, _ = torch.kthvalue(
                        cls_scores.cpu(),
                        number_of_detections - self.detections_per_img + 1)
                    keep = cls_scores >= image_thresh.item()
                    keep = torch.nonzero(keep).squeeze(1)
                    result_per_image = result_per_image[keep]
                result.append(result_per_image)
            return result, {}

    def prepare_reg_label(self, sorted_boxes, sorted_score, targets):
        '''
        :param sorted_boxes: [first_n, fg_cls_num, 4]
        :param sorted_score: [first_n, fg_cls_num]
        :param targets: BoxList
        :return: label [first_n, fg_cls_num, num_thresh]
        '''
        TO_REMOVE = 1
        labels = targets.get_field('labels')
        output_reg_list = []
        for i in range(self.fg_class):
            cls_label_indice = torch.nonzero(labels == (i + 1))
            cls_target_bbox = targets.bbox[cls_label_indice[:, 0]]
            # TODO: avoid the no-GT situation.
            num_valid_gt = len(cls_label_indice)
            if num_valid_gt == 0:
                output = np.zeros(
                    (sorted_boxes.shape[0], len(self.target_thresh)))
                output_reg_list.append(output)
            else:
                output_reg_list_per_class = []
                eye_matrix = np.eye(num_valid_gt)
                score_per_class = sorted_score[:, i:i + 1].cpu().numpy()
                boxes = sorted_boxes[:, i, :]
                boxes = boxes.view(-1, 4)
                area1 = (boxes[:, 2] - boxes[:, 0] + TO_REMOVE) * (
                    boxes[:, 3] - boxes[:, 1] + TO_REMOVE)
                area2 = (cls_target_bbox[:, 2] - cls_target_bbox[:, 0] +
                         TO_REMOVE) * (cls_target_bbox[:, 3] -
                                       cls_target_bbox[:, 1] + TO_REMOVE)
                lt = torch.max(boxes[:, None, :2],
                               cls_target_bbox[:, :2])  # [N,M,2]
                rb = torch.min(boxes[:, None, 2:],
                               cls_target_bbox[:, 2:])  # [N,M,2]
                wh = (rb - lt + TO_REMOVE).clamp(min=0)  # [N,M,2]
                inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
                iou = inter / (area1[:, None] + area2 - inter)
                iou = iou.cpu().numpy()
                for thresh in self.target_thresh:
                    overlap_mask = (iou > thresh)
                    overlap_iou = iou * overlap_mask
                    valid_bbox_indices = np.where(overlap_mask)[0]
                    overlap_score = np.tile(score_per_class,
                                            (1, num_valid_gt))
                    overlap_score *= overlap_mask
                    max_overlap_indices = np.argmax(iou, axis=1)
                    max_overlap_mask = eye_matrix[max_overlap_indices]
                    overlap_score *= max_overlap_mask
                    overlap_iou = overlap_iou * max_overlap_mask
                    max_score_indices = np.argmax(overlap_score, axis=0)
                    max_overlap_iou = overlap_iou[
                        max_score_indices,
                        np.arange(overlap_score.shape[1])]
                    output_reg = np.zeros((sorted_boxes.shape[0], ))
                    output_idx, inter_1, inter_2 = np.intersect1d(
                        max_score_indices,
                        valid_bbox_indices,
                        return_indices=True)
                    output_reg[output_idx] = max_overlap_iou[inter_1]
                    output_reg_list_per_class.append(output_reg)
                output_reg_per_class = np.stack(output_reg_list_per_class,
                                                axis=-1)
                output_reg_list.append(output_reg_per_class)
        output_reg = np.stack(output_reg_list,
                              axis=1).astype(np.float32, copy=False)
        return output_reg

    def prepare_label(self, sorted_boxes, sorted_score, targets):
        '''
        :param sorted_boxes: [first_n, fg_cls_num, 4]
        :param sorted_score: [first_n, fg_cls_num]
        :param targets: BoxList
        :return: label [first_n, fg_cls_num, num_thresh]
        '''
        TO_REMOVE = 1
        labels = targets.get_field('labels')
        output_list = []
        for i in range(self.fg_class):
            cls_label_indice = torch.nonzero(labels == (i + 1))
            cls_target_bbox = targets.bbox[cls_label_indice[:, 0]]
            # TODO: avoid the no-GT situation.
            num_valid_gt = len(cls_label_indice)
            if num_valid_gt == 0:
                output = np.zeros(
                    (sorted_boxes.shape[0], len(self.target_thresh)))
                output_list.append(output)
            else:
                output_list_per_class = []
                eye_matrix = np.eye(num_valid_gt)
                score_per_class = sorted_score[:, i:i + 1].cpu().numpy()
                boxes = sorted_boxes[:, i, :]
                boxes = boxes.view(-1, 4)
                area1 = (boxes[:, 2] - boxes[:, 0] + TO_REMOVE) * (
                    boxes[:, 3] - boxes[:, 1] + TO_REMOVE)
                area2 = (cls_target_bbox[:, 2] - cls_target_bbox[:, 0] +
                         TO_REMOVE) * (cls_target_bbox[:, 3] -
                                       cls_target_bbox[:, 1] + TO_REMOVE)
                lt = torch.max(boxes[:, None, :2],
                               cls_target_bbox[:, :2])  # [N,M,2]
                rb = torch.min(boxes[:, None, 2:],
                               cls_target_bbox[:, 2:])  # [N,M,2]
                wh = (rb - lt + TO_REMOVE).clamp(min=0)  # [N,M,2]
                inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
                iou = inter / (area1[:, None] + area2 - inter)
                iou = iou.cpu().numpy()
                for thresh in self.target_thresh:
                    overlap_mask = (iou > thresh)
                    valid_bbox_indices = np.where(overlap_mask)[0]
                    overlap_score = np.tile(score_per_class,
                                            (1, num_valid_gt))
                    overlap_score *= overlap_mask
                    max_overlap_indices = np.argmax(iou, axis=1)
                    max_overlap_mask = eye_matrix[max_overlap_indices]
                    overlap_score *= max_overlap_mask
                    max_score_indices = np.argmax(overlap_score, axis=0)
                    output = np.zeros((sorted_boxes.shape[0], ))
                    output[np.intersect1d(max_score_indices,
                                          valid_bbox_indices)] = 1
                    output_list_per_class.append(output)
                output_per_class = np.stack(output_list_per_class, axis=-1)
                output_list.append(output_per_class)
        output = np.stack(output_list, axis=1).astype(np.float32, copy=False)
        return output

    def prepare_ranking(self, cls_score, box_regression, proposals, targets,
                        reg_iou=False):
        '''
        :param cls_score: [num_per_img * batchsize, class]
        :param proposals: list of BoxList
        :return: list of sorted BoxList
        '''
        # At test time targets is None and is simply not iterated over.
        boxes_per_image = [len(box) for box in proposals]
        concat_boxes = torch.cat([a.bbox for a in proposals], dim=0)
        image_shapes = [box.size for box in proposals]
        objectness = [f.get_field('objectness') for f in proposals]
        proposals = self.boxcoder.decode(
            box_regression.view(sum(boxes_per_image), -1), concat_boxes)
        proposals = proposals.split(boxes_per_image, dim=0)
        cls_score = cls_score.split(boxes_per_image, dim=0)
        results = []
        if self.training:
            for prob, boxes_per_img, image_shape, target, obj in zip(
                    cls_score, proposals, image_shapes, targets, objectness):
                boxlist = self.filter_results(boxes_per_img, target, prob,
                                              image_shape, self.fg_class + 1,
                                              obj, reg_iou)
                results.append(boxlist)
        else:
            # Test time has no targets.
            for prob, boxes_per_img, image_shape, obj in zip(
                    cls_score, proposals, image_shapes, objectness):
                boxlist = self.filter_results(boxes_per_img, None, prob,
                                              image_shape, self.fg_class + 1,
                                              obj, reg_iou=reg_iou)
                results.append(boxlist)
        return results

    def filter_results(self, boxes, targets, scores, image_shape, num_classes,
                       obj, reg_iou=False):
        """Return the sorted boxlist and sorted idx."""
        # Unwrap the boxlist to avoid additional overhead; if we had
        # multi-class NMS, we could perform this directly on the boxlist.
        boxes = boxes.reshape(-1, 4 * num_classes)
        scores = scores.reshape(-1, num_classes)
        cat_boxes = []
        for j in range(1, num_classes):
            # Skip class 0, because it is the background class.
            cls_boxes = boxes[:, j * 4:(j + 1) * 4]
            cat_boxes.append(cls_boxes)
        boxes = torch.cat([bbox[:, :, None] for bbox in cat_boxes], dim=2)
        scores = scores[:, 1:]
        ori_scores = scores
        num_roi = boxes.shape[0]
        first_n = num_roi if num_roi <= self.first_n else self.first_n
        sorted_scores, indices = torch.topk(scores,
                                            first_n,
                                            dim=0,
                                            largest=True,
                                            sorted=True)
        if obj.shape[0] < first_n:
            indices = indices[:obj.shape[0]]
            sorted_scores = sorted_scores[:obj.shape[0]]
        ori_scores = ori_scores[indices]
        sorted_obj = obj[indices]
        sorted_boxes = boxes[indices]
        if self.class_agnostic:
            # [first_n, num_fg_class, 4]
            sorted_boxes = torch.squeeze(sorted_boxes, dim=-1)
        else:
            # Pick, for each foreground class, the boxes sorted by that
            # class's own scores.
            mask = torch.arange(0, num_classes - 1).to(device=self.device)
            mask = mask.view(1, -1, 1,
                             1).expand(first_n, num_classes - 1, 4, 1)
            sorted_boxes = torch.gather(sorted_boxes, dim=3,
                                        index=mask).squeeze(dim=3)
        if self.training:
            labels = self.prepare_label(sorted_boxes, sorted_scores, targets)
            labels_cls = torch.from_numpy(labels).to(sorted_scores.device)
            if reg_iou:
                labels_reg = self.prepare_reg_label(sorted_boxes,
                                                    sorted_scores, targets)
                labels_reg = torch.from_numpy(labels_reg).to(
                    sorted_scores.device)
        sorted_boxes = sorted_boxes.view(first_n * (num_classes - 1), -1)
        sorted_obj = sorted_obj.view(first_n * (num_classes - 1))
        boxlist = BoxList(sorted_boxes, image_shape, mode="xyxy")
        boxlist.add_field('sorted_idx', indices)
        boxlist.add_field('objectness', sorted_obj)
        boxlist.add_field('scores', sorted_scores)
        boxlist.add_field('all_scores', ori_scores)
        if self.training:
            if reg_iou:
                boxlist.add_field('labels_iou_reg', labels_reg)
            else:
                boxlist.add_field('labels', labels_cls)
        boxlist = boxlist.clip_to_image(remove_empty=False)
        return boxlist

    def merge_multi_thread_score_test(self, scores):
        if self.cfg.MODEL.RELATION_NMS.MERGE_METHOD == -1:
            scores = torch.mean(scores, -1)
        elif self.cfg.MODEL.RELATION_NMS.MERGE_METHOD == -2:
            # torch.max returns (values, indices); keep the values.
            scores = torch.max(scores, -1)[0]
        else:
            idx = self.cfg.MODEL.RELATION_NMS.MERGE_METHOD
            # Clamp to a valid threshold index.
            idx = min(max(idx, 0), len(self.target_thresh) - 1)
            scores = scores[:, :, idx]
        return scores
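# A toy illustration (not repo code) of the three MERGE_METHOD modes above,
# applied to a [first_n, num_fg_classes, num_thresh] score tensor:
import torch

scores = torch.rand(5, 2, 3)
mean_merged = scores.mean(-1)     # MERGE_METHOD == -1
max_merged = scores.max(-1)[0]    # MERGE_METHOD == -2
idx_merged = scores[:, :, 1]      # MERGE_METHOD == 1 (a valid threshold index)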
def forward(self, features, proposals, targets=None, proposals_sampled=None):
    """
    Arguments:
        features (list[Tensor]): feature-maps from possibly several levels
        proposals (list[BoxList]): proposal boxes
        targets (list[BoxList], optional): the ground-truth targets.

    Returns:
        x (Tensor): the result of the feature extractor
        proposals (list[BoxList]): during training, the subsampled proposals
            are returned. During testing, the predicted boxlists are returned
        losses (dict[Tensor]): During training, returns the losses for the
            head. During testing, returns an empty dict.
    """
    if self.training:
        # During training, Faster R-CNN subsamples the proposals with a
        # fixed positive / negative ratio.
        if proposals_sampled is None:
            with torch.no_grad():
                proposals_sampled = self.loss_evaluator.subsample(
                    proposals, targets)
        proposals = proposals_sampled

    # Extract features that will be fed to the final classifier. The
    # feature_extractor generally corresponds to the pooler + heads.
    x = self.feature_extractor(features, proposals)
    # Final classifier that converts the features into predictions.
    class_logits, box_regression = self.predictor(x)

    if not self.training:
        result = self.post_processor((class_logits, box_regression),
                                     proposals)
        return x, result, {}

    # TODO: the loss is not needed for the mean teacher when MT_ON.
    if not self.cfg.MODEL.ROI_BOX_HEAD.FREEZE_WEIGHT:
        loss_classifier, loss_box_reg = self.loss_evaluator(
            [class_logits], [box_regression], proposals)

    if self.cfg.MODEL.ROI_BOX_HEAD.OUTPUT_DECODED_PROPOSAL:
        bbox_reg_weights = self.cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS
        box_coder = BoxCoder(weights=bbox_reg_weights)
        boxes_per_image = [len(box) for box in proposals]
        concat_boxes = torch.cat([a.bbox for a in proposals], dim=0)
        decoded_proposals = box_coder.decode(
            box_regression.view(sum(boxes_per_image), -1), concat_boxes)
        decoded_proposals = decoded_proposals.split(boxes_per_image, dim=0)
        # Make sure there are valid proposals before overwriting.
        for i, boxes in enumerate(decoded_proposals):
            if len(boxes) > 0:
                proposals[i].bbox = boxes.reshape(-1, 4)

    loss_dict = dict()
    if self.cfg.MODEL.MT_ON:
        loss_dict.update(class_logits=class_logits,
                         box_logits=box_regression)
    if not self.is_mt and not self.cfg.MODEL.ROI_BOX_HEAD.FREEZE_WEIGHT:
        loss_dict.update(
            dict(loss_classifier=loss_classifier,
                 loss_box_reg=loss_box_reg))
    return x, proposals, loss_dict
class MaskRCNNLossComputation(object):
    def __init__(self, proposal_matcher, discretization_size, use_mil_loss,
                 use_aff, use_box_mask):
        """
        Arguments:
            proposal_matcher (Matcher)
            discretization_size (int)
            use_mil_loss (bool)
            use_aff (bool)
            use_box_mask (bool)
        """
        self.proposal_matcher = proposal_matcher
        self.discretization_size = discretization_size
        # 3x3 affinity kernels: each kernel is (center - neighbor) for one
        # of the 8 neighbours.
        center_weight = torch.zeros((3, 3))
        center_weight[1][1] = 1.
        aff_weights = []
        for i in range(3):
            for j in range(3):
                if i == 1 and j == 1:
                    continue
                weight = torch.zeros((3, 3))
                weight[i][j] = 1.
                aff_weights.append(center_weight - weight)
        aff_weights = [w.view(1, 1, 3, 3).to("cuda") for w in aff_weights]
        self.aff_weights = torch.cat(aff_weights, 0)
        self.box_coder = BoxCoder(weights=(10., 10., 5., 5.))
        self.use_mil_loss = use_mil_loss
        self.use_aff = use_aff
        if use_box_mask:
            assert not use_mil_loss
        self.use_box_mask = use_box_mask

    def match_targets_to_proposals(self, proposal, target):
        match_quality_matrix = boxlist_iou(target, proposal)
        matched_idxs = self.proposal_matcher(match_quality_matrix)
        # Mask R-CNN needs the "labels" and "masks" fields for creating the
        # targets.
        target = target.copy_with_fields(["labels", "masks"])
        # Get the corresponding GT for each proposal.
        # NB: need to clamp the indices because we can have a single
        # GT in the image, and matched_idxs can be -2, which goes
        # out of bounds.
        matched_targets = target[matched_idxs.clamp(min=0)]
        matched_targets.add_field("matched_idxs", matched_idxs)
        return matched_targets

    def prepare_targets(self, proposals, targets):
        labels = []
        masks = []
        for proposals_per_image, targets_per_image in zip(proposals, targets):
            matched_targets = self.match_targets_to_proposals(
                proposals_per_image, targets_per_image)
            matched_idxs = matched_targets.get_field("matched_idxs")
            labels_per_image = matched_targets.get_field("labels")
            labels_per_image = labels_per_image.to(dtype=torch.int64)
            # This can probably be removed, but is left here for clarity
            # and completeness.
            neg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD
            labels_per_image[neg_inds] = 0
            # Mask scores are only computed on positive samples.
            positive_inds = torch.nonzero(labels_per_image > 0).squeeze(1)
            segmentation_masks = matched_targets.get_field("masks")
            segmentation_masks = segmentation_masks[positive_inds]
            positive_proposals = proposals_per_image[positive_inds]
            masks_per_image = project_masks_on_boxes(
                segmentation_masks, positive_proposals,
                self.discretization_size)
            labels.append(labels_per_image)
            masks.append(masks_per_image)
        return labels, masks

    def prepare_targets_cr(self, proposals):
        # Sample both negative and positive proposals, with per-column/row
        # labels only (no mask).
        labels = []
        for proposals_per_image in proposals:
            regression_target = proposals_per_image.get_field(
                "regression_targets")
            matched_bbox = self.box_coder.decode(regression_target,
                                                 proposals_per_image.bbox)
            M = self.discretization_size
            pos_masks_per_image = torch.ones(len(proposals_per_image),
                                             M, M, dtype=torch.float32)
            not_matched_idx = (regression_target != 0).any(1)
            # If a box fully matched the proposal, we set all values to 1;
            # otherwise, we project the box on the proposal.
            if not_matched_idx.any():
                pos_masks_per_image[not_matched_idx] = project_boxes_on_boxes(
                    matched_bbox[not_matched_idx],
                    proposals_per_image[not_matched_idx], M)
            pos_masks_per_image = pos_masks_per_image.cuda()
            pos_labels = torch.cat(
                [pos_masks_per_image.sum(2), pos_masks_per_image.sum(1)], 1)
            pos_labels = (pos_labels > 0).float()
            labels.append(pos_labels)
        return labels

    def prepare_targets_boxes(self, proposals):
        masks = []
        for proposals_per_image in proposals:
            regression_target = proposals_per_image.get_field(
                "regression_targets")
            matched_bbox = self.box_coder.decode(regression_target,
                                                 proposals_per_image.bbox)
            M = self.discretization_size
            pos_masks_per_image = torch.ones(len(proposals_per_image),
                                             M, M, dtype=torch.float32)
            not_matched_idx = (regression_target != 0).any(1)
            # If a box fully matched the proposal, we set all values to 1;
            # otherwise, we project the box on the proposal.
            if not_matched_idx.any():
                pos_masks_per_image[not_matched_idx] = project_boxes_on_boxes(
                    matched_bbox[not_matched_idx],
                    proposals_per_image[not_matched_idx], M)
            masks.append(pos_masks_per_image.cuda())
        return masks

    def __call__(self, proposals, all_mask_logits, targets):
        """
        Arguments:
            proposals (list[BoxList])
            all_mask_logits (list[Tensor])
            targets (list[BoxList])

        Return:
            mask_loss (Tensor): scalar tensor containing the loss
        """
        labels = cat([p.get_field("proto_labels") for p in proposals])
        labels = (labels > 0).long()
        pos_inds = torch.nonzero(labels > 0).squeeze(1)
        if not self.use_mil_loss:
            mask_logits = all_mask_logits[0]
            if self.use_box_mask:
                mask_targets = self.prepare_targets_boxes(proposals)
            else:
                _, mask_targets = self.prepare_targets(proposals, targets)
            mask_targets = cat(mask_targets, dim=0)
            labels_pos = labels[pos_inds]
            if mask_targets.numel() == 0:
                return mask_logits.sum() * 0
            mask_loss = F.binary_cross_entropy_with_logits(
                mask_logits[pos_inds, labels_pos], mask_targets[pos_inds])
            return mask_loss
        labels_cr = self.prepare_targets_cr(proposals)
        labels_cr = cat(labels_cr, dim=0)
        mil_losses = []
        for mask_logits in all_mask_logits:
            mil_score = mask_logits[:, 1]
            mil_score = torch.cat(
                [mil_score.max(2)[0], mil_score.max(1)[0]], 1)
            # torch.mean (in binary_cross_entropy_with_logits) doesn't
            # accept empty tensors, so handle it separately.
            if mil_score.numel() == 0:
                mil_losses.append(mask_logits.sum() * 0)
                continue
            mil_loss = F.binary_cross_entropy_with_logits(
                mil_score[pos_inds], labels_cr[pos_inds])
            mil_losses.append(mil_loss)
        if self.use_aff:
            mask_logits = all_mask_logits[0]
            mask_logits_n = mask_logits[:, 1:].sigmoid()
            aff_maps = F.conv2d(mask_logits_n, self.aff_weights,
                                padding=(1, 1))
            affinity_loss = mask_logits_n * (aff_maps**2)
            affinity_loss = torch.mean(affinity_loss)
            return 1.2 * sum(mil_losses) / len(mil_losses) \
                + 0.05 * affinity_loss
        else:
            return sum(mil_losses) / len(mil_losses)
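# A toy illustration (not repo code) of the MIL "line" score used in
# __call__ above: per-mask logits [N, M, M] are reduced to per-row and
# per-column maxima, giving the [N, 2*M] scores that are trained against the
# projected per-row/column box labels.
import torch

M = 4
mask_logits = torch.randn(2, M, M)                     # [N, M, M]
mil_score = torch.cat([mask_logits.max(2)[0],          # max over columns -> one score per row
                       mask_logits.max(1)[0]], dim=1)  # max over rows -> one score per column
assert mil_score.shape == (2, 2 * M)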
class PostProcessor(nn.Module):
    """
    From a set of classification scores, box regression and proposals,
    computes the post-processed boxes, and applies NMS to obtain the final
    results.
    """

    def __init__(self, cfg):
        super(PostProcessor, self).__init__()
        bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS
        self.box_coder = BoxCoder(weights=bbox_reg_weights)
        self.score_thresh = cfg.MODEL.ROI_HEADS.SCORE_THRESH
        self.nms_thresh = cfg.MODEL.ROI_HEADS.NMS
        self.detections_per_img = cfg.MODEL.ROI_HEADS.DETECTIONS_PER_IMG
        self.cls_agnostic_bbox_reg = cfg.MODEL.CLS_AGNOSTIC_BBOX_REG
        self.bbox_aug_enabled = cfg.TEST.BBOX_AUG.ENABLED
        self.use_nms = not (cfg.MODEL.MASKIOU_ON
                            and "ROI_MASKIOU_HEAD" in cfg.MODEL
                            and cfg.MODEL.ROI_MASKIOU_HEAD.USE_NMS)
        if self.use_nms:
            self.nms_func = lambda boxes, scores: _box_nms(
                boxes, scores, self.nms_thresh)
            method = cfg.MODEL.ROI_HEADS.SOFT_NMS.METHOD
            self.use_soft_nms = (cfg.MODEL.ROI_HEADS.USE_SOFT_NMS
                                 and method in [1, 2])
            if self.use_soft_nms:

                def soft_nms_func(boxes, scores):
                    sigma = cfg.MODEL.ROI_HEADS.SOFT_NMS.SIGMA
                    score_thresh = cfg.MODEL.ROI_HEADS.SOFT_NMS.SCORE_THRESH
                    indices, keep, scores_new = _box_soft_nms(
                        boxes.cpu(),
                        scores.cpu(),
                        nms_thresh=self.nms_thresh,
                        sigma=sigma,
                        score_thresh=score_thresh,
                        method=method)
                    return indices, keep, scores_new

                self.nms_func = soft_nms_func
        else:
            self.detections_per_img = -1

    def forward(self, x, boxes):
        """
        Arguments:
            x (tuple[tensor, tensor]): x contains the class logits
                and the box_regression from the model.
            boxes (list[BoxList]): bounding boxes that are used as
                reference, one for each image

        Returns:
            results (list[BoxList]): one BoxList for each image, containing
                the extra fields labels and scores
        """
        class_logits, box_regression = x
        class_prob = F.softmax(class_logits, -1)

        # TODO think about a representation of batch of boxes
        image_shapes = [box.size for box in boxes]
        boxes_per_image = [len(box) for box in boxes]
        concat_boxes = torch.cat([a.bbox for a in boxes], dim=0)

        if self.cls_agnostic_bbox_reg:
            box_regression = box_regression[:, -4:]
        proposals = self.box_coder.decode(
            box_regression.view(sum(boxes_per_image), -1), concat_boxes)
        if self.cls_agnostic_bbox_reg:
            proposals = proposals.repeat(1, class_prob.shape[1])

        num_classes = class_prob.shape[1]

        proposals = proposals.split(boxes_per_image, dim=0)
        class_prob = class_prob.split(boxes_per_image, dim=0)

        results = []
        for prob, boxes_per_img, image_shape in zip(class_prob, proposals,
                                                    image_shapes):
            boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape)
            boxlist = boxlist.clip_to_image(remove_empty=False)
            if not self.bbox_aug_enabled:
                # If bbox aug is enabled, we will do it later.
                boxlist = self.filter_results(boxlist, num_classes)
            results.append(boxlist)
        return results

    def prepare_boxlist(self, boxes, scores, image_shape):
        """
        Returns a BoxList from `boxes` and adds probability scores as an
        extra field.

        `boxes` has shape (#detections, 4 * #classes), where each row
        represents a list of predicted bounding boxes for each of the object
        classes in the dataset (including the background class). The
        detections in each row originate from the same object proposal.

        `scores` has shape (#detections, #classes), where each row represents
        a list of object detection confidence scores for each of the object
        classes in the dataset (including the background class).
        `scores[i, j]` corresponds to the box at `boxes[i, j * 4:(j + 1) * 4]`.
        """
        boxes = boxes.reshape(-1, 4)
        scores = scores.reshape(-1)
        boxlist = BoxList(boxes, image_shape, mode="xyxy")
        boxlist.add_field("scores", scores)
        return boxlist

    def filter_results(self, boxlist, num_classes):
        """Returns bounding-box detection results by thresholding on scores
        and applying non-maximum suppression (NMS).
        """
        score_field = "scores"
        # Unwrap the boxlist to avoid additional overhead; if we had
        # multi-class NMS, we could perform this directly on the boxlist.
        boxes = boxlist.convert("xyxy").bbox.reshape(-1, num_classes * 4)
        scores = boxlist.get_field(score_field).reshape(-1, num_classes)
        device = scores.device
        result = []
        # Apply threshold on detection probabilities and apply NMS.
        # Skip j = 0, because it's the background class.
        inds_all = scores > self.score_thresh
        for j in range(1, num_classes):
            inds = inds_all[:, j].nonzero().squeeze(1)
            scores_j = scores[inds, j]
            boxes_j = boxes[inds, j * 4:(j + 1) * 4]
            boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
            boxlist_for_class.add_field(score_field, scores_j)
            if self.use_nms:
                keep = self.nms_func(boxes_j, scores_j)
                if self.use_soft_nms:
                    indices, keep, scores_j_new = keep
                    boxlist_for_class = boxlist_for_class[indices]
                    boxlist_for_class.add_field(
                        score_field, scores_j_new.to(device=device))
                boxlist_for_class = boxlist_for_class[keep]
            num_labels = len(boxlist_for_class)
            boxlist_for_class.add_field(
                "labels",
                torch.full((num_labels, ), j, dtype=torch.int64,
                           device=device))
            result.append(boxlist_for_class)

        result = cat_boxlist(result)
        number_of_detections = len(result)

        # Limit to max_per_image detections **over all classes**.
        if number_of_detections > self.detections_per_img > 0:
            cls_scores = result.get_field(score_field)
            image_thresh, _ = torch.kthvalue(
                cls_scores.cpu(),
                number_of_detections - self.detections_per_img + 1)
            keep = cls_scores >= image_thresh.item()
            keep = torch.nonzero(keep).squeeze(1)
            result = result[keep]
        return result
class ROIBoxHead(torch.nn.Module):
    """
    Generic Box Head class.
    """

    def __init__(self, cfg, in_channels):
        super(ROIBoxHead, self).__init__()
        self.feature_extractor = make_roi_box_feature_extractor(
            cfg, in_channels)
        self.predictor = make_roi_box_predictor(
            cfg, self.feature_extractor.out_channels)
        self.post_processor = make_roi_box_post_processor(cfg)
        self.loss_evaluator = make_roi_box_loss_evaluator(cfg)
        self.box_coder = BoxCoder(weights=(10., 10., 5., 5.))
        self.cls_agnostic_bbox_reg = cfg.MODEL.CLS_AGNOSTIC_BBOX_REG

    def add_gt_proposals(self, proposals, targets):
        """
        Arguments:
            proposals: list[BoxList]
            targets: list[BoxList]
        """
        # Get the device we're operating on.
        device = proposals[0].bbox.device
        gt_boxes = [target.copy_with_fields([]) for target in targets]
        # A later cat of bbox requires all fields to be present for all
        # boxes, so a dummy objectness field would be needed if it were
        # missing:
        # for gt_box in gt_boxes:
        #     gt_box.add_field("objectness",
        #                      torch.ones(len(gt_box), device=device))
        proposals = [
            cat_boxlist((proposal, gt_box))
            for proposal, gt_box in zip(proposals, gt_boxes)
        ]
        return proposals

    def box_regression_to_proposals(self, box_regression, boxes,
                                    is_train=True, class_logits=None):
        # TODO think about a representation of batch of boxes
        image_shapes = [box.size for box in boxes]
        boxes_per_image = [len(box) for box in boxes]
        concat_boxes = torch.cat([a.bbox for a in boxes], dim=0)
        device = box_regression.device
        if is_train:
            labels = cat([proposal.get_field("labels") for proposal in boxes],
                         dim=0)
            sampled_pos_inds_subset = torch.nonzero(labels >= 0).squeeze(1)
            labels_pos = labels[sampled_pos_inds_subset]
            if self.cls_agnostic_bbox_reg:
                map_inds = torch.tensor([4, 5, 6, 7], device=device)
            else:
                map_inds = 4 * labels_pos[:, None] + torch.tensor(
                    [0, 1, 2, 3], device=device)
            box_regression = box_regression[sampled_pos_inds_subset[:, None],
                                            map_inds]
        else:
            if class_logits is not None:
                if self.cls_agnostic_bbox_reg:
                    bbox_label = torch.ones(len(class_logits), device=device)
                    map_inds = torch.tensor([4, 5, 6, 7], device=device)
                else:
                    bbox_label = class_logits.argmax(dim=1)
                    map_inds = 4 * bbox_label[:, None] + torch.tensor(
                        [0, 1, 2, 3], device=device)
                sampled_pos_inds_subset = torch.nonzero(
                    bbox_label >= 0).squeeze(1)
                box_regression = box_regression[
                    sampled_pos_inds_subset[:, None], map_inds]
        proposals = self.box_coder.decode(
            box_regression.view(sum(boxes_per_image), -1), concat_boxes)
        # Split per image.
        proposals = proposals.split(boxes_per_image, dim=0)
        result = []
        for proposal, box in zip(proposals, boxes):
            im_shape = box.size
            boxlist = BoxList(proposal, im_shape, mode="xyxy")
            result.append(boxlist)
        return result

    def forward(self, features, proposals, targets=None):
        """
        Arguments:
            features (list[Tensor]): feature-maps from possibly several levels
            proposals (list[BoxList]): proposal boxes
            targets (list[BoxList], optional): the ground-truth targets.

        Returns:
            x (Tensor): the result of the feature extractor
            proposals (list[BoxList]): during training, the subsampled
                proposals are returned. During testing, the predicted
                boxlists are returned
            losses (dict[Tensor]): During training, returns the losses for
                the head. During testing, returns an empty dict.
        """
        if self.training:
            # During training, Faster R-CNN subsamples the proposals with a
            # fixed positive / negative ratio.
            with torch.no_grad():
                proposals = self.loss_evaluator.subsample(proposals, targets,
                                                          stage=1)
            # Extract features that will be fed to the final classifier. The
            # feature_extractor generally corresponds to the pooler + heads.
            x = self.feature_extractor(features, proposals)
            # Final classifier that converts the features into predictions.
            class_logits, box_regression = self.predictor(x, stage=1)

            # 2nd stage
            with torch.no_grad():
                proposals_stage2 = self.box_regression_to_proposals(
                    box_regression, proposals)
                proposals_stage2 = self.add_gt_proposals(
                    proposals_stage2, targets)
                proposals_stage2 = self.loss_evaluator.subsample(
                    proposals_stage2, targets, stage=2)
            x = self.feature_extractor(features, proposals_stage2)
            class_logits_stage2, box_regression_stage2 = self.predictor(
                x, stage=2)

            # 3rd stage
            with torch.no_grad():
                proposals_stage3 = self.box_regression_to_proposals(
                    box_regression_stage2, proposals_stage2)
                proposals_stage3 = self.add_gt_proposals(
                    proposals_stage3, targets)
                proposals_stage3 = self.loss_evaluator.subsample(
                    proposals_stage3, targets, stage=3)
            x = self.feature_extractor(features, proposals_stage3)
            class_logits_stage3, box_regression_stage3 = self.predictor(
                x, stage=3)
        else:
            x = self.feature_extractor(features, proposals)
            class_logits, box_regression = self.predictor(x, stage=1)
            proposals_stage2 = self.box_regression_to_proposals(
                box_regression, proposals, False, class_logits)
            x = self.feature_extractor(features, proposals_stage2)
            class_logits_stage2, box_regression_stage2 = self.predictor(
                x, stage=2)
            proposals_stage3 = self.box_regression_to_proposals(
                box_regression_stage2, proposals_stage2, False,
                class_logits_stage2)
            x = self.feature_extractor(features, proposals_stage3)
            class_logits_stage3, box_regression_stage3 = self.predictor(
                x, stage=3)
            # Re-run the stage-1 and stage-2 heads on the stage-3 features
            # and average the class logits over the three heads.
            class_logits_stage2, box_regression_stage2 = self.predictor(
                x, stage=2)
            class_logits, box_regression = self.predictor(x, stage=1)
            class_logits_average = (class_logits + class_logits_stage2 +
                                    class_logits_stage3) / 3

        if not self.training:
            result = self.post_processor(
                (class_logits_average, box_regression_stage3),
                proposals_stage3)
            return x, result, {}

        loss_classifier, loss_box_reg = self.loss_evaluator(
            [class_logits], [box_regression], proposals)
        loss_classifier_stage2, loss_box_reg_stage2 = self.loss_evaluator(
            [class_logits_stage2], [box_regression_stage2], proposals_stage2)
        loss_classifier_stage3, loss_box_reg_stage3 = self.loss_evaluator(
            [class_logits_stage3], [box_regression_stage3], proposals_stage3)
        return (x, proposals,
                dict(loss_classifier=loss_classifier,
                     loss_box_reg=loss_box_reg,
                     loss_classifier_stage2=loss_classifier_stage2,
                     loss_box_reg_stage2=loss_box_reg_stage2,
                     loss_classifier_stage3=loss_classifier_stage3,
                     loss_box_reg_stage3=loss_box_reg_stage3))
class GAnchorGenerator(nn.Module):
    """
    For a set of image sizes and feature maps, computes a set of anchors.
    """

    def __init__(self,
                 aspect_ratios=(0.5, 1.0, 2.0),
                 anchor_strides=(4, 8, 16, 32, 64),
                 straddle_thresh=0,
                 octave_base_scale=8,
                 scales_per_octave=3,
                 anchor_weights=(10., 10., 5., 5.),
                 loc_filter_thr=0.01):
        super(GAnchorGenerator, self).__init__()
        if len(anchor_strides) == 1:
            anchor_stride = anchor_strides[0]
            approx_sizes = [
                octave_base_scale * 2**(i * 1.0 / scales_per_octave) *
                anchor_stride for i in range(scales_per_octave)
            ]
            approx_anchors = [
                generate_anchors(anchor_stride, approx_sizes,
                                 aspect_ratios).float()
            ]
            square_anchors = [
                generate_anchors(anchor_stride,
                                 [octave_base_scale * anchor_stride],
                                 [1.0]).float()
            ]
        else:
            approx_anchors = [
                generate_anchors(anchor_stride, [
                    octave_base_scale * 2**(i * 1.0 / scales_per_octave) *
                    anchor_stride for i in range(scales_per_octave)
                ], aspect_ratios).float() for anchor_stride in anchor_strides
            ]
            square_anchors = [
                generate_anchors(
                    anchor_stride,
                    [octave_base_scale * anchor_stride],
                    [1.0],
                ).float() for anchor_stride in anchor_strides
            ]
        self.strides = anchor_strides
        self.approx_anchors = BufferList(approx_anchors)
        self.square_anchors = BufferList(square_anchors)
        self.straddle_thresh = straddle_thresh
        self.anchor_box_coder = BoxCoder(weights=anchor_weights)
        self.loc_filter_thr = loc_filter_thr

    def num_approx_anchors_per_location(self):
        return [len(approx_anchors) for approx_anchors in self.approx_anchors]

    def grid_square_anchors(self, grid_sizes):
        anchors = []
        for size, stride, base_anchors in zip(grid_sizes, self.strides,
                                              self.square_anchors):
            grid_height, grid_width = size
            device = base_anchors.device
            shifts_x = torch.arange(0,
                                    grid_width * stride,
                                    step=stride,
                                    dtype=torch.float32,
                                    device=device)
            shifts_y = torch.arange(0,
                                    grid_height * stride,
                                    step=stride,
                                    dtype=torch.float32,
                                    device=device)
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
            anchors.append(
                (shifts.view(-1, 1, 4) +
                 base_anchors.view(1, -1, 4)).reshape(-1, 4))
        return anchors

    def grid_approx_anchors(self, grid_sizes):
        anchors = []
        for size, stride, base_anchors in zip(grid_sizes, self.strides,
                                              self.approx_anchors):
            grid_height, grid_width = size
            device = base_anchors.device
            shifts_x = torch.arange(0,
                                    grid_width * stride,
                                    step=stride,
                                    dtype=torch.float32,
                                    device=device)
            shifts_y = torch.arange(0,
                                    grid_height * stride,
                                    step=stride,
                                    dtype=torch.float32,
                                    device=device)
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
            anchors.append(
                (shifts.view(-1, 1, 4) +
                 base_anchors.view(1, -1, 4)).reshape(-1, 4))
        return anchors

    def add_visibility_to(self, boxlist):
        image_width, image_height = boxlist.size
        anchors = boxlist.bbox
        if self.straddle_thresh >= 0:
            inds_inside = (
                (anchors[..., 0] >= -self.straddle_thresh)
                & (anchors[..., 1] >= -self.straddle_thresh)
                & (anchors[..., 2] < image_width + self.straddle_thresh)
                & (anchors[..., 3] < image_height + self.straddle_thresh))
        else:
            device = anchors.device
            inds_inside = torch.ones(anchors.shape[0],
                                     dtype=torch.uint8,
                                     device=device)
        boxlist.add_field("visibility", inds_inside)

    def forward(self, image_list, feature_maps, shapes, locs):
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        num_levels = len(feature_maps)
        square_anchors_over_all_feature_maps = self.grid_square_anchors(
            grid_sizes)
        square_anchors = []
        for i, (image_height,
                image_width) in enumerate(image_list.image_sizes):
            square_anchors_in_image = []
            for square_anchors_per_feature_map in square_anchors_over_all_feature_maps:
                boxlist = BoxList(square_anchors_per_feature_map,
                                  (image_width, image_height),
                                  mode="xyxy")
                self.add_visibility_to(boxlist)
                square_anchors_in_image.append(boxlist)
            square_anchors.append(square_anchors_in_image)

        guided_anchors = []
        loc_masks = []
        for img_id, (image_height,
                     image_width) in enumerate(image_list.image_sizes):
            guided_anchors_in_image = []
            loc_mask_in_image = []
            for i in range(num_levels):
                squares = square_anchors[img_id][i]
                shape_pred = shapes[i][img_id]
                loc_pred = locs[i][img_id]
                guide_anchors_single, loc_mask_single = self.get_guided_anchors(
                    squares,
                    shape_pred,
                    loc_pred,
                    use_loc_filter=not self.training)
                guide_anchors_single = BoxList(guide_anchors_single,
                                               (image_width, image_height),
                                               mode="xyxy")
                self.add_visibility_to(guide_anchors_single)
                guided_anchors_in_image.append(guide_anchors_single)
                loc_mask_in_image.append(loc_mask_single)
            guided_anchors.append(guided_anchors_in_image)
            loc_masks.append(loc_mask_in_image)
        return square_anchors, guided_anchors, loc_masks

    def get_guided_anchors(self, squares, shape_pred, loc_pred,
                           use_loc_filter=False):
        loc_pred = loc_pred.sigmoid().detach()
        if use_loc_filter:
            loc_mask = loc_pred >= self.loc_filter_thr
        else:
            loc_mask = loc_pred >= 0
        mask = loc_mask.permute(1, 2, 0).expand(-1, -1, 1)
        mask = mask.contiguous().view(-1)
        squares = squares[mask]
        anchor_deltas = shape_pred.permute(1, 2, 0).contiguous().view(
            -1, 2).detach()[mask]
        # Only (dw, dh) are predicted; (dx, dy) stay zero.
        bbox_deltas = anchor_deltas.new_full(squares.bbox.size(), 0)
        bbox_deltas[:, 2:] = anchor_deltas
        guided_anchors = self.anchor_box_coder.decode(bbox_deltas,
                                                      squares.bbox)
        return guided_anchors, mask

    def get_sampled_approxs(self, image_list, feature_maps):
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        approx_anchors_over_all_feature_maps = self.grid_approx_anchors(
            grid_sizes)
        approx_anchors = []
        for i, (image_height,
                image_width) in enumerate(image_list.image_sizes):
            approx_anchors_in_image = []
            for approx_anchors_per_feature_map in approx_anchors_over_all_feature_maps:
                boxlist = BoxList(approx_anchors_per_feature_map,
                                  (image_width, image_height),
                                  mode="xyxy")
                self.add_visibility_to(boxlist)
                approx_anchors_in_image.append(boxlist)
            approx_anchors.append(approx_anchors_in_image)
        return approx_anchors
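# A toy illustration (made-up values; it assumes the same BoxCoder as above)
# of the guided-anchor delta layout in get_guided_anchors: only the
# width/height deltas are predicted, so dx/dy are zeroed before decoding.
import torch
from maskrcnn_benchmark.modeling.box_coder import BoxCoder

coder = BoxCoder(weights=(10., 10., 5., 5.))
squares = torch.tensor([[0., 0., 63., 63.]])   # one square anchor
anchor_deltas = torch.tensor([[0.3, -0.2]])    # predicted (dw, dh)
bbox_deltas = squares.new_zeros(squares.size())
bbox_deltas[:, 2:] = anchor_deltas
guided = coder.decode(bbox_deltas, squares)    # same center, rescaled w/h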