def update(self, mini_batch):
        """Accumulate detection records for every image in a mini batch.

        Args:
            mini_batch (list(list)): one ``[gt_bboxes, pred_bboxes]`` pair per
                image in the batch. For a batch size of 2 it looks like
                ``[[gt_bboxes1, pred_bboxes1], [gt_bboxes2, pred_bboxes2]]``,
                where each pair holds the ground-truth boxes and the predicted
                boxes of a single image.

        For each image, predictions are ranked by descending ``score`` and
        only the first ``self._max_detections`` of them are kept. The kept
        predictions are grouped by label, and each group is passed, together
        with all of that image's ground-truth boxes, to
        ``self._label_records[label].add_records``. Afterwards every
        ground-truth box increments ``self._gt_bboxes_count`` for its label.
        """
        for gt_boxes, predictions in mini_batch:
            # Best-scoring predictions first; reverse=True keeps the sort
            # stable, so equal scores retain their input order.
            ranked = sorted(
                predictions, key=lambda box: box.score, reverse=True
            )
            overflow = len(ranked) > self._max_detections
            kept = ranked[: self._max_detections] if overflow else ranked

            grouped = group_bbox2d_per_label(kept)
            for label, label_boxes in grouped.items():
                # Ground truth of the whole image is handed to every label's
                # record; any per-label filtering happens inside add_records
                # (NOTE(review): presumably — confirm in the record class).
                self._label_records[label].add_records(gt_boxes, label_boxes)

            # Tally ground truth only after the records are added, matching
            # the original ordering in case gt_boxes is a one-shot iterable.
            for gt in gt_boxes:
                self._gt_bboxes_count[gt.label] += 1
def test_group_bbox2d_per_label():
    """Boxes are partitioned by label and each box lands in its own group."""
    count1, count2 = 10, 11
    bbox1 = BBox2D(label="car", x=1, y=1, w=2, h=3)
    bbox2 = BBox2D(label="pedestrian", x=7, y=6, w=3, h=4)
    bboxes = [bbox1] * count1 + [bbox2] * count2
    bboxes_per_label = group_bbox2d_per_label(bboxes)

    # Only labels present in the input should yield a group.
    assert set(bboxes_per_label) == {"car", "pedestrian"}
    assert len(bboxes_per_label["car"]) == count1
    assert len(bboxes_per_label["pedestrian"]) == count2
    # Size checks alone would pass even if boxes were put in the wrong
    # group, so also verify each group's contents carry its label.
    assert all(box.label == "car" for box in bboxes_per_label["car"])
    assert all(
        box.label == "pedestrian" for box in bboxes_per_label["pedestrian"]
    )