Пример #1
0
def main(args: EvalArgs):
    """Evaluate a prediction dataset against its ground truth and report the result.

    The data parameters (e.g. the text preprocessing applied to the gt) are
    taken from ``args.checkpoint`` when one is given, otherwise the defaults of
    ``Data`` are used. If ``args.pred`` is not set, the prediction dataset is
    derived from the gt via ``args.gt.to_prediction()``.

    The summary, the most common confusions and the worst lines are printed;
    if ``args.xlsx_output`` is set the result is also written to that xlsx file.

    Returns:
        The evaluation result dict produced by ``Evaluator.run``.
    """
    # These imports stay local because they require tensorflow.
    from calamari_ocr.ocr.scenario import CalamariScenario
    from calamari_ocr.ocr.dataset.data import Data
    from calamari_ocr.ocr.evaluator import Evaluator

    if not args.checkpoint:
        data_params = Data.default_params()
    else:
        model = SavedCalamariModel(args.checkpoint, auto_update=True)
        params_cls = CalamariScenario.trainer_cls().params_cls()
        data_params = params_cls.from_dict(model.dict).scenario.data

    data = Data(data_params)

    pred_dataset = args.gt.to_prediction() if args.pred is None else args.pred
    evaluator = Evaluator(args.evaluator, data=data)
    evaluator.preload_gt(gt_dataset=args.gt)
    result = evaluator.run(gt_dataset=args.gt, pred_dataset=pred_dataset)

    # TODO: More output
    for header_line in ("Evaluation result", "=================", ""):
        print(header_line)
    summary_keys = ("avg_ler", "total_char_errs", "total_chars", "total_sync_errs")
    print(
        "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
            *(result[key] for key in summary_keys)
        )
    )

    print_confusions(result, args.n_confusions)

    # The gt samples are needed both for the worst-line report and for the xlsx file ids.
    gt_samples = data.create_pipeline(evaluator.params.setup, args.gt).reader().samples()
    print_worst_lines(result, gt_samples, args.n_worst_lines)

    if args.xlsx_output:
        sheet = {
            "prefix": "evaluation",
            "results": result,
            "gt_files": [sample["id"] for sample in gt_samples],
        }
        write_xlsx(args.xlsx_output, [sheet])

    return result
