def __call__(self, box_cls, box_regression, centerness, targets, anchors):
    """Compute the ATSS classification, box-regression and centerness losses.

    Arguments:
        box_cls (list[Tensor]): per-level classification outputs.
        box_regression (list[Tensor]): per-level box-regression outputs.
        centerness (list[Tensor]): per-level centerness outputs; each is
            permuted with ``(0, 2, 3, 1)`` and reshaped to ``(N, -1, 1)``
            below (presumably ``(N, 1, H, W)`` on input — TODO confirm).
        targets (list[BoxList]): ground-truth boxes, one BoxList per image.
        anchors (list[list[BoxList]]): per image, one BoxList per level.

    Returns:
        dict with ``cls_loss``, ``reg_loss`` and ``centerness_loss``, or the
        result of ``self.guided_loss`` when ``self.sampling_free`` is set.
    """
    labels, reg_targets = self.prepare_targets(targets, anchors)

    N = len(labels)
    box_cls_flatten, box_regression_flatten = concat_box_prediction_layers(
        box_cls, box_regression)
    centerness_flatten = [
        ct.permute(0, 2, 3, 1).reshape(N, -1, 1) for ct in centerness
    ]
    centerness_flatten = torch.cat(centerness_flatten, dim=1).reshape(-1)

    labels_flatten = torch.cat(labels, dim=0)
    reg_targets_flatten = torch.cat(reg_targets, dim=0)
    anchors_flatten = torch.cat([
        cat_boxlist(anchors_per_image).bbox for anchors_per_image in anchors
    ], dim=0)

    pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1)

    num_gpus = get_num_gpus()
    # Normalise by the number of positives summed over all processes and
    # averaged per process (at least 1 to avoid dividing by zero).
    total_num_pos = reduce_sum(pos_inds.new_tensor([pos_inds.numel()])).item()
    num_pos_avg_per_gpu = max(total_num_pos / float(num_gpus), 1.0)

    cls_loss = self.cls_loss_func(
        box_cls_flatten, labels_flatten.int()) / num_pos_avg_per_gpu

    if pos_inds.numel() > 0:
        box_regression_flatten = box_regression_flatten[pos_inds]
        reg_targets_flatten = reg_targets_flatten[pos_inds]
        anchors_flatten = anchors_flatten[pos_inds]
        centerness_flatten = centerness_flatten[pos_inds]
        centerness_targets = self.compute_centerness_targets(
            reg_targets_flatten, anchors_flatten)

        # The regression loss is weighted by the centerness targets and
        # normalised by their cross-process sum averaged per process.
        sum_centerness_targets_avg_per_gpu = reduce_sum(
            centerness_targets.sum()).item() / float(num_gpus)
        reg_loss = self.GIoULoss(
            box_regression_flatten, reg_targets_flatten, anchors_flatten,
            weight=centerness_targets) / sum_centerness_targets_avg_per_gpu
        centerness_loss = self.centerness_loss_func(
            centerness_flatten, centerness_targets) / num_pos_avg_per_gpu
    else:
        # ``reduce_sum`` is a cross-process reduction (its result is divided
        # by ``num_gpus``). The branch above calls it once more, so mirror
        # that call here: otherwise processes with and without positives
        # issue different collective sequences and training hangs.
        reduce_sum(reg_targets_flatten.new_zeros(()))
        # Zero-valued losses that still reach both heads' outputs, so every
        # parameter stays in the graph without receiving a spurious gradient.
        reg_loss = box_regression_flatten.sum() * 0
        centerness_loss = centerness_flatten.sum() * 0

    reg_loss = self.reg_loss_weight * reg_loss

    if self.sampling_free:
        return self.guided_loss(
            [cls_loss, reg_loss, centerness_loss],
            ["cls_loss", "reg_loss", "centerness_loss"])
    else:
        return dict(cls_loss=cls_loss,
                    reg_loss=reg_loss,
                    centerness_loss=centerness_loss)
def forward(self, anchors, objectness, box_regression, targets=None):
    """Decode and filter RPN proposals for every image.

    Arguments:
        anchors: list[list[BoxList]] — per image, one BoxList per level
        objectness: list[tensor] — one entry per feature level
        box_regression: list[tensor] — one entry per feature level
        targets: optional list[BoxList]; when given during training the
            ground-truth boxes are appended to the proposals

    Returns:
        boxlists (list[BoxList]): the post-processed anchors, after
            applying box decoding and NMS
    """
    # Transpose image-major anchors to level-major so they pair up with the
    # per-level prediction tensors.
    anchors_by_level = list(zip(*anchors))
    per_level_results = [
        self.forward_for_single_feature_map(level_anchors, level_obj, level_reg)
        for level_anchors, level_obj, level_reg
        in zip(anchors_by_level, objectness, box_regression)
    ]

    # Transpose back to image-major and merge each image's levels.
    boxlists = [cat_boxlist(per_image) for per_image in zip(*per_level_results)]

    # Cross-level selection only makes sense with several levels.
    if len(objectness) > 1:
        boxlists = self.select_over_all_levels(boxlists)

    # append ground-truth bboxes to proposals
    if self.training and targets is not None:
        boxlists = self.add_gt_proposals(boxlists, targets)

    return boxlists
