def test_eval_files_with_different_sources(self): run_predict( predict_args(data=FileDataParams( pred_extension=".ext-pred.txt", images=sorted( glob_all([ os.path.join(this_dir, "data", "uw3_50lines", "test", "*.png") ])), ))) r = run_eval( eval_args( gt_data=FileDataParams(texts=sorted( glob_all([ os.path.join(this_dir, "data", "uw3_50lines", "test", "*.gt.txt") ]))), pred_data=FileDataParams(texts=sorted( glob_all([ os.path.join( this_dir, "data", "uw3_50lines", "test", "*.ext-pred.txt", ) ]))), )) self.assertLess(r["avg_ler"], 0.0009, msg="Current best model yields about 0.09% CER")
def prepare_for_mode(self, mode: PipelineMode): logger.info("Resolving input files") input_image_files = sorted(glob_all(self.images)) if not self.texts: gt_txt_files = [split_all_ext(f)[0] + self.gt_extension for f in input_image_files] else: gt_txt_files = sorted(glob_all(self.texts)) if mode in INPUT_PROCESSOR: input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files) for img, gt in zip(input_image_files, gt_txt_files): if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]: raise Exception(f"Expected identical basenames of file: {img} and {gt}") else: input_image_files = None if mode in {PipelineMode.TRAINING, PipelineMode.EVALUATION}: if len(set(gt_txt_files)) != len(gt_txt_files): logger.warning( "Some ground truth text files occur more than once in the data set " "(ignore this warning, if this was intended)." ) if len(set(input_image_files)) != len(input_image_files): logger.warning( "Some images occur more than once in the data set. " "This warning should usually not be ignored." ) self.images = input_image_files self.texts = gt_txt_files
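The snippets below repeatedly pair image files with ground-truth text files via `keep_files_with_same_file_name`. As a point of reference, here is a minimal standalone sketch of that pairing under the assumption that files are matched by their extension-free basename; the helper below is hypothetical and only illustrates the idea, it is not the library implementation:

import os

def keep_files_with_same_file_name_sketch(images, texts):
    # Hypothetical re-implementation for illustration: pair files by their
    # extension-free basename and drop entries without a counterpart.
    def base(path):
        return os.path.basename(path).partition(".")[0]
    texts_by_base = {base(t): t for t in texts}
    kept = [(img, texts_by_base[base(img)]) for img in images if base(img) in texts_by_base]
    return [img for img, _ in kept], [txt for _, txt in kept]

imgs, txts = keep_files_with_same_file_name_sketch(
    ["lines/0001.png", "lines/0002.png"],
    ["lines/0001.gt.txt"],
)
# imgs == ["lines/0001.png"], txts == ["lines/0001.gt.txt"]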
def create_train_dataset(args, dataset_args=None): gt_extension = args.gt_extension if args.gt_extension is not None else DataSetType.gt_extension(args.dataset) # Training dataset print("Resolving input files") input_image_files = sorted(glob_all(args.files)) if not args.text_files: if gt_extension: gt_txt_files = [split_all_ext(f)[0] + gt_extension for f in input_image_files] else: gt_txt_files = [None] * len(input_image_files) else: gt_txt_files = sorted(glob_all(args.text_files)) input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files) for img, gt in zip(input_image_files, gt_txt_files): if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]: raise Exception("Expected identical basenames of file: {} and {}".format(img, gt)) if len(set(gt_txt_files)) != len(gt_txt_files): raise Exception("Some images occur more than once in the data set.") dataset = create_dataset( args.dataset, DataSetMode.TRAIN, images=input_image_files, texts=gt_txt_files, skip_invalid=not args.no_skip_invalid_gt, args=dataset_args if dataset_args else {}, ) print("Found {} files in the dataset".format(len(dataset))) return dataset
def uw3_trainer_params(with_validation=False, with_split=False, preload=True, debug=False): p = CalamariTestScenario.default_trainer_params() p.scenario.debug_graph_construction = debug p.force_eager = debug train = FileDataParams( images=glob_all( [os.path.join(this_dir, "data", "uw3_50lines", "train", "*.png")]), preload=preload, ) if with_split: p.gen = CalamariSplitTrainerPipelineParams(validation_split_ratio=0.2, train=train) elif with_validation: p.gen.val.images = glob_all( [os.path.join(this_dir, "data", "uw3_50lines", "test", "*.png")]) p.gen.val.preload = preload p.gen.train = train p.gen.__post_init__() else: p.gen = CalamariTrainOnlyPipelineParams(train=train) p.gen.setup.val.batch_size = 1 p.gen.setup.val.num_processes = 1 p.gen.setup.train.batch_size = 1 p.gen.setup.train.num_processes = 1 post_init(p) return p
def prepare_for_mode(self, mode: PipelineMode) -> 'PipelineParams': from calamari_ocr.ocr.dataset.datareader.factory import DataReaderFactory assert (self.type is not None) params_out = deepcopy(self) # Training dataset logger.info("Resolving input files") if isinstance(self.type, str): try: self.type = DataSetType.from_string(self.type) except ValueError: # Not a valid type, must be custom if self.type not in DataReaderFactory.CUSTOM_READERS: raise KeyError( f"DataSetType {self.type} is neither a standard DataSetType or preset as custom " f"reader ({list(DataReaderFactory.CUSTOM_READERS.keys())})" ) if not isinstance(self.type, str) and self.type not in { DataSetType.RAW, DataSetType.GENERATED_LINE }: input_image_files = sorted(glob_all( self.files)) if self.files else None if not self.text_files: if self.gt_extension: gt_txt_files = [ split_all_ext(f)[0] + self.gt_extension for f in input_image_files ] else: gt_txt_files = None else: gt_txt_files = sorted(glob_all(self.text_files)) if mode in INPUT_PROCESSOR: input_image_files, gt_txt_files = keep_files_with_same_file_name( input_image_files, gt_txt_files) for img, gt in zip(input_image_files, gt_txt_files): if split_all_ext( os.path.basename(img))[0] != split_all_ext( os.path.basename(gt))[0]: raise Exception( "Expected identical basenames of file: {} and {}" .format(img, gt)) else: input_image_files = None if mode in {PipelineMode.Training, PipelineMode.Evaluation}: if len(set(gt_txt_files)) != len(gt_txt_files): logger.warning( "Some ground truth text files occur more than once in the data set " "(ignore this warning, if this was intended).") if len(set(input_image_files)) != len(input_image_files): logger.warning( "Some images occur more than once in the data set. " "This warning should usually not be ignored.") params_out.files = input_image_files params_out.text_files = gt_txt_files return params_out
def data_reader_from_params(mode: PipelineMode, params: PipelineParams) -> DataReader: assert (params.type is not None) from calamari_ocr.ocr.dataset.dataset_factory import create_data_reader # Training dataset logger.info("Resolving input files") if params.type not in {DataSetType.RAW, DataSetType.GENERATED_LINE}: input_image_files = sorted(glob_all( params.files)) if params.files else None if not params.text_files: if params.gt_extension: gt_txt_files = [ split_all_ext(f)[0] + params.gt_extension for f in input_image_files ] else: gt_txt_files = None else: gt_txt_files = sorted(glob_all(params.text_files)) if mode in INPUT_PROCESSOR: input_image_files, gt_txt_files = keep_files_with_same_file_name( input_image_files, gt_txt_files) for img, gt in zip(input_image_files, gt_txt_files): if split_all_ext( os.path.basename(img))[0] != split_all_ext( os.path.basename(gt))[0]: raise Exception( "Expected identical basenames of file: {} and {}". format(img, gt)) else: input_image_files = None if mode in {PipelineMode.Training, PipelineMode.Evaluation}: if len(set(gt_txt_files)) != len(gt_txt_files): logger.warning( "Some ground truth text files occur more than once in the data set " "(ignore this warning, if this was intended).") if len(set(input_image_files)) != len(input_image_files): logger.warning( "Some images occur more than once in the data set. " "This warning should usually not be ignored.") else: input_image_files = params.files gt_txt_files = params.text_files dataset = create_data_reader( params.type, mode, images=input_image_files, texts=gt_txt_files, skip_invalid=params.skip_invalid, args=params.data_reader_args if params.data_reader_args else FileDataReaderArgs(), ) logger.info(f"Found {len(dataset)} files in the dataset") return dataset
def create_test_dataset( cfg: CfgNode, dataset_args=None ) -> Union[List[Union[RawDataSet, FileDataSet, AbbyyDataSet, PageXMLDataset, Hdf5DataSet, ExtendedPredictionDataSet, GeneratedLineDataset]], None]: if cfg.DATASET.VALID.TEXT_FILES: assert len(cfg.DATASET.VALID.PATH) == len(cfg.DATASET.VALID.TEXT_FILES) if cfg.DATASET.VALID.PATH: validation_dataset_list = [] print("Resolving validation files") for i, valid_path in enumerate(cfg.DATASET.VALID.PATH): validation_image_files = glob_all(valid_path) dataregistry.register( i, os.path.basename(os.path.dirname(valid_path)), len(validation_image_files)) if not cfg.DATASET.VALID.TEXT_FILES: val_txt_files = [ split_all_ext(f)[0] + cfg.DATASET.VALID.GT_EXTENSION for f in validation_image_files ] else: val_txt_files = sorted( glob_all(cfg.DATASET.VALID.TEXT_FILES[i])) validation_image_files, val_txt_files = keep_files_with_same_file_name( validation_image_files, val_txt_files) for img, gt in zip(validation_image_files, val_txt_files): if split_all_ext( os.path.basename(img))[0] != split_all_ext( os.path.basename(gt))[0]: raise Exception( "Expected identical basenames of validation file: {} and {}" .format(img, gt)) if len(set(val_txt_files)) != len(val_txt_files): raise Exception( "Some validation images are occurring more than once in the data set." ) validation_dataset = create_dataset( cfg.DATASET.VALID.TYPE, DataSetMode.TRAIN, images=validation_image_files, texts=val_txt_files, skip_invalid=not cfg.DATALOADER.NO_SKIP_INVALID_GT, args=dataset_args, ) print("Found {} files in the validation dataset".format( len(validation_dataset))) validation_dataset_list.append(validation_dataset) else: validation_dataset_list = None return validation_dataset_list
def prepare_for_mode(self, mode: PipelineMode): self.images = sorted(glob_all(self.images)) self.xml_files = sorted(glob_all(self.xml_files)) if not self.xml_files: self.xml_files = [ split_all_ext(f)[0] + self.gt_extension for f in self.images ] if not self.images: self.images = [None] * len(self.xml_files)
def main(): parser = ArgumentParser() parser.add_argument("--checkpoint", type=str, required=True, help="The checkpoint used to resume") parser.add_argument("--validation", type=str, nargs="+", help="Validation line files used for early stopping") parser.add_argument("files", type=str, nargs="+", help="The files to use for training") args = parser.parse_args() # Train dataset input_image_files = glob_all(args.files) gt_txt_files = [split_all_ext(f)[0] + ".gt.txt" for f in input_image_files] if len(set(gt_txt_files)) != len(gt_txt_files): raise Exception("Some images occur more than once in the data set.") dataset = FileDataSet(input_image_files, gt_txt_files) print("Found {} files in the dataset".format(len(dataset))) # Validation dataset if args.validation: validation_image_files = glob_all(args.validation) val_txt_files = [split_all_ext(f)[0] + ".gt.txt" for f in validation_image_files] if len(set(val_txt_files)) != len(val_txt_files): raise Exception("Some validation images occur more than once in the data set.") validation_dataset = FileDataSet(validation_image_files, val_txt_files) print("Found {} files in the validation dataset".format(len(validation_dataset))) else: validation_dataset = None with open(args.checkpoint + '.json', 'r') as f: checkpoint_params = json_format.Parse(f.read(), CheckpointParams()) trainer = Trainer(checkpoint_params, dataset, validation_dataset=validation_dataset, restore=args.checkpoint) trainer.train(progress_bar=True)
def test_eval_list_files(self): run_predict( predict_args(data=FileDataParams(images=sorted( glob_all([ os.path.join(this_dir, "data", "uw3_50lines", "test.files") ]))))) r = run_eval( eval_args(gt_data=FileDataParams(texts=sorted( glob_all([ os.path.join(this_dir, "data", "uw3_50lines", "test.gt.files") ]))))) self.assertLess(r["avg_ler"], 0.0009, msg="Current best model yields about 0.09% CER")
def to_prediction(self): self.files = sorted(glob_all(self.files)) pred = deepcopy(self) pred.files = [ split_all_ext(f)[0] + self.pred_extension for f in self.files ] return pred
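As a quick illustration of the extension swap done by `to_prediction`, the following self-contained sketch mimics the behaviour; `split_all_ext` is assumed to strip every trailing extension (consistent with how it is used throughout these snippets), and the stand-in below is hypothetical:

import os

def split_all_ext_sketch(path):
    # Assumed behaviour: "dir/line01.gt.txt" -> ("dir/line01", ".gt.txt")
    dirname, basename = os.path.split(path)
    name, dot, ext = basename.partition(".")
    return os.path.join(dirname, name), (dot + ext) if ext else ""

files = ["data/line01.png", "data/line02.png"]
pred_files = [split_all_ext_sketch(f)[0] + ".pred.txt" for f in files]
# pred_files == ["data/line01.pred.txt", "data/line02.pred.txt"]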
def main(): parser = argparse.ArgumentParser() parser.add_argument("--files", type=str, default=[], nargs="+", required=True, help="Protobuf files to convert") parser.add_argument("--logits", action="store_true", help="Do write logits") args = parser.parse_args() files = glob_all(args.files) for file in tqdm(files, desc="Converting"): predictions = Predictions() with open(file, 'rb') as f: predictions.ParseFromString(f.read()) if not args.logits: for prediction in predictions.predictions: prediction.logits.rows = 0 prediction.logits.cols = 0 prediction.logits.data[:] = [] out_json_path = split_all_ext(file)[0] + ".json" with open(out_json_path, 'w') as f: f.write( MessageToJson(predictions, including_default_value_fields=True))
def test_prediction_extended_and_positions(self): # With actual model to evaluate correct positions args = predict_args() args.checkpoint = [ os.path.join(this_dir, "models", f"version{SavedCalamariModel.VERSION}", "0.ckpt") ] args.extended_prediction_data = True run(args) jsons = [ os.path.join(this_dir, "data", "uw3_50lines", "test", "*.json") ] run_compute_avg_pred(ExtendedPredictionDataParams(files=jsons)) def assert_pos_in_interval(p, start, end): self.assertGreaterEqual(p.global_start, start) self.assertGreaterEqual(p.global_end, start) self.assertLessEqual(p.global_start, end) self.assertLessEqual(p.global_end, end) with open(sorted(glob_all(jsons[0]))[0]) as f: first_pred: Predictions = Predictions.from_json(f.read()) for p in first_pred.predictions: # Check for correct prediction string (models is trained!) self.assertEqual( p.sentence, "The problem, simplified for our purposes, is set up as") # Check for correct character positions assert_pos_in_interval(p.positions[0], 0, 24) # T assert_pos_in_interval(p.positions[1], 24, 43) # h assert_pos_in_interval(p.positions[2], 45, 63) # e # ... assert_pos_in_interval(p.positions[-2], 1062, 1081) # a assert_pos_in_interval(p.positions[-1], 1084, 1099) # s
def __init__(self): self.files = glob_all([os.path.join(this_dir, "data", "uw3_50lines", "train", "*.png")]) self.seed = 24 self.backend = "tensorflow" self.network = "cnn=40:3x3,pool=2x2,cnn=60:3x3,pool=2x2,lstm=200,dropout=0.5" self.line_height = 48 self.pad = 16 self.num_threads = 1 self.display = 1 self.batch_size = 1 self.checkpoint_frequency = 1000 self.max_iters = 1000 self.stats_size = 100 self.no_skip_invalid_gt = False self.no_progress_bars = True self.output_dir = os.path.join(this_dir, "test_models") self.output_model_prefix = "uw3_50lines" self.bidi_dir = None self.weights = None self.whitelist_files = [] self.whitelist = [] self.gradient_clipping_mode = "AUTO" self.gradient_clipping_const = 0 self.validation = None self.early_stopping_frequency = -1 self.early_stopping_nbest = 10 self.early_stopping_best_model_prefix = "uw3_50lines_best" self.early_stopping_best_model_output_dir = self.output_dir self.n_augmentations = 0 self.fuzzy_ctc_library_path = "" self.num_inter_threads = 0 self.num_intra_threads = 0 self.text_regularization = ["extended"] self.text_normalization = "NFC"
def __init__(self, settings: AlgorithmPredictorSettings): super().__init__(settings) # ctc_decoder_params = deepcopy(settings.params.ctcDecoder.params) # lnp = LyricsNormalizationProcessor(LyricsNormalizationParams(LyricsNormalization.ONE_STRING)) # if len(ctc_decoder_params.dictionary) > 0: # ctc_decoder_params.dictionary[:] = [lnp.apply(word) for word in ctc_decoder_params.dictionary] # else: # with open(os.path.join(BASE_DIR, 'internal_storage', 'resources', 'hyphen_dictionary.txt')) as f: # # TODO: dataset params in settings, that we can create the correct normalization params # ctc_decoder_params.dictionary[:] = [lnp.apply(line.split()[0]) for line in f.readlines()] # self.predictor = MultiPredictor(glob_all([s + '/text_best*.ckpt.json' for s in params.checkpoints])) voter_params = VoterParams() voter_params.type = VoterParams.type.ConfidenceVoterDefaultCTC self.predictor = MultiPredictor.from_paths( checkpoints=glob_all([settings.model.local_file('text.ckpt.json') ]), voter_params=voter_params, predictor_params=PredictorParams( silent=True, progress_bar=True, pipeline=DataPipelineParams(batch_size=1, mode=PipelineMode("prediction")))) # self.height = self.predictor.predictors[0].network_params.features self.voter = voter_from_params(voter_params) self.dict_corrector = None if settings.params.useDictionaryCorrection: self.dict_corrector = DictionaryCorrector()
def __init__(self, n_folds, source_files, output_dir): """ Prepare cross fold training This class creates folds out of the given source files. The individual splits are then optionally written to the `output_dir` in a json format. The file with index i will be assigned to fold i % n_folds (not randomly!) Parameters ---------- n_folds : int the number of folds to create source_files : str the source file names output_dir : str where to store the folds """ self.n_folds = n_folds self.inputs = glob_all(source_files) self.output_dir = os.path.abspath(output_dir) if len(self.inputs) == 0: raise Exception("No files found at '{}'".format(source_files)) if self.n_folds <= 1: raise Exception("At least two folds are required") # fill single fold files self.folds = [[] for _ in range(self.n_folds)] for i, input in enumerate(self.inputs): self.folds[i % n_folds].append(input)
def main(): parser = argparse.ArgumentParser() parser.add_argument("--files", nargs="+", required=True, help="The image files to predict with its gt and pred") parser.add_argument("--html_output", type=str, required=True, help="Where to write the html file") parser.add_argument("--open", action="store_true", help="Automatically open the file") args = parser.parse_args() img_files = sorted(glob_all(args.files)) gt_files = [split_all_ext(f)[0] + ".gt.txt" for f in img_files] pred_files = [split_all_ext(f)[0] + ".pred.txt" for f in img_files] with open(args.html_output, 'w') as html: html.write(""" <!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"/> </head> <body> <ul>""") for img, gt, pred in zip(img_files, gt_files, pred_files): html.write("<li><p><img src=\"file://{}\"></p><p>{}</p><p>{}</p>\n".format( img.replace('\\', '/').replace('/', '\\\\'), open(gt).read(), open(pred).read() )) html.write("</ul></body></html>") if args.open: webbrowser.open(args.html_output)
def __init__(self, n_folds, source_files, output_dir): """ Prepare cross fold training This class creates folds out of the given source files. The individual splits are then optionally written to the `output_dir` in a json format. The file with index i will be assigned to fold i % n_folds (not randomly!) Parameters ---------- n_folds : int the number of folds to create source_files : str the source file names output_dir : str where to store the folds """ self.n_folds = n_folds self.inputs = sorted(glob_all(source_files)) self.output_dir = os.path.abspath(output_dir) if len(self.inputs) == 0: raise Exception("No files found at '{}'".format(source_files)) if self.n_folds <= 1: raise Exception("At least two folds are required") # fill single fold files self.folds = [[] for _ in range(self.n_folds)] for i, input in enumerate(self.inputs): self.folds[i % n_folds].append(input)
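The round-robin fold assignment described in the docstrings above (the file with index i goes to fold i % n_folds) can be illustrated with a tiny standalone example:

n_folds = 3
inputs = ["line_{}.png".format(i) for i in range(7)]
folds = [[] for _ in range(n_folds)]
for i, name in enumerate(inputs):
    folds[i % n_folds].append(name)
# folds == [['line_0.png', 'line_3.png', 'line_6.png'],
#           ['line_1.png', 'line_4.png'],
#           ['line_2.png', 'line_5.png']]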
def test_prediction_files_with_different_extension(self): run_predict( predict_args(data=FileDataParams( pred_extension='.ext-pred.txt', images=sorted( glob_all([ os.path.join(this_dir, "data", "uw3_50lines", "test", "*.png") ]))))) run_eval( eval_args(gt_data=FileDataParams( pred_extension='.ext-pred.txt', texts=sorted( glob_all([ os.path.join(this_dir, "data", "uw3_50lines", "test", "*.gt.txt") ])))))
def prepare_for_mode(self, mode: PipelineMode): self.images = sorted(glob_all(self.images)) self.xml_files = sorted(glob_all(self.xml_files)) if not self.xml_files: self.xml_files = [split_all_ext(f)[0] + self.gt_extension for f in self.images] if not self.images: self.images = [None] * len(self.xml_files) if len(self.images) != len(self.xml_files): raise ValueError(f"Different number of image and xml files, {len(self.images)} != {len(self.xml_files)}") for img_path, xml_path in zip(self.images, self.xml_files): if img_path and xml_path: img_bn, xml_bn = split_all_ext(img_path)[0], split_all_ext(xml_path)[0] if img_bn != xml_bn: logger.warning( f"Filenames are not matching, got base names \n image: {img_bn}\n xml: {xml_bn}\n." )
def __init__(self, settings: AlgorithmPredictorSettings): super().__init__(settings) self.predictor = MultiPredictor(glob_all([s + '/omr_best*.ckpt.json' for s in [settings.model.path]]), ctc_decoder_params=settings.params.ctcDecoder.params) self.height = self.predictor.predictors[0].network_params.features voter_params = VoterParams() voter_params.type = VoterParams.CONFIDENCE_VOTER_DEFAULT_CTC self.voter = voter_from_proto(voter_params)
def __init__(self): self.dataset = DataSetType.FILE self.gt_extension = DataSetType.gt_extension(self.dataset) self.files = glob_all( [os.path.join(this_dir, "data", "uw3_50lines", "train", "*.png")]) self.seed = 24 self.backend = "tensorflow" self.network = "cnn=40:3x3,pool=2x2,cnn=60:3x3,pool=2x2,lstm=200,dropout=0.5" self.line_height = 48 self.pad = 16 self.num_threads = 1 self.display = 1 self.batch_size = 1 self.checkpoint_frequency = 1000 self.epochs = 1 self.samples_per_epoch = 8 self.stats_size = 100 self.no_skip_invalid_gt = False self.no_progress_bars = True self.output_dir = None self.output_model_prefix = "uw3_50lines" self.bidi_dir = None self.weights = None self.ema_weights = False self.whitelist_files = [] self.whitelist = [] self.gradient_clipping_norm = 5 self.validation = None self.validation_dataset = DataSetType.FILE self.validation_extension = None self.validation_split_ratio = None self.early_stopping_frequency = -1 self.early_stopping_nbest = 10 self.early_stopping_at_accuracy = 0.99 self.early_stopping_best_model_prefix = "uw3_50lines_best" self.early_stopping_best_model_output_dir = self.output_dir self.n_augmentations = 0 self.num_inter_threads = 0 self.num_intra_threads = 0 self.text_regularization = ["extended"] self.text_normalization = "NFC" self.text_generator_params = None self.line_generator_params = None self.pagexml_text_index = 0 self.text_files = None self.only_train_on_augmented = False self.data_preprocessing = [p.name for p in default_image_processors()] self.shuffle_buffer_size = 1000 self.keep_loaded_codec = False self.train_data_on_the_fly = False self.validation_data_on_the_fly = False self.no_auto_compute_codec = False self.dataset_pad = 0 self.debug = False self.train_verbose = True self.use_train_as_val = False self.ensemble = -1 self.masking_mode = 1
def main(): parser = argparse.ArgumentParser() parser.add_argument("--files", nargs="+", required=True, help="List of all image files with corresponding gt.txt files") parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE) parser.add_argument("--line_height", type=int, default=48, help="The line height") parser.add_argument("--pad", type=int, default=16, help="Padding (left right) of the line") args = parser.parse_args() print("Resolving files") image_files = glob_all(args.files) gt_files = [split_all_ext(p)[0] + ".gt.txt" for p in image_files] ds = create_dataset( args.dataset, DataSetMode.TRAIN, images=image_files, texts=gt_files, non_existing_as_empty=True) print("Loading {} files".format(len(image_files))) ds.load_samples(processes=1, progress_bar=True) images, texts = ds.train_samples(skip_empty=True) statistics = { "n_lines": len(images), "chars": [len(c) for c in texts], "widths": [img.shape[1] / img.shape[0] * args.line_height + 2 * args.pad for img in images if img is not None and img.shape[0] > 0 and img.shape[1] > 0], "total_line_width": 0, "char_counts": {}, } for image, text in zip(images, texts): for c in text: if c in statistics["char_counts"]: statistics["char_counts"][c] += 1 else: statistics["char_counts"][c] = 1 statistics["av_line_width"] = np.average(statistics["widths"]) statistics["max_line_width"] = np.max(statistics["widths"]) statistics["min_line_width"] = np.min(statistics["widths"]) statistics["total_line_width"] = np.sum(statistics["widths"]) statistics["av_chars"] = np.average(statistics["chars"]) statistics["max_chars"] = np.max(statistics["chars"]) statistics["min_chars"] = np.min(statistics["chars"]) statistics["total_chars"] = np.sum(statistics["chars"]) statistics["av_px_per_char"] = statistics["av_line_width"] / statistics["av_chars"] statistics["codec_size"] = len(statistics["char_counts"]) del statistics["chars"] del statistics["widths"] print(statistics)
def main(): parser = argparse.ArgumentParser() parser.add_argument("--files", nargs="+", type=str, required=True, help="The image files to copy") parser.add_argument("--target_dir", type=str, required=True, help="") parser.add_argument("--index_files", action="store_true") parser.add_argument("--convert_images", type=str, help="Convert the image to a given type (by default use original format). E. g. jpg, png, tif, ...") parser.add_argument("--gt_ext", type=str, default=".gt.txt") parser.add_argument("--index_ext", type=str, default=".index") args = parser.parse_args() if args.convert_images and not args.convert_images.startswith("."): args.convert_images = "." + args.convert_images args.target_dir = os.path.expanduser(args.target_dir) print("Resolving files") image_files = glob_all(args.files) gt_files = [split_all_ext(p)[0] + ".gt.txt" for p in image_files] if len(image_files) == 0: raise Exception("No files found") if not os.path.isdir(args.target_dir): os.makedirs(args.target_dir) for i, (img, gt) in tqdm(enumerate(zip(image_files, gt_files)), total=len(gt_files), desc="Copying"): if not os.path.exists(img) or not os.path.exists(gt): # skip non existing examples continue # img with optional convert try: ext = split_all_ext(img)[1] target_ext = args.convert_images if args.convert_images else ext target_name = os.path.join(args.target_dir, "{:08}{}".format(i, target_ext)) if ext == target_ext: shutil.copyfile(img, target_name) else: data = skimage_io.imread(img) skimage_io.imsave(target_name, data) except: continue # gt txt target_name = os.path.join(args.target_dir, "{:08}{}".format(i, args.gt_ext)) shutil.copyfile(gt, target_name) if args.index_files: target_name = os.path.join(args.target_dir, "{:08}{}".format(i, args.index_ext)) with open(target_name, "w") as f: f.write(str(i))
def main(): parser = argparse.ArgumentParser() parser.add_argument('--checkpoints', nargs='+', type=str, required=True) parser.add_argument('--dry_run', action='store_true') args = parser.parse_args() for ckpt in tqdm(glob_all(args.checkpoints)): ckpt = os.path.splitext(ckpt)[0] SavedCalamariModel(ckpt, dry_run=args.dry_run)
def setup_trainer_params(preload=True, debug=False): p = CalamariTestEnsembleScenario.default_trainer_params() p.force_eager = debug p.gen.train = FileDataParams( images=glob_all([os.path.join(this_dir, "data", "uw3_50lines", "train", "*.png")]), preload=preload, ) post_init(p) return p
def __init__(self): self.files = glob_all([os.path.join(this_dir, "data", "uw3_50lines", "test", "*.png")]) self.checkpoint = [os.path.join(this_dir, "test_models", "uw3_50lines_best.ckpt")] self.processes = 1 self.batch_size = 1 self.verbose = True self.voter = "confidence_voter_default_ctc" self.output_dir = None self.extended_prediction_data = None self.extended_prediction_data_format = "json" self.no_progress_bars = True
def test_prediction_files(self): run_predict( predict_args(data=FileDataParams(images=sorted( glob_all([ os.path.join(this_dir, "data", "uw3_50lines", "test", "*.png") ]))))) run_eval( eval_args(gt_data=FileDataParams(texts=sorted( glob_all([ os.path.join(this_dir, "data", "uw3_50lines", "test", "*.gt.txt") ]))))) args = eval_args(gt_data=FileDataParams(texts=sorted( glob_all([ os.path.join(this_dir, "data", "uw3_50lines", "test", "*.gt.txt") ])))) with tempfile.TemporaryDirectory() as d: args.xlsx_output = os.path.join(d, 'output.xlsx') run_eval(args)
def __post_init__(self): # parse whitelist if len(self.include) == 1: include = set(self.include[0]) else: include = set(self.include) for f in glob_all(self.include_files): with open(f) as txt: include = include.union(txt.read()) self.resolved_include_chars = include
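A small sketch of how the character whitelist above resolves to a set of characters: a single `include` entry is interpreted as a string of individual characters, multiple entries are kept as-is, and every file listed in `include_files` contributes all of its characters (the values below are made up for illustration):

include = ["abc"]           # a single entry is split into the characters {"a", "b", "c"}
include_files = []          # e.g. ["whitelist.txt"]; each file adds all characters it contains

if len(include) == 1:
    chars = set(include[0])
else:
    chars = set(include)

for path in include_files:  # mirrors the glob_all(...) loop above
    with open(path) as txt:
        chars = chars.union(txt.read())

print(sorted(chars))        # ['a', 'b', 'c']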
def main(): parser = argparse.ArgumentParser() parser.add_argument("--files", type=str, nargs="+", required=True, help="Image files to apply the data preprocessing to") parser.add_argument("--line_height", type=int, default=48, help="The line height") parser.add_argument("--pad", type=int, default=16, help="Padding (left right) of the line") parser.add_argument("--pad_value", type=int, default=1, help="Padding value") parser.add_argument("--processes", type=int, default=1) parser.add_argument("--verbose", action="store_true") parser.add_argument("--invert", action="store_true") parser.add_argument("--transpose", action="store_true") parser.add_argument("--dry_run", action="store_true", help="Do not overwrite files, just run") args = parser.parse_args() params = DataPreprocessorParams() params.line_height = args.line_height params.pad = args.pad params.pad_value = args.pad_value params.no_invert = not args.invert params.no_transpos = not args.transpose data_proc = MultiDataProcessor([ DataRangeNormalizer(), CenterNormalizer(params), FinalPreparation(params, as_uint8=True), ]) print("Resolving files") img_files = sorted(glob_all(args.files)) handler = Handler(data_proc, args.dry_run) with multiprocessing.Pool(processes=args.processes, maxtasksperchild=100) as pool: list(tqdm(pool.imap(handler.handle_single, img_files), desc="Processing", total=len(img_files)))
def main(): parser = argparse.ArgumentParser(description=usage_str) parser.add_argument('--checkpoints', nargs='+', type=str, required=True) parser.add_argument('--replace_from') parser.add_argument('--replace_to') parser.add_argument('--add_prefix') parser.add_argument('--dry_run', action='store_true') args = parser.parse_args() for ckpt in tqdm(glob_all(args.checkpoints)): ckpt = os.path.splitext(ckpt)[0] rename(ckpt, args.replace_from, args.replace_to, args.add_prefix, args.dry_run)
def create_train_dataset(cfg: CfgNode, dataset_args=None): gt_extension = cfg.DATASET.TRAIN.GT_EXTENSION if cfg.DATASET.TRAIN.GT_EXTENSION is not False else DataSetType.gt_extension(cfg.DATASET.TRAIN.TYPE) # Training dataset print("Resolving input files") input_image_files = sorted(glob_all(cfg.DATASET.TRAIN.PATH)) if not cfg.DATASET.TRAIN.TEXT_FILES: if gt_extension: gt_txt_files = [split_all_ext(f)[0] + gt_extension for f in input_image_files] else: gt_txt_files = [None] * len(input_image_files) else: gt_txt_files = sorted(glob_all(cfg.DATASET.TRAIN.TEXT_FILES)) input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files) for img, gt in zip(input_image_files, gt_txt_files): if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]: raise Exception("Expected identical basenames of file: {} and {}".format(img, gt)) if len(set(gt_txt_files)) != len(gt_txt_files): raise Exception("Some images occur more than once in the data set.") dataset = create_dataset( cfg.DATASET.TRAIN.TYPE, DataSetMode.TRAIN, images=input_image_files, texts=gt_txt_files, skip_invalid=not cfg.DATALOADER.NO_SKIP_INVALID_GT, args=dataset_args if dataset_args else {}, ) print("Found {} files in the dataset".format(len(dataset))) return dataset
def main(): parser = argparse.ArgumentParser() parser.add_argument("--files", type=str, nargs="+", required=True, help="Text files to apply text processing") parser.add_argument("--text_regularization", type=str, nargs="+", default=["extended"], help="Text regularization to apply.") parser.add_argument("--text_normalization", type=str, default="NFC", help="Unicode text normalization to apply. Defaults to NFC") parser.add_argument("--verbose", action="store_true") parser.add_argument("--dry_run", action="store_true", help="Do not overwrite files, just run") args = parser.parse_args() # Text pre processing (reading) preproc = TextProcessorParams() preproc.type = TextProcessorParams.MULTI_NORMALIZER default_text_normalizer_params(preproc.children.add(), default=args.text_normalization) default_text_regularizer_params(preproc.children.add(), groups=args.text_regularization) strip_processor_params = preproc.children.add() strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER txt_proc = text_processor_from_proto(preproc, "pre") print("Resolving files") text_files = glob_all(args.files) for path in tqdm(text_files, desc="Processing", total=len(text_files)): with codecs.open(path, "r", "utf-8") as f: content = f.read() content = txt_proc.apply(content) if args.verbose: print(content) if not args.dry_run: with codecs.open(path, "w", "utf-8") as f: f.write(content)
def main(): parser = ArgumentParser() parser.add_argument("--pred", nargs="+", required=True, help="Extended prediction files (.json extension)") args = parser.parse_args() print("Resolving files") pred_files = sorted(glob_all(args.pred)) data_set = create_dataset( DataSetType.EXTENDED_PREDICTION, DataSetMode.EVAL, texts=pred_files, ) data_set.load_samples(progress_bar=True) print('Average confidence: {:.2%}'.format(np.mean([s['best_prediction'].avg_char_probability for s in data_set.samples()])))
def main(): parser = argparse.ArgumentParser() parser.add_argument("--files", nargs="+", required=True, help="All img files, an appropriate .gt.txt must exist") parser.add_argument("--n_eval", type=float, required=True, help="The (relative or absolute) count of evaluation files (or -1 to use the remaining)") parser.add_argument("--n_train", type=float, required=True, help="The (relative or absolute) count of training files (or -1 to use the remaining)") parser.add_argument("--output_dir", type=str, required=True, help="Where to write the splits") parser.add_argument("--eval_sub_dir", type=str, default="eval") parser.add_argument("--train_sub_dir", type=str, default="train") args = parser.parse_args() img_files = sorted(glob_all(args.files)) if len(img_files) == 0: raise Exception("No files were found") gt_txt_files = [split_all_ext(p)[0] + ".gt.txt" for p in img_files] if args.n_eval < 0: pass elif args.n_eval < 1: args.n_eval = int(args.n_eval * len(img_files)) else: args.n_eval = int(args.n_eval) if args.n_train < 0: pass elif args.n_train < 1: args.n_train = int(args.n_train * len(img_files)) else: args.n_train = int(args.n_train) if args.n_eval < 0 and args.n_train < 0: raise Exception("Only one of n_eval or n_train may be < 0") if args.n_eval < 0: args.n_eval = len(img_files) - args.n_train elif args.n_train < 0: args.n_train = len(img_files) - args.n_eval if args.n_eval + args.n_train > len(img_files): raise Exception("Got {} eval and {} train files = {} in total, but only {} files are in the dataset".format(args.n_eval, args.n_train, args.n_eval + args.n_train, len(img_files))) def copy_files(imgs, txts, out_dir): assert(len(imgs) == len(txts)) if not os.path.exists(out_dir): os.makedirs(out_dir) for img, txt in tqdm(zip(imgs, txts), total=len(imgs), desc="Writing to {}".format(out_dir)): if not os.path.exists(img): print("Image file at {} not found".format(img)) continue if not os.path.exists(txt): print("Ground truth file at {} not found".format(txt)) continue shutil.copyfile(img, os.path.join(out_dir, os.path.basename(img))) shutil.copyfile(txt, os.path.join(out_dir, os.path.basename(txt))) copy_files(img_files[:args.n_eval], gt_txt_files[:args.n_eval], os.path.join(args.output_dir, args.eval_sub_dir)) copy_files(img_files[args.n_eval:], gt_txt_files[args.n_eval:], os.path.join(args.output_dir, args.train_sub_dir))
def main(): parser = ArgumentParser() parser.add_argument("--checkpoint", type=str, required=True, help="The checkpoint used to resume") # validation files parser.add_argument("--validation", type=str, nargs="+", help="Validation line files used for early stopping") parser.add_argument("--validation_text_files", nargs="+", default=None, help="Optional list of validation GT files if they are in another directory") parser.add_argument("--validation_extension", default=None, help="Default extension of the gt files (expected to exist in same dir)") parser.add_argument("--validation_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE) # input files parser.add_argument("--files", nargs="+", help="List all image files that shall be processed. Ground truth files with the same " "base name but with '.gt.txt' as extension are required at the same location") parser.add_argument("--text_files", nargs="+", default=None, help="Optional list of GT files if they are in another directory") parser.add_argument("--gt_extension", default=None, help="Default extension of the gt files (expected to exist in same dir)") parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE) parser.add_argument("--no_skip_invalid_gt", action="store_true", help="Do not skip invalid gt, instead raise an exception.") args = parser.parse_args() if args.gt_extension is None: args.gt_extension = DataSetType.gt_extension(args.dataset) if args.validation_extension is None: args.validation_extension = DataSetType.gt_extension(args.validation_dataset) # Training dataset print("Resolving input files") input_image_files = sorted(glob_all(args.files)) if not args.text_files: gt_txt_files = [split_all_ext(f)[0] + args.gt_extension for f in input_image_files] else: gt_txt_files = sorted(glob_all(args.text_files)) input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files) for img, gt in zip(input_image_files, gt_txt_files): if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]: raise Exception("Expected identical basenames of file: {} and {}".format(img, gt)) if len(set(gt_txt_files)) != len(gt_txt_files): raise Exception("Some images occur more than once in the data set.") dataset = create_dataset( args.dataset, DataSetMode.TRAIN, images=input_image_files, texts=gt_txt_files, skip_invalid=not args.no_skip_invalid_gt ) print("Found {} files in the dataset".format(len(dataset))) # Validation dataset if args.validation: print("Resolving validation files") validation_image_files = glob_all(args.validation) if not args.validation_text_files: val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files] else: val_txt_files = sorted(glob_all(args.validation_text_files)) validation_image_files, val_txt_files = keep_files_with_same_file_name(validation_image_files, val_txt_files) for img, gt in zip(validation_image_files, val_txt_files): if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]: raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt)) if len(set(val_txt_files)) != len(val_txt_files): raise Exception("Some validation images occur more than once in the data set.") validation_dataset = create_dataset( args.validation_dataset, DataSetMode.TRAIN, images=validation_image_files, texts=val_txt_files, skip_invalid=not args.no_skip_invalid_gt) print("Found {} files in the validation dataset".format(len(validation_dataset))) else: validation_dataset = None print("Resuming training") with open(args.checkpoint + '.json', 'r') as f: checkpoint_params = json_format.Parse(f.read(), CheckpointParams()) trainer = Trainer(checkpoint_params, dataset, validation_dataset=validation_dataset, weights=args.checkpoint) trainer.train(progress_bar=True)
def main(): parser = argparse.ArgumentParser() parser.add_argument("--eval_imgs", type=str, nargs="+", required=True, help="The evaluation files") parser.add_argument("--eval_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE) parser.add_argument("--checkpoint", type=str, nargs="+", default=[], help="Path to the checkpoint without file extension") parser.add_argument("-j", "--processes", type=int, default=1, help="Number of processes to use") parser.add_argument("--verbose", action="store_true", help="Print additional information") parser.add_argument("--voter", type=str, nargs="+", default=["sequence_voter", "confidence_voter_default_ctc", "confidence_voter_fuzzy_ctc"], help="The voting algorithm to use. Possible values: confidence_voter_default_ctc (default), " "confidence_voter_fuzzy_ctc, sequence_voter") parser.add_argument("--batch_size", type=int, default=10, help="The batch size for prediction") parser.add_argument("--dump", type=str, help="Dump the output as serialized pickle object") parser.add_argument("--no_skip_invalid_gt", action="store_true", help="Do not skip invalid gt, instead raise an exception.") args = parser.parse_args() # allow user to specify json file for model definition, but remove the file extension # for further processing args.checkpoint = [(cp[:-5] if cp.endswith(".json") else cp) for cp in args.checkpoint] # load files gt_images = sorted(glob_all(args.eval_imgs)) gt_txts = [split_all_ext(path)[0] + ".gt.txt" for path in gt_images] dataset = create_dataset( args.eval_dataset, DataSetMode.TRAIN, images=gt_images, texts=gt_txts, skip_invalid=not args.no_skip_invalid_gt ) print("Found {} files in the dataset".format(len(dataset))) if len(dataset) == 0: raise Exception("Empty dataset provided. Check your eval_imgs argument (got {})!".format(args.eval_imgs)) # predict for all models n_models = len(args.checkpoint) predictor = MultiPredictor(checkpoints=args.checkpoint, batch_size=args.batch_size, processes=args.processes) do_prediction = predictor.predict_dataset(dataset, progress_bar=True) voters = [] all_voter_sentences = [] all_prediction_sentences = [[] for _ in range(n_models)] for voter in args.voter: # create voter voter_params = VoterParams() voter_params.type = VoterParams.Type.Value(voter.upper()) voters.append(voter_from_proto(voter_params)) all_voter_sentences.append([]) for prediction, sample in do_prediction: for sent, p in zip(all_prediction_sentences, prediction): sent.append(p.sentence) # vote results for voter, voter_sentences in zip(voters, all_voter_sentences): voter_sentences.append(voter.vote_prediction_result(prediction).sentence) # evaluation text_preproc = text_processor_from_proto(predictor.predictors[0].model_params.text_preprocessor) evaluator = Evaluator(text_preprocessor=text_preproc) evaluator.preload_gt(gt_dataset=dataset, progress_bar=True) def single_evaluation(predicted_sentences): if len(predicted_sentences) != len(dataset): raise Exception("Mismatch in number of gt and pred files: {} != {}. Probably, the prediction did not succeed".format(len(dataset), len(predicted_sentences))) pred_data_set = create_dataset( DataSetType.RAW, DataSetMode.EVAL, texts=predicted_sentences) r = evaluator.run(pred_dataset=pred_data_set, progress_bar=True, processes=args.processes) return r full_evaluation = {} for id, data in [(str(i), sent) for i, sent in enumerate(all_prediction_sentences)] + list(zip(args.voter, all_voter_sentences)): full_evaluation[id] = {"eval": single_evaluation(data), "data": data} if args.verbose: print(full_evaluation) if args.dump: import pickle with open(args.dump, 'wb') as f: pickle.dump({"full": full_evaluation, "gt_txts": gt_txts, "gt": dataset.text_samples()}, f)
def run(args): # check if loading a json file if len(args.files) == 1 and args.files[0].endswith("json"): import json with open(args.files[0], 'r') as f: json_args = json.load(f) for key, value in json_args.items(): setattr(args, key, value) # parse whitelist whitelist = args.whitelist if len(whitelist) == 1: whitelist = list(whitelist[0]) whitelist_files = glob_all(args.whitelist_files) for f in whitelist_files: with open(f) as txt: whitelist += list(txt.read()) if args.gt_extension is None: args.gt_extension = DataSetType.gt_extension(args.dataset) if args.validation_extension is None: args.validation_extension = DataSetType.gt_extension(args.validation_dataset) # Training dataset print("Resolving input files") input_image_files = sorted(glob_all(args.files)) if not args.text_files: gt_txt_files = [split_all_ext(f)[0] + args.gt_extension for f in input_image_files] else: gt_txt_files = sorted(glob_all(args.text_files)) input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files) for img, gt in zip(input_image_files, gt_txt_files): if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]: raise Exception("Expected identical basenames of file: {} and {}".format(img, gt)) if len(set(gt_txt_files)) != len(gt_txt_files): raise Exception("Some images occur more than once in the data set.") dataset = create_dataset( args.dataset, DataSetMode.TRAIN, images=input_image_files, texts=gt_txt_files, skip_invalid=not args.no_skip_invalid_gt ) print("Found {} files in the dataset".format(len(dataset))) # Validation dataset if args.validation: print("Resolving validation files") validation_image_files = glob_all(args.validation) if not args.validation_text_files: val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files] else: val_txt_files = sorted(glob_all(args.validation_text_files)) validation_image_files, val_txt_files = keep_files_with_same_file_name(validation_image_files, val_txt_files) for img, gt in zip(validation_image_files, val_txt_files): if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]: raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt)) if len(set(val_txt_files)) != len(val_txt_files): raise Exception("Some validation images occur more than once in the data set.") validation_dataset = create_dataset( args.validation_dataset, DataSetMode.TRAIN, images=validation_image_files, texts=val_txt_files, skip_invalid=not args.no_skip_invalid_gt) print("Found {} files in the validation dataset".format(len(validation_dataset))) else: validation_dataset = None params = CheckpointParams() params.max_iters = args.max_iters params.stats_size = args.stats_size params.batch_size = args.batch_size params.checkpoint_frequency = args.checkpoint_frequency if args.checkpoint_frequency >= 0 else args.early_stopping_frequency params.output_dir = args.output_dir params.output_model_prefix = args.output_model_prefix params.display = args.display params.skip_invalid_gt = not args.no_skip_invalid_gt params.processes = args.num_threads params.data_aug_retrain_on_original = not args.only_train_on_augmented params.early_stopping_frequency = args.early_stopping_frequency params.early_stopping_nbest = args.early_stopping_nbest params.early_stopping_best_model_prefix = args.early_stopping_best_model_prefix params.early_stopping_best_model_output_dir = args.early_stopping_best_model_output_dir if args.early_stopping_best_model_output_dir else args.output_dir params.model.data_preprocessor.type = DataPreprocessorParams.DEFAULT_NORMALIZER params.model.data_preprocessor.line_height = args.line_height params.model.data_preprocessor.pad = args.pad # Text pre processing (reading) params.model.text_preprocessor.type = TextProcessorParams.MULTI_NORMALIZER default_text_normalizer_params(params.model.text_preprocessor.children.add(), default=args.text_normalization) default_text_regularizer_params(params.model.text_preprocessor.children.add(), groups=args.text_regularization) strip_processor_params = params.model.text_preprocessor.children.add() strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER # Text post processing (prediction) params.model.text_postprocessor.type = TextProcessorParams.MULTI_NORMALIZER default_text_normalizer_params(params.model.text_postprocessor.children.add(), default=args.text_normalization) default_text_regularizer_params(params.model.text_postprocessor.children.add(), groups=args.text_regularization) strip_processor_params = params.model.text_postprocessor.children.add() strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER if args.seed > 0: params.model.network.backend.random_seed = args.seed if args.bidi_dir: # change bidirectional text direction if desired bidi_dir_to_enum = {"rtl": TextProcessorParams.BIDI_RTL, "ltr": TextProcessorParams.BIDI_LTR, "auto": TextProcessorParams.BIDI_AUTO} bidi_processor_params = params.model.text_preprocessor.children.add() bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER bidi_processor_params.bidi_direction = bidi_dir_to_enum[args.bidi_dir] bidi_processor_params = params.model.text_postprocessor.children.add() bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER bidi_processor_params.bidi_direction = TextProcessorParams.BIDI_AUTO params.model.line_height = args.line_height network_params_from_definition_string(args.network, params.model.network) params.model.network.clipping_mode = NetworkParams.ClippingMode.Value("CLIP_" + args.gradient_clipping_mode.upper()) params.model.network.clipping_constant = args.gradient_clipping_const params.model.network.backend.fuzzy_ctc_library_path = args.fuzzy_ctc_library_path params.model.network.backend.num_inter_threads = args.num_inter_threads params.model.network.backend.num_intra_threads = args.num_intra_threads # create the actual trainer trainer = Trainer(params, dataset, validation_dataset=validation_dataset, data_augmenter=SimpleDataAugmenter(), n_augmentations=args.n_augmentations, weights=args.weights, codec_whitelist=whitelist, preload_training=not args.train_data_on_the_fly, preload_validation=not args.validation_data_on_the_fly, ) trainer.train( auto_compute_codec=not args.no_auto_compute_codec, progress_bar=not args.no_progress_bars )
def main(): parser = ArgumentParser() parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE) parser.add_argument("--gt", nargs="+", required=True, help="Ground truth files (.gt.txt extension)") parser.add_argument("--pred", nargs="+", default=None, help="Prediction files if provided. Else files with .pred.txt are expected at the same " "location as the gt.") parser.add_argument("--pred_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE) parser.add_argument("--pred_ext", type=str, default=".pred.txt", help="Extension of the predicted text files") parser.add_argument("--n_confusions", type=int, default=10, help="Only print n most common confusions. Defaults to 10, use -1 for all.") parser.add_argument("--n_worst_lines", type=int, default=0, help="Print the n worst recognized text lines with its error") parser.add_argument("--xlsx_output", type=str, help="Optionally write a xlsx file with the evaluation results") parser.add_argument("--num_threads", type=int, default=1, help="Number of threads to use for evaluation") parser.add_argument("--non_existing_file_handling_mode", type=str, default="error", help="How to handle non-existing .pred.txt files. Possible modes: skip, empty, error. " "'Skip' will simply skip the evaluation of that file (not counting it to errors). " "'Empty' will handle this file as if it were empty (fully checking for errors). " "'Error' will throw an exception if a file is not existing. This is the default behaviour.") parser.add_argument("--no_progress_bars", action="store_true", help="Do not show any progress bars") parser.add_argument("--checkpoint", type=str, default=None, help="Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)") # page xml specific args parser.add_argument("--pagexml_gt_text_index", default=0) parser.add_argument("--pagexml_pred_text_index", default=1) args = parser.parse_args() print("Resolving files") gt_files = sorted(glob_all(args.gt)) if args.pred: pred_files = sorted(glob_all(args.pred)) else: pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files] args.pred_dataset = args.dataset if args.non_existing_file_handling_mode.lower() == "skip": non_existing_pred = [p for p in pred_files if not os.path.exists(p)] for f in non_existing_pred: idx = pred_files.index(f) del pred_files[idx] del gt_files[idx] text_preproc = None if args.checkpoint: with open(args.checkpoint if args.checkpoint.endswith(".json") else args.checkpoint + '.json', 'r') as f: checkpoint_params = json_format.Parse(f.read(), CheckpointParams()) text_preproc = text_processor_from_proto(checkpoint_params.model.text_preprocessor) non_existing_as_empty = args.non_existing_file_handling_mode.lower() != "error" gt_data_set = create_dataset( args.dataset, DataSetMode.EVAL, texts=gt_files, non_existing_as_empty=non_existing_as_empty, args={'text_index': args.pagexml_gt_text_index}, ) pred_data_set = create_dataset( args.pred_dataset, DataSetMode.EVAL, texts=pred_files, non_existing_as_empty=non_existing_as_empty, args={'text_index': args.pagexml_pred_text_index}, ) evaluator = Evaluator(text_preprocessor=text_preproc) r = evaluator.run(gt_dataset=gt_data_set, pred_dataset=pred_data_set, processes=args.num_threads, progress_bar=not args.no_progress_bars) # TODO: More output print("Evaluation result") print("=================") print("") print("Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(r["avg_ler"], r["total_char_errs"], r["total_chars"], r["total_sync_errs"])) # sort descending print_confusions(r, args.n_confusions) print_worst_lines(r, gt_data_set.samples(), pred_data_set.text_samples(), args.n_worst_lines) if args.xlsx_output: write_xlsx(args.xlsx_output, [{ "prefix": "evaluation", "results": r, "gt_files": gt_files, "gts": gt_data_set.text_samples(), "preds": pred_data_set.text_samples() }])
def main(): parser = argparse.ArgumentParser() parser.add_argument("--base_dir", type=str, required=True, help="The base directory where to store all working files") parser.add_argument("--eval_files", type=str, nargs="+", required=True, help="All files that shall be used for evaluation") parser.add_argument("--train_files", type=str, nargs="+", required=True, help="All files that shall be used for (cross-fold) training") parser.add_argument("--n_lines", type=int, default=[-1], nargs="+", help="Optional argument to specify the number of lines (images) used for training. " "On default, all available lines will be used.") parser.add_argument("--run", type=str, default=None, help="An optional command that will receive the train calls. Useful e.g. when using a resource " "manager such as slurm.") parser.add_argument("--n_folds", type=int, default=5, help="The number of folds, that is the number of models to train") parser.add_argument("--max_parallel_models", type=int, default=-1, help="Number of models to train in parallel per fold. Defaults to all.") parser.add_argument("--weights", type=str, nargs="+", default=[], help="Load network weights from the given file. If more than one file is provided the number " "models must match the number of folds. Each fold is then initialized with the weights " "of each model, respectively") parser.add_argument("--single_fold", type=int, nargs="+", default=[], help="Only train a single (list of single) specific fold(s).") parser.add_argument("--skip_train", action="store_true", help="Skip the cross fold training") parser.add_argument("--skip_eval", action="store_true", help="Skip the cross fold evaluation") parser.add_argument("--verbose", action="store_true", help="Verbose output") parser.add_argument("--n_confusions", type=int, default=0, help="Only print n most common confusions. Defaults to 0, use -1 for all.") parser.add_argument("--xlsx_output", type=str, help="Optionally write a xlsx file with the evaluation results") setup_train_args(parser, omit=["files", "validation", "weights", "early_stopping_best_model_output_dir", "early_stopping_best_model_prefix", "output_dir"]) args = parser.parse_args() args.base_dir = os.path.abspath(os.path.expanduser(args.base_dir)) np.random.seed(args.seed) random.seed(args.seed) # argument checks args.weights = glob_all(args.weights) if len(args.weights) > 1 and len(args.weights) != args.n_folds: raise Exception("Either no, one or n_folds (={}) models are required for pretraining but got {}.".format(args.n_folds, len(args.weights))) if len(args.single_fold) > 0: if len(set(args.single_fold)) != len(args.single_fold): raise Exception("Repeated fold id's found.") for fold_id in args.single_fold: if fold_id < 0 or fold_id >= args.n_folds: raise Exception("Invalid fold id found: 0 <= id <= {}, but id == {}".format(args.n_folds, fold_id)) actual_folds = args.single_fold else: actual_folds = list(range(args.n_folds)) # run for all lines single_args = [copy.copy(args) for _ in args.n_lines] for s_args, n_lines in zip(single_args, args.n_lines): s_args.n_lines = n_lines predictions = parallel_map(run_for_single_line, single_args, progress_bar=False, processes=len(single_args), use_thread_pool=True) # output predictions as csv: header = "lines," + ",".join([str(fold) for fold in range(args.n_folds)]) + ",avg,std,seq. vot., def. conf. vot., fuz. conf. vot." print(header) for prediction_map, n_lines in zip(predictions, args.n_lines): prediction = prediction_map["full"] data = "{}".format(n_lines) folds_lers = [] for fold in range(len(actual_folds)): eval = prediction[str(fold)]["eval"] data += ",{}".format(eval['avg_ler']) folds_lers.append(eval['avg_ler']) data += ",{},{}".format(np.mean(folds_lers), np.std(folds_lers)) for voter in ['sequence_voter', 'confidence_voter_default_ctc', 'confidence_voter_fuzzy_ctc']: eval = prediction[voter]["eval"] data += ",{}".format(eval['avg_ler']) print(data) if args.n_confusions != 0: for prediction_map, n_lines in zip(predictions, args.n_lines): prediction = prediction_map["full"] print("") print("CONFUSIONS (lines = {})".format(n_lines)) print("==========") print() for fold in range(len(actual_folds)): print("FOLD {}".format(fold)) print_confusions(prediction[str(fold)]['eval'], args.n_confusions) for voter in ['sequence_voter', 'confidence_voter_default_ctc', 'confidence_voter_fuzzy_ctc']: print("VOTER {}".format(voter)) print_confusions(prediction[voter]['eval'], args.n_confusions) if args.xlsx_output: data_list = [] for prediction_map, n_lines in zip(predictions, args.n_lines): prediction = prediction_map["full"] for fold in actual_folds: pred = prediction[str(fold)] data_list.append({ "prefix": "L{} - Fold{}".format(n_lines, fold), "results": pred['eval'], "gt_files": prediction_map['gt_txts'], "gts": prediction_map['gt'], "preds": pred['data'] }) for voter in ['sequence_voter', 'confidence_voter_default_ctc']: pred = prediction[voter] data_list.append({ "prefix": "L{} - {}".format(n_lines, voter[:3]), "results": pred['eval'], "gt_files": prediction_map['gt_txts'], "gts": prediction_map['gt'], "preds": pred['data'] }) write_xlsx(args.xlsx_output, data_list)
def run_for_single_line(args): # lines/network/pretraining as base dir args.base_dir = os.path.join(args.base_dir, "all" if args.n_lines < 0 else str(args.n_lines)) pretrain_prefix = "scratch" if args.weights and len(args.weights) > 0: pretrain_prefix = ",".join([split_all_ext(os.path.basename(path))[0] for path in args.weights]) args.base_dir = os.path.join(args.base_dir, args.network, pretrain_prefix) if not os.path.exists(args.base_dir): os.makedirs(args.base_dir) tmp_dir = os.path.join(args.base_dir, "tmp") if not os.path.exists(tmp_dir): os.makedirs(tmp_dir) best_models_dir = os.path.join(args.base_dir, "models") if not os.path.exists(best_models_dir): os.makedirs(best_models_dir) prediction_dir = os.path.join(args.base_dir, "predictions") if not os.path.exists(prediction_dir): os.makedirs(prediction_dir) # select number of files files = args.train_files if args.n_lines > 0: all_files = glob_all(args.train_files) files = random.sample(all_files, args.n_lines) # run the cross-fold-training setattr(args, "max_parallel_models", args.max_parallel_models) setattr(args, "best_models_dir", best_models_dir) setattr(args, "temporary_dir", tmp_dir) setattr(args, "keep_temporary_files", False) setattr(args, "files", files) setattr(args, "best_model_label", "{id}") if not args.skip_train: cross_fold_train.main(args) dump_file = os.path.join(tmp_dir, "prediction.pkl") # run the prediction if not args.skip_eval: # locate the eval script (must be in the same dir as "this") predict_script_path = os.path.join(this_absdir, "experiment_eval.py") if len(args.single_fold) > 0: models = [os.path.join(best_models_dir, "{}.ckpt.json".format(sf)) for sf in args.single_fold] for m in models: if not os.path.exists(m): raise Exception("Expected model at '{}', but file does not exist".format(m)) else: models = [os.path.join(best_models_dir, d) for d in sorted(os.listdir(best_models_dir)) if d.endswith("json")] if len(models) != args.n_folds: raise Exception("Expected {} models, one for each fold respectively, but only {} models were found".format( args.n_folds, len(models) )) for line in run(prefix_run_command([ "python3", "-u", predict_script_path, "-j", str(args.num_threads), "--batch_size", str(args.batch_size), "--dump", dump_file, "--eval_imgs"] + args.eval_files + [ ] + (["--verbose"] if args.verbose else []) + [ "--checkpoint"] + models + [ ], args.run, {"threads": args.num_threads}), verbose=args.verbose): # Print the output of the thread if args.verbose: print(line) import pickle with open(dump_file, 'rb') as f: prediction = pickle.load(f) return prediction