def show_anchor(input_shape_hw, stride, anchor_generator_cfg, random_n, select_n):
    """Draw random subsets of the generated anchors, one image per level.

    Args:
        input_shape_hw: Canvas size; index 0 is used as height and index 1 as
            width when deriving feature-map sizes and filtering anchors.
        stride: Iterable of down-sampling strides, one per feature level. It
            is also forwarded to ``cv_core.show_img`` as its second argument.
        anchor_generator_cfg: Config handed to ``build_anchor_generator``.
        random_n: Number of display rounds; every round reshuffles anchors.
        select_n: Upper bound on the anchors drawn per level in each round.
    """
    canvas = np.zeros(input_shape_hw, np.uint8)
    featmap_sizes = [[input_shape_hw[0] // s, input_shape_hw[1] // s]
                     for s in stride]

    generator = build_anchor_generator(anchor_generator_cfg)
    # Per the original note: anchors are in input-image scale, xyxy format.
    multi_level_anchors = generator.grid_anchors(featmap_sizes)

    for _ in range(random_n):
        level_images = []
        for level_anchors in multi_level_anchors:
            boxes = level_anchors.cpu().numpy()
            # Keep only anchors lying strictly inside the canvas.
            inside = ((boxes[:, 0] > 0) & (boxes[:, 1] > 0)
                      & (boxes[:, 2] < input_shape_hw[1])
                      & (boxes[:, 3] < input_shape_hw[0]))
            shuffled = np.random.permutation(boxes[inside])
            level_images.append(
                cv_core.show_bbox(canvas, shuffled[:select_n], thickness=1,
                                  is_show=False))
        cv_core.show_img(level_images, stride)
def _show_save_data(featurevis, img, img_orig, feature_indexs, filepath, is_show, output_dir): show_datas = [] for feature_index in feature_indexs: feature_map = featurevis.run(img.copy(), feature_index=feature_index)[0] data = show_tensor(feature_map[0], resize_hw=img.shape[:2], show_split=False, is_show=False)[0] am_data = cv2.addWeighted(data, 0.5, img_orig, 0.5, 0) show_datas.append(am_data) if is_show: show_img(show_datas) if output_dir is not None: filename = os.path.join(output_dir, Path(filepath).name ) if len(show_datas) == 1: imwrite(show_datas[0], filename) else: for i in range(len(show_datas)): fname, suffix = os.path.splitext(filename) imwrite(show_datas[i], fname + '_{}'.format(str(i)) + suffix)
def _get_target_single(self, flat_anchors, valid_flags, num_level_anchors, gt_bboxes, gt_bboxes_ignore, gt_labels, img_meta, label_channels=1, unmap_outputs=True):
    """Build classification and regression targets for one image.

    Args:
        flat_anchors (Tensor): Multi-level anchors of the image concatenated
            into one tensor of shape (num_anchors, 4).
        valid_flags (Tensor): Multi-level valid flags concatenated into one
            tensor of shape (num_anchors,).
        num_level_anchors (Tensor): Number of anchors of each scale level.
        gt_bboxes (Tensor): Ground truth boxes, shape (num_gts, 4).
        gt_bboxes_ignore (Tensor): Ground truth boxes to be ignored, shape
            (num_ignored_gts, 4).
        gt_labels (Tensor): Ground truth labels, shape (num_gts,). ``None``
            makes every positive anchor receive label 0.
        img_meta (dict): Meta info of the image; ``img_meta['img_shape']``
            is read here.
        label_channels (int): Channel of label. Not referenced in this body.
        unmap_outputs (bool): Whether to map outputs back to the original
            set of anchors.

    Returns:
        tuple: Seven ``None`` values when no anchor passes the border check,
        otherwise (N is the number of anchors after the optional unmapping):
            anchors (Tensor): Anchors that were assigned.
            labels (Tensor): Labels of all anchors, shape (N,).
            label_weights (Tensor): Label weights, shape (N,).
            bbox_targets (Tensor): Box targets, shape (N, 4).
            bbox_weights (Tensor): Box weights, shape (N, 4).
            pos_inds (Tensor): Indices of positive anchors, shape (num_pos,).
            neg_inds (Tensor): Indices of negative anchors, shape (num_neg,).
    """
    # The border check works on img_shape; it can reduce the anchor count.
    inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                       img_meta['img_shape'][:2],
                                       self.train_cfg.allowed_border)
    if not inside_flags.any():
        return (None, ) * 7

    # From here on only the anchors that passed the border check are used, so
    # per-level bookkeeping must use the "inside" counts, not the raw ones.
    anchors = flat_anchors[inside_flags, :]
    num_level_anchors_inside = self.get_num_level_anchors_inside(
        num_level_anchors, inside_flags)

    assign_result = self.assigner.assign(anchors, num_level_anchors_inside,
                                         gt_bboxes, gt_bboxes_ignore,
                                         gt_labels)
    sampling_result = self.sampler.sample(assign_result, anchors, gt_bboxes)

    # Positive-sample visualisation, one image per level.
    if self.debug:
        all_gt_inds = assign_result.gt_inds
        print('单张图片中正样本anchor个数', len(all_gt_inds[all_gt_inds > 0]))
        level_images = []
        offset = 0
        for level_size in num_level_anchors_inside:
            level_gt_inds = assign_result.gt_inds[offset:offset + level_size]
            level_anchors = anchors[offset:offset + level_size]
            offset += level_size
            positives = level_anchors[level_gt_inds > 0]
            level_images.append(
                show_pos_anchor(img_meta, positives, gt_bboxes,
                                is_show=False))
        cv_core.show_img(level_images)

    num_valid_anchors = anchors.shape[0]
    bbox_targets = torch.zeros_like(anchors)
    bbox_weights = torch.zeros_like(anchors)
    # Every anchor starts with the label value self.num_classes.
    labels = anchors.new_full((num_valid_anchors, ),
                              self.num_classes,
                              dtype=torch.long)
    label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)

    pos_inds = sampling_result.pos_inds
    neg_inds = sampling_result.neg_inds

    if len(pos_inds) > 0:
        bbox_targets[pos_inds, :] = self.bbox_coder.encode(
            sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
        bbox_weights[pos_inds, :] = 1.0

        if gt_labels is None:
            # Only rpn gives gt_labels as None
            # Foreground is the first class since v2.5.0
            labels[pos_inds] = 0
        else:
            labels[pos_inds] = gt_labels[
                sampling_result.pos_assigned_gt_inds]

        pos_weight = self.train_cfg.pos_weight
        label_weights[pos_inds] = 1.0 if pos_weight <= 0 else pos_weight

    if len(neg_inds) > 0:
        label_weights[neg_inds] = 1.0

    # Map the per-valid-anchor results back onto the full anchor set.
    if unmap_outputs:
        num_total_anchors = flat_anchors.size(0)
        anchors = unmap(anchors, num_total_anchors, inside_flags)
        labels = unmap(labels, num_total_anchors, inside_flags,
                       fill=self.num_classes)
        label_weights = unmap(label_weights, num_total_anchors,
                              inside_flags)
        bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
        bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)

    return (anchors, labels, label_weights, bbox_targets, bbox_weights,
            pos_inds, neg_inds)
