def train(params, strategy, dataset=None):
  """Runs training.

  Args:
    params: Model/optimizer configuration object passed through to the
      model and optimizer factories.
    strategy: A `tf.distribute.Strategy`; the model, optimizer and
      checkpoint machinery are all created under its scope.
    dataset: Optional training dataset. When falsy, one is built from
      `FLAGS.train_file_pattern`.

  Returns:
    A dict with a single key, "training_loss", holding the final recorded
    training loss as a float.
  """
  if not dataset:
    dataset = input_pipeline.get_input_dataset(
        FLAGS.train_file_pattern,
        FLAGS.train_batch_size,
        params,
        is_training=True,
        strategy=strategy)

  with strategy.scope():
    model = models.create_model(
        FLAGS.model_type, params, init_checkpoint=FLAGS.init_checkpoint)
    opt = optimizer.create_optimizer(params)

    trainer = Trainer(model, params)
    trainer.compile(optimizer=opt, steps_per_execution=FLAGS.steps_per_loop)

    # TensorBoard summaries are refreshed once per host loop, but never more
    # often than every 100 steps.
    summary_callback = tf.keras.callbacks.TensorBoard(
        os.path.join(FLAGS.model_dir, "summaries"),
        update_freq=max(100, FLAGS.steps_per_loop))

    # The manager keys on the optimizer's step counter so it can enforce
    # FLAGS.checkpoint_interval between saves.
    checkpoint_manager = tf.train.CheckpointManager(
        tf.train.Checkpoint(
            model=model, optimizer=opt, global_step=opt.iterations),
        directory=FLAGS.model_dir,
        max_to_keep=10,
        step_counter=opt.iterations,
        checkpoint_interval=FLAGS.checkpoint_interval)
    if checkpoint_manager.restore_or_initialize():
      logging.info("Training restored from the checkpoints in: %s",
                   FLAGS.model_dir)
    checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)

  # Trains the model. One Keras "epoch" covers min(train_steps,
  # checkpoint_interval) steps; the epoch count is the step budget divided
  # by that span.
  steps_per_epoch = min(FLAGS.train_steps, FLAGS.checkpoint_interval)
  epochs = FLAGS.train_steps // steps_per_epoch
  history = trainer.fit(
      x=dataset,
      steps_per_epoch=steps_per_epoch,
      epochs=epochs,
      callbacks=[summary_callback, checkpoint_callback],
      verbose=2)

  # Gets final loss from training.
  final_loss = float(history.history["training_loss"][-1])
  return dict(training_loss=final_loss)
def continuous_eval(strategy,
                    params,
                    model_type,
                    eval_file_pattern=None,
                    batch_size=4,
                    eval_steps=None,
                    model_dir=None,
                    timeout=3000):
  """Continuously evaluate checkpoints on testing data.

  Watches `model_dir` for new checkpoints; for each one, runs the test
  dataset through the model, writes BLEU/ROUGE scores plus the metrics
  accumulated by `metrics_v2.MetricLayer` to TensorBoard, then resets all
  metrics before waiting for the next checkpoint.

  Args:
    strategy: A `tf.distribute.Strategy` used to distribute the eval step.
    params: Model hyperparameters; must expose `vocab_size` and
      `pad_token_id`.
    model_type: Model type string forwarded to `models.create_model`.
    eval_file_pattern: File pattern for the evaluation data.
    batch_size: Evaluation batch size.
    eval_steps: If set, caps the number of batches evaluated per checkpoint.
    model_dir: Directory watched for checkpoints; summaries are written
      under `<model_dir>/summaries/eval`.
    timeout: Seconds `tf.train.checkpoints_iterator` waits for a new
      checkpoint before stopping.

  Returns:
    A dict mapping metric names to float results from the most recently
    evaluated checkpoint.
  """
  test_dataset = input_pipeline.get_input_dataset(
      eval_file_pattern,
      batch_size=batch_size,
      params=params,
      is_training=False,
      strategy=strategy)

  with strategy.scope():
    model = models.create_model(model_type, params)
    metric_layer = metrics_v2.MetricLayer(params.vocab_size)
    eval_summary_writer = tf.summary.create_file_writer(
        os.path.join(model_dir, "summaries/eval"))
    # Restored from each checkpoint below, so summaries are tagged with the
    # training step at which the checkpoint was written.
    global_step = tf.Variable(
        0,
        trainable=False,
        dtype=tf.int64,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        shape=[])

  @tf.function
  def test_step(inputs):
    """Calculates evaluation metrics on distributed devices."""

    def _test_step_fn(inputs):
      """Replicated accuracy calculation."""
      targets = models.remove_sos_from_seq(inputs["target_ids"],
                                           params.pad_token_id)
      # Using ground truth sequences as targets to calculate logits for
      # accuracy and perplexity metrics.
      logits, _, _ = model(inputs, training=False, mode="train")
      metric_layer([logits, targets])

      # Get logits from top beam search results for bleu and rouge metrics.
      logits = model(inputs, training=False, mode="eval")
      return targets, logits

    outputs = strategy.run(_test_step_fn, args=(inputs,))
    # Unwrap PerReplica values into per-device local result tuples.
    return tf.nest.map_structure(strategy.experimental_local_results, outputs)

  # Each Mean metric averages its paired score function over all evaluated
  # batches and replicas.
  metrics_and_funcs = [
      (tf.keras.metrics.Mean("bleu", dtype=tf.float32), bleu_score),
      (tf.keras.metrics.Mean("rouge_2_fscore", dtype=tf.float32),
       rouge_2_fscore),
      (tf.keras.metrics.Mean("rouge_l_fscore", dtype=tf.float32),
       rouge_l_fscore),
  ]
  eval_results = {}
  for latest_checkpoint in tf.train.checkpoints_iterator(
      model_dir, timeout=timeout):
    checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
    # expect_partial() silences warnings about checkpoint values not restored
    # here — presumably training-only state such as optimizer slots; confirm
    # against the objects saved in train().
    checkpoint.restore(latest_checkpoint).expect_partial()
    logging.info("Loaded checkpoint %s", latest_checkpoint)

    for i, inputs in enumerate(test_dataset):
      if eval_steps and i >= eval_steps:
        break
      outputs = test_step(inputs)
      # outputs[0]/outputs[1] are the per-replica lists of targets/logits
      # returned by test_step.
      for metric, func in metrics_and_funcs:
        for targets, logits in zip(outputs[0], outputs[1]):
          metric.update_state(func(logits.numpy(), targets.numpy()))

    with eval_summary_writer.as_default():
      step = global_step.numpy()
      for metric, _ in metrics_and_funcs:
        eval_results[metric.name] = metric.result().numpy().astype(float)
        tf.summary.scalar(metric.name, eval_results[metric.name], step=step)
      # Also report the metrics the metric layer accumulated during the
      # "train"-mode forward passes.
      for metric in metric_layer.metrics:
        eval_results[metric.name] = metric.result().numpy().astype(float)
        tf.summary.scalar(metric.name, eval_results[metric.name], step=step)
      logging.info("Step %d Metrics= %s", step, str(eval_results))
      eval_summary_writer.flush()

    # Resets metrics so the next checkpoint is scored from scratch.
    for metric, _ in metrics_and_funcs:
      metric.reset_states()
    for metric in metric_layer.metrics:
      metric.reset_states()

  return eval_results