class VisCallback(keras.callbacks.Callback):
    """Keras callback collecting running loss/CER/timing statistics.

    Every ``display`` iterations it decodes one batch from ``data_gen``
    via ``predict_func`` and forwards the averaged metrics together with
    the decoded prediction / ground-truth sentences to
    ``training_callback``.
    """

    def __init__(self, training_callback, codec, data_gen, predict_func,
                 checkpoint_params, steps_per_epoch, text_post_proc):
        self.training_callback = training_callback
        self.codec = codec
        self.data_gen = data_gen
        self.predict_func = predict_func
        self.checkpoint_params = checkpoint_params
        self.steps_per_epoch = steps_per_epoch
        self.text_post_proc = text_post_proc

        # Rolling windows over the most recent `stats_size` observations,
        # seeded from the checkpoint so a resumed run keeps its averages.
        stats_size = checkpoint_params.stats_size
        self.loss_stats = RunningStatistics(stats_size,
                                            checkpoint_params.loss_stats)
        self.ler_stats = RunningStatistics(stats_size,
                                           checkpoint_params.ler_stats)
        self.dt_stats = RunningStatistics(stats_size,
                                          checkpoint_params.dt_stats)

        raw_display = checkpoint_params.display
        # Values in (0, 1] are interpreted as a fraction of an epoch.
        self.display_epochs = raw_display <= 1
        if raw_display <= 0:
            # Non-positive setting disables periodic display entirely.
            self.display = 0
        elif self.display_epochs:
            self.display = max(1, int(raw_display * steps_per_epoch))
        else:
            self.display = max(1, int(raw_display))

        now = time.time()
        self.iter_start_time = now
        self.train_start_time = now

    def on_train_begin(self, logs):
        # Reset timers so setup cost is not attributed to the first batch.
        now = time.time()
        self.iter_start_time = now
        self.train_start_time = now

    def on_train_end(self, logs):
        elapsed = time.time() - self.train_start_time
        self.training_callback.training_finished(elapsed,
                                                 self.checkpoint_params.iter)

    def on_batch_end(self, batch, logs):
        now = time.time()
        self.dt_stats.push(now - self.iter_start_time)
        self.iter_start_time = now
        self.loss_stats.push(logs['loss'])
        self.checkpoint_params.iter += 1

        if self.display <= 0:
            return
        if self.checkpoint_params.iter % self.display != 0:
            return

        # Decode a single batch so the user sees a live prediction,
        # with postprocessing applied to display the true output.
        cer, target, decoded = self._generate(1)
        self.ler_stats.push(cer)
        pred_sentence = self.text_post_proc.apply(
            "".join(self.codec.decode(decoded[0])))
        gt_sentence = self.text_post_proc.apply(
            "".join(self.codec.decode(target[0])))

        self.training_callback.display(self.ler_stats.mean(),
                                       self.loss_stats.mean(),
                                       self.dt_stats.mean(),
                                       self.checkpoint_params.iter,
                                       self.steps_per_epoch,
                                       self.display_epochs,
                                       pred_sentence, gt_sentence)

    def on_epoch_end(self, epoch, logs):
        # Per-epoch reporting is not needed; metrics are batch-driven.
        pass

    def _generate(self, count):
        """Predict ``count`` batches.

        Returns a tuple of (mean CER, flattened target label lists,
        flattened decoded label lists).
        """
        batches = iter(self.data_gen)
        outputs = [self.predict_func(next(batches)) for _ in range(count)]
        cers, targets, decodings = zip(*outputs)

        def _flatten(sparse_batches):
            # Concatenate the per-batch list-of-label-lists.
            flat = []
            for sparse in sparse_batches:
                flat.extend(sparse_to_lists(sparse))
            return flat

        return np.mean(cers), _flatten(targets), _flatten(decodings)
