def create_voter(self, data_params: 'DataParams') -> MultiModelVoter:
    # Cut the non-text processors (the first two) from the post processors
    post_proc = [
        Data.data_processor_factory().create_sequence(
            data.params().post_processors_.sample_processors[2:],
            data.params(),
            PipelineMode.Prediction,
        )
        for data in self.datas
    ]
    pre_proc = Data.data_processor_factory().create_sequence(
        self.data.params().pre_processors_.sample_processors,
        self.data.params(),
        PipelineMode.Prediction,
    )
    out_to_in_transformer = OutputToInputTransformer(pre_proc)
    return CalamariMultiModelVoter(self.voter_params, self.datas, post_proc, out_to_in_transformer)
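
# A minimal sketch of the slicing convention used above, with a placeholder
# list of processor names (the real entries are DataProcessorFactoryParams; see
# the trainer setup in run() below, where the post-processor list starts with
# ReshapeOutputsProcessor and CTCDecoderProcessor). Voting operates on decoded
# text, so only the text processors from index 2 onwards are re-applied per model.
demo_post_processors = ["ReshapeOutputsProcessor", "CTCDecoderProcessor",
                        "TextNormalizer", "TextRegularizer", "StripTextProcessor"]
text_only = demo_post_processors[2:]
assert text_only == ["TextNormalizer", "TextRegularizer", "StripTextProcessor"]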
def main(args: EvalArgs):
    # Local imports (imports that require tensorflow)
    from calamari_ocr.ocr.scenario import CalamariScenario
    from calamari_ocr.ocr.dataset.data import Data
    from calamari_ocr.ocr.evaluator import Evaluator

    if args.checkpoint:
        saved_model = SavedCalamariModel(args.checkpoint, auto_update=True)
        trainer_params = CalamariScenario.trainer_cls().params_cls().from_dict(saved_model.dict)
        data_params = trainer_params.scenario.data
    else:
        data_params = Data.default_params()

    data = Data(data_params)
    pred_data = args.pred if args.pred is not None else args.gt.to_prediction()
    evaluator = Evaluator(args.evaluator, data=data)
    evaluator.preload_gt(gt_dataset=args.gt)
    r = evaluator.run(gt_dataset=args.gt, pred_dataset=pred_data)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print(
        "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
            r["avg_ler"], r["total_char_errs"], r["total_chars"], r["total_sync_errs"]
        )
    )

    # sort descending
    print_confusions(r, args.n_confusions)
    samples = data.create_pipeline(evaluator.params.setup, args.gt).reader().samples()
    print_worst_lines(r, samples, args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(
            args.xlsx_output,
            [
                {
                    "prefix": "evaluation",
                    "results": r,
                    "gt_files": [s["id"] for s in samples],
                }
            ],
        )

    return r
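
# For illustration only: a quick sanity check of the summary format string used
# above, with made-up numbers (not real evaluation results).
r_demo = {"avg_ler": 0.0235, "total_char_errs": 47, "total_chars": 2000, "total_sync_errs": 49}
print("Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
    r_demo["avg_ler"], r_demo["total_char_errs"], r_demo["total_chars"], r_demo["total_sync_errs"]))
# -> Got mean normalized label error rate of 2.35% (47 errs, 2000 total chars, 49 sync errs)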
def get_preproc_image():
    data_params = Data.default_params()
    data_params.skip_invalid_gt = False
    data_params.pre_proc.run_parallel = False
    data_params.pre_proc.processors = data_params.pre_proc.processors[:-1]
    for p in data_params.pre_proc.processors_of_type(FinalPreparationProcessorParams):
        p.pad = 0

    post_init(data_params)
    pl = Data(data_params).create_pipeline(DataPipelineParams, None)
    pl.mode = PipelineMode.PREDICTION
    preproc = data_params.pre_proc.create(pl)

    def pp(image):
        its = InputSample(image, None, SampleMeta("001", fold_id="01")).to_input_target_sample()
        s = preproc.apply_on_sample(its)
        return s.inputs

    return pp
def get_preproc_text(rtl=False):
    data_params = Data.default_params()
    data_params.skip_invalid_gt = False
    data_params.pre_proc.run_parallel = False
    if rtl:
        for p in data_params.pre_proc.processors_of_type(BidiTextProcessorParams):
            p.bidi_direction = BidiDirection.RTL

    post_init(data_params)
    pl = Data(data_params).create_pipeline(DataPipelineParams, None)
    pl.mode = PipelineMode.TARGETS
    preproc = data_params.pre_proc.create(pl)

    def pp(text):
        its = InputSample(None, text, SampleMeta("001", fold_id="01")).to_input_target_sample()
        s = preproc.apply_on_sample(its)
        return s.targets

    return pp
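
# A minimal usage sketch for the two factories above, assuming the surrounding
# module's imports (Data, PipelineMode, etc.) are available. The blank line
# image is a hypothetical stand-in for a real line scan; the height-by-width
# uint8 layout is an assumption for illustration, not something the factories
# enforce themselves.
import numpy as np

preproc_image = get_preproc_image()
preproc_text = get_preproc_text(rtl=False)

line = np.full((48, 320), 255, dtype=np.uint8)  # hypothetical blank line image
net_input = preproc_image(line)       # preprocessed network input
target = preproc_text("Lorem ipsum")  # normalized/regularized/stripped target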
def run(args):
    # Local imports (to prevent tensorflow from being imported too early)
    from calamari_ocr.ocr.scenario import Scenario
    from calamari_ocr.ocr.dataset.data import Data

    # Check if a json file with all parameters shall be loaded
    if len(args.files) == 1 and args.files[0].endswith("json"):
        with open(args.files[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                if key == 'dataset' or key == 'validation_dataset':
                    setattr(args, key, DataSetType.from_string(value))
                else:
                    setattr(args, key, value)

    check_train_args(args)

    if args.output_dir is not None:
        args.output_dir = os.path.abspath(args.output_dir)
        setup_log(args.output_dir, append=False)

    # Parse the whitelist
    whitelist = args.whitelist
    if len(whitelist) == 1:
        whitelist = list(whitelist[0])

    whitelist_files = glob_all(args.whitelist_files)
    for f in whitelist_files:
        with open(f) as txt:
            whitelist += list(txt.read())

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    if args.dataset == DataSetType.GENERATED_LINE or args.validation_dataset == DataSetType.GENERATED_LINE:
        if args.text_generator_params is not None:
            with open(args.text_generator_params, 'r') as f:
                args.text_generator_params = TextGeneratorParams.from_json(f.read())
        else:
            args.text_generator_params = TextGeneratorParams()

        if args.line_generator_params is not None:
            with open(args.line_generator_params, 'r') as f:
                args.line_generator_params = LineGeneratorParams.from_json(f.read())
        else:
            args.line_generator_params = LineGeneratorParams()

    dataset_args = FileDataReaderArgs(
        line_generator_params=args.line_generator_params,
        text_generator_params=args.text_generator_params,
        pad=args.dataset_pad,
        text_index=args.pagexml_text_index,
    )

    params: TrainerParams = Scenario.default_trainer_params()

    # =================================================================================================================
    # Data Params
    data_params: DataParams = params.scenario_params.data_params
    data_params.train = PipelineParams(
        type=args.dataset,
        skip_invalid=not args.no_skip_invalid_gt,
        remove_invalid=True,
        files=args.files,
        text_files=args.text_files,
        gt_extension=args.gt_extension,
        data_reader_args=dataset_args,
        batch_size=args.batch_size,
        num_processes=args.num_threads,
    )

    if args.validation_split_ratio:
        if args.validation is not None:
            raise ValueError("Set either validation_split_ratio or validation")
        if not 0 < args.validation_split_ratio < 1:
            raise ValueError("validation_split_ratio must be in (0, 1)")

        # Resolve all files so that they can be split
        data_params.train = data_params.train.prepare_for_mode(PipelineMode.Training)
        n = int(args.validation_split_ratio * len(data_params.train.files))
        logger.info(
            f"Splitting training and validation files with ratio {args.validation_split_ratio}: "
            f"{n}/{len(data_params.train.files) - n} for validation/training.")
        indices = list(range(len(data_params.train.files)))
        shuffle(indices)
        all_files = data_params.train.files
        all_text_files = data_params.train.text_files

        # Split the train and val img/gt files; the validation pipeline reuses the train settings
        data_params.train.files = [all_files[i] for i in indices[:n]]
        if all_text_files is not None:
            assert len(all_text_files) == len(all_files)
            data_params.train.text_files = [all_text_files[i] for i in indices[:n]]
        data_params.val = PipelineParams(
            type=args.dataset,
            skip_invalid=not args.no_skip_invalid_gt,
            remove_invalid=True,
            files=[all_files[i] for i in indices[n:]],
            text_files=[all_text_files[i] for i in indices[n:]]
            if data_params.train.text_files is not None else None,
            gt_extension=args.gt_extension,
            data_reader_args=dataset_args,
            batch_size=args.batch_size,
            num_processes=args.num_threads,
        )
    elif args.validation:
        data_params.val = PipelineParams(
            type=args.validation_dataset,
            files=args.validation,
            text_files=args.validation_text_files,
            skip_invalid=not args.no_skip_invalid_gt,
            gt_extension=args.validation_extension,
            data_reader_args=dataset_args,
            batch_size=args.batch_size,
            num_processes=args.num_threads,
        )
    else:
        data_params.val = None

    data_params.pre_processors_ = SamplePipelineParams(run_parallel=True)
    # Note: assign the whole SamplePipelineParams (the original assigned it to
    # .run_parallel, which would have replaced a bool with a params object)
    data_params.post_processors_ = SamplePipelineParams(
        run_parallel=False,
        sample_processors=[
            DataProcessorFactoryParams(ReshapeOutputsProcessor.__name__),
            DataProcessorFactoryParams(CTCDecoderProcessor.__name__),
        ],
    )
    for p in args.data_preprocessing:
        p_p = Data.data_processor_factory().processors[p].default_params()
        if 'pad' in p_p:
            p_p['pad'] = args.pad
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(p, INPUT_PROCESSOR, p_p))

    # Text pre processing (reading)
    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(TextNormalizer.__name__, TARGETS_PROCESSOR,
                                   {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(TextRegularizer.__name__, TARGETS_PROCESSOR,
                                   {'replacements': default_text_regularizer_replacements(args.text_regularization)}),
        DataProcessorFactoryParams(StripTextProcessor.__name__, TARGETS_PROCESSOR),
    ])

    # Text post processing (prediction)
    data_params.post_processors_.sample_processors.extend([
        DataProcessorFactoryParams(TextNormalizer.__name__, TARGETS_PROCESSOR,
                                   {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(TextRegularizer.__name__, TARGETS_PROCESSOR,
                                   {'replacements': default_text_regularizer_replacements(args.text_regularization)}),
        DataProcessorFactoryParams(StripTextProcessor.__name__, TARGETS_PROCESSOR),
    ])

    if args.bidi_dir:
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__, TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))
        data_params.post_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__, TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))

    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(AugmentationProcessor.__name__, {PipelineMode.Training},
                                   {'augmenter_type': 'simple'}),
        DataProcessorFactoryParams(PrepareSampleProcessor.__name__, INPUT_PROCESSOR),
    ])

    data_params.data_aug_params = DataAugmentationAmount.from_factor(args.n_augmentations)
    data_params.line_height_ = args.line_height

    # =================================================================================================================
    # Trainer Params
    params.calc_ema = args.ema_weights
    params.verbose = args.train_verbose
    params.force_eager = args.debug
    params.skip_model_load_test = not args.debug
    params.scenario_params.debug_graph_construction = args.debug
    params.epochs = args.epochs
    params.samples_per_epoch = int(args.samples_per_epoch) if args.samples_per_epoch >= 1 else -1
    params.scale_epoch_size = abs(args.samples_per_epoch) if args.samples_per_epoch < 1 else 1
    params.skip_load_model_test = True
    params.scenario_params.export_frozen = False
    params.checkpoint_save_freq_ = args.checkpoint_frequency if args.checkpoint_frequency >= 0 else args.early_stopping_frequency
    params.checkpoint_dir = args.output_dir
    params.test_every_n = args.display
    params.skip_invalid_gt = not args.no_skip_invalid_gt
    params.data_aug_retrain_on_original = not args.only_train_on_augmented
    params.use_training_as_validation = args.use_train_as_val
    if args.seed > 0:
        params.random_seed = args.seed

    params.optimizer_params.clip_grad = args.gradient_clipping_norm
    params.codec_whitelist = whitelist
    params.keep_loaded_codec = args.keep_loaded_codec
    params.preload_training = not args.train_data_on_the_fly
    params.preload_validation = not args.validation_data_on_the_fly
    params.warmstart_params.model = args.weights
    params.auto_compute_codec = not args.no_auto_compute_codec
    params.progress_bar = not args.no_progress_bars

    params.early_stopping_params.frequency = args.early_stopping_frequency
    params.early_stopping_params.upper_threshold = 0.9
    params.early_stopping_params.lower_threshold = 1.0 - args.early_stopping_at_accuracy
    params.early_stopping_params.n_to_go = args.early_stopping_nbest
    params.early_stopping_params.best_model_name = ''
    params.early_stopping_params.best_model_output_dir = args.early_stopping_best_model_output_dir
    params.scenario_params.default_serve_dir_ = f'{args.early_stopping_best_model_prefix}.ckpt.h5'
    params.scenario_params.trainer_params_filename_ = f'{args.early_stopping_best_model_prefix}.ckpt.json'

    # =================================================================================================================
    # Model params
    params_from_definition_string(args.network, params)
    params.scenario_params.model_params.ensemble = args.ensemble
    params.scenario_params.model_params.masking_mode = args.masking_mode

    scenario = Scenario(params.scenario_params)
    trainer = scenario.create_trainer(params)
    trainer.train()
def setup_train_args(parser, omit=None):
    # Required params for args
    from calamari_ocr.ocr.dataset.data import Data

    if omit is None:
        omit = []

    parser.add_argument('--version', action='version', version='%(prog)s v' + __version__)

    if "files" not in omit:
        parser.add_argument(
            "--files", nargs="+", default=[],
            help="List all image files that shall be processed. Ground truth files with the same "
                 "base name but with '.gt.txt' as extension are required at the same location")
        parser.add_argument(
            "--text_files", nargs="+", default=None,
            help="Optional list of GT files if they are in another directory")
        parser.add_argument(
            "--gt_extension", default=None,
            help="Default extension of the gt files (expected to exist in the same dir)")
        parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType),
                            default=DataSetType.FILE)

    parser.add_argument(
        "--train_data_on_the_fly", action='store_true', default=False,
        help='Instead of preloading all data during the training, load the data on the fly. '
             'This is slower, but might be required for limited RAM or large datasets')
    parser.add_argument(
        "--seed", type=int, default=0,
        help="Seed for random operations. If negative or zero, a 'random' seed is used")
    parser.add_argument(
        "--network", type=str,
        default="cnn=40:3x3,pool=2x2,cnn=60:3x3,pool=2x2,lstm=200,dropout=0.5",
        help="The network structure")
    parser.add_argument("--line_height", type=int, default=48, help="The line height")
    parser.add_argument("--pad", type=int, default=16,
                        help="Padding (left and right) of the line")
    parser.add_argument("--num_threads", type=int, default=1,
                        help="The number of threads to use for all operations")
    parser.add_argument(
        "--display", type=int, default=1,
        help="Frequency of how often an output shall occur during training [epochs]")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="The batch size to use for training")
    parser.add_argument("--ema_weights", action="store_true", default=False,
                        help="Use exponentially averaged weights")
    parser.add_argument(
        "--checkpoint_frequency", type=int, default=-1,
        help="How often to write checkpoints during training [epochs]. "
             "If -1 (default), the early_stopping_frequency will be used. "
             "If 0, no checkpoints are written")
    parser.add_argument(
        "--epochs", type=int, default=100,
        help="The number of iterations for training. "
             "If using early stopping, this is the maximum number of iterations")
    parser.add_argument(
        "--samples_per_epoch", type=float, default=-1,
        help="The number of samples to process per epoch. By default the size of the training "
             "dataset. If in (0, 1), it is relative to the dataset size")
    parser.add_argument(
        "--early_stopping_at_accuracy", type=float, default=1.0,
        help="Stop training if the early stopping accuracy reaches this value")
    parser.add_argument(
        "--no_skip_invalid_gt", action="store_true",
        help="Do not skip invalid gt; instead, raise an exception.")
    parser.add_argument("--no_progress_bars", action="store_true",
                        help="Do not show any progress bars")

    if "output_dir" not in omit:
        parser.add_argument(
            "--output_dir", type=str, default="",
            help="Default directory where to store checkpoints and models")
    if "output_model_prefix" not in omit:
        parser.add_argument("--output_model_prefix", type=str, default="model_",
                            help="Prefix for storing checkpoints and models")

    parser.add_argument(
        "--bidi_dir", type=str, default=None, choices=["ltr", "rtl", "auto"],
        help="The default text direction when preprocessing bidirectional text. Supported values "
             "are 'auto' to automatically detect the direction, and 'ltr' and 'rtl' for "
             "left-to-right and right-to-left, respectively")

    if "weights" not in omit:
        parser.add_argument("--weights", type=str, default=None,
                            help="Load network weights from the given file.")

    parser.add_argument(
        "--no_auto_compute_codec", action='store_true', default=False,
        help="Do not compute the codec automatically. See also --whitelist")
    parser.add_argument(
        "--whitelist_files", type=str, nargs="+", default=[],
        help="Whitelist of txt files that may not be removed on restoring a model")
    parser.add_argument(
        "--whitelist", type=str, nargs="+", default=[],
        help="Whitelist of characters that may not be removed on restoring a model. "
             "For large datasets you can use this to skip the automatic codec computation "
             "(see --no_auto_compute_codec)")
    parser.add_argument(
        "--keep_loaded_codec", action='store_true', default=False,
        help="Fully include the codec of the loaded model in the new codec")
    parser.add_argument("--ensemble", type=int, default=-1, help="Number of voting models")
    parser.add_argument("--masking_mode", type=int, default=0, help="Do not use")  # TODO: remove

    # Clipping
    parser.add_argument("--gradient_clipping_norm", type=float, default=5,
                        help="Clipping constant of the norm of the gradients.")

    # Early stopping
    if "validation" not in omit:
        parser.add_argument("--validation", type=str, nargs="+",
                            help="Validation line files used for early stopping")
        parser.add_argument(
            "--validation_text_files", nargs="+", default=None,
            help="Optional list of validation GT files if they are in another directory")
        parser.add_argument(
            "--validation_extension", default=None,
            help="Default extension of the gt files (expected to exist in the same dir). "
                 "By default the same as gt_extension")
        parser.add_argument(
            "--validation_dataset", type=DataSetType.from_string, choices=list(DataSetType),
            default=None,
            help="Default validation data set type. By default the same as --dataset")
        parser.add_argument("--use_train_as_val", action='store_true', default=False)
        parser.add_argument(
            "--validation_split_ratio", type=float, default=None,
            help="Use n percent of the training dataset for validation. "
                 "Cannot be used with --validation")

    parser.add_argument(
        "--validation_data_on_the_fly", action='store_true', default=False,
        help='Instead of preloading all data during the training, load the data on the fly. '
             'This is slower, but might be required for limited RAM or large datasets')
    parser.add_argument("--early_stopping_frequency", type=int, default=1,
                        help="The frequency of early stopping [epochs].")
    parser.add_argument(
        "--early_stopping_nbest", type=int, default=5,
        help="The number of models that must be worse than the current best model to stop")

    if "early_stopping_best_model_prefix" not in omit:
        parser.add_argument("--early_stopping_best_model_prefix", type=str, default="best",
                            help="The prefix of the best model using early stopping")
    if "early_stopping_best_model_output_dir" not in omit:
        parser.add_argument(
            "--early_stopping_best_model_output_dir", type=str, default=None,
            help="Path where to store the best model. Default is output_dir")

    parser.add_argument(
        "--n_augmentations", type=float, default=0,
        help="Amount of data augmentation per line (done before training). "
             "If this number is < 1, the amount is relative.")
    parser.add_argument(
        "--only_train_on_augmented", action="store_true", default=False,
        help="When training with augmentations, the model is usually retrained in a second run on "
             "only the non-augmented data. This will take longer. Use this flag to disable this "
             "behavior.")

    # Text normalization/regularization
    parser.add_argument("--text_regularization", type=str, nargs="+", default=["extended"],
                        help="Text regularization to apply.")
    parser.add_argument("--text_normalization", type=str, default="NFC",
                        help="Unicode text normalization to apply. Defaults to NFC")
    parser.add_argument(
        "--data_preprocessing", nargs="+", type=str,
        choices=[k for k, p in Data.data_processor_factory().processors.items()
                 if issubclass(p, ImageProcessor)],
        default=[p.name for p in default_image_processors()])

    # Text/line generation params (loaded from json files)
    parser.add_argument("--text_generator_params", type=str, default=None)
    parser.add_argument("--line_generator_params", type=str, default=None)

    # Additional dataset args
    parser.add_argument("--dataset_pad", default=None, nargs='+', type=int)
    parser.add_argument("--pagexml_text_index", default=0)

    parser.add_argument("--debug", action='store_true')
    parser.add_argument("--train_verbose", default=1)
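
# A minimal sketch of how setup_train_args() and run() are typically wired
# together in a console entry point. The function name and the invocation below
# are illustrative, not the actual installed entry points.
import argparse

def cli_main():
    parser = argparse.ArgumentParser()
    setup_train_args(parser)
    run(parser.parse_args())

# Hypothetical invocation:
#   python train.py --files "lines/*.png" --epochs 10 --output_dir models/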
def main(args=None):
    parser = PAIArgumentParser()
    parser.add_argument('--version', action='version', version='%(prog)s v' + __version__)
    parser.add_argument("--n_cols", type=int, default=1)
    parser.add_argument("--n_rows", type=int, default=5)
    parser.add_argument("--select", type=int, nargs="+", default=[])
    parser.add_argument("--preload", action='store_true', help='Simulate preloading')
    parser.add_argument("--as_validation", action='store_true',
                        help="Access as validation instead of training data.")
    parser.add_argument("--n_augmentations", type=float, default=0)
    parser.add_argument("--no_plot", action='store_true', help='This parameter is for testing only')
    parser.add_root_argument("data", DataWrapper)

    args = parser.parse_args(args=args)

    data_wrapper: DataWrapper = args.data
    data_params = data_wrapper.data
    data_params.pre_proc.run_parallel = False
    data_params.pre_proc.erase_all(PrepareSampleProcessorParams)
    for p in data_params.pre_proc.processors_of_type(AugmentationProcessorParams):
        p.n_augmentations = args.n_augmentations

    data_params.__post_init__()
    data_wrapper.pipeline.mode = PipelineMode.EVALUATION if args.as_validation else PipelineMode.TRAINING
    data_wrapper.gen.prepare_for_mode(data_wrapper.pipeline.mode)
    data = Data(data_params)

    if len(args.select) == 0:
        args.select = list(range(len(data_wrapper.gen)))
    else:
        try:
            data_wrapper.gen.select(args.select)
        except NotImplementedError:
            logger.warning(f"Selecting is not supported for a data generator of type "
                           f"{type(data_wrapper.gen)}. Resuming without selection.")

    data_pipeline = data.create_pipeline(data_wrapper.pipeline, data_wrapper.gen)
    if args.preload:
        data_pipeline = data_pipeline.as_preloaded()

    if args.no_plot:
        with data_pipeline as dataset:
            list(zip(args.select, dataset.generate_input_samples(auto_repeat=False)))
        return

    import matplotlib.pyplot as plt

    f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
    row, col = 0, 0
    with data_pipeline as dataset:
        for i, (id, sample) in enumerate(zip(args.select, dataset.generate_input_samples(auto_repeat=False))):
            line, text, params = sample.inputs, sample.targets, sample.meta
            if args.n_cols == 1:
                ax[row].imshow(line.transpose())
                ax[row].set_title("ID: {}\n{}".format(id, text))
            else:
                ax[row, col].imshow(line.transpose())
                ax[row, col].set_title("ID: {}\n{}".format(id, text))

            row += 1
            if row == args.n_rows:
                row = 0
                col += 1

            if col == args.n_cols or i == len(dataset) - 1:
                plt.show()
                f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
                row, col = 0, 0

    plt.show()
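
# Smoke-test note: with --no_plot the viewer only iterates the pipeline and
# returns, which is handy for automated tests. The data flags depend on the
# fields that DataWrapper registers via add_root_argument("data", ...), so the
# invocation below is a hypothetical placeholder rather than a working command:
#
#   main(["--n_rows", "3", "--no_plot"])  # plus the required "data" arguments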
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--version', action='version', version='%(prog)s v' + __version__)
    parser.add_argument(
        "--files", nargs="+", required=True,
        help="List all image files that shall be processed. Ground truth files with the same "
             "base name but with '.gt.txt' as extension are required at the same location")
    parser.add_argument(
        "--text_files", nargs="+", default=None,
        help="Optional list of GT files if they are in another directory")
    parser.add_argument(
        "--gt_extension", default=None,
        help="Default extension of the gt files (expected to exist in the same dir)")
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--line_height", type=int, default=48, help="The line height")
    parser.add_argument("--pad", type=int, default=16,
                        help="Padding (left and right) of the line")
    parser.add_argument("--processes", type=int, default=1,
                        help="The number of threads to use for all operations")
    parser.add_argument("--n_cols", type=int, default=1)
    parser.add_argument("--n_rows", type=int, default=5)
    parser.add_argument("--select", type=int, nargs="+", default=[])

    # Text normalization/regularization
    parser.add_argument(
        "--n_augmentations", type=float, default=0,
        help="Amount of data augmentation per line (done before training). "
             "If this number is < 1, the amount is relative.")
    parser.add_argument("--text_regularization", type=str, nargs="+", default=["extended"],
                        help="Text regularization to apply.")
    parser.add_argument("--text_normalization", type=str, default="NFC",
                        help="Unicode text normalization to apply. Defaults to NFC")
    parser.add_argument(
        "--data_preprocessing", nargs="+", type=str,
        choices=[k for k, p in Data.data_processor_factory().processors.items()
                 if issubclass(p, ImageProcessor)],
        default=[p.name for p in default_image_processors()])
    parser.add_argument(
        "--bidi_dir", type=str, default=None, choices=["ltr", "rtl", "auto"],
        help="The default text direction when preprocessing bidirectional text. Supported values "
             "are 'auto' to automatically detect the direction, and 'ltr' and 'rtl' for "
             "left-to-right and right-to-left, respectively")
    parser.add_argument("--preload", action='store_true', help='Simulate preloading')
    parser.add_argument("--as_validation", action='store_true',
                        help="Access as validation instead of training data.")

    args = parser.parse_args()

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    dataset_args = FileDataReaderArgs(pad=args.pad)

    data_params: DataParams = Data.get_default_params()
    data_params.train = PipelineParams(
        type=args.dataset,
        remove_invalid=True,
        files=args.files,
        text_files=args.text_files,
        gt_extension=args.gt_extension,
        data_reader_args=dataset_args,
        batch_size=1,
        num_processes=args.processes,
    )
    data_params.val = data_params.train

    data_params.pre_processors_ = SamplePipelineParams(run_parallel=True)
    data_params.post_processors_ = SamplePipelineParams(run_parallel=True)
    for p in args.data_preprocessing:
        p_p = Data.data_processor_factory().processors[p].default_params()
        if 'pad' in p_p:
            p_p['pad'] = args.pad
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(p, INPUT_PROCESSOR, p_p))

    # Text pre processing (reading)
    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(TextNormalizer.__name__, TARGETS_PROCESSOR,
                                   {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(TextRegularizer.__name__, TARGETS_PROCESSOR,
                                   {'replacements': default_text_regularizer_replacements(args.text_regularization)}),
        DataProcessorFactoryParams(StripTextProcessor.__name__, TARGETS_PROCESSOR),
    ])

    # Text post processing (prediction)
    data_params.post_processors_.sample_processors.extend([
        DataProcessorFactoryParams(TextNormalizer.__name__, TARGETS_PROCESSOR,
                                   {'unicode_normalization': args.text_normalization}),
        DataProcessorFactoryParams(TextRegularizer.__name__, TARGETS_PROCESSOR,
                                   {'replacements': default_text_regularizer_replacements(args.text_regularization)}),
        DataProcessorFactoryParams(StripTextProcessor.__name__, TARGETS_PROCESSOR),
    ])

    if args.bidi_dir:
        data_params.pre_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__, TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))
        data_params.post_processors_.sample_processors.append(
            DataProcessorFactoryParams(BidiTextProcessor.__name__, TARGETS_PROCESSOR,
                                       {'bidi_direction': args.bidi_dir}))

    data_params.pre_processors_.sample_processors.extend([
        DataProcessorFactoryParams(AugmentationProcessor.__name__, {PipelineMode.Training},
                                   {'augmenter_type': 'simple'}),
        # DataProcessorFactoryParams(PrepareSampleProcessor.__name__),  # NOT THIS, since we want to access the raw input
    ])

    data_params.data_aug_params = DataAugmentationAmount.from_factor(args.n_augmentations)
    data_params.line_height_ = args.line_height

    data = Data(data_params)
    data_pipeline = data.get_val_data() if args.as_validation else data.get_train_data()
    if not args.preload:
        reader: DataReader = data_pipeline.reader()
        if len(args.select) == 0:
            args.select = range(len(reader))
        else:
            reader._samples = [reader.samples()[i] for i in args.select]
    else:
        data.preload()
        data_pipeline = data.get_val_data() if args.as_validation else data.get_train_data()
        samples = data_pipeline.samples
        if len(args.select) == 0:
            args.select = range(len(samples))
        else:
            data_pipeline.samples = [samples[i] for i in args.select]

    f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
    row, col = 0, 0
    with data_pipeline as dataset:
        for i, (id, sample) in enumerate(zip(args.select, dataset.generate_input_samples(auto_repeat=False))):
            line, text, params = sample
            if args.n_cols == 1:
                ax[row].imshow(line.transpose())
                ax[row].set_title("ID: {}\n{}".format(id, text))
            else:
                ax[row, col].imshow(line.transpose())
                ax[row, col].set_title("ID: {}\n{}".format(id, text))

            row += 1
            if row == args.n_rows:
                row = 0
                col += 1

            if col == args.n_cols or i == len(dataset) - 1:
                plt.show()
                f, ax = plt.subplots(args.n_rows, args.n_cols, sharey='all')
                row, col = 0, 0
def main():
    # Local imports (imports that require tensorflow)
    from calamari_ocr.ocr.scenario import Scenario
    from calamari_ocr.ocr.dataset.data import Data
    from calamari_ocr.ocr.evaluator import Evaluator

    parser = ArgumentParser()
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--gt", nargs="+", required=True,
                        help="Ground truth files (.gt.txt extension). "
                             "Optionally, you can pass a single json file defining all parameters.")
    parser.add_argument("--pred", nargs="+", default=None,
                        help="Prediction files if provided. Else files with .pred.txt are expected "
                             "at the same location as the gt.")
    parser.add_argument("--pred_dataset", type=DataSetType.from_string, choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--pred_ext", type=str, default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument("--n_confusions", type=int, default=10,
                        help="Only print the n most common confusions. Defaults to 10; use -1 for all.")
    parser.add_argument("--n_worst_lines", type=int, default=0,
                        help="Print the n worst recognized text lines with their errors")
    parser.add_argument("--xlsx_output", type=str,
                        help="Optionally write an xlsx file with the evaluation results")
    parser.add_argument("--num_threads", type=int, default=1,
                        help="Number of threads to use for evaluation")
    parser.add_argument("--non_existing_file_handling_mode", type=str, default="error",
                        help="How to handle non-existing .pred.txt files. Possible modes: skip, empty, error. "
                             "'Skip' will simply skip the evaluation of that file (not counting it as an error). "
                             "'Empty' will handle the file as if it were empty (fully checking for errors). "
                             "'Error' will throw an exception if a file does not exist. This is the default behaviour.")
    parser.add_argument("--skip_empty_gt", action="store_true", default=False,
                        help="Ignore lines of the gt that are empty.")
    parser.add_argument("--no_progress_bars", action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Specify an optional checkpoint to parse the text preprocessor "
                             "(for the gt txt files)")

    # Page xml specific args
    parser.add_argument("--pagexml_gt_text_index", default=0)
    parser.add_argument("--pagexml_pred_text_index", default=1)

    args = parser.parse_args()

    # Check if a json file with all parameters shall be loaded
    if len(args.gt) == 1 and args.gt[0].endswith("json"):
        with open(args.gt[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    logger.info("Resolving files")
    gt_files = sorted(glob_all(args.gt))

    if args.pred:
        pred_files = sorted(glob_all(args.pred))
    else:
        pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files]
        args.pred_dataset = args.dataset

    if args.non_existing_file_handling_mode.lower() == "skip":
        non_existing_pred = [p for p in pred_files if not os.path.exists(p)]
        for f in non_existing_pred:
            idx = pred_files.index(f)
            del pred_files[idx]
            del gt_files[idx]

    data_params = Data.get_default_params()
    if args.checkpoint:
        saved_model = SavedCalamariModel(args.checkpoint, auto_update=True)
        trainer_params = Scenario.trainer_params_from_dict(saved_model.dict)
        data_params = trainer_params.scenario_params.data_params

    data = Data(data_params)

    gt_reader_args = FileDataReaderArgs(text_index=args.pagexml_gt_text_index)
    pred_reader_args = FileDataReaderArgs(text_index=args.pagexml_pred_text_index)
    gt_data_set = PipelineParams(
        type=args.dataset,
        text_files=gt_files,
        data_reader_args=gt_reader_args,
        skip_invalid=args.skip_empty_gt,
    )
    pred_data_set = PipelineParams(
        type=args.pred_dataset,
        text_files=pred_files,
        data_reader_args=pred_reader_args,
    )

    evaluator = Evaluator(data=data)
    evaluator.preload_gt(gt_dataset=gt_data_set)
    r = evaluator.run(gt_dataset=gt_data_set, pred_dataset=pred_data_set,
                      processes=args.num_threads, progress_bar=not args.no_progress_bars)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print("Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
        r["avg_ler"], r["total_char_errs"], r["total_chars"], r["total_sync_errs"]))

    # sort descending
    print_confusions(r, args.n_confusions)
    print_worst_lines(r, data.create_pipeline(PipelineMode.Targets, gt_data_set).reader().samples(),
                      args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output, [{
            "prefix": "evaluation",
            "results": r,
            "gt_files": gt_files,
        }])
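
# A hypothetical command line for this evaluation script (the script name is
# illustrative, not the installed entry point). If --pred is omitted, a
# "<base>.pred.txt" file is expected next to each ground-truth file:
#
#   python eval.py --gt "lines/*.gt.txt" --n_confusions 10 --n_worst_lines 5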