Exemplo n.º 1
0
 def test_additional_and_batch(self):
     # IoU of the t12 fixture boxes (corner format) must agree with the
     # expected values element-wise, to within self.epsilon.
     computed = intersection_over_union(self.t12_bboxes1,
                                        self.t12_bboxes2,
                                        box_format="corners")
     # squeeze(1) drops the size-1 second dimension before comparing.
     deviation = torch.abs(self.t12_correct_ious - computed.squeeze(1))
     self.assertTrue(torch.all(deviation < self.epsilon))
Exemplo n.º 2
0
def non_max_suppression(bboxes,
                        iou_threshold,
                        threshold,
                        box_format='corners'):
    """
    Greedy per-class non-max suppression.

    :param bboxes: list of boxes, each [class, prob, x1, y1, x2, y2]
    :param iou_threshold: a lower-scored box of the same class is kept only
        if its IoU with the selected box is below this value
    :param threshold: boxes whose prob is not above this value are dropped
        before suppression starts
    :param box_format: forwarded to intersection_over_union
    :return: the surviving boxes, ordered by descending prob
    """
    assert type(bboxes) is list

    # Work on a filtered copy so the caller's list is never modified.
    candidates = [entry for entry in bboxes if entry[1] > threshold]
    candidates.sort(key=lambda entry: entry[1], reverse=True)
    selected = []

    while candidates:
        best, rest = candidates[0], candidates[1:]

        survivors = []
        for other in rest:
            # Boxes of another class are never suppressed by `best`.
            if other[0] != best[0]:
                survivors.append(other)
                continue
            overlap = intersection_over_union(torch.tensor(best[2:]),
                                              torch.tensor(other[2:]),
                                              box_format=box_format)
            if overlap < iou_threshold:
                survivors.append(other)

        candidates = survivors
        selected.append(best)

    return selected
Exemplo n.º 3
0
def nms(bboxes, iou_threshold, threshold, box_format="corners"):
    """
    Run greedy, class-aware Non Max Suppression over a list of boxes.

    Parameters:
        bboxes (list): boxes, each given as
            [class_pred, prob_score, x1, y1, x2, y2]
        iou_threshold (float): a same-class box survives only while its IoU
            with the currently selected box stays below this value
        threshold (float): boxes with prob_score not above this value are
            discarded up front (independent of IoU)
        box_format (str): "midpoint" or "corners"; passed through to
            intersection_over_union
    Returns:
        list: the selected boxes in descending prob_score order
    """

    assert type(bboxes) is list

    queue = sorted(
        (item for item in bboxes if item[1] > threshold),
        key=lambda item: item[1],
        reverse=True,
    )
    kept = []

    while queue:
        head, *tail = queue

        remaining = []
        for cand in tail:
            same_class = cand[0] == head[0]
            # Only same-class boxes that overlap `head` too much are dropped.
            if same_class and not intersection_over_union(
                    torch.tensor(head[2:]),
                    torch.tensor(cand[2:]),
                    box_format=box_format,
            ) < iou_threshold:
                continue
            remaining.append(cand)

        queue = remaining
        kept.append(head)

    return kept
def non_max_suppression(bboxes,
                        iou_threshold,
                        threshold,
                        box_format="corners"):
    """
    Calculate Non Max Suppression.

    :param bboxes: list of boxes; index 0 is used as the class, index 1 as
        the score and the remaining entries as the box coordinates
    :param iou_threshold: same-class boxes are kept only if their IoU with
        the selected box is below this value
    :param threshold: minimum score (exclusive) for a box to be considered
    :param box_format: forwarded to intersection_over_union
    :return: list of selected boxes, highest score first
    """
    assert type(bboxes) is list

    def survives(winner, rival):
        # A rival outlives the winner if it belongs to another class or
        # overlaps the winner by less than iou_threshold.
        if rival[0] != winner[0]:
            return True
        overlap = intersection_over_union(
            torch.tensor(winner[2:]),
            torch.tensor(rival[2:]),
            box_format=box_format,
        )
        return bool(overlap < iou_threshold)

    ranked = sorted(
        [entry for entry in bboxes if entry[1] > threshold],
        key=lambda entry: entry[1],
        reverse=True,
    )
    winners = []

    while ranked:
        winner = ranked.pop(0)
        ranked = [rival for rival in ranked if survives(winner, rival)]
        winners.append(winner)

    return winners
