Example #1
0
    def evaluate_books(self, books, models, mode="auto", sample=-1):
        """Score one or more models on books stored in the HDF5 cache.

        Args:
            books: A book name or a list of book names; each is looked up as a
                top-level group of ``self.cachefile``.
            models: A checkpoint path or a list whose items are either a
                checkpoint path or a list of checkpoint paths (handed together
                to a single ``MultiPredictor``).
            mode: ``"conf"`` scores a model by the mean ``avg_char_probability``
                of its voted predictions; ``"eval"`` scores it by
                ``1 - avg_ler`` against the cached ground truth; ``"auto"``
                selects ``"eval"`` if any cached line of ``books`` has a
                ``"text"`` attribute and ``"conf"`` otherwise.
            sample: When ``0 < sample < len(dataset)``, only a random subset of
                ``sample`` lines is used.

        Returns:
            dict mapping ``"/".join(checkpoints)`` of each model to its score.
            In ``"conf"`` mode an empty dataset yields a score of ``0.0``.
        """
        if isinstance(books, str):
            books = [books]
        if isinstance(models, str):
            models = [models]
        results = {}

        if mode == "auto":
            # Ground truth is available if any line of any requested book has a
            # "text" attribute; any() stops at the first match.
            with h5py.File(self.cachefile, 'r', libver='latest', swmr=True) as cache:
                has_gt = any("text" in cache[b][p][s].attrs
                             for b in books
                             for p in cache[b]
                             for s in cache[b][p])
            mode = "eval" if has_gt else "conf"

        if mode == "conf":
            dset = Nash5DataSet(DataSetMode.PREDICT, self.cachefile, books)
        else:
            dset = Nash5DataSet(DataSetMode.EVAL, self.cachefile, books)

        if 0 < sample < len(dset):
            # Subsample by removing randomly chosen lines from the dataset.
            delsamples = random.sample(dset._samples, len(dset) - sample)
            for s in delsamples:
                dset._samples.remove(s)

        if mode == "conf":
            for model in models:
                if isinstance(model, str):
                    model = [model]
                predictor = MultiPredictor(checkpoints=model, data_preproc=NoopDataPreprocessor(), batch_size=1, processes=1)
                voter_params = VoterParams()
                voter_params.type = VoterParams.Type.Value("confidence_voter_default_ctc".upper())
                voter = voter_from_proto(voter_params)
                do_prediction = predictor.predict_dataset(dset, progress_bar=True)
                avg_sentence_confidence = 0
                n_predictions = 0
                # `_` rather than `sample` so the `sample` argument is not shadowed.
                for result, _ in do_prediction:
                    n_predictions += 1
                    prediction = voter.vote_prediction_result(result)
                    avg_sentence_confidence += prediction.avg_char_probability
                # Avoid ZeroDivisionError when the dataset is empty.
                results["/".join(model)] = (avg_sentence_confidence / n_predictions
                                            if n_predictions else 0.0)

        else:
            for model in models:
                if isinstance(model, str):
                    model = [model]
                # `checkpoints=` is the keyword the "conf" branch uses for
                # MultiPredictor.
                # NOTE(review): `with_gt` is passed through as before; confirm the
                # installed MultiPredictor accepts it.
                predictor = MultiPredictor(checkpoints=model, data_preproc=NoopDataPreprocessor(), batch_size=1, processes=1, with_gt=True)
                out_gen = predictor.predict_dataset(dset, progress_bar=True, apply_preproc=False)
                # Lazily build (ground truth, prediction) string pairs.
                gt_pred_pairs = ((''.join(d[0].ground_truth), ''.join(d[0].chars))
                                 for d in out_gen)
                result = Evaluator.evaluate_single_list(map(Evaluator.evaluate_single_args, gt_pred_pairs))
                results["/".join(model)] = 1 - result["avg_ler"]
        return results
