def params_from_args(args):
    """Translate parsed calamari training arguments into a ``CheckpointParams`` message.

    Args:
        args: namespace produced by the calamari training argument parser.

    Returns:
        A new ``CheckpointParams`` populated with the training, early-stopping,
        data/text processing and network settings taken from ``args``.
    """
    params = CheckpointParams()

    # Options whose argument name equals the CheckpointParams field name.
    params.max_iters = args.max_iters
    params.stats_size = args.stats_size
    params.batch_size = args.batch_size
    params.checkpoint_frequency = args.checkpoint_frequency
    params.output_dir = args.output_dir
    params.output_model_prefix = args.output_model_prefix
    params.display = args.display
    params.early_stopping_nbest = args.early_stopping_nbest
    params.early_stopping_best_model_prefix = args.early_stopping_best_model_prefix

    params.processes = args.num_threads
    params.skip_invalid_gt = not args.no_skip_invalid_gt

    # A negative early-stopping frequency falls back to the checkpoint frequency.
    if args.early_stopping_frequency >= 0:
        params.early_stopping_frequency = args.early_stopping_frequency
    else:
        params.early_stopping_frequency = args.checkpoint_frequency

    # Without a dedicated directory the best models go to the regular output directory.
    params.early_stopping_best_model_output_dir = (
        args.early_stopping_best_model_output_dir or args.output_dir
    )

    data_pre = params.model.data_preprocessor
    data_pre.type = DataPreprocessorParams.DEFAULT_NORMALIZER
    data_pre.line_height = args.line_height
    data_pre.pad = args.pad

    def build_text_pipeline(processor):
        """Configure ``processor`` as normalizer -> regularizer -> strip chain."""
        processor.type = TextProcessorParams.MULTI_NORMALIZER
        default_text_normalizer_params(processor.children.add(), default=args.text_normalization)
        default_text_regularizer_params(processor.children.add(), groups=args.text_regularization)
        processor.children.add().type = TextProcessorParams.STRIP_NORMALIZER

    build_text_pipeline(params.model.text_preprocessor)   # text pre processing (reading)
    build_text_pipeline(params.model.text_postprocessor)  # text post processing (prediction)

    if args.seed > 0:
        params.model.network.backend.random_seed = args.seed

    if args.bidi_dir:
        # Reading uses the requested direction; predictions are always resolved with AUTO.
        direction_by_name = {
            "rtl": TextProcessorParams.BIDI_RTL,
            "ltr": TextProcessorParams.BIDI_LTR,
            "auto": TextProcessorParams.BIDI_AUTO,
        }
        pre_bidi = params.model.text_preprocessor.children.add()
        pre_bidi.type = TextProcessorParams.BIDI_NORMALIZER
        pre_bidi.bidi_direction = direction_by_name[args.bidi_dir]

        post_bidi = params.model.text_postprocessor.children.add()
        post_bidi.type = TextProcessorParams.BIDI_NORMALIZER
        post_bidi.bidi_direction = TextProcessorParams.BIDI_AUTO

    params.model.line_height = args.line_height

    network = params.model.network
    network_params_from_definition_string(args.network, network)
    network.clipping_mode = NetworkParams.ClippingMode.Value("CLIP_" + args.gradient_clipping_mode.upper())
    network.clipping_constant = args.gradient_clipping_const

    backend = network.backend
    backend.fuzzy_ctc_library_path = args.fuzzy_ctc_library_path
    backend.num_inter_threads = args.num_inter_threads
    backend.num_intra_threads = args.num_intra_threads

    return params