# Example #2
    def train(self, progress_bar=False):
        """Load data, build the backend, and run the training loop.

        Loads and preprocesses the training (and optional validation)
        samples, computes or restores/aligns the codec, creates the
        backend network, and runs the iteration loop with periodic
        display, checkpointing, and optional early stopping based on
        validation accuracy.

        Args:
            progress_bar: whether to display progress bars during data
                loading, preprocessing, and validation prediction.

        Raises:
            Exception: if the training or validation dataset is empty,
                if a restored model's line height mismatches the
                requested one, or if the loss becomes non-finite before
                any checkpoint was written.
        """
        checkpoint_params = self.checkpoint_params

        # NOTE(review): to accumulate previously trained time on resume this
        # presumably should be `time.time() - self.checkpoint_params.total_time`
        # (so that `time.time() - train_start_time` at checkpoint time yields
        # previous + current elapsed time) — TODO confirm before changing.
        train_start_time = time.time() + self.checkpoint_params.total_time

        self.dataset.load_samples(processes=1, progress_bar=progress_bar)
        datas, txts = self.dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
        if len(datas) == 0:
            raise Exception("Empty dataset is not allowed. Check if the data is at the correct location")

        if self.validation_dataset:
            self.validation_dataset.load_samples(processes=1, progress_bar=progress_bar)
            validation_datas, validation_txts = self.validation_dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
            if len(validation_datas) == 0:
                raise Exception("Validation dataset is empty. Provide valid validation data for early stopping.")
        else:
            validation_datas, validation_txts = [], []

        # preprocessing steps
        texts = self.txt_preproc.apply(txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
        datas = self.data_preproc.apply(datas, processes=checkpoint_params.processes, progress_bar=progress_bar)
        validation_txts = self.txt_preproc.apply(validation_txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
        validation_datas = self.data_preproc.apply(validation_datas, processes=checkpoint_params.processes, progress_bar=progress_bar)

        # compute the codec (explicit codec wins over one derived from the texts)
        codec = self.codec if self.codec else Codec.from_texts(texts, whitelist=self.codec_whitelist)

        # data augmentation on preprocessed data
        if self.data_augmenter:
            datas, texts = self.data_augmenter.augment_datas(datas, texts, n_augmentations=self.n_augmentations,
                                                             processes=checkpoint_params.processes, progress_bar=progress_bar)

            # TODO: validation data augmentation
            # validation_datas, validation_txts = self.data_augmenter.augment_datas(validation_datas, validation_txts, n_augmentations=0,
            #                                                  processes=checkpoint_params.processes, progress_bar=progress_bar)

        # create backend
        network_params = checkpoint_params.model.network
        network_params.features = checkpoint_params.model.line_height
        network_params.classes = len(codec)
        if self.weights:
            # if we load the weights, take care of codec changes as-well
            with open(self.weights + '.json', 'r') as f:
                restore_checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
                restore_model_params = restore_checkpoint_params.model

            # Compatibility check of the restored model.
            # BUGFIX: the previous comparison used `network_params.features`,
            # which was assigned from `checkpoint_params.model.line_height`
            # just above and therefore always matched (dead check); compare
            # against the restored model's line height instead.
            if restore_model_params.line_height != checkpoint_params.model.line_height:
                raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                    restore_model_params.line_height, checkpoint_params.model.line_height
                ))

            # create codec of the same type
            restore_codec = codec.__class__(restore_model_params.codec.charset)
            # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
            codec_changes = restore_codec.align(codec)
            codec = restore_codec
            print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
            # The actual weight/bias matrix will be changed after loading the old weights
        else:
            codec_changes = None

        # store the new codec
        checkpoint_params.model.codec.charset[:] = codec.charset
        print("CODEC: {}".format(codec.charset))

        # compute the labels with (new/current) codec
        labels = [codec.encode(txt) for txt in texts]

        backend = create_backend_from_proto(network_params,
                                            weights=self.weights,
                                            )
        backend.set_train_data(datas, labels)
        backend.set_prediction_data(validation_datas)
        if codec_changes:
            backend.realign_model_labels(*codec_changes)
        backend.prepare(train=True)

        # Rolling statistics seeded from the checkpoint so a resumed run
        # continues with its previous averages.
        loss_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.loss_stats)
        ler_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.ler_stats)
        dt_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.dt_stats)

        early_stopping_enabled = self.validation_dataset is not None \
                                 and checkpoint_params.early_stopping_frequency > 0 \
                                 and checkpoint_params.early_stopping_nbest > 1
        early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
        early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
        early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

        early_stopping_predictor = Predictor(codec=codec, text_postproc=self.txt_postproc,
                                             backend=backend)

        # Start the actual training
        # ====================================================================================

        iter = checkpoint_params.iter

        # helper function to write a checkpoint: saves the backend weights and
        # serializes the current training state (iter, stats, best model info)
        # next to it as JSON so training can be resumed
        def make_checkpoint(base_dir, prefix, version=None):
            if version:
                checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{}.ckpt".format(prefix, version)))
            else:
                checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{:08d}.ckpt".format(prefix, iter + 1)))
            print("Storing checkpoint to '{}'".format(checkpoint_path))
            backend.save_checkpoint(checkpoint_path)
            checkpoint_params.iter = iter
            checkpoint_params.loss_stats[:] = loss_stats.values
            checkpoint_params.ler_stats[:] = ler_stats.values
            checkpoint_params.dt_stats[:] = dt_stats.values
            checkpoint_params.total_time = time.time() - train_start_time
            checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
            checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
            checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

            with open(checkpoint_path + ".json", 'w') as f:
                f.write(json_format.MessageToJson(checkpoint_params))

            return checkpoint_path

        try:
            last_checkpoint = None

            # Training loop, can be interrupted by early stopping
            for iter in range(iter, checkpoint_params.max_iters):
                checkpoint_params.iter = iter

                iter_start_time = time.time()
                result = backend.train_step(checkpoint_params.batch_size)

                if not np.isfinite(result['loss']):
                    print("Error: Loss is not finite! Trying to restart from last checkpoint.")
                    if not last_checkpoint:
                        raise Exception("No checkpoint written yet. Training must be stopped.")
                    else:
                        # reload also non trainable weights, such as solver-specific variables
                        backend.load_checkpoint_weights(last_checkpoint, restore_only_trainable=False)
                        continue

                loss_stats.push(result['loss'])
                ler_stats.push(result['ler'])

                dt_stats.push(time.time() - iter_start_time)

                if iter % checkpoint_params.display == 0:
                    # apply postprocessing so the user sees the true output
                    pred_sentence = self.txt_postproc.apply("".join(codec.decode(result["decoded"][0])))
                    gt_sentence = self.txt_postproc.apply("".join(codec.decode(result["gt"][0])))
                    print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(iter, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))
                    print(" PRED: '{}'".format(pred_sentence))
                    print(" TRUE: '{}'".format(gt_sentence))

                if (iter + 1) % checkpoint_params.checkpoint_frequency == 0:
                    last_checkpoint = make_checkpoint(checkpoint_params.output_dir, checkpoint_params.output_model_prefix)

                if early_stopping_enabled and (iter + 1) % checkpoint_params.early_stopping_frequency == 0:
                    print("Checking early stopping model")

                    out = early_stopping_predictor.predict_raw(validation_datas, batch_size=checkpoint_params.batch_size,
                                                               progress_bar=progress_bar, apply_preproc=False)
                    pred_texts = [d.sentence for d in out]
                    result = Evaluator.evaluate(gt_data=validation_txts, pred_data=pred_texts, progress_bar=progress_bar)
                    accuracy = 1 - result["avg_ler"]

                    if accuracy > early_stopping_best_accuracy:
                        early_stopping_best_accuracy = accuracy
                        early_stopping_best_cur_nbest = 1
                        early_stopping_best_at_iter = iter + 1
                        # overwrite as best model
                        last_checkpoint = make_checkpoint(
                            checkpoint_params.early_stopping_best_model_output_dir,
                            prefix="",
                            version=checkpoint_params.early_stopping_best_model_prefix,
                        )
                        print("Found better model with accuracy of {:%}".format(early_stopping_best_accuracy))
                    else:
                        early_stopping_best_cur_nbest += 1
                        print("No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})".
                              format(early_stopping_best_accuracy, early_stopping_best_at_iter,
                                     checkpoint_params.early_stopping_nbest - early_stopping_best_cur_nbest))

                    if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                        print("Early stopping now.")
                        break

        except KeyboardInterrupt as e:
            # persist progress before propagating the interrupt
            print("Storing interrupted checkpoint")
            make_checkpoint(checkpoint_params.output_dir,
                            checkpoint_params.output_model_prefix,
                            "interrupted")
            raise e

        print("Total time {}s for {} iterations.".format(time.time() - train_start_time, iter))
