Example #1
  def make_tf_examples(self, example, is_training):
    """Converts one entry into readable passages, spans, and token maps."""
    passages = []
    spans = []
    token_maps = []
    vocab_file = self._get_vocab_file()
    tf_example_creator = tf_io.CreateTFExampleFn(
        is_training=is_training,
        max_question_length=64,
        max_seq_length=512,
        doc_stride=128,
        include_unknowns=1.0,
        vocab_file=vocab_file)
    # Maps wordpiece ids back to tokens; CreateTFExampleFn exposes its vocab
    # dict (see reverse_vocab_table in Example #2).
    reverse_vocab = {
        word_id: word for word, word_id in tf_example_creator.vocab.items()
    }
    for record in tf_example_creator.process(
        example, errors=[], debug_info={}):
      tfexample = tf.train.Example()
      tfexample.ParseFromString(record)
      # Recover the wordpiece tokens from the serialized ids ("input_ids" is
      # the standard BERT feature name and is assumed here); stripping " ##"
      # when joining undoes the wordpiece segmentation.
      tokens = [
          reverse_vocab[word_id] for word_id in
          tfexample.features.feature["input_ids"].int64_list.value
      ]
      passages.append(" ".join(tokens).replace(" ##", ""))
      if is_training:
        start = tfexample.features.feature["start_positions"].int64_list.value[
            0]
        end = tfexample.features.feature["end_positions"].int64_list.value[0]
        # `end` is inclusive, hence the end + 1 slice bound.
        spans.append(" ".join(tokens[start:end + 1]).replace(" ##", ""))
      else:
        token_maps.append(
            tfexample.features.feature["token_map"].int64_list.value)

    return passages, spans, token_maps
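
The ParseFromString / features.feature access pattern above is plain protobuf
round-tripping. A minimal self-contained sketch (the feature values are made
up for illustration):

import tensorflow as tf

# Build a tiny serialized Example with the same int64 feature layout the
# snippet above reads back.
features = tf.train.Features(feature={
    "start_positions": tf.train.Feature(
        int64_list=tf.train.Int64List(value=[3])),
    "end_positions": tf.train.Feature(
        int64_list=tf.train.Int64List(value=[7])),
})
record = tf.train.Example(features=features).SerializeToString()

# Round-trip: mirrors how make_tf_examples parses start/end positions.
parsed = tf.train.Example()
parsed.ParseFromString(record)
start = parsed.features.feature["start_positions"].int64_list.value[0]
end = parsed.features.feature["end_positions"].int64_list.value[0]
assert (start, end) == (3, 7)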
Example #2
# Assumed module context: `tf` is TensorFlow 1.x, `random` is the stdlib
# module, FLAGS is the module's command-line flag object, and `tf_io`,
# `debug`, `read_entries`, and `get_lang_counts` are defined alongside main().
def main(_):
  examples_processed = 0
  num_examples_with_correct_context = 0
  num_errors = 0
  tf_examples = []
  sample_ratio = {}

  # Print the first 25 examples so the user can see what's going on.
  num_examples_to_print = 25

  if FLAGS.oversample and FLAGS.is_training:
    lang_count = get_lang_counts(FLAGS.input_jsonl)
    max_count = max(lang_count.values())
    for lang, curr_count in lang_count.items():
      sample_ratio[lang] = int(
          min(FLAGS.max_oversample_ratio, max_count / curr_count))
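  # Worked example with illustrative counts: if lang_count is
  # {"english": 1000, "swahili": 100} and max_oversample_ratio is 4, then
  # max_count = 1000, so sample_ratio["swahili"] = int(min(4, 10.0)) = 4
  # and sample_ratio["english"] = int(min(4, 1.0)) = 1; each Swahili
  # feature is then appended 4 times below.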
  creator_fn = tf_io.CreateTFExampleFn(
      is_training=FLAGS.is_training,
      max_question_length=FLAGS.max_question_length,
      max_seq_length=FLAGS.max_seq_length,
      doc_stride=FLAGS.doc_stride,
      include_unknowns=FLAGS.include_unknowns,
      vocab_file=FLAGS.vocab_file)
  # id -> wordpiece table, used only for human-readable debug logging below.
  reverse_vocab_table = {
      word_id: word for word, word_id in creator_fn.vocab.items()
  }
  tf.logging.info("Reading examples from glob: %s", FLAGS.input_jsonl)
  for filename, line_no, entry, debug_info in read_entries(
      FLAGS.input_jsonl, fail_on_invalid=FLAGS.fail_on_invalid):
    errors = []
    for tf_example in creator_fn.process(entry, errors, debug_info):
      if FLAGS.oversample and FLAGS.is_training:
        # Emit sample_ratio copies of this feature for its language; ratios
        # are only computed when training (see above), so guard on both flags.
        tf_examples.extend([tf_example] * sample_ratio[entry["language"]])
      else:
        tf_examples.append(tf_example)

    if errors or examples_processed < num_examples_to_print:
      debug.log_debug_info(filename, line_no, entry, debug_info,
                           reverse_vocab_table)

    if examples_processed % 10 == 0:
      tf.logging.info("Examples processed: %d", examples_processed)
    examples_processed += 1

    if errors:
      tf.logging.info(
          "Encountered errors while creating {} example ({}:{}): {}".format(
              entry["language"], filename, line_no, "; ".join(errors)))
      if FLAGS.fail_on_invalid:
        raise ValueError(
            "Encountered errors while creating example ({}:{}): {}".format(
                filename, line_no, "; ".join(errors)))
      num_errors += 1
      if num_errors % 10 == 0:
        tf.logging.info("Errors so far: %d", num_errors)

    if entry["has_correct_context"]:
      num_examples_with_correct_context += 1
    if FLAGS.max_examples > 0 and examples_processed >= FLAGS.max_examples:
      break
  tf.logging.info("Examples with correct context retained: %d of %d",
                  num_examples_with_correct_context, examples_processed)

  # Even though the input is shuffled, reshuffle here so that oversampled
  # duplicates of the same feature do not end up adjacent in the output.
  random.shuffle(tf_examples)
  num_features = len(tf_examples)
  tf.logging.info("Number of total features %d", num_features)
  tf.logging.info("'Features' are windowed slices of a document paired with "
                  "a supervision label.")

  with tf.python_io.TFRecordWriter(FLAGS.output_tfrecord) as writer:
    for tf_example in tf_examples:
      writer.write(tf_example.SerializeToString())
  if FLAGS.record_count_file:
    with tf.gfile.Open(FLAGS.record_count_file, "w") as writer:
      writer.write(str(num_features))
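
A minimal read-back sketch for the TFRecord written above, using the same
TF 1.x API family; the path is illustrative and stands in for
FLAGS.output_tfrecord:

import tensorflow as tf

path = "/tmp/tf_examples.tfrecord"  # stands in for FLAGS.output_tfrecord

num_records = 0
for serialized in tf.python_io.tf_record_iterator(path):
  example = tf.train.Example()
  example.ParseFromString(serialized)
  num_records += 1
print("features read back:", num_records)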