Exemplo n.º 5
0
def non_max_supression(predicted_boxes, iou_threshold, prob_threshold):
    """Run greedy, class-aware Non Max Supression on the given boxes.

    Args:
        predicted_boxes (list): predicted bounding boxes, each described as [class_pred, prob_score, x1, y1, x2, y2]
        iou_threshold (float): a same-class box is kept only if its IoU with the selected box is below this value
        prob_threshold (float): boxes whose prob_score is not above this value are discarded first

    Returns:
        list: the selected bounding boxes, highest prob_score first
    """

    assert type(predicted_boxes) is list

    # Discard low-probability predictions, then rank the rest by score.
    pending = sorted(
        (candidate for candidate in predicted_boxes
         if candidate[1] > prob_threshold),
        key=lambda candidate: candidate[1],
        reverse=True,
    )

    selected = []
    while pending:
        top_box = pending[0]

        remaining = []
        for candidate in pending[1:]:
            if candidate[0] != top_box[0]:
                # Another class: never suppressed by top_box.
                remaining.append(candidate)
            elif intersection_over_union(
                tf.constant(top_box[2:], dtype=tf.float32),
                tf.constant(candidate[2:], dtype=tf.float32)
            ) < iou_threshold:
                # Same class but low overlap with top_box: keep it.
                remaining.append(candidate)

        pending = remaining
        selected.append(top_box)

    return selected
Exemplo n.º 6
0
def mean_average_precision(pred_boxes,
                           true_boxes,
                           iou_threshold=0.5,
                           box_format="midpoint",
                           num_classes=20):
    """
    Compute mean average precision (mAP) at a single IoU threshold.
    Parameters:
        pred_boxes (list): predicted boxes, each given as
        [train_idx, class_prediction, prob_score, x1, y1, x2, y2]
        true_boxes (list): ground-truth boxes in the same layout
        iou_threshold (float): a prediction must exceed this IoU with a
        ground-truth box to count as a true positive
        box_format (str): "midpoint" or "corners"; forwarded to
        intersection_over_union
        num_classes (int): classes 0 .. num_classes - 1 are evaluated
    Returns:
        The mean, over all evaluated classes, of the area under each
        class's precision-recall curve.
    """

    # keeps the divisions below away from 0 / 0
    eps = 1e-6
    per_class_ap = []

    for cls in range(num_classes):
        # Restrict predictions and targets to the current class.
        class_preds = [pred for pred in pred_boxes if pred[1] == cls]
        class_gts = [gt for gt in true_boxes if gt[1] == cls]

        # Count ground truths per image, e.g. {0: 3, 1: 5}, then swap each
        # count for a zero vector of that length. Entry i flips to 1 once
        # the i-th ground truth of that image has been matched.
        matched = Counter(gt[0] for gt in class_gts)
        for img_id in matched:
            matched[img_id] = torch.zeros(matched[img_id])

        # Highest-confidence predictions get first pick of ground truths.
        class_preds.sort(key=lambda pred: pred[2], reverse=True)
        n_preds = len(class_preds)
        true_pos = torch.zeros(n_preds)
        false_pos = torch.zeros(n_preds)

        # Decide for every prediction whether it is a TP or an FP.
        for rank, pred in enumerate(class_preds):
            img_id = pred[0]
            # Ground truths that live in the same image as the prediction.
            same_image_gts = [gt for gt in class_gts if gt[0] == img_id]

            best_iou = 0
            for gt_pos, gt in enumerate(same_image_gts):
                overlap = intersection_over_union(torch.tensor(pred[3:]),
                                                  torch.tensor(gt[3:]),
                                                  box_format=box_format)
                if overlap > best_iou:
                    best_iou, best_gt_idx = overlap, gt_pos

            is_new_match = False
            if best_iou > iou_threshold:
                flags = matched[img_id]
                if flags[best_gt_idx] == 0:
                    # First prediction to claim this ground truth; later
                    # predictions hitting it count as false positives.
                    flags[best_gt_idx] = 1
                    is_new_match = True

            if is_new_match:
                true_pos[rank] = 1
            else:
                false_pos[rank] = 1

        # Running totals, e.g. [1, 1, 0, 1, 0] -> [1, 2, 2, 3, 3]
        tp_running = true_pos.cumsum(dim=0)
        fp_running = false_pos.cumsum(dim=0)
        recall = tp_running / (len(class_gts) + eps)
        precision = tp_running / (tp_running + fp_running + eps)

        # Anchor the precision-recall curve at (recall=0, precision=1).
        precision = torch.cat((torch.tensor([1]), precision))
        recall = torch.cat((torch.tensor([0]), recall))

        # Area under the precision-recall curve.
        per_class_ap.append(torch.trapz(precision, recall))

    return sum(per_class_ap) / len(per_class_ap)
