Example #1
  def test_predict(self):
    # Build a tagging task and model with a small BERT encoder and three tags.
    task_config = tagging.TaggingConfig(
        model=tagging.ModelConfig(encoder=self._encoder_config),
        train_data=self._train_data_config,
        class_names=["O", "B-PER", "I-PER"])
    task = tagging.TaggingTask(task_config)
    model = task.build_model()

    # Write a fake TFRecord test set to a temporary path.
    test_data_path = os.path.join(self.get_temp_dir(), "test.tf_record")
    seq_length = 16
    num_examples = 100
    _create_fake_dataset(
        test_data_path,
        seq_length=seq_length,
        num_labels=len(task_config.class_names),
        num_examples=num_examples)
    # Build an evaluation data config over that file, keeping sentence ids.
    test_data_config = tagging_data_loader.TaggingDataConfig(
        input_path=test_data_path,
        seq_length=seq_length,
        is_training=False,
        global_batch_size=16,
        drop_remainder=False,
        include_sentence_id=True)

    # predict() should yield one prediction and one sentence id per example.
    predict_ids, sentence_ids = tagging.predict(task, test_data_config, model)
    self.assertLen(predict_ids, num_examples)
    self.assertLen(sentence_ids, num_examples)
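
The helper `_create_fake_dataset` used by these tests is defined elsewhere in the test modules and is not shown. Below is a minimal sketch of what such a helper might look like, assuming `numpy as np` and `tensorflow as tf` are imported at module level and assuming the loader reads `tf.train.Example` records with `input_ids`, `input_mask`, `segment_ids`, `label_ids`, and optional `sentence_id`/`sub_sentence_id` int64 features; the feature names, signature, and defaults are illustrative assumptions, not taken from the snippets.

def _create_fake_dataset(output_path, seq_length, num_labels=10,
                         num_examples=100):
  """Writes fake tagging examples to a TFRecord file (illustrative only)."""
  writer = tf.io.TFRecordWriter(output_path)

  def _int64_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

  for i in range(num_examples):
    input_ids = np.random.randint(100, size=(seq_length,))
    features = {
        "input_ids": _int64_feature(input_ids),
        "input_mask": _int64_feature(np.ones_like(input_ids)),
        "segment_ids": _int64_feature(np.zeros_like(input_ids)),
        "label_ids": _int64_feature(
            np.random.randint(num_labels, size=(seq_length,))),
        # Only consumed when the data config sets include_sentence_id=True.
        "sentence_id": _int64_feature([i]),
        "sub_sentence_id": _int64_feature([0]),
    }
    example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(example.SerializeToString())
  writer.close()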
Example #2
  def test_load_dataset(self):
    seq_length = 16
    batch_size = 10
    # Write a fake training TFRecord file for the loader to read.
    train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
    _create_fake_dataset(train_data_path, seq_length)
    data_config = tagging_data_loader.TaggingDataConfig(
        input_path=train_data_path,
        seq_length=seq_length,
        global_batch_size=batch_size)

    # Load one batch and check the feature keys and tensor shapes.
    dataset = tagging_data_loader.TaggingDataLoader(data_config).load()
    features, labels = next(iter(dataset))
    self.assertCountEqual(['input_word_ids', 'input_mask', 'input_type_ids'],
                          features.keys())
    self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
    self.assertEqual(labels.shape, (batch_size, seq_length))
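
The snippets in this section omit their module-level imports. A plausible set is sketched below, assuming the TensorFlow Model Garden (`official`) package layout these tests come from; the exact module paths, in particular the data-loader module name, vary across releases and are an assumption here.

import os

import numpy as np
import tensorflow as tf

from official.nlp.configs import encoders
from official.nlp.data import tagging_data_loader
from official.nlp.tasks import tagging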
Example #3
  def setUp(self):
    super(TaggingTest, self).setUp()
    self._encoder_config = encoders.EncoderConfig(
        bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1))
    self._train_data_config = tagging_data_loader.TaggingDataConfig(
        input_path="dummy", seq_length=128, global_batch_size=1)