Example No. 1
def input_fn_builder(input_files, is_training, num_cpu_threads=4):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""
    num_blocks_per_example, block_length = input_utils.get_block_params_from_input_file(
        input_files[0])
    max_num_annotations = input_utils.get_num_annotations_from_input_file(
        input_files[0])
    logging.info("***** Building Input pipeline *****")
    logging.info("  Number of blocks per example = %d", num_blocks_per_example)
    logging.info("  Block length = %d", block_length)
    logging.info("  Number of anntotations per block = %d",
                 max_num_annotations)

    def input_fn(params):
        """The actual input function."""
        logging.info("*** Input: Params ***")
        for name in sorted(params.keys()):
            logging.info("  %s = %s", name, params[name])

        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        if is_training:
            d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))

            # From https://www.tensorflow.org/guide/data#randomly_shuffling_input_data
            # Dataset.shuffle doesn't signal the end of an epoch until the shuffle
            # buffer is empty. So a shuffle placed before a repeat will show every
            # element of one epoch before moving to the next.
            d = d.shuffle(buffer_size=len(input_files))
            d = d.repeat()

            # `cycle_length` is the number of parallel files that get read.
            cycle_length = min(num_cpu_threads, len(input_files))

            # `sloppy` mode means that the interleaving is not exact. This adds
            # even more randomness to the training pipeline.
            d = d.apply(
                tf.data.experimental.parallel_interleave(
                    tf.data.TFRecordDataset,
                    sloppy=is_training,
                    cycle_length=cycle_length))
            d = d.shuffle(buffer_size=100)
        else:
            d = tf.data.TFRecordDataset(input_files)

        d = d.map(input_utils.get_span_prediction_example_decode_fn(
            num_blocks_per_example,
            block_length,
            max_num_answer_annotations=max_num_annotations,
            max_num_entity_annotations=max_num_annotations),
                  num_parallel_calls=tf.data.experimental.AUTOTUNE)

        d = d.prefetch(tf.data.experimental.AUTOTUNE)
        return d

    return input_fn
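The training branch of `input_fn` above combines three tf.data idioms: shuffle the file list before `repeat()` so every shard is seen once per epoch, interleave reads from several shards in parallel, and add a record-level shuffle on top. Below is a minimal, self-contained sketch of the same pattern; the file pattern and function name are illustrative, not part of the original code:

import tensorflow as tf


def make_toy_train_dataset(file_pattern="/tmp/shards/*.tfrecord",
                           num_cpu_threads=4):
    """Illustrative sketch of the shuffle-before-repeat + interleave pattern."""
    filenames = tf.io.gfile.glob(file_pattern)
    d = tf.data.Dataset.from_tensor_slices(tf.constant(filenames))
    # Shuffling before `repeat()` guarantees every shard is visited once
    # per epoch before the next epoch starts.
    d = d.shuffle(buffer_size=len(filenames))
    d = d.repeat()
    cycle_length = min(num_cpu_threads, len(filenames))
    # `sloppy=True` allows non-deterministic interleaving for extra randomness,
    # mirroring `sloppy=is_training` in the snippet above.
    d = d.apply(
        tf.data.experimental.parallel_interleave(
            tf.data.TFRecordDataset, sloppy=True, cycle_length=cycle_length))
    # Record-level shuffle across the interleaved shards.
    d = d.shuffle(buffer_size=100)
    return d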
Example No. 2
def main(_):
  logging.set_verbosity(logging.INFO)

  validate_flags()
  tf.io.gfile.makedirs(FLAGS.output_dir)

  for flag in FLAGS.flags_by_module_dict()[sys.argv[0]]:
    logging.info("  %s = %s", flag.name, flag.value)

  model_config = config.get_model_config(
      model_dir=FLAGS.output_dir,
      source_file=FLAGS.read_it_twice_bert_config_file,
      source_base64=FLAGS.read_it_twice_bert_config_base64,
      write_from_source=FLAGS.do_train)

  if FLAGS.checkpoint is not None:
    assert not FLAGS.do_train
    assert FLAGS.do_eval

  if FLAGS.cross_attention_top_k is not None:
    model_config = dataclasses.replace(
        model_config, cross_attention_top_k=FLAGS.cross_attention_top_k)

  input_files = []
  for input_pattern in FLAGS.input_file.split(","):
    input_files.extend(tf.io.gfile.glob(input_pattern))

  logging.info("*** Input Files ***")
  for input_file in input_files:
    logging.info("  %s", input_file)

  num_blocks_per_example, block_length = input_utils.get_block_params_from_input_file(
      input_files[0])

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  # PER_HOST_V1: iterator.get_next() is called 1 time with per_worker_batch_size
  # PER_HOST_V2: iterator.get_next() is called 8 times with per_core_batch_size
  is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1
  run_config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      # Keep all checkpoints
      keep_checkpoint_max=None,
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          tpu_job_name=FLAGS.tpu_job_name,
          per_host_input_for_training=is_per_host,
          experimental_host_call_every_n_steps=FLAGS.steps_per_summary))

  # TODO(urikz): Is there a better way to compute the number of tasks?
  # the code below doesn't work because `tpu_cluster_resolver.cluster_spec()`
  # returns None. Therefore, I have to pass number of total tasks via CLI arg.
  # num_tpu_tasks = tpu_cluster_resolver.cluster_spec().num_tasks()
  batch_size = (FLAGS.num_tpu_tasks or 1) * num_blocks_per_example

  num_train_examples = input_utils.get_num_examples_in_tf_records(input_files)
  num_train_steps = int(num_train_examples * FLAGS.num_train_epochs)
  num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

  logging.info("***** Input configuration *****")
  logging.info("  Number of blocks per example = %d", num_blocks_per_example)
  logging.info("  Block length = %d", block_length)
  logging.info("  Number of TPU tasks = %d", FLAGS.num_tpu_tasks or 1)
  logging.info("  Batch size = %d", batch_size)
  logging.info("  Number of TPU cores = %d", FLAGS.num_tpu_cores or 0)
  logging.info("  Number training steps = %d", num_train_steps)
  logging.info("  Number warmup steps = %d", num_warmup_steps)

  model_fn = model_fn_builder(
      model_config=model_config,
      padding_token_id=FLAGS.padding_token_id,
      enable_side_inputs=FLAGS.enable_side_inputs,
      num_replicas_concat=FLAGS.num_tpu_cores,
      cross_block_attention_mode=FLAGS.cross_block_attention_mode,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_one_hot_embeddings,
      optimizer=FLAGS.optimizer,
      poly_power=FLAGS.poly_power,
      start_warmup_step=FLAGS.start_warmup_step,
      learning_rate_schedule=FLAGS.learning_rate_schedule,
      nbest_logits_for_eval=FLAGS.decode_top_k)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      predict_batch_size=batch_size)

  training_done_path = os.path.join(FLAGS.output_dir, "training_done")

  if FLAGS.do_train:
    logging.info("***** Running training *****")
    train_input_fn = input_fn_builder(input_files=input_files, is_training=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
    # Write file to signal training is done.
    with tf.io.gfile.GFile(training_done_path, "w") as writer:
      writer.write("\n")

  if FLAGS.do_eval:
    logging.info("***** Running evaluation *****")
    eval_input_fn = input_fn_builder(input_files=input_files, is_training=False)
    question_ids, ground_truth = read_question_answer_json(FLAGS.input_json)
    tokenizer = tokenization.FullTokenizer(FLAGS.spm_model_path)
    logging.info("Loaded SentencePiece model from %s", FLAGS.spm_model_path)

    # Writer for TensorBoard.
    summary_writer = tf.summary.FileWriter(
        os.path.join(FLAGS.output_dir, "eval_metrics"))

    if not FLAGS.checkpoint:
      # for checkpoint_path in _get_all_checkpoints(FLAGS.output_dir):
      checkpoint_iter = tf.train.checkpoints_iterator(
          FLAGS.output_dir, min_interval_secs=5 * 60, timeout=8 * 60 * 60)
    else:
      checkpoint_iter = [FLAGS.checkpoint]

    for checkpoint_path in checkpoint_iter:
      start_time = time.time()
      global_step = _get_global_step_for_checkpoint(checkpoint_path)
      if global_step == 0:
        continue
      logging.info("Starting eval on step %d on checkpoint: %s", global_step,
                   checkpoint_path)
      try:
        nbest_predictions = collections.OrderedDict()
        yesno_logits, yesno_labels = {}, {}
        supporting_fact_logits, supporting_fact_labels = {}, {}

        for prediction in estimator.predict(
            eval_input_fn,
            checkpoint_path=checkpoint_path,
            yield_single_examples=True):
          block_id = prediction["block_ids"]
          if block_id == 0:
            # Padding document
            continue
          question_id = question_ids[block_id]
          if question_id not in nbest_predictions:
            nbest_predictions[question_id] = {}
            yesno_logits[question_id] = []
            yesno_labels[question_id] = []
            supporting_fact_logits[question_id] = []
            supporting_fact_labels[question_id] = []

          yesno_logits[question_id].append(prediction["yesno_logits"].tolist())
          yesno_labels[question_id].append(prediction["answer_type"].tolist())
          supporting_fact_logits[question_id].append(
              prediction["supporting_fact_logits"].tolist())
          supporting_fact_labels[question_id].append(
              prediction["is_supporting_fact"].tolist())

          token_ids = prediction["token_ids"]
          for begin_index, begin_logit in zip(
              prediction["begin_logits_indices"],
              prediction["begin_logits_values"]):
            for end_index, end_logit in zip(prediction["end_logits_indices"],
                                            prediction["end_logits_values"]):
              if begin_index > end_index or end_index - begin_index + 1 > FLAGS.decode_max_size:
                continue
              answer = "".join(
                  tokenizer.convert_ids_to_tokens([
                      int(token_id)
                      for token_id in token_ids[begin_index:end_index + 1]
                  ]))

              answer = answer.replace(tokenization.SPIECE_UNDERLINE,
                                      " ").strip()
              if not answer:
                continue
              normalized_answer = evaluation.normalize_answer(answer)
              if normalized_answer not in nbest_predictions[question_id]:
                nbest_predictions[question_id][normalized_answer] = []
              nbest_predictions[question_id][normalized_answer].append(
                  begin_logit + end_logit)
      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        tf.logging.info("Checkpoint %s no longer exists, skipping checkpoint",
                        checkpoint_path)
        continue

      nbest_predictions_probs = _convert_prediction_logits_to_probs(
          nbest_predictions)

      best_predictions_max = _get_best_predictions(nbest_predictions_probs, max)
      for question_id in yesno_logits:
        if question_id in best_predictions_max:
          span_answer = best_predictions_max[question_id]
        else:
          span_answer = None
        best_predictions_max[question_id] = {
            "yesno_logits": yesno_logits[question_id],
            "yesno_labels": yesno_labels[question_id],
            "supporting_fact_logits": supporting_fact_logits[question_id],
            "supporting_fact_labels": supporting_fact_labels[question_id],
        }
        if span_answer is not None:
          best_predictions_max[question_id]["span_answer"] = span_answer

      with tf.io.gfile.GFile(checkpoint_path + ".best_predictions_max.json",
                             "w") as f:
        json.dump(best_predictions_max, f, indent=2)

      best_predictions_max_results = evaluation.make_predictions_and_eval(
          ground_truth, best_predictions_max)
      write_eval_results(global_step, best_predictions_max_results, "max",
                         summary_writer)

      if tf.io.gfile.exists(training_done_path):
        # Break if the checkpoint we just processed is the last one.
        last_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        if last_checkpoint is None:
          continue
        last_global_step = _get_global_step_for_checkpoint(last_checkpoint)
        if global_step == last_global_step:
          break

      global_step = _get_global_step_for_checkpoint(checkpoint_path)
      logging.info("Finished eval on step %d in %d seconds", global_step,
                   time.time() - start_time)
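Both eval loops (here and in Example No. 4) rely on `_get_global_step_for_checkpoint`, whose definition is not part of these snippets. A plausible sketch, assuming the standard `model.ckpt-<step>` checkpoint naming and returning 0 when no step can be parsed (which the loop above then skips):

import os
import re


def _get_global_step_for_checkpoint(checkpoint_path):
  """Hypothetical helper: extracts the global step from a checkpoint path."""
  match = re.search(r"-(\d+)$", os.path.basename(checkpoint_path))
  return int(match.group(1)) if match else 0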
Example No. 3
def main(_):
  logging.set_verbosity(logging.INFO)

  validate_flags()

  tf.io.gfile.makedirs(FLAGS.output_dir)

  model_config = config.get_model_config(
      model_dir=FLAGS.output_dir,
      source_file=FLAGS.source_model_config_file,
      source_base64=FLAGS.source_model_config_base64,
      write_from_source=FLAGS.do_train)

  input_files = []
  for input_pattern in FLAGS.input_file.split(","):
    input_files.extend(tf.io.gfile.glob(input_pattern))

  num_blocks_per_example, block_length = input_utils.get_block_params_from_input_file(
      input_files[0])
  max_num_annotations = None
  if FLAGS.mlm_entity_fraction_to_mask is not None:
    max_num_annotations = input_utils.get_num_annotations_from_input_file(
        input_files[0])

  # TODO(urikz): Define `pretrain_input_utils`.
  input_config = pretrain_input_utils.PretrainInputConfig(  # pylint: disable=undefined-variable
      num_blocks_per_example=num_blocks_per_example,
      block_length=block_length,
      mlm_fraction_to_mask=FLAGS.mlm_fraction_to_mask,
      mlm_max_consecutive_masks=FLAGS.mlm_max_consecutive_masks,
      mlm_use_whole_word=FLAGS.mlm_use_whole_word,
      mask_token_id=FLAGS.mask_token_id,
      padding_token_id=FLAGS.padding_token_id,
      max_num_annotations=max_num_annotations,
      mlm_entity_fraction_to_mask=FLAGS.mlm_entity_fraction_to_mask,
      mention_mask_mode=FLAGS.mention_mask_mode)

  logging.info("*** Input Files ***")
  for input_file in input_files:
    logging.info("  %s", input_file)

  tpu_cluster_resolver, num_tpu_tasks, num_tpu_cores = None, None, None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    tpu_system_metadata = tpu_cluster_resolver.get_tpu_system_metadata()
    num_tpu_tasks = tpu_system_metadata.num_hosts
    num_tpu_cores = tpu_system_metadata.num_cores

  # PER_HOST_V1: iterator.get_next() is called 1 time with per_worker_batch_size
  # PER_HOST_V2: iterator.get_next() is called 8 times with per_core_batch_size
  is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1
  run_config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          tpu_job_name=FLAGS.tpu_job_name,
          per_host_input_for_training=is_per_host,
          experimental_host_call_every_n_steps=FLAGS.iterations_per_loop))

  batch_size = (num_tpu_tasks or 1) * num_blocks_per_example
  logging.info("***** Input configuration *****")
  logging.info("  Number of blocks per example = %d", num_blocks_per_example)
  logging.info("  Block length = %d", block_length)
  logging.info("  Number of TPU tasks = %d", num_tpu_tasks or 0)
  logging.info("  Batch size = %d", batch_size)
  logging.info("  Number of TPU cores = %d", num_tpu_cores or 0)
  logging.info("  Number of annotations per example = %d",
               input_config.max_num_annotations or 0)

  model_fn = model_fn_builder(
      model_config=model_config,
      padding_token_id=FLAGS.padding_token_id,
      enable_side_inputs=FLAGS.enable_side_inputs,
      num_replicas_concat=num_tpu_cores,
      cross_block_attention_mode=FLAGS.cross_block_attention_mode,
      extra_loss=parse_extra_loss_flag(FLAGS.extra_loss),
      summary_num_layers=FLAGS.summary_num_layers,
      summary_num_cross_attention_heads=FLAGS.summary_num_cross_attention_heads,
      summary_enable_default_side_input=FLAGS.summary_enable_default_side_input,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=FLAGS.num_train_steps,
      num_warmup_steps=FLAGS.num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_one_hot_embeddings,
      optimizer=FLAGS.optimizer,
      poly_power=FLAGS.poly_power,
      start_warmup_step=FLAGS.start_warmup_step,
      learning_rate_schedule=FLAGS.learning_rate_schedule,
      metrics_name=FLAGS.metrics_name)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=batch_size,
      eval_batch_size=batch_size)

  if FLAGS.do_train:
    logging.info("***** Running training *****")
    train_input_fn = input_fn_builder(
        input_files=input_files,
        input_config=input_config,
        model_config=model_config,
        is_training=True)
    estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)

  if FLAGS.do_eval:
    logging.info("***** Running evaluation *****")
    num_eval_examples = input_utils.get_num_examples_in_tf_records(input_files)
    eval_steps_per_epoch = num_eval_examples
    max_eval_steps = eval_steps_per_epoch * FLAGS.num_eval_epochs
    logging.info("  Number of eval examples = %d", num_eval_examples)
    logging.info("  Number of TPU tasks = %d", num_tpu_tasks or 1)
    logging.info("  Number of eval steps per epoch = %d", eval_steps_per_epoch)
    logging.info("  Eval steps = %d", max_eval_steps)

    eval_input_fn = input_fn_builder(
        input_files=input_files,
        input_config=input_config,
        model_config=model_config,
        is_training=False)

    # Run evaluation for each new checkpoint.
    for ckpt in tf.train.checkpoints_iterator(FLAGS.output_dir):
      logging.info("Starting eval on new checkpoint: %s", ckpt)
      try:
        start_timestamp = time.time()  # This time will include compilation time
        eval_results = estimator.evaluate(
            input_fn=eval_input_fn, steps=max_eval_steps, checkpoint_path=ckpt)
        elapsed_time = int(time.time() - start_timestamp)
        logging.info("Eval results: %s. Elapsed seconds: %d", eval_results,
                     elapsed_time)

        # Terminate eval job when final checkpoint is reached.
        current_step = int(os.path.basename(ckpt).split("-")[1])
        if current_step >= FLAGS.num_train_steps:
          logging.info("Evaluation finished after training step %d",
                       current_step)
          break

      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        logging.info("Checkpoint %s no longer exists, skipping checkpoint",
                     ckpt)
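All three training drivers set the global batch size to `num_tpu_tasks * num_blocks_per_example` and use `PER_HOST_V1`, under which each host's iterator is asked for one per-host batch per step. A back-of-the-envelope check with illustrative numbers (not taken from the original configuration) shows that this makes every host consume exactly one example's worth of blocks per step:

num_tpu_tasks = 4            # hypothetical number of TPU hosts
num_blocks_per_example = 8   # read from the input files in the real code

global_batch_size = num_tpu_tasks * num_blocks_per_example       # 32 blocks
per_host_batch_size = global_batch_size // num_tpu_tasks         # 8 blocks
examples_per_host_per_step = per_host_batch_size // num_blocks_per_example
assert examples_per_host_per_step == 1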
Example No. 4
def main(_):
    logging.set_verbosity(logging.INFO)

    for flag in FLAGS.flags_by_module_dict()[sys.argv[0]]:
        logging.info("  %s = %s", flag.name, flag.value)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    tf.io.gfile.makedirs(FLAGS.output_dir)

    model_config = config.get_model_config(
        model_dir=FLAGS.output_dir,
        source_file=FLAGS.read_it_twice_bert_config_file,
        source_base64=None,
        write_from_source=False)

    if FLAGS.cross_attention_top_k is not None:
        model_config = dataclasses.replace(
            model_config, cross_attention_top_k=FLAGS.cross_attention_top_k)

    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.io.gfile.glob(input_pattern))

    logging.info("*** Input Files ***")
    for input_file in input_files:
        logging.info("  %s", input_file)

    num_blocks_per_example, block_length = input_utils.get_block_params_from_input_file(
        input_files[0])

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    # PER_HOST_V1: iterator.get_next() is called 1 time with per_worker_batch_size
    # PER_HOST_V2: iterator.get_next() is called 8 times with per_core_batch_size
    is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1
    run_config = tf.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            tpu_job_name=FLAGS.tpu_job_name,
            per_host_input_for_training=is_per_host,
            experimental_host_call_every_n_steps=FLAGS.steps_per_summary))

    # TODO(urikz): Is there a better way to compute the number of tasks?
    # the code below doesn't work because `tpu_cluster_resolver.cluster_spec()`
    # returns None. Therefore, I have to pass number of total tasks via CLI arg.
    # num_tpu_tasks = tpu_cluster_resolver.cluster_spec().num_tasks()
    batch_size = (FLAGS.num_tpu_tasks or 1) * num_blocks_per_example

    num_train_examples = input_utils.get_num_examples_in_tf_records(
        input_files)
    num_train_steps = int(num_train_examples * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    logging.info("***** Input configuration *****")
    logging.info("  Number of blocks per example = %d", num_blocks_per_example)
    logging.info("  Block length = %d", block_length)
    logging.info("  Number of TPU tasks = %d", FLAGS.num_tpu_tasks or 1)
    logging.info("  Batch size = %d", batch_size)
    logging.info("  Number of TPU cores = %d", FLAGS.num_tpu_cores or 0)
    logging.info("  Number training steps = %d", num_train_steps)
    logging.info("  Number warmup steps = %d", num_warmup_steps)

    model_fn = model_fn_builder(
        model_config=model_config,
        padding_token_id=FLAGS.padding_token_id,
        enable_side_inputs=FLAGS.enable_side_inputs,
        num_replicas_concat=FLAGS.num_tpu_cores,
        cross_block_attention_mode=FLAGS.cross_block_attention_mode,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_one_hot_embeddings,
        optimizer=FLAGS.optimizer,
        poly_power=FLAGS.poly_power,
        start_warmup_step=FLAGS.start_warmup_step,
        learning_rate_schedule=FLAGS.learning_rate_schedule,
        nbest_logits_for_eval=FLAGS.decode_top_k)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.estimator.tpu.TPUEstimator(use_tpu=FLAGS.use_tpu,
                                              model_fn=model_fn,
                                              config=run_config,
                                              train_batch_size=batch_size,
                                              eval_batch_size=batch_size,
                                              predict_batch_size=batch_size)

    training_done_path = os.path.join(FLAGS.output_dir, "training_done")

    if FLAGS.do_train:
        logging.info("***** Running training *****")
        train_input_fn = input_fn_builder(input_files=input_files,
                                          is_training=True)
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
        # Write file to signal training is done.
        with tf.io.gfile.GFile(training_done_path, "w") as writer:
            writer.write("\n")

    if FLAGS.do_eval:
        logging.info("***** Running evaluation *****")
        eval_input_fn = input_fn_builder(input_files=input_files,
                                         is_training=False)

        if FLAGS.eval_json_path is None:
            raise ValueError("Must specify `eval_json_path` for eval.")

        if FLAGS.eval_data_split == "test":
            do_eval = False
        elif FLAGS.eval_data_split == "valid":
            do_eval = True
        else:
            raise ValueError("Unknown mode: " + FLAGS.eval_data_split)

        with tf.io.gfile.GFile(FLAGS.eval_json_path) as f:
            data = json.load(f)["Data"]
        if do_eval:
            ground_truth = {
                datum["QuestionId"]: datum["Answer"]
                for datum in data
            }
        # Index 0 is None because block_id 0 corresponds to a padding document.
        question_ids = [None] + [datum["QuestionId"] for datum in data]

        if do_eval:
            logging.info("Loaded %d questions for evaluation from %s",
                         len(ground_truth), FLAGS.eval_json_path)

        tokenizer = tokenization.FullTokenizer(FLAGS.spm_model_path)
        logging.info("Loaded SentencePiece model from %s",
                     FLAGS.spm_model_path)

        # Writer for TensorBoard.
        summary_writer = tf.summary.FileWriter(
            os.path.join(FLAGS.output_dir, "eval_metrics"))

        if _is_checkpoint_in_the_model_dir():
            logging.info("FINAL EVALUATION: checkpoint = %s, data split = %s",
                         FLAGS.init_checkpoint, FLAGS.eval_data_split)
            checkpoints = [FLAGS.init_checkpoint]
        else:
            # for checkpoint_path in _get_all_checkpoints(FLAGS.output_dir):
            checkpoints = tf.train.checkpoints_iterator(
                FLAGS.output_dir,
                min_interval_secs=5 * 60,
                timeout=8 * 60 * 60)

        for checkpoint_path in checkpoints:
            start_time = time.time()
            global_step = _get_global_step_for_checkpoint(checkpoint_path)
            logging.info("Starting eval on step %d on checkpoint: %s",
                         global_step, checkpoint_path)

            nbest_predictions = collections.OrderedDict()

            try:
                for prediction_index, prediction in enumerate(
                        estimator.predict(eval_input_fn,
                                          checkpoint_path=checkpoint_path,
                                          yield_single_examples=True)):
                    if prediction_index % 100000 == 0:
                        logging.info("Processing example: %d",
                                     prediction_index)
                    block_id = prediction["block_ids"]
                    if block_id == 0:
                        # Padding document
                        continue
                    question_id = question_ids[block_id]
                    if question_id not in nbest_predictions:
                        nbest_predictions[question_id] = {}
                    token_ids = prediction["token_ids"]
                    for begin_index, begin_logit in zip(
                            prediction["begin_logits_indices"],
                            prediction["begin_logits_values"]):
                        for end_index, end_logit in zip(
                                prediction["end_logits_indices"],
                                prediction["end_logits_values"]):
                            if begin_index > end_index or end_index - begin_index + 1 > FLAGS.decode_max_size:
                                continue
                            answer = "".join(
                                tokenizer.convert_ids_to_tokens([
                                    int(token_id) for token_id in
                                    token_ids[begin_index:end_index + 1]
                                ]))
                            answer = evaluation.normalize_answer(
                                answer.replace(tokenization.SPIECE_UNDERLINE,
                                               " "))
                            if not answer:
                                continue
                            if answer not in nbest_predictions[question_id]:
                                nbest_predictions[question_id][answer] = []
                            nbest_predictions[question_id][answer].append(
                                begin_logit + end_logit)

                nbest_predictions_probs = _convert_prediction_logits_to_probs(
                    nbest_predictions)

                best_predictions_max = _get_best_predictions(
                    nbest_predictions_probs, max)
                with tf.io.gfile.GFile(
                        checkpoint_path + ".%s.best_predictions_max.json" %
                        FLAGS.eval_data_split, "w") as f:
                    json.dump(best_predictions_max, f, indent=2)

                if do_eval:
                    best_predictions_max_results = evaluation.evaluate_triviaqa(
                        ground_truth, best_predictions_max, mute=False)
                    write_eval_results(global_step,
                                       best_predictions_max_results, "max",
                                       summary_writer)

                    summary_writer.flush()
            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                logging.info(
                    "Checkpoint %s no longer exists, skipping checkpoint",
                    checkpoint_path)
                continue

            if tf.io.gfile.exists(training_done_path):
                # Break if the checkpoint we just processed is the last one.
                last_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
                if last_checkpoint is None:
                    continue
                last_global_step = _get_global_step_for_checkpoint(
                    last_checkpoint)
                if global_step == last_global_step:
                    break

            global_step = _get_global_step_for_checkpoint(checkpoint_path)
            logging.info("Finished eval on step %d in %d seconds", global_step,
                         time.time() - start_time)
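The span-prediction eval loops call two helpers that are not shown in these examples: `_convert_prediction_logits_to_probs` and `_get_best_predictions`. One plausible implementation, consistent with how they are used above (softmax-normalize all span logits of a question, keep the per-answer score lists, then reduce them with the function passed in, `max` in both examples), is sketched below; the original code may differ:

import numpy as np


def _convert_prediction_logits_to_probs(nbest_predictions):
    """Hypothetical: softmax-normalizes every span logit within a question."""
    nbest_probs = {}
    for question_id, answers in nbest_predictions.items():
        if not answers:
            continue
        all_logits = np.array(
            [score for scores in answers.values() for score in scores])
        all_probs = np.exp(all_logits - all_logits.max())
        all_probs /= all_probs.sum()
        probs, offset = {}, 0
        for answer, scores in answers.items():
            probs[answer] = all_probs[offset:offset + len(scores)].tolist()
            offset += len(scores)
        nbest_probs[question_id] = probs
    return nbest_probs


def _get_best_predictions(nbest_predictions_probs, reduce_fn):
    """Hypothetical: picks, per question, the answer with the best reduced score."""
    best = {}
    for question_id, answers in nbest_predictions_probs.items():
        if answers:
            best[question_id] = max(
                answers, key=lambda answer: reduce_fn(answers[answer]))
    return best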