def test_predict(self):
    """Checks that `tagging.predict` returns one entry per input example.

    Builds a tagging model from the shared encoder / train-data configs,
    writes a fake TFRecord file with `expected_count` records, runs
    `tagging.predict` over it with sentence ids enabled, and asserts that
    both returned sequences have exactly `expected_count` entries.
    """
    label_vocab = ["O", "B-PER", "I-PER"]
    config = tagging.TaggingConfig(
        model=tagging.ModelConfig(encoder=self._encoder_config),
        train_data=self._train_data_config,
        class_names=label_vocab)
    tagging_task = tagging.TaggingTask(config)
    tagging_model = tagging_task.build_model()

    max_len = 16
    expected_count = 100
    record_path = os.path.join(self.get_temp_dir(), "test.tf_record")
    _create_fake_dataset(
        record_path,
        seq_length=max_len,
        num_labels=len(config.class_names),
        num_examples=expected_count)

    # 100 examples at a batch size of 16 leave a partial final batch;
    # drop_remainder=False presumably keeps it so every example reaches the
    # length assertions below.
    eval_data_config = tagging_data_loader.TaggingDataConfig(
        input_path=record_path,
        seq_length=max_len,
        is_training=False,
        global_batch_size=16,
        drop_remainder=False,
        include_sentence_id=True)

    predictions, sent_ids = tagging.predict(
        tagging_task, eval_data_config, tagging_model)
    self.assertLen(predictions, expected_count)
    self.assertLen(sent_ids, expected_count)
def test_load_dataset(self):
    """Checks the keys and shapes of the first batch from TaggingDataLoader.

    Writes a fake TFRecord file, loads it with a global batch size of 10 and
    a sequence length of 16, and verifies two things about the first batch:
    the feature dict holds exactly the three expected input keys, and every
    feature plus the labels has shape (batch, seq_length).
    """
    batch, max_len = 10, 16
    record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
    _create_fake_dataset(record_path, max_len)
    loader_config = tagging_data_loader.TaggingDataConfig(
        input_path=record_path,
        seq_length=max_len,
        global_batch_size=batch)
    first_features, first_labels = next(
        iter(tagging_data_loader.TaggingDataLoader(loader_config).load()))

    feature_names = ['input_word_ids', 'input_mask', 'input_type_ids']
    self.assertCountEqual(feature_names, first_features.keys())

    expected_shape = (batch, max_len)
    # Same assertions, in the same order, as checking each feature by hand.
    for name in feature_names:
        self.assertEqual(first_features[name].shape, expected_shape)
    self.assertEqual(first_labels.shape, expected_shape)
def setUp(self):
    """Creates the encoder and training-data configs shared by the tests.

    The encoder is a single-layer BERT config with vocab_size=30522. The
    training-data config points at the literal path "dummy", which is
    presumably never read by these tests, with seq_length=128 and a global
    batch size of 1.
    """
    super().setUp()
    bert_config = encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)
    self._encoder_config = encoders.EncoderConfig(bert=bert_config)
    self._train_data_config = tagging_data_loader.TaggingDataConfig(
        input_path="dummy", seq_length=128, global_batch_size=1)