def evaluate(self, results, metric='acc', logger=None, **kwargs):
    """Score recognition results against the dataset's ground-truth text.

    Args:
        results (list): Testing results of the dataset. Entry ``i`` must
            expose a ``'text'`` item holding the prediction for sample ``i``.
        metric (str | list[str]): Metrics to be evaluated. Accepted for
            interface compatibility; this implementation does not read it.
        logger (logging.Logger | str | None): Logger used for printing
            related information during evaluation. Not used by this
            implementation. Default: None.
        **kwargs: Accepted and ignored.

    Returns:
        dict[str: float]: Whatever ``eval_ocr_metric`` returns for the
        collected predictions and ground truths.
    """
    sample_indices = range(len(self))
    # Ground truth is read from the annotation records and predictions from
    # ``results``; both use the same indices so the two lists stay aligned.
    gt_texts = [self.data_infos[idx]['text'] for idx in sample_indices]
    pred_texts = [results[idx]['text'] for idx in sample_indices]
    return eval_ocr_metric(pred_texts, gt_texts)
def main():
    """Run recognition inference over a list of images and report results.

    Command-line arguments:
        img_root_path: Directory that the image paths in the list are
            relative to.
        img_list: Text file with one sample per line: an image path,
            optionally followed by whitespace and a ground-truth label.
        config: Model config file.
        checkpoint: Model checkpoint file.
        --out-dir: Directory receiving the log file, the visualizations
            and the saved results (default ``./results``).
        --show: Display each visualization instead of writing it to disk.
        --device: Device used for inference (default ``cuda:0``).

    For every listed image the model prediction is visualized. When a
    ground-truth label is present and visualizations are written to disk,
    the visualization is additionally copied into ``correct`` or ``wrong``
    under the output directory. If every processed line carried a label,
    the predictions are evaluated with ``eval_ocr_metric`` and logged.

    Raises:
        FileNotFoundError: If a listed image does not exist under
            ``img_root_path``.
    """
    parser = ArgumentParser()
    parser.add_argument('img_root_path', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument(
        '--out-dir', type=str, default='./results', help='Dir to save results')
    parser.add_argument(
        '--show', action='store_true', help='show image or save')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    args = parser.parse_args()

    # The log file is placed inside out_dir, so the directory has to exist
    # before the logger is handed that path.
    mmcv.mkdir_or_exist(args.out_dir)

    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(args.out_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level='INFO')

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    if hasattr(model, 'module'):
        # Unwrap a wrapped model so its own methods are called directly.
        model = model.module

    # Start Inference
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)
    correct_vis_dir = osp.join(args.out_dir, 'correct')
    mmcv.mkdir_or_exist(correct_vis_dir)
    wrong_vis_dir = osp.join(args.out_dir, 'wrong')
    mmcv.mkdir_or_exist(wrong_vis_dir)
    img_paths, pred_labels, gt_labels = [], [], []

    lines = list_from_file(args.img_list)
    progressbar = ProgressBar(task_num=len(lines))
    num_gt_label = 0
    for line in lines:
        progressbar.update()
        item_list = line.strip().split()
        if not item_list:
            # Blank line (e.g. a trailing newline in the list file): there
            # is no image to process, and indexing would raise IndexError.
            continue
        img_file = item_list[0]
        gt_label = ''
        if len(item_list) >= 2:
            gt_label = item_list[1]
            num_gt_label += 1
        img_path = osp.join(args.img_root_path, img_file)
        if not osp.exists(img_path):
            raise FileNotFoundError(img_path)

        # Test a single image
        result = model_inference(model, img_path)
        pred_label = result['text']

        # Flatten the relative path so every visualization lands directly
        # in the output directory.
        out_img_name = img_file.replace('/', '_')
        out_file = osp.join(out_vis_dir, out_img_name)
        kwargs_dict = {
            'gt_label': gt_label,
            'show': args.show,
            'out_file': '' if args.show else out_file
        }
        model.show_result(img_path, result, **kwargs_dict)

        # Sort the saved visualization into correct/wrong. In --show mode no
        # output file is requested, so there is nothing to copy.
        if gt_label != '' and not args.show:
            if gt_label == pred_label:
                dst_file = osp.join(correct_vis_dir, out_img_name)
            else:
                dst_file = osp.join(wrong_vis_dir, out_img_name)
            shutil.copy(out_file, dst_file)

        img_paths.append(img_path)
        gt_labels.append(gt_label)
        pred_labels.append(pred_label)

    # Save results
    save_results(img_paths, pred_labels, gt_labels, args.out_dir)

    # Evaluate only when every processed sample came with a ground-truth
    # label; a partially labelled list cannot be scored.
    if num_gt_label == len(pred_labels):
        # eval
        eval_results = eval_ocr_metric(pred_labels, gt_labels)
        logger.info('\n' + '-' * 100)
        info = ('eval on testset with img_root_path '
                f'{args.img_root_path} and img_list {args.img_list}\n')
        logger.info(info)
        logger.info(eval_results)

    print(f'\nInference done, and results saved in {args.out_dir}\n')