def target_assign_single_img(self, cls_preds, center_priors, decoded_bboxes,
                             gt_bboxes, gt_labels):
    """Compute classification, regression and distance targets for the
    priors of a single image.

    Args:
        cls_preds (Tensor): Classification logits of one image, a 2D tensor
            of shape [num_priors, num_classes]. ``sigmoid`` is applied before
            they are handed to the assigner.
        center_priors (Tensor): All priors of one image, a 2D tensor of shape
            [num_priors, 4] in [cx, cy, stride_w, stride_h] format.
        decoded_bboxes (Tensor): Decoded bbox predictions of one image, a 2D
            tensor of shape [num_priors, 4] in [tl_x, tl_y, br_x, br_y] format.
        gt_bboxes (ndarray | Tensor): Ground-truth boxes of one image, shape
            [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
        gt_labels (ndarray | Tensor): Ground-truth labels of one image, shape
            [num_gts].

    Returns:
        tuple: ``(labels, label_scores, bbox_targets, dist_targets,
        num_pos_per_img)``:

        - ``labels``: long tensor [num_priors], ``self.num_classes``
          (background) everywhere except at positive priors, which get the
          label of their assigned ground truth.
        - ``label_scores``: float tensor [num_priors], the assigner's
          ``max_overlaps`` at positive priors and 0 elsewhere.
        - ``bbox_targets``: [num_priors, 4], assigned gt boxes at positive
          priors, zeros elsewhere.
        - ``dist_targets``: [num_priors, 4], ``bbox2distance`` of the prior
          centre to its gt box divided by the prior's stride, clamped to
          ``[0, reg_max - 0.1]``; zeros elsewhere.
        - ``num_pos_per_img``: number of positive priors (int).
    """
    num_priors = center_priors.size(0)
    device = center_priors.device

    # torch.as_tensor accepts numpy arrays (the historical input type) as
    # well as tensors; torch.from_numpy would reject the latter.
    gt_bboxes = torch.as_tensor(gt_bboxes, device=device).to(decoded_bboxes.dtype)
    # Advanced-index assignment into the long ``labels`` tensor below needs
    # matching dtypes, so normalise the labels to int64 up front.
    gt_labels = torch.as_tensor(gt_labels, device=device).long()
    num_gts = gt_labels.size(0)

    bbox_targets = torch.zeros_like(center_priors)
    dist_targets = torch.zeros_like(center_priors)
    # Every prior starts as background (label == num_classes) with score 0.
    labels = center_priors.new_full((num_priors, ), self.num_classes,
                                    dtype=torch.long)
    label_scores = center_priors.new_zeros(labels.shape, dtype=torch.float)

    # No ground truth: every prior stays background.
    if num_gts == 0:
        return labels, label_scores, bbox_targets, dist_targets, 0

    assign_result = self.assigner.assign(cls_preds.sigmoid(), center_priors,
                                         decoded_bboxes, gt_bboxes, gt_labels)
    pos_inds, _neg_inds, pos_gt_bboxes, pos_assigned_gt_inds = self.sample(
        assign_result, gt_bboxes)

    num_pos_per_img = pos_inds.size(0)
    pos_ious = assign_result.max_overlaps[pos_inds]

    if num_pos_per_img > 0:
        bbox_targets[pos_inds, :] = pos_gt_bboxes
        # Centre-to-box distances (presumably l/t/r/b from bbox2distance),
        # expressed in units of the prior's stride (column 2 = stride_w).
        dist_targets[pos_inds, :] = (
            bbox2distance(center_priors[pos_inds, :2], pos_gt_bboxes)
            / center_priors[pos_inds, None, 2])
        # Keep targets strictly below reg_max (presumably so the upper
        # integral bin used by the distribution loss stays in range).
        dist_targets = dist_targets.clamp(min=0, max=self.reg_max - 0.1)
        labels[pos_inds] = gt_labels[pos_assigned_gt_inds]
        label_scores[pos_inds] = pos_ious

    return (
        labels,
        label_scores,
        bbox_targets,
        dist_targets,
        num_pos_per_img,
    )
