コード例 #1
0
    def __call__(self, ins_pred_list, ins_label_list, cate_preds, cate_labels,
                 num_ins):
        """
        Get the losses of the SOLOv2 network.

        Args:
            ins_pred_list (list): Instance-branch outputs, one per level.
                Entries that are None are skipped.
            ins_label_list (list): Instance labels per batch, paired
                element-wise with ``ins_pred_list``.
            cate_preds (Tensor): Concatenated category-branch output; its
                last dimension is the number of classes.
            cate_labels (Tensor): Concatenated category labels per batch.
                Label 0 maps to an all-zero target row.
            num_ins (int): Number of positive samples in a mini-batch.

        Returns:
            loss_ins (Tensor): The instance (dice) loss of SOLOv2.
            loss_cate (Tensor): The category (sigmoid focal) loss of SOLOv2.
        """
        # 1. Use dice loss to calculate the instance loss.
        loss_ins = []
        total_weights = paddle.zeros(shape=[1], dtype='float32')
        for ins_pred, ins_target in zip(ins_pred_list, ins_label_list):
            if ins_pred is None:
                continue
            ins_target = paddle.cast(ins_target, 'float32')
            ins_target = paddle.reshape(
                ins_target,
                shape=[-1,
                       paddle.shape(ins_pred)[-2],
                       paddle.shape(ins_pred)[-1]])
            # Only target masks with at least one foreground pixel contribute.
            weights = paddle.cast(
                paddle.sum(ins_target, axis=[1, 2]) > 0, 'float32')
            ins_pred = F.sigmoid(ins_pred)
            dice_out = paddle.multiply(
                self._dice_loss(ins_pred, ins_target), weights)
            total_weights += paddle.sum(weights)
            loss_ins.append(dice_out)
        # The weights are 0/1, so total_weights is either 0 or >= 1. Clipping
        # at 1 only affects the "no foreground mask" case, where the original
        # division would have been 0 / 0 (NaN).
        total_weights = paddle.clip(total_weights, min=1.0)
        loss_ins = paddle.sum(paddle.concat(loss_ins)) / total_weights
        loss_ins = loss_ins * self.ins_loss_weight

        # 2. Use sigmoid focal loss to calculate the category loss.
        num_classes = cate_preds.shape[-1]
        # One-hot over num_classes + 1 columns, then drop column 0 so that
        # label 0 yields an all-zero target row (presumably background).
        cate_labels_bin = F.one_hot(cate_labels, num_classes=num_classes + 1)
        cate_labels_bin = cate_labels_bin[:, 1:]

        loss_cate = F.sigmoid_focal_loss(cate_preds,
                                         label=cate_labels_bin,
                                         normalizer=num_ins + 1.,
                                         gamma=self.focal_loss_gamma,
                                         alpha=self.focal_loss_alpha)

        return loss_ins, loss_cate
コード例 #2
0
 def forward(self, pred, target, reduction='none'):
     """Compute the sigmoid focal loss for class logits, scaled by loss_weight.

     Args:
         pred (Tensor): class logits of shape (N, num_classes).
         target (Tensor): class labels of shape (N, ). A label equal to
             num_classes produces an all-zero target row.
         reduction (str): reduction mode, one of 'none', 'sum' or 'mean'.

     Returns:
         Tensor: the focal loss multiplied by ``self.loss_weight``.
     """
     depth = pred.shape[1] + 1
     # Encode with one extra trailing column, then discard it. The target
     # is detached so no gradient flows through the labels.
     onehot = F.one_hot(target, depth).cast(pred.dtype)
     onehot = onehot[:, :-1].detach()
     focal = F.sigmoid_focal_loss(pred,
                                  onehot,
                                  alpha=self.alpha,
                                  gamma=self.gamma,
                                  reduction=reduction)
     weighted = focal * self.loss_weight
     return weighted
