def forward(self, box_preds, gt_boxes):
    """Generate IoU-based ("dynamic") YOLOv3 targets.

    Only the objectness target carries information here.
    - A prediction whose best IoU with any ground-truth box of the same image
      exceeds ``self._ignore_iou_thresh`` is marked ``-1`` (ignore).
    - Every other objectness entry is ``0``.
    The remaining outputs are placeholders.

    Parameters
    ----------
    box_preds: Predicted bounding boxes. (batch, xx, 4).
    gt_boxes: Ground-truth bounding boxes, iterated per image in lockstep with
        ``box_preds``. An image may have zero boxes.

    Returns
    -------
    (tuple of) tensor.
    objectness: 0 for negative, -1 for ignore. (batch, xx, 1), float32.
    center_targets: all zeros. (batch, xx, 2).
    scale_targets: all zeros. (batch, xx, 2).
    weights: all zeros. (batch, xx, 2).
    class_targets: all -1. (batch, xx, self._num_class).
    """
    with torch.no_grad():
        batch, num_preds = box_preds.shape[:2]
        center_t = torch.zeros_like(box_preds[:, :, 0:2])
        scale_t = torch.zeros_like(center_t)
        weight_t = torch.zeros_like(center_t)
        class_t = box_preds.new_full((batch, num_preds, self._num_class), -1)

        ious_max = []
        for preds_per_img, gts_per_img in zip(box_preds, gt_boxes):
            if gts_per_img.shape[0] == 0:
                # torch.max over an empty gt axis raises. With no gt nothing
                # overlaps, so the best IoU is 0 for every prediction.
                ious_max.append(
                    preds_per_img.new_zeros((preds_per_img.shape[0], 1)))
                continue
            ious = bbox_overlaps(preds_per_img, gts_per_img)
            # (h*w*num_anchors, 1): best IoU of each prediction over all gts.
            ious_max.append(torch.max(ious, dim=-1, keepdim=True)[0])
        ious_max = torch.stack(ious_max, dim=0)
        # use -1 for ignored.
        objness_t = (ious_max > self._ignore_iou_thresh).to(torch.float32) * -1
    return objness_t, center_t, scale_t, weight_t, class_t
def forward_single_image(self, gt_boxes, gt_labels, img_metas,
                         shift_anchor_boxes, shape_like, num_anchors, anchors,
                         pad_shape, all_featmaps, num_offsets):
    """Build static (anchor-matched) YOLOv3 targets for a single image.

    Each valid gt box is matched to the anchor shape with the highest IoU.
    Both gt and anchors are compared as if centred at the origin, so only
    their shapes matter. Targets are written at the grid cell that contains
    the gt centre, on the feature level that owns the matched anchor.

    Args:
        gt_boxes: gt boxes of one image. A row with any negative coordinate
            is treated as padding, and the first such row ends processing.
        gt_labels: labels of ``gt_boxes``. ``int(label) - 1`` is used as the
            one-hot index, so labels presumably are 1-based -- TODO confirm.
        img_metas: dict. It may carry ``'mixup_params'`` with
            ``'first_n_labels'`` and ``'lambd'``.
        shift_anchor_boxes: zero-centred anchor boxes, one row per anchor.
        shape_like: shape of the flattened target tensors, e.g.
            (h3*w3+h2*w2+h1*w1, 9 anchors, 2).
        num_anchors: per-level anchor counts. The first level with
            ``num_anchors > match`` owns anchor ``match``, so they are
            presumably cumulative.
        anchors: anchor (w, h) sizes indexed by anchor id.
        pad_shape: (pad_height, pad_width) of the padded input.
        all_featmaps: per-level feature maps. ``shape[2]`` is the height and
            ``shape[3]`` the width.
        num_offsets: start row of each level in the flattened targets.

    Returns:
        tuple: (objectness, center_targets, scale_targets, weights,
        class_targets). Objectness and class targets have last dimension 1
        and ``self._num_class`` respectively.
    """
    # Allocate on the same device as the gts instead of hard-coding CUDA.
    center_targets = torch.zeros(shape_like, device=gt_boxes.device)
    scale_targets = torch.zeros_like(center_targets)
    weights = torch.zeros_like(center_targets)
    objectness = torch.zeros_like(center_targets[..., :1])
    class_targets = torch.full_like(
        objectness.repeat(1, 1, self._num_class), -1.)

    gtx, gty, gtw, gth = point_to_center(gt_boxes, split=True, keep_axis=True)
    shift_gt_boxes = torch.cat(
        (-0.5 * gtw, -0.5 * gth, 0.5 * gtw, 0.5 * gth), dim=-1)
    # ious between zero-center anchors(9) and zero-center gt boxes(gt num).
    ious = bbox_overlaps(shift_anchor_boxes, shift_gt_boxes)
    # Assume gt and anchor centres are aligned and pick the best anchor shape.
    matches = ious.argmax(dim=0).to(torch.int32)  # (num_gt,)
    valid_gts = (gt_boxes >= 0).prod(dim=-1)  # (num_gt,)
    pad_height, pad_width = pad_shape

    # Mixup parameters are per image, so read them once rather than per gt.
    mixup_params = img_metas.get('mixup_params', dict())
    first_n = mixup_params.get('first_n_labels', len(matches))
    lambd = mixup_params.get('lambd', 1.)

    for m in range(matches.shape[0]):  # for each gt in a single image.
        if valid_gts[m] < 1:
            break
        match = matches[m]  # matched anchor idx, note that 0 <= match < 9.
        # NOTE(review): mixes numpy with a torch scalar; this works only if
        # num_anchors/match are CPU-convertible -- confirm with the caller.
        nlayer = np.nonzero(num_anchors > match)[0][0]
        height = all_featmaps[nlayer].shape[2]
        width = all_featmaps[nlayer].shape[3]
        mgtx, mgty, mgtw, mgth = gtx[m, 0], gty[m, 0], gtw[m, 0], gth[m, 0]
        # gt centre in feature-map units. Its integer part is the owning cell;
        # clamp so a centre lying exactly on the padded edge stays in-bounds.
        grid_x = mgtx / pad_width * width
        grid_y = mgty / pad_height * height
        loc_x = grid_x.to(torch.int32).clamp(max=width - 1)
        loc_y = grid_y.to(torch.int32).clamp(max=height - 1)
        # write back to targets
        index = num_offsets[nlayer] + loc_y * width + loc_x
        center_targets[index, match, 0] = grid_x - loc_x  # tx
        center_targets[index, match, 1] = grid_y - loc_y  # ty
        # Width/height below 1 pixel are clamped to 1 before the log.
        scale_targets[index, match, 0] = torch.log(
            mgtw.clamp(min=1) / anchors[match, 0])
        scale_targets[index, match, 1] = torch.log(
            mgth.clamp(min=1) / anchors[match, 1])
        # Smaller boxes (relative to the padded image) get larger weights.
        weights[index, match, :] = 2.0 - mgtw * mgth / pad_width / pad_height
        # Under mixup, the first `first_n` gts come from the first image and
        # get weight `lambd`; the remaining gts get 1 - lambd.
        if m < first_n:
            objectness[index, match, 0] = lambd
        else:
            objectness[index, match, 0] = 1. - lambd
        class_targets[index, match, :] = 0
        class_targets[index, match, int(gt_labels[m]) - 1] = 1
    return objectness, center_targets, scale_targets, weights, class_targets
