def main(_):
    setup_xla_flags()

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    dllogging = utils.dllogger_class.dllogger_class(FLAGS.dllog_path)

    if FLAGS.horovod:
      hvd.init()

    processors = {
        "bc5cdr": BC5CDRProcessor,
        "clefe": CLEFEProcessor,
        'i2b2': I2b22012Processor
    }
    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict` must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    task_name = FLAGS.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    processor = processors[task_name]()

    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    master_process = True
    training_hooks = []
    global_batch_size = FLAGS.train_batch_size
    hvd_rank = 0

    config = tf.compat.v1.ConfigProto()
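    # With Horovod enabled, the global batch size is scaled by hvd.size(),
    # each process is pinned to its local GPU, and only rank 0 (master_process)
    # owns the model_dir and checkpointing; a broadcast hook keeps the
    # remaining ranks' variables in sync at startup.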
    if FLAGS.horovod:
      global_batch_size = FLAGS.train_batch_size * hvd.size()
      master_process = (hvd.rank() == 0)
      hvd_rank = hvd.rank()
      config.gpu_options.visible_device_list = str(hvd.local_rank())
      if hvd.size() > 1:
        training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

    if FLAGS.use_xla:
        config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
        if FLAGS.amp:
            tf.enable_resource_variables()
    run_config = tf.estimator.RunConfig(
      model_dir=FLAGS.output_dir if master_process else None,
      session_config=config,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps if master_process else None,
      keep_checkpoint_max=1)

    if master_process:
      tf.compat.v1.logging.info("***** Configuaration *****")
      for key in FLAGS.__flags.keys():
          tf.compat.v1.logging.info('  {}: {}'.format(key, getattr(FLAGS, key)))
      tf.compat.v1.logging.info("**************************")

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    training_hooks.append(LogTrainRunHook(global_batch_size, hvd_rank))

    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / global_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

        start_index = 0
        end_index = len(train_examples)
        tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record")]

        if FLAGS.horovod:
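          # Shard the training set across ranks: each rank gets
          # len(train_examples) // hvd.size() examples, the first `remainder`
          # ranks take one extra, and each rank writes its own
          # train.tf_record<rank> file below.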
          tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record{}".format(i)) for i in range(hvd.size())]
          num_examples_per_rank = len(train_examples) // hvd.size()
          remainder = len(train_examples) % hvd.size()
          if hvd.rank() < remainder:
            start_index = hvd.rank() * (num_examples_per_rank+1)
            end_index = start_index + num_examples_per_rank + 1
          else:
            start_index = hvd.rank() * num_examples_per_rank + remainder
            end_index = start_index + (num_examples_per_rank)

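    # num_labels is len(label_list) + 1, presumably to reserve an extra index
    # for padding positions in the tagging output. Under Horovod the learning
    # rate is scaled linearly by hvd.size(), matching the larger effective
    # batch size.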
    model_fn = model_fn_builder(
        bert_config=bert_config,
        num_labels=len(label_list) + 1,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate if not FLAGS.horovod else FLAGS.learning_rate * hvd.size(),
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_one_hot_embeddings=False,
        hvd=None if not FLAGS.horovod else hvd,
        amp=FLAGS.amp)

    estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      config=run_config)

    if FLAGS.do_train:
        filed_based_convert_examples_to_features(
          train_examples[start_index:end_index], label_list, FLAGS.max_seq_length, tokenizer, tmp_filenames[hvd_rank])
        tf.compat.v1.logging.info("***** Running training *****")
        tf.compat.v1.logging.info("  Num examples = %d", len(train_examples))
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.compat.v1.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = file_based_input_fn_builder(
            input_file=tmp_filenames,
            batch_size=FLAGS.train_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True,
            hvd=None if not FLAGS.horovod else hvd)
        
        train_start_time = time.time()
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps, hooks=training_hooks)
        train_time_elapsed = time.time() - train_start_time
        train_time_wo_overhead = training_hooks[-1].total_time
        avg_sentences_per_second = num_train_steps * global_batch_size * 1.0 / train_time_elapsed
        ss_sentences_per_second = (num_train_steps - training_hooks[-1].skipped) * global_batch_size * 1.0 / train_time_wo_overhead

        if master_process:
          tf.compat.v1.logging.info("-----------------------------")
          tf.compat.v1.logging.info("Total Training Time = %0.2f for Sentences = %d", train_time_elapsed,
                        num_train_steps * global_batch_size)
          tf.compat.v1.logging.info("Total Training Time W/O Overhead = %0.2f for Sentences = %d", train_time_wo_overhead,
                        (num_train_steps - training_hooks[-1].skipped) * global_batch_size)
          tf.compat.v1.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
          tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
          dllogging.logger.log(step=(), data={"throughput_train": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
          tf.compat.v1.logging.info("-----------------------------")

    if FLAGS.do_eval and master_process:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
        filed_based_convert_examples_to_features(
            eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)

        tf.compat.v1.logging.info("***** Running evaluation *****")
        tf.compat.v1.logging.info("  Num examples = %d", len(eval_examples))
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.eval_batch_size)
        eval_steps = None
        eval_drop_remainder = False
        eval_input_fn = file_based_input_fn_builder(
            input_file=eval_file,
            batch_size=FLAGS.eval_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=eval_drop_remainder)
        result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            tf.compat.v1.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.compat.v1.logging.info("  %s = %s", key, str(result[key]))
                dllogging.logger.log(step=(), data={key: float(str(result[key]))}, verbosity=Verbosity.DEFAULT)
                writer.write("%s = %s\n" % (key, str(result[key])))
    if FLAGS.do_predict and master_process:
        predict_examples = processor.get_test_examples(FLAGS.data_dir)
        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        filed_based_convert_examples_to_features(predict_examples, label_list,
                                                 FLAGS.max_seq_length, tokenizer,
                                                 predict_file, mode="test")

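        # label2id.pkl is assumed to have been written during feature
        # conversion; invert it so predicted label ids can be mapped back
        # to label strings when writing the output files.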
        with tf.io.gfile.GFile(os.path.join(FLAGS.output_dir, 'label2id.pkl'), 'rb') as rf:
            label2id = pickle.load(rf)
            id2label = {value: key for key, value in label2id.items()}
        token_path = os.path.join(FLAGS.output_dir, "token_test.txt")
        if tf.io.gfile.exists(token_path):
            tf.io.gfile.remove(token_path)

        tf.compat.v1.logging.info("***** Running prediction*****")
        tf.compat.v1.logging.info("  Num examples = %d", len(predict_examples))
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_drop_remainder = False
        predict_input_fn = file_based_input_fn_builder(
            input_file=predict_file,
            batch_size=FLAGS.predict_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=predict_drop_remainder)

        eval_hooks = [LogEvalRunHook(FLAGS.predict_batch_size)]
        eval_start_time = time.time()

        output_predict_file = os.path.join(FLAGS.output_dir, "label_test.txt")
        test_labels_file = os.path.join(FLAGS.output_dir, "test_labels.txt")
        test_labels_err_file = os.path.join(FLAGS.output_dir, "test_labels_errs.txt")
        with tf.io.gfile.GFile(output_predict_file, 'w') as writer, \
                tf.io.gfile.GFile(test_labels_file, 'w') as tl, \
                tf.io.gfile.GFile(test_labels_err_file, 'w') as tle:
            print(id2label)
            i=0
            for prediction in estimator.predict(input_fn=predict_input_fn, hooks=eval_hooks,
                                                yield_single_examples=True):
                output_line = "\n".join(id2label[id] for id in prediction if id != 0) + "\n"
                writer.write(output_line)
                result_to_pair(predict_examples[i], prediction, id2label, tl, tle)
                i = i + 1

        eval_time_elapsed = time.time() - eval_start_time

        time_list = eval_hooks[-1].time_list
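        # The hook records one latency per predict batch; after sorting, the
        # "Latency Confidence Level X" figures below are approximate
        # percentiles, taken as the max over the fastest X% of batches.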
        time_list.sort()
        # Removing outliers (init/warmup) in throughput computation.
        eval_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.99)])
        num_sentences = (int(len(time_list) * 0.99)) * FLAGS.predict_batch_size

        avg = np.mean(time_list)
        cf_50 = max(time_list[:int(len(time_list) * 0.50)])
        cf_90 = max(time_list[:int(len(time_list) * 0.90)])
        cf_95 = max(time_list[:int(len(time_list) * 0.95)])
        cf_99 = max(time_list[:int(len(time_list) * 0.99)])
        cf_100 = max(time_list[:int(len(time_list) * 1)])
        ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead

        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info("Total Inference Time = %0.2f for Sentences = %d", eval_time_elapsed,
                        eval_hooks[-1].count * FLAGS.predict_batch_size)
        tf.compat.v1.logging.info("Total Inference Time W/O Overhead = %0.2f for Sentences = %d", eval_time_wo_overhead,
                        num_sentences)
        tf.compat.v1.logging.info("Summary Inference Statistics")
        tf.compat.v1.logging.info("Batch size = %d", FLAGS.predict_batch_size)
        tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
        tf.compat.v1.logging.info("Precision = %s", "fp16" if FLAGS.amp else "fp32")
        tf.compat.v1.logging.info("Latency Confidence Level 50 (ms) = %0.2f", cf_50 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 90 (ms) = %0.2f", cf_90 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 95 (ms) = %0.2f", cf_95 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 99 (ms) = %0.2f", cf_99 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 100 (ms) = %0.2f", cf_100 * 1000)
        tf.compat.v1.logging.info("Latency Average (ms) = %0.2f", avg * 1000)
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
        dllogging.logger.log(step=(), data={"throughput_val": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
        tf.compat.v1.logging.info("-----------------------------")

        tf.compat.v1.logging.info('Reading: %s', test_labels_file)
        with tf.io.gfile.GFile(test_labels_file, "r") as f:
            counts = evaluate(f)
        eval_result = report_notprint(counts)
        print(''.join(eval_result))
        with tf.io.gfile.GFile(os.path.join(FLAGS.output_dir, 'test_results_conlleval.txt'), 'w') as fd:
            fd.write(''.join(eval_result))
def main(_):

  setup_xla_flags()

  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  dllogging = utils.dllogger_class.dllogger_class(FLAGS.dllog_path)

  if FLAGS.horovod:
    hvd.init()

  processors = {
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
  }

  if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
    raise ValueError(
        "At least one of `do_train`, `do_eval` or `do_predict' must be True.")

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  if FLAGS.max_seq_length > bert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, bert_config.max_position_embeddings))

  tf.io.gfile.makedirs(FLAGS.output_dir)

  task_name = FLAGS.task_name.lower()

  if task_name not in processors:
    raise ValueError("Task not found: %s" % (task_name))

  processor = processors[task_name]()

  label_list = processor.get_labels()

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  master_process = True
  training_hooks = []
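  # Effective global batch size: per-GPU batch size times the number of
  # gradient-accumulation steps (and, when Horovod is enabled below, times
  # hvd.size()). Step counts and throughput are reported against this value.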
  global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
  hvd_rank = 0

  config = tf.compat.v1.ConfigProto()
  if FLAGS.horovod:

      tf.compat.v1.logging.info("Multi-GPU training with TF Horovod")
      tf.compat.v1.logging.info("hvd.size() = %d hvd.rank() = %d", hvd.size(), hvd.rank())
      global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps * hvd.size()
      master_process = (hvd.rank() == 0)
      hvd_rank = hvd.rank()
      config.gpu_options.visible_device_list = str(hvd.local_rank())
      set_affinity(hvd.local_rank())
      if hvd.size() > 1:
          training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
  if FLAGS.use_xla:
    config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
    if FLAGS.amp:
      tf.enable_resource_variables()

  run_config = tf.estimator.RunConfig(
      model_dir=FLAGS.output_dir if master_process else None,
      session_config=config,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps if master_process else None,
      save_summary_steps=FLAGS.save_checkpoints_steps if master_process else None,
      log_step_count_steps=FLAGS.display_loss_steps,
      keep_checkpoint_max=1)

  if master_process:
      tf.compat.v1.logging.info("***** Configuaration *****")
      for key in FLAGS.__flags.keys():
          tf.compat.v1.logging.info('  {}: {}'.format(key, getattr(FLAGS, key)))
      tf.compat.v1.logging.info("**************************")

  train_examples = None
  num_train_steps = None
  num_warmup_steps = None
  training_hooks.append(LogTrainRunHook(global_batch_size, hvd_rank, FLAGS.save_checkpoints_steps, num_steps_ignore_xla=25))

  if FLAGS.do_train:
    train_examples = processor.get_train_examples(FLAGS.data_dir)
    num_train_steps = int(
        len(train_examples) / global_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    start_index = 0
    end_index = len(train_examples)
    tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record")]

    if FLAGS.horovod:
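      # Split the training examples evenly across Horovod ranks (the first
      # `remainder` ranks take one extra) and write one TFRecord shard per rank.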
      tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record{}".format(i)) for i in range(hvd.size())]
      num_examples_per_rank = len(train_examples) // hvd.size()
      remainder = len(train_examples) % hvd.size()
      if hvd.rank() < remainder:
        start_index = hvd.rank() * (num_examples_per_rank+1)
        end_index = start_index + num_examples_per_rank + 1
      else:
        start_index = hvd.rank() * num_examples_per_rank + remainder
        end_index = start_index + (num_examples_per_rank)

  model_fn = model_fn_builder(
      task_name=task_name,
      bert_config=bert_config,
      num_labels=len(label_list),
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate if not FLAGS.horovod else FLAGS.learning_rate * hvd.size(),
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_one_hot_embeddings=False,
      hvd=None if not FLAGS.horovod else hvd)

  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      config=run_config)

  if FLAGS.do_train:

    file_based_convert_examples_to_features(
        train_examples[start_index:end_index], label_list, FLAGS.max_seq_length, tokenizer, tmp_filenames[hvd_rank])

    tf.compat.v1.logging.info("***** Running training *****")
    tf.compat.v1.logging.info("  Num examples = %d", len(train_examples))
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.compat.v1.logging.info("  Num steps = %d", num_train_steps)
    train_input_fn = file_based_input_fn_builder(
        input_file=tmp_filenames,
        batch_size=FLAGS.train_batch_size,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True,
        hvd=None if not FLAGS.horovod else hvd)

    train_start_time = time.time()
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps, hooks=training_hooks)
    train_time_elapsed = time.time() - train_start_time
    train_time_wo_overhead = training_hooks[-1].total_time
    avg_sentences_per_second = num_train_steps * global_batch_size * 1.0 / train_time_elapsed
    ss_sentences_per_second = (training_hooks[-1].count - training_hooks[-1].skipped) * global_batch_size * 1.0 / train_time_wo_overhead

    if master_process:
        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info("Total Training Time = %0.2f for Sentences = %d", train_time_elapsed,
                        num_train_steps * global_batch_size)
        tf.compat.v1.logging.info("Total Training Time W/O Overhead = %0.2f for Sentences = %d", train_time_wo_overhead,
                        (training_hooks[-1].count - training_hooks[-1].skipped) * global_batch_size)
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
        tf.compat.v1.logging.info("-----------------------------")

  if FLAGS.do_eval and master_process:
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
    file_based_convert_examples_to_features(
        eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)

    tf.compat.v1.logging.info("***** Running evaluation *****")
    tf.compat.v1.logging.info("  Num examples = %d", len(eval_examples))
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

    eval_drop_remainder = False
    eval_input_fn = file_based_input_fn_builder(
        input_file=eval_file,
        batch_size=FLAGS.eval_batch_size,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=eval_drop_remainder)

    eval_hooks = [LogEvalRunHook(FLAGS.eval_batch_size)]
    eval_start_time = time.time()
    result = estimator.evaluate(input_fn=eval_input_fn, hooks=eval_hooks)

    eval_time_elapsed = time.time() - eval_start_time

    time_list = eval_hooks[-1].time_list
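    # Per-batch eval latencies from LogEvalRunHook; after sorting, the
    # "Latency Confidence Level X" metrics below are approximate percentiles,
    # taken as the max over the fastest X% of batches.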
    time_list.sort()
    # Removing outliers (init/warmup) in throughput computation.
    eval_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.8)])
    num_sentences = (int(len(time_list) * 0.8)) * FLAGS.eval_batch_size

    avg = np.mean(time_list)
    cf_50 = max(time_list[:int(len(time_list) * 0.50)])
    cf_90 = max(time_list[:int(len(time_list) * 0.90)])
    cf_95 = max(time_list[:int(len(time_list) * 0.95)])
    cf_99 = max(time_list[:int(len(time_list) * 0.99)])
    cf_100 = max(time_list[:int(len(time_list) * 1)])
    ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead

    tf.compat.v1.logging.info("-----------------------------")
    tf.compat.v1.logging.info("Total Inference Time = %0.2f for Sentences = %d", eval_time_elapsed,
                    eval_hooks[-1].count * FLAGS.eval_batch_size)
    tf.compat.v1.logging.info("Total Inference Time W/O Overhead = %0.2f for Sentences = %d", eval_time_wo_overhead,
                    num_sentences)
    tf.compat.v1.logging.info("Summary Inference Statistics on EVAL set")
    tf.compat.v1.logging.info("Batch size = %d", FLAGS.eval_batch_size)
    tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
    tf.compat.v1.logging.info("Precision = %s", "fp16" if FLAGS.amp else "fp32")
    tf.compat.v1.logging.info("Latency Confidence Level 50 (ms) = %0.2f", cf_50 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 90 (ms) = %0.2f", cf_90 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 95 (ms) = %0.2f", cf_95 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 99 (ms) = %0.2f", cf_99 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 100 (ms) = %0.2f", cf_100 * 1000)
    tf.compat.v1.logging.info("Latency Average (ms) = %0.2f", avg * 1000)
    tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
    dllogging.logger.log(step=(), data={"throughput_val": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
    tf.compat.v1.logging.info("-----------------------------")


    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
    with tf.io.gfile.GFile(output_eval_file, "w") as writer:
      tf.compat.v1.logging.info("***** Eval results *****")
      for key in sorted(result.keys()):
        dllogging.logger.log(step=(), data={key: float(result[key])}, verbosity=Verbosity.DEFAULT)
        tf.compat.v1.logging.info("  %s = %s", key, str(result[key]))
        writer.write("%s = %s\n" % (key, str(result[key])))

  if FLAGS.do_predict and master_process:
    predict_examples = processor.get_test_examples(FLAGS.data_dir)
    predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
    file_based_convert_examples_to_features(predict_examples, label_list,
                                            FLAGS.max_seq_length, tokenizer,
                                            predict_file)

    tf.compat.v1.logging.info("***** Running prediction*****")
    tf.compat.v1.logging.info("  Num examples = %d", len(predict_examples))
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_drop_remainder = False
    predict_input_fn = file_based_input_fn_builder(
        input_file=predict_file,
        batch_size=FLAGS.predict_batch_size,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=predict_drop_remainder)

    predict_hooks = [LogEvalRunHook(FLAGS.predict_batch_size)]
    predict_start_time = time.time()

    output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
    with tf.io.gfile.GFile(output_predict_file, "w") as writer:
        tf.compat.v1.logging.info("***** Predict results *****")
        for prediction in estimator.predict(input_fn=predict_input_fn, hooks=predict_hooks,
                                            yield_single_examples=False):
            output_line = "\t".join(
                str(class_probability) for class_probability in prediction) + "\n"
            writer.write(output_line)


    predict_time_elapsed = time.time() - predict_start_time

    time_list = predict_hooks[-1].time_list
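    # Same latency accounting as the eval pass: approximate percentiles over
    # the sorted per-batch times, with throughput computed on the fastest 80%.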
    time_list.sort()
    # Removing outliers (init/warmup) in throughput computation.
    predict_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.8)])
    num_sentences = (int(len(time_list) * 0.8)) * FLAGS.predict_batch_size

    avg = np.mean(time_list)
    cf_50 = max(time_list[:int(len(time_list) * 0.50)])
    cf_90 = max(time_list[:int(len(time_list) * 0.90)])
    cf_95 = max(time_list[:int(len(time_list) * 0.95)])
    cf_99 = max(time_list[:int(len(time_list) * 0.99)])
    cf_100 = max(time_list[:int(len(time_list) * 1)])
    ss_sentences_per_second = num_sentences * 1.0 / predict_time_wo_overhead

    tf.compat.v1.logging.info("-----------------------------")
    tf.compat.v1.logging.info("Total Inference Time = %0.2f for Sentences = %d", predict_time_elapsed,
                    predict_hooks[-1].count * FLAGS.predict_batch_size)
    tf.compat.v1.logging.info("Total Inference Time W/O Overhead = %0.2f for Sentences = %d", predict_time_wo_overhead,
                              num_sentences)
    tf.compat.v1.logging.info("Summary Inference Statistics on TEST SET")
    tf.compat.v1.logging.info("Batch size = %d", FLAGS.predict_batch_size)
    tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
    tf.compat.v1.logging.info("Precision = %s", "fp16" if FLAGS.amp else "fp32")
    tf.compat.v1.logging.info("Latency Confidence Level 50 (ms) = %0.2f", cf_50 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 90 (ms) = %0.2f", cf_90 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 95 (ms) = %0.2f", cf_95 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 99 (ms) = %0.2f", cf_99 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 100 (ms) = %0.2f", cf_100 * 1000)
    tf.compat.v1.logging.info("Latency Average (ms) = %0.2f", avg * 1000)
    tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
    dllogging.logger.log(step=(), data={"throughput_val": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
    tf.compat.v1.logging.info("-----------------------------")
def main(_):
  setup_xla_flags()

  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  dllogging = utils.dllogger_class.dllogger_class(FLAGS.dllog_path)

  if not FLAGS.do_train and not FLAGS.do_eval:
    raise ValueError("At least one of `do_train` or `do_eval` must be True.")

  if FLAGS.horovod:
    import horovod.tensorflow as hvd
    hvd.init()

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  tf.io.gfile.makedirs(FLAGS.output_dir)

  input_files = []
  for input_file_dir in FLAGS.input_files_dir.split(","):
    input_files.extend(tf.io.gfile.glob(os.path.join(input_file_dir, "*")))

  if FLAGS.horovod and len(input_files) < hvd.size():
      raise ValueError("Input Files must be sharded")
  if FLAGS.amp and FLAGS.manual_fp16:
      raise ValueError("AMP and manual mixed precision training cannot both be enabled")

  config = tf.compat.v1.ConfigProto()
  if FLAGS.horovod:
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    set_affinity(hvd.local_rank())
    if hvd.rank() == 0:
      tf.compat.v1.logging.info("***** Configuaration *****")
      for key in FLAGS.__flags.keys():
          tf.compat.v1.logging.info('  {}: {}'.format(key, getattr(FLAGS, key)))
      tf.compat.v1.logging.info("**************************")

  if FLAGS.use_xla: 
      config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
      config.graph_options.rewrite_options.memory_optimization = rewriter_config_pb2.RewriterConfig.NO_MEM_OPT
      if FLAGS.amp:
        tf.enable_resource_variables()

  run_config = tf.estimator.RunConfig(
      model_dir=FLAGS.output_dir,
      session_config=config,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps if not FLAGS.horovod or hvd.rank() == 0 else None,
      save_summary_steps=FLAGS.save_checkpoints_steps if not FLAGS.horovod or hvd.rank() == 0 else None,
      # This variable controls how often estimator reports examples/sec.
      # Default value is every 100 steps.
      # When --report_loss is True, we set to very large value to prevent
      # default info reporting from estimator.
      # Ideally we should set it to None, but that does not work.
      log_step_count_steps=10000 if FLAGS.report_loss else 100)

  model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate if not FLAGS.horovod else FLAGS.learning_rate*hvd.size(),
      num_train_steps=FLAGS.num_train_steps,
      num_warmup_steps=FLAGS.num_warmup_steps,
      use_one_hot_embeddings=False,
      hvd=None if not FLAGS.horovod else hvd)

  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      config=run_config)

  if FLAGS.do_train:

    training_hooks = []
    if FLAGS.horovod and hvd.size() > 1:
      training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
    if (not FLAGS.horovod or hvd.rank() == 0):
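      # Global batch size = per-GPU batch size x gradient-accumulation steps,
      # further scaled by hvd.size() under Horovod; the sentences/sec numbers
      # logged after training are relative to this value.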
      global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps if not FLAGS.horovod else FLAGS.train_batch_size * FLAGS.num_accumulation_steps * hvd.size()
      training_hooks.append(_LogSessionRunHook(global_batch_size, FLAGS.num_accumulation_steps, dllogging, FLAGS.display_loss_steps, FLAGS.save_checkpoints_steps, FLAGS.report_loss))

    tf.compat.v1.logging.info("***** Running training *****")
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    train_input_fn = input_fn_builder(
        input_files=input_files,
        batch_size=FLAGS.train_batch_size,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        is_training=True,
        hvd=None if not FLAGS.horovod else hvd)

    train_start_time = time.time()
    estimator.train(input_fn=train_input_fn, hooks=training_hooks, max_steps=FLAGS.num_train_steps)
    train_time_elapsed = time.time() - train_start_time

    if (not FLAGS.horovod or hvd.rank() == 0):
        train_time_wo_overhead = training_hooks[-1].total_time
        avg_sentences_per_second = FLAGS.num_train_steps * global_batch_size * 1.0 / train_time_elapsed
        ss_sentences_per_second = (FLAGS.num_train_steps - training_hooks[-1].skipped) * global_batch_size * 1.0 / train_time_wo_overhead

        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info("Total Training Time = %0.2f for Sentences = %d", train_time_elapsed,
                        FLAGS.num_train_steps * global_batch_size)
        tf.compat.v1.logging.info("Total Training Time W/O Overhead = %0.2f for Sentences = %d", train_time_wo_overhead,
                        (FLAGS.num_train_steps - training_hooks[-1].skipped) * global_batch_size)
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
        dllogging.logger.log(step=(), data={"throughput_train": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
        tf.compat.v1.logging.info("-----------------------------")

  if FLAGS.do_eval and (not FLAGS.horovod or hvd.rank() == 0):
    tf.compat.v1.logging.info("***** Running evaluation *****")
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

    eval_files = []
    for eval_file_dir in FLAGS.eval_files_dir.split(","):
        eval_files.extend(tf.io.gfile.glob(os.path.join(eval_file_dir, "*")))

    eval_input_fn = input_fn_builder(
        input_files=eval_files,
        batch_size=FLAGS.eval_batch_size,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        is_training=False,
        hvd=None if not FLAGS.horovod else hvd)

    eval_hooks = [LogEvalRunHook(FLAGS.eval_batch_size)]
    eval_start_time = time.time()
    result = estimator.evaluate(
        input_fn=eval_input_fn, steps=FLAGS.max_eval_steps, hooks=eval_hooks)

    eval_time_elapsed = time.time() - eval_start_time
    time_list = eval_hooks[-1].time_list
    time_list.sort()
    # Removing outliers (init/warmup) in throughput computation.
    eval_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.99)])
    num_sentences = (int(len(time_list) * 0.99)) * FLAGS.eval_batch_size

    ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead

    tf.compat.v1.logging.info("-----------------------------")
    tf.compat.v1.logging.info("Total Inference Time = %0.2f for Sentences = %d", eval_time_elapsed,
                    eval_hooks[-1].count * FLAGS.eval_batch_size)
    tf.compat.v1.logging.info("Total Inference Time W/O Overhead = %0.2f for Sentences = %d", eval_time_wo_overhead,
                    num_sentences)
    tf.compat.v1.logging.info("Summary Inference Statistics on EVAL set")
    tf.compat.v1.logging.info("Batch size = %d", FLAGS.eval_batch_size)
    tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
    tf.compat.v1.logging.info("Precision = %s", "fp16" if FLAGS.amp else "fp32")
    tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
    dllogging.logger.log(step=(), data={"throughput_val": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
    tf.compat.v1.logging.info("-----------------------------")

    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
    with tf.io.gfile.GFile(output_eval_file, "w") as writer:
      tf.compat.v1.logging.info("***** Eval results *****")
      for key in sorted(result.keys()):
        tf.compat.v1.logging.info("  %s = %s", key, str(result[key]))
        writer.write("%s = %s\n" % (key, str(result[key])))