class VisCallback(keras.callbacks.Callback):
    """Keras callback that tracks running loss/CER/time statistics.

    Every ``display`` iterations it decodes one sample batch through
    ``predict_func`` and forwards the smoothed metrics plus a sample
    prediction/ground-truth pair to ``training_callback.display``.
    """

    def __init__(self, training_callback, codec, data_gen, predict_func,
                 checkpoint_params, steps_per_epoch, text_post_proc):
        self.training_callback = training_callback
        self.codec = codec
        self.data_gen = data_gen
        self.predict_func = predict_func
        self.checkpoint_params = checkpoint_params
        self.steps_per_epoch = steps_per_epoch
        self.text_post_proc = text_post_proc

        # Running windows over the most recent observations, seeded from the
        # checkpoint so statistics survive a training restart.
        window = checkpoint_params.stats_size
        self.loss_stats = RunningStatistics(window, checkpoint_params.loss_stats)
        self.ler_stats = RunningStatistics(window, checkpoint_params.ler_stats)
        self.dt_stats = RunningStatistics(window, checkpoint_params.dt_stats)

        freq = checkpoint_params.display
        self.display_epochs = freq <= 1
        if freq <= 0:
            self.display = 0  # do not display anything
        elif self.display_epochs:
            self.display = max(1, int(freq * steps_per_epoch))  # relative to epochs
        else:
            self.display = max(1, int(freq))  # iterations

        self.iter_start_time = time.time()
        self.train_start_time = time.time()

    def on_train_begin(self, logs):
        # Reset both timers at the true start of training.
        started = time.time()
        self.iter_start_time = started
        self.train_start_time = started

    def on_train_end(self, logs):
        self.training_callback.training_finished(
            time.time() - self.train_start_time, self.checkpoint_params.iter)

    def on_batch_end(self, batch, logs):
        now = time.time()
        self.dt_stats.push(now - self.iter_start_time)
        self.iter_start_time = now
        self.loss_stats.push(logs['loss'])
        self.checkpoint_params.iter += 1

        if self.display <= 0 or self.checkpoint_params.iter % self.display != 0:
            return

        # Decode one batch and apply postprocessing to display the true output.
        cer, target, decoded = self._generate(1)
        self.ler_stats.push(cer)
        pred_sentence = self.text_post_proc.apply("".join(self.codec.decode(decoded[0])))
        gt_sentence = self.text_post_proc.apply("".join(self.codec.decode(target[0])))
        self.training_callback.display(
            self.ler_stats.mean(), self.loss_stats.mean(), self.dt_stats.mean(),
            self.checkpoint_params.iter, self.steps_per_epoch,
            self.display_epochs, pred_sentence, gt_sentence)

    def on_epoch_end(self, epoch, logs):
        pass

    def _generate(self, count):
        """Run ``predict_func`` on ``count`` batches; return mean CER plus
        flattened target and decoded label sequences."""
        source = iter(self.data_gen)
        outcomes = [self.predict_func(next(source)) for _ in range(count)]
        cer, target, decoded = zip(*outcomes)

        def flatten(sparse_batches):
            return sum(map(sparse_to_lists, sparse_batches), [])

        return np.mean(cer), flatten(target), flatten(decoded)
