Example 1
def softmax_loss(pred, label, ignore_label=-1):
    # subtract the per-row max (no gradient through it) for numerical stability
    max_pred = F.zero_grad(pred.max(axis=1, keepdims=True))
    pred -= max_pred
    log_prob = pred - F.log(F.exp(pred).sum(axis=1, keepdims=True))
    # mask out entries whose label equals ignore_label (their loss becomes 0)
    mask = 1 - F.equal(label, ignore_label)
    vlabel = label * mask
    # negative log-probability of the target class, per sample (no reduction)
    loss = -(F.indexing_one_hot(log_prob, vlabel, 1) * mask)
    return loss
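
For reference, here is a plain-NumPy sketch of the same masked, numerically stable cross-entropy. It is an illustration only; the function name softmax_loss_ref and the sample tensors are made up and do not come from any of the snippets in this collection.

import numpy as np

def softmax_loss_ref(pred, label, ignore_label=-1):
    # numerically stable log-softmax: subtract the per-row max first
    shifted = pred - pred.max(axis=1, keepdims=True)
    log_prob = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    mask = label != ignore_label
    vlabel = np.where(mask, label, 0)       # map ignored labels to class 0
    picked = log_prob[np.arange(pred.shape[0]), vlabel]
    return -(picked * mask)                 # per-sample loss, zero where ignored

pred = np.array([[2.0, 0.5, 0.1],
                 [0.2, 1.5, 0.3]], dtype=np.float32)
label = np.array([0, -1])
print(softmax_loss_ref(pred, label))        # second entry is zero (ignored)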
Example 2
def softmax_cross_entropy(pred, label, axis=1, ignore_index=255):
    offset = F.zero_grad(pred.max(axis=axis, keepdims=True))
    pred = pred - offset
    log_prob = pred - F.log(F.exp(pred).sum(axis=axis, keepdims=True))

    mask = 1 - F.equal(label, ignore_index)
    vlabel = label * mask
    loss = -(F.indexing_one_hot(log_prob, vlabel, axis) *
             mask).sum() / F.maximum(mask.sum(), 1)
    return loss
Example 3
def softmax_loss(score, label, ignore_label=-1):
    max_score = F.zero_grad(score.max(axis=1, keepdims=True))
    score -= max_score
    log_prob = score - F.log(F.exp(score).sum(axis=1, keepdims=True))
    mask = (label != ignore_label)
    vlabel = label * mask
    loss = -(F.indexing_one_hot(log_prob, vlabel.astype("int32"), 1) *
             mask).sum()
    loss = loss / F.maximum(mask.sum(), 1)
    return loss
Example 4
def softmax_loss(scores: Tensor,
                 labels: Tensor,
                 ignore_label: int = -1) -> Tensor:
    max_scores = F.zero_grad(scores.max(axis=1, keepdims=True))
    scores -= max_scores
    log_prob = scores - F.log(F.exp(scores).sum(axis=1, keepdims=True))
    mask = labels != ignore_label
    vlabels = labels * mask
    loss = -(F.indexing_one_hot(log_prob, vlabels.astype("int32"), 1) *
             mask).sum()
    loss = loss / F.maximum(mask.sum(), 1)
    return loss
Example 5
    def forward(self, fpn_fms, rcnn_rois, im_info=None, gt_boxes=None):
        rcnn_rois, labels, bbox_targets = self.get_ground_truth(
            rcnn_rois, im_info, gt_boxes)

        fpn_fms = [fpn_fms[x] for x in self.in_features]
        pool_features = layers.roi_pool(
            fpn_fms,
            rcnn_rois,
            self.stride,
            self.pooling_size,
            self.pooling_method,
        )
        flatten_feature = F.flatten(pool_features, start_axis=1)
        roi_feature = F.relu(self.fc1(flatten_feature))
        roi_feature = F.relu(self.fc2(roi_feature))
        pred_cls = self.pred_cls(roi_feature)
        pred_delta = self.pred_delta(roi_feature)

        if self.training:
            # loss for classification
            loss_rcnn_cls = layers.softmax_loss(pred_cls, labels)
            # loss for regression
            pred_delta = pred_delta.reshape(-1, self.cfg.num_classes + 1, 4)

            vlabels = labels.reshape(-1, 1).broadcast((labels.shapeof(0), 4))
            pred_delta = F.indexing_one_hot(pred_delta, vlabels, axis=1)

            loss_rcnn_loc = layers.get_smooth_l1_loss(
                pred_delta,
                bbox_targets,
                labels,
                self.cfg.rcnn_smooth_l1_beta,
                norm_type="all",
            )
            loss_dict = {
                'loss_rcnn_cls': loss_rcnn_cls,
                'loss_rcnn_loc': loss_rcnn_loc
            }
            return loss_dict
        else:
            # start at index 1 to drop the background class scores
            pred_scores = F.softmax(pred_cls, axis=1)[:, 1:]
            pred_delta = pred_delta[:, 4:].reshape(-1, 4)
            target_shape = (rcnn_rois.shapeof(0), self.cfg.num_classes, 4)
            # rois (N, 4) -> (N, 1, 4) -> (N, 80, 4) -> (N * 80, 4)
            base_rois = F.add_axis(rcnn_rois[:, 1:5],
                                   1).broadcast(target_shape).reshape(-1, 4)
            pred_bbox = self.box_coder.decode(base_rois, pred_delta)
            return pred_bbox, pred_scores
Example 6
 def forward(self, fpn_fms, rcnn_rois, labels=None, bbox_targets=None):
     # fpn_fms arrive ordered by stride 64, 32, 16, 8, 4; drop the stride-64 level and reverse -> 4, 8, 16, 32
     fpn_fms = fpn_fms[1:][::-1]
     stride = [4, 8, 16, 32]
     pool_features, rcnn_rois, labels, bbox_targets = roi_pool(
         fpn_fms, rcnn_rois, stride, (7, 7), 'roi_align', labels,
         bbox_targets)
     flatten_feature = F.flatten(pool_features, start_axis=1)
     roi_feature = F.relu(self.fc1(flatten_feature))
     roi_feature = F.relu(self.fc2(roi_feature))
     pred_cls = self.pred_cls(roi_feature)
     pred_delta = self.pred_delta(roi_feature)
     if self.training:
         # loss for regression
         labels = labels.astype(np.int32).reshape(-1)
         # pick one set of box deltas per RoI: those of its class (non-positives map to class 0)
         pos_masks = labels > 0
         pred_delta = pred_delta.reshape(-1, config.num_classes, 4)
         indexing_label = (labels * pos_masks).reshape(-1, 1)
         indexing_label = indexing_label.broadcast((labels.shapeof()[0], 4))
         pred_delta = F.indexing_one_hot(pred_delta, indexing_label, 1)
         localization_loss = smooth_l1_loss(pred_delta, bbox_targets,
                                            config.rcnn_smooth_l1_beta)
         localization_loss = localization_loss * pos_masks
         # loss for classification
         valid_masks = labels >= 0
         objectness_loss = softmax_loss(pred_cls, labels)
         objectness_loss = objectness_loss * valid_masks
         normalizer = 1.0 / (valid_masks.sum())
         loss_rcnn_cls = objectness_loss.sum() * normalizer
         loss_rcnn_loc = localization_loss.sum() * normalizer
         loss_dict = {}
         loss_dict['loss_rcnn_cls'] = loss_rcnn_cls
         loss_dict['loss_rcnn_loc'] = loss_rcnn_loc
         return loss_dict
     else:
         pred_scores = F.softmax(pred_cls)[:, 1:].reshape(-1, 1)
         pred_delta = pred_delta[:, 4:].reshape(-1, 4)
         target_shape = (rcnn_rois.shapeof()[0], config.num_classes - 1, 4)
         base_rois = F.add_axis(rcnn_rois[:, 1:5],
                                1).broadcast(target_shape).reshape(-1, 4)
         pred_bbox = restore_bbox(base_rois, pred_delta, True)
         pred_bbox = F.concat([pred_bbox, pred_scores], axis=1)
         return pred_bbox
Example 7
 def forward(self, fpn_fms, rcnn_rois, labels=None, bbox_targets=None):
     # fpn_fms arrive ordered by stride 64, 32, 16, 8, 4; drop the stride-64 level and reverse -> 4, 8, 16, 32
     fpn_fms = fpn_fms[1:][::-1]
     stride = [4, 8, 16, 32]
     pool_features, rcnn_rois, labels, bbox_targets = roi_pool(
             fpn_fms, rcnn_rois, stride, (7, 7), 'roi_align',
             labels, bbox_targets)
     flatten_feature = F.flatten(pool_features, start_axis=1)
     roi_feature = F.relu(self.fc1(flatten_feature))
     roi_feature = F.relu(self.fc2(roi_feature))
     pred_emd_pred_cls_0 = self.emd_pred_cls_0(roi_feature)
     pred_emd_pred_delta_0 = self.emd_pred_delta_0(roi_feature)
     pred_emd_pred_cls_1 = self.emd_pred_cls_1(roi_feature)
     pred_emd_pred_delta_1 = self.emd_pred_delta_1(roi_feature)
     if self.training:
         loss0 = emd_loss(
                     pred_emd_pred_delta_0, pred_emd_pred_cls_0,
                     pred_emd_pred_delta_1, pred_emd_pred_cls_1,
                     bbox_targets, labels)
         loss1 = emd_loss(
                     pred_emd_pred_delta_1, pred_emd_pred_cls_1,
                     pred_emd_pred_delta_0, pred_emd_pred_cls_0,
                     bbox_targets, labels)
         loss = F.concat([loss0, loss1], axis=1)
         indices = F.argmin(loss, axis=1)
         loss_emd = F.indexing_one_hot(loss, indices, 1)
         loss_emd = loss_emd.sum()/loss_emd.shapeof()[0]
         loss_dict = {}
         loss_dict['loss_rcnn_emd'] = loss_emd
         return loss_dict
     else:
         pred_scores_0 = F.softmax(pred_emd_pred_cls_0)[:, 1:].reshape(-1, 1)
         pred_scores_1 = F.softmax(pred_emd_pred_cls_1)[:, 1:].reshape(-1, 1)
         pred_delta_0 = pred_emd_pred_delta_0[:, 4:].reshape(-1, 4)
         pred_delta_1 = pred_emd_pred_delta_1[:, 4:].reshape(-1, 4)
         target_shape = (rcnn_rois.shapeof()[0], config.num_classes - 1, 4)
         base_rois = F.add_axis(rcnn_rois[:, 1:5], 1).broadcast(target_shape).reshape(-1, 4)
         pred_bbox_0 = restore_bbox(base_rois, pred_delta_0, True)
         pred_bbox_1 = restore_bbox(base_rois, pred_delta_1, True)
         pred_bbox_0 = F.concat([pred_bbox_0, pred_scores_0], axis=1)
         pred_bbox_1 = F.concat([pred_bbox_1, pred_scores_1], axis=1)
         # interleaved output: [{head0, roi 1}, {head1, roi 1}, {head0, roi 2}, ...]
         pred_bbox = F.concat((pred_bbox_0, pred_bbox_1), axis=1).reshape(-1,5)
         return pred_bbox
Example 8
def emd_loss(p_b0, p_c0, p_b1, p_c1, targets, labels):
    pred_box = F.concat([p_b0, p_b1], axis=1).reshape(-1, p_b0.shapeof()[-1])
    pred_box = pred_box.reshape(-1, config.num_classes, 4)
    pred_score = F.concat([p_c0, p_c1], axis=1).reshape(-1, p_c0.shapeof()[-1])
    targets = targets.reshape(-1, 4)
    labels = labels.reshape(-1).astype(np.int32)
    fg_masks = F.greater(labels, 0)
    non_ignore_masks = F.greater_equal(labels, 0)
    # pick one set of box deltas per sample: those of its class (non-positives map to class 0)
    indexing_label = (labels * fg_masks).reshape(-1,1)
    indexing_label = indexing_label.broadcast((labels.shapeof()[0], 4))
    pred_box = F.indexing_one_hot(pred_box, indexing_label, 1)
    # loss for regression
    loss_box_reg = smooth_l1_loss(
        pred_box,
        targets,
        config.rcnn_smooth_l1_beta)
    # loss for classification
    loss_cls = softmax_loss(pred_score, labels)
    loss = loss_cls*non_ignore_masks + loss_box_reg * fg_masks
    loss = loss.reshape(-1, 2).sum(axis=1)
    return loss.reshape(-1, 1)
Example 9
    def get_ground_truth(self, anchors_list, batched_gt_boxes,
                         batched_num_gts):
        labels_list = []
        offsets_list = []
        ctrness_list = []

        all_level_anchors = F.concat(anchors_list, axis=0)
        for bid in range(batched_gt_boxes.shape[0]):
            gt_boxes = batched_gt_boxes[bid, :batched_num_gts[bid]]

            offsets = self.point_coder.encode(
                all_level_anchors, F.expand_dims(gt_boxes[:, :4], axis=1))

            object_sizes_of_interest = F.concat([
                F.broadcast_to(
                    F.expand_dims(mge.tensor(size, dtype=np.float32), axis=0),
                    (anchors_i.shape[0], 2)) for anchors_i, size in zip(
                        anchors_list, self.cfg.object_sizes_of_interest)
            ],
                                                axis=0)
            max_offsets = F.max(offsets, axis=2)
            is_cared_in_the_level = (
                (max_offsets >= F.expand_dims(object_sizes_of_interest[:, 0],
                                              axis=0))
                & (max_offsets <= F.expand_dims(object_sizes_of_interest[:, 1],
                                                axis=0)))

            if self.cfg.center_sampling_radius > 0:
                gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:4]) / 2
                is_in_boxes = []
                for stride, anchors_i in zip(self.cfg.stride, anchors_list):
                    radius = stride * self.cfg.center_sampling_radius
                    center_boxes = F.concat([
                        F.maximum(gt_centers - radius, gt_boxes[:, :2]),
                        F.minimum(gt_centers + radius, gt_boxes[:, 2:4]),
                    ],
                                            axis=1)
                    center_offsets = self.point_coder.encode(
                        anchors_i, F.expand_dims(center_boxes, axis=1))
                    is_in_boxes.append(F.min(center_offsets, axis=2) > 0)
                is_in_boxes = F.concat(is_in_boxes, axis=1)
            else:
                is_in_boxes = F.min(offsets, axis=2) > 0

            gt_area = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] -
                                                           gt_boxes[:, 1])
            # FIXME: use repeat instead of broadcast_to
            areas = F.broadcast_to(F.expand_dims(gt_area, axis=1),
                                   offsets.shape[:2])
            areas[~is_cared_in_the_level] = float("inf")
            areas[~is_in_boxes] = float("inf")

            match_indices = F.argmin(areas, axis=0)
            gt_boxes_matched = gt_boxes[match_indices]
            anchor_min_area = F.indexing_one_hot(areas, match_indices, axis=0)

            labels = gt_boxes_matched[:, 4].astype(np.int32)
            labels[anchor_min_area == float("inf")] = 0
            offsets = self.point_coder.encode(all_level_anchors,
                                              gt_boxes_matched[:, :4])

            left_right = offsets[:, [0, 2]]
            top_bottom = offsets[:, [1, 3]]
            ctrness = F.sqrt(
                F.maximum(
                    F.min(left_right, axis=1) / F.max(left_right, axis=1), 0) *
                F.maximum(
                    F.min(top_bottom, axis=1) / F.max(top_bottom, axis=1), 0))

            labels_list.append(labels)
            offsets_list.append(offsets)
            ctrness_list.append(ctrness)

        return (
            F.stack(labels_list, axis=0).detach(),
            F.stack(offsets_list, axis=0).detach(),
            F.stack(ctrness_list, axis=0).detach(),
        )
Example 10
 def forward(self, fpn_fms, rcnn_rois, labels=None, bbox_targets=None):
     # fpn_fms arrive ordered by stride 64, 32, 16, 8, 4; drop the stride-64 level and reverse -> 4, 8, 16, 32
     fpn_fms = fpn_fms[1:][::-1]
     stride = [4, 8, 16, 32]
     pool_features, rcnn_rois, labels, bbox_targets = roi_pool(
             fpn_fms, rcnn_rois, stride, (7, 7), 'roi_align',
             labels, bbox_targets)
     flatten_feature = F.flatten(pool_features, start_axis=1)
     roi_feature = F.relu(self.fc1(flatten_feature))
     roi_feature = F.relu(self.fc2(roi_feature))
     pred_emd_pred_cls_0 = self.emd_pred_cls_0(roi_feature)
     pred_emd_pred_delta_0 = self.emd_pred_delta_0(roi_feature)
     pred_emd_pred_cls_1 = self.emd_pred_cls_1(roi_feature)
     pred_emd_pred_delta_1 = self.emd_pred_delta_1(roi_feature)
     pred_emd_scores_0 = F.softmax(pred_emd_pred_cls_0)
     pred_emd_scores_1 = F.softmax(pred_emd_pred_cls_1)
     # make refine feature
     box_0 = F.concat((pred_emd_pred_delta_0,
         pred_emd_scores_0[:, 1][:, None]), axis=1)[:, None, :]
     box_1 = F.concat((pred_emd_pred_delta_1,
         pred_emd_scores_1[:, 1][:, None]), axis=1)[:, None, :]
     boxes_feature_0 = box_0.broadcast(
             box_0.shapeof()[0], 4, box_0.shapeof()[-1]).reshape(box_0.shapeof()[0], -1)
     boxes_feature_1 = box_1.broadcast(
             box_1.shapeof()[0], 4, box_1.shapeof()[-1]).reshape(box_1.shapeof()[0], -1)
     boxes_feature_0 = F.concat((roi_feature, boxes_feature_0), axis=1)
     boxes_feature_1 = F.concat((roi_feature, boxes_feature_1), axis=1)
     refine_feature_0 = F.relu(self.fc3(boxes_feature_0))
     refine_feature_1 = F.relu(self.fc3(boxes_feature_1))
     # refine
     pred_ref_pred_cls_0 = self.ref_pred_cls_0(refine_feature_0)
     pred_ref_pred_delta_0 = self.ref_pred_delta_0(refine_feature_0)
     pred_ref_pred_cls_1 = self.ref_pred_cls_1(refine_feature_1)
     pred_ref_pred_delta_1 = self.ref_pred_delta_1(refine_feature_1)
     if self.training:
         loss0 = emd_loss(
                     pred_emd_pred_delta_0, pred_emd_pred_cls_0,
                     pred_emd_pred_delta_1, pred_emd_pred_cls_1,
                     bbox_targets, labels)
         loss1 = emd_loss(
                     pred_emd_pred_delta_1, pred_emd_pred_cls_1,
                     pred_emd_pred_delta_0, pred_emd_pred_cls_0,
                     bbox_targets, labels)
         loss2 = emd_loss(
                     pred_ref_pred_delta_0, pred_ref_pred_cls_0,
                     pred_ref_pred_delta_1, pred_ref_pred_cls_1,
                     bbox_targets, labels)
         loss3 = emd_loss(
                     pred_ref_pred_delta_1, pred_ref_pred_cls_1,
                     pred_ref_pred_delta_0, pred_ref_pred_cls_0,
                     bbox_targets, labels)
         loss_rcnn = F.concat([loss0, loss1], axis=1)
         loss_ref = F.concat([loss2, loss3], axis=1)
         indices_rcnn = F.argmin(loss_rcnn, axis=1)
         indices_ref = F.argmin(loss_ref, axis=1)
         loss_rcnn = F.indexing_one_hot(loss_rcnn, indices_rcnn, 1)
         loss_ref = F.indexing_one_hot(loss_ref, indices_ref, 1)
         loss_rcnn = loss_rcnn.sum()/loss_rcnn.shapeof()[0]
         loss_ref = loss_ref.sum()/loss_ref.shapeof()[0]
         loss_dict = {}
         loss_dict['loss_rcnn_emd'] = loss_rcnn
         loss_dict['loss_ref_emd'] = loss_ref
         return loss_dict
     else:
         pred_ref_scores_0 = F.softmax(pred_ref_pred_cls_0)
         pred_ref_scores_1 = F.softmax(pred_ref_pred_cls_1)
         pred_bbox_0 = restore_bbox(rcnn_rois[:, 1:5], pred_ref_pred_delta_0, True)
         pred_bbox_1 = restore_bbox(rcnn_rois[:, 1:5], pred_ref_pred_delta_1, True)
         pred_bbox_0 = F.concat([pred_bbox_0, pred_ref_scores_0[:, 1].reshape(-1,1)], axis=1)
         pred_bbox_1 = F.concat([pred_bbox_1, pred_ref_scores_1[:, 1].reshape(-1,1)], axis=1)
         pred_bbox = F.concat((pred_bbox_0, pred_bbox_1), axis=1).reshape(-1,5)
         return pred_bbox
Example 11
 def fwd(src, index):
     return F.indexing_one_hot(src, index)
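
Example 11 above is just the bare call. As a quick illustration of what F.indexing_one_hot returns, here is a minimal sketch, assuming a recent MegEngine install; the sample tensors are made up. For a 2-D src and a 1-D index, the call picks src[i, index[i]] for every row i.

import numpy as np
import megengine as mge
import megengine.functional as F

src = mge.tensor(np.array([[0.1, 0.7, 0.2],
                           [0.5, 0.3, 0.2]], dtype=np.float32))
index = mge.tensor(np.array([1, 0], dtype=np.int32))

# pick one element per row along axis 1: src[i, index[i]]
picked = F.indexing_one_hot(src, index, axis=1)
print(picked.numpy())                               # [0.7 0.5]

# the same gather in NumPy
print(src.numpy()[np.arange(2), index.numpy()])     # [0.7 0.5]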
Example 12
    def get_ground_truth(self, anchors_list, batched_gt_boxes,
                         batched_num_gts):
        labels_list = []
        offsets_list = []
        ctrness_list = []

        all_level_anchors = F.concat(anchors_list, axis=0)
        for bid in range(batched_gt_boxes.shape[0]):
            gt_boxes = batched_gt_boxes[bid, :batched_num_gts[bid]]

            ious = []
            candidate_idxs = []
            base = 0
            for stride, anchors_i in zip(self.cfg.stride, anchors_list):
                ious.append(
                    layers.get_iou(
                        gt_boxes[:, :4],
                        F.concat([
                            anchors_i - stride * self.cfg.anchor_scale / 2,
                            anchors_i + stride * self.cfg.anchor_scale / 2,
                        ],
                                 axis=1)))
                gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:4]) / 2
                distances = F.sqrt(
                    F.sum((F.expand_dims(gt_centers, axis=1) - anchors_i)**2,
                          axis=2))
                _, topk_idxs = F.topk(distances, self.cfg.anchor_topk)
                candidate_idxs.append(base + topk_idxs)
                base += anchors_i.shape[0]
            ious = F.concat(ious, axis=1)
            candidate_idxs = F.concat(candidate_idxs, axis=1)

            candidate_ious = F.gather(ious, 1, candidate_idxs)
            ious_thr = (F.mean(candidate_ious, axis=1, keepdims=True) +
                        F.std(candidate_ious, axis=1, keepdims=True))
            is_foreground = F.scatter(
                F.zeros(ious.shape), 1, candidate_idxs,
                F.ones(candidate_idxs.shape)).astype(bool) & (ious >= ious_thr)

            is_in_boxes = F.min(self.point_coder.encode(
                all_level_anchors, F.expand_dims(gt_boxes[:, :4], axis=1)),
                                axis=2) > 0

            ious[~is_foreground] = -1
            ious[~is_in_boxes] = -1

            match_indices = F.argmax(ious, axis=0)
            gt_boxes_matched = gt_boxes[match_indices]
            anchor_max_iou = F.indexing_one_hot(ious, match_indices, axis=0)

            labels = gt_boxes_matched[:, 4].astype(np.int32)
            labels[anchor_max_iou == -1] = 0
            offsets = self.point_coder.encode(all_level_anchors,
                                              gt_boxes_matched[:, :4])

            left_right = offsets[:, [0, 2]]
            top_bottom = offsets[:, [1, 3]]
            ctrness = F.sqrt(
                F.clip(F.min(left_right, axis=1) / F.max(left_right, axis=1),
                       lower=0) *
                F.clip(F.min(top_bottom, axis=1) / F.max(top_bottom, axis=1),
                       lower=0))

            labels_list.append(labels)
            offsets_list.append(offsets)
            ctrness_list.append(ctrness)

        return (
            F.stack(labels_list, axis=0).detach(),
            F.stack(offsets_list, axis=0).detach(),
            F.stack(ctrness_list, axis=0).detach(),
        )
Example 13
    def get_losses(self, anchors, pred_logits, pred_offsets, gt_boxes,
                   im_info):
        # pylint: disable=too-many-statements
        def positive_bag_loss(logits, axis=1):
            weight = 1.0 / (1.0 - logits)
            weight /= weight.sum(axis=axis, keepdims=True)
            bag_prob = (weight * logits).sum(axis=axis)
            return -layers.safelog(bag_prob)

        def negative_bag_loss(logits, gamma):
            return (logits**gamma) * (-layers.safelog(1.0 - logits))

        pred_scores = F.sigmoid(pred_logits)
        box_prob_list = []
        positive_losses = []
        clamp_eps = 1e-7
        bucket_size = self.cfg.bucket_size

        for bid in range(im_info.shape[0]):
            boxes_info = gt_boxes[bid, :im_info[bid, 4].astype("int32")]
            # id 0 is reserved for the background class, so subtract 1 to get foreground class indices
            labels = boxes_info[:, 4].astype("int32") - 1

            pred_box = self.box_coder.decode(anchors,
                                             pred_offsets[bid]).detach()
            overlaps = layers.get_iou(boxes_info[:, :4], pred_box).detach()
            thresh1 = self.cfg.box_iou_threshold
            thresh2 = F.clip(overlaps.max(axis=1, keepdims=True),
                             lower=thresh1 + clamp_eps,
                             upper=1.0)
            gt_pred_prob = F.clip((overlaps - thresh1) / (thresh2 - thresh1),
                                  lower=0,
                                  upper=1.0)

            image_boxes_prob = F.zeros(pred_logits.shape[1:]).detach()
            # guarantee that nonzero_idx is not empty
            if gt_pred_prob.max() > clamp_eps:
                _, nonzero_idx = F.cond_take(gt_pred_prob != 0, gt_pred_prob)
                # nonzero_idx is flat (1-D), so use num_anchors to recover (gt, anchor) index pairs
                num_anchors = gt_pred_prob.shape[1]
                anchors_idx = nonzero_idx % num_anchors
                gt_idx = nonzero_idx // num_anchors
                image_boxes_prob[anchors_idx,
                                 labels[gt_idx]] = gt_pred_prob[gt_idx,
                                                                anchors_idx]

            box_prob_list.append(image_boxes_prob)

            # construct bags for objects
            match_quality_matrix = layers.get_iou(boxes_info[:, :4],
                                                  anchors).detach()
            num_gt = match_quality_matrix.shape[0]
            _, matched_idx = F.topk(
                match_quality_matrix,
                k=bucket_size,
                descending=True,
                no_sort=True,
            )

            matched_idx = matched_idx.detach()
            matched_idx_flatten = matched_idx.reshape(-1)
            gather_idx = labels.reshape(-1, 1)
            gather_idx = F.broadcast_to(gather_idx, (num_gt, bucket_size))

            gather_src = pred_scores[bid, matched_idx_flatten]
            gather_src = gather_src.reshape(num_gt, bucket_size, -1)
            matched_score = F.indexing_one_hot(gather_src, gather_idx, axis=2)

            topk_anchors = anchors[matched_idx_flatten]
            boxes_broad_cast = F.broadcast_to(
                F.expand_dims(boxes_info[:, :4], axis=1),
                (num_gt, bucket_size, 4)).reshape(-1, 4)

            matched_offsets = self.box_coder.encode(topk_anchors,
                                                    boxes_broad_cast)

            reg_loss = layers.smooth_l1_loss(
                pred_offsets[bid, matched_idx_flatten],
                matched_offsets,
                beta=self.cfg.smooth_l1_beta).sum(
                    axis=-1) * self.cfg.reg_loss_weight
            matched_reg_scores = F.exp(-reg_loss)

            positive_losses.append(
                positive_bag_loss(matched_score *
                                  matched_reg_scores.reshape(-1, bucket_size),
                                  axis=1))

        num_foreground = im_info[:, 4].sum()
        pos_loss = F.concat(positive_losses).sum() / F.maximum(
            1.0, num_foreground)
        box_probs = F.stack(box_prob_list, axis=0)

        neg_loss = negative_bag_loss(
            pred_scores *
            (1 - box_probs), self.cfg.focal_loss_gamma).sum() / F.maximum(
                1.0, num_foreground * bucket_size)

        alpha = self.cfg.focal_loss_alpha
        pos_loss = pos_loss * alpha
        neg_loss = neg_loss * (1 - alpha)
        loss_dict = {
            "total_loss": pos_loss + neg_loss,
            "pos_loss": pos_loss,
            "neg_loss": neg_loss,
        }
        return loss_dict