def test():
    """Smoke-check a training TFRecord shard: decode a few examples and print them.

    Reads the last shard of the brokenegg training data from /tmp, parses each
    serialized example, and prints the language id, source text, and target
    text for roughly the first dozen records.
    """
    import os

    import tensorflow.compat.v1 as tf

    from brokenegg_transformer.utils import tokenizer

    data_dir = '/tmp/brokenegg_transformer'
    dataset = tf.data.TFRecordDataset(
        os.path.join(data_dir, 'brokenegg-train-00030-of-00030'))
    # Token-id sequences are variable length, so they are stored sparse.
    feature_description = {
        'inputs': tf.VarLenFeature(dtype=tf.int64),
        'targets': tf.VarLenFeature(dtype=tf.int64),
    }
    subtokenizer = tokenizer.Subtokenizer(
        os.path.join(data_dir, 'brokenegg.en-es-ja.spm64k.model'))
    for count, raw_record in enumerate(dataset):
        # NOTE: the original also parsed raw_record into a tf.train.Example
        # whose result was discarded; that dead work has been removed.
        example = tf.io.parse_single_example(raw_record, feature_description)
        encoded_inputs = tf.sparse.to_dense(example['inputs']).numpy().tolist()
        encoded_targets = tf.sparse.to_dense(
            example['targets']).numpy().tolist()
        # First target token carries the language id; the rest is the sentence.
        print('LANG: %d' % encoded_targets[0])
        print('SRC: %s' % subtokenizer.decode(encoded_inputs))
        print('TGT: %s' % subtokenizer.decode(encoded_targets[1:]))
        if count > 10:
            break
def evaluate_and_log_bleu(model, params, bleu_source, bleu_ref, vocab_file,
                          distribution_strategy=None):
  """Compute BLEU for the model's translations and log both variants.

  Args:
    model: A Keras model, used to generate the translations.
    params: A dictionary, containing the translation related parameters.
    bleu_source: A file containing source sentences for translation.
    bleu_ref: A file containing the reference for the translated sentences.
    vocab_file: A file containing the vocabulary for translation.
    distribution_strategy: A platform distribution strategy, used for TPU based
      translation.

  Returns:
    A (uncased_score, cased_score) pair of floats: the case-insensitive and
    case-sensitive BLEU scores.
  """
  sub = tokenizer.Subtokenizer(vocab_file)
  scores = translate_and_compute_bleu(model, params, sub, bleu_source,
                                      bleu_ref, distribution_strategy)
  uncased, cased = scores
  logging.info("Bleu score (uncased): %s", uncased)
  logging.info("Bleu score (cased): %s", cased)
  return uncased, cased
def predict(self):
    """Run inference with the restored model and print decoded translations."""
    params, flags_obj = self.params, self.flags_obj
    with tf.name_scope("model"):
        model = transformer.create_model(params, is_train=False)
        latest = tf.train.latest_checkpoint(self.flags_obj.model_dir)
        self._load_weights_if_possible(model, latest)
        model.summary()
    tok = tokenizer.Subtokenizer(flags_obj.vocab_file)

    dataset = data_pipeline.eval_input_fn(params)
    # When targets carry a leading language id, feed it to the decoder as the
    # initial id alongside the encoder inputs; otherwise inputs alone suffice.
    if params['targets_with_lang_id']:
        mapper = lambda x, y: {
            'inputs': x,
            'initial_ids': tf.cast(y[:, 0], tf.int32)
        }
    else:
        mapper = lambda x, y: x
    dataset = dataset.map(mapper).take(_SINGLE_SAMPLE)

    val_outputs, _ = model.predict(dataset)
    for row in val_outputs:
        translate.translate_from_input(row, tok)
def test_simple(self):
    """Round-trip a mixed English/Japanese string through the subtokenizer."""
    # GCS copy of this model:
    # gs://brokenegg/data/brokenegg/brokenegg.en-es-ja.spm64k.model
    # (the original assigned that path and immediately shadowed it — dead code)
    path = '/tmp/brokenegg_transformer/brokenegg.en-es-ja.spm64k.model'
    subtokenizer = tokenizer.Subtokenizer(path)
    text = "Hello, world! こんにちはです。"
    encoded = subtokenizer.encode(text)
    print(encoded)
    decoded = subtokenizer.decode(encoded)
    # Decoding the encoded ids must reproduce the input exactly.
    self.assertEqual(decoded, text)
from flask import request from brokenegg_transformer import model_params from brokenegg_transformer import transformer from brokenegg_transformer.utils import tokenizer import tensorflow.compat.v1 as tf import numpy as np import os MODEL_DIR = os.getenv("MODEL_DIR", "/tmp") VOCAB_FILE = os.path.join(MODEL_DIR, 'wikimatrix_lang10.spm64k.model') langs = sorted({'en', 'es', 'fr', 'ru', 'de', 'ja', 'ar', 'zh', 'el', 'ko'}) lang_map = {k: v + 64000 for v, k in enumerate(langs)} subtokenizer = tokenizer.Subtokenizer(VOCAB_FILE) params = model_params.BASE_PARAMS.copy() params["dtype"] = tf.float32 with tf.name_scope("model"): model = transformer.create_model(params, is_train=False) init_weight_path = tf.train.latest_checkpoint(MODEL_DIR) print('Restoring from %s' % init_weight_path) checkpoint = tf.train.Checkpoint(model=model) checkpoint.restore(init_weight_path) def translate(text, lang_id): encoded = subtokenizer.encode(text, add_eos=True) output, score = model.predict([ np.array([encoded], dtype=np.int64),