def test_get_class_memberships(self):
    vocabulary = Vocabulary.from_file(self.classes_file, 'srilm-classes',
                                      oos_words=['yksitoista'])
    word_ids = numpy.array([vocabulary.word_to_id['yksi'],
                            vocabulary.word_to_id['kaksi'],
                            vocabulary.word_to_id['kolme'],
                            vocabulary.word_to_id['neljä'],
                            vocabulary.word_to_id['viisi'],
                            vocabulary.word_to_id['kuusi'],
                            vocabulary.word_to_id['seitsemän'],
                            vocabulary.word_to_id['kahdeksan'],
                            vocabulary.word_to_id['yhdeksän'],
                            vocabulary.word_to_id['kymmenen'],
                            vocabulary.word_to_id['<s>'],
                            vocabulary.word_to_id['</s>'],
                            vocabulary.word_to_id['<unk>']])
    class_ids, probs = vocabulary.get_class_memberships(word_ids)
    assert_equal(class_ids, vocabulary.word_id_to_class_id[word_ids])
    assert_almost_equal(probs, [1.0,
                                0.999,
                                0.599 / (0.599 + 0.400),
                                0.400 / (0.599 + 0.400),
                                1.0,
                                0.281 / (0.281 + 0.226 + 0.262 + 0.228),
                                0.226 / (0.281 + 0.226 + 0.262 + 0.228),
                                0.262 / (0.281 + 0.226 + 0.262 + 0.228),
                                0.228 / (0.281 + 0.226 + 0.262 + 0.228),
                                1.0,
                                1.0,
                                0.001,
                                1.0])

    word_counts = compute_word_counts([self.sentences3_file])
    vocabulary.compute_probs(word_counts)
    class_ids, probs = vocabulary.get_class_memberships(word_ids)
    assert_almost_equal(probs, [1.0,
                                1.0 / 6.0,
                                0.5,
                                0.5,
                                1.0,
                                0.25,
                                0.25,
                                0.25,
                                0.25,
                                1.0,
                                1.0,
                                5.0 / 6.0,
                                1.0])

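# The expected values above encode the normalization rule for srilm-classes
# membership probabilities: within one class, the probabilities listed in the
# file are scaled to sum to one. The sketch below illustrates just that rule;
# the helper name `normalize_memberships` is hypothetical and not part of the
# Vocabulary API.

def normalize_memberships(raw_probs):
    """Scale the raw membership probabilities of one class to sum to one."""
    total = sum(raw_probs)
    return [p / total for p in raw_probs]

# For example, 'kolme' and 'neljä' share a class with raw probabilities
# 0.599 and 0.400 in the classes file:
assert normalize_memberships([0.599, 0.400]) == \
    [0.599 / (0.599 + 0.400), 0.400 / (0.599 + 0.400)]
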
def test_compute_probs(self):
    self.classes_file.seek(0)
    vocabulary = Vocabulary.from_file(self.classes_file, 'srilm-classes')
    word_counts = compute_word_counts(
        [self.sentences1_file, self.sentences2_file])
    vocabulary.compute_probs(word_counts)

    # 10 * <s> + 10 * </s> + 20 words.
    total_count = 40.0

    word_id = vocabulary.word_to_id['yksi']
    self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                           2.0 / total_count)
    self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 1.0)
    word_id = vocabulary.word_to_id['kaksi']
    self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                           2.0 / total_count)
    self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 2.0 / 12.0)
    word_id = vocabulary.word_to_id['kolme']
    self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                           2.0 / total_count)
    self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 0.5)
    word_id = vocabulary.word_to_id['neljä']
    self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                           2.0 / total_count)
    self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 0.5)
    word_id = vocabulary.word_to_id['viisi']
    self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                           2.0 / total_count)
    self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 1.0)
    word_id = vocabulary.word_to_id['kuusi']
    self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                           2.0 / total_count)
    self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 0.25)
    word_id = vocabulary.word_to_id['seitsemän']
    self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                           2.0 / total_count)
    self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 0.25)
    word_id = vocabulary.word_to_id['kahdeksan']
    self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                           2.0 / total_count)
    self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 0.25)
    word_id = vocabulary.word_to_id['yhdeksän']
    self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                           2.0 / total_count)
    self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 0.25)
    word_id = vocabulary.word_to_id['kymmenen']
    self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                           2.0 / total_count)
    self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 1.0)
    word_id = vocabulary.word_to_id['<s>']
    self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                           10.0 / total_count)
    self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 1.0)
    word_id = vocabulary.word_to_id['</s>']
    self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                           10.0 / total_count)
    self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 10.0 / 12.0)
    word_id = vocabulary.word_to_id['<unk>']
    self.assertAlmostEqual(vocabulary._unigram_probs[word_id], 0.0)
    self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 1.0)

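# A note on the two probabilities asserted above, as implied by the expected
# values rather than taken from the Vocabulary implementation: _unigram_probs
# holds count(word) divided by the total corpus count, while get_word_prob()
# appears to return count(word) divided by the total count of the word's
# class. For example, 'kaksi' occurs twice and seems to share a class with
# '</s>' (10 occurrences), hence 2.0 / 40.0 and 2.0 / 12.0:

total_count = 40.0
kaksi_count = 2.0
class_count = 2.0 + 10.0  # 'kaksi' plus '</s>'
assert abs(kaksi_count / total_count - 2.0 / 40.0) < 1e-9
assert abs(kaksi_count / class_count - 2.0 / 12.0) < 1e-9
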
def test_bigram_statistics(self):
    self.sentences_file.seek(0)
    word_counts = compute_word_counts([self.sentences_file])
    self.vocabulary = Vocabulary.from_word_counts(word_counts)
    self.sentences_file.seek(0)
    statistics = BigramStatistics([self.sentences_file], self.vocabulary)

    unigram_counts = statistics.unigram_counts
    vocabulary = self.vocabulary
    self.assertEqual(unigram_counts[vocabulary.word_to_id['a']], 13)
    self.assertEqual(unigram_counts[vocabulary.word_to_id['b']], 8)
    self.assertEqual(unigram_counts[vocabulary.word_to_id['c']], 8)
    self.assertEqual(unigram_counts[vocabulary.word_to_id['d']], 11)
    self.assertEqual(unigram_counts[vocabulary.word_to_id['e']], 15)
    self.assertEqual(unigram_counts[vocabulary.word_to_id['<unk>']], 0)
    self.assertEqual(unigram_counts[vocabulary.word_to_id['<s>']], 11)
    self.assertEqual(unigram_counts[vocabulary.word_to_id['</s>']], 11)

    bigram_counts = statistics.bigram_counts
    a_id = vocabulary.word_to_id['a']
    b_id = vocabulary.word_to_id['b']
    self.assertEqual(bigram_counts[a_id, a_id], 3)
    self.assertEqual(bigram_counts[a_id, b_id], 2)
    self.assertEqual(bigram_counts[b_id, a_id], 1)
    self.assertEqual(bigram_counts[b_id, b_id], 0)

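# For reference, a minimal sketch of the bigram counting that the assertions
# above exercise (illustrative only, not the actual BigramStatistics code):
# each pair of adjacent tokens in a sentence, with <s> and </s> included,
# increments one cell of the count matrix.

from collections import defaultdict

def count_bigrams_sketch(sentences):
    counts = defaultdict(int)
    for sentence in sentences:
        tokens = ['<s>'] + sentence.split() + ['</s>']
        for first, second in zip(tokens[:-1], tokens[1:]):
            counts[(first, second)] += 1
    return counts
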
def test_compute_word_counts(self):
    self.sentences_file.seek(0)
    word_counts = compute_word_counts([self.sentences_file])
    self.assertEqual(word_counts['a'], 13)
    self.assertEqual(word_counts['b'], 8)
    self.assertEqual(word_counts['c'], 8)
    self.assertEqual(word_counts['d'], 11)
    self.assertEqual(word_counts['e'], 15)
    self.assertEqual(word_counts['<s>'], 11)
    self.assertEqual(word_counts['</s>'], 11)

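# A hedged sketch of what compute_word_counts is expected to do, judging by
# the assertions above (an illustrative stand-in, not the actual
# implementation): every line contributes its whitespace-separated words plus
# one <s> and one </s> token.

from collections import Counter

def compute_word_counts_sketch(input_files):
    counts = Counter()
    for input_file in input_files:
        for line in input_file:
            counts.update(['<s>'] + line.split() + ['</s>'])
    return counts
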
def setUp(self):
    script_path = os.path.dirname(os.path.realpath(__file__))
    sentences_path = os.path.join(script_path, 'sentences.txt')
    self.sentences_file = open(sentences_path)
    self.num_classes = 2
    word_counts = compute_word_counts([self.sentences_file])
    self.vocabulary = Vocabulary.from_word_counts(word_counts,
                                                  self.num_classes)
    self.sentences_file.seek(0)
    self.statistics = BigramStatistics([self.sentences_file],
                                       self.vocabulary)

def test_from_word_counts(self):
    self.sentences1_file.seek(0)
    word_counts = compute_word_counts([self.sentences1_file])
    vocabulary = Vocabulary.from_word_counts(word_counts)
    self.assertEqual(vocabulary.num_words(), 10 + 3)
    self.assertEqual(vocabulary.num_shortlist_words(), 10 + 3)
    self.assertEqual(vocabulary.num_normal_classes, 10)
    self.assertEqual(vocabulary.num_classes(), 10 + 3)

    self.sentences1_file.seek(0)
    self.sentences2_file.seek(0)
    word_counts = compute_word_counts(
        [self.sentences1_file, self.sentences2_file])
    vocabulary = Vocabulary.from_word_counts(word_counts, 3)
    self.assertEqual(vocabulary.num_words(), 10 + 3)
    self.assertEqual(vocabulary.num_shortlist_words(), 10 + 3)
    self.assertEqual(vocabulary.num_normal_classes, 3)
    self.assertEqual(vocabulary.num_classes(), 3 + 3)

    sos_id = vocabulary.word_to_id['<s>']
    eos_id = vocabulary.word_to_id['</s>']
    unk_id = vocabulary.word_to_id['<unk>']
    self.assertEqual(sos_id, 10)
    self.assertEqual(eos_id, 11)
    self.assertEqual(unk_id, 12)
    self.assertEqual(vocabulary.word_id_to_class_id[sos_id], 3)
    self.assertEqual(vocabulary.word_id_to_class_id[eos_id], 4)
    self.assertEqual(vocabulary.word_id_to_class_id[unk_id], 5)

    word_ids = set()
    class_ids = set()
    for word in vocabulary.words():
        if not word.startswith('<'):
            word_id = vocabulary.word_to_id[word]
            word_ids.add(word_id)
            class_ids.add(vocabulary.word_id_to_class_id[word_id])
    self.assertEqual(word_ids, set(range(10)))
    self.assertEqual(class_ids, set(range(3)))

def test_from_state(self):
    self.classes_file.seek(0)
    vocabulary1 = Vocabulary.from_file(self.classes_file, 'srilm-classes')
    word_counts = compute_word_counts(
        [self.sentences1_file, self.sentences2_file])
    vocabulary1.compute_probs(word_counts)

    # An explicit file mode is required by recent h5py versions.
    f = h5py.File('in-memory.h5', 'w', driver='core', backing_store=False)
    vocabulary1.get_state(f)
    vocabulary2 = Vocabulary.from_state(f)
    self.assertTrue(
        numpy.array_equal(vocabulary1.id_to_word, vocabulary2.id_to_word))
    self.assertDictEqual(vocabulary1.word_to_id, vocabulary2.word_to_id)
    self.assertTrue(
        numpy.array_equal(vocabulary1.word_id_to_class_id,
                          vocabulary2.word_id_to_class_id))
    self.assertListEqual(list(vocabulary1._word_classes),
                         list(vocabulary2._word_classes))
    self.assertTrue(
        numpy.array_equal(vocabulary1._unigram_probs,
                          vocabulary2._unigram_probs))

def _read_vocabulary(args, state):
    """If ``state`` contains data, reads the vocabulary from the HDF5 state.
    Otherwise reads a vocabulary file or constructs the vocabulary from the
    training set, and writes it to the HDF5 state.

    If the state does not contain data and the --vocabulary argument is
    given, reads the vocabulary from the file given after the argument. The
    rest of the words in the training set will be added as out-of-shortlist
    words.

    If the state does not contain data and no vocabulary is given, constructs
    a vocabulary that contains all the training set words. In that case, the
    --num-classes argument can be used to control the number of classes.

    :type args: argparse.Namespace
    :param args: a collection of command line arguments

    :type state: h5py.File
    :param state: HDF5 file where the vocabulary should be saved

    :rtype: Vocabulary
    :returns: the created vocabulary
    """

    if state.keys():
        print("Reading vocabulary from existing network state.")
        sys.stdout.flush()
        result = Vocabulary.from_state(state)
        if not result.has_unigram_probs():
            # This is for backward compatibility. Remove at some point.
            print("Computing unigram word probabilities from training set.")
            sys.stdout.flush()
            word_counts = compute_word_counts(args.training_set)
            shortlist_words = list(result.id_to_word)
            shortlist_set = set(shortlist_words)
            oos_words = [x for x in word_counts.keys()
                         if x not in shortlist_set]
            result.id_to_word = numpy.asarray(shortlist_words + oos_words,
                                              dtype=object)
            result.word_to_id = {word: word_id
                                 for word_id, word
                                 in enumerate(result.id_to_word)}
            result.compute_probs(word_counts, update_class_probs=False)
            result.get_state(state)
    elif args.vocabulary is None:
        print("Constructing vocabulary from training set.")
        sys.stdout.flush()
        word_counts = compute_word_counts(args.training_set)
        result = Vocabulary.from_word_counts(word_counts, args.num_classes)
        result.get_state(state)
    else:
        print("Reading vocabulary from {}.".format(args.vocabulary))
        sys.stdout.flush()
        word_counts = compute_word_counts(args.training_set)
        oos_words = word_counts.keys()
        with open(args.vocabulary, 'rt', encoding='utf-8') as vocab_file:
            result = Vocabulary.from_file(vocab_file,
                                          args.vocabulary_format,
                                          oos_words=oos_words)
        if args.vocabulary_format == 'classes':
            print("Computing class membership probabilities and unigram "
                  "probabilities for out-of-shortlist words.")
            sys.stdout.flush()
            update_class_probs = True
        else:
            print("Computing unigram probabilities for out-of-shortlist "
                  "words.")
            sys.stdout.flush()
            update_class_probs = False
        result.compute_probs(word_counts,
                             update_class_probs=update_class_probs)
        result.get_state(state)

    print("Number of words in vocabulary:", result.num_words())
    print("Number of words in shortlist:", result.num_shortlist_words())
    print("Number of word classes:", result.num_classes())
    return result

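# A usage sketch of _read_vocabulary, assuming it is called at training
# start-up with parsed command line arguments and an open HDF5 state file.
# The file name 'model.h5' and the wrapper function are illustrative only.

def example_startup(parser):
    args = parser.parse_args()
    with h5py.File('model.h5', 'a') as state:  # hypothetical state file
        vocabulary = _read_vocabulary(args, state)
    return vocabulary
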
def main():
    parser = argparse.ArgumentParser(prog='wctool')

    argument_group = parser.add_argument_group("files")
    argument_group.add_argument(
        '--training-set', metavar='FILE', type=TextFileType('r'), nargs='+',
        required=True,
        help='text or .gz files containing training data (one sentence per '
             'line)')
    argument_group.add_argument(
        '--vocabulary', metavar='FILE', type=TextFileType('r'), default=None,
        help='text or .gz file containing a list of words to include in '
             'class forming, and possibly their initial classes')
    argument_group.add_argument(
        '--vocabulary-format', metavar='FORMAT', type=str, default='words',
        help='vocabulary format, one of "words" (one word per line, '
             'default), "classes" (word and class ID per line), '
             '"srilm-classes" (class name, membership probability, and word '
             'per line)')
    argument_group.add_argument(
        '--output-file', metavar='FILE', type=TextFileType('w'), default='-',
        help='where to write the word classes (default stdout)')
    argument_group.add_argument(
        '--output-format', metavar='FORMAT', type=str,
        default='srilm-classes',
        help='format of the output file, one of "classes" (word and class '
             'ID per line), "srilm-classes" (default; class name, '
             'membership probability, and word per line)')
    argument_group.add_argument(
        '--output-frequency', metavar='N', type=int, default=1,
        help='save classes N times per optimization iteration (default 1)')

    argument_group = parser.add_argument_group("optimization")
    argument_group.add_argument(
        '--num-classes', metavar='N', type=int, default=2000,
        help='number of classes to form, if vocabulary is not specified '
             '(default 2000)')
    argument_group.add_argument(
        '--method', metavar='NAME', type=str, default='bigram-theano',
        help='method for creating word classes, one of "bigram-theano", '
             '"bigram-numpy" (default "bigram-theano")')

    argument_group = parser.add_argument_group("logging and debugging")
    argument_group.add_argument(
        '--log-file', metavar='FILE', type=str, default='-',
        help='path where to write log file (default is standard output)')
    argument_group.add_argument(
        '--log-level', metavar='LEVEL', type=str, default='info',
        help='minimum level of events to log, one of "debug", "info", '
             '"warn" (default "info")')
    argument_group.add_argument(
        '--log-interval', metavar='N', type=int, default=1000,
        help='print statistics after every Nth word; quiet if less than one '
             '(default 1000)')

    args = parser.parse_args()

    log_file = args.log_file
    log_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(log_level, int):
        raise ValueError("Invalid logging level requested: " +
                         args.log_level)
    log_format = '%(asctime)s %(funcName)s: %(message)s'
    if args.log_file == '-':
        logging.basicConfig(stream=sys.stdout, format=log_format,
                            level=log_level)
    else:
        logging.basicConfig(filename=log_file, format=log_format,
                            level=log_level)

    if args.vocabulary is None:
        word_counts = compute_word_counts(args.training_set)
        vocabulary = Vocabulary.from_word_counts(word_counts,
                                                 args.num_classes)
        for subset_file in args.training_set:
            subset_file.seek(0)
    else:
        vocabulary = Vocabulary.from_file(args.vocabulary,
                                          args.vocabulary_format)

    print("Number of words in vocabulary:",
          vocabulary.num_shortlist_words())
    print("Number of word classes:", vocabulary.num_classes())
    print("Number of normal word classes:", vocabulary.num_normal_classes)

    logging.info("Reading word unigram and bigram statistics.")
    statistics = BigramStatistics(args.training_set, vocabulary)

    if args.method == 'bigram-theano':
        optimizer = TheanoBigramOptimizer(statistics, vocabulary)
    elif args.method == 'bigram-numpy':
        optimizer = NumpyBigramOptimizer(statistics, vocabulary)
    else:
        raise ValueError("Invalid method requested: " + args.method)

    iteration = 1
    while True:
        logging.info("Starting iteration %d.", iteration)
        num_words = 0
        num_moves = 0
        for word in vocabulary.words():
            start_time = time()
            num_words += 1
            if optimizer.move_to_best_class(word):
                num_moves += 1
            duration = time() - start_time
            if (args.log_interval >= 1) and \
               (num_words % args.log_interval == 0):
                logging.info(
                    "[%d] (%.1f %%) of iteration %d -- moves = %d, "
                    "cost = %.2f, duration = %.1f ms",
                    num_words,
                    num_words / vocabulary.num_shortlist_words() * 100,
                    iteration,
                    num_moves,
                    optimizer.log_likelihood(),
                    duration * 1000)  # Convert seconds to milliseconds.
            if is_scheduled(num_words, args.output_frequency,
                            vocabulary.num_shortlist_words()):
                save(optimizer, args.output_file, args.output_format)

        if num_moves == 0:
            break
        iteration += 1

    logging.info("Optimization finished.")
    save(optimizer, args.output_file, args.output_format)