def prepare_iou_based_targets(self, targets, anchors):
    """Label every anchor by IoU matching against the ground truth.

    Returns three per-image lists: float32 anchor labels (0 for anchors
    below the low threshold, -1 for anchors between thresholds, otherwise
    the matched GT label), ``self.box_coder``-encoded regression targets,
    and the matcher's indices reshaped to ``(1, num_anchors)``.
    """
    cls_labels, reg_targets, matched_idx_all = [], [], []

    for im_i, targets_per_im in enumerate(targets):
        assert targets_per_im.mode == "xyxy"
        anchors_per_im = cat_boxlist(anchors[im_i])

        # The matcher yields a pair here; only the indices are needed.
        matched_idxs, _ = self.matcher(boxlist_iou(targets_per_im, anchors_per_im))

        labelled_targets = targets_per_im.copy_with_fields(['labels'])
        matched_targets = labelled_targets[matched_idxs.clamp(min=0)]
        labels_per_im = matched_targets.get_field("labels").to(dtype=torch.float32)

        below_low = matched_idxs == Matcher.BELOW_LOW_THRESHOLD
        between = matched_idxs == Matcher.BETWEEN_THRESHOLDS
        labels_per_im[below_low] = 0    # background
        labels_per_im[between] = -1     # ignored during training

        matched_idx_all.append(matched_idxs.view(1, -1))
        encoded = self.box_coder.encode(matched_targets.bbox, anchors_per_im.bbox)
        cls_labels.append(labels_per_im)
        reg_targets.append(encoded)

    return cls_labels, reg_targets, matched_idx_all