def loss_single(self, cls_score, bbox_pred, iou_pred, anchors, labels,
                label_weights, bbox_targets, bbox_weights, num_total_samples):
    """Compute the classification, box and IoU losses of one scale level.

    Args:
        cls_score (Tensor): Box scores for each scale level.
            Has shape (N, num_anchors * num_classes, H, W).
        bbox_pred (Tensor): Box energies / deltas for each scale
            level with shape (N, num_anchors * 4, H, W).
        iou_pred (Tensor): Predicted IoU for each anchor. It is flattened
            after permuting to (N, H, W, C), so it is presumably
            (N, num_anchors, H, W).
        anchors (Tensor): Box reference for each scale level with shape
            (N, num_total_anchors, 4).
        labels (Tensor): Labels of each anchor with shape
            (N, num_total_anchors). Values in [0, self.num_classes) are
            positives.
        label_weights (Tensor): Label weights of each anchor with shape
            (N, num_total_anchors).
        bbox_targets (Tensor): BBox regression targets of each anchor with
            shape (N, num_total_anchors, 4). They are deltas, or absolute
            boxes when ``self.reg_decoded_bbox`` is set.
        bbox_weights (Tensor): BBox regression loss weights of each anchor
            with shape (N, num_total_anchors, 4).
        num_total_samples (int): If sampling, num total samples equal to
            the number of total anchors; Otherwise, it is the number of
            positive anchors. Used as ``avg_factor`` for every loss.

    Returns:
        tuple[Tensor]: ``(loss_cls, loss_bbox, loss_iou)``.
    """
    # classification loss
    anchors = anchors.reshape(-1, 4)
    labels = labels.reshape(-1)
    label_weights = label_weights.reshape(-1)
    cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
    loss_cls = self.loss_cls(
        cls_score, labels, label_weights, avg_factor=num_total_samples)

    # regression loss
    bbox_targets = bbox_targets.reshape(-1, 4)
    bbox_weights = bbox_weights.reshape(-1, 4)
    bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
    iou_pred = iou_pred.permute(0, 2, 3, 1).reshape(-1)

    # The IoU branch is supervised on every anchor that has a non-zero box
    # regression weight. Its target stays 0 except at positives, where it is
    # filled below.
    iou_targets = label_weights.new_zeros(labels.shape)
    iou_weights = label_weights.new_zeros(labels.shape)
    iou_weights[bbox_weights.sum(dim=1) > 0] = 1.

    bg_class_ind = self.num_classes
    pos_inds = ((labels >= 0) & (labels < bg_class_ind)).nonzero().squeeze(1)

    if self.reg_decoded_bbox:
        bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
    loss_bbox = self.loss_bbox(
        bbox_pred, bbox_targets, bbox_weights, avg_factor=num_total_samples)

    if len(pos_inds) > 0:
        pos_bbox_targets = bbox_targets[pos_inds]
        pos_bbox_pred = bbox_pred[pos_inds]
        if self.reg_decoded_bbox:
            # Prediction (decoded above) and target are already boxes.
            pos_decode_bbox_pred = pos_bbox_pred
            gt_bboxes = pos_bbox_targets
        else:
            # Both are deltas w.r.t. the anchors; decode them into boxes.
            pos_anchors = anchors[pos_inds]
            pos_decode_bbox_pred = self.bbox_coder.decode(
                pos_anchors, pos_bbox_pred)
            gt_bboxes = self.bbox_coder.decode(pos_anchors, pos_bbox_targets)
        if self.detach:
            # Keep the IoU target from back-propagating into the box branch.
            pos_decode_bbox_pred = pos_decode_bbox_pred.detach()
        iou_targets[pos_inds] = bbox_overlaps(
            pos_decode_bbox_pred, gt_bboxes, is_aligned=True)
    loss_iou = self.loss_iou(
        iou_pred, iou_targets, iou_weights, avg_factor=num_total_samples)
    return loss_cls, loss_bbox, loss_iou