def run(cfg: CfgNode):
    """Train a calamari model from a configuration node.

    Optionally merges a JSON argument file into ``cfg``, collects the codec
    whitelist, creates the training and validation datasets, converts the
    configuration into ``CheckpointParams`` and runs a ``Trainer``.

    Args:
        cfg: training configuration. This function writes defaults back into it
            (``DATASET.*.GT_EXTENSION`` and ``INPUT.DATA_PREPROCESSING``).
    """
    # check if loading a json file
    if len(cfg.DATASET.TRAIN.PATH) == 1 and cfg.DATASET.TRAIN.PATH[0].endswith("json"):
        import json
        with open(cfg.DATASET.TRAIN.PATH[0], 'r') as f:
            json_args = json.load(f)
        for key, value in json_args.items():
            if key == 'dataset' or key == 'validation_dataset':
                setattr(cfg, key, DataSetType.from_string(value))
            else:
                setattr(cfg, key, value)

    # parse whitelist
    if len(cfg.MODEL.CODEX.WHITELIST) == 1:
        # a single entry is split into its individual characters
        whitelist = list(cfg.MODEL.CODEX.WHITELIST[0])
    else:
        # copy: extending the list below must not modify the configuration itself
        whitelist = list(cfg.MODEL.CODEX.WHITELIST)
    for whitelist_file in glob_all(cfg.MODEL.CODEX.WHITELIST_FILES):
        # explicit UTF-8 so non-ASCII whitelist characters do not depend on the locale
        with open(whitelist_file, encoding='utf-8') as txt:
            whitelist += list(txt.read())

    if cfg.DATASET.TRAIN.GT_EXTENSION is False:
        cfg.DATASET.TRAIN.GT_EXTENSION = DataSetType.gt_extension(cfg.DATASET.TRAIN.TYPE)
    if cfg.DATASET.VALID.GT_EXTENSION is False:
        cfg.DATASET.VALID.GT_EXTENSION = DataSetType.gt_extension(cfg.DATASET.VALID.TYPE)

    text_generator_params = TextGeneratorParameters()
    line_generator_params = LineGeneratorParameters()
    dataset_args = {
        'line_generator_params': line_generator_params,
        'text_generator_params': text_generator_params,
        'pad': None,
        'text_index': 0,
    }

    # Training dataset
    dataset = create_train_dataset(cfg, dataset_args)

    # Validation dataset
    validation_dataset_list = create_test_dataset(cfg, dataset_args)

    params = CheckpointParams()
    params.max_iters = cfg.SOLVER.MAX_ITER
    params.stats_size = cfg.STATS_SIZE
    params.batch_size = cfg.SOLVER.BATCH_SIZE
    # a negative checkpoint frequency falls back to the early-stopping frequency
    params.checkpoint_frequency = (
        cfg.SOLVER.CHECKPOINT_FREQ if cfg.SOLVER.CHECKPOINT_FREQ >= 0 else cfg.SOLVER.EARLY_STOPPING_FREQ
    )
    params.output_dir = cfg.OUTPUT_DIR
    params.output_model_prefix = cfg.OUTPUT_MODEL_PREFIX
    params.display = cfg.DISPLAY
    params.skip_invalid_gt = not cfg.DATALOADER.NO_SKIP_INVALID_GT
    params.processes = cfg.NUM_THREADS
    params.data_aug_retrain_on_original = not cfg.DATALOADER.ONLY_TRAIN_ON_AUGMENTED

    params.early_stopping_at_acc = cfg.SOLVER.EARLY_STOPPING_AT_ACC
    params.early_stopping_frequency = cfg.SOLVER.EARLY_STOPPING_FREQ
    params.early_stopping_nbest = cfg.SOLVER.EARLY_STOPPING_NBEST
    params.early_stopping_best_model_prefix = cfg.EARLY_STOPPING_BEST_MODEL_PREFIX
    params.early_stopping_best_model_output_dir = \
        cfg.EARLY_STOPPING_BEST_MODEL_OUTPUT_DIR if cfg.EARLY_STOPPING_BEST_MODEL_OUTPUT_DIR else cfg.OUTPUT_DIR

    # Missing/False/None/empty preprocessing list falls back to the default normalizer.
    if not cfg.INPUT.DATA_PREPROCESSING:
        cfg.INPUT.DATA_PREPROCESSING = [DataPreprocessorParams.DEFAULT_NORMALIZER]

    params.model.data_preprocessor.type = DataPreprocessorParams.MULTI_NORMALIZER
    for preproc in cfg.INPUT.DATA_PREPROCESSING:
        pp = params.model.data_preprocessor.children.add()
        # entries may be enum names (str) or enum values
        pp.type = DataPreprocessorParams.Type.Value(preproc) if isinstance(preproc, str) else preproc
        pp.line_height = cfg.INPUT.LINE_HEIGHT
        pp.pad = cfg.INPUT.PAD

    # Text pre processing (reading)
    params.model.text_preprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_preprocessor.children.add(),
                                   default=cfg.INPUT.TEXT_NORMALIZATION)
    default_text_regularizer_params(params.model.text_preprocessor.children.add(),
                                    groups=cfg.INPUT.TEXT_REGULARIZATION)
    strip_processor_params = params.model.text_preprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    # Text post processing (prediction)
    params.model.text_postprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_postprocessor.children.add(),
                                   default=cfg.INPUT.TEXT_NORMALIZATION)
    default_text_regularizer_params(params.model.text_postprocessor.children.add(),
                                    groups=cfg.INPUT.TEXT_REGULARIZATION)
    strip_processor_params = params.model.text_postprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    if cfg.SEED > 0:
        params.model.network.backend.random_seed = cfg.SEED

    if cfg.INPUT.BIDI_DIR:
        # change bidirectional text direction if desired
        bidi_dir_to_enum = {
            "rtl": TextProcessorParams.BIDI_RTL,
            "ltr": TextProcessorParams.BIDI_LTR,
            "auto": TextProcessorParams.BIDI_AUTO,
        }

        bidi_processor_params = params.model.text_preprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = bidi_dir_to_enum[cfg.INPUT.BIDI_DIR]

        bidi_processor_params = params.model.text_postprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = TextProcessorParams.BIDI_AUTO

    params.model.line_height = cfg.INPUT.LINE_HEIGHT
    params.model.network.learning_rate = cfg.SOLVER.LR
    params.model.network.lr_decay = cfg.SOLVER.LR_DECAY
    params.model.network.lr_decay_freq = cfg.SOLVER.LR_DECAY_FREQ
    params.model.network.train_last_n_layer = cfg.SOLVER.TRAIN_LAST_N_LAYER
    network_params_from_definition_string(cfg.MODEL.NETWORK, params.model.network)
    params.model.network.clipping_norm = cfg.SOLVER.GRADIENT_CLIPPING_NORM
    params.model.network.backend.num_inter_threads = 0
    params.model.network.backend.num_intra_threads = 0
    params.model.network.backend.shuffle_buffer_size = cfg.DATALOADER.SHUFFLE_BUFFER_SIZE

    # an empty string means "no pretrained weights"
    if cfg.MODEL.WEIGHTS == "":
        weights = None
    else:
        weights = cfg.MODEL.WEIGHTS

    # create the actual trainer
    trainer = Trainer(
        params,
        dataset,
        validation_dataset=validation_dataset_list,
        data_augmenter=SimpleDataAugmenter(),
        n_augmentations=cfg.INPUT.N_AUGMENT,
        weights=weights,
        codec_whitelist=whitelist,
        keep_loaded_codec=cfg.MODEL.CODEX.KEEP_LOADED_CODEC,
        preload_training=not cfg.DATALOADER.TRAIN_ON_THE_FLY,
        preload_validation=not cfg.DATALOADER.VALID_ON_THE_FLY,
    )
    trainer.train(auto_compute_codec=not cfg.MODEL.CODEX.SEE_WHITELIST,
                  progress_bar=not cfg.NO_PROGRESS_BAR)