def __call__(self, anchors, objectness, box_regression, targets):
    """Compute the RPN objectness and box-regression losses.

    Arguments:
        anchors (list[list[BoxList]]): per image, the anchors of every
            feature level (concatenated per image with ``cat_boxlist``).
        objectness (list[Tensor])
        box_regression (list[Tensor])
        targets (list[BoxList])

    Returns:
        dict with ``rpn_obj_loss`` and ``rpn_loc_loss`` (Tensor).
    """
    anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors]
    labels, regression_targets = self.prepare_targets(anchors, targets)

    if not self.sampling_free:
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
        sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)
        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

    objectness, box_regression = \
        concat_box_prediction_layers(objectness, box_regression)
    objectness = objectness.squeeze()

    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)

    if self.sampling_free:
        # Positives (label > 0) drive the box loss; every non-ignored anchor
        # (label >= 0) drives the objectness loss.
        positive, valid = labels > 0, labels >= 0
        if positive.any():
            rpn_loc_loss = 0.5 * smooth_l1_loss(
                box_regression[positive],
                regression_targets[positive],
                beta=1.0 / 9,
                size_average=True,
            )
        else:
            # Averaging over an empty selection would give NaN; use a zero
            # that stays attached to the graph instead.
            rpn_loc_loss = box_regression.sum() * 0
        # clamp(min=1): a batch without positives must not divide by zero.
        rpn_obj_loss = self.ce_loss(
            objectness[valid].view(-1, 1), labels[valid].int().view(-1, 1)
        ) / positive.sum().clamp(min=1)
        # Rescale the objectness loss to the box loss's magnitude; the ratio
        # is treated as a constant. With no positives this yields zero.
        with torch.no_grad():
            ratio = rpn_loc_loss / rpn_obj_loss
        rpn_obj_loss = ratio * rpn_obj_loss
    else:
        rpn_loc_loss = smooth_l1_loss(
            box_regression[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            beta=1.0 / 9,
            size_average=False,
        ) / max(sampled_inds.numel(), 1)

        rpn_obj_loss = F.binary_cross_entropy_with_logits(
            objectness[sampled_inds], labels[sampled_inds]
        )

    return dict(rpn_obj_loss=rpn_obj_loss, rpn_loc_loss=rpn_loc_loss)
def forward(self, box_cls, box_regression, centerness, anchors):
    """Decode per-level ATSS predictions into per-image detections.

    ``anchors`` holds, per image, one BoxList per level; it is transposed to
    level-major order to pair with the per-level prediction lists. Cross-level
    selection is skipped only when box augmentation is enabled without voting.
    """
    anchors_by_level = list(zip(*anchors))

    level_results = []
    for logits, deltas, ctr, level_anchors in zip(
            box_cls, box_regression, centerness, anchors_by_level):
        level_results.append(
            self.forward_for_single_feature_map(logits, deltas, ctr, level_anchors))

    boxlists = [cat_boxlist(per_image) for per_image in zip(*level_results)]

    run_selection = (not self.bbox_aug_enabled) or self.bbox_aug_vote
    if run_selection:
        boxlists = self.select_over_all_levels(boxlists)
    return boxlists
def merge_result_from_multi_scales(boxlists, nms_type='nms', vote_thresh=0.65):
    """Merge per-image detections collected at several test scales.

    For each image, runs per-class NMS (class 0 is skipped as background),
    then keeps at most ``cfg.MODEL.ATSS.PRE_NMS_TOP_N`` detections over all
    classes.

    Arguments:
        boxlists (list[BoxList]): one BoxList per image with ``"scores"`` and
            ``"labels"`` fields.
        nms_type (str): forwarded to ``boxlist_nms``.
        vote_thresh (float): forwarded to ``boxlist_nms``.

    Returns:
        list[BoxList]: merged detections, one per image, each carrying
        ``"scores"`` and ``"labels"`` fields.
    """
    # NOTE: the field/keyword previously read "ssampling_frees" /
    # "ssampling_free_field" (a mangled rename of "scores"); the rest of this
    # file uses the "scores" field and ``boxlist_nms(..., score_field=...)``.
    num_images = len(boxlists)
    results = []
    for i in range(num_images):
        scores = boxlists[i].get_field("scores")
        labels = boxlists[i].get_field("labels")
        boxes = boxlists[i].bbox
        boxlist = boxlists[i]
        result = []
        # skip the background
        for j in range(1, cfg.MODEL.RETINANET.NUM_CLASSES):
            inds = (labels == j).nonzero().view(-1)

            scores_j = scores[inds]
            boxes_j = boxes[inds, :].view(-1, 4)
            boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
            boxlist_for_class.add_field("scores", scores_j)
            boxlist_for_class = boxlist_nms(
                boxlist_for_class, cfg.MODEL.ATSS.NMS_TH,
                score_field="scores",
                nms_type=nms_type,
                vote_thresh=vote_thresh)
            num_labels = len(boxlist_for_class)
            boxlist_for_class.add_field(
                "labels",
                torch.full((num_labels, ), j, dtype=torch.int64,
                           device=scores.device))
            result.append(boxlist_for_class)

        result = cat_boxlist(result)
        number_of_detections = len(result)

        # Limit to max_per_image detections **over all classes**
        if number_of_detections > cfg.MODEL.ATSS.PRE_NMS_TOP_N > 0:
            cls_scores = result.get_field("scores")
            image_thresh, _ = torch.kthvalue(
                cls_scores.cpu(),
                number_of_detections - cfg.MODEL.ATSS.PRE_NMS_TOP_N + 1)
            keep = cls_scores >= image_thresh.item()
            keep = torch.nonzero(keep).squeeze(1)
            result = result[keep]
        results.append(result)
    return results
def select_over_all_levels(self, boxlists):
    """Apply per-class NMS and cap the detection count for each image.

    Class 0 is treated as background and skipped. After NMS, at most
    ``self.fpn_post_nms_top_n`` detections are kept per image (when that
    limit is positive), selected by a score threshold from ``kthvalue``.
    """
    results = []
    for boxlist in boxlists:
        all_scores = boxlist.get_field("scores")
        all_labels = boxlist.get_field("labels")
        all_boxes = boxlist.bbox

        per_class = []
        for cls_id in range(1, self.num_classes):
            idx = (all_labels == cls_id).nonzero().view(-1)
            class_scores = all_scores[idx]
            class_boxes = all_boxes[idx, :].view(-1, 4)

            class_boxlist = BoxList(class_boxes, boxlist.size, mode="xyxy")
            class_boxlist.add_field("scores", class_scores)
            class_boxlist = boxlist_nms(class_boxlist, self.nms_thresh,
                                        score_field="scores")
            class_boxlist.add_field(
                "labels",
                torch.full((len(class_boxlist), ), cls_id,
                           dtype=torch.int64, device=all_scores.device))
            per_class.append(class_boxlist)

        merged = cat_boxlist(per_class)
        num_dets = len(merged)

        if num_dets > self.fpn_post_nms_top_n > 0:
            merged_scores = merged.get_field("scores")
            kth = num_dets - self.fpn_post_nms_top_n + 1
            image_thresh, _ = torch.kthvalue(merged_scores.cpu(), kth)
            kept = torch.nonzero(merged_scores >= image_thresh.item()).squeeze(1)
            merged = merged[kept]

        results.append(merged)
    return results
def filter_results(self, boxlist, num_classes):
    """Threshold per-class scores, run per-class NMS and cap detections.

    ``boxlist.bbox`` is read as ``(-1, num_classes * 4)`` and its ``"scores"``
    field as ``(-1, num_classes)``. Class 0 (background) is skipped. At most
    ``self.detections_per_img`` results are kept when that limit is positive.
    """
    # Work on raw tensors rather than the BoxList to avoid per-class overhead.
    boxes = boxlist.bbox.reshape(-1, num_classes * 4)
    scores = boxlist.get_field("scores").reshape(-1, num_classes)
    device = scores.device

    above_thresh = scores > self.score_thresh
    per_class = []
    for cls_id in range(1, num_classes):
        keep_idx = above_thresh[:, cls_id].nonzero().squeeze(1)
        col = cls_id * 4
        class_boxlist = BoxList(boxes[keep_idx, col:col + 4], boxlist.size,
                                mode="xyxy")
        class_boxlist.add_field("scores", scores[keep_idx, cls_id])
        class_boxlist = boxlist_nms(class_boxlist, self.nms)
        class_boxlist.add_field(
            "labels",
            torch.full((len(class_boxlist), ), cls_id, dtype=torch.int64,
                       device=device))
        per_class.append(class_boxlist)

    result = cat_boxlist(per_class)
    num_dets = len(result)

    if num_dets > self.detections_per_img > 0:
        result_scores = result.get_field("scores")
        kth = num_dets - self.detections_per_img + 1
        image_thresh, _ = torch.kthvalue(result_scores.cpu(), kth)
        kept = torch.nonzero(result_scores >= image_thresh.item()).squeeze(1)
        result = result[kept]
    return result
def __call__(self, anchors, box_cls, box_regression, targets):
    """Compute the RetinaNet classification and box-regression losses.

    Arguments:
        anchors (list[list[BoxList]]): per image, the anchors of every
            feature level (concatenated per image with ``cat_boxlist``).
        box_cls (list[Tensor])
        box_regression (list[Tensor])
        targets (list[BoxList])

    Returns:
        dict with ``cls_loss`` and ``loc_loss`` (Tensor), or the result of
        ``self.guided_loss`` when ``self.sampling_free`` is set.
    """
    anchors = [
        cat_boxlist(anchors_per_image) for anchors_per_image in anchors
    ]
    labels, regression_targets = self.prepare_targets(anchors, targets)

    box_cls, box_regression = concat_box_prediction_layers(
        box_cls, box_regression)

    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)

    pos_inds = torch.nonzero(labels > 0).squeeze(1)
    pos_numel = pos_inds.numel()

    # Both losses are normalised by the positive count; max(1, ...) keeps a
    # batch without positives from dividing by zero.
    loc_loss = smooth_l1_loss(
        box_regression[pos_inds],
        regression_targets[pos_inds],
        beta=self.bbox_reg_beta,
        size_average=False,
    ) / max(1, pos_numel * self.regress_norm)

    cls_loss = self.box_cls_loss_func(box_cls, labels.int()) / max(
        1, pos_numel)

    if self.sampling_free:
        return self.guided_loss([cls_loss, loc_loss],
                                ["cls_loss", "loc_loss"])
    else:
        return dict(cls_loss=cls_loss, loc_loss=loc_loss)