Exemplo n.º 7
0
 def test_both_inside_cell_shares_entire_area(self):
     # t5 fixture, midpoint format: IoU must be within epsilon of expected.
     actual = intersection_over_union(
         self.t5_box1, self.t5_box2, box_format="midpoint")
     error = torch.abs(actual - self.t5_correct_iou)
     self.assertTrue(error < self.epsilon)
Exemplo n.º 8
0
 def test_partially_outside_cell_shares_area(self):
     # t2 fixture, midpoint format: IoU must be within epsilon of expected.
     actual = intersection_over_union(
         self.t2_box1, self.t2_box2, box_format="midpoint")
     error = torch.abs(actual - self.t2_correct_iou)
     self.assertTrue(error < self.epsilon)
Exemplo n.º 9
0
 def test_box_format_x1_y1_x2_y2(self):
     # t6 fixture, corners format: IoU must be within epsilon of expected.
     actual = intersection_over_union(
         self.t6_box1, self.t6_box2, box_format="corners")
     error = torch.abs(actual - self.t6_correct_iou)
     self.assertTrue(error < self.epsilon)
def mean_average_precision(pred_boxes,
                           true_boxes,
                           iou_threshold=0.5,
                           box_format="corners",
                           num_classes=20):
    """
    Compute mean average precision (mAP) at a single IoU threshold.

    :param pred_boxes: [[train_idx, class_pred, prob_score, x1, y1, x2, y2], ...]
    :param true_boxes: ground-truth boxes in the same layout as pred_boxes
    :param iou_threshold: a prediction must exceed this IoU with a
        ground-truth box to count as a true positive
    :param box_format: forwarded to intersection_over_union
    :param num_classes: classes 0 .. num_classes - 1 are evaluated
    :return: mean over the evaluated classes of the area under each
        class's precision-recall curve
    """
    average_precisions = []
    # keeps the divisions below away from 0 / 0
    epsilon = 1e-6

    for c in range(num_classes):
        # Restrict predictions and targets to the current class.
        detections = [box for box in pred_boxes if box[1] == c]
        ground_truths = [box for box in true_boxes if box[1] == c]

        # Per image, a zero vector with one slot per ground truth, e.g.
        # {0: torch.tensor([0, 0, 0]), 1: torch.tensor([0, 0, 0, 0, 0])}.
        # A slot flips to 1 once that ground truth has been matched.
        amount_bboxes = Counter([gt[0] for gt in ground_truths])
        for key, val in amount_bboxes.items():
            amount_bboxes[key] = torch.zeros(val)

        # Rank the whole detection list by confidence (index 2) so the
        # most confident predictions get first pick of ground truths.
        detections.sort(key=lambda x: x[2], reverse=True)
        TP = torch.zeros((len(detections)))
        FP = torch.zeros((len(detections)))
        total_true_bboxes = len(ground_truths)

        for detection_idx, detection in enumerate(detections):
            # Ground truths located in the same image as this detection.
            ground_truth_img = [
                bbox for bbox in ground_truths if bbox[0] == detection[0]
            ]

            best_iou = 0
            # Reset per detection so an index found for an earlier
            # detection can never be reused here.
            best_gt_idx = None

            for idx, gt in enumerate(ground_truth_img):
                # Skip the first three entries:
                # train_idx, class_pred, and prob_score.
                iou = intersection_over_union(torch.tensor(detection[3:]),
                                              torch.tensor(gt[3:]),
                                              box_format=box_format)

                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            if best_gt_idx is not None and best_iou > iou_threshold:
                if amount_bboxes[detection[0]][best_gt_idx] == 0:
                    # First detection to claim this ground truth.
                    TP[detection_idx] = 1
                    amount_bboxes[detection[0]][best_gt_idx] = 1
                else:
                    # Ground truth already claimed by a stronger detection.
                    FP[detection_idx] = 1
            else:
                FP[detection_idx] = 1

        # Running totals, e.g. [1, 1, 0, 1, 0] -> [1, 2, 2, 3, 3]
        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = torch.divide(TP_cumsum, (TP_cumsum + FP_cumsum + epsilon))
        # Anchor the curve at (recall=0, precision=1) before integrating.
        precisions = torch.cat((torch.tensor([1]), precisions))
        recalls = torch.cat((torch.tensor([0]), recalls))
        average_precisions.append(torch.trapz(precisions, recalls))

    return sum(average_precisions) / len(average_precisions)
