def generate_squad_dataset():
    """Write the SQuAD training records and hand back the generator's result.

    The tokenizer-specific library is chosen from ``FLAGS.tokenization``:
    ``"WordPiece"`` routes to ``squad_lib_wp`` (configured with a vocab file),
    anything else must be ``"SentencePiece"`` and routes to ``squad_lib_sp``
    (configured with a SentencePiece model file). All remaining options are
    forwarded unchanged from the command-line flags.

    Returns:
        Whatever ``generate_tf_record_from_json_file`` returns (the input
        meta data of the generated dataset).

    Raises:
        AssertionError: if ``FLAGS.squad_data_file`` is empty, or if
            ``FLAGS.tokenization`` is neither of the two supported values.
    """
    assert FLAGS.squad_data_file

    # Select the implementation together with the one keyword argument that
    # differs between the two libraries.
    if FLAGS.tokenization == "WordPiece":
        write_records = squad_lib_wp.generate_tf_record_from_json_file
        tokenizer_kwargs = {"vocab_file_path": FLAGS.vocab_file}
    else:
        assert FLAGS.tokenization == "SentencePiece"
        write_records = squad_lib_sp.generate_tf_record_from_json_file
        tokenizer_kwargs = {"sp_model_file": FLAGS.sp_model_file}

    # Everything else is shared by both tokenizer back ends.
    return write_records(
        input_file_path=FLAGS.squad_data_file,
        output_path=FLAGS.train_data_output_path,
        translated_input_folder=FLAGS.translated_squad_data_folder,
        max_seq_length=FLAGS.max_seq_length,
        do_lower_case=FLAGS.do_lower_case,
        max_query_length=FLAGS.max_query_length,
        doc_stride=FLAGS.doc_stride,
        version_2_with_negative=FLAGS.version_2_with_negative,
        xlnet_format=FLAGS.xlnet_format,
        **tokenizer_kwargs)
示例#2
0
def generate_squad_dataset():
  """Write the SQuAD training records and hand back the generator's result.

  ``FLAGS.tokenizer_impl`` selects the library: ``"word_piece"`` uses
  ``squad_lib_wp`` with the vocab file, otherwise the flag must be
  ``"sentence_piece"`` and ``squad_lib_sp`` is used with the SentencePiece
  model file. The remaining arguments are passed positionally, in the same
  order for both libraries.

  Returns:
    Whatever ``generate_tf_record_from_json_file`` returns (the input meta
    data of the generated dataset).

  Raises:
    AssertionError: if ``FLAGS.squad_data_file`` is empty, or if
      ``FLAGS.tokenizer_impl`` is neither of the two supported values.
  """
  assert FLAGS.squad_data_file

  # Pick the library and the tokenizer resource that fills the second
  # positional slot of its record generator.
  if FLAGS.tokenizer_impl == "word_piece":
    squad_lib, tokenizer_resource = squad_lib_wp, FLAGS.vocab_file
  else:
    assert FLAGS.tokenizer_impl == "sentence_piece"
    squad_lib, tokenizer_resource = squad_lib_sp, FLAGS.sp_model_file

  return squad_lib.generate_tf_record_from_json_file(
      FLAGS.squad_data_file,
      tokenizer_resource,
      FLAGS.train_data_output_path,
      FLAGS.max_seq_length,
      FLAGS.do_lower_case,
      FLAGS.max_query_length,
      FLAGS.doc_stride,
      FLAGS.version_2_with_negative)