def default_image_processors() -> List[DataProcessorFactoryParams]:
    return [
        DataProcessorFactoryParams(DataRangeNormalizer.__name__, INPUT_PROCESSOR),
        DataProcessorFactoryParams(CenterNormalizer.__name__, INPUT_PROCESSOR),
        DataProcessorFactoryParams(FinalPreparation.__name__, INPUT_PROCESSOR),
    ]
def default_text_pre_processors() -> List[DataProcessorFactoryParams]:
    return [
        DataProcessorFactoryParams(BidiTextProcessor.__name__),
        DataProcessorFactoryParams(StripTextProcessor.__name__),
        DataProcessorFactoryParams(TextNormalizer.__name__),
        DataProcessorFactoryParams(TextRegularizer.__name__),
    ]
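# Reference sketch (hedged): throughout this file DataProcessorFactoryParams is
# constructed positionally as (processor_name, modes, params_dict); later calls
# such as DataProcessorFactoryParams(p, INPUT_PROCESSOR, p_p) follow the same
# pattern. `_example_processor` is a hypothetical name used only to illustrate.
_example_processor = DataProcessorFactoryParams(
    TextNormalizer.__name__,           # processors are registered by class name
    TARGETS_PROCESSOR,                 # pipeline modes the processor applies to
    {'unicode_normalization': 'NFC'},  # processor-specific parameters
)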
@classmethod  # method of Data; the enclosing class is elided in this excerpt
def get_default_params(cls) -> DataParams:
    params: DataParams = super(Data, cls).get_default_params()
    params.pre_processors_ = SamplePipelineParams(
        run_parallel=True,
        sample_processors=default_image_processors() + default_text_pre_processors() + [
            DataProcessorFactoryParams(AugmentationProcessor.__name__, {PipelineMode.Training}),
            DataProcessorFactoryParams(PrepareSampleProcessor.__name__, INPUT_PROCESSOR),
        ],
    )
    params.post_processors_ = SamplePipelineParams(
        run_parallel=True,
        sample_processors=[
            DataProcessorFactoryParams(ReshapeOutputsProcessor.__name__),
            DataProcessorFactoryParams(CTCDecoderProcessor.__name__),
        ] + default_text_pre_processors(),
    )
    return params
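# Minimal usage sketch: `_default_params_demo` is a hypothetical helper showing
# how the defaults above are typically consumed (see run() and main() below).
def _default_params_demo() -> 'Data':
    params = Data.get_default_params()
    params.line_height_ = 48  # same field that run()/main() set from args
    return Data(params)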
def run(args):
    # local imports (to prevent tensorflow from being imported too early)
    from calamari_ocr.ocr.scenario import Scenario
    from calamari_ocr.ocr.dataset.data import Data

    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                if key == 'dataset' or key == 'validation_dataset':
                    setattr(args, key, DataSetType.from_string(value))
                else:
                    setattr(args, key, value)

    check_train_args(args)

    if args.output_dir is not None:
        args.output_dir = os.path.abspath(args.output_dir)
        setup_log(args.output_dir, append=False)

    # parse whitelist
    whitelist = args.whitelist
    if len(whitelist) == 1:
        whitelist = list(whitelist[0])

    whitelist_files = glob_all(args.whitelist_files)
    for f in whitelist_files:
        with open(f) as txt:
            whitelist += list(txt.read())

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    if args.dataset == DataSetType.GENERATED_LINE or args.validation_dataset == DataSetType.GENERATED_LINE:
        if args.text_generator_params is not None:
            with open(args.text_generator_params, 'r') as f:
                args.text_generator_params = TextGeneratorParams.from_json(f.read())
        else:
            args.text_generator_params = TextGeneratorParams()

        if args.line_generator_params is not None:
            with open(args.line_generator_params, 'r') as f:
                args.line_generator_params = LineGeneratorParams.from_json(f.read())
        else:
            args.line_generator_params = LineGeneratorParams()

    dataset_args = FileDataReaderArgs(
        line_generator_params=args.line_generator_params,
        text_generator_params=args.text_generator_params,
        pad=args.dataset_pad,
        text_index=args.pagexml_text_index,
    )

    params: TrainerParams = Scenario.default_trainer_params()

    # =================================================================================================================
    # Data Params
    data_params: DataParams = params.scenario_params.data_params
    data_params.train = PipelineParams(
        type=args.dataset,
        skip_invalid=not args.no_skip_invalid_gt,
        remove_invalid=True,
        files=args.files,
        text_files=args.text_files,
        gt_extension=args.gt_extension,
        data_reader_args=dataset_args,
        batch_size=args.batch_size,
        num_processes=args.num_threads,
    )

    if args.validation_split_ratio:
        if args.validation is not None:
            raise ValueError("Set either validation_split_ratio or validation")
        if not 0 < args.validation_split_ratio < 1:
            raise ValueError("validation_split_ratio must be in (0, 1)")

        # resolve all files so we can split them
        data_params.train = data_params.train.prepare_for_mode(PipelineMode.Training)
        n = int(args.validation_split_ratio * len(data_params.train.files))
        logger.info(f"Splitting training and validation files with ratio {args.validation_split_ratio}: "
                    f"{n}/{len(data_params.train.files) - n} for validation/training.")
        indices = list(range(len(data_params.train.files)))
        shuffle(indices)
        all_files = data_params.train.files
        all_text_files = data_params.train.text_files

        # split train and val img/gt files. Use train settings
        data_params.train.files = [all_files[i] for i in indices[:n]]
        if all_text_files is not None:
            assert len(all_text_files) == len(all_files)
            data_params.train.text_files = [all_text_files[i] for i in indices[:n]]
        data_params.val = PipelineParams(
            type=args.dataset,
            skip_invalid=not args.no_skip_invalid_gt,
            remove_invalid=True,
            files=[all_files[i] for i in indices[n:]],
            text_files=[all_text_files[i] for i in indices[n:]] if data_params.train.text_files is not None else None,
            gt_extension=args.gt_extension,
            data_reader_args=dataset_args,
            batch_size=args.batch_size,
            num_processes=args.num_threads,
        )
    elif args.validation:
        data_params.val = PipelineParams(
            type=args.validation_dataset,
            files=args.validation,
            text_files=args.validation_text_files,
            skip_invalid=not args.no_skip_invalid_gt,
            gt_extension=args.validation_extension,
            data_reader_args=dataset_args,
            batch_size=args.batch_size,
            num_processes=args.num_threads,
        )
    else:
        data_params.val = None

    data_params.pre_processors_ = SamplePipelineParams(run_parallel=True)
    # note: the original assigned the SamplePipelineParams to `.run_parallel`,
    # which is a bool field; assigning to `post_processors_` itself is intended
    data_params.post_processors_ = SamplePipelineParams(
        run_parallel=False,
        sample_processors=[
            DataProcessorFactoryParams(ReshapeOutputsProcessor.__name__),
            DataProcessorFactoryParams(CTCDecoderProcessor.__name__),
        ],
    )
    for p in args.data_preprocessing:
        p_p = Data.data_processor_factory().processors[p].default_params()
        if 'pad' in p_p:
            p_p['pad'] = args.pad
        data_params.pre_processors_.sample_processors.append(DataProcessorFactoryParams(p, INPUT_PROCESSOR, p_p))

    # Text pre processing (reading)
    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(TextNormalizer.__name__, TARGETS_PROCESSOR,
                                   {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(TextRegularizer.__name__, TARGETS_PROCESSOR,
                                   {'replacements': default_text_regularizer_replacements(args.text_regularization)}),
        DataProcessorFactoryParams(StripTextProcessor.__name__, TARGETS_PROCESSOR),
    ])

    # Text post processing (prediction)
    data_params.post_processors_.sample_processors.extend([
        DataProcessorFactoryParams(TextNormalizer.__name__, TARGETS_PROCESSOR,
                                   {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(TextRegularizer.__name__, TARGETS_PROCESSOR,
                                   {'replacements': default_text_regularizer_replacements(args.text_regularization)}),
        DataProcessorFactoryParams(StripTextProcessor.__name__, TARGETS_PROCESSOR),
    ])

    if args.bidi_dir:
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__, TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))
        data_params.post_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__, TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))

    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(AugmentationProcessor.__name__, {PipelineMode.Training},
                                   {'augmenter_type': 'simple'}),
        DataProcessorFactoryParams(PrepareSampleProcessor.__name__, INPUT_PROCESSOR),
    ])

    data_params.data_aug_params = DataAugmentationAmount.from_factor(args.n_augmentations)
    data_params.line_height_ = args.line_height

    # =================================================================================================================
    # Trainer Params
    params.calc_ema = args.ema_weights
    params.verbose = args.train_verbose
    params.force_eager = args.debug
    params.skip_model_load_test = not args.debug
    params.scenario_params.debug_graph_construction = args.debug
    params.epochs = args.epochs
    params.samples_per_epoch = int(args.samples_per_epoch) if args.samples_per_epoch >= 1 else -1
    params.scale_epoch_size = abs(args.samples_per_epoch) if args.samples_per_epoch < 1 else 1
    params.skip_load_model_test = True
    params.scenario_params.export_frozen = False
    params.checkpoint_save_freq_ = args.checkpoint_frequency if args.checkpoint_frequency >= 0 else args.early_stopping_frequency
    params.checkpoint_dir = args.output_dir
    params.test_every_n = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.data_aug_retrain_on_original = not args.only_train_on_augmented
    params.use_training_as_validation = args.use_train_as_val
    if args.seed > 0:
        params.random_seed = args.seed

    params.optimizer_params.clip_grad = args.gradient_clipping_norm
    params.codec_whitelist = whitelist
    params.keep_loaded_codec = args.keep_loaded_codec
    params.preload_training = not args.train_data_on_the_fly
    params.preload_validation = not args.validation_data_on_the_fly
    params.warmstart_params.model = args.weights
    params.auto_compute_codec = not args.no_auto_compute_codec
    params.progress_bar = not args.no_progress_bars

    params.early_stopping_params.frequency = args.early_stopping_frequency
    params.early_stopping_params.upper_threshold = 0.9
    params.early_stopping_params.lower_threshold = 1.0 - args.early_stopping_at_accuracy
    params.early_stopping_params.n_to_go = args.early_stopping_nbest
    params.early_stopping_params.best_model_name = ''
    params.early_stopping_params.best_model_output_dir = args.early_stopping_best_model_output_dir
    params.scenario_params.default_serve_dir_ = f'{args.early_stopping_best_model_prefix}.ckpt.h5'
    params.scenario_params.trainer_params_filename_ = f'{args.early_stopping_best_model_prefix}.ckpt.json'

    # =================================================================================================================
    # Model params
    params_from_definition_string(args.network, params)
    params.scenario_params.model_params.ensemble = args.ensemble
    params.scenario_params.model_params.masking_mode = args.masking_mode

    scenario = Scenario(params.scenario_params)
    trainer = scenario.create_trainer(params)
    trainer.train()
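# Hedged invocation sketch: the calamari CLI builds the argparse namespace via
# `setup_train_args` (the parser-setup function from calamari's train script,
# assumed importable here) and then calls run(); illustrative only.
def _train_cli_demo():
    parser = argparse.ArgumentParser()
    setup_train_args(parser)  # registers --files, --validation, --network, ...
    run(parser.parse_args())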
# Demo / smoke-test block. The snippet begins mid-expression; the enclosing
# `fdr = PipelineParams(` constructor is reconstructed here (fdr is used as
# train/val below), and `base_path` is assumed to be defined earlier.
fdr = PipelineParams(
    files=[os.path.join(base_path, '*.png')],
    gt_extension=DataSetType.gt_extension(DataSetType.FILE),
    limit=1000,
)
params = DataParams(
    codec=Codec('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,:;-?+=_()*{}[]`@#$%^&\'"'),
    downscale_factor_=4,
    line_height_=48,
    pre_processors_=SamplePipelineParams(
        run_parallel=True,
        sample_processors=default_image_processors() + default_text_pre_processors() + [
            DataProcessorFactoryParams(AugmentationProcessor.__name__, {PipelineMode.Training}),
            DataProcessorFactoryParams(PrepareSampleProcessor.__name__),
        ],
    ),
    post_processors_=SamplePipelineParams(run_parallel=False),
    data_aug_params=DataAugmentationAmount(amount=2),
    train=fdr,
    val=fdr,
    input_channels=1,
)
params = DataParams.from_json(params.to_json())  # JSON round-trip sanity check
print(params.to_json(indent=2))

data = Data(params)
pipeline: CalamariPipeline = data.get_train_data()
if False:  # disabled demo path (body not shown)
    pass
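# Hedged iteration sketch: `generate_input_samples` is the same generator the
# viewer below consumes; each sample unpacks to (line_image, text, meta).
with pipeline as dataset:
    for line, text, meta in dataset.generate_input_samples(auto_repeat=False):
        print(line.shape, text)
        break  # only inspect the first sample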
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--version', action='version', version='%(prog)s v' + __version__)
    parser.add_argument("--files", nargs="+", required=True,
                        help="List all image files that shall be processed. Ground truth files with the same "
                             "base name but with '.gt.txt' as extension are required at the same location")
    parser.add_argument("--text_files", nargs="+", default=None,
                        help="Optional list of GT files if they are in another directory")
    parser.add_argument("--gt_extension", default=None,
                        help="Default extension of the gt files (expected to exist in the same dir)")
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--line_height", type=int, default=48, help="The line height")
    parser.add_argument("--pad", type=int, default=16, help="Padding (left and right) of the line")
    parser.add_argument("--processes", type=int, default=1,
                        help="The number of threads to use for all operations")
    parser.add_argument("--n_cols", type=int, default=1)
    parser.add_argument("--n_rows", type=int, default=5)
    parser.add_argument("--select", type=int, nargs="+", default=[])

    # text normalization/regularization
    parser.add_argument("--n_augmentations", type=float, default=0,
                        help="Amount of data augmentation per line (done before training). If this number is < 1 "
                             "the amount is relative.")
    parser.add_argument("--text_regularization", type=str, nargs="+", default=["extended"],
                        help="Text regularization to apply.")
    parser.add_argument("--text_normalization", type=str, default="NFC",
                        help="Unicode text normalization to apply. Defaults to NFC")
    parser.add_argument("--data_preprocessing", nargs="+", type=str,
                        choices=[k for k, p in Data.data_processor_factory().processors.items()
                                 if issubclass(p, ImageProcessor)],
                        default=[p.name for p in default_image_processors()])
    parser.add_argument("--bidi_dir", type=str, default=None, choices=["ltr", "rtl", "auto"],
                        help="The default text direction when preprocessing bidirectional text. Supported values "
                             "are 'auto' to automatically detect the direction, and 'ltr' and 'rtl' for "
                             "left-to-right and right-to-left, respectively")
    parser.add_argument("--preload", action='store_true', help='Simulate preloading')
    parser.add_argument("--as_validation", action='store_true',
                        help="Access as validation instead of training data.")

    args = parser.parse_args()

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    dataset_args = FileDataReaderArgs(pad=args.pad)

    data_params: DataParams = Data.get_default_params()
    data_params.train = PipelineParams(
        type=args.dataset,
        remove_invalid=True,
        files=args.files,
        text_files=args.text_files,
        gt_extension=args.gt_extension,
        data_reader_args=dataset_args,
        batch_size=1,
        num_processes=args.processes,
    )
    data_params.val = data_params.train

    data_params.pre_processors_ = SamplePipelineParams(run_parallel=True)
    data_params.post_processors_ = SamplePipelineParams(run_parallel=True)
    for p in args.data_preprocessing:
        p_p = Data.data_processor_factory().processors[p].default_params()
        if 'pad' in p_p:
            p_p['pad'] = args.pad
        data_params.pre_processors_.sample_processors.append(DataProcessorFactoryParams(p, INPUT_PROCESSOR, p_p))

    # Text pre processing (reading)
    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(TextNormalizer.__name__, TARGETS_PROCESSOR,
                                   {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(TextRegularizer.__name__, TARGETS_PROCESSOR,
                                   {'replacements': default_text_regularizer_replacements(args.text_regularization)}),
        DataProcessorFactoryParams(StripTextProcessor.__name__, TARGETS_PROCESSOR),
    ])

    # Text post processing (prediction)
    data_params.post_processors_.sample_processors.extend([
        DataProcessorFactoryParams(TextNormalizer.__name__, TARGETS_PROCESSOR,
                                   {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(TextRegularizer.__name__, TARGETS_PROCESSOR,
                                   {'replacements': default_text_regularizer_replacements(args.text_regularization)}),
        DataProcessorFactoryParams(StripTextProcessor.__name__, TARGETS_PROCESSOR),
    ])

    if args.bidi_dir:
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__, TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))
        data_params.post_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__, TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))

    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(AugmentationProcessor.__name__, {PipelineMode.Training},
                                   {'augmenter_type': 'simple'}),
        # DataProcessorFactoryParams(PrepareSampleProcessor.__name__),  # omitted, since we want to access the raw input
    ])

    data_params.data_aug_params = DataAugmentationAmount.from_factor(args.n_augmentations)
    data_params.line_height_ = args.line_height

    data = Data(data_params)
    data_pipeline = data.get_val_data() if args.as_validation else data.get_train_data()
    if not args.preload:
        reader: DataReader = data_pipeline.reader()
        if len(args.select) == 0:
            args.select = range(len(reader))
        else:
            reader._samples = [reader.samples()[i] for i in args.select]
    else:
        data.preload()
        data_pipeline = data.get_val_data() if args.as_validation else data.get_train_data()
        samples = data_pipeline.samples
        if len(args.select) == 0:
            args.select = range(len(samples))
        else:
            data_pipeline.samples = [samples[i] for i in args.select]

    f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
    row, col = 0, 0
    with data_pipeline as dataset:
        for i, (id, sample) in enumerate(zip(args.select, dataset.generate_input_samples(auto_repeat=False))):
            line, text, params = sample
            if args.n_cols == 1:
                ax[row].imshow(line.transpose())
                ax[row].set_title("ID: {}\n{}".format(id, text))
            else:
                ax[row, col].imshow(line.transpose())
                ax[row, col].set_title("ID: {}\n{}".format(id, text))

            row += 1
            if row == args.n_rows:
                row = 0
                col += 1

            if col == args.n_cols or i == len(dataset) - 1:
                plt.show()
                f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
                row, col = 0, 0
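# Standard entry-point guard (assumed to close the original viewer script).
if __name__ == '__main__':
    main()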