コード例 #1
0
ファイル: predict.py プロジェクト: maxwehner/Calamari-OCR
def create_ctc_decoder_params(args):
    """Build the CTC decoder parameters from the command line arguments.

    The beam width is taken from ``args.beam_width`` and the word separator
    is a single space. If ``args.dictionary`` is set and not empty, every
    file matched by ``glob_all(args.dictionary)`` is read and split on
    whitespace; the unique words are stored in the parameters and the decoder
    type is switched to word beam search. Otherwise ``args.dictionary`` is
    reset to ``None``.

    Args:
        args: Argument namespace; reads ``beam_width`` and ``dictionary``.

    Returns:
        The populated ``CTCDecoderParams`` instance.
    """
    decoder_params = CTCDecoderParams()
    decoder_params.beam_width = args.beam_width
    decoder_params.word_separator = ' '

    # No dictionary requested: normalise the argument and keep the default decoder.
    if not args.dictionary or len(args.dictionary) <= 0:
        args.dictionary = None
        return decoder_params

    words = set()
    print("Creating dictionary")
    for dictionary_path in glob_all(args.dictionary):
        with open(dictionary_path, 'r') as dictionary_file:
            words.update(dictionary_file.read().split())

    decoder_params.dictionary[:] = words
    print("Dictionary with {} unique words successfully created.".format(
        len(words)))

    # A dictionary was given, so decode with the (slow) word beam search.
    print(
        "USING A LANGUAGE MODEL IS CURRENTLY EXPERIMENTAL ONLY. NOTE: THE PREDICTION IS VERY SLOW!"
    )
    decoder_params.type = CTCDecoderParams.CTC_WORD_BEAM_SEARCH
    return decoder_params
コード例 #2
0
ファイル: predict.py プロジェクト: jacektl/calamari
def prepare_ctc_decoder_params(ctc_decoder: CTCDecoderParams):
    """Resolve the dictionary of the CTC decoder parameters in place.

    If ``ctc_decoder.dictionary`` is set, it is treated as a list of path
    patterns: every file matched by ``glob_all`` is read and split on
    whitespace, and the attribute is replaced by the set of unique words.
    If the resulting dictionary is not empty, the decoder type is switched
    to word beam search.

    Args:
        ctc_decoder: Decoder parameters that are modified in place.
    """
    if ctc_decoder.dictionary:
        words = set()
        logger.info("Creating dictionary")
        for dictionary_path in glob_all(ctc_decoder.dictionary):
            with open(dictionary_path, "r") as dictionary_file:
                words.update(dictionary_file.read().split())

        ctc_decoder.dictionary = words
        message = "Dictionary with {} unique words successfully created."
        logger.info(message.format(len(words)))

    # Re-check: the files may not have contained a single word.
    if not ctc_decoder.dictionary:
        return

    logger.warning(
        "USING A LANGUAGE MODEL IS CURRENTLY EXPERIMENTAL ONLY. NOTE: THE PREDICTION IS VERY SLOW!"
    )
    ctc_decoder.type = CTCDecoderType.WordBeamSearch
