Beispiel #1
0
def main(args: EvalArgs):
    """Evaluate predictions against ground truth and print a short report.

    Args:
        args: Evaluation settings. This function reads ``checkpoint``,
            ``gt``, ``pred``, ``evaluator``, ``n_confusions``,
            ``n_worst_lines`` and ``xlsx_output`` from it.

    Returns:
        The result object returned by ``Evaluator.run``.
    """
    # Imported here rather than at module level because they require tensorflow.
    from calamari_ocr.ocr.scenario import CalamariScenario
    from calamari_ocr.ocr.dataset.data import Data
    from calamari_ocr.ocr.evaluator import Evaluator

    # Take the data parameters from the checkpoint when one is given,
    # otherwise fall back to the defaults.
    if not args.checkpoint:
        data_params = Data.default_params()
    else:
        checkpoint = SavedCalamariModel(args.checkpoint, auto_update=True)
        trainer_params_cls = CalamariScenario.trainer_cls().params_cls()
        data_params = trainer_params_cls.from_dict(checkpoint.dict).scenario.data

    data = Data(data_params)

    # Without an explicit prediction dataset, derive one from the ground truth.
    predictions = args.pred
    if predictions is None:
        predictions = args.gt.to_prediction()

    evaluator = Evaluator(args.evaluator, data=data)
    evaluator.preload_gt(gt_dataset=args.gt)
    result = evaluator.run(gt_dataset=args.gt, pred_dataset=predictions)

    # TODO: More output
    for header_line in ("Evaluation result", "=================", ""):
        print(header_line)

    summary_template = "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)"
    summary_keys = ("avg_ler", "total_char_errs", "total_chars", "total_sync_errs")
    print(summary_template.format(*[result[key] for key in summary_keys]))

    print_confusions(result, args.n_confusions)

    # The ground-truth samples feed both the worst-lines listing and the
    # optional spreadsheet export.
    gt_pipeline = data.create_pipeline(evaluator.params.setup, args.gt)
    samples = gt_pipeline.reader().samples()
    print_worst_lines(result, samples, args.n_worst_lines)

    if args.xlsx_output:
        sheet = {
            "prefix": "evaluation",
            "results": result,
            "gt_files": [sample["id"] for sample in samples],
        }
        write_xlsx(args.xlsx_output, [sheet])

    return result
Beispiel #2
0
def get_preproc_image():
    """Build a callable that runs the image pre-processing on a single image.

    Starts from ``Data.default_params()``, disables skipping of invalid
    ground truth and parallel execution, drops the last configured
    pre-processor, and sets ``pad`` to 0 on every
    ``FinalPreparationProcessorParams``. The pre-processor is created on a
    pipeline whose mode is set to ``PipelineMode.PREDICTION``.

    Returns:
        A function taking an image and returning the ``inputs`` of the
        pre-processed sample.
    """
    params = Data.default_params()
    params.skip_invalid_gt = False
    params.pre_proc.run_parallel = False

    # Every processor except the last one is kept.
    params.pre_proc.processors = params.pre_proc.processors[:-1]
    for final_prep in params.pre_proc.processors_of_type(FinalPreparationProcessorParams):
        final_prep.pad = 0
    post_init(params)

    pipeline = Data(params).create_pipeline(DataPipelineParams, None)
    pipeline.mode = PipelineMode.PREDICTION
    processor = params.pre_proc.create(pipeline)

    def preprocess_image(image):
        # Wrap the bare image in a sample with fixed dummy meta data and no target.
        meta = SampleMeta("001", fold_id="01")
        sample = InputSample(image, None, meta).to_input_target_sample()
        return processor.apply_on_sample(sample).inputs

    return preprocess_image
Beispiel #3
0
def get_preproc_text(rtl=False):
    """Build a callable that runs the text pre-processing on a single string.

    Starts from ``Data.default_params()`` with skipping of invalid ground
    truth and parallel execution disabled. The pre-processor is created on a
    pipeline whose mode is set to ``PipelineMode.TARGETS``.

    Args:
        rtl: When truthy, every ``BidiTextProcessorParams`` has its
            ``bidi_direction`` set to ``BidiDirection.RTL``.

    Returns:
        A function taking a text and returning the ``targets`` of the
        pre-processed sample.
    """
    params = Data.default_params()
    params.skip_invalid_gt = False
    params.pre_proc.run_parallel = False

    if rtl:
        for bidi_params in params.pre_proc.processors_of_type(BidiTextProcessorParams):
            bidi_params.bidi_direction = BidiDirection.RTL
    post_init(params)

    pipeline = Data(params).create_pipeline(DataPipelineParams, None)
    pipeline.mode = PipelineMode.TARGETS
    processor = params.pre_proc.create(pipeline)

    def preprocess_text(text):
        # Wrap the bare text in a sample with fixed dummy meta data and no image.
        meta = SampleMeta("001", fold_id="01")
        sample = InputSample(None, text, meta).to_input_target_sample()
        return processor.apply_on_sample(sample).targets

    return preprocess_text