def test_load_from_file(self):
  """Vocab files load identically with or without quote-wrapped entries."""
  expected = ["the", "and", "of"]

  # Case 1: bare entries, one subtoken string per line.
  enc = text_encoder.SubwordTextEncoder()
  enc._load_from_file_object(io.StringIO("the\nand\nof\n"))
  self.assertAllEqual(enc.all_subtoken_strings, expected)

  # Case 2: the same entries wrapped in quotes; the loader must strip them.
  enc = text_encoder.SubwordTextEncoder()
  enc._load_from_file_object(io.StringIO('"the"\n"and"\n"of"\n'))
  self.assertAllEqual(enc.all_subtoken_strings, expected)
def test_reserved_token_chars_not_in_alphabet(self):
  """Reserved-token characters stay encodable across a store/load cycle."""
  token_counts = collections.Counter("dog".split(" "))
  original = text_encoder.SubwordTextEncoder.build_to_target_size(
      100, token_counts, 2, 100)

  vocab_path = os.path.join(self.test_temp_dir, "out.voc")
  original.store_to_file(vocab_path)
  reloaded = text_encoder.SubwordTextEncoder(filename=vocab_path)

  self.assertEqual(original._alphabet, reloaded._alphabet)
  # Every character of every reserved token must encode without error in
  # both the freshly-built encoder and the one loaded from disk.
  for token in text_encoder.RESERVED_TOKENS:
    for ch in token:
      original.encode(ch)
      reloaded.encode(ch)
def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
                                generator, max_subtoken_length=None,
                                reserved_tokens=None):
  """Inner implementation for vocab generators.

  Loads an existing vocab file when one is present; otherwise builds a new
  SubwordTextEncoder from `generator` and (optionally) persists it.

  Args:
    data_dir: The base directory where data and vocab files are stored. If
      None, then do not save the vocab even if it doesn't exist.
    vocab_filename: relative filename where vocab file is stored
    vocab_size: target size of the vocabulary constructed by
      SubwordTextEncoder
    generator: a generator that produces tokens from the vocabulary
    max_subtoken_length: an optional integer. Set this to a finite value to
      avoid quadratic costs during vocab building.
    reserved_tokens: List of reserved tokens. `text_encoder.RESERVED_TOKENS`
      should be a prefix of `reserved_tokens`. If `None`, defaults to
      `RESERVED_TOKENS`.

  Returns:
    A SubwordTextEncoder vocabulary object.
  """
  vocab_filepath = (
      os.path.join(data_dir, vocab_filename)
      if data_dir and vocab_filename else None)

  # Fast path: reuse a previously generated vocab file.
  if vocab_filepath and tf.gfile.Exists(vocab_filepath):
    tf.logging.info("Found vocab file: %s", vocab_filepath)
    return text_encoder.SubwordTextEncoder(vocab_filepath)

  tf.logging.info("Generating vocab file: %s", vocab_filepath)
  vocab = text_encoder.SubwordTextEncoder.build_from_generator(
      generator, vocab_size, max_subtoken_length=max_subtoken_length,
      reserved_tokens=reserved_tokens)

  if vocab_filepath:
    tf.gfile.MakeDirs(data_dir)
    vocab.store_to_file(vocab_filepath)
  return vocab
def test_save_and_reload_no_single_quotes(self):
  """Saving without quote-wrapping still round-trips all encoder state."""
  corpus = "the quick brown fox jumps over the lazy dog"
  token_counts = collections.Counter(corpus.split(" "))
  original = text_encoder.SubwordTextEncoder.build_to_target_size(
      100, token_counts, 2, 10)

  # Persist with add_single_quotes=False and reload from the same file.
  vocab_path = os.path.join(self.test_temp_dir, "out.voc")
  original.store_to_file(vocab_path, add_single_quotes=False)
  reloaded = text_encoder.SubwordTextEncoder(vocab_path)

  # Every piece of derived state must survive the round trip.
  self.assertEqual(original._alphabet, reloaded._alphabet)
  self.assertEqual(original.all_subtoken_strings,
                   reloaded.all_subtoken_strings)
  self.assertEqual(original._subtoken_string_to_id,
                   reloaded._subtoken_string_to_id)
  self.assertEqual(original._max_subtoken_len, reloaded._max_subtoken_len)
def testText2TextTmpDir(self):
  """Data generation emits vocab/train/dev files and two decodable examples."""
  problem = Test1()
  problem.generate_data(self.tmp_dir, self.tmp_dir)

  # All three expected output files must exist on disk.
  vocab_file = os.path.join(self.tmp_dir, "vocab.test1.3.subwords")
  self.assertTrue(tf.gfile.Exists(vocab_file))
  self.assertTrue(tf.gfile.Exists(
      os.path.join(self.tmp_dir, "test1-train-00000-of-00001")))
  self.assertTrue(tf.gfile.Exists(
      os.path.join(self.tmp_dir, "test1-dev-00000-of-00001")))

  dataset = problem.dataset(tf.estimator.ModeKeys.TRAIN, self.tmp_dir)
  features = dataset.make_one_shot_iterator().get_next()

  # Pull exactly two examples; a third fetch must exhaust the dataset.
  examples = []
  exhausted = False
  with self.test_session() as sess:
    examples.append(sess.run(features))
    examples.append(sess.run(features))
    try:
      sess.run(features)
    except tf.errors.OutOfRangeError:
      exhausted = True

  self.assertTrue(exhausted)
  self.assertEqual(2, len(examples))
  self.assertNotEqual(
      list(examples[0]["inputs"]), list(examples[1]["inputs"]))

  # Decoding (minus the trailing EOS id) must recover known strings.
  encoder = text_encoder.SubwordTextEncoder(vocab_file)
  example = examples[0]
  for key, pool in (("inputs", self.inputs), ("targets", self.targets)):
    ids = list(example[key])
    ids.pop()  # rm EOS
    self.assertTrue(encoder.decode(ids) in pool)
def main(args):
  """Build a subword vocabulary from the input corpus and write it to disk.

  Args:
    args: parsed CLI namespace with attributes `data` (input path),
      `min_count`, `num_iterations`, and `output_filename`.
  """
  fns = read_fns_codesearchnet(args.data)
  # Bug fix: build_from_generator is a constructor-style classmethod that
  # RETURNS a new encoder (see how get_or_generate_vocab_inner assigns its
  # result). The previous code called it on a fresh instance and discarded
  # the return value, so store_to_file saved an empty encoder.
  # NOTE(review): the positional args here are named min_count and
  # num_iterations while build_from_generator's parameters are target size
  # and max_subtoken_length — confirm these are the intended values.
  encoder = text_encoder.SubwordTextEncoder.build_from_generator(
      fns, args.min_count, args.num_iterations)
  encoder.store_to_file(args.output_filename)