def test_from_file(self):
    """Check vocabulary sizes when reading a plain word list.

    Without ``oos_words`` every count is 10 words + 3 (the ``+ 3`` terms).
    With ``oos_words`` the total word count rises to 12 + 3, while the
    shortlist and class counts stay at 10 + 3.
    """
    def check_sizes(vocab, num_words, num_shortlist, num_classes):
        # Same three assertions, in the same order, for each vocabulary.
        self.assertEqual(vocab.num_words(), num_words)
        self.assertEqual(vocab.num_shortlist_words(), num_shortlist)
        self.assertEqual(vocab.num_classes(), num_classes)

    self.vocabulary_file.seek(0)
    plain = Vocabulary.from_file(self.vocabulary_file, 'words')
    check_sizes(plain, 10 + 3, 10 + 3, 10 + 3)

    extra_words = ['yksi', 'kaksi', 'yksitoista', 'kaksitoista']
    self.vocabulary_file.seek(0)
    extended = Vocabulary.from_file(self.vocabulary_file, 'words',
                                    oos_words=extra_words)
    check_sizes(extended, 12 + 3, 10 + 3, 10 + 3)
def test_get_oos_probs(self):
    """Check get_oos_probs() after compute_probs().

    Shortlist words are expected to get 1.0. The two out-of-shortlist words
    share their mass in proportion to their counts (3 and 7 -> 0.3, 0.7).
    """
    extra_words = ['yksitoista', 'kaksitoista']
    self.vocabulary_file.seek(0)
    vocabulary = Vocabulary.from_file(self.vocabulary_file, 'words',
                                      oos_words=extra_words)

    shortlist = ['yksi', 'kaksi', 'kolme', 'neljä', 'viisi', 'kuusi',
                 'seitsemän', 'kahdeksan', 'yhdeksän', 'kymmenen',
                 '<s>', '</s>', '<unk>']
    # Shortlist words get counts 1, 2, ..., 13 in listing order.
    word_counts = {word: count
                   for count, word in enumerate(shortlist, start=1)}
    word_counts.update({'yksitoista': 3, 'kaksitoista': 7})
    vocabulary.compute_probs(word_counts)

    expected = [1.0] * len(shortlist) + [0.3, 0.7]
    assert_almost_equal(vocabulary.get_oos_probs(), expected)
def test_get_class_memberships(self):
    """Check that get_class_memberships() returns class IDs and probabilities.

    The returned class IDs must match ``word_id_to_class_id``. The first set
    of expected probabilities presumably mirrors the membership
    probabilities in the class file, normalised within each class (TODO
    confirm against the fixture). After compute_probs() with counts from
    ``sentences3_file``, the probabilities must follow those counts instead.
    """
    # Rewind like the other tests that read classes_file, so the result does
    # not depend on where an earlier read left the file position.
    self.classes_file.seek(0)
    vocabulary = Vocabulary.from_file(self.classes_file, 'srilm-classes',
                                      oos_words=['yksitoista'])
    words = ['yksi', 'kaksi', 'kolme', 'neljä', 'viisi', 'kuusi',
             'seitsemän', 'kahdeksan', 'yhdeksän', 'kymmenen',
             '<s>', '</s>', '<unk>']
    word_ids = numpy.array([vocabulary.word_to_id[word] for word in words])

    class_ids, probs = vocabulary.get_class_memberships(word_ids)
    assert_equal(class_ids, vocabulary.word_id_to_class_id[word_ids])
    # Sums of the raw membership values of the two- and four-word classes.
    pair_total = 0.599 + 0.400
    quad_total = 0.281 + 0.226 + 0.262 + 0.228
    assert_almost_equal(probs, [
        1.0, 0.999,
        0.599 / pair_total, 0.400 / pair_total,
        1.0,
        0.281 / quad_total, 0.226 / quad_total,
        0.262 / quad_total, 0.228 / quad_total,
        1.0, 1.0, 0.001, 1.0
    ])

    word_counts = compute_word_counts([self.sentences3_file])
    vocabulary.compute_probs(word_counts)
    _, probs = vocabulary.get_class_memberships(word_ids)
    assert_almost_equal(probs, [
        1.0, 1.0 / 6.0, 0.5, 0.5, 1.0, 0.25, 0.25, 0.25, 0.25,
        1.0, 1.0, 5.0 / 6.0, 1.0
    ])