def nms_resampling_discrete(self, proposals, scores, ids, gt_bboxes, gt_labels,
                            a_r, a_c, a_f, select_thresh=0.3):
    """NMS resampling with a different IoU threshold per category group.

    Gt boxes are split into rare and common groups by label. Every other
    label counts as frequent.

    1. Run NMS on all proposals with threshold ``a_f`` to get the frequent
       proposals.
    2. For each non-empty group (rare with ``a_r``, then common with
       ``a_c``):
       - rerun NMS with that group's threshold;
       - keep the proposals whose best IoU with the group's gts is at least
         ``select_thresh``;
       - drop those same-covering proposals (IoU >= ``select_thresh``) from
         the frequent set.

    Args:
        proposals: proposal rows whose first 4 columns are boxes.
        scores, ids: forwarded to ``batched_nms`` with ``proposals``.
        gt_bboxes: gt boxes (4 columns), aligned with ``gt_labels``.
            Assumed to be on the same device as ``gt_labels``.
        gt_labels: gt category ids.
        a_r, a_c, a_f: NMS IoU thresholds for the rare / common / frequent
            groups.
        select_thresh: a proposal with IoU below this against a group's gts
            is considered background for that group. Defaults to 0.3.

    Returns:
        Tensor: concatenation of [rare, common, frequent] proposals. A group
        contributes only if it has gts in this image.
    """
    device = gt_labels.device
    # rare, common, frequent = self.get_category_frequency(gt_labels.device)
    # Hard-coded groups. Labels in neither group (e.g. 0 and 3) are frequent.
    common = torch.tensor([1, 4, 9], device=device)
    rare = torch.tensor([2, 5, 6, 7, 8, 10], device=device)

    # Vectorised grouping. It keeps gt order and has no fixed-size buffers,
    # so any number of gts is supported.
    labels_col = gt_labels.reshape(-1, 1)
    is_rare = (labels_col == rare).any(dim=1)
    is_common = (labels_col == common).any(dim=1) & ~is_rare
    rare_gtbox = gt_bboxes[is_rare]
    common_gtbox = gt_bboxes[is_common]

    frequent_proposals, _ = batched_nms(
        proposals, scores, ids, dict(type='nms', iou_threshold=a_f))

    out = []
    for group_gtbox, iou_thr in ((rare_gtbox, a_r), (common_gtbox, a_c)):
        if len(group_gtbox) == 0:
            continue
        group_proposals, _ = batched_nms(
            proposals, scores, ids, dict(type='nms', iou_threshold=iou_thr))
        # Keep group proposals that cover a gt of this group (foreground).
        group_max_iou = bbox_overlaps(
            group_gtbox, group_proposals[:, :4]).max(dim=0)[0]
        out.append(group_proposals[group_max_iou >= select_thresh])
        # Remove frequent proposals covering the same gts to avoid duplicates.
        freq_max_iou = bbox_overlaps(
            group_gtbox, frequent_proposals[:, :4]).max(dim=0)[0]
        frequent_proposals = frequent_proposals[freq_max_iou < select_thresh]
    out.append(frequent_proposals)
    return torch.cat(out, 0)