Beispiel #1
0
    def create_loss(self,
                    box_preds,
                    cls_preds,
                    cls_targets,
                    cls_weights,
                    reg_targets,
                    reg_weights,
                    num_class,
                    use_sigmoid_cls=True,
                    encode_rad_error_by_sin=True,
                    box_code_size=7):
        """Compute per-anchor localization and classification losses.

        Args:
            box_preds: Box regression output. Its leading dimension is the
                batch; it is reshaped to ``(batch, -1, box_code_size)``.
            cls_preds: Classification logits. Reshaped to
                ``(batch, -1, num_class)`` when ``use_sigmoid_cls`` is true,
                otherwise to ``(batch, -1, num_class + 1)``.
            cls_targets: Integer class labels, one-hot encoded with depth
                ``num_class + 1``.
            cls_weights: Per-anchor classification weights. A trailing axis
                is added so they broadcast over the class dimension.
            reg_targets: Regression targets paired with ``box_preds``.
            reg_weights: Per-anchor regression weights. A trailing axis is
                added so they broadcast over the box-code dimension.
            num_class: Number of classes predicted in the sigmoid layout.
            use_sigmoid_cls: If true, one-hot column 0 is dropped (presumably
                the background label -- confirm) so the targets are
                ``num_class`` wide, matching ``cls_preds``.
            encode_rad_error_by_sin: If true, predictions and targets are
                passed through ``self.add_sin_difference`` before the
                regression loss.
            box_code_size: Number of values per encoded box.

        Returns:
            Tuple ``(loc_losses, cls_losses)`` as returned by
            ``weighted_smoothl1`` and ``weighted_sigmoid_focal_loss`` (both
            called with ``avg_factor=1.``).
        """
        batch_size = int(box_preds.shape[0])
        # reshape (rather than view) also accepts non-contiguous inputs such
        # as permuted network outputs; for contiguous tensors it returns the
        # same view that .view() would.
        box_preds = box_preds.reshape(batch_size, -1, box_code_size)

        # The sigmoid layout has one fewer class channel than the softmax one
        # (one-hot column 0 is dropped below to match).
        num_cls_channels = num_class if use_sigmoid_cls else num_class + 1
        cls_preds = cls_preds.reshape(batch_size, -1, num_cls_channels)

        one_hot_targets = one_hot(
            cls_targets, depth=num_class + 1, dtype=box_preds.dtype)
        if use_sigmoid_cls:
            one_hot_targets = one_hot_targets[..., 1:]

        if encode_rad_error_by_sin:
            # sin(a - b) = sin(a)cos(b) - cos(a)sin(b); presumably
            # add_sin_difference rewrites the angle channel so the regression
            # loss penalises sin(pred - target) -- confirm in that helper.
            box_preds, reg_targets = self.add_sin_difference(
                box_preds, reg_targets)

        loc_losses = weighted_smoothl1(
            box_preds, reg_targets, beta=1 / 9.,
            weight=reg_weights[..., None], avg_factor=1.)
        # NOTE(review): the softmax layout (use_sigmoid_cls=False) is also
        # scored with the sigmoid focal loss here -- confirm this is intended.
        cls_losses = weighted_sigmoid_focal_loss(
            cls_preds, one_hot_targets,
            weight=cls_weights[..., None], avg_factor=1.)

        return loc_losses, cls_losses
Beispiel #2
0
 def get_direction_target(self, anchors, reg_targets, use_one_hot=True,
                          box_code_size=7):
     """Build per-anchor direction classification targets.

     The ground-truth angle is recovered as the last regression-target
     channel plus the last anchor channel. The target is 1 where that
     angle is positive and 0 otherwise.

     Args:
         anchors: Anchor boxes, reshaped to ``(batch, -1, box_code_size)``
             using the batch size taken from ``reg_targets``.
         reg_targets: Regression targets. Their leading dimension is the
             batch, and their last channel is added to the anchors' last
             channel.
         use_one_hot: If true, return a 2-wide one-hot encoding (in the
             anchors' dtype) instead of integer labels.
         box_code_size: Number of values per anchor box (defaults to 7,
             the previously hard-coded value).

     Returns:
         A ``long`` tensor of 0/1 labels, or its 2-class one-hot encoding
         when ``use_one_hot`` is true.
     """
     batch_size = reg_targets.shape[0]
     # reshape (rather than view) also accepts non-contiguous anchor tensors.
     anchors = anchors.reshape(batch_size, -1, box_code_size)
     # Assumes the last channel is the heading angle, with the target stored
     # as a residual relative to the anchor -- TODO confirm with the box coder.
     rot_gt = reg_targets[..., -1] + anchors[..., -1]
     dir_cls_targets = (rot_gt > 0).long()
     if use_one_hot:
         dir_cls_targets = one_hot(dir_cls_targets, 2, dtype=anchors.dtype)
     return dir_cls_targets