# Example #3
    def _run_train(self, train_net, test_net, codec, train_start_time,
                   progress_bar):
        """Execute the main training loop.

        Tracks running loss/CER/timing statistics, prints periodic
        progress (with bidi direction marks around the displayed
        sentences), writes checkpoints, recovers from repeated
        non-finite losses, and optionally performs early stopping based
        on validation accuracy computed with ``test_net``.

        Args:
            train_net: network used for training steps and checkpointing.
            test_net: network used for validation predictions; its
                ``input_dataset`` is the validation dataset.
            codec: codec used to decode label sequences to text.
            train_start_time: reference wall-clock time used when
                storing ``total_time`` in checkpoints.
            progress_bar: whether to show progress bars during the
                validation prediction pass.
        """
        checkpoint_params = self.checkpoint_params
        validation_dataset = test_net.input_dataset
        # number of training iterations that make up one epoch
        iters_per_epoch = max(
            1,
            int(train_net.input_dataset.epoch_size() /
                checkpoint_params.batch_size))

        # Rolling statistics seeded from the checkpoint so a resumed run
        # continues with its previous averages.
        loss_stats = RunningStatistics(checkpoint_params.stats_size,
                                       checkpoint_params.loss_stats)
        ler_stats = RunningStatistics(checkpoint_params.stats_size,
                                      checkpoint_params.ler_stats)
        dt_stats = RunningStatistics(checkpoint_params.stats_size,
                                     checkpoint_params.dt_stats)

        # display <= 1 is interpreted as a fraction of an epoch,
        # larger values as an absolute iteration interval
        display = checkpoint_params.display
        display_epochs = display <= 1
        if display <= 0:
            display = 0  # to not display anything
        elif display_epochs:
            display = max(1,
                          int(display * iters_per_epoch))  # relative to epochs
        else:
            display = max(1, int(display))  # iterations

        # Normalize the checkpoint/early-stopping frequencies the same
        # way: negative means "auto", values in (0, 1] are epoch fractions.
        checkpoint_frequency = checkpoint_params.checkpoint_frequency
        early_stopping_frequency = checkpoint_params.early_stopping_frequency
        if early_stopping_frequency < 0:
            # set early stopping frequency to half epoch
            early_stopping_frequency = int(0.5 * iters_per_epoch)
        elif 0 < early_stopping_frequency <= 1:
            early_stopping_frequency = int(
                early_stopping_frequency *
                iters_per_epoch)  # relative to epochs
        else:
            early_stopping_frequency = int(early_stopping_frequency)
        early_stopping_frequency = max(1, early_stopping_frequency)

        if checkpoint_frequency < 0:
            # default checkpointing cadence follows early stopping
            checkpoint_frequency = early_stopping_frequency
        elif 0 < checkpoint_frequency <= 1:
            checkpoint_frequency = int(checkpoint_frequency *
                                       iters_per_epoch)  # relative to epochs
        else:
            checkpoint_frequency = int(checkpoint_frequency)

        # NOTE(review): this gate reads the raw (un-normalized) frequency, so
        # a negative "auto" frequency — normalized to half an epoch above —
        # disables early stopping here; confirm that is intended.
        early_stopping_enabled = self.validation_dataset is not None \
                                 and checkpoint_params.early_stopping_frequency > 0 \
                                 and checkpoint_params.early_stopping_nbest > 1
        # best-model bookkeeping restored from the checkpoint
        early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
        early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
        early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

        early_stopping_predictor = Predictor(codec=codec,
                                             text_postproc=self.txt_postproc,
                                             network=test_net)

        # Start the actual training
        # ====================================================================================

        iter = checkpoint_params.iter

        # helper function to write a checkpoint: saves the network weights
        # and serializes the current training state next to it as JSON
        def make_checkpoint(base_dir, prefix, version=None):
            if version:
                checkpoint_path = os.path.abspath(
                    os.path.join(base_dir, "{}{}.ckpt".format(prefix,
                                                              version)))
            else:
                checkpoint_path = os.path.abspath(
                    os.path.join(base_dir,
                                 "{}{:08d}.ckpt".format(prefix, iter + 1)))
            print("Storing checkpoint to '{}'".format(checkpoint_path))
            train_net.save_checkpoint(checkpoint_path)
            checkpoint_params.version = Checkpoint.VERSION
            checkpoint_params.iter = iter
            checkpoint_params.loss_stats[:] = loss_stats.values
            checkpoint_params.ler_stats[:] = ler_stats.values
            checkpoint_params.dt_stats[:] = dt_stats.values
            checkpoint_params.total_time = time.time() - train_start_time
            checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
            checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
            checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

            with open(checkpoint_path + ".json", 'w') as f:
                f.write(json_format.MessageToJson(checkpoint_params))

            return checkpoint_path

        try:
            last_checkpoint = None
            # tolerate a few isolated non-finite losses before reloading
            n_infinite_losses = 0
            n_max_infinite_losses = 5

            # Training loop, can be interrupted by early stopping
            for iter in range(iter, checkpoint_params.max_iters):
                checkpoint_params.iter = iter

                iter_start_time = time.time()
                result = train_net.train_step()

                if not np.isfinite(result['loss']):
                    n_infinite_losses += 1

                    if n_max_infinite_losses == n_infinite_losses:
                        print(
                            "Error: Loss is not finite! Trying to restart from last checkpoint."
                        )
                        if not last_checkpoint:
                            raise Exception(
                                "No checkpoint written yet. Training must be stopped."
                            )
                        else:
                            # reload also non trainable weights, such as solver-specific variables
                            # NOTE(review): n_infinite_losses is not reset here,
                            # so the `==` trigger above can never fire again
                            # until a finite loss resets the counter — confirm
                            # whether a reset after reloading was intended.
                            train_net.load_weights(
                                last_checkpoint, restore_only_trainable=False)
                            continue
                    else:
                        # skip this iteration's statistics entirely
                        continue

                n_infinite_losses = 0

                loss_stats.push(result['loss'])
                ler_stats.push(result['ler'])

                dt_stats.push(time.time() - iter_start_time)

                if display > 0 and iter % display == 0:
                    # apply postprocessing to display the true output
                    pred_sentence = self.txt_postproc.apply("".join(
                        codec.decode(result["decoded"][0])))
                    gt_sentence = self.txt_postproc.apply("".join(
                        codec.decode(result["gt"][0])))

                    if display_epochs:
                        # report progress as a fractional epoch count
                        print("#{:08f}: loss={:.8f} ler={:.8f} dt={:.8f}s".
                              format(iter / iters_per_epoch, loss_stats.mean(),
                                     ler_stats.mean(), dt_stats.mean()))
                    else:
                        print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".
                              format(iter, loss_stats.mean(), ler_stats.mean(),
                                     dt_stats.mean()))

                    # Insert utf-8 ltr/rtl direction marks for bidi support
                    lr = "\u202A\u202B"
                    print(" PRED: '{}{}{}'".format(
                        lr[bidi.get_base_level(pred_sentence)], pred_sentence,
                        "\u202C"))
                    print(" TRUE: '{}{}{}'".format(
                        lr[bidi.get_base_level(gt_sentence)], gt_sentence,
                        "\u202C"))

                if checkpoint_frequency > 0 and (
                        iter + 1) % checkpoint_frequency == 0:
                    last_checkpoint = make_checkpoint(
                        checkpoint_params.output_dir,
                        checkpoint_params.output_model_prefix)

                if early_stopping_enabled and (
                        iter + 1) % early_stopping_frequency == 0:
                    print("Checking early stopping model")

                    # predict the full validation set and compute the mean
                    # label error rate; preprocessing is applied to both the
                    # ground truth and the prediction for a fair comparison
                    out_gen = early_stopping_predictor.predict_input_dataset(
                        validation_dataset, progress_bar=progress_bar)
                    result = Evaluator.evaluate_single_list(
                        map(
                            Evaluator.evaluate_single_args,
                            map(
                                lambda d: tuple(
                                    self.txt_preproc.apply([
                                        ''.join(d.ground_truth), d.sentence
                                    ])), out_gen)))
                    accuracy = 1 - result["avg_ler"]

                    if accuracy > early_stopping_best_accuracy:
                        early_stopping_best_accuracy = accuracy
                        early_stopping_best_cur_nbest = 1
                        early_stopping_best_at_iter = iter + 1
                        # overwrite as best model
                        last_checkpoint = make_checkpoint(
                            checkpoint_params.
                            early_stopping_best_model_output_dir,
                            prefix="",
                            version=checkpoint_params.
                            early_stopping_best_model_prefix,
                        )
                        print(
                            "Found better model with accuracy of {:%}".format(
                                early_stopping_best_accuracy))
                    else:
                        early_stopping_best_cur_nbest += 1
                        print(
                            "No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})"
                            .format(
                                early_stopping_best_accuracy,
                                early_stopping_best_at_iter,
                                checkpoint_params.early_stopping_nbest -
                                early_stopping_best_cur_nbest))

                    if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                        print("Early stopping now.")
                        break

                    if accuracy >= 1:
                        print(
                            "Reached perfect score on validation set. Early stopping now."
                        )
                        break

        except KeyboardInterrupt as e:
            # persist progress before propagating the interrupt
            print("Storing interrupted checkpoint")
            make_checkpoint(checkpoint_params.output_dir,
                            checkpoint_params.output_model_prefix,
                            "interrupted")
            raise e

        print("Total time {}s for {} iterations.".format(
            time.time() - train_start_time, iter))