Exemplo n.º 11
0
def _evaluate_sample(hf, sample, name, thresholds, output_folder,
                     small_objects_size):
    """Score one prediction against its ground truth and save its images.

    Returns a list holding one IoU value per entry of ``thresholds``.
    """
    gt = np.asarray(hf[name + '/fg'], dtype=np.uint8)
    raw = np.asarray(hf[name + '/raw'])
    print('raw shape: ', raw.shape, raw.dtype)
    # Clip intensities to [0, 1000] and rescale to 8-bit for PNG output.
    raw = (np.clip(raw, 0, 1000) / 1000. * 255).astype(np.uint8)
    # Max over axis 1, then move axis 0 to the end.
    # NOTE(review): presumably raw is (channel, z, y, x), making this a
    # channels-last max-intensity projection -- confirm against the data.
    raw_mip = np.moveaxis(np.max(raw, axis=1), 0, 2)

    zf = zarr.open(sample)
    pred = np.asarray(zf['volumes/pred_mask'])

    mip = np.max(gt, axis=0)
    mip[mip > 0] = 255
    io.imsave(os.path.join(output_folder, name + '_gt.png'), mip)
    io.imsave(os.path.join(output_folder, name + '_raw.png'), raw_mip)

    results = []

    for thresh in thresholds:

        print('sample: ', name)

        # Binarise, then drop connected components below the size limit.
        mask = (pred >= thresh).astype(np.uint8)
        mask = morphology.remove_small_objects(
            measure.label(mask, background=0, connectivity=1),
            min_size=small_objects_size,
            connectivity=1)
        mask = (mask > 0).astype(np.uint8)

        result = iou.intersection_over_union(np.expand_dims(mask, -1), gt)
        print('iou for threshold: ', thresh, ', result: ', result)
        # NOTE(review): element 1 of the returned value is what gets
        # reported; presumably the foreground score -- confirm in `iou`.
        results.append(result[1])

        mip = np.max(mask, axis=0)
        idx = mip > 0
        mip[idx] = 255
        io.imsave(
            os.path.join(output_folder,
                         name + '_mask_' + str(thresh) + '.png'), mip)

        # Blend a fixed colour into the raw projection where the mask is
        # set. raw_mip is modified in place, so overlays accumulate when
        # more than one threshold is configured.
        raw_mip[idx] = (0.7 * np.array([139, 0, 128]) +
                        0.3 * raw_mip[idx]).astype(np.uint8)
        io.imsave(
            os.path.join(output_folder,
                         name + '_overlay_' + str(thresh) + '.png'),
            raw_mip)

    return results


