def body(*args: Any) -> BeamSearchLoopState:
    """Execute a single beam search step.

    An implementation of the beam search algorithm, which works as
    follows:

    1. Create a valid ``logprobs`` tensor which contains distributions
       over the output tokens for each hypothesis in the beam. For
       finished hypotheses, the log probabilities of all tokens except
       the padding token are set to negative infinity.

    2. Expand the beam by appending every possible token to every
       existing hypothesis. Update the log probability sum of each
       hypothesis and its length (add one for unfinished hypotheses).
       For each hypothesis, compute the score using the length penalty
       term.

    3. Select the ``beam_size`` best hypotheses from the score pool.
       This is implemented by flattening the scores tensor and using
       the ``tf.nn.top_k`` function.

    4. Reconstruct the beam by gathering elements from the original
       data structures using the data indices computed in the previous
       step.

    5. Call the ``body`` function of the underlying decoder.

    6. Populate a new ``BeamSearchLoopState`` object with the selected
       values and with the newly obtained decoder loop state.

    Note that this function expects the decoder to be called at least
    once prior to the first execution.

    Arguments:
        args: An instance of the ``BeamSearchLoopState`` structure.
            (see the docs for this module)

    Returns:
        A ``BeamSearchLoopState`` after one step of the decoding.
    """
    loop_state = BeamSearchLoopState(*args)
    dec_loop_state = loop_state.decoder_loop_state
    search_state = loop_state.search_state
    search_results = loop_state.search_results

    # mask the probabilities
    # >> shape(logprobs) = [batch, beam, vocabulary]
    logprobs = search_state.prev_logprobs

    # >> shape(finished_mask) = [batch, beam, 1]
    # float, 0 or 1; 1 for finished hypotheses, 0 otherwise
    finished_mask = tf.expand_dims(tf.to_float(search_state.finished), 2)

    # >> shape(unfinished_logprobs) = [batch, beam, vocabulary]
    # finished hypotheses get logprob 0 for every vocabulary token
    unfinished_logprobs = (1. - finished_mask) * logprobs

    # >> shape(finished_row) = [vocabulary]
    # -INF everywhere except at the <PAD> index, which holds 0
    finished_row = tf.one_hot(PAD_TOKEN_INDEX,
                              len(self.vocabulary),
                              dtype=tf.float32,
                              on_value=0.,
                              off_value=-INF)

    # >> shape(finished_logprobs) = [batch, beam, vocabulary]
    # unfinished hypotheses have 0 over the whole vocabulary;
    # finished hypotheses have -INF logprob for every vocabulary token
    # except <PAD>, which holds 0. This ensures a finished hypothesis
    # gets padded when there is nothing better to do.
    finished_logprobs = finished_mask * finished_row

    # >> shape(logprobs) = [batch, beam, vocabulary]
    # token probabilities for the hypotheses; finished hypotheses have
    # -INF everywhere except <PAD>, which holds 0
    logprobs = unfinished_logprobs + finished_logprobs

    # update hypothesis scores
    # >> shape(hyp_probs) = [batch, beam, vocabulary]
    # hyp_probs holds the score of the whole hypothesis after appending
    # the given token
    hyp_probs = tf.expand_dims(search_state.logprob_sum, 2) + logprobs

    # update hypothesis lengths
    # >> shape(hyp_lengths) = [batch, beam]
    # increase the length of unfinished hypotheses by 1
    hyp_lengths = search_state.lengths + 1 - tf.to_int32(
        search_state.finished)

    # >> shape(scores) = [batch, beam, vocabulary]
    # apply the length penalty to the hypothesis scores;
    # scores now holds the new hypothesis scores
    scores = hyp_probs / tf.expand_dims(
        self._length_penalty(hyp_lengths), 2)

    # reshape to [batch, beam * vocabulary] for topk
    # >> shape(scores_flat) = [batch, beam * vocabulary]
    scores_flat = tf.reshape(
        scores, [-1, self.beam_size * len(self.vocabulary)])

    # >> shape(both) = [batch, beam]
    # topk_scores holds the top `beam_size` hypothesis scores;
    # topk_indices holds their indices in scores_flat
    topk_scores, topk_indices = tf.nn.top_k(scores_flat, k=self.beam_size)

    topk_indices.set_shape([None, self.beam_size])
    topk_scores.set_shape([None, self.beam_size])

    # >> shape(next_word_ids) = [batch, beam]
    # next_word_ids holds the vocabulary indices of the new tokens
    #
    # >> shape(next_beam_ids) = [batch, beam]
    # next_beam_ids holds the indices of the beams from which the new
    # hypotheses originate; parent_ids are taken from here
    next_word_ids = tf.mod(topk_indices, len(self.vocabulary))
    next_beam_ids = tf.div(topk_indices, len(self.vocabulary))

    parent_ids = next_beam_ids

    # batch offset for tf.gather_nd
    batch_offset = tf.tile(
        tf.expand_dims(tf.range(self.batch_size), 1),
        [1, self.beam_size])
    batch_beam_ids = tf.stack([batch_offset, next_beam_ids], axis=2)

    # gather the topk logprob_sums
    # >> shape(next_beam_lengths) = [batch, beam]
    # next_beam_lengths holds the lengths of the new beam_size hypotheses
    next_beam_lengths = tf.gather_nd(hyp_lengths, batch_beam_ids)

    # >> shape(next_beam_logprob_sum) = [batch, beam]
    # next_beam_logprob_sum holds the scores of the new hypotheses
    next_beam_logprob_sum = tf.gather_nd(
        tf.reshape(hyp_probs, [-1, self.beam_size * len(self.vocabulary)]),
        tf.stack([batch_offset, topk_indices], axis=2))

    # mark finished beams
    next_finished = tf.gather_nd(search_state.finished, batch_beam_ids)
    next_just_finished = tf.equal(next_word_ids, END_TOKEN_INDEX)
    next_finished = tf.logical_or(next_finished, next_just_finished)

    # we need to flatten the feedables for the parent_decoder
    next_feedables = tf.contrib.framework.nest.map_structure(
        lambda x: gather_flat(x, batch_beam_ids,
                              self.batch_size, self.beam_size),
        dec_loop_state.feedables)

    next_feedables = next_feedables._replace(
        input_symbol=tf.reshape(next_word_ids, [-1]),
        finished=tf.reshape(next_finished, [-1]))

    # histories have shape [len, batch, ...]
    def gather_fn(x):
        return partial_transpose(
            gather_flat(
                partial_transpose(x, [1, 0]),
                batch_beam_ids,
                self.batch_size,
                self.beam_size),
            [1, 0])

    next_histories = tf.contrib.framework.nest.map_structure(
        gather_fn, dec_loop_state.histories)

    dec_loop_state = dec_loop_state._replace(
        feedables=next_feedables,
        histories=next_histories)

    # CALL THE DECODER BODY FUNCTION
    next_loop_state = decoder_body(*dec_loop_state)

    next_search_state = SearchState(
        logprob_sum=next_beam_logprob_sum,
        prev_logprobs=tf.reshape(
            tf.nn.log_softmax(next_loop_state.feedables.prev_logits),
            [self.batch_size, self.beam_size, len(self.vocabulary)]),
        lengths=next_beam_lengths,
        finished=next_finished)

    # next_token_ids = tf.transpose(search_results.token_ids, [1, 2, 0])
    # next_token_ids = tf.gather_nd(next_token_ids, batch_beam_ids)
    # next_token_ids = tf.transpose(next_token_ids, [2, 0, 1])
    # commented out because we want the original tokens
    next_output = SearchResults(
        scores=append_tensor(search_results.scores, topk_scores),
        # token_ids=append_tensor(next_token_ids, next_word_ids),
        token_ids=append_tensor(search_results.token_ids, next_word_ids),
        parent_ids=append_tensor(search_results.parent_ids, parent_ids))

    return BeamSearchLoopState(
        search_state=next_search_state,
        search_results=next_output,
        decoder_loop_state=next_loop_state)
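
# A minimal sketch (NumPy only, outside the TensorFlow graph) of the
# index arithmetic used by the step above. After the scores are
# flattened to [batch, beam * vocabulary], every top-k index jointly
# encodes the parent beam and the chosen token; tf.mod and tf.div
# decompose it. All names and sizes below are illustrative, not taken
# from the decoder.
def _demo_topk_index_arithmetic():
    import numpy as np
    beam_size, vocab_size = 3, 5
    flat_scores = np.random.randn(beam_size * vocab_size)
    # indices of the beam_size highest-scoring entries, in descending order
    topk_flat = np.argsort(flat_scores)[::-1][:beam_size]
    word_ids = topk_flat % vocab_size    # plays the role of tf.mod above
    beam_ids = topk_flat // vocab_size   # plays the role of tf.div above
    # each flat index decomposes uniquely as beam * vocab_size + word,
    # so the parent hypothesis and the appended token are both recovered
    assert np.all(topk_flat == beam_ids * vocab_size + word_ids)
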
def body(*args: Any) -> BeamSearchLoopState:
    """Execute a single beam search step.

    An implementation of the beam search algorithm, which works as
    follows:

    1. Create a valid ``logprobs`` tensor which contains distributions
       over the output tokens for each hypothesis in the beam. For
       finished hypotheses, the log probabilities of all tokens except
       the padding token are set to negative infinity.

    2. Expand the beam by appending every possible token to every
       existing hypothesis. Update the log probability sum of each
       hypothesis and its length (add one for unfinished hypotheses).
       For each hypothesis, compute the score using the length penalty
       term.

    3. Select the ``beam_size`` best hypotheses from the score pool.
       This is implemented by flattening the scores tensor and using
       the ``tf.nn.top_k`` function.

    4. Reconstruct the beam by gathering elements from the original
       data structures using the data indices computed in the previous
       step.

    5. Call the ``body`` function of the underlying decoder.

    6. Populate a new ``BeamSearchLoopState`` object with the selected
       values and with the newly obtained decoder loop state.

    Note that this function expects the decoder to be called at least
    once prior to the first execution.

    Arguments:
        args: An instance of the ``BeamSearchLoopState`` structure.
            (see the docs for this module)

    Returns:
        A ``BeamSearchLoopState`` after one step of the decoding.
    """
    loop_state = BeamSearchLoopState(*args)
    dec_loop_state = loop_state.decoder_loop_state
    search_state = loop_state.search_state
    search_results = loop_state.search_results

    # mask the probabilities
    # shape(logprobs) = [batch, beam, vocabulary]
    logprobs = search_state.prev_logprobs

    finished_mask = tf.expand_dims(
        tf.to_float(search_state.finished), 2)
    unfinished_logprobs = (1. - finished_mask) * logprobs

    finished_row = tf.one_hot(
        PAD_TOKEN_INDEX,
        len(self.vocabulary),
        dtype=tf.float32,
        on_value=0.,
        off_value=-INF)
    finished_logprobs = finished_mask * finished_row
    logprobs = unfinished_logprobs + finished_logprobs

    # update hypothesis scores
    # shape(hyp_probs) = [batch, beam, vocabulary]
    hyp_probs = tf.expand_dims(search_state.logprob_sum, 2) + logprobs

    # update hypothesis lengths
    hyp_lengths = search_state.lengths + 1 - tf.to_int32(
        search_state.finished)

    # shape(scores) = [batch, beam, vocabulary]
    scores = hyp_probs / tf.expand_dims(
        self._length_penalty(hyp_lengths), 2)

    # reshape to [batch, beam * vocabulary] for topk
    scores_flat = tf.reshape(
        scores, [-1, self.beam_size * len(self.vocabulary)])

    # shape(both) = [batch, beam]
    topk_scores, topk_indices = tf.nn.top_k(
        scores_flat, k=self.beam_size)

    topk_indices.set_shape([None, self.beam_size])
    topk_scores.set_shape([None, self.beam_size])

    next_word_ids = tf.mod(topk_indices, len(self.vocabulary))
    next_beam_ids = tf.div(topk_indices, len(self.vocabulary))

    # batch offset for tf.gather_nd
    batch_offset = tf.tile(
        tf.expand_dims(tf.range(self.batch_size), 1),
        [1, self.beam_size])
    batch_beam_ids = tf.stack([batch_offset, next_beam_ids], axis=2)

    # gather the topk logprob_sums
    next_beam_lengths = tf.gather_nd(hyp_lengths, batch_beam_ids)
    next_beam_logprob_sum = tf.gather_nd(
        tf.reshape(
            hyp_probs, [-1, self.beam_size * len(self.vocabulary)]),
        tf.stack([batch_offset, topk_indices], axis=2))

    # mark finished beams
    next_finished = tf.gather_nd(search_state.finished, batch_beam_ids)
    next_just_finished = tf.equal(next_word_ids, END_TOKEN_INDEX)
    next_finished = tf.logical_or(next_finished, next_just_finished)

    # we need to flatten the feedables for the parent_decoder
    next_feedables = tf.contrib.framework.nest.map_structure(
        lambda x: gather_flat(x, batch_beam_ids,
                              self.batch_size, self.beam_size),
        dec_loop_state.feedables)

    next_feedables = next_feedables._replace(
        input_symbol=tf.reshape(next_word_ids, [-1]),
        finished=tf.reshape(next_finished, [-1]))

    # histories have shape [len, batch, ...]
    def gather_fn(x):
        return partial_transpose(
            gather_flat(
                partial_transpose(x, [1, 0]),
                batch_beam_ids,
                self.batch_size,
                self.beam_size),
            [1, 0])

    next_histories = tf.contrib.framework.nest.map_structure(
        gather_fn, dec_loop_state.histories)

    dec_loop_state = dec_loop_state._replace(
        feedables=next_feedables,
        histories=next_histories)

    # CALL THE DECODER BODY FUNCTION
    next_loop_state = decoder_body(*dec_loop_state)

    next_search_state = SearchState(
        logprob_sum=next_beam_logprob_sum,
        prev_logprobs=tf.reshape(
            tf.nn.log_softmax(next_loop_state.feedables.prev_logits),
            [self.batch_size, self.beam_size, len(self.vocabulary)]),
        lengths=next_beam_lengths,
        finished=next_finished)

    next_token_ids = tf.transpose(search_results.token_ids, [1, 2, 0])
    next_token_ids = tf.gather_nd(next_token_ids, batch_beam_ids)
    next_token_ids = tf.transpose(next_token_ids, [2, 0, 1])

    next_output = SearchResults(
        scores=topk_scores,
        token_ids=append_tensor(next_token_ids, next_word_ids))

    return BeamSearchLoopState(
        search_state=next_search_state,
        search_results=next_output,
        decoder_loop_state=next_loop_state)
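
# ``gather_flat`` and ``partial_transpose`` are helpers defined elsewhere
# in the codebase and are not shown in this section. Below is a plausible
# sketch of the behaviour the feedables gather above relies on, assuming
# gather_flat unflattens the leading [batch * beam] dimension, selects
# rows with the [batch, beam, 2] indices, and flattens back. This is a
# hypothetical stand-in, not the real helper; it relies on the
# module-level ``import tensorflow as tf`` already used above.
def _gather_flat_sketch(x, batch_beam_ids, batch_size, beam_size):
    rest = tf.shape(x)[1:]  # dimensions after the flat batch * beam axis
    # [batch * beam, ...] -> [batch, beam, ...]
    unflat = tf.reshape(
        x, tf.concat([[batch_size, beam_size], rest], axis=0))
    # pick the surviving (batch, beam) rows of the expanded beam
    gathered = tf.gather_nd(unflat, batch_beam_ids)
    # [batch, beam, ...] -> [batch * beam, ...]
    return tf.reshape(gathered, tf.concat([[-1], rest], axis=0))
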
def body(*args: Any) -> BeamSearchLoopState:
    """Execute a single beam search step.

    An implementation of the beam search algorithm, which works as
    follows:

    1. Create a valid ``logprobs`` tensor which contains distributions
       over the output tokens for each hypothesis in the beam. For
       finished hypotheses, the log probabilities of all tokens except
       the padding token are set to negative infinity.

    2. Expand the beam by appending every possible token to every
       existing hypothesis. Update the log probability sum of each
       hypothesis and its length (add one for unfinished hypotheses).
       For each hypothesis, compute the score using the length penalty
       term.

    3. Select the ``beam_size`` best hypotheses from the score pool.
       This is implemented by flattening the scores tensor and using
       the ``tf.nn.top_k`` function.

    4. Reconstruct the beam by gathering elements from the original
       data structures using the data indices computed in the previous
       step.

    5. Call the ``body`` function of the underlying decoder.

    6. Populate a new ``BeamSearchLoopState`` object with the selected
       values and with the newly obtained decoder loop state.

    Note that this function expects the decoder to be called at least
    once prior to the first execution.

    Arguments:
        args: An instance of the ``BeamSearchLoopState`` structure.
            (see the docs for this module)

    Returns:
        A ``BeamSearchLoopState`` after one step of the decoding.
    """
    loop_state = BeamSearchLoopState(*args)
    dec_loop_state = loop_state.decoder_loop_state
    search_state = loop_state.search_state
    search_results = loop_state.search_results

    # mask the probabilities
    # shape(logprobs) = [batch, beam, vocabulary]
    logprobs = search_state.prev_logprobs

    finished_mask = tf.expand_dims(
        tf.to_float(search_state.finished), 2)
    unfinished_logprobs = (1. - finished_mask) * logprobs

    finished_row = tf.one_hot(
        PAD_TOKEN_INDEX,
        len(self.vocabulary),
        dtype=tf.float32,
        on_value=0.,
        off_value=-INF)
    finished_logprobs = finished_mask * finished_row
    logprobs = unfinished_logprobs + finished_logprobs

    # update hypothesis scores
    # shape(hyp_probs) = [batch, beam, vocabulary]
    hyp_probs = tf.expand_dims(search_state.logprob_sum, 2) + logprobs

    # update hypothesis lengths
    hyp_lengths = search_state.lengths + 1 - tf.to_int32(
        search_state.finished)

    # shape(scores) = [batch, beam, vocabulary]
    scores = hyp_probs / tf.expand_dims(
        self._length_penalty(hyp_lengths), 2)

    # reshape to [batch, beam * vocabulary] for topk
    scores_flat = tf.reshape(
        scores, [-1, self.beam_size * len(self.vocabulary)])

    # shape(both) = [batch, beam]
    topk_scores, topk_indices = tf.nn.top_k(
        scores_flat, k=self.beam_size)

    topk_indices.set_shape([None, self.beam_size])
    topk_scores.set_shape([None, self.beam_size])

    next_word_ids = tf.to_int64(
        tf.mod(topk_indices, len(self.vocabulary)))
    next_beam_ids = tf.div(topk_indices, len(self.vocabulary))

    # batch offset for tf.gather_nd
    batch_offset = tf.tile(
        tf.expand_dims(tf.range(self.batch_size), 1),
        [1, self.beam_size])
    batch_beam_ids = tf.stack([batch_offset, next_beam_ids], axis=2)

    # gather the topk logprob_sums
    next_beam_lengths = tf.gather_nd(hyp_lengths, batch_beam_ids)
    next_beam_logprob_sum = tf.gather_nd(
        tf.reshape(
            hyp_probs, [-1, self.beam_size * len(self.vocabulary)]),
        tf.stack([batch_offset, topk_indices], axis=2))

    # mark finished beams
    next_finished = tf.gather_nd(search_state.finished, batch_beam_ids)
    next_just_finished = tf.equal(next_word_ids, END_TOKEN_INDEX)
    next_finished = tf.logical_or(next_finished, next_just_finished)

    # we need to flatten the feedables for the parent_decoder
    next_feedables = tf.contrib.framework.nest.map_structure(
        lambda x: gather_flat(x, batch_beam_ids,
                              self.batch_size, self.beam_size),
        dec_loop_state.feedables)

    next_feedables = next_feedables._replace(
        embedded_input=self.parent_decoder.embed_input_symbols(
            tf.reshape(next_word_ids, [-1])),
        finished=tf.reshape(next_finished, [-1]))

    # histories have shape [len, batch, ...]
    def gather_fn(x):
        if len(x.shape.dims) < 2:
            return x

        return partial_transpose(
            gather_flat(
                partial_transpose(x, [1, 0]),
                batch_beam_ids,
                self.batch_size,
                self.beam_size),
            [1, 0])

    next_histories = tf.contrib.framework.nest.map_structure(
        gather_fn, dec_loop_state.histories)

    dec_loop_state = dec_loop_state._replace(
        feedables=next_feedables,
        histories=next_histories)

    # CALL THE DECODER BODY FUNCTION
    next_loop_state = decoder_body(*dec_loop_state)

    logits = next_loop_state.histories.logits[-1, :, :]
    next_search_state = SearchState(
        logprob_sum=next_beam_logprob_sum,
        prev_logprobs=tf.reshape(
            tf.nn.log_softmax(logits),
            [self.batch_size, self.beam_size, len(self.vocabulary)]),
        lengths=next_beam_lengths,
        finished=next_finished)

    next_token_ids = tf.transpose(search_results.token_ids, [1, 2, 0])
    next_token_ids = tf.gather_nd(next_token_ids, batch_beam_ids)
    next_token_ids = tf.transpose(next_token_ids, [2, 0, 1])

    next_output = SearchResults(
        scores=topk_scores,
        token_ids=append_tensor(next_token_ids, next_word_ids))

    return BeamSearchLoopState(
        search_state=next_search_state,
        search_results=next_output,
        decoder_loop_state=next_loop_state)
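
# ``self._length_penalty`` is defined outside this section. One common
# choice compatible with the division ``hyp_probs / penalty`` performed
# above is the length normalization of Wu et al. (2016). The sketch
# below is an assumed implementation with a hypothetical ``alpha``
# hyperparameter, not necessarily the one the decoder actually uses.
def _length_penalty_sketch(lengths, alpha=0.6):
    # lp(Y) = ((5 + |Y|) / 6) ** alpha; alpha = 0 disables the penalty
    return ((5. + tf.to_float(lengths)) / 6.) ** alpha
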