def __init__(self):
    """Populate default training arguments for the uw3_50lines sample data.

    NOTE(review): this looks like an argparse-Namespace stand-in used by a
    test/benchmark harness for the OCR training entry point — confirm
    against the callers before relying on individual defaults.
    """
    # --- input data ---
    self.dataset = DataSetType.FILE
    # ground-truth file extension derived from the chosen dataset type
    self.gt_extension = DataSetType.gt_extension(self.dataset)
    self.files = glob_all([os.path.join(this_dir, "data", "uw3_50lines", "train", "*.png")])

    # --- model / training setup ---
    self.seed = 24
    self.backend = "tensorflow"
    # network spec string: two conv+pool stages, one LSTM layer, dropout
    self.network = "cnn=40:3x3,pool=2x2,cnn=60:3x3,pool=2x2,lstm=200,dropout=0.5"
    self.line_height = 48
    self.pad = 16
    self.num_threads = 1
    self.display = 1
    self.batch_size = 1
    self.checkpoint_frequency = 1000
    self.epochs = 1
    self.samples_per_epoch = 8
    self.stats_size = 100
    self.no_skip_invalid_gt = False
    self.no_progress_bars = True

    # --- output / checkpointing ---
    self.output_dir = None
    self.output_model_prefix = "uw3_50lines"
    self.bidi_dir = None
    self.weights = None  # no pretrained weights to warm-start from
    self.ema_weights = False
    self.whitelist_files = []
    self.whitelist = []
    self.gradient_clipping_norm = 5

    # --- validation / early stopping ---
    self.validation = None
    self.validation_dataset = DataSetType.FILE
    self.validation_extension = None
    self.validation_split_ratio = None
    self.early_stopping_frequency = -1
    self.early_stopping_nbest = 10
    self.early_stopping_at_accuracy = 0.99
    self.early_stopping_best_model_prefix = "uw3_50lines_best"
    # mirrors output_dir (None at this point, since it is assigned above)
    self.early_stopping_best_model_output_dir = self.output_dir

    # --- augmentation / threading ---
    self.n_augmentations = 0
    self.num_inter_threads = 0
    self.num_intra_threads = 0

    # --- text processing ---
    self.text_regularization = ["extended"]
    self.text_normalization = "NFC"  # Unicode normalization form applied to GT text
    self.text_generator_params = None
    self.line_generator_params = None
    self.pagexml_text_index = 0
    self.text_files = None
    self.only_train_on_augmented = False

    # --- data pipeline ---
    # default image preprocessor names, resolved from the project helper
    self.data_preprocessing = [p.name for p in default_image_processors()]
    self.shuffle_buffer_size = 1000
    self.keep_loaded_codec = False
    self.train_data_on_the_fly = False
    self.validation_data_on_the_fly = False
    self.no_auto_compute_codec = False
    self.dataset_pad = 0

    # --- misc flags ---
    self.debug = False
    self.train_verbose = True
    self.use_train_as_val = False
    self.ensemble = -1  # presumably "no ensembling" sentinel — TODO confirm
    self.masking_mode = 1
