def test_evaluation(self):
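    """Runs distributed validation steps and checks the reported BLEU metrics."""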
    config = translation.TranslationConfig(
        model=translation.ModelConfig(
            encoder=translation.EncDecoder(),
            decoder=translation.EncDecoder(),
            padded_decode=False,
            decode_max_length=64),
        validation_data=wmt_dataloader.WMTDataConfig(
            input_path=self._record_input_path,
            src_lang="en",
            tgt_lang="reverse_en",
            static_batch=True,
            global_batch_size=4),
        sentencepiece_model_path=self._sentencepeice_model_path)
    logging_dir = self.get_temp_dir()
    task = translation.TranslationTask(config, logging_dir=logging_dir)
    dataset = orbit.utils.make_distributed_dataset(
        tf.distribute.get_strategy(), task.build_inputs,
        config.validation_data)
    model = task.build_model()
    strategy = tf.distribute.get_strategy()
    aggregated = None
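    # Drive evaluation manually: run validation_step on each batch under the
    # distribution strategy, gather the per-replica results, and fold them
    # into the running aggregation state.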
    for data in dataset:
        distributed_outputs = strategy.run(
            functools.partial(task.validation_step, model=model),
            args=(data,))
        outputs = tf.nest.map_structure(strategy.experimental_local_results,
                                        distributed_outputs)
        aggregated = task.aggregate_logs(state=aggregated,
                                         step_outputs=outputs)
    metrics = task.reduce_aggregated_logs(aggregated)
    self.assertIn("sacrebleu_score", metrics)
    self.assertIn("bleu_score", metrics)

def test_no_sentencepiece_path(self):
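    """Checks that building the task without a sentencepiece model fails."""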
    config = translation.TranslationConfig(
        model=translation.ModelConfig(
            encoder=translation.EncDecoder(),
            decoder=translation.EncDecoder()),
        train_data=wmt_dataloader.WMTDataConfig(
            input_path=self._record_input_path,
            src_lang="en",
            tgt_lang="reverse_en",
            is_training=True,
            static_batch=True,
            global_batch_size=4,
            max_seq_length=4),
        sentencepiece_model_path=None)
    with self.assertRaisesRegex(ValueError,
                                "Setencepiece model path not provided."):
        translation.TranslationTask(config)

def test_task(self):
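    """Builds the translation task and runs a single training step."""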
    config = translation.TranslationConfig(
        model=translation.ModelConfig(
            encoder=translation.EncDecoder(),
            decoder=translation.EncDecoder()),
        train_data=wmt_dataloader.WMTDataConfig(
            input_path=self._record_input_path,
            src_lang="en",
            tgt_lang="reverse_en",
            is_training=True,
            static_batch=True,
            global_batch_size=24,
            max_seq_length=12),
        sentencepiece_model_path=self._sentencepeice_model_path)
    task = translation.TranslationTask(config)
    model = task.build_model()
    dataset = task.build_inputs(config.train_data)
    iterator = iter(dataset)
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    task.train_step(next(iterator), model, optimizer)

def test_sentencepiece_no_eos(self):
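    """Checks that a sentencepiece model lacking an EOS token is rejected."""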
    sentencepeice_model_prefix = os.path.join(self._temp_dir, "sp_no_eos")
    _train_sentencepiece(self._sentencepeice_input_path,
                         20,
                         sentencepeice_model_prefix,
                         eos_id=-1)
    sentencepeice_model_path = "{}.model".format(sentencepeice_model_prefix)
    config = translation.TranslationConfig(
        model=translation.ModelConfig(
            encoder=translation.EncDecoder(),
            decoder=translation.EncDecoder()),
        train_data=wmt_dataloader.WMTDataConfig(
            input_path=self._record_input_path,
            src_lang="en",
            tgt_lang="reverse_en",
            is_training=True,
            static_batch=True,
            global_batch_size=4,
            max_seq_length=4),
        sentencepiece_model_path=sentencepeice_model_path)
    with self.assertRaisesRegex(ValueError,
                                "EOS token not in tokenizer vocab.*"):
        translation.TranslationTask(config)

def test_translation(self, padded_decode, batch_size):
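    """Exports the translation model for serving and runs text inference.

    Assumes a parameterized test decorator (not shown here) supplies the
    `padded_decode` and `batch_size` arguments.
    """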
    sp_path = _make_sentencepeice(self.get_temp_dir())
    encdecoder = translation.EncDecoder(num_attention_heads=4,
                                        intermediate_size=256)
    config = translation.TranslationConfig(
        model=translation.ModelConfig(
            encoder=encdecoder,
            decoder=encdecoder,
            embedding_width=256,
            padded_decode=padded_decode,
            decode_max_length=100),
        sentencepiece_model_path=sp_path)
    task = translation.TranslationTask(config)
    model = task.build_model()

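    # Query the serving module's text-in/text-out signature directly, before
    # anything is written to disk.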
    params = serving_modules.Translation.Params(
        sentencepiece_model_path=sp_path, batch_size=batch_size)
    export_module = serving_modules.Translation(params=params, model=model)
    functions = export_module.get_inference_signatures(
        {"serve_text": "serving_default"})
    outputs = functions["serving_default"](tf.constant(["abcd", "ef gh"]))
    self.assertEqual(outputs.shape, (2,))
    self.assertEqual(outputs.dtype, tf.string)

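    # Round-trip through export: save a checkpoint, export a SavedModel,
    # reload it, and check that the serving signature translates both inputs.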
    tmp_dir = self.get_temp_dir()
    tmp_dir = os.path.join(tmp_dir, "padded_decode", str(padded_decode))
    export_base_dir = os.path.join(tmp_dir, "export")
    ckpt_dir = os.path.join(tmp_dir, "ckpt")
    ckpt_path = tf.train.Checkpoint(model=model).save(ckpt_dir)
    export_dir = export_base.export(export_module,
                                    {"serve_text": "serving_default"},
                                    export_base_dir, ckpt_path)
    loaded = tf.saved_model.load(export_dir)
    infer = loaded.signatures["serving_default"]
    out = infer(text=tf.constant(["abcd", "ef gh"]))
    self.assertLen(out["output_0"], 2)