Пример #2
0
    def evaluate_books(
        self,
        books,
        checkpoint,
        cachefile=None,
        output_individual_voters=False,
        n_confusions=10,
        silent=True,
    ):
        """Predict the lines of one or more books and evaluate the predictions.

        The line ids are collected with ``lids_from_books`` (``complete_only`` and
        ``skip_commented`` enabled). They are predicted with a ``MultiPredictor``
        built from all given checkpoints, and then evaluated against the ground
        truth preloaded from the same ``Nsh5`` dataset.

        Args:
            books: A single book name or a list of book names.
            checkpoint: A checkpoint path (glob patterns allowed) or a list of
                them. The ``.json`` extension is optional.
            cachefile: Cache file the lines are read from; defaults to
                ``self.cachefile``.
            output_individual_voters: If True, additionally evaluate the
                prediction of every single voter, not only the voted result.
            n_confusions: Number of most common confusions printed per
                evaluation (forwarded to ``print_confusions``).
            silent: Forwarded to ``PredictorParams.silent``. If False, the full
                evaluation dict is printed at the end.

        Returns:
            A dict mapping an evaluation label (the voter index as ``str``, or
            ``"voted"``) to ``{"eval": <evaluation result>, "data": <dict of
            line id -> predicted sentence>}``.
        """
        keras.backend.clear_session()
        if isinstance(books, str):
            books = [books]
        if isinstance(checkpoint, str):
            checkpoint = [checkpoint]
        # Glob on the .json model definitions, then strip ".json" again because
        # the checkpoint paths are used without file extension further on.
        checkpoint = [cp if cp.endswith(".json") else cp + ".json" for cp in checkpoint]
        checkpoint = [cp[: -len(".json")] for cp in glob_all(checkpoint)]
        if cachefile is None:
            cachefile = self.cachefile

        lids = list(
            lids_from_books(books, cachefile, complete_only=True, skip_commented=True)
        )
        dataset = Nsh5(cachefile=cachefile, lines=lids)

        predparams = PredictorParams()
        # Use every GPU reported by list_physical_devices.
        predparams.device.gpus = list(range(len(list_physical_devices("GPU"))))
        predparams.silent = silent

        predictor = MultiPredictor.from_paths(
            checkpoints=checkpoint,
            voter_params=VoterParams(),
            predictor_params=predparams,
        )

        # Keep only the sample-preparation and final-preparation processors and
        # drop every other preprocessing step. The final preparation is
        # reconfigured to neither normalize nor invert, but to transpose.
        # NOTE(review): presumably because the cached lines are already
        # preprocessed; confirm against how the cache is filled.
        newprcs = []
        for prc in predictor.data.params.pre_proc.processors:
            prc = deepcopy(prc)
            if isinstance(prc, FinalPreparationProcessorParams):
                prc.normalize, prc.invert, prc.transpose = False, False, True
                newprcs.append(prc)
            elif isinstance(prc, PrepareSampleProcessorParams):
                newprcs.append(prc)
        predictor.data.params.pre_proc.processors = newprcs

        # line id -> sentence of the voted prediction
        voted_sentences = {}
        # voter index -> {line id -> sentence of that voter}
        individual_voter_sentences = {}
        for sample in predictor.predict(dataset):
            _, prediction = sample.outputs
            line_id = sample.meta["id"]
            if prediction.voter_predictions is not None and output_individual_voters:
                for voter_idx, voter_prediction in enumerate(prediction.voter_predictions):
                    individual_voter_sentences.setdefault(voter_idx, {})[line_id] = (
                        voter_prediction.sentence
                    )
            voted_sentences[line_id] = prediction.sentence

        # evaluation
        from calamari_ocr.ocr.evaluator import Evaluator, EvaluatorParams

        evaluator_params = EvaluatorParams(
            setup=predparams.pipeline,
            progress_bar=True,
            skip_empty_gt=True,
        )
        evaluator = Evaluator(evaluator_params, predictor.data)
        evaluator.preload_gt(gt_dataset=dataset, progress_bar=True)

        def single_evaluation(label, predicted_sentences):
            """Evaluate one set of predicted sentences, print a summary and return the result."""
            r = evaluator.evaluate(
                gt_data=evaluator.preloaded_gt, pred_data=predicted_sentences
            )

            print("=================")
            print(f"Evaluation result of {label}")
            print("=================")
            print("")
            print(
                "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
                    r["avg_ler"],
                    r["total_char_errs"],
                    r["total_chars"],
                    r["total_sync_errs"],
                )
            )
            print()
            print()

            print_confusions(r, n_confusions)

            return r

        # Individual voters first (in voter order), the voted result last.
        runs = [(str(idx), sentences) for idx, sentences in individual_voter_sentences.items()]
        runs.append(("voted", voted_sentences))

        full_evaluation = {}
        for label, sentences in runs:
            full_evaluation[label] = {"eval": single_evaluation(label, sentences), "data": sentences}

        if not predparams.silent:
            print(full_evaluation)

        return full_evaluation