def params_from_args(args):
    """Turn parsed calamari training arguments into a ``CheckpointParams`` message.

    Args:
        args: namespace produced by the calamari training argument parser.

    Returns:
        A new ``CheckpointParams`` with training, early-stopping, data/text
        processing and network settings copied from ``args``.
    """
    params = CheckpointParams()
    params.max_iters = args.max_iters
    params.stats_size = args.stats_size
    params.batch_size = args.batch_size
    # a negative checkpoint frequency falls back to the early-stopping frequency
    params.checkpoint_frequency = (
        args.checkpoint_frequency if args.checkpoint_frequency >= 0 else args.early_stopping_frequency
    )
    params.output_dir = args.output_dir
    params.output_model_prefix = args.output_model_prefix
    params.display = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.processes = args.num_threads
    params.data_aug_retrain_on_original = not args.only_train_on_augmented

    # Set exactly once (the original assigned this field a second time before returning).
    params.early_stopping_at_acc = args.early_stopping_at_accuracy
    params.early_stopping_frequency = args.early_stopping_frequency
    params.early_stopping_nbest = args.early_stopping_nbest
    params.early_stopping_best_model_prefix = args.early_stopping_best_model_prefix
    params.early_stopping_best_model_output_dir = \
        args.early_stopping_best_model_output_dir if args.early_stopping_best_model_output_dir else args.output_dir

    params.model.data_preprocessor.type = DataPreprocessorParams.DEFAULT_NORMALIZER
    params.model.data_preprocessor.line_height = args.line_height
    params.model.data_preprocessor.pad = args.pad

    # Text pre processing (reading)
    params.model.text_preprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_preprocessor.children.add(),
                                   default=args.text_normalization)
    default_text_regularizer_params(params.model.text_preprocessor.children.add(),
                                    groups=args.text_regularization)
    strip_processor_params = params.model.text_preprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    # Text post processing (prediction)
    params.model.text_postprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_postprocessor.children.add(),
                                   default=args.text_normalization)
    default_text_regularizer_params(params.model.text_postprocessor.children.add(),
                                    groups=args.text_regularization)
    strip_processor_params = params.model.text_postprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    if args.seed > 0:
        params.model.network.backend.random_seed = args.seed

    if args.bidi_dir:
        # change bidirectional text direction if desired; predictions always use AUTO
        bidi_dir_to_enum = {
            "rtl": TextProcessorParams.BIDI_RTL,
            "ltr": TextProcessorParams.BIDI_LTR,
            "auto": TextProcessorParams.BIDI_AUTO,
        }

        bidi_processor_params = params.model.text_preprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = bidi_dir_to_enum[args.bidi_dir]

        bidi_processor_params = params.model.text_postprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = TextProcessorParams.BIDI_AUTO

    params.model.line_height = args.line_height

    network_params_from_definition_string(args.network, params.model.network)
    params.model.network.clipping_norm = args.gradient_clipping_norm
    params.model.network.backend.num_inter_threads = args.num_inter_threads
    params.model.network.backend.num_intra_threads = args.num_intra_threads
    params.model.network.backend.shuffle_buffer_size = args.shuffle_buffer_size

    return params
