def prepare_for_mode(self, mode: PipelineMode):
    logger.info("Resolving input files")
    input_image_files = sorted(glob_all(self.images))
    if not self.texts:
        gt_txt_files = [split_all_ext(f)[0] + self.gt_extension for f in input_image_files]
    else:
        gt_txt_files = sorted(glob_all(self.texts))

    if mode in INPUT_PROCESSOR:
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception(f"Expected identical basenames of file: {img} and {gt}")
    else:
        input_image_files = None

    if mode in {PipelineMode.TRAINING, PipelineMode.EVALUATION}:
        if len(set(gt_txt_files)) != len(gt_txt_files):
            logger.warning(
                "Some ground truth text files occur more than once in the data set "
                "(ignore this warning, if this was intended)."
            )
        if len(set(input_image_files)) != len(input_image_files):
            logger.warning(
                "Some images occur more than once in the data set. "
                "This warning should usually not be ignored."
            )

    self.images = input_image_files
    self.texts = gt_txt_files

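# Usage sketch (illustrative only, not part of the original file; field names as
# used above, paths hypothetical). It shows the pairing contract this method
# enforces: with `texts` empty, each image is paired with the ground truth file
# that shares its basename plus `gt_extension`.
#
# params.images = ["lines/0001.png", "lines/0002.png"]
# params.texts = []                 # empty: derive "<basename>.gt.txt" per image
# params.gt_extension = ".gt.txt"
# params.prepare_for_mode(PipelineMode.TRAINING)
# assert params.texts == ["lines/0001.gt.txt", "lines/0002.gt.txt"]
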
def create_train_dataset(args, dataset_args=None):
    gt_extension = args.gt_extension if args.gt_extension is not None else DataSetType.gt_extension(args.dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        if gt_extension:
            gt_txt_files = [split_all_ext(f)[0] + gt_extension for f in input_image_files]
        else:
            gt_txt_files = [None] * len(input_image_files)
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

        if len(set(gt_txt_files)) != len(gt_txt_files):
            raise Exception("Some ground truth text files occur more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not args.no_skip_invalid_gt,
        args=dataset_args if dataset_args else {},
    )
    print("Found {} files in the dataset".format(len(dataset)))
    return dataset

def create_dataset(
    type: DataSetType,
    mode: DataSetMode,
    images: List[str] = None,
    texts: List[str] = None,
    skip_invalid=False,
    remove_invalid=True,
    non_existing_as_empty=False,
    args: dict = None,
):
    if images is None:
        images = []
    if texts is None:
        texts = []
    if args is None:
        args = dict()

    if DataSetType.files(type):
        if images:
            images.sort()
        if texts:
            texts.sort()
        if images and texts:
            images, texts = keep_files_with_same_file_name(images, texts)

    if type == DataSetType.RAW:
        return RawDataSet(mode, images, texts)
    elif type == DataSetType.FILE:
        return FileDataSet(mode, images, texts,
                           skip_invalid=skip_invalid,
                           remove_invalid=remove_invalid,
                           non_existing_as_empty=non_existing_as_empty)
    elif type == DataSetType.ABBYY:
        return AbbyyDataSet(mode, images, texts,
                            skip_invalid=skip_invalid,
                            remove_invalid=remove_invalid,
                            non_existing_as_empty=non_existing_as_empty)
    elif type == DataSetType.PAGEXML:
        return PageXMLDataset(mode, images, texts,
                              skip_invalid=skip_invalid,
                              remove_invalid=remove_invalid,
                              non_existing_as_empty=non_existing_as_empty,
                              args=args)
    elif type == DataSetType.EXTENDED_PREDICTION:
        from .extended_prediction_dataset import ExtendedPredictionDataSet
        return ExtendedPredictionDataSet(texts=texts)
    else:
        raise Exception("Unsupported dataset type {}".format(type))

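# Usage sketch (illustrative, hypothetical paths): building a FILE dataset for
# training. The factory sorts and pairs images and texts itself, so the raw
# glob results can be passed directly.
#
# train_data = create_dataset(
#     DataSetType.FILE,
#     DataSetMode.TRAIN,
#     images=glob_all(["lines/*.png"]),
#     texts=glob_all(["lines/*.gt.txt"]),
#     skip_invalid=True,
# )
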
def prepare_for_mode(self, mode: PipelineMode) -> 'PipelineParams':
    from calamari_ocr.ocr.dataset.datareader.factory import DataReaderFactory
    assert self.type is not None
    params_out = deepcopy(self)

    # Training dataset
    logger.info("Resolving input files")
    if isinstance(self.type, str):
        try:
            self.type = DataSetType.from_string(self.type)
        except ValueError:
            # Not a valid type, must be custom
            if self.type not in DataReaderFactory.CUSTOM_READERS:
                raise KeyError(
                    f"DataSetType {self.type} is neither a standard DataSetType nor preset as a custom "
                    f"reader ({list(DataReaderFactory.CUSTOM_READERS.keys())})"
                )

    if not isinstance(self.type, str) and self.type not in {DataSetType.RAW, DataSetType.GENERATED_LINE}:
        input_image_files = sorted(glob_all(self.files)) if self.files else None
        if not self.text_files:
            if self.gt_extension:
                gt_txt_files = [split_all_ext(f)[0] + self.gt_extension for f in input_image_files]
            else:
                gt_txt_files = None
        else:
            gt_txt_files = sorted(glob_all(self.text_files))

        if mode in INPUT_PROCESSOR:
            input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
            for img, gt in zip(input_image_files, gt_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                    raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))
        else:
            input_image_files = None

        if mode in {PipelineMode.Training, PipelineMode.Evaluation}:
            if len(set(gt_txt_files)) != len(gt_txt_files):
                logger.warning(
                    "Some ground truth text files occur more than once in the data set "
                    "(ignore this warning, if this was intended)."
                )
            if len(set(input_image_files)) != len(input_image_files):
                logger.warning(
                    "Some images occur more than once in the data set. "
                    "This warning should usually not be ignored."
                )

        params_out.files = input_image_files
        params_out.text_files = gt_txt_files

    return params_out

def data_reader_from_params(mode: PipelineMode, params: PipelineParams) -> DataReader:
    assert params.type is not None
    from calamari_ocr.ocr.dataset.dataset_factory import create_data_reader

    # Training dataset
    logger.info("Resolving input files")
    if params.type not in {DataSetType.RAW, DataSetType.GENERATED_LINE}:
        input_image_files = sorted(glob_all(params.files)) if params.files else None
        if not params.text_files:
            if params.gt_extension:
                gt_txt_files = [split_all_ext(f)[0] + params.gt_extension for f in input_image_files]
            else:
                gt_txt_files = None
        else:
            gt_txt_files = sorted(glob_all(params.text_files))

        if mode in INPUT_PROCESSOR:
            input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
            for img, gt in zip(input_image_files, gt_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                    raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))
        else:
            input_image_files = None

        if mode in {PipelineMode.Training, PipelineMode.Evaluation}:
            if len(set(gt_txt_files)) != len(gt_txt_files):
                logger.warning(
                    "Some ground truth text files occur more than once in the data set "
                    "(ignore this warning, if this was intended)."
                )
            if len(set(input_image_files)) != len(input_image_files):
                logger.warning(
                    "Some images occur more than once in the data set. "
                    "This warning should usually not be ignored."
                )
    else:
        input_image_files = params.files
        gt_txt_files = params.text_files

    dataset = create_data_reader(
        params.type,
        mode,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=params.skip_invalid,
        args=params.data_reader_args if params.data_reader_args else FileDataReaderArgs(),
    )
    logger.info(f"Found {len(dataset)} files in the dataset")
    return dataset

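# Usage sketch (illustrative; the PipelineParams field values are hypothetical,
# and PipelineMode.Prediction is assumed to exist alongside the Training and
# Evaluation members used above). For prediction no ground truth is resolved,
# so text_files can stay unset.
#
# params = PipelineParams(type=DataSetType.FILE, files=["lines/*.png"])
# reader = data_reader_from_params(PipelineMode.Prediction, params)
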
def create_test_dataset(
    cfg: CfgNode, dataset_args=None
) -> Union[List[Union[RawDataSet, FileDataSet, AbbyyDataSet, PageXMLDataset, Hdf5DataSet,
                      ExtendedPredictionDataSet, GeneratedLineDataset]], None]:
    if cfg.DATASET.VALID.TEXT_FILES:
        assert len(cfg.DATASET.VALID.PATH) == len(cfg.DATASET.VALID.TEXT_FILES)

    if cfg.DATASET.VALID.PATH:
        validation_dataset_list = []
        print("Resolving validation files")
        for i, valid_path in enumerate(cfg.DATASET.VALID.PATH):
            validation_image_files = glob_all(valid_path)
            dataregistry.register(i, os.path.basename(os.path.dirname(valid_path)), len(validation_image_files))
            if not cfg.DATASET.VALID.TEXT_FILES:
                val_txt_files = [split_all_ext(f)[0] + cfg.DATASET.VALID.GT_EXTENSION for f in validation_image_files]
            else:
                val_txt_files = sorted(glob_all(cfg.DATASET.VALID.TEXT_FILES[i]))

            validation_image_files, val_txt_files = keep_files_with_same_file_name(validation_image_files, val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                    raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))

            if len(set(val_txt_files)) != len(val_txt_files):
                raise Exception("Some validation images occur more than once in the data set.")

            validation_dataset = create_dataset(
                cfg.DATASET.VALID.TYPE,
                DataSetMode.TRAIN,
                images=validation_image_files,
                texts=val_txt_files,
                skip_invalid=not cfg.DATALOADER.NO_SKIP_INVALID_GT,
                args=dataset_args,
            )
            print("Found {} files in the validation dataset".format(len(validation_dataset)))
            validation_dataset_list.append(validation_dataset)
    else:
        validation_dataset_list = None
    return validation_dataset_list

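# Config sketch (hypothetical values, assuming the yacs-style CfgNode used
# above): PATH and TEXT_FILES are parallel lists with one entry per validation
# set, which is exactly what the leading assert checks.
#
# cfg.DATASET.VALID.TYPE = DataSetType.FILE
# cfg.DATASET.VALID.PATH = ["books/a/val/*.png", "books/b/val/*.png"]
# cfg.DATASET.VALID.TEXT_FILES = ["books/a/val/*.gt.txt", "books/b/val/*.gt.txt"]
# cfg.DATASET.VALID.GT_EXTENSION = ".gt.txt"
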
def create_dataset(
    type: DataSetType,
    mode: DataSetMode,
    images=None,
    texts=None,
    skip_invalid=False,
    remove_invalid=True,
    non_existing_as_empty=False,
    args=None,
):
    # Use None defaults instead of mutable defaults (list()/dict() defaults are
    # shared across calls and a common source of bugs).
    if images is None:
        images = []
    if texts is None:
        texts = []
    if args is None:
        args = dict()

    if DataSetType.files(type):
        if images:
            images.sort()
        if texts:
            texts.sort()
        if images and texts:
            images, texts = keep_files_with_same_file_name(images, texts)

    if type == DataSetType.RAW:
        return RawDataSet(mode, images, texts)
    elif type == DataSetType.FILE:
        return FileDataSet(mode, images, texts,
                           skip_invalid=skip_invalid,
                           remove_invalid=remove_invalid,
                           non_existing_as_empty=non_existing_as_empty)
    elif type == DataSetType.ABBYY:
        return AbbyyDataSet(mode, images, texts,
                            skip_invalid=skip_invalid,
                            remove_invalid=remove_invalid,
                            non_existing_as_empty=non_existing_as_empty)
    elif type == DataSetType.PAGEXML:
        return PageXMLDataset(mode, images, texts,
                              skip_invalid=skip_invalid,
                              remove_invalid=remove_invalid,
                              non_existing_as_empty=non_existing_as_empty,
                              args=args)
    else:
        raise Exception("Unsupported dataset type {}".format(type))

def create_dataset(
    type: DataSetType,
    mode: DataSetMode,
    images=None,
    texts=None,
    skip_invalid=False,
    remove_invalid=True,
    non_existing_as_empty=False,
    args=None,
):
    # Use None defaults instead of mutable defaults (list()/dict() defaults are
    # shared across calls and a common source of bugs).
    if images is None:
        images = []
    if texts is None:
        texts = []
    if args is None:
        args = dict()

    if DataSetType.files(type):
        if images:
            images.sort()
        if texts:
            texts.sort()
        if images and texts:
            images, texts = keep_files_with_same_file_name(images, texts)

    if type == DataSetType.RAW:
        return RawDataSet(mode, images, texts)
    elif type == DataSetType.FILE:
        return FileDataSet(mode, images, texts,
                           skip_invalid=skip_invalid,
                           remove_invalid=remove_invalid,
                           non_existing_as_empty=non_existing_as_empty)
    elif type == DataSetType.ABBYY:
        return AbbyyDataSet(mode, images, texts,
                            skip_invalid=skip_invalid,
                            remove_invalid=remove_invalid,
                            non_existing_as_empty=non_existing_as_empty)
    elif type == DataSetType.PAGEXML:
        return PageXMLDataset(mode, images, texts,
                              skip_invalid=skip_invalid,
                              remove_invalid=remove_invalid,
                              non_existing_as_empty=non_existing_as_empty,
                              args=args)
    elif type == DataSetType.EXTENDED_PREDICTION:
        from .extended_prediction_dataset import ExtendedPredictionDataSet
        return ExtendedPredictionDataSet(texts=texts)
    else:
        raise Exception("Unsupported dataset type {}".format(type))

def create_train_dataset(cfg: CfgNode, dataset_args=None):
    gt_extension = cfg.DATASET.TRAIN.GT_EXTENSION if cfg.DATASET.TRAIN.GT_EXTENSION is not False \
        else DataSetType.gt_extension(cfg.DATASET.TRAIN.TYPE)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(cfg.DATASET.TRAIN.PATH))
    if not cfg.DATASET.TRAIN.TEXT_FILES:
        if gt_extension:
            gt_txt_files = [split_all_ext(f)[0] + gt_extension for f in input_image_files]
        else:
            gt_txt_files = [None] * len(input_image_files)
    else:
        gt_txt_files = sorted(glob_all(cfg.DATASET.TRAIN.TEXT_FILES))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

        if len(set(gt_txt_files)) != len(gt_txt_files):
            raise Exception("Some ground truth text files occur more than once in the data set.")

    dataset = create_dataset(
        cfg.DATASET.TRAIN.TYPE,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not cfg.DATALOADER.NO_SKIP_INVALID_GT,
        args=dataset_args if dataset_args else {},
    )
    print("Found {} files in the dataset".format(len(dataset)))
    return dataset

def create_data_reader(
    cls,
    type: Union[DataSetType, str],
    mode: PipelineMode,
    images: List[str] = None,
    texts: List[str] = None,
    skip_invalid=False,
    remove_invalid=True,
    non_existing_as_empty=False,
    args: FileDataReaderArgs = None,
) -> DataReader:
    if images is None:
        images = []
    if texts is None:
        texts = []
    if args is None:
        args = dict()

    if type in cls.CUSTOM_READERS:
        return cls.CUSTOM_READERS[type](
            mode=mode,
            images=images,
            texts=texts,
            skip_invalid=skip_invalid,
            remove_invalid=remove_invalid,
            non_existing_as_empty=non_existing_as_empty,
            args=args,
        )

    if DataSetType.files(type):
        if images:
            images.sort()
        if texts:
            texts.sort()
        if images and texts:
            images, texts = keep_files_with_same_file_name(images, texts)

    if type == DataSetType.RAW:
        from calamari_ocr.ocr.dataset.datareader.raw import RawDataReader
        return RawDataReader(mode, images, texts)
    elif type == DataSetType.FILE:
        from calamari_ocr.ocr.dataset.datareader.file import FileDataReader
        return FileDataReader(mode, images, texts,
                              skip_invalid=skip_invalid,
                              remove_invalid=remove_invalid,
                              non_existing_as_empty=non_existing_as_empty)
    elif type == DataSetType.ABBYY:
        from calamari_ocr.ocr.dataset.datareader.abbyy import AbbyyReader
        return AbbyyReader(mode, images, texts,
                           skip_invalid=skip_invalid,
                           remove_invalid=remove_invalid,
                           non_existing_as_empty=non_existing_as_empty)
    elif type == DataSetType.PAGEXML:
        from calamari_ocr.ocr.dataset.datareader.pagexml.reader import PageXMLReader
        return PageXMLReader(mode, images, texts,
                             skip_invalid=skip_invalid,
                             remove_invalid=remove_invalid,
                             non_existing_as_empty=non_existing_as_empty,
                             args=args)
    elif type == DataSetType.HDF5:
        from calamari_ocr.ocr.dataset.datareader.hdf5 import Hdf5Reader
        return Hdf5Reader(mode, images, texts)
    elif type == DataSetType.EXTENDED_PREDICTION:
        from calamari_ocr.ocr.dataset.extended_prediction_dataset import ExtendedPredictionDataSet
        return ExtendedPredictionDataSet(texts=texts)
    elif type == DataSetType.GENERATED_LINE:
        from calamari_ocr.ocr.dataset.datareader.generated_line_dataset.dataset import GeneratedLineDataset
        return GeneratedLineDataset(mode, args=args)
    else:
        raise Exception("Unsupported dataset type {}".format(type))

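# Sketch of the CUSTOM_READERS extension point (hypothetical reader class and
# registry key; it assumes create_data_reader is a classmethod of the factory
# class referenced as DataReaderFactory elsewhere in this code, as the cls
# parameter suggests). A registered reader only has to accept the keyword
# arguments forwarded above.
#
# class MyLineReader(DataReader):
#     def __init__(self, mode, images, texts, skip_invalid, remove_invalid,
#                  non_existing_as_empty, args):
#         ...
#
# DataReaderFactory.CUSTOM_READERS["my_lines"] = MyLineReader
# reader = DataReaderFactory.create_data_reader("my_lines", PipelineMode.Training)
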
def main():
    parser = ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="The checkpoint used to resume")

    # validation files
    parser.add_argument("--validation", type=str, nargs="+",
                        help="Validation line files used for early stopping")
    parser.add_argument("--validation_text_files", nargs="+", default=None,
                        help="Optional list of validation GT files if they are in other directory")
    parser.add_argument("--validation_extension", default=None,
                        help="Default extension of the gt files (expected to exist in same dir)")
    parser.add_argument("--validation_dataset", type=DataSetType.from_string,
                        choices=list(DataSetType), default=DataSetType.FILE)

    # input files
    parser.add_argument("--files", nargs="+",
                        help="List all image files that shall be processed. Ground truth files with the same "
                             "base name but with '.gt.txt' as extension are required at the same location")
    parser.add_argument("--text_files", nargs="+", default=None,
                        help="Optional list of GT files if they are in other directory")
    parser.add_argument("--gt_extension", default=None,
                        help="Default extension of the gt files (expected to exist in same dir)")
    parser.add_argument("--dataset", type=DataSetType.from_string,
                        choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do not skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)
    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        gt_txt_files = [split_all_ext(f)[0] + args.gt_extension for f in input_image_files]
    else:
        gt_txt_files = sorted(glob_all(args.text_files))

    input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
    for img, gt in zip(input_image_files, gt_txt_files):
        if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
            raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some ground truth text files occur more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not args.no_skip_invalid_gt,
    )
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        if not args.validation_text_files:
            val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))

        validation_image_files, val_txt_files = keep_files_with_same_file_name(validation_image_files, val_txt_files)
        for img, gt in zip(validation_image_files, val_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images occur more than once in the data set.")

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt,
        )
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    print("Resuming training")
    with open(args.checkpoint + '.json', 'r') as f:
        checkpoint_params = json_format.Parse(f.read(), CheckpointParams())

    trainer = Trainer(checkpoint_params, dataset,
                      validation_dataset=validation_dataset,
                      weights=args.checkpoint)
    trainer.train(progress_bar=True)

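# Example invocation (hypothetical paths and script name): resume from a
# checkpoint while validating on a separate line set. The script reads the
# CheckpointParams from "<checkpoint>.json", so that file must sit next to the
# checkpoint.
#
#   python resume_training.py \
#       --checkpoint model_00100000.ckpt \
#       --files "train/*.png" \
#       --validation "val/*.png"
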
def run(args):
    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                if key == 'dataset' or key == 'validation_dataset':
                    setattr(args, key, DataSetType.from_string(value))
                else:
                    setattr(args, key, value)

    # parse whitelist
    whitelist = args.whitelist
    if len(whitelist) == 1:
        whitelist = list(whitelist[0])

    whitelist_files = glob_all(args.whitelist_files)
    for f in whitelist_files:
        with open(f) as txt:
            whitelist += list(txt.read())

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    if args.text_generator_params is not None:
        with open(args.text_generator_params, 'r') as f:
            args.text_generator_params = json_format.Parse(f.read(), TextGeneratorParameters())
    else:
        args.text_generator_params = TextGeneratorParameters()

    if args.line_generator_params is not None:
        with open(args.line_generator_params, 'r') as f:
            args.line_generator_params = json_format.Parse(f.read(), LineGeneratorParameters())
    else:
        args.line_generator_params = LineGeneratorParameters()

    dataset_args = {
        'line_generator_params': args.line_generator_params,
        'text_generator_params': args.text_generator_params,
        'pad': args.dataset_pad,
        'text_index': args.pagexml_text_index,
    }

    # Training dataset
    dataset = create_train_dataset(args, dataset_args)

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)

        if not args.validation_text_files:
            val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))

        validation_image_files, val_txt_files = keep_files_with_same_file_name(validation_image_files, val_txt_files)
        for img, gt in zip(validation_image_files, val_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images occur more than once in the data set.")

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt,
            args=dataset_args,
        )
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    params = CheckpointParams()

    params.max_iters = args.max_iters
    params.stats_size = args.stats_size
    params.batch_size = args.batch_size
    params.checkpoint_frequency = args.checkpoint_frequency if args.checkpoint_frequency >= 0 else args.early_stopping_frequency
    params.output_dir = args.output_dir
    params.output_model_prefix = args.output_model_prefix
    params.display = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.processes = args.num_threads
    params.data_aug_retrain_on_original = not args.only_train_on_augmented

    params.early_stopping_frequency = args.early_stopping_frequency
    params.early_stopping_nbest = args.early_stopping_nbest
    params.early_stopping_best_model_prefix = args.early_stopping_best_model_prefix
    params.early_stopping_best_model_output_dir = \
        args.early_stopping_best_model_output_dir if args.early_stopping_best_model_output_dir else args.output_dir

    if args.data_preprocessing is None or len(args.data_preprocessing) == 0:
        args.data_preprocessing = [DataPreprocessorParams.DEFAULT_NORMALIZER]

    params.model.data_preprocessor.type = DataPreprocessorParams.MULTI_NORMALIZER
    for preproc in args.data_preprocessing:
        pp = params.model.data_preprocessor.children.add()
        pp.type = DataPreprocessorParams.Type.Value(preproc) if isinstance(preproc, str) else preproc
        pp.line_height = args.line_height
        pp.pad = args.pad

    # Text pre processing (reading)
    params.model.text_preprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_preprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_preprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_preprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    # Text post processing (prediction)
    params.model.text_postprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_postprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_postprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_postprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    if args.seed > 0:
        params.model.network.backend.random_seed = args.seed

    if args.bidi_dir:
        # change bidirectional text direction if desired
        bidi_dir_to_enum = {"rtl": TextProcessorParams.BIDI_RTL,
                            "ltr": TextProcessorParams.BIDI_LTR,
                            "auto": TextProcessorParams.BIDI_AUTO}

        bidi_processor_params = params.model.text_preprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = bidi_dir_to_enum[args.bidi_dir]

        bidi_processor_params = params.model.text_postprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = TextProcessorParams.BIDI_AUTO

    params.model.line_height = args.line_height

    network_params_from_definition_string(args.network, params.model.network)
    params.model.network.clipping_mode = NetworkParams.ClippingMode.Value("CLIP_" + args.gradient_clipping_mode.upper())
    params.model.network.clipping_constant = args.gradient_clipping_const
    params.model.network.backend.fuzzy_ctc_library_path = args.fuzzy_ctc_library_path
    params.model.network.backend.num_inter_threads = args.num_inter_threads
    params.model.network.backend.num_intra_threads = args.num_intra_threads
    params.model.network.backend.shuffle_buffer_size = args.shuffle_buffer_size

    # create the actual trainer
    trainer = Trainer(params, dataset,
                      validation_dataset=validation_dataset,
                      data_augmenter=SimpleDataAugmenter(),
                      n_augmentations=args.n_augmentations,
                      weights=args.weights,
                      codec_whitelist=whitelist,
                      keep_loaded_codec=args.keep_loaded_codec,
                      preload_training=not args.train_data_on_the_fly,
                      preload_validation=not args.validation_data_on_the_fly,
                      )
    trainer.train(
        auto_compute_codec=not args.no_auto_compute_codec,
        progress_bar=not args.no_progress_bars
    )

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--version', action='version', version='%(prog)s v' + __version__)
    parser.add_argument("--files", nargs="+", required=True,
                        help="List all image files that shall be processed. Ground truth files with the same "
                             "base name but with '.gt.txt' as extension are required at the same location")
    parser.add_argument("--text_files", nargs="+", default=None,
                        help="Optional list of GT files if they are in other directory")
    parser.add_argument("--gt_extension", default=None,
                        help="Default extension of the gt files (expected to exist in same dir)")
    parser.add_argument("--dataset", type=DataSetType.from_string,
                        choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--line_height", type=int, default=48,
                        help="The line height")
    parser.add_argument("--pad", type=int, default=16,
                        help="Padding (left right) of the line")
    parser.add_argument("--processes", type=int, default=1,
                        help="The number of threads to use for all operations")
    parser.add_argument("--n_cols", type=int, default=1)
    parser.add_argument("--n_rows", type=int, default=5)
    parser.add_argument("--select", type=int, nargs="+", default=[])

    # text normalization/regularization
    parser.add_argument("--n_augmentations", type=float, default=0,
                        help="Amount of data augmentation per line (done before training). If this number is < 1 "
                             "the amount is relative.")
    parser.add_argument("--text_regularization", type=str, nargs="+", default=["extended"],
                        help="Text regularization to apply.")
    parser.add_argument("--text_normalization", type=str, default="NFC",
                        help="Unicode text normalization to apply. Defaults to NFC")
    parser.add_argument("--data_preprocessing", nargs="+", type=DataPreprocessorParams.Type.Value,
                        choices=DataPreprocessorParams.Type.values(),
                        default=[DataPreprocessorParams.DEFAULT_NORMALIZER])

    args = parser.parse_args()

    # Text/Data processing
    if args.data_preprocessing is None or len(args.data_preprocessing) == 0:
        args.data_preprocessing = [DataPreprocessorParams.DEFAULT_NORMALIZER]

    data_preprocessor = DataPreprocessorParams()
    data_preprocessor.type = DataPreprocessorParams.MULTI_NORMALIZER
    for preproc in args.data_preprocessing:
        pp = data_preprocessor.children.add()
        pp.type = preproc
        pp.line_height = args.line_height
        pp.pad = args.pad

    # Text pre processing (reading)
    text_preprocessor = TextProcessorParams()
    text_preprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(text_preprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(text_preprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = text_preprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    text_preprocessor = text_processor_from_proto(text_preprocessor)
    data_preprocessor = data_processor_from_proto(data_preprocessor)

    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        if args.gt_extension:
            gt_txt_files = [split_all_ext(f)[0] + args.gt_extension for f in input_image_files]
        else:
            gt_txt_files = [None] * len(input_image_files)
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

        if len(set(gt_txt_files)) != len(gt_txt_files):
            raise Exception("Some ground truth text files occur more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        non_existing_as_empty=True,
    )

    if len(args.select) == 0:
        args.select = range(len(dataset.samples()))
        dataset._samples = dataset.samples()
    else:
        dataset._samples = [dataset.samples()[i] for i in args.select]

    samples = dataset.samples()
    print("Found {} files in the dataset".format(len(dataset)))

    with StreamingInputDataset(dataset,
                               data_preprocessor,
                               text_preprocessor,
                               SimpleDataAugmenter(),
                               args.n_augmentations,
                               ) as input_dataset:
        f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
        row, col = 0, 0
        for i, (id, sample) in enumerate(zip(args.select, input_dataset.generator(args.processes))):
            line, text, params = sample
            if args.n_cols == 1:
                ax[row].imshow(line.transpose())
                ax[row].set_title("ID: {}\n{}".format(id, text))
            else:
                ax[row, col].imshow(line.transpose())
                ax[row, col].set_title("ID: {}\n{}".format(id, text))

            row += 1
            if row == args.n_rows:
                row = 0
                col += 1

            if col == args.n_cols or i == len(samples) - 1:
                plt.show()
                f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
                row, col = 0, 0

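# Example invocation of this viewer (hypothetical paths and script name): show
# the dataset lines in a 5-row, 1-column grid with the default preprocessing;
# --select picks explicit sample indices instead of the full range.
#
#   python dataset_viewer.py --files "lines/*.png" --n_rows 5 --n_cols 1
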
def run(args):
    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    # parse whitelist
    whitelist = args.whitelist
    if len(whitelist) == 1:
        whitelist = list(whitelist[0])

    whitelist_files = glob_all(args.whitelist_files)
    for f in whitelist_files:
        with open(f) as txt:
            whitelist += list(txt.read())

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        gt_txt_files = [split_all_ext(f)[0] + args.gt_extension for f in input_image_files]
    else:
        gt_txt_files = sorted(glob_all(args.text_files))

    input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
    for img, gt in zip(input_image_files, gt_txt_files):
        if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
            raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some ground truth text files occur more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not args.no_skip_invalid_gt
    )
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)

        if not args.validation_text_files:
            val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))

        validation_image_files, val_txt_files = keep_files_with_same_file_name(validation_image_files, val_txt_files)
        for img, gt in zip(validation_image_files, val_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images occur more than once in the data set.")

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    params = CheckpointParams()

    params.max_iters = args.max_iters
    params.stats_size = args.stats_size
    params.batch_size = args.batch_size
    params.checkpoint_frequency = args.checkpoint_frequency if args.checkpoint_frequency >= 0 else args.early_stopping_frequency
    params.output_dir = args.output_dir
    params.output_model_prefix = args.output_model_prefix
    params.display = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.processes = args.num_threads
    params.data_aug_retrain_on_original = not args.only_train_on_augmented

    params.early_stopping_frequency = args.early_stopping_frequency
    params.early_stopping_nbest = args.early_stopping_nbest
    params.early_stopping_best_model_prefix = args.early_stopping_best_model_prefix
    params.early_stopping_best_model_output_dir = \
        args.early_stopping_best_model_output_dir if args.early_stopping_best_model_output_dir else args.output_dir

    params.model.data_preprocessor.type = DataPreprocessorParams.DEFAULT_NORMALIZER
    params.model.data_preprocessor.line_height = args.line_height
    params.model.data_preprocessor.pad = args.pad

    # Text pre processing (reading)
    params.model.text_preprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_preprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_preprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_preprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    # Text post processing (prediction)
    params.model.text_postprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_postprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_postprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_postprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    if args.seed > 0:
        params.model.network.backend.random_seed = args.seed

    if args.bidi_dir:
        # change bidirectional text direction if desired
        bidi_dir_to_enum = {"rtl": TextProcessorParams.BIDI_RTL,
                            "ltr": TextProcessorParams.BIDI_LTR,
                            "auto": TextProcessorParams.BIDI_AUTO}

        bidi_processor_params = params.model.text_preprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = bidi_dir_to_enum[args.bidi_dir]

        bidi_processor_params = params.model.text_postprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = TextProcessorParams.BIDI_AUTO

    params.model.line_height = args.line_height

    network_params_from_definition_string(args.network, params.model.network)
    params.model.network.clipping_mode = NetworkParams.ClippingMode.Value("CLIP_" + args.gradient_clipping_mode.upper())
    params.model.network.clipping_constant = args.gradient_clipping_const
    params.model.network.backend.fuzzy_ctc_library_path = args.fuzzy_ctc_library_path
    params.model.network.backend.num_inter_threads = args.num_inter_threads
    params.model.network.backend.num_intra_threads = args.num_intra_threads

    # create the actual trainer
    trainer = Trainer(params, dataset,
                      validation_dataset=validation_dataset,
                      data_augmenter=SimpleDataAugmenter(),
                      n_augmentations=args.n_augmentations,
                      weights=args.weights,
                      codec_whitelist=whitelist,
                      preload_training=not args.train_data_on_the_fly,
                      preload_validation=not args.validation_data_on_the_fly,
                      )
    trainer.train(
        auto_compute_codec=not args.no_auto_compute_codec,
        progress_bar=not args.no_progress_bars
    )

def main():
    parser = ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="The checkpoint used to resume")

    # validation files
    parser.add_argument("--validation", type=str, nargs="+",
                        help="Validation line files used for early stopping")
    parser.add_argument("--validation_text_files", nargs="+", default=None,
                        help="Optional list of validation GT files if they are in other directory")
    parser.add_argument("--validation_extension", default=None,
                        help="Default extension of the gt files (expected to exist in same dir)")
    parser.add_argument("--validation_dataset", type=DataSetType.from_string,
                        choices=list(DataSetType), default=DataSetType.FILE)

    # input files
    parser.add_argument("--files", nargs="+",
                        help="List all image files that shall be processed. Ground truth files with the same "
                             "base name but with '.gt.txt' as extension are required at the same location")
    parser.add_argument("--text_files", nargs="+", default=None,
                        help="Optional list of GT files if they are in other directory")
    parser.add_argument("--gt_extension", default=None,
                        help="Default extension of the gt files (expected to exist in same dir)")
    parser.add_argument("--dataset", type=DataSetType.from_string,
                        choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do not skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)
    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        gt_txt_files = [split_all_ext(f)[0] + args.gt_extension for f in input_image_files]
    else:
        gt_txt_files = sorted(glob_all(args.text_files))

    input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
    for img, gt in zip(input_image_files, gt_txt_files):
        if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
            raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some ground truth text files occur more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not args.no_skip_invalid_gt
    )
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        if not args.validation_text_files:
            val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))

        validation_image_files, val_txt_files = keep_files_with_same_file_name(validation_image_files, val_txt_files)
        for img, gt in zip(validation_image_files, val_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images occur more than once in the data set.")

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    print("Resuming training")
    with open(args.checkpoint + '.json', 'r') as f:
        checkpoint_params = json_format.Parse(f.read(), CheckpointParams())

    trainer = Trainer(checkpoint_params, dataset,
                      validation_dataset=validation_dataset,
                      weights=args.checkpoint)
    trainer.train(progress_bar=True)