def test_load_dataset(self):
     batch_tokens_size = 100
     train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
     _create_fake_dataset(train_data_path)
     data_config = wmt_dataloader.WMTDataConfig(
         input_path=train_data_path,
         max_seq_length=35,
         global_batch_size=batch_tokens_size,
         is_training=True,
         static_batch=False)
     dataset = wmt_dataloader.WMTDataLoader(data_config).load()
     examples = next(iter(dataset))
     inputs, targets = examples['inputs'], examples['targets']
     logging.info('dynamic inputs=%s targets=%s', inputs, targets)
     data_config = wmt_dataloader.WMTDataConfig(
         input_path=train_data_path,
         max_seq_length=35,
         global_batch_size=batch_tokens_size,
         is_training=True,
         static_batch=True)
     dataset = wmt_dataloader.WMTDataLoader(data_config).load()
     examples = next(iter(dataset))
     inputs, targets = examples['inputs'], examples['targets']
     logging.info('static inputs=%s targets=%s', inputs, targets)
     self.assertEqual(inputs.shape, (2, 35))
     self.assertEqual(targets.shape, (2, 35))
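The helper _create_fake_dataset is not shown in these snippets; below is a minimal sketch (an assumption, not the actual helper) of the kind of fixture it would have to write: a TFRecord file of tf.train.Example protos with variable-length int64 'inputs' and 'targets' features for the loader to bucket.

import tensorflow as tf


def _write_fake_wmt_records(output_path, num_examples=20):
  """Writes serialized tf.train.Example records with fake token ids."""
  with tf.io.TFRecordWriter(output_path) as writer:
    for i in range(num_examples):
      length = 10 + i % 11  # varying lengths so dynamic bucketing has work to do
      feature = {
          'inputs': tf.train.Feature(
              int64_list=tf.train.Int64List(value=list(range(1, length + 1)))),
          'targets': tf.train.Feature(
              int64_list=tf.train.Int64List(value=list(range(1, length + 1)))),
      }
      writer.write(
          tf.train.Example(
              features=tf.train.Features(feature=feature)).SerializeToString())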
Example #2
def wmt_transformer_large_progressive() -> cfg.ExperimentConfig:
  """WMT Transformer Larger with progressive training.

  Please refer to
  tensorflow_models/official/nlp/data/train_sentencepiece.py
  to generate sentencepiece_model
  and pass
  --params_override=task.sentencepiece_model_path='YOUR_PATH'
  to the train script.
  """
  hidden_size = 1024
  train_steps = 300000
  token_batch_size = 24576
  encdecoder = translation.EncDecoder(
      num_attention_heads=16, intermediate_size=hidden_size * 4)
  config = cfg.ExperimentConfig(
      task=progressive_translation.ProgTranslationConfig(
          model=translation.ModelConfig(
              encoder=encdecoder,
              decoder=encdecoder,
              embedding_width=hidden_size,
              padded_decode=True,
              decode_max_length=100),
          train_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='train',
              src_lang='en',
              tgt_lang='de',
              is_training=True,
              global_batch_size=token_batch_size,
              static_batch=True,
              max_seq_length=64
          ),
          validation_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='test',
              src_lang='en',
              tgt_lang='de',
              is_training=False,
              global_batch_size=32,
              static_batch=True,
              max_seq_length=100,
          ),
          sentencepiece_model_path=None,
      ),
      trainer=prog_trainer_lib.ProgressiveTrainerConfig(
          train_steps=train_steps,
          validation_steps=-1,
          steps_per_loop=1000,
          summary_interval=1000,
          checkpoint_interval=5000,
          validation_interval=5000,
          optimizer_config=None,
      ),
      restrictions=[
          'task.train_data.is_training != None',
          'task.sentencepiece_model_path != None',
      ])
  return config
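The restrictions above require task.sentencepiece_model_path to be set before training; a minimal sketch (the path is a placeholder) of filling it in on the returned config, using the override method that Model Garden configs inherit and that mirrors the --params_override flag mentioned in the docstring:

config = wmt_transformer_large_progressive()
config.override(
    {'task': {'sentencepiece_model_path': '/path/to/sentencepiece.model'}})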
Example #3
 def setUp(self):
     super(ProgressiveTranslationTest, self).setUp()
     self._temp_dir = self.get_temp_dir()
     src_lines = ["abc ede fg", "bbcd ef a g", "de f a a g"]
     tgt_lines = ["dd cc a ef  g", "bcd ef a g", "gef cd ba"]
     self._record_input_path = os.path.join(self._temp_dir, "train.record")
     _generate_record_file(self._record_input_path, src_lines, tgt_lines)
     self._sentencepeice_input_path = os.path.join(self._temp_dir,
                                                   "inputs.txt")
     _generate_line_file(self._sentencepeice_input_path,
                         src_lines + tgt_lines)
     sentencepeice_model_prefix = os.path.join(self._temp_dir, "sp")
     _train_sentencepiece(self._sentencepeice_input_path, 11,
                          sentencepeice_model_prefix)
     self._sentencepeice_model_path = "{}.model".format(
         sentencepeice_model_prefix)
     encdecoder = translation.EncDecoder(num_attention_heads=2,
                                         intermediate_size=8)
     self.task_config = progressive_translation.ProgTranslationConfig(
         model=translation.ModelConfig(encoder=encdecoder,
                                       decoder=encdecoder,
                                       embedding_width=8,
                                       padded_decode=True,
                                       decode_max_length=100),
         train_data=wmt_dataloader.WMTDataConfig(
             input_path=self._record_input_path,
             is_training=True,
             global_batch_size=24,
             static_batch=True,
             src_lang="en",
             tgt_lang="reverse_en",
             max_seq_length=12),
         validation_data=wmt_dataloader.WMTDataConfig(
             input_path=self._record_input_path,
             is_training=False,
             global_batch_size=2,
             static_batch=True,
             src_lang="en",
             tgt_lang="reverse_en",
             max_seq_length=12),
         sentencepiece_model_path=self._sentencepeice_model_path,
         stage_list=[
             progressive_translation.StackingStageConfig(
                 num_encoder_layers=1, num_decoder_layers=1, num_steps=4),
             progressive_translation.StackingStageConfig(
                 num_encoder_layers=2, num_decoder_layers=1, num_steps=8),
         ],
     )
     self.exp_config = cfg.ExperimentConfig(
         task=self.task_config,
         trainer=prog_trainer_lib.ProgressiveTrainerConfig())
 def test_evaluation(self):
     config = translation.TranslationConfig(
         model=translation.ModelConfig(encoder=translation.EncDecoder(),
                                       decoder=translation.EncDecoder(),
                                       padded_decode=False,
                                       decode_max_length=64),
         validation_data=wmt_dataloader.WMTDataConfig(
             input_path=self._record_input_path,
             src_lang="en",
             tgt_lang="reverse_en",
             static_batch=True,
             global_batch_size=4),
         sentencepiece_model_path=self._sentencepeice_model_path)
     logging_dir = self.get_temp_dir()
     task = translation.TranslationTask(config, logging_dir=logging_dir)
     dataset = orbit.utils.make_distributed_dataset(
         tf.distribute.get_strategy(), task.build_inputs,
         config.validation_data)
     model = task.build_model()
     strategy = tf.distribute.get_strategy()
     aggregated = None
     for data in dataset:
         distributed_outputs = strategy.run(functools.partial(
             task.validation_step, model=model),
                                            args=(data, ))
         outputs = tf.nest.map_structure(
             strategy.experimental_local_results, distributed_outputs)
         aggregated = task.aggregate_logs(state=aggregated,
                                          step_outputs=outputs)
     metrics = task.reduce_aggregated_logs(aggregated)
     self.assertIn("sacrebleu_score", metrics)
     self.assertIn("bleu_score", metrics)
Example #5
 def test_load_dataset_raise_invalid_window(self):
     batch_tokens_size = 10  # this is too small to form buckets.
     train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
     _create_fake_dataset(train_data_path)
     data_config = wmt_dataloader.WMTDataConfig(
         input_path=train_data_path,
         max_seq_length=100,
         global_batch_size=batch_tokens_size)
     with self.assertRaisesRegex(
             ValueError,
             'The token budget, global batch size, is too small.*'):
         _ = wmt_dataloader.WMTDataLoader(data_config).load()
 def test_load_dataset_raise_invalid_window(self):
   batch_tokens_size = 10  # this is too small to form buckets.
   data_config = wmt_dataloader.WMTDataConfig(
       input_path=self._record_train_input_path,
       max_seq_length=100,
       global_batch_size=batch_tokens_size,
       is_training=True,
       static_batch=False,
       src_lang='en',
       tgt_lang='reverse_en',
       sentencepiece_model_path=self._sentencepeice_model_path)
   with self.assertRaisesRegex(
       ValueError, 'The token budget, global batch size, is too small.*'):
     _ = wmt_dataloader.WMTDataLoader(data_config).load()
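Both versions of this test exercise the same constraint: with dynamic (bucketed) batching, global_batch_size is a per-batch token budget rather than an example count, so it has to be large enough to hold at least one max_seq_length example. A sketch of a configuration that would load cleanly inside the same test class, using the fixtures its (not shown) setUp provides:

data_config = wmt_dataloader.WMTDataConfig(
    input_path=self._record_train_input_path,
    max_seq_length=35,
    global_batch_size=100,  # token budget per batch, comfortably above 35
    is_training=True,
    static_batch=False,
    src_lang='en',
    tgt_lang='reverse_en',
    sentencepiece_model_path=self._sentencepeice_model_path)
_ = wmt_dataloader.WMTDataLoader(data_config).load()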
 def test_no_sentencepiece_path(self):
     config = translation.TranslationConfig(
         model=translation.ModelConfig(encoder=translation.EncDecoder(),
                                       decoder=translation.EncDecoder()),
         train_data=wmt_dataloader.WMTDataConfig(
             input_path=self._record_input_path,
             src_lang="en",
             tgt_lang="reverse_en",
             is_training=True,
             static_batch=True,
             global_batch_size=4,
             max_seq_length=4),
         sentencepiece_model_path=None)
     with self.assertRaisesRegex(ValueError,
                                 "Setencepiece model path not provided."):
         translation.TranslationTask(config)
 def test_load_dataset(
     self, is_training, static_batch, batch_size, expected_shape):
   data_config = wmt_dataloader.WMTDataConfig(
       input_path=self._record_train_input_path
       if is_training else self._record_test_input_path,
       max_seq_length=35,
       global_batch_size=batch_size,
       is_training=is_training,
       static_batch=static_batch,
       src_lang='en',
       tgt_lang='reverse_en',
       sentencepiece_model_path=self._sentencepeice_model_path)
   dataset = wmt_dataloader.WMTDataLoader(data_config).load()
   examples = next(iter(dataset))
   inputs, targets = examples['inputs'], examples['targets']
   self.assertEqual(inputs.shape, expected_shape)
   self.assertEqual(targets.shape, expected_shape)
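This version of test_load_dataset receives its settings as arguments, so in its test class it is presumably driven by absl's parameterized decorator; an illustrative sketch of that wiring, where the concrete cases and expected shapes are assumptions rather than the originals:

from absl.testing import parameterized
import tensorflow as tf


class WMTDataLoaderTestSketch(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(
      ('train_static', True, True, 4, (4, 35)),
      ('eval_static', False, True, 4, (4, 35)),
  )
  def test_load_dataset(self, is_training, static_batch, batch_size,
                        expected_shape):
    pass  # body as in the snippet above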
 def test_task(self):
     config = translation.TranslationConfig(
         model=translation.ModelConfig(encoder=translation.EncDecoder(),
                                       decoder=translation.EncDecoder()),
         train_data=wmt_dataloader.WMTDataConfig(
             input_path=self._record_input_path,
             src_lang="en",
             tgt_lang="reverse_en",
             is_training=True,
             static_batch=True,
             global_batch_size=24,
             max_seq_length=12),
         sentencepiece_model_path=self._sentencepeice_model_path)
     task = translation.TranslationTask(config)
     model = task.build_model()
     dataset = task.build_inputs(config.train_data)
     iterator = iter(dataset)
      optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
     task.train_step(next(iterator), model, optimizer)
 def test_sentencepiece_no_eos(self):
     sentencepeice_model_prefix = os.path.join(self._temp_dir, "sp_no_eos")
     _train_sentencepiece(self._sentencepeice_input_path,
                          20,
                          sentencepeice_model_prefix,
                          eos_id=-1)
     sentencepeice_model_path = "{}.model".format(
         sentencepeice_model_prefix)
     config = translation.TranslationConfig(
         model=translation.ModelConfig(encoder=translation.EncDecoder(),
                                       decoder=translation.EncDecoder()),
         train_data=wmt_dataloader.WMTDataConfig(
             input_path=self._record_input_path,
             src_lang="en",
             tgt_lang="reverse_en",
             is_training=True,
             static_batch=True,
             global_batch_size=4,
             max_seq_length=4),
         sentencepiece_model_path=sentencepeice_model_path)
     with self.assertRaisesRegex(ValueError,
                                 "EOS token not in tokenizer vocab.*"):
         translation.TranslationTask(config)
Example #11
def wmt_transformer_large() -> cfg.ExperimentConfig:
    """WMT Transformer Large.

    Please refer to
    tensorflow_models/official/nlp/data/train_sentencepiece.py
    to generate sentencepiece_model
    and pass
    --params_override=task.sentencepiece_model_path='YOUR_PATH'
    to the train script.
    """
    learning_rate = 2.0
    hidden_size = 1024
    learning_rate *= (hidden_size**-0.5)
    warmup_steps = 16000
    train_steps = 300000
    token_batch_size = 24576
    encdecoder = translation.EncDecoder(num_attention_heads=16,
                                        intermediate_size=hidden_size * 4)
    config = cfg.ExperimentConfig(
        runtime=cfg.RuntimeConfig(enable_xla=True),
        task=translation.TranslationConfig(
            model=translation.ModelConfig(encoder=encdecoder,
                                          decoder=encdecoder,
                                          embedding_width=hidden_size,
                                          padded_decode=True,
                                          decode_max_length=100),
            train_data=wmt_dataloader.WMTDataConfig(
                tfds_name='wmt14_translate/de-en',
                tfds_split='train',
                src_lang='en',
                tgt_lang='de',
                is_training=True,
                global_batch_size=token_batch_size,
                static_batch=True,
                max_seq_length=64),
            validation_data=wmt_dataloader.WMTDataConfig(
                tfds_name='wmt14_translate/de-en',
                tfds_split='test',
                src_lang='en',
                tgt_lang='de',
                is_training=False,
                global_batch_size=32,
                static_batch=True,
                max_seq_length=100,
            ),
            sentencepiece_model_path=None,
        ),
        trainer=cfg.TrainerConfig(
            train_steps=train_steps,
            validation_steps=-1,
            steps_per_loop=1000,
            summary_interval=1000,
            checkpoint_interval=5000,
            validation_interval=5000,
            max_to_keep=1,
            optimizer_config=optimization.OptimizationConfig({
                'optimizer': {
                    'type': 'adam',
                    'adam': {
                        'beta_2': 0.997,
                        'epsilon': 1e-9,
                    },
                },
                'learning_rate': {
                    'type': 'power',
                    'power': {
                        'initial_learning_rate': learning_rate,
                        'power': -0.5,
                    }
                },
                'warmup': {
                    'type': 'linear',
                    'linear': {
                        'warmup_steps': warmup_steps,
                        'warmup_learning_rate': 0.0
                    }
                }
            })),
        restrictions=[
            'task.train_data.is_training != None',
            'task.sentencepiece_model_path != None',
        ])
    return config
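In the Model Garden these factory functions are normally exposed through the experiment registry so a training driver can look them up by name; a sketch of that wiring, where the registered name 'wmt_transformer/large' is an assumption:

from official.core import exp_factory

exp_factory.register_config_factory('wmt_transformer/large')(
    wmt_transformer_large)

# A trainer script can then recover the config by name:
experiment_config = exp_factory.get_exp_config('wmt_transformer/large')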