コード例 #1
0
def test():
    """Smoke-check a TFRecord training shard by decoding a few examples.

    Reads the last training shard from ``/tmp/brokenegg_transformer``,
    parses the variable-length ``inputs``/``targets`` features, and prints
    the detokenized source/target text for the first dozen records.
    """
    from brokenegg_transformer.utils import tokenizer

    import os
    import tensorflow.compat.v1 as tf

    data_dir = '/tmp/brokenegg_transformer'
    dataset = tf.data.TFRecordDataset(
        os.path.join(data_dir, 'brokenegg-train-00030-of-00030'))
    # Both features are variable-length sequences of subword token ids.
    feature_description = {
        'inputs': tf.VarLenFeature(dtype=tf.int64),
        'targets': tf.VarLenFeature(dtype=tf.int64),
    }
    subtokenizer = tokenizer.Subtokenizer(
        os.path.join(data_dir, 'brokenegg.en-es-ja.spm64k.model'))
    for count, raw_record in enumerate(dataset):
        example = tf.io.parse_single_example(raw_record, feature_description)
        encoded_inputs = tf.sparse.to_dense(example['inputs']).numpy().tolist()
        encoded_targets = tf.sparse.to_dense(
            example['targets']).numpy().tolist()
        # The first target token is the language id; the rest is the text.
        print('LANG: %d' % encoded_targets[0])
        print('SRC: %s' % subtokenizer.decode(encoded_inputs))
        print('TGT: %s' % subtokenizer.decode(encoded_targets[1:]))
        if count > 10:
            break
コード例 #2
0
def evaluate_and_log_bleu(model,
                          params,
                          bleu_source,
                          bleu_ref,
                          vocab_file,
                          distribution_strategy=None):
    """Calculate and record the BLEU score.

    Args:
        model: A Keras model, used to generate the translations.
        params: A dictionary, containing the translation related parameters.
        bleu_source: A file containing source sentences for translation.
        bleu_ref: A file containing the reference for the translated sentences.
        vocab_file: A file containing the vocabulary for translation.
        distribution_strategy: A platform distribution strategy, used for TPU
            based translation.

    Returns:
        uncased_score: A float, the case insensitive BLEU score.
        cased_score: A float, the case sensitive BLEU score.
    """
    sub = tokenizer.Subtokenizer(vocab_file)

    scores = translate_and_compute_bleu(
        model, params, sub, bleu_source, bleu_ref, distribution_strategy)
    uncased_score, cased_score = scores

    logging.info("Bleu score (uncased): %s", uncased_score)
    logging.info("Bleu score (cased): %s", cased_score)
    return uncased_score, cased_score
コード例 #3
0
    def predict(self):
        """Predicts result from the model."""
        params = self.params
        flags_obj = self.flags_obj

        # Build an inference-mode model and restore the newest checkpoint.
        with tf.name_scope("model"):
            model = transformer.create_model(params, is_train=False)
            latest = tf.train.latest_checkpoint(self.flags_obj.model_dir)
            self._load_weights_if_possible(model, latest)
            model.summary()
        subtokenizer = tokenizer.Subtokenizer(flags_obj.vocab_file)

        ds = data_pipeline.eval_input_fn(params)
        if params['targets_with_lang_id']:
            # Seed decoding with the language id taken from the first
            # target token of each example.
            def _with_initial_ids(x, y):
                return {
                    'inputs': x,
                    'initial_ids': tf.cast(y[:, 0], tf.int32),
                }

            ds = ds.map(_with_initial_ids).take(_SINGLE_SAMPLE)
        else:
            ds = ds.map(lambda x, y: x).take(_SINGLE_SAMPLE)
        val_outputs, _ = model.predict(ds)
        for output in val_outputs:
            translate.translate_from_input(output, subtokenizer)
コード例 #4
0
 def test_simple(self):
     """Round-trips a mixed-script string through encode/decode."""
     # Local copy of
     # 'gs://brokenegg/data/brokenegg/brokenegg.en-es-ja.spm64k.model'.
     path = '/tmp/brokenegg_transformer/brokenegg.en-es-ja.spm64k.model'
     subtokenizer = tokenizer.Subtokenizer(path)
     text = "Hello, world! こんにちはです。"
     encoded = subtokenizer.encode(text)
     print(encoded)
     decoded = subtokenizer.decode(encoded)
     self.assertEqual(decoded, text)
コード例 #5
0
from flask import request

from brokenegg_transformer import model_params
from brokenegg_transformer import transformer
from brokenegg_transformer.utils import tokenizer
import tensorflow.compat.v1 as tf
import numpy as np
import os

# Directory holding the checkpoint and SentencePiece model; override via env.
MODEL_DIR = os.getenv("MODEL_DIR", "/tmp")
VOCAB_FILE = os.path.join(MODEL_DIR, 'wikimatrix_lang10.spm64k.model')

# Supported language codes, sorted so each language gets a stable id.
langs = sorted({'en', 'es', 'fr', 'ru', 'de', 'ja', 'ar', 'zh', 'el', 'ko'})
# Language-id tokens are placed right after the 64k-subword vocabulary.
lang_map = {k: v + 64000 for v, k in enumerate(langs)}

subtokenizer = tokenizer.Subtokenizer(VOCAB_FILE)

# Build an inference-mode model and restore the most recent checkpoint
# found under MODEL_DIR (None if no checkpoint exists — TODO confirm
# intended behavior when the directory is empty).
params = model_params.BASE_PARAMS.copy()
params["dtype"] = tf.float32
with tf.name_scope("model"):
    model = transformer.create_model(params, is_train=False)
init_weight_path = tf.train.latest_checkpoint(MODEL_DIR)
print('Restoring from %s' % init_weight_path)
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore(init_weight_path)


def translate(text, lang_id):
    encoded = subtokenizer.encode(text, add_eos=True)
    output, score = model.predict([
        np.array([encoded], dtype=np.int64),