import argparse

import numpy as np

# calamari_ocr imports (module paths as of calamari 1.x; may differ in other versions)
from calamari_ocr.utils import glob_all, split_all_ext
from calamari_ocr.ocr.datasets import create_dataset, DataSetType, DataSetMode


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--files", nargs="+", required=True,
                        help="List of all image files with corresponding gt.txt files")
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--line_height", type=int, default=48,
                        help="The line height")
    parser.add_argument("--pad", type=int, default=16,
                        help="Padding (left and right) of the line")

    args = parser.parse_args()

    print("Resolving files")
    image_files = glob_all(args.files)
    gt_files = [split_all_ext(p)[0] + ".gt.txt" for p in image_files]

    ds = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=image_files,
        texts=gt_files,
        non_existing_as_empty=True,
    )

    print("Loading {} files".format(len(image_files)))
    ds.load_samples(processes=1, progress_bar=True)
    images, texts = ds.train_samples(skip_empty=True)

    statistics = {
        "n_lines": len(images),
        "chars": [len(t) for t in texts],
        "widths": [img.shape[1] / img.shape[0] * args.line_height + 2 * args.pad
                   for img in images
                   if img is not None and img.shape[0] > 0 and img.shape[1] > 0],
        "total_line_width": 0,
        "char_counts": {},
    }

    for text in texts:
        for c in text:
            statistics["char_counts"][c] = statistics["char_counts"].get(c, 0) + 1

    statistics["av_line_width"] = np.average(statistics["widths"])
    statistics["max_line_width"] = np.max(statistics["widths"])
    statistics["min_line_width"] = np.min(statistics["widths"])
    statistics["total_line_width"] = np.sum(statistics["widths"])

    statistics["av_chars"] = np.average(statistics["chars"])
    statistics["max_chars"] = np.max(statistics["chars"])
    statistics["min_chars"] = np.min(statistics["chars"])
    statistics["total_chars"] = np.sum(statistics["chars"])

    statistics["av_px_per_char"] = statistics["av_line_width"] / statistics["av_chars"]
    statistics["codec_size"] = len(statistics["char_counts"])

    del statistics["chars"]
    del statistics["widths"]

    print(statistics)
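# The per-character tally in main() above can be expressed more compactly with
# collections.Counter. A minimal, self-contained sketch; the `texts` list here
# is a hypothetical stand-in for the ground-truth lines loaded by the dataset.
from collections import Counter

texts = ["hello world", "hallo welt"]  # placeholder ground-truth lines

char_counts = Counter()
for text in texts:
    char_counts.update(text)  # counts every character of the line

print(dict(char_counts))  # character -> frequency, as in statistics["char_counts"]
print(len(char_counts))   # equals the "codec_size" statistic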
def test_raw_dataset_prediction(self):
    args = PredictionAttrs()
    predictor = Predictor(checkpoint=args.checkpoint[0])
    data = create_dataset(
        DataSetType.FILE,
        DataSetMode.PREDICT,
        images=args.files,
    )
    for prediction, sample in predictor.predict_dataset(data):
        pass
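# The test above merely drains the prediction generator. A minimal consumer
# sketch (predictor and data as constructed in the test above; the .sentence
# attribute is an assumption, matching the evaluation script further below):
for prediction, sample in predictor.predict_dataset(data):
    print(prediction.sentence)  # decoded text of one input line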
def single_evaluation(predicted_sentences):
    if len(predicted_sentences) != len(dataset):
        raise Exception("Mismatch in number of gt and pred files: {} != {}. "
                        "Probably, the prediction did not succeed".format(
                            len(dataset), len(predicted_sentences)))

    pred_data_set = create_dataset(
        DataSetType.RAW,
        DataSetMode.EVAL,
        texts=predicted_sentences,
    )

    return evaluator.run(pred_dataset=pred_data_set, progress_bar=True,
                         processes=args.processes)
import argparse
import pickle

# calamari_ocr imports (module paths as of calamari 1.x; may differ in other versions)
from calamari_ocr.utils import glob_all, split_all_ext
from calamari_ocr.ocr.datasets import create_dataset, DataSetType, DataSetMode
from calamari_ocr.ocr import MultiPredictor, Evaluator
from calamari_ocr.ocr.voting import voter_from_proto
from calamari_ocr.ocr.text_processing import text_processor_from_proto
from calamari_ocr.proto import VoterParams


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_imgs", type=str, nargs="+", required=True,
                        help="The evaluation files")
    parser.add_argument("--eval_dataset", type=DataSetType.from_string, choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--checkpoint", type=str, nargs="+", default=[],
                        help="Path to the checkpoint without file extension")
    parser.add_argument("-j", "--processes", type=int, default=1,
                        help="Number of processes to use")
    parser.add_argument("--verbose", action="store_true",
                        help="Print additional information")
    parser.add_argument("--voter", type=str, nargs="+",
                        default=["sequence_voter", "confidence_voter_default_ctc", "confidence_voter_fuzzy_ctc"],
                        help="The voting algorithm to use. Possible values: confidence_voter_default_ctc (default), "
                             "confidence_voter_fuzzy_ctc, sequence_voter")
    parser.add_argument("--batch_size", type=int, default=10,
                        help="The batch size for prediction")
    parser.add_argument("--dump", type=str,
                        help="Dump the output as a serialized pickle object")
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do not skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # Allow the user to specify a json file for the model definition, but remove
    # the file extension for further processing.
    args.checkpoint = [(cp[:-5] if cp.endswith(".json") else cp) for cp in args.checkpoint]

    # Load the files.
    gt_images = sorted(glob_all(args.eval_imgs))
    gt_txts = [split_all_ext(path)[0] + ".gt.txt" for path in gt_images]

    dataset = create_dataset(
        args.eval_dataset,
        DataSetMode.TRAIN,
        images=gt_images,
        texts=gt_txts,
        skip_invalid=not args.no_skip_invalid_gt,
    )

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        raise Exception("Empty dataset provided. Check your --eval_imgs argument (got {})!".format(args.eval_imgs))

    # Predict with all models.
    n_models = len(args.checkpoint)
    predictor = MultiPredictor(checkpoints=args.checkpoint, batch_size=args.batch_size,
                               processes=args.processes)
    do_prediction = predictor.predict_dataset(dataset, progress_bar=True)

    voters = []
    all_voter_sentences = []
    all_prediction_sentences = [[] for _ in range(n_models)]

    for voter in args.voter:
        # Create the voter.
        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value(voter.upper())
        voters.append(voter_from_proto(voter_params))
        all_voter_sentences.append([])

    for prediction, sample in do_prediction:
        for sent, p in zip(all_prediction_sentences, prediction):
            sent.append(p.sentence)

        # Vote on the results of the individual models.
        for voter, voter_sentences in zip(voters, all_voter_sentences):
            voter_sentences.append(voter.vote_prediction_result(prediction).sentence)

    # Evaluation.
    text_preproc = text_processor_from_proto(predictor.predictors[0].model_params.text_preprocessor)
    evaluator = Evaluator(text_preprocessor=text_preproc)
    evaluator.preload_gt(gt_dataset=dataset, progress_bar=True)

    def single_evaluation(predicted_sentences):
        if len(predicted_sentences) != len(dataset):
            raise Exception("Mismatch in number of gt and pred files: {} != {}. "
                            "Probably, the prediction did not succeed".format(
                                len(dataset), len(predicted_sentences)))

        pred_data_set = create_dataset(
            DataSetType.RAW,
            DataSetMode.EVAL,
            texts=predicted_sentences,
        )

        return evaluator.run(pred_dataset=pred_data_set, progress_bar=True,
                             processes=args.processes)

    full_evaluation = {}
    for model_id, data in [(str(i), sent) for i, sent in enumerate(all_prediction_sentences)] \
            + list(zip(args.voter, all_voter_sentences)):
        full_evaluation[model_id] = {"eval": single_evaluation(data), "data": data}

    if args.verbose:
        print(full_evaluation)

    if args.dump:
        with open(args.dump, 'wb') as f:
            pickle.dump({"full": full_evaluation, "gt_txts": gt_txts, "gt": dataset.text_samples()}, f)
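# The --dump option above pickles a dict with the keys "full", "gt_txts" and
# "gt", so the results can be inspected offline. A minimal loader sketch; the
# file name is a placeholder:
import pickle

with open("eval_dump.pkl", "rb") as f:  # placeholder path
    dump = pickle.load(f)

# "full" maps each model index / voter name to its evaluation result ("eval")
# and its predicted sentences ("data").
for model_id, entry in dump["full"].items():
    print(model_id, len(entry["data"]), "predicted lines")

print(len(dump["gt_txts"]), "ground-truth files")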
import os
from argparse import ArgumentParser

from google.protobuf import json_format

# calamari_ocr imports (module paths as of calamari 1.x; may differ in other versions)
from calamari_ocr.utils import glob_all, split_all_ext, keep_files_with_same_file_name
from calamari_ocr.ocr.datasets import create_dataset, DataSetType, DataSetMode
from calamari_ocr.ocr import Trainer
from calamari_ocr.proto import CheckpointParams


def main():
    parser = ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="The checkpoint used to resume")

    # Validation files
    parser.add_argument("--validation", type=str, nargs="+",
                        help="Validation line files used for early stopping")
    parser.add_argument("--validation_text_files", nargs="+", default=None,
                        help="Optional list of validation GT files if they are in another directory")
    parser.add_argument("--validation_extension", default=None,
                        help="Default extension of the gt files (expected to exist in the same dir)")
    parser.add_argument("--validation_dataset", type=DataSetType.from_string, choices=list(DataSetType),
                        default=DataSetType.FILE)

    # Input files
    parser.add_argument("--files", nargs="+",
                        help="List all image files that shall be processed. Ground truth files with the same "
                             "base name but with '.gt.txt' as extension are required at the same location")
    parser.add_argument("--text_files", nargs="+", default=None,
                        help="Optional list of GT files if they are in another directory")
    parser.add_argument("--gt_extension", default=None,
                        help="Default extension of the gt files (expected to exist in the same dir)")
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do not skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    if args.gt_extension is None:
        args.gt_extension = DataSetType.gt_extension(args.dataset)

    if args.validation_extension is None:
        args.validation_extension = DataSetType.gt_extension(args.validation_dataset)

    # Training dataset
    print("Resolving input files")
    input_image_files = sorted(glob_all(args.files))
    if not args.text_files:
        gt_txt_files = [split_all_ext(f)[0] + args.gt_extension for f in input_image_files]
    else:
        gt_txt_files = sorted(glob_all(args.text_files))
        input_image_files, gt_txt_files = keep_files_with_same_file_name(input_image_files, gt_txt_files)
        for img, gt in zip(input_image_files, gt_txt_files):
            if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                raise Exception("Expected identical basenames of file: {} and {}".format(img, gt))

    if len(set(gt_txt_files)) != len(gt_txt_files):
        raise Exception("Some images occur more than once in the data set.")

    dataset = create_dataset(
        args.dataset,
        DataSetMode.TRAIN,
        images=input_image_files,
        texts=gt_txt_files,
        skip_invalid=not args.no_skip_invalid_gt,
    )
    print("Found {} files in the dataset".format(len(dataset)))

    # Validation dataset
    if args.validation:
        print("Resolving validation files")
        validation_image_files = glob_all(args.validation)
        if not args.validation_text_files:
            val_txt_files = [split_all_ext(f)[0] + args.validation_extension for f in validation_image_files]
        else:
            val_txt_files = sorted(glob_all(args.validation_text_files))
            validation_image_files, val_txt_files = keep_files_with_same_file_name(
                validation_image_files, val_txt_files)
            for img, gt in zip(validation_image_files, val_txt_files):
                if split_all_ext(os.path.basename(img))[0] != split_all_ext(os.path.basename(gt))[0]:
                    raise Exception("Expected identical basenames of validation file: {} and {}".format(img, gt))

        if len(set(val_txt_files)) != len(val_txt_files):
            raise Exception("Some validation images occur more than once in the data set.")

        validation_dataset = create_dataset(
            args.validation_dataset,
            DataSetMode.TRAIN,
            images=validation_image_files,
            texts=val_txt_files,
            skip_invalid=not args.no_skip_invalid_gt,
        )
        print("Found {} files in the validation dataset".format(len(validation_dataset)))
    else:
        validation_dataset = None

    print("Resuming training")
    with open(args.checkpoint + '.json', 'r') as f:
        checkpoint_params = json_format.Parse(f.read(), CheckpointParams())

    trainer = Trainer(checkpoint_params, dataset,
                      validation_dataset=validation_dataset,
                      weights=args.checkpoint)
    trainer.train(progress_bar=True)
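# keep_files_with_same_file_name is imported from calamari's utilities; its
# actual implementation is not shown in this section. A hypothetical sketch of
# the pair-by-basename idea, for illustration only:
import os


def keep_files_with_same_file_name_sketch(image_files, text_files):
    """Keep only pairs whose basenames (all extensions stripped) occur in both
    lists, preserving the order of the image list. Illustrative only."""
    def base(path):
        # "0001.nrm.png" -> "0001": drop everything after the first dot.
        return os.path.basename(path).split(".", 1)[0]

    texts_by_base = {base(t): t for t in text_files}
    kept_images, kept_texts = [], []
    for img in image_files:
        if base(img) in texts_by_base:
            kept_images.append(img)
            kept_texts.append(texts_by_base[base(img)])
    return kept_images, kept_texts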