def aux_loss(pts_labels, center_targets, point_cls, point_reg):
    """Auxiliary-task loss: point-wise foreground segmentation + center-offset regression.

    Args:
        pts_labels: sampled segmentation labels (sample_seg_label), B * 512.
        center_targets: sampled center-offset targets (sample_seg_offset), B * 512 * 3.
        point_cls: predicted segmentation logits (estimation_seg), B * 512.
        point_reg: predicted center offsets (estimation_offset), B * 512 * 3.

    Returns:
        Tuple ``(aux_loss_cls, aux_loss_reg)``, each divided by the batch size.
    """
    batch_size = len(pts_labels)
    cls_target = pts_labels.float()

    positives = (pts_labels > 0).float()
    negatives = (pts_labels == 0).float()

    # Both weight maps are normalized by the positive count (clamped to >= 1
    # so frames without foreground points do not divide by zero).
    normalizer = torch.clamp(positives.sum(), min=1.0)
    cls_weights = (positives + negatives) / normalizer
    reg_weights = positives / normalizer

    cls_loss = weighted_sigmoid_focal_loss(
        point_cls, cls_target, weight=cls_weights, avg_factor=1.)
    cls_loss = cls_loss / batch_size

    reg_loss = weighted_smoothl1(
        point_reg, center_targets, beta=1 / 9.,
        weight=reg_weights[..., None], avg_factor=1.)
    reg_loss = reg_loss / batch_size

    return cls_loss, reg_loss
def create_loss(self, box_preds, cls_preds, cls_targets, cls_weights,
                reg_targets, reg_weights, num_class, use_sigmoid_cls=True,
                encode_rad_error_by_sin=True, box_code_size=7):
    """Build localization and classification losses for anchor-based box heads.

    Args:
        box_preds: raw box regression output; reshaped to (B, -1, box_code_size).
        cls_preds: raw classification output; reshaped per ``use_sigmoid_cls``.
        cls_targets: integer class targets (0 = background).
        cls_weights: per-anchor classification weights.
        reg_targets: encoded regression targets.
        reg_weights: per-anchor regression weights.
        num_class: number of foreground classes.
        use_sigmoid_cls: if True, drop the background column and use sigmoid logits.
        encode_rad_error_by_sin: if True, encode the heading residual via
            sin(a - b) = sin(a)cos(b) - cos(a)sin(b).
        box_code_size: number of encoded box parameters (default 7).

    Returns:
        Tuple ``(loc_losses, cls_losses)`` (unreduced weighted losses).
    """
    batch_size = int(box_preds.shape[0])
    box_preds = box_preds.view(batch_size, -1, box_code_size)

    cls_channels = num_class if use_sigmoid_cls else num_class + 1
    cls_preds = cls_preds.view(batch_size, -1, cls_channels)

    one_hot_targets = one_hot(
        cls_targets, depth=num_class + 1, dtype=box_preds.dtype)
    if use_sigmoid_cls:
        # Sigmoid classification has no explicit background class.
        one_hot_targets = one_hot_targets[..., 1:]

    if encode_rad_error_by_sin:
        # sin(a - b) = sin(a)cos(b) - cos(a)sin(b)
        box_preds, reg_targets = self.add_sin_difference(box_preds, reg_targets)

    loc_losses = weighted_smoothl1(
        box_preds, reg_targets, beta=1 / 9.,
        weight=reg_weights[..., None], avg_factor=1.)
    cls_losses = weighted_sigmoid_focal_loss(
        cls_preds, one_hot_targets,
        weight=cls_weights[..., None], avg_factor=1.)
    return loc_losses, cls_losses
def aux_loss(points, point_cls, point_reg, gt_bboxes):
    """Auxiliary loss variant that builds its own targets from ground-truth boxes.

    Args:
        points: input points; the first four columns are passed to
            ``build_aux_target``.
        point_cls: predicted per-point segmentation logits.
        point_reg: predicted per-point center offsets.
        gt_bboxes: list of per-sample ground-truth boxes (its length is the
            batch size used for averaging).

    Returns:
        Dict with single-element lists ``aux_loss_cls`` and ``aux_loss_reg``.
    """
    batch_size = len(gt_bboxes)
    pts_labels, center_targets = build_aux_target(points[:, :4], gt_bboxes)

    cls_target = pts_labels.float()
    positives = (pts_labels > 0).float()
    negatives = (pts_labels == 0).float()

    # Normalize by positive count, clamped so empty scenes don't divide by zero.
    normalizer = torch.clamp(positives.sum(), min=1.0)
    cls_weights = (positives + negatives) / normalizer
    reg_weights = positives / normalizer

    cls_loss = weighted_sigmoid_focal_loss(
        point_cls.view(-1), cls_target, weight=cls_weights, avg_factor=1.)
    cls_loss = cls_loss / batch_size

    reg_loss = weighted_smoothl1(
        point_reg, center_targets, beta=1 / 9.,
        weight=reg_weights[..., None], avg_factor=1.)
    reg_loss = reg_loss / batch_size

    return dict(
        aux_loss_cls=[cls_loss],
        aux_loss_reg=[reg_loss],
    )
def loss(self, cls_preds, gt_bboxes, gt_labels, anchors, cfg):
    """Classification (rescoring) loss over class-agnostic anchors.

    Args:
        cls_preds: raw classification output; reshaped to (-1, num_class).
        gt_bboxes: list of per-sample ground-truth boxes.
        gt_labels: unused here (targets come from ``create_target_torch``).
        anchors: list of per-sample anchor sets (its length is the batch size).
        cfg: config with ``assigner.similarity_fn``, ``assigner.pos_iou_thr``
            and ``assigner.neg_iou_thr``.

    Returns:
        Dict with ``loss_cls``, the focal loss averaged over the batch.
    """
    batch_size = len(anchors)
    batch_none = (None, ) * batch_size

    # Currently only rescoring for class-agnostic anchors is supported.
    labels, targets, ious = multi_apply(
        create_target_torch,
        anchors, batch_none, gt_bboxes, batch_none, batch_none,
        similarity_fn=getattr(iou3d_utils, cfg.assigner.similarity_fn)(),
        box_encoding_fn=second_box_encode,
        matched_threshold=cfg.assigner.pos_iou_thr,
        unmatched_threshold=cfg.assigner.neg_iou_thr)

    labels = torch.cat(labels, ).unsqueeze_(1)

    cared = labels >= 0        # entries labelled < 0 are ignored
    positives = labels > 0
    negatives = labels == 0

    # Weight both positives and negatives, normalized by the positive count
    # (clamped to >= 1 to avoid dividing by zero when nothing matched).
    cls_weights = negatives.type(torch.float32) + positives.type(torch.float32)
    cls_weights /= torch.clamp(positives.sum().type(torch.float32), min=1.0)

    cls_targets = labels * cared.type_as(labels)
    cls_preds = cls_preds.view(-1, self._num_class)

    cls_losses = weighted_sigmoid_focal_loss(
        cls_preds, cls_targets.float(), weight=cls_weights, avg_factor=1.)
    return dict(loss_cls=cls_losses / batch_size, )