def train(self, progress_bar=False):
    """Run the full training pipeline.

    Loads and preprocesses the training (and optional validation) samples,
    builds/aligns the codec, optionally restores weights from an existing
    checkpoint, then runs the iteration loop with periodic display,
    checkpointing and nbest-based early stopping.

    Args:
        progress_bar: show progress bars during loading/preprocessing.

    Raises:
        Exception: on an empty train/validation dataset, on a line-height
            mismatch when restoring weights, or when the loss becomes
            non-finite before any checkpoint was written.
    """
    checkpoint_params = self.checkpoint_params

    # NOTE(review): `total_time` is later recomputed as
    # `time.time() - train_start_time`, so *adding* the previous total here
    # subtracts it from the accumulated time on resume — verify the intended
    # sign (`-` would accumulate across restarts).
    train_start_time = time.time() + self.checkpoint_params.total_time

    self.dataset.load_samples(processes=1, progress_bar=progress_bar)
    datas, txts = self.dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
    if len(datas) == 0:
        raise Exception("Empty dataset is not allowed. Check if the data is at the correct location")

    if self.validation_dataset:
        self.validation_dataset.load_samples(processes=1, progress_bar=progress_bar)
        validation_datas, validation_txts = self.validation_dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
        if len(validation_datas) == 0:
            raise Exception("Validation dataset is empty. Provide valid validation data for early stopping.")
    else:
        validation_datas, validation_txts = [], []

    # preprocessing steps
    texts = self.txt_preproc.apply(txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
    datas = self.data_preproc.apply(datas, processes=checkpoint_params.processes, progress_bar=progress_bar)
    validation_txts = self.txt_preproc.apply(validation_txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
    validation_datas = self.data_preproc.apply(validation_datas, processes=checkpoint_params.processes, progress_bar=progress_bar)

    # compute the codec (explicit codec wins over one derived from the texts)
    codec = self.codec if self.codec else Codec.from_texts(texts, whitelist=self.codec_whitelist)

    # data augmentation on preprocessed data
    if self.data_augmenter:
        datas, texts = self.data_augmenter.augment_datas(datas, texts, n_augmentations=self.n_augmentations,
                                                         processes=checkpoint_params.processes, progress_bar=progress_bar)
        # TODO: validation data augmentation
        # validation_datas, validation_txts = self.data_augmenter.augment_datas(validation_datas, validation_txts, n_augmentations=0,
        #                                                                       processes=checkpoint_params.processes, progress_bar=progress_bar)

    # create backend
    network_params = checkpoint_params.model.network
    network_params.features = checkpoint_params.model.line_height
    network_params.classes = len(codec)
    if self.weights:
        # if we load the weights, take care of codec changes as-well
        with open(self.weights + '.json', 'r') as f:
            restore_checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
            restore_model_params = restore_checkpoint_params.model

        # checks: the restored model must match the requested line height.
        # BUGFIX: compare against the *restored* checkpoint's line height;
        # the previous check compared `checkpoint_params.model.line_height`
        # with `network_params.features`, which was just assigned from the
        # very same field and therefore could never fire.
        if restore_model_params.line_height != network_params.features:
            raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                restore_model_params.line_height, network_params.features
            ))

        # create codec of the same type
        restore_codec = codec.__class__(restore_model_params.codec.charset)
        # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
        codec_changes = restore_codec.align(codec)
        codec = restore_codec
        print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
        # The actual weight/bias matrix will be changed after loading the old weights
    else:
        codec_changes = None

    # store the new codec
    checkpoint_params.model.codec.charset[:] = codec.charset
    print("CODEC: {}".format(codec.charset))

    # compute the labels with (new/current) codec
    labels = [codec.encode(txt) for txt in texts]

    backend = create_backend_from_proto(network_params, weights=self.weights, )
    backend.set_train_data(datas, labels)
    backend.set_prediction_data(validation_datas)
    if codec_changes:
        backend.realign_model_labels(*codec_changes)
    backend.prepare(train=True)

    # running statistics windows, seeded from the checkpoint for resuming
    loss_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.loss_stats)
    ler_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.ler_stats)
    dt_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.dt_stats)

    early_stopping_enabled = self.validation_dataset is not None \
        and checkpoint_params.early_stopping_frequency > 0 \
        and checkpoint_params.early_stopping_nbest > 1
    early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
    early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
    early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

    early_stopping_predictor = Predictor(codec=codec, text_postproc=self.txt_postproc, backend=backend)

    # Start the actual training
    # ====================================================================================
    iter = checkpoint_params.iter  # NOTE: shadows the builtin; kept for compatibility

    # helper function to write a checkpoint (closes over `iter` and the stats)
    def make_checkpoint(base_dir, prefix, version=None):
        if version:
            checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{}.ckpt".format(prefix, version)))
        else:
            checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{:08d}.ckpt".format(prefix, iter + 1)))
        print("Storing checkpoint to '{}'".format(checkpoint_path))
        backend.save_checkpoint(checkpoint_path)
        # persist the training state next to the weights so runs can resume
        checkpoint_params.iter = iter
        checkpoint_params.loss_stats[:] = loss_stats.values
        checkpoint_params.ler_stats[:] = ler_stats.values
        checkpoint_params.dt_stats[:] = dt_stats.values
        checkpoint_params.total_time = time.time() - train_start_time
        checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
        checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
        checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

        with open(checkpoint_path + ".json", 'w') as f:
            f.write(json_format.MessageToJson(checkpoint_params))

        return checkpoint_path

    try:
        last_checkpoint = None

        # Training loop, can be interrupted by early stopping
        for iter in range(iter, checkpoint_params.max_iters):
            checkpoint_params.iter = iter

            iter_start_time = time.time()
            result = backend.train_step(checkpoint_params.batch_size)

            if not np.isfinite(result['loss']):
                print("Error: Loss is not finite! Trying to restart from last checkpoint.")
                if not last_checkpoint:
                    raise Exception("No checkpoint written yet. Training must be stopped.")
                else:
                    # reload also non trainable weights, such as solver-specific variables
                    backend.load_checkpoint_weights(last_checkpoint, restore_only_trainable=False)
                    continue

            loss_stats.push(result['loss'])
            ler_stats.push(result['ler'])
            dt_stats.push(time.time() - iter_start_time)

            if iter % checkpoint_params.display == 0:
                # apply postprocessing so the displayed text matches the true output
                pred_sentence = self.txt_postproc.apply("".join(codec.decode(result["decoded"][0])))
                gt_sentence = self.txt_postproc.apply("".join(codec.decode(result["gt"][0])))

                print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(iter, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))
                print(" PRED: '{}'".format(pred_sentence))
                print(" TRUE: '{}'".format(gt_sentence))

            if (iter + 1) % checkpoint_params.checkpoint_frequency == 0:
                last_checkpoint = make_checkpoint(checkpoint_params.output_dir, checkpoint_params.output_model_prefix)

            if early_stopping_enabled and (iter + 1) % checkpoint_params.early_stopping_frequency == 0:
                print("Checking early stopping model")

                out = early_stopping_predictor.predict_raw(validation_datas, batch_size=checkpoint_params.batch_size,
                                                           progress_bar=progress_bar, apply_preproc=False)
                pred_texts = [d.sentence for d in out]
                result = Evaluator.evaluate(gt_data=validation_txts, pred_data=pred_texts, progress_bar=progress_bar)
                accuracy = 1 - result["avg_ler"]

                if accuracy > early_stopping_best_accuracy:
                    early_stopping_best_accuracy = accuracy
                    early_stopping_best_cur_nbest = 1
                    early_stopping_best_at_iter = iter + 1
                    # overwrite as best model
                    last_checkpoint = make_checkpoint(
                        checkpoint_params.early_stopping_best_model_output_dir,
                        prefix="",
                        version=checkpoint_params.early_stopping_best_model_prefix,
                    )
                    print("Found better model with accuracy of {:%}".format(early_stopping_best_accuracy))
                else:
                    early_stopping_best_cur_nbest += 1
                    print("No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})".
                          format(early_stopping_best_accuracy, early_stopping_best_at_iter,
                                 checkpoint_params.early_stopping_nbest - early_stopping_best_cur_nbest))

                if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                    print("Early stopping now.")
                    break

    except KeyboardInterrupt as e:
        # persist what we have before propagating the interrupt
        print("Storing interrupted checkpoint")
        make_checkpoint(checkpoint_params.output_dir, checkpoint_params.output_model_prefix, "interrupted")
        raise e

    print("Total time {}s for {} iterations.".format(time.time() - train_start_time, iter))