def test_compute_probs(self):
    """Check unigram and within-class word probabilities after compute_probs().

    Counts come from ``sentences1_file`` and ``sentences2_file``. For every
    word, ``_unigram_probs`` must equal its count over the total, and
    ``get_word_prob()`` its share of its class (e.g. 'kaksi' and '</s>'
    split 2 : 10).
    """
    self.classes_file.seek(0)
    vocabulary = Vocabulary.from_file(self.classes_file, 'srilm-classes')
    word_counts = compute_word_counts(
        [self.sentences1_file, self.sentences2_file])
    vocabulary.compute_probs(word_counts)

    # 10 * <s> + 10 * </s> + 20 words.
    total_count = 40.0
    # (word, expected unigram probability, expected probability in class)
    expectations = [
        ('yksi', 2.0 / total_count, 1.0),
        ('kaksi', 2.0 / total_count, 2.0 / 12.0),
        ('kolme', 2.0 / total_count, 0.5),
        ('neljä', 2.0 / total_count, 0.5),
        ('viisi', 2.0 / total_count, 1.0),
        ('kuusi', 2.0 / total_count, 0.25),
        ('seitsemän', 2.0 / total_count, 0.25),
        ('kahdeksan', 2.0 / total_count, 0.25),
        ('yhdeksän', 2.0 / total_count, 0.25),
        ('kymmenen', 2.0 / total_count, 1.0),
        ('<s>', 10.0 / total_count, 1.0),
        ('</s>', 10.0 / total_count, 10.0 / 12.0),
        ('<unk>', 0.0, 1.0),
    ]
    for word, unigram_prob, class_prob in expectations:
        word_id = vocabulary.word_to_id[word]
        self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                               unigram_prob)
        self.assertAlmostEqual(vocabulary.get_word_prob(word_id), class_prob)
def test_class_ids(self):
    """Check that class IDs read from the class file group words as expected.

    'kolme' and 'neljä' share a class, 'kaksi' shares a class with '</s>',
    and 'yksi' is in a different class from 'kaksi'.
    """
    self.classes_file.seek(0)
    vocabulary = Vocabulary.from_file(self.classes_file, 'srilm-classes')

    def class_of(word):
        return vocabulary.word_id_to_class_id[vocabulary.word_to_id[word]]

    yksi, kaksi, kolme, nelja, eos = [
        class_of(word)
        for word in ('yksi', 'kaksi', 'kolme', 'neljä', '</s>')
    ]
    self.assertNotEqual(yksi, kaksi)
    self.assertEqual(kolme, nelja)
    self.assertNotEqual(kolme, eos)
    self.assertEqual(kaksi, eos)
def test_from_state(self):
    """Check that a vocabulary restored with from_state() matches the saved one.

    A vocabulary with computed probabilities is written with get_state()
    into an in-memory HDF5 file, read back with Vocabulary.from_state(),
    and compared field by field.
    """
    self.classes_file.seek(0)
    vocabulary1 = Vocabulary.from_file(self.classes_file, 'srilm-classes')
    word_counts = compute_word_counts(
        [self.sentences1_file, self.sentences2_file])
    vocabulary1.compute_probs(word_counts)

    # driver='core' with backing_store=False keeps the HDF5 file in memory
    # only. The context manager guarantees the handle is closed even when an
    # assertion fails. The assertions stay inside the block in case
    # from_state() keeps references to datasets in `f`.
    with h5py.File('in-memory.h5', driver='core', backing_store=False) as f:
        vocabulary1.get_state(f)
        vocabulary2 = Vocabulary.from_state(f)
        self.assertTrue(
            numpy.array_equal(vocabulary1.id_to_word,
                              vocabulary2.id_to_word))
        self.assertDictEqual(vocabulary1.word_to_id, vocabulary2.word_to_id)
        self.assertTrue(
            numpy.array_equal(vocabulary1.word_id_to_class_id,
                              vocabulary2.word_id_to_class_id))
        self.assertListEqual(list(vocabulary1._word_classes),
                             list(vocabulary2._word_classes))
        self.assertTrue(
            numpy.array_equal(vocabulary1._unigram_probs,
                              vocabulary2._unigram_probs))