コード例 #3
0
def run(args):
    """Predict a dataset with one or more checkpoints and store the voted results.

    The voted sentence of every sample is handed to ``dataset.store_text``
    with the extension ``.pred.txt``. If ``args.extended_prediction_data`` is
    set, the voted prediction plus the prediction of every single model are
    additionally written to ``<sample id>.pred`` (serialized protobuf) or
    ``<sample id>.json`` (logits removed).

    Args:
        args: Argument namespace. Attributes read here: ``files``,
            ``text_files``, ``checkpoint``, ``voter``, ``dataset``,
            ``pagexml_text_index``, ``batch_size``, ``processes``,
            ``no_progress_bars``, ``verbose``, ``output_dir``,
            ``extended_prediction_data`` and
            ``extended_prediction_data_format``. ``checkpoint`` and
            ``text_files`` are replaced by their resolved paths. If ``files``
            is a single path ending in "json", that file is loaded and its
            entries overwrite the attributes of ``args``.

    Raises:
        Exception: If the extended prediction data format is neither "pred"
            nor "json", or if the dataset is empty.
    """
    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    # checks
    if args.extended_prediction_data_format not in ["pred", "json"]:
        raise Exception(
            "Only 'pred' and 'json' are allowed extended prediction data formats"
        )

    # add json as extension, resolve wildcard, expand user, ... and remove .json again
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json")
                       for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]

    # create voter
    voter_params = VoterParams()
    voter_params.type = VoterParams.Type.Value(args.voter.upper())
    voter = voter_from_proto(voter_params)

    # load files
    input_image_files = glob_all(args.files)
    if args.text_files:
        args.text_files = glob_all(args.text_files)

    # skip invalid files and remove them, there wont be predictions of invalid files
    dataset = create_dataset(
        args.dataset,
        DataSetMode.PREDICT,
        input_image_files,
        args.text_files,
        skip_invalid=True,
        remove_invalid=True,
        args={'text_index': args.pagexml_text_index},
    )

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        raise Exception(
            "Empty dataset provided. Check your files argument (got {})!".
            format(args.files))

    # predict for all models
    predictor = MultiPredictor(checkpoints=args.checkpoint,
                               batch_size=args.batch_size,
                               processes=args.processes)
    do_prediction = predictor.predict_dataset(
        dataset, progress_bar=not args.no_progress_bars)

    avg_sentence_confidence = 0
    n_predictions = 0

    # U+202A / U+202B open a left-to-right / right-to-left embedding; the
    # entry is selected by the index returned from get_base_level().
    lr = "\u202A\u202B"

    # output the voted results to the appropriate files
    for result, sample in do_prediction:
        n_predictions += 1
        for i, p in enumerate(result):
            p.prediction.id = "fold_{}".format(i)

        # vote the results (if only one model is given, this will just return the sentences)
        prediction = voter.vote_prediction_result(result)
        prediction.id = "voted"
        sentence = prediction.sentence
        avg_sentence_confidence += prediction.avg_char_probability
        if args.verbose:
            # U+202C closes the embedding that was opened in front of the sentence
            print("{}: '{}{}{}'".format(sample['id'],
                                        lr[get_base_level(sentence)], sentence,
                                        "\u202C"))

        output_dir = args.output_dir

        dataset.store_text(sentence,
                           sample,
                           output_dir=output_dir,
                           extension=".pred.txt")

        if args.extended_prediction_data:
            ps = Predictions()
            ps.line_path = sample[
                'image_path'] if 'image_path' in sample else sample['id']
            ps.predictions.extend([prediction] +
                                  [r.prediction for r in result])
            output_dir = output_dir if output_dir else os.path.dirname(
                ps.line_path)
            # dirname() yields '' for files in the current directory, which
            # needs no creation; makedirs also handles nested directories and
            # does not fail if the directory appears concurrently.
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)

            if args.extended_prediction_data_format == "pred":
                with open(os.path.join(output_dir, sample['id'] + ".pred"),
                          'wb') as f:
                    f.write(ps.SerializeToString())
            elif args.extended_prediction_data_format == "json":
                with open(os.path.join(output_dir, sample['id'] + ".json"),
                          'w') as f:
                    # remove logits
                    for stored_prediction in ps.predictions:
                        stored_prediction.logits.rows = 0
                        stored_prediction.logits.cols = 0
                        stored_prediction.logits.data[:] = []

                    f.write(
                        MessageToJson(ps, including_default_value_fields=True))
            else:
                raise Exception("Unknown prediction format.")

    # no average can be computed if the predictor did not yield any result
    if n_predictions > 0:
        print("Average sentence confidence: {:.2%}".format(
            avg_sentence_confidence / n_predictions))

    dataset.store()
    print("All files written")
