class LearnBinarySearchTest(parameterized.TestCase):
    """Tests for learner.learn_binary_search vocab-size targeting.

    Each case checks that the learned vocab lands on (or within `delta` of)
    params.vocab_size, never exceeds it, and matches the expected tokens.
    """

    @parameterized.named_parameters(
        dict(
            testcase_name='ReachesVocabSize',
            word_counts=[('apple', 2), ('peach', 1), ('pear', 1)],
            lower=1,
            upper=10,
            delta=0,
            expected_vocab=[
                'a', 'c', 'e', 'h', 'l', 'p', 'r', 'apple', 'peach', 'pear'
            ],
            params=learner.Params(
                upper_thresh=4,
                lower_thresh=1,
                num_iterations=4,
                max_input_tokens=1000,
                max_token_length=50,
                max_unique_chars=50,
                vocab_size=10,
                slack_ratio=0,
                include_joiner_token=False,
                joiner='##',
                reserved_tokens=[])),
        dict(
            testcase_name='VocabSizeWithinSlack',
            word_counts=[('apple', 2), ('peach', 1), ('pear', 1), ('app', 2)],
            lower=1,
            upper=10,
            delta=6,
            expected_vocab=['a', 'c', 'e', 'h', 'l', 'p', 'r'],
            params=learner.Params(
                upper_thresh=4,
                lower_thresh=1,
                num_iterations=4,
                max_input_tokens=1000,
                max_token_length=50,
                max_unique_chars=50,
                vocab_size=12,
                slack_ratio=0.5,
                include_joiner_token=False,
                joiner='##',
                reserved_tokens=[])))
    def testBinarySearch(self, word_counts, lower, upper, delta,
                         expected_vocab, params):
        """Runs the binary search and checks size bounds and exact contents."""
        learned = learner.learn_binary_search(word_counts, lower, upper,
                                              params)
        # Size must be within `delta` of the target, and never above it.
        self.assertAlmostEqual(len(learned), params.vocab_size, delta=delta)
        self.assertLessEqual(len(learned), params.vocab_size)
        self.assertEqual(learned, expected_vocab)
def main(_):
  """Learns a wordpiece vocabulary from a wordcount file and writes it out.

  Reads "<word> <count>" lines from FLAGS.input_path, prepends any requested
  padding tokens to the reserved-token list, runs the learner, and writes one
  vocab entry per line to FLAGS.output_path.

  Args:
    _: Unused positional arg supplied by the app runner.
  """
  # Read in wordcount file. Split each line only once (the original split
  # twice per line), and skip blank lines instead of raising IndexError.
  word_counts = []
  with open(FLAGS.input_path) as wordcount_file:
    for line in wordcount_file:
      fields = line.split()
      if not fields:
        continue
      word_counts.append((fields[0], int(fields[1])))

  # Add in padding tokens: '<pad>', '<pad1>', ... ahead of the user-supplied
  # reserved tokens.
  reserved_tokens = FLAGS.reserved_tokens
  if FLAGS.num_pad_tokens:
    padded_tokens = ['<pad>']
    padded_tokens += ['<pad%d>' % i for i in range(1, FLAGS.num_pad_tokens)]
    reserved_tokens = padded_tokens + reserved_tokens

  params = learner.Params(FLAGS.upper_thresh, FLAGS.lower_thresh,
                          FLAGS.num_iterations, FLAGS.max_input_tokens,
                          FLAGS.max_token_length, FLAGS.max_unique_chars,
                          FLAGS.vocab_size, FLAGS.slack_ratio,
                          FLAGS.include_joiner_token, FLAGS.joiner,
                          reserved_tokens)

  vocab = learner.learn(word_counts, params)

  # Write vocab to file, one token per line.
  with open(FLAGS.output_path, 'w') as vocab_file:
    vocab_file.writelines(token + '\n' for token in vocab)
# Example #3
def main(_):
    """Learns a vocab from a TF dataset and writes vocab and metrics files.

    Builds the input schema, assembles reserved tokens (with optional
    padding tokens prepended), and hands everything to generate_vocab.

    Args:
      _: Unused positional arg supplied by the app runner.
    """
    # Define schema: each raw example carries its text and a language code.
    feature_spec = {
        'text': tf.FixedLenFeature([], tf.string),
        'language_code': tf.FixedLenFeature([], tf.string),
    }
    raw_metadata = dataset_metadata.DatasetMetadata(
        dataset_schema.from_feature_spec(feature_spec))

    # Prepend padding tokens ('<pad>', '<pad1>', ...) to reserved tokens.
    reserved_tokens = FLAGS.reserved_tokens
    if FLAGS.num_pad_tokens:
        pads = ['<pad>'] + [
            '<pad%d>' % i for i in range(1, FLAGS.num_pad_tokens)
        ]
        reserved_tokens = pads + reserved_tokens

    params = learner.Params(FLAGS.upper_thresh, FLAGS.lower_thresh,
                            FLAGS.num_iterations, FLAGS.max_input_tokens,
                            FLAGS.max_token_length, FLAGS.max_unique_chars,
                            FLAGS.vocab_size, FLAGS.slack_ratio,
                            FLAGS.include_joiner_token, FLAGS.joiner,
                            reserved_tokens)

    generate_vocab(FLAGS.data_file, FLAGS.vocab_file, FLAGS.metrics_file,
                   raw_metadata, params)
class LearnWithThreshTest(parameterized.TestCase):
    """Tests for learner.learn_with_thresh.

    Covers iteration-count effects (one vs. two passes) and a higher count
    threshold filtering out rarer substrings.
    """

    @parameterized.named_parameters(
        dict(
            testcase_name='LearnWithOneIteration',
            word_counts=[('apple', 1), ('app', 1)],
            thresh=1,
            expected_vocab=[
                'a', 'e', 'l', 'p', 'app', 'apple', 'le', 'ple', 'pp', 'pple'
            ],
            params=learner.Params(
                upper_thresh=4,
                lower_thresh=1,
                num_iterations=1,
                max_input_tokens=1000,
                max_token_length=50,
                max_unique_chars=5,
                vocab_size=10,
                slack_ratio=0,
                include_joiner_token=False,
                joiner='##',
                reserved_tokens=[])),
        dict(
            testcase_name='LearnWithTwoIterations',
            word_counts=[('apple', 1), ('app', 1)],
            thresh=1,
            expected_vocab=['a', 'e', 'l', 'p', 'app', 'apple'],
            params=learner.Params(
                upper_thresh=4,
                lower_thresh=1,
                num_iterations=2,
                max_input_tokens=1000,
                max_token_length=50,
                max_unique_chars=5,
                vocab_size=10,
                slack_ratio=0,
                include_joiner_token=False,
                joiner='##',
                reserved_tokens=[])),
        dict(
            testcase_name='LearnWithHigherThresh',
            word_counts=[('apple', 1), ('app', 2)],
            thresh=2,
            expected_vocab=['a', 'e', 'l', 'p', 'app', 'pp'],
            params=learner.Params(
                upper_thresh=4,
                lower_thresh=1,
                num_iterations=1,
                max_input_tokens=1000,
                max_token_length=50,
                max_unique_chars=5,
                vocab_size=10,
                slack_ratio=0,
                include_joiner_token=False,
                joiner='##',
                reserved_tokens=[])))
    def testLearnWithThresh(self, word_counts, thresh, expected_vocab, params):
        """Learns with a fixed threshold and checks the exact vocab."""
        learned = learner.learn_with_thresh(word_counts, thresh, params)
        self.assertEqual(learned, expected_vocab)