def test_evaluation(self):
  """Runs a full validation pass and checks BLEU metrics are reported."""
  task_config = translation.TranslationConfig(
      model=translation.ModelConfig(
          encoder=translation.EncDecoder(),
          decoder=translation.EncDecoder(),
          padded_decode=False,
          decode_max_length=64),
      validation_data=wmt_dataloader.WMTDataConfig(
          input_path=self._record_input_path,
          src_lang="en",
          tgt_lang="reverse_en",
          static_batch=True,
          global_batch_size=4),
      sentencepiece_model_path=self._sentencepeice_model_path)
  log_dir = self.get_temp_dir()
  task = translation.TranslationTask(task_config, logging_dir=log_dir)
  strategy = tf.distribute.get_strategy()
  val_dataset = orbit.utils.make_distributed_dataset(
      strategy, task.build_inputs, task_config.validation_data)
  model = task.build_model()
  # Accumulate per-step validation outputs exactly as the trainer would.
  logs_state = None
  for batch in val_dataset:
    per_replica = strategy.run(
        functools.partial(task.validation_step, model=model), args=(batch,))
    local_outputs = tf.nest.map_structure(
        strategy.experimental_local_results, per_replica)
    logs_state = task.aggregate_logs(
        state=logs_state, step_outputs=local_outputs)
  metrics = task.reduce_aggregated_logs(logs_state)
  self.assertIn("sacrebleu_score", metrics)
  self.assertIn("bleu_score", metrics)
def test_no_sentencepiece_path(self):
  """A missing sentencepiece model path must fail task construction."""
  task_config = translation.TranslationConfig(
      model=translation.ModelConfig(
          encoder=translation.EncDecoder(),
          decoder=translation.EncDecoder()),
      train_data=wmt_dataloader.WMTDataConfig(
          input_path=self._record_input_path,
          src_lang="en",
          tgt_lang="reverse_en",
          is_training=True,
          static_batch=True,
          global_batch_size=4,
          max_seq_length=4),
      sentencepiece_model_path=None)
  # NOTE(review): the pattern (including the "Setencepiece" misspelling)
  # mirrors the message raised by TranslationTask — keep them in sync.
  with self.assertRaisesRegex(ValueError,
                              "Setencepiece model path not provided."):
    translation.TranslationTask(task_config)
def test_task(self):
  """Smoke-tests one training step end to end.

  Builds the task, model, and training dataset, then runs a single
  `train_step` with SGD to verify the pipeline executes without error.
  """
  config = translation.TranslationConfig(
      model=translation.ModelConfig(
          encoder=translation.EncDecoder(),
          decoder=translation.EncDecoder()),
      train_data=wmt_dataloader.WMTDataConfig(
          input_path=self._record_input_path,
          src_lang="en",
          tgt_lang="reverse_en",
          is_training=True,
          static_batch=True,
          global_batch_size=24,
          max_seq_length=12),
      sentencepiece_model_path=self._sentencepeice_model_path)
  task = translation.TranslationTask(config)
  model = task.build_model()
  dataset = task.build_inputs(config.train_data)
  iterator = iter(dataset)
  # `lr` is a deprecated alias (removed in newer Keras); use `learning_rate`.
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  task.train_step(next(iterator), model, optimizer)
def test_sentencepiece_no_eos(self):
  """Task construction must reject a tokenizer trained without EOS."""
  sp_prefix = os.path.join(self._temp_dir, "sp_no_eos")
  # eos_id=-1 disables the EOS piece in the trained sentencepiece model.
  _train_sentencepiece(
      self._sentencepeice_input_path, 20, sp_prefix, eos_id=-1)
  sp_model_path = f"{sp_prefix}.model"
  task_config = translation.TranslationConfig(
      model=translation.ModelConfig(
          encoder=translation.EncDecoder(),
          decoder=translation.EncDecoder()),
      train_data=wmt_dataloader.WMTDataConfig(
          input_path=self._record_input_path,
          src_lang="en",
          tgt_lang="reverse_en",
          is_training=True,
          static_batch=True,
          global_batch_size=4,
          max_seq_length=4),
      sentencepiece_model_path=sp_model_path)
  with self.assertRaisesRegex(ValueError,
                              "EOS token not in tokenizer vocab.*"):
    translation.TranslationTask(task_config)
