示例#1
0
    def __init__(self, checkpoint=None, predict_fn=None):
        """Initializes the BLEURT scorer.

        Reads the BLEURT config for `checkpoint`, builds the tokenizer it
        describes and, unless `predict_fn` is supplied, loads the prediction
        function from the checkpoint.

        Args:
          checkpoint: BLEURT checkpoint. Defaults to BLEURT-tiny (obtained via
            `_get_default_checkpoint()`) if None or empty.
          predict_fn: (optional) prediction function; when given it is used
            instead of the one built from `checkpoint`. Mostly used for
            testing.
        """
        if not checkpoint:
            logging.info("No checkpoint specified, defaulting to BLEURT-tiny.")
            checkpoint = _get_default_checkpoint()

        logging.info("Reading checkpoint %s.", checkpoint)
        config = checkpoint_lib.read_bleurt_config(checkpoint)
        max_seq_length = config["max_seq_length"]
        vocab_file = config["vocab_file"]
        do_lower_case = config["do_lower_case"]

        logging.info("Creating BLEURT scorer.")
        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=vocab_file, do_lower_case=do_lower_case)
        self.max_seq_length = max_seq_length
        # Set on every path so the attribute exists even when an injected
        # predict_fn bypasses loading the model from the checkpoint.
        self.chkpt_dir = checkpoint

        if predict_fn:
            self.predict_fn = predict_fn
        else:
            logging.info("Loading model...")
            self.predict_fn = _make_predict_fn_from_checkpoint(checkpoint)
        logging.info("BLEURT initialized.")
示例#2
0
def encode_and_serialize(input_file, output_file, vocab_file, do_lower_case,
                         max_seq_length):
    """Encodes and serializes a set of ratings in JSON format.

    Reads a JSON-lines file of ratings, encodes every (reference, candidate,
    score) record with `serialize_example` and writes the results to a
    TFRecord file.

    Args:
      input_file: Path to a JSON-lines file whose records contain the fields
        "reference", "candidate" and "score".
      output_file: Path of the TFRecord file to write.
      vocab_file: Vocabulary file used to build the tokenizer.
      do_lower_case: Passed to the tokenizer as `do_lower_case`.
      max_seq_length: Passed to `serialize_example` for every record.

    Raises:
      FileNotFoundError: If `input_file` does not exist.
      ValueError: If a required field is missing from the input file.
    """
    # Explicit checks rather than `assert`: asserts are stripped under
    # `python -O`, which would silently skip this validation.
    if not tf.io.gfile.exists(input_file):
        raise FileNotFoundError("Could not find file: {}".format(input_file))
    logging.info("Reading data...")
    with tf.io.gfile.GFile(input_file, "r") as f:
        examples_df = pd.read_json(f, lines=True)
    for col in ("reference", "candidate", "score"):
        if col not in examples_df.columns:
            raise ValueError("field {} not found in input file!".format(col))
    n_records = len(examples_df)
    logging.info("Read %d examples.", n_records)

    logging.info("Encoding and writing TFRecord file...")
    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file,
                                           do_lower_case=do_lower_case)
    # Log progress roughly every 10% of the records (at least every record).
    log_every = max(n_records // 10, 1)
    # `tf.io.TFRecordWriter` is available in both TF 1.x and TF 2.x, unlike
    # the TF1-only `tf.python_io` namespace.
    with tf.io.TFRecordWriter(output_file) as writer:
        records = examples_df.itertuples(index=False)
        for example_num, record in enumerate(records, start=1):
            if example_num % log_every == 0:
                logging.info("Writing example %d of %d", example_num,
                             n_records)
            tf_example = serialize_example(record.reference,
                                           record.candidate,
                                           tokenizer,
                                           max_seq_length,
                                           score=record.score)
            writer.write(tf_example)
    logging.info("Done writing %d tf examples.", n_records)
示例#3
0
    def bleurt_ops(references, candidates):
        """Builds the TF computation that scores candidates with BLEURT.

        Uses the enclosing scope's `checkpoint` for the config and the
        SavedModel, and its `bleurt_model_fn` (if truthy) in place of the
        loaded model.

        Args:
          references: <tf.string>[...] Tensor holding the reference sentences.
          candidates: <tf.string>[...] Tensor holding the candidate sentences.

        Returns:
          The output of the model call, documented as a <tf.float>[...] Tensor
          of BLEURT scores. NOTE(review): a SavedModel signature call usually
          returns a dict of tensors — confirm what callers expect.
        """
        logging.info("Creating BLEURT TF Ops...")

        logging.info("Reading info from checkpoint {}".format(checkpoint))
        bleurt_config = checkpoint_lib.read_bleurt_config(checkpoint)
        seq_len = bleurt_config["max_seq_length"]
        vocab = bleurt_config["vocab_file"]
        lower_case = bleurt_config["do_lower_case"]

        logging.info("Creating tokenizer...")
        tok = tokenization.FullTokenizer(vocab_file=vocab,
                                         do_lower_case=lower_case)
        logging.info("Tokenizer created")

        logging.info("Creating BLEURT Preprocessing Ops...")
        preprocess = create_bleurt_preprocessing_ops(tok, seq_len)
        logging.info("Preprocessing Ops created.")

        logging.info("Loading checkpoint...")
        if bleurt_model_fn:
            # An injected model function takes precedence over the checkpoint.
            model = bleurt_model_fn
        else:
            saved = tf.saved_model.load_v2(checkpoint)
            model = saved.signatures["serving_default"]
        logging.info("BLEURT Checkpoint loaded")

        ids, mask, segments = preprocess(references, candidates)
        scores = model(input_ids=ids, input_mask=mask, segment_ids=segments)
        logging.info("BLEURT TF Ops created.")
        return scores