Example #1
def main(args: EvalArgs):
    # Local imports (imports that require tensorflow)
    from calamari_ocr.ocr.scenario import CalamariScenario
    from calamari_ocr.ocr.dataset.data import Data
    from calamari_ocr.ocr.evaluator import Evaluator

    if args.checkpoint:
        # Restore the data parameters (e.g. the text preprocessor) from the
        # saved model so that GT and predictions are normalized consistently
        saved_model = SavedCalamariModel(args.checkpoint, auto_update=True)
        trainer_params = CalamariScenario.trainer_cls().params_cls().from_dict(saved_model.dict)
        data_params = trainer_params.scenario.data
    else:
        data_params = Data.default_params()

    data = Data(data_params)

    # If no prediction dataset is given, derive one from the ground-truth dataset
    pred_data = args.pred if args.pred is not None else args.gt.to_prediction()
    evaluator = Evaluator(args.evaluator, data=data)
    evaluator.preload_gt(gt_dataset=args.gt)
    r = evaluator.run(gt_dataset=args.gt, pred_dataset=pred_data)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print(
        "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
            r["avg_ler"], r["total_char_errs"], r["total_chars"], r["total_sync_errs"]
        )
    )

    # Print the most frequent confusions, sorted descending by frequency
    print_confusions(r, args.n_confusions)

    # Re-read the GT samples so the worst lines can be reported by sample id
    samples = data.create_pipeline(evaluator.params.setup, args.gt).reader().samples()
    print_worst_lines(r, samples, args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(
            args.xlsx_output,
            [
                {
                    "prefix": "evaluation",
                    "results": r,
                    "gt_files": [s["id"] for s in samples],
                }
            ],
        )

    return r
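A hedged sketch of how this entry point might be wired to the command line, mirroring the PAIArgumentParser pattern from Example #2 below; PAIArgumentParser is the real paiargparse package, while EvalArgs and its registration as a root argument are assumptions based only on the signature of main():

def run(argv=None):
    # Sketch only: EvalArgs is the dataclass consumed by main() above; its
    # fields are not shown in these examples, so this wiring is an assumption.
    from paiargparse import PAIArgumentParser

    parser = PAIArgumentParser()
    parser.add_root_argument("root", EvalArgs)  # root dataclass, name assumed
    return main(parser.parse_args(args=argv).root)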
Example #2
def main(args=None):
    parser = PAIArgumentParser()
    parser.add_argument('--version', action='version', version='%(prog)s v' + __version__)

    parser.add_argument("--n_cols", type=int, default=1)
    parser.add_argument("--n_rows", type=int, default=5)
    parser.add_argument("--select", type=int, nargs="+", default=[])

    parser.add_argument("--preload", action='store_true', help='Simulate preloading')
    parser.add_argument("--as_validation", action='store_true', help="Access as validation instead of training data.")
    parser.add_argument("--n_augmentations", type=float, default=0)
    parser.add_argument("--no_plot", action='store_true', help='This parameter is for testing only')

    parser.add_root_argument("data", DataWrapper)
    args = parser.parse_args(args=args)

    data_wrapper: DataWrapper = args.data
    data_params = data_wrapper.data
    # Inspect the raw pre-processing output: run sequentially, drop the final
    # sample-preparation step, and apply the requested number of augmentations
    data_params.pre_proc.run_parallel = False
    data_params.pre_proc.erase_all(PrepareSampleProcessorParams)
    for p in data_params.pre_proc.processors_of_type(AugmentationProcessorParams):
        p.n_augmentations = args.n_augmentations
    data_params.__post_init__()  # re-validate the modified parameters
    data_wrapper.pipeline.mode = PipelineMode.EVALUATION if args.as_validation else PipelineMode.TRAINING
    data_wrapper.gen.prepare_for_mode(data_wrapper.pipeline.mode)

    data = Data(data_params)
    if len(args.select) == 0:
        args.select = list(range(len(data_wrapper.gen)))
    else:
        try:
            data_wrapper.gen.select(args.select)
        except NotImplementedError:
            logger.warning(f"Selecting is not supported for a data generator of type {type(data_wrapper.gen)}. "
                           f"Resuming without selection.")
    data_pipeline = data.create_pipeline(data_wrapper.pipeline, data_wrapper.gen)
    if args.preload:
        data_pipeline = data_pipeline.as_preloaded()

    if args.no_plot:
        # Testing mode: materialize all samples without plotting anything
        with data_pipeline as dataset:
            list(zip(args.select, dataset.generate_input_samples(auto_repeat=False)))
        return

    import matplotlib.pyplot as plt
    f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
    row, col = 0, 0
    with data_pipeline as dataset:
        for i, (id, sample) in enumerate(zip(args.select, dataset.generate_input_samples(auto_repeat=False))):
            line, text, params = sample.inputs, sample.targets, sample.meta
            if args.n_cols == 1:
                # with a single column, plt.subplots returns a 1D axes array
                ax[row].imshow(line.transpose())
                ax[row].set_title("ID: {}\n{}".format(id, text))
            else:
                ax[row, col].imshow(line.transpose())
                ax[row, col].set_title("ID: {}\n{}".format(id, text))

            row += 1
            if row == args.n_rows:
                row = 0
                col += 1

            if col == args.n_cols or i == len(dataset) - 1:
                # grid full (or data exhausted): show it and start a new figure
                plt.show()
                f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
                row, col = 0, 0

    plt.show()
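A hedged usage sketch for the viewer above; all flags come from the parser in this example, but note that a real invocation must also supply the fields of the required "data" root argument (a DataWrapper), which these examples do not show:

if __name__ == "__main__":
    # Sketch only: the "data" root argument's sub-flags are omitted because
    # DataWrapper's fields are not part of these examples.
    main([
        "--n_rows", "3",
        "--n_cols", "2",
        "--as_validation",        # view the data in EVALUATION mode
        "--n_augmentations", "2",
    ])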
Example #3
def main():
    # Local imports (imports that require tensorflow)
    from calamari_ocr.ocr.scenario import Scenario
    from calamari_ocr.ocr.dataset.data import Data
    from calamari_ocr.ocr.evaluator import Evaluator

    parser = ArgumentParser()
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--gt", nargs="+", required=True,
                        help="Ground truth files (.gt.txt extension). "
                             "Optionally, you can pass a single json file defining all parameters.")
    parser.add_argument("--pred", nargs="+", default=None,
                        help="Prediction files if provided. Else files with .pred.txt are expected at the same "
                             "location as the gt.")
    parser.add_argument("--pred_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--pred_ext", type=str, default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument("--n_confusions", type=int, default=10,
                        help="Only print n most common confusions. Defaults to 10, use -1 for all.")
    parser.add_argument("--n_worst_lines", type=int, default=0,
                        help="Print the n worst recognized text lines with its error")
    parser.add_argument("--xlsx_output", type=str,
                        help="Optionally write a xlsx file with the evaluation results")
    parser.add_argument("--num_threads", type=int, default=1,
                        help="Number of threads to use for evaluation")
    parser.add_argument("--non_existing_file_handling_mode", type=str, default="error",
                        help="How to handle non existing .pred.txt files. Possible modes: skip, empty, error. "
                             "'Skip' will simply skip the evaluation of that file (not counting it to errors). "
                             "'Empty' will handle this file as would it be empty (fully checking for errors)."
                             "'Error' will throw an exception if a file is not existing. This is the default behaviour.")
    parser.add_argument("--skip_empty_gt", action="store_true", default=False,
                        help="Ignore lines of the gt that are empty.")
    parser.add_argument("--no_progress_bars", action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)")

    # page xml specific args
    parser.add_argument("--pagexml_gt_text_index", default=0)
    parser.add_argument("--pagexml_pred_text_index", default=1)

    args = parser.parse_args()

    # If a single json file is passed via --gt, load all arguments from it
    if len(args.gt) == 1 and args.gt[0].endswith("json"):
        with open(args.gt[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    logger.info("Resolving files")
    gt_files = sorted(glob_all(args.gt))

    if args.pred:
        pred_files = sorted(glob_all(args.pred))
    else:
        # No explicit predictions: expect files with the prediction extension
        # next to the GT files, in the same dataset format
        pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files]
        args.pred_dataset = args.dataset

    if args.non_existing_file_handling_mode.lower() == "skip":
        # drop missing prediction files together with their GT counterparts
        non_existing_pred = [p for p in pred_files if not os.path.exists(p)]
        for f in non_existing_pred:
            idx = pred_files.index(f)
            del pred_files[idx]
            del gt_files[idx]

    data_params = Data.get_default_params()
    if args.checkpoint:
        # restore the data parameters (e.g. the text preprocessor) from the checkpoint
        saved_model = SavedCalamariModel(args.checkpoint, auto_update=True)
        trainer_params = Scenario.trainer_params_from_dict(saved_model.dict)
        data_params = trainer_params.scenario_params.data_params

    data = Data(data_params)

    gt_reader_args = FileDataReaderArgs(
        text_index=args.pagexml_gt_text_index
    )
    pred_reader_args = FileDataReaderArgs(
        text_index=args.pagexml_pred_text_index
    )
    gt_data_set = PipelineParams(
        type=args.dataset,
        text_files=gt_files,
        data_reader_args=gt_reader_args,
        skip_invalid=args.skip_empty_gt,
    )
    pred_data_set = PipelineParams(
        type=args.pred_dataset,
        text_files=pred_files,
        data_reader_args=pred_reader_args,
    )

    evaluator = Evaluator(data=data)
    evaluator.preload_gt(gt_dataset=gt_data_set)
    r = evaluator.run(gt_dataset=gt_data_set, pred_dataset=pred_data_set, processes=args.num_threads,
                      progress_bar=not args.no_progress_bars)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print("Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
        r["avg_ler"], r["total_char_errs"], r["total_chars"], r["total_sync_errs"]))

    # Print the most frequent confusions, sorted descending by frequency
    print_confusions(r, args.n_confusions)

    samples = data.create_pipeline(PipelineMode.Targets, gt_data_set).reader().samples()
    print_worst_lines(r, samples, args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output,
                   [{
                       "prefix": "evaluation",
                       "results": r,
                       "gt_files": gt_files,
                   }])
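Since this example accepts a single JSON file in place of --gt (see the setattr loop above), here is a hedged sketch of writing such a file; the keys are simply the parser's argument names, and all paths and values are illustrative:

import json

# Sketch only: every key must match an argument name of the parser above.
eval_config = {
    "gt": ["lines/*.gt.txt"],       # glob patterns, resolved by glob_all
    "pred_ext": ".pred.txt",
    "n_confusions": 25,
    "skip_empty_gt": True,
    "xlsx_output": "evaluation.xlsx",
}
with open("eval_args.json", "w") as f:
    json.dump(eval_config, f, indent=2)
# The file is then passed as the sole --gt argument.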