def _run_train(self, train_net, test_net, codec, train_start_time, progress_bar):
    """Run the training iteration loop on an already-built network.

    Resolves display/checkpoint/early-stopping frequencies (values <= 1 are
    interpreted as fractions of an epoch), then iterates `train_net.train_step()`
    up to `max_iters`, maintaining running loss/LER/time statistics, writing
    checkpoints, and performing nbest-based early stopping against the
    validation dataset of `test_net`.

    Raises:
        Exception: when the loss is non-finite for 5 consecutive steps and no
            checkpoint has been written yet.
    """
    checkpoint_params = self.checkpoint_params
    validation_dataset = test_net.input_dataset
    iters_per_epoch = max(
        1, int(train_net.input_dataset.epoch_size() / checkpoint_params.batch_size))

    # running windows, seeded from the checkpoint so a resumed run continues its stats
    loss_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.loss_stats)
    ler_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.ler_stats)
    dt_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.dt_stats)

    display = checkpoint_params.display
    display_epochs = display <= 1
    if display <= 0:
        display = 0  # to not display anything
    elif display_epochs:
        display = max(1, int(display * iters_per_epoch))  # relative to epochs
    else:
        display = max(1, int(display))  # iterations

    checkpoint_frequency = checkpoint_params.checkpoint_frequency
    early_stopping_frequency = checkpoint_params.early_stopping_frequency
    if early_stopping_frequency < 0:
        # set early stopping frequency to half epoch
        early_stopping_frequency = int(0.5 * iters_per_epoch)
    elif 0 < early_stopping_frequency <= 1:
        early_stopping_frequency = int(
            early_stopping_frequency * iters_per_epoch)  # relative to epochs
    else:
        early_stopping_frequency = int(early_stopping_frequency)
    # guard against a zero frequency (would make the modulo below divide by zero)
    early_stopping_frequency = max(1, early_stopping_frequency)

    if checkpoint_frequency < 0:
        # default: checkpoint whenever early stopping is evaluated
        checkpoint_frequency = early_stopping_frequency
    elif 0 < checkpoint_frequency <= 1:
        checkpoint_frequency = int(checkpoint_frequency * iters_per_epoch)  # relative to epochs
    else:
        checkpoint_frequency = int(checkpoint_frequency)

    early_stopping_enabled = self.validation_dataset is not None \
        and checkpoint_params.early_stopping_frequency > 0 \
        and checkpoint_params.early_stopping_nbest > 1
    early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
    early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
    early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

    early_stopping_predictor = Predictor(codec=codec, text_postproc=self.txt_postproc,
                                         network=test_net)

    # Start the actual training
    # ====================================================================================
    iter = checkpoint_params.iter  # NOTE: shadows the builtin; kept as-is

    # helper function to write a checkpoint (closes over `iter` and the stats)
    def make_checkpoint(base_dir, prefix, version=None):
        if version:
            checkpoint_path = os.path.abspath(
                os.path.join(base_dir, "{}{}.ckpt".format(prefix, version)))
        else:
            checkpoint_path = os.path.abspath(
                os.path.join(base_dir, "{}{:08d}.ckpt".format(prefix, iter + 1)))
        print("Storing checkpoint to '{}'".format(checkpoint_path))
        train_net.save_checkpoint(checkpoint_path)
        # persist full training state next to the weights so runs can resume
        checkpoint_params.version = Checkpoint.VERSION
        checkpoint_params.iter = iter
        checkpoint_params.loss_stats[:] = loss_stats.values
        checkpoint_params.ler_stats[:] = ler_stats.values
        checkpoint_params.dt_stats[:] = dt_stats.values
        checkpoint_params.total_time = time.time() - train_start_time
        checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
        checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
        checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

        with open(checkpoint_path + ".json", 'w') as f:
            f.write(json_format.MessageToJson(checkpoint_params))

        return checkpoint_path

    try:
        last_checkpoint = None
        n_infinite_losses = 0
        n_max_infinite_losses = 5  # tolerate a few bad steps before restoring

        # Training loop, can be interrupted by early stopping
        for iter in range(iter, checkpoint_params.max_iters):
            checkpoint_params.iter = iter

            iter_start_time = time.time()
            result = train_net.train_step()

            if not np.isfinite(result['loss']):
                n_infinite_losses += 1
                if n_max_infinite_losses == n_infinite_losses:
                    # too many consecutive non-finite losses: roll back
                    print(
                        "Error: Loss is not finite! Trying to restart from last checkpoint."
                    )
                    if not last_checkpoint:
                        raise Exception(
                            "No checkpoint written yet. Training must be stopped."
                        )
                    else:
                        # reload also non trainable weights, such as solver-specific variables
                        train_net.load_weights(
                            last_checkpoint, restore_only_trainable=False)
                        continue
                else:
                    # skip this step but keep counting
                    continue

            n_infinite_losses = 0  # finite loss resets the failure counter
            loss_stats.push(result['loss'])
            ler_stats.push(result['ler'])
            dt_stats.push(time.time() - iter_start_time)

            if display > 0 and iter % display == 0:
                # apply postprocessing to display the true output
                pred_sentence = self.txt_postproc.apply("".join(
                    codec.decode(result["decoded"][0])))
                gt_sentence = self.txt_postproc.apply("".join(
                    codec.decode(result["gt"][0])))

                if display_epochs:
                    print("#{:08f}: loss={:.8f} ler={:.8f} dt={:.8f}s".
                          format(iter / iters_per_epoch, loss_stats.mean(),
                                 ler_stats.mean(), dt_stats.mean()))
                else:
                    print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".
                          format(iter, loss_stats.mean(), ler_stats.mean(),
                                 dt_stats.mean()))

                # Insert utf-8 ltr/rtl direction marks for bidi support
                lr = "\u202A\u202B"
                print(" PRED: '{}{}{}'".format(
                    lr[bidi.get_base_level(pred_sentence)], pred_sentence,
                    "\u202C"))
                print(" TRUE: '{}{}{}'".format(
                    lr[bidi.get_base_level(gt_sentence)], gt_sentence,
                    "\u202C"))

            if checkpoint_frequency > 0 and (
                    iter + 1) % checkpoint_frequency == 0:
                last_checkpoint = make_checkpoint(
                    checkpoint_params.output_dir,
                    checkpoint_params.output_model_prefix)

            if early_stopping_enabled and (
                    iter + 1) % early_stopping_frequency == 0:
                print("Checking early stopping model")

                out_gen = early_stopping_predictor.predict_input_dataset(
                    validation_dataset, progress_bar=progress_bar)
                # preprocess GT and prediction identically before scoring
                result = Evaluator.evaluate_single_list(
                    map(
                        Evaluator.evaluate_single_args,
                        map(
                            lambda d: tuple(
                                self.txt_preproc.apply([
                                    ''.join(d.ground_truth), d.sentence
                                ])), out_gen)))
                accuracy = 1 - result["avg_ler"]

                if accuracy > early_stopping_best_accuracy:
                    early_stopping_best_accuracy = accuracy
                    early_stopping_best_cur_nbest = 1
                    early_stopping_best_at_iter = iter + 1
                    # overwrite as best model
                    last_checkpoint = make_checkpoint(
                        checkpoint_params.early_stopping_best_model_output_dir,
                        prefix="",
                        version=checkpoint_params.early_stopping_best_model_prefix,
                    )
                    print(
                        "Found better model with accuracy of {:%}".format(
                            early_stopping_best_accuracy))
                else:
                    early_stopping_best_cur_nbest += 1
                    print(
                        "No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})"
                        .format(
                            early_stopping_best_accuracy,
                            early_stopping_best_at_iter,
                            checkpoint_params.early_stopping_nbest -
                            early_stopping_best_cur_nbest))

                if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                    print("Early stopping now.")
                    break

                if accuracy >= 1:
                    print(
                        "Reached perfect score on validation set. Early stopping now."
                    )
                    break

    except KeyboardInterrupt as e:
        # persist what we have before propagating the interrupt
        print("Storing interrupted checkpoint")
        make_checkpoint(checkpoint_params.output_dir,
                        checkpoint_params.output_model_prefix, "interrupted")
        raise e

    print("Total time {}s for {} iterations.".format(
        time.time() - train_start_time, iter))