コード例 #4
0
ファイル: predict.py プロジェクト: jacektl/calamari
def run(args: PredictArgs):
    """Predict the data described by ``args.data`` and store the voted results.

    Every voted prediction is stored through ``reader.store_text_prediction``.
    If ``args.extended_prediction_data`` is set, the voted prediction plus the
    prediction of every single model are additionally stored through
    ``reader.store_extended_prediction``: as zlib-compressed UTF-8 JSON for
    the format "pred", or as plain JSON without logits for the format "json".

    Args:
        args: Prediction arguments. Attributes read here: ``checkpoint``,
            ``ctc_decoder``, ``voter``, ``predictor``, ``data``, ``verbose``,
            ``output_dir``, ``extended_prediction_data`` and
            ``extended_prediction_data_format``. ``checkpoint`` is replaced by
            the resolved checkpoint paths.

    Raises:
        Exception: If the extended prediction data format is neither "pred"
            nor "json", or if the reader contains no samples.
    """
    # check if loading a json file
    # TODO: support running from JSON
    # if len(args.files) == 1 and args.files[0].endswith("json"):
    #     import json
    #     with open(args.files[0], 'r') as f:
    #        json_args = json.load(f)
    #        for key, value in json_args.items():
    #            setattr(args, key, value)

    # checks
    if args.extended_prediction_data_format not in ["pred", "json"]:
        raise Exception(
            "Only 'pred' and 'json' are allowed extended prediction data formats"
        )

    # add json as extension, resolve wildcard, expand user, ... and remove .json again
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json")
                       for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]

    # create ctc decoder
    prepare_ctc_decoder_params(args.ctc_decoder)

    # predict for all models
    from calamari_ocr.ocr.predict.predictor import MultiPredictor

    predictor = MultiPredictor.from_paths(
        checkpoints=args.checkpoint,
        voter_params=args.voter,
        predictor_params=args.predictor,
    )
    do_prediction = predictor.predict(args.data)
    pipeline: CalamariPipeline = predictor.data.get_or_create_pipeline(
        predictor.params.pipeline, args.data)
    reader = pipeline.reader()
    if len(reader) == 0:
        raise Exception(
            "Empty dataset provided. Check your command line arguments or if the provided files are empty."
        )

    avg_sentence_confidence = 0
    n_predictions = 0

    reader.prepare_store()

    # U+202A / U+202B open a left-to-right / right-to-left embedding; the
    # entry is selected by the index returned from get_base_level().
    lr = "\u202A\u202B"

    # output the voted results to the appropriate files
    for s in do_prediction:
        _, (result, prediction), meta = s.inputs, s.outputs, s.meta
        sample = reader.sample_by_id(meta["id"])
        n_predictions += 1
        sentence = prediction.sentence

        avg_sentence_confidence += prediction.avg_char_probability
        if args.verbose:
            # U+202C closes the embedding that was opened in front of the sentence
            logger.info("{}: '{}{}{}'".format(meta["id"],
                                              lr[get_base_level(sentence)],
                                              sentence, "\u202C"))

        output_dir = args.output_dir if args.output_dir else os.path.dirname(
            prediction.line_path)

        reader.store_text_prediction(prediction,
                                     meta["id"],
                                     output_dir=output_dir)

        if args.extended_prediction_data:
            ps = Predictions()
            ps.line_path = sample[
                "image_path"] if "image_path" in sample else sample["id"]
            ps.predictions.extend([prediction] +
                                  [r.prediction for r in result])
            output_dir = output_dir if output_dir else os.path.dirname(
                ps.line_path)
            # dirname() yields '' for files in the current directory, which
            # needs no creation; makedirs also handles nested directories and
            # does not fail if the directory appears concurrently.
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)

            if args.extended_prediction_data_format == "pred":
                data = zlib.compress(
                    ps.to_json(indent=2, ensure_ascii=False).encode("utf-8"))
            elif args.extended_prediction_data_format == "json":
                # remove logits
                for p in ps.predictions:
                    p.logits = None

                data = ps.to_json(indent=2)
            else:
                raise Exception("Unknown prediction format.")

            reader.store_extended_prediction(
                data,
                sample,
                output_dir=output_dir,
                extension=args.extended_prediction_data_format,
            )

    # no average can be computed if the predictor did not yield any result
    if n_predictions > 0:
        logger.info("Average sentence confidence: {:.2%}".format(
            avg_sentence_confidence / n_predictions))

    reader.store()
    logger.info("All prediction files written")
