Exemple #1
0
    def __init__(self, checkpoint=None, predict_fn=None):
        """Sets up the tokenizer and the prediction function of the scorer.

        Args:
          checkpoint: BLEURT checkpoint. When falsy, the value returned by
            `_get_default_checkpoint()` is used (logged as BLEURT-tiny).
          predict_fn: (optional) prediction function. When truthy it is stored
            as-is and nothing is loaded from the checkpoint besides its
            config. Mostly used for testing.
        """
        if not checkpoint:
            logging.info("No checkpoint specified, defaulting to BLEURT-tiny.")
            checkpoint = _get_default_checkpoint()

        logging.info("Reading checkpoint {}.".format(checkpoint))
        bleurt_config = checkpoint_lib.read_bleurt_config(checkpoint)
        seq_len = bleurt_config["max_seq_length"]
        vocab_path = bleurt_config["vocab_file"]
        lower_case = bleurt_config["do_lower_case"]

        logging.info("Creating BLEURT scorer.")
        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=vocab_path, do_lower_case=lower_case)
        self.max_seq_length = seq_len

        if predict_fn:
            # Caller-supplied predictor: no model is loaded, and `chkpt_dir`
            # is not set on this path.
            self.predict_fn = predict_fn
        else:
            logging.info("Loading model...")
            self.chkpt_dir = checkpoint
            self.predict_fn = _make_predict_fn_from_checkpoint(checkpoint)
        logging.info("BLEURT initialized.")
Exemple #2
0
    def test_finetune_and_predict(self):
        """Runs one fine-tuning step, then scores with the exported checkpoint.

        Verifies that the pipeline's export exists and has a readable BLEURT
        config, then that a scorer built from the export returns two scores
        close to `ref_scores`. `references`, `candidates` and `ref_scores` are
        not defined in this method; they are expected to come from the
        enclosing module.
        """
        checkpoint = get_test_checkpoint()
        train_file, dev_file = get_test_data()

        with tempfile.TemporaryDirectory() as model_dir:
            # Sets new flags. All outputs go under the temporary directory.
            FLAGS.init_checkpoint = os.path.join(checkpoint, "variables",
                                                 "variables")
            FLAGS.bert_config_file = os.path.join(checkpoint,
                                                  "bert_config.json")
            FLAGS.vocab_file = os.path.join(checkpoint, "vocab.txt")
            FLAGS.do_lower_case = True
            FLAGS.dynamic_seq_length = True
            FLAGS.max_seq_length = 512
            FLAGS.model_dir = model_dir
            FLAGS.num_train_steps = 1
            # NOTE(review): the near-zero learning rate presumably keeps the
            # weights unchanged so the scores still match `ref_scores` --
            # confirm against the test fixtures.
            FLAGS.learning_rate = 0.00000000001
            FLAGS.serialized_train_set = os.path.join(model_dir,
                                                      "train.tfrecord")
            FLAGS.serialized_dev_set = os.path.join(model_dir, "dev.tfrecord")

            # Runs 1 training step.
            export = finetune.run_finetuning_pipeline(train_file, dev_file)

            # Checks if the pipeline produced a valid BLEURT checkpoint.
            self.assertTrue(tf.io.gfile.exists(export))
            config = checkpoint_lib.read_bleurt_config(export)
            # assertTrue(type(config), dict) always passed: `dict` was taken
            # as the failure message. Check the type for real.
            self.assertIsInstance(config, dict)

            # Runs a prediction.
            scorer = score.LengthBatchingBleurtScorer(export)
            scores = scorer.score(references=references, candidates=candidates)
            self.assertLen(scores, 2)
            self.assertAllClose(scores, ref_scores)
    def test_finetune_from_bert(self):
        """Runs one fine-tuning step and validates the exported checkpoint.

        Verifies that the path returned by the fine-tuning pipeline exists and
        that its BLEURT config can be read back as a dict.
        """
        checkpoint = get_test_checkpoint()
        train_file, dev_file = get_test_data()

        with tempfile.TemporaryDirectory() as model_dir:
            # Sets new flags. All outputs go under the temporary directory.
            FLAGS.model_dir = model_dir
            FLAGS.init_checkpoint = os.path.join(checkpoint, "variables",
                                                 "variables")
            FLAGS.bert_config_file = os.path.join(checkpoint,
                                                  "bert_config.json")
            FLAGS.vocab_file = os.path.join(checkpoint, "vocab.txt")
            FLAGS.do_lower_case = True
            FLAGS.max_seq_length = 512
            FLAGS.num_train_steps = 1
            FLAGS.serialized_train_set = os.path.join(model_dir,
                                                      "train.tfrecord")
            FLAGS.serialized_dev_set = os.path.join(model_dir, "dev.tfrecord")

            # Runs 1 training step.
            export = finetune.run_finetuning_pipeline(train_file, dev_file)

            # Checks if the pipeline produced a valid BLEURT checkpoint.
            self.assertTrue(tf.io.gfile.exists(export))
            config = checkpoint_lib.read_bleurt_config(export)
            # assertTrue(type(config), dict) always passed: `dict` was taken
            # as the failure message. Check the type for real.
            self.assertIsInstance(config, dict)
