def __init__(self, checkpoint=None, predict_fn=None):
  """Initializes the BLEURT model.

  Args:
    checkpoint: BLEURT checkpoint. Will default to BLEURT-tiny if None.
    predict_fn: (optional) prediction function, overrides chkpt_dir. Mostly
      used for testing.
  """
  if not checkpoint:
    logging.info("No checkpoint specified, defaulting to BLEURT-tiny.")
    checkpoint = _get_default_checkpoint()

  logging.info("Reading checkpoint {}.".format(checkpoint))
  config = checkpoint_lib.read_bleurt_config(checkpoint)
  max_seq_length = config["max_seq_length"]
  vocab_file = config["vocab_file"]
  do_lower_case = config["do_lower_case"]

  logging.info("Creating BLEURT scorer.")
  self.tokenizer = tokenization.FullTokenizer(
      vocab_file=vocab_file, do_lower_case=do_lower_case)
  self.max_seq_length = max_seq_length

  # A user-supplied predict_fn short-circuits model loading entirely.
  if predict_fn:
    self.predict_fn = predict_fn
    logging.info("BLEURT initialized.")
    return

  logging.info("Loading model...")
  self.chkpt_dir = checkpoint
  self.predict_fn = _make_predict_fn_from_checkpoint(checkpoint)
  logging.info("BLEURT initialized.")
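# A minimal usage sketch, not part of the original module. It assumes the
# enclosing class is exported as `score.BleurtScorer` and that its `score()`
# method accepts `references` and `candidates` keyword lists, as the tests
# below do.

from bleurt import score

scorer = score.BleurtScorer()  # No checkpoint given: falls back to BLEURT-tiny.
scores = scorer.score(
    references=["This is a test."], candidates=["This is the test."])
assert isinstance(scores, list) and len(scores) == 1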
def test_finetune_and_predict(self):
  checkpoint = get_test_checkpoint()
  train_file, dev_file = get_test_data()

  with tempfile.TemporaryDirectory() as model_dir:
    # Sets new flags.
    FLAGS.init_checkpoint = os.path.join(checkpoint, "variables", "variables")
    FLAGS.bert_config_file = os.path.join(checkpoint, "bert_config.json")
    FLAGS.vocab_file = os.path.join(checkpoint, "vocab.txt")
    FLAGS.do_lower_case = True
    FLAGS.dynamic_seq_length = True
    FLAGS.max_seq_length = 512
    FLAGS.model_dir = model_dir
    FLAGS.num_train_steps = 1
    # Near-zero learning rate, so the single training step leaves the weights
    # (and hence the expected scores) virtually unchanged.
    FLAGS.learning_rate = 0.00000000001
    FLAGS.serialized_train_set = os.path.join(model_dir, "train.tfrecord")
    FLAGS.serialized_dev_set = os.path.join(model_dir, "dev.tfrecord")

    # Runs 1 training step.
    export = finetune.run_finetuning_pipeline(train_file, dev_file)

    # Checks if the pipeline produced a valid BLEURT checkpoint.
    self.assertTrue(tf.io.gfile.exists(export))
    config = checkpoint_lib.read_bleurt_config(export)
    self.assertIsInstance(config, dict)

    # Runs a prediction.
    scorer = score.LengthBatchingBleurtScorer(export)
    scores = scorer.score(references=references, candidates=candidates)
    self.assertLen(scores, 2)
    self.assertAllClose(scores, ref_scores)
def test_finetune_from_bert(self):
  checkpoint = get_test_checkpoint()
  train_file, dev_file = get_test_data()

  with tempfile.TemporaryDirectory() as model_dir:
    # Sets new flags.
    FLAGS.model_dir = model_dir
    FLAGS.init_checkpoint = os.path.join(checkpoint, "variables", "variables")
    FLAGS.bert_config_file = os.path.join(checkpoint, "bert_config.json")
    FLAGS.vocab_file = os.path.join(checkpoint, "vocab.txt")
    FLAGS.do_lower_case = True
    FLAGS.max_seq_length = 512
    FLAGS.num_train_steps = 1
    FLAGS.serialized_train_set = os.path.join(model_dir, "train.tfrecord")
    FLAGS.serialized_dev_set = os.path.join(model_dir, "dev.tfrecord")

    # Runs 1 training step.
    export = finetune.run_finetuning_pipeline(train_file, dev_file)

    # Checks if the pipeline produced a valid BLEURT checkpoint.
    self.assertTrue(tf.io.gfile.exists(export))
    config = checkpoint_lib.read_bleurt_config(export)
    self.assertIsInstance(config, dict)
def __init__(self, checkpoint=None, predict_fn=None):
  """Initializes the BLEURT model.

  Args:
    checkpoint: BLEURT checkpoint. Will default to BLEURT-tiny if None.
    predict_fn: (optional) prediction function, overrides chkpt_dir. Mostly
      used for testing.
  """
  if not checkpoint:
    logging.info("No checkpoint specified, defaulting to BLEURT-tiny.")
    checkpoint = _get_default_checkpoint()

  logging.info("Reading checkpoint {}.".format(checkpoint))
  self.config = checkpoint_lib.read_bleurt_config(checkpoint)
  max_seq_length = self.config["max_seq_length"]
  vocab_file = self.config["vocab_file"]
  do_lower_case = self.config["do_lower_case"]
  sp_model = self.config["sp_model"]

  logging.info("Creating BLEURT scorer.")
  self.tokenizer = tokenizers.create_tokenizer(
      vocab_file=vocab_file, do_lower_case=do_lower_case, sp_model=sp_model)
  self.max_seq_length = max_seq_length
  self._predictor = _create_predictor(checkpoint, predict_fn)
  self._predictor.initialize()
  logging.info("BLEURT initialized.")
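# A minimal testing sketch, not part of the original module. `predict_fn`
# bypasses the SavedModel predictor, so a stub can exercise the scorer without
# a trained model. The class name `BleurtScorer` and the predict_fn contract
# (a dict of BERT features in, one score per example out) are assumptions
# based on the constructor above.

import numpy as np
from bleurt import score

def dummy_predict_fn(input_dict):
  # Returns a constant score for each example in the batch.
  return np.zeros(len(input_dict["input_ids"]))

scorer = score.BleurtScorer(predict_fn=dummy_predict_fn)
scores = scorer.score(
    references=["This is a test."], candidates=["This is the test."])
assert len(scores) == 1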
def bleurt_ops(*args, references=None, candidates=None):
  """Builds computation graph for BLEURT.

  Args:
    *args: dummy positional arguments.
    references: <tf.string>[...] Tensor that contains reference sentences.
    candidates: <tf.string>[...] Tensor that contains candidate sentences.

  Returns:
    A <tf.float>[...] Tensor that contains BLEURT scores.
  """
  logging.info("Creating BLEURT TF Ops...")

  assert not args, (
      "This function does not accept positional arguments. Please "
      "specify the names of the arguments explicitly, i.e., "
      "`bleurt_ops(references=..., candidates=...)`.")
  assert references is not None and candidates is not None

  logging.info("Reading info from checkpoint {}".format(checkpoint))
  config = checkpoint_lib.read_bleurt_config(checkpoint)
  max_seq_length = config["max_seq_length"]
  vocab_file = config["vocab_file"]
  do_lower_case = config["do_lower_case"]
  sp_model = config["sp_model"]

  logging.info("Creating tokenizer...")
  tokenizer = tokenizers.create_tokenizer(
      vocab_file=vocab_file, do_lower_case=do_lower_case, sp_model=sp_model)
  logging.info("Tokenizer created")

  logging.info("Creating BLEURT Preprocessing Ops...")
  bleurt_preprocessing_ops = create_bleurt_preprocessing_ops(
      tokenizer, max_seq_length)
  logging.info("Preprocessing Ops created.")

  logging.info("Loading checkpoint...")
  if not bleurt_model_fn:
    imported = tf.saved_model.load(checkpoint)
    bleurt_model_ops = imported.signatures["serving_default"]
  else:
    bleurt_model_ops = bleurt_model_fn
  logging.info("BLEURT Checkpoint loaded")

  input_ids, input_mask, segment_ids = bleurt_preprocessing_ops(
      references, candidates)
  out = bleurt_model_ops(
      input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids)
  logging.info("BLEURT TF Ops created.")
  return out
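# A usage sketch, not part of the original module. It assumes `bleurt_ops`
# above is the closure returned by a factory such as
# `score.create_bleurt_ops(checkpoint, bleurt_model_fn)` (which would bind the
# free variables `checkpoint` and `bleurt_model_fn`), and that the serving
# signature returns its scores under a "predictions" key.

import tensorflow as tf
from bleurt import score

bleurt_ops = score.create_bleurt_ops()
bleurt_out = bleurt_ops(
    references=tf.constant(["This is a test."]),
    candidates=tf.constant(["This is the test."]))
assert bleurt_out["predictions"].shape == (1,)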
def bleurt_ops(references, candidates):
  """Builds computation graph for BLEURT.

  Args:
    references: <tf.string>[...] Tensor that contains reference sentences.
    candidates: <tf.string>[...] Tensor that contains candidate sentences.

  Returns:
    A <tf.float>[...] Tensor that contains BLEURT scores.
  """
  logging.info("Creating BLEURT TF Ops...")

  logging.info("Reading info from checkpoint {}".format(checkpoint))
  config = checkpoint_lib.read_bleurt_config(checkpoint)
  max_seq_length = config["max_seq_length"]
  vocab_file = config["vocab_file"]
  do_lower_case = config["do_lower_case"]

  logging.info("Creating tokenizer...")
  tokenizer = tokenization.FullTokenizer(
      vocab_file=vocab_file, do_lower_case=do_lower_case)
  logging.info("Tokenizer created")

  logging.info("Creating BLEURT Preprocessing Ops...")
  bleurt_preprocessing_ops = create_bleurt_preprocessing_ops(
      tokenizer, max_seq_length)
  logging.info("Preprocessing Ops created.")

  logging.info("Loading checkpoint...")
  if not bleurt_model_fn:
    imported = tf.saved_model.load_v2(checkpoint)
    bleurt_model_ops = imported.signatures["serving_default"]
  else:
    bleurt_model_ops = bleurt_model_fn
  logging.info("BLEURT Checkpoint loaded")

  input_ids, input_mask, segment_ids = bleurt_preprocessing_ops(
      references, candidates)
  out = bleurt_model_ops(
      input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids)
  logging.info("BLEURT TF Ops created.")
  return out
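# A TF1-style usage sketch for this older variant, not part of the original
# module. `tf.saved_model.load_v2` is the TF1.x alias of the TF2 loader, so
# the returned ops are meant to be evaluated in a session after the lookup
# tables are initialized. The factory name `score.create_bleurt_ops` and the
# "predictions" output key are assumptions, as above.

import tensorflow.compat.v1 as tf
from bleurt import score

tf.disable_eager_execution()  # Session-based execution, per the TF1 API.

references = tf.constant(["This is a test."])
candidates = tf.constant(["This is the test."])
bleurt_out = score.create_bleurt_ops()(references, candidates)

with tf.Session() as session:
  session.run(tf.global_variables_initializer())
  session.run(tf.tables_initializer())
  scores = session.run(bleurt_out["predictions"])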