コード例 #3
0
ファイル: detr_loss.py プロジェクト: ghostxsl/PaddleDetection
    def _get_loss_mask(self, masks, gt_mask, match_indices, num_gts):
        """Compute the focal and dice losses between matched mask predictions
        and ground-truth masks.

        Args:
            masks: predicted masks, [b, query, h, w].
            gt_mask: list of ground-truth masks, each [n, H, W].
            match_indices: query/gt assignment passed through to
                ``self._get_src_target_assign``.
            num_gts: normalizer for both losses.

        Returns:
            dict: ``{'loss_mask': ..., 'loss_dice': ...}``. Both are a zero
            tensor when there are no ground-truth masks at all.
        """
        total_gt = sum(len(per_image) for per_image in gt_mask)
        if not total_gt:
            return {
                'loss_mask': paddle.to_tensor([0.]),
                'loss_dice': paddle.to_tensor([0.]),
            }

        src_masks, target_masks = self._get_src_target_assign(
            masks, gt_mask, match_indices)
        # Resize predictions to the ground-truth mask resolution.
        resized = F.interpolate(src_masks.unsqueeze(0),
                                size=target_masks.shape[-2:],
                                mode="bilinear")
        src_masks = resized[0]

        normalizer = paddle.to_tensor([num_gts], dtype='float32')
        focal = F.sigmoid_focal_loss(src_masks, target_masks, normalizer)
        dice = self._dice_loss(src_masks, target_masks, num_gts)
        return {
            'loss_mask': self.loss_coeff['mask'] * focal,
            'loss_dice': self.loss_coeff['dice'] * dice,
        }
