def test_load_from_file(self):
  # Test a vocab file with words not wrapped with single quotes
  encoder = text_encoder.SubwordTextEncoder()
  correct_vocab = ["the", "and", "of"]
  vocab = io.StringIO("the\n"
                      "and\n"
                      "of\n")
  encoder._load_from_file_object(vocab)
  self.assertAllEqual(encoder.all_subtoken_strings, correct_vocab)

  # Test a vocab file with words wrapped in single quotes
  encoder = text_encoder.SubwordTextEncoder()
  vocab = io.StringIO("\"the\"\n"
                      "\"and\"\n"
                      "\"of\"\n")
  encoder._load_from_file_object(vocab)
  self.assertAllEqual(encoder.all_subtoken_strings, correct_vocab)
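# Usage sketch (not part of the original test file): the quote-stripping load
# behavior exercised above, driven directly. Assumes text_encoder and io are
# imported as in this module.
#
#   enc = text_encoder.SubwordTextEncoder()
#   enc._load_from_file_object(io.StringIO("the\nand\nof\n"))
#   enc.all_subtoken_strings  # -> ["the", "and", "of"]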
def _get_vocab(vocab_type='subword', vocab_file=None, vocab_dir=None):
  """Gets the vocabulary object for tokenization; see tokenize for details."""
  if vocab_type not in [
      'char', 'subword', 'sentencepiece', 'bert', 'bert-lowercase'
  ]:
    raise ValueError(
        'vocab_type must be "subword", "char", "sentencepiece", "bert" or '
        f'"bert-lowercase" but got {vocab_type}')

  if vocab_type == 'char':
    # Note that we set num_reserved_ids=0 below. We could instead pass
    # the value n_reserved_ids from tokenize here -- ByteTextEncoder does
    # exactly the same thing as tokenize above, i.e., adds num_reserved_ids.
    return text_encoder.ByteTextEncoder(num_reserved_ids=0)

  vocab_dir = vocab_dir or 'gs://trax-ml/vocabs/'
  path = os.path.join(vocab_dir, vocab_file)

  if vocab_type == 'subword':
    return text_encoder.SubwordTextEncoder(path)

  if vocab_type == 'bert':
    return text_encoder.BertEncoder(path, do_lower_case=False)

  if vocab_type == 'bert-lowercase':
    return text_encoder.BertEncoder(path, do_lower_case=True)

  assert vocab_type == 'sentencepiece'
  return t5.data.SentencePieceVocabulary(sentencepiece_model_file=path,
                                         extra_ids=0)
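# Usage sketch (illustrative, not part of the original module): resolving a
# subword vocabulary from the default GCS bucket. The file name
# 'en_8k.subword' is a hypothetical example.
#
#   vocab = _get_vocab(vocab_type='subword', vocab_file='en_8k.subword')
#   ids = vocab.encode('Hello world.')
#   text = vocab.decode(ids)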
def test_reserved_token_chars_not_in_alphabet(self):
  corpus = "dog"
  token_counts = collections.Counter(corpus.split(" "))
  encoder1 = text_encoder.SubwordTextEncoder.build_to_target_size(
      100, token_counts, 2, 100)
  filename = os.path.join(self.test_temp_dir, "out.voc")
  encoder1.store_to_file(filename)
  encoder2 = text_encoder.SubwordTextEncoder(filename=filename)

  self.assertEqual(encoder1._alphabet, encoder2._alphabet)

  for t in text_encoder.RESERVED_TOKENS:
    for c in t:
      # Verify that encoders can encode all reserved token chars.
      encoder1.encode(c)
      encoder2.encode(c)
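# For context (a sketch; values assumed from this codebase's conventions):
# RESERVED_TOKENS is expected to be ["<pad>", "<EOS>"], so the loop above
# checks that characters such as "<", ">", and "E" remain encodable even
# though the one-word corpus "dog" never contains them.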
def test_save_and_reload_no_single_quotes(self):
  corpus = "the quick brown fox jumps over the lazy dog"
  token_counts = collections.Counter(corpus.split(" "))

  # Deliberately exclude some required encoding chars from the alphabet
  # and token list, making some strings unencodable.
  encoder = text_encoder.SubwordTextEncoder.build_to_target_size(
      100, token_counts, 2, 10)

  filename = os.path.join(self.test_temp_dir, "out.voc")
  encoder.store_to_file(filename, add_single_quotes=False)
  new_encoder = text_encoder.SubwordTextEncoder(filename)

  self.assertEqual(encoder._alphabet, new_encoder._alphabet)
  self.assertEqual(encoder.all_subtoken_strings,
                   new_encoder.all_subtoken_strings)
  self.assertEqual(encoder._subtoken_string_to_id,
                   new_encoder._subtoken_string_to_id)
  self.assertEqual(encoder._max_subtoken_len, new_encoder._max_subtoken_len)
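# Round-trip sketch (illustrative; "vocab.voc" is a hypothetical path): with
# add_single_quotes=False, store_to_file writes one bare subtoken per line,
# and the constructor parses that format back unchanged.
#
#   encoder.store_to_file("vocab.voc", add_single_quotes=False)
#   reloaded = text_encoder.SubwordTextEncoder("vocab.voc")
#   assert reloaded.all_subtoken_strings == encoder.all_subtoken_strings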
def main(unused_argv):
  if FLAGS.corpus_filepattern and FLAGS.vocab_filepattern:
    raise ValueError(
        'Must only provide one of --corpus_filepattern or --vocab_filepattern')
  elif FLAGS.corpus_filepattern:
    token_counts = tokenizer.corpus_token_counts(
        FLAGS.corpus_filepattern,
        FLAGS.corpus_max_lines,
        split_on_newlines=FLAGS.split_on_newlines)
  elif FLAGS.vocab_filepattern:
    token_counts = tokenizer.vocab_token_counts(FLAGS.vocab_filepattern,
                                                FLAGS.corpus_max_lines)
  else:
    raise ValueError(
        'Must provide one of --corpus_filepattern or --vocab_filepattern')

  encoder = text_encoder.SubwordTextEncoder()
  encoder.build_from_token_counts(token_counts, FLAGS.min_count,
                                  FLAGS.num_iterations)
  encoder.store_to_file(FLAGS.output_filename)
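# Example invocation (a sketch; the script name and flag values are
# illustrative, and the flags are assumed to be registered alongside main()):
#
#   python text_encoder_build_subword.py \
#     --corpus_filepattern='data/*.txt' \
#     --corpus_max_lines=10000 \
#     --min_count=5 \
#     --num_iterations=4 \
#     --output_filename=/tmp/my.subword_vocab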