コード例 #5
0
def main():
    """Command line entry point: predict image files and write the voted text.

    For every input image the voted sentence is written as UTF-8 to
    ``<sample id>.pred.txt``, either into ``--output_dir`` or next to the
    image. With ``--extended_prediction_data`` the voted prediction plus the
    prediction of every single model are additionally written to a ``.pred``
    (serialized protobuf) or ``.json`` (logits removed) file.

    Raises:
        Exception: If the extended prediction data format is neither "pred"
            nor "json", or if the dataset is empty.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--files",
                        nargs="+",
                        required=True,
                        default=[],
                        help="List all image files that shall be processed")
    parser.add_argument("--checkpoint",
                        type=str,
                        nargs="+",
                        default=[],
                        help="Path to the checkpoint without file extension")
    parser.add_argument("-j",
                        "--processes",
                        type=int,
                        default=1,
                        help="Number of processes to use")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help=
        "The batch size during the prediction (number of lines to process in parallel)"
    )
    parser.add_argument("--verbose",
                        action="store_true",
                        help="Print additional information")
    parser.add_argument(
        "--voter",
        type=str,
        default="confidence_voter_default_ctc",
        help=
        "The voting algorithm to use. Possible values: confidence_voter_default_ctc (default), "
        "confidence_voter_fuzzy_ctc, sequence_voter")
    parser.add_argument(
        "--output_dir",
        type=str,
        help=
        "By default the prediction files will be written to the same directory as the given files. "
        "You can use this argument to specify a specific output dir for the prediction files."
    )
    parser.add_argument(
        "--extended_prediction_data",
        action="store_true",
        help=
        "Write: Predicted string, labels; position, probabilities and alternatives of chars to a .pred (protobuf) file"
    )
    parser.add_argument(
        "--extended_prediction_data_format",
        type=str,
        default="json",
        help=
        "Extension format: Either pred or json. Note that json will not print logits."
    )
    parser.add_argument("--no_progress_bars",
                        action="store_true",
                        help="Do not show any progress bars")

    args = parser.parse_args()

    # checks
    if args.extended_prediction_data_format not in ["pred", "json"]:
        raise Exception(
            "Only 'pred' and 'json' are allowed extended prediction data formats"
        )

    # add json as extension, resolve wildcard, expand user, ... and remove .json again
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json")
                       for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]

    # create voter
    voter_params = VoterParams()
    voter_params.type = VoterParams.Type.Value(args.voter.upper())
    voter = voter_from_proto(voter_params)

    # load files
    input_image_files = sorted(glob_all(args.files))

    # skip invalid files, but keep them so that empty predictions are created
    dataset = FileDataSet(input_image_files,
                          skip_invalid=True,
                          remove_invalid=False)

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        raise Exception(
            "Empty dataset provided. Check your files argument (got {})!".
            format(args.files))

    # predict for all models
    predictor = MultiPredictor(checkpoints=args.checkpoint)
    do_prediction = predictor.predict_dataset(
        dataset,
        batch_size=args.batch_size,
        processes=args.processes,
        progress_bar=not args.no_progress_bars)

    # output the voted results to the appropriate files
    # NOTE(review): pairing by position assumes one prediction per input file
    # in sorted order — confirm against FileDataSet/MultiPredictor.
    for (result, sample), filepath in zip(do_prediction, input_image_files):
        for i, p in enumerate(result):
            p.prediction.id = "fold_{}".format(i)

        # vote the results (if only one model is given, this will just return the sentences)
        prediction = voter.vote_prediction_result(result)
        prediction.id = "voted"
        sentence = prediction.sentence
        if args.verbose:
            print("{}: '{}'".format(sample['id'], sentence))

        output_dir = args.output_dir if args.output_dir else os.path.dirname(
            filepath)
        # A user supplied output directory may not exist yet; '' (image in
        # the current directory) needs no creation.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with codecs.open(os.path.join(output_dir, sample['id'] + ".pred.txt"),
                         'w', 'utf-8') as f:
            f.write(sentence)

        if args.extended_prediction_data:
            ps = Predictions()
            ps.line_path = filepath
            ps.predictions.extend([prediction] +
                                  [r.prediction for r in result])
            if args.extended_prediction_data_format == "pred":
                with open(os.path.join(output_dir, sample['id'] + ".pred"),
                          'wb') as f:
                    f.write(ps.SerializeToString())
            elif args.extended_prediction_data_format == "json":
                with open(os.path.join(output_dir, sample['id'] + ".json"),
                          'w') as f:
                    # remove logits
                    for stored_prediction in ps.predictions:
                        stored_prediction.logits.rows = 0
                        stored_prediction.logits.cols = 0
                        stored_prediction.logits.data[:] = []

                    f.write(
                        MessageToJson(ps, including_default_value_fields=True))
            else:
                raise Exception("Unknown prediction format.")

    print("All files written")
コード例 #6
0
def main():
    """Command line entry point: configure and run the training of a model.

    Parses the training arguments, resolves the training (and optional
    validation) images together with their ``.gt.txt`` files, fills a
    ``CheckpointParams`` instance from the arguments and starts the trainer.

    Raises:
        Exception: If two images of the training or of the validation set map
            to the same ground truth file.
    """
    arg_parser = argparse.ArgumentParser()
    setup_train_args(arg_parser)
    args = arg_parser.parse_args()

    # a single "*json" path given as files is an argument file whose entries
    # overwrite the parsed arguments
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as json_file:
            for name, value in json.load(json_file).items():
                setattr(args, name, value)

    # parse whitelist: every character of every whitelist file is appended
    whitelist = args.whitelist
    for whitelist_path in glob_all(args.whitelist_files):
        with open(whitelist_path) as whitelist_file:
            whitelist += list(whitelist_file.read())

    skip_invalid_gt = not args.no_skip_invalid_gt

    # Training dataset
    print("Resolving input files")
    input_image_files = glob_all(args.files)
    gt_txt_files = []
    for image_path in input_image_files:
        gt_txt_files.append(split_all_ext(image_path)[0] + ".gt.txt")

    if len(gt_txt_files) != len(set(gt_txt_files)):
        raise Exception("Some image are occurring more than once in the data set.")

    dataset = FileDataSet(input_image_files, gt_txt_files, skip_invalid=skip_invalid_gt)
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    validation_dataset = None
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        val_txt_files = []
        for image_path in validation_image_files:
            val_txt_files.append(split_all_ext(image_path)[0] + ".gt.txt")

        if len(val_txt_files) != len(set(val_txt_files)):
            raise Exception("Some validation images are occurring more than once in the data set.")

        validation_dataset = FileDataSet(validation_image_files, val_txt_files,
                                         skip_invalid=skip_invalid_gt)
        print("Found {} files in the validation dataset".format(len(validation_dataset)))

    params = CheckpointParams()

    params.max_iters = args.max_iters
    params.stats_size = args.stats_size
    params.batch_size = args.batch_size
    params.checkpoint_frequency = args.checkpoint_frequency
    params.output_dir = args.output_dir
    params.output_model_prefix = args.output_model_prefix
    params.display = args.display
    params.skip_invalid_gt = skip_invalid_gt
    params.processes = args.num_threads

    # a negative early stopping frequency falls back to the checkpoint frequency
    if args.early_stopping_frequency >= 0:
        params.early_stopping_frequency = args.early_stopping_frequency
    else:
        params.early_stopping_frequency = args.checkpoint_frequency
    params.early_stopping_nbest = args.early_stopping_nbest
    params.early_stopping_best_model_prefix = args.early_stopping_best_model_prefix
    params.early_stopping_best_model_output_dir = (
        args.early_stopping_best_model_output_dir or args.output_dir)

    data_preprocessor = params.model.data_preprocessor
    data_preprocessor.type = DataPreprocessorParams.DEFAULT_NORMALIZER
    data_preprocessor.line_height = args.line_height
    data_preprocessor.pad = args.pad

    # Text pre processing (reading) first, then text post processing
    # (prediction): both get a normalizer, a regularizer and a strip processor
    for text_processor in (params.model.text_preprocessor, params.model.text_postprocessor):
        text_processor.type = TextProcessorParams.MULTI_NORMALIZER
        default_text_normalizer_params(text_processor.children.add(), default=args.text_normalization)
        default_text_regularizer_params(text_processor.children.add(), groups=args.text_regularization)
        strip_processor_params = text_processor.children.add()
        strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    if args.seed > 0:
        params.model.network.backend.random_seed = args.seed

    if args.bidi_dir:
        # change bidirectional text direction if desired
        bidi_dir_to_enum = {"rtl": TextProcessorParams.BIDI_RTL, "ltr": TextProcessorParams.BIDI_LTR}

        for text_processor in (params.model.text_preprocessor, params.model.text_postprocessor):
            bidi_processor_params = text_processor.children.add()
            bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
            bidi_processor_params.bidi_direction = bidi_dir_to_enum[args.bidi_dir]

    params.model.line_height = args.line_height

    network = params.model.network
    network_params_from_definition_string(args.network, network)
    clipping_mode_name = "CLIP_" + args.gradient_clipping_mode.upper()
    network.clipping_mode = NetworkParams.ClippingMode.Value(clipping_mode_name)
    network.clipping_constant = args.gradient_clipping_const
    network.backend.fuzzy_ctc_library_path = args.fuzzy_ctc_library_path
    network.backend.num_inter_threads = args.num_inter_threads
    network.backend.num_intra_threads = args.num_intra_threads

    # create the actual trainer
    trainer = Trainer(
        params,
        dataset,
        validation_dataset=validation_dataset,
        data_augmenter=SimpleDataAugmenter(),
        n_augmentations=args.n_augmentations,
        weights=args.weights,
        codec_whitelist=whitelist,
    )
    trainer.train(progress_bar=not args.no_progress_bars)
コード例 #7
0
ファイル: predict_abbyy.py プロジェクト: Joanzm/calamari
def run(args):
    """Predict the lines of an ABBYY dataset and write the result as XML.

    The voted sentence of every sample is stored in the ``text`` attribute of
    ``sample["format"]``; after all predictions the book of the dataset is
    written with ``XMLWriter``. With ``args.extended_prediction_data`` the
    voted prediction plus the prediction of every single model are
    additionally written to a ``.pred`` (serialized protobuf) or ``.json``
    (logits removed) file.

    Args:
        args: Argument namespace. Attributes read here: ``files``,
            ``checkpoint``, ``voter``, ``binary``, ``batch_size``,
            ``processes``, ``no_progress_bars``, ``verbose``, ``output_dir``,
            ``extended_prediction_data`` and
            ``extended_prediction_data_format``. ``checkpoint`` is replaced by
            the resolved checkpoint paths.

    Raises:
        Exception: If the extended prediction data format is neither "pred"
            nor "json", if the dataset is empty, or if not a single
            prediction was produced.
    """
    # checks
    if args.extended_prediction_data_format not in ["pred", "json"]:
        raise Exception(
            "Only 'pred' and 'json' are allowed extended prediction data formats"
        )

    # add json as extension, resolve wildcard, expand user, ... and remove .json again
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json")
                       for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]

    # create voter
    voter_params = VoterParams()
    voter_params.type = VoterParams.Type.Value(args.voter.upper())
    voter = voter_from_proto(voter_params)

    # load files
    files = glob.glob(args.files)
    dataset = AbbyyDataSet(files,
                           skip_invalid=True,
                           remove_invalid=False,
                           binary=args.binary)

    dataset.load_samples(processes=args.processes,
                         progress_bar=not args.no_progress_bars)

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        raise Exception(
            "Empty dataset provided. Check your files argument (got {})!".
            format(args.files))

    # predict for all models
    predictor = MultiPredictor(checkpoints=args.checkpoint,
                               batch_size=args.batch_size,
                               processes=args.processes)
    do_prediction = predictor.predict_dataset(
        dataset, progress_bar=not args.no_progress_bars)

    # output the voted results to the appropriate files
    input_image_files = []

    # create input_image_files list for next loop: the image file of a page
    # is repeated once per format of that page
    for page in dataset.book.pages:
        input_image_files.extend(page.imgFile for _ in page.getFormats())

    # U+202A / U+202B open a left-to-right / right-to-left embedding; the
    # entry is selected by the index returned from get_base_level().
    lr = "\u202A\u202B"

    n_predictions = 0
    for (result, sample), filepath in zip(do_prediction, input_image_files):
        n_predictions += 1
        for i, p in enumerate(result):
            p.prediction.id = "fold_{}".format(i)

        # vote the results (if only one model is given, this will just return the sentences)
        prediction = voter.vote_prediction_result(result)
        prediction.id = "voted"
        sentence = prediction.sentence
        if args.verbose:
            # U+202C closes the embedding that was opened in front of the sentence
            print("{}: '{}{}{}'".format(sample['id'],
                                        lr[get_base_level(sentence)], sentence,
                                        "\u202C"))

        output_dir = args.output_dir if args.output_dir else os.path.dirname(
            filepath)
        # A user supplied output directory may not exist yet; '' (image in
        # the current directory) needs no creation.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        sample["format"].text = sentence

        if args.extended_prediction_data:
            ps = Predictions()
            ps.line_path = filepath
            ps.predictions.extend([prediction] +
                                  [r.prediction for r in result])
            if args.extended_prediction_data_format == "pred":
                with open(os.path.join(output_dir, sample['id'] + ".pred"),
                          'wb') as f:
                    f.write(ps.SerializeToString())
            elif args.extended_prediction_data_format == "json":
                with open(os.path.join(output_dir, sample['id'] + ".json"),
                          'w') as f:
                    # remove logits
                    for stored_prediction in ps.predictions:
                        stored_prediction.logits.rows = 0
                        stored_prediction.logits.cols = 0
                        stored_prediction.logits.data[:] = []

                    f.write(
                        MessageToJson(ps, including_default_value_fields=True))
            else:
                raise Exception("Unknown prediction format.")

    # The writer below uses the output directory and image path of the last
    # prediction; without any prediction these names would be unbound.
    if n_predictions == 0:
        raise Exception(
            "No predictions were produced, cannot write the XML output.")

    w = XMLWriter(output_dir, os.path.dirname(filepath), dataset.book)
    w.write()

    print("All files written")
コード例 #8
0
ファイル: predict.py プロジェクト: AIRob/calamari
def run(args):
    """Predict a dataset with one or more checkpoints and store the voted results.

    The voted sentence of every sample is handed to ``dataset.store_text``
    with the extension ``.pred.txt``. If ``args.extended_prediction_data`` is
    set, the voted prediction plus the prediction of every single model are
    additionally written to ``<sample id>.pred`` (serialized protobuf) or
    ``<sample id>.json`` (logits removed).

    Args:
        args: Argument namespace. Attributes read here: ``files``,
            ``text_files``, ``checkpoint``, ``voter``, ``dataset``,
            ``pagexml_text_index``, ``batch_size``, ``processes``,
            ``no_progress_bars``, ``verbose``, ``output_dir``,
            ``extended_prediction_data`` and
            ``extended_prediction_data_format``. ``checkpoint`` and
            ``text_files`` are replaced by their resolved paths. If ``files``
            is a single path ending in "json", that file is loaded and its
            entries overwrite the attributes of ``args``.

    Raises:
        Exception: If the extended prediction data format is neither "pred"
            nor "json", or if the dataset is empty.
    """
    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    # checks
    if args.extended_prediction_data_format not in ["pred", "json"]:
        raise Exception("Only 'pred' and 'json' are allowed extended prediction data formats")

    # add json as extension, resolve wildcard, expand user, ... and remove .json again
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json") for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]

    # create voter
    voter_params = VoterParams()
    voter_params.type = VoterParams.Type.Value(args.voter.upper())
    voter = voter_from_proto(voter_params)

    # load files
    input_image_files = glob_all(args.files)
    if args.text_files:
        args.text_files = glob_all(args.text_files)

    # skip invalid files and remove them, there wont be predictions of invalid files
    dataset = create_dataset(
        args.dataset,
        DataSetMode.PREDICT,
        input_image_files,
        args.text_files,
        skip_invalid=True,
        remove_invalid=True,
        args={'text_index': args.pagexml_text_index},
    )

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        raise Exception("Empty dataset provided. Check your files argument (got {})!".format(args.files))

    # predict for all models
    predictor = MultiPredictor(checkpoints=args.checkpoint, batch_size=args.batch_size, processes=args.processes)
    do_prediction = predictor.predict_dataset(dataset, progress_bar=not args.no_progress_bars)

    avg_sentence_confidence = 0
    n_predictions = 0

    # U+202A / U+202B open a left-to-right / right-to-left embedding; the
    # entry is selected by the index returned from get_base_level().
    lr = "\u202A\u202B"

    # output the voted results to the appropriate files
    for result, sample in do_prediction:
        n_predictions += 1
        for i, p in enumerate(result):
            p.prediction.id = "fold_{}".format(i)

        # vote the results (if only one model is given, this will just return the sentences)
        prediction = voter.vote_prediction_result(result)
        prediction.id = "voted"
        sentence = prediction.sentence
        avg_sentence_confidence += prediction.avg_char_probability
        if args.verbose:
            # U+202C closes the embedding that was opened in front of the sentence
            print("{}: '{}{}{}'".format(sample['id'], lr[get_base_level(sentence)], sentence, "\u202C"))

        output_dir = args.output_dir

        dataset.store_text(sentence, sample, output_dir=output_dir, extension=".pred.txt")

        if args.extended_prediction_data:
            ps = Predictions()
            ps.line_path = sample['image_path'] if 'image_path' in sample else sample['id']
            ps.predictions.extend([prediction] + [r.prediction for r in result])
            output_dir = output_dir if output_dir else os.path.dirname(ps.line_path)
            # dirname() yields '' for files in the current directory, which
            # needs no creation; makedirs also handles nested directories and
            # does not fail if the directory appears concurrently.
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)

            if args.extended_prediction_data_format == "pred":
                with open(os.path.join(output_dir, sample['id'] + ".pred"), 'wb') as f:
                    f.write(ps.SerializeToString())
            elif args.extended_prediction_data_format == "json":
                with open(os.path.join(output_dir, sample['id'] + ".json"), 'w') as f:
                    # remove logits
                    for stored_prediction in ps.predictions:
                        stored_prediction.logits.rows = 0
                        stored_prediction.logits.cols = 0
                        stored_prediction.logits.data[:] = []

                    f.write(MessageToJson(ps, including_default_value_fields=True))
            else:
                raise Exception("Unknown prediction format.")

    # no average can be computed if the predictor did not yield any result
    if n_predictions > 0:
        print("Average sentence confidence: {:.2%}".format(avg_sentence_confidence / n_predictions))

    dataset.store()
    print("All files written")
コード例 #9
0
ファイル: predict.py プロジェクト: znsoftm/calamari
def run(args):
    """Predict the given files with one or more checkpoints and store the voted results.

    The voted sentence of every sample is stored through ``reader.store_text``
    with the extension ``args.extension``. If
    ``args.extended_prediction_data`` is set, the voted prediction plus the
    prediction of every single model are additionally stored through
    ``reader.store_extended_prediction``: as zlib-compressed UTF-8 JSON for
    the format "pred", or as plain JSON without logits for the format "json".

    Args:
        args: Argument namespace. Attributes read here: ``files``,
            ``text_files``, ``checkpoint``, ``extension``, ``dataset``,
            ``voter``, ``dataset_pad``, ``pagexml_text_index``,
            ``batch_size``, ``processes``, ``no_progress_bars``, ``verbose``,
            ``output_dir``, ``extended_prediction_data`` and
            ``extended_prediction_data_format``, plus whatever
            ``create_ctc_decoder_params`` reads. ``checkpoint``,
            ``text_files`` and ``extension`` are replaced by their resolved
            values. If ``files`` is a single path ending in "json", that file
            is loaded and its entries overwrite the attributes of ``args``.

    Raises:
        Exception: If the extended prediction data format is neither "pred"
            nor "json", or if the reader contains no samples.
    """
    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    # checks
    if args.extended_prediction_data_format not in ["pred", "json"]:
        raise Exception(
            "Only 'pred' and 'json' are allowed extended prediction data formats"
        )

    # add json as extension, resolve wildcard, expand user, ... and remove .json again
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json")
                       for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]
    args.extension = args.extension if args.extension else DataSetType.pred_extension(
        args.dataset)

    # create ctc decoder
    # NOTE: the parameters are not handed to the predictor yet (see TODO below)
    ctc_decoder_params = create_ctc_decoder_params(args)

    # create voter
    voter_params = VoterParams()
    voter_params.type = VoterType(args.voter)

    # load files
    input_image_files = glob_all(args.files)
    if args.text_files:
        args.text_files = glob_all(args.text_files)

    # skip invalid files and remove them, there wont be predictions of invalid files
    predict_params = PipelineParams(
        type=args.dataset,
        skip_invalid=True,
        remove_invalid=True,
        files=input_image_files,
        text_files=args.text_files,
        data_reader_args=FileDataReaderArgs(
            pad=args.dataset_pad,
            text_index=args.pagexml_text_index,
        ),
        batch_size=args.batch_size,
        num_processes=args.processes,
    )

    # predict for all models
    # TODO: Use CTC Decoder params
    from calamari_ocr.ocr.predict.predictor import MultiPredictor
    predictor = MultiPredictor.from_paths(
        checkpoints=args.checkpoint,
        voter_params=voter_params,
        predictor_params=PredictorParams(
            silent=True, progress_bar=not args.no_progress_bars))
    do_prediction = predictor.predict(predict_params)
    pipeline: CalamariPipeline = predictor.data.get_predict_data(
        predict_params)
    reader = pipeline.reader()
    if len(reader) == 0:
        raise Exception(
            "Empty dataset provided. Check your files argument (got {})!".
            format(args.files))

    avg_sentence_confidence = 0
    n_predictions = 0

    reader.prepare_store()

    # U+202A / U+202B open a left-to-right / right-to-left embedding; the
    # entry is selected by the index returned from get_base_level().
    lr = "\u202A\u202B"

    # output the voted results to the appropriate files
    for s in do_prediction:
        _, (result, prediction), meta = s.inputs, s.outputs, s.meta
        sample = reader.sample_by_id(meta['id'])
        n_predictions += 1
        sentence = prediction.sentence

        avg_sentence_confidence += prediction.avg_char_probability
        if args.verbose:
            # U+202C closes the embedding that was opened in front of the sentence
            logger.info("{}: '{}{}{}'".format(meta['id'],
                                              lr[get_base_level(sentence)],
                                              sentence, "\u202C"))

        output_dir = args.output_dir

        reader.store_text(sentence,
                          sample,
                          output_dir=output_dir,
                          extension=args.extension)

        if args.extended_prediction_data:
            ps = Predictions()
            ps.line_path = sample[
                'image_path'] if 'image_path' in sample else sample['id']
            ps.predictions.extend([prediction] +
                                  [r.prediction for r in result])
            output_dir = output_dir if output_dir else os.path.dirname(
                ps.line_path)
            # dirname() yields '' for files in the current directory, which
            # needs no creation; makedirs also handles nested directories and
            # does not fail if the directory appears concurrently.
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)

            if args.extended_prediction_data_format == "pred":
                data = zlib.compress(
                    ps.to_json(indent=2, ensure_ascii=False).encode('utf-8'))
            elif args.extended_prediction_data_format == "json":
                # remove logits
                for p in ps.predictions:
                    p.logits = None

                data = ps.to_json(indent=2)
            else:
                raise Exception("Unknown prediction format.")

            reader.store_extended_prediction(
                data,
                sample,
                output_dir=output_dir,
                extension=args.extended_prediction_data_format)

    # no average can be computed if the predictor did not yield any result
    if n_predictions > 0:
        logger.info("Average sentence confidence: {:.2%}".format(
            avg_sentence_confidence / n_predictions))

    reader.store(args.extension)
    logger.info("All prediction files written")