示例#1
0
    def test_finetune_and_predict(self):
        """Runs one fine-tuning step, then scores with the exported checkpoint.

        End-to-end smoke test: fine-tunes for a single step with a near-zero
        learning rate (so the weights barely move), checks that a valid BLEURT
        checkpoint is exported, and verifies that prediction reproduces the
        reference scores.
        """
        checkpoint = get_test_checkpoint()
        train_file, dev_file = get_test_data()

        with tempfile.TemporaryDirectory() as model_dir:
            # Sets new flags.
            FLAGS.init_checkpoint = os.path.join(checkpoint, "variables",
                                                 "variables")
            FLAGS.bert_config_file = os.path.join(checkpoint,
                                                  "bert_config.json")
            FLAGS.vocab_file = os.path.join(checkpoint, "vocab.txt")
            FLAGS.do_lower_case = True
            FLAGS.dynamic_seq_length = True
            FLAGS.max_seq_length = 512
            FLAGS.model_dir = model_dir
            FLAGS.num_train_steps = 1
            # Tiny learning rate keeps the weights effectively unchanged, so
            # the predicted scores stay close to the reference scores below.
            FLAGS.learning_rate = 0.00000000001
            FLAGS.serialized_train_set = os.path.join(model_dir,
                                                      "train.tfrecord")
            FLAGS.serialized_dev_set = os.path.join(model_dir, "dev.tfrecord")

            # Runs 1 training step.
            export = finetune.run_finetuning_pipeline(train_file, dev_file)

            # Checks if the pipeline produced a valid BLEURT checkpoint.
            self.assertTrue(tf.io.gfile.exists(export))
            config = checkpoint_lib.read_bleurt_config(export)
            # Bug fix: `assertTrue(type(config), dict)` passed `dict` as the
            # failure *message* and `type(config)` is always truthy, so the
            # assertion could never fail. Assert the type explicitly instead.
            self.assertIsInstance(config, dict)

            # Runs a prediction.
            scorer = score.LengthBatchingBleurtScorer(export)
            scores = scorer.score(references=references, candidates=candidates)
            self.assertLen(scores, 2)
            self.assertAllClose(scores, ref_scores)
示例#2
0
    def test_finetune_from_bert(self):
        """Runs one fine-tuning step starting from a raw BERT checkpoint.

        Smoke test: fine-tunes for a single step and checks that the pipeline
        exports a valid BLEURT checkpoint (readable config), without making
        any prediction.
        """
        checkpoint = get_test_checkpoint()
        train_file, dev_file = get_test_data()

        with tempfile.TemporaryDirectory() as model_dir:
            # Sets new flags.
            FLAGS.model_dir = model_dir
            FLAGS.init_checkpoint = os.path.join(checkpoint, "variables",
                                                 "variables")
            FLAGS.bert_config_file = os.path.join(checkpoint,
                                                  "bert_config.json")
            FLAGS.vocab_file = os.path.join(checkpoint, "vocab.txt")
            FLAGS.do_lower_case = True
            FLAGS.max_seq_length = 512
            FLAGS.num_train_steps = 1
            FLAGS.serialized_train_set = os.path.join(model_dir,
                                                      "train.tfrecord")
            FLAGS.serialized_dev_set = os.path.join(model_dir, "dev.tfrecord")

            # Runs 1 training step.
            export = finetune.run_finetuning_pipeline(train_file, dev_file)

            # Checks if the pipeline produced a valid BLEURT checkpoint.
            self.assertTrue(tf.io.gfile.exists(export))
            config = checkpoint_lib.read_bleurt_config(export)
            # Bug fix: `assertTrue(type(config), dict)` passed `dict` as the
            # failure *message* and `type(config)` is always truthy, so the
            # assertion could never fail. Assert the type explicitly instead.
            self.assertIsInstance(config, dict)
示例#3
0
def run_benchmark():
    """Runs the WMT Metrics Benchmark end-to-end.

    Pipeline: (1) builds train/dev/test rating files from the WMT shared-task
    data under ``FLAGS.data_dir`` (deleting any stale copies first),
    (2) fine-tunes BLEURT on the train/dev split, and (3) evaluates the
    exported checkpoint on the test set, writing results to
    ``FLAGS.results_json`` (or ``<data_dir>/results.json`` by default).
    """
    logging.info("Running WMT Metrics Shared Task Benchmark")

    # Prepares the datasets.
    if not tf.io.gfile.exists(FLAGS.data_dir):
        # Lazy %-style args are the logging idiom (formatting is skipped when
        # the message is filtered out).
        logging.info("Creating directory %s", FLAGS.data_dir)
        tf.io.gfile.mkdir(FLAGS.data_dir)

    train_ratings_file = os.path.join(FLAGS.data_dir, "train_ratings.json")
    dev_ratings_file = os.path.join(FLAGS.data_dir, "dev_ratings.json")
    test_ratings_file = os.path.join(FLAGS.data_dir, "test_ratings.json")

    # Removes stale outputs from a previous run so each stage starts clean.
    for f in [train_ratings_file, dev_ratings_file, test_ratings_file]:
        if tf.io.gfile.exists(f):
            logging.info("Deleting existing file: %s", f)
            tf.io.gfile.remove(f)
            # Consistency fix: was a bare `print("Done.")` amid logging calls.
            logging.info("Done.")

    logging.info("\n*** Creating training data. ***")
    db_builder.create_wmt_dataset(train_ratings_file, FLAGS.train_years,
                                  FLAGS.target_language)
    db_builder.postprocess(train_ratings_file)
    # Splits the training ratings in place into train/dev partitions.
    db_builder.shuffle_split(train_ratings_file,
                             train_ratings_file,
                             dev_ratings_file,
                             dev_ratio=FLAGS.dev_ratio,
                             prevent_leaks=FLAGS.prevent_leaks)

    logging.info("\n*** Creating test data. ***")
    db_builder.create_wmt_dataset(test_ratings_file, FLAGS.test_years,
                                  FLAGS.target_language)
    db_builder.postprocess(test_ratings_file,
                           average_duplicates=FLAGS.average_duplicates_on_test)

    # Trains BLEURT.
    logging.info("\n*** Training BLEURT. ***")
    export_dir = finetune.run_finetuning_pipeline(train_ratings_file,
                                                  dev_ratings_file)

    # Runs the eval.
    logging.info("\n*** Testing BLEURT. ***")
    if not FLAGS.results_json:
        results_json = os.path.join(FLAGS.data_dir, "results.json")
    else:
        results_json = FLAGS.results_json
    results = evaluator.eval_checkpoint(export_dir, test_ratings_file,
                                        results_json)
    logging.info(results)
    logging.info("\n*** Done. ***")