def generate_tagging_dataset():
  """Generates tagging dataset.

  Looks up the processor named by `FLAGS.tagging_task_name`, builds the
  tokenizer selected by `FLAGS.tokenizer_impl`, and writes train/eval/test
  TF record files to the paths given by the output flags.

  Returns:
    The metadata dict produced by
    `tagging_data_lib.generate_tf_record_from_data_file`.

  Raises:
    ValueError: If the task name or tokenizer implementation is unknown.
  """
  processors = {
      "panx": tagging_data_lib.PanxProcessor,
      "udpos": tagging_data_lib.UdposProcessor,
  }
  task_name = FLAGS.tagging_task_name.lower()
  if task_name not in processors:
    raise ValueError("Task not found: %s" % task_name)

  # Pick the tokenizer and the matching raw-text preprocessing function.
  impl = FLAGS.tokenizer_impl
  if impl == "word_piece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  elif impl == "sentence_piece":
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)
  else:
    raise ValueError("Unsupported tokenizer_impl: %s" % impl)

  processor = processors[task_name]()
  return tagging_data_lib.generate_tf_record_from_data_file(
      processor,
      FLAGS.input_data_dir,
      tokenizer,
      FLAGS.max_seq_length,
      FLAGS.train_data_output_path,
      FLAGS.eval_data_output_path,
      FLAGS.test_data_output_path,
      processor_text_fn)
def test_generate_tf_record(self, task_type):
  """Checks TF record generation produces the expected files and sizes."""
  processor = self.processors[task_type]()
  data_dir = os.path.join(self.get_temp_dir(), task_type)
  tf.io.gfile.mkdir(data_dir)

  # Fake train file.
  _create_fake_file(
      os.path.join(data_dir, "train-en.tsv"),
      processor.get_labels(),
      is_test=False)
  # Fake dev file.
  _create_fake_file(
      os.path.join(data_dir, "dev-en.tsv"),
      processor.get_labels(),
      is_test=False)
  # One fake test file per supported language.
  for lang in processor.supported_languages:
    _create_fake_file(
        os.path.join(data_dir, "test-%s.tsv" % lang),
        processor.get_labels(),
        is_test=True)

  out_dir = os.path.join(self.get_temp_dir(), task_type, "output")
  tokenizer = tokenization.FullTokenizer(
      vocab_file=self.vocab_file, do_lower_case=True)
  metadata = tagging_data_lib.generate_tf_record_from_data_file(
      processor,
      data_dir,
      tokenizer,
      max_seq_length=8,
      train_data_output_path=os.path.join(out_dir, "train.tfrecord"),
      eval_data_output_path=os.path.join(out_dir, "eval.tfrecord"),
      test_data_output_path=os.path.join(out_dir, "test_{}.tfrecord"),
      text_preprocessing=tokenization.convert_to_unicode)

  self.assertEqual(metadata["train_data_size"], 5)
  # Exactly one train, one eval, and one test file per language.
  expected_files = [
      os.path.join(out_dir, "train.tfrecord"),
      os.path.join(out_dir, "eval.tfrecord"),
  ] + [
      os.path.join(out_dir, "test_%s.tfrecord" % lang)
      for lang in processor.supported_languages
  ]
  self.assertCountEqual(tf.io.gfile.glob(out_dir + "/*"), expected_files)