# Example #4
class CustomTensorBoard(TensorBoard):
    """
    Custom TensorBoard Logging Class
    Per display freq:
        - training cer on 20 batches of population-wise data
        - validation cer on 20 batches every subpopulation data
    Per epoch:
        - model weights
    """
    def __init__(self,
                 training_callback,
                 codec,
                 train_data_gen,
                 validation_data_gen: Union[tuple, None],
                 predict_func,
                 checkpoint_params,
                 steps_per_epoch,
                 text_post_proc,
                 log_dir='logs',
                 histogram_freq=0,
                 write_graph=True,
                 write_images=False,
                 update_freq='batch',
                 embeddings_freq=0,
                 embeddings_metadata=None,
                 **kwargs):

        # TensorBoard-specific options are forwarded to the base class.
        super().__init__(log_dir=log_dir,
                         histogram_freq=histogram_freq,
                         write_graph=write_graph,
                         write_images=write_images,
                         update_freq=update_freq,
                         embeddings_freq=embeddings_freq,
                         embeddings_metadata=embeddings_metadata,
                         **kwargs)

        # override default folder structure
        self._train_run_name = ''
        self._validation_run_name = ''

        self.training_callback = training_callback
        self.codec = codec
        self.train_data_gen = train_data_gen
        self.validation_data_gen = validation_data_gen
        self.predict_func = predict_func
        self.checkpoint_params = checkpoint_params
        self.steps_per_epoch = steps_per_epoch
        self.text_post_proc = text_post_proc

        # Rolling windows over the most recent `stats_size` observations,
        # seeded from the checkpoint so a resumed run keeps its averages.
        self.loss_stats = RunningStatistics(checkpoint_params.stats_size,
                                            checkpoint_params.loss_stats)
        self.ler_stats = RunningStatistics(checkpoint_params.stats_size,
                                           checkpoint_params.ler_stats)
        self.dt_stats = RunningStatistics(checkpoint_params.stats_size,
                                          checkpoint_params.dt_stats)

        # One CER statistics window per validation sub-generator.
        # NOTE(review): despite the Union[tuple, None] annotation, passing
        # None here raises a TypeError on len() — confirm whether None is
        # actually supported.
        self.val_ler_stats = [
            RunningStatistics(checkpoint_params.stats_size,
                              checkpoint_params.ler_stats)
            for _ in range(len(self.validation_data_gen))
        ]

        # display <= 1 is interpreted as a fraction of an epoch,
        # larger values as an absolute iteration interval
        display = checkpoint_params.display
        self.display_epochs = display <= 1
        if display <= 0:
            display = 0  # do not display anything
        elif self.display_epochs:
            display = max(1,
                          int(display * steps_per_epoch))  # relative to epochs
        else:
            display = max(1, int(display))  # iterations

        self.display = display
        self.iter_start_time = time.time()
        self.train_start_time = time.time()

    def on_train_begin(self, logs):
        super().on_train_begin(logs)

        # Log initial weights/embeddings before the first batch.
        if self.histogram_freq:
            self._log_weights(0)

        if self.embeddings_freq:
            self._log_embeddings(0)

        # Reset timers so setup cost is not attributed to the first batch.
        self.iter_start_time = time.time()
        self.train_start_time = time.time()

    def on_train_end(self, logs):
        super().on_train_end(logs)

        # Report total elapsed wall-clock time and the final iteration.
        self.training_callback.training_finished(
            time.time() - self.train_start_time, self.checkpoint_params.iter)

    def on_train_batch_end(self, batch, logs=None):
        # NOTE(review): assumes this callback is the only place that
        # increments checkpoint_params.iter, mirroring the base class's
        # _total_batches_seen counter — stripped under `python -O`.
        assert self._total_batches_seen == self.checkpoint_params.iter

        self.checkpoint_params.iter += 1

        # In 'epoch' mode only profiling work would remain per batch.
        if self.update_freq == 'epoch' and self._profile_batch is None:
            return

        # Track per-batch wall-clock time and loss.
        dt = time.time() - self.iter_start_time
        self.iter_start_time = time.time()
        self.dt_stats.push(dt)
        self.loss_stats.push(logs['loss'])

        logs = logs or {}
        if (self.update_freq != 'epoch' and self.display > 0
                and self.checkpoint_params.iter % self.display == 0):
            cer, target, decoded = self._generate(
                self.train_data_gen,
                20)  # 20 batches for generating training metrics
            self.ler_stats.push(cer)
            # apply postprocessing so the logged sentences show the true output
            pred_sentence = self.text_post_proc.apply("".join(
                self.codec.decode(decoded[0])))
            gt_sentence = self.text_post_proc.apply("".join(
                self.codec.decode(target[0])))
            self._log_metrics({"loss": self.loss_stats.mean()},
                              prefix='training/batch_',
                              step=self.checkpoint_params.iter)
            self._log_metrics({"cer": self.ler_stats.mean()},
                              prefix='training/batch_',
                              step=self.checkpoint_params.iter)
            self._log_metrics({"lr": logs['lr']},
                              prefix='',
                              step=self.checkpoint_params.iter)

            # Forward the averaged metrics and sample sentences to the
            # external progress reporter.
            self.training_callback.display(self.ler_stats.mean(),
                                           self.loss_stats.mean(),
                                           self.dt_stats.mean(),
                                           self.checkpoint_params.iter,
                                           self.steps_per_epoch,
                                           self.display_epochs, pred_sentence,
                                           gt_sentence)

            if self.validation_data_gen is not None:
                # Log per-subpopulation validation CER.
                for i, val_data in enumerate(self.validation_data_gen):
                    val_data_name = dataregistry.get_name(i)
                    val_cer, _, _ = self._generate(
                        val_data,
                        20)  # 20 batches for generating training metrics
                    self.val_ler_stats[i].push(val_cer)
                    self._log_metrics(
                        {"cer": self.val_ler_stats[i].mean()},
                        prefix=f'{val_data_name}/validation_batch_',
                        step=self.checkpoint_params.iter)

        # Keep the base class's batch counter in sync (we bypassed its
        # own on_train_batch_end implementation).
        self._total_batches_seen += 1

        # Profiling trace handling, mirroring the base TensorBoard callback.
        if context.executing_eagerly():
            if self._is_tracing:
                self._log_trace()
            elif (not self._is_tracing and math_ops.equal(
                    self.checkpoint_params.iter, self._profile_batch - 1)):
                self._enable_trace()

    def on_epoch_end(self, epoch, logs=None):
        self._log_metrics(logs, prefix='epoch_', step=epoch)

        if self.histogram_freq and epoch % self.histogram_freq == 0:
            self._log_weights(epoch)

        if self.embeddings_freq and epoch % self.embeddings_freq == 0:
            self._log_embeddings(epoch)

        if self.update_freq == 'epoch':
            # In 'epoch' mode the CER metrics are computed here instead of
            # per batch.
            train_cer, _, _ = self._generate(
                self.train_data_gen,
                20)  # 20 batches for generating training metrics
            self.ler_stats.push(train_cer)
            self._log_metrics({"cer": self.ler_stats.mean()},
                              prefix='training/batch_',
                              step=epoch)

            if self.validation_data_gen is not None:
                for i, val_data in enumerate(self.validation_data_gen):
                    val_data_name = dataregistry.get_name(i)
                    val_cer, _, _ = self._generate(
                        val_data,
                        20)  # 20 batches for generating training metrics
                    self.val_ler_stats[i].push(val_cer)
                    # NOTE(review): steps by iter here but by epoch for the
                    # training CER above — confirm this asymmetry is intended.
                    self._log_metrics(
                        {"cer": self.val_ler_stats[i].mean()},
                        prefix=f'{val_data_name}/validation_batch_',
                        step=self.checkpoint_params.iter)

    def _generate(self, data_gen, count):
        """Predict ``count`` batches from ``data_gen``.

        Returns (mean CER, flattened targets, flattened decodings).

        NOTE(review): when ``data_gen`` is None this implicitly returns
        None, which the callers unpack into three values and would raise
        a TypeError — confirm whether the None branch is reachable.
        """
        if data_gen is None:
            pass
        else:
            it = iter(data_gen)
            cer, target, decoded = zip(
                *[self.predict_func(next(it)) for _ in range(count)])
            return np.mean(cer), sum(map(sparse_to_lists, target),
                                     []), sum(map(sparse_to_lists, decoded),
                                              [])
