def __init__(self, checkpoint=None, predict_fn=None):
  """Initializes the BLEURT scorer.

  Args:
    checkpoint: BLEURT checkpoint. Will default to BLEURT-tiny if None.
    predict_fn: (optional) prediction function, overrides `checkpoint`.
      Mostly used for testing.
  """
  if not checkpoint:
    logging.info("No checkpoint specified, defaulting to BLEURT-tiny.")
    checkpoint = _get_default_checkpoint()

  logging.info("Reading checkpoint {}.".format(checkpoint))
  config = checkpoint_lib.read_bleurt_config(checkpoint)
  max_seq_length = config["max_seq_length"]
  vocab_file = config["vocab_file"]
  do_lower_case = config["do_lower_case"]

  logging.info("Creating BLEURT scorer.")
  self.tokenizer = tokenization.FullTokenizer(
      vocab_file=vocab_file, do_lower_case=do_lower_case)
  self.max_seq_length = max_seq_length

  if predict_fn:
    self.predict_fn = predict_fn
    logging.info("BLEURT initialized.")
    return

  logging.info("Loading model...")
  self.chkpt_dir = checkpoint
  self.predict_fn = _make_predict_fn_from_checkpoint(checkpoint)
  logging.info("BLEURT initialized.")
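# Usage sketch (assumptions: this __init__ belongs to a `BleurtScorer`
# class in bleurt/score.py, and the class exposes a
# `score(references=..., candidates=...)` method; both names are inferred
# from the surrounding code, not guaranteed by it):
#
#   from bleurt import score
#
#   scorer = score.BleurtScorer()  # Defaults to BLEURT-tiny.
#   scores = scorer.score(
#       references=["The cat sat on the mat."],
#       candidates=["A cat was sitting on a mat."])
#   # `scores` holds one float per (reference, candidate) pair.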
def encode_and_serialize(input_file, output_file, vocab_file, do_lower_case,
                         max_seq_length):
  """Encodes and serializes a set of ratings in JSON format."""
  assert tf.io.gfile.exists(input_file), \
      "Could not find file {}.".format(input_file)
  logging.info("Reading data...")
  with tf.io.gfile.GFile(input_file, "r") as f:
    examples_df = pd.read_json(f, lines=True)
  for col in ["reference", "candidate", "score"]:
    assert col in examples_df.columns, \
        "field {} not found in input file!".format(col)
  n_records = len(examples_df)
  logging.info("Read {} examples.".format(n_records))

  logging.info("Encoding and writing TFRecord file...")
  tokenizer = tokenization.FullTokenizer(
      vocab_file=vocab_file, do_lower_case=do_lower_case)
  with tf.python_io.TFRecordWriter(output_file) as writer:
    # Log progress roughly every 10% of the records.
    iterator_id, iterator_cycle = 0, max(n_records // 10, 1)
    for record in examples_df.itertuples(index=False):
      iterator_id += 1
      if iterator_id % iterator_cycle == 0:
        logging.info("Writing example %d of %d", iterator_id, n_records)
      tf_example = serialize_example(
          record.reference,
          record.candidate,
          tokenizer,
          max_seq_length,
          score=record.score)
      writer.write(tf_example)
  logging.info("Done writing {} tf examples.".format(n_records))
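# Usage sketch (assumptions: a JSON-lines ratings file whose records carry
# the "reference", "candidate", and "score" fields the asserts above check
# for; all paths below are placeholders):
#
#   import json
#   import tensorflow as tf
#
#   with tf.io.gfile.GFile("/tmp/ratings.jsonl", "w") as f:
#     f.write(json.dumps({"reference": "The cat sat on the mat.",
#                         "candidate": "A cat was sitting on a mat.",
#                         "score": 0.8}) + "\n")
#
#   encode_and_serialize(
#       input_file="/tmp/ratings.jsonl",
#       output_file="/tmp/ratings.tfrecord",
#       vocab_file="/path/to/vocab.txt",  # Usually read from a checkpoint.
#       do_lower_case=True,
#       max_seq_length=128)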
def bleurt_ops(references, candidates):
  """Builds computation graph for BLEURT.

  Args:
    references: <tf.string>[...] Tensor that contains reference sentences.
    candidates: <tf.string>[...] Tensor that contains candidate sentences.

  Returns:
    A <tf.float>[...] Tensor that contains BLEURT scores.
  """
  logging.info("Creating BLEURT TF Ops...")

  logging.info("Reading info from checkpoint {}".format(checkpoint))
  config = checkpoint_lib.read_bleurt_config(checkpoint)
  max_seq_length = config["max_seq_length"]
  vocab_file = config["vocab_file"]
  do_lower_case = config["do_lower_case"]

  logging.info("Creating tokenizer...")
  tokenizer = tokenization.FullTokenizer(
      vocab_file=vocab_file, do_lower_case=do_lower_case)
  logging.info("Tokenizer created")

  logging.info("Creating BLEURT Preprocessing Ops...")
  bleurt_preprocessing_ops = create_bleurt_preprocessing_ops(
      tokenizer, max_seq_length)
  logging.info("Preprocessing Ops created.")

  logging.info("Loading checkpoint...")
  if not bleurt_model_fn:
    imported = tf.saved_model.load_v2(checkpoint)
    bleurt_model_ops = imported.signatures["serving_default"]
  else:
    bleurt_model_ops = bleurt_model_fn
  logging.info("BLEURT Checkpoint loaded")

  input_ids, input_mask, segment_ids = bleurt_preprocessing_ops(
      references, candidates)
  out = bleurt_model_ops(
      input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids)

  logging.info("BLEURT TF Ops created.")
  return out
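# Usage sketch (assumptions: `bleurt_ops` is returned by a factory such as
# `create_bleurt_ops(checkpoint, bleurt_model_fn=None)` that closes over the
# free variables `checkpoint` and `bleurt_model_fn`; the factory name is
# inferred, not confirmed. Graph-mode TF1 style to match the rest of the
# file):
#
#   import tensorflow.compat.v1 as tf
#
#   references = tf.placeholder(dtype=tf.string, shape=[None])
#   candidates = tf.placeholder(dtype=tf.string, shape=[None])
#   bleurt_out = bleurt_ops(references, candidates)
#
#   with tf.Session() as session:
#     session.run(tf.global_variables_initializer())
#     session.run(tf.tables_initializer())
#     out = session.run(bleurt_out, feed_dict={
#         references: ["An apple a day keeps the doctor away."],
#         candidates: ["An apple a day keeps doctors away."]})
#   # `out` mirrors the serving signature's outputs, so it may be a dict of
#   # arrays (e.g. keyed "predictions") rather than a single array.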