Example #1
def train(params, strategy, dataset=None):
  """Runs training."""

  if not dataset:
    dataset = input_pipeline.get_input_dataset(
        FLAGS.train_file_pattern,
        FLAGS.train_batch_size,
        params,
        is_training=True,
        strategy=strategy)

  with strategy.scope():
    model = models.create_model(
        FLAGS.model_type, params, init_checkpoint=FLAGS.init_checkpoint)
    opt = optimizer.create_optimizer(params)
    trainer = Trainer(model, params)

    # steps_per_execution runs multiple batches inside each tf.function call,
    # reducing per-step host overhead.
    trainer.compile(
        optimizer=opt,
        steps_per_execution=FLAGS.steps_per_loop)
    summary_dir = os.path.join(FLAGS.model_dir, "summaries")
    summary_callback = tf.keras.callbacks.TensorBoard(
        summary_dir, update_freq=max(100, FLAGS.steps_per_loop))
    checkpoint = tf.train.Checkpoint(
        model=model, optimizer=opt, global_step=opt.iterations)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=FLAGS.model_dir,
        max_to_keep=10,
        step_counter=opt.iterations,
        checkpoint_interval=FLAGS.checkpoint_interval)
    if checkpoint_manager.restore_or_initialize():
      logging.info("Training restored from the checkpoints in: %s",
                   FLAGS.model_dir)
    checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)

  # Trains the model. steps_per_epoch is capped at the checkpoint interval so
  # that a checkpoint becomes eligible for saving at each epoch boundary.
  steps_per_epoch = min(FLAGS.train_steps, FLAGS.checkpoint_interval)
  epochs = FLAGS.train_steps // steps_per_epoch
  history = trainer.fit(
      x=dataset,
      steps_per_epoch=steps_per_epoch,
      epochs=epochs,
      callbacks=[summary_callback, checkpoint_callback],
      verbose=2)
  train_hist = history.history
  # Gets final loss from training.
  stats = dict(training_loss=float(train_hist["training_loss"][-1]))
  return stats
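
The checkpointing pattern above, a tf.train.Checkpoint tracking the model and optimizer managed by a tf.train.CheckpointManager keyed on a step counter, can be tried outside the project. Below is a minimal standalone sketch; the toy model, the explicit step variable (the original uses opt.iterations), the directory, and the interval are placeholders, not part of the original example:

import tensorflow as tf

# Toy model and optimizer standing in for models.create_model and
# optimizer.create_optimizer.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
opt = tf.keras.optimizers.Adam()
global_step = tf.Variable(0, trainable=False, dtype=tf.int64)

checkpoint = tf.train.Checkpoint(model=model, optimizer=opt, global_step=global_step)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint,
    directory="/tmp/toy_model_dir",  # placeholder for FLAGS.model_dir
    max_to_keep=10,
    step_counter=global_step,
    checkpoint_interval=100)

# restore_or_initialize() returns the restored checkpoint path, or None on a
# fresh start -- the same truthiness check train() relies on.
if checkpoint_manager.restore_or_initialize():
  print("Restored from", checkpoint_manager.latest_checkpoint)
else:
  print("Initializing from scratch")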
Example #2
def continuous_eval(strategy,
                    params,
                    model_type,
                    eval_file_pattern=None,
                    batch_size=4,
                    eval_steps=None,
                    model_dir=None,
                    timeout=3000):
  """Continuously evaluate checkpoints on testing data."""
  test_dataset = input_pipeline.get_input_dataset(
      eval_file_pattern,
      batch_size=batch_size,
      params=params,
      is_training=False,
      strategy=strategy)

  with strategy.scope():
    model = models.create_model(model_type, params)
    metric_layer = metrics_v2.MetricLayer(params.vocab_size)
    eval_summary_writer = tf.summary.create_file_writer(
        os.path.join(model_dir, "summaries/eval"))
    global_step = tf.Variable(
        0,
        trainable=False,
        dtype=tf.int64,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        shape=[])

  @tf.function
  def test_step(inputs):
    """Calculates evaluation metrics on distributed devices."""

    def _test_step_fn(inputs):
      """Replicated accuracy calculation."""
      targets = models.remove_sos_from_seq(inputs["target_ids"],
                                           params.pad_token_id)

      # Using ground truth sequences as targets to calculate logits for accuracy
      # and perplexity metrics.
      logits, _, _ = model(inputs, training=False, mode="train")
      metric_layer([logits, targets])

      # Get logits from the top beam search results for BLEU and ROUGE metrics.
      logits = model(inputs, training=False, mode="eval")

      return targets, logits

    outputs = strategy.run(_test_step_fn, args=(inputs,))

    return tf.nest.map_structure(strategy.experimental_local_results, outputs)

  metrics_and_funcs = [
      (tf.keras.metrics.Mean("bleu", dtype=tf.float32), bleu_score),
      (tf.keras.metrics.Mean("rouge_2_fscore",
                             dtype=tf.float32), rouge_2_fscore),
      (tf.keras.metrics.Mean("rouge_l_fscore",
                             dtype=tf.float32), rouge_l_fscore),
  ]
  eval_results = {}
  for latest_checkpoint in tf.train.checkpoints_iterator(
      model_dir, timeout=timeout):
    checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
    # expect_partial() silences warnings about training-only objects (e.g. the
    # optimizer) that exist in the checkpoint but are not restored here.
    checkpoint.restore(latest_checkpoint).expect_partial()
    logging.info("Loaded checkpoint %s", latest_checkpoint)

    for i, inputs in enumerate(test_dataset):
      if eval_steps and i >= eval_steps:
        break
      outputs = test_step(inputs)
      for metric, func in metrics_and_funcs:
        for targets, logits in zip(outputs[0], outputs[1]):
          metric.update_state(func(logits.numpy(), targets.numpy()))

    with eval_summary_writer.as_default():
      step = global_step.numpy()
      for metric, _ in metrics_and_funcs:
        eval_results[metric.name] = metric.result().numpy().astype(float)
        tf.summary.scalar(
            metric.name,
            eval_results[metric.name],
            step=step)
      for metric in metric_layer.metrics:
        eval_results[metric.name] = metric.result().numpy().astype(float)
        tf.summary.scalar(
            metric.name,
            eval_results[metric.name],
            step=step)
      logging.info("Step %d Metrics= %s", step, str(eval_results))
      eval_summary_writer.flush()

    # Resets metrics before evaluating the next checkpoint.
    for metric, _ in metrics_and_funcs:
      metric.reset_states()
    for metric in metric_layer.metrics:
      metric.reset_states()
  return eval_results
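
The loop in continuous_eval() is driven by tf.train.checkpoints_iterator, which blocks until a new checkpoint file appears in the directory (or the timeout elapses) and then yields its path. A minimal standalone sketch of that waiting loop, with a placeholder directory and timeout:

import tensorflow as tf
from absl import logging

# Yields each new checkpoint path as it is written; the loop exits once no new
# checkpoint appears within `timeout` seconds.
for ckpt_path in tf.train.checkpoints_iterator("/tmp/toy_model_dir", timeout=60):
  logging.info("New checkpoint to evaluate: %s", ckpt_path)
  # continuous_eval() restores the model from ckpt_path at this point, iterates
  # the test dataset, and writes the aggregated metrics to TensorBoard summaries.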