def evaluate(pred_folder,
             data_folder,
             dataset,
             output_folder,
             small_objects_size=20):
    """Evaluate predicted masks against ground truth and write IoU scores.

    Reads ground truth from ``<data_folder>/<dataset>.hdf`` and predictions
    from every ``*.zarr`` entry in ``pred_folder``. For each prediction
    whose name is a key of the HDF5 file, the prediction is thresholded,
    small connected components are removed, and the IoU with the ground
    truth is computed. Results go to ``<output_folder>/iou.csv`` (one row
    per sample plus a final 'Average' row) together with PNG projections
    of the ground truth, raw data, mask and overlay.

    Args:
        pred_folder: directory containing the ``*.zarr`` predictions.
        data_folder: directory containing the HDF5 dataset file.
        dataset: HDF5 file name without the ``.hdf`` extension.
        output_folder: directory that receives ``iou.csv`` and the PNGs.
        small_objects_size: ``min_size`` passed to
            ``morphology.remove_small_objects``.
    """
    thresholds = [0.9]

    # Context managers make sure the HDF5 handle is released and the CSV is
    # flushed and closed even when a sample fails part-way through.
    with h5py.File(os.path.join(data_folder, dataset + '.hdf'), 'r') as hf:
        ds_samples = list(hf.keys())
        pred_samples = glob.glob(os.path.join(pred_folder, "*.zarr"))
        print(pred_samples)
        print('output folder: ', output_folder)

        with open(os.path.join(output_folder, 'iou.csv'), 'w') as cf:
            writer = csv.writer(cf,
                                delimiter=' ',
                                quotechar='|',
                                quoting=csv.QUOTE_MINIMAL)
            writer.writerow(['sample'] + thresholds)
            iou_results = {}

            for sample in pred_samples:
                # Strip the 5-character ".zarr" suffix to get the key.
                name = os.path.basename(sample)[:-5]

                if name not in ds_samples:
                    print(name, " not in gt dataset! Skipping...")
                    continue

                iou_results[name] = _evaluate_sample(hf, sample, name,
                                                     thresholds,
                                                     output_folder,
                                                     small_objects_size)
                print(len(iou_results))

            avg = None
            for k, v in iou_results.items():
                writer.writerow([k] + v)
                avg = np.asarray(v) if avg is None else avg + np.asarray(v)

            # With no evaluated sample there is nothing to average.
            if avg is not None:
                writer.writerow(['Average'] + list(avg / len(iou_results)))