def __init__(self):
    """Populate default training arguments for the uw3_50lines sample data.

    NOTE(review): appears to be an older variant of the fixture above
    (iteration-based `max_iters` instead of epochs; `fuzzy_ctc` and
    mode/const gradient-clipping fields) — confirm which API version of
    the trainer consumes it.
    """
    # --- input data ---
    self.dataset = DataSetType.FILE
    # ground-truth file extension derived from the chosen dataset type
    self.gt_extension = DataSetType.gt_extension(self.dataset)
    self.files = glob_all([os.path.join(this_dir, "data", "uw3_50lines", "train", "*.png")])

    # --- model / training setup ---
    self.seed = 24
    self.backend = "tensorflow"
    # network spec string: two conv+pool stages, one LSTM layer, dropout
    self.network = "cnn=40:3x3,pool=2x2,cnn=60:3x3,pool=2x2,lstm=200,dropout=0.5"
    self.line_height = 48
    self.pad = 16
    self.num_threads = 1
    self.display = 1
    self.batch_size = 1
    self.checkpoint_frequency = 1000
    self.max_iters = 1000
    self.stats_size = 100
    self.no_skip_invalid_gt = False
    self.no_progress_bars = True

    # --- output / checkpointing ---
    self.output_dir = os.path.join(this_dir, "test_models")
    self.output_model_prefix = "uw3_50lines"
    self.bidi_dir = None
    self.weights = None  # no pretrained weights to warm-start from
    self.whitelist_files = []
    self.whitelist = []
    self.gradient_clipping_mode = "AUTO"
    self.gradient_clipping_const = 0

    # --- validation / early stopping ---
    self.validation = None
    self.validation_dataset = DataSetType.FILE
    self.validation_extension = None
    self.early_stopping_frequency = -1
    self.early_stopping_nbest = 10
    self.early_stopping_best_model_prefix = "uw3_50lines_best"
    # mirrors output_dir as assigned above
    self.early_stopping_best_model_output_dir = self.output_dir

    # --- augmentation / threading ---
    self.n_augmentations = 0
    self.fuzzy_ctc_library_path = ""  # empty: fuzzy-CTC custom op not used
    self.num_inter_threads = 0
    self.num_intra_threads = 0

    # --- text processing ---
    self.text_regularization = ["extended"]
    self.text_normalization = "NFC"  # Unicode normalization form applied to GT text
    self.text_generator_params = None
    self.line_generator_params = None
    self.pagexml_text_index = 0
    self.text_files = None
    self.only_train_on_augmented = False

    # --- data pipeline ---
    self.data_preprocessing = [DataPreprocessorParams.DEFAULT_NORMALIZER]
    self.shuffle_buffer_size = 1000
    self.keep_loaded_codec = False
    self.train_data_on_the_fly = False
    self.validation_data_on_the_fly = False
    self.no_auto_compute_codec = False
