def main(_):
  if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_export:
    raise ValueError(
        "At least one of `do_train`, `do_eval`, `do_export` must be True.")

  bert_config = flags.as_dictionary()

  if FLAGS.max_encoder_length > bert_config["max_position_embeddings"]:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_encoder_length, bert_config["max_position_embeddings"]))

  tf.io.gfile.makedirs(FLAGS.output_dir)
  if FLAGS.do_train:
    flags.save(os.path.join(FLAGS.output_dir, "pretrain.config"))

  model_fn = model_fn_builder(bert_config)
  estimator = utils.get_estimator(bert_config, model_fn)
  tmp_data_dir = os.path.join(FLAGS.output_dir, "tfds")

  if FLAGS.do_train:
    logging.info("***** Running training *****")
    logging.info("  Batch size = %d", estimator.train_batch_size)
    logging.info("  Num steps = %d", FLAGS.num_train_steps)
    train_input_fn = input_fn_builder(
        data_dir=FLAGS.data_dir,
        vocab_model_file=FLAGS.vocab_model_file,
        masked_lm_prob=FLAGS.masked_lm_prob,
        max_encoder_length=FLAGS.max_encoder_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        preprocessed_data=FLAGS.preprocessed_data,
        substitute_newline=FLAGS.substitute_newline,
        tmp_dir=tmp_data_dir,
        is_training=True)
    estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)

  if FLAGS.do_eval:
    logging.info("***** Running evaluation *****")
    logging.info("  Batch size = %d", estimator.eval_batch_size)
    eval_input_fn = input_fn_builder(
        data_dir=FLAGS.data_dir,
        vocab_model_file=FLAGS.vocab_model_file,
        masked_lm_prob=FLAGS.masked_lm_prob,
        max_encoder_length=FLAGS.max_encoder_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        preprocessed_data=FLAGS.preprocessed_data,
        substitute_newline=FLAGS.substitute_newline,
        tmp_dir=tmp_data_dir,
        is_training=False)

    # Run continuous evaluation on the latest checkpoint as training progresses.
    last_evaluated = None
    while True:
      latest = tf.train.latest_checkpoint(FLAGS.output_dir)
      if latest == last_evaluated:
        if not latest:
          logging.info("No checkpoints found yet.")
        else:
          logging.info("Latest checkpoint %s already evaluated.", latest)
        time.sleep(300)
        continue
      else:
        logging.info("Evaluating checkpoint %s", latest)
        last_evaluated = latest
        current_step = int(os.path.basename(latest).split("-")[1])
        output_eval_file = os.path.join(
            FLAGS.output_dir, "eval_results_{}.txt".format(current_step))
        result = estimator.evaluate(input_fn=eval_input_fn,
                                    steps=FLAGS.max_eval_steps,
                                    checkpoint_path=latest)

        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
          logging.info("***** Eval results *****")
          for key in sorted(result.keys()):
            logging.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

  if FLAGS.do_export:
    logging.info("***** Running export *****")
    serving_input_fn = serving_input_fn_builder(
        batch_size=FLAGS.eval_batch_size,
        vocab_model_file=FLAGS.vocab_model_file,
        max_encoder_length=FLAGS.max_encoder_length,
        substitute_newline=FLAGS.substitute_newline)
    estimator.export_saved_model(
        os.path.join(FLAGS.output_dir, "export"), serving_input_fn)
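# Alternative sketch (not part of the original scripts): the hand-rolled
# checkpoint-polling loop in the pretraining `main` above could also be written
# with tf.train.checkpoints_iterator, which blocks until a new checkpoint file
# appears in the directory. The `estimator` and `eval_input_fn` arguments are
# assumed to be built exactly as in `main`; `tf` and `logging` are assumed to
# be imported at module level as in the surrounding code.
def continuous_eval_sketch(estimator, eval_input_fn, output_dir, max_eval_steps):
  """Evaluates every new checkpoint written to `output_dir` as training runs."""
  for ckpt in tf.train.checkpoints_iterator(output_dir, min_interval_secs=300):
    result = estimator.evaluate(input_fn=eval_input_fn,
                                steps=max_eval_steps,
                                checkpoint_path=ckpt)
    logging.info("Eval results for %s: %s", ckpt, result)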
def main(_):
  if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_export:
    raise ValueError(
        "At least one of `do_train`, `do_eval`, `do_export` must be True.")

  bert_config = flags.as_dictionary()

  if FLAGS.max_encoder_length > bert_config["max_position_embeddings"]:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_encoder_length, bert_config["max_position_embeddings"]))

  tf.io.gfile.makedirs(FLAGS.output_dir)
  if FLAGS.do_train:
    flags.save(os.path.join(FLAGS.output_dir, "classifier.config"))

  model_fn = model_fn_builder(bert_config)
  estimator = utils.get_estimator(bert_config, model_fn)

  if FLAGS.do_train:
    logging.info("***** Running training *****")
    logging.info("  Batch size = %d", estimator.train_batch_size)
    logging.info("  Num steps = %d", FLAGS.num_train_steps)
    train_input_fn = input_fn_builder(
        data_dir=FLAGS.data_dir,
        vocab_model_file=FLAGS.vocab_model_file,
        max_encoder_length=FLAGS.max_encoder_length,
        substitute_newline=FLAGS.substitute_newline,
        tmp_dir=os.path.join(FLAGS.output_dir, "tfds"),
        is_training=True)
    estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)

  if FLAGS.do_eval:
    logging.info("***** Running evaluation *****")
    logging.info("  Batch size = %d", estimator.eval_batch_size)
    eval_input_fn = input_fn_builder(
        data_dir=FLAGS.data_dir,
        vocab_model_file=FLAGS.vocab_model_file,
        max_encoder_length=FLAGS.max_encoder_length,
        substitute_newline=FLAGS.substitute_newline,
        tmp_dir=os.path.join(FLAGS.output_dir, "tfds"),
        is_training=False)

    if FLAGS.use_tpu:
      with tf.compat.v1.Session() as sess:
        eval_steps = eval_input_fn({
            "batch_size": estimator.eval_batch_size
        }).cardinality().eval(session=sess)
    else:
      eval_steps = None

    # Run evaluation for every checkpoint found in the output directory.
    all_ckpts = [
        v.split(".meta")[0] for v in tf.io.gfile.glob(
            os.path.join(FLAGS.output_dir, "model.ckpt*.meta"))
    ]
    all_ckpts = natsorted(all_ckpts)
    for ckpt in all_ckpts:
      current_step = int(os.path.basename(ckpt).split("-")[1])
      output_eval_file = os.path.join(
          FLAGS.output_dir, "eval_results_{}.txt".format(current_step))
      result = estimator.evaluate(input_fn=eval_input_fn,
                                  checkpoint_path=ckpt,
                                  steps=eval_steps)

      with tf.io.gfile.GFile(output_eval_file, "w") as writer:
        logging.info("***** Eval results *****")
        for key in sorted(result.keys()):
          logging.info("  %s = %s", key, str(result[key]))
          writer.write("%s = %s\n" % (key, str(result[key])))

  if FLAGS.do_export:
    logging.info("***** Running export *****")
    serving_input_fn = serving_input_fn_builder(
        batch_size=FLAGS.eval_batch_size,
        vocab_model_file=FLAGS.vocab_model_file,
        max_encoder_length=FLAGS.max_encoder_length,
        substitute_newline=FLAGS.substitute_newline)
    estimator.export_saved_model(
        os.path.join(FLAGS.output_dir, "export"), serving_input_fn)
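# Minimal sketch (not part of the original scripts) of loading the SavedModel
# written by `estimator.export_saved_model` in either `main` above.
# `export_saved_model` writes a timestamped subdirectory under the export base;
# the exact serving signature depends on `serving_input_fn_builder`, so no
# example inputs are shown here. `tf`, `os`, and `logging` are assumed to be
# imported at module level as in the surrounding code.
def load_exported_model_sketch(output_dir):
  """Loads the most recently exported SavedModel under `output_dir`/export."""
  export_base = os.path.join(output_dir, "export")
  latest = sorted(tf.io.gfile.listdir(export_base))[-1]
  model = tf.saved_model.load(os.path.join(export_base, latest))
  logging.info("Serving signatures: %s", list(model.signatures.keys()))
  return model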