def evaluate_books(self, books, models, mode="auto", sample=-1):
    if type(books) == str:
        books = [books]
    if type(models) == str:
        models = [models]
    results = {}
    if mode == "auto":
        with h5py.File(self.cachefile, 'r', libver='latest', swmr=True) as cache:
            for b in books:
                for p in cache[b]:
                    for s in cache[b][p]:
                        if "text" in cache[b][p][s].attrs:
                            mode = "eval"
                            break
                    if mode != "auto":
                        break
                if mode != "auto":
                    break
        if mode == "auto":
            mode = "conf"

    if mode == "conf":
        dset = Nash5DataSet(DataSetMode.PREDICT, self.cachefile, books)
    else:
        dset = Nash5DataSet(DataSetMode.EVAL, self.cachefile, books)

    if 0 < sample < len(dset):
        delsamples = random.sample(dset._samples, len(dset) - sample)
        for s in delsamples:
            dset._samples.remove(s)

    if mode == "conf":
        for model in models:
            if isinstance(model, str):
                model = [model]
            predictor = MultiPredictor(checkpoints=model, data_preproc=NoopDataPreprocessor(),
                                       batch_size=1, processes=1)
            voter_params = VoterParams()
            voter_params.type = VoterParams.Type.Value("confidence_voter_default_ctc".upper())
            voter = voter_from_proto(voter_params)
            do_prediction = predictor.predict_dataset(dset, progress_bar=True)
            avg_sentence_confidence = 0
            n_predictions = 0
            for result, sample in do_prediction:
                n_predictions += 1
                prediction = voter.vote_prediction_result(result)
                avg_sentence_confidence += prediction.avg_char_probability
            results["/".join(model)] = avg_sentence_confidence / n_predictions

    else:
        for model in models:
            if isinstance(model, str):
                model = [model]
            predictor = MultiPredictor(checkpoints=model, data_preproc=NoopDataPreprocessor(),
                                       batch_size=1, processes=1, with_gt=True)
            out_gen = predictor.predict_dataset(dset, progress_bar=True, apply_preproc=False)
            result = Evaluator.evaluate_single_list(
                map(Evaluator.evaluate_single_args,
                    map(lambda d: tuple([''.join(d[0].ground_truth), ''.join(d[0].chars)]), out_gen)))
            results["/".join(model)] = 1 - result["avg_ler"]
    return results
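# Usage sketch (illustrative, not part of the original source): `ocr` stands for an
# instance of the class that defines evaluate_books() above, backed by an HDF5 cache file.
#
#   scores = ocr.evaluate_books(books="book_0001", models=["models/model_00100000"],
#                               mode="auto", sample=200)
#   # "conf" mode: mean voted sentence confidence per model
#   # "eval" mode: 1 - average label error rate against the stored ground truth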
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_imgs", type=str, nargs="+", required=True,
                        help="The evaluation files")
    parser.add_argument("--checkpoint", type=str, nargs="+", default=[],
                        help="Path to the checkpoint without file extension")
    parser.add_argument("-j", "--processes", type=int, default=1,
                        help="Number of processes to use")
    parser.add_argument("--verbose", action="store_true",
                        help="Print additional information")
    parser.add_argument("--voter", type=str, nargs="+",
                        default=["sequence_voter", "confidence_voter_default_ctc", "confidence_voter_fuzzy_ctc"],
                        help="The voting algorithm to use. Possible values: confidence_voter_default_ctc (default), "
                             "confidence_voter_fuzzy_ctc, sequence_voter")
    parser.add_argument("--batch_size", type=int, default=10,
                        help="The batch size for prediction")
    parser.add_argument("--dump", type=str,
                        help="Dump the output as serialized pickle object")
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do not skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # allow the user to specify a json file for the model definition, but remove the file extension
    # for further processing
    args.checkpoint = [(cp[:-5] if cp.endswith(".json") else cp) for cp in args.checkpoint]

    # load files
    gt_images = sorted(glob_all(args.eval_imgs))
    gt_txts = [split_all_ext(path)[0] + ".gt.txt" for path in sorted(glob_all(args.eval_imgs))]

    dataset = FileDataSet(images=gt_images, texts=gt_txts,
                          skip_invalid=not args.no_skip_invalid_gt)

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        raise Exception("Empty dataset provided. Check your --eval_imgs argument (got {})!".format(args.eval_imgs))

    # predict for all models
    n_models = len(args.checkpoint)
    predictor = MultiPredictor(checkpoints=args.checkpoint, batch_size=args.batch_size,
                               processes=args.processes)
    do_prediction = predictor.predict_dataset(dataset, progress_bar=True)

    voters = []
    all_voter_sentences = []
    all_prediction_sentences = [[] for _ in range(n_models)]

    for voter in args.voter:
        # create voter
        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value(voter.upper())
        voters.append(voter_from_proto(voter_params))
        all_voter_sentences.append([])

    for prediction, sample in do_prediction:
        for sent, p in zip(all_prediction_sentences, prediction):
            sent.append(p.sentence)

        # vote results
        for voter, voter_sentences in zip(voters, all_voter_sentences):
            voter_sentences.append(voter.vote_prediction_result(prediction).sentence)

    # evaluation
    text_preproc = text_processor_from_proto(predictor.predictors[0].model_params.text_preprocessor)
    evaluator = Evaluator(text_preprocessor=text_preproc)
    evaluator.preload_gt(gt_dataset=dataset, progress_bar=True)

    def single_evaluation(predicted_sentences):
        if len(predicted_sentences) != len(dataset):
            raise Exception("Mismatch in number of gt and pred files: {} != {}. "
                            "Probably, the prediction did not succeed".format(
                                len(dataset), len(predicted_sentences)))

        pred_data_set = RawDataSet(texts=predicted_sentences)
        r = evaluator.run(pred_dataset=pred_data_set, progress_bar=True, processes=args.processes)
        return r

    full_evaluation = {}
    for id, data in [(str(i), sent) for i, sent in enumerate(all_prediction_sentences)] \
            + list(zip(args.voter, all_voter_sentences)):
        full_evaluation[id] = {"eval": single_evaluation(data), "data": data}

    if args.verbose:
        print(full_evaluation)

    if args.dump:
        import pickle
        with open(args.dump, 'wb') as f:
            pickle.dump({"full": full_evaluation, "gt_txts": gt_txts, "gt": dataset.text_samples()}, f)
def train(self, progress_bar=False):
    checkpoint_params = self.checkpoint_params

    train_start_time = time.time() + self.checkpoint_params.total_time

    self.dataset.load_samples(processes=1, progress_bar=progress_bar)
    datas, txts = self.dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
    if len(datas) == 0:
        raise Exception("Empty dataset is not allowed. Check if the data is at the correct location")

    if self.validation_dataset:
        self.validation_dataset.load_samples(processes=1, progress_bar=progress_bar)
        validation_datas, validation_txts = self.validation_dataset.train_samples(
            skip_empty=checkpoint_params.skip_invalid_gt)
        if len(validation_datas) == 0:
            raise Exception("Validation dataset is empty. Provide valid validation data for early stopping.")
    else:
        validation_datas, validation_txts = [], []

    # preprocessing steps
    texts = self.txt_preproc.apply(txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
    datas = self.data_preproc.apply(datas, processes=checkpoint_params.processes, progress_bar=progress_bar)
    validation_txts = self.txt_preproc.apply(validation_txts, processes=checkpoint_params.processes,
                                             progress_bar=progress_bar)
    validation_datas = self.data_preproc.apply(validation_datas, processes=checkpoint_params.processes,
                                               progress_bar=progress_bar)

    # compute the codec
    codec = self.codec if self.codec else Codec.from_texts(texts, whitelist=self.codec_whitelist)

    # data augmentation on preprocessed data
    if self.data_augmenter:
        datas, texts = self.data_augmenter.augment_datas(datas, texts, n_augmentations=self.n_augmentations,
                                                         processes=checkpoint_params.processes,
                                                         progress_bar=progress_bar)

        # TODO: validation data augmentation
        # validation_datas, validation_txts = self.data_augmenter.augment_datas(
        #     validation_datas, validation_txts, n_augmentations=0,
        #     processes=checkpoint_params.processes, progress_bar=progress_bar)

    # create backend
    network_params = checkpoint_params.model.network
    network_params.features = checkpoint_params.model.line_height
    network_params.classes = len(codec)
    if self.weights:
        # if we load the weights, take care of codec changes as-well
        with open(self.weights + '.json', 'r') as f:
            restore_checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
            restore_model_params = restore_checkpoint_params.model

        # checks
        if restore_model_params.line_height != network_params.features:
            raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                restore_model_params.line_height, network_params.features
            ))

        # create codec of the same type
        restore_codec = codec.__class__(restore_model_params.codec.charset)
        # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
        codec_changes = restore_codec.align(codec)
        codec = restore_codec
        print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
        # The actual weight/bias matrix will be changed after loading the old weights
    else:
        codec_changes = None

    # store the new codec
    checkpoint_params.model.codec.charset[:] = codec.charset
    print("CODEC: {}".format(codec.charset))

    # compute the labels with the (new/current) codec
    labels = [codec.encode(txt) for txt in texts]

    backend = create_backend_from_proto(network_params, weights=self.weights)
    backend.set_train_data(datas, labels)
    backend.set_prediction_data(validation_datas)
    if codec_changes:
        backend.realign_model_labels(*codec_changes)
    backend.prepare(train=True)

    loss_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.loss_stats)
    ler_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.ler_stats)
    dt_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.dt_stats)

    early_stopping_enabled = self.validation_dataset is not None \
        and checkpoint_params.early_stopping_frequency > 0 \
        and checkpoint_params.early_stopping_nbest > 1
    early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
    early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
    early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

    early_stopping_predictor = Predictor(codec=codec, text_postproc=self.txt_postproc, backend=backend)

    # Start the actual training
    # ====================================================================================

    iter = checkpoint_params.iter

    # helper function to write a checkpoint
    def make_checkpoint(base_dir, prefix, version=None):
        if version:
            checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{}.ckpt".format(prefix, version)))
        else:
            checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{:08d}.ckpt".format(prefix, iter + 1)))
        print("Storing checkpoint to '{}'".format(checkpoint_path))
        backend.save_checkpoint(checkpoint_path)
        checkpoint_params.iter = iter
        checkpoint_params.loss_stats[:] = loss_stats.values
        checkpoint_params.ler_stats[:] = ler_stats.values
        checkpoint_params.dt_stats[:] = dt_stats.values
        checkpoint_params.total_time = time.time() - train_start_time
        checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
        checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
        checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

        with open(checkpoint_path + ".json", 'w') as f:
            f.write(json_format.MessageToJson(checkpoint_params))

        return checkpoint_path

    try:
        last_checkpoint = None

        # Training loop, can be interrupted by early stopping
        for iter in range(iter, checkpoint_params.max_iters):
            checkpoint_params.iter = iter

            iter_start_time = time.time()
            result = backend.train_step(checkpoint_params.batch_size)

            if not np.isfinite(result['loss']):
                print("Error: Loss is not finite! Trying to restart from last checkpoint.")
                if not last_checkpoint:
                    raise Exception("No checkpoint written yet. Training must be stopped.")
                else:
                    # reload also non trainable weights, such as solver-specific variables
                    backend.load_checkpoint_weights(last_checkpoint, restore_only_trainable=False)
                    continue

            loss_stats.push(result['loss'])
            ler_stats.push(result['ler'])
            dt_stats.push(time.time() - iter_start_time)

            if iter % checkpoint_params.display == 0:
                pred_sentence = self.txt_postproc.apply("".join(codec.decode(result["decoded"][0])))
                gt_sentence = self.txt_postproc.apply("".join(codec.decode(result["gt"][0])))
                print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(
                    iter, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))
                print(" PRED: '{}'".format(pred_sentence))
                print(" TRUE: '{}'".format(gt_sentence))

            if (iter + 1) % checkpoint_params.checkpoint_frequency == 0:
                last_checkpoint = make_checkpoint(checkpoint_params.output_dir,
                                                  checkpoint_params.output_model_prefix)

            if early_stopping_enabled and (iter + 1) % checkpoint_params.early_stopping_frequency == 0:
                print("Checking early stopping model")

                out = early_stopping_predictor.predict_raw(validation_datas,
                                                           batch_size=checkpoint_params.batch_size,
                                                           progress_bar=progress_bar, apply_preproc=False)
                pred_texts = [d.sentence for d in out]
                result = Evaluator.evaluate(gt_data=validation_txts, pred_data=pred_texts,
                                            progress_bar=progress_bar)
                accuracy = 1 - result["avg_ler"]

                if accuracy > early_stopping_best_accuracy:
                    early_stopping_best_accuracy = accuracy
                    early_stopping_best_cur_nbest = 1
                    early_stopping_best_at_iter = iter + 1
                    # overwrite as best model
                    last_checkpoint = make_checkpoint(
                        checkpoint_params.early_stopping_best_model_output_dir,
                        prefix="",
                        version=checkpoint_params.early_stopping_best_model_prefix,
                    )
                    print("Found better model with accuracy of {:%}".format(early_stopping_best_accuracy))
                else:
                    early_stopping_best_cur_nbest += 1
                    print("No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})".format(
                        early_stopping_best_accuracy, early_stopping_best_at_iter,
                        checkpoint_params.early_stopping_nbest - early_stopping_best_cur_nbest))

                if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                    print("Early stopping now.")
                    break

    except KeyboardInterrupt as e:
        print("Storing interrupted checkpoint")
        make_checkpoint(checkpoint_params.output_dir, checkpoint_params.output_model_prefix, "interrupted")
        raise e

    print("Total time {}s for {} iterations.".format(time.time() - train_start_time, iter))
def main(): parser = ArgumentParser() parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE) parser.add_argument( "--gt", nargs="+", required=True, help="Ground truth files (.gt.txt extension). " "Optionally, you can pass a single json file defining all parameters.") parser.add_argument( "--pred", nargs="+", default=None, help= "Prediction files if provided. Else files with .pred.txt are expected at the same " "location as the gt.") parser.add_argument("--pred_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE) parser.add_argument("--pred_ext", type=str, default=".pred.txt", help="Extension of the predicted text files") parser.add_argument( "--n_confusions", type=int, default=10, help= "Only print n most common confusions. Defaults to 10, use -1 for all.") parser.add_argument( "--n_worst_lines", type=int, default=0, help="Print the n worst recognized text lines with its error") parser.add_argument( "--xlsx_output", type=str, help="Optionally write a xlsx file with the evaluation results") parser.add_argument("--num_threads", type=int, default=1, help="Number of threads to use for evaluation") parser.add_argument( "--non_existing_file_handling_mode", type=str, default="error", help= "How to handle non existing .pred.txt files. Possible modes: skip, empty, error. " "'Skip' will simply skip the evaluation of that file (not counting it to errors). " "'Empty' will handle this file as would it be empty (fully checking for errors)." "'Error' will throw an exception if a file is not existing. This is the default behaviour." ) parser.add_argument("--skip_empty_gt", action="store_true", default=False, help="Ignore lines of the gt that are empty.") parser.add_argument("--no_progress_bars", action="store_true", help="Do not show any progress bars") parser.add_argument( "--checkpoint", type=str, default=None, help= "Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)" ) # page xml specific args parser.add_argument("--pagexml_gt_text_index", default=0) parser.add_argument("--pagexml_pred_text_index", default=1) args = parser.parse_args() # check if loading a json file if len(args.gt) == 1 and args.gt[0].endswith("json"): with open(args.gt[0], 'r') as f: json_args = json.load(f) for key, value in json_args.items(): setattr(args, key, value) print("Resolving files") gt_files = sorted(glob_all(args.gt)) if args.pred: pred_files = sorted(glob_all(args.pred)) else: pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files] args.pred_dataset = args.dataset if args.non_existing_file_handling_mode.lower() == "skip": non_existing_pred = [p for p in pred_files if not os.path.exists(p)] for f in non_existing_pred: idx = pred_files.index(f) del pred_files[idx] del gt_files[idx] text_preproc = None if args.checkpoint: with open( args.checkpoint if args.checkpoint.endswith(".json") else args.checkpoint + '.json', 'r') as f: checkpoint_params = json_format.Parse(f.read(), CheckpointParams()) text_preproc = text_processor_from_proto( checkpoint_params.model.text_preprocessor) non_existing_as_empty = args.non_existing_file_handling_mode.lower( ) != "error " gt_data_set = create_dataset( args.dataset, DataSetMode.EVAL, texts=gt_files, non_existing_as_empty=non_existing_as_empty, args={'text_index': args.pagexml_gt_text_index}, ) pred_data_set = create_dataset( args.pred_dataset, DataSetMode.EVAL, texts=pred_files, non_existing_as_empty=non_existing_as_empty, 
args={'text_index': args.pagexml_pred_text_index}, ) evaluator = Evaluator(text_preprocessor=text_preproc, skip_empty_gt=args.skip_empty_gt) r = evaluator.run(gt_dataset=gt_data_set, pred_dataset=pred_data_set, processes=args.num_threads, progress_bar=not args.no_progress_bars) # TODO: More output print("Evaluation result") print("=================") print("") print( "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)" .format(r["avg_ler"], r["total_char_errs"], r["total_chars"], r["total_sync_errs"])) # sort descending print_confusions(r, args.n_confusions) print_worst_lines(r, gt_data_set.samples(), args.n_worst_lines) if args.xlsx_output: write_xlsx(args.xlsx_output, [{ "prefix": "evaluation", "results": r, "gt_files": gt_files, }])
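# Example invocation of this evaluation script (paths are illustrative; the exact
# entry-point name depends on how the script is installed):
#
#   python eval.py --gt "data/test/*.gt.txt" \
#       --non_existing_file_handling_mode skip \
#       --n_confusions 10 --n_worst_lines 5 \
#       --xlsx_output report.xlsx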
def evaluate_books(self, books, models, rtl=False, mode="auto", sample=-1):
    if type(books) == str:
        books = [books]
    if type(models) == str:
        models = [models]
    results = {}
    if mode == "auto":
        with h5py.File(self.cachefile, 'r', libver='latest', swmr=True) as cache:
            for b in books:
                for p in cache[b]:
                    for s in cache[b][p]:
                        if "text" in cache[b][p][s].attrs:
                            mode = "eval"
                            break
                    if mode != "auto":
                        break
                if mode != "auto":
                    break
        if mode == "auto":
            mode = "conf"

    if mode == "conf":
        dset = Nash5DataSet(DataSetMode.PREDICT, self.cachefile, books)
    else:
        dset = Nash5DataSet(DataSetMode.TRAIN, self.cachefile, books)
        dset.mode = DataSetMode.PREDICT  # otherwise results are randomised

    if 0 < sample < len(dset):
        delsamples = random.sample(dset._samples, len(dset) - sample)
        for s in delsamples:
            dset._samples.remove(s)

    if mode == "conf":
        # dset = dset.to_raw_input_dataset(processes=1, progress_bar=True)
        for model in models:
            if isinstance(model, str):
                model = [model]
            predictor = MultiPredictor(checkpoints=model, data_preproc=NoopDataPreprocessor(),
                                       batch_size=1, processes=1)
            voter_params = VoterParams()
            voter_params.type = VoterParams.Type.Value("confidence_voter_default_ctc".upper())
            voter = voter_from_proto(voter_params)
            do_prediction = predictor.predict_dataset(dset, progress_bar=True)
            avg_sentence_confidence = 0
            n_predictions = 0
            for result, sample in do_prediction:
                n_predictions += 1
                prediction = voter.vote_prediction_result(result)
                avg_sentence_confidence += prediction.avg_char_probability
            results["/".join(model)] = avg_sentence_confidence / n_predictions

    else:
        for model in models:
            if isinstance(model, str):
                model = [model]
            predictor = MultiPredictor(checkpoints=model, data_preproc=NoopDataPreprocessor(),
                                       batch_size=1, processes=1)
            voter_params = VoterParams()
            voter_params.type = VoterParams.Type.Value("confidence_voter_default_ctc".upper())
            voter = voter_from_proto(voter_params)
            out_gen = predictor.predict_dataset(dset, progress_bar=True)

            preproc = self.bidi_preproc if rtl else self.txt_preproc
            pred_dset = RawDataSet(DataSetMode.EVAL,
                                   texts=preproc.apply([voter.vote_prediction_result(d[0]).sentence
                                                        for d in out_gen]))

            evaluator = Evaluator(text_preprocessor=NoopTextProcessor(), skip_empty_gt=False)
            r = evaluator.run(gt_dataset=dset, pred_dataset=pred_dset, processes=1, progress_bar=True)
            results["/".join(model)] = 1 - r["avg_ler"]
    return results
def _run_train(self, train_net, test_net, codec, train_start_time, progress_bar):
    checkpoint_params = self.checkpoint_params
    validation_dataset = test_net.input_dataset
    iters_per_epoch = max(1, int(train_net.input_dataset.epoch_size() / checkpoint_params.batch_size))

    loss_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.loss_stats)
    ler_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.ler_stats)
    dt_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.dt_stats)

    display = checkpoint_params.display
    display_epochs = display <= 1
    if display <= 0:
        display = 0  # to not display anything
    elif display_epochs:
        display = max(1, int(display * iters_per_epoch))  # relative to epochs
    else:
        display = max(1, int(display))  # iterations

    checkpoint_frequency = checkpoint_params.checkpoint_frequency
    early_stopping_frequency = checkpoint_params.early_stopping_frequency
    if early_stopping_frequency < 0:
        # set early stopping frequency to half epoch
        early_stopping_frequency = int(0.5 * iters_per_epoch)
    elif 0 < early_stopping_frequency <= 1:
        early_stopping_frequency = int(early_stopping_frequency * iters_per_epoch)  # relative to epochs
    else:
        early_stopping_frequency = int(early_stopping_frequency)
    early_stopping_frequency = max(1, early_stopping_frequency)

    if checkpoint_frequency < 0:
        checkpoint_frequency = early_stopping_frequency
    elif 0 < checkpoint_frequency <= 1:
        checkpoint_frequency = int(checkpoint_frequency * iters_per_epoch)  # relative to epochs
    else:
        checkpoint_frequency = int(checkpoint_frequency)

    early_stopping_enabled = self.validation_dataset is not None \
        and checkpoint_params.early_stopping_frequency > 0 \
        and checkpoint_params.early_stopping_nbest > 1
    early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
    early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
    early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

    early_stopping_predictor = Predictor(codec=codec, text_postproc=self.txt_postproc, network=test_net)

    # Start the actual training
    # ====================================================================================

    iter = checkpoint_params.iter

    # helper function to write a checkpoint
    def make_checkpoint(base_dir, prefix, version=None):
        if version:
            checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{}.ckpt".format(prefix, version)))
        else:
            checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{:08d}.ckpt".format(prefix, iter + 1)))
        print("Storing checkpoint to '{}'".format(checkpoint_path))
        train_net.save_checkpoint(checkpoint_path)
        checkpoint_params.version = Checkpoint.VERSION
        checkpoint_params.iter = iter
        checkpoint_params.loss_stats[:] = loss_stats.values
        checkpoint_params.ler_stats[:] = ler_stats.values
        checkpoint_params.dt_stats[:] = dt_stats.values
        checkpoint_params.total_time = time.time() - train_start_time
        checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
        checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
        checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

        with open(checkpoint_path + ".json", 'w') as f:
            f.write(json_format.MessageToJson(checkpoint_params))

        return checkpoint_path

    try:
        last_checkpoint = None
        n_infinite_losses = 0
        n_max_infinite_losses = 5

        # Training loop, can be interrupted by early stopping
        for iter in range(iter, checkpoint_params.max_iters):
            checkpoint_params.iter = iter

            iter_start_time = time.time()
            result = train_net.train_step()

            if not np.isfinite(result['loss']):
                n_infinite_losses += 1

                if n_max_infinite_losses == n_infinite_losses:
                    print("Error: Loss is not finite! Trying to restart from last checkpoint.")
                    if not last_checkpoint:
                        raise Exception("No checkpoint written yet. Training must be stopped.")
                    else:
                        # reload also non trainable weights, such as solver-specific variables
                        train_net.load_weights(last_checkpoint, restore_only_trainable=False)
                        continue
                else:
                    continue

            n_infinite_losses = 0
            loss_stats.push(result['loss'])
            ler_stats.push(result['ler'])
            dt_stats.push(time.time() - iter_start_time)

            if display > 0 and iter % display == 0:
                # apply postprocessing to display the true output
                pred_sentence = self.txt_postproc.apply("".join(codec.decode(result["decoded"][0])))
                gt_sentence = self.txt_postproc.apply("".join(codec.decode(result["gt"][0])))

                if display_epochs:
                    print("#{:08f}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(
                        iter / iters_per_epoch, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))
                else:
                    print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(
                        iter, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))

                # Insert utf-8 ltr/rtl direction marks for bidi support
                lr = "\u202A\u202B"
                print(" PRED: '{}{}{}'".format(lr[bidi.get_base_level(pred_sentence)], pred_sentence, "\u202C"))
                print(" TRUE: '{}{}{}'".format(lr[bidi.get_base_level(gt_sentence)], gt_sentence, "\u202C"))

            if checkpoint_frequency > 0 and (iter + 1) % checkpoint_frequency == 0:
                last_checkpoint = make_checkpoint(checkpoint_params.output_dir,
                                                  checkpoint_params.output_model_prefix)

            if early_stopping_enabled and (iter + 1) % early_stopping_frequency == 0:
                print("Checking early stopping model")

                out_gen = early_stopping_predictor.predict_input_dataset(validation_dataset,
                                                                         progress_bar=progress_bar)
                result = Evaluator.evaluate_single_list(
                    map(Evaluator.evaluate_single_args,
                        map(lambda d: tuple(self.txt_preproc.apply([''.join(d.ground_truth), d.sentence])),
                            out_gen)))
                accuracy = 1 - result["avg_ler"]

                if accuracy > early_stopping_best_accuracy:
                    early_stopping_best_accuracy = accuracy
                    early_stopping_best_cur_nbest = 1
                    early_stopping_best_at_iter = iter + 1
                    # overwrite as best model
                    last_checkpoint = make_checkpoint(
                        checkpoint_params.early_stopping_best_model_output_dir,
                        prefix="",
                        version=checkpoint_params.early_stopping_best_model_prefix,
                    )
                    print("Found better model with accuracy of {:%}".format(early_stopping_best_accuracy))
                else:
                    early_stopping_best_cur_nbest += 1
                    print("No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})".format(
                        early_stopping_best_accuracy, early_stopping_best_at_iter,
                        checkpoint_params.early_stopping_nbest - early_stopping_best_cur_nbest))

                if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                    print("Early stopping now.")
                    break

                if accuracy >= 1:
                    print("Reached perfect score on validation set. Early stopping now.")
                    break

    except KeyboardInterrupt as e:
        print("Storing interrupted checkpoint")
        make_checkpoint(checkpoint_params.output_dir, checkpoint_params.output_model_prefix, "interrupted")
        raise e

    print("Total time {}s for {} iterations.".format(time.time() - train_start_time, iter))
def main(): parser = ArgumentParser() parser.add_argument("--gt", nargs="+", required=True, help="Ground truth files (.gt.txt extension)") parser.add_argument( "--pred", nargs="+", default=None, help= "Prediction files if provided. Else files with .pred.txt are expected at the same " "location as the gt.") parser.add_argument("--pred_ext", type=str, default=".pred.txt", help="Extension of the predicted text files") parser.add_argument("--n_confusions", type=int, default=-1, help="Only print n most common confusions") parser.add_argument("--num_threads", type=int, default=1, help="Number of threads to use for evaluation") args = parser.parse_args() gt_files = sorted(glob_all(args.gt)) if args.pred: pred_files = sorted(glob_all(args.pred)) if len(pred_files) != len(gt_files): raise Exception( "Mismatch in the number of gt and pred files: {} vs {}".format( len(gt_files), len(pred_files))) else: pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files] gt_data_set = FileDataSet(texts=gt_files) pred_data_set = FileDataSet(texts=pred_files) evaluator = Evaluator() r = evaluator.run(gt_dataset=gt_data_set, pred_dataset=pred_data_set, processes=args.num_threads, progress_bar=True) # TODO: More output print("Evaluation result") print("=================") print("") print( "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)" .format(r["avg_ler"], r["total_char_errs"], r["total_chars"], r["total_sync_errs"])) # sort descending if args.n_confusions != 0 and r["total_sync_errs"] > 0: total_percent = 0 keys = sorted(r['confusion'].items(), key=lambda item: -item[1]) print("{:8s} {:8s} {:8s} {:10s}".format("GT", "PRED", "COUNT", "PERCENT")) for i, ((gt, pred), count) in enumerate(keys): gt_fmt = "{" + gt + "}" pred_fmt = "{" + pred + "}" if i == args.n_confusions: break percent = count * max(len(gt), len(pred)) / r["total_sync_errs"] print("{:8s} {:8s} {:8d} {:10.2%}".format(gt_fmt, pred_fmt, count, percent)) total_percent += percent print("The remaining but hidden errors make up {:.2%}".format( 1.0 - total_percent))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_imgs", type=str, nargs="+", required=True,
                        help="The evaluation files")
    parser.add_argument("--eval_dataset", type=DataSetType.from_string, choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--checkpoint", type=str, nargs="+", default=[],
                        help="Path to the checkpoint without file extension")
    parser.add_argument("-j", "--processes", type=int, default=1,
                        help="Number of processes to use")
    parser.add_argument("--verbose", action="store_true",
                        help="Print additional information")
    parser.add_argument("--voter", type=str, nargs="+",
                        default=["sequence_voter", "confidence_voter_default_ctc", "confidence_voter_fuzzy_ctc"],
                        help="The voting algorithm to use. Possible values: confidence_voter_default_ctc (default), "
                             "confidence_voter_fuzzy_ctc, sequence_voter")
    parser.add_argument("--batch_size", type=int, default=10,
                        help="The batch size for prediction")
    parser.add_argument("--dump", type=str,
                        help="Dump the output as serialized pickle object")
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do not skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # allow the user to specify a json file for the model definition, but remove the file extension
    # for further processing
    args.checkpoint = [(cp[:-5] if cp.endswith(".json") else cp) for cp in args.checkpoint]

    # load files
    gt_images = sorted(glob_all(args.eval_imgs))
    gt_txts = [split_all_ext(path)[0] + ".gt.txt" for path in sorted(glob_all(args.eval_imgs))]

    dataset = create_dataset(
        args.eval_dataset,
        DataSetMode.TRAIN,
        images=gt_images,
        texts=gt_txts,
        skip_invalid=not args.no_skip_invalid_gt
    )

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        raise Exception("Empty dataset provided. Check your --eval_imgs argument (got {})!".format(args.eval_imgs))

    # predict for all models
    n_models = len(args.checkpoint)
    predictor = MultiPredictor(checkpoints=args.checkpoint, batch_size=args.batch_size,
                               processes=args.processes)
    do_prediction = predictor.predict_dataset(dataset, progress_bar=True)

    voters = []
    all_voter_sentences = []
    all_prediction_sentences = [[] for _ in range(n_models)]

    for voter in args.voter:
        # create voter
        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value(voter.upper())
        voters.append(voter_from_proto(voter_params))
        all_voter_sentences.append([])

    for prediction, sample in do_prediction:
        for sent, p in zip(all_prediction_sentences, prediction):
            sent.append(p.sentence)

        # vote results
        for voter, voter_sentences in zip(voters, all_voter_sentences):
            voter_sentences.append(voter.vote_prediction_result(prediction).sentence)

    # evaluation
    text_preproc = text_processor_from_proto(predictor.predictors[0].model_params.text_preprocessor)
    evaluator = Evaluator(text_preprocessor=text_preproc)
    evaluator.preload_gt(gt_dataset=dataset, progress_bar=True)

    def single_evaluation(predicted_sentences):
        if len(predicted_sentences) != len(dataset):
            raise Exception("Mismatch in number of gt and pred files: {} != {}. "
                            "Probably, the prediction did not succeed".format(
                                len(dataset), len(predicted_sentences)))

        pred_data_set = create_dataset(
            DataSetType.RAW,
            DataSetMode.EVAL,
            texts=predicted_sentences)

        r = evaluator.run(pred_dataset=pred_data_set, progress_bar=True, processes=args.processes)
        return r

    full_evaluation = {}
    for id, data in [(str(i), sent) for i, sent in enumerate(all_prediction_sentences)] \
            + list(zip(args.voter, all_voter_sentences)):
        full_evaluation[id] = {"eval": single_evaluation(data), "data": data}

    if args.verbose:
        print(full_evaluation)

    if args.dump:
        import pickle
        with open(args.dump, 'wb') as f:
            pickle.dump({"full": full_evaluation, "gt_txts": gt_txts, "gt": dataset.text_samples()}, f)
def main():
    parser = ArgumentParser()
    parser.add_argument("--gt", nargs="+", required=True,
                        help="Ground truth files (.gt.txt extension)")
    parser.add_argument("--pred", nargs="+", default=None,
                        help="Prediction files if provided. Else files with .pred.txt are expected at the same "
                             "location as the gt.")
    parser.add_argument("--pred_ext", type=str, default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument("--n_confusions", type=int, default=10,
                        help="Only print n most common confusions. Defaults to 10, use -1 for all.")
    parser.add_argument("--n_worst_lines", type=int, default=0,
                        help="Print the n worst recognized text lines with their errors")
    parser.add_argument("--xlsx_output", type=str,
                        help="Optionally write a xlsx file with the evaluation results")
    parser.add_argument("--num_threads", type=int, default=1,
                        help="Number of threads to use for evaluation")
    parser.add_argument("--non_existing_file_handling_mode", type=str, default="error",
                        help="How to handle non existing .pred.txt files. Possible modes: skip, empty, error. "
                             "'Skip' will simply skip the evaluation of that file (not counting it to errors). "
                             "'Empty' will handle this file as if it were empty (fully checking for errors). "
                             "'Error' will throw an exception if a file is not existing. This is the default behaviour.")
    parser.add_argument("--no_progress_bars", action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)")

    args = parser.parse_args()

    print("Resolving files")
    gt_files = sorted(glob_all(args.gt))

    if args.pred:
        pred_files = sorted(glob_all(args.pred))
        if len(pred_files) != len(gt_files):
            raise Exception("Mismatch in the number of gt and pred files: {} vs {}".format(
                len(gt_files), len(pred_files)))
    else:
        pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files]

    if args.non_existing_file_handling_mode.lower() == "skip":
        non_existing_pred = [p for p in pred_files if not os.path.exists(p)]
        for f in non_existing_pred:
            idx = pred_files.index(f)
            del pred_files[idx]
            del gt_files[idx]

    text_preproc = None
    if args.checkpoint:
        with open(args.checkpoint if args.checkpoint.endswith(".json") else args.checkpoint + '.json', 'r') as f:
            checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
            text_preproc = text_processor_from_proto(checkpoint_params.model.text_preprocessor)

    non_existing_as_empty = args.non_existing_file_handling_mode.lower() == "empty"
    gt_data_set = FileDataSet(texts=gt_files, non_existing_as_empty=non_existing_as_empty)
    pred_data_set = FileDataSet(texts=pred_files, non_existing_as_empty=non_existing_as_empty)

    evaluator = Evaluator(text_preprocessor=text_preproc)
    r = evaluator.run(gt_dataset=gt_data_set, pred_dataset=pred_data_set, processes=args.num_threads,
                      progress_bar=not args.no_progress_bars)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print("Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
        r["avg_ler"], r["total_char_errs"], r["total_chars"], r["total_sync_errs"]))

    # sort descending
    print_confusions(r, args.n_confusions)

    print_worst_lines(r, gt_files, gt_data_set.text_samples(), pred_data_set.text_samples(), args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output,
                   [{
                       "prefix": "evaluation",
                       "results": r,
                       "gt_files": gt_files,
                       "gts": gt_data_set.text_samples(),
                       "preds": pred_data_set.text_samples()
                   }])
def main():
    parser = ArgumentParser()
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--gt", nargs="+", required=True,
                        help="Ground truth files (.gt.txt extension)")
    parser.add_argument("--pred", nargs="+", default=None,
                        help="Prediction files if provided. Else files with .pred.txt are expected at the same "
                             "location as the gt.")
    parser.add_argument("--pred_dataset", type=DataSetType.from_string, choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--pred_ext", type=str, default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument("--n_confusions", type=int, default=10,
                        help="Only print n most common confusions. Defaults to 10, use -1 for all.")
    parser.add_argument("--n_worst_lines", type=int, default=0,
                        help="Print the n worst recognized text lines with their errors")
    parser.add_argument("--xlsx_output", type=str,
                        help="Optionally write a xlsx file with the evaluation results")
    parser.add_argument("--num_threads", type=int, default=1,
                        help="Number of threads to use for evaluation")
    parser.add_argument("--non_existing_file_handling_mode", type=str, default="error",
                        help="How to handle non existing .pred.txt files. Possible modes: skip, empty, error. "
                             "'Skip' will simply skip the evaluation of that file (not counting it to errors). "
                             "'Empty' will handle this file as if it were empty (fully checking for errors). "
                             "'Error' will throw an exception if a file is not existing. This is the default behaviour.")
    parser.add_argument("--no_progress_bars", action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)")

    # page xml specific args
    parser.add_argument("--pagexml_gt_text_index", default=0)
    parser.add_argument("--pagexml_pred_text_index", default=1)

    args = parser.parse_args()

    print("Resolving files")
    gt_files = sorted(glob_all(args.gt))

    if args.pred:
        pred_files = sorted(glob_all(args.pred))
    else:
        pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files]
        args.pred_dataset = args.dataset

    if args.non_existing_file_handling_mode.lower() == "skip":
        non_existing_pred = [p for p in pred_files if not os.path.exists(p)]
        for f in non_existing_pred:
            idx = pred_files.index(f)
            del pred_files[idx]
            del gt_files[idx]

    text_preproc = None
    if args.checkpoint:
        with open(args.checkpoint if args.checkpoint.endswith(".json") else args.checkpoint + '.json', 'r') as f:
            checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
            text_preproc = text_processor_from_proto(checkpoint_params.model.text_preprocessor)

    non_existing_as_empty = args.non_existing_file_handling_mode.lower() != "error"

    gt_data_set = create_dataset(
        args.dataset,
        DataSetMode.EVAL,
        texts=gt_files,
        non_existing_as_empty=non_existing_as_empty,
        args={'text_index': args.pagexml_gt_text_index},
    )
    pred_data_set = create_dataset(
        args.pred_dataset,
        DataSetMode.EVAL,
        texts=pred_files,
        non_existing_as_empty=non_existing_as_empty,
        args={'text_index': args.pagexml_pred_text_index},
    )

    evaluator = Evaluator(text_preprocessor=text_preproc)
    r = evaluator.run(gt_dataset=gt_data_set, pred_dataset=pred_data_set, processes=args.num_threads,
                      progress_bar=not args.no_progress_bars)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print("Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
        r["avg_ler"], r["total_char_errs"], r["total_chars"], r["total_sync_errs"]))

    # sort descending
    print_confusions(r, args.n_confusions)

    print_worst_lines(r, gt_data_set.samples(), pred_data_set.text_samples(), args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output,
                   [{
                       "prefix": "evaluation",
                       "results": r,
                       "gt_files": gt_files,
                       "gts": gt_data_set.text_samples(),
                       "preds": pred_data_set.text_samples()
                   }])
def _run_train(self, train_net, test_net, codec, train_start_time, progress_bar):
    checkpoint_params = self.checkpoint_params
    validation_dataset = test_net.input_dataset
    iters_per_epoch = max(1, int(len(train_net.input_dataset) / checkpoint_params.batch_size))

    loss_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.loss_stats)
    ler_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.ler_stats)
    dt_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.dt_stats)

    display = checkpoint_params.display
    display_epochs = display <= 1
    if display <= 0:
        display = 0  # to not display anything
    elif display_epochs:
        display = max(1, int(display * iters_per_epoch))  # relative to epochs
    else:
        display = max(1, int(display))  # iterations

    checkpoint_frequency = checkpoint_params.checkpoint_frequency
    early_stopping_frequency = checkpoint_params.early_stopping_frequency
    if early_stopping_frequency < 0:
        # set early stopping frequency to half epoch
        early_stopping_frequency = int(0.5 * iters_per_epoch)
    elif 0 < early_stopping_frequency <= 1:
        early_stopping_frequency = int(early_stopping_frequency * iters_per_epoch)  # relative to epochs
    else:
        early_stopping_frequency = int(early_stopping_frequency)

    if checkpoint_frequency < 0:
        checkpoint_frequency = early_stopping_frequency
    elif 0 < checkpoint_frequency <= 1:
        checkpoint_frequency = int(checkpoint_frequency * iters_per_epoch)  # relative to epochs
    else:
        checkpoint_frequency = int(checkpoint_frequency)

    early_stopping_enabled = self.validation_dataset is not None \
        and checkpoint_params.early_stopping_frequency > 0 \
        and checkpoint_params.early_stopping_nbest > 1
    early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
    early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
    early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

    early_stopping_predictor = Predictor(codec=codec, text_postproc=self.txt_postproc, network=test_net)

    # Start the actual training
    # ====================================================================================

    iter = checkpoint_params.iter

    # helper function to write a checkpoint
    def make_checkpoint(base_dir, prefix, version=None):
        if version:
            checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{}.ckpt".format(prefix, version)))
        else:
            checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{:08d}.ckpt".format(prefix, iter + 1)))
        print("Storing checkpoint to '{}'".format(checkpoint_path))
        train_net.save_checkpoint(checkpoint_path)
        checkpoint_params.version = Checkpoint.VERSION
        checkpoint_params.iter = iter
        checkpoint_params.loss_stats[:] = loss_stats.values
        checkpoint_params.ler_stats[:] = ler_stats.values
        checkpoint_params.dt_stats[:] = dt_stats.values
        checkpoint_params.total_time = time.time() - train_start_time
        checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
        checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
        checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

        with open(checkpoint_path + ".json", 'w') as f:
            f.write(json_format.MessageToJson(checkpoint_params))

        return checkpoint_path

    try:
        last_checkpoint = None
        n_infinite_losses = 0
        n_max_infinite_losses = 5

        # Training loop, can be interrupted by early stopping
        for iter in range(iter, checkpoint_params.max_iters):
            checkpoint_params.iter = iter

            iter_start_time = time.time()
            result = train_net.train_step()

            if not np.isfinite(result['loss']):
                n_infinite_losses += 1

                if n_max_infinite_losses == n_infinite_losses:
                    print("Error: Loss is not finite! Trying to restart from last checkpoint.")
                    if not last_checkpoint:
                        raise Exception("No checkpoint written yet. Training must be stopped.")
                    else:
                        # reload also non trainable weights, such as solver-specific variables
                        train_net.load_weights(last_checkpoint, restore_only_trainable=False)
                        continue
                else:
                    continue

            n_infinite_losses = 0
            loss_stats.push(result['loss'])
            ler_stats.push(result['ler'])
            dt_stats.push(time.time() - iter_start_time)

            if display > 0 and iter % display == 0:
                # apply postprocessing to display the true output
                pred_sentence = self.txt_postproc.apply("".join(codec.decode(result["decoded"][0])))
                gt_sentence = self.txt_postproc.apply("".join(codec.decode(result["gt"][0])))

                if display_epochs:
                    print("#{:08f}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(
                        iter / iters_per_epoch, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))
                else:
                    print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(
                        iter, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))

                # Insert utf-8 ltr/rtl direction marks for bidi support
                lr = "\u202A\u202B"
                print(" PRED: '{}{}{}'".format(lr[bidi.get_base_level(pred_sentence)], pred_sentence, "\u202C"))
                print(" TRUE: '{}{}{}'".format(lr[bidi.get_base_level(gt_sentence)], gt_sentence, "\u202C"))

            if checkpoint_frequency > 0 and (iter + 1) % checkpoint_frequency == 0:
                last_checkpoint = make_checkpoint(checkpoint_params.output_dir,
                                                  checkpoint_params.output_model_prefix)

            if early_stopping_enabled and (iter + 1) % early_stopping_frequency == 0:
                print("Checking early stopping model")

                out_gen = early_stopping_predictor.predict_input_dataset(validation_dataset,
                                                                         progress_bar=progress_bar)
                result = Evaluator.evaluate_single_list(
                    map(Evaluator.evaluate_single_args,
                        map(lambda d: tuple(self.txt_preproc.apply([''.join(d.ground_truth), d.sentence])),
                            out_gen)))
                accuracy = 1 - result["avg_ler"]

                if accuracy > early_stopping_best_accuracy:
                    early_stopping_best_accuracy = accuracy
                    early_stopping_best_cur_nbest = 1
                    early_stopping_best_at_iter = iter + 1
                    # overwrite as best model
                    last_checkpoint = make_checkpoint(
                        checkpoint_params.early_stopping_best_model_output_dir,
                        prefix="",
                        version=checkpoint_params.early_stopping_best_model_prefix,
                    )
                    print("Found better model with accuracy of {:%}".format(early_stopping_best_accuracy))
                else:
                    early_stopping_best_cur_nbest += 1
                    print("No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})".format(
                        early_stopping_best_accuracy, early_stopping_best_at_iter,
                        checkpoint_params.early_stopping_nbest - early_stopping_best_cur_nbest))

                if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                    print("Early stopping now.")
                    break

                if accuracy >= 1:
                    print("Reached perfect score on validation set. Early stopping now.")
                    break

    except KeyboardInterrupt as e:
        print("Storing interrupted checkpoint")
        make_checkpoint(checkpoint_params.output_dir, checkpoint_params.output_model_prefix, "interrupted")
        raise e

    print("Total time {}s for {} iterations.".format(time.time() - train_start_time, iter))