def add_gt_proposals(self, proposals, targets):
    """Append each image's ground-truth boxes to that image's proposals.

    Arguments:
        proposals: list[BoxList]
        targets: list[BoxList]

    Returns:
        list[BoxList]: per image, the proposals followed by the GT boxes.
    """
    device = proposals[0].bbox.device

    gt_boxes = []
    for target in targets:
        gt_box = target.copy_with_fields([])
        # cat_boxlist needs the same fields on every input, so give the GT
        # boxes a dummy all-ones "objectness" field matching the proposals.
        gt_box.add_field("objectness", torch.ones(len(gt_box), device=device))
        gt_boxes.append(gt_box)

    return [cat_boxlist((proposal, gt_box))
            for proposal, gt_box in zip(proposals, gt_boxes)]
def forward(self, locations, box_cls, box_regression, centerness, image_sizes):
    """Decode per-level FCOS predictions into per-image detections.

    Arguments:
        locations: per-level locations (one entry per feature level)
        box_cls: list[tensor], one entry per feature level
        box_regression: list[tensor], one entry per feature level
        centerness: list[tensor], one entry per feature level
        image_sizes: list[(h, w)]

    Returns:
        boxlists (list[BoxList]): per-image detections after box decoding;
            cross-level selection/NMS is applied unless box augmentation
            is enabled.
    """
    per_level = [
        self.forward_for_single_feature_map(loc, logits, deltas, ctr, image_sizes)
        for loc, logits, deltas, ctr
        in zip(locations, box_cls, box_regression, centerness)
    ]
    boxlists = [cat_boxlist(per_image) for per_image in zip(*per_level)]

    if self.bbox_aug_enabled:
        return boxlists
    return self.select_over_all_levels(boxlists)