コード例 #4
0
    def get_odm_loss(self, odm_target, s2anet_head_out, reg_loss_type='gwd'):
        """Compute the ODM classification and regression losses, summed over
        all feature levels.

        Args:
            odm_target (tuple): (labels, label_weights, bbox_targets,
                bbox_weights, bbox_gt_bboxes, pos_inds, neg_inds). The
                per-anchor arrays hold all levels back to back, in the order
                of ``self.featmap_sizes_list``.
            s2anet_head_out (tuple): (fam_cls_branch_list,
                fam_reg_branch_list, odm_cls_branch_list,
                odm_reg_branch_list). Only the two ODM lists are used.
            reg_loss_type (str): 'l1' for weighted smooth-L1 on the box
                deltas, or 'gwd' for ``self.gwd_loss`` on decoded boxes.

        Returns:
            tuple: (odm_cls_loss, odm_reg_loss).

        Raises:
            ValueError: if ``reg_loss_type`` is not 'l1' or 'gwd'. ('iou' used
                to be accepted here but never produced a loss, so it failed
                with UnboundLocalError.)
        """
        if reg_loss_type not in ('l1', 'gwd'):
            raise ValueError(
                "unsupported reg_loss_type {!r}; expected 'l1' or 'gwd'".format(
                    reg_loss_type))

        (labels, label_weights, bbox_targets, bbox_weights, bbox_gt_bboxes,
         pos_inds, neg_inds) = odm_target
        _, _, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out

        odm_cls_losses = []
        odm_bbox_losses = []
        st_idx = 0
        num_total_samples = len(pos_inds) + len(
            neg_inds) if self.sampling else len(pos_inds)
        # Convert once, outside the level loop.
        num_total_samples = paddle.to_tensor(
            max(1, num_total_samples), dtype='float32', stop_gradient=True)

        for idx, feat_size in enumerate(self.featmap_sizes_list):
            feat_anchor_num = feat_size[0] * feat_size[1]
            end_idx = st_idx + feat_anchor_num

            # step1: slice this level's rows out of the concatenated targets
            feat_labels = labels[st_idx:end_idx].reshape(-1)
            feat_label_weights = label_weights[st_idx:end_idx].reshape(-1)
            feat_bbox_targets = bbox_targets[st_idx:end_idx, :]
            feat_bbox_weights = bbox_weights[st_idx:end_idx, :]

            # step2: classification loss
            odm_cls_score = paddle.squeeze(odm_cls_branch_list[idx], axis=0)
            # One-hot over cls_out_channels + 1 columns; column 0 is dropped,
            # so label 0 gives an all-zero target row.
            feat_labels_one_hot = paddle.nn.functional.one_hot(
                paddle.to_tensor(feat_labels), self.cls_out_channels + 1)
            feat_labels_one_hot = feat_labels_one_hot[:, 1:]
            feat_labels_one_hot.stop_gradient = True

            odm_cls = F.sigmoid_focal_loss(
                odm_cls_score,
                feat_labels_one_hot,
                normalizer=num_total_samples,
                reduction='none')

            # Broadcast the per-anchor weight across every class column.
            feat_label_weights = np.repeat(
                feat_label_weights.reshape(feat_label_weights.shape[0], 1),
                self.cls_out_channels,
                axis=1)
            feat_label_weights = paddle.to_tensor(feat_label_weights)
            feat_label_weights.stop_gradient = True
            odm_cls_losses.append(paddle.sum(odm_cls * feat_label_weights))

            # step3: regression loss
            odm_bbox_pred = paddle.squeeze(odm_reg_branch_list[idx], axis=0)
            odm_bbox_pred = paddle.reshape(odm_bbox_pred, [-1, 5])
            feat_bbox_weights = paddle.to_tensor(
                feat_bbox_weights, stop_gradient=True)

            if reg_loss_type == 'l1':
                feat_bbox_targets = paddle.to_tensor(
                    feat_bbox_targets, dtype='float32')
                feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
                feat_bbox_targets.stop_gradient = True
                odm_bbox = self.smooth_l1_loss(odm_bbox_pred,
                                               feat_bbox_targets)
                loss_weight = paddle.to_tensor(
                    self.reg_loss_weight, dtype='float32', stop_gradient=True)
                odm_bbox = paddle.multiply(odm_bbox, loss_weight)
                odm_bbox = odm_bbox * feat_bbox_weights
                odm_bbox_total = paddle.sum(odm_bbox) / num_total_samples
            else:  # 'gwd'
                odm_bbox_decode = self.delta2rbox(self.refine_anchor_list[idx],
                                                  odm_bbox_pred)
                # Ground-truth boxes for this level only, on the same
                # dtype/place as the decoded predictions.
                bbox_gt_bboxes_level = paddle.to_tensor(
                    bbox_gt_bboxes[st_idx:end_idx, :],
                    dtype=odm_bbox_decode.dtype,
                    place=odm_bbox_decode.place)
                bbox_gt_bboxes_level.stop_gradient = True
                # Collapse the per-coordinate weights to one weight per anchor.
                feat_bbox_weights = paddle.sum(feat_bbox_weights, axis=-1)
                odm_bbox_total = self.gwd_loss(odm_bbox_decode,
                                               bbox_gt_bboxes_level)
                odm_bbox_total = odm_bbox_total * feat_bbox_weights
                odm_bbox_total = paddle.sum(odm_bbox_total) / num_total_samples

            odm_bbox_losses.append(odm_bbox_total)
            st_idx = end_idx

        odm_cls_loss = paddle.add_n(odm_cls_losses)
        odm_cls_loss_weight = paddle.to_tensor(
            self.cls_loss_weight[1], dtype='float32', stop_gradient=True)
        odm_cls_loss = odm_cls_loss * odm_cls_loss_weight
        odm_reg_loss = paddle.add_n(odm_bbox_losses)
        return odm_cls_loss, odm_reg_loss
