Exemple #1
0
def _inclusive_iou(pred_box, gt_box):
    """Return the IoU of two [left, top, right, bottom] boxes, or None if disjoint.

    Widths and heights use the inclusive "+ 1" pixel convention.
    """
    inter_w = min(gt_box[2], pred_box[2]) - max(pred_box[0], gt_box[0]) + 1
    inter_h = min(gt_box[3], pred_box[3]) - max(gt_box[1], pred_box[1]) + 1
    if inter_w <= 0 or inter_h <= 0:
        return None
    inter_area = inter_w * inter_h
    pred_area = (pred_box[2] - pred_box[0] + 1) * (pred_box[3] - pred_box[1] + 1)
    gt_area = (gt_box[2] - gt_box[0] + 1) * (gt_box[3] - gt_box[1] + 1)
    return inter_area / (pred_area + gt_area - inter_area)


def eval(model_path, min_Iou=0.5, yolo_weights=None):
    """
    Introduction
    ------------
        Compute the model's mAP on the COCO validation set; used to
        evaluate the model. The result is printed, nothing is returned.
    Parameters
    ----------
        model_path: checkpoint directory handed to
            ``tf.train.get_checkpoint_state``; only used when
            ``yolo_weights`` is None
        min_Iou: a prediction counts as a true positive only when its best
            IoU with a same-class ground-truth box is strictly greater than
            this value and that box has not been matched yet
        yolo_weights: optional weights file handed to ``load_weights``;
            when given, it is used instead of the checkpoint
    Raises
    ------
        ValueError: no checkpoint could be found under ``model_path``
    """
    from itertools import accumulate

    ground_truth = {}
    class_pred = defaultdict(list)
    gt_counter_per_class = defaultdict(int)
    input_image_shape = tf.placeholder(dtype=tf.int32, shape=(2, ))
    input_image = tf.placeholder(shape=[None, 416, 416, 3], dtype=tf.float32)
    predictor = yolo_predictor(config.obj_threshold, config.nms_threshold,
                               config.classes_path, config.anchors_path)
    boxes, scores, classes = predictor.predict(input_image, input_image_shape)
    val_Reader = Reader("val",
                        config.data_dir,
                        config.anchors_path,
                        config.num_classes,
                        input_shape=config.input_shape,
                        max_boxes=config.max_boxes)
    image_files, bboxes_data = val_Reader.read_annotations()
    allBBox = 0
    with tf.Session() as sess:
        if yolo_weights is not None:
            # NOTE(review): the prediction graph is built a second time here,
            # under the 'predict' scope, and the outputs built above are
            # discarded in this branch - confirm the first build is needed.
            with tf.variable_scope('predict'):
                boxes, scores, classes = predictor.predict(
                    input_image, input_image_shape)
            load_op = load_weights(tf.global_variables(scope='predict'),
                                   weights_file=yolo_weights)
            sess.run(load_op)
        else:
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(model_path)
            # Fail with a clear message instead of an AttributeError on None.
            if ckpt is None or not ckpt.model_checkpoint_path:
                raise ValueError(
                    "no checkpoint found in {}".format(model_path))
            saver.restore(sess, ckpt.model_checkpoint_path)

        for index, image_file in enumerate(image_files):
            # File name up to its first dot identifies the image.
            file_id = os.path.basename(image_file).split('.')[0]

            # Collect the ground-truth boxes of this image.
            val_bboxes = []
            for gt_row in bboxes_data[index]:
                class_name = val_Reader.class_names[int(gt_row[4])]
                val_bboxes.append({
                    "class_name": class_name,
                    "bbox": [float(coord) for coord in gt_row[:4]],
                    "used": False
                })
                gt_counter_per_class[class_name] += 1
            ground_truth[file_id] = val_bboxes

            # Load and normalise the image; the handle is closed once the
            # pixel data has been copied into a numpy array.
            with Image.open(image_file) as image:
                image_width, image_height = image.size
                resize_image = letterbox_image(image, (416, 416))
                image_data = np.array(resize_image, dtype=np.float32)
            image_data /= 255.
            image_data = np.expand_dims(image_data, axis=0)

            out_boxes, out_scores, out_classes = sess.run(
                [boxes, scores, classes],
                feed_dict={
                    input_image: image_data,
                    input_image_shape: [image_height, image_width]
                })
            allBBox += len(out_boxes)
            print("detect {}/{} found boxes: {},allBBox:{}".format(
                index, len(image_files), len(out_boxes), allBBox))
            for o, c in enumerate(out_classes):
                predicted_class = val_Reader.class_names[c]
                score = out_scores[o]

                # Predicted boxes are unpacked as (top, left, bottom, right),
                # rounded to the nearest pixel and clipped to the image.
                top, left, bottom, right = out_boxes[o]
                top = max(0, np.floor(top + 0.5).astype('int32'))
                left = max(0, np.floor(left + 0.5).astype('int32'))
                bottom = min(image_height,
                             np.floor(bottom + 0.5).astype('int32'))
                right = min(image_width,
                            np.floor(right + 0.5).astype('int32'))

                class_pred[predicted_class].append({
                    "confidence": str(score),
                    "file_id": file_id,
                    "bbox": [left, top, right, bottom]
                })

    # Compute the AP of every class that has ground-truth boxes.
    sum_AP = 0.0
    sum_rec = 0.0
    sum_prec = 0.0
    for class_name in sorted(gt_counter_per_class.keys()):
        # The precision/recall curve is only meaningful when detections are
        # ranked from most to least confident.
        predictions_data = sorted(class_pred.get(class_name, []),
                                  key=lambda p: float(p['confidence']),
                                  reverse=True)
        nd = len(predictions_data)  # number of predicted boxes of this class
        tp = [0] * nd  # true positive flags
        fp = [0] * nd  # false positive flags
        for idx, prediction in enumerate(predictions_data):
            bbox_pred = prediction['bbox']
            Iou_max = -1
            gt_match = None
            for obj in ground_truth[prediction['file_id']]:
                if obj['class_name'] != class_name:
                    continue
                Iou = _inclusive_iou(bbox_pred, obj['bbox'])
                if Iou is not None and Iou > Iou_max:
                    Iou_max = Iou
                    gt_match = obj
            # Each ground-truth box may be claimed by one prediction only;
            # further overlapping predictions are false positives.
            if (gt_match is not None and Iou_max > min_Iou
                    and not gt_match['used']):
                tp[idx] = 1
                gt_match['used'] = True
            else:
                fp[idx] = 1

        # Turn the flags into running totals, then into recall/precision.
        fp = list(accumulate(fp))
        tp = list(accumulate(tp))
        gt_total = gt_counter_per_class[class_name]
        rec = [float(hits) / gt_total for hits in tp]
        prec = [float(hits) / (misses + hits)
                for hits, misses in zip(tp, fp)]

        ap, mrec, mprec = voc_ap(rec, prec)
        sum_AP += ap
        sum_rec += (mrec[-2])
        # NOTE(review): this "precision" term divides the summed curve by the
        # total detection count of all classes, and prec/rec below are sums
        # over classes rather than averages - kept as in the original,
        # confirm this is the intended statistic.
        sum_prec += sum(mprec) / (allBBox + 2)

    num_classes = len(gt_counter_per_class)
    MAP = sum_AP / num_classes * 100 if num_classes else 0.0
    pr_total = sum_rec + sum_prec
    f1 = 2 * sum_rec * sum_prec / pr_total if pr_total else 0.0
    print("The Model Eval MAP: {},prec:{},rec:{},f1:{}".format(
        MAP, sum_prec, sum_rec, f1))