def _get_targets_single(self, anchors, responsible_flags, gt_bboxes, gt_labels, img_meta, num_level_anchors=None):
    """Generate matching bounding box prior and converted GT.

    Args:
        anchors (list[Tensor]): Multi-level anchors of the image.
        responsible_flags (list[Tensor]): Multi-level responsible flags of
            anchors.
        gt_bboxes (Tensor): Ground truth bboxes of single image.
        gt_labels (Tensor): Ground truth labels of single image.
        img_meta (dict): Meta info of the image; only used by the debug
            visualisation.
        num_level_anchors (list[int] | None): Anchor count per level. Only
            used in debug mode to show positives level by level; ``None``
            shows all levels on one image.

    Returns:
        tuple:
            target_map (Tensor): Prediction target map of each scale level,
                shape (num_total_anchors, 5+num_classes)
            neg_map (Tensor): Negative map of each scale level,
                shape (num_total_anchors,)
    """
    # One stride value per anchor, level by level.
    anchor_strides = torch.cat([
        torch.tensor(self.featmap_strides[level],
                     device=gt_bboxes.device).repeat(len(level_anchors))
        for level, level_anchors in enumerate(anchors)
    ])
    # Merge the anchors and flags of all output levels.
    concat_anchors = torch.cat(anchors)
    concat_responsible_flags = torch.cat(responsible_flags)
    assert len(anchor_strides) == len(concat_anchors) == \
        len(concat_responsible_flags)

    assign_result = self.assigner.assign(concat_anchors,
                                         concat_responsible_flags, gt_bboxes)

    if self.debug:
        # Count and visualise the positive anchors after assignment.
        print('----anchor分配正负样本后,正样本anchor可视化,白色bbox是gt----')
        all_gt_inds = assign_result.gt_inds
        positive_mask = all_gt_inds > 0
        positive_anchors = concat_anchors[positive_mask]
        print('单张图片中正样本anchor个数', len(all_gt_inds[positive_mask]))
        if num_level_anchors is None:
            # No per-level split: draw everything on one image.
            show_pos_anchor(img_meta, positive_anchors, gt_bboxes)
        else:
            level_images = []
            offset = 0
            for level_size in num_level_anchors:
                level_gt_inds = assign_result.gt_inds[offset:offset + level_size]
                level_anchors = concat_anchors[offset:offset + level_size]
                offset += level_size
                level_images.append(
                    show_pos_anchor(img_meta,
                                    level_anchors[level_gt_inds > 0],
                                    gt_bboxes, is_show=False))
            print('从小特征图到大特征图顺序显示')
            cv_core.show_img(level_images)

    # Original note: the sampler is effectively a pass-through here.
    sampling_result = self.sampler.sample(assign_result, concat_anchors,
                                          gt_bboxes)
    pos_inds = sampling_result.pos_inds

    # Layout per anchor: 4 box values, 1 confidence, then class scores.
    target_map = concat_anchors.new_zeros(concat_anchors.size(0),
                                          self.num_attrib)
    # Encoded from positive anchors, their matched gt boxes and strides.
    target_map[pos_inds, :4] = self.bbox_coder.encode(
        sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes,
        anchor_strides[pos_inds])
    target_map[pos_inds, 4] = 1  # confidence label

    gt_labels_one_hot = F.one_hot(gt_labels,
                                  num_classes=self.num_classes).float()
    if self.one_hot_smoother != 0:  # label smooth
        keep = 1 - self.one_hot_smoother
        gt_labels_one_hot = gt_labels_one_hot * keep + \
            self.one_hot_smoother / self.num_classes
    # Class one-hot label of the gt each positive anchor was matched to.
    target_map[pos_inds, 5:] = gt_labels_one_hot[
        sampling_result.pos_assigned_gt_inds]

    neg_map = concat_anchors.new_zeros(concat_anchors.size(0),
                                       dtype=torch.uint8)
    neg_map[sampling_result.neg_inds] = 1

    return target_map, neg_map