def _parse_args():
    """Build the wctool argument parser and parse the command line."""
    parser = argparse.ArgumentParser(prog='wctool')

    argument_group = parser.add_argument_group("files")
    argument_group.add_argument(
        '--training-set', metavar='FILE', type=TextFileType('r'),
        nargs='+', required=True,
        help='text or .gz files containing training data (one sentence per '
             'line)')
    argument_group.add_argument(
        '--vocabulary', metavar='FILE', type=TextFileType('r'), default=None,
        help='text or .gz file containing a list of words to include in class '
             'forming, and possibly their initial classes')
    argument_group.add_argument(
        '--vocabulary-format', metavar='FORMAT', type=str, default='words',
        help='vocabulary format, one of "words" (one word per line, default), '
             '"classes" (word and class ID per line), "srilm-classes" (class '
             'name, membership probability, and word per line)')
    argument_group.add_argument(
        '--output-file', metavar='FILE', type=TextFileType('w'), default='-',
        help='where to write the word classes (default stdout)')
    argument_group.add_argument(
        '--output-format', metavar='FORMAT', type=str,
        default='srilm-classes',
        help='format of the output file, one of "classes" (word and class ID '
             'per line), "srilm-classes" (default; class name, membership '
             'probability, and word per line)')
    # Default is an int to match type=int. The previous string default '1'
    # only worked through argparse's implicit conversion of string defaults.
    argument_group.add_argument(
        '--output-frequency', metavar='N', type=int, default=1,
        help='save classes N times per optimization iteration (default 1)')

    argument_group = parser.add_argument_group("optimization")
    argument_group.add_argument(
        '--num-classes', metavar='N', type=int, default=2000,
        help='number of classes to form, if vocabulary is not specified '
             '(default 2000)')
    argument_group.add_argument(
        '--method', metavar='NAME', type=str, default='bigram-theano',
        help='method for creating word classes, one of "bigram-theano", '
             '"bigram-numpy" (default "bigram-theano")')

    argument_group = parser.add_argument_group("logging and debugging")
    argument_group.add_argument(
        '--log-file', metavar='FILE', type=str, default='-',
        help='path where to write log file (default is standard output)')
    argument_group.add_argument(
        '--log-level', metavar='LEVEL', type=str, default='info',
        help='minimum level of events to log, one of "debug", "info", "warn" '
             '(default "info")')
    argument_group.add_argument(
        '--log-interval', metavar='N', type=int, default=1000,
        help='print statistics after every Nth word; quiet if less than one '
             '(default 1000)')

    return parser.parse_args()


def _configure_logging(args):
    """Configure the root logger from --log-file and --log-level.

    Raises:
        ValueError: if --log-level does not name an integer level in the
            ``logging`` module.
    """
    log_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(log_level, int):
        raise ValueError("Invalid logging level requested: " + args.log_level)
    log_format = '%(asctime)s %(funcName)s: %(message)s'
    if args.log_file == '-':
        logging.basicConfig(stream=sys.stdout, format=log_format,
                            level=log_level)
    else:
        logging.basicConfig(filename=args.log_file, format=log_format,
                            level=log_level)


def _create_vocabulary(args):
    """Create the vocabulary from --vocabulary, or from training-data counts."""
    if args.vocabulary is not None:
        return Vocabulary.from_file(args.vocabulary, args.vocabulary_format)

    word_counts = compute_word_counts(args.training_set)
    vocabulary = Vocabulary.from_word_counts(word_counts, args.num_classes)
    # Rewind the training files so that BigramStatistics can read them again
    # after compute_word_counts() has read through them.
    for subset_file in args.training_set:
        subset_file.seek(0)
    return vocabulary


def _create_optimizer(args, statistics, vocabulary):
    """Instantiate the optimizer selected by --method.

    Raises:
        ValueError: if --method is not "bigram-theano" or "bigram-numpy".
    """
    if args.method == 'bigram-theano':
        return TheanoBigramOptimizer(statistics, vocabulary)
    if args.method == 'bigram-numpy':
        return NumpyBigramOptimizer(statistics, vocabulary)
    raise ValueError("Invalid method requested: " + args.method)


def main():
    """Run wctool: iteratively optimize word classes and save them.

    Each iteration tries to move every vocabulary word to its best class.
    Optimization stops after the first iteration in which no word moved.
    Classes are saved according to --output-frequency during each iteration,
    and once more at the end.
    """
    args = _parse_args()
    _configure_logging(args)
    vocabulary = _create_vocabulary(args)

    print("Number of words in vocabulary:", vocabulary.num_shortlist_words())
    print("Number of word classes:", vocabulary.num_classes())
    print("Number of normal word classes:", vocabulary.num_normal_classes)

    logging.info("Reading word unigram and bigram statistics.")
    statistics = BigramStatistics(args.training_set, vocabulary)
    optimizer = _create_optimizer(args, statistics, vocabulary)

    iteration = 1
    while True:
        logging.info("Starting iteration %d.", iteration)
        num_words = 0
        num_moves = 0
        for word in vocabulary.words():
            start_time = time()
            num_words += 1
            if optimizer.move_to_best_class(word):
                num_moves += 1
            # Wall-clock time of this single move, in seconds.
            duration = time() - start_time
            if (args.log_interval >= 1) and \
               (num_words % args.log_interval == 0):
                # time() returns seconds, so multiply by 1000 for the "ms"
                # unit printed here (previously * 100, off by 10x).
                logging.info(
                    "[%d] (%.1f %%) of iteration %d -- moves = %d, cost = %.2f, duration = %.1f ms",
                    num_words,
                    num_words / vocabulary.num_shortlist_words() * 100,
                    iteration,
                    num_moves,
                    optimizer.log_likelihood(),
                    duration * 1000)
            if is_scheduled(num_words,
                            args.output_frequency,
                            vocabulary.num_shortlist_words()):
                save(optimizer, args.output_file, args.output_format)

        if num_moves == 0:
            break
        iteration += 1

    logging.info("Optimization finished.")
    save(optimizer, args.output_file, args.output_format)