Code example #1
    def test_load_from_file(self):
        # Test a vocab file with words not wrapped with single quotes
        encoder = text_encoder.SubwordTextEncoder()
        correct_vocab = ["the", "and", "of"]
        vocab = io.StringIO("the\n" "and\n" "of\n")
        encoder._load_from_file_object(vocab)
        self.assertAllEqual(encoder.all_subtoken_strings, correct_vocab)

        # Test a vocab file with words wrapped in single quotes
        encoder = text_encoder.SubwordTextEncoder()
        vocab = io.StringIO("'the'\n" "'and'\n" "'of'\n")
        encoder._load_from_file_object(vocab)
        self.assertAllEqual(encoder.all_subtoken_strings, correct_vocab)
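
        # For reference (a sketch, not in the original test): store_to_file
        # writes one subtoken per line wrapped in single quotes by default,
        # producing exactly the quoted format parsed above.
        filename = os.path.join(self.test_temp_dir, "demo.voc")
        encoder.store_to_file(filename)  # file lines look like: 'the'
        reloaded = text_encoder.SubwordTextEncoder(filename)
        self.assertAllEqual(reloaded.all_subtoken_strings, correct_vocab)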
Code example #2
    def test_reserved_token_chars_not_in_alphabet(self):
        corpus = "dog"
        token_counts = collections.Counter(corpus.split(" "))
        encoder1 = text_encoder.SubwordTextEncoder.build_to_target_size(
            100, token_counts, 2, 100)
        filename = os.path.join(self.test_temp_dir, "out.voc")
        encoder1.store_to_file(filename)
        encoder2 = text_encoder.SubwordTextEncoder(filename=filename)

        self.assertEqual(encoder1._alphabet, encoder2._alphabet)

        for t in text_encoder.RESERVED_TOKENS:
            for c in t:
                # Verify that encoders can encode all reserved token chars.
                encoder1.encode(c)
                encoder2.encode(c)
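
        # For reference (not in the original test): the reserved-token
        # defaults defined in tensor2tensor's text_encoder module, which the
        # loop above iterates over character by character.
        self.assertEqual(text_encoder.RESERVED_TOKENS, ["<pad>", "<EOS>"])
        self.assertEqual(text_encoder.PAD_ID, 0)
        self.assertEqual(text_encoder.EOS_ID, 1)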
Code example #3
def get_or_generate_vocab_inner(data_dir,
                                vocab_filename,
                                vocab_size,
                                generator,
                                max_subtoken_length=None,
                                reserved_tokens=None):
    """Inner implementation for vocab generators.

    Args:
      data_dir: The base directory where data and vocab files are stored. If
        None, then do not save the vocab even if it doesn't exist.
      vocab_filename: relative filename where vocab file is stored
      vocab_size: target size of the vocabulary constructed by
        SubwordTextEncoder
      generator: a generator that produces tokens from the vocabulary
      max_subtoken_length: an optional integer. Set this to a finite value to
        avoid quadratic costs during vocab building.
      reserved_tokens: List of reserved tokens. `text_encoder.RESERVED_TOKENS`
        should be a prefix of `reserved_tokens`. If `None`, defaults to
        `RESERVED_TOKENS`.

    Returns:
      A SubwordTextEncoder vocabulary object.
    """
    if data_dir and vocab_filename:
        vocab_filepath = os.path.join(data_dir, vocab_filename)
        if tf.gfile.Exists(vocab_filepath):
            tf.logging.info("Found vocab file: %s", vocab_filepath)
            return text_encoder.SubwordTextEncoder(vocab_filepath)
    else:
        vocab_filepath = None

    tf.logging.info("Generating vocab file: %s", vocab_filepath)
    vocab = text_encoder.SubwordTextEncoder.build_from_generator(
        generator,
        vocab_size,
        max_subtoken_length=max_subtoken_length,
        reserved_tokens=reserved_tokens)

    if vocab_filepath:
        tf.gfile.MakeDirs(data_dir)
        vocab.store_to_file(vocab_filepath)

    return vocab
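
A usage sketch for the function above; the toy generator, directory, filename, and target size are illustrative assumptions, not part of the original module:

def toy_generator():
    yield "the quick brown fox"
    yield "jumps over the lazy dog"

vocab = get_or_generate_vocab_inner(
    data_dir="/tmp/t2t_data",        # hypothetical data directory
    vocab_filename="vocab.demo.32",  # hypothetical filename
    vocab_size=32,
    generator=toy_generator())
print("vocab size:", vocab.vocab_size)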
Code example #4
    def test_save_and_reload_no_single_quotes(self):
        corpus = "the quick brown fox jumps over the lazy dog"
        token_counts = collections.Counter(corpus.split(" "))

        # Deliberately exclude some required encoding chars from the alphabet
        # and token list, making some strings unencodable.
        encoder = text_encoder.SubwordTextEncoder.build_to_target_size(
            100, token_counts, 2, 10)

        filename = os.path.join(self.test_temp_dir, "out.voc")
        encoder.store_to_file(filename, add_single_quotes=False)
        new_encoder = text_encoder.SubwordTextEncoder(filename)

        self.assertEqual(encoder._alphabet, new_encoder._alphabet)
        self.assertEqual(encoder.all_subtoken_strings,
                         new_encoder.all_subtoken_strings)
        self.assertEqual(encoder._subtoken_string_to_id,
                         new_encoder._subtoken_string_to_id)
        self.assertEqual(encoder._max_subtoken_len,
                         new_encoder._max_subtoken_len)
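
        # Follow-up sanity sketch (not in the original test): encode/decode
        # should round-trip text over the corpus alphabet, confirming the
        # reloaded encoder is usable, not just structurally equal.
        ids = new_encoder.encode("the quick brown fox")
        self.assertEqual(new_encoder.decode(ids), "the quick brown fox")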
Code example #5
  def testText2TextTmpDir(self):
    problem = Test1()
    problem.generate_data(self.tmp_dir, self.tmp_dir)
    vocab_file = os.path.join(self.tmp_dir, "vocab.test1.3.subwords")
    train_file = os.path.join(self.tmp_dir, "test1-train-00000-of-00001")
    eval_file = os.path.join(self.tmp_dir, "test1-dev-00000-of-00001")
    self.assertTrue(tf.gfile.Exists(vocab_file))
    self.assertTrue(tf.gfile.Exists(train_file))
    self.assertTrue(tf.gfile.Exists(eval_file))

    dataset = problem.dataset(tf.estimator.ModeKeys.TRAIN, self.tmp_dir)
    features = dataset.make_one_shot_iterator().get_next()

    examples = []
    exhausted = False
    with self.test_session() as sess:
      examples.append(sess.run(features))
      examples.append(sess.run(features))
      try:
        sess.run(features)
      except tf.errors.OutOfRangeError:
        exhausted = True

    self.assertTrue(exhausted)
    self.assertEqual(2, len(examples))

    self.assertNotEqual(
        list(examples[0]["inputs"]), list(examples[1]["inputs"]))

    example = examples[0]
    encoder = text_encoder.SubwordTextEncoder(vocab_file)
    inputs_encoded = list(example["inputs"])
    inputs_encoded.pop()  # rm EOS
    self.assertTrue(encoder.decode(inputs_encoded) in self.inputs)
    targets_encoded = list(example["targets"])
    targets_encoded.pop()  # rm EOS
    self.assertTrue(encoder.decode(targets_encoded) in self.targets)
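
    # Note (not in the original test): the trailing id popped above is
    # text_encoder.EOS_ID (1), appended to each example by the text
    # problem's data generation pipeline.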
Code example #6
def main(args):
    fns = read_fns_codesearchnet(args.data)
    # build_from_generator is a classmethod that returns a *new* encoder, so
    # calling it on an instance and discarding the result leaves `encoder`
    # empty; build_from_token_counts builds in place and matches the
    # min_count/num_iterations arguments. Whitespace tokenization is a sketch.
    token_counts = collections.Counter(
        tok for fn in fns for tok in fn.split())
    encoder = text_encoder.SubwordTextEncoder()
    encoder.build_from_token_counts(
        token_counts, args.min_count, num_iterations=args.num_iterations)
    encoder.store_to_file(args.output_filename)
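
A hypothetical command-line entry point matching the argument names used above; the flag names and defaults are assumptions, not from the original script:

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Build a subword vocabulary from CodeSearchNet functions.")
    parser.add_argument("--data", required=True)
    parser.add_argument("--min_count", type=int, default=5)
    parser.add_argument("--num_iterations", type=int, default=4)
    parser.add_argument("--output_filename", required=True)
    main(parser.parse_args())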