def _append_word(self, tokens, target_word):
    """Appends a word to each of the given tokens, and updates their scores.

    :type tokens: list of LatticeDecoder.Tokens
    :param tokens: input tokens

    :type target_word: int or str
    :param target_word: word ID or word to be appended to the existing
                        history of each input token; if not an integer, the
                        word will be considered ``<unk>`` and this variable
                        will be taken literally as the word that will be used
                        in the resulting transcript
    """

    def str_to_unk(self, word):
        """Returns the ``<unk>`` word ID if the argument is not a word ID.
        """
        if isinstance(word, int):
            return word
        else:
            return self._unk_id

    input_word_ids = [[str_to_unk(self, token.history[-1])
                       for token in tokens]]
    input_word_ids = numpy.asarray(input_word_ids).astype('int64')
    input_class_ids, membership_probs = \
        self._vocabulary.get_class_memberships(input_word_ids)
    recurrent_state = [token.state for token in tokens]
    recurrent_state = RecurrentState.combine_sequences(recurrent_state)
    target_word_id = str_to_unk(self, target_word)
    target_class_ids = numpy.ones(shape=(1, len(tokens))).astype('int64')
    target_class_ids *= self._vocabulary.word_id_to_class_id[target_word_id]
    step_result = self._step_function(input_word_ids,
                                      input_class_ids,
                                      target_class_ids,
                                      *recurrent_state.get())
    logprobs = step_result[0]
    # Add logprobs from the class membership of the predicted words.
    logprobs += numpy.log(membership_probs)
    output_state = step_result[1:]

    for index, token in enumerate(tokens):
        token.history.append(target_word)
        token.state = RecurrentState(self._network.recurrent_state_size)
        # Slice the sequence that corresponds to this token.
        token.state.set([layer_state[:, index:index + 1]
                         for layer_state in output_state])

        if target_word_id == self._unk_id:
            if self._ignore_unk:
                continue
            if self._unk_penalty is not None:
                token.nn_lm_logprob += self._unk_penalty
                continue
        # logprobs matrix contains only one time step.
        token.nn_lm_logprob += logprobs[0, index]
def generate(self, max_length=30, num_sequences=1):
    """Generates a text sequence.

    Calls self.step_function() repeatedly, reading the word output and the
    state output of the hidden layer and passing the hidden layer state
    output to the next time step. Generates at most ``max_length`` words,
    stopping if a sentence break is generated.

    :type max_length: int
    :param max_length: maximum number of words in a sequence

    :type num_sequences: int
    :param num_sequences: number of sequences to generate in parallel

    :rtype: list of list of strs
    :returns: list of word sequences
    """

    sos_id = self.vocabulary.word_to_id['<s>']
    sos_class_id = self.vocabulary.word_id_to_class_id[sos_id]
    eos_id = self.vocabulary.word_to_id['</s>']

    word_input = sos_id * \
        numpy.ones(shape=(1, num_sequences)).astype('int64')
    class_input = sos_class_id * \
        numpy.ones(shape=(1, num_sequences)).astype('int64')
    result = sos_id * \
        numpy.ones(shape=(max_length, num_sequences)).astype('int64')
    state = RecurrentState(self.network.recurrent_state_size, num_sequences)

    for time_step in range(1, max_length):
        # The input is the output from the previous step.
        step_result = self.step_function(word_input,
                                         class_input,
                                         *state.get())
        class_ids = step_result[0]
        # The class IDs from the single time step.
        step_class_ids = class_ids[0]
        step_word_ids = numpy.array(
            self.vocabulary.class_ids_to_word_ids(step_class_ids))
        result[time_step] = step_word_ids
        word_input = step_word_ids[numpy.newaxis]
        class_input = class_ids
        state.set(step_result[1:])
        # Stop early if every sequence has generated a sentence break, as
        # promised in the docstring.
        if (step_word_ids == eos_id).all():
            break

    return self.vocabulary.id_to_word[result.transpose()].tolist()
def generate(self, length, num_sequences=1):
    """Generates a text sequence.

    Calls self.step_function() repeatedly, reading the word output and the
    state output of the hidden layer and passing the hidden layer state
    output to the next time step.

    :type length: int
    :param length: number of words (tokens) in each sequence

    :type num_sequences: int
    :param num_sequences: number of sequences to generate in parallel

    :rtype: list of list of strs
    :returns: list of word sequences
    """

    sos_id = self._vocabulary.word_to_id['<s>']
    sos_class_id = self._vocabulary.word_id_to_class_id[sos_id]
    eos_id = self._vocabulary.word_to_id['</s>']

    input_word_ids = sos_id * \
        numpy.ones(shape=(1, num_sequences)).astype('int64')
    input_class_ids = sos_class_id * \
        numpy.ones(shape=(1, num_sequences)).astype('int64')
    result = sos_id * \
        numpy.ones(shape=(length, num_sequences)).astype('int64')
    state = RecurrentState(self._network.recurrent_state_size,
                           num_sequences)

    for time_step in range(1, length):
        # The input is the output from the previous step.
        step_result = self.step_function(input_word_ids,
                                         input_class_ids,
                                         *state.get())
        class_ids = step_result[0]
        # The class IDs from the single time step.
        step_class_ids = class_ids[0]
        step_word_ids = numpy.array(
            self._vocabulary.class_ids_to_word_ids(step_class_ids))
        result[time_step] = step_word_ids
        input_word_ids = step_word_ids[numpy.newaxis]
        input_class_ids = class_ids
        state.set(step_result[1:])

    return self._vocabulary.id_to_word[result.transpose()].tolist()
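# Usage sketch (hedged): ``TextSampler`` is a hypothetical name for the class
# that owns generate() above, and the constructor arguments are assumptions
# based on the attributes the method reads (self._network, self._vocabulary,
# self.step_function). The surrounding model loading code is not shown in
# this section.
#
#     sampler = TextSampler(network, vocabulary)
#     sequences = sampler.generate(length=20, num_sequences=5)
#     for words in sequences:
#         print(' '.join(words))  # each sequence begins with '<s>'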
def _append_word(self, tokens, target_word_id):
    """Appends a word to each of the given tokens, and updates their scores.

    :type tokens: list of LatticeDecoder.Tokens
    :param tokens: input tokens

    :type target_word_id: int
    :param target_word_id: word ID to be appended to the existing history of
                           each input token
    """

    input_word_ids = [[token.history[-1] for token in tokens]]
    input_word_ids = numpy.asarray(input_word_ids).astype('int64')
    input_class_ids, membership_probs = \
        self._vocabulary.get_class_memberships(input_word_ids)
    recurrent_state = [token.state for token in tokens]
    recurrent_state = RecurrentState.combine_sequences(recurrent_state)
    target_class_ids = numpy.ones(shape=(1, len(tokens))).astype('int64')
    target_class_ids *= self._vocabulary.word_id_to_class_id[target_word_id]
    step_result = self.step_function(input_word_ids,
                                     input_class_ids,
                                     target_class_ids,
                                     *recurrent_state.get())
    logprobs = step_result[0]
    # Add logprobs from the class membership of the predicted words.
    logprobs += numpy.log(membership_probs)
    output_state = step_result[1:]

    for index, token in enumerate(tokens):
        token.history.append(target_word_id)
        token.state = RecurrentState(self._network.recurrent_state_size)
        # Slice the sequence that corresponds to this token.
        token.state.set([layer_state[:, index:index + 1]
                         for layer_state in output_state])

        if target_word_id == self._unk_id:
            if self._ignore_unk:
                continue
            if self._unk_penalty is not None:
                token.nn_lm_logprob += self._unk_penalty
                continue
        # logprobs matrix contains only one time step.
        token.nn_lm_logprob += logprobs[0, index]
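# Why the membership log probabilities are added above: in a class-based
# language model the word probability factors as
#     P(w | h) = P(class(w) | h) * P(w | class(w)),
# so the two terms simply add in log space. A standalone numeric check of
# this identity (values invented for illustration):
#
# import numpy
#
# class_logprob = numpy.log(0.3)  # log P(class(w) | history) from the net
# membership_prob = 0.25          # P(w | class(w)) from the vocabulary
# word_logprob = class_logprob + numpy.log(membership_prob)
# assert abs(word_logprob - numpy.log(0.3 * 0.25)) < 1e-12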
def test_set(self):
    state = RecurrentState([5, 10], 3)
    layer1_state = numpy.arange(15, dtype='int64').reshape((1, 3, 5))
    layer2_state = numpy.arange(30, dtype='int64').reshape((1, 3, 10))
    state.set([layer1_state, layer2_state])
    assert_equal(state.get(0), layer1_state)
    assert_equal(state.get(1), layer2_state)

    with self.assertRaises(ValueError):
        state.set([layer2_state, layer1_state])
def test_init(self):
    state = RecurrentState([200, 100, 300], 3)
    self.assertEqual(len(state.get()), 3)
    self.assertEqual(state.get(0).shape, (1, 3, 200))
    self.assertEqual(state.get(1).shape, (1, 3, 100))
    self.assertEqual(state.get(2).shape, (1, 3, 300))
    assert_equal(state.get(0), numpy.zeros(shape=(1, 3, 200), dtype='int64'))
    assert_equal(state.get(1), numpy.zeros(shape=(1, 3, 100), dtype='int64'))
    assert_equal(state.get(2), numpy.zeros(shape=(1, 3, 300), dtype='int64'))

    layer1_state = numpy.arange(15, dtype='int64').reshape((1, 3, 5))
    layer2_state = numpy.arange(30, dtype='int64').reshape((1, 3, 10))
    state = RecurrentState([5, 10], 3, [layer1_state, layer2_state])
    assert_equal(state.get(0), layer1_state)
    assert_equal(state.get(1), layer2_state)
def _append_word(self, tokens, target_word_id):
    """Appends a word to each of the given tokens, and updates their scores.

    :type tokens: list of LatticeDecoder.Tokens
    :param tokens: input tokens

    :type target_word_id: int
    :param target_word_id: word ID to be appended to the existing history of
                           each input token
    """

    input_word_ids = [[token.history[-1] for token in tokens]]
    input_word_ids = numpy.asarray(input_word_ids).astype('int64')
    input_class_ids, membership_probs = \
        self._vocabulary.get_class_memberships(input_word_ids)
    recurrent_state = [token.state for token in tokens]
    recurrent_state = RecurrentState.combine_sequences(recurrent_state)
    target_class_ids = numpy.ones(shape=(1, len(tokens))).astype('int64')
    target_class_ids *= self._vocabulary.word_id_to_class_id[target_word_id]
    step_result = self.step_function(input_word_ids,
                                     input_class_ids,
                                     target_class_ids,
                                     *recurrent_state.get())
    logprobs = step_result[0]
    # Add logprobs from the class membership of the predicted words.
    logprobs += numpy.log(membership_probs)
    output_state = step_result[1:]

    for index, token in enumerate(tokens):
        token.history.append(target_word_id)
        token.state = RecurrentState(self._network.recurrent_state_size)
        # Slice the sequence that corresponds to this token.
        token.state.set([layer_state[:, index:index + 1]
                         for layer_state in output_state])

        if target_word_id == self._unk_id:
            if self._ignore_unk:
                continue
            if self._unk_penalty is not None:
                token.nn_lm_logprob += self._unk_penalty
                continue
        # logprobs matrix contains only one time step.
        token.nn_lm_logprob += logprobs[0, index]
def test_init(self):
    state = RecurrentState([200, 100, 300], 3)
    self.assertEqual(len(state.get()), 3)
    self.assertEqual(state.get(0).shape, (1, 3, 200))
    self.assertEqual(state.get(1).shape, (1, 3, 100))
    self.assertEqual(state.get(2).shape, (1, 3, 300))
    assert_equal(state.get(0),
                 numpy.zeros(shape=(1, 3, 200), dtype='int64'))
    assert_equal(state.get(1),
                 numpy.zeros(shape=(1, 3, 100), dtype='int64'))
    assert_equal(state.get(2),
                 numpy.zeros(shape=(1, 3, 300), dtype='int64'))

    layer1_state = numpy.arange(15, dtype='int64').reshape((1, 3, 5))
    layer2_state = numpy.arange(30, dtype='int64').reshape((1, 3, 10))
    state = RecurrentState([5, 10], 3, [layer1_state, layer2_state])
    assert_equal(state.get(0), layer1_state)
    assert_equal(state.get(1), layer2_state)
def test_combine_sequences(self):
    state1 = RecurrentState([5, 10], 1)
    layer1_state = numpy.arange(5, dtype='int64').reshape(1, 1, 5)
    layer2_state = numpy.arange(10, 20, dtype='int64').reshape(1, 1, 10)
    state1.set([layer1_state, layer2_state])

    state2 = RecurrentState([5, 10], 1)
    layer1_state = numpy.arange(100, 105, dtype='int64').reshape(1, 1, 5)
    layer2_state = numpy.arange(110, 120, dtype='int64').reshape(1, 1, 10)
    state2.set([layer1_state, layer2_state])

    state3 = RecurrentState([5, 10], 2)
    layer1_state = numpy.arange(200, 210, dtype='int64').reshape(1, 2, 5)
    layer2_state = numpy.arange(210, 230, dtype='int64').reshape(1, 2, 10)
    state3.set([layer1_state, layer2_state])

    combined_state = RecurrentState.combine_sequences(
        [state1, state2, state3])
    self.assertEqual(combined_state.num_sequences, 4)
    self.assertEqual(len(combined_state.get()), 2)
    self.assertEqual(combined_state.get(0).shape, (1, 4, 5))
    self.assertEqual(combined_state.get(1).shape, (1, 4, 10))
    assert_equal(
        combined_state.get(0),
        numpy.asarray([[list(range(5)),
                        list(range(100, 105)),
                        list(range(200, 205)),
                        list(range(205, 210))]],
                      dtype='int64'))
    assert_equal(
        combined_state.get(1),
        numpy.asarray([[list(range(10, 20)),
                        list(range(110, 120)),
                        list(range(210, 220)),
                        list(range(220, 230))]],
                      dtype='int64'))

    state4 = RecurrentState([5, 11], 2)
    with self.assertRaises(ValueError):
        combined_state = RecurrentState.combine_sequences(
            [state1, state2, state3, state4])
def decode(self, lattice):
    """Propagates tokens through given lattice and returns a list of tokens
    in the final node.

    Propagates tokens at a node to every outgoing link by creating a copy
    of each token and updating the language model scores according to the
    link.

    :type lattice: Lattice
    :param lattice: a word lattice to be decoded

    :rtype: list of LatticeDecoder.Tokens
    :returns: the final tokens sorted by total log probability in
              descending order
    """

    if self._lm_scale is not None:
        lm_scale = logprob_type(self._lm_scale)
    elif lattice.lm_scale is not None:
        lm_scale = logprob_type(lattice.lm_scale)
    else:
        lm_scale = logprob_type(1.0)

    if self._wi_penalty is not None:
        wi_penalty = logprob_type(self._wi_penalty)
    elif lattice.wi_penalty is not None:
        wi_penalty = logprob_type(lattice.wi_penalty)
    else:
        wi_penalty = logprob_type(0.0)

    self._tokens = [list() for _ in lattice.nodes]
    initial_state = RecurrentState(self._network.recurrent_state_size)
    initial_token = self.Token(history=[self._sos_id], state=initial_state)
    initial_token.recompute_hash(self._recombination_order)
    initial_token.recompute_total(self._nnlm_weight, lm_scale, wi_penalty,
                                  self._linear_interpolation)
    self._tokens[lattice.initial_node.id].append(initial_token)
    lattice.initial_node.best_logprob = initial_token.total_logprob

    self._sorted_nodes = lattice.sorted_nodes()
    nodes_processed = 0
    for node in self._sorted_nodes:
        node_tokens = self._tokens[node.id]
        assert node_tokens
        num_pruned_tokens = len(node_tokens)
        self._prune(node)
        node_tokens = self._tokens[node.id]
        assert node_tokens
        num_pruned_tokens -= len(node_tokens)

        if node.id == lattice.final_node.id:
            new_tokens = self._propagate(
                node_tokens, None, lm_scale, wi_penalty)
            return sorted(new_tokens,
                          key=lambda token: token.total_logprob,
                          reverse=True)

        num_new_tokens = 0
        for link in node.out_links:
            new_tokens = self._propagate(
                node_tokens, link, lm_scale, wi_penalty)
            self._tokens[link.end_node.id].extend(new_tokens)
            num_new_tokens += len(new_tokens)

        nodes_processed += 1
        if nodes_processed % math.ceil(len(self._sorted_nodes) / 20) == 0:
            logging.debug("[%d] (%.2f %%) -- tokens = %d +%d -%d",
                          nodes_processed,
                          nodes_processed / len(self._sorted_nodes) * 100,
                          len(node_tokens),
                          num_new_tokens,
                          num_pruned_tokens)

    raise InputError("Could not reach the final node of word lattice.")
def generate(self, length, num_sequences=1, seed_sequence=''):
    """Generates a text sequence.

    Calls self.step_function() repeatedly, reading the word output and the
    state output of the hidden layer and passing the hidden layer state
    output to the next time step. If ``seed_sequence`` is given, its words
    are fed to the network first, before sampling starts.

    :type length: int
    :param length: number of words (tokens) in each sequence

    :type num_sequences: int
    :param num_sequences: number of sequences to generate in parallel

    :type seed_sequence: str
    :param seed_sequence: whitespace-separated words used to prime the
                          network before sampling

    :rtype: list of list of strs
    :returns: list of word sequences
    """

    seed_tokens = seed_sequence.strip().split()
    sos_id = self._vocabulary.word_to_id['<s>']
    sos_class_id = self._vocabulary.word_id_to_class_id[sos_id]

    input_word_ids = sos_id * \
        numpy.ones(shape=(1, num_sequences)).astype('int64')
    input_class_ids = sos_class_id * \
        numpy.ones(shape=(1, num_sequences)).astype('int64')
    result = sos_id * \
        numpy.ones(shape=(len(seed_tokens) + length,
                          num_sequences)).astype('int64')
    state = RecurrentState(self._network.recurrent_state_size,
                           num_sequences)

    # First, possibly compute forward passes with the seed sequence.
    for time_step, token in enumerate(seed_tokens, start=1):
        step_result = self.step_function(input_word_ids,
                                         input_class_ids,
                                         *state.get())
        token_id = self._vocabulary.word_to_id[token]
        token_class_id = self._vocabulary.word_id_to_class_id[token_id]
        input_word_ids = token_id * \
            numpy.ones(shape=(1, num_sequences)).astype('int64')
        input_class_ids = token_class_id * \
            numpy.ones(shape=(1, num_sequences)).astype('int64')
        step_word_ids = input_word_ids
        result[time_step] = step_word_ids
        state.set(step_result[1:])

    # Then sample.
    for time_step in range(len(seed_tokens) + 1,
                           len(seed_tokens) + length):
        # The input is the output from the previous step.
        step_result = self.step_function(input_word_ids,
                                         input_class_ids,
                                         *state.get())
        class_ids = step_result[0]
        # The class IDs from the single time step.
        step_class_ids = class_ids[0]
        step_word_ids = numpy.array(
            self._vocabulary.class_ids_to_word_ids(step_class_ids))
        result[time_step] = step_word_ids
        input_word_ids = step_word_ids[numpy.newaxis]
        input_class_ids = class_ids
        state.set(step_result[1:])

    return self._vocabulary.id_to_word[result.transpose()].tolist()
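# Usage sketch (hedged) for the seeded variant, assuming the same
# hypothetical sampler object as in the earlier generate() sketch. The seed
# words must exist in the vocabulary, since they are looked up with
# self._vocabulary.word_to_id.
#
#     sequences = sampler.generate(length=15, num_sequences=2,
#                                  seed_sequence='the quick brown')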
def decode(self, lattice):
    """Propagates tokens through given lattice and returns a list of tokens
    in the final nodes.

    Propagates tokens at a node to every outgoing link by creating a copy
    of each token and updating the language model scores according to the
    link.

    The function returns two lists. The first list contains the final
    tokens, sorted in the descending order of total log probability. I.e.
    the first token in the list represents the best path through the
    lattice. The second list contains the tokens that were dropped during
    recombination. This is needed for constructing a new rescored lattice.

    :type lattice: Lattice
    :param lattice: a word lattice to be decoded

    :rtype: a tuple of two lists of LatticeDecoder.Tokens
    :returns: a list of the final tokens sorted by probability (most likely
              token first), and a list of the tokens that were dropped
              during recombination
    """

    if self._lm_scale is not None:
        lm_scale = logprob_type(self._lm_scale)
    elif lattice.lm_scale is not None:
        lm_scale = logprob_type(lattice.lm_scale)
    else:
        lm_scale = logprob_type(1.0)

    if self._wi_penalty is not None:
        wi_penalty = logprob_type(self._wi_penalty)
    elif lattice.wi_penalty is not None:
        wi_penalty = logprob_type(lattice.wi_penalty)
    else:
        wi_penalty = logprob_type(0.0)

    tokens = [list() for _ in lattice.nodes]
    recomb_tokens = []
    initial_state = RecurrentState(self._network.recurrent_state_size)
    initial_token = self.Token(history=(self._sos_id,),
                               state=initial_state)
    initial_token.recompute_hash(self._recombination_order)
    initial_token.recompute_total(self._nnlm_weight, lm_scale, wi_penalty,
                                  self._linear_interpolation)
    tokens[lattice.initial_node.id].append(initial_token)
    lattice.initial_node.best_logprob = initial_token.total_logprob

    sorted_nodes = lattice.sorted_nodes()
    self._nodes_processed = 0
    final_tokens = []
    for node in sorted_nodes:
        stats = self._prune(node, sorted_nodes, tokens, recomb_tokens)

        num_new_tokens = 0
        node_tokens = tokens[node.id]
        assert node_tokens
        if node.final:
            new_tokens = self._propagate(node_tokens, None, lm_scale,
                                         wi_penalty)
            final_tokens.extend(new_tokens)
            num_new_tokens += len(new_tokens)

        for link in node.out_links:
            new_tokens = self._propagate(node_tokens, link, lm_scale,
                                         wi_penalty)
            tokens[link.end_node.id].extend(new_tokens)
            # If there are lots of tokens in the end node, prune already to
            # conserve memory.
            if self._max_tokens_per_node is not None and \
               len(tokens[link.end_node.id]) > self._max_tokens_per_node * 2:
                self._prune(link.end_node, sorted_nodes, tokens,
                            recomb_tokens)
            num_new_tokens += len(new_tokens)
        stats['new'] = num_new_tokens

        self._nodes_processed += 1
        self._log_stats(stats, node.id, len(sorted_nodes))

    if len(final_tokens) == 0:
        raise InputError("Could not reach a final node of word lattice.")

    final_tokens = self._sorted_recombined_tokens(final_tokens,
                                                  recomb_tokens)

    return final_tokens, recomb_tokens
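# Usage sketch (hedged): the LatticeDecoder constructor below matches the
# call in the tests in this section; how the Lattice object is read from
# disk is an assumption and not shown here.
#
#     decoder = LatticeDecoder(network, decoding_options)
#     final_tokens, recomb_tokens = decoder.decode(lattice)
#     best = final_tokens[0]  # most likely path through the lattice
#     print(best.total_logprob, best.history)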
def test_append_word(self):
    decoding_options = {
        'nnlm_weight': 1.0,
        'lm_scale': 1.0,
        'wi_penalty': 0.0,
        'ignore_unk': False,
        'unk_penalty': 0.0,
        'linear_interpolation': False,
        'max_tokens_per_node': 10,
        'beam': None,
        'recombination_order': None
    }

    initial_state = RecurrentState(self.network.recurrent_state_size)
    token1 = LatticeDecoder.Token(history=[self.sos_id],
                                  state=initial_state)
    token2 = LatticeDecoder.Token(history=[self.sos_id, self.yksi_id],
                                  state=initial_state)
    decoder = LatticeDecoder(self.network, decoding_options)

    self.assertSequenceEqual(token1.history, [self.sos_id])
    self.assertSequenceEqual(token2.history, [self.sos_id, self.yksi_id])
    assert_equal(token1.state.get(0),
                 numpy.zeros(shape=(1, 1, 3)).astype(theano.config.floatX))
    assert_equal(token2.state.get(0),
                 numpy.zeros(shape=(1, 1, 3)).astype(theano.config.floatX))
    self.assertEqual(token1.nn_lm_logprob, 0.0)
    self.assertEqual(token2.nn_lm_logprob, 0.0)

    decoder._append_word([token1, token2], self.kaksi_id)
    self.assertSequenceEqual(token1.history,
                             [self.sos_id, self.kaksi_id])
    self.assertSequenceEqual(token2.history,
                             [self.sos_id, self.yksi_id, self.kaksi_id])
    assert_equal(token1.state.get(0),
                 numpy.ones(shape=(1, 1, 3)).astype(theano.config.floatX))
    assert_equal(token2.state.get(0),
                 numpy.ones(shape=(1, 1, 3)).astype(theano.config.floatX))
    token1_nn_lm_logprob = math.log(self.sos_prob + self.kaksi_prob)
    token2_nn_lm_logprob = math.log(self.yksi_prob + self.kaksi_prob)
    self.assertAlmostEqual(token1.nn_lm_logprob, token1_nn_lm_logprob)
    self.assertAlmostEqual(token2.nn_lm_logprob, token2_nn_lm_logprob)

    decoder._append_word([token1, token2], self.eos_id)
    self.assertSequenceEqual(token1.history,
                             [self.sos_id, self.kaksi_id, self.eos_id])
    self.assertSequenceEqual(
        token2.history,
        [self.sos_id, self.yksi_id, self.kaksi_id, self.eos_id])
    assert_equal(
        token1.state.get(0),
        numpy.ones(shape=(1, 1, 3)).astype(theano.config.floatX) * 2)
    assert_equal(
        token2.state.get(0),
        numpy.ones(shape=(1, 1, 3)).astype(theano.config.floatX) * 2)
    token1_nn_lm_logprob += math.log(self.kaksi_prob + self.eos_prob)
    token2_nn_lm_logprob += math.log(self.kaksi_prob + self.eos_prob)
    self.assertAlmostEqual(token1.nn_lm_logprob, token1_nn_lm_logprob)
    self.assertAlmostEqual(token2.nn_lm_logprob, token2_nn_lm_logprob)

    lm_scale = 2.0
    token1.recompute_total(1.0, lm_scale, -0.01)
    token2.recompute_total(1.0, lm_scale, -0.01)
    self.assertAlmostEqual(token1.total_logprob,
                           token1_nn_lm_logprob * lm_scale - 0.03)
    self.assertAlmostEqual(token2.total_logprob,
                           token2_nn_lm_logprob * lm_scale - 0.04)
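# Worked check of the final assertions above, assuming recompute_total()
# combines the scores as nn_lm_logprob * lm_scale + wi_penalty * |history|
# (which is what the expected values imply, with nnlm_weight = 1.0):
#     token1: 3-word history -> token1_nn_lm_logprob * 2.0 - 0.01 * 3
#     token2: 4-word history -> token2_nn_lm_logprob * 2.0 - 0.01 * 4
# hence the -0.03 and -0.04 terms in the two assertions.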
def test_combine_sequences(self):
    state1 = RecurrentState([5, 10], 1)
    layer1_state = numpy.arange(5, dtype='int64').reshape(1, 1, 5)
    layer2_state = numpy.arange(10, 20, dtype='int64').reshape(1, 1, 10)
    state1.set([layer1_state, layer2_state])

    state2 = RecurrentState([5, 10], 1)
    layer1_state = numpy.arange(100, 105, dtype='int64').reshape(1, 1, 5)
    layer2_state = numpy.arange(110, 120, dtype='int64').reshape(1, 1, 10)
    state2.set([layer1_state, layer2_state])

    state3 = RecurrentState([5, 10], 2)
    layer1_state = numpy.arange(200, 210, dtype='int64').reshape(1, 2, 5)
    layer2_state = numpy.arange(210, 230, dtype='int64').reshape(1, 2, 10)
    state3.set([layer1_state, layer2_state])

    combined_state = RecurrentState.combine_sequences(
        [state1, state2, state3])
    self.assertEqual(combined_state.num_sequences, 4)
    self.assertEqual(len(combined_state.get()), 2)
    self.assertEqual(combined_state.get(0).shape, (1, 4, 5))
    self.assertEqual(combined_state.get(1).shape, (1, 4, 10))
    assert_equal(
        combined_state.get(0),
        numpy.asarray([[list(range(5)),
                        list(range(100, 105)),
                        list(range(200, 205)),
                        list(range(205, 210))]],
                      dtype='int64'))
    assert_equal(
        combined_state.get(1),
        numpy.asarray([[list(range(10, 20)),
                        list(range(110, 120)),
                        list(range(210, 220)),
                        list(range(220, 230))]],
                      dtype='int64'))

    state4 = RecurrentState([5, 11], 2)
    with self.assertRaises(ValueError):
        combined_state = RecurrentState.combine_sequences(
            [state1, state2, state3, state4])
def _append_word(self, tokens, target_word, oov_logprob=None):
    """Appends a word to each of the given tokens, and updates their scores.

    :type tokens: list of LatticeDecoder.Tokens
    :param tokens: input tokens

    :type target_word: int or str
    :param target_word: word ID or word to be appended to the existing
                        history of each input token; if not an integer, the
                        word will be considered ``<unk>`` and this variable
                        will be taken literally as the word that will be used
                        in the resulting transcript

    :type oov_logprob: float
    :param oov_logprob: log probability to be assigned to OOV words
    """

    def limit_to_shortlist(self, word):
        """Returns the ``<unk>`` word ID if the argument is not a shortlist
        word ID.
        """
        if isinstance(word, int) and self._vocabulary.in_shortlist(word):
            return word
        else:
            return self._unk_id

    # If the network has attention, pass all the previous input.
    if self._network.has_attention:
        # Handle the sequences separately, as the histories can have varying
        # lengths. They could be batched by padding to a common length, but
        # here each token is processed on its own.
        for token in tokens:
            input_word_ids = [[limit_to_shortlist(self, t)
                               for t in token.history]]
            input_word_ids = \
                numpy.asarray(input_word_ids).astype('int64').transpose()
            input_class_ids, membership_prob = \
                self._vocabulary.get_class_memberships(input_word_ids)
            recurrent_state = RecurrentState.combine_sequences(
                [token.state])
            target_word_id = limit_to_shortlist(self, target_word)
            target_class_ids = numpy.ones(shape=(1, 1)).astype('int64')
            target_class_ids *= self._vocabulary.word_id_to_class_id[
                target_word_id]
            sub_step_result = self._step_function(input_word_ids,
                                                  input_class_ids,
                                                  target_class_ids,
                                                  *recurrent_state.get())

            # Update the token with the new values.
            logprobs = sub_step_result[0]
            # Add logprobs from the class membership of the predicted words.
            logprobs += numpy.log(membership_prob[-1, :])
            output_state = sub_step_result[1:]
            token.history = token.history + (target_word,)
            token.state = RecurrentState(
                self._network.recurrent_state_size)
            # Only one sequence was processed, so no slicing is needed.
            token.state.set(list(output_state))
            # logprobs matrix contains only one time step.
            token.nn_lm_logprob += self._handle_unk_logprob(
                target_word, logprobs[0, 0], oov_logprob)

    # No attention in the network.
    else:
        input_word_ids = [[limit_to_shortlist(self, token.history[-1])
                           for token in tokens]]
        input_word_ids = numpy.asarray(input_word_ids).astype('int64')
        input_class_ids, membership_probs = \
            self._vocabulary.get_class_memberships(input_word_ids)
        recurrent_state = [token.state for token in tokens]
        recurrent_state = RecurrentState.combine_sequences(recurrent_state)
        target_word_id = limit_to_shortlist(self, target_word)
        target_class_ids = numpy.ones(shape=(1, len(tokens))).astype('int64')
        target_class_ids *= self._vocabulary.word_id_to_class_id[
            target_word_id]
        step_result = self._step_function(input_word_ids,
                                          input_class_ids,
                                          target_class_ids,
                                          *recurrent_state.get())
        logprobs = step_result[0]
        # Add logprobs from the class membership of the predicted words.
        logprobs += numpy.log(membership_probs)
        output_state = step_result[1:]

        for index, token in enumerate(tokens):
            token.history = token.history + (target_word,)
            token.state = RecurrentState(
                self._network.recurrent_state_size)
            # Slice the sequence that corresponds to this token.
            token.state.set([layer_state[:, index:index + 1]
                             for layer_state in output_state])
            # logprobs matrix contains only one time step.
            token.nn_lm_logprob += self._handle_unk_logprob(
                target_word, logprobs[0, index], oov_logprob)