def loss(self,
         cls_scores,
         bbox_preds,
         centernesses,
         gt_bboxes,
         gt_labels,
         img_metas,
         cfg,
         gt_bboxes_ignore=None):
    """Compute classification, bbox and centerness losses.

    Args:
        cls_scores: Sequence of per-level score tensors. Only the last
            two dimensions (read as the feature-map size) and the device
            of the first entry are inspected here.
        bbox_preds: Per-level box predictions, forwarded to
            ``self.loss_single``.
        centernesses: Per-level centerness predictions, forwarded to
            ``self.loss_single``.
        gt_bboxes: Ground-truth boxes, forwarded to ``self.atss_target``.
        gt_labels: Ground-truth labels, forwarded to ``self.atss_target``.
        img_metas: Per-image meta information.
        cfg: Config object forwarded to ``self.atss_target`` and
            ``self.loss_single``.
        gt_bboxes_ignore: Optional ground-truth boxes to ignore.

    Returns:
        ``None`` when ``self.atss_target`` returns ``None``; otherwise a
        dict with the keys ``loss_cls``, ``loss_bbox`` and
        ``loss_centerness``, each holding the per-level values returned
        by ``self.loss_single``.
    """
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    assert len(featmap_sizes) == len(self.anchor_generators)

    device = cls_scores[0].device
    anchor_list, valid_flag_list = self.get_anchors(
        featmap_sizes, img_metas, device=device)
    label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1

    cls_reg_targets = self.atss_target(
        anchor_list,
        valid_flag_list,
        gt_bboxes,
        img_metas,
        cfg,
        gt_bboxes_ignore_list=gt_bboxes_ignore,
        gt_labels_list=gt_labels,
        label_channels=label_channels)
    if cls_reg_targets is None:
        return None

    (anchor_list, labels_list, label_weights_list, bbox_targets_list,
     bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets

    # Build the counter on the same device as the predictions instead of
    # forcing CUDA, and as a float so that an averaging reduction does not
    # operate on an integer tensor.
    # NOTE(review): reduce_mean is defined outside this block; presumably it
    # averages the value across distributed workers -- confirm.
    num_total_samples = reduce_mean(
        torch.tensor(num_total_pos, dtype=torch.float,
                     device=device)).item()
    num_total_samples = max(num_total_samples, 1.0)

    losses_cls, losses_bbox, loss_centerness, bbox_avg_factor = multi_apply(
        self.loss_single,
        anchor_list,
        cls_scores,
        bbox_preds,
        centernesses,
        labels_list,
        label_weights_list,
        bbox_targets_list,
        num_total_samples=num_total_samples,
        cfg=cfg)

    bbox_avg_factor = sum(bbox_avg_factor)
    bbox_avg_factor = reduce_mean(bbox_avg_factor).item()
    # Guard the normalisation: a zero factor would turn every bbox loss
    # into NaN/inf through the division below.
    eps = 1e-12
    if bbox_avg_factor < eps:
        bbox_avg_factor = 1.0
    losses_bbox = [loss_bbox / bbox_avg_factor for loss_bbox in losses_bbox]

    return dict(
        loss_cls=losses_cls,
        loss_bbox=losses_bbox,
        loss_centerness=loss_centerness)
def loss(self,
         cls_scores,
         bbox_preds,
         gt_bboxes,
         gt_labels,
         img_metas,
         cfg,
         gt_bboxes_ignore=None):
    """Compute the per-level classification and bbox regression losses.

    Anchors are generated for every feature map, targets are built with
    ``anchor_target`` and ``self.loss_single`` is applied per level through
    ``multi_apply``.

    Returns:
        ``None`` if ``anchor_target`` returns ``None``; otherwise a dict
        with the keys ``loss_cls`` and ``loss_bbox``.
    """
    featmap_sizes = [score_map.size()[-2:] for score_map in cls_scores]
    assert len(featmap_sizes) == len(self.anchor_generators)

    device = cls_scores[0].device
    anchor_list, valid_flag_list = self.get_anchors(
        featmap_sizes, img_metas, device=device)

    # One channel per class for sigmoid classification, a single channel
    # otherwise.
    if self.use_sigmoid_cls:
        label_channels = self.cls_out_channels
    else:
        label_channels = 1

    targets = anchor_target(
        anchor_list,
        valid_flag_list,
        gt_bboxes,
        img_metas,
        self.target_means,
        self.target_stds,
        cfg,
        gt_bboxes_ignore_list=gt_bboxes_ignore,
        gt_labels_list=gt_labels,
        label_channels=label_channels,
        sampling=self.sampling)
    if targets is None:
        return None

    (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
     num_total_pos, num_total_neg) = targets

    # With sampling the normaliser counts positives and negatives,
    # without it only the positives.
    if self.sampling:
        num_total_samples = num_total_pos + num_total_neg
    else:
        num_total_samples = num_total_pos

    losses_cls, losses_bbox = multi_apply(
        self.loss_single,
        cls_scores,
        bbox_preds,
        labels_list,
        label_weights_list,
        bbox_targets_list,
        bbox_weights_list,
        num_total_samples=num_total_samples,
        cfg=cfg)
    return {'loss_cls': losses_cls, 'loss_bbox': losses_bbox}
def fovea_target(self, gt_bbox_list, gt_label_list, featmap_sizes, points):
    """Build flattened label and bbox targets.

    ``self.fovea_target_single`` is applied to each (gt boxes, gt labels)
    pair through ``multi_apply``. The per-image results are regrouped with
    ``zip(*...)``, flattened, and concatenated into one label tensor and
    one ``(-1, 4)`` bbox-target tensor.

    Returns:
        tuple: ``(flatten_labels, flatten_bbox_targets)``.
    """
    label_list, bbox_target_list = multi_apply(
        self.fovea_target_single,
        gt_bbox_list,
        gt_label_list,
        featmap_size_list=featmap_sizes,
        point_list=points)

    # Regroup the labels (one tuple per zipped position), flatten every
    # entry and join each group.
    labels_per_group = []
    for label_group in zip(*label_list):
        flat_pieces = [piece.flatten() for piece in label_group]
        labels_per_group.append(torch.cat(flat_pieces))

    # Same regrouping for the bbox targets, reshaped to rows of 4 values.
    bbox_targets_per_group = []
    for target_group in zip(*bbox_target_list):
        row_pieces = [piece.reshape(-1, 4) for piece in target_group]
        bbox_targets_per_group.append(torch.cat(row_pieces))

    flatten_labels = torch.cat(labels_per_group)
    flatten_bbox_targets = torch.cat(bbox_targets_per_group)
    return flatten_labels, flatten_bbox_targets
def fcos_target(self, points, gt_bboxes_list, gt_labels_list):
    """Build per-level labels and bbox targets from multi-level points.

    All levels are concatenated, ``self.fcos_target_single`` is run once
    per image through ``multi_apply``, and the results are split back per
    level and concatenated over images.

    Returns:
        tuple: ``(concat_lvl_labels, concat_lvl_bbox_targets)``, two lists
        with one tensor per level.
    """
    assert len(points) == len(self.regress_ranges)
    num_levels = len(points)

    # Broadcast every level's regress range to the shape of that level's
    # points, then merge all levels along dim 0.
    expanded_regress_ranges = []
    for lvl in range(num_levels):
        lvl_points = points[lvl]
        lvl_range = lvl_points.new_tensor(self.regress_ranges[lvl])
        expanded_regress_ranges.append(lvl_range[None].expand_as(lvl_points))
    concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
    concat_points = torch.cat(points, dim=0)

    # Per-image targets computed over the concatenated points.
    labels_list, bbox_targets_list = multi_apply(
        self.fcos_target_single,
        gt_bboxes_list,
        gt_labels_list,
        points=concat_points,
        regress_ranges=concat_regress_ranges)

    # Cut every image's result back into per-level chunks.
    num_points = [lvl_points.size(0) for lvl_points in points]
    labels_list = [
        img_labels.split(num_points, 0) for img_labels in labels_list
    ]
    bbox_targets_list = [
        img_targets.split(num_points, 0) for img_targets in bbox_targets_list
    ]

    # For each level, join the chunks of all images.
    concat_lvl_labels = [
        torch.cat([img_labels[lvl] for img_labels in labels_list])
        for lvl in range(num_levels)
    ]
    concat_lvl_bbox_targets = [
        torch.cat([img_targets[lvl] for img_targets in bbox_targets_list])
        for lvl in range(num_levels)
    ]
    return concat_lvl_labels, concat_lvl_bbox_targets
def forward(self, feats):
    """Apply ``self.forward_single`` to ``feats`` through ``multi_apply``.

    Args:
        feats: Input passed as the only positional iterable to
            ``multi_apply``.

    Returns:
        Whatever ``multi_apply`` returns for ``self.forward_single``.
    """
    single_level_forward = self.forward_single
    outputs = multi_apply(single_level_forward, feats)
    return outputs
def loss(self,
         cls_scores,
         pts_preds_init,
         pts_preds_refine,
         gt_bboxes,
         gt_labels,
         img_metas,
         cfg,
         gt_bboxes_ignore=None):
    """Compute classification, initial-point and refined-point losses.

    Targets are built twice with ``point_target``: once for the initial
    stage (``cfg.init``) and once for the refinement stage
    (``cfg.refine``), where the candidates are boxes derived from the
    detached initial predictions.

    Args:
        cls_scores: Per-level score tensors; the last two dimensions of
            each are read as the feature-map size.
        pts_preds_init: Per-level initial point predictions.
        pts_preds_refine: Per-level refined point predictions.
        gt_bboxes: Ground-truth boxes, forwarded to ``point_target``.
        gt_labels: Ground-truth labels, forwarded to ``point_target``.
        img_metas: Per-image meta information.
        cfg: Config exposing ``init`` and ``refine`` sub-configs.
        gt_bboxes_ignore: Optional ground-truth boxes to ignore.

    Returns:
        dict: ``loss_cls``, ``loss_pts_init`` and ``loss_pts_refine``, each
        holding the per-level values returned by ``self.loss_single``.
    """
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    assert len(featmap_sizes) == len(self.point_generators)
    label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1

    # ---- targets for the initial stage ----
    center_list, valid_flag_list = self.get_points(featmap_sizes, img_metas)
    pts_coordinate_preds_init = self.offset_to_pts(center_list,
                                                   pts_preds_init)
    if cfg.init.assigner['type'] == 'PointAssigner':
        # Targets are assigned directly on the point centers.
        candidate_list = center_list
    else:
        # Centers are converted to boxes and targets are assigned on those.
        candidate_list = self.centers_to_bboxes(center_list)
    cls_reg_targets_init = point_target(
        candidate_list,
        valid_flag_list,
        gt_bboxes,
        img_metas,
        cfg.init,
        gt_bboxes_ignore_list=gt_bboxes_ignore,
        gt_labels_list=gt_labels,
        label_channels=label_channels,
        sampling=self.sampling)
    (*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init,
     num_total_pos_init, num_total_neg_init) = cls_reg_targets_init
    num_total_samples_init = (
        num_total_pos_init +
        num_total_neg_init if self.sampling else num_total_pos_init)

    # ---- targets for the refinement stage ----
    # The centers are generated again rather than reused.
    # NOTE(review): presumably because point_target modifies the lists it
    # receives -- confirm against its implementation.
    center_list, valid_flag_list = self.get_points(featmap_sizes, img_metas)
    pts_coordinate_preds_refine = self.offset_to_pts(center_list,
                                                     pts_preds_refine)

    # The box shift of a level depends only on the level, not on the
    # image, so it is computed once per level instead of once per
    # (image, level) pair.
    num_levels = len(pts_preds_refine)
    bbox_shifts = [
        self.points2bbox(pts_preds_init[i_lvl].detach()) *
        self.point_strides[i_lvl] for i_lvl in range(num_levels)
    ]

    bbox_list = []
    for i_img, center in enumerate(center_list):
        bbox = []
        for i_lvl in range(num_levels):
            lvl_center = center[i_lvl][:, :2]
            # Repeat the (first two) center coordinates to form a 4-column
            # base that the 4-value shift is added to.
            bbox_center = torch.cat([lvl_center, lvl_center], dim=1)
            bbox.append(
                bbox_center +
                bbox_shifts[i_lvl][i_img].permute(1, 2, 0).reshape(-1, 4))
        bbox_list.append(bbox)

    cls_reg_targets_refine = point_target(
        bbox_list,
        valid_flag_list,
        gt_bboxes,
        img_metas,
        cfg.refine,
        gt_bboxes_ignore_list=gt_bboxes_ignore,
        gt_labels_list=gt_labels,
        label_channels=label_channels,
        sampling=self.sampling)
    (labels_list, label_weights_list, bbox_gt_list_refine,
     candidate_list_refine, bbox_weights_list_refine, num_total_pos_refine,
     num_total_neg_refine) = cls_reg_targets_refine
    num_total_samples_refine = (
        num_total_pos_refine +
        num_total_neg_refine if self.sampling else num_total_pos_refine)

    # ---- per-level losses ----
    losses_cls, losses_pts_init, losses_pts_refine = multi_apply(
        self.loss_single,
        cls_scores,
        pts_coordinate_preds_init,
        pts_coordinate_preds_refine,
        labels_list,
        label_weights_list,
        bbox_gt_list_init,
        bbox_weights_list_init,
        bbox_gt_list_refine,
        bbox_weights_list_refine,
        self.point_strides,
        num_total_samples_init=num_total_samples_init,
        num_total_samples_refine=num_total_samples_refine)
    loss_dict_all = {
        'loss_cls': losses_cls,
        'loss_pts_init': losses_pts_init,
        'loss_pts_refine': losses_pts_refine
    }
    return loss_dict_all
def atss_target(self,
                anchor_list,
                valid_flag_list,
                gt_bboxes_list,
                img_metas,
                cfg,
                gt_bboxes_ignore_list=None,
                gt_labels_list=None,
                label_channels=1,
                unmap_outputs=True):
    """Compute per-level targets and also return the anchors.

    Almost the same as ``anchor_target``; the difference is that the
    anchors are returned as well.

    Note:
        ``anchor_list`` and ``valid_flag_list`` are modified in place: each
        per-image list of level tensors is replaced by one concatenated
        tensor.

    Returns:
        ``None`` if any image produced ``None`` labels; otherwise the tuple
        ``(anchors_list, labels_list, label_weights_list,
        bbox_targets_list, bbox_weights_list, num_total_pos,
        num_total_neg)``.
    """
    num_imgs = len(img_metas)
    assert len(anchor_list) == len(valid_flag_list) == num_imgs

    # Anchor count of every level, read from the first image and shared
    # (same list object) by all images.
    num_level_anchors = [
        level_anchors.size(0) for level_anchors in anchor_list[0]
    ]
    num_level_anchors_list = [num_level_anchors] * num_imgs

    # Merge the levels of each image into a single tensor (in place).
    for img_idx in range(num_imgs):
        img_anchors = anchor_list[img_idx]
        img_flags = valid_flag_list[img_idx]
        assert len(img_anchors) == len(img_flags)
        anchor_list[img_idx] = torch.cat(img_anchors)
        valid_flag_list[img_idx] = torch.cat(img_flags)

    # Missing optional inputs become one ``None`` per image.
    if gt_bboxes_ignore_list is None:
        gt_bboxes_ignore_list = [None] * num_imgs
    if gt_labels_list is None:
        gt_labels_list = [None] * num_imgs

    per_image_results = multi_apply(
        self.atss_target_single,
        anchor_list,
        valid_flag_list,
        num_level_anchors_list,
        gt_bboxes_list,
        gt_bboxes_ignore_list,
        gt_labels_list,
        img_metas,
        cfg=cfg,
        label_channels=label_channels,
        unmap_outputs=unmap_outputs)
    (all_anchors, all_labels, all_label_weights, all_bbox_targets,
     all_bbox_weights, pos_inds_list, neg_inds_list) = per_image_results

    # An image without labels means there is nothing to train on.
    if any(img_labels is None for img_labels in all_labels):
        return None

    # Every image contributes at least 1 to each count.
    num_total_pos = sum(max(inds.numel(), 1) for inds in pos_inds_list)
    num_total_neg = sum(max(inds.numel(), 1) for inds in neg_inds_list)

    # Regroup the per-image results into per-level lists.
    anchors_list = images_to_levels(all_anchors, num_level_anchors)
    labels_list = images_to_levels(all_labels, num_level_anchors)
    label_weights_list = images_to_levels(all_label_weights,
                                          num_level_anchors)
    bbox_targets_list = images_to_levels(all_bbox_targets,
                                         num_level_anchors)
    bbox_weights_list = images_to_levels(all_bbox_weights,
                                         num_level_anchors)
    return (anchors_list, labels_list, label_weights_list,
            bbox_targets_list, bbox_weights_list, num_total_pos,
            num_total_neg)
def loss(self,
         cls_scores,
         bbox_preds,
         gt_bboxes,
         gt_labels,
         img_metas,
         cfg,
         gt_bboxes_ignore=None):
    """Compute classification and bbox losses.

    Predictions and targets of all levels are merged into tensors whose
    first dimension is the number of images, and ``self.loss_single`` is
    then applied along that dimension through ``multi_apply``.

    Returns:
        ``None`` if ``anchor_target`` returns ``None``; otherwise a dict
        with the keys ``loss_cls`` and ``loss_bbox``.

    Raises:
        AssertionError: If the merged scores or box predictions contain
            non-finite values.
    """
    featmap_sizes = [score_map.size()[-2:] for score_map in cls_scores]
    assert len(featmap_sizes) == len(self.anchor_generators)

    device = cls_scores[0].device
    anchor_list, valid_flag_list = self.get_anchors(
        featmap_sizes, img_metas, device=device)
    targets = anchor_target(
        anchor_list,
        valid_flag_list,
        gt_bboxes,
        img_metas,
        self.target_means,
        self.target_stds,
        cfg,
        gt_bboxes_ignore_list=gt_bboxes_ignore,
        gt_labels_list=gt_labels,
        label_channels=1,
        sampling=False,
        unmap_outputs=False)
    if targets is None:
        return None
    (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
     num_total_pos, num_total_neg) = targets

    num_images = len(img_metas)

    # Move dim 1 of every score map to the end, flatten the middle
    # dimensions and join the levels along dim 1.
    flat_scores = []
    for level_scores in cls_scores:
        moved = level_scores.permute(0, 2, 3, 1)
        flat_scores.append(
            moved.reshape(num_images, -1, self.cls_out_channels))
    all_cls_scores = torch.cat(flat_scores, 1)

    all_labels = torch.cat(labels_list, -1).view(num_images, -1)
    all_label_weights = torch.cat(label_weights_list,
                                  -1).view(num_images, -1)

    # Same layout change for the box predictions, with 4 values per row.
    flat_bbox_preds = []
    for level_preds in bbox_preds:
        moved = level_preds.permute(0, 2, 3, 1)
        flat_bbox_preds.append(moved.reshape(num_images, -1, 4))
    all_bbox_preds = torch.cat(flat_bbox_preds, -2)

    all_bbox_targets = torch.cat(bbox_targets_list,
                                 -2).view(num_images, -1, 4)
    all_bbox_weights = torch.cat(bbox_weights_list,
                                 -2).view(num_images, -1, 4)

    # Stop early if the predictions contain NaN or Inf.
    scores_are_finite = torch.isfinite(all_cls_scores).all().item()
    assert scores_are_finite, \
        'classification scores become infinite or NaN!'
    preds_are_finite = torch.isfinite(all_bbox_preds).all().item()
    assert preds_are_finite, \
        'bbox predications become infinite or NaN!'

    losses_cls, losses_bbox = multi_apply(
        self.loss_single,
        all_cls_scores,
        all_bbox_preds,
        all_labels,
        all_label_weights,
        all_bbox_targets,
        all_bbox_weights,
        num_total_samples=num_total_pos,
        cfg=cfg)
    return {'loss_cls': losses_cls, 'loss_bbox': losses_bbox}
def loss(self,
         cls_scores,
         bbox_preds,
         shape_preds,
         loc_preds,
         gt_bboxes,
         gt_labels,
         img_metas,
         cfg,
         gt_bboxes_ignore=None):
    """Compute classification, bbox, anchor-shape and anchor-location losses.

    Returns:
        ``None`` if either ``ga_shape_target`` or ``anchor_target`` returns
        ``None``; otherwise a dict with the keys ``loss_cls``,
        ``loss_bbox``, ``loss_shape`` and ``loss_loc``.
    """
    featmap_sizes = [score_map.size()[-2:] for score_map in cls_scores]
    assert len(featmap_sizes) == len(self.approx_generators)

    device = cls_scores[0].device

    # Location targets.
    loc_targets, loc_weights, loc_avg_factor = ga_loc_target(
        gt_bboxes,
        featmap_sizes,
        self.octave_base_scale,
        self.anchor_strides,
        center_ratio=cfg.center_ratio,
        ignore_ratio=cfg.ignore_ratio)

    # Sampled approxes.
    approxs_list, inside_flag_list = self.get_sampled_approxs(
        featmap_sizes, img_metas, cfg, device=device)

    # Squares and guided anchors (the third return value is not used).
    squares_list, guided_anchors_list, _ = self.get_anchors(
        featmap_sizes, shape_preds, loc_preds, img_metas, device=device)

    # Shape targets; sampling is enabled only when cfg has a 'ga_sampler'.
    shape_sampling = hasattr(cfg, 'ga_sampler')
    shape_targets = ga_shape_target(
        approxs_list,
        inside_flag_list,
        squares_list,
        gt_bboxes,
        img_metas,
        self.approxs_per_octave,
        cfg,
        sampling=shape_sampling)
    if shape_targets is None:
        return None
    (bbox_anchors_list, bbox_gts_list, anchor_weights_list, anchor_fg_num,
     anchor_bg_num) = shape_targets
    if shape_sampling:
        anchor_total_num = anchor_fg_num + anchor_bg_num
    else:
        anchor_total_num = anchor_fg_num

    # Anchor targets; sampling is disabled when focal loss is used.
    cls_sampling = not self.cls_focal_loss
    label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
    cls_reg_targets = anchor_target(
        guided_anchors_list,
        inside_flag_list,
        gt_bboxes,
        img_metas,
        self.target_means,
        self.target_stds,
        cfg,
        gt_bboxes_ignore_list=gt_bboxes_ignore,
        gt_labels_list=gt_labels,
        label_channels=label_channels,
        sampling=cls_sampling)
    if cls_reg_targets is None:
        return None
    (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
     num_total_pos, num_total_neg) = cls_reg_targets
    if self.cls_focal_loss:
        num_total_samples = num_total_pos
    else:
        num_total_samples = num_total_pos + num_total_neg

    # Classification and bbox regression losses, one entry per level.
    losses_cls, losses_bbox = multi_apply(
        self.loss_single,
        cls_scores,
        bbox_preds,
        labels_list,
        label_weights_list,
        bbox_targets_list,
        bbox_weights_list,
        num_total_samples=num_total_samples,
        cfg=cfg)

    # Anchor location loss, one entry per element of loc_preds.
    losses_loc = [
        self.loss_loc_single(
            loc_pred,
            loc_targets[idx],
            loc_weights[idx],
            loc_avg_factor=loc_avg_factor,
            cfg=cfg) for idx, loc_pred in enumerate(loc_preds)
    ]

    # Anchor shape loss, one entry per element of shape_preds.
    losses_shape = [
        self.loss_shape_single(
            shape_pred,
            bbox_anchors_list[idx],
            bbox_gts_list[idx],
            anchor_weights_list[idx],
            anchor_total_num=anchor_total_num)
        for idx, shape_pred in enumerate(shape_preds)
    ]

    return dict(
        loss_cls=losses_cls,
        loss_bbox=losses_bbox,
        loss_shape=losses_shape,
        loss_loc=losses_loc)