コード例 #5
0
    def forward(self, cls_logits, bboxes_reg, centerness, tag_labels,
                tag_bboxes, tag_center):
        """
        Calculate the loss for classification, location and centerness.

        Args:
            cls_logits (list): list of Tensor, which is predicted
                score for all anchor points with shape [N, M, C]
            bboxes_reg (list): list of Tensor, which is predicted
                offsets for all anchor points with shape [N, M, 4]
            centerness (list): list of Tensor, which is predicted
                centerness for all anchor points with shape [N, M, 1]
            tag_labels (list): list of Tensor, which is category
                targets for each anchor point
            tag_bboxes (list): list of Tensor, which is bounding
                boxes targets for positive samples
            tag_center (list): list of Tensor, which is centerness
                targets for positive samples

        Return:
            loss (dict): with keys "loss_centerness", "loss_cls" and
            "loss_box". When a batch has no positive samples, the
            normalizers fall back to 1 instead of dividing by zero.
        """
        cls_logits_flatten_list = []
        bboxes_reg_flatten_list = []
        centerness_flatten_list = []
        tag_labels_flatten_list = []
        tag_bboxes_flatten_list = []
        tag_center_flatten_list = []
        for lvl in range(len(cls_logits)):
            cls_logits_flatten_list.append(
                flatten_tensor(cls_logits[lvl], True))
            bboxes_reg_flatten_list.append(
                flatten_tensor(bboxes_reg[lvl], True))
            centerness_flatten_list.append(
                flatten_tensor(centerness[lvl], True))

            tag_labels_flatten_list.append(
                flatten_tensor(tag_labels[lvl], False))
            tag_bboxes_flatten_list.append(
                flatten_tensor(tag_bboxes[lvl], False))
            tag_center_flatten_list.append(
                flatten_tensor(tag_center[lvl], False))

        cls_logits_flatten = paddle.concat(cls_logits_flatten_list, axis=0)
        bboxes_reg_flatten = paddle.concat(bboxes_reg_flatten_list, axis=0)
        centerness_flatten = paddle.concat(centerness_flatten_list, axis=0)

        tag_labels_flatten = paddle.concat(tag_labels_flatten_list, axis=0)
        tag_bboxes_flatten = paddle.concat(tag_bboxes_flatten_list, axis=0)
        tag_center_flatten = paddle.concat(tag_center_flatten_list, axis=0)
        tag_labels_flatten.stop_gradient = True
        tag_bboxes_flatten.stop_gradient = True
        tag_center_flatten.stop_gradient = True

        mask_positive_bool = tag_labels_flatten > 0
        mask_positive_bool.stop_gradient = True
        mask_positive_float = paddle.cast(mask_positive_bool, dtype="float32")
        mask_positive_float.stop_gradient = True

        # The positive count is a whole number, so clipping at 1 leaves it
        # unchanged whenever positives exist and avoids x / 0 otherwise.
        num_positive_fp32 = paddle.clip(
            paddle.sum(mask_positive_float), min=1.0)
        num_positive_fp32.stop_gradient = True

        # Centerness sums can legitimately be below 1, so replace only an
        # exact zero (no positives) instead of clipping.
        normalize_sum = paddle.sum(tag_center_flatten * mask_positive_float)
        normalize_sum = paddle.where(normalize_sum > 0, normalize_sum,
                                     paddle.ones_like(normalize_sum))
        normalize_sum.stop_gradient = True

        # 1. cls_logits: sigmoid_focal_loss
        # One-hot over num_classes + 1 columns; column 0 is dropped, so
        # label 0 gives an all-zero target row.
        num_classes = cls_logits_flatten.shape[-1]
        tag_labels_flatten = paddle.squeeze(tag_labels_flatten, axis=-1)
        tag_labels_flatten_bin = F.one_hot(
            tag_labels_flatten, num_classes=1 + num_classes)
        tag_labels_flatten_bin = tag_labels_flatten_bin[:, 1:]
        cls_loss = F.sigmoid_focal_loss(
            cls_logits_flatten, tag_labels_flatten_bin) / num_positive_fp32

        # 2. bboxes_reg: giou_loss, weighted by centerness, positives only
        mask_positive_float = paddle.squeeze(mask_positive_float, axis=-1)
        tag_center_flatten = paddle.squeeze(tag_center_flatten, axis=-1)
        reg_loss = self.__iou_loss(
            bboxes_reg_flatten,
            tag_bboxes_flatten,
            mask_positive_float,
            weights=tag_center_flatten)
        reg_loss = reg_loss * mask_positive_float / normalize_sum

        # 3. centerness: sigmoid_cross_entropy_with_logits_loss, positives only
        centerness_flatten = paddle.squeeze(centerness_flatten, axis=-1)
        ctn_loss = ops.sigmoid_cross_entropy_with_logits(centerness_flatten,
                                                         tag_center_flatten)
        ctn_loss = ctn_loss * mask_positive_float / num_positive_fp32

        loss_all = {
            "loss_centerness": paddle.sum(ctn_loss),
            "loss_cls": paddle.sum(cls_loss),
            "loss_box": paddle.sum(reg_loss)
        }
        return loss_all