def main(args: PredictAndEvalArgs):
    """Predict ``args.data`` with a voted model ensemble and evaluate the result.

    Every checkpoint in ``args.checkpoint`` is loaded into one ``MultiPredictor``.
    The voted prediction is always evaluated. With
    ``args.output_individual_voters``, the prediction of each voter is
    evaluated as well. The summary and the most common confusions are printed
    for every evaluation. If ``args.dump`` is set, the results and the
    preloaded ground truth are pickled to that path.

    Note:
        ``args.checkpoint`` is rewritten in place to the globbed checkpoint
        paths without ``.json`` extension.

    Returns:
        A dict mapping an evaluation label (the voter index as ``str``, or
        ``"voted"``) to ``{"eval": <evaluation result>, "data": <dict of line
        id -> predicted sentence>}``.
    """
    # Allow the user to specify the .json model definition: glob on it, then
    # strip the extension again for further processing.
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json")
                       for cp in args.checkpoint]
    args.checkpoint = [cp[:-len(".json")] for cp in glob_all(args.checkpoint)]

    from calamari_ocr.ocr.predict.predictor import MultiPredictor

    predictor = MultiPredictor.from_paths(
        checkpoints=args.checkpoint,
        voter_params=VoterParams(),
        predictor_params=args.predictor,
    )

    # line id -> sentence of the voted prediction
    voted_sentences = {}
    # voter index -> {line id -> sentence of that voter}
    individual_voter_sentences = {}
    for sample in predictor.predict(args.data):
        _, prediction = sample.outputs
        line_id = sample.meta["id"]
        if prediction.voter_predictions is not None and args.output_individual_voters:
            for voter_idx, voter_prediction in enumerate(prediction.voter_predictions):
                individual_voter_sentences.setdefault(voter_idx, {})[line_id] = voter_prediction.sentence
        voted_sentences[line_id] = prediction.sentence

    # evaluation
    from calamari_ocr.ocr.evaluator import Evaluator, EvaluatorParams

    evaluator_params = EvaluatorParams(
        setup=args.predictor.pipeline,
        progress_bar=args.predictor.progress_bar,
        skip_empty_gt=args.skip_empty_gt,
    )
    evaluator = Evaluator(evaluator_params, predictor.data)
    # Use the same progress-bar setting as the evaluator itself.
    evaluator.preload_gt(gt_dataset=args.data, progress_bar=args.predictor.progress_bar)

    def single_evaluation(label, predicted_sentences):
        """Evaluate one set of predicted sentences, print a summary and return the result."""
        r = evaluator.evaluate(gt_data=evaluator.preloaded_gt,
                               pred_data=predicted_sentences)

        print("=================")
        print(f"Evaluation result of {label}")
        print("=================")
        print("")
        print(
            "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)"
            .format(
                r["avg_ler"],
                r["total_char_errs"],
                r["total_chars"],
                r["total_sync_errs"],
            ))
        print()
        print()

        print_confusions(r, args.n_confusions)

        return r

    # Individual voters first (in voter order), the voted result last.
    runs = [(str(idx), sentences) for idx, sentences in individual_voter_sentences.items()]
    runs.append(("voted", voted_sentences))

    full_evaluation = {}
    for label, sentences in runs:
        full_evaluation[label] = {
            "eval": single_evaluation(label, sentences),
            "data": sentences,
        }

    if not args.predictor.silent:
        print(full_evaluation)

    if args.dump:
        import pickle

        with open(args.dump, "wb") as f:
            pickle.dump({
                "full": full_evaluation,
                "gt": evaluator.preloaded_gt
            }, f)

    return full_evaluation
