def __init__(self):
     """Fill in the fixed default training arguments for the uw3_50lines data.

     Every attribute receives a constant value, except ``files`` (globbed
     from ``data/uw3_50lines/train/*.png`` below ``this_dir``),
     ``gt_extension`` (looked up for the selected dataset type) and
     ``data_preprocessing`` (names of the default image processors).
     """
     model_prefix = "uw3_50lines"

     # Training data.
     self.dataset = DataSetType.FILE
     self.gt_extension = DataSetType.gt_extension(self.dataset)
     train_pattern = os.path.join(
         this_dir, "data", "uw3_50lines", "train", "*.png")
     self.files = glob_all([train_pattern])

     # Model and training loop.
     self.seed = 24
     self.backend = "tensorflow"
     self.network = ",".join([
         "cnn=40:3x3", "pool=2x2",
         "cnn=60:3x3", "pool=2x2",
         "lstm=200", "dropout=0.5",
     ])
     self.line_height = 48
     self.pad = 16
     self.num_threads = self.display = self.batch_size = 1
     self.checkpoint_frequency = 1000
     self.epochs = 1
     self.samples_per_epoch = 8
     self.stats_size = 100
     self.no_skip_invalid_gt = False
     self.no_progress_bars = True

     # Output and initial weights.
     self.output_dir = None
     self.output_model_prefix = model_prefix
     self.bidi_dir = self.weights = None
     self.ema_weights = False
     # Two separate list objects on purpose: they must not alias each other.
     self.whitelist_files = []
     self.whitelist = []
     self.gradient_clipping_norm = 5

     # Validation.
     self.validation = None
     self.validation_dataset = DataSetType.FILE
     self.validation_extension = self.validation_split_ratio = None

     # Early stopping.
     self.early_stopping_frequency = -1
     self.early_stopping_nbest = 10
     self.early_stopping_at_accuracy = 0.99
     self.early_stopping_best_model_prefix = model_prefix + "_best"
     # Mirrors whatever output_dir was set to above.
     self.early_stopping_best_model_output_dir = self.output_dir

     # Augmentation and threading.
     self.n_augmentations = 0
     self.num_inter_threads = self.num_intra_threads = 0

     # Text handling and data generation.
     self.text_regularization = ["extended"]
     self.text_normalization = "NFC"
     self.text_generator_params = self.line_generator_params = None
     self.pagexml_text_index = 0
     self.text_files = None
     self.only_train_on_augmented = False
     self.data_preprocessing = [
         processor.name for processor in default_image_processors()
     ]
     self.shuffle_buffer_size = 1000

     # Remaining switches.
     self.keep_loaded_codec = False
     self.train_data_on_the_fly = self.validation_data_on_the_fly = False
     self.no_auto_compute_codec = False
     self.dataset_pad = 0
     self.debug = False
     self.train_verbose = True
     self.use_train_as_val = False
     self.ensemble = -1
     self.masking_mode = 1
 def __init__(self):
     """Fill in the fixed default training arguments for the uw3_50lines data.

     Every attribute receives a constant value, except ``files`` (globbed
     from ``data/uw3_50lines/train/*.png`` below ``this_dir``),
     ``gt_extension`` (looked up for the selected dataset type) and
     ``output_dir`` (the ``test_models`` directory below ``this_dir``).
     """
     model_prefix = "uw3_50lines"

     # Training data.
     self.dataset = DataSetType.FILE
     self.gt_extension = DataSetType.gt_extension(self.dataset)
     train_pattern = os.path.join(
         this_dir, "data", "uw3_50lines", "train", "*.png")
     self.files = glob_all([train_pattern])

     # Model and training loop.
     self.seed = 24
     self.backend = "tensorflow"
     self.network = ",".join([
         "cnn=40:3x3", "pool=2x2",
         "cnn=60:3x3", "pool=2x2",
         "lstm=200", "dropout=0.5",
     ])
     self.line_height = 48
     self.pad = 16
     self.num_threads = self.display = self.batch_size = 1
     self.checkpoint_frequency = self.max_iters = 1000
     self.stats_size = 100
     self.no_skip_invalid_gt = False
     self.no_progress_bars = True

     # Output and initial weights.
     self.output_dir = os.path.join(this_dir, "test_models")
     self.output_model_prefix = model_prefix
     self.bidi_dir = self.weights = None
     # Two separate list objects on purpose: they must not alias each other.
     self.whitelist_files = []
     self.whitelist = []
     self.gradient_clipping_mode = "AUTO"
     self.gradient_clipping_const = 0

     # Validation.
     self.validation = None
     self.validation_dataset = DataSetType.FILE
     self.validation_extension = None

     # Early stopping.
     self.early_stopping_frequency = -1
     self.early_stopping_nbest = 10
     self.early_stopping_best_model_prefix = model_prefix + "_best"
     # Best models are written next to the regular output.
     self.early_stopping_best_model_output_dir = self.output_dir

     # Augmentation, CTC library and threading.
     self.n_augmentations = 0
     self.fuzzy_ctc_library_path = ""
     self.num_inter_threads = self.num_intra_threads = 0

     # Text handling and data generation.
     self.text_regularization = ["extended"]
     self.text_normalization = "NFC"
     self.text_generator_params = self.line_generator_params = None
     self.pagexml_text_index = 0
     self.text_files = None
     self.only_train_on_augmented = False
     default_normalizer = DataPreprocessorParams.DEFAULT_NORMALIZER
     self.data_preprocessing = [default_normalizer]
     self.shuffle_buffer_size = 1000

     # Remaining switches.
     self.keep_loaded_codec = False
     self.train_data_on_the_fly = self.validation_data_on_the_fly = False
     self.no_auto_compute_codec = False
