@parameterized.parameters(True, False)
def test_load_dataset(self, include_sentence_id):
  seq_length = 16
  batch_size = 10
  train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
  _create_fake_dataset(train_data_path, seq_length, include_sentence_id)
  data_config = tagging_dataloader.TaggingDataConfig(
      input_path=train_data_path,
      seq_length=seq_length,
      global_batch_size=batch_size,
      include_sentence_id=include_sentence_id)

  dataset = tagging_dataloader.TaggingDataLoader(data_config).load()
  features, labels = next(iter(dataset))

  # The sentence-id features are only present when requested in the config.
  expected_keys = ['input_word_ids', 'input_mask', 'input_type_ids']
  if include_sentence_id:
    expected_keys.extend(['sentence_id', 'sub_sentence_id'])
  self.assertCountEqual(expected_keys, features.keys())

  # Token-level features and labels are padded/truncated to seq_length.
  self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
  self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
  self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
  self.assertEqual(labels.shape, (batch_size, seq_length))
  if include_sentence_id:
    # Sentence ids are per-example scalars, so they batch to rank 1.
    self.assertEqual(features['sentence_id'].shape, (batch_size,))
    self.assertEqual(features['sub_sentence_id'].shape, (batch_size,))
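# A minimal sketch of the _create_fake_dataset helper the test calls; it is
# referenced but not shown in this section. The on-disk feature names
# ('input_ids', 'segment_ids', 'label_ids') and the record count are
# assumptions here: TaggingDataLoader is presumed to map them to the output
# keys asserted above ('input_word_ids', 'input_type_ids', etc.).
def _create_fake_dataset(output_path, seq_length, include_sentence_id):
  """Writes fake tf.train.Examples for the tagging loader to a TFRecord."""
  writer = tf.io.TFRecordWriter(output_path)

  def create_int_feature(values):
    return tf.train.Feature(
        int64_list=tf.train.Int64List(value=list(values)))

  for i in range(100):
    features = {}
    input_ids = np.random.randint(100, size=(seq_length))
    features['input_ids'] = create_int_feature(input_ids)
    features['input_mask'] = create_int_feature(np.ones_like(input_ids))
    features['segment_ids'] = create_int_feature(np.ones_like(input_ids))
    features['label_ids'] = create_int_feature(
        np.random.randint(10, size=(seq_length)))
    if include_sentence_id:
      features['sentence_id'] = create_int_feature([i])
      features['sub_sentence_id'] = create_int_feature([0])

    tf_example = tf.train.Example(
        features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
  writer.close()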
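# A sketch of the module scaffolding these tests assume. The import path for
# tagging_dataloader follows the TF Model Garden layout and is an assumption
# if this test lives elsewhere.
#
#   import os
#
#   from absl.testing import parameterized
#   import numpy as np
#   import tensorflow as tf
#
#   from official.nlp.data import tagging_dataloader
#
#   class TaggingDataLoaderTest(tf.test.TestCase, parameterized.TestCase):
#     ...  # test methods above
#
#   if __name__ == '__main__':
#     tf.test.main()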