def run(args):
    """Train a calamari model from parsed command line arguments.

    If ``args.files`` is a single ``*json`` path, its key/value pairs are copied
    onto ``args`` first. The function then collects the codec whitelist, loads
    optional text/line generator parameters, builds the training and validation
    datasets, converts everything into ``CheckpointParams`` and runs a ``Trainer``.

    Args:
        args: namespace from the calamari training parser. Several attributes
            are overwritten with resolved values (extensions, generator params,
            data preprocessing list).
    """
    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
        for key, value in json_args.items():
            if key == 'dataset' or key == 'validation_dataset':
                setattr(args, key, DataSetType.from_string(value))
            else:
                setattr(args, key, value)

    # parse whitelist
    if len(args.whitelist) == 1:
        # a single entry is split into its individual characters
        whitelist = list(args.whitelist[0])
    else:
        # copy: extending the list below must not modify args.whitelist
        whitelist = list(args.whitelist)
    for whitelist_file in glob_all(args.whitelist_files):
        # explicit UTF-8 so non-ASCII whitelist characters do not depend on the locale
        with open(whitelist_file, encoding='utf-8') as txt:
            whitelist += list(txt.read())

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    # optional JSON-encoded generator parameter files replace the defaults
    if args.text_generator_params is not None:
        with open(args.text_generator_params, 'r') as f:
            args.text_generator_params = json_format.Parse(f.read(), TextGeneratorParameters())
    else:
        args.text_generator_params = TextGeneratorParameters()

    if args.line_generator_params is not None:
        with open(args.line_generator_params, 'r') as f:
            args.line_generator_params = json_format.Parse(f.read(), LineGeneratorParameters())
    else:
        args.line_generator_params = LineGeneratorParameters()

    dataset_args = {
        'line_generator_params': args.line_generator_params,
        'text_generator_params': args.text_generator_params,
        'pad': args.dataset_pad,
        'text_index': args.pagexml_text_index,
    }

    # Training dataset
    dataset = create_train_dataset(args, dataset_args)

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        if not args.validation_text_files:
            val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(validation_image_files,
                                                                                   val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                    raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images are occurring more than once in the data set.")

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt,
            args=dataset_args,
        )
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    params = CheckpointParams()
    params.max_iters = args.max_iters
    params.stats_size = args.stats_size
    params.batch_size = args.batch_size
    # a negative checkpoint frequency falls back to the early-stopping frequency
    params.checkpoint_frequency = (
        args.checkpoint_frequency if args.checkpoint_frequency >= 0 else args.early_stopping_frequency
    )
    params.output_dir = args.output_dir
    params.output_model_prefix = args.output_model_prefix
    params.display = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.processes = args.num_threads
    params.data_aug_retrain_on_original = not args.only_train_on_augmented

    params.early_stopping_frequency = args.early_stopping_frequency
    params.early_stopping_nbest = args.early_stopping_nbest
    params.early_stopping_best_model_prefix = args.early_stopping_best_model_prefix
    params.early_stopping_best_model_output_dir = \
        args.early_stopping_best_model_output_dir if args.early_stopping_best_model_output_dir else args.output_dir

    if args.data_preprocessing is None or len(args.data_preprocessing) == 0:
        args.data_preprocessing = [DataPreprocessorParams.DEFAULT_NORMALIZER]

    params.model.data_preprocessor.type = DataPreprocessorParams.MULTI_NORMALIZER
    for preproc in args.data_preprocessing:
        pp = params.model.data_preprocessor.children.add()
        # entries may be enum names (str) or enum values
        pp.type = DataPreprocessorParams.Type.Value(preproc) if isinstance(preproc, str) else preproc
        pp.line_height = args.line_height
        pp.pad = args.pad

    # Text pre processing (reading)
    params.model.text_preprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_preprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_preprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_preprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    # Text post processing (prediction)
    params.model.text_postprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_postprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_postprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_postprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    if args.seed > 0:
        params.model.network.backend.random_seed = args.seed

    if args.bidi_dir:
        # change bidirectional text direction if desired; predictions always use AUTO
        bidi_dir_to_enum = {"rtl": TextProcessorParams.BIDI_RTL, "ltr": TextProcessorParams.BIDI_LTR,
                            "auto": TextProcessorParams.BIDI_AUTO}

        bidi_processor_params = params.model.text_preprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = bidi_dir_to_enum[args.bidi_dir]

        bidi_processor_params = params.model.text_postprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = TextProcessorParams.BIDI_AUTO

    params.model.line_height = args.line_height

    network_params_from_definition_string(args.network, params.model.network)
    params.model.network.clipping_mode = NetworkParams.ClippingMode.Value("CLIP_" + args.gradient_clipping_mode.upper())
    params.model.network.clipping_constant = args.gradient_clipping_const
    params.model.network.backend.fuzzy_ctc_library_path = args.fuzzy_ctc_library_path
    params.model.network.backend.num_inter_threads = args.num_inter_threads
    params.model.network.backend.num_intra_threads = args.num_intra_threads
    params.model.network.backend.shuffle_buffer_size = args.shuffle_buffer_size

    # create the actual trainer
    trainer = Trainer(params,
                      dataset,
                      validation_dataset=validation_dataset,
                      data_augmenter=SimpleDataAugmenter(),
                      n_augmentations=args.n_augmentations,
                      weights=args.weights,
                      codec_whitelist=whitelist,
                      keep_loaded_codec=args.keep_loaded_codec,
                      preload_training=not args.train_data_on_the_fly,
                      preload_validation=not args.validation_data_on_the_fly,
                      )
    trainer.train(
        auto_compute_codec=not args.no_auto_compute_codec,
        progress_bar=not args.no_progress_bars
    )