示例#3
0
def main():
    """Resume training from an existing checkpoint, driven by CLI arguments.

    Parses the command line, resolves the training (and optional validation)
    image / ground-truth file pairs, builds the datasets, reads the
    checkpoint parameters from ``<checkpoint>.json`` and continues training
    with the checkpoint passed as initial weights.

    Raises:
        Exception: If an image and its ground-truth file do not share the
            same base name, or if a ground-truth file occurs more than once
            (for either the training or the validation set).
    """
    parser = ArgumentParser()
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="The checkpoint used to resume")

    # validation files
    parser.add_argument("--validation",
                        type=str,
                        nargs="+",
                        help="Validation line files used for early stopping")
    parser.add_argument(
        "--validation_text_files",
        nargs="+",
        default=None,
        help=
        "Optional list of validation GT files if they are in other directory")
    parser.add_argument(
        "--validation_extension",
        default=None,
        help="Default extension of the gt files (expected to exist in same dir)"
    )
    parser.add_argument("--validation_dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)

    # input files
    parser.add_argument(
        "--files",
        nargs="+",
        help=
        "List all image files that shall be processed. Ground truth fils with the same "
        "base name but with '.gt.txt' as extension are required at the same location"
    )
    parser.add_argument(
        "--text_files",
        nargs="+",
        default=None,
        help="Optional list of GT files if they are in other directory")
    parser.add_argument(
        "--gt_extension",
        default=None,
        help="Default extension of the gt files (expected to exist in same dir)"
    )
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument(
        "--no_skip_invalid_gt",
        action="store_true",
        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # Fall back to the extension associated with the chosen dataset type.
    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(
            args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        # Ground truth is expected next to each image, differing only in
        # the extension.
        gt_txt_files = [
            split_all_ext(f)[0] + args.gt_extension for f in input_image_files
        ]
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(
            input_image_files, gt_txt_files)
        # Verify that the two lists are aligned pair by pair.
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(
                    os.path.basename(gt))[0]:
                raise Exception(
                    "Expected identical basenames of file: {} and {}".format(
                        img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception(
            "Some image are occurring more than once in the data set.")

    dataset = create_dataset(args.dataset,
                             DataSetMode.TRAIN,
                             images=input_image_files,
                             texts=gt_txt_files,
                             skip_invalid=not args.no_skip_invalid_gt)
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        if not args.validation_text_files:
            val_txt_files = [
                split_all_ext(f)[0] + args.validation_extension
                for f in validation_image_files
            ]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(
                validation_image_files, val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(
                        os.path.basename(gt))[0]:
                    raise Exception(
                        "Expected identical basenames of validation file: {} and {}"
                        .format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception(
                "Some validation images are occurring more than once in the data set."
            )

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(
            len(validation_dataset)))
    else:
        validation_dataset = None

    print("Resuming training")
    # Keep the params file open only while it is parsed: training may run
    # for a long time and does not need the file handle.
    with open(args.checkpoint + '.json', 'r') as f:
        checkpoint_params = json_format.Parse(f.read(), CheckpointParams())

    trainer = Trainer(checkpoint_params,
                      dataset,
                      validation_dataset=validation_dataset,
                      weights=args.checkpoint)
    trainer.train(progress_bar=True)
示例#4
0
def main():
    """Predict the given files with one or more checkpoints and evaluate them.

    Parses the command line, runs a ``MultiPredictor`` built from all
    matching checkpoints over the input files, and evaluates the voted
    sentences against the ground truth. With ``--output_individual_voters``
    each voter's predictions are evaluated as well. The collected results
    are printed when ``--verbose`` is set and pickled to the ``--dump`` path
    when one is given.

    Raises:
        Exception: If the number of predicted sentences differs from the
            number of preloaded ground-truth entries.
    """
    parser = argparse.ArgumentParser()

    # GENERAL/SHARED PARAMETERS
    parser.add_argument('--version',
                        action='version',
                        version='%(prog)s v' + __version__)

    parser.add_argument("--files",
                        nargs="+",
                        required=True,
                        default=[],
                        help="List all image files that shall be processed")
    parser.add_argument(
        "--text_files",
        nargs="+",
        default=None,
        help=
        "Optional list of additional text files. E.g. when updating Abbyy prediction, this parameter must be used for the xml files."
    )
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--gt_extension",
                        type=str,
                        default=None,
                        help="Define the gt extension.")
    parser.add_argument("-j",
                        "--processes",
                        type=int,
                        default=1,
                        help="Number of processes to use")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help=
        "The batch size during the prediction (number of lines to process in parallel)"
    )
    parser.add_argument("--verbose",
                        action="store_true",
                        help="Print additional information")
    parser.add_argument("--no_progress_bars",
                        action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument("--dump",
                        type=str,
                        help="Dump the output as serialized pickle object")
    parser.add_argument(
        "--no_skip_invalid_gt",
        action="store_true",
        help="Do no skip invalid gt, instead raise an exception.")
    # dataset extra args
    parser.add_argument("--dataset_pad", default=None, nargs='+', type=int)
    # type=int so a value given on the command line is an int like the
    # default, instead of arriving as a string.
    parser.add_argument("--pagexml_text_index", type=int, default=1)

    # PREDICT PARAMETERS
    parser.add_argument("--checkpoint",
                        type=str,
                        nargs="+",
                        required=True,
                        help="Path to the checkpoint without file extension")

    # EVAL PARAMETERS
    parser.add_argument("--output_individual_voters",
                        action='store_true',
                        default=False)
    parser.add_argument(
        "--n_confusions",
        type=int,
        default=10,
        help=
        "Only print n most common confusions. Defaults to 10, use -1 for all.")

    args = parser.parse_args()

    # allow user to specify json file for model definition, but remove the file extension
    # for further processing
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json")
                       for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    # Strip the trailing ".json" (5 characters) again.
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]
    # load files
    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    pipeline_params = PipelineParams(
        type=args.dataset,
        skip_invalid=not args.no_skip_invalid_gt,
        remove_invalid=True,
        files=args.files,
        gt_extension=args.gt_extension,
        text_files=args.text_files,
        data_reader_args=FileDataReaderArgs(
            pad=args.dataset_pad,
            text_index=args.pagexml_text_index,
        ),
        batch_size=args.batch_size,
        num_processes=args.processes,
    )

    from calamari_ocr.ocr.predict.predictor import MultiPredictor
    voter_params = VoterParams()
    predictor = MultiPredictor.from_paths(checkpoints=args.checkpoint,
                                          voter_params=voter_params,
                                          predictor_params=PredictorParams(
                                              silent=True, progress_bar=True))
    do_prediction = predictor.predict(pipeline_params)

    all_voter_sentences = []
    # Maps the voter index to the list of sentences predicted by that voter.
    all_prediction_sentences = {}

    for sample in do_prediction:
        # Only the second element of the outputs pair is used here.
        _, prediction = sample.outputs
        if prediction.voter_predictions is not None and args.output_individual_voters:
            for i, p in enumerate(prediction.voter_predictions):
                all_prediction_sentences.setdefault(i, []).append(p.sentence)
        all_voter_sentences.append(prediction.sentence)

    # evaluation
    from calamari_ocr.ocr.evaluator import Evaluator
    evaluator = Evaluator(predictor.data)
    evaluator.preload_gt(gt_dataset=pipeline_params, progress_bar=True)

    def single_evaluation(label, predicted_sentences):
        """Evaluate one list of predictions against the preloaded GT and print a report."""
        if len(predicted_sentences) != len(evaluator.preloaded_gt):
            raise Exception(
                "Mismatch in number of gt and pred files: {} != {}. Probably, the prediction did "
                "not succeed".format(len(evaluator.preloaded_gt),
                                     len(predicted_sentences)))

        r = evaluator.evaluate(gt_data=evaluator.preloaded_gt,
                               pred_data=predicted_sentences,
                               progress_bar=True,
                               processes=args.processes)

        print("=================")
        print(f"Evaluation result of {label}")
        print("=================")
        print("")
        print(
            "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)"
            .format(r["avg_ler"], r["total_char_errs"], r["total_chars"],
                    r["total_sync_errs"]))
        print()
        print()

        # Limit the printed confusions to the requested number.
        print_confusions(r, args.n_confusions)

        return r

    # Individual voters first (keyed by their index as string), then the
    # voted result.
    labelled_sentences = [
        (str(i), sent) for i, sent in all_prediction_sentences.items()
    ]
    labelled_sentences.append(('voted', all_voter_sentences))

    full_evaluation = {}
    for label, sentences in labelled_sentences:
        full_evaluation[label] = {
            "eval": single_evaluation(label, sentences),
            "data": sentences
        }

    if args.verbose:
        print(full_evaluation)

    if args.dump:
        import pickle
        with open(args.dump, 'wb') as f:
            pickle.dump({
                "full": full_evaluation,
                "gt": evaluator.preloaded_gt
            }, f)
示例#5
0
def main():
    """Resume training from an existing checkpoint, driven by CLI arguments.

    Parses the command line, resolves the training (and optional validation)
    image / ground-truth file pairs, builds the datasets, reads the
    checkpoint parameters from ``<checkpoint>.json`` and continues training
    with the checkpoint passed as initial weights.

    Raises:
        Exception: If an image and its ground-truth file do not share the
            same base name, or if a ground-truth file occurs more than once
            (for either the training or the validation set).
    """
    parser = ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="The checkpoint used to resume")

    # validation files
    parser.add_argument("--validation", type=str, nargs="+",
                        help="Validation line files used for early stopping")
    parser.add_argument("--validation_text_files", nargs="+", default=None,
                        help="Optional list of validation GT files if they are in other directory")
    parser.add_argument("--validation_extension", default=None,
                        help="Default extension of the gt files (expected to exist in same dir)")
    parser.add_argument("--validation_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)

    # input files
    parser.add_argument("--files", nargs="+",
                        help="List all image files that shall be processed. Ground truth fils with the same "
                             "base name but with '.gt.txt' as extension are required at the same location")
    parser.add_argument("--text_files", nargs="+", default=None,
                        help="Optional list of GT files if they are in other directory")
    parser.add_argument("--gt_extension", default=None,
                        help="Default extension of the gt files (expected to exist in same dir)")
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # Fall back to the extension associated with the chosen dataset type.
    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        # Ground truth is expected next to each image, differing only in the extension.
        gt_txt_files = [split_all_ext(f)[0] + args.gt_extension for f in input_image_files]
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        # Verify that the two lists are aligned pair by pair.
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some image are occurring more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not args.no_skip_invalid_gt
    )
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        if not args.validation_text_files:
            val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(validation_image_files, val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                    raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images are occurring more than once in the data set.")

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    print("Resuming training")
    # Keep the params file open only while it is parsed: training may run
    # for a long time and does not need the file handle.
    with open(args.checkpoint + '.json', 'r') as f:
        checkpoint_params = json_format.Parse(f.read(), CheckpointParams())

    trainer = Trainer(checkpoint_params, dataset,
                      validation_dataset=validation_dataset,
                      weights=args.checkpoint)
    trainer.train(progress_bar=True)