Exemplo n.º 12
0
def mean_average_precision2(pred_boxes,
                            true_boxes,
                            iou_threshold=0.5,
                            box_format='corner',
                            num_classes=20):
    """
    Compute mean average precision (mAP) at a single IoU threshold.

    train_idx is the index of the image a box belongs to.

    Parameters:
        pred_boxes (list): [[train_idx, class, confidence, boxes(4)], ...]
        true_boxes (list): [[train_idx, class, boxes(4)], ...] (no
            confidence entry, so the box starts at index 2)
        iou_threshold (float): a prediction must exceed this IoU with a
            ground-truth box to count as a true positive
        box_format (str): forwarded to intersection_over_union.
            NOTE(review): the default is 'corner' (singular) -- confirm
            that intersection_over_union accepts this spelling.
        num_classes (int): classes 0 .. num_classes - 1 are evaluated
    Returns:
        The mean over all evaluated classes of the area under each
        class's precision-recall curve.
    """
    # Average Precision for each class
    average_precisions = []
    # keeps the divisions below away from 0 / 0
    epsilon = 1e-6

    # Compute AP separately for every class.
    for c in range(num_classes):
        # Predicted and ground-truth boxes whose class is c.
        detections = [box for box in pred_boxes if box[1] == c]
        ground_truths = [box for box in true_boxes if box[1] == c]

        # Number of ground-truth boxes per image (train_idx). For example,
        # 3 boxes in img0 and 5 in img1 gives {0: 3, 1: 5}.
        amount_bboxes = Counter([gt[0] for gt in ground_truths])

        for key, val in amount_bboxes.items():
            # {0: torch.Tensor([0,0,0]), 1: torch.Tensor([0,0,0,0,0])}
            # A slot flips to 1 once that ground truth has been matched.
            amount_bboxes[key] = torch.zeros(val)

        # Sort predictions by confidence, highest first.
        detections.sort(key=lambda x: x[2], reverse=True)
        TP = torch.zeros(len(detections))  # 1 where the prediction is a TP
        FP = torch.zeros(len(detections))  # 1 where the prediction is an FP
        total_true_bboxes = len(ground_truths)  # number of ground truths

        # Use IoU to find the ground truth matching each prediction.
        for detection_idx, detection in enumerate(detections):

            # Collect every ground truth in the same image as this
            # prediction; which of them (if any) it corresponds to is not
            # known in advance, so all of them are compared.
            ground_truth_img = [
                bbox for bbox in ground_truths if bbox[0] == detection[0]
            ]

            best_iou = 0
            # Reset per prediction so an index found for an earlier
            # prediction can never be reused here.
            best_gt_idx = None

            for idx, gt in enumerate(ground_truth_img):
                # IoU between the prediction and one ground truth.
                iou = intersection_over_union(
                    torch.tensor(detection[3:]),
                    torch.tensor(gt[2:]),
                    box_format=box_format,
                )

                # Track the ground truth with the largest IoU.
                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            # The best-overlapping ground truth clears the threshold.
            if best_gt_idx is not None and best_iou > iou_threshold:
                if amount_bboxes[detection[0]][best_gt_idx] == 0:
                    # Ground truth not claimed yet: true positive, and
                    # mark the ground truth as seen.
                    TP[detection_idx] = 1
                    amount_bboxes[detection[0]][best_gt_idx] = 1
                else:
                    # Ground truth already claimed by a more confident
                    # prediction: this one is a false positive.
                    FP[detection_idx] = 1

            # No ground truth overlaps enough: false positive.
            else:
                FP[detection_idx] = 1

        # TP = [1, 1, 0, 1, 0] -> [1, 2, 2, 3, 3]
        # Equivalent to adding predictions one at a time in confidence
        # order and recomputing precision and recall after each.
        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)

        # Recall; the curve starts at x = 0, so prepend 0.
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        recalls = torch.cat((torch.tensor([0]), recalls))

        # Precision; the curve starts at y = 1, so prepend 1.
        precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)
        precisions = torch.cat((torch.tensor([1]), precisions))

        # Area under the precision-recall curve.
        average_precisions.append(torch.trapz(precisions, recalls))

    return sum(average_precisions) / len(average_precisions)
Exemplo n.º 13
0
def mean_average_precision(pred_bboxes,
                           true_bboxes,
                           iou_threshold=0.5,
                           box_format='corner',
                           num_classes=91):
    """
    Compute mean average precision from 2-D tensors of boxes.

    Column usage, as read by the code below:
        pred_bboxes: col 0 image index, col 1 class, col 2 score (sort
            key), cols 3.. box passed to intersection_over_union
        true_bboxes: col 0 image index, col 1 class, cols 2.. box

    A prediction is a true positive when its best IoU with a ground truth
    of the same image and class is >= iou_threshold and that ground truth
    has not been claimed by a higher-scored prediction.

    NOTE(review): the default box_format is 'corner' (singular) -- confirm
    that intersection_over_union accepts this spelling.

    Returns:
        (mAP, AP_per_class): mean of the per-class APs and the tensor of
        per-class APs (length num_classes).
    """
    eps = 1e-6
    class_ap = torch.zeros((num_classes))

    for cls in range(num_classes):
        cls_preds = pred_bboxes[pred_bboxes[:, 1] == cls, :]
        cls_gts = true_bboxes[true_bboxes[:, 1] == cls, :]

        # One zero vector per image that has ground truths of this class,
        # e.g. {0: tensor(0,0,0), 1: tensor(0,0,0,0), ...}; an entry flips
        # to 1 once that ground truth has been matched.
        gt_counts = cls_gts[:, 0].long().bincount()
        claimed = {}
        for img_id, count in enumerate(gt_counts):
            if count > 0:
                claimed[img_id] = torch.zeros(count)

        # Highest-scored predictions get first pick of ground truths.
        ranked = sorted(cls_preds, key=lambda row: row[2], reverse=True)
        n_ranked = len(ranked)
        hits = torch.zeros(n_ranked)
        misses = torch.zeros(n_ranked)
        n_gts = len(cls_gts)

        for rank, pred in enumerate(ranked):
            same_img_gts = cls_gts[cls_gts[:, 0] == pred[0], :]

            best_iou = 0
            for gt_pos, gt_row in enumerate(same_img_gts):
                overlap = intersection_over_union(pred[3:],
                                                  gt_row[2:],
                                                  box_format=box_format)
                if overlap > best_iou:
                    best_iou, best_gt_bbox_idx = overlap, gt_pos

            is_new_match = False
            if best_iou >= iou_threshold:
                flags = claimed[pred[0].item()]
                if flags[best_gt_bbox_idx] == 0:
                    flags[best_gt_bbox_idx] = 1
                    is_new_match = True

            if is_new_match:
                hits[rank] = 1
            else:
                misses[rank] = 1

        tp_running = hits.cumsum(dim=0)
        fp_running = misses.cumsum(dim=0)

        # Precision-recall curve, anchored at (recall=0, precision=1).
        recall = torch.cat([torch.tensor([0]), tp_running / (n_gts + eps)])
        precision = torch.cat([
            torch.tensor([1]),
            tp_running / (tp_running + fp_running + eps),
        ])

        # AP is the area under that curve.
        class_ap[cls] = torch.trapz(precision, recall)

    return sum(class_ap) / len(class_ap), class_ap