class CustomTensorBoard(TensorBoard):
    """
    Custom TensorBoard Logging Class

    Per display freq:
        - training cer on 20 batches of population-wise data
        - validation cer on 20 batches every subpopulation data

    Per epoch:
        - model weights
    """

    def __init__(self,
                 training_callback,
                 codec,
                 train_data_gen,
                 validation_data_gen: Union[tuple, None],
                 predict_func,
                 checkpoint_params,
                 steps_per_epoch,
                 text_post_proc,
                 log_dir='logs',
                 histogram_freq=0,
                 write_graph=True,
                 write_images=False,
                 update_freq='batch',
                 embeddings_freq=0,
                 embeddings_metadata=None,
                 **kwargs):
        super().__init__(log_dir=log_dir,
                         histogram_freq=histogram_freq,
                         write_graph=write_graph,
                         write_images=write_images,
                         update_freq=update_freq,
                         embeddings_freq=embeddings_freq,
                         embeddings_metadata=embeddings_metadata,
                         **kwargs)
        # override default folder structure
        self._train_run_name = ''
        self._validation_run_name = ''
        self.training_callback = training_callback
        self.codec = codec
        self.train_data_gen = train_data_gen
        self.validation_data_gen = validation_data_gen
        self.predict_func = predict_func
        self.checkpoint_params = checkpoint_params
        self.steps_per_epoch = steps_per_epoch
        self.text_post_proc = text_post_proc
        # running windows seeded from the checkpoint so stats survive a restart
        self.loss_stats = RunningStatistics(checkpoint_params.stats_size,
                                            checkpoint_params.loss_stats)
        self.ler_stats = RunningStatistics(checkpoint_params.stats_size,
                                           checkpoint_params.ler_stats)
        self.dt_stats = RunningStatistics(checkpoint_params.stats_size,
                                          checkpoint_params.dt_stats)
        # one CER window per validation subpopulation
        # NOTE(review): the annotation allows validation_data_gen to be None,
        # but len(None) here would raise TypeError — confirm callers always
        # pass a tuple when constructing this callback.
        self.val_ler_stats = [
            RunningStatistics(checkpoint_params.stats_size,
                              checkpoint_params.ler_stats)
            for _ in range(len(self.validation_data_gen))
        ]
        display = checkpoint_params.display
        self.display_epochs = display <= 1
        if display <= 0:
            display = 0  # do not display anything
        elif self.display_epochs:
            display = max(1, int(display * steps_per_epoch))  # relative to epochs
        else:
            display = max(1, int(display))  # iterations
        self.display = display
        self.iter_start_time = time.time()
        self.train_start_time = time.time()

    def on_train_begin(self, logs):
        super().on_train_begin(logs)
        # log initial weights/embeddings before the first step
        if self.histogram_freq:
            self._log_weights(0)
        if self.embeddings_freq:
            self._log_embeddings(0)
        self.iter_start_time = time.time()
        self.train_start_time = time.time()

    def on_train_end(self, logs):
        super().on_train_end(logs)
        self.training_callback.training_finished(
            time.time() - self.train_start_time, self.checkpoint_params.iter)

    def on_train_batch_end(self, batch, logs=None):
        # the checkpoint iteration counter mirrors TensorBoard's batch counter
        assert self._total_batches_seen == self.checkpoint_params.iter
        self.checkpoint_params.iter += 1
        # NOTE(review): this early return skips the `_total_batches_seen += 1`
        # below while `checkpoint_params.iter` was already incremented, so the
        # assert above would fail on the next batch — confirm this path is
        # unreachable with the configurations actually used.
        if self.update_freq == 'epoch' and self._profile_batch is None:
            return
        dt = time.time() - self.iter_start_time
        self.iter_start_time = time.time()
        self.dt_stats.push(dt)
        self.loss_stats.push(logs['loss'])
        logs = logs or {}
        if (self.update_freq != 'epoch' and self.display > 0
                and self.checkpoint_params.iter % self.display == 0):
            cer, target, decoded = self._generate(
                self.train_data_gen,
                20)  # 20 batches for generating training metrics
            self.ler_stats.push(cer)
            pred_sentence = self.text_post_proc.apply("".join(
                self.codec.decode(decoded[0])))
            gt_sentence = self.text_post_proc.apply("".join(
                self.codec.decode(target[0])))
            self._log_metrics({"loss": self.loss_stats.mean()},
                              prefix='training/batch_',
                              step=self.checkpoint_params.iter)
            self._log_metrics({"cer": self.ler_stats.mean()},
                              prefix='training/batch_',
                              step=self.checkpoint_params.iter)
            self._log_metrics({"lr": logs['lr']},
                              prefix='',
                              step=self.checkpoint_params.iter)
            self.training_callback.display(self.ler_stats.mean(),
                                           self.loss_stats.mean(),
                                           self.dt_stats.mean(),
                                           self.checkpoint_params.iter,
                                           self.steps_per_epoch,
                                           self.display_epochs,
                                           pred_sentence, gt_sentence)
            # per-subpopulation validation CER at the same display frequency
            if self.validation_data_gen is not None:
                for i, val_data in enumerate(self.validation_data_gen):
                    val_data_name = dataregistry.get_name(i)
                    val_cer, _, _ = self._generate(
                        val_data,
                        20)  # 20 batches for generating training metrics
                    self.val_ler_stats[i].push(val_cer)
                    self._log_metrics(
                        {"cer": self.val_ler_stats[i].mean()},
                        prefix=f'{val_data_name}/validation_batch_',
                        step=self.checkpoint_params.iter)
        self._total_batches_seen += 1
        # profiling/tracing bookkeeping mirrored from the base TensorBoard callback
        if context.executing_eagerly():
            if self._is_tracing:
                self._log_trace()
            elif (not self._is_tracing and math_ops.equal(
                    self.checkpoint_params.iter, self._profile_batch - 1)):
                self._enable_trace()

    def on_epoch_end(self, epoch, logs=None):
        self._log_metrics(logs, prefix='epoch_', step=epoch)
        if self.histogram_freq and epoch % self.histogram_freq == 0:
            self._log_weights(epoch)
        if self.embeddings_freq and epoch % self.embeddings_freq == 0:
            self._log_embeddings(epoch)
        if self.update_freq == 'epoch':
            # epoch-mode: compute the CER metrics here instead of per batch
            train_cer, _, _ = self._generate(
                self.train_data_gen,
                20)  # 20 batches for generating training metrics
            self.ler_stats.push(train_cer)
            self._log_metrics({"cer": self.ler_stats.mean()},
                              prefix='training/batch_',
                              step=epoch)
            if self.validation_data_gen is not None:
                for i, val_data in enumerate(self.validation_data_gen):
                    val_data_name = dataregistry.get_name(i)
                    val_cer, _, _ = self._generate(
                        val_data,
                        20)  # 20 batches for generating training metrics
                    self.val_ler_stats[i].push(val_cer)
                    self._log_metrics(
                        {"cer": self.val_ler_stats[i].mean()},
                        prefix=f'{val_data_name}/validation_batch_',
                        step=self.checkpoint_params.iter)

    def _generate(self, data_gen, count):
        """Run ``predict_func`` on ``count`` batches from ``data_gen``;
        return mean CER plus flattened target/decoded label lists.

        NOTE(review): when ``data_gen`` is None the `pass` branch leaves
        ``cer``/``target``/``decoded`` unbound and the return raises
        NameError — confirm callers never pass None here.
        """
        if data_gen is None:
            pass
        else:
            it = iter(data_gen)
            cer, target, decoded = zip(
                *[self.predict_func(next(it)) for _ in range(count)])
        return np.mean(cer), sum(map(sparse_to_lists, target),
                                 []), sum(map(sparse_to_lists, decoded), [])