def prepare_targets(self, targets, anchors):
    """Assign a class label and a regression target to every anchor.

    The strategy is selected by ``self.positive_type``:

    * ``'SSC'``: an anchor centre is matched to the smallest-area GT box that
      contains it and whose largest l/t/r/b distance lies in the anchor
      level's size range; otherwise it is background (0).
    * ``'ATSS'``: per level, the ``self.topk`` anchors (or all anchors of a
      smaller level) closest by centre distance to each GT are candidates.
      A candidate is positive if its IoU is >= mean + std of that GT's
      candidate IoUs and its centre lies inside the GT; an anchor positive
      for several GTs takes the one with the highest IoU.
    * ``'IoU'``: ``self.matcher`` thresholds the GT/anchor IoU matrix.
      Between-threshold anchors, and positives whose centre falls outside
      the matched GT, get label -1.

    Arguments:
        targets (list[BoxList]): GT boxes in ``xyxy`` mode with ``labels``.
        anchors (list[list[BoxList]]): per image, one BoxList per level.

    Returns:
        tuple[list[Tensor], list[Tensor]]: per image, the anchor labels
        (0 = background) and the ``self.box_coder``-encoded regression
        targets.

    Raises:
        NotImplementedError: for an unknown ``self.positive_type``.
    """
    cls_labels = []
    reg_targets = []
    for im_i in range(len(targets)):
        targets_per_im = targets[im_i]
        assert targets_per_im.mode == "xyxy"
        bboxes_per_im = targets_per_im.bbox
        labels_per_im = targets_per_im.get_field("labels")
        anchors_per_im = cat_boxlist(anchors[im_i])
        num_gt = bboxes_per_im.shape[0]

        if num_gt == 0 and self.positive_type in ('SSC', 'ATSS', 'IoU'):
            # No GT boxes: every anchor is background. The branches below fail
            # on an empty GT set (reductions over a zero-size dim, view(0, -1)).
            # Regression targets are presumably read only at positive anchors
            # (as the loss __call__ does), so zeros suffice here.
            num_anchors = anchors_per_im.bbox.shape[0]
            cls_labels_per_im = labels_per_im.new_zeros(num_anchors)
            if self.positive_type == 'IoU':
                # match the dtype produced by the 'IoU' branch
                cls_labels_per_im = cls_labels_per_im.to(dtype=torch.float32)
            cls_labels.append(cls_labels_per_im)
            reg_targets.append(torch.zeros_like(anchors_per_im.bbox))
            continue

        if self.positive_type == 'SSC':
            object_sizes_of_interest = [[-1, 64], [64, 128], [128, 256],
                                        [256, 512], [512, INF]]
            area_per_im = targets_per_im.area()
            expanded_object_sizes_of_interest = []
            points = []
            for level, anchors_per_level in enumerate(anchors[im_i]):
                anchors_per_level = anchors_per_level.bbox
                anchors_cx_per_level = (anchors_per_level[:, 2] + anchors_per_level[:, 0]) / 2.0
                anchors_cy_per_level = (anchors_per_level[:, 3] + anchors_per_level[:, 1]) / 2.0
                points_per_level = torch.stack(
                    (anchors_cx_per_level, anchors_cy_per_level), dim=1)
                points.append(points_per_level)
                object_sizes_of_interest_per_level = \
                    points_per_level.new_tensor(object_sizes_of_interest[level])
                expanded_object_sizes_of_interest.append(
                    object_sizes_of_interest_per_level[None].expand(
                        len(points_per_level), -1))
            expanded_object_sizes_of_interest = torch.cat(
                expanded_object_sizes_of_interest, dim=0)
            points = torch.cat(points, dim=0)

            # Distances from each anchor centre to the four sides of each GT.
            xs, ys = points[:, 0], points[:, 1]
            left = xs[:, None] - bboxes_per_im[:, 0][None]
            top = ys[:, None] - bboxes_per_im[:, 1][None]
            right = bboxes_per_im[:, 2][None] - xs[:, None]
            bottom = bboxes_per_im[:, 3][None] - ys[:, None]
            ltrb_per_im = torch.stack([left, top, right, bottom], dim=2)

            is_in_boxes = ltrb_per_im.min(dim=2)[0] > 0.01

            max_ltrb_per_im = ltrb_per_im.max(dim=2)[0]
            is_cared_in_the_level = \
                (max_ltrb_per_im >= expanded_object_sizes_of_interest[:, [0]]) & \
                (max_ltrb_per_im <= expanded_object_sizes_of_interest[:, [1]])

            # Among admissible GTs, pick the one with the smallest area.
            locations_to_gt_area = area_per_im[None].repeat(len(points), 1)
            locations_to_gt_area[is_in_boxes == 0] = INF
            locations_to_gt_area[is_cared_in_the_level == 0] = INF
            locations_to_min_area, locations_to_gt_inds = \
                locations_to_gt_area.min(dim=1)

            cls_labels_per_im = labels_per_im[locations_to_gt_inds]
            cls_labels_per_im[locations_to_min_area == INF] = 0
            matched_gts = bboxes_per_im[locations_to_gt_inds]

        elif self.positive_type == 'ATSS':
            num_anchors_per_level = [
                len(anchors_per_level.bbox) for anchors_per_level in anchors[im_i]
            ]
            ious = boxlist_iou(anchors_per_im, targets_per_im)

            gt_cx = (bboxes_per_im[:, 2] + bboxes_per_im[:, 0]) / 2.0
            gt_cy = (bboxes_per_im[:, 3] + bboxes_per_im[:, 1]) / 2.0
            gt_points = torch.stack((gt_cx, gt_cy), dim=1)

            anchors_cx_per_im = (anchors_per_im.bbox[:, 2] + anchors_per_im.bbox[:, 0]) / 2.0
            anchors_cy_per_im = (anchors_per_im.bbox[:, 3] + anchors_per_im.bbox[:, 1]) / 2.0
            anchor_points = torch.stack((anchors_cx_per_im, anchors_cy_per_im), dim=1)

            distances = (anchor_points[:, None, :] - gt_points[None, :, :]).pow(2).sum(-1).sqrt()

            # Selecting candidates based on the center distance between anchor box and object
            candidate_idxs = []
            star_idx = 0
            for level, anchors_per_level in enumerate(anchors[im_i]):
                end_idx = star_idx + num_anchors_per_level[level]
                distances_per_level = distances[star_idx:end_idx, :]
                # topk() fails if k exceeds the number of anchors on a level.
                topk = min(self.topk, num_anchors_per_level[level])
                _, topk_idxs_per_level = distances_per_level.topk(
                    topk, dim=0, largest=False)
                candidate_idxs.append(topk_idxs_per_level + star_idx)
                star_idx = end_idx
            candidate_idxs = torch.cat(candidate_idxs, dim=0)

            # Using the sum of mean and standard deviation as the IoU threshold to select final positive samples
            candidate_ious = ious[candidate_idxs, torch.arange(num_gt)]
            iou_mean_per_gt = candidate_ious.mean(0)
            iou_std_per_gt = candidate_ious.std(0)
            iou_thresh_per_gt = iou_mean_per_gt + iou_std_per_gt
            is_pos = candidate_ious >= iou_thresh_per_gt[None, :]

            # Limiting the final positive samples' center to object.
            # Offset candidate indices per GT so they index a GT-major
            # (num_gt * anchor_num) flattened layout.
            anchor_num = anchors_cx_per_im.shape[0]
            for ng in range(num_gt):
                candidate_idxs[:, ng] += ng * anchor_num
            e_anchors_cx = anchors_cx_per_im.view(1, -1).expand(
                num_gt, anchor_num).contiguous().view(-1)
            e_anchors_cy = anchors_cy_per_im.view(1, -1).expand(
                num_gt, anchor_num).contiguous().view(-1)
            candidate_idxs = candidate_idxs.view(-1)
            left = e_anchors_cx[candidate_idxs].view(-1, num_gt) - bboxes_per_im[:, 0]
            top = e_anchors_cy[candidate_idxs].view(-1, num_gt) - bboxes_per_im[:, 1]
            right = bboxes_per_im[:, 2] - e_anchors_cx[candidate_idxs].view(-1, num_gt)
            bottom = bboxes_per_im[:, 3] - e_anchors_cy[candidate_idxs].view(-1, num_gt)
            is_in_gts = torch.stack([left, top, right, bottom], dim=1).min(dim=1)[0] > 0.01
            is_pos = is_pos & is_in_gts

            # if an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.
            ious_inf = torch.full_like(ious, -INF).t().contiguous().view(-1)
            index = candidate_idxs.view(-1)[is_pos.view(-1)]
            ious_inf[index] = ious.t().contiguous().view(-1)[index]
            ious_inf = ious_inf.view(num_gt, -1).t()

            anchors_to_gt_values, anchors_to_gt_indexs = ious_inf.max(dim=1)
            cls_labels_per_im = labels_per_im[anchors_to_gt_indexs]
            cls_labels_per_im[anchors_to_gt_values == -INF] = 0
            matched_gts = bboxes_per_im[anchors_to_gt_indexs]

        elif self.positive_type == 'IoU':
            match_quality_matrix = boxlist_iou(targets_per_im, anchors_per_im)
            matched_idxs = self.matcher(match_quality_matrix)
            if isinstance(matched_idxs, tuple):
                # Elsewhere in this code base the Matcher result is unpacked
                # as ``matched_idxs, _``; keep only the indices in that case.
                matched_idxs = matched_idxs[0]
            targets_per_im = targets_per_im.copy_with_fields(['labels'])
            matched_targets = targets_per_im[matched_idxs.clamp(min=0)]

            cls_labels_per_im = matched_targets.get_field("labels")
            cls_labels_per_im = cls_labels_per_im.to(dtype=torch.float32)

            # Background (negative examples)
            bg_indices = matched_idxs == Matcher.BELOW_LOW_THRESHOLD
            cls_labels_per_im[bg_indices] = 0

            # discard indices that are between thresholds
            inds_to_discard = matched_idxs == Matcher.BETWEEN_THRESHOLDS
            cls_labels_per_im[inds_to_discard] = -1

            matched_gts = matched_targets.bbox

            # Limiting positive samples' center to object
            # in order to filter out poor positives and use the centerness branch
            pos_idxs = torch.nonzero(cls_labels_per_im > 0).squeeze(1)
            pos_anchors_cx = (anchors_per_im.bbox[pos_idxs, 2] + anchors_per_im.bbox[pos_idxs, 0]) / 2.0
            pos_anchors_cy = (anchors_per_im.bbox[pos_idxs, 3] + anchors_per_im.bbox[pos_idxs, 1]) / 2.0
            left = pos_anchors_cx - matched_gts[pos_idxs, 0]
            top = pos_anchors_cy - matched_gts[pos_idxs, 1]
            right = matched_gts[pos_idxs, 2] - pos_anchors_cx
            bottom = matched_gts[pos_idxs, 3] - pos_anchors_cy
            is_in_gts = torch.stack([left, top, right, bottom], dim=1).min(dim=1)[0] > 0.01
            cls_labels_per_im[pos_idxs[is_in_gts == 0]] = -1
        else:
            raise NotImplementedError

        reg_targets_per_im = self.box_coder.encode(matched_gts, anchors_per_im.bbox)
        cls_labels.append(cls_labels_per_im)
        reg_targets.append(reg_targets_per_im)

    return cls_labels, reg_targets