def mean_avg_precision(predicted_boxes, true_boxes, iou_threshold=0.5, num_classes=20):
    """Calculate mean average precision

    Args:
        predicted_boxes (list): list of all the bounding boxes described as [train_idx, class_pred, prob_score, x1, y1, x2, y2]
        true_boxes (list): list of all the correct bounding boxes, similar to the prediction boxes
        iou_threshold (float, optional): IoU a prediction must exceed to count as correct. Defaults to 0.5.
        num_classes (int, optional): number of classes; classes 0 .. num_classes - 1 are evaluated. Defaults to 20.

    Returns:
        float: mean average precision across all classes for a specific IoU threshold
    """
    epsilon = 1e-6
    average_precisions = []

    for c in range(num_classes):
        detections = [box for box in predicted_boxes if box[1] == c]
        # Built fresh for every class. Sharing one list across iterations
        # would mix ground truths of earlier classes into this class's
        # matching and inflate the recall denominator.
        ground_truths = [box for box in true_boxes if box[1] == c]

        # Per image, a zero vector with one slot per ground truth; a slot
        # flips to 1 once that ground truth has been matched.
        num_bounding_boxes = Counter([ground_truth[0]
                                      for ground_truth in ground_truths])

        for key, value in num_bounding_boxes.items():
            num_bounding_boxes[key] = np.zeros(value)

        # Most confident predictions get first pick of ground truths.
        detections.sort(key=lambda x: x[2], reverse=True)
        TP = np.zeros((len(detections)))
        FP = np.zeros((len(detections)))
        total_true_bounding_boxes = len(ground_truths)

        for detection_idx, detection in enumerate(detections):
            # Ground truths located in the same image as this detection.
            ground_truth_img = [
                bounding_box for bounding_box in ground_truths if bounding_box[0] == detection[0]]

            best_iou = 0
            # Reset per detection so an index found for an earlier
            # detection can never be reused here.
            best_ground_truth_idx = None

            for idx, ground_truth in enumerate(ground_truth_img):
                iou = intersection_over_union(
                    tf.constant(detection[3:], dtype=tf.float32), tf.constant(ground_truth[3:], dtype=tf.float32))

                if iou > best_iou:
                    best_iou = iou
                    best_ground_truth_idx = idx

            if best_ground_truth_idx is not None and best_iou > iou_threshold:
                if num_bounding_boxes[detection[0]][best_ground_truth_idx] == 0:
                    # First detection to claim this ground truth.
                    TP[detection_idx] = 1
                    num_bounding_boxes[detection[0]][best_ground_truth_idx] = 1
                else:
                    # Ground truth already claimed by a stronger detection.
                    FP[detection_idx] = 1
            else:
                FP[detection_idx] = 1

        TP_cumulative_sum = tf.cumsum(TP, axis=0)
        FP_cumulative_sum = tf.cumsum(FP, axis=0)
        recalls = TP_cumulative_sum / (total_true_bounding_boxes + epsilon)
        precisions = tf.divide(
            TP_cumulative_sum, (TP_cumulative_sum + FP_cumulative_sum + epsilon))
        # Anchor the curve at (recall=0, precision=1) before integrating.
        precisions = tf.concat(
            (tf.constant([1], dtype=tf.float64), precisions), axis=0)
        recalls = tf.concat(
            (tf.constant([0], dtype=tf.float64), recalls), axis=0)
        average_precisions.append(np.trapz(precisions, recalls))

    return sum(average_precisions) / len(average_precisions)
