def test_finetune_and_predict(self):
  checkpoint = get_test_checkpoint()
  train_file, dev_file = get_test_data()

  with tempfile.TemporaryDirectory() as model_dir:
    # Sets new flags.
    FLAGS.init_checkpoint = os.path.join(checkpoint, "variables", "variables")
    FLAGS.bert_config_file = os.path.join(checkpoint, "bert_config.json")
    FLAGS.vocab_file = os.path.join(checkpoint, "vocab.txt")
    FLAGS.do_lower_case = True
    FLAGS.dynamic_seq_length = True
    FLAGS.max_seq_length = 512
    FLAGS.model_dir = model_dir
    FLAGS.num_train_steps = 1
    # Tiny learning rate so the single training step barely moves the weights.
    FLAGS.learning_rate = 1e-11
    FLAGS.serialized_train_set = os.path.join(model_dir, "train.tfrecord")
    FLAGS.serialized_dev_set = os.path.join(model_dir, "dev.tfrecord")

    # Runs 1 training step.
    export = finetune.run_finetuning_pipeline(train_file, dev_file)

    # Checks that the pipeline produced a valid BLEURT checkpoint.
    self.assertTrue(tf.io.gfile.exists(export))
    config = checkpoint_lib.read_bleurt_config(export)
    self.assertIsInstance(config, dict)

    # Runs a prediction.
    scorer = score.LengthBatchingBleurtScorer(export)
    scores = scorer.score(references=references, candidates=candidates)
    self.assertLen(scores, 2)
    self.assertAllClose(scores, ref_scores)
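# The test above references module-level fixtures `references`, `candidates`,
# and `ref_scores` that are defined elsewhere in this file (typically near the
# top). A minimal sketch of what such fixtures could look like is shown below;
# the sentence pairs and the expected scores are hypothetical placeholders, not
# the values shipped with the repository.
references = [
    "An apple a day keeps the doctor away.",
    "The cat sat on the mat.",
]
candidates = [
    "An apple every day keeps doctors away.",
    "A dog slept on the rug.",
]
# Placeholder expected scores; regenerate them from the exported checkpoint if
# the sentence pairs or the test checkpoint change.
ref_scores = [0.9, 0.3]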
def test_finetune_from_bert(self):
  checkpoint = get_test_checkpoint()
  train_file, dev_file = get_test_data()

  with tempfile.TemporaryDirectory() as model_dir:
    # Sets new flags.
    FLAGS.model_dir = model_dir
    FLAGS.init_checkpoint = os.path.join(checkpoint, "variables", "variables")
    FLAGS.bert_config_file = os.path.join(checkpoint, "bert_config.json")
    FLAGS.vocab_file = os.path.join(checkpoint, "vocab.txt")
    FLAGS.do_lower_case = True
    FLAGS.max_seq_length = 512
    FLAGS.num_train_steps = 1
    FLAGS.serialized_train_set = os.path.join(model_dir, "train.tfrecord")
    FLAGS.serialized_dev_set = os.path.join(model_dir, "dev.tfrecord")

    # Runs 1 training step.
    export = finetune.run_finetuning_pipeline(train_file, dev_file)

    # Checks that the pipeline produced a valid BLEURT checkpoint.
    self.assertTrue(tf.io.gfile.exists(export))
    config = checkpoint_lib.read_bleurt_config(export)
    self.assertIsInstance(config, dict)
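# Both tests call `get_test_checkpoint()` and `get_test_data()`, which are
# defined elsewhere in this file. A minimal sketch is given below, assuming the
# fixtures live in "test_checkpoint" and "test_data" directories next to this
# file; the directory layout and the ratings file names are assumptions, and
# the actual helpers in the repository may resolve different paths.
def get_test_checkpoint():
  pkg = os.path.dirname(os.path.abspath(__file__))
  ckpt = os.path.join(pkg, "test_checkpoint")
  assert tf.io.gfile.exists(ckpt)
  return ckpt


def get_test_data():
  pkg = os.path.dirname(os.path.abspath(__file__))
  train_file = os.path.join(pkg, "test_data", "ratings_train.jsonl")
  dev_file = os.path.join(pkg, "test_data", "ratings_dev.jsonl")
  assert tf.io.gfile.exists(train_file)
  assert tf.io.gfile.exists(dev_file)
  return train_file, dev_file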
def run_benchmark():
  """Runs the WMT Metrics Benchmark end-to-end."""
  logging.info("Running WMT Metrics Shared Task Benchmark")

  # Prepares the datasets.
  if not tf.io.gfile.exists(FLAGS.data_dir):
    logging.info("Creating directory {}".format(FLAGS.data_dir))
    tf.io.gfile.mkdir(FLAGS.data_dir)

  train_ratings_file = os.path.join(FLAGS.data_dir, "train_ratings.json")
  dev_ratings_file = os.path.join(FLAGS.data_dir, "dev_ratings.json")
  test_ratings_file = os.path.join(FLAGS.data_dir, "test_ratings.json")
  for f in [train_ratings_file, dev_ratings_file, test_ratings_file]:
    if tf.io.gfile.exists(f):
      logging.info("Deleting existing file: {}".format(f))
      tf.io.gfile.remove(f)
      print("Done.")

  logging.info("\n*** Creating training data. ***")
  db_builder.create_wmt_dataset(train_ratings_file, FLAGS.train_years,
                                FLAGS.target_language)
  db_builder.postprocess(train_ratings_file)
  db_builder.shuffle_split(
      train_ratings_file,
      train_ratings_file,
      dev_ratings_file,
      dev_ratio=FLAGS.dev_ratio,
      prevent_leaks=FLAGS.prevent_leaks)

  logging.info("\n*** Creating test data. ***")
  db_builder.create_wmt_dataset(test_ratings_file, FLAGS.test_years,
                                FLAGS.target_language)
  db_builder.postprocess(
      test_ratings_file, average_duplicates=FLAGS.average_duplicates_on_test)

  # Trains BLEURT.
  logging.info("\n*** Training BLEURT. ***")
  export_dir = finetune.run_finetuning_pipeline(train_ratings_file,
                                                dev_ratings_file)

  # Runs the eval.
  logging.info("\n*** Testing BLEURT. ***")
  if not FLAGS.results_json:
    results_json = os.path.join(FLAGS.data_dir, "results.json")
  else:
    results_json = FLAGS.results_json
  results = evaluator.eval_checkpoint(export_dir, test_ratings_file,
                                      results_json)
  logging.info(results)
  logging.info("\n*** Done. ***")
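# A minimal entry-point sketch for running the benchmark as a script. The
# `main` wrapper and the use of absl's `app.run` are assumptions about how the
# module is wired up, not taken from the repository.
from absl import app


def main(_):
  run_benchmark()


if __name__ == "__main__":
  app.run(main)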