def test_predict(self):
  task_config = tagging.TaggingConfig(
      model=tagging.ModelConfig(encoder=self._encoder_config),
      train_data=self._train_data_config,
      class_names=["O", "B-PER", "I-PER"])
  task = tagging.TaggingTask(task_config)
  model = task.build_model()

  test_data_path = os.path.join(self.get_temp_dir(), "test.tf_record")
  seq_length = 16
  num_examples = 100
  _create_fake_dataset(
      test_data_path,
      seq_length=seq_length,
      num_labels=len(task_config.class_names),
      num_examples=num_examples)
  test_data_config = tagging_dataloader.TaggingDataConfig(
      input_path=test_data_path,
      seq_length=seq_length,
      is_training=False,
      global_batch_size=16,
      drop_remainder=False,
      include_sentence_id=True)

  results = tagging.predict(task, test_data_config, model)
  self.assertLen(results, num_examples)
  # Each prediction is a (sentence_id, sub_sentence_id, predict_ids) tuple.
  self.assertLen(results[0], 3)
def write_tagging(task, model, input_file, output_file, predict_batch_size,
                  seq_length):
  """Makes tagging predictions and writes them to the output file."""
  data_config = tagging_dataloader.TaggingDataConfig(
      input_path=input_file,
      is_training=False,
      seq_length=seq_length,
      global_batch_size=predict_batch_size,
      drop_remainder=False,
      include_sentence_id=True)
  results = tagging.predict(task, data_config, model)
  class_names = task.task_config.class_names
  last_sentence_id = -1
  with tf.io.gfile.GFile(output_file, 'w') as writer:
    # Each result is a (sentence_id, sub_sentence_id, predict_ids) tuple;
    # sub-sentences of the same sentence arrive consecutively, so a blank
    # line is written only when a new sentence starts.
    for sentence_id, _, predict_ids in results:
      token_labels = [class_names[x] for x in predict_ids]
      # Sentence ids must repeat or advance by exactly one.
      assert sentence_id == last_sentence_id or (
          sentence_id == last_sentence_id + 1)
      if sentence_id != last_sentence_id and last_sentence_id != -1:
        writer.write('\n')
      writer.write('\n'.join(token_labels))
      writer.write('\n')
      last_sentence_id = sentence_id
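# A minimal usage sketch for write_tagging, not taken from the original
# sources: the encoder config, the module import paths, the checkpoint
# restore call, and all file paths below are assumptions; task/model
# construction only mirrors the pattern in test_predict above, and only
# write_tagging's signature comes from the function itself.
import tensorflow as tf

from official.nlp.configs import encoders
from official.nlp.tasks import tagging

encoder_config = encoders.EncoderConfig(
    bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=12))
task_config = tagging.TaggingConfig(
    model=tagging.ModelConfig(encoder=encoder_config),
    class_names=["O", "B-PER", "I-PER"])
task = tagging.TaggingTask(task_config)
model = task.build_model()

# Restore trained weights (placeholder checkpoint path).
tf.train.Checkpoint(model=model).restore(
    "/path/to/trained/checkpoint").expect_partial()

write_tagging(
    task,
    model,
    input_file="/path/to/test.tf_record",     # placeholder input TFRecord
    output_file="/path/to/predictions.txt",   # placeholder output path
    predict_batch_size=16,
    seq_length=16)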