def _run_train(self, train_net, test_net, codec, train_start_time, progress_bar):
    """Run the training iteration loop on an already-built network.

    Resolves display/checkpoint/early-stopping frequencies (values <= 1 are
    interpreted as fractions of an epoch), then iterates
    ``train_net.train_step()`` up to ``max_iters``, maintaining running
    loss/LER/time statistics, writing checkpoints, and performing
    nbest-based early stopping against the validation dataset of
    ``test_net``.

    Raises:
        Exception: when the loss is non-finite for 5 consecutive steps and
            no checkpoint has been written yet.
    """
    checkpoint_params = self.checkpoint_params
    validation_dataset = test_net.input_dataset
    iters_per_epoch = max(1, int(len(train_net.input_dataset) / checkpoint_params.batch_size))

    # running windows, seeded from the checkpoint so a resumed run continues its stats
    loss_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.loss_stats)
    ler_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.ler_stats)
    dt_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.dt_stats)

    display = checkpoint_params.display
    display_epochs = display <= 1
    if display <= 0:
        display = 0  # to not display anything
    elif display_epochs:
        display = max(1, int(display * iters_per_epoch))  # relative to epochs
    else:
        display = max(1, int(display))  # iterations

    checkpoint_frequency = checkpoint_params.checkpoint_frequency
    early_stopping_frequency = checkpoint_params.early_stopping_frequency
    if early_stopping_frequency < 0:
        # set early stopping frequency to half epoch
        early_stopping_frequency = int(0.5 * iters_per_epoch)
    elif 0 < early_stopping_frequency <= 1:
        early_stopping_frequency = int(early_stopping_frequency * iters_per_epoch)  # relative to epochs
    else:
        early_stopping_frequency = int(early_stopping_frequency)
    # BUGFIX (consistency with the sibling _run_train implementation): clamp to
    # at least 1 — a fractional frequency times a small epoch size (or the
    # half-epoch default) could otherwise become 0 and make the
    # `% early_stopping_frequency` check below raise ZeroDivisionError.
    early_stopping_frequency = max(1, early_stopping_frequency)

    if checkpoint_frequency < 0:
        # default: checkpoint whenever early stopping is evaluated
        checkpoint_frequency = early_stopping_frequency
    elif 0 < checkpoint_frequency <= 1:
        checkpoint_frequency = int(checkpoint_frequency * iters_per_epoch)  # relative to epochs
    else:
        checkpoint_frequency = int(checkpoint_frequency)

    early_stopping_enabled = self.validation_dataset is not None \
        and checkpoint_params.early_stopping_frequency > 0 \
        and checkpoint_params.early_stopping_nbest > 1
    early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
    early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
    early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

    early_stopping_predictor = Predictor(codec=codec, text_postproc=self.txt_postproc,
                                         network=test_net)

    # Start the actual training
    # ====================================================================================
    iter = checkpoint_params.iter  # NOTE: shadows the builtin; kept as-is

    # helper function to write a checkpoint (closes over `iter` and the stats)
    def make_checkpoint(base_dir, prefix, version=None):
        if version:
            checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{}.ckpt".format(prefix, version)))
        else:
            checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{:08d}.ckpt".format(prefix, iter + 1)))
        print("Storing checkpoint to '{}'".format(checkpoint_path))
        train_net.save_checkpoint(checkpoint_path)
        # persist full training state next to the weights so runs can resume
        checkpoint_params.version = Checkpoint.VERSION
        checkpoint_params.iter = iter
        checkpoint_params.loss_stats[:] = loss_stats.values
        checkpoint_params.ler_stats[:] = ler_stats.values
        checkpoint_params.dt_stats[:] = dt_stats.values
        checkpoint_params.total_time = time.time() - train_start_time
        checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
        checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
        checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

        with open(checkpoint_path + ".json", 'w') as f:
            f.write(json_format.MessageToJson(checkpoint_params))

        return checkpoint_path

    try:
        last_checkpoint = None
        n_infinite_losses = 0
        n_max_infinite_losses = 5  # tolerate a few bad steps before restoring

        # Training loop, can be interrupted by early stopping
        for iter in range(iter, checkpoint_params.max_iters):
            checkpoint_params.iter = iter

            iter_start_time = time.time()
            result = train_net.train_step()

            if not np.isfinite(result['loss']):
                n_infinite_losses += 1
                if n_max_infinite_losses == n_infinite_losses:
                    # too many consecutive non-finite losses: roll back
                    print("Error: Loss is not finite! Trying to restart from last checkpoint.")
                    if not last_checkpoint:
                        raise Exception("No checkpoint written yet. Training must be stopped.")
                    else:
                        # reload also non trainable weights, such as solver-specific variables
                        train_net.load_weights(last_checkpoint, restore_only_trainable=False)
                        continue
                else:
                    # skip this step but keep counting
                    continue

            n_infinite_losses = 0  # finite loss resets the failure counter
            loss_stats.push(result['loss'])
            ler_stats.push(result['ler'])
            dt_stats.push(time.time() - iter_start_time)

            if display > 0 and iter % display == 0:
                # apply postprocessing to display the true output
                pred_sentence = self.txt_postproc.apply("".join(codec.decode(result["decoded"][0])))
                gt_sentence = self.txt_postproc.apply("".join(codec.decode(result["gt"][0])))

                if display_epochs:
                    print("#{:08f}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(
                        iter / iters_per_epoch, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))
                else:
                    print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(
                        iter, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))

                # Insert utf-8 ltr/rtl direction marks for bidi support
                lr = "\u202A\u202B"
                print(" PRED: '{}{}{}'".format(lr[bidi.get_base_level(pred_sentence)], pred_sentence, "\u202C"))
                print(" TRUE: '{}{}{}'".format(lr[bidi.get_base_level(gt_sentence)], gt_sentence, "\u202C"))

            if checkpoint_frequency > 0 and (iter + 1) % checkpoint_frequency == 0:
                last_checkpoint = make_checkpoint(checkpoint_params.output_dir, checkpoint_params.output_model_prefix)

            if early_stopping_enabled and (iter + 1) % early_stopping_frequency == 0:
                print("Checking early stopping model")

                out_gen = early_stopping_predictor.predict_input_dataset(validation_dataset, progress_bar=progress_bar)
                # preprocess GT and prediction identically before scoring
                result = Evaluator.evaluate_single_list(map(
                    Evaluator.evaluate_single_args,
                    map(lambda d: tuple(self.txt_preproc.apply([''.join(d.ground_truth), d.sentence])), out_gen)))
                accuracy = 1 - result["avg_ler"]

                if accuracy > early_stopping_best_accuracy:
                    early_stopping_best_accuracy = accuracy
                    early_stopping_best_cur_nbest = 1
                    early_stopping_best_at_iter = iter + 1
                    # overwrite as best model
                    last_checkpoint = make_checkpoint(
                        checkpoint_params.early_stopping_best_model_output_dir,
                        prefix="",
                        version=checkpoint_params.early_stopping_best_model_prefix,
                    )
                    print("Found better model with accuracy of {:%}".format(early_stopping_best_accuracy))
                else:
                    early_stopping_best_cur_nbest += 1
                    print("No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})".
                          format(early_stopping_best_accuracy, early_stopping_best_at_iter,
                                 checkpoint_params.early_stopping_nbest - early_stopping_best_cur_nbest))

                if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                    print("Early stopping now.")
                    break

                if accuracy >= 1:
                    print("Reached perfect score on validation set. Early stopping now.")
                    break

    except KeyboardInterrupt as e:
        # persist what we have before propagating the interrupt
        print("Storing interrupted checkpoint")
        make_checkpoint(checkpoint_params.output_dir, checkpoint_params.output_model_prefix, "interrupted")
        raise e

    print("Total time {}s for {} iterations.".format(time.time() - train_start_time, iter))