# Exemple #5 (scraper artifact: snippet separator, not code)
# 0
    def _run_train(self, train_net, test_net, codec, train_start_time, progress_bar):
        """Run the main iteration-based training loop.

        Trains ``train_net`` step by step while maintaining running
        loss/LER/timing statistics, periodically printing example
        predictions, writing checkpoints, and optionally performing early
        stopping based on the accuracy of ``test_net`` on the validation
        dataset. On ``KeyboardInterrupt`` an "interrupted" checkpoint is
        written before the exception is re-raised.

        Args:
            train_net: network to train; must provide ``train_step()``,
                ``input_dataset``, ``save_checkpoint()`` and ``load_weights()``.
            test_net: network evaluated for early stopping; its
                ``input_dataset`` is used as the validation dataset.
            codec: codec used to decode label sequences into text.
            train_start_time: wall-clock start time, stored in checkpoints
                as ``total_time``.
            progress_bar: whether to show a progress bar during the
                early-stopping prediction pass.
        """
        checkpoint_params = self.checkpoint_params
        validation_dataset = test_net.input_dataset
        iters_per_epoch = max(1, int(len(train_net.input_dataset) / checkpoint_params.batch_size))

        # Sliding-window statistics over the last `stats_size` iterations,
        # seeded from the values persisted in the checkpoint.
        loss_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.loss_stats)
        ler_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.ler_stats)
        dt_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.dt_stats)

        # Resolve the display interval: values in (0, 1] are interpreted as a
        # fraction of an epoch, values > 1 as an absolute iteration count,
        # and values <= 0 disable display output entirely.
        display = checkpoint_params.display
        display_epochs = display <= 1
        if display <= 0:
            display = 0                                       # to not display anything
        elif display_epochs:
            display = max(1, int(display * iters_per_epoch))  # relative to epochs
        else:
            display = max(1, int(display))                    # iterations

        # Resolve checkpoint / early-stopping frequencies with the same
        # fraction-of-epoch convention; a negative early-stopping frequency
        # defaults to half an epoch, a negative checkpoint frequency inherits
        # the early-stopping frequency.
        checkpoint_frequency = checkpoint_params.checkpoint_frequency
        early_stopping_frequency = checkpoint_params.early_stopping_frequency
        if early_stopping_frequency < 0:
            # set early stopping frequency to half epoch
            early_stopping_frequency = int(0.5 * iters_per_epoch)
        elif 0 < early_stopping_frequency <= 1:
            early_stopping_frequency = int(early_stopping_frequency * iters_per_epoch)  # relative to epochs
        else:
            early_stopping_frequency = int(early_stopping_frequency)

        if checkpoint_frequency < 0:
            checkpoint_frequency = early_stopping_frequency
        elif 0 < checkpoint_frequency <= 1:
            checkpoint_frequency = int(checkpoint_frequency * iters_per_epoch)  # relative to epochs
        else:
            checkpoint_frequency = int(checkpoint_frequency)

        # NOTE(review): this checks self.validation_dataset while the
        # prediction pass below uses the local validation_dataset
        # (= test_net.input_dataset) — confirm both refer to the same data.
        # Also note the raw checkpoint_params frequency is tested here, not
        # the resolved early_stopping_frequency computed above.
        early_stopping_enabled = self.validation_dataset is not None \
                                 and checkpoint_params.early_stopping_frequency > 0 \
                                 and checkpoint_params.early_stopping_nbest > 1
        early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
        early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
        early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

        early_stopping_predictor = Predictor(codec=codec, text_postproc=self.txt_postproc,
                                             network=test_net)

        # Start the actual training
        # ====================================================================================

        # NOTE: shadows the builtin `iter`; kept as-is since the closure
        # below captures this name.
        iter = checkpoint_params.iter

        # helper function to write a checkpoint
        def make_checkpoint(base_dir, prefix, version=None):
            # A named version overrides the zero-padded iteration suffix.
            if version:
                checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{}.ckpt".format(prefix, version)))
            else:
                checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{:08d}.ckpt".format(prefix, iter + 1)))
            print("Storing checkpoint to '{}'".format(checkpoint_path))
            train_net.save_checkpoint(checkpoint_path)
            # Persist the full training state so a restart resumes seamlessly.
            checkpoint_params.version = Checkpoint.VERSION
            checkpoint_params.iter = iter
            checkpoint_params.loss_stats[:] = loss_stats.values
            checkpoint_params.ler_stats[:] = ler_stats.values
            checkpoint_params.dt_stats[:] = dt_stats.values
            checkpoint_params.total_time = time.time() - train_start_time
            checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
            checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
            checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

            with open(checkpoint_path + ".json", 'w') as f:
                f.write(json_format.MessageToJson(checkpoint_params))

            return checkpoint_path

        try:
            last_checkpoint = None
            n_infinite_losses = 0
            n_max_infinite_losses = 5

            # Training loop, can be interrupted by early stopping
            for iter in range(iter, checkpoint_params.max_iters):
                checkpoint_params.iter = iter

                iter_start_time = time.time()
                result = train_net.train_step()

                # Tolerate a few non-finite losses; after
                # n_max_infinite_losses in a row, reload the last checkpoint
                # (including solver state) and continue from there.
                if not np.isfinite(result['loss']):
                    n_infinite_losses += 1

                    if n_max_infinite_losses == n_infinite_losses:
                        print("Error: Loss is not finite! Trying to restart from last checkpoint.")
                        if not last_checkpoint:
                            raise Exception("No checkpoint written yet. Training must be stopped.")
                        else:
                            # reload also non trainable weights, such as solver-specific variables
                            train_net.load_weights(last_checkpoint, restore_only_trainable=False)
                            continue
                    else:
                        continue

                n_infinite_losses = 0

                loss_stats.push(result['loss'])
                ler_stats.push(result['ler'])

                dt_stats.push(time.time() - iter_start_time)

                if display > 0 and iter % display == 0:
                    # apply postprocessing to display the true output
                    pred_sentence = self.txt_postproc.apply("".join(codec.decode(result["decoded"][0])))
                    gt_sentence = self.txt_postproc.apply("".join(codec.decode(result["gt"][0])))

                    # Report progress in epochs or raw iterations depending
                    # on how the display interval was specified.
                    if display_epochs:
                        print("#{:08f}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(
                            iter / iters_per_epoch, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))
                    else:
                        print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(
                            iter, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))

                    # Insert utf-8 ltr/rtl direction marks for bidi support
                    lr = "\u202A\u202B"
                    print(" PRED: '{}{}{}'".format(lr[bidi.get_base_level(pred_sentence)], pred_sentence, "\u202C"))
                    print(" TRUE: '{}{}{}'".format(lr[bidi.get_base_level(gt_sentence)], gt_sentence, "\u202C"))

                if checkpoint_frequency > 0 and (iter + 1) % checkpoint_frequency == 0:
                    last_checkpoint = make_checkpoint(checkpoint_params.output_dir, checkpoint_params.output_model_prefix)

                if early_stopping_enabled and (iter + 1) % early_stopping_frequency == 0:
                    print("Checking early stopping model")

                    # Evaluate the current model on the validation set and
                    # derive accuracy as 1 - average label error rate.
                    out_gen = early_stopping_predictor.predict_input_dataset(validation_dataset,
                                                                             progress_bar=progress_bar)
                    result = Evaluator.evaluate_single_list(map(
                        Evaluator.evaluate_single_args,
                        map(lambda d: tuple(self.txt_preproc.apply([''.join(d.ground_truth), d.sentence])), out_gen)))
                    accuracy = 1 - result["avg_ler"]

                    if accuracy > early_stopping_best_accuracy:
                        early_stopping_best_accuracy = accuracy
                        early_stopping_best_cur_nbest = 1
                        early_stopping_best_at_iter = iter + 1
                        # overwrite as best model
                        last_checkpoint = make_checkpoint(
                            checkpoint_params.early_stopping_best_model_output_dir,
                            prefix="",
                            version=checkpoint_params.early_stopping_best_model_prefix,
                        )
                        print("Found better model with accuracy of {:%}".format(early_stopping_best_accuracy))
                    else:
                        early_stopping_best_cur_nbest += 1
                        print("No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})".
                              format(early_stopping_best_accuracy, early_stopping_best_at_iter,
                                     checkpoint_params.early_stopping_nbest - early_stopping_best_cur_nbest))

                    # Stop once nbest consecutive checks produced no
                    # improvement, or immediately on a perfect score.
                    if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                        print("Early stopping now.")
                        break

                    if accuracy >= 1:
                        print("Reached perfect score on validation set. Early stopping now.")
                        break

        except KeyboardInterrupt as e:
            print("Storing interrupted checkpoint")
            make_checkpoint(checkpoint_params.output_dir,
                            checkpoint_params.output_model_prefix,
                            "interrupted")
            raise e

        print("Total time {}s for {} iterations.".format(time.time() - train_start_time, iter))