Пример #4
0
def main():
    """Command-line entry point: predict files with a model ensemble and evaluate them.

    Parses the command line and predicts all ``--files`` with every given
    ``--checkpoint``; the voters are combined by a ``MultiPredictor``. It then
    evaluates the voted result against the ground truth. With
    ``--output_individual_voters``, each voter is evaluated separately as well.
    The results are printed and, with ``--dump``, pickled to the given path.

    Raises:
        Exception: if the number of predictions differs from the number of
            preloaded ground-truth lines.
    """
    parser = argparse.ArgumentParser()

    # GENERAL/SHARED PARAMETERS
    parser.add_argument('--version',
                        action='version',
                        version='%(prog)s v' + __version__)

    parser.add_argument("--files",
                        nargs="+",
                        required=True,
                        default=[],
                        help="List all image files that shall be processed")
    parser.add_argument(
        "--text_files",
        nargs="+",
        default=None,
        help=
        "Optional list of additional text files. E.g. when updating Abbyy prediction, this parameter must be used for the xml files."
    )
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--gt_extension",
                        type=str,
                        default=None,
                        help="Define the gt extension.")
    parser.add_argument("-j",
                        "--processes",
                        type=int,
                        default=1,
                        help="Number of processes to use")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help=
        "The batch size during the prediction (number of lines to process in parallel)"
    )
    parser.add_argument("--verbose",
                        action="store_true",
                        help="Print additional information")
    parser.add_argument("--no_progress_bars",
                        action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument("--dump",
                        type=str,
                        help="Dump the output as serialized pickle object")
    parser.add_argument(
        "--no_skip_invalid_gt",
        action="store_true",
        help="Do no skip invalid gt, instead raise an exception.")
    # dataset extra args
    parser.add_argument("--dataset_pad", default=None, nargs='+', type=int)
    # type=int so that a value given on the command line has the same type as the default
    parser.add_argument("--pagexml_text_index", default=1, type=int)

    # PREDICT PARAMETERS
    parser.add_argument("--checkpoint",
                        type=str,
                        nargs="+",
                        required=True,
                        help="Path to the checkpoint without file extension")

    # EVAL PARAMETERS
    parser.add_argument("--output_individual_voters",
                        action='store_true',
                        default=False)
    parser.add_argument(
        "--n_confusions",
        type=int,
        default=10,
        help=
        "Only print n most common confusions. Defaults to 10, use -1 for all.")

    args = parser.parse_args()
    # Honour --no_progress_bars everywhere a progress bar can be shown.
    progress_bar = not args.no_progress_bars

    # Allow the user to specify the .json model definition: glob on it, then
    # strip the extension again for further processing.
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json")
                       for cp in args.checkpoint]
    args.checkpoint = [cp[:-len(".json")] for cp in glob_all(args.checkpoint)]
    # load files
    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    pipeline_params = PipelineParams(
        type=args.dataset,
        skip_invalid=not args.no_skip_invalid_gt,
        remove_invalid=True,
        files=args.files,
        gt_extension=args.gt_extension,
        text_files=args.text_files,
        data_reader_args=FileDataReaderArgs(
            pad=args.dataset_pad,
            text_index=args.pagexml_text_index,
        ),
        batch_size=args.batch_size,
        num_processes=args.processes,
    )

    from calamari_ocr.ocr.predict.predictor import MultiPredictor
    predictor = MultiPredictor.from_paths(checkpoints=args.checkpoint,
                                          voter_params=VoterParams(),
                                          predictor_params=PredictorParams(
                                              silent=True, progress_bar=progress_bar))

    # Voted sentences in prediction order.
    voted_sentences = []
    # voter index -> sentences of that voter in prediction order
    individual_voter_sentences = {}
    for sample in predictor.predict(pipeline_params):
        _, prediction = sample.outputs
        if prediction.voter_predictions is not None and args.output_individual_voters:
            for voter_idx, voter_prediction in enumerate(prediction.voter_predictions):
                individual_voter_sentences.setdefault(voter_idx, []).append(voter_prediction.sentence)
        voted_sentences.append(prediction.sentence)

    # evaluation
    from calamari_ocr.ocr.evaluator import Evaluator
    evaluator = Evaluator(predictor.data)
    evaluator.preload_gt(gt_dataset=pipeline_params, progress_bar=progress_bar)

    def single_evaluation(label, predicted_sentences):
        """Evaluate one list of predicted sentences, print a summary and return the result."""
        # Predictions and gt are matched by position, so their counts must agree.
        if len(predicted_sentences) != len(evaluator.preloaded_gt):
            raise Exception(
                "Mismatch in number of gt and pred files: {} != {}. Probably, the prediction did "
                "not succeed".format(len(evaluator.preloaded_gt),
                                     len(predicted_sentences)))

        r = evaluator.evaluate(gt_data=evaluator.preloaded_gt,
                               pred_data=predicted_sentences,
                               progress_bar=progress_bar,
                               processes=args.processes)

        print("=================")
        print(f"Evaluation result of {label}")
        print("=================")
        print("")
        print(
            "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)"
            .format(r["avg_ler"], r["total_char_errs"], r["total_chars"],
                    r["total_sync_errs"]))
        print()
        print()

        print_confusions(r, args.n_confusions)

        return r

    # Individual voters first (in voter order), the voted result last.
    runs = [(str(idx), sentences) for idx, sentences in individual_voter_sentences.items()]
    runs.append(('voted', voted_sentences))

    full_evaluation = {}
    for label, sentences in runs:
        full_evaluation[label] = {
            "eval": single_evaluation(label, sentences),
            "data": sentences
        }

    if args.verbose:
        print(full_evaluation)

    if args.dump:
        import pickle
        with open(args.dump, 'wb') as f:
            pickle.dump({
                "full": full_evaluation,
                "gt": evaluator.preloaded_gt
            }, f)