def __call__(self, box_cls, box_regression, iou_pred, targets, anchors, locations):
    """Compute the PAA classification, box-regression and IoU-prediction losses.

    Anchors are first labelled by IoU (``prepare_iou_based_targets``). Their
    combined classification + regression losses are then passed to
    ``compute_paa``, which picks the final positive anchors.

    Arguments:
        box_cls (list[Tensor]): per-level classification outputs.
        box_regression (list[Tensor]): per-level box-regression outputs.
        iou_pred (list[Tensor]): per-level IoU predictions; each is permuted
            with ``(0, 2, 3, 1)`` and reshaped to ``(N, -1, 1)`` below.
        targets (list[BoxList]): ground-truth boxes, one BoxList per image.
        anchors (list[list[BoxList]]): per image, one BoxList per level.
        locations: not used by this method.

    Returns:
        dict with ``cls_loss``, ``reg_loss`` and ``iou_pred_loss``, or the
        result of ``self.guided_loss`` when ``self.sampling_free`` is set.
    """
    # get IoU-based anchor assignment first to compute anchor scores
    (iou_based_labels, iou_based_reg_targets,
     matched_idx_all) = self.prepare_iou_based_targets(targets, anchors)
    matched_idx_all = torch.cat(matched_idx_all, dim=0)

    N = len(iou_based_labels)
    iou_based_labels_flatten = torch.cat(iou_based_labels, dim=0).int()
    iou_based_reg_targets_flatten = torch.cat(iou_based_reg_targets, dim=0)
    box_cls_flatten, box_regression_flatten = concat_box_prediction_layers(
        box_cls, box_regression)
    anchors_flatten = torch.cat(
        [cat_boxlist(anchors_per_image).bbox for anchors_per_image in anchors],
        dim=0)
    iou_pred_flatten = [ip.permute(0, 2, 3, 1).reshape(N, -1, 1) for ip in iou_pred]
    iou_pred_flatten = torch.cat(iou_pred_flatten, dim=1).reshape(-1)

    pos_inds = torch.nonzero(iou_based_labels_flatten > 0, as_tuple=False).squeeze(1)

    if pos_inds.numel() > 0:
        n_loss_per_box = 1 if 'iou' in self.reg_loss_type else 4

        # compute anchor scores (losses) for all anchors
        iou_based_cls_loss = self.cls_loss_func(
            box_cls_flatten.detach(), iou_based_labels_flatten, sum=False)
        iou_based_reg_loss = self.compute_reg_loss(
            iou_based_reg_targets_flatten, box_regression_flatten.detach(),
            anchors_flatten, iou_based_labels_flatten, weights=None)
        # Rescale cls losses so both terms have the same total magnitude.
        iou_based_cls_loss *= iou_based_reg_loss.sum() / iou_based_cls_loss.sum()
        # Non-positive anchors get an INF regression score.
        iou_based_reg_loss_full = torch.full(
            (iou_based_cls_loss.shape[0],), fill_value=INF,
            device=iou_based_cls_loss.device, dtype=iou_based_cls_loss.dtype)
        iou_based_reg_loss_full[pos_inds] = \
            iou_based_reg_loss.view(-1, n_loss_per_box).mean(1)
        combined_loss = iou_based_cls_loss.sum(dim=1) + iou_based_reg_loss_full
        assert not torch.isnan(combined_loss).any()

        # compute labels and targets using PAA
        labels, reg_targets = self.compute_paa(
            targets, anchors, iou_based_labels_flatten.view(N, -1),
            combined_loss.view(N, -1), matched_idx_all)
        labels_flatten = torch.cat(labels, dim=0).int()
        reg_targets_flatten = torch.cat(reg_targets, dim=0)
    else:
        # No IoU-based positives: every anchor is background, so the
        # regression / IoU branch below is skipped entirely.
        labels_flatten = torch.zeros_like(iou_based_labels_flatten)

    pos_inds = torch.nonzero(labels_flatten > 0, as_tuple=False).squeeze(1)
    total_num_pos = reduce_sum(pos_inds.new_tensor([pos_inds.numel()])).item()
    num_pos_avg_per_gpu = max(total_num_pos / self.num_gpus, 1.0)

    cls_loss = self.cls_loss_func(box_cls_flatten, labels_flatten.int(), sum=False)

    if pos_inds.numel() > 0:
        box_regression_flatten = box_regression_flatten[pos_inds]
        reg_targets_flatten = reg_targets_flatten[pos_inds]
        anchors_flatten = anchors_flatten[pos_inds]

        # compute iou prediction targets
        iou_pred_flatten = iou_pred_flatten[pos_inds]
        gt_boxes = self.box_coder.decode(reg_targets_flatten, anchors_flatten)
        boxes = self.box_coder.decode(box_regression_flatten, anchors_flatten).detach()
        ious = self.compute_ious(gt_boxes, boxes)

        # compute iou losses
        iou_pred_loss = self.iou_pred_loss_func(
            iou_pred_flatten, ious) / num_pos_avg_per_gpu * self.iou_loss_weight

        sum_ious_targets_avg_per_gpu = reduce_sum(ious.sum()).item() / self.num_gpus

        # set regression loss weights to ious between predicted boxes and GTs
        reg_loss_weight = ious
        reg_loss = self.compute_reg_loss(
            reg_targets_flatten, box_regression_flatten, anchors_flatten,
            labels_flatten[pos_inds], weights=reg_loss_weight)
    else:
        # ``reduce_sum`` is a cross-process reduction (results are divided by
        # ``self.num_gpus``). Mirror the call made in the branch above so that
        # every process issues the same collective sequence; previously this
        # path also left cls/iou losses and normalisers undefined.
        reduce_sum(box_regression_flatten.new_zeros(()))
        sum_ious_targets_avg_per_gpu = 1.0
        # Zero-valued losses that keep both heads' outputs in the graph.
        reg_loss = box_regression_flatten.sum() * 0
        iou_pred_loss = iou_pred_flatten.sum() * 0

    res = [cls_loss.sum() / num_pos_avg_per_gpu,
           reg_loss.sum() / sum_ious_targets_avg_per_gpu * self.reg_loss_weight,
           iou_pred_loss]

    if self.sampling_free:
        return self.guided_loss(res, ["cls_loss", "reg_loss", "iou_pred_loss"])
    else:
        return {"cls_loss": res[0], "reg_loss": res[1], "iou_pred_loss": res[2]}
