def test_bigram_statistics(self):
    """Verify unigram and bigram counts collected from the test sentences."""
    self.sentences_file.seek(0)
    counts = compute_word_counts([self.sentences_file])
    self.vocabulary = Vocabulary.from_word_counts(counts)
    self.sentences_file.seek(0)
    statistics = BigramStatistics([self.sentences_file], self.vocabulary)

    word_to_id = self.vocabulary.word_to_id
    unigram_counts = statistics.unigram_counts
    # Expected occurrence counts, including the special tokens.
    expected_unigrams = [('a', 13), ('b', 8), ('c', 8), ('d', 11),
                         ('e', 15), ('<unk>', 0), ('<s>', 11), ('</s>', 11)]
    for word, expected_count in expected_unigrams:
        self.assertEqual(unigram_counts[word_to_id[word]], expected_count)

    bigram_counts = statistics.bigram_counts
    a_id = word_to_id['a']
    b_id = word_to_id['b']
    self.assertEqual(bigram_counts[a_id, a_id], 3)
    self.assertEqual(bigram_counts[a_id, b_id], 2)
    self.assertEqual(bigram_counts[b_id, a_id], 1)
    self.assertEqual(bigram_counts[b_id, b_id], 0)
def _read_vocabulary(args, state):
    """Load or build the vocabulary, keeping the HDF5 state in sync.

    If ``state`` already contains data, the vocabulary is read from the
    HDF5 state (and, for old states that lack unigram probabilities, the
    probabilities are computed from the training set and written back).

    If the state is empty and the --vocabulary argument is given, the
    vocabulary is read from that file and the remaining training set words
    are added as out-of-shortlist words. If no vocabulary file is given,
    a vocabulary containing every training set word is constructed;
    --num-classes then controls the number of classes. In both cases the
    result is written to the HDF5 state.

    :type args: argparse.Namespace
    :param args: a collection of command line arguments

    :type state: hdf5.File
    :param state: HDF5 file where the vocabulary should be saved

    :rtype: Vocabulary
    :returns: the created vocabulary
    """
    if state.keys():
        print("Reading vocabulary from existing network state.")
        sys.stdout.flush()
        vocabulary = Vocabulary.from_state(state)
        if not vocabulary.has_unigram_probs():
            # This is for backward compatibility. Remove at some point.
            print("Computing unigram word probabilities from training set.")
            sys.stdout.flush()
            word_counts = compute_word_counts(args.training_set)
            shortlist_words = list(vocabulary.id_to_word)
            shortlist_set = set(shortlist_words)
            oos_words = [word for word in word_counts.keys()
                         if word not in shortlist_set]
            vocabulary.id_to_word = numpy.asarray(shortlist_words + oos_words,
                                                  dtype=object)
            vocabulary.word_to_id = {word: word_id
                                     for word_id, word
                                     in enumerate(vocabulary.id_to_word)}
            vocabulary.compute_probs(word_counts, update_class_probs=False)
            vocabulary.get_state(state)
    elif args.vocabulary is None:
        print("Constructing vocabulary from training set.")
        sys.stdout.flush()
        word_counts = compute_word_counts(args.training_set)
        vocabulary = Vocabulary.from_word_counts(word_counts,
                                                 args.num_classes)
        vocabulary.get_state(state)
    else:
        print("Reading vocabulary from {}.".format(args.vocabulary))
        sys.stdout.flush()
        word_counts = compute_word_counts(args.training_set)
        # Every training set word is offered as a potential OOS word; the
        # Vocabulary constructor ignores those already in the shortlist.
        oos_words = word_counts.keys()
        with open(args.vocabulary, 'rt', encoding='utf-8') as vocab_file:
            vocabulary = Vocabulary.from_file(vocab_file,
                                              args.vocabulary_format,
                                              oos_words=oos_words)
        if args.vocabulary_format == 'classes':
            print("Computing class membership probabilities and unigram "
                  "probabilities for out-of-shortlist words.")
            sys.stdout.flush()
            update_class_probs = True
        else:
            print(
                "Computing unigram probabilities for out-of-shortlist words.")
            sys.stdout.flush()
            update_class_probs = False
        vocabulary.compute_probs(word_counts,
                                 update_class_probs=update_class_probs)
        vocabulary.get_state(state)

    print("Number of words in vocabulary:", vocabulary.num_words())
    print("Number of words in shortlist:", vocabulary.num_shortlist_words())
    print("Number of word classes:", vocabulary.num_classes())
    return vocabulary
