def test_load_dataset(self):
  batch_tokens_size = 100
  train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
  _create_fake_dataset(train_data_path)
  data_config = wmt_dataloader.WMTDataConfig(
      input_path=train_data_path,
      max_seq_length=35,
      global_batch_size=batch_tokens_size,
      is_training=True,
      static_batch=False)
  dataset = wmt_dataloader.WMTDataLoader(data_config).load()
  examples = next(iter(dataset))
  inputs, targets = examples['inputs'], examples['targets']
  logging.info('dynamic inputs=%s targets=%s', inputs, targets)
  data_config = wmt_dataloader.WMTDataConfig(
      input_path=train_data_path,
      max_seq_length=35,
      global_batch_size=batch_tokens_size,
      is_training=True,
      static_batch=True)
  dataset = wmt_dataloader.WMTDataLoader(data_config).load()
  examples = next(iter(dataset))
  inputs, targets = examples['inputs'], examples['targets']
  logging.info('static inputs=%s targets=%s', inputs, targets)
  self.assertEqual(inputs.shape, (2, 35))
  self.assertEqual(targets.shape, (2, 35))

def wmt_transformer_large_progressive() -> cfg.ExperimentConfig:
  """WMT Transformer Large with progressive training.

  Please refer to tensorflow_models/official/nlp/data/train_sentencepiece.py
  to generate sentencepiece_model and pass
  --params_override=task.sentencepiece_model_path='YOUR_PATH' to the train
  script.
  """
  hidden_size = 1024
  train_steps = 300000
  token_batch_size = 24576
  encdecoder = translation.EncDecoder(
      num_attention_heads=16, intermediate_size=hidden_size * 4)
  config = cfg.ExperimentConfig(
      task=progressive_translation.ProgTranslationConfig(
          model=translation.ModelConfig(
              encoder=encdecoder,
              decoder=encdecoder,
              embedding_width=hidden_size,
              padded_decode=True,
              decode_max_length=100),
          train_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='train',
              src_lang='en',
              tgt_lang='de',
              is_training=True,
              global_batch_size=token_batch_size,
              static_batch=True,
              max_seq_length=64),
          validation_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='test',
              src_lang='en',
              tgt_lang='de',
              is_training=False,
              global_batch_size=32,
              static_batch=True,
              max_seq_length=100),
          sentencepiece_model_path=None,
      ),
      trainer=prog_trainer_lib.ProgressiveTrainerConfig(
          train_steps=train_steps,
          validation_steps=-1,
          steps_per_loop=1000,
          summary_interval=1000,
          checkpoint_interval=5000,
          validation_interval=5000,
          optimizer_config=None,
      ),
      restrictions=[
          'task.train_data.is_training != None',
          'task.sentencepiece_model_path != None',
      ])
  return config

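# Hedged usage sketch, not part of the config above: the sentencepiece model is
# produced with tensorflow_models/official/nlp/data/train_sentencepiece.py and
# its path is injected at launch time, as the docstring describes. The trainer
# script path, experiment name, and extra flags below are illustrative
# assumptions rather than values taken from this file; only the
# --params_override string comes from the docstring:
#
#   python3 official/nlp/train.py \
#     --experiment=<registered_progressive_wmt_experiment> \
#     --mode=train_and_eval \
#     --model_dir=/tmp/wmt_prog \
#     --params_override=task.sentencepiece_model_path='YOUR_PATH'
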
def setUp(self):
  super(ProgressiveTranslationTest, self).setUp()
  self._temp_dir = self.get_temp_dir()
  src_lines = ["abc ede fg", "bbcd ef a g", "de f a a g"]
  tgt_lines = ["dd cc a ef g", "bcd ef a g", "gef cd ba"]
  self._record_input_path = os.path.join(self._temp_dir, "train.record")
  _generate_record_file(self._record_input_path, src_lines, tgt_lines)
  self._sentencepeice_input_path = os.path.join(self._temp_dir, "inputs.txt")
  _generate_line_file(self._sentencepeice_input_path, src_lines + tgt_lines)
  sentencepeice_model_prefix = os.path.join(self._temp_dir, "sp")
  _train_sentencepiece(self._sentencepeice_input_path, 11,
                       sentencepeice_model_prefix)
  self._sentencepeice_model_path = "{}.model".format(
      sentencepeice_model_prefix)
  encdecoder = translation.EncDecoder(
      num_attention_heads=2, intermediate_size=8)
  self.task_config = progressive_translation.ProgTranslationConfig(
      model=translation.ModelConfig(
          encoder=encdecoder,
          decoder=encdecoder,
          embedding_width=8,
          padded_decode=True,
          decode_max_length=100),
      train_data=wmt_dataloader.WMTDataConfig(
          input_path=self._record_input_path,
          is_training=True,
          global_batch_size=24,
          static_batch=True,
          src_lang="en",
          tgt_lang="reverse_en",
          max_seq_length=12),
      validation_data=wmt_dataloader.WMTDataConfig(
          input_path=self._record_input_path,
          is_training=False,
          global_batch_size=2,
          static_batch=True,
          src_lang="en",
          tgt_lang="reverse_en",
          max_seq_length=12),
      sentencepiece_model_path=self._sentencepeice_model_path,
      stage_list=[
          progressive_translation.StackingStageConfig(
              num_encoder_layers=1, num_decoder_layers=1, num_steps=4),
          progressive_translation.StackingStageConfig(
              num_encoder_layers=2, num_decoder_layers=1, num_steps=8),
      ],
  )
  self.exp_config = cfg.ExperimentConfig(
      task=self.task_config,
      trainer=prog_trainer_lib.ProgressiveTrainerConfig())

def test_evaluation(self):
  config = translation.TranslationConfig(
      model=translation.ModelConfig(
          encoder=translation.EncDecoder(),
          decoder=translation.EncDecoder(),
          padded_decode=False,
          decode_max_length=64),
      validation_data=wmt_dataloader.WMTDataConfig(
          input_path=self._record_input_path,
          src_lang="en",
          tgt_lang="reverse_en",
          static_batch=True,
          global_batch_size=4),
      sentencepiece_model_path=self._sentencepeice_model_path)
  logging_dir = self.get_temp_dir()
  task = translation.TranslationTask(config, logging_dir=logging_dir)
  dataset = orbit.utils.make_distributed_dataset(
      tf.distribute.get_strategy(), task.build_inputs,
      config.validation_data)
  model = task.build_model()
  strategy = tf.distribute.get_strategy()
  aggregated = None
  for data in dataset:
    distributed_outputs = strategy.run(
        functools.partial(task.validation_step, model=model), args=(data,))
    outputs = tf.nest.map_structure(
        strategy.experimental_local_results, distributed_outputs)
    aggregated = task.aggregate_logs(state=aggregated, step_outputs=outputs)
  metrics = task.reduce_aggregated_logs(aggregated)
  self.assertIn("sacrebleu_score", metrics)
  self.assertIn("bleu_score", metrics)

def test_load_dataset_raise_invalid_window(self):
  batch_tokens_size = 10  # this is too small to form buckets.
  train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
  _create_fake_dataset(train_data_path)
  data_config = wmt_dataloader.WMTDataConfig(
      input_path=train_data_path,
      max_seq_length=100,
      global_batch_size=batch_tokens_size)
  with self.assertRaisesRegex(
      ValueError, 'The token budget, global batch size, is too small.*'):
    _ = wmt_dataloader.WMTDataLoader(data_config).load()

def test_load_dataset_raise_invalid_window(self):
  batch_tokens_size = 10  # this is too small to form buckets.
  data_config = wmt_dataloader.WMTDataConfig(
      input_path=self._record_train_input_path,
      max_seq_length=100,
      global_batch_size=batch_tokens_size,
      is_training=True,
      static_batch=False,
      src_lang='en',
      tgt_lang='reverse_en',
      sentencepiece_model_path=self._sentencepeice_model_path)
  with self.assertRaisesRegex(
      ValueError, 'The token budget, global batch size, is too small.*'):
    _ = wmt_dataloader.WMTDataLoader(data_config).load()

def test_no_sentencepiece_path(self):
  config = translation.TranslationConfig(
      model=translation.ModelConfig(
          encoder=translation.EncDecoder(), decoder=translation.EncDecoder()),
      train_data=wmt_dataloader.WMTDataConfig(
          input_path=self._record_input_path,
          src_lang="en",
          tgt_lang="reverse_en",
          is_training=True,
          static_batch=True,
          global_batch_size=4,
          max_seq_length=4),
      sentencepiece_model_path=None)
  with self.assertRaisesRegex(ValueError,
                              "Setencepiece model path not provided."):
    translation.TranslationTask(config)

def test_load_dataset(self, is_training, static_batch, batch_size,
                      expected_shape):
  data_config = wmt_dataloader.WMTDataConfig(
      input_path=self._record_train_input_path
      if is_training else self._record_test_input_path,
      max_seq_length=35,
      global_batch_size=batch_size,
      is_training=is_training,
      static_batch=static_batch,
      src_lang='en',
      tgt_lang='reverse_en',
      sentencepiece_model_path=self._sentencepeice_model_path)
  dataset = wmt_dataloader.WMTDataLoader(data_config).load()
  examples = next(iter(dataset))
  inputs, targets = examples['inputs'], examples['targets']
  self.assertEqual(inputs.shape, expected_shape)
  self.assertEqual(targets.shape, expected_shape)

def test_task(self):
  config = translation.TranslationConfig(
      model=translation.ModelConfig(
          encoder=translation.EncDecoder(), decoder=translation.EncDecoder()),
      train_data=wmt_dataloader.WMTDataConfig(
          input_path=self._record_input_path,
          src_lang="en",
          tgt_lang="reverse_en",
          is_training=True,
          static_batch=True,
          global_batch_size=24,
          max_seq_length=12),
      sentencepiece_model_path=self._sentencepeice_model_path)
  task = translation.TranslationTask(config)
  model = task.build_model()
  dataset = task.build_inputs(config.train_data)
  iterator = iter(dataset)
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  task.train_step(next(iterator), model, optimizer)

def test_sentencepiece_no_eos(self):
  sentencepeice_model_prefix = os.path.join(self._temp_dir, "sp_no_eos")
  _train_sentencepiece(self._sentencepeice_input_path, 20,
                       sentencepeice_model_prefix, eos_id=-1)
  sentencepeice_model_path = "{}.model".format(sentencepeice_model_prefix)
  config = translation.TranslationConfig(
      model=translation.ModelConfig(
          encoder=translation.EncDecoder(), decoder=translation.EncDecoder()),
      train_data=wmt_dataloader.WMTDataConfig(
          input_path=self._record_input_path,
          src_lang="en",
          tgt_lang="reverse_en",
          is_training=True,
          static_batch=True,
          global_batch_size=4,
          max_seq_length=4),
      sentencepiece_model_path=sentencepeice_model_path)
  with self.assertRaisesRegex(ValueError,
                              "EOS token not in tokenizer vocab.*"):
    translation.TranslationTask(config)

def wmt_transformer_large() -> cfg.ExperimentConfig:
  """WMT Transformer Large.

  Please refer to tensorflow_models/official/nlp/data/train_sentencepiece.py
  to generate sentencepiece_model and pass
  --params_override=task.sentencepiece_model_path='YOUR_PATH' to the train
  script.
  """
  learning_rate = 2.0
  hidden_size = 1024
  learning_rate *= (hidden_size**-0.5)
  warmup_steps = 16000
  train_steps = 300000
  token_batch_size = 24576
  encdecoder = translation.EncDecoder(
      num_attention_heads=16, intermediate_size=hidden_size * 4)
  config = cfg.ExperimentConfig(
      runtime=cfg.RuntimeConfig(enable_xla=True),
      task=translation.TranslationConfig(
          model=translation.ModelConfig(
              encoder=encdecoder,
              decoder=encdecoder,
              embedding_width=hidden_size,
              padded_decode=True,
              decode_max_length=100),
          train_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='train',
              src_lang='en',
              tgt_lang='de',
              is_training=True,
              global_batch_size=token_batch_size,
              static_batch=True,
              max_seq_length=64),
          validation_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='test',
              src_lang='en',
              tgt_lang='de',
              is_training=False,
              global_batch_size=32,
              static_batch=True,
              max_seq_length=100),
          sentencepiece_model_path=None,
      ),
      trainer=cfg.TrainerConfig(
          train_steps=train_steps,
          validation_steps=-1,
          steps_per_loop=1000,
          summary_interval=1000,
          checkpoint_interval=5000,
          validation_interval=5000,
          max_to_keep=1,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adam',
                  'adam': {
                      'beta_2': 0.997,
                      'epsilon': 1e-9,
                  },
              },
              'learning_rate': {
                  'type': 'power',
                  'power': {
                      'initial_learning_rate': learning_rate,
                      'power': -0.5,
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': warmup_steps,
                      'warmup_learning_rate': 0.0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.sentencepiece_model_path != None',
      ])
  return config
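

# A minimal programmatic sketch, not part of the experiment definition above:
# the factory leaves sentencepiece_model_path as None, so the
# `task.sentencepiece_model_path != None` restriction must be satisfied before
# training. The model path below is a placeholder, and direct attribute
# assignment plus validate() are assumed from the hyperparams Config API that
# backs cfg.ExperimentConfig.
def _example_wmt_transformer_large_with_sentencepiece(
    sentencepiece_model_path: str = '/path/to/sentencepiece.model'
) -> cfg.ExperimentConfig:
  config = wmt_transformer_large()
  config.task.sentencepiece_model_path = sentencepiece_model_path
  config.validate()  # Evaluates the restrictions declared on the config.
  return config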