# Exemple #6 (scraper artifact: snippet separator, not code)
# 0
class VisCallback(keras.callbacks.Callback):
    """Keras callback printing running loss/LER/timing statistics plus an
    example prediction/ground-truth pair at a configurable interval.

    The ``display`` setting in ``checkpoint_params`` follows the usual
    convention: values <= 0 disable output, values in (0, 1] are a fraction
    of an epoch, and values > 1 are an absolute iteration interval.
    """

    def __init__(self, codec, data_gen, predict_func, checkpoint_params,
                 steps_per_epoch, text_post_proc):
        self.codec = codec
        self.data_gen = data_gen
        self.predict_func = predict_func
        self.checkpoint_params = checkpoint_params
        self.steps_per_epoch = steps_per_epoch
        self.text_post_proc = text_post_proc

        # Sliding-window statistics, seeded from the checkpoint state.
        window = checkpoint_params.stats_size
        self.loss_stats = RunningStatistics(window, checkpoint_params.loss_stats)
        self.ler_stats = RunningStatistics(window, checkpoint_params.ler_stats)
        self.dt_stats = RunningStatistics(window, checkpoint_params.dt_stats)

        raw_display = checkpoint_params.display
        self.display_epochs = raw_display <= 1
        if raw_display <= 0:
            self.display = 0  # disabled entirely
        elif self.display_epochs:
            # fraction of an epoch -> iteration interval
            self.display = max(1, int(raw_display * steps_per_epoch))
        else:
            # absolute iteration interval
            self.display = max(1, int(raw_display))

        self.iter_start_time = time.time()
        self.train_start_time = time.time()

    def on_train_begin(self, logs):
        # Reset both clocks at the true start of training.
        self.iter_start_time = time.time()
        self.train_start_time = time.time()

    def on_train_end(self, logs):
        print("Total training time {}s for {} iterations.".format(
            time.time() - self.train_start_time, self.checkpoint_params.iter))

    def on_batch_end(self, batch, logs):
        # Track per-iteration wall time and the running loss.
        elapsed = time.time() - self.iter_start_time
        self.iter_start_time = time.time()
        self.dt_stats.push(elapsed)
        self.loss_stats.push(logs['loss'])
        self.checkpoint_params.iter += 1

        # Guard clause: only emit output at the configured interval.
        if self.display <= 0 or self.checkpoint_params.iter % self.display != 0:
            return

        # Sample one batch and post-process for human-readable display.
        cer, target, decoded = self._generate(1)
        self.ler_stats.push(cer)
        pred_sentence = self.text_post_proc.apply(
            "".join(self.codec.decode(decoded[0])))
        gt_sentence = self.text_post_proc.apply(
            "".join(self.codec.decode(target[0])))

        # Progress header in epochs or raw iterations.
        if self.display_epochs:
            header = "#{:08f}".format(
                self.checkpoint_params.iter / self.steps_per_epoch)
        else:
            header = "#{:08d}".format(self.checkpoint_params.iter)
        print("{}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(
            header, self.loss_stats.mean(), self.ler_stats.mean(),
            self.dt_stats.mean()))

        # Insert utf-8 ltr/rtl direction marks for bidi support
        lr = "\u202A\u202B"
        for label, sentence in (("PRED", pred_sentence),
                                ("TRUE", gt_sentence)):
            print("  {}: '{}{}{}'".format(
                label, lr[bidi.get_base_level(sentence)], sentence,
                "\u202C"))

    def on_epoch_end(self, epoch, logs):
        # Nothing to do at epoch boundaries; reporting is per-batch.
        pass

    def _generate(self, count):
        """Predict `count` batches; return (mean_cer, targets, decoded)."""
        batches = iter(self.data_gen)
        outcomes = [self.predict_func(next(batches)) for _ in range(count)]
        cers, targets, decodeds = zip(*outcomes)
        flat_targets = sum(map(sparse_to_lists, targets), [])
        flat_decoded = sum(map(sparse_to_lists, decodeds), [])
        return np.mean(cers), flat_targets, flat_decoded