def test_decode(self):
    """Decode the test lattice and compare to SRILM lattice-tool output."""
    word_counts = dict.fromkeys(
        ['TO', 'AND', 'IT', 'BUT', 'A.', 'IN', 'A', 'AT', 'THE', 'E.',
         "DIDN'T", 'ELABORATE'], 1)
    vocabulary = Vocabulary.from_word_counts(word_counts)
    # A flat projection gives every word the same network probability.
    projection_vector = tensor.ones(shape=(vocabulary.num_words(), ),
                                    dtype=theano.config.floatX)
    projection_vector *= 0.05
    network = DummyNetwork(vocabulary, projection_vector)

    decoding_options = {
        'nnlm_weight': 0.0,
        'lm_scale': None,
        'wi_penalty': None,
        'ignore_unk': False,
        'unk_penalty': None,
        'linear_interpolation': True,
        'max_tokens_per_node': None,
        'beam': None,
        'recombination_order': None
    }
    decoder = LatticeDecoder(network, decoding_options)
    tokens = decoder.decode(self.lattice)

    # Compare tokens to n-best list given by SRILM lattice-tool.
    log_scale = math.log(10)
    print()
    for token in tokens:
        print(token.ac_logprob / log_scale,
              token.lat_lm_logprob / log_scale,
              token.total_logprob / log_scale,
              ' '.join(vocabulary.id_to_word[token.history]))

    all_paths = [
        "<s> IT DIDN'T ELABORATE </s>",
        "<s> BUT IT DIDN'T ELABORATE </s>",
        "<s> THE DIDN'T ELABORATE </s>",
        "<s> AND IT DIDN'T ELABORATE </s>",
        "<s> E. DIDN'T ELABORATE </s>",
        "<s> IN IT DIDN'T ELABORATE </s>",
        "<s> A DIDN'T ELABORATE </s>",
        "<s> AT IT DIDN'T ELABORATE </s>",
        "<s> IT IT DIDN'T ELABORATE </s>",
        "<s> TO IT DIDN'T ELABORATE </s>",
        "<s> A. IT DIDN'T ELABORATE </s>",
        "<s> A IT DIDN'T ELABORATE </s>"
    ]
    paths = [' '.join(vocabulary.id_to_word[token.history])
             for token in tokens]
    self.assertListEqual(paths, all_paths)

    # Spot-check the first, second, and last token's log probabilities.
    expected = [(0, -8686.28, -94.3896, 4),
                (1, -8743.96, -111.488, 5),
                (-1, -8696.26, -178.00, 5)]
    for index, ac_logprob, lat_lm_logprob, num_words in expected:
        token = tokens[index]
        self.assertAlmostEqual(token.ac_logprob / log_scale,
                               ac_logprob, places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale,
                               lat_lm_logprob, places=2)
        self.assertAlmostEqual(token.nn_lm_logprob,
                               math.log(0.1) * num_words)
