Example 1
# Imports assumed from the TF Model Garden layout (exact module paths may
# differ between releases); the flag definitions themselves are not shown.
import functools

from absl import flags

from official.nlp.data import tagging_data_lib
from official.nlp.tools import tokenization

FLAGS = flags.FLAGS


def generate_tagging_dataset():
    """Generates tf_record files for the tagging task selected by flags."""
    processors = {
        "panx": tagging_data_lib.PanxProcessor,
        "udpos": tagging_data_lib.UdposProcessor,
    }
    task_name = FLAGS.tagging_task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)

    if FLAGS.tokenizer_impl == "word_piece":
        tokenizer = tokenization.FullTokenizer(
            vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
        processor_text_fn = tokenization.convert_to_unicode
    elif FLAGS.tokenizer_impl == "sentence_piece":
        tokenizer = tokenization.FullSentencePieceTokenizer(
            FLAGS.sp_model_file)
        processor_text_fn = functools.partial(tokenization.preprocess_text,
                                              lower=FLAGS.do_lower_case)
    else:
        raise ValueError("Unsupported tokenizer_impl: %s" %
                         FLAGS.tokenizer_impl)

    processor = processors[task_name]()
    return tagging_data_lib.generate_tf_record_from_data_file(
        processor, FLAGS.input_data_dir, tokenizer, FLAGS.max_seq_length,
        FLAGS.train_data_output_path, FLAGS.eval_data_output_path,
        FLAGS.test_data_output_path, processor_text_fn)
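

# A minimal usage sketch (illustrative, not part of the source): the function
# is driven entirely by absl flags, so a caller overrides the flags referenced
# above and then invokes the generator. It assumes the corresponding
# flags.DEFINE_* declarations (omitted here) exist; the values and the
# _example_run helper are placeholders.
def _example_run():
    FLAGS.tagging_task_name = "panx"
    FLAGS.tokenizer_impl = "word_piece"
    FLAGS.vocab_file = "/path/to/vocab.txt"
    FLAGS.do_lower_case = True
    FLAGS.input_data_dir = "/path/to/panx_data"
    FLAGS.max_seq_length = 128
    FLAGS.train_data_output_path = "/tmp/tagging/train.tfrecord"
    FLAGS.eval_data_output_path = "/tmp/tagging/eval.tfrecord"
    # The "{}" in the test path template is filled per language by the library.
    FLAGS.test_data_output_path = "/tmp/tagging/test_{}.tfrecord"
    return generate_tagging_dataset()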


  # The following test method exercises
  # tagging_data_lib.generate_tf_record_from_data_file. It is a method of a
  # test-case class (e.g. tf.test.TestCase): self.processors and
  # self.vocab_file are set up elsewhere in that class, _create_fake_file is a
  # module-level helper of the test file, and task_type is presumably supplied
  # by a parameterized decorator. A plausible skeleton of the surrounding
  # class is sketched after the method.
  def test_generate_tf_record(self, task_type):
    processor = self.processors[task_type]()
    input_data_dir = os.path.join(self.get_temp_dir(), task_type)
    tf.io.gfile.mkdir(input_data_dir)
    # Write fake train file.
    _create_fake_file(
        os.path.join(input_data_dir, "train-en.tsv"),
        processor.get_labels(),
        is_test=False)

    # Write fake dev file.
    _create_fake_file(
        os.path.join(input_data_dir, "dev-en.tsv"),
        processor.get_labels(),
        is_test=False)

    # Write fake test files.
    for lang in processor.supported_languages:
      _create_fake_file(
          os.path.join(input_data_dir, "test-%s.tsv" % lang),
          processor.get_labels(),
          is_test=True)

    output_path = os.path.join(self.get_temp_dir(), task_type, "output")
    tokenizer = tokenization.FullTokenizer(
        vocab_file=self.vocab_file, do_lower_case=True)
    metadata = tagging_data_lib.generate_tf_record_from_data_file(
        processor,
        input_data_dir,
        tokenizer,
        max_seq_length=8,
        train_data_output_path=os.path.join(output_path, "train.tfrecord"),
        eval_data_output_path=os.path.join(output_path, "eval.tfrecord"),
        test_data_output_path=os.path.join(output_path, "test_{}.tfrecord"),
        text_preprocessing=tokenization.convert_to_unicode)

    self.assertEqual(metadata["train_data_size"], 5)
    files = tf.io.gfile.glob(output_path + "/*")
    expected_files = []
    expected_files.append(os.path.join(output_path, "train.tfrecord"))
    expected_files.append(os.path.join(output_path, "eval.tfrecord"))
    for lang in processor.supported_languages:
      expected_files.append(
          os.path.join(output_path, "test_%s.tfrecord" % lang))

    self.assertCountEqual(files, expected_files)
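

# A plausible skeleton for the test class around the method above (an
# assumption reconstructed from what the method references, not the original
# test file): it shows where self.processors and self.vocab_file come from and
# how task_type can be parameterized. _create_fake_file (not sketched) would
# write small token/label TSV files with a few sentences each.
import os

from absl.testing import parameterized
import tensorflow as tf

from official.nlp.data import tagging_data_lib


class TaggingDataLibTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.processors = {
        "panx": tagging_data_lib.PanxProcessor,
        "udpos": tagging_data_lib.UdposProcessor,
    }
    # A tiny WordPiece vocab is enough for FullTokenizer in a unit test.
    self.vocab_file = os.path.join(self.get_temp_dir(), "vocab.txt")
    with tf.io.gfile.GFile(self.vocab_file, "w") as writer:
      writer.write("\n".join(["[PAD]", "[UNK]", "[CLS]", "[SEP]", "word"]))

  @parameterized.parameters("panx", "udpos")
  def test_generate_tf_record(self, task_type):
    ...  # Body as shown above.


if __name__ == "__main__":
  tf.test.main()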