Пример #1
0
def aux_loss(points, point_cls, point_reg, gt_bboxes):
    """Compute the auxiliary point-wise classification and regression losses.

    Per-point labels and center targets are produced by ``build_aux_target``
    from the first four columns of ``points`` and from ``gt_bboxes``. Points
    labelled ``> 0`` count as positives and points labelled ``== 0`` as
    negatives; any other label value receives zero weight in both losses.

    Args:
        points: Indexed as ``points[:, :4]``; only those columns are used.
        point_cls: Classification predictions, flattened with ``view(-1)``.
        point_reg: Regression predictions compared with the center targets.
        gt_bboxes: Ground-truth boxes. ``len(gt_bboxes)`` divides both losses
            (presumably the batch size -- confirm against the caller).

    Returns:
        A dict with keys ``aux_loss_cls`` and ``aux_loss_reg``, each mapped
        to a one-element list holding the corresponding loss.
    """
    num_groups = len(gt_bboxes)
    labels, reg_targets = build_aux_target(points[:, :4], gt_bboxes)

    positive_mask = (labels > 0).float()
    negative_mask = (labels == 0).float()
    # Clamp to at least one so a sample set without positives cannot cause
    # a division by zero when the weights are normalised.
    num_positive = positive_mask.sum().clamp(min=1.0)

    cls_loss = weighted_sigmoid_focal_loss(
        point_cls.view(-1),
        labels.float(),
        weight=(positive_mask + negative_mask) / num_positive,
        avg_factor=1.,
    )
    cls_loss /= num_groups

    # Only positive points contribute to the regression term.
    reg_loss = weighted_smoothl1(
        point_reg,
        reg_targets,
        beta=1 / 9.,
        weight=(positive_mask / num_positive).unsqueeze(-1),
        avg_factor=1.,
    )
    reg_loss /= num_groups

    return {
        "aux_loss_cls": [cls_loss],
        "aux_loss_reg": [reg_loss],
    }
Пример #2
0
def aux_loss(pts_labels, center_targets, point_cls, point_reg):
    """Loss function of the auxiliary task.

    Args:
        pts_labels: sample_seg_label, B * 512.
        center_targets: sample_seg_offset, B * 512 * 3.
        point_cls: estimation_seg, B * 512.
        point_reg: estimation_offset, B * 512 * 3.

    Returns:
        Tuple ``(aux_loss_cls, aux_loss_reg)``. Both losses are weighted by
        the inverse of the (clamped) number of positive points and then
        divided by ``len(pts_labels)``.
    """
    num_groups = len(pts_labels)

    # Labels > 0 are positives, labels == 0 are negatives; any other label
    # value ends up with zero weight in both losses.
    positive_mask = (pts_labels > 0).float()
    negative_mask = (pts_labels == 0).float()
    # Clamp to at least one to avoid dividing by zero without positives.
    num_positive = positive_mask.sum().clamp(min=1.0)

    cls_loss = weighted_sigmoid_focal_loss(
        point_cls,
        pts_labels.float(),
        weight=(positive_mask + negative_mask) / num_positive,
        avg_factor=1.,
    )
    cls_loss /= num_groups

    # Only positive points contribute to the offset regression term.
    reg_loss = weighted_smoothl1(
        point_reg,
        center_targets,
        beta=1 / 9.,
        weight=(positive_mask / num_positive).unsqueeze(-1),
        avg_factor=1.,
    )
    reg_loss /= num_groups

    return cls_loss, reg_loss
Пример #3
0
    def create_loss(self, box_preds, cls_preds, cls_targets, cls_weights,
                    reg_targets, reg_weights, num_class,
                    use_sigmoid_cls=True, encode_rad_error_by_sin=True,
                    box_code_size=7):
        """Compute the box-regression and classification losses.

        Args:
            box_preds: Box predictions; reshaped to
                ``(batch, -1, box_code_size)`` where ``batch`` is
                ``box_preds.shape[0]``.
            cls_preds: Class predictions; reshaped to ``(batch, -1, C)`` with
                ``C = num_class`` when ``use_sigmoid_cls`` is true and
                ``C = num_class + 1`` otherwise.
            cls_targets: Class targets, one-hot encoded with depth
                ``num_class + 1``.
            cls_weights: Classification weights; a trailing axis is added
                before they are passed to the focal loss.
            reg_targets: Regression targets for ``box_preds``.
            reg_weights: Regression weights; a trailing axis is added before
                they are passed to the smooth-L1 loss.
            num_class: Number of classes used for the reshape and the
                one-hot depth.
            use_sigmoid_cls: When true, index 0 of the one-hot encoding is
                dropped (presumably the background class -- confirm).
            encode_rad_error_by_sin: When true, predictions and targets are
                first passed through ``self.add_sin_difference``.
            box_code_size: Size of the last axis of the reshaped box
                predictions.

        Returns:
            Tuple ``(loc_losses, cls_losses)``.
        """
        num_samples = int(box_preds.shape[0])
        box_preds = box_preds.view(num_samples, -1, box_code_size)

        cls_channels = num_class if use_sigmoid_cls else num_class + 1
        cls_preds = cls_preds.view(num_samples, -1, cls_channels)

        cls_onehot = one_hot(cls_targets,
                             depth=num_class + 1,
                             dtype=box_preds.dtype)
        if use_sigmoid_cls:
            cls_onehot = cls_onehot[..., 1:]

        if encode_rad_error_by_sin:
            # sin(a - b) = sin(a)cos(b) - cos(a)sin(b)
            box_preds, reg_targets = self.add_sin_difference(
                box_preds, reg_targets)

        loc_losses = weighted_smoothl1(
            box_preds,
            reg_targets,
            beta=1 / 9.,
            weight=reg_weights.unsqueeze(-1),
            avg_factor=1.,
        )
        cls_losses = weighted_sigmoid_focal_loss(
            cls_preds,
            cls_onehot,
            weight=cls_weights.unsqueeze(-1),
            avg_factor=1.,
        )

        return loc_losses, cls_losses