def test_decode(self):
    """Decode the test lattice and compare to SRILM lattice-tool output."""
    word_counts = dict.fromkeys(
        ['TO', 'AND', 'IT', 'BUT', 'A.', 'IN', 'A', 'AT', 'THE', 'E.',
         "DIDN'T", 'ELABORATE'], 1)
    vocabulary = Vocabulary.from_word_counts(word_counts)
    # Uniform projection so the dummy network scores every word alike.
    projection_vector = tensor.ones(shape=(vocabulary.num_words(),),
                                    dtype=theano.config.floatX)
    projection_vector *= 0.05
    network = DummyNetwork(vocabulary, projection_vector)

    decoding_options = {
        'nnlm_weight': 0.0,
        'lm_scale': None,
        'wi_penalty': None,
        'ignore_unk': False,
        'unk_penalty': None,
        'linear_interpolation': True,
        'max_tokens_per_node': None,
        'beam': None,
        'recombination_order': None
    }
    decoder = LatticeDecoder(network, decoding_options)
    tokens = decoder.decode(self.lattice)

    # Compare tokens to n-best list given by SRILM lattice-tool.
    log_scale = math.log(10)
    print()
    for token in tokens:
        print(token.ac_logprob / log_scale,
              token.lat_lm_logprob / log_scale,
              token.total_logprob / log_scale,
              ' '.join(vocabulary.id_to_word[token.history]))

    all_paths = [
        "<s> IT DIDN'T ELABORATE </s>",
        "<s> BUT IT DIDN'T ELABORATE </s>",
        "<s> THE DIDN'T ELABORATE </s>",
        "<s> AND IT DIDN'T ELABORATE </s>",
        "<s> E. DIDN'T ELABORATE </s>",
        "<s> IN IT DIDN'T ELABORATE </s>",
        "<s> A DIDN'T ELABORATE </s>",
        "<s> AT IT DIDN'T ELABORATE </s>",
        "<s> IT IT DIDN'T ELABORATE </s>",
        "<s> TO IT DIDN'T ELABORATE </s>",
        "<s> A. IT DIDN'T ELABORATE </s>",
        "<s> A IT DIDN'T ELABORATE </s>"
    ]
    paths = [' '.join(vocabulary.id_to_word[token.history])
             for token in tokens]
    self.assertListEqual(paths, all_paths)

    # Spot-check the first, second, and last token's log probabilities.
    expected = [(0, -8686.28, -94.3896, 4),
                (1, -8743.96, -111.488, 5),
                (-1, -8696.26, -178.00, 5)]
    for index, ac_logprob, lat_lm_logprob, num_words in expected:
        token = tokens[index]
        self.assertAlmostEqual(token.ac_logprob / log_scale,
                               ac_logprob, places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale,
                               lat_lm_logprob, places=2)
        self.assertAlmostEqual(token.nn_lm_logprob,
                               math.log(0.1) * num_words)
def test_decode(self):
    """Decode the test lattice and compare to SRILM lattice-tool output."""
    word_counts = dict.fromkeys(
        ['to', 'and', 'it', 'but', 'a.', 'in', 'a', 'at', 'the',
         "didn't", 'elaborate'], 1)
    vocabulary = Vocabulary.from_word_counts(word_counts)
    # NOTE(review): 'e.' is absent from the vocabulary but appears in an
    # expected path below — presumably history_words() passes lattice
    # words through; confirm against LatticeDecoder.
    projection_vector = tensor.ones(
        shape=(vocabulary.num_shortlist_words(), ),
        dtype=theano.config.floatX)
    projection_vector *= 0.05
    network = DummyNetwork(vocabulary, projection_vector)

    decoding_options = {
        'nnlm_weight': 0.0,
        'lm_scale': None,
        'wi_penalty': None,
        'unk_penalty': None,
        'use_shortlist': False,
        'unk_from_lattice': False,
        'linear_interpolation': True,
        'max_tokens_per_node': None,
        'beam': None,
        'recombination_order': 20
    }
    decoder = LatticeDecoder(network, decoding_options)
    tokens = decoder.decode(self.lattice)[0]

    # Compare tokens to n-best list given by SRILM lattice-tool.
    log_scale = math.log(10)
    print()
    for token in tokens:
        print(token.ac_logprob / log_scale,
              token.lat_lm_logprob / log_scale,
              token.total_logprob / log_scale,
              ' '.join(token.history_words(vocabulary)))

    all_paths = [
        "<s> it didn't elaborate </s>",
        "<s> but it didn't elaborate </s>",
        "<s> the didn't elaborate </s>",
        "<s> and it didn't elaborate </s>",
        "<s> e. didn't elaborate </s>",
        "<s> in it didn't elaborate </s>",
        "<s> a didn't elaborate </s>",
        "<s> at it didn't elaborate </s>",
        "<s> it it didn't elaborate </s>",
        "<s> to it didn't elaborate </s>",
        "<s> a. it didn't elaborate </s>",
        "<s> a it didn't elaborate </s>"
    ]
    paths = [' '.join(token.history_words(vocabulary))
             for token in tokens]
    self.assertListEqual(paths, all_paths)

    # Spot-check the first, second, and last token's log probabilities.
    expected = [(0, -8686.28, -94.3896, 4),
                (1, -8743.96, -111.488, 5),
                (-1, -8696.26, -178.00, 5)]
    for index, ac_logprob, lat_lm_logprob, num_words in expected:
        token = tokens[index]
        self.assertAlmostEqual(token.ac_logprob / log_scale,
                               ac_logprob, places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale,
                               lat_lm_logprob, places=2)
        self.assertAlmostEqual(token.nn_lm_logprob,
                               math.log(0.1) * num_words)