Example #2
0
def main():
    """Predict line images with one or more models, vote, and evaluate.

    Command line entry point. Every checkpoint given via ``--checkpoint``
    predicts each image matched by ``--eval_imgs``. Each individual model's
    sentences and the output of each requested voter are evaluated against
    the ``.gt.txt`` file located next to each image. Results are printed with
    ``--verbose`` and optionally pickled to the ``--dump`` path.

    Raises:
        Exception: if no evaluation files are found, or if the number of
            predicted sentences differs from the number of dataset entries.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_imgs",
                        type=str,
                        nargs="+",
                        required=True,
                        help="The evaluation files")
    parser.add_argument("--checkpoint",
                        type=str,
                        nargs="+",
                        default=[],
                        help="Path to the checkpoint without file extension")
    parser.add_argument("-j",
                        "--processes",
                        type=int,
                        default=1,
                        help="Number of processes to use")
    parser.add_argument("--verbose",
                        action="store_true",
                        help="Print additional information")
    parser.add_argument(
        "--voter",
        type=str,
        nargs="+",
        default=[
            "sequence_voter", "confidence_voter_default_ctc",
            "confidence_voter_fuzzy_ctc"
        ],
        help=
        "The voting algorithm(s) to use. Possible values: confidence_voter_default_ctc, "
        "confidence_voter_fuzzy_ctc, sequence_voter. By default all three are evaluated.")
    parser.add_argument("--batch_size",
                        type=int,
                        default=10,
                        help="The batch size for prediction")
    parser.add_argument("--dump",
                        type=str,
                        help="Dump the output as serialized pickle object")
    parser.add_argument(
        "--no_skip_invalid_gt",
        action="store_true",
        help="Do not skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # allow user to specify json file for model definition, but remove the file extension
    # for further processing
    args.checkpoint = [(cp[:-5] if cp.endswith(".json") else cp)
                       for cp in args.checkpoint]

    # load files; the ground truth of each image is "<image without ext>.gt.txt"
    gt_images = sorted(glob_all(args.eval_imgs))
    gt_txts = [split_all_ext(path)[0] + ".gt.txt" for path in gt_images]

    dataset = FileDataSet(images=gt_images,
                          texts=gt_txts,
                          skip_invalid=not args.no_skip_invalid_gt)

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        # report the argument that was actually given (there is no `files` argument)
        raise Exception(
            "Empty dataset provided. Check your --eval_imgs argument (got {})!".
            format(args.eval_imgs))

    # predict for all models
    n_models = len(args.checkpoint)
    predictor = MultiPredictor(checkpoints=args.checkpoint,
                               batch_size=args.batch_size,
                               processes=args.processes)
    do_prediction = predictor.predict_dataset(dataset, progress_bar=True)

    voters = []
    all_voter_sentences = []
    all_prediction_sentences = [[] for _ in range(n_models)]

    for voter_name in args.voter:
        # create voter
        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value(voter_name.upper())
        voters.append(voter_from_proto(voter_params))
        all_voter_sentences.append([])

    for prediction, sample in do_prediction:
        # collect the sentence of each individual model
        for sent, p in zip(all_prediction_sentences, prediction):
            sent.append(p.sentence)

        # vote results
        for voter, voter_sentences in zip(voters, all_voter_sentences):
            voter_sentences.append(
                voter.vote_prediction_result(prediction).sentence)

    # evaluation, using the text preprocessor of the first model for the gt
    text_preproc = text_processor_from_proto(
        predictor.predictors[0].model_params.text_preprocessor)
    evaluator = Evaluator(text_preprocessor=text_preproc)
    evaluator.preload_gt(gt_dataset=dataset, progress_bar=True)

    def single_evaluation(predicted_sentences):
        """Evaluate one list of predicted sentences against the preloaded gt."""
        if len(predicted_sentences) != len(dataset):
            raise Exception(
                "Mismatch in number of gt and pred files: {} != {}. Probably, the prediction did "
                "not succeed".format(len(dataset), len(predicted_sentences)))

        pred_data_set = RawDataSet(texts=predicted_sentences)

        r = evaluator.run(pred_dataset=pred_data_set,
                          progress_bar=True,
                          processes=args.processes)

        return r

    # individual models are keyed by their index, voters by their name
    labelled_predictions = [
        (str(i), sent) for i, sent in enumerate(all_prediction_sentences)
    ]
    labelled_predictions += list(zip(args.voter, all_voter_sentences))

    full_evaluation = {}
    for key, data in labelled_predictions:
        full_evaluation[key] = {"eval": single_evaluation(data), "data": data}

    if args.verbose:
        print(full_evaluation)

    if args.dump:
        import pickle
        with open(args.dump, 'wb') as f:
            pickle.dump(
                {
                    "full": full_evaluation,
                    "gt_txts": gt_txts,
                    "gt": dataset.text_samples()
                }, f)
Example #3
0
    def train(self, progress_bar=False):
        """Train the model, writing checkpoints until max_iters or early stopping.

        Loads and preprocesses the training data and the optional validation
        data, and builds or restores the codec. It then creates the backend and
        runs the training loop. During the loop:

        * a checkpoint is written every ``checkpoint_frequency`` iterations;
        * when early stopping is enabled, the best validation model is written;
        * on KeyboardInterrupt an ``interrupted`` checkpoint is written and
          the interrupt is re-raised.

        Args:
            progress_bar: Show progress bars while loading/preprocessing data
                and while predicting the validation data.

        Raises:
            Exception: if the training or validation dataset is empty, if the
                restored model's line height differs from the requested one,
                or if the loss becomes non-finite before any checkpoint exists.
        """
        checkpoint_params = self.checkpoint_params

        # Shift the start back by the time already spent in previous runs so that
        # `time.time() - train_start_time` is the accumulated total training time.
        train_start_time = time.time() - self.checkpoint_params.total_time

        self.dataset.load_samples(processes=1, progress_bar=progress_bar)
        datas, txts = self.dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
        if len(datas) == 0:
            raise Exception("Empty dataset is not allowed. Check if the data is at the correct location")

        if self.validation_dataset:
            self.validation_dataset.load_samples(processes=1, progress_bar=progress_bar)
            validation_datas, validation_txts = self.validation_dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
            if len(validation_datas) == 0:
                raise Exception("Validation dataset is empty. Provide valid validation data for early stopping.")
        else:
            validation_datas, validation_txts = [], []

        # preprocessing steps
        texts = self.txt_preproc.apply(txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
        datas = self.data_preproc.apply(datas, processes=checkpoint_params.processes, progress_bar=progress_bar)
        validation_txts = self.txt_preproc.apply(validation_txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
        validation_datas = self.data_preproc.apply(validation_datas, processes=checkpoint_params.processes, progress_bar=progress_bar)

        # compute the codec
        codec = self.codec if self.codec else Codec.from_texts(texts, whitelist=self.codec_whitelist)

        # data augmentation on preprocessed data
        if self.data_augmenter:
            datas, texts = self.data_augmenter.augment_datas(datas, texts, n_augmentations=self.n_augmentations,
                                                             processes=checkpoint_params.processes, progress_bar=progress_bar)

            # TODO: validation data augmentation

        # create backend
        network_params = checkpoint_params.model.network
        network_params.features = checkpoint_params.model.line_height
        network_params.classes = len(codec)
        if self.weights:
            # if we load the weights, take care of codec changes as-well
            with open(self.weights + '.json', 'r') as f:
                restore_checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
                restore_model_params = restore_checkpoint_params.model

            # checks: the restored network must use the requested line height
            # (network_params.features was set from the requested value above,
            # so the comparison has to be against the restored model params)
            if restore_model_params.line_height != checkpoint_params.model.line_height:
                raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                    restore_model_params.line_height, checkpoint_params.model.line_height
                ))

            # create codec of the same type
            restore_codec = codec.__class__(restore_model_params.codec.charset)
            # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
            codec_changes = restore_codec.align(codec)
            codec = restore_codec
            print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
            # The actual weight/bias matrix will be changed after loading the old weights
        else:
            codec_changes = None

        # store the new codec
        checkpoint_params.model.codec.charset[:] = codec.charset
        print("CODEC: {}".format(codec.charset))

        # compute the labels with (new/current) codec
        labels = [codec.encode(txt) for txt in texts]

        backend = create_backend_from_proto(network_params,
                                            weights=self.weights,
                                            )
        backend.set_train_data(datas, labels)
        backend.set_prediction_data(validation_datas)
        if codec_changes:
            backend.realign_model_labels(*codec_changes)
        backend.prepare(train=True)

        loss_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.loss_stats)
        ler_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.ler_stats)
        dt_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.dt_stats)

        early_stopping_enabled = self.validation_dataset is not None \
                                 and checkpoint_params.early_stopping_frequency > 0 \
                                 and checkpoint_params.early_stopping_nbest > 1
        early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
        early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
        early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

        early_stopping_predictor = Predictor(codec=codec, text_postproc=self.txt_postproc,
                                             backend=backend)

        # Start the actual training
        # ====================================================================================

        iter = checkpoint_params.iter

        # helper function to write a checkpoint (reads the current `iter` and stats)
        def make_checkpoint(base_dir, prefix, version=None):
            if version:
                checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{}.ckpt".format(prefix, version)))
            else:
                checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{:08d}.ckpt".format(prefix, iter + 1)))
            print("Storing checkpoint to '{}'".format(checkpoint_path))
            backend.save_checkpoint(checkpoint_path)
            checkpoint_params.iter = iter
            checkpoint_params.loss_stats[:] = loss_stats.values
            checkpoint_params.ler_stats[:] = ler_stats.values
            checkpoint_params.dt_stats[:] = dt_stats.values
            checkpoint_params.total_time = time.time() - train_start_time
            checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
            checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
            checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

            with open(checkpoint_path + ".json", 'w') as f:
                f.write(json_format.MessageToJson(checkpoint_params))

            return checkpoint_path

        try:
            last_checkpoint = None

            # Training loop, can be interrupted by early stopping
            for iter in range(iter, checkpoint_params.max_iters):
                checkpoint_params.iter = iter

                iter_start_time = time.time()
                result = backend.train_step(checkpoint_params.batch_size)

                if not np.isfinite(result['loss']):
                    print("Error: Loss is not finite! Trying to restart from last checkpoint.")
                    if not last_checkpoint:
                        raise Exception("No checkpoint written yet. Training must be stopped.")
                    else:
                        # reload also non trainable weights, such as solver-specific variables
                        backend.load_checkpoint_weights(last_checkpoint, restore_only_trainable=False)
                        continue

                loss_stats.push(result['loss'])
                ler_stats.push(result['ler'])

                dt_stats.push(time.time() - iter_start_time)

                if iter % checkpoint_params.display == 0:
                    pred_sentence = self.txt_postproc.apply("".join(codec.decode(result["decoded"][0])))
                    gt_sentence = self.txt_postproc.apply("".join(codec.decode(result["gt"][0])))
                    print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(iter, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))
                    print(" PRED: '{}'".format(pred_sentence))
                    print(" TRUE: '{}'".format(gt_sentence))

                if (iter + 1) % checkpoint_params.checkpoint_frequency == 0:
                    last_checkpoint = make_checkpoint(checkpoint_params.output_dir, checkpoint_params.output_model_prefix)

                if early_stopping_enabled and (iter + 1) % checkpoint_params.early_stopping_frequency == 0:
                    print("Checking early stopping model")

                    out = early_stopping_predictor.predict_raw(validation_datas, batch_size=checkpoint_params.batch_size,
                                                               progress_bar=progress_bar, apply_preproc=False)
                    pred_texts = [d.sentence for d in out]
                    result = Evaluator.evaluate(gt_data=validation_txts, pred_data=pred_texts, progress_bar=progress_bar)
                    accuracy = 1 - result["avg_ler"]

                    if accuracy > early_stopping_best_accuracy:
                        early_stopping_best_accuracy = accuracy
                        early_stopping_best_cur_nbest = 1
                        early_stopping_best_at_iter = iter + 1
                        # overwrite as best model
                        last_checkpoint = make_checkpoint(
                            checkpoint_params.early_stopping_best_model_output_dir,
                            prefix="",
                            version=checkpoint_params.early_stopping_best_model_prefix,
                        )
                        print("Found better model with accuracy of {:%}".format(early_stopping_best_accuracy))
                    else:
                        early_stopping_best_cur_nbest += 1
                        print("No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})".
                              format(early_stopping_best_accuracy, early_stopping_best_at_iter,
                                     checkpoint_params.early_stopping_nbest - early_stopping_best_cur_nbest))

                    if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                        print("Early stopping now.")
                        break

        except KeyboardInterrupt:
            print("Storing interrupted checkpoint")
            make_checkpoint(checkpoint_params.output_dir,
                            checkpoint_params.output_model_prefix,
                            "interrupted")
            raise

        print("Total time {}s for {} iterations.".format(time.time() - train_start_time, iter))
Example #4
0
def main():
    """Evaluate predicted text files against ground truth files and report errors.

    Command line entry point. The ground truth files come from ``--gt``, or
    from a single json file whose keys override the parsed arguments. The
    predictions come from ``--pred``, or otherwise from files next to the gt
    with extension ``--pred_ext``. The function prints the mean label error
    rate, the most common confusions and the worst lines, and can optionally
    write an xlsx report.
    """
    parser = ArgumentParser()
    parser.add_argument("--dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument(
        "--gt",
        nargs="+",
        required=True,
        help="Ground truth files (.gt.txt extension). "
        "Optionally, you can pass a single json file defining all parameters.")
    parser.add_argument(
        "--pred",
        nargs="+",
        default=None,
        help=
        "Prediction files if provided. Else files with .pred.txt are expected at the same "
        "location as the gt.")
    parser.add_argument("--pred_dataset",
                        type=DataSetType.from_string,
                        choices=list(DataSetType),
                        default=DataSetType.FILE)
    parser.add_argument("--pred_ext",
                        type=str,
                        default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument(
        "--n_confusions",
        type=int,
        default=10,
        help=
        "Only print n most common confusions. Defaults to 10, use -1 for all.")
    parser.add_argument(
        "--n_worst_lines",
        type=int,
        default=0,
        help="Print the n worst recognized text lines with its error")
    parser.add_argument(
        "--xlsx_output",
        type=str,
        help="Optionally write a xlsx file with the evaluation results")
    parser.add_argument("--num_threads",
                        type=int,
                        default=1,
                        help="Number of threads to use for evaluation")
    parser.add_argument(
        "--non_existing_file_handling_mode",
        type=str,
        default="error",
        help=
        "How to handle non existing .pred.txt files. Possible modes: skip, empty, error. "
        "'Skip' will simply skip the evaluation of that file (not counting it to errors). "
        "'Empty' will handle this file as if it were empty (fully checking for errors). "
        "'Error' will throw an exception if a file is not existing. This is the default behaviour."
    )
    parser.add_argument("--skip_empty_gt",
                        action="store_true",
                        default=False,
                        help="Ignore lines of the gt that are empty.")
    parser.add_argument("--no_progress_bars",
                        action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help=
        "Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)"
    )

    # page xml specific args; parse as int so command line values are not strings
    parser.add_argument("--pagexml_gt_text_index", type=int, default=0)
    parser.add_argument("--pagexml_pred_text_index", type=int, default=1)

    args = parser.parse_args()

    # check if loading a json file; its keys override the parsed arguments
    if len(args.gt) == 1 and args.gt[0].endswith("json"):
        with open(args.gt[0], 'r') as f:
            json_args = json.load(f)
            for key, value in json_args.items():
                setattr(args, key, value)

    print("Resolving files")
    gt_files = sorted(glob_all(args.gt))

    if args.pred:
        pred_files = sorted(glob_all(args.pred))
    else:
        pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files]
        args.pred_dataset = args.dataset

    handling_mode = args.non_existing_file_handling_mode.lower()

    if handling_mode == "skip":
        # drop every gt/pred pair (same position) whose prediction file is missing
        missing = {i for i, p in enumerate(pred_files) if not os.path.exists(p)}
        pred_files = [p for i, p in enumerate(pred_files) if i not in missing]
        gt_files = [g for i, g in enumerate(gt_files) if i not in missing]

    text_preproc = None
    if args.checkpoint:
        with open(
                args.checkpoint if args.checkpoint.endswith(".json") else
                args.checkpoint + '.json', 'r') as f:
            checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
            text_preproc = text_processor_from_proto(
                checkpoint_params.model.text_preprocessor)

    # only the "error" mode must fail on missing files
    non_existing_as_empty = handling_mode != "error"
    gt_data_set = create_dataset(
        args.dataset,
        DataSetMode.EVAL,
        texts=gt_files,
        non_existing_as_empty=non_existing_as_empty,
        args={'text_index': args.pagexml_gt_text_index},
    )
    pred_data_set = create_dataset(
        args.pred_dataset,
        DataSetMode.EVAL,
        texts=pred_files,
        non_existing_as_empty=non_existing_as_empty,
        args={'text_index': args.pagexml_pred_text_index},
    )

    evaluator = Evaluator(text_preprocessor=text_preproc,
                          skip_empty_gt=args.skip_empty_gt)
    r = evaluator.run(gt_dataset=gt_data_set,
                      pred_dataset=pred_data_set,
                      processes=args.num_threads,
                      progress_bar=not args.no_progress_bars)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print(
        "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)"
        .format(r["avg_ler"], r["total_char_errs"], r["total_chars"],
                r["total_sync_errs"]))

    # sort descending
    print_confusions(r, args.n_confusions)

    print_worst_lines(r, gt_data_set.samples(), args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output, [{
            "prefix": "evaluation",
            "results": r,
            "gt_files": gt_files,
        }])
Example #5
0
    def evaluate_books(self, books, models, rtl=False, mode="auto", sample=-1):
        """Score one or more models on books stored in the HDF5 cache.

        Args:
            books: A book name or a list of book names; each is looked up as a
                top-level group of ``self.cachefile``.
            models: A checkpoint path or a list whose items are either a
                checkpoint path or a list of checkpoint paths (handed together
                to a single ``MultiPredictor``).
            rtl: In ``"eval"`` mode, apply ``self.bidi_preproc`` instead of
                ``self.txt_preproc`` to the voted predictions.
            mode: ``"conf"`` scores a model by the mean ``avg_char_probability``
                of its voted predictions; ``"eval"`` scores it by
                ``1 - avg_ler`` against the cached ground truth; ``"auto"``
                selects ``"eval"`` if any cached line of ``books`` has a
                ``"text"`` attribute and ``"conf"`` otherwise.
            sample: When ``0 < sample < len(dataset)``, only a random subset of
                ``sample`` lines is used.

        Returns:
            dict mapping ``"/".join(checkpoints)`` of each model to its score.
            In ``"conf"`` mode an empty dataset yields a score of ``0.0``.
        """
        if isinstance(books, str):
            books = [books]
        if isinstance(models, str):
            models = [models]
        results = {}

        if mode == "auto":
            # Ground truth is available if any line of any requested book has a
            # "text" attribute; any() stops at the first match.
            with h5py.File(self.cachefile, 'r', libver='latest',
                           swmr=True) as cache:
                has_gt = any("text" in cache[b][p][s].attrs
                             for b in books
                             for p in cache[b]
                             for s in cache[b][p])
            mode = "eval" if has_gt else "conf"

        if mode == "conf":
            dset = Nash5DataSet(DataSetMode.PREDICT, self.cachefile, books)
        else:
            dset = Nash5DataSet(DataSetMode.TRAIN, self.cachefile, books)
            dset.mode = DataSetMode.PREDICT  # otherwise results are randomised

        if 0 < sample < len(dset):
            # Subsample by removing randomly chosen lines from the dataset.
            delsamples = random.sample(dset._samples, len(dset) - sample)
            for s in delsamples:
                dset._samples.remove(s)

        if mode == "conf":
            for model in models:
                if isinstance(model, str):
                    model = [model]
                predictor = MultiPredictor(checkpoints=model,
                                           data_preproc=NoopDataPreprocessor(),
                                           batch_size=1,
                                           processes=1)
                voter_params = VoterParams()
                voter_params.type = VoterParams.Type.Value(
                    "confidence_voter_default_ctc".upper())
                voter = voter_from_proto(voter_params)
                do_prediction = predictor.predict_dataset(dset,
                                                          progress_bar=True)
                avg_sentence_confidence = 0
                n_predictions = 0
                # `_` rather than `sample` so the `sample` argument is not shadowed.
                for result, _ in do_prediction:
                    n_predictions += 1
                    prediction = voter.vote_prediction_result(result)
                    avg_sentence_confidence += prediction.avg_char_probability

                # Avoid ZeroDivisionError when the dataset is empty.
                results["/".join(model)] = (avg_sentence_confidence / n_predictions
                                            if n_predictions else 0.0)

        else:
            for model in models:
                if isinstance(model, str):
                    model = [model]

                predictor = MultiPredictor(checkpoints=model,
                                           data_preproc=NoopDataPreprocessor(),
                                           batch_size=1,
                                           processes=1)

                voter_params = VoterParams()
                voter_params.type = VoterParams.Type.Value(
                    "confidence_voter_default_ctc".upper())
                voter = voter_from_proto(voter_params)

                out_gen = predictor.predict_dataset(dset, progress_bar=True)

                preproc = self.bidi_preproc if rtl else self.txt_preproc

                # Voted sentences, preprocessed like the ground truth is expected to be.
                voted_sentences = [
                    voter.vote_prediction_result(d[0]).sentence for d in out_gen
                ]
                pred_dset = RawDataSet(DataSetMode.EVAL,
                                       texts=preproc.apply(voted_sentences))

                evaluator = Evaluator(text_preprocessor=NoopTextProcessor(),
                                      skip_empty_gt=False)
                r = evaluator.run(gt_dataset=dset,
                                  pred_dataset=pred_dset,
                                  processes=1,
                                  progress_bar=True)

                results["/".join(model)] = 1 - r["avg_ler"]
        return results
Example #6
0
    def _run_train(self, train_net, test_net, codec, train_start_time,
                   progress_bar):
        checkpoint_params = self.checkpoint_params
        validation_dataset = test_net.input_dataset
        iters_per_epoch = max(
            1,
            int(train_net.input_dataset.epoch_size() /
                checkpoint_params.batch_size))

        loss_stats = RunningStatistics(checkpoint_params.stats_size,
                                       checkpoint_params.loss_stats)
        ler_stats = RunningStatistics(checkpoint_params.stats_size,
                                      checkpoint_params.ler_stats)
        dt_stats = RunningStatistics(checkpoint_params.stats_size,
                                     checkpoint_params.dt_stats)

        display = checkpoint_params.display
        display_epochs = display <= 1
        if display <= 0:
            display = 0  # to not display anything
        elif display_epochs:
            display = max(1,
                          int(display * iters_per_epoch))  # relative to epochs
        else:
            display = max(1, int(display))  # iterations

        checkpoint_frequency = checkpoint_params.checkpoint_frequency
        early_stopping_frequency = checkpoint_params.early_stopping_frequency
        if early_stopping_frequency < 0:
            # set early stopping frequency to half epoch
            early_stopping_frequency = int(0.5 * iters_per_epoch)
        elif 0 < early_stopping_frequency <= 1:
            early_stopping_frequency = int(
                early_stopping_frequency *
                iters_per_epoch)  # relative to epochs
        else:
            early_stopping_frequency = int(early_stopping_frequency)
        early_stopping_frequency = max(1, early_stopping_frequency)

        if checkpoint_frequency < 0:
            checkpoint_frequency = early_stopping_frequency
        elif 0 < checkpoint_frequency <= 1:
            checkpoint_frequency = int(checkpoint_frequency *
                                       iters_per_epoch)  # relative to epochs
        else:
            checkpoint_frequency = int(checkpoint_frequency)

        early_stopping_enabled = self.validation_dataset is not None \
                                 and checkpoint_params.early_stopping_frequency > 0 \
                                 and checkpoint_params.early_stopping_nbest > 1
        early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
        early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
        early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

        early_stopping_predictor = Predictor(codec=codec,
                                             text_postproc=self.txt_postproc,
                                             network=test_net)

        # Start the actual training
        # ====================================================================================

        iter = checkpoint_params.iter

        # helper function to write a checkpoint
        def make_checkpoint(base_dir, prefix, version=None):
            if version:
                checkpoint_path = os.path.abspath(
                    os.path.join(base_dir, "{}{}.ckpt".format(prefix,
                                                              version)))
            else:
                checkpoint_path = os.path.abspath(
                    os.path.join(base_dir,
                                 "{}{:08d}.ckpt".format(prefix, iter + 1)))
            print("Storing checkpoint to '{}'".format(checkpoint_path))
            train_net.save_checkpoint(checkpoint_path)
            checkpoint_params.version = Checkpoint.VERSION
            checkpoint_params.iter = iter
            checkpoint_params.loss_stats[:] = loss_stats.values
            checkpoint_params.ler_stats[:] = ler_stats.values
            checkpoint_params.dt_stats[:] = dt_stats.values
            checkpoint_params.total_time = time.time() - train_start_time
            checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
            checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
            checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

            with open(checkpoint_path + ".json", 'w') as f:
                f.write(json_format.MessageToJson(checkpoint_params))

            return checkpoint_path

        try:
            last_checkpoint = None
            n_infinite_losses = 0
            n_max_infinite_losses = 5

            # Training loop, can be interrupted by early stopping
            for iter in range(iter, checkpoint_params.max_iters):
                checkpoint_params.iter = iter

                iter_start_time = time.time()
                result = train_net.train_step()

                if not np.isfinite(result['loss']):
                    n_infinite_losses += 1

                    if n_max_infinite_losses == n_infinite_losses:
                        print(
                            "Error: Loss is not finite! Trying to restart from last checkpoint."
                        )
                        if not last_checkpoint:
                            raise Exception(
                                "No checkpoint written yet. Training must be stopped."
                            )
                        else:
                            # reload also non trainable weights, such as solver-specific variables
                            train_net.load_weights(
                                last_checkpoint, restore_only_trainable=False)
                            continue
                    else:
                        continue

                n_infinite_losses = 0

                loss_stats.push(result['loss'])
                ler_stats.push(result['ler'])

                dt_stats.push(time.time() - iter_start_time)

                if display > 0 and iter % display == 0:
                    # apply postprocessing to display the true output
                    pred_sentence = self.txt_postproc.apply("".join(
                        codec.decode(result["decoded"][0])))
                    gt_sentence = self.txt_postproc.apply("".join(
                        codec.decode(result["gt"][0])))

                    if display_epochs:
                        print("#{:08f}: loss={:.8f} ler={:.8f} dt={:.8f}s".
                              format(iter / iters_per_epoch, loss_stats.mean(),
                                     ler_stats.mean(), dt_stats.mean()))
                    else:
                        print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".
                              format(iter, loss_stats.mean(), ler_stats.mean(),
                                     dt_stats.mean()))

                    # Insert utf-8 ltr/rtl direction marks for bidi support
                    lr = "\u202A\u202B"
                    print(" PRED: '{}{}{}'".format(
                        lr[bidi.get_base_level(pred_sentence)], pred_sentence,
                        "\u202C"))
                    print(" TRUE: '{}{}{}'".format(
                        lr[bidi.get_base_level(gt_sentence)], gt_sentence,
                        "\u202C"))

                if checkpoint_frequency > 0 and (
                        iter + 1) % checkpoint_frequency == 0:
                    last_checkpoint = make_checkpoint(
                        checkpoint_params.output_dir,
                        checkpoint_params.output_model_prefix)

                if early_stopping_enabled and (
                        iter + 1) % early_stopping_frequency == 0:
                    print("Checking early stopping model")

                    out_gen = early_stopping_predictor.predict_input_dataset(
                        validation_dataset, progress_bar=progress_bar)
                    result = Evaluator.evaluate_single_list(
                        map(
                            Evaluator.evaluate_single_args,
                            map(
                                lambda d: tuple(
                                    self.txt_preproc.apply([
                                        ''.join(d.ground_truth), d.sentence
                                    ])), out_gen)))
                    accuracy = 1 - result["avg_ler"]

                    if accuracy > early_stopping_best_accuracy:
                        early_stopping_best_accuracy = accuracy
                        early_stopping_best_cur_nbest = 1
                        early_stopping_best_at_iter = iter + 1
                        # overwrite as best model
                        last_checkpoint = make_checkpoint(
                            checkpoint_params.
                            early_stopping_best_model_output_dir,
                            prefix="",
                            version=checkpoint_params.
                            early_stopping_best_model_prefix,
                        )
                        print(
                            "Found better model with accuracy of {:%}".format(
                                early_stopping_best_accuracy))
                    else:
                        early_stopping_best_cur_nbest += 1
                        print(
                            "No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})"
                            .format(
                                early_stopping_best_accuracy,
                                early_stopping_best_at_iter,
                                checkpoint_params.early_stopping_nbest -
                                early_stopping_best_cur_nbest))

                    if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                        print("Early stopping now.")
                        break

                    if accuracy >= 1:
                        print(
                            "Reached perfect score on validation set. Early stopping now."
                        )
                        break

        except KeyboardInterrupt as e:
            print("Storing interrupted checkpoint")
            make_checkpoint(checkpoint_params.output_dir,
                            checkpoint_params.output_model_prefix,
                            "interrupted")
            raise e

        print("Total time {}s for {} iterations.".format(
            time.time() - train_start_time, iter))
Example #7
0
def main():
    """Compare prediction text files with ground-truth text files and print error statistics.

    Prediction files are either passed explicitly via ``--pred`` (paired with the gt files
    by sorted order) or derived from each gt path by replacing its extension with
    ``--pred_ext``. After the overall label error rate, the most frequent confusions are
    listed: all of them for a negative ``--n_confusions``, none for 0.

    Raises:
        Exception: if ``--pred`` resolves to a different number of files than ``--gt``.
    """
    parser = ArgumentParser()
    parser.add_argument("--gt", nargs="+", required=True,
                        help="Ground truth files (.gt.txt extension)")
    parser.add_argument("--pred", nargs="+", default=None,
                        help="Prediction files if provided. Else files with .pred.txt are expected at the same "
                             "location as the gt.")
    parser.add_argument("--pred_ext", type=str, default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument("--n_confusions", type=int, default=-1,
                        help="Only print n most common confusions")
    parser.add_argument("--num_threads", type=int, default=1,
                        help="Number of threads to use for evaluation")
    args = parser.parse_args()

    gt_files = sorted(glob_all(args.gt))
    if not args.pred:
        # derive "<name><pred_ext>" next to every gt file
        pred_files = [split_all_ext(gt_path)[0] + args.pred_ext for gt_path in gt_files]
    else:
        pred_files = sorted(glob_all(args.pred))
        if len(gt_files) != len(pred_files):
            raise Exception("Mismatch in the number of gt and pred files: {} vs {}".format(
                len(gt_files), len(pred_files)))

    gt_data_set = FileDataSet(texts=gt_files)
    pred_data_set = FileDataSet(texts=pred_files)
    stats = Evaluator().run(gt_dataset=gt_data_set, pred_dataset=pred_data_set,
                            processes=args.num_threads, progress_bar=True)

    # TODO: More output
    for header_line in ("Evaluation result", "=================", ""):
        print(header_line)
    print("Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
        stats["avg_ler"], stats["total_char_errs"], stats["total_chars"], stats["total_sync_errs"]))

    sync_errs = stats["total_sync_errs"]
    if args.n_confusions == 0 or not sync_errs > 0:
        return

    # most frequent confusions first (stable order for equal counts)
    ranked = sorted(stats['confusion'].items(), key=lambda entry: -entry[1])
    print("{:8s} {:8s} {:8s} {:10s}".format("GT", "PRED", "COUNT", "PERCENT"))

    listed_share = 0
    for rank, ((gt, pred), count) in enumerate(ranked):
        if rank == args.n_confusions:
            break
        # weight each confusion by the longer of its two sides
        share = count * max(len(gt), len(pred)) / sync_errs
        print("{:8s} {:8s} {:8d} {:10.2%}".format("{" + gt + "}", "{" + pred + "}", count, share))
        listed_share += share

    print("The remaining but hidden errors make up {:.2%}".format(1.0 - listed_share))
Example #8
0
def main():
    """Predict the evaluation lines with every checkpoint and voter, then evaluate each result.

    Ground-truth texts are expected next to the images with a ``.gt.txt`` extension.
    Each individual model (keyed by its index as a string) and each voter (keyed by its
    name) is evaluated separately. With ``--verbose`` the full result dict is printed;
    with ``--dump`` it is pickled together with the gt file names and gt texts.

    Raises:
        Exception: if the dataset is empty, or if a prediction run yields a different
            number of lines than the dataset contains.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_imgs", type=str, nargs="+", required=True,
                        help="The evaluation files")
    parser.add_argument("--eval_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--checkpoint", type=str, nargs="+", default=[],
                        help="Path to the checkpoint without file extension")
    parser.add_argument("-j", "--processes", type=int, default=1,
                        help="Number of processes to use")
    parser.add_argument("--verbose", action="store_true",
                        help="Print additional information")
    parser.add_argument("--voter", type=str, nargs="+",
                        default=["sequence_voter", "confidence_voter_default_ctc", "confidence_voter_fuzzy_ctc"],
                        help="The voting algorithm(s) to use. Possible values: confidence_voter_default_ctc, "
                             "confidence_voter_fuzzy_ctc, sequence_voter. By default all three are evaluated.")
    parser.add_argument("--batch_size", type=int, default=10,
                        help="The batch size for prediction")
    parser.add_argument("--dump", type=str,
                        help="Dump the output as serialized pickle object")
    parser.add_argument("--no_skip_invalid_gt", action="store_true",
                        help="Do not skip invalid gt, instead raise an exception.")

    args = parser.parse_args()

    # allow user to specify json file for model definition, but remove the file extension
    # for further processing
    args.checkpoint = [(cp[:-5] if cp.endswith(".json") else cp) for cp in args.checkpoint]

    # load files: resolve the image globs once and derive the gt text paths from them
    gt_images = sorted(glob_all(args.eval_imgs))
    gt_txts = [split_all_ext(path)[0] + ".gt.txt" for path in gt_images]

    dataset = create_dataset(
        args.eval_dataset,
        DataSetMode.TRAIN,
        images=gt_images,
        texts=gt_txts,
        skip_invalid=not args.no_skip_invalid_gt
    )

    print("Found {} files in the dataset".format(len(dataset)))
    if len(dataset) == 0:
        # report the option the user actually passed (there is no ``--files`` argument)
        raise Exception("Empty dataset provided. Check your eval_imgs argument (got {})!".format(args.eval_imgs))

    # predict for all models
    n_models = len(args.checkpoint)
    predictor = MultiPredictor(checkpoints=args.checkpoint, batch_size=args.batch_size, processes=args.processes)
    do_prediction = predictor.predict_dataset(dataset, progress_bar=True)

    # one voter instance and one output list per requested voting algorithm
    voters = []
    all_voter_sentences = []
    all_prediction_sentences = [[] for _ in range(n_models)]

    for voter_name in args.voter:
        voter_params = VoterParams()
        voter_params.type = VoterParams.Type.Value(voter_name.upper())
        voters.append(voter_from_proto(voter_params))
        all_voter_sentences.append([])

    for prediction, _sample in do_prediction:
        # collect the individual prediction of every model
        for sentences, single_prediction in zip(all_prediction_sentences, prediction):
            sentences.append(single_prediction.sentence)

        # vote results
        for voter, voter_sentences in zip(voters, all_voter_sentences):
            voter_sentences.append(voter.vote_prediction_result(prediction).sentence)

    # evaluation: gt is preprocessed with the text preprocessor of the first model
    text_preproc = text_processor_from_proto(predictor.predictors[0].model_params.text_preprocessor)
    evaluator = Evaluator(text_preprocessor=text_preproc)
    evaluator.preload_gt(gt_dataset=dataset, progress_bar=True)

    def single_evaluation(predicted_sentences):
        """Evaluate one list of predicted sentences against the preloaded gt."""
        if len(predicted_sentences) != len(dataset):
            raise Exception("Mismatch in number of gt and pred files: {} != {}. Probably, the prediction did "
                            "not succeed".format(len(dataset), len(predicted_sentences)))

        pred_data_set = create_dataset(
            DataSetType.RAW,
            DataSetMode.EVAL,
            texts=predicted_sentences)

        return evaluator.run(pred_dataset=pred_data_set, progress_bar=True, processes=args.processes)

    # individual models are keyed "0", "1", ...; voters by their name
    named_results = [(str(i), sentences) for i, sentences in enumerate(all_prediction_sentences)]
    named_results += list(zip(args.voter, all_voter_sentences))

    full_evaluation = {}
    for key, data in named_results:
        full_evaluation[key] = {"eval": single_evaluation(data), "data": data}

    if args.verbose:
        print(full_evaluation)

    if args.dump:
        import pickle
        with open(args.dump, 'wb') as f:
            pickle.dump({"full": full_evaluation, "gt_txts": gt_txts, "gt": dataset.text_samples()}, f)