Exemple #4
0
    def __init__(self, checkpoint=None, predict_fn=None):
        """Sets up the config, tokenizer and predictor of the scorer.

        Args:
          checkpoint: BLEURT checkpoint. When falsy, the value returned by
            `_get_default_checkpoint()` is used (logged as BLEURT-tiny).
          predict_fn: (optional) prediction function, forwarded to
            `_create_predictor` together with the checkpoint. Mostly used
            for testing.
        """
        if not checkpoint:
            logging.info("No checkpoint specified, defaulting to BLEURT-tiny.")
            checkpoint = _get_default_checkpoint()

        logging.info("Reading checkpoint {}.".format(checkpoint))
        bleurt_config = checkpoint_lib.read_bleurt_config(checkpoint)
        self.config = bleurt_config
        seq_len = bleurt_config["max_seq_length"]
        vocab_path = bleurt_config["vocab_file"]
        lower_case = bleurt_config["do_lower_case"]
        sentencepiece_model = bleurt_config["sp_model"]

        logging.info("Creating BLEURT scorer.")
        self.tokenizer = tokenizers.create_tokenizer(
            vocab_file=vocab_path,
            do_lower_case=lower_case,
            sp_model=sentencepiece_model)
        self.max_seq_length = seq_len

        # The predictor is stored on the instance before it is initialized.
        predictor = _create_predictor(checkpoint, predict_fn)
        self._predictor = predictor
        predictor.initialize()
        logging.info("BLEURT initialized.")
Exemple #5
0
    def bleurt_ops(*args, references=None, candidates=None):
        """Builds the BLEURT computation graph for reference/candidate pairs.

        Both inputs must be passed by keyword. `checkpoint` and
        `bleurt_model_fn` are not parameters; they are read from the enclosing
        scope.

        Args:
          *args: must be empty. Only accepted so that positional calls fail
            with an explicit message.
          references: <tf.string>[...] Tensor that contains reference sentences.
          candidates: <tf.string>[...] Tensor that contains candidate sentences.

        Returns:
          The output of the model ops, documented as a <tf.float>[...] Tensor
          that contains BLEURT scores.
        """
        logging.info("Creating BLEURT TF Ops...")

        # Argument checks: keyword-only usage, and both tensors provided.
        assert not args, (
            "This function does not accept positional arguments. Please "
            "specify the name of the arguments explicitly, i.e., "
            "`bleurt_ops(references=..., candidates=...`).")
        assert references is not None and candidates is not None

        logging.info("Reading info from checkpoint {}".format(checkpoint))
        bleurt_config = checkpoint_lib.read_bleurt_config(checkpoint)
        seq_len = bleurt_config["max_seq_length"]
        vocab_path = bleurt_config["vocab_file"]
        lower_case = bleurt_config["do_lower_case"]
        sentencepiece_model = bleurt_config["sp_model"]

        logging.info("Creating tokenizer...")
        bleurt_tokenizer = tokenizers.create_tokenizer(
            vocab_file=vocab_path,
            do_lower_case=lower_case,
            sp_model=sentencepiece_model)
        logging.info("Tokenizer created")

        logging.info("Creating BLEURT Preprocessing Ops...")
        preprocess = create_bleurt_preprocessing_ops(bleurt_tokenizer, seq_len)
        logging.info("Preprocessing Ops created.")

        logging.info("Loading checkpoint...")
        if bleurt_model_fn:
            # A model function was supplied; the SavedModel is not loaded.
            model_ops = bleurt_model_fn
        else:
            saved_model = tf.saved_model.load(checkpoint)
            model_ops = saved_model.signatures["serving_default"]
        logging.info("BLEURT Checkpoint loaded")

        ids, mask, segments = preprocess(references, candidates)
        result = model_ops(input_ids=ids,
                           input_mask=mask,
                           segment_ids=segments)
        logging.info("BLEURT TF Ops created.")
        return result
Exemple #6
0
    def bleurt_ops(references, candidates):
        """Builds the BLEURT computation graph for reference/candidate pairs.

        `checkpoint` and `bleurt_model_fn` are not parameters; they are read
        from the enclosing scope.

        Args:
          references: <tf.string>[...] Tensor that contains reference sentences.
          candidates: <tf.string>[...] Tensor that contains candidate sentences.

        Returns:
          The output of the model ops, documented as a <tf.float>[...] Tensor
          that contains BLEURT scores.
        """
        logging.info("Creating BLEURT TF Ops...")

        logging.info("Reading info from checkpoint {}".format(checkpoint))
        bleurt_config = checkpoint_lib.read_bleurt_config(checkpoint)
        seq_len = bleurt_config["max_seq_length"]
        vocab_path = bleurt_config["vocab_file"]
        lower_case = bleurt_config["do_lower_case"]

        logging.info("Creating tokenizer...")
        bleurt_tokenizer = tokenization.FullTokenizer(
            vocab_file=vocab_path, do_lower_case=lower_case)
        logging.info("Tokenizer created")

        logging.info("Creating BLEURT Preprocessing Ops...")
        preprocess = create_bleurt_preprocessing_ops(bleurt_tokenizer, seq_len)
        logging.info("Preprocessing Ops created.")

        logging.info("Loading checkpoint...")
        if bleurt_model_fn:
            # A model function was supplied; the SavedModel is not loaded.
            model_ops = bleurt_model_fn
        else:
            # NOTE(review): `load_v2` availability depends on the TensorFlow
            # version in use -- confirm before upgrading.
            saved_model = tf.saved_model.load_v2(checkpoint)
            model_ops = saved_model.signatures["serving_default"]
        logging.info("BLEURT Checkpoint loaded")

        ids, mask, segments = preprocess(references, candidates)
        result = model_ops(input_ids=ids,
                           input_mask=mask,
                           segment_ids=segments)
        logging.info("BLEURT TF Ops created.")
        return result