Example #1
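# Excerpt from an AVA-style evaluation script; it relies on module-level imports and
# helpers not shown here (logging, pprint, time, numpy as np, read_labelmap,
# read_exclusions, read_csv, print_time, object_detection_evaluation, standard_fields).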
def run_evaluation(labelmap, groundtruth, detections, exclusions):
    """Runs evaluations given input files.

  Args:
    labelmap: file object containing map of labels to consider, in pbtxt format
    groundtruth: file object
    detections: file object
    exclusions: file object or None.
  """
    categories, class_whitelist = read_labelmap(labelmap)
    logging.info("CATEGORIES (%d):\n%s", len(categories),
                 pprint.pformat(categories, indent=2))
    excluded_keys = read_exclusions(exclusions)

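    # PASCAL VOC-style evaluator: accumulates per-frame ground truth and detections,
    # then reports per-class AP and overall mAP (default matching IoU is 0.5).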
    pascal_evaluator = object_detection_evaluation.PascalDetectionEvaluator(
        categories)

    # Reads the ground truth data.
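    # read_csv's third argument appears to cap how many of the highest-scoring boxes
    # are kept per frame (0 = no limit); 50 is used for detections below.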
    boxes, labels, _ = read_csv(groundtruth, class_whitelist, 0)
    start = time.time()
    for image_key in boxes:
        if image_key in excluded_keys:
            logging.info(("Found excluded timestamp in ground truth: %s. "
                          "It will be ignored."), image_key)
            continue
        pascal_evaluator.add_single_ground_truth_image_info(
            image_key, {
                standard_fields.InputDataFields.groundtruth_boxes:
                np.array(boxes[image_key], dtype=float),
                standard_fields.InputDataFields.groundtruth_classes:
                np.array(labels[image_key], dtype=int),
                standard_fields.InputDataFields.groundtruth_difficult:
                np.zeros(len(boxes[image_key]), dtype=bool)
            })
    print_time("convert groundtruth", start)

    # Reads detections data.
    boxes, labels, scores = read_csv(detections, class_whitelist, 50)
    start = time.time()
    for image_key in boxes:
        if image_key in excluded_keys:
            logging.info(("Found excluded timestamp in detections: %s. "
                          "It will be ignored."), image_key)
            continue
        pascal_evaluator.add_single_detected_image_info(
            image_key, {
                standard_fields.DetectionResultFields.detection_boxes:
                np.array(boxes[image_key], dtype=float),
                standard_fields.DetectionResultFields.detection_classes:
                np.array(labels[image_key], dtype=int),
                standard_fields.DetectionResultFields.detection_scores:
                np.array(scores[image_key], dtype=float)
            })
    print_time("convert detections", start)

    start = time.time()
    metrics = pascal_evaluator.evaluate()
    print_time("run_evaluator", start)
    pprint.pprint(metrics, indent=2)
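# Illustrative invocation only; the file names are placeholders, and real runs would
# normally pass these files in via command-line arguments:
#
#   with open("ava_action_list_v2.1.pbtxt") as labelmap, \
#        open("ava_val_v2.1.csv") as groundtruth, \
#        open("detections.csv") as detections, \
#        open("ava_val_excluded_timestamps_v2.1.csv") as exclusions:
#       run_evaluation(labelmap, groundtruth, detections, exclusions)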
Example #2
 def __init__(self):
     this_dir = os.path.dirname(__file__)
     lib_path = os.path.join(this_dir, '../external/ActivityNet/Evaluation')
     sys.path.insert(0, lib_path)
     from ava import object_detection_evaluation
     from ava import standard_fields
     sys.path.pop(0)
     self.sf = standard_fields
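     # AVA defines 80 atomic actions; register ids 1-80 as categories, reusing the numeric id as the name.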
     categories = [{"id": i, "name": i} for i in range(1, 81)]
     self.evaluator = object_detection_evaluation.PascalDetectionEvaluator(
         categories)
Example #3
 def __init__(self):
     this_dir = os.path.dirname(__file__)
     lib_path = os.path.join(this_dir, '../external/ActivityNet/Evaluation')
     sys.path.insert(0, lib_path)
     from ava import object_detection_evaluation
     from ava import standard_fields
     sys.path.pop(0)
     self.sf = standard_fields
     categories = [{"id": i, "name": i} for i in range(1, 81)]
     self.evaluator = object_detection_evaluation.PascalDetectionEvaluator(
         categories)
     top60path = '../external/ActivityNet/Evaluation/ava/ava_action_list_v2.1_for_activitynet_2018.pbtxt.txt'
     top60path = os.path.join(this_dir, top60path)
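     # Pull every integer out of the pbtxt: these are the class ids of the top-60 actions evaluated in the ActivityNet 2018 task.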
     with open(top60path) as f:
         self.top60 = [int(x) for x in re.findall('[0-9]+', f.read())]
def run_evaluation(labelmap, groundtruth, exclusions, iou):
    """Runs evaluations over the hard-coded experiment detection files.

    Args:
      labelmap: file object containing map of labels to consider, in pbtxt format
      groundtruth: file object (unused here; ground-truth CSVs are opened below)
      exclusions: file object or None.
      iou: matching IoU threshold passed to the PASCAL evaluator.
    """

    root_dir = '../../../data/AVA/files/'
    test_dir = "../test_outputs/"
    # Map experiment names to legend labels (filters) and to the detection CSV files to score.
    experiments_filters = {}
    experiments_detections = {}

    experiments_filters['pose'] = ['Pose']
    experiments_detections['pose'] = [
        open(test_dir + "/rgb_rgb/output_test_rgb.csv", 'rb')
    ]

    experiments_filters['rgb-streams-aug'] = ['RGB', 'Crop', 'Gauss', 'Fovea']
    experiments_detections['rgb-streams-aug'] = [
        open(test_dir + "/rgb_rgb/output_test_rgb.csv", 'rb'),
        open(test_dir + "/rgb_crop/output_test_crop.csv", 'rb'),
        open(test_dir + "/rgb_gauss/output_test_gauss.csv", 'rb'),
        open(test_dir + "/rgb_fovea/output_test_fovea.csv", 'rb')
    ]

    experiments_filters['flow vs flowcrop'] = ['Flow', 'Flowcrop']
    experiments_detections['flow vs flowcrop'] = [
        open(test_dir + "/flow/output_test_flowcrop.csv", 'rb'),
    ]

    #all_detections.append(open(test_dir + "/flow/output_test_flow.csv", 'rb'))

    experiments_filters['two-streams'] = [
        'Two-Stream-RGB', 'Two-Stream-Crop', 'Two-Stream-Gauss',
        'Two-Stream-Fovea'
    ]
    experiments_detections['two-streams'] = [
        open(test_dir + "/rgb_rgb/output_test_rgb.csv", 'rb'),
        open(test_dir + "/rgb_crop/output_test_crop.csv", 'rb'),
        open(test_dir + "/rgb_gauss/output_test_gauss.csv", 'rb'),
        open(test_dir + "/rgb_fovea/output_test_fovea.csv", 'rb')
    ]

    experiments_filters['two-streams-aug'] = ['RGB', 'Crop', 'Gauss', 'Fovea']
    experiments_detections['two-streams-aug'] = [
        open(test_dir + "/rgb_rgb/output_test_rgb.csv", 'rb'),
        open(test_dir + "/rgb_crop/output_test_crop.csv", 'rb'),
        open(test_dir + "/rgb_gauss/output_test_gauss.csv", 'rb'),
        open(test_dir + "/rgb_fovea/output_test_fovea.csv", 'rb')
    ]

    experiments_filters['mlp vs lstm'] = ['RGB', 'Crop', 'Gauss', 'Fovea']
    experiments_detections['mlp vs lstm'] = [
        open(test_dir + "/rgb_rgb/output_test_rgb.csv", 'rb'),
        open(test_dir + "/rgb_crop/output_test_crop.csv", 'rb'),
        open(test_dir + "/rgb_gauss/output_test_gauss.csv", 'rb'),
        open(test_dir + "/rgb_fovea/output_test_fovea.csv", 'rb')
    ]

    experiments_filters['lstmA vs lstmB'] = ['RGB', 'Crop', 'Gauss', 'Fovea']
    experiments_detections['lstmA vs lstmB'] = [
        open(test_dir + "/rgb_rgb/output_test_rgb.csv", 'rb'),
        open(test_dir + "/rgb_crop/output_test_crop.csv", 'rb'),
        open(test_dir + "/rgb_gauss/output_test_gauss.csv", 'rb'),
        open(test_dir + "/rgb_fovea/output_test_fovea.csv", 'rb')
    ]

    experiments_filters['context-fusion mlp vs lstm'] = [
        'RGB', 'Crop', 'Gauss', 'Fovea'
    ]
    experiments_detections['context-fusion mlp vs lstm'] = [
        open(test_dir + "/rgb_rgb/output_test_rgb.csv", 'rb'),
        open(test_dir + "/rgb_crop/output_test_crop.csv", 'rb'),
        open(test_dir + "/rgb_gauss/output_test_gauss.csv", 'rb'),
        open(test_dir + "/rgb_fovea/output_test_fovea.csv", 'rb')
    ]

    experiments_filters['balancing sampling'] = [
        'RGB', 'Crop', 'Gauss', 'Fovea'
    ]
    experiments_detections['balancing sampling'] = [
        open(test_dir + "/rgb_rgb/output_test_rgb.csv", 'rb'),
        open(test_dir + "/rgb_crop/output_test_crop.csv", 'rb'),
        open(test_dir + "/rgb_gauss/output_test_gauss.csv", 'rb'),
        open(test_dir + "/rgb_fovea/output_test_fovea.csv", 'rb')
    ]

    experiments_filters['balancing weights'] = [
        'RGB', 'Crop', 'Gauss', 'Fovea'
    ]
    experiments_detections['balancing weights'] = [
        open(test_dir + "/rgb_rgb/output_test_rgb.csv", 'rb'),
        open(test_dir + "/rgb_crop/output_test_crop.csv", 'rb'),
        open(test_dir + "/rgb_gauss/output_test_gauss.csv", 'rb'),
        open(test_dir + "/rgb_fovea/output_test_fovea.csv", 'rb')
    ]

    experiments_filters['balancing prior'] = ['RGB', 'Crop', 'Gauss', 'Fovea']
    experiments_detections['balancing prior'] = [
        open(test_dir + "/rgb_rgb/output_test_rgb.csv", 'rb'),
        open(test_dir + "/rgb_crop/output_test_crop.csv", 'rb'),
        open(test_dir + "/rgb_gauss/output_test_gauss.csv", 'rb'),
        open(test_dir + "/rgb_fovea/output_test_fovea.csv", 'rb')
    ]

    # Experiment selection is done manually here: the filters (legend labels) and
    # all_detections lists below are filled in by hand rather than via the dicts above.
    filters = []

    # filters.append("pose")

    # filters.append("rgb-base")
    # filters.append("rgb-prior")
    # filters.append("rgb-sampling")
    # filters.append("rgb-weights")

    # filters.append("rgb-kinetics")
    # filters.append("flow-kinetics")

    # filters.append("rgb")
    # filters.append("crop")
    # filters.append("gauss")
    # filters.append("fovea")

    # filters.append("flowcrop")
    # filters.append("flow")

    # filters.append("MLP")
    #filters.append("best case scenario thresh 0.1")
    #filters.append("two pass scenario thresh 0.1")
    filters.append("fovea")
    filters.append("dense-gt")
    #filters.append("sampling no aug")
    filters.append("dense-2pass")
    #filters.append("weights no aug")

    # filters.append("LSTM5-A-512")
    # filters.append("random")
    # filters.append("LSTM5-B-512")
    # filters.append("LSTM10")

    # filters.append("2st(rgb)")
    # filters.append("2st(crop)")
    # filters.append("2st(gauss)")
    # filters.append("2st(fovea)")

    #filters.append("2st(crop) + flowcrop")
    #filters.append("2st(gauss) + flowcrop")
    #filters.append("2st(fovea) + flowcrop")

    #filters.append("2st(fovea) + mlp")
    #filters.append("2st(crop) + mlp")
    #filters.append("2st(gauss) + mlp")

    # filters.append("2stream")
    #filters.append("2stream + lstm (extra pass)")
    # filters.append("gauss")
    #filters.append("gauss aug")
    #filters.append("LSTMA 512 5 2")
    #filters.append("LSTMA 512 5 3")
    #filters.append("LSTMA 512 5 3")
    #filters.append("LSTMA 1024 5 3")
    #filters.append("LSTMA 2048 5 3")

    #filters.append("LSTMB 512 3 3")
    #filters.append("LSTMB 1024 3 3")
    #filters.append("LSTMB 2048 3 3")

    # filters.append("2st(gauss) + lstm")
    all_detections = []
    #all_detections.append(open(test_dir + "context/lstmA/output_test_ctx_lstm_512_5_3.csv", 'rb'))
    #all_detections.append(open(test_dir + "context/lstmA/output_test_ctx_lstm_1024_5_3.csv", 'rb'))
    #all_detections.append(open(test_dir + "context/lstmA/output_test_ctx_lstm_2048_5_3.csv", 'rb'))
    #all_detections.append(open(test_dir + "context/lstmB/output_test_ctx_lstm_512_3_3.csv", 'rb'))
    #all_detections.append(open(test_dir + "context/lstmB/output_test_ctx_lstm_1024_3_3.csv", 'rb'))
    #all_detections.append(open(test_dir + "context/lstmB/output_test_ctx_lstm_2048_3_3.csv", 'rb'))

    # Pose
    # all_detections.append(open(test_dir + "output_test_flowcrop.csv", 'rb'))

    # Balancing
    #all_detections.append(open(test_dir + "output_test_flowcrop.csv", 'rb'))
    #all_detections.append(open(test_dir + "/rgb_gauss/output_test_gauss.csv", 'rb'))
    #all_detections.append(open(test_dir + "/augmentation/predictions_rgb_gauss_1807241628_1000.csv", 'rb'))
    #all_detections.append(open(test_dir + "/augmentation/output_test_sampling_gauss_1809221859.csv", 'rb'))
    #all_detections.append(open(test_dir + "/augmentation/output_test_weights_gauss_1809221904.csv", 'rb'))
    # RGB Streams
    #all_detections.append(open(test_dir + "/kinetics_init/output_test_rgb_kineticsinit_gauss_1809220212.csv", 'rb'))
    #all_detections.append(open(test_dir + "/kinetics_init/output_test_flow_kineticsinit_1809220244.csv", 'rb'))
    # Flow Streams

    # Context (LSTMs)
    #filters.append("LSTMB 512 3 3")
    #filters.append("LSTMB 512 3 2")
    #filters.append("LSTMB 512 3 1")

    #all_detections.append(open(test_dir + "context/lstmB/output_test_ctx_lstm_512_3_1.csv", 'rb'))
    #all_detections.append(open(test_dir + "context/lstmB/output_test_ctx_lstm_512_3_2.csv", 'rb'))
    #all_detections.append(open(test_dir + "context/lstmB/output_test_ctx_lstm_512_3_3.csv", 'rb'))
    #all_detections.append(open(test_dir + "context/lstmA/output_test_ctx_lstm_512_5_1.csv", 'rb'))
    #all_detections.append(open(test_dir + "context/lstmA/output_test_ctx_lstm_512_5_2.csv", 'rb'))
    #all_detections.append(open(test_dir + "context/lstmA/output_test_ctx_lstm_512_5_3.csv", 'rb'))

    #all_detections.append(open(test_dir + "context/lstmB/output_test_ctx_lstm_32_3_3.csv", 'rb'))
    #all_detections.append(open(test_dir + "context/lstmB/output_test_ctx_lstm_32_5_3.csv", 'rb'))
    #all_detections.append(open(test_dir + "context/lstmB/output_test_ctx_lstm_32_10_3.csv", 'rb'))

    #all_detections.append(open(test_dir + "context/mlp/output_test_ctx.csv", 'rb'))
    #all_detections.append(open(test_dir + "context/mlp/output_test_ctx_mlp_1809212356.csv", 'rb'))
    #all_detections.append(open(test_dir + "context/lstmA/output_test_ctx_lstm_512_5_3_1809220010.csv", 'rb'))
    #all_detections.append(open(test_dir + "random/output_test_random_1809221552.csv", 'rb'))
    #all_detections.append(open(test_dir + "context/lstmB/output_test_ctx_lstm_512_5_3_1809211924.csv", 'rb'))
    #all_detections.append(open(test_dir + "context/lstmB/output_test_ctx_lstm_128_10_3_1809211930.csv", 'rb'))

    # 6 2-streams + baseline
    #all_detections.append(open(test_dir + "/two-streams/output_test_2stream_rgb_1809220100.csv", 'rb'))
    #all_detections.append(open(test_dir + "/two-streams/output_test_2stream_crop.csv", 'rb'))
    #all_detections.append(open(test_dir + "/two-streams/output_test_2stream_gauss.csv", 'rb'))
    all_detections.append(
        open(test_dir + "/two-streams/output_test_2stream_fovea.csv", 'rb'))

    #all_detections.append(open(test_dir + "/two-streams/output_test_2stream_flowcrop_crop_1809220117.csv", 'rb'))
    #all_detections.append(open(test_dir + "/two-streams/output_test_2stream_flowcrop_gauss_1809220152.csv", 'rb'))
    #all_detections.append(open(test_dir + "/two-streams/output_test_2stream_flowcrop_fovea_1809220136.csv", 'rb'))

    # Context Fusions
    # all_detections.append(open(test_dir + "/context_fusion/output_test_3stream_fovea.csv", 'rb'))
    # all_detections.append(open(test_dir + "/context_fusion/output_test_3stream_crop.csv", 'rb'))
    # all_detections.append(open(test_dir + "/context_fusion/output_test_3stream_gauss.csv", 'rb'))
    all_detections.append(
        open(
            test_dir +
            "/context_fusion/output_test_LSTM_FCfusion_contextGT_gauss_1810011737.csv",
            'rb'))
    all_detections.append(
        open(
            test_dir +
            "/context_fusion/output_test_LSTM_FCfusion_context_secondpass_gauss_1810011754.csv",
            'rb'))

    # LSTMs
    #all_detections.append(open(test_dir + "/two-streams/output_test_2stream_rgb_1809220100.csv", 'rb'))
    #all_detections.append(open(test_dir + "/two-streams/output_test_2stream_crop.csv", 'rb'))
    #all_detections.append(open(test_dir + "/two-streams/output_test_2stream_gauss.csv", 'rb'))
    #all_detections.append(open(test_dir + "/two-streams/output_test_2stream_fovea.csv", 'rb'))
    #all_detections.append(open(test_dir + "/context_fusion/output_test_ctx_lstm_fusion_thresh_512_5_3_1809242315.csv", 'rb'))
    #all_detections.append(open(test_dir + "/context_fusion/output_test_ctx_lstm_512_5_3_1809242252.csv", 'rb'))
    #all_detections.append(open(test_dir + "/context_fusion/output_test_ctx_lstmavggoodpedro_512_5_3_1809242338.csv", 'rb'))
    #all_detections.append(open(test_dir + "/context_fusion/output_test_ctx_lstmavg_twophase_thresh02_512_5_3_1809281219.csv", 'rb'))
    #all_detections.append(open(test_dir + "/context_fusion/output_test_ctx_lstmavg_threephase_512_5_3_1809281317.csv", 'rb'))
    #all_detections.append(open(test_dir + "/context_fusion/output_test_ctx_lstm_fusion_thresh01_512_5_3_1809281400.csv", 'rb'))
    #all_detections.append(open(test_dir + "/context_fusion/output_test_ctx_lstmavg_twophase_thresh01_512_5_3_1809281423.csv", 'rb'))
    #all_detections.append(open(test_dir + "rgb_gauss/output_test_gauss.csv", 'rb'))
    #all_detections.append(open(test_dir + "augmentation/output_test_sampling_gauss_1809221859.csv", 'rb'))
    #all_detections.append(open(test_dir + "augmentation/output_test_samplingnoaug_gauss_1809281439.csv", 'rb'))
    #all_detections.append(open(test_dir + "augmentation/output_test_weightsnew_gauss_1809291104.csv", 'rb'))
    #all_detections.append(open(test_dir + "augmentation/output_test_weightsaug_gauss_1809261228.csv", 'rb'))
    # output_test_ctx_lstm_512_5_3_1809242252.csv

    #all_detections.append(open(test_dir + "/two-streams/output_test_2stream_flowcrop_crop_1809220117.csv", 'rb'))
    #all_detections.append(open(test_dir + "/two-streams/output_test_2stream_flowcrop_gauss_1809220152.csv", 'rb'))
    #all_detections.append(open(test_dir + "/two-streams/output_test_2stream_flowcrop_fovea_1809220136.csv", 'rb'))

    # ---------------------------------
    # New run to compare new flow
    #all_detections.append(open(test_dir + "/flow/output_test_flowcrop.csv", 'rb'))
    #all_detections.append(open(test_dir + "/flow/output_test_flow.csv", 'rb'))

    # New 2 and 3 streams
    # all_detections.append(open(test_dir + "output_test_gauss.csv", 'rb'))
    # all_detections.append(open(test_dir + "output_test_gauss_extra.csv", 'rb'))
    # all_detections.append(open(test_dir + "output_test_3stream_gauss.csv", 'rb'))
    # all_detections.append(open(test_dir + "output_test_3stream_crop.csv", 'rb'))

    # Flow, context, 2-stream, 3-stream run
    #all_detections.append(open(test_dir + "output_test_ctx.csv", 'rb'))
    #all_detections.append(open(test_dir + "output_test_flow.csv", 'rb'))

    #all_detections.append(open(test_dir + "output_test_2stream.csv", 'rb'))
    #all_detections.append(open(test_dir + "output_test_3stream.csv", 'rb'))

    # RGB run
    # all_detections.append(open(test_dir + "/rgb_rgb/output_test_rgb.csv", 'rb'))
    # all_detections.append(open(test_dir + "/rgb_crop/output_test_crop.csv", 'rb'))
    # all_detections.append(open(test_dir + "/rgb_gauss/output_test_gauss.csv", 'rb'))
    # all_detections.append(open(test_dir + "/rgb_fovea/output_test_fovea.csv", 'rb'))
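    # Every detections file is scored against the same test ground truth; the
    # class-balanced CSV is used instead when balancing is True.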
    balancing = False

    all_gndtruths = []
    for i in range(len(all_detections)):
        if balancing is False:
            all_gndtruths.append(
                open(root_dir + "AVA_Test_Custom_Corrected.csv", 'rb'))
        else:
            all_gndtruths.append(
                open(root_dir + "AVA_Test_Custom_Corrected_Balanced.csv",
                     'rb'))
    """Runs evaluations given input files.

    Args:
      labelmap: file object containing map of labels to consider, in pbtxt format
      groundtruth: file object
      detections: file object
      exclusions: file object or None.
    """
    categories, class_whitelist = read_labelmap(labelmap)
    logging.info("CATEGORIES (%d):\n%s", len(categories),
                 pprint.pformat(categories, indent=2))
    excluded_keys = read_exclusions(exclusions)

    # Reads detections data.
    x_axis = []
    xpose_ax = []
    xobj_ax = []
    xhuman_ax = []
    ypose_ax = []
    yobj_ax = []
    yhuman_ax = []
    colors_pose = []
    colors_obj = []
    colors_human = []
    finalmAPs = []
    colors = []

    maxY = -1.0

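    # For each (detections, ground truth, label) triple, build a fresh PASCAL evaluator,
    # feed it all boxes, and bucket the resulting per-class APs into
    # pose / human-object / human-human groups for plotting.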
    for detections, gndtruth, filter_type in zip(all_detections, all_gndtruths,
                                                 filters):
        pascal_evaluator = None
        metrics = None
        actions = None
        start = 0

        pascal_evaluator = object_detection_evaluation.PascalDetectionEvaluator(
            categories, matching_iou_threshold=iou)

        # Reads the ground truth data.
        boxes, labels, _ = read_csv(gndtruth, class_whitelist)
        start = time.time()
        for image_key in boxes:
            if image_key in excluded_keys:
                logging.info(("Found excluded timestamp in ground truth: %s. "
                              "It will be ignored."), image_key)
                continue
            pascal_evaluator.add_single_ground_truth_image_info(
                image_key, {
                    standard_fields.InputDataFields.groundtruth_boxes:
                    np.array(boxes[image_key], dtype=float),
                    standard_fields.InputDataFields.groundtruth_classes:
                    np.array(labels[image_key], dtype=int),
                    standard_fields.InputDataFields.groundtruth_difficult:
                    np.zeros(len(boxes[image_key]), dtype=bool)
                })
        print_time("convert groundtruth", start)

        # Run evaluation
        boxes, labels, scores = read_csv(detections, class_whitelist)
        start = time.time()
        for image_key in boxes:
            if image_key in excluded_keys:
                logging.info(("Found excluded timestamp in detections: %s. "
                              "It will be ignored."), image_key)
                continue
            pascal_evaluator.add_single_detected_image_info(
                image_key, {
                    standard_fields.DetectionResultFields.detection_boxes:
                    np.array(boxes[image_key], dtype=float),
                    standard_fields.DetectionResultFields.detection_classes:
                    np.array(labels[image_key], dtype=int),
                    standard_fields.DetectionResultFields.detection_scores:
                    np.array(scores[image_key], dtype=float)
                })
        print_time("convert detections", start)

        start = time.time()
        metrics = pascal_evaluator.evaluate()
        print_time("run_evaluator", start)

        # TODO Show a pretty histogram here besides pprint
        actions = list(metrics.keys())

        final_value = 0.0
        for m in actions:
            ms = m.split("/")[-1]

            if ms == 'mAP@' + str(iou) + 'IOU':
                final_value = metrics[m]
                finalmAPs.append(final_value)
            else:
                # x_axis.append(ms)
                # y_axis.append(metrics[m])
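                # Group per-class AP by action type: ids up to 10 are treated as pose,
                # 11-22 as human-object, the rest as human-human interactions.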
                for cat in categories:
                    if cat['name'].split("/")[-1] == ms:
                        if maxY < metrics[m]:
                            maxY = metrics[m]
                        if cat['id'] <= 10:
                            xpose_ax.append("[" + filter_type + "] " + ms)
                            ypose_ax.append(metrics[m])
                            colors_pose.append('pose')
                        elif cat['id'] <= 22:
                            xobj_ax.append("[" + filter_type + "] " + ms)
                            yobj_ax.append(metrics[m])
                            colors_obj.append('human-object')
                        else:
                            xhuman_ax.append("[" + filter_type + "] " + ms)
                            yhuman_ax.append(metrics[m])
                            colors_human.append('human-human')

                # Make a confusion matrix for this run

        pascal_evaluator = None

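    # split_interleave (a helper defined elsewhere in this module) presumably reorders
    # each group so entries from the different runs alternate in the plot.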
    x_axis = split_interleave(xpose_ax) + split_interleave(
        xobj_ax) + split_interleave(xhuman_ax)
    y_axis = split_interleave(ypose_ax) + split_interleave(
        yobj_ax) + split_interleave(yhuman_ax)
    colors = split_interleave(colors_pose) + split_interleave(
        colors_obj) + split_interleave(colors_human)

    plt.ylabel('frame-mAP')
    top = maxY + 0.1  # offset a bit so it looks good
    sns.set_style("whitegrid")

    g = sns.barplot(y_axis,
                    x_axis,
                    hue=colors,
                    palette=['red', 'blue', 'green'])

    ax = g
    #ax.legend(loc='lower right')
    # annotate axis = seaborn axis
    # for p in ax.patches:
    #    ax.annotate("%.3f" % p.get_height(), (p.get_x() + p.get_width() / 2., p.get_height()),
    #                ha='center', va='center', fontsize=10, color='gray', rotation=90, xytext=(0, 20),
    #                textcoords='offset points')
    # ax.set_ylim(-1, len(y_axis))
    sns.set()
    ax.tick_params(labelsize=6)
    for p in ax.patches:
        p.set_height(p.get_height() * 3)
        ax.annotate("%.3f" % p.get_width(),
                    (p.get_x() + p.get_width(), p.get_y()),
                    xytext=(5, -5),
                    fontsize=8,
                    color='gray',
                    textcoords='offset points')

    _ = g.set_xlim(0, top)  # To make space for the annotations
    pprint.pprint(metrics, indent=2)

    ax.set(ylabel="", xlabel="AP")
    plt.xticks(rotation=0)

    title = ""
    results_file = open("results.txt", "w")
    for filter_type, mAP in zip(filters, finalmAPs):
        ft = filter_type + ': mAP@' + str(iou) + 'IOU = ' + str(mAP) + '\n'
        title += ft
        results_file.write(ft)
    results_file.close()

    # ax.figure.tight_layout()
    ax.figure.subplots_adjust(left=0.2)  # change 0.2 to suit your needs.
    plt.title(title)
    plt.gca().xaxis.grid(True)

    plt.show()

    if len(all_detections) == 1:
        sz = 2
        grid_sz = [1, 1]
    elif len(all_detections) == 2:
        sz = 3
        grid_sz = [1, 2]
    elif len(all_detections) == 3:
        sz = 4
        grid_sz = [2, 2]
    else:
        sz = 5
        grid_sz = [2, 2]

    for i in range(1, sz):
        print(i)
        plt.subplot(grid_sz[0], grid_sz[1], i)
        if i <= len(all_detections):

            # Confusion matrix
            classes = []
            for k in categories:
                classes.append(k['name'])
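            # Assumes a project-local confusion_matrix helper that accepts the ground-truth
            # and detection file objects plus the label list (sklearn's confusion_matrix
            # takes label arrays instead).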
            cm = confusion_matrix(all_gndtruths[i - 1], all_detections[i - 1],
                                  x_axis)
            g = sns.heatmap(cm,
                            annot=True,
                            fmt="d",
                            xticklabels=classes[:10],
                            yticklabels=classes[:10],
                            linewidths=0.5,
                            linecolor='black',
                            cbar=True)

            #t = 0
            # for ytick_label, xtick_label in zip(g.axes.get_yticklabels(), g.axes.get_xticklabels()):
            #    if t <= 9:
            #        ytick_label.set_color("r")
            #        xtick_label.set_color("r")

            #    elif t <= 22:
            #        ytick_label.set_color("b")
            #        xtick_label.set_color("b")
            #    else:
            #        ytick_label.set_color("g")
            #        xtick_label.set_color("g")
            #    t += 1
            plt.xticks(rotation=-90)
            plt.title("Pose Confusion Matrix (" + filters[i - 1] + ")")
    plt.show()
def run_evaluation_threshold(labelmap, groundtruth, exclusions, iou):
    """Runs evaluations of the context-fusion outputs at several detection-score thresholds.

    Args:
      labelmap: file object containing map of labels to consider, in pbtxt format
      groundtruth: file object (unused here; the validation CSV is opened below)
      exclusions: file object or None.
      iou: matching IoU threshold passed to the PASCAL evaluator.
    """

    # sns.palplot(sns.diverging_palette(128, 240, n=10))
    # seq_col_brew = sns.color_palette("Blues_r", 4) # For sequential, blue gradient in reverse
    # Qualitative data palette
    # current_palette = sns.color_palette("Paired")
    # sns.set_palette(current_palette)

    # Detection-score thresholds to compare.
    filters = ["0.1", "0.2", "0.3", "0.4", "0.5", "0.6", "0.7", "0.8", "0.9"]

    root_dir = '../../../data/AVA/files/'
    ftype = "fusion"

    all_detections = []
    ts = "1809281055"
    for f in filters:
        all_detections.append(
            open(
                "../thresholds/context_" + ftype +
                "/predictions_fusion_avg_fovea_" + ts + "_" + f + ".csv",
                'rb'))

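    # The same validation ground truth is reused for every threshold.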
    all_gndtruths = []
    for i in range(len(filters)):
        all_gndtruths.append(
            open(root_dir + "AVA_Val_Custom_Corrected.csv", 'rb'))

    #all_gndtruths.append(open("AVA_Test_Custom_Corrected.csv", 'rb'))
    #all_gndtruths.append(open("AVA_Test_Custom_Corrected.csv", 'rb'))
    """Runs evaluations given input files.

    Args:
      labelmap: file object containing map of labels to consider, in pbtxt format
      groundtruth: file object
      detections: file object
      exclusions: file object or None.
    """
    categories, class_whitelist = read_labelmap(labelmap)
    logging.info("CATEGORIES (%d):\n%s", len(categories),
                 pprint.pformat(categories, indent=2))
    excluded_keys = read_exclusions(exclusions)

    # Reads detections data.
    x_axis = []
    xpose_ax = []
    xobj_ax = []
    xhuman_ax = []
    ypose_ax = []
    yobj_ax = []
    yhuman_ax = []
    colors_pose = []
    colors_obj = []
    colors_human = []
    finalmAPs = []
    colors = []

    maxY = -1.0

    for detections, gndtruth, filter_type in zip(all_detections, all_gndtruths,
                                                 filters):
        pascal_evaluator = None
        metrics = None
        actions = None
        start = 0

        pascal_evaluator = object_detection_evaluation.PascalDetectionEvaluator(
            categories, matching_iou_threshold=iou)

        # Reads the ground truth data.
        boxes, labels, _ = read_csv(gndtruth, class_whitelist)
        start = time.time()
        for image_key in boxes:
            if image_key in excluded_keys:
                logging.info(("Found excluded timestamp in ground truth: %s. "
                              "It will be ignored."), image_key)
                continue
            pascal_evaluator.add_single_ground_truth_image_info(
                image_key, {
                    standard_fields.InputDataFields.groundtruth_boxes:
                    np.array(boxes[image_key], dtype=float),
                    standard_fields.InputDataFields.groundtruth_classes:
                    np.array(labels[image_key], dtype=int),
                    standard_fields.InputDataFields.groundtruth_difficult:
                    np.zeros(len(boxes[image_key]), dtype=bool)
                })
        print_time("convert groundtruth", start)

        # Run evaluation
        boxes, labels, scores = read_csv(detections, class_whitelist)
        start = time.time()
        for image_key in boxes:
            if image_key in excluded_keys:
                logging.info(("Found excluded timestamp in detections: %s. "
                              "It will be ignored."), image_key)
                continue
            pascal_evaluator.add_single_detected_image_info(
                image_key, {
                    standard_fields.DetectionResultFields.detection_boxes:
                    np.array(boxes[image_key], dtype=float),
                    standard_fields.DetectionResultFields.detection_classes:
                    np.array(labels[image_key], dtype=int),
                    standard_fields.DetectionResultFields.detection_scores:
                    np.array(scores[image_key], dtype=float)
                })
        print_time("convert detections", start)

        start = time.time()
        metrics = pascal_evaluator.evaluate()
        print_time("run_evaluator", start)

        # TODO Show a pretty histogram here besides pprint
        actions = list(metrics.keys())

        final_value = 0.0
        for m in actions:
            ms = m.split("/")[-1]

            if ms == 'mAP@' + str(iou) + 'IOU':
                final_value = metrics[m]
                finalmAPs.append(final_value)
            else:
                # x_axis.append(ms)
                # y_axis.append(metrics[m])
                for cat in categories:
                    if cat['name'].split("/")[-1] == ms:
                        if maxY < metrics[m]:
                            maxY = metrics[m]
                        if cat['id'] <= 10:
                            xpose_ax.append("(" + filter_type + ") " + ms)
                            ypose_ax.append(metrics[m])
                            colors_pose.append('red')
                        elif cat['id'] <= 22:
                            xobj_ax.append("(" + filter_type + ") " + ms)
                            yobj_ax.append(metrics[m])
                            colors_obj.append('blue')
                        else:
                            xhuman_ax.append("(" + filter_type + ") " + ms)
                            yhuman_ax.append(metrics[m])
                            colors_human.append('green')

                # Make a confusion matrix for this run

        pascal_evaluator = None

    x_axis = split_interleave(xpose_ax) + split_interleave(
        xobj_ax) + split_interleave(xhuman_ax)
    y_axis = split_interleave(ypose_ax) + split_interleave(
        yobj_ax) + split_interleave(yhuman_ax)
    colors = split_interleave(colors_pose) + split_interleave(
        colors_obj) + split_interleave(colors_human)
    print(filters)
    print(finalmAPs)

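    # Bar chart of overall mAP per score threshold; the best-performing threshold is drawn in red.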
    plt.ylabel('frame-mAP')
    top = 0.1  # fixed y-axis limit, leaving space for the annotations
    sns.set_style("whitegrid")
    clrs = ['blue' if (x < max(finalmAPs)) else 'red' for x in finalmAPs]
    g = sns.barplot(filters, finalmAPs, palette=clrs)
    ax = g
    # annotate axis = seaborn axis
    for p in ax.patches:
        ax.annotate("%.4f" % p.get_height(),
                    (p.get_x() + p.get_width() / 2., p.get_height()),
                    ha='center',
                    va='center',
                    fontsize=10,
                    color='gray',
                    rotation=90,
                    xytext=(0, 20),
                    textcoords='offset points')
    _ = g.set_ylim(0, top)  # To make space for the annotations

    plt.show()
def run_evaluation(labelmap, groundtruth, exclusions, iou):
    """Runs evaluations for the selected experiment's detection files.

    Args:
      labelmap: file object containing map of labels to consider, in pbtxt format
      groundtruth: file object (unused here; ground-truth CSVs are opened below)
      exclusions: file object or None.
      iou: matching IoU threshold passed to the PASCAL evaluator.
    """

    root_dir = '../../../data/AVA/files/'
    test_dir = "../test_outputs/"
    # Map experiment names to legend labels (filters) and to the detection CSV files to score;
    # 'experiment' selects which configuration is evaluated below.
    experiments_filters = {}
    experiments_detections = {}
    experiment = 'balancing'

    # Baseline
    experiments_filters['baseline'] = ['RGB', 'Flow', 'RGB+Flow']
    experiments_detections['baseline'] = [
        open(test_dir + "/rgb_rgb/output_test_rgb.csv", 'rb'),
        open(test_dir + "/flow/output_test_flow.csv", 'rb'),
        open(test_dir + "/two-streams/output_test_2stream_rgb_1809220100.csv",
             'rb')
    ]

    # RGBS
    experiments_filters['rgb-streams-aug'] = ['Crop', 'Gauss', 'Fovea']
    experiments_detections['rgb-streams-aug'] = [
        open(test_dir + "/rgb_crop/output_test_crop.csv", 'rb'),
        open(test_dir + "/rgb_gauss/output_test_gauss.csv", 'rb'),
        open(test_dir + "/rgb_fovea/output_test_fovea.csv", 'rb')
    ]

    # Flows
    experiments_filters['flow vs flowcrop'] = ['Flow', 'Flowcrop']
    experiments_detections['flow vs flowcrop'] = [
        open(test_dir + "/flow/output_test_flow.csv", 'rb'),
        open(test_dir + "/flow/output_test_flowcrop.csv", 'rb'),
    ]

    # Two-streams
    experiments_filters['two-streams'] = [
        'Two-Stream-Crop', 'Two-Stream-GBB', 'Two-Stream-Fovea'
    ]
    experiments_detections['two-streams'] = [
        open(test_dir + "/two-streams/output_test_2stream_crop_1807252254.csv",
             'rb'),
        open(
            test_dir + "/two-streams/output_test_2stream_gauss_1807252309.csv",
            'rb'),
        open(test_dir + "/two-streams/output_test_2stream_fovea.csv", 'rb')
    ]

    #experiments_filters['two-streams'] = ['Two-Stream-Crop', 'Two-Stream-Fovea']
    #experiments_detections['two-streams'] = [open(test_dir + "/two-streams/output_test_2stream_crop_1807252254.csv", 'rb'), open(test_dir + "/two-streams/output_test_2stream_fovea.csv", 'rb')]

    experiments_filters['two-streams-flowcrop'] = [
        'Two-Stream-Crop (Flowcrop)', 'Two-Stream-Gauss (Flowcrop)',
        'Two-Stream-Fovea (Flowcrop)'
    ]
    experiments_detections['two-streams-flowcrop'] = [
        open(
            test_dir +
            "/two-streams/output_test_2stream_flowcrop_crop_1809220117.csv",
            'rb'),
        open(
            test_dir +
            "/two-streams/output_test_2stream_flowcrop_gauss_1809220152.csv",
            'rb'),
        open(
            test_dir +
            "/two-streams/output_test_2stream_flowcrop_fovea_1809220136.csv",
            'rb')
    ]

    # MLP VS LSTM
    experiments_filters['mlp vs lstm'] = ['MLP', 'LSTMA', 'LSTMB']
    experiments_detections['mlp vs lstm'] = [
        open(test_dir + "context/mlp/output_test_ctx_mlp_1809212356.csv",
             'rb'),
        open(test_dir + "context/lstmA/output_test_ctx_lstm_128_3_3.csv",
             'rb'),
        open(test_dir + "context/lstmB/output_test_ctx_lstm_128_3_3.csv", 'rb')
    ]

    # LSTMS
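    # The LSTM output file names below appear to encode hidden units (nhu),
    # time window (tws) and neighbourhood size (neighbs).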
    nhu = 512
    neighbs = 3
    tws = 3
    experiments_filters['lstmA vs lstmB'] = ['LSTM A', 'LSTM B']
    experiments_detections['lstmA vs lstmB'] = [
        open(
            test_dir + "context/lstmA/output_test_ctx_lstm_" + str(nhu) + "_" +
            str(tws) + "_" + str(neighbs) + ".csv", 'rb'),
        open(
            test_dir + "context/lstmB/output_test_ctx_lstm_" + str(nhu) + "_" +
            str(tws) + "_" + str(neighbs) + ".csv", 'rb')
    ]

    # Fusions
    experiments_filters['class-score-fusion-gt'] = [
        '2-stream Fovea', 'Class Score Fusion (GT)'
    ]
    experiments_detections['class-score-fusion-gt'] = [
        open(test_dir + "/two-streams/output_test_2stream_fovea.csv", 'rb'),
        open(
            test_dir +
            "/context_fusion/output_test_ctx_lstm_fusion_512_5_3_1809242338.csv",
            'rb')
    ]

    experiments_filters['class-score-fusion-two-pass'] = [
        '2-stream Fovea', 'Class Score Fusion (Two Pass)'
    ]
    experiments_detections['class-score-fusion-two-pass'] = [
        open(test_dir + "/two-streams/output_test_2stream_fovea.csv", 'rb'),
        open(
            test_dir +
            "/context_fusion/output_test_ctx_lstmavg_twophase_512_5_3_1809281149.csv",
            'rb')
    ]

    experiments_filters['dense-fusion-gt'] = ['2-stream Fovea', 'Dense Fusion']
    experiments_detections['dense-fusion-gt'] = [
        open(test_dir + "/two-streams/output_test_2stream_fovea.csv", 'rb'),
        open(
            test_dir +
            "/context_fusion/output_test_LSTM_FCfusion_contextGT_fovea_1810011737.csv",
            'rb')
    ]

    experiments_filters['dense-fusion-two-pass'] = [
        '2-stream Fovea', 'Dense Fusion (Two Pass)'
    ]
    experiments_detections['dense-fusion-two-pass'] = [
        open(test_dir + "/two-streams/output_test_2stream_fovea.csv", 'rb'),
        open(
            test_dir +
            "/context_fusion/output_test_LSTM_FCfusion_context_secondpass_fovea_1810011754.csv",
            'rb')
    ]

    # Voting
    experiments_filters['class-score-fusion-gt-voting'] = [
        'Class Score Fusion (0.2)', 'Class Score Fusion (0.1)'
    ]
    experiments_detections['class-score-fusion-gt-voting'] = [
        open(
            test_dir +
            "/context_fusion/output_test_ctx_lstm_fusion_thresh02_512_5_3_1809242315.csv",
            'rb'),
        open(
            test_dir +
            "/context_fusion/output_test_ctx_lstm_fusion_thresh01_512_5_3_1809281400.csv",
            'rb')
    ]

    experiments_filters['class-score-fusion-two-pass-voting'] = [
        'Class Score Fusion (Two-pass) (0.2)',
        'Class Score Fusion (Two-pass) (0.1)'
    ]
    experiments_detections['class-score-fusion-two-pass-voting'] = [
        open(
            test_dir +
            "/context_fusion/output_test_ctx_lstmavg_twophase_thresh02_512_5_3_1809281219.csv",
            'rb'),
        open(
            test_dir +
            "/context_fusion/output_test_ctx_lstmavg_twophase_thresh01_512_5_3_1809281423.csv",
            'rb')
    ]

    # Balancing
    experiments_filters['balancing'] = ['Imbalanced', 'Oversampling']
    experiments_detections['balancing'] = [
        open(test_dir + "rgb_gauss/output_test_gauss.csv", 'rb'),
        open(
            test_dir +
            "augmentation/output_test_samplingnoaug_gauss_1809281439.csv",
            'rb')
    ]

    # Extra experiments
    experiments_filters['context-fusion mlp'] = [
        '2-stream Fovea', 'Dense Fusion MLP'
    ]
    experiments_detections['context-fusion mlp'] = [
        open(test_dir + "/two-streams/output_test_2stream_fovea.csv", 'rb'),
        open(test_dir + "/context_fusion/output_test_3stream_fovea.csv", 'rb')
    ]

    experiments_filters['context-fusion extra pass'] = [
        'Class Score (Two Pass)', 'Class Score (Extra Pass)'
    ]
    experiments_detections['context-fusion extra pass'] = [
        open(
            test_dir +
            "/context_fusion/output_test_ctx_lstmavg_twophase_512_5_3_1809281149.csv",
            'rb'),
        open(
            test_dir +
            "/context_fusion/output_test_ctx_lstmavg_threephase_512_5_3_1809281317.csv",
            'rb')
    ]

    experiments_filters['random'] = ['Random guessing']
    experiments_detections['random'] = [
        open(test_dir + "random/output_test_random_1809221552.csv", 'rb')
    ]

    # Best
    experiments_filters['best'] = [
        '2-stream Fovea', 'Class Score Fusion (GT)',
        'Class Score Fusion (GT, v=0.1)'
    ]
    experiments_detections['best'] = [
        open(test_dir + "/two-streams/output_test_2stream_fovea.csv", 'rb'),
        open(
            test_dir +
            "/context_fusion/output_test_ctx_lstm_fusion_512_5_3_1809242338.csv",
            'rb'),
        open(
            test_dir +
            "/context_fusion/output_test_ctx_lstm_fusion_thresh01_512_5_3_1809281400.csv",
            'rb')
    ]

    filters = experiments_filters[experiment]
    all_detections = experiments_detections[experiment]

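    # Every detections file is scored against the same test ground truth; the
    # class-balanced CSV is used instead when balancing is True.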
    balancing = False

    all_gndtruths = []
    for i in range(len(all_detections)):
        if balancing is False:
            all_gndtruths.append(
                open(root_dir + "AVA_Test_Custom_Corrected.csv", 'rb'))
        else:
            all_gndtruths.append(
                open(root_dir + "AVA_Test_Custom_Corrected_Balanced.csv",
                     'rb'))
    """Runs evaluations given input files.

    Args:
      labelmap: file object containing map of labels to consider, in pbtxt format
      groundtruth: file object
      detections: file object
      exclusions: file object or None.
    """
    categories, class_whitelist = read_labelmap(labelmap)
    logging.info("CATEGORIES (%d):\n%s", len(categories),
                 pprint.pformat(categories, indent=2))
    excluded_keys = read_exclusions(exclusions)

    # Reads detections data.
    x_axis = []
    xpose_ax = []
    xobj_ax = []
    xhuman_ax = []
    ypose_ax = []
    yobj_ax = []
    yhuman_ax = []
    colors_pose = []
    colors_obj = []
    colors_human = []
    finalmAPs = []
    colors = []

    maxY = -1.0

    for detections, gndtruth, filter_type in zip(all_detections, all_gndtruths,
                                                 filters):
        pascal_evaluator = None
        metrics = None
        actions = None
        start = 0

        pascal_evaluator = object_detection_evaluation.PascalDetectionEvaluator(
            categories, matching_iou_threshold=iou)

        # Reads the ground truth data.
        boxes, labels, _ = read_csv(gndtruth, class_whitelist)
        start = time.time()
        for image_key in boxes:
            if image_key in excluded_keys:
                logging.info(("Found excluded timestamp in ground truth: %s. "
                              "It will be ignored."), image_key)
                continue
            pascal_evaluator.add_single_ground_truth_image_info(
                image_key, {
                    standard_fields.InputDataFields.groundtruth_boxes:
                    np.array(boxes[image_key], dtype=float),
                    standard_fields.InputDataFields.groundtruth_classes:
                    np.array(labels[image_key], dtype=int),
                    standard_fields.InputDataFields.groundtruth_difficult:
                    np.zeros(len(boxes[image_key]), dtype=bool)
                })
        print_time("convert groundtruth", start)

        # Run evaluation
        boxes, labels, scores = read_csv(detections, class_whitelist)
        start = time.time()
        for image_key in boxes:
            if image_key in excluded_keys:
                logging.info(("Found excluded timestamp in detections: %s. "
                              "It will be ignored."), image_key)
                continue
            pascal_evaluator.add_single_detected_image_info(
                image_key, {
                    standard_fields.DetectionResultFields.detection_boxes:
                    np.array(boxes[image_key], dtype=float),
                    standard_fields.DetectionResultFields.detection_classes:
                    np.array(labels[image_key], dtype=int),
                    standard_fields.DetectionResultFields.detection_scores:
                    np.array(scores[image_key], dtype=float)
                })
        print_time("convert detections", start)

        start = time.time()
        metrics = pascal_evaluator.evaluate()
        print_time("run_evaluator", start)

        # TODO Show a pretty histogram here besides pprint
        actions = list(metrics.keys())

        final_value = 0.0
        for m in actions:
            ms = m.split("/")[-1]

            if ms == 'mAP@' + str(iou) + 'IOU':
                final_value = metrics[m]
                finalmAPs.append(final_value)
            else:
                # x_axis.append(ms)
                # y_axis.append(metrics[m])
                for cat in categories:
                    if cat['name'].split("/")[-1] == ms:
                        if maxY < metrics[m]:
                            maxY = metrics[m]
                        if cat['id'] <= 10:
                            xpose_ax.append("[" + filter_type + "] " + ms)
                            ypose_ax.append(metrics[m])
                            colors_pose.append('pose')
                        elif cat['id'] <= 22:
                            xobj_ax.append("[" + filter_type + "] " + ms)
                            yobj_ax.append(metrics[m])
                            colors_obj.append('human-object')
                        else:
                            xhuman_ax.append("[" + filter_type + "] " + ms)
                            yhuman_ax.append(metrics[m])
                            colors_human.append('human-human')

                # Make a confusion matrix for this run

        pascal_evaluator = None
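    # Interleave per-group results across the len(filters) runs, presumably so each
    # action's bars from different runs sit side by side in the plot.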
    parts = len(filters)
    x_axis = split_interleave(xpose_ax, parts) + split_interleave(
        xobj_ax, parts) + split_interleave(xhuman_ax, parts)
    y_axis = split_interleave(ypose_ax, parts) + split_interleave(
        yobj_ax, parts) + split_interleave(yhuman_ax, parts)
    colors = split_interleave(colors_pose, parts) + split_interleave(
        colors_obj, parts) + split_interleave(colors_human, parts)

    plt.ylabel('frame-mAP')
    top = maxY + 0.1  # offset a bit so it looks good
    sns.set_style("whitegrid")

    g = sns.barplot(y_axis,
                    x_axis,
                    hue=colors,
                    palette=['red', 'blue', 'green'])

    ax = g
    # ax.legend(loc='lower right')
    # annotate axis = seaborn axis
    # for p in ax.patches:
    #    ax.annotate("%.3f" % p.get_height(), (p.get_x() + p.get_width() / 2., p.get_height()),
    #                ha='center', va='center', fontsize=10, color='gray', rotation=90, xytext=(0, 20),
    #                textcoords='offset points')
    # ax.set_ylim(-1, len(y_axis))
    sns.set()
    ax.tick_params(labelsize=6)
    for p in ax.patches:
        p.set_height(p.get_height() * 3)
        ax.annotate("%.3f" % p.get_width(),
                    (p.get_x() + p.get_width(), p.get_y()),
                    xytext=(5, -5),
                    fontsize=8,
                    color='gray',
                    textcoords='offset points')

    _ = g.set_xlim(0, top)  # To make space for the annotations
    pprint.pprint(metrics, indent=2)

    ax.set(ylabel="", xlabel="AP")
    plt.xticks(rotation=0)

    title = ""
    results_file = open("results.txt", "w")
    for filter_type, mAP in zip(filters, finalmAPs):
        ft = filter_type + ': mAP@' + str(iou) + 'IOU = ' + str(mAP) + '\n'
        title += ft
        results_file.write(ft)
    results_file.close()

    # ax.figure.tight_layout()
    ax.figure.subplots_adjust(left=0.2)  # change 0.2 to suit your needs.
    plt.title(title)
    plt.gca().xaxis.grid(True)

    plt.show()

    if len(all_detections) == 1:
        sz = 2
        grid_sz = [1, 1]
    elif len(all_detections) == 2:
        sz = 3
        grid_sz = [1, 2]
    elif len(all_detections) == 3:
        sz = 4
        grid_sz = [2, 2]
    else:
        sz = 5
        grid_sz = [2, 2]

    for i in range(1, sz):
        print(i)
        plt.subplot(grid_sz[0], grid_sz[1], i)
        if i <= len(all_detections):

            # Confusion matrix
            classes = []
            for k in categories:
                classes.append(k['name'])
            cm = confusion_matrix(all_gndtruths[i - 1], all_detections[i - 1],
                                  x_axis)
            g = sns.heatmap(cm,
                            annot=True,
                            fmt="d",
                            xticklabels=classes[:10],
                            yticklabels=classes[:10],
                            linewidths=0.5,
                            linecolor='black',
                            cbar=True,
                            vmin=0,
                            vmax=2000)

            # t = 0
            # for ytick_label, xtick_label in zip(g.axes.get_yticklabels(), g.axes.get_xticklabels()):
            #    if t <= 9:
            #        ytick_label.set_color("r")
            #        xtick_label.set_color("r")

            #    elif t <= 22:
            #        ytick_label.set_color("b")
            #        xtick_label.set_color("b")
            #    else:
            #        ytick_label.set_color("g")
            #        xtick_label.set_color("g")
            #    t += 1
            plt.xticks(rotation=-90)
            plt.title("Pose Confusion Matrix (" + filters[i - 1] + ")")
    plt.show()