Example #9
0
def main():
    """Evaluate prediction text files against ground-truth text files.

    Prints the mean label error rate, the most common confusions and (optionally) the
    worst recognized lines. With ``--xlsx_output`` the results are also written to an
    xlsx file. Missing prediction files are handled according to
    ``--non_existing_file_handling_mode`` (skip / empty / error).

    Raises:
        Exception: if ``--pred`` resolves to a different number of files than ``--gt``.
    """
    parser = ArgumentParser()
    parser.add_argument("--gt",
                        nargs="+",
                        required=True,
                        help="Ground truth files (.gt.txt extension)")
    parser.add_argument(
        "--pred",
        nargs="+",
        default=None,
        help=
        "Prediction files if provided. Else files with .pred.txt are expected at the same "
        "location as the gt.")
    parser.add_argument("--pred_ext",
                        type=str,
                        default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument(
        "--n_confusions",
        type=int,
        default=10,
        help=
        "Only print n most common confusions. Defaults to 10, use -1 for all.")
    parser.add_argument(
        "--n_worst_lines",
        type=int,
        default=0,
        help="Print the n worst recognized text lines with its error")
    parser.add_argument(
        "--xlsx_output",
        type=str,
        help="Optionally write a xlsx file with the evaluation results")
    parser.add_argument("--num_threads",
                        type=int,
                        default=1,
                        help="Number of threads to use for evaluation")
    parser.add_argument(
        "--non_existing_file_handling_mode",
        type=str,
        default="error",
        help=
        "How to handle non existing .pred.txt files. Possible modes: skip, empty, error. "
        "'Skip' will simply skip the evaluation of that file (not counting it to errors). "
        "'Empty' will handle this file as if it were empty (fully checking for errors). "
        "'Error' will throw an exception if a file is not existing. This is the default behaviour."
    )
    parser.add_argument("--no_progress_bars",
                        action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help=
        "Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)"
    )

    args = parser.parse_args()

    print("Resolving files")
    gt_files = sorted(glob_all(args.gt))

    if args.pred:
        pred_files = sorted(glob_all(args.pred))
        if len(pred_files) != len(gt_files):
            raise Exception(
                "Mismatch in the number of gt and pred files: {} vs {}".format(
                    len(gt_files), len(pred_files)))
    else:
        pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files]

    if args.non_existing_file_handling_mode.lower() == "skip":
        # Drop every index-aligned (gt, pred) pair whose prediction file is missing.
        # Single linear pass instead of repeated list.index()/del (quadratic).
        missing = {i for i, pred in enumerate(pred_files) if not os.path.exists(pred)}
        gt_files = [gt for i, gt in enumerate(gt_files) if i not in missing]
        pred_files = [pred for i, pred in enumerate(pred_files) if i not in missing]

    # optionally reuse the text preprocessor stored in a checkpoint's json params
    text_preproc = None
    if args.checkpoint:
        with open(
                args.checkpoint if args.checkpoint.endswith(".json") else
                args.checkpoint + '.json', 'r') as f:
            checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
            text_preproc = text_processor_from_proto(
                checkpoint_params.model.text_preprocessor)

    non_existing_as_empty = args.non_existing_file_handling_mode.lower() == "empty"
    gt_data_set = FileDataSet(texts=gt_files,
                              non_existing_as_empty=non_existing_as_empty)
    pred_data_set = FileDataSet(texts=pred_files,
                                non_existing_as_empty=non_existing_as_empty)

    evaluator = Evaluator(text_preprocessor=text_preproc)
    r = evaluator.run(gt_dataset=gt_data_set,
                      pred_dataset=pred_data_set,
                      processes=args.num_threads,
                      progress_bar=not args.no_progress_bars)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print(
        "Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)"
        .format(r["avg_ler"], r["total_char_errs"], r["total_chars"],
                r["total_sync_errs"]))

    # sort descending
    print_confusions(r, args.n_confusions)

    print_worst_lines(r, gt_files, gt_data_set.text_samples(),
                      pred_data_set.text_samples(), args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output, [{
            "prefix": "evaluation",
            "results": r,
            "gt_files": gt_files,
            "gts": gt_data_set.text_samples(),
            "preds": pred_data_set.text_samples()
        }])
