Esempio n. 1
0
def tao_evaluation(tao_ann_file, anns, results_coco_format):
    """Convert per-image tracking results to TAO format and run ``TaoEval``.

    Args:
        tao_ann_file: Ground-truth annotation source passed to ``TaoEval``.
        anns (dict): Annotation dict whose ``'images'`` list is iterated in
            lockstep with ``results_coco_format``. Each image entry must
            provide ``'id'``, ``'frame_id'`` and ``'video_id'``.
        results_coco_format (list[dict]): One mapping per image, from
            instance id (anything ``int()`` accepts) to a result dict with
            ``'label'`` and ``'bbox'``. ``bbox[:-1]`` is the box passed to
            ``xyxy2xywh`` and ``bbox[-1]`` is the detection score.

    Returns:
        The value of ``TaoEval.get_results()``.

    Note:
        Sets the root logger level to INFO as a side effect, presumably so
        that TaoEval's logging output is shown.
    """
    import logging

    from tao.toolkit.tao import TaoEval

    # Every time a new video starts (frame_id == 0), the track-id offset grows
    # by this stride so that per-video instance ids become globally unique.
    # NOTE(review): ids collide if a single video uses instance ids >= 10000.
    track_id_stride = 10000

    images = anns['images']
    if len(images) != len(results_coco_format):
        # zip() below silently drops the unmatched tail; make that visible.
        logging.getLogger(__name__).warning(
            'tao_evaluation: %d images but %d result entries; '
            'only the first %d pairs are evaluated.',
            len(images), len(results_coco_format),
            min(len(images), len(results_coco_format)))

    track_id_offset = 0
    results_tao_format = []
    for img, results_in_img in zip(images, results_coco_format):
        if img['frame_id'] == 0:
            track_id_offset += track_id_stride

        for instance_id, result in results_in_img.items():
            bbox = result['bbox']
            results_tao_format.append({
                "image_id": img['id'],
                # +1: the label is treated as a 0-based class index, while
                # category ids are 1-based.
                "category_id": result['label'] + 1,
                "bbox": xyxy2xywh(bbox[:-1]),
                "score": bbox[-1],
                "track_id": int(instance_id) + track_id_offset,
                "video_id": img['video_id'],
            })

    logging.getLogger().setLevel(logging.INFO)
    tao_eval = TaoEval(tao_ann_file, results_tao_format)
    tao_eval.run()
    tao_eval.print_results()
    return tao_eval.get_results()
Esempio n. 2
0
    def evaluate(self,
                 results,
                 metric=('track',),
                 logger=None,
                 resfile_path=None):
        """Evaluate tracking and/or detection results.

        Args:
            results: Raw model outputs, forwarded to ``self.format_results``.
            metric (str | list[str] | tuple[str]): Metric(s) to compute.
                ``'track'`` runs ``TaoEval``; ``'bbox'`` runs ``LVISEval``.
                Defaults to ``('track',)``.
            logger: Forwarded to ``print_log``.
            resfile_path: Forwarded to ``self.format_results``.

        Returns:
            dict: ``track_<AP key>`` and/or ``bbox_<AP key>`` entries as
            floats rounded to 3 decimals. When ``'bbox'`` is requested, also
            ``bbox_mAP_copypaste``, a space-separated ``key:value`` summary.

        Raises:
            TypeError: If ``metric`` is not a str, list or tuple.
            KeyError: If an unsupported metric name is requested.
        """
        if isinstance(metric, str):
            metrics = [metric]
        elif isinstance(metric, (list, tuple)):
            metrics = list(metric)
        else:
            raise TypeError('metric must be a str, list or tuple.')
        allowed_metrics = ['bbox', 'track']
        for name in metrics:
            if name not in allowed_metrics:
                raise KeyError(f'metric {name} is not supported.')

        def _to_3dp(value):
            # Round to 3 decimals via string formatting, as reported values.
            return float('{:.3f}'.format(float(value)))

        result_files, tmp_dir = self.format_results(results, resfile_path)
        eval_results = dict()
        # try/finally so the temporary result directory is always removed,
        # even when an evaluator raises.
        try:
            if 'track' in metrics:
                from tao.toolkit.tao import TaoEval
                print_log('Evaluating TAO results...', logger)
                tao_eval = TaoEval(self.ann_file, result_files['track'])
                # Restrict evaluation to this dataset's images and categories,
                # and to IoU thresholds 0.5 and 0.75 only.
                tao_eval.params.img_ids = self.img_ids
                tao_eval.params.cat_ids = self.cat_ids
                tao_eval.params.iou_thrs = np.array([0.5, 0.75])
                tao_eval.run()

                tao_eval.print_results()
                for k, v in tao_eval.get_results().items():
                    if isinstance(k, str) and k.startswith('AP'):
                        eval_results[f'track_{k}'] = _to_3dp(v)

            if 'bbox' in metrics:
                print_log('Evaluating detection results...', logger)
                lvis_gt = LVIS(self.ann_file)
                lvis_dt = LVISResults(lvis_gt, result_files['bbox'])
                lvis_eval = LVISEval(lvis_gt, lvis_dt, 'bbox')
                lvis_eval.params.imgIds = self.img_ids
                lvis_eval.params.catIds = self.cat_ids
                lvis_eval.evaluate()
                lvis_eval.accumulate()
                lvis_eval.summarize()
                lvis_eval.print_results()
                ap_items = [(k, v)
                            for k, v in lvis_eval.get_results().items()
                            if k.startswith('AP')]
                for k, v in ap_items:
                    eval_results[f'bbox_{k}'] = _to_3dp(v)
                eval_results['bbox_mAP_copypaste'] = ' '.join(
                    '{}:{:.3f}'.format(k, float(v)) for k, v in ap_items)
        finally:
            if tmp_dir is not None:
                tmp_dir.cleanup()

        return eval_results