def _get_targets_single(self, flat_anchors, valid_flags, gt_bboxes, gt_bboxes_ignore, gt_labels, img_meta, label_channels=1, unmap_outputs=True, num_level_anchors=None):
    """Compute regression and classification targets for anchors in a
    single image.

    Args:
        flat_anchors (Tensor): Multi-level anchors of the image, which are
            concatenated into a single tensor of shape (num_anchors, 4).
        valid_flags (Tensor): Multi-level valid flags of the image, which
            are concatenated into a single tensor of shape (num_anchors,).
        gt_bboxes (Tensor): Ground truth bboxes of the image,
            shape (num_gts, 4).
        gt_bboxes_ignore (Tensor): Ground truth bboxes to be ignored,
            shape (num_ignored_gts, 4).
        gt_labels (Tensor): Ground truth labels of each box,
            shape (num_gts,). ``None`` makes every positive anchor receive
            label 1.
        img_meta (dict): Meta info of the image; ``img_meta['img_shape']``
            is read here.
        label_channels (int): Channel of label. Not referenced in this body.
        unmap_outputs (bool): Whether to map outputs back to the original
            set of anchors.
        num_level_anchors (list[int] | None): Number of anchors of each
            level in ``flat_anchors``. Only used by the debug visualisation
            to show positives level by level; ``None`` draws all levels on
            one image.

    Returns:
        tuple: Seven ``None`` values when no anchor passes the border check,
        otherwise:
            labels (Tensor): Labels of the anchors.
            label_weights (Tensor): Label weights of the anchors.
            bbox_targets (Tensor): Box targets of the anchors.
            bbox_weights (Tensor): Box weights of the anchors.
            pos_inds (Tensor): Indices of positive anchors (relative to the
                anchors that passed the border check).
            neg_inds (Tensor): Indices of negative anchors (same indexing).
            sampling_result: The object returned by ``self.sampler.sample``.
    """
    # Original note: with the default allowed_border of -1 this check is
    # equivalent to looking at valid_flags alone.
    inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                       img_meta['img_shape'][:2],
                                       self.train_cfg.allowed_border)
    if not inside_flags.any():
        return (None,) * 7

    # Only anchors that passed the border check take part in assignment.
    anchors = flat_anchors[inside_flags, :]

    # Core step: assign anchors to gt boxes as positive/negative samples.
    assign_result = self.assigner.assign(
        anchors, gt_bboxes, gt_bboxes_ignore,
        None if self.sampling else gt_labels)

    if self.debug:
        # Count and visualise the positive anchors after assignment.
        print('----anchor分配正负样本后,正样本anchor可视化,白色bbox是gt----')
        gt_inds = assign_result.gt_inds
        index = gt_inds > 0
        gt_anchors = anchors[index]
        print('单张图片中正样本anchor个数', len(gt_inds[index]))
        if num_level_anchors is None:
            # No per-level split: draw everything on one image.
            show_pos_anchor(img_meta, gt_anchors, gt_bboxes)
        else:
            imgs = []
            flat_start = 0  # offset into the unfiltered flat_anchors
            kept_start = 0  # offset into the border-filtered anchors
            for num_level in num_level_anchors:
                # `anchors` and `gt_inds` only contain anchors that passed
                # the border check, so each level must be sliced by the
                # number of anchors it kept, not by its original size.
                num_kept = int(
                    inside_flags[flat_start:flat_start + num_level].sum())
                flat_start += num_level
                level_gt_inds = assign_result.gt_inds[
                    kept_start:kept_start + num_kept]
                level_anchors = anchors[kept_start:kept_start + num_kept]
                kept_start += num_kept
                gt_anchor = level_anchors[level_gt_inds > 0]
                img = show_pos_anchor(img_meta, gt_anchor, gt_bboxes,
                                      is_show=False)
                imgs.append(img)
            print('从大特征图到小特征图顺序显示')
            cv_core.show_img(imgs)

    # Original note: the default sampler is pseudo, i.e. a pass-through.
    sampling_result = self.sampler.sample(assign_result, anchors,
                                          gt_bboxes)

    # Targets are first built for the valid anchors only.
    num_valid_anchors = anchors.shape[0]
    bbox_targets = torch.zeros_like(anchors)
    bbox_weights = torch.zeros_like(anchors)
    labels = anchors.new_full((num_valid_anchors,),
                              self.background_label,
                              dtype=torch.long)
    # Positive and negative anchors get a non-zero weight below; every other
    # position stays at 0.
    label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)

    pos_inds = sampling_result.pos_inds
    neg_inds = sampling_result.neg_inds
    if len(pos_inds) > 0:
        # Original note: L1-style losses need encoded targets, IoU-style
        # losses work on decoded boxes directly.
        if not self.reg_decoded_bbox:
            pos_bbox_targets = self.bbox_coder.encode(
                sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
        else:
            pos_bbox_targets = sampling_result.pos_gt_bboxes
        bbox_targets[pos_inds, :] = pos_bbox_targets
        bbox_weights[pos_inds, :] = 1.0
        if gt_labels is None:
            # only rpn gives gt_labels as None, this time FG is 1
            labels[pos_inds] = 1
        else:
            labels[pos_inds] = gt_labels[
                sampling_result.pos_assigned_gt_inds]
        if self.train_cfg.pos_weight <= 0:
            label_weights[pos_inds] = 1.0
        else:
            label_weights[pos_inds] = self.train_cfg.pos_weight
    if len(neg_inds) > 0:
        label_weights[neg_inds] = 1.0

    # Map the results back onto the original anchor set; filtered-out anchors
    # get the background label and `unmap`'s default fill for the weights.
    if unmap_outputs:
        num_total_anchors = flat_anchors.size(0)
        labels = unmap(
            labels, num_total_anchors, inside_flags,
            fill=self.background_label)  # fill bg label
        label_weights = unmap(label_weights, num_total_anchors,
                              inside_flags)
        bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
        bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)

    return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
            neg_inds, sampling_result)