class VisCallback(keras.callbacks.Callback):
    """Console-reporting Keras callback.

    Maintains running windows of loss, label error rate (CER) and
    per-iteration wall time; every ``display`` iterations it decodes one
    sample batch and prints the smoothed metrics plus a prediction /
    ground-truth pair (wrapped in bidi direction marks).
    """

    def __init__(self, codec, data_gen, predict_func, checkpoint_params,
                 steps_per_epoch, text_post_proc):
        self.codec = codec
        self.data_gen = data_gen
        self.predict_func = predict_func
        self.checkpoint_params = checkpoint_params
        self.steps_per_epoch = steps_per_epoch
        self.text_post_proc = text_post_proc

        # Running windows over the most recent observations, seeded from the
        # checkpoint so statistics survive a training restart.
        window = checkpoint_params.stats_size
        self.loss_stats = RunningStatistics(window, checkpoint_params.loss_stats)
        self.ler_stats = RunningStatistics(window, checkpoint_params.ler_stats)
        self.dt_stats = RunningStatistics(window, checkpoint_params.dt_stats)

        freq = checkpoint_params.display
        self.display_epochs = freq <= 1
        if freq <= 0:
            self.display = 0  # do not display anything
        elif self.display_epochs:
            self.display = max(1, int(freq * steps_per_epoch))  # relative to epochs
        else:
            self.display = max(1, int(freq))  # iterations

        self.iter_start_time = time.time()
        self.train_start_time = time.time()

    def on_train_begin(self, logs):
        # Reset both timers at the true start of training.
        started = time.time()
        self.iter_start_time = started
        self.train_start_time = started

    def on_train_end(self, logs):
        print("Total training time {}s for {} iterations.".format(
            time.time() - self.train_start_time, self.checkpoint_params.iter))

    def on_batch_end(self, batch, logs):
        now = time.time()
        self.dt_stats.push(now - self.iter_start_time)
        self.iter_start_time = now
        self.loss_stats.push(logs['loss'])
        self.checkpoint_params.iter += 1

        if self.display <= 0 or self.checkpoint_params.iter % self.display != 0:
            return

        # Decode one batch and apply postprocessing to display the true output.
        cer, target, decoded = self._generate(1)
        self.ler_stats.push(cer)
        pred_sentence = self.text_post_proc.apply("".join(self.codec.decode(decoded[0])))
        gt_sentence = self.text_post_proc.apply("".join(self.codec.decode(target[0])))

        if self.display_epochs:
            # progress expressed as a (fractional) epoch count
            print("#{:08f}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(
                self.checkpoint_params.iter / self.steps_per_epoch,
                self.loss_stats.mean(), self.ler_stats.mean(),
                self.dt_stats.mean()))
        else:
            print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(
                self.checkpoint_params.iter, self.loss_stats.mean(),
                self.ler_stats.mean(), self.dt_stats.mean()))

        # Insert utf-8 ltr/rtl direction marks for bidi support
        lr = "\u202A\u202B"
        print(" PRED: '{}{}{}'".format(
            lr[bidi.get_base_level(pred_sentence)], pred_sentence, "\u202C"))
        print(" TRUE: '{}{}{}'".format(
            lr[bidi.get_base_level(gt_sentence)], gt_sentence, "\u202C"))

    def on_epoch_end(self, epoch, logs):
        pass

    def _generate(self, count):
        """Run ``predict_func`` on ``count`` batches; return mean CER plus
        flattened target and decoded label sequences."""
        source = iter(self.data_gen)
        outcomes = [self.predict_func(next(source)) for _ in range(count)]
        cer, target, decoded = zip(*outcomes)

        def flatten(sparse_batches):
            return sum(map(sparse_to_lists, sparse_batches), [])

        return np.mean(cer), flatten(target), flatten(decoded)