def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.horovod:
        hvd.init()
    if FLAGS.use_fp16:
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"
    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "xnli": XnliProcessor,
    }

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict' must be True."
        )

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.gfile.MakeDirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    master_process = True
    training_hooks = []
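    # Effective (global) batch size: per-GPU batch size x gradient-accumulation
    # steps; under Horovod it is additionally scaled by the worker count below.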
    global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
    hvd_rank = 0

    config = tf.ConfigProto()
    if FLAGS.horovod:

        tf.logging.info("Multi-GPU training with TF Horovod")
        tf.logging.info("hvd.size() = %d hvd.rank() = %d", hvd.size(),
                        hvd.rank())
        global_batch_size = (FLAGS.train_batch_size *
                             FLAGS.num_accumulation_steps * hvd.size())
        master_process = (hvd.rank() == 0)
        hvd_rank = hvd.rank()
        config.gpu_options.allow_growth = True
        config.gpu_options.visible_device_list = str(hvd.local_rank())
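        # Broadcast rank 0's freshly initialized variables to all other
        # workers so every rank starts training from identical weights.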
        if hvd.size() > 1:
            training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
    if FLAGS.use_xla:
        config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.output_dir if master_process else None,
        session_config=config,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps
        if master_process else None,
        keep_checkpoint_max=1)

    if master_process:
        tf.logging.info("***** Configuaration *****")
        for key in FLAGS.__flags.keys():
            tf.logging.info('  {}: {}'.format(key, getattr(FLAGS, key)))
        tf.logging.info("**************************")

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    training_hooks.append(LogTrainRunHook(global_batch_size, hvd_rank))

    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / global_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

        start_index = 0
        end_index = len(train_examples)
        tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record")]

        if FLAGS.horovod:
            tmp_filenames = [
                os.path.join(FLAGS.output_dir, "train.tf_record{}".format(i))
                for i in range(hvd.size())
            ]
            num_examples_per_rank = len(train_examples) // hvd.size()
            remainder = len(train_examples) % hvd.size()
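            # Contiguous sharding: the first `remainder` ranks take one extra
            # example, so shard sizes differ by at most one.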
            if hvd.rank() < remainder:
                start_index = hvd.rank() * (num_examples_per_rank + 1)
                end_index = start_index + num_examples_per_rank + 1
            else:
                start_index = hvd.rank() * num_examples_per_rank + remainder
                end_index = start_index + (num_examples_per_rank)

    model_fn = model_fn_builder(task_name=task_name,
                                bert_config=bert_config,
                                num_labels=len(label_list),
                                init_checkpoint=FLAGS.init_checkpoint,
                                learning_rate=(FLAGS.learning_rate
                                               if not FLAGS.horovod else
                                               FLAGS.learning_rate * hvd.size()),
                                num_train_steps=num_train_steps,
                                num_warmup_steps=num_warmup_steps,
                                use_one_hot_embeddings=False,
                                hvd=None if not FLAGS.horovod else hvd)

    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

    if FLAGS.do_train:

        file_based_convert_examples_to_features(
            train_examples[start_index:end_index], label_list,
            FLAGS.max_seq_length, tokenizer, tmp_filenames[hvd_rank])

        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num examples = %d", len(train_examples))
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = file_based_input_fn_builder(
            input_file=tmp_filenames,
            batch_size=FLAGS.train_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True,
            hvd=None if not FLAGS.horovod else hvd)

        train_start_time = time.time()
        estimator.train(input_fn=train_input_fn,
                        max_steps=num_train_steps,
                        hooks=training_hooks)
        train_time_elapsed = time.time() - train_start_time
        train_time_wo_overhead = training_hooks[-1].total_time
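        # Throughput is reported twice: once over wall-clock time (including
        # startup/checkpointing overhead) and once over the hook-measured step
        # time, which excludes the warmup steps counted in `.skipped`.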
        avg_sentences_per_second = num_train_steps * global_batch_size * 1.0 / train_time_elapsed
        ss_sentences_per_second = (
            num_train_steps - training_hooks[-1].skipped
        ) * global_batch_size * 1.0 / train_time_wo_overhead

        if master_process:
            tf.logging.info("-----------------------------")
            tf.logging.info("Total Training Time = %0.2f for Sentences = %d",
                            train_time_elapsed,
                            num_train_steps * global_batch_size)
            tf.logging.info(
                "Total Training Time W/O Overhead = %0.2f for Sentences = %d",
                train_time_wo_overhead,
                (num_train_steps - training_hooks[-1].skipped) *
                global_batch_size)
            tf.logging.info(
                "Throughput Average (sentences/sec) with overhead = %0.2f",
                avg_sentences_per_second)
            tf.logging.info("Throughput Average (sentences/sec) = %0.2f",
                            ss_sentences_per_second)
            tf.logging.info("-----------------------------")

    if FLAGS.do_eval and master_process:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
        file_based_convert_examples_to_features(eval_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer, eval_file)

        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Num examples = %d", len(eval_examples))
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        eval_drop_remainder = False
        eval_input_fn = file_based_input_fn_builder(
            input_file=eval_file,
            batch_size=FLAGS.eval_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=eval_drop_remainder)

        eval_hooks = [LogEvalRunHook(FLAGS.eval_batch_size)]
        eval_start_time = time.time()
        result = estimator.evaluate(input_fn=eval_input_fn, hooks=eval_hooks)

        eval_time_elapsed = time.time() - eval_start_time
        eval_time_wo_overhead = eval_hooks[-1].total_time

        time_list = eval_hooks[-1].time_list
        time_list.sort()
        num_sentences = (eval_hooks[-1].count -
                         eval_hooks[-1].skipped) * FLAGS.eval_batch_size

        avg = np.mean(time_list)
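        # Percentile latencies: with time_list sorted ascending, the max of
        # the fastest p% of batch times is the p-th percentile batch latency.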
        cf_50 = max(time_list[:int(len(time_list) * 0.50)])
        cf_90 = max(time_list[:int(len(time_list) * 0.90)])
        cf_95 = max(time_list[:int(len(time_list) * 0.95)])
        cf_99 = max(time_list[:int(len(time_list) * 0.99)])
        cf_100 = max(time_list[:int(len(time_list) * 1)])
        ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead

        tf.logging.info("-----------------------------")
        tf.logging.info("Total Inference Time = %0.2f for Sentences = %d",
                        eval_time_elapsed,
                        eval_hooks[-1].count * FLAGS.eval_batch_size)
        tf.logging.info(
            "Total Inference Time W/O Overhead = %0.2f for Sentences = %d",
            eval_time_wo_overhead,
            (eval_hooks[-1].count - eval_hooks[-1].skipped) *
            FLAGS.eval_batch_size)
        tf.logging.info("Summary Inference Statistics on EVAL set")
        tf.logging.info("Batch size = %d", FLAGS.eval_batch_size)
        tf.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
        tf.logging.info("Precision = %s", "fp16" if FLAGS.use_fp16 else "fp32")
        tf.logging.info("Latency Confidence Level 50 (ms) = %0.2f",
                        cf_50 * 1000)
        tf.logging.info("Latency Confidence Level 90 (ms) = %0.2f",
                        cf_90 * 1000)
        tf.logging.info("Latency Confidence Level 95 (ms) = %0.2f",
                        cf_95 * 1000)
        tf.logging.info("Latency Confidence Level 99 (ms) = %0.2f",
                        cf_99 * 1000)
        tf.logging.info("Latency Confidence Level 100 (ms) = %0.2f",
                        cf_100 * 1000)
        tf.logging.info("Latency Average (ms) = %0.2f", avg * 1000)
        tf.logging.info("Throughput Average (sentences/sec) = %0.2f",
                        ss_sentences_per_second)
        tf.logging.info("-----------------------------")

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    if FLAGS.do_predict and master_process:
        predict_examples = processor.get_test_examples(FLAGS.data_dir)
        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        file_based_convert_examples_to_features(predict_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer, predict_file)

        tf.logging.info("***** Running prediction*****")
        tf.logging.info("  Num examples = %d", len(predict_examples))
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_drop_remainder = False
        predict_input_fn = file_based_input_fn_builder(
            input_file=predict_file,
            batch_size=FLAGS.predict_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=predict_drop_remainder)

        predict_hooks = [LogEvalRunHook(FLAGS.predict_batch_size)]
        predict_start_time = time.time()

        output_predict_file = os.path.join(FLAGS.output_dir,
                                           "test_results.tsv")
        with tf.gfile.GFile(output_predict_file, "w") as writer:
            tf.logging.info("***** Predict results *****")
            for prediction in estimator.predict(input_fn=predict_input_fn,
                                                hooks=predict_hooks,
                                                yield_single_examples=False):
                output_line = "\t".join(
                    str(class_probability)
                    for class_probability in prediction) + "\n"
                writer.write(output_line)

        predict_time_elapsed = time.time() - predict_start_time
        predict_time_wo_overhead = predict_hooks[-1].total_time

        time_list = predict_hooks[-1].time_list
        time_list.sort()
        num_sentences = (predict_hooks[-1].count -
                         predict_hooks[-1].skipped) * FLAGS.predict_batch_size

        avg = np.mean(time_list)
        cf_50 = max(time_list[:int(len(time_list) * 0.50)])
        cf_90 = max(time_list[:int(len(time_list) * 0.90)])
        cf_95 = max(time_list[:int(len(time_list) * 0.95)])
        cf_99 = max(time_list[:int(len(time_list) * 0.99)])
        cf_100 = max(time_list[:int(len(time_list) * 1)])
        ss_sentences_per_second = num_sentences * 1.0 / predict_time_wo_overhead

        tf.logging.info("-----------------------------")
        tf.logging.info("Total Inference Time = %0.2f for Sentences = %d",
                        predict_time_elapsed,
                        predict_hooks[-1].count * FLAGS.predict_batch_size)
        tf.logging.info(
            "Total Inference Time W/O Overhead = %0.2f for Sentences = %d",
            predict_time_wo_overhead,
            (predict_hooks[-1].count - predict_hooks[-1].skipped) *
            FLAGS.predict_batch_size)

        tf.logging.info("Summary Inference Statistics on TEST SET")
        tf.logging.info("Batch size = %d", FLAGS.predict_batch_size)
        tf.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
        tf.logging.info("Precision = %s", "fp16" if FLAGS.use_fp16 else "fp32")
        tf.logging.info("Latency Confidence Level 50 (ms) = %0.2f",
                        cf_50 * 1000)
        tf.logging.info("Latency Confidence Level 90 (ms) = %0.2f",
                        cf_90 * 1000)
        tf.logging.info("Latency Confidence Level 95 (ms) = %0.2f",
                        cf_95 * 1000)
        tf.logging.info("Latency Confidence Level 99 (ms) = %0.2f",
                        cf_99 * 1000)
        tf.logging.info("Latency Confidence Level 100 (ms) = %0.2f",
                        cf_100 * 1000)
        tf.logging.info("Latency Average (ms) = %0.2f", avg * 1000)
        tf.logging.info("Throughput Average (sentences/sec) = %0.2f",
                        ss_sentences_per_second)
        tf.logging.info("-----------------------------")
def main(_):
  # causes memory fragmentation for bert leading to OOM
  if os.environ.get("TF_XLA_FLAGS", None) is not None:
    os.environ["TF_XLA_FLAGS"] += " --tf_xla_enable_lazy_compilation false"
  else:
    os.environ["TF_XLA_FLAGS"] = " --tf_xla_enable_lazy_compilation false"

  # Enable async_io to speed up multi-gpu training with XLA and Horovod.
  os.environ["TF_XLA_FLAGS"] += " --tf_xla_async_io_level 1"

  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  dllogging = utils.dllogger_class.dllogger_class(FLAGS.dllog_path)

  if FLAGS.horovod:
    hvd.init()

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  validate_flags_or_throw(bert_config)

  tf.io.gfile.makedirs(FLAGS.output_dir)

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  master_process = True
  training_hooks = []
  global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
  hvd_rank = 0

  config = tf.compat.v1.ConfigProto()
  learning_rate = FLAGS.learning_rate
  if FLAGS.horovod:

      tf.compat.v1.logging.info("Multi-GPU training with TF Horovod")
      tf.compat.v1.logging.info("hvd.size() = %d hvd.rank() = %d", hvd.size(), hvd.rank())
      global_batch_size = FLAGS.train_batch_size * hvd.size() * FLAGS.num_accumulation_steps
      learning_rate = learning_rate * hvd.size()
      master_process = (hvd.rank() == 0)
      hvd_rank = hvd.rank()
      config.gpu_options.visible_device_list = str(hvd.local_rank())
      if hvd.size() > 1:
          training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
  if FLAGS.use_xla:
    config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
    if FLAGS.amp:
        tf.enable_resource_variables()

  run_config = tf.estimator.RunConfig(
      model_dir=FLAGS.output_dir if master_process else None,
      session_config=config,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps if master_process else None,
      save_summary_steps=FLAGS.save_checkpoints_steps if master_process else None,
      log_step_count_steps=FLAGS.display_loss_steps,
      keep_checkpoint_max=1)

  if master_process:
      tf.compat.v1.logging.info("***** Configuaration *****")
      for key in FLAGS.__flags.keys():
          tf.compat.v1.logging.info('  {}: {}'.format(key, getattr(FLAGS, key)))
      tf.compat.v1.logging.info("**************************")

  train_examples = None
  num_train_steps = None
  num_warmup_steps = None
  training_hooks.append(LogTrainRunHook(global_batch_size, hvd_rank, FLAGS.save_checkpoints_steps))

  # Prepare Training Data
  if FLAGS.do_train:
    train_examples = read_squad_examples(
        input_file=FLAGS.train_file, is_training=True,
        version_2_with_negative=FLAGS.version_2_with_negative)
    num_train_steps = int(
        len(train_examples) / global_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    # Pre-shuffle the input to avoid having to make a very large shuffle
    # buffer in the `input_fn`.
    rng = random.Random(12345)
    rng.shuffle(train_examples)

    start_index = 0 
    end_index = len(train_examples)
    tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record")]

    if FLAGS.horovod:
      tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record{}".format(i)) for i in range(hvd.size())]
      num_examples_per_rank = len(train_examples) // hvd.size()
      remainder = len(train_examples) % hvd.size()
      if hvd.rank() < remainder:
        start_index = hvd.rank() * (num_examples_per_rank+1)
        end_index = start_index + num_examples_per_rank + 1
      else:
        start_index = hvd.rank() * num_examples_per_rank + remainder
        end_index = start_index + (num_examples_per_rank)


  model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      hvd=None if not FLAGS.horovod else hvd,
      amp=FLAGS.amp)

  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      config=run_config)

  if FLAGS.do_train:

    # We write to a temporary file to avoid storing very large constant tensors
    # in memory.
    train_writer = FeatureWriter(
        filename=tmp_filenames[hvd_rank],
        is_training=True)
    convert_examples_to_features(
        examples=train_examples[start_index:end_index],
        tokenizer=tokenizer,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=True,
        output_fn=train_writer.process_feature,
        verbose_logging=FLAGS.verbose_logging)
    train_writer.close()

    tf.compat.v1.logging.info("***** Running training *****")
    tf.compat.v1.logging.info("  Num orig examples = %d", end_index - start_index)
    tf.compat.v1.logging.info("  Num split examples = %d", train_writer.num_features)
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.compat.v1.logging.info("  Num steps = %d", num_train_steps)
    tf.compat.v1.logging.info("  LR = %f", learning_rate)
    del train_examples

    train_input_fn = input_fn_builder(
        input_file=tmp_filenames,
        batch_size=FLAGS.train_batch_size,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True,
        hvd=None if not FLAGS.horovod else hvd)

    train_start_time = time.time()
    estimator.train(input_fn=train_input_fn, hooks=training_hooks, max_steps=num_train_steps)
    train_time_elapsed = time.time() - train_start_time
    train_time_wo_overhead = training_hooks[-1].total_time
    avg_sentences_per_second = num_train_steps * global_batch_size * 1.0 / train_time_elapsed
    ss_sentences_per_second = (num_train_steps - training_hooks[-1].skipped) * global_batch_size * 1.0 / train_time_wo_overhead

    if master_process:
        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info("Total Training Time = %0.2f for Sentences = %d", train_time_elapsed,
                        num_train_steps * global_batch_size)
        tf.compat.v1.logging.info("Total Training Time W/O Overhead = %0.2f for Sentences = %d", train_time_wo_overhead,
                        (num_train_steps - training_hooks[-1].skipped) * global_batch_size)
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
        dllogging.logger.log(step=(), data={"throughput_train": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
        tf.compat.v1.logging.info("-----------------------------")


  if FLAGS.export_triton and master_process:
    export_model(estimator, FLAGS.output_dir, FLAGS.init_checkpoint)

  if FLAGS.do_predict and master_process:
    eval_examples = read_squad_examples(
        input_file=FLAGS.predict_file, is_training=False,
        version_2_with_negative=FLAGS.version_2_with_negative)

    # Perform evaluation on subset, useful for profiling
    if FLAGS.num_eval_iterations is not None:
        eval_examples = eval_examples[:FLAGS.num_eval_iterations*FLAGS.predict_batch_size]

    eval_writer = FeatureWriter(
        filename=os.path.join(FLAGS.output_dir, "eval.tf_record"),
        is_training=False)
    eval_features = []

    def append_feature(feature):
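      # Keep each feature in memory (write_predictions needs the full list
      # later) while also streaming it to the TFRecord eval file.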
      eval_features.append(feature)
      eval_writer.process_feature(feature)

    convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=False,
        output_fn=append_feature,
        verbose_logging=FLAGS.verbose_logging)
    eval_writer.close()

    tf.compat.v1.logging.info("***** Running predictions *****")
    tf.compat.v1.logging.info("  Num orig examples = %d", len(eval_examples))
    tf.compat.v1.logging.info("  Num split examples = %d", len(eval_features))
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_input_fn = input_fn_builder(
        input_file=eval_writer.filename,
        batch_size=FLAGS.predict_batch_size,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=False)

    all_results = []
    eval_hooks = [LogEvalRunHook(FLAGS.predict_batch_size)]
    eval_start_time = time.time()
    for result in estimator.predict(
        predict_input_fn, yield_single_examples=True, hooks=eval_hooks):
      if len(all_results) % 1000 == 0:
        tf.compat.v1.logging.info("Processing example: %d" % (len(all_results)))
      unique_id = int(result["unique_ids"])
      start_logits = [float(x) for x in result["start_logits"].flat]
      end_logits = [float(x) for x in result["end_logits"].flat]
      all_results.append(
          RawResult(
              unique_id=unique_id,
              start_logits=start_logits,
              end_logits=end_logits))

    eval_time_elapsed = time.time() - eval_start_time

    time_list = eval_hooks[-1].time_list
    time_list.sort()
    # Removing outliers (init/warmup) in throughput computation.
    eval_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.99)])
    num_sentences = (int(len(time_list) * 0.99)) * FLAGS.predict_batch_size

    avg = np.mean(time_list)
    cf_50 = max(time_list[:int(len(time_list) * 0.50)])
    cf_90 = max(time_list[:int(len(time_list) * 0.90)])
    cf_95 = max(time_list[:int(len(time_list) * 0.95)])
    cf_99 = max(time_list[:int(len(time_list) * 0.99)])
    cf_100 = max(time_list[:int(len(time_list) * 1)])
    ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead

    tf.compat.v1.logging.info("-----------------------------")
    tf.compat.v1.logging.info("Total Inference Time = %0.2f for Sentences = %d", eval_time_elapsed,
                    eval_hooks[-1].count * FLAGS.predict_batch_size)
    tf.compat.v1.logging.info("Total Inference Time W/O Overhead = %0.2f for Sentences = %d", eval_time_wo_overhead,
                    num_sentences)
    tf.compat.v1.logging.info("Summary Inference Statistics")
    tf.compat.v1.logging.info("Batch size = %d", FLAGS.predict_batch_size)
    tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
    tf.compat.v1.logging.info("Precision = %s", "fp16" if FLAGS.amp else "fp32")
    tf.compat.v1.logging.info("Latency Confidence Level 50 (ms) = %0.2f", cf_50 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 90 (ms) = %0.2f", cf_90 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 95 (ms) = %0.2f", cf_95 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 99 (ms) = %0.2f", cf_99 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 100 (ms) = %0.2f", cf_100 * 1000)
    tf.compat.v1.logging.info("Latency Average (ms) = %0.2f", avg * 1000)
    tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
    dllogging.logger.log(step=(), data={"throughput_val": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
    tf.compat.v1.logging.info("-----------------------------")

    output_prediction_file = os.path.join(FLAGS.output_dir, "predictions.json")
    output_nbest_file = os.path.join(FLAGS.output_dir, "nbest_predictions.json")
    output_null_log_odds_file = os.path.join(FLAGS.output_dir, "null_odds.json")

    write_predictions(eval_examples, eval_features, all_results,
                      FLAGS.n_best_size, FLAGS.max_answer_length,
                      FLAGS.do_lower_case, output_prediction_file,
                      output_nbest_file, output_null_log_odds_file,
                      FLAGS.version_2_with_negative, FLAGS.verbose_logging)

    if FLAGS.eval_script:
        import sys
        import subprocess
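        # Run the official SQuAD evaluation script and parse its stdout. The
        # string slicing below assumes output shaped like
        # {"exact_match": 81.2, "f1": 88.5}; a format change would break it.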
        eval_out = subprocess.check_output([sys.executable, FLAGS.eval_script,
                                          FLAGS.predict_file, output_prediction_file])
        scores = str(eval_out).strip()
        exact_match = float(scores.split(":")[1].split(",")[0])
        f1 = float(scores.split(":")[2].split("}")[0])
        dllogging.logger.log(step=(), data={"f1": f1}, verbosity=Verbosity.DEFAULT)
        dllogging.logger.log(step=(), data={"exact_match": exact_match}, verbosity=Verbosity.DEFAULT)
        print(str(eval_out))
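
# --- Illustrative sketch (a hypothetical helper, not part of the original
# scripts). A standalone version of the Horovod data sharding used in the
# training paths above: contiguous, near-equal slices per rank.
def shard_bounds(num_examples, rank, world_size):
    """Return the [start, end) slice of examples owned by `rank`."""
    per_rank, remainder = divmod(num_examples, world_size)
    if rank < remainder:
        start = rank * (per_rank + 1)
        end = start + per_rank + 1
    else:
        start = rank * per_rank + remainder
        end = start + per_rank
    return start, end

# Example: 10 examples over 4 ranks -> (0, 3), (3, 6), (6, 8), (8, 10).
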
def main(_):
    # Lazy compilation causes memory fragmentation for BERT, leading to OOM.
    os.environ["TF_XLA_FLAGS"] = "--tf_xla_enable_lazy_compilation=false"

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    dllogging = utils.dllogger_class.dllogger_class(FLAGS.dllog_path)

    if FLAGS.horovod:
        hvd.init()

    processors = {
        "bc5cdr": BC5CDRProcessor,
        "clefe": CLEFEProcessor,
        'i2b2': I2b22012Processor
    }
    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict` must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    task_name = FLAGS.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    processor = processors[task_name]()

    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2

    master_process = True
    training_hooks = []
    global_batch_size = FLAGS.train_batch_size
    hvd_rank = 0

    config = tf.compat.v1.ConfigProto()
    if FLAGS.horovod:
        global_batch_size = FLAGS.train_batch_size * hvd.size()
        master_process = (hvd.rank() == 0)
        hvd_rank = hvd.rank()
        config.gpu_options.visible_device_list = str(hvd.local_rank())
        if hvd.size() > 1:
            training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

    if FLAGS.use_xla:
        config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
        tf.enable_resource_variables()
    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.output_dir if master_process else None,
        session_config=config,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps
        if master_process else None,
        keep_checkpoint_max=1)

    if master_process:
        tf.compat.v1.logging.info("***** Configuaration *****")
        for key in FLAGS.__flags.keys():
            tf.compat.v1.logging.info('  {}: {}'.format(
                key, getattr(FLAGS, key)))
        tf.compat.v1.logging.info("**************************")

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    training_hooks.append(LogTrainRunHook(global_batch_size, hvd_rank))

    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / global_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

        start_index = 0
        end_index = len(train_examples)
        tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record")]

        if FLAGS.horovod:
            tmp_filenames = [
                os.path.join(FLAGS.output_dir, "train.tf_record{}".format(i))
                for i in range(hvd.size())
            ]
            num_examples_per_rank = len(train_examples) // hvd.size()
            remainder = len(train_examples) % hvd.size()
            if hvd.rank() < remainder:
                start_index = hvd.rank() * (num_examples_per_rank + 1)
                end_index = start_index + num_examples_per_rank + 1
            else:
                start_index = hvd.rank() * num_examples_per_rank + remainder
                end_index = start_index + (num_examples_per_rank)

    model_fn = model_fn_builder(bert_config=bert_config,
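                                # len(label_list) + 1: for this token-level
                                # (NER) task, label id 0 is presumably reserved
                                # for padding (see the `id != 0` filter in the
                                # prediction loop below).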
                                num_labels=len(label_list) + 1,
                                init_checkpoint=FLAGS.init_checkpoint,
                                learning_rate=(FLAGS.learning_rate
                                               if not FLAGS.horovod else
                                               FLAGS.learning_rate * hvd.size()),
                                num_train_steps=num_train_steps,
                                num_warmup_steps=num_warmup_steps,
                                use_one_hot_embeddings=False,
                                hvd=None if not FLAGS.horovod else hvd,
                                use_fp16=FLAGS.use_fp16)

    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

    if FLAGS.do_train:
        #train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
        #filed_based_convert_examples_to_features(
        #    train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)
        filed_based_convert_examples_to_features(
            train_examples[start_index:end_index], label_list,
            FLAGS.max_seq_length, tokenizer, tmp_filenames[hvd_rank])
        tf.compat.v1.logging.info("***** Running training *****")
        tf.compat.v1.logging.info("  Num examples = %d", len(train_examples))
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.compat.v1.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = file_based_input_fn_builder(
            input_file=tmp_filenames,  #train_file,
            batch_size=FLAGS.train_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True,
            hvd=None if not FLAGS.horovod else hvd)

        #estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
        train_start_time = time.time()
        estimator.train(input_fn=train_input_fn,
                        max_steps=num_train_steps,
                        hooks=training_hooks)
        train_time_elapsed = time.time() - train_start_time
        train_time_wo_overhead = training_hooks[-1].total_time
        avg_sentences_per_second = num_train_steps * global_batch_size * 1.0 / train_time_elapsed
        ss_sentences_per_second = (
            num_train_steps - training_hooks[-1].skipped
        ) * global_batch_size * 1.0 / train_time_wo_overhead

        if master_process:
            tf.compat.v1.logging.info("-----------------------------")
            tf.compat.v1.logging.info(
                "Total Training Time = %0.2f for Sentences = %d",
                train_time_elapsed, num_train_steps * global_batch_size)
            tf.compat.v1.logging.info(
                "Total Training Time W/O Overhead = %0.2f for Sentences = %d",
                train_time_wo_overhead,
                (num_train_steps - training_hooks[-1].skipped) *
                global_batch_size)
            tf.compat.v1.logging.info(
                "Throughput Average (sentences/sec) with overhead = %0.2f",
                avg_sentences_per_second)
            tf.compat.v1.logging.info(
                "Throughput Average (sentences/sec) = %0.2f",
                ss_sentences_per_second)
            dllogging.logger.log(
                step=(),
                data={"throughput_train": ss_sentences_per_second},
                verbosity=Verbosity.DEFAULT)
            tf.compat.v1.logging.info("-----------------------------")

    if FLAGS.do_eval and master_process:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
        filed_based_convert_examples_to_features(eval_examples, label_list,
                                                 FLAGS.max_seq_length,
                                                 tokenizer, eval_file)

        tf.compat.v1.logging.info("***** Running evaluation *****")
        tf.compat.v1.logging.info("  Num examples = %d", len(eval_examples))
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.eval_batch_size)
        eval_steps = None
        eval_drop_remainder = False
        eval_input_fn = file_based_input_fn_builder(
            input_file=eval_file,
            batch_size=FLAGS.eval_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=eval_drop_remainder)
        result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            tf.compat.v1.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.compat.v1.logging.info("  %s = %s", key, str(result[key]))
                dllogging.logger.log(step=(),
                                     data={key: float(result[key])},
                                     verbosity=Verbosity.DEFAULT)
                writer.write("%s = %s\n" % (key, str(result[key])))
    if FLAGS.do_predict and master_process:
        predict_examples = processor.get_test_examples(FLAGS.data_dir)
        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        filed_based_convert_examples_to_features(predict_examples,
                                                 label_list,
                                                 FLAGS.max_seq_length,
                                                 tokenizer,
                                                 predict_file,
                                                 mode="test")

        with tf.io.gfile.GFile(os.path.join(FLAGS.output_dir, 'label2id.pkl'),
                               'rb') as rf:
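            # Load the label->id mapping written during feature conversion;
            # invert it so predicted ids can be decoded back to label strings.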
            label2id = pickle.load(rf)
            id2label = {value: key for key, value in label2id.items()}
        token_path = os.path.join(FLAGS.output_dir, "token_test.txt")
        if tf.io.gfile.exists(token_path):
            tf.io.gfile.remove(token_path)

        tf.compat.v1.logging.info("***** Running prediction*****")
        tf.compat.v1.logging.info("  Num examples = %d", len(predict_examples))
        tf.compat.v1.logging.info("  Batch size = %d",
                                  FLAGS.predict_batch_size)

        predict_drop_remainder = False
        predict_input_fn = file_based_input_fn_builder(
            input_file=predict_file,
            batch_size=FLAGS.predict_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=predict_drop_remainder)

        eval_hooks = [LogEvalRunHook(FLAGS.predict_batch_size)]
        eval_start_time = time.time()

        output_predict_file = os.path.join(FLAGS.output_dir, "label_test.txt")
        test_labels_file = os.path.join(FLAGS.output_dir, "test_labels.txt")
        test_labels_err_file = os.path.join(FLAGS.output_dir,
                                            "test_labels_errs.txt")
        with tf.io.gfile.GFile(output_predict_file, 'w') as writer, \
                tf.io.gfile.GFile(test_labels_file, 'w') as tl, \
                tf.io.gfile.GFile(test_labels_err_file, 'w') as tle:
            print(id2label)
            i = 0
            for prediction in estimator.predict(input_fn=predict_input_fn,
                                                hooks=eval_hooks,
                                                yield_single_examples=True):
                output_line = "\n".join(id2label[id]
                                        for id in prediction if id != 0) + "\n"
                writer.write(output_line)
                result_to_pair(predict_examples[i], prediction, id2label, tl,
                               tle)
                i = i + 1

        eval_time_elapsed = time.time() - eval_start_time

        time_list = eval_hooks[-1].time_list
        time_list.sort()
        # Removing outliers (init/warmup) in throughput computation.
        eval_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.99)])
        num_sentences = (int(len(time_list) * 0.99)) * FLAGS.predict_batch_size

        avg = np.mean(time_list)
        cf_50 = max(time_list[:int(len(time_list) * 0.50)])
        cf_90 = max(time_list[:int(len(time_list) * 0.90)])
        cf_95 = max(time_list[:int(len(time_list) * 0.95)])
        cf_99 = max(time_list[:int(len(time_list) * 0.99)])
        cf_100 = max(time_list[:int(len(time_list) * 1)])
        ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead

        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info(
            "Total Inference Time = %0.2f for Sentences = %d",
            eval_time_elapsed, eval_hooks[-1].count * FLAGS.predict_batch_size)
        tf.compat.v1.logging.info(
            "Total Inference Time W/O Overhead = %0.2f for Sentences = %d",
            eval_time_wo_overhead, num_sentences)
        tf.compat.v1.logging.info("Summary Inference Statistics")
        tf.compat.v1.logging.info("Batch size = %d", FLAGS.predict_batch_size)
        tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
        tf.compat.v1.logging.info("Precision = %s",
                                  "fp16" if FLAGS.use_fp16 else "fp32")
        tf.compat.v1.logging.info("Latency Confidence Level 50 (ms) = %0.2f",
                                  cf_50 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 90 (ms) = %0.2f",
                                  cf_90 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 95 (ms) = %0.2f",
                                  cf_95 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 99 (ms) = %0.2f",
                                  cf_99 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 100 (ms) = %0.2f",
                                  cf_100 * 1000)
        tf.compat.v1.logging.info("Latency Average (ms) = %0.2f", avg * 1000)
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f",
                                  ss_sentences_per_second)
        dllogging.logger.log(step=(),
                             data={"throughput_val": ss_sentences_per_second},
                             verbosity=Verbosity.DEFAULT)
        tf.compat.v1.logging.info("-----------------------------")

        tf.compat.v1.logging.info('Reading: %s', test_labels_file)
        with tf.io.gfile.GFile(test_labels_file, "r") as f:
            counts = evaluate(f)
        eval_result = report_notprint(counts)
        print(''.join(eval_result))
        with tf.io.gfile.GFile(
                os.path.join(FLAGS.output_dir, 'test_results_conlleval.txt'),
                'w') as fd:
            fd.write(''.join(eval_result))
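
# --- Illustrative sketch (a hypothetical helper, not part of the original
# scripts). Appending to TF_XLA_FLAGS requires a separating space between
# flags, as the += branch below must supply; a small helper makes that
# explicit.
import os

def append_env_flag(name, flag):
    """Append `flag` to the env var `name`, keeping flags space-separated."""
    existing = os.environ.get(name)
    os.environ[name] = flag if not existing else existing + " " + flag

# Example: append_env_flag("TF_XLA_FLAGS", "--tf_xla_enable_lazy_compilation=false")
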
def main(_):
    # causes memory fragmentation for bert leading to OOM
    if os.environ.get("TF_XLA_FLAGS", None) is not None:
        os.environ["TF_XLA_FLAGS"] += "--tf_xla_enable_lazy_compilation=false"
    else:
        os.environ["TF_XLA_FLAGS"] = "--tf_xla_enable_lazy_compilation=false"

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    dllogging = utils.dllogger_class.dllogger_class(FLAGS.dllog_path)

    if FLAGS.horovod:
        hvd.init()

    processors = {
        "chemprot": BioBERTChemprotProcessor,
        'mednli': MedNLIProcessor,
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict' must be True."
        )

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2

    master_process = True
    training_hooks = []
    global_batch_size = FLAGS.train_batch_size
    hvd_rank = 0

    config = tf.compat.v1.ConfigProto()
    if FLAGS.horovod:
        global_batch_size = FLAGS.train_batch_size * hvd.size()
        master_process = (hvd.rank() == 0)
        hvd_rank = hvd.rank()
        config.gpu_options.visible_device_list = str(hvd.local_rank())
        if hvd.size() > 1:
            training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

    if FLAGS.use_xla:
        config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
        tf.enable_resource_variables()
    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.output_dir if master_process else None,
        session_config=config,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps
        if master_process else None,
        keep_checkpoint_max=1)

    if master_process:
        tf.compat.v1.logging.info("***** Configuaration *****")
        for key in FLAGS.__flags.keys():
            tf.compat.v1.logging.info('  {}: {}'.format(
                key, getattr(FLAGS, key)))
        tf.compat.v1.logging.info("**************************")

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None

    training_hooks.append(LogTrainRunHook(global_batch_size, hvd_rank))

    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / global_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

        start_index = 0
        end_index = len(train_examples)
        tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record")]

        if FLAGS.horovod:
            tmp_filenames = [
                os.path.join(FLAGS.output_dir, "train.tf_record{}".format(i))
                for i in range(hvd.size())
            ]
            num_examples_per_rank = len(train_examples) // hvd.size()
            remainder = len(train_examples) % hvd.size()
            if hvd.rank() < remainder:
                start_index = hvd.rank() * (num_examples_per_rank + 1)
                end_index = start_index + num_examples_per_rank + 1
            else:
                start_index = hvd.rank() * num_examples_per_rank + remainder
                end_index = start_index + (num_examples_per_rank)

    model_fn = model_fn_builder(bert_config=bert_config,
                                num_labels=len(label_list),
                                init_checkpoint=FLAGS.init_checkpoint,
                                learning_rate=(FLAGS.learning_rate
                                               if not FLAGS.horovod else
                                               FLAGS.learning_rate * hvd.size()),
                                num_train_steps=num_train_steps,
                                num_warmup_steps=num_warmup_steps,
                                use_one_hot_embeddings=False,
                                hvd=None if not FLAGS.horovod else hvd,
                                amp=FLAGS.amp)

    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

    if FLAGS.do_train:
        file_based_convert_examples_to_features(
            train_examples[start_index:end_index], label_list,
            FLAGS.max_seq_length, tokenizer, tmp_filenames[hvd_rank])
        tf.compat.v1.logging.info("***** Running training *****")
        tf.compat.v1.logging.info("  Num examples = %d", len(train_examples))
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.compat.v1.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = file_based_input_fn_builder(
            input_file=tmp_filenames,
            batch_size=FLAGS.train_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True,
            hvd=None if not FLAGS.horovod else hvd)

        train_start_time = time.time()
        estimator.train(input_fn=train_input_fn,
                        max_steps=num_train_steps,
                        hooks=training_hooks)
        train_time_elapsed = time.time() - train_start_time
        train_time_wo_overhead = training_hooks[-1].total_time
        avg_sentences_per_second = num_train_steps * global_batch_size * 1.0 / train_time_elapsed
        ss_sentences_per_second = (
            num_train_steps - training_hooks[-1].skipped
        ) * global_batch_size * 1.0 / train_time_wo_overhead

        if master_process:
            tf.compat.v1.logging.info("-----------------------------")
            tf.compat.v1.logging.info(
                "Total Training Time = %0.2f for Sentences = %d",
                train_time_elapsed, num_train_steps * global_batch_size)
            tf.compat.v1.logging.info(
                "Total Training Time W/O Overhead = %0.2f for Sentences = %d",
                train_time_wo_overhead,
                (num_train_steps - training_hooks[-1].skipped) *
                global_batch_size)
            tf.compat.v1.logging.info(
                "Throughput Average (sentences/sec) with overhead = %0.2f",
                avg_sentences_per_second)
            tf.compat.v1.logging.info(
                "Throughput Average (sentences/sec) = %0.2f",
                ss_sentences_per_second)
            dllogging.logger.log(
                step=(),
                data={"throughput_train": ss_sentences_per_second},
                verbosity=Verbosity.DEFAULT)
            tf.compat.v1.logging.info("-----------------------------")

    if FLAGS.do_eval and master_process:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        num_actual_eval_examples = len(eval_examples)

        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
        file_based_convert_examples_to_features(eval_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer, eval_file)

        tf.compat.v1.logging.info("***** Running evaluation *****")
        tf.compat.v1.logging.info(
            "  Num examples = %d (%d actual, %d padding)", len(eval_examples),
            num_actual_eval_examples,
            len(eval_examples) - num_actual_eval_examples)
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        # This tells the estimator to run through the entire set.
        eval_steps = None

        eval_drop_remainder = False
        eval_input_fn = file_based_input_fn_builder(
            input_file=eval_file,
            batch_size=FLAGS.eval_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=eval_drop_remainder)

        result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            tf.compat.v1.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.compat.v1.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    if FLAGS.do_predict and master_process:
        predict_examples = processor.get_test_examples(FLAGS.data_dir)
        num_actual_predict_examples = len(predict_examples)

        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        file_based_convert_examples_to_features(predict_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer, predict_file)

        tf.compat.v1.logging.info("***** Running prediction*****")
        tf.compat.v1.logging.info(
            "  Num examples = %d (%d actual, %d padding)",
            len(predict_examples), num_actual_predict_examples,
            len(predict_examples) - num_actual_predict_examples)
        tf.compat.v1.logging.info("  Batch size = %d",
                                  FLAGS.predict_batch_size)

        predict_drop_remainder = False
        predict_input_fn = file_based_input_fn_builder(
            input_file=predict_file,
            batch_size=FLAGS.predict_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=predict_drop_remainder)

        eval_hooks = [LogEvalRunHook(FLAGS.predict_batch_size)]
        eval_start_time = time.time()

        output_predict_file = os.path.join(FLAGS.output_dir,
                                           "test_results.tsv")
        with tf.io.gfile.GFile(output_predict_file, "w") as writer:
            num_written_lines = 0
            tf.compat.v1.logging.info("***** Predict results *****")
            for prediction in estimator.predict(input_fn=predict_input_fn,
                                                hooks=eval_hooks,
                                                yield_single_examples=True):
                probabilities = prediction["probabilities"]
                output_line = "\t".join(
                    str(class_probability)
                    for class_probability in probabilities) + "\n"
                writer.write(output_line)
                num_written_lines += 1
        assert num_written_lines == num_actual_predict_examples

        eval_time_elapsed = time.time() - eval_start_time

        time_list = eval_hooks[-1].time_list
        time_list.sort()
        # Removing outliers (init/warmup) in throughput computation.
        eval_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.99)])
        num_sentences = (int(len(time_list) * 0.99)) * FLAGS.predict_batch_size

        avg = np.mean(time_list)
        cf_50 = max(time_list[:int(len(time_list) * 0.50)])
        cf_90 = max(time_list[:int(len(time_list) * 0.90)])
        cf_95 = max(time_list[:int(len(time_list) * 0.95)])
        cf_99 = max(time_list[:int(len(time_list) * 0.99)])
        cf_100 = max(time_list[:int(len(time_list) * 1)])
        ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead

        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info(
            "Total Inference Time = %0.2f for Sentences = %d",
            eval_time_elapsed, eval_hooks[-1].count * FLAGS.predict_batch_size)
        tf.compat.v1.logging.info(
            "Total Inference Time W/O Overhead = %0.2f for Sentences = %d",
            eval_time_wo_overhead, num_sentences)
        tf.compat.v1.logging.info("Summary Inference Statistics")
        tf.compat.v1.logging.info("Batch size = %d", FLAGS.predict_batch_size)
        tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
        tf.compat.v1.logging.info("Precision = %s",
                                  "fp16" if FLAGS.amp else "fp32")
        tf.compat.v1.logging.info("Latency Confidence Level 50 (ms) = %0.2f",
                                  cf_50 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 90 (ms) = %0.2f",
                                  cf_90 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 95 (ms) = %0.2f",
                                  cf_95 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 99 (ms) = %0.2f",
                                  cf_99 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 100 (ms) = %0.2f",
                                  cf_100 * 1000)
        tf.compat.v1.logging.info("Latency Average (ms) = %0.2f", avg * 1000)
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f",
                                  ss_sentences_per_second)
        dllogging.logger.log(step=(),
                             data={"throughput_val": ss_sentences_per_second},
                             verbosity=Verbosity.DEFAULT)
        tf.compat.v1.logging.info("-----------------------------")
def main(_):
    os.environ["TF_XLA_FLAGS"] = "--tf_xla_enable_lazy_compilation=false" #causes memory fragmentation for bert leading to OOM
    os.environ["XLA_FLAGS"]="--xla_gpu_cuda_data_dir=/appl/spack/install-tree/gcc-8.3.0/cuda-10.1.168-mrdepn/"
    #os.environ["TF_CUDA_HOST_MEM_LIMIT_IN_MB"] = "30000"
    #os.environ["TF_XLA_FLAGS"]="--tf_xla_cpu_global_jit"
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    #dllogging = utils.dllogger_class.dllogger_class(FLAGS.dllog_path) 
    if FLAGS.horovod:
      hvd.init()
    #if FLAGS.use_fp16:
    #    os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"

    processors = {'consensus': ConsensusProcessor}
    
    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case, FLAGS.init_checkpoint)

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
       raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    task_name = FLAGS.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))
    
    tf.io.gfile.makedirs(FLAGS.output_dir)

    processor = processors[task_name]()

    label_list,label_map = processor.get_labels(FLAGS.labels_dir)
    inv_label_map = { v: k for k, v in label_map.items() }
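    # inv_label_map maps predicted class ids back to their label strings
    # when writing predictions in the do_predict branch below.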
    #label_list = processor.get_labels()
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    # Unused TPU input-pipeline setting left over from the original BERT script:
    # is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2

    master_process = True
    training_hooks = []
    #training_hooks.append(OomReportingHook())
    global_batch_size = FLAGS.train_batch_size
    hvd_rank = 0

    config = tf.compat.v1.ConfigProto()
    # allow_growth alone did not prevent OOM here:
    #config.gpu_options.allow_growth = True
    if FLAGS.horovod:
      global_batch_size = FLAGS.train_batch_size * hvd.size()
      master_process = (hvd.rank() == 0)
      hvd_rank = hvd.rank()
      #os.environ['CUDA_VISIBLE_DEVICES'] = str(hvd.local_rank())
      config.gpu_options.visible_device_list = str(hvd.local_rank())
      if hvd.size() > 1:
        training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
      #config.gpu_options.per_process_gpu_memory_fraction = 0.4
    #config.gpu_options.per_process_gpu_memory_fraction = 0.7
    if FLAGS.use_xla:
        config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
        #config.graph_options.rewrite_options.memory_optimization = rewriter_config_pb2.RewriterConfig.NO_MEM_OPT
    run_config = tf.estimator.RunConfig(
      model_dir=FLAGS.output_dir if master_process else None,
      session_config=config,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps if master_process else None,
      keep_checkpoint_max=1)

    if master_process:
      tf.compat.v1.logging.info("***** Configuration *****")
      for key in FLAGS.__flags.keys():
          tf.compat.v1.logging.info('  {}: {}'.format(key, getattr(FLAGS, key)))
      tf.compat.v1.logging.info("**************************")

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    training_hooks.append(LogTrainRunHook(global_batch_size, hvd_rank))

    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / global_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

        start_index = 0
        end_index = len(train_examples)
        tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record")]

        if FLAGS.horovod:
          tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record{}".format(i)) for i in range(hvd.size())]
          num_examples_per_rank = len(train_examples) // hvd.size()
          remainder = len(train_examples) % hvd.size()
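          # Contiguous near-even split: the first `remainder` ranks take one
          # extra example each, so every example lands on exactly one rank.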
          if hvd.rank() < remainder:
            start_index = hvd.rank() * (num_examples_per_rank+1)
            end_index = start_index + num_examples_per_rank + 1
          else:
            start_index = hvd.rank() * num_examples_per_rank + remainder
            end_index = start_index + (num_examples_per_rank)

    model_fn = model_fn_builder(
        bert_config=bert_config,
        num_labels=len(label_list) + 1,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_one_hot_embeddings=False,
        hvd=None if not FLAGS.horovod else hvd,
        use_fp16=FLAGS.use_fp16)

    estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      config=run_config)

    if FLAGS.do_train: 
        filed_based_convert_examples_to_features(
          train_examples[start_index:end_index], label_list, label_map, FLAGS.max_seq_length, tokenizer, tmp_filenames[hvd_rank], FLAGS.replace_span)
        tf.compat.v1.logging.info("***** Running training *****")
        tf.compat.v1.logging.info("  Num examples = %d", len(train_examples))
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.compat.v1.logging.info("  Num steps = %d", num_train_steps)
        tf.compat.v1.logging.info("  Num of labels = %d", len(label_list))
        train_input_fn = file_based_input_fn_builder(
            input_file=tmp_filenames,
            batch_size=FLAGS.train_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True,
            hvd=None if not FLAGS.horovod else hvd)
        
        train_start_time = time.time()
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps, hooks=training_hooks)
        train_time_elapsed = time.time() - train_start_time
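        # LogTrainRunHook skips warmup steps, so the "W/O Overhead" figures
        # below reflect steady-state training throughput.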
        train_time_wo_overhead = training_hooks[-1].total_time
        avg_sentences_per_second = num_train_steps * global_batch_size * 1.0 / train_time_elapsed
        ss_sentences_per_second = (num_train_steps - training_hooks[-1].skipped) * global_batch_size * 1.0 / train_time_wo_overhead

        if master_process:
          tf.compat.v1.logging.info("-----------------------------")
          tf.compat.v1.logging.info("Total Training Time = %0.2f for Sentences = %d", train_time_elapsed,
                        num_train_steps * global_batch_size)
          tf.compat.v1.logging.info("Total Training Time W/O Overhead = %0.2f for Sentences = %d", train_time_wo_overhead,
                        (num_train_steps - training_hooks[-1].skipped) * global_batch_size)
          tf.compat.v1.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
          tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
          tf.compat.v1.logging.info("-----------------------------")

    if FLAGS.do_eval and master_process:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        num_actual_eval_examples = len(eval_examples)
        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
        filed_based_convert_examples_to_features(
            eval_examples, label_list, label_map, FLAGS.max_seq_length, tokenizer, eval_file, FLAGS.replace_span)

        tf.compat.v1.logging.info("***** Running evaluation *****")
        tf.compat.v1.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(eval_examples), num_actual_eval_examples,
                        len(eval_examples) - num_actual_eval_examples)
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.eval_batch_size)
        # This tells the estimator to run through the entire set.
        eval_steps = None
        eval_drop_remainder = False
        eval_input_fn = file_based_input_fn_builder(
            input_file=eval_file,
            batch_size=FLAGS.eval_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=eval_drop_remainder)
        result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            tf.compat.v1.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.compat.v1.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
        
 
    if FLAGS.do_predict and master_process:
        predict_examples = processor.get_test_examples(FLAGS.data_dir)
        num_actual_predict_examples = len(predict_examples)
        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        filed_based_convert_examples_to_features(predict_examples, label_list, label_map,
                                                 FLAGS.max_seq_length, tokenizer,
                                                 predict_file, FLAGS.replace_span)
        tf.compat.v1.logging.info("***** Running prediction*****")
        tf.compat.v1.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(predict_examples), num_actual_predict_examples,
                        len(predict_examples) - num_actual_predict_examples)        
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_drop_remainder = False
        predict_input_fn = file_based_input_fn_builder(
            input_file=predict_file,
            batch_size=FLAGS.predict_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=False,       
            drop_remainder=predict_drop_remainder)

        eval_hooks = [LogEvalRunHook(FLAGS.predict_batch_size)]
        eval_start_time = time.time()

        output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
        output_class_file = os.path.join(FLAGS.output_dir, "test_output_labels.txt")
        with tf.io.gfile.GFile(output_predict_file, "w") as writer, tf.io.gfile.GFile(output_class_file, "w") as writer2:
            num_written_lines = 0
            tf.compat.v1.logging.info("***** Predict results *****")
            for prediction in estimator.predict(input_fn=predict_input_fn, hooks=eval_hooks,
                                                     yield_single_examples=True):
                probabilities = prediction["probabilities"]
                logits = prediction["logits"]
                pr_res = np.argmax(logits, axis=-1)
                output = str(inv_label_map[pr_res])+"\n"
                writer2.write(output)
                output_line = "\t".join(
                    str(class_probability)
                    for class_probability in probabilities) + "\n"
                writer.write(output_line)
                num_written_lines += 1
        assert num_written_lines == num_actual_predict_examples
    
        eval_time_elapsed = time.time() - eval_start_time
        eval_time_wo_overhead = eval_hooks[-1].total_time

        time_list = eval_hooks[-1].time_list
        time_list.sort()
        num_sentences = (eval_hooks[-1].count - eval_hooks[-1].skipped) * FLAGS.predict_batch_size

        avg = np.mean(time_list)
        cf_50 = max(time_list[:int(len(time_list) * 0.50)])
        cf_90 = max(time_list[:int(len(time_list) * 0.90)])
        cf_95 = max(time_list[:int(len(time_list) * 0.95)])
        cf_99 = max(time_list[:int(len(time_list) * 0.99)])
        cf_100 = max(time_list[:int(len(time_list) * 1)])
        ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead

        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info("Total Inference Time = %0.2f for Sentences = %d", eval_time_elapsed,
                        eval_hooks[-1].count * FLAGS.predict_batch_size)
        tf.compat.v1.logging.info("Total Inference Time W/O Overhead = %0.2f for Sentences = %d", eval_time_wo_overhead,
                        (eval_hooks[-1].count - eval_hooks[-1].skipped) * FLAGS.predict_batch_size)
        tf.compat.v1.logging.info("Summary Inference Statistics")
        tf.compat.v1.logging.info("Batch size = %d", FLAGS.predict_batch_size)
        tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
        tf.compat.v1.logging.info("Precision = %s", "fp16" if FLAGS.use_fp16 else "fp32")
        tf.compat.v1.logging.info("Latency Confidence Level 50 (ms) = %0.2f", cf_50 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 90 (ms) = %0.2f", cf_90 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 95 (ms) = %0.2f", cf_95 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 99 (ms) = %0.2f", cf_99 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 100 (ms) = %0.2f", cf_100 * 1000)
        tf.compat.v1.logging.info("Latency Average (ms) = %0.2f", avg * 1000)
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
        tf.compat.v1.logging.info("-----------------------------")
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.horovod:
        hvd.init()

    if FLAGS.use_fp16:
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"

    model_fn_builder, file_based_input_fn_builder, eval_file_based_input_fn_builder = get_func_by_task(
        FLAGS.task_name.lower())

    if not FLAGS.do_train:
        raise ValueError("`do_train` must be True.")

    query_bert_config = modeling_bison.BertConfig.from_json_file(
        FLAGS.query_bert_config_file)
    meta_bert_config = modeling_bison.BertConfig.from_json_file(
        FLAGS.meta_bert_config_file)

    # Sequence length check
    if FLAGS.max_seq_length_query > query_bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use query sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length_query,
             query_bert_config.max_position_embeddings))

    meta_seq_length = FLAGS.max_seq_length_url + FLAGS.max_seq_length_title
    if FLAGS.enable_body:
        meta_seq_length += FLAGS.max_seq_length_body
    if meta_seq_length > meta_bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use meta sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (meta_seq_length, meta_bert_config.max_position_embeddings))

    master_process = True
    training_hooks = []
    global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
    hvd_rank = 0
    config = tf.ConfigProto()
    if FLAGS.horovod:
        tf.logging.info("Multi-GPU training with TF Horovod")
        tf.logging.info("hvd.size() = %d hvd.rank() = %d", hvd.size(),
                        hvd.rank())
        global_batch_size = FLAGS.train_batch_size * \
            FLAGS.num_accumulation_steps * hvd.size()
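        # Effective global batch = per-GPU batch x accumulation steps x
        # number of Horovod ranks.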
        master_process = (hvd.rank() == 0)
        hvd_rank = hvd.rank()
        config.gpu_options.allow_growth = True
        config.gpu_options.visible_device_list = str(hvd.local_rank())
        if hvd.size() > 1:
            training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
    if FLAGS.use_xla:
        config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    tf.gfile.MakeDirs(FLAGS.output_dir)
    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.output_dir if master_process else None,
        session_config=config,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps
        if master_process else None,
        keep_checkpoint_max=10)

    if master_process:
        tf.logging.info("***** Configuaration *****")
        for key in FLAGS.__flags.keys():
            tf.logging.info('  {}: {}'.format(key, getattr(FLAGS, key)))
        tf.logging.info("**************************")

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    train_examples_count = FLAGS.train_line_count
    log_train_run_hook = LogTrainRunHook(global_batch_size, hvd_rank)
    training_hooks.append(log_train_run_hook)

    if FLAGS.do_train:
        num_train_steps = int(train_examples_count / global_batch_size *
                              FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    model_fn = model_fn_builder(
        query_bert_config=query_bert_config,
        meta_bert_config=meta_bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_one_hot_embeddings=FLAGS.use_one_hot_embeddings,
        nce_temperature=FLAGS.nce_temperature,
        nce_weight=FLAGS.nce_weight,
        hvd=None if not FLAGS.horovod else hvd)

    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

    if FLAGS.do_train:
        start_index = 0
        end_index = FLAGS.train_partition_count

        if FLAGS.horovod:
            tfrecord_per_GPU = int(FLAGS.train_partition_count / hvd.size())
            start_index = hvd.rank() * tfrecord_per_GPU
            end_index = start_index + tfrecord_per_GPU

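            # Ranks run 0..size-1; the last rank picks up the partitions left
            # over when train_partition_count is not divisible by hvd.size().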
            if hvd.rank() == hvd.size() - 1:
                end_index = FLAGS.train_partition_count

        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num examples = %d", train_examples_count)
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        tf.logging.info("  hvd rank = %d", hvd.rank())
        tf.logging.info("  Num start_index = %d", start_index)
        tf.logging.info("  Num end_index = %d", end_index)

        train_file_list = []
        for i in range(start_index, end_index):
            train_file_list.append(
                os.path.join(FLAGS.preprocess_train_dir, str(i),
                             FLAGS.preprocess_train_file_name))
        tf.logging.info("merge " + str(end_index - start_index) +
                        " preprocessed file from preprocess dir")
        tf.logging.info(train_file_list)

        train_input_fn = file_based_input_fn_builder(
            input_file=train_file_list,
            batch_size=FLAGS.train_batch_size,
            query_seq_length=FLAGS.max_seq_length_query,
            meta_seq_length=meta_seq_length,
            is_training=True,
            drop_remainder=True,
            is_fidelity_eval=False,
            hvd=None if not FLAGS.horovod else hvd)

        # Initialize the eval files.
        # preprocess_eval_dir must be set; every file under that folder is
        # treated as a TFRecord file.
        if FLAGS.preprocess_eval_dir is None:
            raise ValueError('preprocess_eval_dir must be set explicitly.')

        all_eval_files = []
        eval_file_list = []
        find_all_file_in_folder(FLAGS.preprocess_eval_dir, all_eval_files)
        for i in range(len(all_eval_files)):
            if hvd.rank() == i % hvd.size():
                eval_file_list.append(all_eval_files[i])
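        # Round-robin assignment: rank r evaluates files r, r + size,
        # r + 2 * size, ...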

        if not eval_file_list:
            raise ValueError('Rank: %d got an empty eval file list.' % (hvd.rank()))

        tf.logging.info(
            "********** Check how many eval examples are in the current rank **********"
        )
        eval_examples_count = check_line_count_in_tfrecords(eval_file_list)

        eval_steps = int(math.ceil(eval_examples_count /
                                   FLAGS.eval_batch_size))

        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Rank: %d will eval files:%s" %
                        (hvd.rank(), str(eval_file_list)))
        tf.logging.info("  Rank: %d eval example count:%d" %
                        (hvd.rank(), eval_examples_count))
        tf.logging.info("  Rank: %d eval batch size:%d" %
                        (hvd.rank(), FLAGS.eval_batch_size))
        tf.logging.info("  Rank: %d eval_steps:%d" % (hvd.rank(), eval_steps))

        eval_input_fn = eval_file_based_input_fn_builder(
            input_file=eval_file_list,
            query_seq_length=FLAGS.max_seq_length_query,
            meta_seq_length=meta_seq_length,
            drop_remainder=False,
            is_fidelity_eval=False)

        # Create an InMemoryEvaluatorHook that runs evaluation inside the
        # training loop every `save_checkpoints_steps` steps.
        in_memory_evaluator = tf.estimator.experimental.InMemoryEvaluatorHook(
            estimator=estimator,
            # steps must be set, otherwise the hook prints no logs (reason unknown).
            steps=eval_steps,
            input_fn=eval_input_fn,
            every_n_iter=FLAGS.save_checkpoints_steps,
            name="fidelity_eval")
        training_hooks.append(in_memory_evaluator)

        train_start_time = time.time()
        estimator.train(input_fn=train_input_fn,
                        max_steps=num_train_steps,
                        hooks=training_hooks)
        train_time_elapsed = time.time() - train_start_time
        train_time_wo_overhead = log_train_run_hook.total_time
        avg_sentences_per_second = num_train_steps * \
            global_batch_size * 1.0 / train_time_elapsed
        ss_sentences_per_second = (
            num_train_steps - log_train_run_hook.skipped
        ) * global_batch_size * 1.0 / train_time_wo_overhead

        if master_process:
            tf.logging.info("-----------------------------")
            tf.logging.info("Total Training Time = %0.2f for Sentences = %d",
                            train_time_elapsed,
                            num_train_steps * global_batch_size)
            tf.logging.info(
                "Total Training Time W/O Overhead = %0.2f for Sentences = %d",
                train_time_wo_overhead,
                (num_train_steps - log_train_run_hook.skipped) *
                global_batch_size)
            tf.logging.info(
                "Throughput Average (sentences/sec) with overhead = %0.2f",
                avg_sentences_per_second)
            tf.logging.info("Throughput Average (sentences/sec) = %0.2f",
                            ss_sentences_per_second)
            tf.logging.info("-----------------------------")
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.horovod:
        hvd.init()
    if FLAGS.use_fp16:
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    validate_flags_or_throw(bert_config)

    tf.gfile.MakeDirs(FLAGS.output_dir)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    master_process = True
    training_hooks = []
    global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
    hvd_rank = 0

    config = tf.ConfigProto()

    learning_rate = 2e-5  # hardcoded; overrides FLAGS.learning_rate
    # if FLAGS.horovod:
    #
    #     tf.logging.info("Multi-GPU training with TF Horovod")
    #     tf.logging.info("hvd.size() = %d hvd.rank() = %d", hvd.size(), hvd.rank())
    #     global_batch_size = FLAGS.train_batch_size * hvd.size() * FLAGS.num_accumulation_steps
    #     learning_rate = learning_rate * hvd.size()
    #     master_process = (hvd.rank() == 0)
    #     hvd_rank = hvd.rank()
    #     config.gpu_options.visible_device_list = str(hvd.local_rank())
    #     if hvd.size() > 1:
    #         training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
    # if FLAGS.use_xla:
    #   config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.output_dir if master_process else None,
        session_config=config,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps
        if master_process else None,
        keep_checkpoint_max=1)

    if master_process:
        tf.logging.info("***** Configuaration *****")
        for key in FLAGS.__flags.keys():
            tf.logging.info('  {}: {}'.format(key, getattr(FLAGS, key)))
        tf.logging.info("**************************")

    train_examples = None
    # CHECK: revert num_train_steps to None once the commented-out do_train
    # data prep below is re-enabled; both values are hardcoded for now.
    num_train_steps = 1564
    num_warmup_steps = 1000
    training_hooks.append(
        LogTrainRunHook(global_batch_size, hvd_rank,
                        FLAGS.save_checkpoints_steps))

    # Prepare Training Data
    # if FLAGS.do_train:
    #   train_examples = read_squad_examples(
    #       input_file=FLAGS.train_file, is_training=True,
    #       version_2_with_negative=FLAGS.version_2_with_negative)
    #   num_train_steps = int(
    #       len(train_examples) / global_batch_size * FLAGS.num_train_epochs)
    #   num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
    #
    #   # Pre-shuffle the input to avoid having to make a very large shuffle
    #   # buffer in in the `input_fn`.
    #   rng = random.Random(12345)
    #   rng.shuffle(train_examples)
    #
    #   start_index = 0
    #   end_index = len(train_examples)
    #   tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record")]
    #
    #   if FLAGS.horovod:
    #     tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record{}".format(i)) for i in range(hvd.size())]
    #     num_examples_per_rank = len(train_examples) // hvd.size()
    #     remainder = len(train_examples) % hvd.size()
    #     if hvd.rank() < remainder:
    #       start_index = hvd.rank() * (num_examples_per_rank+1)
    #       end_index = start_index + num_examples_per_rank + 1
    #     else:
    #       start_index = hvd.rank() * num_examples_per_rank + remainder
    #       end_index = start_index + (num_examples_per_rank)

    model_fn = model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=learning_rate,  #2e-5
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        hvd=None if not FLAGS.horovod else hvd,
        use_fp16=FLAGS.use_fp16)

    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

    # if FLAGS.do_train:
    #
    #   # We write to a temporary file to avoid storing very large constant tensors
    #   # in memory.
    #   train_writer = FeatureWriter(
    #       filename=tmp_filenames[hvd_rank],
    #       is_training=True)
    #   convert_examples_to_features(
    #       examples=train_examples[start_index:end_index],
    #       tokenizer=tokenizer,
    #       max_seq_length=FLAGS.max_seq_length,
    #       doc_stride=FLAGS.doc_stride,
    #       max_query_length=FLAGS.max_query_length,
    #       is_training=True,
    #       output_fn=train_writer.process_feature,
    #       verbose_logging=FLAGS.verbose_logging)
    #   train_writer.close()
    #
    #   tf.logging.info("***** Running training *****")
    #   tf.logging.info("  Num orig examples = %d", end_index - start_index)
    #   tf.logging.info("  Num split examples = %d", train_writer.num_features)
    #   tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    #   tf.logging.info("  Num steps = %d", num_train_steps)
    #   tf.logging.info("  LR = %f", learning_rate)
    #   del train_examples
    #
    #   train_input_fn = input_fn_builder(
    #       input_file=tmp_filenames,
    #       batch_size=FLAGS.train_batch_size,
    #       seq_length=FLAGS.max_seq_length,
    #       is_training=True,
    #       drop_remainder=True,
    #       hvd=None if not FLAGS.horovod else hvd)
    #
    #   train_start_time = time.time()
    #   estimator.train(input_fn=train_input_fn, hooks=training_hooks, max_steps=num_train_steps)
    #   train_time_elapsed = time.time() - train_start_time
    #   train_time_wo_overhead = training_hooks[-1].total_time
    #   avg_sentences_per_second = num_train_steps * global_batch_size * 1.0 / train_time_elapsed
    #   ss_sentences_per_second = (num_train_steps - training_hooks[-1].skipped) * global_batch_size * 1.0 / train_time_wo_overhead
    #
    #   if master_process:
    #       tf.logging.info("-----------------------------")
    #       tf.logging.info("Total Training Time = %0.2f for Sentences = %d", train_time_elapsed,
    #                       num_train_steps * global_batch_size)
    #       tf.logging.info("Total Training Time W/O Overhead = %0.2f for Sentences = %d", train_time_wo_overhead,
    #                       (num_train_steps - training_hooks[-1].skipped) * global_batch_size)
    #       tf.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
    #       tf.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
    #       tf.logging.info("-----------------------------")

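    # export_trtis presumably exports the estimator as a SavedModel for
    # NVIDIA's TensorRT Inference Server (TRTIS, since renamed Triton).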
    if FLAGS.export_trtis and master_process:
        export_model(estimator, FLAGS.output_dir, FLAGS.init_checkpoint)