import random

import torch
from torchvision.ops import box_iou


def random_crop(image, boxes, labels, difficulties,
                choices=(0., .1, .3, .5, .7, .9, None)):
    """
    Performs a random crop in the manner stated in the paper. Helps to learn to detect larger and partial objects.

    Note that some objects may be cut out entirely.

    Adapted from https://github.com/amdegroot/ssd.pytorch/blob/master/utils/augmentations.py

    :param image: image, a tensor of dimensions (3, original_h, original_w)
    :param boxes: bounding boxes in boundary coordinates, a tensor of dimensions (n_objects, 4)
    :param labels: labels of objects, a tensor of dimensions (n_objects)
    :param difficulties: difficulties of detection of these objects, a tensor of dimensions (n_objects)
    :param choices: candidate minimum Jaccard overlaps; None means no cropping
    :return: cropped image, updated bounding box coordinates, updated labels, updated difficulties
    """
    # Image is (3, H, W), so height and width are dimensions 1 and 2
    original_h = image.size(1)
    original_w = image.size(2)
    # Keep choosing a minimum overlap until a successful crop is made
    while True:
        # Randomly draw the value for minimum overlap; None refers to no cropping
        min_overlap = random.choice(choices)

        # If not cropping
        if min_overlap is None:
            return image, boxes, labels, difficulties

        # Try up to 50 times for this choice of minimum overlap
        # This isn't mentioned in the paper, of course, but 50 is chosen in the paper authors' original Caffe repo
        max_trials = 50
        for _ in range(max_trials):
            # Crop dimensions must be in [0.3, 1] of original dimensions
            # Note - it's [0.1, 1] in the paper, but actually [0.3, 1] in the authors' repo
            min_scale = 0.3
            scale_h = random.uniform(min_scale, 1)
            scale_w = random.uniform(min_scale, 1)
            new_h = int(scale_h * original_h)
            new_w = int(scale_w * original_w)

            # Aspect ratio has to be in [0.5, 2]
            aspect_ratio = new_h / new_w
            if not 0.5 < aspect_ratio < 2:
                continue

            # Crop coordinates (origin at top-left of image)
            left = random.randint(0, original_w - new_w)
            right = left + new_w
            top = random.randint(0, original_h - new_h)
            bottom = top + new_h
            crop = torch.FloatTensor([left, top, right, bottom])  # (4)

            # Calculate Jaccard overlap between the crop and the bounding boxes
            overlap = box_iou(crop.unsqueeze(0), boxes)  # (1, n_objects), n_objects is the no. of objects in this image
            overlap = overlap.squeeze(0)  # (n_objects)

            # If not a single bounding box has a Jaccard overlap greater than the minimum, try again
            if overlap.max().item() < min_overlap:
                continue

            # Crop image
            new_image = image[:, top:bottom, left:right]  # (3, new_h, new_w)

            # Find centers of original bounding boxes
            bb_centers = (boxes[:, :2] + boxes[:, 2:]) / 2.  # (n_objects, 2)

            # Find bounding boxes whose centers are in the crop
            centers_in_crop = (bb_centers[:, 0] > left) * (bb_centers[:, 0] < right) * \
                              (bb_centers[:, 1] > top) * (bb_centers[:, 1] < bottom)  # (n_objects), usable as a boolean index

            # If not a single bounding box has its center in the crop, try again
            if not centers_in_crop.any():
                continue

            # Discard bounding boxes that don't meet this criterion
            new_boxes = boxes[centers_in_crop, :]
            new_labels = labels[centers_in_crop]
            new_difficulties = difficulties[centers_in_crop]

            # Calculate bounding boxes' new coordinates in the crop
            new_boxes[:, :2] = torch.max(new_boxes[:, :2], crop[:2])  # crop[:2] is [left, top]
            new_boxes[:, :2] -= crop[:2]
            new_boxes[:, 2:] = torch.min(new_boxes[:, 2:], crop[2:])  # crop[2:] is [right, bottom]
            new_boxes[:, 2:] -= crop[:2]

            return new_image, new_boxes, new_labels, new_difficulties
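# Hedged usage sketch for random_crop (not from the original source): a dummy
# 3x300x400 image with two boxes in boundary (xyxy) pixel coordinates. The call
# either returns the inputs unchanged (when min_overlap is drawn as None) or a
# crop that keeps at least one box center.
if __name__ == '__main__':
    image = torch.rand(3, 300, 400)
    boxes = torch.FloatTensor([[50., 60., 200., 220.], [10., 10., 80., 90.]])
    labels = torch.LongTensor([3, 7])
    difficulties = torch.ByteTensor([0, 1])
    new_image, new_boxes, new_labels, new_difficulties = random_crop(
        image, boxes, labels, difficulties)
    print(new_image.shape, new_boxes, new_labels, new_difficulties)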
import torch
from torchvision import ops


def iou_check(box, expected, tolerance=1e-4):
    out = ops.box_iou(box, box)
    torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
def forward(self, inputs: Tensor, target: Tensor) -> Tensor:
    # Element-wise IoU loss: 1 - IoU of each matched (input, target) box pair
    return 1.0 - box_iou(inputs, target).diagonal()
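# Quick check of what the forward above computes (the tensors here are
# illustrative, not from the source): identical boxes give zero loss and a
# quarter-overlap pair gives 0.75.
import torch
from torchvision.ops import box_iou

pred = torch.tensor([[0., 0., 2., 2.], [0., 0., 4., 4.]])
target = torch.tensor([[0., 0., 2., 2.], [2., 2., 4., 4.]])
print(1.0 - box_iou(pred, target).diagonal())  # tensor([0.0000, 0.7500])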
import itertools
from collections import defaultdict

import numpy as np


def calc_detection_voc_prec_rec(pred_bboxes, pred_labels, pred_scores,
                                gt_bboxes, gt_labels, gt_difficulties=None,
                                iou_thresh=0.5):
    # NOTE: box_iou here must accept NumPy arrays (cf. chainercv's bbox_iou)
    pred_bboxes = iter(pred_bboxes)
    pred_labels = iter(pred_labels)
    pred_scores = iter(pred_scores)
    gt_bboxes = iter(gt_bboxes)
    gt_labels = iter(gt_labels)
    if gt_difficulties is None:
        gt_difficulties = itertools.repeat(None)
    else:
        gt_difficulties = iter(gt_difficulties)

    n_pos = defaultdict(int)
    score = defaultdict(list)
    match = defaultdict(list)

    for pred_bbox, pred_label, pred_score, gt_bbox, gt_label, gt_difficult in zip(
            pred_bboxes, pred_labels, pred_scores,
            gt_bboxes, gt_labels, gt_difficulties):

        if gt_difficult is None:
            gt_difficult = np.zeros(gt_bbox.shape[0], dtype=bool)

        for l in np.unique(np.concatenate((pred_label, gt_label)).astype(int)):
            pred_mask_l = pred_label == l
            pred_bbox_l = pred_bbox[pred_mask_l]
            pred_score_l = pred_score[pred_mask_l]
            # Sort by score, descending
            order = pred_score_l.argsort()[::-1]
            pred_bbox_l = pred_bbox_l[order]
            pred_score_l = pred_score_l[order]

            gt_mask_l = gt_label == l
            gt_bbox_l = gt_bbox[gt_mask_l]
            gt_difficult_l = gt_difficult[gt_mask_l]

            n_pos[l] += np.logical_not(gt_difficult_l).sum()
            score[l].extend(pred_score_l)

            if len(pred_bbox_l) == 0:
                continue
            if len(gt_bbox_l) == 0:
                match[l].extend((0,) * pred_bbox_l.shape[0])
                continue

            # VOC evaluation follows integer-typed bounding boxes.
            pred_bbox_l = pred_bbox_l.copy()
            pred_bbox_l[:, 2:] += 1
            gt_bbox_l = gt_bbox_l.copy()
            gt_bbox_l[:, 2:] += 1

            iou = box_iou(pred_bbox_l, gt_bbox_l)
            gt_index = iou.argmax(axis=1)
            # Set -1 if there is no matching ground truth
            gt_index[iou.max(axis=1) < iou_thresh] = -1
            del iou

            selec = np.zeros(gt_bbox_l.shape[0], dtype=bool)
            for gt_idx in gt_index:
                if gt_idx >= 0:
                    if gt_difficult_l[gt_idx]:
                        match[l].append(-1)
                    else:
                        if not selec[gt_idx]:
                            match[l].append(1)
                        else:
                            match[l].append(0)
                    selec[gt_idx] = True
                else:
                    match[l].append(0)

    for iter_ in (pred_bboxes, pred_labels, pred_scores,
                  gt_bboxes, gt_labels, gt_difficulties):
        if next(iter_, None) is not None:
            raise ValueError('Length of input iterables need to be same.')

    n_fg_class = max(n_pos.keys()) + 1
    prec = [None] * n_fg_class
    rec = [None] * n_fg_class

    for l in n_pos.keys():
        score_l = np.array(score[l])
        match_l = np.array(match[l], dtype=np.int8)

        order = score_l.argsort()[::-1]
        match_l = match_l[order]

        tp = np.cumsum(match_l == 1)
        fp = np.cumsum(match_l == 0)

        # If an element of fp + tp is 0,
        # the corresponding element of prec[l] is nan.
        prec[l] = tp / (fp + tp)
        # If n_pos[l] is 0, rec[l] stays None.
        if n_pos[l] > 0:
            rec[l] = tp / n_pos[l]

    return prec, rec
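# A minimal companion sketch (the standard VOC all-points AP, not part of the
# function above) that turns the per-class precision/recall lists returned by
# calc_detection_voc_prec_rec into average precisions, as chainercv does:
def _voc_ap_from_prec_rec(prec, rec):
    ap = np.empty(len(prec))
    for l, (p, r) in enumerate(zip(prec, rec)):
        if p is None or r is None:
            ap[l] = np.nan
            continue
        # Pad the curve, make precision monotonically decreasing from the
        # right, then sum the area under the resulting step function
        mpre = np.concatenate(([0.], np.nan_to_num(p), [0.]))
        mrec = np.concatenate(([0.], r, [1.]))
        mpre = np.maximum.accumulate(mpre[::-1])[::-1]
        idx = np.where(mrec[1:] != mrec[:-1])[0]
        ap[l] = np.sum((mrec[idx + 1] - mrec[idx]) * mpre[idx + 1])
    return ap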
def iou_check(box, expected, tolerance=1e-4):
    out = ops.box_iou(box, box)
    assert out.size() == expected.size()
    assert ((out - expected).abs().max() < tolerance).item()
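# Hedged example exercising either iou_check variant above: box_iou of a box
# set with itself has ones on the diagonal, so for two disjoint unit boxes the
# expected matrix is the 2x2 identity.
import torch
from torchvision import ops

boxes = torch.tensor([[0., 0., 1., 1.], [2., 2., 3., 3.]])
iou_check(boxes, torch.eye(2))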
def forward(self, features, labels=None, gt_bboxes=None):
    """
    :param features: OrderedDict; each item has shape (BS, C_i, H_i, W_i)
    :param labels: shape (BS, n_objs)
    :param gt_bboxes: shape (BS, n_objs, 4)
    :return:
    """
    if self.training:
        pre_nms_top_n = self.pre_nms_top_n_in_train
        post_nms_top_n = self.post_nms_top_n_in_train
    else:
        pre_nms_top_n = self.pre_nms_top_n_in_test
        post_nms_top_n = self.post_nms_top_n_in_test

    total_anchors = []
    total_cls_pred = []
    total_reg_pred = []
    total_cls_scores = []
    total_reg_bboxes = []
    for i, feat in enumerate(features.values()):
        x = F.relu(self.conv(feat))
        cls_pred = self.cls(x)  # (BS, num_anchors, H, W)
        reg_pred = self.reg(x)  # (BS, num_anchors*4, H, W)
        BS, num_anchors, H, W = cls_pred.shape
        # (BS, H, W, num_anchors)
        cls_pred = cls_pred.permute(0, 2, 3, 1)
        # (BS, H, W, num_anchors, 4)
        reg_pred = reg_pred.permute(0, 2, 3, 1).reshape((BS, H, W, num_anchors, 4))
        # (H, W, num_anchors, 4)
        anchors = self._buffers["anchor%i" % i]

        # (BS, H, W, num_anchors) -> (BS, H*W*num_anchors)
        cls_pred = cls_pred.reshape((BS, -1))
        # (BS, H, W, num_anchors, 4) -> (BS, H*W*num_anchors, 4)
        reg_pred = reg_pred.reshape((BS, -1, 4))
        # (H, W, num_anchors, 4) -> (H*W*num_anchors, 4)
        anchors = anchors.reshape((-1, 4))

        total_anchors.append(anchors)
        total_cls_pred.append(cls_pred)
        total_reg_pred.append(reg_pred)

        with torch.no_grad():
            # Refine the anchors with the predicted regression deltas
            reg_bboxes = self.box_coder.decode(anchors, reg_pred.detach())
            reg_bboxes[..., 0].clamp_(0, self.image_size[0])
            reg_bboxes[..., 1].clamp_(0, self.image_size[1])
            reg_bboxes[..., 2].clamp_(0, self.image_size[0])
            reg_bboxes[..., 3].clamp_(0, self.image_size[1])
            # Compute objectness scores
            cls_scores = torch.sigmoid(cls_pred.detach())

            if not self.nms_per_layer:
                total_cls_scores.append(cls_scores)
                total_reg_bboxes.append(reg_bboxes)
            else:
                # NMS per layer; b indexes the batch so it doesn't shadow the layer index i
                BS = cls_scores.shape[0]
                keep_bboxes = []
                keep_scores = []
                for b in range(BS):
                    dtype = reg_bboxes.dtype
                    device = reg_bboxes.device
                    _bboxes = torch.full((post_nms_top_n // len(features), 4), -1,
                                         dtype=dtype, device=device)
                    _scores = torch.full((post_nms_top_n // len(features),), -1,
                                         dtype=cls_scores.dtype, device=cls_scores.device)
                    pre_nms_top_n_indices = torch.argsort(cls_scores[b], descending=True)
                    _num_anchors = pre_nms_top_n_indices.shape[0]
                    _pre_nms_top_n = min(_num_anchors, pre_nms_top_n // len(features))
                    pre_nms_top_n_indices = pre_nms_top_n_indices[:_pre_nms_top_n]
                    _reg_bboxes = reg_bboxes[b][pre_nms_top_n_indices]
                    _cls_scores = cls_scores[b][pre_nms_top_n_indices]
                    keep = ops.nms(_reg_bboxes, _cls_scores, self.nms_thresh)
                    n_keep = min(keep.shape[0], post_nms_top_n // len(features))
                    keep = keep[:n_keep]
                    _bboxes[:n_keep] = _reg_bboxes[keep]
                    _scores[:n_keep] = _cls_scores[keep]
                    keep_bboxes.append(_bboxes)
                    keep_scores.append(_scores)
                total_reg_bboxes.append(torch.stack(keep_bboxes))
                total_cls_scores.append(torch.stack(keep_scores))

    # (-1, 4)
    anchors = torch.cat(total_anchors, dim=0)
    # (BS, -1)
    cls_pred = torch.cat(total_cls_pred, dim=1)
    # (BS, -1, 4)
    reg_pred = torch.cat(total_reg_pred, dim=1)

    if not self.nms_per_layer:
        # (BS, -1)
        cls_scores = torch.cat(total_cls_scores, dim=1)
        # (BS, -1, 4)
        reg_bboxes = torch.cat(total_reg_bboxes, dim=1)
        # NMS
        BS = cls_pred.shape[0]
        keep_bboxes = []
        for i in range(BS):
            dtype = reg_bboxes.dtype
            device = reg_bboxes.device
            _bboxes = torch.full((post_nms_top_n, 4), -1, dtype=dtype, device=device)
            pre_nms_top_n_indices = torch.argsort(cls_scores[i], descending=True)
            _num_anchors = pre_nms_top_n_indices.shape[0]
            _pre_nms_top_n = min(_num_anchors, pre_nms_top_n)
            pre_nms_top_n_indices = pre_nms_top_n_indices[:_pre_nms_top_n]
            _reg_bboxes = reg_bboxes[i][pre_nms_top_n_indices]
            keep = ops.nms(_reg_bboxes, cls_scores[i][pre_nms_top_n_indices],
                           self.nms_thresh)
            n_keep = min(keep.shape[0], post_nms_top_n)
            keep = keep[:n_keep]
            _bboxes[:n_keep] = _reg_bboxes[keep]
            keep_bboxes.append(_bboxes)
        bboxes = torch.stack(keep_bboxes)  # (BS, post_nms_top_n, 4)
    else:
        # (BS, post_nms_top_n)
        cls_scores = torch.cat(total_cls_scores, dim=1)
        # (BS, post_nms_top_n, 4)
        bboxes = torch.cat(total_reg_bboxes, dim=1)

    # Sort the bboxes (rois) in descending order of score.
    # Possible reasons this matters:
    #   - The order of the rois affects the RCNN head, e.g. during the RCNN's NMS
    #     the earlier rois are preferred.
    #   - Rois with high RPN scores tend to get high RCNN scores as well.
    #   - When rois have similar scores, the ones in front suppress the later
    #     ones during NMS.
    for i in range(cls_scores.shape[0]):
        sorted_indices = torch.argsort(cls_scores[i], descending=True)
        bboxes[i] = bboxes[i][sorted_indices]

    if self.training:
        total_cls_pred = []
        total_reg_pred = []
        total_reg_target = []
        total_fg_bg_mask = []
        all_cls_pred = []
        all_fg_bg_mask = []
        BS = gt_bboxes.shape[0]
        for i in range(BS):
            # Assign a label to each anchor
            areas = ops.boxes.box_area(anchors)
            # (num_total_anchors, num_gt_bboxes), (N, M) for short
            ious = ops.box_iou(anchors, gt_bboxes[i])
            # Zero-area anchors produce NaNs in the IoU matrix; replace NaN with 0
            zero_mask = (areas == 0).reshape(-1, 1).expand_as(ious)
            ious[zero_mask] = 0
            if torch.any(torch.isnan(ious)):
                raise Exception("some elements in ious are nan")

            # The anchor(s) with the highest Intersection-over-Union (IoU)
            # overlap with a ground-truth box
            iou_max_gt, _ = torch.max(ious, dim=0)
            # Skip the padded entries of gt_bboxes
            iou_max_gt = torch.where(labels[i] == -1,
                                     torch.ones_like(iou_max_gt), iou_max_gt)
            highest_mask = (ious == iou_max_gt)
            fg_mask = torch.any(highest_mask, dim=1)

            # An anchor whose IoU overlap with any ground-truth box is above fg_iou_thresh
            iou_max, matched_idx = torch.max(ious, dim=1)

            # 1 for foreground, -1 for background, 0 for ignore
            fg_bg_mask = torch.zeros_like(iou_max)
            # Confirm positive samples
            fg_bg_mask = torch.where(iou_max >= self.fg_iou_thresh,
                                     torch.ones_like(iou_max), fg_bg_mask)
            fg_bg_mask = torch.where(fg_mask, torch.ones_like(iou_max), fg_bg_mask)
            # Confirm negative samples
            fg_bg_mask = torch.where(iou_max <= self.bg_iou_thresh,
                                     torch.full_like(iou_max, -1), fg_bg_mask)

            all_cls_pred.append(cls_pred[i].detach())
            all_fg_bg_mask.append(fg_bg_mask.detach())

            # Random sampling
            indices = torch.arange(fg_bg_mask.shape[0], dtype=torch.int64,
                                   device=fg_bg_mask.device)
            rand_indices = torch.rand_like(fg_bg_mask).argsort()
            fg_bg_mask = fg_bg_mask[rand_indices]  # shuffle to make the pick "random"
            indices = indices[rand_indices]
            sorted_indices = fg_bg_mask.argsort(descending=True)
            fg_bg_mask = fg_bg_mask[sorted_indices]
            indices = indices[sorted_indices]
            fg_indices = indices[:self.num_pos]
            fg_mask = fg_bg_mask[:self.num_pos]
            bg_indices = indices[-self.num_neg:]
            bg_mask = fg_bg_mask[-self.num_neg:]
            indices = torch.cat([fg_indices, bg_indices], dim=0)
            fg_bg_mask = torch.cat([fg_mask, bg_mask], dim=0)
            matched_idx = matched_idx[indices]
            _anchors = anchors[indices]

            total_cls_pred.append(cls_pred[i][indices])
            total_reg_pred.append(reg_pred[i][indices])
            total_fg_bg_mask.append(fg_bg_mask)
            total_reg_target.append(
                self.box_coder.encode(_anchors, gt_bboxes[i][matched_idx]))

        # (BS, num_samples)
        cls_pred = torch.stack(total_cls_pred)
        # (BS, num_samples, 4)
        reg_pred = torch.stack(total_reg_pred)
        # (BS, num_samples)
        fg_bg_mask = torch.stack(total_fg_bg_mask)
        # (BS, num_samples, 4)
        reg_target = torch.stack(total_reg_target)

        cls_label = torch.where(fg_bg_mask == 1, torch.ones_like(cls_pred),
                                torch.zeros_like(cls_pred))
        cls_loss = F.binary_cross_entropy_with_logits(
            cls_pred[fg_bg_mask != 0], cls_label[fg_bg_mask != 0])

        if torch.any(torch.isnan(reg_target[fg_bg_mask == 1])):
            raise Exception("some elements in reg_target are nan")
        if torch.any(torch.isnan(reg_pred[fg_bg_mask == 1])):
            raise Exception("some elements in reg_pred are nan")

        if torch.any(fg_bg_mask == 1):
            reg_loss = F.smooth_l1_loss(reg_pred[fg_bg_mask == 1],
                                        reg_target[fg_bg_mask == 1])
        else:
            # No positive samples
            reg_loss = torch.zeros_like(cls_loss)

        # NOTE: cls_pred holds logits, so thresholding at 0.5 here corresponds
        # to a probability of about 0.62; the "all_" stats below threshold at
        # logit 0, i.e. probability 0.5
        cls_pred = cls_pred >= 0.5
        cls_label = cls_label == 1
        acc = torch.mean((cls_label == cls_pred)[fg_bg_mask != 0].to(torch.float))
        num_pos = (fg_bg_mask == 1).sum()
        num_neg = (fg_bg_mask == -1).sum()
        TP = (cls_pred == True)[fg_bg_mask == 1].sum().to(torch.float32)
        FP = (cls_pred == True)[fg_bg_mask == -1].sum().to(torch.float32)
        FN = (cls_pred == False)[fg_bg_mask == 1].sum().to(torch.float32)
        precision = TP / (TP + FP)
        recall = TP / (TP + FN)

        all_cls_pred = torch.stack(all_cls_pred)
        all_fg_bg_mask = torch.stack(all_fg_bg_mask)
        all_cls_pred = all_cls_pred >= 0
        all_TP = (all_cls_pred == True)[all_fg_bg_mask == 1].sum().to(torch.float32)
        all_FP = (all_cls_pred == True)[all_fg_bg_mask == -1].sum().to(torch.float32)
        all_FN = (all_cls_pred == False)[all_fg_bg_mask == 1].sum().to(torch.float32)
        all_precision = all_TP / (all_TP + all_FP)
        all_recall = all_TP / (all_TP + all_FN)

        if self.logger is not None:
            self.logger.add_scalar("rpn/TP", TP.detach().cpu().item())
            self.logger.add_scalar("rpn/FP", FP.detach().cpu().item())
            self.logger.add_scalar("rpn/FN", FN.detach().cpu().item())
            self.logger.add_scalar("rpn/all_TP", all_TP.detach().cpu().item())
            self.logger.add_scalar("rpn/all_FP", all_FP.detach().cpu().item())
            self.logger.add_scalar("rpn/all_FN", all_FN.detach().cpu().item())
            self.logger.add_scalar("rpn/acc", acc.detach().cpu().item())
            self.logger.add_scalar("rpn/num_pos", num_pos.detach().cpu().item())
            self.logger.add_scalar("rpn/num_neg", num_neg.detach().cpu().item())
            self.logger.add_scalar("rpn/precision", precision.detach().cpu().item())
            self.logger.add_scalar("rpn/recall", recall.detach().cpu().item())
            self.logger.add_scalar("rpn/all_precision", all_precision.detach().cpu().item())
            self.logger.add_scalar("rpn/all_recall", all_recall.detach().cpu().item())
        return bboxes, cls_loss, reg_loss
    return bboxes, None, None
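# A plausible sketch of the box_coder used by the RPN forward above
# (assumption: the standard Faster R-CNN delta parameterization; the real
# coder may scale or clamp the deltas differently).
import torch

class BoxCoder:
    @staticmethod
    def encode(anchors, gt):
        # anchors, gt: (..., 4) in xyxy; returns (dx, dy, dw, dh)
        aw, ah = anchors[..., 2] - anchors[..., 0], anchors[..., 3] - anchors[..., 1]
        ax, ay = anchors[..., 0] + 0.5 * aw, anchors[..., 1] + 0.5 * ah
        gw, gh = gt[..., 2] - gt[..., 0], gt[..., 3] - gt[..., 1]
        gx, gy = gt[..., 0] + 0.5 * gw, gt[..., 1] + 0.5 * gh
        return torch.stack([(gx - ax) / aw, (gy - ay) / ah,
                            torch.log(gw / aw), torch.log(gh / ah)], dim=-1)

    @staticmethod
    def decode(anchors, deltas):
        # Inverse of encode: apply (dx, dy, dw, dh) to the anchors, return xyxy;
        # broadcasting lets deltas carry extra leading dims over (N, 4) anchors
        aw, ah = anchors[..., 2] - anchors[..., 0], anchors[..., 3] - anchors[..., 1]
        ax, ay = anchors[..., 0] + 0.5 * aw, anchors[..., 1] + 0.5 * ah
        cx, cy = deltas[..., 0] * aw + ax, deltas[..., 1] * ah + ay
        w, h = torch.exp(deltas[..., 2]) * aw, torch.exp(deltas[..., 3]) * ah
        return torch.stack([cx - 0.5 * w, cy - 0.5 * h,
                            cx + 0.5 * w, cy + 0.5 * h], dim=-1)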
def _evaluate_iou(self, preds, targets):
    """Evaluate intersection over union (IoU) for targets from the dataset and output predictions from the model."""
    # No box detected: 0 IoU
    if preds["boxes"].shape[0] == 0:
        return torch.tensor(0.0, device=preds["boxes"].device)
    return box_iou(preds["boxes"], targets["boxes"]).diag().mean()
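# What _evaluate_iou computes for one sample, spelled out with hypothetical
# boxes; note that .diag() pairs prediction k with target k, so it is only
# meaningful when detections and targets are aligned and equal in number.
import torch
from torchvision.ops import box_iou

preds = {"boxes": torch.tensor([[0., 0., 10., 10.], [0., 0., 20., 20.]])}
targets = {"boxes": torch.tensor([[0., 0., 10., 10.], [10., 10., 20., 20.]])}
print(box_iou(preds["boxes"], targets["boxes"]).diag().mean())  # tensor(0.6250)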
def loss(self, outputs: tuple, gt_bboxes: list, gt_labels: list,
         iou_thresh: float = 0.5) -> dict:
    """
    Loss function.

    Args:
        outputs (tuple): (predicted offsets, predicted objectness, predicted class confidences)
            * predicted offsets:           (B, P, 4) (coord fmt: [Δx, Δy, Δw, Δh])
              (P: number of prior boxes; P = 10647 expected)
            * predicted objectness:        (B, P)
            * predicted class confidences: (B, P, num_classes)
        gt_bboxes (list): ground-truth bbox coordinates [(G1, 4), (G2, 4), ...] (coord fmt: [x, y, w, h])
        gt_labels (list): ground-truth labels [(G1,), (G2,)]
        iou_thresh (float): IoU threshold for deciding Positive / Negative

    Returns:
        dict: { loss: xxx, loss_loc: xxx, loss_obj: xxx, loss_conf: xxx }
    """
    out_locs, out_objs, out_confs = outputs
    device = out_locs.device

    # [Step 1]
    # Build the targets:
    #   - Match predictions to GT boxes.
    #   - The prior box whose grid cell contains (x, y) and whose IoU with the
    #     BBox is highest is assigned to that BBox.
    #   - If the max IoU is >= 0.5 but the prior box is not assigned to any GT,
    #     set the label to -1 (ignored).
    #   - If the max IoU is < 0.5, set the label to 0 (background).
    B, P, C = out_confs.size()
    target_locs = torch.zeros(B, P, 4, device=device)
    target_labels = torch.zeros(B, P, dtype=torch.long, device=device)
    pboxes, grid_length = self.pboxes.to(device).split(4, dim=1)
    for i in range(B):
        bboxes = gt_bboxes[i].to(device)
        labels = gt_labels[i].to(device)

        is_in_grid = (pboxes[:, [0]] <= bboxes[:, 0]) * (bboxes[:, 0] < pboxes[:, [0]] + grid_length) * \
                     (pboxes[:, [1]] <= bboxes[:, 1]) * (bboxes[:, 1] < pboxes[:, [1]] + grid_length)
        bboxes_xyxy = box_convert(bboxes, in_fmt='xywh', out_fmt='xyxy')
        pboxes_xyxy = box_convert(pboxes, in_fmt='xywh', out_fmt='xyxy')
        ious = box_iou(pboxes_xyxy, bboxes_xyxy)
        best_ious, best_pbox_ids = (ious * is_in_grid).max(dim=0)
        max_ious, matched_bbox_ids = ious.max(dim=1)

        # For each BBox, pick the prior box with the highest IoU and assign it
        # to that BBox (fixed: index into the tensor itself; the original
        # matched_bbox_ids[best_pbox_ids][j] wrote into an indexing copy)
        for j in range(len(best_pbox_ids)):
            matched_bbox_ids[best_pbox_ids[j]] = j
        max_ious[best_pbox_ids] = 1.

        bboxes = bboxes[matched_bbox_ids]
        locs = self._calc_delta(bboxes, pboxes, grid_length)
        labels = labels[matched_bbox_ids]
        labels[max_ious < 1.] = -1  # "void" class (ignored)
        labels[max_ious.less(iou_thresh)] = 0  # 0 is the background class; positive classes are 1..C

        target_locs[i] = locs
        target_labels[i] = labels

    # [Step 2]
    # Build pos_mask and neg_mask:
    #   - pos_mask: label > 0
    #   - neg_mask: label == 0
    pos_mask = target_labels > 0
    neg_mask = target_labels == 0
    N = pos_mask.sum()

    # [Step 3]
    # Localization loss over the positives
    loss_loc = (F.binary_cross_entropy_with_logits(
        out_locs[pos_mask][..., :2], target_locs[pos_mask][..., :2], reduction='sum') +
        F.mse_loss(out_locs[pos_mask][..., 2:], target_locs[pos_mask][..., 2:], reduction='sum')) / N

    # [Step 4]
    # Confidence loss over the positives
    loss_conf = F.binary_cross_entropy_with_logits(
        out_confs[pos_mask],
        F.one_hot(target_labels[pos_mask] - 1, num_classes=self.nc).float(),
        reduction='sum') / N

    # [Step 5]
    # Objectness loss over the positives & negatives
    loss_obj = F.binary_cross_entropy_with_logits(
        out_objs[pos_mask | neg_mask],
        pos_mask[pos_mask | neg_mask].float(),
        reduction='sum') / N

    # [Step 6]
    # Sum of the losses
    loss = loss_loc + loss_obj + loss_conf

    return {
        'loss': loss,
        'loss_loc': loss_loc,
        'loss_conf': loss_conf,
        'loss_obj': loss_obj
    }
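# A plausible sketch of the _calc_delta referenced above (assumption: YOLO-style
# targets with Δx, Δy in [0, 1] relative to the grid cell, matching the BCE on
# the first two channels, and log size ratios matching the MSE on the last two;
# the real implementation may differ).
def _calc_delta(self, bboxes, pboxes, grid_length):
    # bboxes, pboxes: (P, 4) in [x, y, w, h]; grid_length: (P, 1)
    dxy = (bboxes[:, :2] - pboxes[:, :2]) / grid_length  # offset inside the cell
    dwh = (bboxes[:, 2:] / pboxes[:, 2:]).log()          # log size ratio
    return torch.cat([dxy, dwh], dim=1)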
def set_ignoring(self, noobj_mask, inference, targets, head_anchors, head_size):
    """
    Args:
        head_anchors: anchors of this head
    """
    batch_size = len(targets)
    head_h, head_w = head_size

    # cx, cy, w, h
    x = (1 + self.EGS_factor) * torch.sigmoid(inference[..., 0]) - 0.5 * self.EGS_factor
    y = (1 + self.EGS_factor) * torch.sigmoid(inference[..., 1]) - 0.5 * self.EGS_factor
    w = inference[..., 2]
    h = inference[..., 3]

    # Set device
    FloatTensor = torch.cuda.FloatTensor if self.device == 'cuda' else torch.FloatTensor

    # Generate coordinate grids
    grid_x = torch.linspace(0, head_w - 1, head_w)
    grid_x = grid_x.repeat(head_h, 1).repeat(batch_size * self.n_head_anchors, 1, 1)
    grid_x = grid_x.view(x.shape).type(FloatTensor)
    grid_y = torch.linspace(0, head_h - 1, head_h)
    grid_y = grid_y.repeat(head_w, 1).t().repeat(batch_size * self.n_head_anchors, 1, 1)
    grid_y = grid_y.view(y.shape).type(FloatTensor)

    # Generate anchors for the coordinate grids
    anchor_w = FloatTensor(head_anchors)[:, 0].unsqueeze(1)
    anchor_w = anchor_w.repeat(batch_size, 1).repeat(1, 1, head_h * head_w).view(w.shape)
    anchor_h = FloatTensor(head_anchors)[:, 1].unsqueeze(1)
    anchor_h = anchor_h.repeat(batch_size, 1).repeat(1, 1, head_h * head_w).view(h.shape)

    # Calculate bboxes
    infer_boxes = FloatTensor(inference[..., :4].shape)
    infer_boxes[..., 0] = x.data + grid_x
    infer_boxes[..., 1] = y.data + grid_y
    infer_boxes[..., 2] = torch.exp(w.data) * anchor_w
    infer_boxes[..., 3] = torch.exp(h.data) * anchor_h

    for bs, tar in enumerate(targets):
        if len(tar) == 0:
            continue
        # (num_anchors, 4)
        ignored_boxes = infer_boxes[bs].view(-1, 4)

        # Ground truth scaled to this head size
        gt_x = tar[:, 0:1] * head_w
        gt_y = tar[:, 1:2] * head_h
        gt_w = tar[:, 2:3] * head_w
        gt_h = tar[:, 3:4] * head_h
        gt_box = FloatTensor(torch.cat([gt_x, gt_y, gt_w, gt_h], -1))

        # Calculate IoU
        iou_matrix = ops.box_iou(gt_box, ignored_boxes)

        # Get the nearest ground truth for each anchor
        max_iou, _ = torch.max(iou_matrix, dim=0)
        max_iou = max_iou.view(infer_boxes[bs].shape[:3])

        # Turn off noobj_mask according to the threshold
        noobj_mask[bs][max_iou > self.ignoring_threshold] = 0

    return noobj_mask, infer_boxes
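# An equivalent, arguably clearer way to build the coordinate grids above with
# torch.meshgrid (a sketch; for integer grid coordinates the result matches the
# linspace/repeat dance before broadcasting over batch and anchors):
import torch

head_h, head_w = 13, 13  # hypothetical head size
gy, gx = torch.meshgrid(torch.arange(head_h, dtype=torch.float32),
                        torch.arange(head_w, dtype=torch.float32),
                        indexing='ij')
# gx[i, j] == j and gy[i, j] == i, i.e. the per-cell x/y offsets of the grid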
import torch
from torchvision.ops import box_iou


def calculate_mAP(det_boxes, det_labels, det_scores, true_boxes, true_labels,
                  true_difficulties):
    """
    Calculate the Mean Average Precision (mAP) of detected objects.

    See https://medium.com/@jonathan_hui/map-mean-average-precision-for-object-detection-45c121a31173 for an explanation.

    :param det_boxes: list of tensors, one tensor for each image containing detected objects' bounding boxes
    :param det_labels: list of tensors, one tensor for each image containing detected objects' labels
    :param det_scores: list of tensors, one tensor for each image containing detected objects' labels' scores
    :param true_boxes: list of tensors, one tensor for each image containing actual objects' bounding boxes
    :param true_labels: list of tensors, one tensor for each image containing actual objects' labels
    :param true_difficulties: list of tensors, one tensor for each image containing actual objects' difficulty (0 or 1)
    :return: dictionary of average precisions keyed by class name, and the mean average precision (mAP)
    """
    # These are all lists of tensors of the same length, i.e. number of images
    assert len(det_boxes) == len(det_labels) == len(det_scores) == \
           len(true_boxes) == len(true_labels) == len(true_difficulties)
    # label_map / rev_label_map are assumed to be module-level mappings
    n_classes = len(label_map)

    # Store all (true) objects in a single continuous tensor while keeping track of the image it is from
    true_images = list()
    for i in range(len(true_labels)):
        true_images.extend([i] * true_labels[i].size(0))
    device = det_boxes[0].device
    true_images = torch.LongTensor(true_images).to(device)  # (n_objects), total no. of objects across all images
    true_boxes = torch.cat(true_boxes, dim=0)  # (n_objects, 4)
    true_labels = torch.cat(true_labels, dim=0)  # (n_objects)
    true_difficulties = torch.cat(true_difficulties, dim=0)  # (n_objects)

    assert true_images.size(0) == true_boxes.size(0) == true_labels.size(0)

    # Store all detections in a single continuous tensor while keeping track of the image it is from
    det_images = list()
    for i in range(len(det_labels)):
        det_images.extend([i] * det_labels[i].size(0))
    det_images = torch.LongTensor(det_images).to(device)  # (n_detections)
    det_boxes = torch.cat(det_boxes, dim=0)  # (n_detections, 4)
    det_labels = torch.cat(det_labels, dim=0)  # (n_detections)
    det_scores = torch.cat(det_scores, dim=0)  # (n_detections)

    assert det_images.size(0) == det_boxes.size(0) == det_labels.size(0) == det_scores.size(0)

    # Calculate APs for each class (except background)
    average_precisions = torch.zeros((n_classes - 1), dtype=torch.float)  # (n_classes - 1)
    for c in range(1, n_classes):
        # Extract only objects with this class
        true_class_images = true_images[true_labels == c]  # (n_class_objects)
        true_class_boxes = true_boxes[true_labels == c]  # (n_class_objects, 4)
        true_class_difficulties = true_difficulties[true_labels == c]  # (n_class_objects)
        # Ignore difficult objects; cast to bool so ~ is a logical NOT, not a
        # bitwise NOT on uint8 (which would yield 254/255)
        n_easy_class_objects = (~true_class_difficulties.bool()).sum().item()

        # Keep track of which true objects with this class have already been 'detected'
        # So far, none
        true_class_boxes_detected = torch.zeros((true_class_difficulties.size(0)),
                                                dtype=torch.uint8).to(device)  # (n_class_objects)

        # Extract only detections with this class
        det_class_images = det_images[det_labels == c]  # (n_class_detections)
        det_class_boxes = det_boxes[det_labels == c]  # (n_class_detections, 4)
        det_class_scores = det_scores[det_labels == c]  # (n_class_detections)
        n_class_detections = det_class_boxes.size(0)
        if n_class_detections == 0:
            continue

        # Sort detections in decreasing order of confidence/scores
        det_class_scores, sort_ind = torch.sort(det_class_scores, dim=0,
                                                descending=True)  # (n_class_detections)
        det_class_images = det_class_images[sort_ind]  # (n_class_detections)
        det_class_boxes = det_class_boxes[sort_ind]  # (n_class_detections, 4)

        # In the order of decreasing scores, check if true or false positive
        true_positives = torch.zeros((n_class_detections), dtype=torch.float).to(device)  # (n_class_detections)
        false_positives = torch.zeros((n_class_detections), dtype=torch.float).to(device)  # (n_class_detections)
        for d in range(n_class_detections):
            this_detection_box = det_class_boxes[d].unsqueeze(0)  # (1, 4)
            this_image = det_class_images[d]  # (), scalar

            # Find objects in the same image with this class, their difficulties,
            # and whether they have been detected before
            object_boxes = true_class_boxes[true_class_images == this_image]  # (n_class_objects_in_img, 4)
            object_difficulties = true_class_difficulties[true_class_images == this_image]  # (n_class_objects_in_img)
            # If no such object in this image, then the detection is a false positive
            if object_boxes.size(0) == 0:
                false_positives[d] = 1
                continue

            # Find maximum overlap of this detection with objects in this image of this class
            overlaps = box_iou(this_detection_box, object_boxes)  # (1, n_class_objects_in_img)
            max_overlap, ind = torch.max(overlaps.squeeze(0), dim=0)  # (), () - scalars

            # 'ind' is the index of the object in these image-level tensors 'object_boxes', 'object_difficulties'
            # In the original class-level tensors 'true_class_boxes', etc., 'ind' corresponds to object with index...
            original_ind = torch.LongTensor(range(true_class_boxes.size(0)))[true_class_images == this_image][ind]
            # We need 'original_ind' to update 'true_class_boxes_detected'

            # If the maximum overlap is greater than the threshold of 0.5, it's a match
            if max_overlap.item() > 0.5:
                # If the object it matched with is 'difficult', ignore it
                if not object_difficulties[ind]:
                    # If this object has not already been detected, it's a true positive
                    if true_class_boxes_detected[original_ind] == 0:
                        true_positives[d] = 1
                        true_class_boxes_detected[original_ind] = 1  # this object has now been detected/accounted for
                    # Otherwise, it's a false positive (since this object is already accounted for)
                    else:
                        false_positives[d] = 1
            # Otherwise, the detection occurs in a different location than the actual object, and is a false positive
            else:
                false_positives[d] = 1

        # Compute cumulative precision and recall at each detection in the order of decreasing scores
        cumul_true_positives = torch.cumsum(true_positives, dim=0)  # (n_class_detections)
        cumul_false_positives = torch.cumsum(false_positives, dim=0)  # (n_class_detections)
        cumul_precision = cumul_true_positives / (
            cumul_true_positives + cumul_false_positives + 1e-10)  # (n_class_detections)
        cumul_recall = cumul_true_positives / n_easy_class_objects  # (n_class_detections)

        # Find the mean of the maximum of the precisions corresponding to recalls above the threshold 't'
        recall_thresholds = torch.arange(start=0, end=1.1, step=.1).tolist()  # (11)
        precisions = torch.zeros((len(recall_thresholds)), dtype=torch.float).to(device)  # (11)
        for i, t in enumerate(recall_thresholds):
            recalls_above_t = cumul_recall >= t
            if recalls_above_t.any():
                precisions[i] = cumul_precision[recalls_above_t].max()
            else:
                precisions[i] = 0.
        average_precisions[c - 1] = precisions.mean()  # c is in [1, n_classes - 1]

    # Calculate Mean Average Precision (mAP)
    mean_average_precision = average_precisions.mean().item()

    # Keep class-wise average precisions in a dictionary
    average_precisions = {rev_label_map[c + 1]: v
                          for c, v in enumerate(average_precisions.tolist())}

    return average_precisions, mean_average_precision
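# Hedged end-to-end sketch for calculate_mAP with two hypothetical one-object
# images; assumes label_map / rev_label_map from the enclosing module (here a
# single foreground class plus background).
det_boxes = [torch.tensor([[0., 0., 10., 10.]]), torch.tensor([[5., 5., 15., 15.]])]
det_labels = [torch.tensor([1]), torch.tensor([1])]
det_scores = [torch.tensor([0.9]), torch.tensor([0.8])]
true_boxes = [t.clone() for t in det_boxes]
true_labels = [t.clone() for t in det_labels]
true_difficulties = [torch.tensor([0], dtype=torch.uint8),
                     torch.tensor([0], dtype=torch.uint8)]
aps, mAP = calculate_mAP(det_boxes, det_labels, det_scores,
                         true_boxes, true_labels, true_difficulties)
print(aps, mAP)  # perfect detections, so the single-class AP and the mAP are 1.0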
def _iou_boxes(self, other: 'BoundingBoxes') -> Tensor:
    sz = other[0].sz
    assert len(self) == len(other)
    a = self.to_tensor(sz).cpu().unsqueeze(1)
    b = other.to_tensor(sz).cpu().unsqueeze(1)
    return torch.cat([box_iou(i, j) for i, j in zip(a, b)]).squeeze()
def _iou_box(self, other: 'BoundingBox') -> Tensor:
    a = self.to_tensor(other.sz)
    b = other.x[None].to(a.device)
    return box_iou(a, b).squeeze(-1)
def iou(self, other: Union['BoundingBox', 'BoundingBoxes']) -> Rank0Tensor:
    if isinstance(other, BoundingBoxes):
        return other.iou(self)
    a = self.x[None]
    b = other.get_resized_x(self.sz)[None].to(a.device)
    return box_iou(a, b).item()
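# Shape reminder for the three helpers above (hypothetical boxes): torchvision's
# box_iou returns an (N, M) matrix, so the single-box path relies on the inputs
# being (1, 4) before calling .item().
import torch
from torchvision.ops import box_iou

a = torch.tensor([[0., 0., 4., 4.]])
b = torch.tensor([[2., 2., 6., 6.]])
print(box_iou(a, b).item())  # 4 / 28 ≈ 0.1429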
def forward(self, images, labels=None, gt_bboxes=None):
    """
    :param images: shape (BS, C, H, W)
    :param labels: shape (BS, n_objs)
    :param gt_bboxes: shape (BS, n_objs, 4)
    :return:
    """
    feats = self.backbone(images)
    # rois shape (BS, num_rois, 4)
    rois, rpn_cls_loss, rpn_reg_loss = self.rpn(feats, labels, gt_bboxes)

    if self.training:
        # Add the gt bboxes to the rois
        rois = torch.cat([rois, gt_bboxes], dim=1)

    # Prepend a batch_id column to the rois
    BS, num_rois, _ = rois.shape
    batch_id = torch.stack(
        [torch.full_like(rois[i, :, :1], i) for i in range(BS)], dim=0)
    # (BS, num_rois, 5)
    rois = torch.cat([batch_id, rois], dim=2)
    # (BS*num_rois, 5)
    rois = rois.reshape((-1, 5))

    # ROI pooling in each feature map
    if self.roi_pooling == "roi_align":
        roi_pooling = ops.roi_align
    elif self.roi_pooling == "roi_pool":
        roi_pooling = ops.roi_pool
    else:
        raise Exception("{} is not supported".format(self.roi_pooling))
    if len(feats) == 1:
        _, feat = feats.popitem()
        roi_feats = roi_pooling(
            feat, rois,
            (self.roi_pooling_output_size, self.roi_pooling_output_size),
            1 / self.strides[0])
    else:
        feat_levels = np.log2(self.strides).astype(np.int64)
        feat_names = [n for n in feats.keys()]
        assert len(feat_levels) == len(feat_names)
        w = rois[:, 3] - rois[:, 1]
        h = rois[:, 4] - rois[:, 2]
        # FPN level assignment: k = floor(k0 + log2(sqrt(w*h) / 224)), k0 = 4
        roi_levels = torch.floor(4 + torch.log2(torch.sqrt(w * h) / 224 + 1e-6))
        _f = feats[feat_names[0]]
        C = _f.shape[1]
        device = _f.device
        dtype = _f.dtype
        roi_feats = torch.zeros(
            (BS * num_rois, C, self.roi_pooling_output_size, self.roi_pooling_output_size),
            dtype=dtype, device=device)
        for i, (feat_level, feat_name) in enumerate(zip(feat_levels, feat_names)):
            mask_in_level = roi_levels == feat_level
            _roi_feats = roi_pooling(
                feats[feat_name], rois[mask_in_level],
                (self.roi_pooling_output_size, self.roi_pooling_output_size),
                1 / self.strides[i])
            roi_feats[mask_in_level] = _roi_feats

    # roi_feats shape (BS*num_rois, C, self.roi_pooling_output_size, self.roi_pooling_output_size)
    # ROI head
    # (BS*num_rois, num_vector)
    box_feats = self.roi_head(roi_feats)
    # (BS*num_rois, num_classes+1)
    cls_pred = self.cls(box_feats)
    # (BS*num_rois, num_classes*4)
    reg_pred = self.reg(box_feats)
    # (BS, num_rois, num_classes+1)
    cls_pred = cls_pred.reshape((BS, num_rois, -1))
    # (BS, num_rois, num_classes, 4)
    reg_pred = reg_pred.reshape((BS, num_rois, -1, 4))

    if self.training:
        # (BS, num_rois, 5)
        rois = rois.reshape((BS, num_rois, 5))
        rois = rois[:, :, 1:]
        total_cls_pred = []
        total_reg_pred = []
        total_fg_bg_mask = []
        total_labels = []
        total_reg_target = []
        for i in range(BS):
            # Assign a label to each roi
            areas = ops.boxes.box_area(rois[i])
            ignore_mask = areas == 0
            # (num_rois, num_gt_bboxes)
            ious = ops.box_iou(rois[i], gt_bboxes[i])
            # Some rois have zero area, e.g. (-1, -1, -1, -1), which produces
            # NaNs in the IoU matrix; replace NaN with 0
            zero_mask = (areas == 0).reshape(-1, 1).expand_as(ious)
            ious[zero_mask] = 0
            if torch.any(torch.isnan(ious)):
                raise Exception("some elements in ious are nan")

            #############################################################
            # Track whether the rois cover all gt boxes (gt recall)
            num_gt = labels.shape[1]
            _ious_without_gt = ious[:-num_gt]  # exclude the gt boxes appended to the rois
            _ious_max_without_gt, _ = torch.max(_ious_without_gt, dim=0)
            gt_recall = (_ious_max_without_gt >= 0.5)[labels[i] != -1].to(
                torch.float32).mean()
            if self.logger is not None:
                self.logger.add_scalar("rcnn/gt_recall_0.5",
                                       gt_recall.detach().cpu().item())
            gt_recall = (_ious_max_without_gt >= 0.7)[labels[i] != -1].to(
                torch.float32).mean()
            if self.logger is not None:
                self.logger.add_scalar("rcnn/gt_recall_0.7",
                                       gt_recall.detach().cpu().item())
            #############################################################

            # The roi(s) with the highest Intersection-over-Union (IoU)
            # overlap with a ground-truth box
            iou_max_gt, _ = torch.max(ious, dim=0)
            # Skip the padded entries of gt_bboxes
            iou_max_gt = torch.where(labels[i] == -1,
                                     torch.ones_like(iou_max_gt), iou_max_gt)
            highest_mask = (ious == iou_max_gt)
            fg_mask = torch.any(highest_mask, dim=1)
            # A roi whose IoU overlap with any ground-truth box is above fg_iou_thresh
            iou_max, matched_idx = torch.max(ious, dim=1)
            # 1 for foreground, -1 for background, 0 for ignore
            fg_bg_mask = torch.zeros_like(iou_max)
            # Confirm positive samples
            fg_bg_mask = torch.where(iou_max >= self.fg_iou_thresh,
                                     torch.ones_like(iou_max), fg_bg_mask)
            fg_bg_mask = torch.where(fg_mask, torch.ones_like(iou_max), fg_bg_mask)
            # Confirm negative samples
            fg_bg_mask = torch.where(iou_max <= self.bg_iou_thresh,
                                     torch.full_like(iou_max, -1), fg_bg_mask)
            # Ignore samples
            fg_bg_mask = torch.where(ignore_mask, torch.zeros_like(iou_max), fg_bg_mask)

            # Random sampling
            indices = torch.arange(fg_bg_mask.shape[0], dtype=torch.int64,
                                   device=fg_bg_mask.device)
            rand_indices = torch.rand_like(fg_bg_mask).argsort()
            fg_bg_mask = fg_bg_mask[rand_indices]  # shuffle to make the pick "random"
            indices = indices[rand_indices]
            sorted_indices = fg_bg_mask.argsort(descending=True)
            fg_bg_mask = fg_bg_mask[sorted_indices]
            indices = indices[sorted_indices]
            fg_indices = indices[:self.num_pos]
            fg_mask = fg_bg_mask[:self.num_pos]
            bg_indices = indices[-self.num_neg:]
            bg_mask = fg_bg_mask[-self.num_neg:]
            indices = torch.cat([fg_indices, bg_indices], dim=0)
            fg_bg_mask = torch.cat([fg_mask, bg_mask], dim=0)
            matched_idx = matched_idx[indices]

            # (num_samples); the label does not include background yet
            label = labels[i][matched_idx]
            # Turn label -1 into 0; F.one_hot does not accept negative values
            _label = label.clone()
            _label[_label == -1] = 0
            # (num_samples, num_classes)
            label_mask = F.one_hot(_label, self.num_classes)
            # (num_samples*num_classes,)
            label_mask = label_mask.reshape(-1).to(torch.bool)
            _reg_pred = reg_pred[i][indices].reshape(-1, 4)
            # (num_samples, 4)
            _reg_pred = _reg_pred[label_mask]
            _rois = rois[i][indices]

            total_cls_pred.append(cls_pred[i][indices])
            total_reg_pred.append(_reg_pred)
            total_fg_bg_mask.append(fg_bg_mask)
            total_labels.append(label)
            total_reg_target.append(
                self.box_coder.encode(_rois, gt_bboxes[i][matched_idx]))

        # (BS, num_samples, num_classes+1)
        cls_pred = torch.stack(total_cls_pred)
        # (BS, num_samples, 4)
        reg_pred = torch.stack(total_reg_pred)
        # (BS, num_samples)
        fg_bg_mask = torch.stack(total_fg_bg_mask)
        # (BS, num_samples)
        labels = torch.stack(total_labels)
        # (BS, num_samples, 4)
        reg_target = torch.stack(total_reg_target)

        if torch.any(torch.isnan(reg_target[fg_bg_mask == 1])):
            raise Exception("some elements in reg_target are nan")
        if torch.any(torch.isinf(reg_target[fg_bg_mask == 1])):
            raise Exception("some elements in reg_target are inf")

        # The gt boxes were added to the rois, so positive samples must exist
        assert torch.any(fg_bg_mask == 1)
        rcnn_reg_loss = F.smooth_l1_loss(reg_pred[fg_bg_mask == 1],
                                         reg_target[fg_bg_mask == 1])

        # Shift all class ids by +1 to reserve 0 for background
        cls_label = labels + 1
        cls_label = cls_label.reshape((-1,))
        cls_pred = cls_pred.reshape((-1, self.num_classes + 1))
        fg_bg_mask = fg_bg_mask.reshape(-1,)
        # Set the background label to 0
        cls_label = torch.where(fg_bg_mask == -1,
                                torch.zeros_like(cls_label), cls_label)
        rcnn_cls_loss = F.cross_entropy(cls_pred[fg_bg_mask != 0],
                                        cls_label[fg_bg_mask != 0])

        cls_pred = torch.argmax(cls_pred, dim=1)
        acc = torch.mean((cls_pred == cls_label)[fg_bg_mask != 0].to(torch.float))
        num_pos = (fg_bg_mask == 1).sum()
        num_neg = (fg_bg_mask == -1).sum()
        if self.logger is not None:
            self.logger.add_scalar("rcnn/acc", acc.detach().cpu().item())
            self.logger.add_scalar("rcnn/num_pos", num_pos.detach().cpu().item())
            self.logger.add_scalar("rcnn/num_neg", num_neg.detach().cpu().item())
        return rpn_cls_loss, rpn_reg_loss, rcnn_cls_loss, rcnn_reg_loss

    cls_scores = F.softmax(cls_pred, dim=2)
    # (BS, num_rois, num_classes); drop the background column
    cls_scores = cls_scores[:, :, 1:]

    # rois: (BS*num_rois, 5)
    # reg_pred: (BS, num_rois, num_classes, 4)
    # _reg_pred: (num_classes, BS*num_rois, 4)
    _reg_pred = reg_pred.permute((2, 0, 1, 3)).reshape(self.num_classes, BS * num_rois, 4)
    # (num_classes, BS*num_rois, 4)
    reg_bboxes = self.box_coder.decode(rois[:, 1:], _reg_pred)
    reg_bboxes[..., 0].clamp_(0, self.image_size[0])
    reg_bboxes[..., 1].clamp_(0, self.image_size[1])
    reg_bboxes[..., 2].clamp_(0, self.image_size[0])
    reg_bboxes[..., 3].clamp_(0, self.image_size[1])
    # (BS, num_rois, num_classes, 4)
    reg_bboxes = reg_bboxes.permute((1, 0, 2)).reshape(
        (BS, num_rois, self.num_classes, 4))

    # (num_rois, num_classes)
    classes_id = torch.cat(
        [
            # (num_rois, 1)
            torch.full_like(cls_scores[0, :, :1], i)
            for i in range(self.num_classes)
        ],
        dim=1)
    # (num_rois*num_classes,)
    classes_id = classes_id.reshape((-1,))
    # (BS, num_rois*num_classes)
    cls_scores = cls_scores.reshape((BS, -1))
    # (BS, num_rois*num_classes, 4)
    reg_bboxes = reg_bboxes.reshape((BS, -1, 4))

    scores = []
    bboxes = []
    labels = []
    for i in range(BS):
        _scores = torch.full((self.max_objs_per_image,), -1,
                             dtype=cls_scores.dtype, device=cls_scores.device)
        _labels = torch.full((self.max_objs_per_image,), -1,
                             dtype=classes_id.dtype, device=classes_id.device)
        _bboxes = torch.full((self.max_objs_per_image, 4), -1,
                             dtype=reg_bboxes.dtype, device=reg_bboxes.device)
        keep_mask = cls_scores[i] >= self.obj_thresh
        _reg_bboxes = reg_bboxes[i][keep_mask]
        _cls_scores = cls_scores[i][keep_mask]
        _classes_id = classes_id[keep_mask]
        keep = ops.boxes.batched_nms(_reg_bboxes, _cls_scores, _classes_id,
                                     self.nms_thresh)
        n_keep = min(keep.shape[0], self.max_objs_per_image)
        keep = keep[:n_keep]
        _scores[:n_keep] = _cls_scores[keep]
        _labels[:n_keep] = _classes_id[keep]
        _bboxes[:n_keep] = _reg_bboxes[keep]
        scores.append(_scores)
        labels.append(_labels)
        bboxes.append(_bboxes)
    scores = torch.stack(scores)  # (BS, max_objs)
    labels = torch.stack(labels)  # (BS, max_objs)
    bboxes = torch.stack(bboxes)  # (BS, max_objs, 4)
    return scores, labels, bboxes
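# Standalone sketch of the FPN level rule used in the forward above, i.e.
# k = floor(k0 + log2(sqrt(w*h) / 224)) with k0 = 4 (from the FPN paper);
# the ROI sizes are hypothetical.
import torch

rois_wh = torch.tensor([[56., 56.], [112., 112.], [224., 224.], [448., 448.]])
k = torch.floor(4 + torch.log2(rois_wh.prod(dim=1).sqrt() / 224 + 1e-6))
print(k)  # tensor([2., 3., 4., 5.])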
import numpy as np
import torch
from torchvision import ops


def _get_graph_centers(boxes, cls_prob, im_labels):
    """Get graph centers."""
    # _get_top_ranking_propoals is defined elsewhere in this module
    num_images, num_classes = im_labels.shape
    assert num_images == 1, 'batch size should be equal to 1'
    dev = cls_prob.device
    gt_boxes = torch.zeros((0, 4), dtype=boxes.dtype, device=dev)
    gt_classes = torch.zeros((0, 1), dtype=torch.long, device=dev)
    gt_scores = torch.zeros((0, 1), dtype=cls_prob.dtype, device=dev)
    for i in im_labels.nonzero()[:, 1]:
        cls_prob_tmp = cls_prob[:, i]
        idxs = (cls_prob_tmp >= 0).nonzero()[:, 0]
        idxs_tmp = _get_top_ranking_propoals(cls_prob_tmp[idxs].reshape(-1, 1))
        idxs = idxs[idxs_tmp]
        boxes_tmp = boxes[idxs, :]
        cls_prob_tmp = cls_prob_tmp[idxs]

        graph = (ops.box_iou(boxes_tmp, boxes_tmp) > 0.4).float()

        keep_idxs = []
        gt_scores_tmp = []
        count = cls_prob_tmp.size(0)
        while True:
            order = graph.sum(dim=1).argsort(descending=True)
            tmp = order[0]
            keep_idxs.append(tmp)
            inds = (graph[tmp, :] > 0).nonzero()[:, 0]
            gt_scores_tmp.append(cls_prob_tmp[inds].max())
            graph[:, inds] = 0
            graph[inds, :] = 0
            count = count - len(inds)
            if count <= 5:
                break

        gt_boxes_tmp = boxes_tmp[keep_idxs, :].view(-1, 4).to(dev)
        gt_scores_tmp = torch.tensor(gt_scores_tmp, device=dev)

        keep_idxs_new = torch.from_numpy(
            (gt_scores_tmp.argsort().to('cpu').numpy()
             [-1:(-1 - min(len(gt_scores_tmp), 5)):-1]).copy()).to(dev)

        gt_boxes = torch.cat((gt_boxes, gt_boxes_tmp[keep_idxs_new, :]))
        gt_scores = torch.cat((gt_scores, gt_scores_tmp[keep_idxs_new].reshape(-1, 1)))
        gt_classes = torch.cat((gt_classes, (i + 1) * torch.ones(
            (len(keep_idxs_new), 1), dtype=torch.long, device=dev)))

        # If a proposal is chosen as a cluster center,
        # we simply delete it from the candidate proposal pool,
        # because we found that the results of different strategies are similar
        # and this strategy is more efficient
        another_tmp = idxs.to('cpu')[torch.tensor(keep_idxs)][keep_idxs_new.to('cpu')].numpy()
        cls_prob = torch.from_numpy(
            np.delete(cls_prob.to('cpu').numpy(), another_tmp, axis=0)).to(dev)
        boxes = torch.from_numpy(
            np.delete(boxes.to('cpu').numpy(), another_tmp, axis=0)).to(dev)

    proposals = {
        'gt_boxes': gt_boxes.to(dev),
        'gt_classes': gt_classes.to(dev),
        'gt_scores': gt_scores.to(dev)
    }

    return proposals
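# Isolated sketch of the greedy graph-center step above: build an IoU > 0.4
# adjacency over a few hypothetical boxes and peel off the highest-degree
# vertex together with its neighbours, exactly as the while-loop does.
import torch
from torchvision import ops

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [50., 50., 60., 60.]])
graph = (ops.box_iou(boxes, boxes) > 0.4).float()
order = graph.sum(dim=1).argsort(descending=True)
center = order[0]                          # densest proposal (0 or 1 here)
neighbors = (graph[center] > 0).nonzero()[:, 0]
print(center.item(), neighbors.tolist())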