def loss(self, cls_scores, bbox_preds, iou_preds, gt_bboxes, gt_labels, img_metas, gt_bboxes_ignore=None):
    """Compute losses of the head.

    Args:
        cls_scores (list[Tensor]): Box scores for each scale level,
            shape (N, num_anchors * num_classes, H, W).
        bbox_preds (list[Tensor]): Box energies / deltas for each scale
            level, shape (N, num_anchors * 4, H, W).
        iou_preds (list[Tensor]): IoU predictions for each scale level,
            shape (N, num_anchors * 1, H, W).
        gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
            shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
        gt_labels (list[Tensor]): Class indices corresponding to each box.
        img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.
        gt_bboxes_ignore (list[Tensor] | None): Boxes that can be ignored
            when computing the loss.

    Returns:
        dict[str, Tensor]: A dictionary of loss components with the keys
        ``loss_cls``, ``loss_bbox`` and ``loss_iou``.
    """
    featmap_sizes = [level_score.size()[-2:] for level_score in cls_scores]
    assert len(featmap_sizes) == self.anchor_generator.num_levels

    anchor_list, valid_flag_list = self.get_anchors(
        featmap_sizes, img_metas, device=cls_scores[0].device)

    label_channels = 1
    if self.use_sigmoid_cls:
        label_channels = self.cls_out_channels

    # First round of sample definition (original note: the threshold is low,
    # so there are many positives).
    (labels, labels_weight, bboxes_target, bboxes_weight, pos_inds,
     pos_gt_index) = self.get_targets(
         anchor_list,
         valid_flag_list,
         gt_bboxes,
         img_metas,
         gt_bboxes_ignore_list=gt_bboxes_ignore,
         gt_labels_list=gt_labels,
         label_channels=label_channels,
     )

    def _per_image(level_maps, channels):
        # Regroup level-wise predictions per image and flatten to 2-D.
        return [item.reshape(-1, channels)
                for item in levels_to_images(level_maps)]

    cls_scores = _per_image(cls_scores, self.cls_out_channels)
    bbox_preds = _per_image(bbox_preds, 4)
    iou_preds = _per_image(iou_preds, 1)

    # Second round of sample definition, driven by the per-positive loss.
    pos_losses_list, = multi_apply(self.get_pos_loss, anchor_list,
                                   cls_scores, bbox_preds, labels,
                                   labels_weight, bboxes_target,
                                   bboxes_weight, pos_inds)

    with torch.no_grad():
        labels, label_weights, bbox_weights, num_pos = multi_apply(
            self.paa_reassign,
            pos_losses_list,
            labels,
            labels_weight,
            bboxes_weight,
            pos_inds,
            pos_gt_index,
            anchor_list,
        )
        num_pos = sum(num_pos)

    # Positive-sample visualisation.
    if self.debug:
        for img_idx, image_anchors in enumerate(anchor_list):  # per image
            img_meta = img_metas[img_idx]
            image_labels = labels[img_idx]
            gt_bbox = gt_bboxes[img_idx]
            level_images = []
            offset = 0
            for level_anchors in image_anchors:  # per output level
                level_labels = image_labels[offset:offset + len(level_anchors)]
                offset += len(level_anchors)
                is_positive = (level_labels >= 0) & \
                    (level_labels < self.num_classes)
                level_images.append(
                    atss_head.show_pos_anchor(img_meta,
                                              level_anchors[is_positive],
                                              gt_bbox, is_show=False))
            cv_core.show_img(level_images)

    # convert all tensor list to a flatten tensor
    cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1))
    bbox_preds = torch.cat(bbox_preds, 0).view(-1, bbox_preds[0].size(-1))
    iou_preds = torch.cat(iou_preds, 0).view(-1, iou_preds[0].size(-1))
    labels = torch.cat(labels, 0).view(-1)
    flatten_anchors = torch.cat(
        [torch.cat(item, 0) for item in anchor_list])
    labels_weight = torch.cat(labels_weight, 0).view(-1)
    bboxes_target = torch.cat(bboxes_target,
                              0).view(-1, bboxes_target[0].size(-1))

    is_foreground = (labels >= 0) & (labels < self.num_classes)
    pos_inds_flatten = is_foreground.nonzero().reshape(-1)

    losses_cls = self.loss_cls(cls_scores,
                               labels,
                               labels_weight,
                               avg_factor=max(num_pos, len(img_metas)))

    if not num_pos:
        # No positives: zero-valued losses still derived from the predictions.
        losses_iou = iou_preds.sum() * 0
        losses_bbox = bbox_preds.sum() * 0
    else:
        pos_bbox_pred = self.bbox_coder.decode(
            flatten_anchors[pos_inds_flatten],
            bbox_preds[pos_inds_flatten])
        pos_bbox_target = bboxes_target[pos_inds_flatten]
        iou_target = bbox_overlaps(pos_bbox_pred.detach(),
                                   pos_bbox_target,
                                   is_aligned=True)
        # IoU prediction branch.
        losses_iou = self.loss_centerness(iou_preds[pos_inds_flatten],
                                          iou_target.unsqueeze(-1),
                                          avg_factor=num_pos)
        losses_bbox = self.loss_bbox(pos_bbox_pred,
                                     pos_bbox_target,
                                     iou_target.clamp(min=eps),
                                     avg_factor=iou_target.sum())

    return dict(loss_cls=losses_cls,
                loss_bbox=losses_bbox,
                loss_iou=losses_iou)