Exemple #2
0
            cumsum += val
        cumsum = 0
        for idx, val in enumerate(tp):
            tp[idx] += cumsum
            cumsum += val
        #print(tp)
        rec = tp[:]
        for idx, val in enumerate(tp):
            rec[idx] = float(tp[idx]) / gt_counter_per_class[class_name]
        #print(rec)
        prec = tp[:]
        for idx, val in enumerate(tp):
            prec[idx] = float(tp[idx]) / (fp[idx] + tp[idx])
        #print(prec)

        ap, mrec, mprec = voc_ap(rec[:], prec[:])
        sum_AP += ap
        # text = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP " #class_name + " AP = {0:.2f}%".format(ap*100)
        text = class_name + ": (AP)= {0:.2f}%".format(ap * 100)
        """
         Write to results.txt
        """
        rounded_prec = ['%.2f' % elem for elem in prec]
        rounded_rec = ['%.2f' % elem for elem in rec]
        results_file.write(text + "\n Precision: " + str(rounded_prec) +
                           "\n Recall :" + str(rounded_rec) + "\n\n")
        if not args.quiet:
            print(text)
        ap_dictionary[class_name] = ap

        n_images = counter_images_per_class[class_name]
Exemple #3
0
    def eval_map(self, gt_folder_path, pred_folder_path, temp_json_folder_path,
                 output_files_path):
        """Compute per-class AP and the mAP from text annotation/detection files.

        Args:
            gt_folder_path: folder of ground-truth ``*.txt`` files, one line
                per object: ``<class_name> <left> <top> <right> <bottom>``.
            pred_folder_path: folder of detection ``*.txt`` files, one line
                per detection:
                ``<class_name> <confidence> <left> <top> <right> <bottom>``.
            temp_json_folder_path: existing folder that receives the
                intermediate ``*_ground_truth.json`` and ``*_dr.json`` files.
            output_files_path: existing folder that receives ``output.txt``;
                the plot output paths handed to ``draw_plot_func`` also point
                into it.

        Raises:
            AssertionError: if there is no ground-truth file, or a
                ground-truth file has no detection file of the same name.
        """
        # ---- Process ground truth ----
        ground_truth_files_list = glob(gt_folder_path + '/*.txt')
        assert len(ground_truth_files_list) > 0, 'no ground truth file'
        ground_truth_files_list.sort()
        # number of ground-truth objects per class
        gt_counter_per_class = {}

        for txt_file in ground_truth_files_list:
            file_id = os.path.splitext(os.path.basename(txt_file))[0]
            # check if there is a correspondent detection-results file
            temp_path = os.path.join(pred_folder_path, (file_id + ".txt"))
            assert os.path.exists(
                temp_path), "Error. File not found: {}\n".format(temp_path)
            bounding_boxes = []
            for line in read_txt_to_list(txt_file):
                class_name, left, top, right, bottom = line.split()
                bounding_boxes.append({
                    "class_name": class_name,
                    "bbox": left + " " + top + " " + right + " " + bottom,
                    "used": False
                })
                gt_counter_per_class[class_name] = (
                    gt_counter_per_class.get(class_name, 0) + 1)

            # dump bounding_boxes into a ".json" file
            new_temp_file = os.path.join(temp_json_folder_path,
                                         file_id + "_ground_truth.json")
            with open(new_temp_file, 'w') as outfile:
                json.dump(bounding_boxes, outfile)

        # classes sorted alphabetically
        gt_classes = sorted(gt_counter_per_class)
        n_classes = len(gt_classes)
        print(gt_classes, gt_counter_per_class)

        # ---- Process predictions ----
        # Every detection file is read once: detections are grouped by class
        # and the per-class totals are counted in the same pass.
        dr_files_list = sorted(glob(os.path.join(pred_folder_path, '*.txt')))
        detections_per_class = {}
        det_counter_per_class = {}
        for txt_file in dr_files_list:
            file_id = os.path.splitext(os.path.basename(txt_file))[0]
            temp_path = os.path.join(gt_folder_path, (file_id + ".txt"))
            if not os.path.exists(temp_path):
                error_msg = f"Error. File not found: {temp_path}\n"
                print(error_msg)
            for line in read_txt_to_list(txt_file):
                fields = line.split()
                if fields:
                    det_counter_per_class[fields[0]] = (
                        det_counter_per_class.get(fields[0], 0) + 1)
                try:
                    tmp_class_name, confidence, left, top, right, bottom = fields
                except ValueError:
                    error_msg = f"""Error: File {txt_file} in the wrong format.\n 
                                        Expected: <class_name> <confidence> <left> <top> <right> <bottom>\n 
                                        Received: {line} \n"""
                    print(error_msg)
                    # Skip the malformed line instead of reusing the fields
                    # parsed from a previous line.
                    continue
                detections_per_class.setdefault(tmp_class_name, []).append({
                    "confidence": confidence,
                    "file_id": file_id,
                    "bbox": left + " " + top + " " + right + " " + bottom
                })

        for class_name in gt_classes:
            bounding_boxes = detections_per_class.get(class_name, [])
            # sort detection-results by decreasing confidence
            bounding_boxes.sort(key=lambda x: float(x['confidence']),
                                reverse=True)
            dr_file = os.path.join(temp_json_folder_path,
                                   class_name + "_dr.json")
            with open(dr_file, 'w') as outfile:
                json.dump(bounding_boxes, outfile)

        # ---- Calculate the AP for each class ----
        sum_AP = 0.0
        ap_dictionary = {}
        count_true_positives = {}
        min_overlap = 0.5
        # Parsed ground-truth files keyed by path, so each one is read from
        # disk once; the file is still rewritten whenever a box gets matched.
        gt_cache = {}
        # open file to store the output
        with open(os.path.join(output_files_path, "output.txt"),
                  'w') as output_file:
            output_file.write("# AP and precision/recall per class\n")
            for class_name in gt_classes:
                count_true_positives[class_name] = 0
                # Load detection-results of that class
                dr_file = os.path.join(temp_json_folder_path,
                                       class_name + "_dr.json")
                with open(dr_file) as infile:
                    dr_data = json.load(infile)

                # Assign detection-results to ground-truth objects
                nd = len(dr_data)
                tp = [0] * nd
                fp = [0] * nd
                for idx, detection in enumerate(dr_data):
                    gt_file = os.path.join(
                        temp_json_folder_path,
                        detection["file_id"] + "_ground_truth.json")
                    ground_truth_data = gt_cache.get(gt_file)
                    if ground_truth_data is None:
                        with open(gt_file) as infile:
                            ground_truth_data = json.load(infile)
                        gt_cache[gt_file] = ground_truth_data

                    ovmax = -1
                    gt_match = None
                    # load detected object bounding-box
                    bb = [float(x) for x in detection["bbox"].split()]
                    for obj in ground_truth_data:
                        # look for a class_name match
                        if obj["class_name"] != class_name:
                            continue
                        bbgt = [float(x) for x in obj["bbox"].split()]
                        iw = min(bb[2], bbgt[2]) - max(bb[0], bbgt[0]) + 1
                        ih = min(bb[3], bbgt[3]) - max(bb[1], bbgt[1]) + 1
                        if iw <= 0 or ih <= 0:
                            continue
                        # overlap (IoU) = area of intersection / area of union
                        ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + \
                             (bbgt[2] - bbgt[0] + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
                        ov = iw * ih / ua
                        if ov > ovmax:
                            ovmax = ov
                            gt_match = obj

                    if (gt_match is not None and ovmax >= min_overlap
                            and not gt_match["used"]):
                        # true positive
                        tp[idx] = 1
                        gt_match["used"] = True
                        count_true_positives[class_name] += 1
                        # update the ".json" file
                        with open(gt_file, 'w') as f:
                            f.write(json.dumps(ground_truth_data))
                    else:
                        # false positive (insufficient overlap, or the
                        # ground-truth box was already matched)
                        fp[idx] = 1

                # compute precision/recall from the running TP/FP totals
                cumsum = 0
                for idx, val in enumerate(fp):
                    fp[idx] += cumsum
                    cumsum += val
                print('fp ', cumsum)
                cumsum = 0
                for idx, val in enumerate(tp):
                    tp[idx] += cumsum
                    cumsum += val
                print('tp ', cumsum)
                rec = [
                    float(hits) / gt_counter_per_class[class_name]
                    for hits in tp
                ]
                print('recall ', rec[-1] if rec else 0.0)
                prec = [
                    float(hits) / (misses + hits)
                    for hits, misses in zip(tp, fp)
                ]
                print('prec ', prec[-1] if prec else 0.0)

                # voc_ap receives copies so rec/prec stay intact for plotting
                ap, mrec, mprec = voc_ap(rec[:], prec[:])
                sum_AP += ap
                text = "{0:.2f}%".format(ap * 100) + " = " + class_name + " AP "

                # Write the per-class result announced by the file header.
                rounded_prec = ['%.2f' % elem for elem in prec]
                rounded_rec = ['%.2f' % elem for elem in rec]
                output_file.write(text + "\n Precision: " + str(rounded_prec) +
                                  "\n Recall :" + str(rounded_rec) + "\n\n")
                print(text)
                ap_dictionary[class_name] = ap

                # Draw the precision/recall plot of this class
                plt.plot(rec, prec, '-o')
                # add a new penultimate point to the list (mrec[-2], 0.0)
                # since the last line segment (and respective area) do not affect the AP value
                area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
                area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
                plt.fill_between(area_under_curve_x,
                                 0,
                                 area_under_curve_y,
                                 alpha=0.2,
                                 edgecolor='r')
                # set window title through the canvas manager;
                # FigureCanvas.set_window_title is gone in newer Matplotlib
                fig = plt.gcf()
                manager = getattr(fig.canvas, "manager", None)
                if manager is not None:
                    manager.set_window_title('AP ' + class_name)
                # set plot title
                plt.title('class: ' + text)
                # set axis titles
                plt.xlabel('Recall')
                plt.ylabel('Precision')
                axes = plt.gca()
                axes.set_xlim([0.0, 1.0])
                axes.set_ylim([0.0, 1.05])  # .05 to give some extra space
                plt.show()

            output_file.write("\n# mAP of all classes\n")
            # Ground-truth files may all be empty; avoid dividing by zero.
            mAP = sum_AP / n_classes if n_classes else 0.0
            text = "mAP = {0:.2f}%".format(mAP * 100)
            output_file.write(text + "\n")
            print(text)

        # ---- Plot the number of occurrences of each ground-truth class ----
        window_title = "ground-truth-info"
        plot_title = "ground-truth\n"
        plot_title += "(" + str(
            len(ground_truth_files_list)) + " files and " + str(
                n_classes) + " classes)"
        x_label = "Number of objects per class"
        output_path = output_files_path + "/ground-truth-info.png"
        to_show = False
        plot_color = 'forestgreen'
        draw_plot_func(
            gt_counter_per_class,
            n_classes,
            window_title,
            plot_title,
            x_label,
            output_path,
            to_show,
            plot_color,
            '',
        )

        # ---- Finish counting true positives ----
        # a class present in the detections but not in the ground truth has
        # no true positives
        for class_name in det_counter_per_class:
            count_true_positives.setdefault(class_name, 0)

        # ---- Plot the number of occurrences of each detected class ----
        window_title = "detection-results-info"
        plot_title = "detection-results\n"
        plot_title += "(" + str(len(dr_files_list)) + " files and "
        count_non_zero_values_in_dictionary = sum(
            int(x) > 0 for x in det_counter_per_class.values())
        plot_title += str(
            count_non_zero_values_in_dictionary) + " detected classes)"
        x_label = "Number of objects per class"
        output_path = output_files_path + "/detection-results-info.png"
        to_show = False
        plot_color = 'forestgreen'
        true_p_bar = count_true_positives
        draw_plot_func(det_counter_per_class, len(det_counter_per_class),
                       window_title, plot_title, x_label, output_path,
                       to_show, plot_color, true_p_bar)

        # ---- Draw the mAP plot (AP of every class) ----
        window_title = "mAP"
        plot_title = "mAP = {0:.2f}%".format(mAP * 100)
        x_label = "Average Precision"
        output_path = output_files_path + "/mAP.png"
        to_show = True
        plot_color = 'royalblue'
        draw_plot_func(ap_dictionary, n_classes, window_title, plot_title,
                       x_label, output_path, to_show, plot_color, "")