Example #1
def main(_):

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_export:
        raise ValueError(
            "At least one of `do_train`, `do_eval`, or `do_export` must be True.")

    bert_config = flags.as_dictionary()

    if FLAGS.max_encoder_length > bert_config["max_position_embeddings"]:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_encoder_length, bert_config["max_position_embeddings"]))

    tf.io.gfile.makedirs(FLAGS.output_dir)
    if FLAGS.do_train:
        flags.save(os.path.join(FLAGS.output_dir, "pretrain.config"))

    model_fn = model_fn_builder(bert_config)
    estimator = utils.get_estimator(bert_config, model_fn)
    tmp_data_dir = os.path.join(FLAGS.output_dir, "tfds")

    if FLAGS.do_train:
        logging.info("***** Running training *****")
        logging.info("  Batch size = %d", estimator.train_batch_size)
        logging.info("  Num steps = %d", FLAGS.num_train_steps)
        train_input_fn = input_fn_builder(
            data_dir=FLAGS.data_dir,
            vocab_model_file=FLAGS.vocab_model_file,
            masked_lm_prob=FLAGS.masked_lm_prob,
            max_encoder_length=FLAGS.max_encoder_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            preprocessed_data=FLAGS.preprocessed_data,
            substitute_newline=FLAGS.substitute_newline,
            tmp_dir=tmp_data_dir,
            is_training=True)
        estimator.train(input_fn=train_input_fn,
                        max_steps=FLAGS.num_train_steps)

    if FLAGS.do_eval:
        logging.info("***** Running evaluation *****")
        logging.info("  Batch size = %d", estimator.eval_batch_size)

        eval_input_fn = input_fn_builder(
            data_dir=FLAGS.data_dir,
            vocab_model_file=FLAGS.vocab_model_file,
            masked_lm_prob=FLAGS.masked_lm_prob,
            max_encoder_length=FLAGS.max_encoder_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            preprocessed_data=FLAGS.preprocessed_data,
            substitute_newline=FLAGS.substitute_newline,
            tmp_dir=tmp_data_dir,
            is_training=False)

        # Continuously evaluate the latest checkpoint as training progresses.
        last_evaluated = None
        while True:
            latest = tf.train.latest_checkpoint(FLAGS.output_dir)
            if latest == last_evaluated:
                if not latest:
                    logging.info("No checkpoints found yet.")
                else:
                    logging.info("Latest checkpoint %s already evaluated.",
                                 latest)
                time.sleep(300)
                continue
            else:
                logging.info("Evaluating check point %s", latest)
                last_evaluated = latest

                current_step = int(os.path.basename(latest).split("-")[1])
                output_eval_file = os.path.join(
                    FLAGS.output_dir,
                    "eval_results_{}.txt".format(current_step))
                result = estimator.evaluate(input_fn=eval_input_fn,
                                            steps=FLAGS.max_eval_steps,
                                            checkpoint_path=latest)

                with tf.io.gfile.GFile(output_eval_file, "w") as writer:
                    logging.info("***** Eval results *****")
                    for key in sorted(result.keys()):
                        logging.info("  %s = %s", key, str(result[key]))
                        writer.write("%s = %s\n" % (key, str(result[key])))

    if FLAGS.do_export:
        logging.info("***** Running export *****")

        serving_input_fn = serving_input_fn_builder(
            batch_size=FLAGS.eval_batch_size,
            vocab_model_file=FLAGS.vocab_model_file,
            max_encoder_length=FLAGS.max_encoder_length,
            substitute_newline=FLAGS.substitute_newline)

        estimator.export_saved_model(os.path.join(FLAGS.output_dir, "export"),
                                     serving_input_fn)
Example #2
def main(_):

  if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_export:
    raise ValueError(
        "At least one of `do_train`, `do_eval`, or `do_export` must be True.")

  bert_config = flags.as_dictionary()

  if FLAGS.max_encoder_length > bert_config["max_position_embeddings"]:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_encoder_length, bert_config["max_position_embeddings"]))

  tf.io.gfile.makedirs(FLAGS.output_dir)
  if FLAGS.do_train:
    flags.save(os.path.join(FLAGS.output_dir, "classifier.config"))

  model_fn = model_fn_builder(bert_config)
  estimator = utils.get_estimator(bert_config, model_fn)

  if FLAGS.do_train:
    logging.info("***** Running training *****")
    logging.info("  Batch size = %d", estimator.train_batch_size)
    logging.info("  Num steps = %d", FLAGS.num_train_steps)
    train_input_fn = input_fn_builder(
        data_dir=FLAGS.data_dir,
        vocab_model_file=FLAGS.vocab_model_file,
        max_encoder_length=FLAGS.max_encoder_length,
        substitute_newline=FLAGS.substitute_newline,
        tmp_dir=os.path.join(FLAGS.output_dir, "tfds"),
        is_training=True)
    estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)

  if FLAGS.do_eval:
    logging.info("***** Running evaluation *****")
    logging.info("  Batch size = %d", estimator.eval_batch_size)

    eval_input_fn = input_fn_builder(
        data_dir=FLAGS.data_dir,
        vocab_model_file=FLAGS.vocab_model_file,
        max_encoder_length=FLAGS.max_encoder_length,
        substitute_newline=FLAGS.substitute_newline,
        tmp_dir=os.path.join(FLAGS.output_dir, "tfds"),
        is_training=False)

    if FLAGS.use_tpu:
      # TPUEstimator.evaluate requires a fixed number of steps on TPU, so
      # compute the eval dataset's cardinality up front (one full pass).
      with tf.compat.v1.Session() as sess:
        eval_steps = eval_input_fn({
            "batch_size": estimator.eval_batch_size
        }).cardinality().eval(session=sess)
    else:
      eval_steps = None

    # Run evaluation for each new checkpoint.
    all_ckpts = [
        v.split(".meta")[0] for v in tf.io.gfile.glob(
            os.path.join(FLAGS.output_dir, "model.ckpt*.meta"))
    ]
    # natsorted (from the natsort package) orders checkpoints numerically,
    # e.g. model.ckpt-2 before model.ckpt-10.
    all_ckpts = natsorted(all_ckpts)
    for ckpt in all_ckpts:
      current_step = int(os.path.basename(ckpt).split("-")[1])
      output_eval_file = os.path.join(
          FLAGS.output_dir, "eval_results_{}.txt".format(current_step))
      result = estimator.evaluate(input_fn=eval_input_fn,
                                  checkpoint_path=ckpt,
                                  steps=eval_steps)

      with tf.io.gfile.GFile(output_eval_file, "w") as writer:
        logging.info("***** Eval results *****")
        for key in sorted(result.keys()):
          logging.info("  %s = %s", key, str(result[key]))
          writer.write("%s = %s\n" % (key, str(result[key])))

  if FLAGS.do_export:
    logging.info("***** Running export *****")

    serving_input_fn = serving_input_fn_builder(
        batch_size=FLAGS.eval_batch_size,
        vocab_model_file=FLAGS.vocab_model_file,
        max_encoder_length=FLAGS.max_encoder_length,
        substitute_newline=FLAGS.substitute_newline)

    estimator.export_saved_model(
        os.path.join(FLAGS.output_dir, "export"), serving_input_fn)
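For reference, export_saved_model writes the SavedModel into a timestamped subdirectory under FLAGS.output_dir/export, and it can be loaded back with the standard TF2 loading API. A minimal sketch, assuming the same FLAGS as above and the default serving signature name; the directory handling is illustrative:

import os

import tensorflow as tf

export_base = os.path.join(FLAGS.output_dir, "export")
# One timestamped subdirectory is written per export; pick the newest.
latest = sorted(tf.io.gfile.listdir(export_base))[-1].rstrip("/")

imported = tf.saved_model.load(os.path.join(export_base, latest))
serve_fn = imported.signatures["serving_default"]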