def loss(self, cls_scores, bbox_preds, centernesses, gt_bboxes, gt_labels, img_metas, gt_bboxes_ignore=None):
    """Compute loss of the head.

    Args:
        cls_scores (list[Tensor]): Box scores for each scale level,
            each is a 4D-tensor, the channel number is
            num_points * num_classes, e.g. [batchsize, 80, H_i, W_i].
        bbox_preds (list[Tensor]): Box energies / deltas for each scale
            level, each is a 4D-tensor, the channel number is
            num_points * 4, e.g. [batchsize, 4, H_i, W_i].
        centernesses (list[Tensor]): Centerness for each scale level, each
            is a 4D-tensor, the channel number is num_points * 1,
            e.g. [batchsize, 1, H_i, W_i].
        gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
            shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
        gt_labels (list[Tensor]): Class indices corresponding to each box.
        img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc. The debug visualisation
            additionally reads the keys ``'img'`` and ``'img_norm_cfg'``.
        gt_bboxes_ignore (None | list[Tensor]): Boxes that can be ignored
            when computing the loss. Not referenced in this body.

    Returns:
        dict[str, Tensor]: A dictionary of loss components with the keys
        ``loss_cls``, ``loss_bbox`` and ``loss_centerness``.
    """
    assert len(cls_scores) == len(bbox_preds) == len(centernesses)
    # Spatial size of every feature level, e.g. for five levels:
    #   [torch.Size([100, 152]), torch.Size([50, 76]), torch.Size([25, 38]),
    #    torch.Size([13, 19]), torch.Size([7, 10])]
    # Original note: each feature-map pixel corresponds to the centre of one
    # grid cell of the input image, so every level defines its own grid.
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                       bbox_preds[0].device)
    labels, bbox_targets = self.get_targets(all_level_points, gt_bboxes,
                                            gt_labels)

    if self.debug:
        is_upsample = False  # use the up-sampled-mask visualisation mode
        circle_ratio = [2, 4, 6, 8, 8]  # circle thickness per level
        # Visualise the positive-sample locations, image by image.
        batch_size = cls_scores[0].shape[0]
        for i in range(batch_size):
            gt_bbox = gt_bboxes[i]
            img_meta = img_metas[i]
            img = img_meta['img'].data.numpy()
            mean = img_meta['img_norm_cfg']['mean']
            std = img_meta['img_norm_cfg']['std']
            # Undo the normalisation (original note: the input is BGR data).
            img = np.transpose(img.copy(), (1, 2, 0))
            img = img * std.reshape([1, 1, 3]) + mean.reshape([1, 1, 3])
            img = img.astype(np.uint8)
            img = cv_core.show_bbox(img, gt_bbox.cpu().numpy(),
                                    is_show=False, thickness=2,
                                    color=(255, 255, 255))
            # One output image per feature level.
            disp_img = []
            for j, label in enumerate(labels):
                bbox_target = bbox_targets[j]
                # The per-level tensors hold all images back to back; take
                # the slice belonging to image i.
                bbox_target = bbox_target[
                    i * len(bbox_target) // batch_size:
                    (i + 1) * len(bbox_target) // batch_size
                ] * self.strides[j]
                level_label = label[i * len(label) // batch_size:
                                    (i + 1) * len(label) // batch_size]
                print(bbox_target[(level_label >= 0)
                                  & (level_label < self.num_classes)])
                level_label = level_label.view(featmap_sizes[j][0],
                                               featmap_sizes[j][1])
                # Locations of the positive samples on this level.
                pos_mask = (level_label >= 0) & \
                    (level_label < self.num_classes)
                if pos_mask.is_cuda:
                    pos_mask = pos_mask.cpu()
                pos_mask = pos_mask.data.numpy()
                # Bring the feature-map mask back to image scale.
                if is_upsample:
                    # Rescale by the stride first ...
                    pos_mask = (pos_mask * 255).astype(np.uint8)
                    pos_mask = cv_core.imrescale(pos_mask, self.strides[j],
                                                 interpolation="nearest")
                    # ... then crop to the image size.
                    resize_pos_mask = pos_mask[0:img.shape[0],
                                               0:img.shape[1]]
                    # Convert to BGR so it can be blended with the image.
                    resize_pos_mask = cv_core.gray2bgr(resize_pos_mask)
                    img_ = cv2.addWeighted(img, 0.5, resize_pos_mask, 0.5,
                                           0)
                else:
                    pos_mask = pos_mask.astype(np.uint8)
                    index = pos_mask.nonzero()
                    index_yx = np.stack(index, axis=1)
                    # Map feature-map coordinates to image coordinates.
                    pos_img_yx = index_yx * self.strides[j] + \
                        self.strides[j] // 2
                    img_ = img.copy()
                    # Circle mode: one circle per positive location.
                    for z in range(pos_img_yx.shape[0]):
                        point = (int(pos_img_yx[z, 1]),
                                 int(pos_img_yx[z, 0]))
                        cv2.circle(img_, point, 1, (0, 0, 255),
                                   circle_ratio[j])
                disp_img.append(img_)
            cv_core.show_img(disp_img)

    num_imgs = cls_scores[0].size(0)
    # flatten cls_scores, bbox_preds and centerness
    flatten_cls_scores = [
        cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
        for cls_score in cls_scores
    ]
    flatten_bbox_preds = [
        bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        for bbox_pred in bbox_preds
    ]
    flatten_centerness = [
        centerness.permute(0, 2, 3, 1).reshape(-1)
        for centerness in centernesses
    ]
    flatten_cls_scores = torch.cat(flatten_cls_scores)
    flatten_bbox_preds = torch.cat(flatten_bbox_preds)
    flatten_centerness = torch.cat(flatten_centerness)
    flatten_labels = torch.cat(labels)
    flatten_bbox_targets = torch.cat(bbox_targets)
    # repeat points to align with bbox_preds
    flatten_points = torch.cat(
        [points.repeat(num_imgs, 1) for points in all_level_points])

    # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
    bg_class_ind = self.num_classes
    pos_inds = ((flatten_labels >= 0)
                & (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
    num_pos = len(pos_inds)
    # Adding num_imgs keeps avg_factor non-zero when there are no positives.
    loss_cls = self.loss_cls(flatten_cls_scores, flatten_labels,
                             avg_factor=num_pos + num_imgs)

    pos_bbox_preds = flatten_bbox_preds[pos_inds]
    pos_centerness = flatten_centerness[pos_inds]

    if num_pos > 0:
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_centerness_targets = self.centerness_target(pos_bbox_targets)
        pos_points = flatten_points[pos_inds]
        pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
        pos_decoded_target_preds = distance2bbox(pos_points,
                                                 pos_bbox_targets)
        # centerness weighted iou loss
        loss_bbox = self.loss_bbox(pos_decoded_bbox_preds,
                                   pos_decoded_target_preds,
                                   weight=pos_centerness_targets,
                                   avg_factor=pos_centerness_targets.sum())
        loss_centerness = self.loss_centerness(pos_centerness,
                                               pos_centerness_targets)
    else:
        # No positives: sums over empty selections.
        loss_bbox = pos_bbox_preds.sum()
        loss_centerness = pos_centerness.sum()

    return dict(loss_cls=loss_cls,
                loss_bbox=loss_bbox,
                loss_centerness=loss_centerness)