def run(args):
    """Train a calamari model from parsed command line arguments (file-list datasets).

    If ``args.files`` is a single ``*json`` path, its key/value pairs are copied
    onto ``args`` first. Image files are paired with ground-truth text files
    (derived from the image name or given explicitly), then datasets,
    ``CheckpointParams`` and a ``Trainer`` are created and training is run.

    Args:
        args: namespace from the calamari training parser. ``gt_extension`` and
            ``validation_extension`` are filled in when they are ``None``.

    Raises:
        Exception: if image and text basenames do not match, or if a ground-truth
            file occurs more than once.
    """
    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
        for key, value in json_args.items():
            setattr(args, key, value)

    # parse whitelist
    if len(args.whitelist) == 1:
        # a single entry is split into its individual characters
        whitelist = list(args.whitelist[0])
    else:
        # copy: extending the list below must not modify args.whitelist
        whitelist = list(args.whitelist)
    for whitelist_file in glob_all(args.whitelist_files):
        # explicit UTF-8 so non-ASCII whitelist characters do not depend on the locale
        with open(whitelist_file, encoding='utf-8') as txt:
            whitelist += list(txt.read())

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        gt_txt_files = [split_all_ext(f)[0] + args.gt_extension for f in input_image_files]
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some image are occurring more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not args.no_skip_invalid_gt
    )
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        if not args.validation_text_files:
            val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(validation_image_files,
                                                                                   val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                    raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images are occurring more than once in the data set.")

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    params = CheckpointParams()
    params.max_iters = args.max_iters
    params.stats_size = args.stats_size
    params.batch_size = args.batch_size
    # a negative checkpoint frequency falls back to the early-stopping frequency
    params.checkpoint_frequency = (
        args.checkpoint_frequency if args.checkpoint_frequency >= 0 else args.early_stopping_frequency
    )
    params.output_dir = args.output_dir
    params.output_model_prefix = args.output_model_prefix
    params.display = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.processes = args.num_threads
    params.data_aug_retrain_on_original = not args.only_train_on_augmented

    params.early_stopping_frequency = args.early_stopping_frequency
    params.early_stopping_nbest = args.early_stopping_nbest
    params.early_stopping_best_model_prefix = args.early_stopping_best_model_prefix
    params.early_stopping_best_model_output_dir = \
        args.early_stopping_best_model_output_dir if args.early_stopping_best_model_output_dir else args.output_dir

    params.model.data_preprocessor.type = DataPreprocessorParams.DEFAULT_NORMALIZER
    params.model.data_preprocessor.line_height = args.line_height
    params.model.data_preprocessor.pad = args.pad

    # Text pre processing (reading)
    params.model.text_preprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_preprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_preprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_preprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    # Text post processing (prediction)
    params.model.text_postprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_postprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_postprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_postprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    if args.seed > 0:
        params.model.network.backend.random_seed = args.seed

    if args.bidi_dir:
        # change bidirectional text direction if desired; predictions always use AUTO
        bidi_dir_to_enum = {"rtl": TextProcessorParams.BIDI_RTL, "ltr": TextProcessorParams.BIDI_LTR,
                            "auto": TextProcessorParams.BIDI_AUTO}

        bidi_processor_params = params.model.text_preprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = bidi_dir_to_enum[args.bidi_dir]

        bidi_processor_params = params.model.text_postprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = TextProcessorParams.BIDI_AUTO

    params.model.line_height = args.line_height

    network_params_from_definition_string(args.network, params.model.network)
    params.model.network.clipping_mode = NetworkParams.ClippingMode.Value("CLIP_" + args.gradient_clipping_mode.upper())
    params.model.network.clipping_constant = args.gradient_clipping_const
    params.model.network.backend.fuzzy_ctc_library_path = args.fuzzy_ctc_library_path
    params.model.network.backend.num_inter_threads = args.num_inter_threads
    params.model.network.backend.num_intra_threads = args.num_intra_threads

    # create the actual trainer
    trainer = Trainer(params,
                      dataset,
                      validation_dataset=validation_dataset,
                      data_augmenter=SimpleDataAugmenter(),
                      n_augmentations=args.n_augmentations,
                      weights=args.weights,
                      codec_whitelist=whitelist,
                      preload_training=not args.train_data_on_the_fly,
                      preload_validation=not args.validation_data_on_the_fly,
                      )
    trainer.train(
        auto_compute_codec=not args.no_auto_compute_codec,
        progress_bar=not args.no_progress_bars
    )
def main():
    """Command line entry point: parse training arguments and train a calamari model.

    Images given by ``args.files`` are paired with ``<basename>.gt.txt`` ground
    truth files (same for ``args.validation``). If ``args.files`` is a single
    ``*json`` path, its key/value pairs are copied onto the parsed arguments first.

    Raises:
        Exception: if a ground-truth file would be used more than once.
    """
    parser = argparse.ArgumentParser()
    setup_train_args(parser)
    args = parser.parse_args()

    # check if loading a json file
    if len(args.files) == 1 and args.files[0].endswith("json"):
        import json
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
        for key, value in json_args.items():
            setattr(args, key, value)

    # parse whitelist; copy so extending it does not modify args.whitelist
    whitelist = list(args.whitelist)
    for whitelist_file in glob_all(args.whitelist_files):
        # explicit UTF-8 so non-ASCII whitelist characters do not depend on the locale
        with open(whitelist_file, encoding='utf-8') as txt:
            whitelist += list(txt.read())

    # Training dataset
    print("Resolving input files")
    input_image_files = glob_all(args.files)
    gt_txt_files = [split_all_ext(f)[0] + ".gt.txt" for f in input_image_files]

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some image are occurring more than once in the data set.")

    dataset = FileDataSet(input_image_files, gt_txt_files, skip_invalid=not args.no_skip_invalid_gt)
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        val_txt_files = [split_all_ext(f)[0] + ".gt.txt" for f in validation_image_files]

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images are occurring more than once in the data set.")

        validation_dataset = FileDataSet(validation_image_files, val_txt_files,
                                         skip_invalid=not args.no_skip_invalid_gt)
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    params = CheckpointParams()
    params.max_iters = args.max_iters
    params.stats_size = args.stats_size
    params.batch_size = args.batch_size
    params.checkpoint_frequency = args.checkpoint_frequency
    params.output_dir = args.output_dir
    params.output_model_prefix = args.output_model_prefix
    params.display = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.processes = args.num_threads

    # a negative early-stopping frequency falls back to the checkpoint frequency
    params.early_stopping_frequency = (
        args.early_stopping_frequency if args.early_stopping_frequency >= 0 else args.checkpoint_frequency
    )
    params.early_stopping_nbest = args.early_stopping_nbest
    params.early_stopping_best_model_prefix = args.early_stopping_best_model_prefix
    params.early_stopping_best_model_output_dir = \
        args.early_stopping_best_model_output_dir if args.early_stopping_best_model_output_dir else args.output_dir

    params.model.data_preprocessor.type = DataPreprocessorParams.DEFAULT_NORMALIZER
    params.model.data_preprocessor.line_height = args.line_height
    params.model.data_preprocessor.pad = args.pad

    # Text pre processing (reading)
    params.model.text_preprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_preprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_preprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_preprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    # Text post processing (prediction)
    params.model.text_postprocessor.type = TextProcessorParams.MULTI_NORMALIZER
    default_text_normalizer_params(params.model.text_postprocessor.children.add(), default=args.text_normalization)
    default_text_regularizer_params(params.model.text_postprocessor.children.add(), groups=args.text_regularization)
    strip_processor_params = params.model.text_postprocessor.children.add()
    strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER

    if args.seed > 0:
        params.model.network.backend.random_seed = args.seed

    if args.bidi_dir:
        # change bidirectional text direction if desired; "auto" is accepted like in
        # the other training entry points instead of raising KeyError
        bidi_dir_to_enum = {"rtl": TextProcessorParams.BIDI_RTL, "ltr": TextProcessorParams.BIDI_LTR,
                            "auto": TextProcessorParams.BIDI_AUTO}

        bidi_processor_params = params.model.text_preprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = bidi_dir_to_enum[args.bidi_dir]

        bidi_processor_params = params.model.text_postprocessor.children.add()
        bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER
        bidi_processor_params.bidi_direction = bidi_dir_to_enum[args.bidi_dir]

    params.model.line_height = args.line_height

    network_params_from_definition_string(args.network, params.model.network)
    params.model.network.clipping_mode = NetworkParams.ClippingMode.Value("CLIP_" + args.gradient_clipping_mode.upper())
    params.model.network.clipping_constant = args.gradient_clipping_const
    params.model.network.backend.fuzzy_ctc_library_path = args.fuzzy_ctc_library_path
    params.model.network.backend.num_inter_threads = args.num_inter_threads
    params.model.network.backend.num_intra_threads = args.num_intra_threads

    # create the actual trainer
    trainer = Trainer(params,
                      dataset,
                      validation_dataset=validation_dataset,
                      data_augmenter=SimpleDataAugmenter(),
                      n_augmentations=args.n_augmentations,
                      weights=args.weights,
                      codec_whitelist=whitelist,
                      )
    trainer.train(progress_bar=not args.no_progress_bars)
def _train(self, target_book: Optional[DatabaseBook] = None, callback: Optional[TrainerCallback] = None):
    """Train the Calamari text-line model and write it to ``self.settings.model.path``.

    Converts the train/validation datasets to Calamari datasets, fills a Calamari
    ``CheckpointParams`` from ``self.params`` / ``self.settings`` and then either

    * runs a ``CrossFoldTrainer`` when ``settings.calamari_params.n_folds > 1``, or
    * runs a single ``Trainer`` whose best model is stored with the ``'text_best'`` prefix.

    Args:
        target_book: Not read by this method.
        callback: Optional progress callback. When given, ``resolving_files()`` is called on it,
            it is forwarded to the dataset conversion, and it is wrapped in a
            ``CalamariTrainerCallback`` for the single-model training run (the cross-fold
            path does not receive it).
    """
    if callback:
        callback.resolving_files()
        calamari_callback = CalamariTrainerCallback(callback)
    else:
        calamari_callback = None

    train_dataset = self.train_dataset.to_text_line_calamari_dataset(train=True, callback=callback)
    # NOTE(review): the validation split is also converted with train=True; presumably so that
    # ground truth stays attached for early-stopping evaluation — confirm.
    val_dataset = self.validation_dataset.to_text_line_calamari_dataset(train=True, callback=callback)

    output = self.settings.model.path

    params = CheckpointParams()
    params.max_iters = self.params.n_iter
    params.stats_size = 1000
    params.batch_size = 1
    params.checkpoint_frequency = 0
    params.output_dir = output
    params.output_model_prefix = 'text'
    params.display = self.params.display
    params.skip_invalid_gt = True
    params.processes = 2
    params.data_aug_retrain_on_original = True

    # An unset/falsy accuracy threshold is mapped to 0.
    params.early_stopping_at_acc = self.params.early_stopping_at_acc or 0
    params.early_stopping_frequency = self.params.early_stopping_test_interval
    params.early_stopping_nbest = self.params.early_stopping_max_keep
    params.early_stopping_best_model_prefix = 'text_best'
    params.early_stopping_best_model_output_dir = output

    params.model.data_preprocessor.type = DataPreprocessorParams.DEFAULT_NORMALIZER
    params.model.data_preprocessor.pad = 5
    params.model.data_preprocessor.line_height = self.settings.dataset_params.height
    # Text is used as-is: no normalization before training or after prediction.
    params.model.text_preprocessor.type = TextProcessorParams.NOOP_NORMALIZER
    params.model.text_postprocessor.type = TextProcessorParams.NOOP_NORMALIZER
    params.model.line_height = self.settings.dataset_params.height

    network_str = self.settings.calamari_params.network
    if self.params.l_rate > 0:
        # The learning rate is passed as part of the network definition string.
        network_str += ',learning_rate={}'.format(self.params.l_rate)

    if self.settings.calamari_params.n_folds > 1:
        # NOTE(review): keys are assumed to be calamari-train option names; all of them are
        # underscore-separated (a hyphenated key would not match any option).
        train_args = {
            "max_iters": params.max_iters,
            "stats_size": params.stats_size,
            "checkpoint_frequency": params.checkpoint_frequency,
            "pad": 0,
            "network": network_str,
            "early_stopping_at_accuracy": params.early_stopping_at_acc,
            "early_stopping_frequency": params.early_stopping_frequency,
            "early_stopping_nbest": params.early_stopping_nbest,
            "line_height": params.model.line_height,
            "data_preprocessing": ["RANGE_NORMALIZER", "FINAL_PREPARATION"],
        }
        trainer = CrossFoldTrainer(
            self.settings.calamari_params.n_folds,
            train_dataset,
            output,
            'omr_best_{id}',
            train_args,
            progress_bars=True,
        )
        temporary_dir = os.path.join(output, "temporary_dir")
        trainer.run(
            self.settings.calamari_params.single_folds,
            temporary_dir=temporary_dir,
            spawn_subprocesses=False,
            max_parallel_models=1,  # Force to run in same scope as parent process
        )
    else:
        network_params_from_definition_string(network_str, params.model.network)

        # Optionally continue from a previously trained model's best checkpoint.
        model_to_load = self.params.model_to_load()
        weights = (
            model_to_load.local_file(params.early_stopping_best_model_prefix + '.ckpt')
            if model_to_load else None
        )

        trainer = Trainer(
            codec_whitelist='abcdefghijklmnopqrstuvwxyz ',  # Always keep space and all letters
            checkpoint_params=params,
            dataset=train_dataset,
            validation_dataset=val_dataset,
            n_augmentations=self.params.data_augmentation_factor or 0,
            data_augmenter=SimpleDataAugmenter(),
            weights=weights,
            preload_training=True,
            preload_validation=True,
        )
        trainer.train(training_callback=calamari_callback, auto_compute_codec=True)
def _train(self, target_book: Optional[DatabaseBook] = None, callback: Optional[TrainerCallback] = None):
    """Train the Calamari model stored with the 'omr' prefix under ``self.settings.model.path``.

    The train and validation datasets are converted to Calamari datasets and a
    ``CheckpointParams`` is filled from ``self.params`` / ``self.settings``. When
    ``settings.calamari_params.n_folds > 0`` a ``CrossFoldTrainer`` is run; otherwise a
    single ``Trainer`` is run with the codec from ``settings.dataset_params.calamari_codec``.

    Args:
        target_book: Not read by this method.
        callback: Optional progress callback. When given, ``resolving_files()`` is called on it
            and it is forwarded to the dataset conversion.
    """
    if callback:
        callback.resolving_files()

    # Both splits are converted with train=True.
    train_dataset = self.train_dataset.to_calamari_dataset(train=True, callback=callback)
    val_dataset = self.validation_dataset.to_calamari_dataset(train=True, callback=callback)

    params = CheckpointParams()
    checkpoint_settings = {
        "max_iters": self.params.n_iter,
        "stats_size": 1000,
        "batch_size": 5,
        "checkpoint_frequency": 0,
        "output_dir": self.settings.model.path,
        "output_model_prefix": 'omr',
        "display": self.params.display,
        "skip_invalid_gt": True,
        "processes": self.params.processes,
        "data_aug_retrain_on_original": False,
        "early_stopping_frequency": self.params.early_stopping_test_interval,
        "early_stopping_nbest": self.params.early_stopping_max_keep,
        "early_stopping_best_model_prefix": 'omr_best',
        "early_stopping_best_model_output_dir": self.settings.model.path,
    }
    for field_name, field_value in checkpoint_settings.items():
        setattr(params, field_name, field_value)

    # Neither input data nor text is normalized by Calamari itself.
    model = params.model
    model.data_preprocessor.type = DataPreprocessorParams.NOOP_NORMALIZER
    model.text_preprocessor.type = TextProcessorParams.NOOP_NORMALIZER
    model.text_postprocessor.type = TextProcessorParams.NOOP_NORMALIZER
    model.line_height = self.settings.dataset_params.height
    model.network.channels = self.settings.calamari_params.channels

    network_str = self.settings.calamari_params.network
    if self.params.l_rate > 0:
        # The learning rate travels inside the network definition string.
        network_str = network_str + ',learning_rate={}'.format(self.params.l_rate)

    use_cross_fold = self.settings.calamari_params.n_folds > 0

    if not use_cross_fold:
        network_params_from_definition_string(network_str, params.model.network)

        # Resume from the best checkpoint of a previous model, if one is configured.
        initial_weights = None
        if self.params.model_to_load():
            initial_weights = self.params.model_to_load().local_file(
                params.early_stopping_best_model_prefix + '.ckpt')

        single_trainer = Trainer(
            checkpoint_params=params,
            dataset=train_dataset,
            validation_dataset=val_dataset,
            n_augmentations=self.settings.page_segmentation_params.data_augmentation * 10,
            data_augmenter=SimpleDataAugmenter(),
            weights=initial_weights,
            preload_training=True,
            preload_validation=True,
            codec=Codec(self.settings.dataset_params.calamari_codec.codec.values()),
        )
        single_trainer.train()
        return

    fold_train_args = {
        "max_iters": params.max_iters,
        "stats_size": params.stats_size,
        "checkpoint_frequency": params.checkpoint_frequency,
        "pad": 0,
        "network": network_str,
        "early_stopping_frequency": params.early_stopping_frequency,
        "early_stopping_nbest": params.early_stopping_nbest,
        "line_height": params.model.line_height,
        "data_preprocessing": ["RANGE_NORMALIZER", "FINAL_PREPARATION"],
    }
    fold_trainer = CrossFoldTrainer(
        self.settings.calamari_params.n_folds,
        train_dataset,
        params.output_dir,
        'omr_best_{id}',
        fold_train_args,
        progress_bars=True,
    )
    scratch_dir = os.path.join(params.output_dir, "temporary_dir")
    fold_trainer.run(
        self.settings.calamari_params.single_folds,
        temporary_dir=scratch_dir,
        spawn_subprocesses=False,
        max_parallel_models=1,  # Force to run in same scope as parent process
    )