Example #10
0
def main():
    """Evaluate predicted text lines against ground truth for any supported dataset type.

    Prints the mean label error rate, the most common confusions and (optionally) the
    worst recognized lines. With ``--xlsx_output`` the results are also written to an
    xlsx file. Missing prediction files are handled according to
    ``--non_existing_file_handling_mode`` (skip / empty / error).
    """
    parser = ArgumentParser()
    parser.add_argument("--dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--gt", nargs="+", required=True,
                        help="Ground truth files (.gt.txt extension)")
    parser.add_argument("--pred", nargs="+", default=None,
                        help="Prediction files if provided. Else files with .pred.txt are expected at the same "
                             "location as the gt.")
    parser.add_argument("--pred_dataset", type=DataSetType.from_string, choices=list(DataSetType), default=DataSetType.FILE)
    parser.add_argument("--pred_ext", type=str, default=".pred.txt",
                        help="Extension of the predicted text files")
    parser.add_argument("--n_confusions", type=int, default=10,
                        help="Only print n most common confusions. Defaults to 10, use -1 for all.")
    parser.add_argument("--n_worst_lines", type=int, default=0,
                        help="Print the n worst recognized text lines with its error")
    parser.add_argument("--xlsx_output", type=str,
                        help="Optionally write a xlsx file with the evaluation results")
    parser.add_argument("--num_threads", type=int, default=1,
                        help="Number of threads to use for evaluation")
    parser.add_argument("--non_existing_file_handling_mode", type=str, default="error",
                        help="How to handle non existing .pred.txt files. Possible modes: skip, empty, error. "
                             "'Skip' will simply skip the evaluation of that file (not counting it to errors). "
                             "'Empty' will handle this file as if it were empty (fully checking for errors). "
                             "'Error' will throw an exception if a file is not existing. This is the default behaviour.")
    parser.add_argument("--no_progress_bars", action="store_true",
                        help="Do not show any progress bars")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)")

    # page xml specific args; without type=int a value given on the command line
    # would be passed on as a string while the default is an int
    parser.add_argument("--pagexml_gt_text_index", type=int, default=0)
    parser.add_argument("--pagexml_pred_text_index", type=int, default=1)

    args = parser.parse_args()

    print("Resolving files")
    gt_files = sorted(glob_all(args.gt))

    if args.pred:
        pred_files = sorted(glob_all(args.pred))
    else:
        # predictions live next to the gt and therefore share its dataset type
        pred_files = [split_all_ext(gt)[0] + args.pred_ext for gt in gt_files]
        args.pred_dataset = args.dataset

    mode = args.non_existing_file_handling_mode.lower()
    if mode == "skip":
        # Drop every index-aligned (gt, pred) pair whose prediction file is missing.
        # Single linear pass instead of repeated list.index()/del (quadratic).
        missing = {i for i, pred in enumerate(pred_files) if not os.path.exists(pred)}
        gt_files = [gt for i, gt in enumerate(gt_files) if i not in missing]
        pred_files = [pred for i, pred in enumerate(pred_files) if i not in missing]

    # optionally reuse the text preprocessor stored in a checkpoint's json params
    text_preproc = None
    if args.checkpoint:
        with open(args.checkpoint if args.checkpoint.endswith(".json") else args.checkpoint + '.json', 'r') as f:
            checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
            text_preproc = text_processor_from_proto(checkpoint_params.model.text_preprocessor)

    # Only "error" mode must keep missing files as errors. The previous comparison
    # against "error " (trailing space) was always unequal, so errors were never raised.
    non_existing_as_empty = mode != "error"
    gt_data_set = create_dataset(
        args.dataset,
        DataSetMode.EVAL,
        texts=gt_files,
        non_existing_as_empty=non_existing_as_empty,
        args={'text_index': args.pagexml_gt_text_index},
    )
    pred_data_set = create_dataset(
        args.pred_dataset,
        DataSetMode.EVAL,
        texts=pred_files,
        non_existing_as_empty=non_existing_as_empty,
        args={'text_index': args.pagexml_pred_text_index},
    )

    evaluator = Evaluator(text_preprocessor=text_preproc)
    r = evaluator.run(gt_dataset=gt_data_set, pred_dataset=pred_data_set, processes=args.num_threads,
                      progress_bar=not args.no_progress_bars)

    # TODO: More output
    print("Evaluation result")
    print("=================")
    print("")
    print("Got mean normalized label error rate of {:.2%} ({} errs, {} total chars, {} sync errs)".format(
        r["avg_ler"], r["total_char_errs"], r["total_chars"], r["total_sync_errs"]))

    # sort descending
    print_confusions(r, args.n_confusions)

    print_worst_lines(r, gt_data_set.samples(), pred_data_set.text_samples(), args.n_worst_lines)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output,
                   [{
                       "prefix": "evaluation",
                       "results": r,
                       "gt_files": gt_files,
                       "gts": gt_data_set.text_samples(),
                       "preds": pred_data_set.text_samples()
                   }])
