def __init__(self):
    self.dataset = DataSetType.FILE
    self.gt_extension = DataSetType.gt_extension(self.dataset)
    self.files = glob_all(
        [os.path.join(this_dir, "data", "uw3_50lines", "train", "*.png")])
    self.seed = 24
    self.backend = "tensorflow"
    self.network = "cnn=40:3x3,pool=2x2,cnn=60:3x3,pool=2x2,lstm=200,dropout=0.5"
    self.line_height = 48
    self.pad = 16
    self.num_threads = 1
    self.display = 1
    self.batch_size = 1
    self.checkpoint_frequency = 1000
    self.epochs = 1
    self.samples_per_epoch = 8
    self.stats_size = 100
    self.no_skip_invalid_gt = False
    self.no_progress_bars = True
    self.output_dir = None
    self.output_model_prefix = "uw3_50lines"
    self.bidi_dir = None
    self.weights = None
    self.ema_weights = False
    self.whitelist_files = []
    self.whitelist = []
    self.gradient_clipping_norm = 5
    self.validation = None
    self.validation_dataset = DataSetType.FILE
    self.validation_extension = None
    self.validation_split_ratio = None
    self.early_stopping_frequency = -1
    self.early_stopping_nbest = 10
    self.early_stopping_at_accuracy = 0.99
    self.early_stopping_best_model_prefix = "uw3_50lines_best"
    self.early_stopping_best_model_output_dir = self.output_dir
    self.n_augmentations = 0
    self.num_inter_threads = 0
    self.num_intra_threads = 0
    self.text_regularization = ["extended"]
    self.text_normalization = "NFC"
    self.text_generator_params = None
    self.line_generator_params = None
    self.pagexml_text_index = 0
    self.text_files = None
    self.only_train_on_augmented = False
    self.data_preprocessing = [p.name for p in default_image_processors()]
    self.shuffle_buffer_size = 1000
    self.keep_loaded_codec = False
    self.train_data_on_the_fly = False
    self.validation_data_on_the_fly = False
    self.no_auto_compute_codec = False
    self.dataset_pad = 0
    self.debug = False
    self.train_verbose = True
    self.use_train_as_val = False
    self.ensemble = -1
    self.masking_mode = 1
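The attributes above mirror the CLI flags defined in setup_train_args further below. A minimal, hedged sketch of overriding a few of them for a quick smoke test; the class name TrainArgs is assumed (only its __init__ is shown here):

# Hedged sketch: 'TrainArgs' is a hypothetical name for the class whose
# __init__ is shown above; only the attribute names come from that code.
args = TrainArgs()
args.epochs = 5                          # train a few epochs instead of one
args.batch_size = 4                      # larger batches if memory allows
args.output_dir = "uw3_50lines_model"    # hypothetical output location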
@classmethod
def default_params(cls) -> DataParams:
    params: DataParams = super(Data, cls).default_params()
    params.pre_proc = SequentialProcessorPipelineParams(
        run_parallel=True,
        processors=default_image_processors() +
        default_text_pre_processors() + [
            AugmentationProcessorParams(modes={PipelineMode.TRAINING}),
            PrepareSampleProcessorParams(modes=INPUT_PROCESSOR),
        ],
    )
    params.post_proc = SequentialProcessorPipelineParams(
        run_parallel=True,
        processors=[
            ReshapeOutputsProcessorParams(),
            CTCDecoderProcessorParams(),
        ] + default_text_pre_processors())
    return params
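A hedged usage sketch for the defaults above; constructing Data directly from its params follows the usage in main() further below, and run_parallel is the flag passed to the pipeline params in this factory:

# Hedged sketch: fetch the defaults and disable parallel pre-processing,
# e.g. to simplify debugging (assumes run_parallel is a plain attribute).
params = Data.default_params()
params.pre_proc.run_parallel = False
data = Data(params)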
@classmethod
def get_default_params(cls) -> DataParams:
    params: DataParams = super(Data, cls).get_default_params()
    params.pre_processors_ = SamplePipelineParams(
        run_parallel=True,
        sample_processors=default_image_processors() +
        default_text_pre_processors() + [
            DataProcessorFactoryParams(AugmentationProcessor.__name__,
                                       {PipelineMode.Training}),
            DataProcessorFactoryParams(PrepareSampleProcessor.__name__,
                                       INPUT_PROCESSOR),
        ],
    )
    params.post_processors_ = SamplePipelineParams(
        run_parallel=True,
        sample_processors=[
            DataProcessorFactoryParams(ReshapeOutputsProcessor.__name__),
            DataProcessorFactoryParams(CTCDecoderProcessor.__name__),
        ] + default_text_pre_processors())
    return params
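The same idea for the legacy parameter API shown above; a hedged sketch that appends one extra text processor to the default pre-processing chain, modeled on the BidiTextProcessor registration used in main() below:

# Hedged sketch: extend the legacy sample-processor list with a bidi step.
params = Data.get_default_params()
params.pre_processors_.sample_processors.append(
    DataProcessorFactoryParams(BidiTextProcessor.__name__,
                               TARGETS_PROCESSOR,
                               {'bidi_direction': 'auto'}))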
this_dir = os.path.dirname(os.path.realpath(__file__))
base_path = os.path.abspath(
    os.path.join(this_dir, "..", "..", "test", "data", "uw3_50lines", "train"))

fdr = FileDataParams(
    num_processes=8,
    images=[os.path.join(base_path, "*.png")],
    limit=1000,
)

params = DataParams(
    codec=Codec("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,:;-?+=_()*{}[]`@#$%^&'\""),
    downscale_factor=4,
    line_height=48,
    pre_proc=SequentialProcessorPipelineParams(
        run_parallel=True,
        processors=default_image_processors() +
        default_text_pre_processors() + [
            AugmentationProcessorParams(
                modes={PipelineMode.TRAINING},
                data_aug_params=DataAugmentationAmount(amount=2),
            ),
            PrepareSampleProcessorParams(modes=INPUT_PROCESSOR),
        ],
    ),
    post_proc=SequentialProcessorPipelineParams(run_parallel=False),
    train=fdr,
    val=fdr,
    input_channels=1,
)
params = DataParams.from_json(params.to_json())
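The trailing line above round-trips the params through JSON as a serialization check. A hedged sketch building on that: the same calls can persist the params to disk and restore them later (the file name is arbitrary):

# Hedged sketch: to_json() / from_json() are the calls used in the
# round-trip above; writing to a file is just ordinary file I/O.
with open("data_params.json", "w") as f:
    f.write(params.to_json())
with open("data_params.json") as f:
    restored = DataParams.from_json(f.read())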
def setup_train_args(parser, omit=None):
    # required params for args
    from calamari_ocr.ocr.dataset.data import Data

    if omit is None:
        omit = []

    parser.add_argument('--version', action='version', version='%(prog)s v' + __version__)

    if "files" not in omit:
        parser.add_argument(
            "--files",
            nargs="+",
            default=[],
            help="List all image files that shall be processed. Ground truth files with the same "
            "base name but with '.gt.txt' as extension are required at the same location")
        parser.add_argument(
            "--text_files",
            nargs="+",
            default=None,
            help="Optional list of GT files if they are in another directory")
        parser.add_argument(
            "--gt_extension",
            default=None,
            help="Default extension of the gt files (expected to exist in the same dir)")
        parser.add_argument("--dataset",
                            type=DataSetType.from_string,
                            choices=list(DataSetType),
                            default=DataSetType.FILE)

    parser.add_argument(
        "--train_data_on_the_fly",
        action='store_true',
        default=False,
        help='Instead of preloading all data during the training, load the data on the fly. '
        'This is slower, but might be required for limited RAM or large datasets')

    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Seed for random operations. If negative or zero a 'random' seed is used")
    parser.add_argument(
        "--network",
        type=str,
        default="cnn=40:3x3,pool=2x2,cnn=60:3x3,pool=2x2,lstm=200,dropout=0.5",
        help="The network structure")
    parser.add_argument("--line_height",
                        type=int,
                        default=48,
                        help="The line height")
    parser.add_argument("--pad",
                        type=int,
                        default=16,
                        help="Padding (left and right) of the line")
    parser.add_argument("--num_threads",
                        type=int,
                        default=1,
                        help="The number of threads to use for all operations")
    parser.add_argument(
        "--display",
        type=int,
        default=1,
        help="Frequency of how often an output shall occur during training [epochs]")
    parser.add_argument("--batch_size",
                        type=int,
                        default=1,
                        help="The batch size to use for training")
    parser.add_argument("--ema_weights",
                        action="store_true",
                        default=False,
                        help="Use exponentially averaged weights")
    parser.add_argument(
        "--checkpoint_frequency",
        type=int,
        default=-1,
        help="The frequency how often to write checkpoints during training [epochs]. "
        "If -1 (default), the early_stopping_frequency will be used. If 0, no checkpoints are written")
    parser.add_argument(
        "--epochs",
        type=int,
        default=100,
        help="The number of iterations for training. "
        "If using early stopping, this is the maximum number of iterations")
    parser.add_argument(
        "--samples_per_epoch",
        type=float,
        default=-1,
        help="The number of samples to process per epoch. By default the size of the training dataset. "
        "If in (0, 1) it is relative to the dataset size")
    parser.add_argument(
        "--early_stopping_at_accuracy",
        type=float,
        default=1.0,
        help="Stop training if the early stopping accuracy reaches this value")
    parser.add_argument(
        "--no_skip_invalid_gt",
        action="store_true",
        help="Do not skip invalid gt; instead raise an exception.")
    parser.add_argument("--no_progress_bars",
                        action="store_true",
                        help="Do not show any progress bars")

    if "output_dir" not in omit:
        parser.add_argument(
            "--output_dir",
            type=str,
            default="",
            help="Default directory where to store checkpoints and models")
    if "output_model_prefix" not in omit:
        parser.add_argument("--output_model_prefix",
                            type=str,
                            default="model_",
                            help="Prefix for storing checkpoints and models")

    parser.add_argument(
        "--bidi_dir",
        type=str,
        default=None,
        choices=["ltr", "rtl", "auto"],
        help="The default text direction when preprocessing bidirectional text. Supported values "
        "are 'auto' to automatically detect the direction, 'ltr' and 'rtl' for left-to-right and "
        "right-to-left, respectively")

    if "weights" not in omit:
        parser.add_argument("--weights",
                            type=str,
                            default=None,
                            help="Load network weights from the given file.")

    parser.add_argument(
        "--no_auto_compute_codec",
        action='store_true',
        default=False,
        help="Do not compute the codec automatically. See also whitelist")
    parser.add_argument(
        "--whitelist_files",
        type=str,
        nargs="+",
        default=[],
        help="Whitelist of txt files that may not be removed on restoring a model")
    parser.add_argument(
        "--whitelist",
        type=str,
        nargs="+",
        default=[],
        help="Whitelist of characters that may not be removed on restoring a model. "
        "For large datasets you can use this to skip the automatic codec computation "
        "(see --no_auto_compute_codec)")
    parser.add_argument(
        "--keep_loaded_codec",
        action='store_true',
        default=False,
        help="Fully include the codec of the loaded model to the new codec")
    parser.add_argument("--ensemble",
                        type=int,
                        default=-1,
                        help="Number of voting models")
    parser.add_argument("--masking_mode",
                        type=int,
                        default=0,
                        help="Do not use")  # TODO: remove

    # clipping
    parser.add_argument("--gradient_clipping_norm",
                        type=float,
                        default=5,
                        help="Clipping constant of the norm of the gradients.")

    # early stopping
    if "validation" not in omit:
        parser.add_argument(
            "--validation",
            type=str,
            nargs="+",
            help="Validation line files used for early stopping")
        parser.add_argument(
            "--validation_text_files",
            nargs="+",
            default=None,
            help="Optional list of validation GT files if they are in another directory")
        parser.add_argument(
            "--validation_extension",
            default=None,
            help="Default extension of the gt files (expected to exist in the same dir). "
            "By default the same as gt_extension")
        parser.add_argument(
            "--validation_dataset",
            type=DataSetType.from_string,
            choices=list(DataSetType),
            default=None,
            help="Default validation data set type. By default the same as --dataset")
        parser.add_argument("--use_train_as_val",
                            action='store_true',
                            default=False)
        parser.add_argument(
            "--validation_split_ratio",
            type=float,
            default=None,
            help="Use n percent of the training dataset for validation. Cannot be used with --validation")

    parser.add_argument(
        "--validation_data_on_the_fly",
        action='store_true',
        default=False,
        help='Instead of preloading all data during the training, load the data on the fly. '
        'This is slower, but might be required for limited RAM or large datasets')

    parser.add_argument("--early_stopping_frequency",
                        type=int,
                        default=1,
                        help="The frequency of early stopping [epochs].")
    parser.add_argument(
        "--early_stopping_nbest",
        type=int,
        default=5,
        help="The number of models that must be worse than the current best model to stop")
    if "early_stopping_best_model_prefix" not in omit:
        parser.add_argument(
            "--early_stopping_best_model_prefix",
            type=str,
            default="best",
            help="The prefix of the best model using early stopping")
    if "early_stopping_best_model_output_dir" not in omit:
        parser.add_argument(
            "--early_stopping_best_model_output_dir",
            type=str,
            default=None,
            help="Path where to store the best model. Default is output_dir")

    parser.add_argument(
        "--n_augmentations",
        type=float,
        default=0,
        help="Amount of data augmentation per line (done before training). If this number is < 1 "
        "the amount is relative.")
    parser.add_argument(
        "--only_train_on_augmented",
        action="store_true",
        default=False,
        help="When training with augmentations, the model is usually retrained in a second run with "
        "only the non-augmented data. This will take longer. Use this flag to disable this "
        "behavior.")

    # text normalization/regularization
    parser.add_argument("--text_regularization",
                        type=str,
                        nargs="+",
                        default=["extended"],
                        help="Text regularization to apply.")
    parser.add_argument(
        "--text_normalization",
        type=str,
        default="NFC",
        help="Unicode text normalization to apply. Defaults to NFC")
    parser.add_argument(
        "--data_preprocessing",
        nargs="+",
        type=str,
        choices=[
            k for k, p in Data.data_processor_factory().processors.items()
            if issubclass(p, ImageProcessor)
        ],
        default=[p.name for p in default_image_processors()])

    # text/line generation params (loaded from json files)
    parser.add_argument("--text_generator_params", type=str, default=None)
    parser.add_argument("--line_generator_params", type=str, default=None)

    # additional dataset args
    parser.add_argument("--dataset_pad", default=None, nargs='+', type=int)
    parser.add_argument("--pagexml_text_index", default=0)

    parser.add_argument("--debug", action='store_true')
    parser.add_argument("--train_verbose", default=1)
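A hedged sketch of wiring setup_train_args into a parser; the flag values are arbitrary examples and only the function signature above is taken from the source:

# Hedged sketch: build a parser from the definitions above and parse a few
# example flags; omit=["files"] skips the --files group as in the function.
parser = argparse.ArgumentParser()
setup_train_args(parser, omit=["files"])
args = parser.parse_args(["--epochs", "10", "--batch_size", "4", "--seed", "1"])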
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--version', action='version', version='%(prog)s v' + __version__)
    parser.add_argument(
        "--files",
        nargs="+",
        help="List all image files that shall be processed. Ground truth files with the same "
        "base name but with '.gt.txt' as extension are required at the same location",
        required=True)
    parser.add_argument(
        "--text_files",
        nargs="+",
        default=None,
        help="Optional list of GT files if they are in another directory")
    parser.add_argument(
        "--gt_extension",
        default=None,
        help="Default extension of the gt files (expected to exist in the same dir)")
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--line_height",
                        type=int,
                        default=48,
                        help="The line height")
    parser.add_argument("--pad",
                        type=int,
                        default=16,
                        help="Padding (left and right) of the line")
    parser.add_argument("--processes",
                        type=int,
                        default=1,
                        help="The number of threads to use for all operations")
    parser.add_argument("--n_cols", type=int, default=1)
    parser.add_argument("--n_rows", type=int, default=5)
    parser.add_argument("--select", type=int, nargs="+", default=[])

    # text normalization/regularization
    parser.add_argument(
        "--n_augmentations",
        type=float,
        default=0,
        help="Amount of data augmentation per line (done before training). If this number is < 1 "
        "the amount is relative.")
    parser.add_argument("--text_regularization",
                        type=str,
                        nargs="+",
                        default=["extended"],
                        help="Text regularization to apply.")
    parser.add_argument(
        "--text_normalization",
        type=str,
        default="NFC",
        help="Unicode text normalization to apply. Defaults to NFC")
    parser.add_argument(
        "--data_preprocessing",
        nargs="+",
        type=str,
        choices=[
            k for k, p in Data.data_processor_factory().processors.items()
            if issubclass(p, ImageProcessor)
        ],
        default=[p.name for p in default_image_processors()])
    parser.add_argument(
        "--bidi_dir",
        type=str,
        default=None,
        choices=["ltr", "rtl", "auto"],
        help="The default text direction when preprocessing bidirectional text. Supported values "
        "are 'auto' to automatically detect the direction, 'ltr' and 'rtl' for left-to-right and "
        "right-to-left, respectively")
    parser.add_argument("--preload",
                        action='store_true',
                        help='Simulate preloading')
    parser.add_argument("--as_validation",
                        action='store_true',
                        help="Access as validation instead of training data.")

    args = parser.parse_args()

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    dataset_args = FileDataReaderArgs(pad=args.pad)

    data_params: DataParams = Data.get_default_params()
    data_params.train = PipelineParams(
        type=args.dataset,
        remove_invalid=True,
        files=args.files,
        text_files=args.text_files,
        gt_extension=args.gt_extension,
        data_reader_args=dataset_args,
        batch_size=1,
        num_processes=args.processes,
    )
    data_params.val = data_params.train

    data_params.pre_processors_ = SamplePipelineParams(run_parallel=True)
    data_params.post_processors_ = SamplePipelineParams(run_parallel=True)
    for p in args.data_preprocessing:
        p_p = Data.data_processor_factory().processors[p].default_params()
        if 'pad' in p_p:
            p_p['pad'] = args.pad
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(p, INPUT_PROCESSOR, p_p))

    # Text pre processing (reading)
    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(
            TextNormalizer.__name__, TARGETS_PROCESSOR,
            {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(
            TextRegularizer.__name__, TARGETS_PROCESSOR, {
                'replacements':
                default_text_regularizer_replacements(args.text_regularization)
            }),
        DataProcessorFactoryParams(StripTextProcessor.__name__,
                                   TARGETS_PROCESSOR)
    ])

    # Text post processing (prediction)
    data_params.post_processors_.sample_processors.extend([
        DataProcessorFactoryParams(
            TextNormalizer.__name__, TARGETS_PROCESSOR,
            {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(
            TextRegularizer.__name__, TARGETS_PROCESSOR, {
                'replacements':
                default_text_regularizer_replacements(args.text_regularization)
            }),
        DataProcessorFactoryParams(StripTextProcessor.__name__,
                                   TARGETS_PROCESSOR)
    ])

    if args.bidi_dir:
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__,
                                       TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))
        data_params.post_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__,
                                       TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))

    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(AugmentationProcessor.__name__,
                                   {PipelineMode.Training},
                                   {'augmenter_type': 'simple'}),
        # DataProcessorFactoryParams(PrepareSampleProcessor.__name__),  # NOT THIS, since we want to access the raw input
    ])

    data_params.data_aug_params = DataAugmentationAmount.from_factor(
        args.n_augmentations)
    data_params.line_height_ = args.line_height

    data = Data(data_params)
    data_pipeline = data.get_val_data() if args.as_validation else data.get_train_data()
    if not args.preload:
        reader: DataReader = data_pipeline.reader()
        if len(args.select) == 0:
            args.select = range(len(reader))
        else:
            reader._samples = [reader.samples()[i] for i in args.select]
    else:
        data.preload()
        data_pipeline = data.get_val_data() if args.as_validation else data.get_train_data()
        samples = data_pipeline.samples
        if len(args.select) == 0:
            args.select = range(len(samples))
        else:
            data_pipeline.samples = [samples[i] for i in args.select]

    f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
    row, col = 0, 0
    with data_pipeline as dataset:
        for i, (id, sample) in enumerate(
                zip(args.select, dataset.generate_input_samples(auto_repeat=False))):
            line, text, params = sample
            if args.n_cols == 1:
                ax[row].imshow(line.transpose())
                ax[row].set_title("ID: {}\n{}".format(id, text))
            else:
                ax[row, col].imshow(line.transpose())
                ax[row, col].set_title("ID: {}\n{}".format(id, text))

            row += 1
            if row == args.n_rows:
                row = 0
                col += 1

            if col == args.n_cols or i == len(dataset) - 1:
                plt.show()
                f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
                row, col = 0, 0
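The excerpt does not show the module entry point; the conventional guard would look like this:

# Assumed entry point (not part of the excerpt above): run the viewer
# when the module is executed directly.
if __name__ == "__main__":
    main()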