Пример #5
0
def main():
    """Command-line entry point: evaluate prediction text files against ground truth.

    The ground-truth files are given by ``--gt``. Alternatively, a single JSON
    file can be passed whose keys overwrite the parsed arguments. The
    prediction files are either given by ``--pred`` or derived from the gt
    paths by replacing their extension with ``--pred_ext``. The summary, the
    most common confusions and the worst lines are printed. With
    ``--xlsx_output``, the result is also written to an xlsx file.
    """
    # Local imports (imports that require tensorflow)
    from calamari_ocr.ocr.scenario import Scenario
    from calamari_ocr.ocr.dataset.data import Data
    from calamari_ocr.ocr.evaluator import Evaluator

    parser = ArgumentParser()
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--gt", nargs="+", required=True,
                        help="Ground truth files (.gt.txt extension). "
                             "Optionally, you can pass a single json file defining all parameters.")
    parser.add_argument("--pred", nargs="+", default=None,
                        help="Prediction files if provided. Else files with .pred.txt are expected at the same "
                             "location as the gt.")
    parser.add_argument("--pred_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--pred_ext", type=str, default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument("--n_confusions", type=int, default=10,
                        help="Only print n most common confusions. Defaults to 10, use -1 for all.")
    parser.add_argument("--n_worst_lines", type=int, default=0,
                        help="Print the n worst recognized text lines with its error")
    parser.add_argument("--xlsx_output", type=str,
                        help="Optionally write a xlsx file with the evaluation results")
    parser.add_argument("--num_threads", type=int, default=1,
                        help="Number of threads to use for evaluation")
    parser.add_argument("--non_existing_file_handling_mode", type=str, default="error",
                        help="How to handle non existing .pred.txt files. Possible modes: skip, empty, error. "
                             "'Skip' will simply skip the evaluation of that file (not counting it to errors). "
                             "'Empty' will handle this file as would it be empty (fully checking for errors)."
                             "'Error' will throw an exception if a file is not existing. This is the default behaviour.")
    parser.add_argument("--skip_empty_gt", action="store_true", default=False,
                        help="Ignore lines of the gt that are empty.")
    parser.add_argument("--no_progress_bars", action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)")

    # page xml specific args (type=int so command-line values match the int defaults)
    parser.add_argument("--pagexml_gt_text_index", default=0, type=int)
    parser.add_argument("--pagexml_pred_text_index", default=1, type=int)

    args = parser.parse_args()

    # check if loading a json file; its keys overwrite the parsed arguments
    if len(args.gt) == 1 and args.gt[0].endswith("json"):
        with open(args.gt[0], 'r', encoding="utf-8") as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    logger.info("Resolving files")
    gt_files = sorted(glob_all(args.gt))

    if args.pred:
        pred_files = sorted(glob_all(args.pred))
    else:
        pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files]
        args.pred_dataset = args.dataset

    # NOTE(review): only the "skip" mode is acted on in this function; "empty"
    # and "error" are not handled here. Confirm the readers cover them.
    if args.non_existing_file_handling_mode.lower() == "skip":
        # gt and pred files are paired by position: drop the same indices from both.
        missing = {idx for idx, pred in enumerate(pred_files) if not os.path.exists(pred)}
        pred_files = [pred for idx, pred in enumerate(pred_files) if idx not in missing]
        gt_files = [gt for idx, gt in enumerate(gt_files) if idx not in missing]

    data_params = Data.get_default_params()
    if args.checkpoint:
        saved_model = SavedCalamariModel(args.checkpoint, auto_update=True)
        trainer_params = Scenario.trainer_params_from_dict(saved_model.dict)
        data_params = trainer_params.scenario_params.data_params

    data = Data(data_params)

    gt_reader_args = FileDataReaderArgs(
        text_index=args.pagexml_gt_text_index
    )
    pred_reader_args = FileDataReaderArgs(
        text_index=args.pagexml_pred_text_index
    )
    gt_data_set = PipelineParams(
        type=args.dataset,
        text_files=gt_files,
        data_reader_args=gt_reader_args,
        skip_invalid=args.skip_empty_gt,
    )
    pred_data_set = PipelineParams(
        type=args.pred_dataset,
        text_files=pred_files,
        data_reader_args=pred_reader_args,
    )

    evaluator = Evaluator(data=data)
    evaluator.preload_gt(gt_dataset=gt_data_set)
    r = evaluator.run(gt_dataset=gt_data_set, pred_dataset=pred_data_set, processes=args.num_threads,
                      progress_bar=not args.no_progress_bars)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print("Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
        r["avg_ler"], r["total_char_errs"], r["total_chars"], r["total_sync_errs"]))

    print_confusions(r, args.n_confusions)

    print_worst_lines(r, data.create_pipeline(PipelineMode.Targets, gt_data_set).reader().samples(), args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output,
                   [{
                       "prefix": "evaluation",
                       "results": r,
                       "gt_files": gt_files,
                   }])