Example #11
0
    def _run_train(self, train_net, test_net, codec, train_start_time, progress_bar):
        """Run the training loop with periodic display, checkpointing and early stopping.

        Training resumes at ``checkpoint_params.iter`` and runs until
        ``checkpoint_params.max_iters`` unless early stopping ends it first. Frequencies in
        ``checkpoint_params`` within (0, 1] are interpreted relative to an epoch; larger
        values are absolute iteration counts.

        Args:
            train_net: network performing the training steps and writing checkpoints.
            test_net: network whose ``input_dataset`` is used for early-stopping validation.
            codec: codec used to decode predictions and ground truth for display.
            train_start_time: ``time.time()`` value at which training started.
            progress_bar: whether to show a progress bar during validation.

        Raises:
            Exception: if the loss stays non-finite and no checkpoint exists to restore.
            KeyboardInterrupt: re-raised after an "interrupted" checkpoint was written.
        """
        checkpoint_params = self.checkpoint_params
        validation_dataset = test_net.input_dataset
        iters_per_epoch = max(1, int(len(train_net.input_dataset) / checkpoint_params.batch_size))

        loss_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.loss_stats)
        ler_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.ler_stats)
        dt_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.dt_stats)

        display = checkpoint_params.display
        display_epochs = display <= 1
        if display <= 0:
            display = 0                                       # to not display anything
        elif display_epochs:
            display = max(1, int(display * iters_per_epoch))  # relative to epochs
        else:
            display = max(1, int(display))                    # iterations

        # Epoch-relative frequencies are clamped to >= 1 (like ``display`` above): for small
        # epochs they would otherwise round down to 0, which silently disables
        # checkpointing or makes ``(iter + 1) % frequency`` raise ZeroDivisionError.
        checkpoint_frequency = checkpoint_params.checkpoint_frequency
        early_stopping_frequency = checkpoint_params.early_stopping_frequency
        if early_stopping_frequency < 0:
            # set early stopping frequency to half epoch
            early_stopping_frequency = max(1, int(0.5 * iters_per_epoch))
        elif 0 < early_stopping_frequency <= 1:
            early_stopping_frequency = max(1, int(early_stopping_frequency * iters_per_epoch))  # relative to epochs
        else:
            early_stopping_frequency = int(early_stopping_frequency)  # iterations (0 = off)

        if checkpoint_frequency < 0:
            checkpoint_frequency = early_stopping_frequency
        elif 0 < checkpoint_frequency <= 1:
            checkpoint_frequency = max(1, int(checkpoint_frequency * iters_per_epoch))  # relative to epochs
        else:
            checkpoint_frequency = int(checkpoint_frequency)  # iterations (0 = off)

        # NOTE(review): this tests the *configured* frequency, so a negative value (the
        # half-epoch default above) only drives checkpointing, not early stopping — confirm intended.
        early_stopping_enabled = self.validation_dataset is not None \
                                 and checkpoint_params.early_stopping_frequency > 0 \
                                 and checkpoint_params.early_stopping_nbest > 1
        early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
        early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
        early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

        early_stopping_predictor = Predictor(codec=codec, text_postproc=self.txt_postproc,
                                             network=test_net)

        # Start the actual training
        # ====================================================================================

        cur_iter = checkpoint_params.iter

        def make_checkpoint(base_dir, prefix, version=None):
            """Save the weights plus the current training state as JSON; return the checkpoint path.

            Without ``version`` the file is named ``<prefix><iter + 1, zero-padded to 8>.ckpt``.
            """
            if version:
                checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{}.ckpt".format(prefix, version)))
            else:
                checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{:08d}.ckpt".format(prefix, cur_iter + 1)))
            print("Storing checkpoint to '{}'".format(checkpoint_path))
            train_net.save_checkpoint(checkpoint_path)
            checkpoint_params.version = Checkpoint.VERSION
            checkpoint_params.iter = cur_iter
            checkpoint_params.loss_stats[:] = loss_stats.values
            checkpoint_params.ler_stats[:] = ler_stats.values
            checkpoint_params.dt_stats[:] = dt_stats.values
            checkpoint_params.total_time = time.time() - train_start_time
            checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
            checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
            checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

            with open(checkpoint_path + ".json", 'w') as f:
                f.write(json_format.MessageToJson(checkpoint_params))

            return checkpoint_path

        try:
            last_checkpoint = None
            n_infinite_losses = 0
            n_max_infinite_losses = 5

            # Training loop, can be interrupted by early stopping
            for cur_iter in range(cur_iter, checkpoint_params.max_iters):
                checkpoint_params.iter = cur_iter

                iter_start_time = time.time()
                result = train_net.train_step()

                if not np.isfinite(result['loss']):
                    n_infinite_losses += 1

                    if n_infinite_losses >= n_max_infinite_losses:
                        print("Error: Loss is not finite! Trying to restart from last checkpoint.")
                        if not last_checkpoint:
                            raise Exception("No checkpoint written yet. Training must be stopped.")
                        # reload also non trainable weights, such as solver-specific variables
                        train_net.load_weights(last_checkpoint, restore_only_trainable=False)
                        # restart the count so a later divergence triggers another restore
                        n_infinite_losses = 0
                    # non-finite steps are not recorded in the statistics
                    continue

                n_infinite_losses = 0

                loss_stats.push(result['loss'])
                ler_stats.push(result['ler'])

                dt_stats.push(time.time() - iter_start_time)

                if display > 0 and cur_iter % display == 0:
                    # apply postprocessing to display the true output
                    pred_sentence = self.txt_postproc.apply("".join(codec.decode(result["decoded"][0])))
                    gt_sentence = self.txt_postproc.apply("".join(codec.decode(result["gt"][0])))

                    if display_epochs:
                        print("#{:08f}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(
                            cur_iter / iters_per_epoch, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))
                    else:
                        print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(
                            cur_iter, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))

                    # Insert utf-8 ltr/rtl direction marks for bidi support
                    lr = "\u202A\u202B"
                    print(" PRED: '{}{}{}'".format(lr[bidi.get_base_level(pred_sentence)], pred_sentence, "\u202C"))
                    print(" TRUE: '{}{}{}'".format(lr[bidi.get_base_level(gt_sentence)], gt_sentence, "\u202C"))

                if checkpoint_frequency > 0 and (cur_iter + 1) % checkpoint_frequency == 0:
                    last_checkpoint = make_checkpoint(checkpoint_params.output_dir, checkpoint_params.output_model_prefix)

                if early_stopping_enabled and (cur_iter + 1) % early_stopping_frequency == 0:
                    print("Checking early stopping model")

                    out_gen = early_stopping_predictor.predict_input_dataset(validation_dataset,
                                                                             progress_bar=progress_bar)
                    eval_result = Evaluator.evaluate_single_list(map(
                        Evaluator.evaluate_single_args,
                        map(lambda d: tuple(self.txt_preproc.apply([''.join(d.ground_truth), d.sentence])), out_gen)))
                    accuracy = 1 - eval_result["avg_ler"]

                    if accuracy > early_stopping_best_accuracy:
                        early_stopping_best_accuracy = accuracy
                        early_stopping_best_cur_nbest = 1
                        early_stopping_best_at_iter = cur_iter + 1
                        # overwrite as best model
                        last_checkpoint = make_checkpoint(
                            checkpoint_params.early_stopping_best_model_output_dir,
                            prefix="",
                            version=checkpoint_params.early_stopping_best_model_prefix,
                        )
                        print("Found better model with accuracy of {:%}".format(early_stopping_best_accuracy))
                    else:
                        early_stopping_best_cur_nbest += 1
                        print("No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})".
                              format(early_stopping_best_accuracy, early_stopping_best_at_iter,
                                     checkpoint_params.early_stopping_nbest - early_stopping_best_cur_nbest))

                    if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                        print("Early stopping now.")
                        break

                    if accuracy >= 1:
                        print("Reached perfect score on validation set. Early stopping now.")
                        break

        except KeyboardInterrupt:
            print("Storing interrupted checkpoint")
            make_checkpoint(checkpoint_params.output_dir,
                            checkpoint_params.output_model_prefix,
                            "interrupted")
            raise

        print("Total time {}s for {} iterations.".format(time.time() - train_start_time, cur_iter))