def create_loss(self, box_preds, cls_preds, cls_targets, cls_weights,
                reg_targets, reg_weights, num_class,
                use_sigmoid_cls=True, encode_rad_error_by_sin=True,
                box_code_size=7):
    """Compute per-anchor localization and classification losses.

    Args:
        box_preds: Box regression predictions; reshaped to
            ``(batch, -1, box_code_size)``.
        cls_preds: Classification logits; reshaped to
            ``(batch, -1, num_class)`` when ``use_sigmoid_cls`` is true,
            otherwise ``(batch, -1, num_class + 1)``.
        cls_targets: Integer class labels, one-hot encoded with depth
            ``num_class + 1``.
        cls_weights: Per-anchor classification weights; a trailing axis is
            appended before they are passed to the loss.
        reg_targets: Box regression targets matching ``box_preds``.
        reg_weights: Per-anchor regression weights; a trailing axis is
            appended before they are passed to the loss.
        num_class: Number of foreground classes.
        use_sigmoid_cls: If true, drop one-hot column 0 so the targets
            have ``num_class`` channels.
        encode_rad_error_by_sin: If true, pass predictions and targets
            through ``self.add_sin_difference`` first.
        box_code_size: Number of values per encoded box.

    Returns:
        Tuple ``(loc_losses, cls_losses)`` from ``weighted_smoothl1`` and
        ``weighted_sigmoid_focal_loss``, both called with ``avg_factor=1.``.
    """
    batch_size = int(box_preds.shape[0])
    box_preds = box_preds.view(batch_size, -1, box_code_size)

    # The non-sigmoid layout carries one extra channel
    # (presumably background — confirm against the head definition).
    num_cls_channels = num_class if use_sigmoid_cls else num_class + 1
    cls_preds = cls_preds.view(batch_size, -1, num_cls_channels)

    one_hot_targets = one_hot(cls_targets, depth=num_class + 1,
                              dtype=box_preds.dtype)
    if use_sigmoid_cls:
        # Discard column 0 so targets line up with the num_class channels.
        one_hot_targets = one_hot_targets[..., 1:]

    if encode_rad_error_by_sin:
        # sin(a - b) = sin(a)cos(b) - cos(a)sin(b)
        box_preds, reg_targets = self.add_sin_difference(box_preds,
                                                         reg_targets)

    loc_losses = weighted_smoothl1(
        box_preds,
        reg_targets,
        beta=1.0 / 9.0,
        weight=reg_weights[..., None],
        avg_factor=1.0,
    )
    # NOTE(review): the sigmoid focal loss is used even when
    # use_sigmoid_cls is False (num_class + 1 channel layout) — confirm
    # this is intended.
    cls_losses = weighted_sigmoid_focal_loss(
        cls_preds,
        one_hot_targets,
        weight=cls_weights[..., None],
        avg_factor=1.0,
    )
    return loc_losses, cls_losses
def get_direction_target(self, anchors, reg_targets, use_one_hot=True,
                         box_code_size=7, rot_index=-1):
    """Build binary direction-classification targets.

    The rotation is reconstructed as
    ``reg_targets[..., rot_index] + anchors[..., rot_index]``. This
    presumably treats the regression target as a residual relative to the
    anchor angle — confirm against the box coder. Each anchor is labelled
    1 when that rotation is strictly positive, else 0.

    Args:
        anchors: Anchor boxes; reshaped to ``(batch, -1, box_code_size)``
            where ``batch = reg_targets.shape[0]``.
        reg_targets: Regression targets whose ``[..., rot_index]`` entry
            must broadcast against the reshaped anchors' rotation channel.
        use_one_hot: If true, return a depth-2 one-hot encoding with the
            anchors' dtype instead of integer labels.
        box_code_size: Number of values per anchor box (was hard-coded to
            7; the default keeps the original behavior).
        rot_index: Channel holding the rotation (was hard-coded to the
            last channel; the default keeps the original behavior).

    Returns:
        ``torch.long`` labels of shape ``(batch, num_anchors)``, or their
        one-hot encoding when ``use_one_hot`` is true.
    """
    batch_size = reg_targets.shape[0]
    anchors = anchors.view(batch_size, -1, box_code_size)
    rot_gt = reg_targets[..., rot_index] + anchors[..., rot_index]
    # A rotation of exactly 0 falls into class 0.
    dir_cls_targets = (rot_gt > 0).long()
    if use_one_hot:
        dir_cls_targets = one_hot(dir_cls_targets, 2, dtype=anchors.dtype)
    return dir_cls_targets