Example No. 1
def main(_):
    examples_processed = 0
    num_examples_with_correct_context = 0
    num_errors = 0
    tf_examples = []
    sample_ratio = {}

    # Print the first 25 examples to let the user know what's going on.
    num_examples_to_print = 25

    if FLAGS.oversample and FLAGS.is_training:
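        # Worked example with illustrative counts: given
        # {"english": 1000, "swahili": 100} and max_oversample_ratio = 4,
        # English gets ratio min(4, 1000 / 1000) = 1 while Swahili gets
        # min(4, 1000 / 100) = 4 copies of each example.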
        lang_count = get_lang_counts(FLAGS.input_jsonl)
        max_count = max(lang_count.values())
        for lang, curr_count in lang_count.items():
            sample_ratio[lang] = int(
                min(FLAGS.max_oversample_ratio, max_count / curr_count))

    # Character-level "tokenizer": every character becomes its own token.
    splitter = char_splitter.CharacterSplitter()
    creator_fn = tf_io.CreateTFExampleFn(
        is_training=FLAGS.is_training,
        max_question_length=FLAGS.max_question_length,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        include_unknowns=FLAGS.include_unknowns,
        tokenizer=splitter)
    tf.logging.info("Reading examples from glob: %s", FLAGS.input_jsonl)
    for filename, line_no, entry, debug_info in tf_io.read_entries(
            FLAGS.input_jsonl,
            tokenizer=splitter,
            max_passages=FLAGS.max_passages,
            max_position=FLAGS.max_position,
            fail_on_invalid=FLAGS.fail_on_invalid):
        errors = []
        for tf_example in creator_fn.process(entry, errors, debug_info):
            # Guard on is_training too: sample_ratio is only populated when
            # both oversample and is_training are set at the top of main().
            if FLAGS.oversample and FLAGS.is_training:
                tf_examples.extend([tf_example] *
                                   sample_ratio[entry["language"]])
            else:
                tf_examples.append(tf_example)

        if errors or examples_processed < num_examples_to_print:
            debug.log_debug_info(filename, line_no, entry, debug_info,
                                 splitter.id_to_string)

        if examples_processed % 10 == 0:
            tf.logging.info("Examples processed: %d", examples_processed)
        examples_processed += 1

        if errors:
            tf.logging.info(
                "Encountered errors while creating {} example ({}:{}): {}".
                format(entry["language"], filename, line_no,
                       "; ".join(errors)))
            if FLAGS.fail_on_invalid:
                raise ValueError(
                    "Encountered errors while creating example ({}:{}): {}".
                    format(filename, line_no, "; ".join(errors)))
            num_errors += 1
            if num_errors % 10 == 0:
                tf.logging.info("Errors so far: %d", num_errors)

        if entry["has_correct_context"]:
            num_examples_with_correct_context += 1
        if FLAGS.max_examples > 0 and examples_processed >= FLAGS.max_examples:
            break
    tf.logging.info("Examples with correct context retained: %d of %d",
                    num_examples_with_correct_context, examples_processed)

    # Even though the input is shuffled, we need to do this in case we're
    # oversampling.
    random.shuffle(tf_examples)
    num_features = len(tf_examples)
    tf.logging.info("Number of total features %d", num_features)
    tf.logging.info("'Features' are windowed slices of a document paired with "
                    "a supervision label.")

    with tf.python_io.TFRecordWriter(FLAGS.output_tfrecord) as writer:
        for tf_example in tf_examples:
            writer.write(tf_example.SerializeToString())
    if FLAGS.record_count_file:
        with tf.gfile.Open(FLAGS.record_count_file, "w") as writer:
            writer.write(str(num_features))
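main() reads its configuration from more than a dozen command-line flags that
are defined elsewhere in the script. Below is a minimal sketch of how the flags
used above might be declared with absl.flags; the types, defaults, and help
strings are illustrative assumptions, not the real script's values:

from absl import flags

FLAGS = flags.FLAGS

# Illustrative declarations; only the flag names are taken from main() above.
flags.DEFINE_string("input_jsonl", None, "Glob of input JSONL files.")
flags.DEFINE_string("output_tfrecord", None, "Path of the output TFRecord file.")
flags.DEFINE_string("record_count_file", "", "Optional path for the record count.")
flags.DEFINE_bool("is_training", True, "Whether to create training features.")
flags.DEFINE_bool("oversample", False, "Whether to oversample low-resource languages.")
flags.DEFINE_float("max_oversample_ratio", 10.0, "Cap on the per-language oversampling ratio.")
flags.DEFINE_integer("max_question_length", 64, "Maximum question length in tokens.")
flags.DEFINE_integer("max_seq_length", 512, "Maximum input sequence length.")
flags.DEFINE_integer("doc_stride", 128, "Stride between overlapping document windows.")
flags.DEFINE_float("include_unknowns", -1.0, "If positive, probability of keeping windows without an answer.")
flags.DEFINE_integer("max_passages", 45, "Maximum number of passages per article.")
flags.DEFINE_integer("max_position", 45, "Maximum passage position to encode.")
flags.DEFINE_bool("fail_on_invalid", False, "Whether to raise on malformed entries.")
flags.DEFINE_integer("max_examples", 0, "If positive, stop after this many examples.")

For a quick sanity check, the TFRecord written at the end of main() can be read
back with the matching TF1-era reader:

# Parses each serialized record back into a tf.train.Example proto.
for record in tf.python_io.tf_record_iterator(FLAGS.output_tfrecord):
    example = tf.train.Example.FromString(record)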
Example No. 2
def make_tokenizer() -> tydi_tokenization_interface.TokenizerWithOffsets:
    return char_splitter.CharacterSplitter()
Example No. 3
def make_tokenizer():
    return char_splitter.CharacterSplitter()
Example No. 4
def get_tokenizer(self):
    return char_splitter.CharacterSplitter()
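Examples 2 through 4 are thin factories around the same character-level
tokenizer. The real CharacterSplitter is defined in the codebase; as a rough
illustration of what a tokenizer of this kind does, here is a hypothetical toy
version. The tokenize_with_offsets method name and the internal layout are
assumptions; only id_to_string is attested, by its use in Example 1:

from typing import List, Tuple


class ToyCharacterSplitter:
    """Hypothetical stand-in: emits one token per character."""

    def __init__(self):
        self._char_to_id = {}  # character -> token id
        self._id_to_char = []  # token id -> character

    def tokenize_with_offsets(self, text: str) -> Tuple[List[int], List[int]]:
        """Returns (token_ids, start_offsets), one entry per character."""
        token_ids, offsets = [], []
        for pos, char in enumerate(text):
            if char not in self._char_to_id:
                self._char_to_id[char] = len(self._id_to_char)
                self._id_to_char.append(char)
            token_ids.append(self._char_to_id[char])
            offsets.append(pos)
        return token_ids, offsets

    def id_to_string(self, token_id: int) -> str:
        """Inverse mapping, as used for the debug logging in Example 1."""
        return self._id_to_char[token_id]

For instance, tokenize_with_offsets("abca") returns ids [0, 1, 2, 0] with
offsets [0, 1, 2, 3]: repeated characters reuse their id, and every token is
exactly one character wide.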