def compute_paa(self, targets, anchors, labels_all, loss_all, matched_idx_all):
    """Select positive anchors per GT box by fitting a 2-component GMM to losses.

    For every GT box, candidates are the IoU-positive anchors matched to it,
    taking on each level at most ``self.topk`` with the lowest loss. A
    two-component Gaussian mixture is fitted to the candidates' losses: the
    component initialised at the minimum loss is treated as foreground.

    Args:
        targets (batch_size): list of BoxLists for GT bboxes
        anchors (batch_size, feature_lvls): anchor boxes per feature level
        labels_all (batch_size x num_anchors): assigned labels
        loss_all (batch_size x num_anchors): calculated loss
        matched_idx_all (batch_size x num_anchors): best-matched GT bbox indexes

    Returns:
        tuple[list[Tensor], list[Tensor]]: per image, long anchor labels
        (0 = background) and ``self.box_coder``-encoded regression targets.
    """
    device = loss_all.device
    cls_labels = []
    reg_targets = []
    for im_i in range(len(targets)):
        targets_per_im = targets[im_i]
        assert targets_per_im.mode == "xyxy"
        bboxes_per_im = targets_per_im.bbox
        labels_per_im = targets_per_im.get_field("labels")
        anchors_per_im = cat_boxlist(anchors[im_i])
        labels_all_per_im = labels_all[im_i]
        loss_all_per_im = loss_all[im_i]
        matched_idx_all_per_im = matched_idx_all[im_i]
        assert labels_all_per_im.shape == matched_idx_all_per_im.shape

        num_anchors_per_level = [len(anchors_per_level.bbox)
                                 for anchors_per_level in anchors[im_i]]

        # select candidates based on IoUs between anchors and GTs
        candidate_idxs = []
        num_gt = bboxes_per_im.shape[0]
        for gt in range(num_gt):
            candidate_idxs_per_gt = []
            star_idx = 0
            for level, anchors_per_level in enumerate(anchors[im_i]):
                end_idx = star_idx + num_anchors_per_level[level]
                loss_per_level = loss_all_per_im[star_idx:end_idx]
                labels_per_level = labels_all_per_im[star_idx:end_idx]
                matched_idx_per_level = matched_idx_all_per_im[star_idx:end_idx]
                # anchors on this level that are IoU-positive and matched to this GT
                match_idx = torch.nonzero(
                    (matched_idx_per_level == gt) & (labels_per_level > 0),
                    as_tuple=False
                )[:, 0]
                if match_idx.numel() > 0:
                    # keep the lowest-loss matches (at most self.topk per level)
                    _, topk_idxs = loss_per_level[match_idx].topk(
                        min(match_idx.numel(), self.topk), largest=False)
                    topk_idxs_per_level_per_gt = match_idx[topk_idxs]
                    # shift level-local indices back to image-wide anchor indices
                    candidate_idxs_per_gt.append(topk_idxs_per_level_per_gt + star_idx)
                star_idx = end_idx
            if candidate_idxs_per_gt:
                candidate_idxs.append(torch.cat(candidate_idxs_per_gt))
            else:
                candidate_idxs.append(None)

        # fit 2-mode GMM per GT box
        n_labels = anchors_per_im.bbox.shape[0]
        cls_labels_per_im = torch.zeros(n_labels, dtype=torch.long).to(device)
        matched_gts = torch.zeros_like(anchors_per_im.bbox)
        # default regression target: the IoU-matched GT (negative index = unmatched)
        fg_inds = matched_idx_all_per_im >= 0
        matched_gts[fg_inds] = bboxes_per_im[matched_idx_all_per_im[fg_inds]]
        # is_grey is never set to anything but None below, so no anchor gets -1 here
        is_grey = None
        for gt in range(num_gt):
            if candidate_idxs[gt] is not None:
                if candidate_idxs[gt].numel() > 1:
                    candidate_loss = loss_all_per_im[candidate_idxs[gt]]
                    # inds maps sorted positions back to positions in candidate_idxs[gt]
                    candidate_loss, inds = candidate_loss.sort()
                    candidate_loss = candidate_loss.view(-1, 1).cpu().numpy()
                    min_loss, max_loss = candidate_loss.min(), candidate_loss.max()
                    # component 0 starts at the min loss (foreground), 1 at the max (background)
                    means_init=[[min_loss], [max_loss]]
                    weights_init = [0.5, 0.5]
                    precisions_init=[[[1.0]], [[1.0]]]
                    # NOTE(review): skm is presumably sklearn.mixture — confirm at the import site
                    gmm = skm.GaussianMixture(2,
                                              weights_init=weights_init,
                                              means_init=means_init,
                                              precisions_init=precisions_init)
                    gmm.fit(candidate_loss)
                    components = gmm.predict(candidate_loss)
                    scores = gmm.score_samples(candidate_loss)
                    components = torch.from_numpy(components).to(device)
                    scores = torch.from_numpy(scores).to(device)
                    fgs = components == 0
                    bgs = components == 1
                    if torch.nonzero(fgs, as_tuple=False).numel() > 0:
                        # Fig 3. (c)
                        # positives: every candidate (in sorted-loss order) up to
                        # the foreground sample with the highest GMM score
                        fg_max_score = scores[fgs].max().item()
                        fg_max_idx = torch.nonzero(fgs & (scores == fg_max_score), as_tuple=False).min()
                        # fgs | bgs covers every candidate; positives are written over them below
                        is_neg = inds[fgs | bgs]
                        is_pos = inds[:fg_max_idx+1]
                    else:
                        # just treat all samples as positive for high recall.
                        is_pos = inds
                        is_neg = is_grey = None
                else:
                    # a single candidate is always positive
                    is_pos = 0
                    is_neg = None
                    is_grey = None
                if is_grey is not None:
                    grey_idx = candidate_idxs[gt][is_grey]
                    cls_labels_per_im[grey_idx] = -1
                if is_neg is not None:
                    neg_idx = candidate_idxs[gt][is_neg]
                    cls_labels_per_im[neg_idx] = 0
                pos_idx = candidate_idxs[gt][is_pos]
                cls_labels_per_im[pos_idx] = labels_per_im[gt].view(-1, 1)
                matched_gts[pos_idx] = bboxes_per_im[gt].view(-1, 4)

        reg_targets_per_im = self.box_coder.encode(matched_gts, anchors_per_im.bbox)
        cls_labels.append(cls_labels_per_im)
        reg_targets.append(reg_targets_per_im)

    return cls_labels, reg_targets