Example #1
    def test_bigram_statistics(self):
        # Build the vocabulary from the word counts of the sentences file,
        # then gather unigram and bigram statistics over the same sentences.
        self.sentences_file.seek(0)
        word_counts = compute_word_counts([self.sentences_file])
        self.vocabulary = Vocabulary.from_word_counts(word_counts)
        self.sentences_file.seek(0)
        statistics = BigramStatistics([self.sentences_file], self.vocabulary)

        # <s> and </s> are counted once per sentence, so both totals equal
        # the number of sentences, and <unk> never occurs in the data.
        unigram_counts = statistics.unigram_counts
        vocabulary = self.vocabulary
        self.assertEqual(unigram_counts[vocabulary.word_to_id['a']], 13)
        self.assertEqual(unigram_counts[vocabulary.word_to_id['b']], 8)
        self.assertEqual(unigram_counts[vocabulary.word_to_id['c']], 8)
        self.assertEqual(unigram_counts[vocabulary.word_to_id['d']], 11)
        self.assertEqual(unigram_counts[vocabulary.word_to_id['e']], 15)
        self.assertEqual(unigram_counts[vocabulary.word_to_id['<unk>']], 0)
        self.assertEqual(unigram_counts[vocabulary.word_to_id['<s>']], 11)
        self.assertEqual(unigram_counts[vocabulary.word_to_id['</s>']], 11)

        bigram_counts = statistics.bigram_counts
        a_id = vocabulary.word_to_id['a']
        b_id = vocabulary.word_to_id['b']
        self.assertEqual(bigram_counts[a_id, a_id], 3)
        self.assertEqual(bigram_counts[a_id, b_id], 2)
        self.assertEqual(bigram_counts[b_id, a_id], 1)
        self.assertEqual(bigram_counts[b_id, b_id], 0)
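For context, here is a minimal sketch of the counting that BigramStatistics is expected to perform, assuming compute_word_counts returns a plain word-count mapping and that every whitespace-tokenized sentence is wrapped in <s> ... </s> (the convention implied by the assertions above); count_ngrams is an illustrative stand-in, not the library's API:

from collections import Counter

def count_ngrams(lines):
    # Wrap each whitespace-tokenized sentence in sentence-boundary tokens
    # and accumulate unigram and adjacent-pair (bigram) counts.
    unigrams = Counter()
    bigrams = Counter()
    for line in lines:
        words = ['<s>'] + line.split() + ['</s>']
        unigrams.update(words)
        bigrams.update(zip(words, words[1:]))
    return unigrams, bigrams

unigrams, bigrams = count_ngrams(["a b c d e", "a a b"])
assert unigrams['<s>'] == 2 and unigrams['</s>'] == 2
assert bigrams[('a', 'a')] == 1 and bigrams[('a', 'b')] == 2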
Example #2
def _read_vocabulary(args, state):
    """If ``state`` contains data, reads the vocabulary from the HDF5 state.
    Otherwise reads a vocabulary file or constructs the vocabulary from the
    training set and writes it to the HDF5 state.

    If the state does not contain data and the ``--vocabulary`` argument is
    given, reads the vocabulary from the given file. Any training-set words
    missing from the file will be added as out-of-shortlist words.

    If the state does not contain data and no vocabulary is given, constructs
    a vocabulary that contains all the words of the training set. In that
    case, the ``--num-classes`` argument can be used to control the number of
    classes.

    :type args: argparse.Namespace
    :param args: a collection of command line arguments

    :type state: h5py.File
    :param state: HDF5 file where the vocabulary should be saved

    :rtype: Vocabulary
    :returns: the created vocabulary
    """

    if state.keys():
        print("Reading vocabulary from existing network state.")
        sys.stdout.flush()
        result = Vocabulary.from_state(state)
        if not result.has_unigram_probs():
            # This is for backward compatibility. Remove at some point.
            print("Computing unigram word probabilities from training set.")
            sys.stdout.flush()
            word_counts = compute_word_counts(args.training_set)
            shortlist_words = list(result.id_to_word)
            shortlist_set = set(shortlist_words)
            oos_words = [
                x for x in word_counts.keys() if x not in shortlist_set
            ]
            result.id_to_word = numpy.asarray(shortlist_words + oos_words,
                                              dtype=object)
            result.word_to_id = {
                word: word_id
                for word_id, word in enumerate(result.id_to_word)
            }
            result.compute_probs(word_counts, update_class_probs=False)
            result.get_state(state)

    elif args.vocabulary is None:
        print("Constructing vocabulary from training set.")
        sys.stdout.flush()
        word_counts = compute_word_counts(args.training_set)
        result = Vocabulary.from_word_counts(word_counts, args.num_classes)
        result.get_state(state)

    else:
        print("Reading vocabulary from {}.".format(args.vocabulary))
        sys.stdout.flush()
        word_counts = compute_word_counts(args.training_set)
        oos_words = word_counts.keys()
        with open(args.vocabulary, 'rt', encoding='utf-8') as vocab_file:
            result = Vocabulary.from_file(vocab_file,
                                          args.vocabulary_format,
                                          oos_words=oos_words)

        if args.vocabulary_format == 'classes':
            print("Computing class membership probabilities and unigram "
                  "probabilities for out-of-shortlist words.")
            sys.stdout.flush()
            update_class_probs = True
        else:
            print(
                "Computing unigram probabilities for out-of-shortlist words.")
            sys.stdout.flush()
            update_class_probs = False
        result.compute_probs(word_counts,
                             update_class_probs=update_class_probs)
        result.get_state(state)

    print("Number of words in vocabulary:", result.num_words())
    print("Number of words in shortlist:", result.num_shortlist_words())
    print("Number of word classes:", result.num_classes())
    return result
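A usage sketch, assuming the surrounding script parses its options with argparse and opens the network state with h5py; the args attribute names come from the function body, while the file names and the training-set handle are hypothetical:

import argparse
import h5py

parser = argparse.ArgumentParser()
parser.add_argument('--vocabulary', default=None)
parser.add_argument('--vocabulary-format', default='words')
parser.add_argument('--num-classes', type=int, default=None)
args = parser.parse_args(['--vocabulary', 'vocab.txt'])
# The function reads the training set to count words, so pass open handles.
args.training_set = [open('train.txt', 'r', encoding='utf-8')]

with h5py.File('model.h5', 'a') as state:
    vocabulary = _read_vocabulary(args, state)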
Example #3
    def test_decode(self):
        vocabulary = Vocabulary.from_word_counts({
            'TO': 1,
            'AND': 1,
            'IT': 1,
            'BUT': 1,
            'A.': 1,
            'IN': 1,
            'A': 1,
            'AT': 1,
            'THE': 1,
            'E.': 1,
            "DIDN'T": 1,
            'ELABORATE': 1
        })
        projection_vector = tensor.ones(shape=(vocabulary.num_words(), ),
                                        dtype=theano.config.floatX)
        projection_vector *= 0.05
        network = DummyNetwork(vocabulary, projection_vector)

        decoding_options = {
            'nnlm_weight': 0.0,
            'lm_scale': None,
            'wi_penalty': None,
            'ignore_unk': False,
            'unk_penalty': None,
            'linear_interpolation': True,
            'max_tokens_per_node': None,
            'beam': None,
            'recombination_order': None
        }
        decoder = LatticeDecoder(network, decoding_options)
        tokens = decoder.decode(self.lattice)

        # Compare tokens to n-best list given by SRILM lattice-tool.
        log_scale = math.log(10)

        print()
        for token in tokens:
            print(token.ac_logprob / log_scale,
                  token.lat_lm_logprob / log_scale,
                  token.total_logprob / log_scale,
                  ' '.join(vocabulary.id_to_word[token.history]))

        all_paths = [
            "<s> IT DIDN'T ELABORATE </s>",
            "<s> BUT IT DIDN'T ELABORATE </s>",
            "<s> THE DIDN'T ELABORATE </s>",
            "<s> AND IT DIDN'T ELABORATE </s>",
            "<s> E. DIDN'T ELABORATE </s>",
            "<s> IN IT DIDN'T ELABORATE </s>",
            "<s> A DIDN'T ELABORATE </s>",
            "<s> AT IT DIDN'T ELABORATE </s>",
            "<s> IT IT DIDN'T ELABORATE </s>",
            "<s> TO IT DIDN'T ELABORATE </s>",
            "<s> A. IT DIDN'T ELABORATE </s>",
            "<s> A IT DIDN'T ELABORATE </s>"
        ]
        paths = [
            ' '.join(vocabulary.id_to_word[token.history]) for token in tokens
        ]
        self.assertListEqual(paths, all_paths)

        token = tokens[0]
        history = ' '.join(vocabulary.id_to_word[token.history])
        self.assertAlmostEqual(token.ac_logprob / log_scale,
                               -8686.28,
                               places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale,
                               -94.3896,
                               places=2)
        self.assertAlmostEqual(token.nn_lm_logprob, math.log(0.1) * 4)

        token = tokens[1]
        self.assertAlmostEqual(token.ac_logprob / log_scale,
                               -8743.96,
                               places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale,
                               -111.488,
                               places=2)
        self.assertAlmostEqual(token.nn_lm_logprob, math.log(0.1) * 5)

        token = tokens[-1]
        self.assertAlmostEqual(token.ac_logprob / log_scale,
                               -8696.26,
                               places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale,
                               -178.00,
                               places=2)
        self.assertAlmostEqual(token.nn_lm_logprob, math.log(0.1) * 5)
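One note on the conversion above: SRILM's lattice-tool reports log probabilities in base 10, while the decoder accumulates natural logarithms, hence the division by log(10). The identity in isolation:

import math

nat_logprob = -217.3                          # hypothetical natural-log score
log10_logprob = nat_logprob / math.log(10)    # log10(x) == ln(x) / ln(10)
assert abs(log10_logprob - nat_logprob * math.log10(math.e)) < 1e-9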
Example #4
    def test_decode(self):
        vocabulary = Vocabulary.from_word_counts({
            'to': 1,
            'and': 1,
            'it': 1,
            'but': 1,
            'a.': 1,
            'in': 1,
            'a': 1,
            'at': 1,
            'the': 1,
            "didn't": 1,
            'elaborate': 1
        })
        projection_vector = tensor.ones(
            shape=(vocabulary.num_shortlist_words(), ),
            dtype=theano.config.floatX)
        projection_vector *= 0.05
        network = DummyNetwork(vocabulary, projection_vector)

        decoding_options = {
            'nnlm_weight': 0.0,
            'lm_scale': None,
            'wi_penalty': None,
            'unk_penalty': None,
            'use_shortlist': False,
            'unk_from_lattice': False,
            'linear_interpolation': True,
            'max_tokens_per_node': None,
            'beam': None,
            'recombination_order': 20
        }
        decoder = LatticeDecoder(network, decoding_options)
        tokens = decoder.decode(self.lattice)[0]

        # Compare tokens to n-best list given by SRILM lattice-tool.
        log_scale = math.log(10)

        print()
        for token in tokens:
            print(token.ac_logprob / log_scale,
                  token.lat_lm_logprob / log_scale,
                  token.total_logprob / log_scale,
                  ' '.join(token.history_words(vocabulary)))

        all_paths = [
            "<s> it didn't elaborate </s>",
            "<s> but it didn't elaborate </s>",
            "<s> the didn't elaborate </s>",
            "<s> and it didn't elaborate </s>",
            "<s> e. didn't elaborate </s>",
            "<s> in it didn't elaborate </s>",
            "<s> a didn't elaborate </s>",
            "<s> at it didn't elaborate </s>",
            "<s> it it didn't elaborate </s>",
            "<s> to it didn't elaborate </s>",
            "<s> a. it didn't elaborate </s>",
            "<s> a it didn't elaborate </s>"
        ]
        paths = [' '.join(token.history_words(vocabulary)) for token in tokens]
        self.assertListEqual(paths, all_paths)

        token = tokens[0]
        history = ' '.join(token.history_words(vocabulary))
        self.assertAlmostEqual(token.ac_logprob / log_scale,
                               -8686.28,
                               places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale,
                               -94.3896,
                               places=2)
        self.assertAlmostEqual(token.nn_lm_logprob, math.log(0.1) * 4)

        token = tokens[1]
        self.assertAlmostEqual(token.ac_logprob / log_scale,
                               -8743.96,
                               places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale,
                               -111.488,
                               places=2)
        self.assertAlmostEqual(token.nn_lm_logprob, math.log(0.1) * 5)

        token = tokens[-1]
        self.assertAlmostEqual(token.ac_logprob / log_scale,
                               -8696.26,
                               places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale,
                               -178.00,
                               places=2)
        self.assertAlmostEqual(token.nn_lm_logprob, math.log(0.1) * 5)
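The options above request linear interpolation with nnlm_weight 0.0, so the neural network LM contributes nothing and the n-best order is determined by the lattice LM alone, which is why the scores can be checked directly against lattice-tool. A minimal sketch of linear interpolation in log space, assuming the conventional combination w * P_nn + (1 - w) * P_lat; the function is illustrative, not the decoder's API:

import math

def interpolate_logprob(nn_logprob, lat_logprob, nnlm_weight):
    # Weighted sum in the probability domain, computed stably in log space:
    # log(w * exp(nn) + (1 - w) * exp(lat)).
    if nnlm_weight <= 0.0:
        return lat_logprob      # the degenerate case used by this test
    if nnlm_weight >= 1.0:
        return nn_logprob
    a = math.log(nnlm_weight) + nn_logprob
    b = math.log(1.0 - nnlm_weight) + lat_logprob
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

# With weight 0.5 the result is the log of the average probability.
x = interpolate_logprob(math.log(0.1), math.log(0.4), 0.5)
assert math.isclose(x, math.log(0.25))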