Example #1
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_imgs",
                        type=str,
                        nargs="+",
                        required=True,
                        help="The evaluation files")
    parser.add_argument("--checkpoint",
                        type=str,
                        nargs="+",
                        default=[],
                        help="Path to the checkpoint without file extension")
    parser.add_argument("-j",
                        "--processes",
                        type=int,
                        default=1,
                        help="Number of processes to use")
    parser.add_argument("--verbose",
                        action="store_true",
                        help="Print additional information")
    parser.add_argument(
        "--voter",
        type=str,
        nargs="+",
        default=[
            "sequence_voter", "confidence_voter_default_ctc",
            "confidence_voter_fuzzy_ctc"
        ],
        help=
        "The voting algorithm to use. Possible values: confidence_voter_default_ctc (default), "
        "confidence_voter_fuzzy_ctc, sequence_voter")
    parser.add_argument("--batch_size",
                        type=int,
                        default=10,
                        help="The batch size for prediction")
    parser.add_argument("--dump",
                        type=str,
                        help="Dump the output as serialized pickle object")
    parser.add_argument(
        "--no_skip_invalid_gt",
        action="store_true",
        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # allow user to specify json file for model definition, but remove the file extension
    # for further processing
    args.checkpoint = [(cp[:-5] if cp.endswith(".json") else cp)
                       for cp in args.checkpoint]

    # load files
    gt_images = sorted(glob_all(args.eval_imgs))
    gt_txts = [split_all_ext(path)[0] + ".gt.txt" for path in gt_images]

    dataset = FileDataSet(images=gt_images,
                          texts=gt_txts,
                          skip_invalid=not args.no_skip_invalid_gt)

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        raise Exception(
            "Empty dataset provided. Check your files argument (got {})!".
            format(args.files))

    # predict for all models
    n_models = len(args.checkpoint)
    predictor = MultiPredictor(checkpoints=args.checkpoint,
                               batch_size=args.batch_size,
                               processes=args.processes)
    do_prediction = predictor.predict_dataset(dataset, progress_bar=True)

    voters = []
    all_voter_sentences = []
    all_prediction_sentences = [[] for _ in range(n_models)]

    for voter in args.voter:
        # create voter
        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value(voter.upper())
        voters.append(voter_from_proto(voter_params))
        all_voter_sentences.append([])

    for prediction, sample in do_prediction:
        for sent, p in zip(all_prediction_sentences, prediction):
            sent.append(p.sentence)

        # vote results
        for voter, voter_sentences in zip(voters, all_voter_sentences):
            voter_sentences.append(
                voter.vote_prediction_result(prediction).sentence)

    # evaluation
    text_preproc = text_processor_from_proto(
        predictor.predictors[0].model_params.text_preprocessor)
    evaluator = Evaluator(text_preprocessor=text_preproc)
    evaluator.preload_gt(gt_dataset=dataset, progress_bar=True)

    def single_evaluation(predicted_sentences):
        if len(predicted_sentences) != len(dataset):
            raise Exception(
                "Mismatch in number of gt and pred files: {} != {}. Probably, the prediction did "
                "not succeed".format(len(dataset), len(predicted_sentences)))

        pred_data_set = RawDataSet(texts=predicted_sentences)

        r = evaluator.run(pred_dataset=pred_data_set,
                          progress_bar=True,
                          processes=args.processes)

        return r

    full_evaluation = {}
    for id, data in [
        (str(i), sent) for i, sent in enumerate(all_prediction_sentences)
    ] + list(zip(args.voter, all_voter_sentences)):
        full_evaluation[id] = {"eval": single_evaluation(data), "data": data}

    if args.verbose:
        print(full_evaluation)

    if args.dump:
        import pickle
        with open(args.dump, 'wb') as f:
            pickle.dump(
                {
                    "full": full_evaluation,
                    "gt_txts": gt_txts,
                    "gt": dataset.text_samples()
                }, f)
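
A minimal sketch of how the object written by `--dump` above could be inspected afterwards; the file name `results.pkl` is a hypothetical example, the keys follow the `pickle.dump` call in the script.

import pickle

# Load the object written by the --dump option above (file name is hypothetical).
with open("results.pkl", "rb") as f:
    dumped = pickle.load(f)

full_evaluation = dumped["full"]   # per-model ("0", "1", ...) and per-voter entries
gt_txts = dumped["gt_txts"]        # ground-truth text file paths
gt = dumped["gt"]                  # ground-truth text samples

# Each entry holds the evaluation result and the predicted sentences.
for key, entry in full_evaluation.items():
    print(key, len(entry["data"]), "predicted sentences")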
Example #2
    def predict_books(
        self,
        books,
        checkpoint,
        cachefile=None,
        pageupload=True,
        text_index=1,
        pred_all=False,
    ):
        keras.backend.clear_session()
        if isinstance(books, str):
            books = [books]
        if isinstance(checkpoint, str):
            checkpoint = [checkpoint]
        checkpoint = [
            (cp if cp.endswith(".json") else cp + ".json") for cp in checkpoint
        ]
        checkpoint = glob_all(checkpoint)
        checkpoint = [cp[:-5] for cp in checkpoint]
        if cachefile is None:
            cachefile = self.cachefile
        verbose = False
        lids = list(
            lids_from_books(
                books,
                cachefile,
                complete_only=False,
                skip_commented=False,
                new_only=not pred_all,
            )
        )
        data = Nsh5(cachefile=cachefile, lines=lids)

        predparams = PredictorParams()
        predparams.device.gpus = [n for n, _ in enumerate(list_physical_devices("GPU"))]

        predictor = MultiPredictor.from_paths(
            checkpoints=checkpoint,
            voter_params=VoterParams(),
            predictor_params=predparams,
        )

        newprcs = []
        for prc in predictor.data.params.pre_proc.processors:
            prc = deepcopy(prc)
            if isinstance(prc, FinalPreparationProcessorParams):
                prc.normalize, prc.invert, prc.transpose = False, False, True
                newprcs.append(prc)
            elif isinstance(prc, PrepareSampleProcessorParams):
                newprcs.append(prc)
        predictor.data.params.pre_proc.processors = newprcs

        do_prediction = predictor.predict(data)
        pipeline = predictor.data.get_or_create_pipeline(
            predictor.params.pipeline, data
        )
        reader = pipeline.reader()
        if len(reader) == 0:
            raise Exception(
                "Empty dataset provided. Check your lines (got {})!".format(lids)
            )

        avg_sentence_confidence = 0
        n_predictions = 0

        reader.prepare_store()

        samples = []
        sentences = []
        # output the voted results to the appropriate files
        for s in do_prediction:
            _, (_, prediction), meta = s.inputs, s.outputs, s.meta
            sample = reader.sample_by_id(meta["id"])
            n_predictions += 1
            sentence = prediction.sentence

            avg_sentence_confidence += prediction.avg_char_probability
            if verbose:
                lr = "\u202A\u202B"
                logger.info(
                    "{}: '{}{}{}'".format(
                        meta["id"], lr[get_base_level(sentence)], sentence, "\u202C"
                    )
                )

            samples.append(sample)
            sentences.append(sentence)
            reader.store_text(sentence, sample, output_dir=None, extension=None)

        logger.info(
            "Average sentence confidence: {:.2%}".format(
                avg_sentence_confidence / n_predictions
            )
        )

        if pageupload:
            ocrdata = {}
            for lname, text in reader.predictions.items():
                _, b, p, ln = lname.split("/")
                if b not in ocrdata:
                    ocrdata[b] = {}
                if p not in ocrdata[b]:
                    ocrdata[b][p] = {}
                ocrdata[b][p][ln] = text

            data = {"ocrdata": ocrdata, "index": text_index}
            self.get_session().post(
                self.baseurl + "/_ocrdata",
                data=gzip.compress(json.dumps(data).encode("utf-8")),
                headers={
                    "Content-Type": "application/json;charset=UTF-8",
                    "Content-Encoding": "gzip",
                },
            )
            logger.info("Results uploaded")
        else:
            reader.store()
            logger.info("All prediction files written")
Example #3
    def evaluate_books(
        self,
        books,
        checkpoint,
        cachefile=None,
        output_individual_voters=False,
        n_confusions=10,
        silent=True,
    ):
        keras.backend.clear_session()
        if isinstance(books, str):
            books = [books]
        if isinstance(checkpoint, str):
            checkpoint = [checkpoint]
        checkpoint = [
            (cp if cp.endswith(".json") else cp + ".json") for cp in checkpoint
        ]
        checkpoint = glob_all(checkpoint)
        checkpoint = [cp[:-5] for cp in checkpoint]
        if cachefile is None:
            cachefile = self.cachefile

        lids = list(
            lids_from_books(books, cachefile, complete_only=True, skip_commented=True)
        )
        data = Nsh5(cachefile=cachefile, lines=lids)

        predparams = PredictorParams()
        predparams.device.gpus = [n for n, _ in enumerate(list_physical_devices("GPU"))]
        predparams.silent = silent

        predictor = MultiPredictor.from_paths(
            checkpoints=checkpoint,
            voter_params=VoterParams(),
            predictor_params=predparams,
        )

        newprcs = []
        for prc in predictor.data.params.pre_proc.processors:
            prc = deepcopy(prc)
            if isinstance(prc, FinalPreparationProcessorParams):
                prc.normalize, prc.invert, prc.transpose = False, False, True
                newprcs.append(prc)
            elif isinstance(prc, PrepareSampleProcessorParams):
                newprcs.append(prc)
        predictor.data.params.pre_proc.processors = newprcs

        do_prediction = predictor.predict(data)

        all_voter_sentences = {}
        all_prediction_sentences = {}

        for s in do_prediction:
            _, (_, prediction), _ = s.inputs, s.outputs, s.meta
            sentence = prediction.sentence
            if prediction.voter_predictions is not None and output_individual_voters:
                for i, p in enumerate(prediction.voter_predictions):
                    if i not in all_prediction_sentences:
                        all_prediction_sentences[i] = {}
                    all_prediction_sentences[i][s.meta["id"]] = p.sentence
            all_voter_sentences[s.meta["id"]] = sentence

        # evaluation
        from calamari_ocr.ocr.evaluator import Evaluator, EvaluatorParams

        evaluator_params = EvaluatorParams(
            setup=predparams.pipeline,
            progress_bar=True,
            skip_empty_gt=True,
        )
        evaluator = Evaluator(evaluator_params, predictor.data)
        evaluator.preload_gt(gt_dataset=data, progress_bar=True)

        def single_evaluation(label, predicted_sentences):
            r = evaluator.evaluate(
                gt_data=evaluator.preloaded_gt, pred_data=predicted_sentences
            )

            print("=================")
            print(f"Evaluation result of {label}")
            print("=================")
            print("")
            print(
                "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
                    r["avg_ler"],
                    r["total_char_errs"],
                    r["total_chars"],
                    r["total_sync_errs"],
                )
            )
            print()
            print()

            # sort descending
            print_confusions(r, n_confusions)

            return r

        full_evaluation = {}
        for id, data in [
            (str(i), sent) for i, sent in all_prediction_sentences.items()
        ] + [("voted", all_voter_sentences)]:
            full_evaluation[id] = {"eval": single_evaluation(id, data), "data": data}

        if not predparams.silent:
            print(full_evaluation)

        return full_evaluation
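
A minimal consumer sketch for the dictionary returned above; `client` stands in for whatever object exposes `evaluate_books` (hypothetical name), and the book/checkpoint values are placeholders.

results = client.evaluate_books(
    books="some_book",
    checkpoint="path/to/model",
    output_individual_voters=True,
)

for key, entry in results.items():
    r = entry["eval"]
    # key is an individual model index ("0", "1", ...) or "voted"
    print("{}: LER {:.2%} ({} char errs / {} chars)".format(
        key, r["avg_ler"], r["total_char_errs"], r["total_chars"]))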
Example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_imgs", type=str, nargs="+", required=True,
                        help="The evaluation files")
    parser.add_argument("--eval_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--checkpoint", type=str, nargs="+", default=[],
                        help="Path to the checkpoint without file extension")
    parser.add_argument("-j", "--processes", type=int, default=1,
                        help="Number of processes to use")
    parser.add_argument("--verbose", action="store_true",
                        help="Print additional information")
    parser.add_argument("--voter", type=str, nargs="+", default=["sequence_voter", "confidence_voter_default_ctc", "confidence_voter_fuzzy_ctc"],
                        help="The voting algorithm to use. Possible values: confidence_voter_default_ctc (default), "
                             "confidence_voter_fuzzy_ctc, sequence_voter")
    parser.add_argument("--batch_size", type=int, default=10,
                        help="The batch size for prediction")
    parser.add_argument("--dump", type=str,
                        help="Dump the output as serialized pickle object")
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # allow user to specify json file for model definition, but remove the file extension
    # for further processing
    args.checkpoint = [(cp[:-5] if cp.endswith(".json") else cp) for cp in args.checkpoint]

    # load files
    gt_images = sorted(glob_all(args.eval_imgs))
    gt_txts = [split_all_ext(path)[0] + ".gt.txt" for path in gt_images]

    dataset = create_dataset(
        args.eval_dataset,
        DataSetMode.TRAIN,
        images=gt_images,
        texts=gt_txts,
        skip_invalid=not args.no_skip_invalid_gt
    )

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        raise Exception("Empty dataset provided. Check your files argument (got {})!".format(args.files))

    # predict for all models
    n_models = len(args.checkpoint)
    predictor = MultiPredictor(checkpoints=args.checkpoint, batch_size=args.batch_size, processes=args.processes)
    do_prediction = predictor.predict_dataset(dataset, progress_bar=True)

    voters = []
    all_voter_sentences = []
    all_prediction_sentences = [[] for _ in range(n_models)]

    for voter in args.voter:
        # create voter
        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value(voter.upper())
        voters.append(voter_from_proto(voter_params))
        all_voter_sentences.append([])

    for prediction, sample in do_prediction:
        for sent, p in zip(all_prediction_sentences, prediction):
            sent.append(p.sentence)

        # vote results
        for voter, voter_sentences in zip(voters, all_voter_sentences):
            voter_sentences.append(voter.vote_prediction_result(prediction).sentence)

    # evaluation
    text_preproc = text_processor_from_proto(predictor.predictors[0].model_params.text_preprocessor)
    evaluator = Evaluator(text_preprocessor=text_preproc)
    evaluator.preload_gt(gt_dataset=dataset, progress_bar=True)

    def single_evaluation(predicted_sentences):
        if len(predicted_sentences) != len(dataset):
            raise Exception("Mismatch in number of gt and pred files: {} != {}. Probably, the prediction did "
                            "not succeed".format(len(dataset), len(predicted_sentences)))

        pred_data_set = create_dataset(
            DataSetType.RAW,
            DataSetMode.EVAL,
            texts=predicted_sentences)

        r = evaluator.run(pred_dataset=pred_data_set, progress_bar=True, processes=args.processes)

        return r

    full_evaluation = {}
    for id, data in [(str(i), sent) for i, sent in enumerate(all_prediction_sentences)] + list(zip(args.voter, all_voter_sentences)):
        full_evaluation[id] = {"eval": single_evaluation(data), "data": data}

    if args.verbose:
        print(full_evaluation)

    if args.dump:
        import pickle
        with open(args.dump, 'wb') as f:
            pickle.dump({"full": full_evaluation, "gt_txts": gt_txts, "gt": dataset.text_samples()}, f)
Example #5
def __init__(self, voter_params, *args, **kwargs):
    super(MultiPredictor, self).__init__(*args, **kwargs)
    self.voter_params = voter_params or VoterParams()

def main(args: PredictAndEvalArgs):
    # allow user to specify json file for model definition, but remove the file extension
    # for further processing
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json")
                       for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]

    from calamari_ocr.ocr.predict.predictor import MultiPredictor

    voter_params = VoterParams()
    predictor = MultiPredictor.from_paths(
        checkpoints=args.checkpoint,
        voter_params=voter_params,
        predictor_params=args.predictor,
    )
    do_prediction = predictor.predict(args.data)

    all_voter_sentences = {}
    all_prediction_sentences = {}

    for s in do_prediction:
        inputs, (result, prediction), meta = s.inputs, s.outputs, s.meta
        sentence = prediction.sentence
        if prediction.voter_predictions is not None and args.output_individual_voters:
            for i, p in enumerate(prediction.voter_predictions):
                if i not in all_prediction_sentences:
                    all_prediction_sentences[i] = {}
                all_prediction_sentences[i][s.meta["id"]] = p.sentence
        all_voter_sentences[s.meta["id"]] = sentence

    # evaluation
    from calamari_ocr.ocr.evaluator import Evaluator, EvaluatorParams

    evaluator_params = EvaluatorParams(
        setup=args.predictor.pipeline,
        progress_bar=args.predictor.progress_bar,
        skip_empty_gt=args.skip_empty_gt,
    )
    evaluator = Evaluator(evaluator_params, predictor.data)
    evaluator.preload_gt(gt_dataset=args.data, progress_bar=True)

    def single_evaluation(label, predicted_sentences):
        r = evaluator.evaluate(gt_data=evaluator.preloaded_gt,
                               pred_data=predicted_sentences)

        print("=================")
        print(f"Evaluation result of {label}")
        print("=================")
        print("")
        print(
            "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)"
            .format(
                r["avg_ler"],
                r["total_char_errs"],
                r["total_chars"],
                r["total_sync_errs"],
            ))
        print()
        print()

        # sort descending
        print_confusions(r, args.n_confusions)

        return r

    full_evaluation = {}
    for id, data in [
        (str(i), sent) for i, sent in all_prediction_sentences.items()
    ] + [("voted", all_voter_sentences)]:
        full_evaluation[id] = {
            "eval": single_evaluation(id, data),
            "data": data
        }

    if not args.predictor.silent:
        print(full_evaluation)

    if args.dump:
        import pickle

        with open(args.dump, "wb") as f:
            pickle.dump({
                "full": full_evaluation,
                "gt": evaluator.preloaded_gt
            }, f)

    return full_evaluation
Example #7
def main():
    parser = argparse.ArgumentParser()

    # GENERAL/SHARED PARAMETERS
    parser.add_argument('--version',
                        action='version',
                        version='%(prog)s v' + __version__)

    parser.add_argument("--files",
                        nargs="+",
                        required=True,
                        default=[],
                        help="List all image files that shall be processed")
    parser.add_argument(
        "--text_files",
        nargs="+",
        default=None,
        help=
        "Optional list of additional text files. E.g. when updating Abbyy prediction, this parameter must be used for the xml files."
    )
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--gt_extension",
                        type=str,
                        default=None,
                        help="Define the gt extension.")
    parser.add_argument("-j",
                        "--processes",
                        type=int,
                        default=1,
                        help="Number of processes to use")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help=
        "The batch size during the prediction (number of lines to process in parallel)"
    )
    parser.add_argument("--verbose",
                        action="store_true",
                        help="Print additional information")
    parser.add_argument("--no_progress_bars",
                        action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument("--dump",
                        type=str,
                        help="Dump the output as serialized pickle object")
    parser.add_argument(
        "--no_skip_invalid_gt",
        action="store_true",
        help="Do no skip invalid gt, instead raise an exception.")
    # dataset extra args
    parser.add_argument("--dataset_pad", default=None, nargs='+', type=int)
    parser.add_argument("--pagexml_text_index", default=1)

    # PREDICT PARAMETERS
    parser.add_argument("--checkpoint",
                        type=str,
                        nargs="+",
                        required=True,
                        help="Path to the checkpoint without file extension")

    # EVAL PARAMETERS
    parser.add_argument("--output_individual_voters",
                        action='store_true',
                        default=False)
    parser.add_argument(
        "--n_confusions",
        type=int,
        default=10,
        help=
        "Only print n most common confusions. Defaults to 10, use -1 for all.")

    args = parser.parse_args()

    # allow user to specify json file for model definition, but remove the file extension
    # for further processing
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json")
                       for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]
    # load files
    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    pipeline_params = PipelineParams(
        type=args.dataset,
        skip_invalid=not args.no_skip_invalid_gt,
        remove_invalid=True,
        files=args.files,
        gt_extension=args.gt_extension,
        text_files=args.text_files,
        data_reader_args=FileDataReaderArgs(
            pad=args.dataset_pad,
            text_index=args.pagexml_text_index,
        ),
        batch_size=args.batch_size,
        num_processes=args.processes,
    )

    from calamari_ocr.ocr.predict.predictor import MultiPredictor
    voter_params = VoterParams()
    predictor = MultiPredictor.from_paths(checkpoints=args.checkpoint,
                                          voter_params=voter_params,
                                          predictor_params=PredictorParams(
                                              silent=True, progress_bar=True))
    do_prediction = predictor.predict(pipeline_params)

    all_voter_sentences = []
    all_prediction_sentences = {}

    for s in do_prediction:
        inputs, (result, prediction), meta = s.inputs, s.outputs, s.meta
        sentence = prediction.sentence
        if prediction.voter_predictions is not None and args.output_individual_voters:
            for i, p in enumerate(prediction.voter_predictions):
                if i not in all_prediction_sentences:
                    all_prediction_sentences[i] = []
                all_prediction_sentences[i].append(p.sentence)
        all_voter_sentences.append(sentence)

    # evaluation
    from calamari_ocr.ocr.evaluator import Evaluator
    evaluator = Evaluator(predictor.data)
    evaluator.preload_gt(gt_dataset=pipeline_params, progress_bar=True)

    def single_evaluation(label, predicted_sentences):
        if len(predicted_sentences) != len(evaluator.preloaded_gt):
            raise Exception(
                "Mismatch in number of gt and pred files: {} != {}. Probably, the prediction did "
                "not succeed".format(len(evaluator.preloaded_gt),
                                     len(predicted_sentences)))

        r = evaluator.evaluate(gt_data=evaluator.preloaded_gt,
                               pred_data=predicted_sentences,
                               progress_bar=True,
                               processes=args.processes)

        print("=================")
        print(f"Evaluation result of {label}")
        print("=================")
        print("")
        print(
            "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)"
            .format(r["avg_ler"], r["total_char_errs"], r["total_chars"],
                    r["total_sync_errs"]))
        print()
        print()

        # sort descending
        print_confusions(r, args.n_confusions)

        return r

    full_evaluation = {}
    for id, data in [
        (str(i), sent) for i, sent in all_prediction_sentences.items()
    ] + [('voted', all_voter_sentences)]:
        full_evaluation[id] = {
            "eval": single_evaluation(id, data),
            "data": data
        }

    if args.verbose:
        print(full_evaluation)

    if args.dump:
        import pickle
        with open(args.dump, 'wb') as f:
            pickle.dump({
                "full": full_evaluation,
                "gt": evaluator.preloaded_gt
            }, f)
Example #8
def run(args):
    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    # checks
    if args.extended_prediction_data_format not in ["pred", "json"]:
        raise Exception(
            "Only 'pred' and 'json' are allowed extended prediction data formats"
        )

    # add json as extension, resolve wildcard, expand user, ... and remove .json again
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json")
                       for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]
    args.extension = args.extension if args.extension else DataSetType.pred_extension(
        args.dataset)

    # create ctc decoder
    ctc_decoder_params = create_ctc_decoder_params(args)

    # create voter
    voter_params = VoterParams()
    voter_params.type = VoterType(args.voter)

    # load files
    input_image_files = glob_all(args.files)
    if args.text_files:
        args.text_files = glob_all(args.text_files)

    # skip invalid files and remove them; there won't be predictions for invalid files
    predict_params = PipelineParams(
        type=args.dataset,
        skip_invalid=True,
        remove_invalid=True,
        files=input_image_files,
        text_files=args.text_files,
        data_reader_args=FileDataReaderArgs(
            pad=args.dataset_pad,
            text_index=args.pagexml_text_index,
        ),
        batch_size=args.batch_size,
        num_processes=args.processes,
    )

    # predict for all models
    # TODO: Use CTC Decoder params
    from calamari_ocr.ocr.predict.predictor import MultiPredictor
    predictor = MultiPredictor.from_paths(
        checkpoints=args.checkpoint,
        voter_params=voter_params,
        predictor_params=PredictorParams(
            silent=True, progress_bar=not args.no_progress_bars))
    do_prediction = predictor.predict(predict_params)
    pipeline: CalamariPipeline = predictor.data.get_predict_data(
        predict_params)
    reader = pipeline.reader()
    if len(reader) == 0:
        raise Exception(
            "Empty dataset provided. Check your files argument (got {})!".
            format(args.files))

    avg_sentence_confidence = 0
    n_predictions = 0

    reader.prepare_store()

    # output the voted results to the appropriate files
    for s in do_prediction:
        inputs, (result, prediction), meta = s.inputs, s.outputs, s.meta
        sample = reader.sample_by_id(meta['id'])
        n_predictions += 1
        sentence = prediction.sentence

        avg_sentence_confidence += prediction.avg_char_probability
        if args.verbose:
            lr = "\u202A\u202B"
            logger.info("{}: '{}{}{}'".format(meta['id'],
                                              lr[get_base_level(sentence)],
                                              sentence, "\u202C"))

        output_dir = args.output_dir

        reader.store_text(sentence,
                          sample,
                          output_dir=output_dir,
                          extension=args.extension)

        if args.extended_prediction_data:
            ps = Predictions()
            ps.line_path = sample[
                'image_path'] if 'image_path' in sample else sample['id']
            ps.predictions.extend([prediction] +
                                  [r.prediction for r in result])
            output_dir = output_dir if output_dir else os.path.dirname(
                ps.line_path)
            if not os.path.exists(output_dir):
                os.mkdir(output_dir)

            if args.extended_prediction_data_format == "pred":
                data = zlib.compress(
                    ps.to_json(indent=2, ensure_ascii=False).encode('utf-8'))
            elif args.extended_prediction_data_format == "json":
                # remove logits
                for p in ps.predictions:
                    p.logits = None

                data = ps.to_json(indent=2)
            else:
                raise Exception("Unknown prediction format.")

            reader.store_extended_prediction(
                data,
                sample,
                output_dir=output_dir,
                extension=args.extended_prediction_data_format)

    logger.info("Average sentence confidence: {:.2%}".format(
        avg_sentence_confidence / n_predictions))

    reader.store(args.extension)
    logger.info("All prediction files written")