コード例 #6
0
    def forward(self, cls_logits, bboxes_reg, centerness, tag_labels, tag_bboxes, tag_center):
        """
        Calculate the loss for classification, location and centerness.

        Args:
            cls_logits (list): list of Tensor, which is predicted
                score for all anchor points with shape [N, M, C]
            bboxes_reg (list): list of Tensor, which is predicted
                offsets for all anchor points with shape [N, M, 4]
            centerness (list): list of Tensor, which is predicted
                centerness for all anchor points with shape [N, M, 1]
            tag_labels (list): list of Tensor, which is category
                targets for each anchor point
            tag_bboxes (list): list of Tensor, which is bounding
                boxes targets for positive samples
            tag_center (list): list of Tensor, which is centerness
                targets for positive samples

        Return:
            loss (dict): with keys "loss_centerness", "loss_cls" and
            "loss_box". The centerness loss is averaged over positive
            samples only; it was previously also divided by the centerness
            sum. When a batch has no positive samples, the normalizers fall
            back to 1 instead of dividing by zero.
        """
        cls_logits_flatten_list = []
        bboxes_reg_flatten_list = []
        centerness_flatten_list = []
        tag_labels_flatten_list = []
        tag_bboxes_flatten_list = []
        tag_center_flatten_list = []
        for lvl in range(len(cls_logits)):
            cls_logits_flatten_list.append(
                flatten_tensor(cls_logits[lvl], True))
            bboxes_reg_flatten_list.append(
                flatten_tensor(bboxes_reg[lvl], True))
            centerness_flatten_list.append(
                flatten_tensor(centerness[lvl], True))

            tag_labels_flatten_list.append(
                flatten_tensor(tag_labels[lvl], False))
            tag_bboxes_flatten_list.append(
                flatten_tensor(tag_bboxes[lvl], False))
            tag_center_flatten_list.append(
                flatten_tensor(tag_center[lvl], False))

        cls_logits_flatten = paddle.concat(cls_logits_flatten_list, axis=0)
        bboxes_reg_flatten = paddle.concat(bboxes_reg_flatten_list, axis=0)
        centerness_flatten = paddle.concat(centerness_flatten_list, axis=0)

        tag_labels_flatten = paddle.concat(tag_labels_flatten_list, axis=0)
        tag_bboxes_flatten = paddle.concat(tag_bboxes_flatten_list, axis=0)
        tag_center_flatten = paddle.concat(tag_center_flatten_list, axis=0)
        tag_labels_flatten.stop_gradient = True
        tag_bboxes_flatten.stop_gradient = True
        tag_center_flatten.stop_gradient = True

        mask_positive_bool = tag_labels_flatten > 0
        mask_positive_bool.stop_gradient = True
        mask_positive_float = paddle.cast(mask_positive_bool, dtype="float32")
        mask_positive_float.stop_gradient = True

        # The positive count is a whole number, so clipping at 1 leaves it
        # unchanged whenever positives exist and avoids x / 0 otherwise.
        num_positive_fp32 = paddle.clip(
            paddle.sum(mask_positive_float), min=1.0)
        num_positive_fp32.stop_gradient = True

        # Centerness sums can legitimately be below 1, so replace only an
        # exact zero (no positives) instead of clipping.
        normalize_sum = paddle.sum(tag_center_flatten * mask_positive_float)
        normalize_sum = paddle.where(normalize_sum > 0, normalize_sum,
                                     paddle.ones_like(normalize_sum))
        normalize_sum.stop_gradient = True

        # Expand labels into a float32 (M, categories) target on the host.
        # Label k in [1, categories] sets column k - 1; any other label
        # (e.g. 0) leaves an all-zero row.
        categories = cls_logits_flatten.shape[-1]
        labels_np = paddle.squeeze(tag_labels_flatten, axis=-1).numpy()
        valid = (labels_np > 0) & (labels_np <= categories)
        tag_labels_bin = np.zeros(
            (tag_labels_flatten.shape[0], categories), dtype='float32')
        tag_labels_bin[valid, labels_np[valid] - 1] = 1
        tag_labels_flatten_bin = paddle.to_tensor(tag_labels_bin)
        cls_loss = F.sigmoid_focal_loss(
            cls_logits_flatten, tag_labels_flatten_bin) / num_positive_fp32

        # IoU box loss, weighted by centerness, positives only.
        mask_positive_float = paddle.squeeze(mask_positive_float, axis=-1)
        tag_center_flatten = paddle.squeeze(tag_center_flatten, axis=-1)
        reg_loss = self.__iou_loss(bboxes_reg_flatten, tag_bboxes_flatten,
                                   mask_positive_float,
                                   weights=tag_center_flatten)
        reg_loss = reg_loss * mask_positive_float / normalize_sum

        # Centerness BCE, averaged over positive samples only.
        centerness_flatten = paddle.squeeze(centerness_flatten, axis=-1)
        ctn_loss = sigmoid_cross_entropy_with_logits(centerness_flatten,
                                                     tag_center_flatten)
        ctn_loss = ctn_loss * mask_positive_float / num_positive_fp32

        loss_all = {
            "loss_centerness": paddle.sum(ctn_loss),
            "loss_cls": paddle.sum(cls_loss),
            "loss_box": paddle.sum(reg_loss)
        }
        return loss_all
