Example #1
    def test_load_from_file(self):
        # Test a vocab file with words not wrapped with single quotes
        encoder = text_encoder.SubwordTextEncoder()
        correct_vocab = ["the", "and", "of"]
        vocab = io.StringIO("the\n" "and\n" "of\n")
        encoder._load_from_file_object(vocab)
        self.assertAllEqual(encoder.all_subtoken_strings, correct_vocab)

        # Test a vocab file with words wrapped in single quotes
        encoder = text_encoder.SubwordTextEncoder()
        vocab = io.StringIO("'the'\n" "'and'\n" "'of'\n")
        encoder._load_from_file_object(vocab)
        self.assertAllEqual(encoder.all_subtoken_strings, correct_vocab)
Example #2
def _get_vocab(vocab_type='subword', vocab_file=None, vocab_dir=None):
    """Gets the vocabulary object for tokenization; see tokenize for details."""
    if vocab_type not in [
            'char', 'subword', 'sentencepiece', 'bert', 'bert-lowercase'
    ]:
        raise ValueError(
            'vocab_type must be "subword", "char", "sentencepiece", "bert" or "bert-lowercase" '
            f'but got {vocab_type}')

    if vocab_type == 'char':
        # Note that we set num_reserved_ids=0 below. We could instead pass
        # the value n_reserved_ids from tokenize here -- ByteTextEncoder does
        # exactly the same thing as tokenize above, i.e., adds num_reserved_ids.
        return text_encoder.ByteTextEncoder(num_reserved_ids=0)

    vocab_dir = vocab_dir or 'gs://trax-ml/vocabs/'
    path = os.path.join(vocab_dir, vocab_file)

    if vocab_type == 'subword':
        return text_encoder.SubwordTextEncoder(path)

    if vocab_type == 'bert':
        return text_encoder.BertEncoder(path, do_lower_case=False)

    if vocab_type == 'bert-lowercase':
        return text_encoder.BertEncoder(path, do_lower_case=True)

    assert vocab_type == 'sentencepiece'
    return t5.data.SentencePieceVocabulary(sentencepiece_model_file=path,
                                           extra_ids=0)
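
A minimal usage sketch for this helper; the vocab file name below is illustrative and not taken from the example itself:

# Character-level vocabulary: no vocab file is needed.
char_vocab = _get_vocab(vocab_type='char')

# Subword vocabulary resolved against the default gs://trax-ml/vocabs/
# directory; 'en_8k.subword' is an illustrative file name.
subword_vocab = _get_vocab(vocab_type='subword', vocab_file='en_8k.subword')

# Any other vocab_type raises the ValueError shown above.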
Example #3
    def test_reserved_token_chars_not_in_alphabet(self):
        corpus = "dog"
        token_counts = collections.Counter(corpus.split(" "))
        encoder1 = text_encoder.SubwordTextEncoder.build_to_target_size(
            100, token_counts, 2, 100)
        filename = os.path.join(self.test_temp_dir, "out.voc")
        encoder1.store_to_file(filename)
        encoder2 = text_encoder.SubwordTextEncoder(filename=filename)

        self.assertEqual(encoder1._alphabet, encoder2._alphabet)

        for t in text_encoder.RESERVED_TOKENS:
            for c in t:
                # Verify that encoders can encode all reserved token chars.
                encoder1.encode(c)
                encoder2.encode(c)
Example #4
    def test_save_and_reload_no_single_quotes(self):
        corpus = "the quick brown fox jumps over the lazy dog"
        token_counts = collections.Counter(corpus.split(" "))

        # Deliberately exclude some required encoding chars from the alphabet
        # and token list, making some strings unencodable.
        encoder = text_encoder.SubwordTextEncoder.build_to_target_size(
            100, token_counts, 2, 10)

        filename = os.path.join(self.test_temp_dir, "out.voc")
        encoder.store_to_file(filename, add_single_quotes=False)
        new_encoder = text_encoder.SubwordTextEncoder(filename)

        self.assertEqual(encoder._alphabet, new_encoder._alphabet)
        self.assertEqual(encoder.all_subtoken_strings,
                         new_encoder.all_subtoken_strings)
        self.assertEqual(encoder._subtoken_string_to_id,
                         new_encoder._subtoken_string_to_id)
        self.assertEqual(encoder._max_subtoken_len,
                         new_encoder._max_subtoken_len)
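
The round trip above uses store_to_file(filename, add_single_quotes=False), which writes bare subtokens one per line; the default (add_single_quotes=True) wraps each subtoken in single quotes, the second case parsed in Example #1. A minimal sketch of the two output styles, assuming tensor2tensor's data_generators module layout and a temporary directory:

import collections
import os
import tempfile

from tensor2tensor.data_generators import text_encoder

token_counts = collections.Counter(
    "the quick brown fox jumps over the lazy dog".split(" "))
encoder = text_encoder.SubwordTextEncoder.build_to_target_size(
    100, token_counts, 2, 10)

with tempfile.TemporaryDirectory() as tmp:
    quoted_path = os.path.join(tmp, "quoted.voc")  # subtokens wrapped in single quotes
    bare_path = os.path.join(tmp, "bare.voc")      # bare subtokens, one per line
    encoder.store_to_file(quoted_path)             # add_single_quotes defaults to True
    encoder.store_to_file(bare_path, add_single_quotes=False)
    # Both files reload to the same subtoken vocabulary.
    assert (text_encoder.SubwordTextEncoder(quoted_path).all_subtoken_strings ==
            text_encoder.SubwordTextEncoder(bare_path).all_subtoken_strings)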
Example #5
def main(unused_argv):
    if FLAGS.corpus_filepattern and FLAGS.vocab_filepattern:
        raise ValueError(
            'Must only provide one of --corpus_filepattern or --vocab_filepattern'
        )

    elif FLAGS.corpus_filepattern:
        token_counts = tokenizer.corpus_token_counts(
            FLAGS.corpus_filepattern,
            FLAGS.corpus_max_lines,
            split_on_newlines=FLAGS.split_on_newlines)

    elif FLAGS.vocab_filepattern:
        token_counts = tokenizer.vocab_token_counts(FLAGS.vocab_filepattern,
                                                    FLAGS.corpus_max_lines)

    else:
        raise ValueError(
            'Must provide one of --corpus_filepattern or --vocab_filepattern')

    encoder = text_encoder.SubwordTextEncoder()
    encoder.build_from_token_counts(token_counts, FLAGS.min_count,
                                    FLAGS.num_iterations)
    encoder.store_to_file(FLAGS.output_filename)
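
The same vocabulary can be built without the flag plumbing; a minimal sketch, assuming tensor2tensor's data_generators modules, with an illustrative corpus pattern and values standing in for the flag defaults:

from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import tokenizer

# Count tokens over a corpus glob (the pattern and line limit are illustrative).
token_counts = tokenizer.corpus_token_counts(
    "/tmp/corpus/*.txt", 10000, split_on_newlines=True)

# Build and save the subword vocabulary; 5 and 4 stand in for --min_count
# and --num_iterations above.
encoder = text_encoder.SubwordTextEncoder()
encoder.build_from_token_counts(token_counts, 5, 4)
encoder.store_to_file("/tmp/vocab.subwords")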