def loss_single(self, grid_cells, cls_score, bbox_pred, labels, label_weights,
                bbox_targets, stride, num_total_samples):
    """Compute the QFL, bbox-regression and DFL losses for one feature level.

    Args:
        grid_cells (Tensor): Grid cells of this level; reshaped to [N, 4].
        cls_score (Tensor): Classification logits, presumably laid out as
            [B, cls_out_channels, H, W]; permuted/reshaped to
            [N, cls_out_channels].
        bbox_pred (Tensor): Box-distribution logits, presumably laid out as
            [B, 4 * (reg_max + 1), H, W]; permuted/reshaped to
            [N, 4 * (reg_max + 1)].
        labels (Tensor): Per-cell labels; values in [0, num_classes) are
            foreground and ``num_classes`` is background.
        label_weights (Tensor): Per-cell classification weights.
        bbox_targets (Tensor): Per-cell target boxes; reshaped to [N, 4].
        stride (int | float): Stride of this level; cell centres and target
            boxes are divided by it before decoding and loss computation.
        num_total_samples (int | float): ``avg_factor`` passed to the QFL
            loss.

    Returns:
        tuple: ``(loss_qfl, loss_bbox, loss_dfl, weight_sum)``, where
        ``weight_sum`` is the sum of the per-positive weights (max sigmoid
        class probability). When there are no positives, ``loss_bbox`` and
        ``loss_dfl`` are zeros derived from ``bbox_pred`` and ``weight_sum``
        is a floating-point zero.
    """
    grid_cells = grid_cells.reshape(-1, 4)
    cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
    bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(
        -1, 4 * (self.reg_max + 1))
    bbox_targets = bbox_targets.reshape(-1, 4)
    labels = labels.reshape(-1)
    label_weights = label_weights.reshape(-1)

    # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
    bg_class_ind = self.num_classes
    # Indices of foreground cells: label >= 0 and label < num_classes.
    pos_inds = torch.nonzero(
        (labels >= 0) & (labels < bg_class_ind), as_tuple=False).squeeze(1)

    # IoU-quality target for QFL; stays 0 for every non-positive cell.
    score = label_weights.new_zeros(labels.shape)

    if pos_inds.numel() > 0:
        pos_bbox_targets = bbox_targets[pos_inds]
        pos_bbox_pred = bbox_pred[pos_inds]  # (n, 4 * (reg_max + 1))
        pos_grid_cells = grid_cells[pos_inds]
        pos_grid_cell_centers = self.grid_cells_to_center(
            pos_grid_cells) / stride

        # Weight each positive by its highest predicted class probability.
        # Detached so the weights carry no gradient; positives are selected
        # before the max so only those rows are reduced.
        weight_targets = cls_score.detach().sigmoid()[pos_inds].max(dim=1)[0]

        pos_bbox_pred_corners = self.distribution_project(pos_bbox_pred)
        pos_decode_bbox_pred = distance2bbox(pos_grid_cell_centers,
                                             pos_bbox_pred_corners)
        pos_decode_bbox_targets = pos_bbox_targets / stride
        score[pos_inds] = bbox_overlaps(pos_decode_bbox_pred.detach(),
                                        pos_decode_bbox_targets,
                                        is_aligned=True)

        pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1)
        target_corners = bbox2distance(pos_grid_cell_centers,
                                       pos_decode_bbox_targets,
                                       self.reg_max).reshape(-1)

        # regression loss
        loss_bbox = self.loss_bbox(pos_decode_bbox_pred,
                                   pos_decode_bbox_targets,
                                   weight=weight_targets,
                                   avg_factor=1.0)

        # dfl loss: each positive contributes 4 edge distributions, so its
        # weight is repeated once per edge.
        loss_dfl = self.loss_dfl(
            pred_corners,
            target_corners,
            weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
            avg_factor=4.0)
    else:
        # Zero losses that are still connected to bbox_pred's graph.
        loss_bbox = bbox_pred.sum() * 0
        loss_dfl = bbox_pred.sum() * 0
        # Float zero on cls_score's device/dtype so that
        # ``weight_targets.sum()`` has the same dtype as in the positive
        # branch (previously an int64 tensor).
        weight_targets = cls_score.new_tensor(0.0)

    # qfl loss
    loss_qfl = self.loss_qfl(cls_score, (labels, score),
                             weight=label_weights,
                             avg_factor=num_total_samples)

    return loss_qfl, loss_bbox, loss_dfl, weight_targets.sum()