def _append_word(self, tokens, target_word):
        """Appends a word to each of the given tokens, and updates their scores.

        :type tokens: list of LatticeDecoder.Tokens
        :param tokens: input tokens

        :type target_word: int or str
        :param target_word: word ID or word to be appended to the existing
                            history of each input token; if not an integer, the
                            word will be considered ``<unk>`` and this variable
                            will be taken literally as the word that will be
                            used in the resulting transcript
        """
        def str_to_unk(self, word):
            """Returns the <unk> word ID if the argument is not a word ID.
            """
            if isinstance(word, int):
                return word
            else:
                return self._unk_id

        input_word_ids = [[
            str_to_unk(self, token.history[-1]) for token in tokens
        ]]
        input_word_ids = numpy.asarray(input_word_ids).astype('int64')
        input_class_ids, membership_probs = \
            self._vocabulary.get_class_memberships(input_word_ids)
        recurrent_state = [token.state for token in tokens]
        recurrent_state = RecurrentState.combine_sequences(recurrent_state)
        target_word_id = str_to_unk(self, target_word)
        target_class_ids = numpy.ones(shape=(1, len(tokens))).astype('int64')
        target_class_ids *= self._vocabulary.word_id_to_class_id[
            target_word_id]
        step_result = self._step_function(input_word_ids, input_class_ids,
                                          target_class_ids,
                                          *recurrent_state.get())
        logprobs = step_result[0]
        # Add logprobs from the class membership of the predicted words.
        logprobs += numpy.log(membership_probs)
        output_state = step_result[1:]

        for index, token in enumerate(tokens):
            token.history.append(target_word)
            token.state = RecurrentState(self._network.recurrent_state_size)
            # Slice the sequence that corresponds to this token.
            token.state.set([
                layer_state[:, index:index + 1] for layer_state in output_state
            ])

            if target_word_id == self._unk_id:
                if self._ignore_unk:
                    continue
                if self._unk_penalty is not None:
                    token.nn_lm_logprob += self._unk_penalty
                    continue
            # logprobs matrix contains only one time step.
            token.nn_lm_logprob += logprobs[0, index]
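The line ``logprobs += numpy.log(membership_probs)`` above reflects the class-based factorization of the language model: the network predicts the log probability of the target word's class, and the word log probability is obtained by adding the log of the word's membership probability within that class. A minimal numpy sketch of the same arithmetic, with made-up numbers that are not taken from the example:

import numpy

# Hypothetical values: the network's class prediction and the vocabulary's
# class membership probability for the target word.
class_logprob = numpy.log(0.2)     # log P(class(w) | history)
membership_prob = 0.5              # P(w | class(w))

# log P(w | history) = log P(class(w) | history) + log P(w | class(w))
word_logprob = class_logprob + numpy.log(membership_prob)
print(word_logprob)                # equals numpy.log(0.1)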
Example #2
    def generate(self, max_length=30, num_sequences=1):
        """Generates a text sequence.

        Calls self.step_function() repeatedly, reading the word output and
        the state output of the hidden layer and passing the hidden layer state
        output to the next time step.

        Generates at most ``max_length`` words, stopping if a sentence break is
        generated.

        :type max_length: int
        :param max_length: maximum number of words in a sequence

        :type num_sequences: int
        :param num_sequences: number of sequences to generate in parallel

        :rtype: list of list of strs
        :returns: list of word sequences
        """

        sos_id = self.vocabulary.word_to_id['<s>']
        sos_class_id = self.vocabulary.word_id_to_class_id[sos_id]
        eos_id = self.vocabulary.word_to_id['</s>']

        word_input = sos_id * \
                     numpy.ones(shape=(1, num_sequences)).astype('int64')
        class_input = sos_class_id * \
                      numpy.ones(shape=(1, num_sequences)).astype('int64')
        result = sos_id * \
                 numpy.ones(shape=(max_length, num_sequences)).astype('int64')
        state = RecurrentState(self.network.recurrent_state_size, num_sequences)

        for time_step in range(1, max_length):
            # The input is the output from the previous step.
            step_result = self.step_function(word_input,
                                             class_input,
                                             *state.get())
            class_ids = step_result[0]
            # The class IDs from the single time step.
            step_class_ids = class_ids[0]
            step_word_ids = numpy.array(
                self.vocabulary.class_ids_to_word_ids(step_class_ids))
            result[time_step] = step_word_ids
            word_input = step_word_ids[numpy.newaxis]
            class_input = class_ids
            state.set(step_result[1:])
            # Stop early if every sequence has generated a sentence break.
            if (step_word_ids == eos_id).all():
                result = result[:time_step + 1]
                break

        return self.vocabulary.id_to_word[result.transpose()].tolist()
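A hedged usage sketch of the method above; the ``sampler`` object and its construction are assumptions, since only the ``generate`` method is shown:

# Hypothetical usage; ``sampler`` is assumed to be an instance of the class
# that defines generate() above, wrapping a trained network and vocabulary.
sequences = sampler.generate(max_length=20, num_sequences=3)
for words in sequences:
    print(' '.join(words))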
Example #3
    def generate(self, length, num_sequences=1):
        """Generates a text sequence.

        Calls self.step_function() repeatedly, reading the word output and
        the state output of the hidden layer and passing the hidden layer state
        output to the next time step.

        :type length: int
        :param length: number of words (tokens) in each sequence

        :type num_sequences: int
        :param num_sequences: number of sequences to generate in parallel

        :rtype: list of list of strs
        :returns: list of word sequences
        """

        sos_id = self._vocabulary.word_to_id['<s>']
        sos_class_id = self._vocabulary.word_id_to_class_id[sos_id]
        eos_id = self._vocabulary.word_to_id['</s>']

        input_word_ids = sos_id * \
                         numpy.ones(shape=(1, num_sequences)).astype('int64')
        input_class_ids = sos_class_id * \
                          numpy.ones(shape=(1, num_sequences)).astype('int64')
        result = sos_id * \
                 numpy.ones(shape=(length, num_sequences)).astype('int64')
        state = RecurrentState(self._network.recurrent_state_size,
                               num_sequences)

        for time_step in range(1, length):
            # The input is the output from the previous step.
            step_result = self.step_function(input_word_ids,
                                             input_class_ids,
                                             *state.get())
            class_ids = step_result[0]
            # The class IDs from the single time step.
            step_class_ids = class_ids[0]
            step_word_ids = numpy.array(
                self._vocabulary.class_ids_to_word_ids(step_class_ids))
            result[time_step] = step_word_ids
            input_word_ids = step_word_ids[numpy.newaxis]
            input_class_ids = class_ids
            state.set(step_result[1:])

        return self._vocabulary.id_to_word[result.transpose()].tolist()
Example #4
    def generate(self, length, num_sequences=1):
        """Generates a text sequence.

        Calls self.step_function() repeatedly, reading the word output and
        the state output of the hidden layer and passing the hidden layer state
        output to the next time step.

        :type length: int
        :param length: number of words (tokens) in each sequence

        :type num_sequences: int
        :param num_sequences: number of sequences to generate in parallel

        :rtype: list of list of strs
        :returns: list of word sequences
        """

        sos_id = self._vocabulary.word_to_id['<s>']
        sos_class_id = self._vocabulary.word_id_to_class_id[sos_id]
        eos_id = self._vocabulary.word_to_id['</s>']

        input_word_ids = sos_id * \
                         numpy.ones(shape=(1, num_sequences)).astype('int64')
        input_class_ids = sos_class_id * \
                          numpy.ones(shape=(1, num_sequences)).astype('int64')
        result = sos_id * \
                 numpy.ones(shape=(length, num_sequences)).astype('int64')
        state = RecurrentState(self._network.recurrent_state_size,
                               num_sequences)

        for time_step in range(1, length):
            # The input is the output from the previous step.
            step_result = self.step_function(input_word_ids, input_class_ids,
                                             *state.get())
            class_ids = step_result[0]
            # The class IDs from the single time step.
            step_class_ids = class_ids[0]
            step_word_ids = numpy.array(
                self._vocabulary.class_ids_to_word_ids(step_class_ids))
            result[time_step] = step_word_ids
            input_word_ids = step_word_ids[numpy.newaxis]
            input_class_ids = class_ids
            state.set(step_result[1:])

        return self._vocabulary.id_to_word[result.transpose()].tolist()
Example #5
    def _append_word(self, tokens, target_word_id):
        """Appends a word to each of the given tokens, and updates their scores.

        :type tokens: list of LatticeDecoder.Tokens
        :param tokens: input tokens

        :type target_word_id: int
        :param target_word_id: word ID to be appended to the existing history of
                               each input token
        """

        input_word_ids = [[token.history[-1] for token in tokens]]
        input_word_ids = numpy.asarray(input_word_ids).astype('int64')
        input_class_ids, membership_probs = \
            self._vocabulary.get_class_memberships(input_word_ids)
        recurrent_state = [token.state for token in tokens]
        recurrent_state = RecurrentState.combine_sequences(recurrent_state)
        target_class_ids = numpy.ones(shape=(1, len(tokens))).astype('int64')
        target_class_ids *= self._vocabulary.word_id_to_class_id[
            target_word_id]
        step_result = self.step_function(input_word_ids, input_class_ids,
                                         target_class_ids,
                                         *recurrent_state.get())
        logprobs = step_result[0]
        # Add logprobs from the class membership of the predicted words.
        logprobs += numpy.log(membership_probs)
        output_state = step_result[1:]

        for index, token in enumerate(tokens):
            token.history.append(target_word_id)
            token.state = RecurrentState(self._network.recurrent_state_size)
            # Slice the sequence that corresponds to this token.
            token.state.set([
                layer_state[:, index:index + 1] for layer_state in output_state
            ])

            if target_word_id == self._unk_id:
                if self._ignore_unk:
                    continue
                if self._unk_penalty is not None:
                    token.nn_lm_logprob += self._unk_penalty
                    continue
            # logprobs matrix contains only one time step.
            token.nn_lm_logprob += logprobs[0, index]
    def test_set(self):
        state = RecurrentState([5, 10], 3)
        layer1_state = numpy.arange(15, dtype='int64').reshape((1, 3, 5))
        layer2_state = numpy.arange(30, dtype='int64').reshape((1, 3, 10))
        state.set([layer1_state, layer2_state])
        assert_equal(state.get(0), layer1_state)
        assert_equal(state.get(1), layer2_state)

        with self.assertRaises(ValueError):
            state.set([layer2_state, layer1_state])
    def test_init(self):
        state = RecurrentState([200, 100, 300], 3)
        self.assertEqual(len(state.get()), 3)
        self.assertEqual(state.get(0).shape, (1,3,200))
        self.assertEqual(state.get(1).shape, (1,3,100))
        self.assertEqual(state.get(2).shape, (1,3,300))
        assert_equal(state.get(0), numpy.zeros(shape=(1,3,200), dtype='int64'))
        assert_equal(state.get(1), numpy.zeros(shape=(1,3,100), dtype='int64'))
        assert_equal(state.get(2), numpy.zeros(shape=(1,3,300), dtype='int64'))

        layer1_state = numpy.arange(15, dtype='int64').reshape((1, 3, 5))
        layer2_state = numpy.arange(30, dtype='int64').reshape((1, 3, 10))
        state = RecurrentState([5, 10], 3, [layer1_state, layer2_state])
        assert_equal(state.get(0), layer1_state)
        assert_equal(state.get(1), layer2_state)
    def test_set(self):
        state = RecurrentState([5, 10], 3)
        layer1_state = numpy.arange(15, dtype='int64').reshape((1, 3, 5))
        layer2_state = numpy.arange(30, dtype='int64').reshape((1, 3, 10))
        state.set([layer1_state, layer2_state])
        assert_equal(state.get(0), layer1_state)
        assert_equal(state.get(1), layer2_state)

        with self.assertRaises(ValueError):
            state.set([layer2_state, layer1_state])
Example #9
    def _append_word(self, tokens, target_word_id):
        """Appends a word to each of the given tokens, and updates their scores.

        :type tokens: list of LatticeDecoder.Tokens
        :param tokens: input tokens

        :type target_word_id: int
        :param target_word_id: word ID to be appended to the existing history of
                               each input token
        """

        input_word_ids = [[token.history[-1] for token in tokens]]
        input_word_ids = numpy.asarray(input_word_ids).astype('int64')
        input_class_ids, membership_probs = \
            self._vocabulary.get_class_memberships(input_word_ids)
        recurrent_state = [token.state for token in tokens]
        recurrent_state = RecurrentState.combine_sequences(recurrent_state)
        target_class_ids = numpy.ones(shape=(1, len(tokens))).astype('int64')
        target_class_ids *= self._vocabulary.word_id_to_class_id[target_word_id]
        step_result = self.step_function(input_word_ids,
                                         input_class_ids,
                                         target_class_ids,
                                         *recurrent_state.get())
        logprobs = step_result[0]
        # Add logprobs from the class membership of the predicted words.
        logprobs += numpy.log(membership_probs)
        output_state = step_result[1:]

        for index, token in enumerate(tokens):
            token.history.append(target_word_id)
            token.state = RecurrentState(self._network.recurrent_state_size)
            # Slice the sequence that corresponds to this token.
            token.state.set([layer_state[:,index:index+1]
                             for layer_state in output_state])

            if target_word_id == self._unk_id:
                if self._ignore_unk:
                    continue
                if self._unk_penalty is not None:
                    token.nn_lm_logprob += self._unk_penalty
                    continue
            # logprobs matrix contains only one time step.
            token.nn_lm_logprob += logprobs[0,index]
    def test_init(self):
        state = RecurrentState([200, 100, 300], 3)
        self.assertEqual(len(state.get()), 3)
        self.assertEqual(state.get(0).shape, (1, 3, 200))
        self.assertEqual(state.get(1).shape, (1, 3, 100))
        self.assertEqual(state.get(2).shape, (1, 3, 300))
        assert_equal(state.get(0), numpy.zeros(shape=(1, 3, 200),
                                               dtype='int64'))
        assert_equal(state.get(1), numpy.zeros(shape=(1, 3, 100),
                                               dtype='int64'))
        assert_equal(state.get(2), numpy.zeros(shape=(1, 3, 300),
                                               dtype='int64'))

        layer1_state = numpy.arange(15, dtype='int64').reshape((1, 3, 5))
        layer2_state = numpy.arange(30, dtype='int64').reshape((1, 3, 10))
        state = RecurrentState([5, 10], 3, [layer1_state, layer2_state])
        assert_equal(state.get(0), layer1_state)
        assert_equal(state.get(1), layer2_state)
    def test_combine_sequences(self):
        state1 = RecurrentState([5, 10], 1)
        layer1_state = numpy.arange(5, dtype='int64').reshape(1, 1, 5)
        layer2_state = numpy.arange(10, 20, dtype='int64').reshape(1, 1, 10)
        state1.set([layer1_state, layer2_state])

        state2 = RecurrentState([5, 10], 1)
        layer1_state = numpy.arange(100, 105, dtype='int64').reshape(1, 1, 5)
        layer2_state = numpy.arange(110, 120, dtype='int64').reshape(1, 1, 10)
        state2.set([layer1_state, layer2_state])

        state3 = RecurrentState([5, 10], 2)
        layer1_state = numpy.arange(200, 210, dtype='int64').reshape(1, 2, 5)
        layer2_state = numpy.arange(210, 230, dtype='int64').reshape(1, 2, 10)
        state3.set([layer1_state, layer2_state])

        combined_state = RecurrentState.combine_sequences([state1, state2, state3])
        self.assertEqual(combined_state.num_sequences, 4)
        self.assertEqual(len(combined_state.get()), 2)
        self.assertEqual(combined_state.get(0).shape, (1,4,5))
        self.assertEqual(combined_state.get(1).shape, (1,4,10))
        assert_equal(combined_state.get(0), numpy.asarray(
            [[list(range(5)),
              list(range(100, 105)),
              list(range(200, 205)),
              list(range(205, 210))]],
            dtype='int64'))
        assert_equal(combined_state.get(1), numpy.asarray(
            [[list(range(10, 20)),
              list(range(110, 120)),
              list(range(210, 220)),
              list(range(220, 230))]],
            dtype='int64'))

        state4 = RecurrentState([5, 11], 2)
        with self.assertRaises(ValueError):
            combined_state = RecurrentState.combine_sequences([state1, state2, state3, state4])
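Judging from the shapes asserted above, combining recurrent states amounts to concatenating each layer's state matrix along the sequence axis. The following standalone numpy sketch illustrates that operation; it is not the actual ``RecurrentState.combine_sequences`` implementation:

import numpy

# Layer states of three token batches, each shaped (1, num_sequences, size).
layer_states = [
    numpy.arange(5, dtype='int64').reshape(1, 1, 5),
    numpy.arange(100, 105, dtype='int64').reshape(1, 1, 5),
    numpy.arange(200, 210, dtype='int64').reshape(1, 2, 5),
]

# Concatenating along axis 1 places the sequences side by side, producing a
# (1, 4, 5) array, which matches the shape the test expects.
combined = numpy.concatenate(layer_states, axis=1)
print(combined.shape)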
Example #12
    def decode(self, lattice):
        """Propagates tokens through given lattice and returns a list of tokens
        in the final node.

        Propagates tokens at a node to every outgoing link by creating a copy of
        each token and updating the language model scores according to the link.

        :type lattice: Lattice
        :param lattice: a word lattice to be decoded

        :rtype: list of LatticeDecoder.Tokens
        :returns: the final tokens sorted by total log probability in descending
                  order
        """

        if self._lm_scale is not None:
            lm_scale = logprob_type(self._lm_scale)
        elif lattice.lm_scale is not None:
            lm_scale = logprob_type(lattice.lm_scale)
        else:
            lm_scale = logprob_type(1.0)

        if self._wi_penalty is not None:
            wi_penalty = logprob_type(self._wi_penalty)
        elif lattice.wi_penalty is not None:
            wi_penalty = logprob_type(lattice.wi_penalty)
        else:
            wi_penalty = logprob_type(0.0)

        self._tokens = [list() for _ in lattice.nodes]
        initial_state = RecurrentState(self._network.recurrent_state_size)
        initial_token = self.Token(history=[self._sos_id], state=initial_state)
        initial_token.recompute_hash(self._recombination_order)
        initial_token.recompute_total(self._nnlm_weight, lm_scale, wi_penalty,
                                      self._linear_interpolation)
        self._tokens[lattice.initial_node.id].append(initial_token)
        lattice.initial_node.best_logprob = initial_token.total_logprob

        self._sorted_nodes = lattice.sorted_nodes()
        nodes_processed = 0
        for node in self._sorted_nodes:
            node_tokens = self._tokens[node.id]
            assert node_tokens
            num_pruned_tokens = len(node_tokens)
            self._prune(node)
            node_tokens = self._tokens[node.id]
            assert node_tokens
            num_pruned_tokens -= len(node_tokens)

            if node.id == lattice.final_node.id:
                new_tokens = self._propagate(
                    node_tokens, None, lm_scale, wi_penalty)
                return sorted(new_tokens,
                              key=lambda token: token.total_logprob,
                              reverse=True)

            num_new_tokens = 0
            for link in node.out_links:
                new_tokens = self._propagate(
                    node_tokens, link, lm_scale, wi_penalty)
                self._tokens[link.end_node.id].extend(new_tokens)
                num_new_tokens += len(new_tokens)

            nodes_processed += 1
            if nodes_processed % math.ceil(len(self._sorted_nodes) / 20) == 0:
                logging.debug("[%d] (%.2f %%) -- tokens = %d +%d -%d",
                              nodes_processed,
                              nodes_processed / len(self._sorted_nodes) * 100,
                              len(node_tokens),
                              num_new_tokens,
                              num_pruned_tokens)

        raise InputError("Could not reach the final node of word lattice.")
Example #13
    def generate(self, length, num_sequences=1, seed_sequence=''):
        """Generates a text sequence.

        Calls self.step_function() repeatedly, reading the word output and
        the state output of the hidden layer and passing the hidden layer state
        output to the next time step.

        :type length: int
        :param length: number of words (tokens) in each sequence

        :type num_sequences: int
        :param num_sequences: number of sequences to generate in parallel

        :type seed_sequence: str
        :param seed_sequence: seed words that are fed to the network before
                              sampling begins; separated by whitespace

        :rtype: list of list of strs
        :returns: list of word sequences
        """
        seed_tokens = seed_sequence.strip().split()
        sos_id = self._vocabulary.word_to_id['<s>']
        sos_class_id = self._vocabulary.word_id_to_class_id[sos_id]

        input_word_ids = sos_id * \
                         numpy.ones(shape=(1, num_sequences)).astype('int64')
        input_class_ids = sos_class_id * \
                          numpy.ones(shape=(1, num_sequences)).astype('int64')
        result = sos_id * \
                 numpy.ones(shape=(len(seed_tokens) + length,
                                   num_sequences)).astype('int64')
        state = RecurrentState(self._network.recurrent_state_size,
                               num_sequences)

        # First, compute forward passes with the seed sequence, if one was
        # given.
        for time_step, token in enumerate(seed_tokens, start=1):
            step_result = self.step_function(input_word_ids, input_class_ids,
                                             *state.get())
            token_id = self._vocabulary.word_to_id[token]
            token_class_id = \
                self._vocabulary.word_id_to_class_id[token_id]
            input_word_ids = token_id * numpy.ones(
                shape=(1, num_sequences)).astype('int64')
            input_class_ids = token_class_id * numpy.ones(
                shape=(1, num_sequences)).astype('int64')
            step_word_ids = input_word_ids
            result[time_step] = step_word_ids
            state.set(step_result[1:])

        # Then sample the remaining words.
        for time_step in range(
                len(seed_tokens) + 1, length + len(seed_tokens)):
            # the input is the output from the previous step.
            step_result = self.step_function(input_word_ids, input_class_ids,
                                             *state.get())
            class_ids = step_result[0]
            # The class IDs from the single time step.
            step_class_ids = class_ids[0]
            step_word_ids = numpy.array(
                self._vocabulary.class_ids_to_word_ids(step_class_ids))
            result[time_step] = step_word_ids
            input_word_ids = step_word_ids[numpy.newaxis]
            input_class_ids = class_ids
            state.set(step_result[1:])

        return self._vocabulary.id_to_word[result.transpose()].tolist()
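A hedged usage sketch of the seeded variant; the ``sampler`` object and the seed words are illustrative assumptions, and the seed words would have to exist in the vocabulary:

# Hypothetical usage; the seed words are fed through the network before
# sampling of new words begins.
sequences = sampler.generate(length=15, num_sequences=2,
                             seed_sequence='some seed words')
for words in sequences:
    print(' '.join(words))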
    def decode(self, lattice):
        """Propagates tokens through given lattice and returns a list of tokens
        in the final nodes.

        Propagates tokens at a node to every outgoing link by creating a copy of
        each token and updating the language model scores according to the link.

        The function returns two lists. The first list contains the final
        tokens, sorted in the descending order of total log probability. I.e.
        the first token in the list represents the best path through the
        lattice. The second list contains the tokens that were dropped during
        recombination. This is needed for constructing a new rescored lattice.

        :type lattice: Lattice
        :param lattice: a word lattice to be decoded

        :rtype: a tuple of two lists of LatticeDecoder.Tokens
        :returns: a list of the final tokens sorted by probability (most likely
                  token first), and a list of the tokens that were dropped
                  during recombination
        """

        if self._lm_scale is not None:
            lm_scale = logprob_type(self._lm_scale)
        elif lattice.lm_scale is not None:
            lm_scale = logprob_type(lattice.lm_scale)
        else:
            lm_scale = logprob_type(1.0)

        if self._wi_penalty is not None:
            wi_penalty = logprob_type(self._wi_penalty)
        elif lattice.wi_penalty is not None:
            wi_penalty = logprob_type(lattice.wi_penalty)
        else:
            wi_penalty = logprob_type(0.0)

        tokens = [list() for _ in lattice.nodes]
        recomb_tokens = []
        initial_state = RecurrentState(self._network.recurrent_state_size)
        initial_token = self.Token(history=(self._sos_id, ),
                                   state=initial_state)
        initial_token.recompute_hash(self._recombination_order)
        initial_token.recompute_total(self._nnlm_weight, lm_scale, wi_penalty,
                                      self._linear_interpolation)
        tokens[lattice.initial_node.id].append(initial_token)
        lattice.initial_node.best_logprob = initial_token.total_logprob

        sorted_nodes = lattice.sorted_nodes()
        self._nodes_processed = 0
        final_tokens = []
        for node in sorted_nodes:
            stats = self._prune(node, sorted_nodes, tokens, recomb_tokens)

            num_new_tokens = 0
            node_tokens = tokens[node.id]
            assert node_tokens
            if node.final:
                new_tokens = self._propagate(node_tokens, None, lm_scale,
                                             wi_penalty)
                final_tokens.extend(new_tokens)
                num_new_tokens += len(new_tokens)
            for link in node.out_links:
                new_tokens = self._propagate(node_tokens, link, lm_scale,
                                             wi_penalty)
                tokens[link.end_node.id].extend(new_tokens)
                # If there are lots of tokens in the end node, prune already to
                # conserve memory.
                if self._max_tokens_per_node is not None and \
                   len(tokens[link.end_node.id]) > self._max_tokens_per_node * 2:
                    self._prune(link.end_node, sorted_nodes, tokens,
                                recomb_tokens)
                num_new_tokens += len(new_tokens)
            stats['new'] = num_new_tokens

            self._nodes_processed += 1
            self._log_stats(stats, node.id, len(sorted_nodes))

        if len(final_tokens) == 0:
            raise InputError("Could not reach a final node of word lattice.")

        final_tokens = self._sorted_recombined_tokens(final_tokens,
                                                      recomb_tokens)
        return final_tokens, recomb_tokens
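A hedged sketch of how this decoder might be used; the ``decoder``, ``lattice``, and ``vocabulary`` objects are assumptions, and only ``decode()`` and the token fields read below appear in the code above:

# Hypothetical usage; decode() returns the final tokens sorted best-first,
# plus the tokens that were dropped during recombination.
final_tokens, recomb_tokens = decoder.decode(lattice)
best = final_tokens[0]
# best.history holds the word IDs along the best path; mapping them back to
# words through the vocabulary is an assumption about its interface.
words = [vocabulary.id_to_word[word_id] for word_id in best.history]
print(' '.join(words))
print('total log probability:', best.total_logprob)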
Example #15
    def test_append_word(self):
        decoding_options = {
            'nnlm_weight': 1.0,
            'lm_scale': 1.0,
            'wi_penalty': 0.0,
            'ignore_unk': False,
            'unk_penalty': 0.0,
            'linear_interpolation': False,
            'max_tokens_per_node': 10,
            'beam': None,
            'recombination_order': None
        }

        initial_state = RecurrentState(self.network.recurrent_state_size)
        token1 = LatticeDecoder.Token(history=[self.sos_id],
                                      state=initial_state)
        token2 = LatticeDecoder.Token(history=[self.sos_id, self.yksi_id],
                                      state=initial_state)
        decoder = LatticeDecoder(self.network, decoding_options)

        self.assertSequenceEqual(token1.history, [self.sos_id])
        self.assertSequenceEqual(token2.history, [self.sos_id, self.yksi_id])
        assert_equal(token1.state.get(0),
                     numpy.zeros(shape=(1, 1, 3)).astype(theano.config.floatX))
        assert_equal(token2.state.get(0),
                     numpy.zeros(shape=(1, 1, 3)).astype(theano.config.floatX))
        self.assertEqual(token1.nn_lm_logprob, 0.0)
        self.assertEqual(token2.nn_lm_logprob, 0.0)

        decoder._append_word([token1, token2], self.kaksi_id)
        self.assertSequenceEqual(token1.history, [self.sos_id, self.kaksi_id])
        self.assertSequenceEqual(token2.history,
                                 [self.sos_id, self.yksi_id, self.kaksi_id])
        assert_equal(token1.state.get(0),
                     numpy.ones(shape=(1, 1, 3)).astype(theano.config.floatX))
        assert_equal(token2.state.get(0),
                     numpy.ones(shape=(1, 1, 3)).astype(theano.config.floatX))
        token1_nn_lm_logprob = math.log(self.sos_prob + self.kaksi_prob)
        token2_nn_lm_logprob = math.log(self.yksi_prob + self.kaksi_prob)
        self.assertAlmostEqual(token1.nn_lm_logprob, token1_nn_lm_logprob)
        self.assertAlmostEqual(token2.nn_lm_logprob, token2_nn_lm_logprob)

        decoder._append_word([token1, token2], self.eos_id)
        self.assertSequenceEqual(token1.history,
                                 [self.sos_id, self.kaksi_id, self.eos_id])
        self.assertSequenceEqual(
            token2.history,
            [self.sos_id, self.yksi_id, self.kaksi_id, self.eos_id])
        assert_equal(
            token1.state.get(0),
            numpy.ones(shape=(1, 1, 3)).astype(theano.config.floatX) * 2)
        assert_equal(
            token2.state.get(0),
            numpy.ones(shape=(1, 1, 3)).astype(theano.config.floatX) * 2)
        token1_nn_lm_logprob += math.log(self.kaksi_prob + self.eos_prob)
        token2_nn_lm_logprob += math.log(self.kaksi_prob + self.eos_prob)
        self.assertAlmostEqual(token1.nn_lm_logprob, token1_nn_lm_logprob)
        self.assertAlmostEqual(token2.nn_lm_logprob, token2_nn_lm_logprob)

        lm_scale = 2.0
        token1.recompute_total(1.0, lm_scale, -0.01)
        token2.recompute_total(1.0, lm_scale, -0.01)
        self.assertAlmostEqual(token1.total_logprob,
                               token1_nn_lm_logprob * lm_scale - 0.03)
        self.assertAlmostEqual(token2.total_logprob,
                               token2_nn_lm_logprob * lm_scale - 0.04)
    def test_combine_sequences(self):
        state1 = RecurrentState([5, 10], 1)
        layer1_state = numpy.arange(5, dtype='int64').reshape(1, 1, 5)
        layer2_state = numpy.arange(10, 20, dtype='int64').reshape(1, 1, 10)
        state1.set([layer1_state, layer2_state])

        state2 = RecurrentState([5, 10], 1)
        layer1_state = numpy.arange(100, 105, dtype='int64').reshape(1, 1, 5)
        layer2_state = numpy.arange(110, 120, dtype='int64').reshape(1, 1, 10)
        state2.set([layer1_state, layer2_state])

        state3 = RecurrentState([5, 10], 2)
        layer1_state = numpy.arange(200, 210, dtype='int64').reshape(1, 2, 5)
        layer2_state = numpy.arange(210, 230, dtype='int64').reshape(1, 2, 10)
        state3.set([layer1_state, layer2_state])

        combined_state = RecurrentState.combine_sequences(
            [state1, state2, state3])
        self.assertEqual(combined_state.num_sequences, 4)
        self.assertEqual(len(combined_state.get()), 2)
        self.assertEqual(combined_state.get(0).shape, (1, 4, 5))
        self.assertEqual(combined_state.get(1).shape, (1, 4, 10))
        assert_equal(
            combined_state.get(0),
            numpy.asarray([[
                list(range(5)),
                list(range(100, 105)),
                list(range(200, 205)),
                list(range(205, 210))
            ]],
                          dtype='int64'))
        assert_equal(
            combined_state.get(1),
            numpy.asarray([[
                list(range(10, 20)),
                list(range(110, 120)),
                list(range(210, 220)),
                list(range(220, 230))
            ]],
                          dtype='int64'))

        state4 = RecurrentState([5, 11], 2)
        with self.assertRaises(ValueError):
            combined_state = RecurrentState.combine_sequences(
                [state1, state2, state3, state4])
Example #17
    def _append_word(self, tokens, target_word, oov_logprob=None):
        """Appends a word to each of the given tokens, and updates their scores.

        :type tokens: list of LatticeDecoder.Tokens
        :param tokens: input tokens

        :type target_word: int or str
        :param target_word: word ID or word to be appended to the existing
                            history of each input token; if not an integer, the
                            word will be considered ``<unk>`` and this variable
                            will be taken literally as the word that will be
                            used in the resulting transcript

        :type oov_logprob: float
        :param oov_logprob: log probability to be assigned to OOV words
        """
        def limit_to_shortlist(self, word):
            """Returns the ``<unk>`` word ID if the argument is not a shortlist
            word ID.
            """
            if isinstance(word, int) and self._vocabulary.in_shortlist(word):
                return word
            else:
                return self._unk_id

        # Pass all (or a limited window of) the previous input if the network
        # has attention.
        if self._network.has_attention:
            # Handle sequences separately, since the histories can have
            # varying lengths.
            for index, token in enumerate(tokens):
                input_word_ids = [[
                    limit_to_shortlist(self, t) for t in token.history
                ]]
                input_word_ids = numpy.asarray(input_word_ids).astype(
                    'int64').transpose()
                #logging.debug("input shape: %s", input_word_ids.shape)
                input_class_ids, membership_prob = \
                   self._vocabulary.get_class_memberships(input_word_ids)
                recurrent_state = [token.state]
                recurrent_state = RecurrentState.combine_sequences(
                    recurrent_state)
                target_word_id = limit_to_shortlist(self, target_word)
                target_class_ids = numpy.ones(shape=(1, 1)).astype('int64')
                target_class_ids *= self._vocabulary.word_id_to_class_id[
                    target_word_id]
                sub_step_result = self._step_function(input_word_ids,
                                                      input_class_ids,
                                                      target_class_ids,
                                                      *recurrent_state.get())
                # Update the token with the computed values.
                logprobs = sub_step_result[0]
                # Add logprobs from the class membership of the predicted words.
                logprobs += numpy.log(membership_prob[-1, :])
                output_state = sub_step_result[1:]
                token.history = token.history + (target_word, )
                token.state = RecurrentState(
                    self._network.recurrent_state_size)
                # Only one sequence was processed, so the whole output state
                # belongs to this token.
                token.state.set([layer_state for layer_state in output_state])
                # logprobs matrix contains only one time step.
                token.nn_lm_logprob += self._handle_unk_logprob(
                    target_word, logprobs[0, 0], oov_logprob)

        #logging.debug("Next time step")
        #for i, data in enumerate(input_word_ids):
        #    logging.debug("Length of data: %d, %s", len(data), data)
        #    if len(data)<max_len:
        #pad with zeros
        #        input_word_ids[i] = [0] * (max_len-len(data)) + data
        #        logging.debug("New length of data: %d, %s", len(input_word_ids[i]), input_word_ids[i])
        # No attention in the network
        else:
            input_word_ids = [[
                limit_to_shortlist(self, token.history[-1]) for token in tokens
            ]]
            input_word_ids = numpy.asarray(input_word_ids).astype('int64')
            input_class_ids, membership_probs = \
                self._vocabulary.get_class_memberships(input_word_ids)
            recurrent_state = [token.state for token in tokens]
            recurrent_state = RecurrentState.combine_sequences(recurrent_state)
            target_word_id = limit_to_shortlist(self, target_word)
            target_class_ids = numpy.ones(shape=(1,
                                                 len(tokens))).astype('int64')
            target_class_ids *= self._vocabulary.word_id_to_class_id[
                target_word_id]
            step_result = self._step_function(input_word_ids, input_class_ids,
                                              target_class_ids,
                                              *recurrent_state.get())
            logprobs = step_result[0]
            # Add logprobs from the class membership of the predicted words.
            logprobs += numpy.log(membership_probs)
            output_state = step_result[1:]

            for index, token in enumerate(tokens):
                token.history = token.history + (target_word, )
                token.state = RecurrentState(
                    self._network.recurrent_state_size)
                # Slice the sequence that corresponds to this token.
                token.state.set([
                    layer_state[:, index:index + 1]
                    for layer_state in output_state
                ])
                # logprobs matrix contains only one time step.
                token.nn_lm_logprob += self._handle_unk_logprob(
                    target_word, logprobs[0, index], oov_logprob)