def test_load_dataset(self):
    batch_tokens_size = 100
    train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
    _create_fake_dataset(train_data_path)
    data_config = wmt_dataloader.WMTDataConfig(
        input_path=train_data_path,
        max_seq_length=35,
        global_batch_size=batch_tokens_size,
        is_training=True,
        static_batch=False)
    dataset = wmt_dataloader.WMTDataLoader(data_config).load()
    examples = next(iter(dataset))
    inputs, targets = examples['inputs'], examples['targets']
    logging.info('dynamic inputs=%s targets=%s', inputs, targets)
    data_config = wmt_dataloader.WMTDataConfig(
        input_path=train_data_path,
        max_seq_length=35,
        global_batch_size=batch_tokens_size,
        is_training=True,
        static_batch=True)
    dataset = wmt_dataloader.WMTDataLoader(data_config).load()
    examples = next(iter(dataset))
    inputs, targets = examples['inputs'], examples['targets']
    logging.info('static inputs=%s targets=%s', inputs, targets)
    self.assertEqual(inputs.shape, (2, 35))
    self.assertEqual(targets.shape, (2, 35))
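These tests call a `_create_fake_dataset` helper that the snippets do not show. A minimal sketch of what such a helper might look like, assuming the loader reads `tf.train.Example` records with variable-length `inputs` and `targets` int64 token features (the field names, lengths, and record count here are assumptions, not the original helper):

import tensorflow as tf

def _create_fake_dataset(output_path, num_examples=10):
    # Hypothetical sketch: writes tf.train.Example records with
    # variable-length 'inputs'/'targets' int64 features so that
    # bucketing by sequence length is exercised.
    with tf.io.TFRecordWriter(output_path) as writer:
        for i in range(num_examples):
            tokens = list(range(1, i + 2))  # lengths 1, 2, ..., num_examples
            feature = {
                'inputs': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=tokens)),
                'targets': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=tokens)),
            }
            example = tf.train.Example(
                features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())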
Example #2
def test_load_dataset_raise_invalid_window(self):
    batch_tokens_size = 10  # this is too small to form buckets.
    train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
    _create_fake_dataset(train_data_path)
    data_config = wmt_dataloader.WMTDataConfig(
        input_path=train_data_path,
        max_seq_length=100,
        global_batch_size=batch_tokens_size)
    with self.assertRaisesRegex(
            ValueError,
            'The token budget, global batch size, is too small.*'):
        _ = wmt_dataloader.WMTDataLoader(data_config).load()
Example #3
def test_load_dataset_raise_invalid_window(self):
   batch_tokens_size = 10  # this is too small to form buckets.
   data_config = wmt_dataloader.WMTDataConfig(
       input_path=self._record_train_input_path,
       max_seq_length=100,
       global_batch_size=batch_tokens_size,
       is_training=True,
       static_batch=False,
       src_lang='en',
       tgt_lang='reverse_en',
       sentencepiece_model_path=self._sentencepeice_model_path)
   with self.assertRaisesRegex(
       ValueError, 'The token budget, global batch size, is too small.*'):
     _ = wmt_dataloader.WMTDataLoader(data_config).load()
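Both versions of this test fail for the same reason: with dynamic batching, `global_batch_size` acts as a per-batch token budget, so it must be at least `max_seq_length` for even one max-length sequence to fit in a bucket. An illustrative fix, reusing the setup of the first invalid-window test (the budget of 400 is an arbitrary assumed value):

data_config = wmt_dataloader.WMTDataConfig(
    input_path=train_data_path,
    max_seq_length=100,
    global_batch_size=400,  # token budget now covers a max-length sequence
    is_training=True,
    static_batch=False)
dataset = wmt_dataloader.WMTDataLoader(data_config).load()  # no ValueError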
Example #4
def test_load_dataset(
     self, is_training, static_batch, batch_size, expected_shape):
   data_config = wmt_dataloader.WMTDataConfig(
       input_path=self._record_train_input_path
       if is_training else self._record_test_input_path,
       max_seq_length=35,
       global_batch_size=batch_size,
       is_training=is_training,
       static_batch=static_batch,
       src_lang='en',
       tgt_lang='reverse_en',
       sentencepiece_model_path=self._sentencepeice_model_path)
   dataset = wmt_dataloader.WMTDataLoader(data_config).load()
   examples = next(iter(dataset))
   inputs, targets = examples['inputs'], examples['targets']
   self.assertEqual(inputs.shape, expected_shape)
   self.assertEqual(targets.shape, expected_shape)
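This last `test_load_dataset` takes extra arguments, so it is presumably driven by a parameterization decorator that the snippet omits. A sketch using `absl.testing.parameterized` (the case names, batch sizes, and expected shapes are illustrative assumptions, chosen to mirror the shapes in the first example):

from absl.testing import parameterized

# Hypothetical parameterization for test_load_dataset above; the case
# names and values are assumptions, not the original test's parameters.
@parameterized.named_parameters(
    ('train_dynamic', True, False, 100, (2, 35)),
    ('train_static', True, True, 4, (4, 35)),
    ('eval_static', False, True, 4, (4, 35)))
def test_load_dataset(
        self, is_training, static_batch, batch_size, expected_shape):
    ...  # body as in Example #4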