コード例 #7
0
    def get_odm_loss(self, odm_target, s2anet_head_out):
        """Compute the single-level ODM classification and regression losses.

        Args:
            odm_target (tuple): (feat_labels, feat_label_weights,
                feat_bbox_targets, feat_bbox_weights, pos_inds, neg_inds).
            s2anet_head_out (tuple): (odm_cls_score, odm_bbox_pred). Axis 0
                of each is squeezed away.

        Returns:
            tuple: (odm_cls_loss, odm_reg_loss).
        """
        (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
         neg_inds) = odm_target
        cls_score, bbox_pred = s2anet_head_out

        # step1: normalizer = positives (+ negatives when sampling), at least 1
        if self.sampling:
            sample_count = len(pos_inds) + len(neg_inds)
        else:
            sample_count = len(pos_inds)
        sample_count = max(1, sample_count)

        # step2: classification loss
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = paddle.squeeze(cls_score, axis=0)

        # NOTE(review): labels are shifted by +1 before one-hot encoding
        # (marked "for debug 0426" upstream), and one-hot column 0 is then
        # dropped. Confirm this matches the target assigner's label convention.
        shifted = paddle.to_tensor(labels + 1)
        one_hot = F.one_hot(shifted, self.cls_out_channels + 1)
        one_hot = one_hot[:, 1:]
        one_hot.stop_gradient = True

        normalizer = paddle.to_tensor(sample_count,
                                      dtype='float32',
                                      stop_gradient=True)

        cls_loss = F.sigmoid_focal_loss(cls_score,
                                        one_hot,
                                        normalizer=normalizer,
                                        reduction='none')

        # Repeat each anchor's weight across every class column.
        per_class_weights = np.repeat(
            label_weights.reshape(label_weights.shape[0], 1),
            self.cls_out_channels,
            axis=1)
        per_class_weights = paddle.to_tensor(per_class_weights,
                                             stop_gradient=True)
        cls_total = paddle.sum(cls_loss * per_class_weights)

        # step3: regression loss (weighted smooth-L1 on 5-value box deltas)
        targets = paddle.to_tensor(bbox_targets,
                                   dtype='float32',
                                   stop_gradient=True)
        targets = paddle.reshape(targets, [-1, 5])
        bbox_pred = paddle.reshape(paddle.squeeze(bbox_pred, axis=0), [-1, 5])
        reg_loss = self.smooth_l1_loss(bbox_pred, targets)
        reg_scale = paddle.to_tensor(self.reg_loss_weight,
                                     dtype='float32',
                                     stop_gradient=True)
        reg_loss = paddle.multiply(reg_loss, reg_scale)
        reg_mask = paddle.to_tensor(bbox_weights, stop_gradient=True)
        reg_loss = reg_loss * reg_mask
        reg_total = paddle.sum(reg_loss) / normalizer

        # NOTE(review): uses cls_loss_weight[0]; confirm this is the ODM entry.
        cls_scale = paddle.to_tensor(self.cls_loss_weight[0],
                                     dtype='float32',
                                     stop_gradient=True)
        odm_cls_loss = cls_total * cls_scale
        odm_reg_loss = paddle.add_n(reg_total)
        return odm_cls_loss, odm_reg_loss