def main():
    """Resume training of an OCR model from a saved checkpoint.

    Resolves training (and optional validation) image/GT file pairs,
    validates that the pairing is consistent, reloads the checkpoint
    parameters from the ``<checkpoint>.json`` sidecar file, and restarts
    the Trainer with the checkpoint weights.
    """
    parser = ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True, help="The checkpoint used to resume")

    # validation files
    parser.add_argument("--validation", type=str, nargs="+", help="Validation line files used for early stopping")
    parser.add_argument("--validation_text_files", nargs="+", default=None,
                        help="Optional list of validation GT files if they are in other directory")
    parser.add_argument("--validation_extension", default=None,
                        help="Default extension of the gt files (expected to exist in same dir)")
    parser.add_argument("--validation_dataset", type=DataSetType.from_string, choices=list(DataSetType),
                        default=DataSetType.FILE)

    # input files
    parser.add_argument("--files", nargs="+",
                        help="List all image files that shall be processed. Ground truth fils with the same "
                             "base name but with '.gt.txt' as extension are required at the same location")
    parser.add_argument("--text_files", nargs="+", default=None,
                        help="Optional list of GT files if they are in other directory")
    parser.add_argument("--gt_extension", default=None,
                        help="Default extension of the gt files (expected to exist in same dir)")
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # fall back to the dataset-type default GT extensions when not given explicitly
    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        # GT lives next to the images: swap the image extension for the GT extension
        gt_txt_files = [split_all_ext(f)[0] + args.gt_extension for f in input_image_files]
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some image are occurring more than once in the data set.")

    dataset = create_dataset(args.dataset,
                             DataSetMode.TRAIN,
                             images=input_image_files,
                             texts=gt_txt_files,
                             skip_invalid=not args.no_skip_invalid_gt)
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        # FIX: sort the glob result (glob order is filesystem-dependent) so the
        # pairing with the sorted validation text files below stays aligned,
        # matching how the training files are resolved above.
        validation_image_files = sorted(glob_all(args.validation))
        if not args.validation_text_files:
            val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(
                validation_image_files, val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                    raise Exception(
                        "Expected identical basenames of validation file: {} and {}".format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images are occurring more than once in the data set.")

        # NOTE(review): validation data is created in TRAIN mode here — confirm
        # this is intentional for the early-stopping pipeline.
        validation_dataset = create_dataset(args.validation_dataset,
                                            DataSetMode.TRAIN,
                                            images=validation_image_files,
                                            texts=val_txt_files,
                                            skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    print("Resuming training")
    # the checkpoint parameters are stored as a JSON sidecar next to the weights
    with open(args.checkpoint + '.json', 'r') as f:
        checkpoint_params = json_format.Parse(f.read(), CheckpointParams())

    trainer = Trainer(checkpoint_params, dataset,
                      validation_dataset=validation_dataset,
                      weights=args.checkpoint)
    trainer.train(progress_bar=True)
def main():
    """Predict with one or more checkpoints and evaluate against preloaded GT.

    Runs the (multi-)predictor over the input files, optionally collects the
    per-voter sentences, computes label-error-rate statistics per model and
    for the voted result, and optionally pickles the full evaluation.
    """
    parser = argparse.ArgumentParser()

    # GENERAL/SHARED PARAMETERS
    parser.add_argument('--version', action='version', version='%(prog)s v' + __version__)
    parser.add_argument("--files", nargs="+", required=True, default=[],
                        help="List all image files that shall be processed")
    parser.add_argument("--text_files", nargs="+", default=None,
                        help="Optional list of additional text files. E.g. when updating Abbyy prediction, "
                             "this parameter must be used for the xml files.")
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--gt_extension", type=str, default=None,
                        help="Define the gt extension.")
    parser.add_argument("-j", "--processes", type=int, default=1,
                        help="Number of processes to use")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="The batch size during the prediction (number of lines to process in parallel)")
    parser.add_argument("--verbose", action="store_true",
                        help="Print additional information")
    parser.add_argument("--no_progress_bars", action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument("--dump", type=str,
                        help="Dump the output as serialized pickle object")
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do no skip invalid gt, instead raise an exception.")

    # dataset extra args
    parser.add_argument("--dataset_pad", default=None, nargs='+', type=int)
    # FIX: add type=int — without it a command-line value is parsed as str while
    # the default is int, and the value is forwarded as FileDataReaderArgs.text_index.
    parser.add_argument("--pagexml_text_index", default=1, type=int)

    # PREDICT PARAMETERS
    parser.add_argument("--checkpoint", type=str, nargs="+", required=True,
                        help="Path to the checkpoint without file extension")

    # EVAL PARAMETERS
    parser.add_argument("--output_individual_voters", action='store_true', default=False)
    parser.add_argument("--n_confusions", type=int, default=10,
                        help="Only print n most common confusions. Defaults to 10, use -1 for all.")

    args = parser.parse_args()

    # allow user to specify json file for model definition, but remove the file extension
    # for further processing
    args.checkpoint = [(cp if cp.endswith(".json") else cp + ".json") for cp in args.checkpoint]
    args.checkpoint = glob_all(args.checkpoint)
    args.checkpoint = [cp[:-5] for cp in args.checkpoint]  # strip the ".json" suffix again

    # load files
    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    pipeline_params = PipelineParams(
        type=args.dataset,
        skip_invalid=not args.no_skip_invalid_gt,
        remove_invalid=True,
        files=args.files,
        gt_extension=args.gt_extension,
        text_files=args.text_files,
        data_reader_args=FileDataReaderArgs(
            pad=args.dataset_pad,
            text_index=args.pagexml_text_index,
        ),
        batch_size=args.batch_size,
        num_processes=args.processes,
    )

    from calamari_ocr.ocr.predict.predictor import MultiPredictor
    voter_params = VoterParams()
    predictor = MultiPredictor.from_paths(checkpoints=args.checkpoint,
                                          voter_params=voter_params,
                                          predictor_params=PredictorParams(silent=True, progress_bar=True))
    do_prediction = predictor.predict(pipeline_params)

    all_voter_sentences = []
    all_prediction_sentences = {}  # voter index -> list of sentences

    for s in do_prediction:
        # s.outputs is a (result, prediction) pair; only the prediction is used here
        inputs, (result, prediction), meta = s.inputs, s.outputs, s.meta
        sentence = prediction.sentence
        if prediction.voter_predictions is not None and args.output_individual_voters:
            for i, p in enumerate(prediction.voter_predictions):
                if i not in all_prediction_sentences:
                    all_prediction_sentences[i] = []
                all_prediction_sentences[i].append(p.sentence)
        all_voter_sentences.append(sentence)

    # evaluation
    from calamari_ocr.ocr.evaluator import Evaluator
    evaluator = Evaluator(predictor.data)
    evaluator.preload_gt(gt_dataset=pipeline_params, progress_bar=True)

    def single_evaluation(label, predicted_sentences):
        # One evaluation run (per voter or for the voted result); prints a
        # summary and the most common confusions, returns the raw metrics dict.
        if len(predicted_sentences) != len(evaluator.preloaded_gt):
            raise Exception("Mismatch in number of gt and pred files: {} != {}. Probably, the prediction did "
                            "not succeed".format(len(evaluator.preloaded_gt), len(predicted_sentences)))

        r = evaluator.evaluate(gt_data=evaluator.preloaded_gt,
                               pred_data=predicted_sentences,
                               progress_bar=True,
                               processes=args.processes)

        print("=================")
        print(f"Evaluation result of {label}")
        print("=================")
        print("")
        print("Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)"
              .format(r["avg_ler"], r["total_char_errs"], r["total_chars"], r["total_sync_errs"]))
        print()
        print()

        # sort descending
        print_confusions(r, args.n_confusions)

        return r

    full_evaluation = {}
    # FIX: renamed loop variable `id` -> `eval_id` (shadowed the builtin `id`)
    for eval_id, data in ([(str(i), sent) for i, sent in all_prediction_sentences.items()]
                          + [('voted', all_voter_sentences)]):
        full_evaluation[eval_id] = {"eval": single_evaluation(eval_id, data), "data": data}

    if args.verbose:
        print(full_evaluation)

    if args.dump:
        import pickle
        with open(args.dump, 'wb') as f:
            pickle.dump({"full": full_evaluation, "gt": evaluator.preloaded_gt}, f)
def main():
    """Resume training of an OCR model from a saved checkpoint.

    Resolves training (and optional validation) image/GT file pairs,
    validates the pairing, reloads the checkpoint parameters from the
    ``<checkpoint>.json`` sidecar file, and restarts the Trainer with the
    checkpoint weights.
    """
    parser = ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True, help="The checkpoint used to resume")

    # validation files
    parser.add_argument("--validation", type=str, nargs="+", help="Validation line files used for early stopping")
    parser.add_argument("--validation_text_files", nargs="+", default=None,
                        help="Optional list of validation GT files if they are in other directory")
    parser.add_argument("--validation_extension", default=None,
                        help="Default extension of the gt files (expected to exist in same dir)")
    parser.add_argument("--validation_dataset", type=DataSetType.from_string, choices=list(DataSetType),
                        default=DataSetType.FILE)

    # input files
    parser.add_argument("--files", nargs="+",
                        help="List all image files that shall be processed. Ground truth fils with the same "
                             "base name but with '.gt.txt' as extension are required at the same location")
    parser.add_argument("--text_files", nargs="+", default=None,
                        help="Optional list of GT files if they are in other directory")
    parser.add_argument("--gt_extension", default=None,
                        help="Default extension of the gt files (expected to exist in same dir)")
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do no skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # fall back to the dataset-type default GT extensions when not given explicitly
    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        # GT lives next to the images: swap the image extension for the GT extension
        gt_txt_files = [split_all_ext(f)[0] + args.gt_extension for f in input_image_files]
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some image are occurring more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not args.no_skip_invalid_gt
    )
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        # FIX: sort the glob result (glob order is filesystem-dependent) so the
        # pairing with the sorted validation text files below stays aligned,
        # matching the training-file resolution above.
        validation_image_files = sorted(glob_all(args.validation))
        if not args.validation_text_files:
            val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(
                validation_image_files, val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                    raise Exception(
                        "Expected identical basenames of validation file: {} and {}".format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images are occurring more than once in the data set.")

        # NOTE(review): validation data is created in TRAIN mode here — confirm
        # this is intentional for the early-stopping pipeline.
        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    print("Resuming training")
    # the checkpoint parameters are stored as a JSON sidecar next to the weights
    with open(args.checkpoint + '.json', 'r') as f:
        checkpoint_params = json_format.Parse(f.read(), CheckpointParams())

    trainer = Trainer(checkpoint_params, dataset,
                      validation_dataset=validation_dataset,
                      weights=args.checkpoint)
    trainer.train(progress_bar=True)