def test_evaluation(self):
     config = translation.TranslationConfig(
         model=translation.ModelConfig(encoder=translation.EncDecoder(),
                                       decoder=translation.EncDecoder(),
                                       padded_decode=False,
                                       decode_max_length=64),
         validation_data=wmt_dataloader.WMTDataConfig(
             input_path=self._record_input_path,
             src_lang="en",
             tgt_lang="reverse_en",
             static_batch=True,
             global_batch_size=4),
         sentencepiece_model_path=self._sentencepeice_model_path)
     logging_dir = self.get_temp_dir()
     task = translation.TranslationTask(config, logging_dir=logging_dir)
     dataset = orbit.utils.make_distributed_dataset(
         tf.distribute.get_strategy(), task.build_inputs,
         config.validation_data)
     model = task.build_model()
     strategy = tf.distribute.get_strategy()
     aggregated = None
     for data in dataset:
         distributed_outputs = strategy.run(functools.partial(
             task.validation_step, model=model),
                                            args=(data, ))
         outputs = tf.nest.map_structure(
             strategy.experimental_local_results, distributed_outputs)
         aggregated = task.aggregate_logs(state=aggregated,
                                          step_outputs=outputs)
     metrics = task.reduce_aggregated_logs(aggregated)
     self.assertIn("sacrebleu_score", metrics)
     self.assertIn("bleu_score", metrics)
class ProgTranslationConfig(translation.TranslationConfig):
    """The progressive model config."""
    model: translation.ModelConfig = translation.ModelConfig(
        encoder=translation.EncDecoder(num_attention_heads=16,
                                       intermediate_size=4096),
        decoder=translation.EncDecoder(num_attention_heads=16,
                                       intermediate_size=4096),
        embedding_width=1024,
        padded_decode=True,
        decode_max_length=100)
    optimizer_config: optimization.OptimizationConfig = (
        optimization.OptimizationConfig({
            'optimizer': {
                'type': 'adam',
                'adam': {
                    'beta_2': 0.997,
                    'epsilon': 1e-9,
                },
            },
            'learning_rate': {
                'type': 'power',
                'power': {
                    'initial_learning_rate': 0.0625,
                    'power': -0.5,
                }
            },
            'warmup': {
                'type': 'linear',
                'linear': {
                    'warmup_steps': 16000,
                    'warmup_learning_rate': 0.0
                }
            }
        }))
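    # initial_learning_rate=0.0625 matches the embedding-width scaling used by
    # wmt_transformer_large() below: 2.0 * 1024 ** -0.5 = 2.0 / 32 = 0.0625.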

    stage_list: List[StackingStageConfig] = dataclasses.field(
        default_factory=lambda: [  # pylint: disable=g-long-lambda
            StackingStageConfig(num_encoder_layers=3,
                                num_decoder_layers=3,
                                num_steps=20000,
                                warmup_steps=5000,
                                initial_learning_rate=0.0625),
            StackingStageConfig(num_encoder_layers=6,
                                num_decoder_layers=6,
                                num_steps=20000,
                                warmup_steps=5000,
                                initial_learning_rate=0.0625),
            StackingStageConfig(num_encoder_layers=12,
                                num_decoder_layers=12,
                                num_steps=100000,
                                warmup_steps=5000,
                                initial_learning_rate=0.0625)
        ])
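# A minimal usage sketch, assuming the field defaults above are enough to construct the
# config: the default stages train 3-, 6- and 12-layer stacks for 20k, 20k and 100k
# steps, i.e. 140k progressive steps in total.
prog_config = ProgTranslationConfig()
assert sum(stage.num_steps for stage in prog_config.stage_list) == 140000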
def test_no_sentencepiece_path(self):
     config = translation.TranslationConfig(
         model=translation.ModelConfig(encoder=translation.EncDecoder(),
                                       decoder=translation.EncDecoder()),
         train_data=wmt_dataloader.WMTDataConfig(
             input_path=self._record_input_path,
             src_lang="en",
             tgt_lang="reverse_en",
             is_training=True,
             static_batch=True,
             global_batch_size=4,
             max_seq_length=4),
         sentencepiece_model_path=None)
     with self.assertRaisesRegex(ValueError,
                                 "Setencepiece model path not provided."):
         translation.TranslationTask(config)
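# assertRaisesRegex has to match the message raised by translation.TranslationTask
# verbatim, which is why the regex reproduces the library's spelling of "Setencepiece".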
def wmt_transformer_large_progressive() -> cfg.ExperimentConfig:
  """WMT Transformer Larger with progressive training.

  Please refer to
  tensorflow_models/official/nlp/data/train_sentencepiece.py
  to generate sentencepiece_model
  and pass
  --params_override=task.sentencepiece_model_path='YOUR_PATH'
  to the train script.
  """
  hidden_size = 1024
  train_steps = 300000
  token_batch_size = 24576
  encdecoder = translation.EncDecoder(
      num_attention_heads=16, intermediate_size=hidden_size * 4)
  config = cfg.ExperimentConfig(
      task=progressive_translation.ProgTranslationConfig(
          model=translation.ModelConfig(
              encoder=encdecoder,
              decoder=encdecoder,
              embedding_width=hidden_size,
              padded_decode=True,
              decode_max_length=100),
          train_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='train',
              src_lang='en',
              tgt_lang='de',
              is_training=True,
              global_batch_size=token_batch_size,
              static_batch=True,
              max_seq_length=64
          ),
          validation_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='test',
              src_lang='en',
              tgt_lang='de',
              is_training=False,
              global_batch_size=32,
              static_batch=True,
              max_seq_length=100,
          ),
          sentencepiece_model_path=None,
      ),
      trainer=prog_trainer_lib.ProgressiveTrainerConfig(
          train_steps=train_steps,
          validation_steps=-1,
          steps_per_loop=1000,
          summary_interval=1000,
          checkpoint_interval=5000,
          validation_interval=5000,
          optimizer_config=None,
      ),
      restrictions=[
          'task.train_data.is_training != None',
          'task.sentencepiece_model_path != None',
      ])
  return config
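# Compared with wmt_transformer_large below, the progressive experiment keeps the
# optimizer and learning-rate schedule in ProgTranslationConfig.optimizer_config (hence
# the trainer's optimizer_config=None) and adds the stage_list that grows the
# encoder/decoder depth during training.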
def test_task(self):
     config = translation.TranslationConfig(
         model=translation.ModelConfig(encoder=translation.EncDecoder(),
                                       decoder=translation.EncDecoder()),
         train_data=wmt_dataloader.WMTDataConfig(
             input_path=self._record_input_path,
             src_lang="en",
             tgt_lang="reverse_en",
             is_training=True,
             static_batch=True,
             global_batch_size=24,
             max_seq_length=12),
         sentencepiece_model_path=self._sentencepeice_model_path)
     task = translation.TranslationTask(config)
     model = task.build_model()
     dataset = task.build_inputs(config.train_data)
     iterator = iter(dataset)
     optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)  # `lr` is a deprecated alias.
     task.train_step(next(iterator), model, optimizer)
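# next(iterator) yields one batch from task.build_inputs, and a single task.train_step
# call exercises the full forward and backward pass with the SGD optimizer, which is all
# this smoke test needs.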
def setUp(self):
     super(ProgressiveTranslationTest, self).setUp()
     self._temp_dir = self.get_temp_dir()
     src_lines = ["abc ede fg", "bbcd ef a g", "de f a a g"]
     tgt_lines = ["dd cc a ef  g", "bcd ef a g", "gef cd ba"]
     self._record_input_path = os.path.join(self._temp_dir, "train.record")
     _generate_record_file(self._record_input_path, src_lines, tgt_lines)
     self._sentencepeice_input_path = os.path.join(self._temp_dir,
                                                   "inputs.txt")
     _generate_line_file(self._sentencepeice_input_path,
                         src_lines + tgt_lines)
     sentencepeice_model_prefix = os.path.join(self._temp_dir, "sp")
     _train_sentencepiece(self._sentencepeice_input_path, 11,
                          sentencepeice_model_prefix)
     self._sentencepeice_model_path = "{}.model".format(
         sentencepeice_model_prefix)
     encdecoder = translation.EncDecoder(num_attention_heads=2,
                                         intermediate_size=8)
     self.task_config = progressive_translation.ProgTranslationConfig(
         model=translation.ModelConfig(encoder=encdecoder,
                                       decoder=encdecoder,
                                       embedding_width=8,
                                       padded_decode=True,
                                       decode_max_length=100),
         train_data=wmt_dataloader.WMTDataConfig(
             input_path=self._record_input_path,
             is_training=True,
             global_batch_size=24,
             static_batch=True,
             src_lang="en",
             tgt_lang="reverse_en",
             max_seq_length=12),
         validation_data=wmt_dataloader.WMTDataConfig(
             input_path=self._record_input_path,
             is_training=False,
             global_batch_size=2,
             static_batch=True,
             src_lang="en",
             tgt_lang="reverse_en",
             max_seq_length=12),
         sentencepiece_model_path=self._sentencepeice_model_path,
         stage_list=[
             progressive_translation.StackingStageConfig(
                 num_encoder_layers=1, num_decoder_layers=1, num_steps=4),
             progressive_translation.StackingStageConfig(
                 num_encoder_layers=2, num_decoder_layers=1, num_steps=8),
         ],
     )
     self.exp_config = cfg.ExperimentConfig(
         task=self.task_config,
         trainer=prog_trainer_lib.ProgressiveTrainerConfig())
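# The fixture above derives everything the progressive tests need from three toy
# sentence pairs: a TFRecord input file, a small sentencepiece model (vocab size 11
# here), and a two-stage stacking schedule (1 encoder/1 decoder layer for 4 steps, then
# 2 encoder layers and 1 decoder layer for 8 steps).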
def test_sentencepiece_no_eos(self):
     sentencepeice_model_prefix = os.path.join(self._temp_dir, "sp_no_eos")
     _train_sentencepiece(self._sentencepeice_input_path,
                          20,
                          sentencepeice_model_prefix,
                          eos_id=-1)
     sentencepeice_model_path = "{}.model".format(
         sentencepeice_model_prefix)
     config = translation.TranslationConfig(
         model=translation.ModelConfig(encoder=translation.EncDecoder(),
                                       decoder=translation.EncDecoder()),
         train_data=wmt_dataloader.WMTDataConfig(
             input_path=self._record_input_path,
             src_lang="en",
             tgt_lang="reverse_en",
             is_training=True,
             static_batch=True,
             global_batch_size=4,
             max_seq_length=4),
         sentencepiece_model_path=sentencepeice_model_path)
     with self.assertRaisesRegex(ValueError,
                                 "EOS token not in tokenizer vocab.*"):
         translation.TranslationTask(config)
    def test_translation(self, padded_decode, batch_size):
        sp_path = _make_sentencepeice(self.get_temp_dir())
        encdecoder = translation.EncDecoder(num_attention_heads=4,
                                            intermediate_size=256)
        config = translation.TranslationConfig(
            model=translation.ModelConfig(encoder=encdecoder,
                                          decoder=encdecoder,
                                          embedding_width=256,
                                          padded_decode=padded_decode,
                                          decode_max_length=100),
            sentencepiece_model_path=sp_path,
        )
        task = translation.TranslationTask(config)
        model = task.build_model()

        params = serving_modules.Translation.Params(
            sentencepiece_model_path=sp_path, batch_size=batch_size)
        export_module = serving_modules.Translation(params=params, model=model)
        functions = export_module.get_inference_signatures(
            {"serve_text": "serving_default"})
        outputs = functions["serving_default"](tf.constant(["abcd", "ef gh"]))
        self.assertEqual(outputs.shape, (2, ))
        self.assertEqual(outputs.dtype, tf.string)

        tmp_dir = self.get_temp_dir()
        tmp_dir = os.path.join(tmp_dir, "padded_decode", str(padded_decode))
        export_base_dir = os.path.join(tmp_dir, "export")
        ckpt_dir = os.path.join(tmp_dir, "ckpt")
        ckpt_path = tf.train.Checkpoint(model=model).save(ckpt_dir)
        export_dir = export_base.export(export_module,
                                        {"serve_text": "serving_default"},
                                        export_base_dir, ckpt_path)
        loaded = tf.saved_model.load(export_dir)
        infer = loaded.signatures["serving_default"]
        out = infer(text=tf.constant(["abcd", "ef gh"]))
        self.assertLen(out["output_0"], 2)
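# The round trip above covers the full serving path: build the task model, wrap it in
# the Translation serving module, export a SavedModel via export_base.export, reload it,
# and invoke the "serving_default" signature on raw text inputs.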
def wmt_transformer_large() -> cfg.ExperimentConfig:
    """WMT Transformer Large.

  Please refer to
  tensorflow_models/official/nlp/data/train_sentencepiece.py
  to generate sentencepiece_model
  and pass
  --params_override=task.sentencepiece_model_path='YOUR_PATH'
  to the train script.
  """
    learning_rate = 2.0
    hidden_size = 1024
    learning_rate *= (hidden_size**-0.5)
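    # With hidden_size = 1024 this is 2.0 / 32 = 0.0625, the embedding-width-scaled
    # peak learning rate fed to the 'power' schedule below.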
    warmup_steps = 16000
    train_steps = 300000
    token_batch_size = 24576
    encdecoder = translation.EncDecoder(num_attention_heads=16,
                                        intermediate_size=hidden_size * 4)
    config = cfg.ExperimentConfig(
        runtime=cfg.RuntimeConfig(enable_xla=True),
        task=translation.TranslationConfig(
            model=translation.ModelConfig(encoder=encdecoder,
                                          decoder=encdecoder,
                                          embedding_width=hidden_size,
                                          padded_decode=True,
                                          decode_max_length=100),
            train_data=wmt_dataloader.WMTDataConfig(
                tfds_name='wmt14_translate/de-en',
                tfds_split='train',
                src_lang='en',
                tgt_lang='de',
                is_training=True,
                global_batch_size=token_batch_size,
                static_batch=True,
                max_seq_length=64),
            validation_data=wmt_dataloader.WMTDataConfig(
                tfds_name='wmt14_translate/de-en',
                tfds_split='test',
                src_lang='en',
                tgt_lang='de',
                is_training=False,
                global_batch_size=32,
                static_batch=True,
                max_seq_length=100,
            ),
            sentencepiece_model_path=None,
        ),
        trainer=cfg.TrainerConfig(
            train_steps=train_steps,
            validation_steps=-1,
            steps_per_loop=1000,
            summary_interval=1000,
            checkpoint_interval=5000,
            validation_interval=5000,
            max_to_keep=1,
            optimizer_config=optimization.OptimizationConfig({
                'optimizer': {
                    'type': 'adam',
                    'adam': {
                        'beta_2': 0.997,
                        'epsilon': 1e-9,
                    },
                },
                'learning_rate': {
                    'type': 'power',
                    'power': {
                        'initial_learning_rate': learning_rate,
                        'power': -0.5,
                    }
                },
                'warmup': {
                    'type': 'linear',
                    'linear': {
                        'warmup_steps': warmup_steps,
                        'warmup_learning_rate': 0.0
                    }
                }
            })),
        restrictions=[
            'task.train_data.is_training != None',
            'task.sentencepiece_model_path != None',
        ])
    return config
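# A minimal usage sketch: the factory above returns a fully populated ExperimentConfig,
# but task.sentencepiece_model_path is still None and, per the restrictions list, must
# be set before training; the path below is a placeholder.
experiment = wmt_transformer_large()
experiment.task.sentencepiece_model_path = '/path/to/sentencepiece.model'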