コード例 #8
0
    def get_odm_loss(self, odm_target, s2anet_head_out):
        """Compute the ODM classification and regression losses, summed over
        all feature levels.

        Args:
            odm_target (tuple): (labels, label_weights, bbox_targets,
                bbox_weights, pos_inds, neg_inds). The per-anchor arrays hold
                all levels back to back, in iteration order of
                ``self.featmap_sizes``.
            s2anet_head_out (tuple): (fam_cls_branch_list,
                fam_reg_branch_list, odm_cls_branch_list,
                odm_reg_branch_list). Only the two ODM lists are used.

        Returns:
            tuple: (odm_cls_loss, odm_reg_loss). The classification total is
            scaled by a fixed factor of 2.0.
        """
        (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
         neg_inds) = odm_target
        _, _, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out

        level_sizes = [self.featmap_sizes[key] for key in self.featmap_sizes]

        # Loop-invariant scalars, converted to tensors once.
        sample_count = (len(pos_inds) + len(neg_inds)
                        if self.sampling else len(pos_inds))
        normalizer = paddle.to_tensor(max(1, sample_count),
                                      dtype='float32',
                                      stop_gradient=True)
        reg_scale = paddle.to_tensor(self.reg_loss_weight,
                                     dtype='float32',
                                     stop_gradient=True)

        cls_terms = []
        reg_terms = []
        start = 0
        for level, size in enumerate(level_sizes):
            stop = start + size[0] * size[1]
            lvl_labels = labels[start:stop].reshape(-1)
            lvl_label_weights = label_weights[start:stop].reshape(-1)
            lvl_bbox_targets = bbox_targets[start:stop, :]
            lvl_bbox_weights = bbox_weights[start:stop, :]
            start = stop

            # Classification: one-hot over cls_out_channels + 1 columns with
            # column 0 dropped, so label 0 gives an all-zero target row.
            cls_score = paddle.squeeze(odm_cls_branch_list[level], axis=0)
            one_hot = paddle.nn.functional.one_hot(
                paddle.to_tensor(lvl_labels), self.cls_out_channels + 1)
            one_hot = one_hot[:, 1:]
            one_hot.stop_gradient = True
            cls_loss = F.sigmoid_focal_loss(cls_score,
                                            one_hot,
                                            normalizer=normalizer,
                                            reduction='none')
            per_class_weights = np.repeat(
                lvl_label_weights.reshape(lvl_label_weights.shape[0], 1),
                self.cls_out_channels,
                axis=1)
            per_class_weights = paddle.to_tensor(per_class_weights)
            per_class_weights.stop_gradient = True
            cls_terms.append(paddle.sum(cls_loss * per_class_weights))

            # Regression: weighted smooth-L1 on 5-value box deltas.
            targets = paddle.reshape(
                paddle.to_tensor(lvl_bbox_targets, dtype='float32'), [-1, 5])
            targets.stop_gradient = True
            bbox_pred = paddle.reshape(
                paddle.squeeze(odm_reg_branch_list[level], axis=0), [-1, 5])
            reg_loss = paddle.multiply(
                self.smooth_l1_loss(bbox_pred, targets), reg_scale)
            reg_loss = reg_loss * paddle.to_tensor(lvl_bbox_weights,
                                                   stop_gradient=True)
            reg_terms.append(paddle.sum(reg_loss) / normalizer)

        odm_cls_loss = paddle.add_n(cls_terms) * 2.0
        odm_reg_loss = paddle.add_n(reg_terms)
        return odm_cls_loss, odm_reg_loss