def test_translation(self, padded_decode, batch_size):
  """Exports the translation serving module and round-trips inference."""
  sp_path = _make_sentencepeice(self.get_temp_dir())
  enc_dec = translation.EncDecoder(
      num_attention_heads=4, intermediate_size=256)
  config = translation.TranslationConfig(
      model=translation.ModelConfig(
          encoder=enc_dec,
          decoder=enc_dec,
          embedding_width=256,
          padded_decode=padded_decode,
          decode_max_length=100),
      sentencepiece_model_path=sp_path,
  )
  task = translation.TranslationTask(config)
  model = task.build_model()
  serving_params = serving_modules.Translation.Params(
      sentencepiece_model_path=sp_path, batch_size=batch_size)
  export_module = serving_modules.Translation(
      params=serving_params, model=model)
  signatures = export_module.get_inference_signatures(
      {"serve_text": "serving_default"})
  # Direct (pre-export) inference: one translated string per input example.
  translated = signatures["serving_default"](tf.constant(["abcd", "ef gh"]))
  self.assertEqual(translated.shape, (2,))
  self.assertEqual(translated.dtype, tf.string)
  work_dir = os.path.join(self.get_temp_dir(), "padded_decode",
                          str(padded_decode))
  ckpt_path = tf.train.Checkpoint(model=model).save(
      os.path.join(work_dir, "ckpt"))
  export_dir = export_base.export(export_module,
                                  {"serve_text": "serving_default"},
                                  os.path.join(work_dir, "export"),
                                  ckpt_path)
  # Reload the SavedModel and verify the exported signature also works.
  reloaded = tf.saved_model.load(export_dir)
  serve_fn = reloaded.signatures["serving_default"]
  served = serve_fn(text=tf.constant(["abcd", "ef gh"]))
  self.assertLen(served["output_0"], 2)
def wmt_transformer_large() -> cfg.ExperimentConfig:
  """WMT Transformer Large.

  Please refer to
  tensorflow_models/official/nlp/data/train_sentencepiece.py
  to generate sentencepiece_model
  and pass
  --params_override=task.sentencepiece_model_path='YOUR_PATH'
  to the train script.
  """
  hidden_size = 1024
  # Base rate 2.0 scaled by 1/sqrt(hidden_size) (Transformer LR schedule).
  learning_rate = 2.0 * hidden_size**-0.5
  warmup_steps = 16000
  train_steps = 300000
  token_batch_size = 24576

  # Encoder and decoder share one architecture spec.
  encdecoder = translation.EncDecoder(
      num_attention_heads=16, intermediate_size=hidden_size * 4)
  model_config = translation.ModelConfig(
      encoder=encdecoder,
      decoder=encdecoder,
      embedding_width=hidden_size,
      padded_decode=True,
      decode_max_length=100)
  train_data = wmt_dataloader.WMTDataConfig(
      tfds_name='wmt14_translate/de-en',
      tfds_split='train',
      src_lang='en',
      tgt_lang='de',
      is_training=True,
      global_batch_size=token_batch_size,
      static_batch=True,
      max_seq_length=64)
  validation_data = wmt_dataloader.WMTDataConfig(
      tfds_name='wmt14_translate/de-en',
      tfds_split='test',
      src_lang='en',
      tgt_lang='de',
      is_training=False,
      global_batch_size=32,
      static_batch=True,
      max_seq_length=100,
  )
  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'adam',
          'adam': {
              'beta_2': 0.997,
              'epsilon': 1e-9,
          },
      },
      'learning_rate': {
          'type': 'power',
          'power': {
              'initial_learning_rate': learning_rate,
              'power': -0.5,
          }
      },
      'warmup': {
          'type': 'linear',
          'linear': {
              'warmup_steps': warmup_steps,
              'warmup_learning_rate': 0.0
          }
      }
  })
  trainer_config = cfg.TrainerConfig(
      train_steps=train_steps,
      validation_steps=-1,
      steps_per_loop=1000,
      summary_interval=1000,
      checkpoint_interval=5000,
      validation_interval=5000,
      max_to_keep=1,
      optimizer_config=optimizer_config)
  return cfg.ExperimentConfig(
      runtime=cfg.RuntimeConfig(enable_xla=True),
      task=translation.TranslationConfig(
          model=model_config,
          train_data=train_data,
          validation_data=validation_data,
          # Must be overridden by the user; enforced via restrictions below.
          sentencepiece_model_path=None,
      ),
      trainer=trainer_config,
      restrictions=[
          'task.train_data.is_training != None',
          'task.sentencepiece_model_path != None',
      ])