def mean_average_precision(pred_boxes,
                           true_boxes,
                           iou_threshold=0.5,
                           box_format="midpoint",
                           num_classes=20):
    """
    Calculates mean average precision
    Parameters:
        pred_boxes (list): list of lists containing all bboxes with each bboxes
        specified as [train_idx, class_prediction, prob_score, x1, y1, x2, y2]
        true_boxes (list): Similar as pred_boxes except all the correct ones
        iou_threshold (float): IoU a prediction must exceed to be counted
        as correct
        box_format (str): "midpoint" or "corners" used to specify bboxes
        num_classes (int): number of classes; class ids 1 .. num_classes
        are evaluated
    Returns:
        float: mAP value across all classes that have at least one ground
        truth, given a specific IoU threshold; 0 when no class has any
        ground truth
    """

    # list storing all AP for respective classes
    average_precisions = []

    # used for numerical stability later on
    epsilon = 1e-6

    for c in range(1, num_classes + 1):
        # Only the ground truths of the current class c
        ground_truths = [box for box in true_boxes if box[1] == c]
        total_true_bboxes = len(ground_truths)

        # If none exists for this class then we can safely skip
        if total_true_bboxes == 0:
            continue

        detections = [box for box in pred_boxes if box[1] == c]

        # find the amount of bboxes for each training example
        # Counter here finds how many ground truth bboxes we get
        # for each training example, so let's say img 0 has 3,
        # img 1 has 5 then we will obtain a dictionary with:
        # amount_bboxes = {0:3, 1:5}
        amount_bboxes = Counter([gt[0] for gt in ground_truths])

        # We then go through each key, val in this dictionary
        # and convert to the following (w.r.t same example):
        # amount_bboxes = {0:torch.tensor[0,0,0], 1:torch.tensor[0,0,0,0,0]}
        for key, val in amount_bboxes.items():
            amount_bboxes[key] = torch.zeros(val)

        # sort by box probabilities which is index 2
        detections.sort(key=lambda x: x[2], reverse=True)
        TP = torch.zeros((len(detections)))
        FP = torch.zeros((len(detections)))

        for detection_idx, detection in enumerate(detections):
            # Only take out the ground_truths that have the same
            # training idx as detection
            ground_truth_img = [
                bbox for bbox in ground_truths if bbox[0] == detection[0]
            ]

            best_iou = 0
            # Reset per detection so an index found for an earlier
            # detection can never be reused here.
            best_gt_idx = None

            for idx, gt in enumerate(ground_truth_img):
                iou = intersection_over_union(
                    torch.tensor(detection[3:]).float(),
                    torch.tensor(gt[3:]).float(),
                    box_format=box_format,
                )

                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            if best_gt_idx is not None and best_iou > iou_threshold:
                # only detect ground truth detection once
                if amount_bboxes[detection[0]][best_gt_idx] == 0:
                    # true positive and add this bounding box to seen
                    TP[detection_idx] = 1
                    amount_bboxes[detection[0]][best_gt_idx] = 1
                else:
                    FP[detection_idx] = 1

            # if IOU is lower then the detection is a false positive
            else:
                FP[detection_idx] = 1

        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)
        # Anchor the curve at (recall=0, precision=1) before integrating.
        precisions = torch.cat((torch.tensor([1]), precisions))
        recalls = torch.cat((torch.tensor([0]), recalls))
        # torch.trapz for numerical integration
        average_precisions.append(torch.trapz(precisions, recalls))

    # Every class was skipped (no ground truths at all): report 0 instead
    # of dividing by zero.
    if not average_precisions:
        return torch.tensor(0.0)

    return sum(average_precisions) / len(average_precisions)