Example #3
        def body(*args: Any) -> BeamSearchLoopState:
            """Execute a single beam search step.

            An implementation of the beam search algorithm, which works as
            follows:

            1. Create a valid ``logprobs`` tensor which contains distributions
               over the output tokens for each hypothesis in the beam. For
               finished hypotheses, the log probabilities of all tokens except
               the padding token are set to negative infinity.

            2. Expand the beam by appending every possible token to every
               existing hypothesis. Update the log probability sum of each
               hypothesis and its length (add one for unfinished hypotheses).
               For each hypothesis, compute the score using the length penalty
               term.

            3. Select the ``beam_size`` best hypotheses from the score pool.
               This is implemented by flattening the scores tensor and using
               the ``tf.nn.top_k`` function.

            4. Reconstruct the beam by gathering elements from the original
               data structures using the data indices computed in the previous
               step.

            5. Call the ``body`` function of the underlying decoder.

            6. Populate a new ``BeamSearchLoopState`` object with the selected
               values and with the newly obtained decoder loop state.

            Note that this function expects the decoder to be called at least
            once prior to the first execution.

            Arguments:
                args: An instance of the ``BeamSearchLoopState`` structure.
                    (see the docs for this module)

            Returns:
                A ``BeamSearchLoopState`` after one step of the decoding.

            """
            loop_state = BeamSearchLoopState(*args)
            dec_loop_state = loop_state.decoder_loop_state
            search_state = loop_state.search_state
            search_results = loop_state.search_results

            # mask the probabilities
            # >> shape(logprobs) = [batch, beam, vocabulary]
            logprobs = search_state.prev_logprobs

            # >> shape(finished_mask) = [batch, beam, 1]
            # float, 0 or 1; 1 for finished hypotheses, 0 otherwise
            finished_mask = tf.expand_dims(tf.to_float(search_state.finished),
                                           2)
            # >> shape(unfinished_logprobs) = [batch, beam, vocabulary]
            # finished hypotheses get logprob 0 for every vocabulary token
            unfinished_logprobs = (1. - finished_mask) * logprobs

            # >> shape(finished_row) = [vocabulary]
            # -INF everywhere except at the <PAD> index, which holds 0
            finished_row = tf.one_hot(PAD_TOKEN_INDEX,
                                      len(self.vocabulary),
                                      dtype=tf.float32,
                                      on_value=0.,
                                      off_value=-INF)

            # >> shape(finished_logprobs) = [batch, beam, vocabulary]
            # unfinished hypotheses have 0 over the whole vocabulary;
            # finished hypotheses have -INF logprob for every vocabulary
            # token except <PAD>, which has 0, so that a finished hypothesis
            # keeps padding if there is nothing better.
            finished_logprobs = finished_mask * finished_row
            # >> shape(logprobs) = [batch, beam, vocabulary]
            # token log-probabilities for the hypotheses; finished
            # hypotheses have -INF everywhere except <PAD>, which has 0
            logprobs = unfinished_logprobs + finished_logprobs
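            # worked example (toy numbers): with <PAD> at index 0, a
            # finished hypothesis with logprobs [-0.1, -2.3, -4.0] is masked
            # to [0., -INF, -INF], so its only viable continuation is <PAD>;
            # unfinished hypotheses keep their logprobs unchanged.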

            # update hypothesis scores
            # >> shape(hyp_probs) = [batch, beam, vocabulary]
            #
            # hyp_probs holds the score of the whole hypothesis after
            # appending the given token
            hyp_probs = tf.expand_dims(search_state.logprob_sum, 2) + logprobs

            # update hypothesis lengths
            #
            # >> shape(hyp_lengths) = [batch, beam]
            # increase the length of unfinished hypotheses by 1
            hyp_lengths = search_state.lengths + 1 - tf.to_int32(
                search_state.finished)

            # >> shape(scores) = [batch, beam, vocabulary]
            #
            # apply the length penalty to the hypothesis scores;
            # scores now holds the new hypothesis scores
            scores = hyp_probs / tf.expand_dims(
                self._length_penalty(hyp_lengths), 2)
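            # for illustration only: if self._length_penalty were the
            # GNMT-style penalty lp(l) = ((5 + l) / 6) ** alpha (an
            # assumption; its actual definition lives elsewhere in the
            # class), a logprob sum of -6.0 at length 10 would score
            # -6.0 / lp(10) rather than the raw -6.0, keeping longer
            # hypotheses competitive.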

            # reshape to [batch, beam * vocabulary] for topk
            # >> shape(scores_flat) = [batch, beam * vocabulary]
            scores_flat = tf.reshape(
                scores, [-1, self.beam_size * len(self.vocabulary)])

            # >> shape(both) = [batch, beam]
            #
            # topk_scores holds the top `beam_size` hypothesis scores,
            # topk_indices holds their indices within scores_flat
            topk_scores, topk_indices = tf.nn.top_k(scores_flat,
                                                    k=self.beam_size)

            topk_indices.set_shape([None, self.beam_size])
            topk_scores.set_shape([None, self.beam_size])

            # >> shape(next_word_ids) = [batch, beam]
            # next_word_ids holds the vocabulary indices of the new tokens
            #
            # >> shape(next_beam_ids) = [batch, beam]
            # next_beam_ids holds the indices of the beams from which the
            # new hypotheses are derived; these become the parent_ids
            next_word_ids = tf.mod(topk_indices, len(self.vocabulary))
            next_beam_ids = tf.div(topk_indices, len(self.vocabulary))
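            # worked example (toy numbers): with a 5-token vocabulary, flat
            # index 7 in scores_flat decomposes into word id 7 mod 5 = 2
            # and parent beam id 7 div 5 = 1.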

            parent_ids = next_beam_ids

            # batch offset for tf.gather_nd
            batch_offset = tf.tile(
                tf.expand_dims(tf.range(self.batch_size), 1),
                [1, self.beam_size])
            batch_beam_ids = tf.stack([batch_offset, next_beam_ids], axis=2)
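            # batch_beam_ids[b, k] == [b, next_beam_ids[b, k]]: index pairs
            # into the [batch, beam] axes, letting tf.gather_nd pull each
            # selected hypothesis from its parent beam row.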

            # gather the topk logprob_sums
            #
            # >> shape(next_beam_lengths) = [batch, beam]
            # next_beam_lengths holds the lengths of the new beam_size
            # hypotheses
            next_beam_lengths = tf.gather_nd(hyp_lengths, batch_beam_ids)

            # >> shape(next_beam_logprob_sum) = [batch, beam]
            # next_beam_logprob_sum holds the scores of the new hypotheses
            next_beam_logprob_sum = tf.gather_nd(
                tf.reshape(hyp_probs,
                           [-1, self.beam_size * len(self.vocabulary)]),
                tf.stack([batch_offset, topk_indices], axis=2))
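            # note that the logprob sums are gathered with topk_indices,
            # which address the flattened beam * vocabulary axis, rather
            # than with batch_beam_ids: the chosen token's logprob is
            # already included in hyp_probs.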

            # mark finished beams
            next_finished = tf.gather_nd(search_state.finished, batch_beam_ids)
            next_just_finished = tf.equal(next_word_ids, END_TOKEN_INDEX)
            next_finished = tf.logical_or(next_finished, next_just_finished)

            # we need to flatten the feedables for the parent_decoder
            next_feedables = tf.contrib.framework.nest.map_structure(
                lambda x: gather_flat(x, batch_beam_ids,
                                      self.batch_size, self.beam_size),
                dec_loop_state.feedables)

            next_feedables = next_feedables._replace(
                input_symbol=tf.reshape(next_word_ids, [-1]),
                finished=tf.reshape(next_finished, [-1]))

            # histories have shape [len, batch, ...]
            def gather_fn(x):
                return partial_transpose(
                    gather_flat(partial_transpose(x, [1, 0]), batch_beam_ids,
                                self.batch_size, self.beam_size), [1, 0])
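            # histories are time-major ([len, batch, ...]) while gather_flat
            # indexes the batch axis, hence the swap of the first two axes
            # before and after the gather.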

            next_histories = tf.contrib.framework.nest.map_structure(
                gather_fn, dec_loop_state.histories)

            dec_loop_state = dec_loop_state._replace(feedables=next_feedables,
                                                     histories=next_histories)

            # CALL THE DECODER BODY FUNCTION
            next_loop_state = decoder_body(*dec_loop_state)

            next_search_state = SearchState(
                logprob_sum=next_beam_logprob_sum,
                prev_logprobs=tf.reshape(
                    tf.nn.log_softmax(next_loop_state.feedables.prev_logits),
                    [self.batch_size, self.beam_size,
                     len(self.vocabulary)]),
                lengths=next_beam_lengths,
                finished=next_finished)

            # next_token_ids = tf.transpose(search_results.token_ids, [1, 2, 0])
            # next_token_ids = tf.gather_nd(next_token_ids, batch_beam_ids)
            # next_token_ids = tf.transpose(next_token_ids, [2, 0, 1])
            # commented out, because we want the original tokens
            next_output = SearchResults(
                scores=append_tensor(search_results.scores, topk_scores),
                #token_ids=append_tensor(next_token_ids, next_word_ids),
                token_ids=append_tensor(search_results.token_ids,
                                        next_word_ids),
                parent_ids=append_tensor(search_results.parent_ids,
                                         parent_ids))

            return BeamSearchLoopState(search_state=next_search_state,
                                       search_results=next_output,
                                       decoder_loop_state=next_loop_state)
Example #4
        def body(*args: Any) -> BeamSearchLoopState:
            """Execute a single beam search step.

            An implementation of the beam search algorithm, which works as
            follows:

            1. Create a valid ``logprobs`` tensor which contains distributions
               over the output tokens for each hypothesis in the beam. For
               finished hypotheses, the log probabilities of all tokens except
               the padding token are set to negative infinity.

            2. Expand the beam by appending every possible token to every
               existing hypothesis. Update the log probability sum of each
               hypothesis and its length (add one for unfinished hypotheses).
               For each hypothesis, compute the score using the length penalty
               term.

            3. Select the ``beam_size`` best hypotheses from the score pool.
               This is implemented by flattening the scores tensor and using
               the ``tf.nn.top_k`` function.

            4. Reconstruct the beam by gathering elements from the original
               data structures using the data indices computed in the previous
               step.

            5. Call the ``body`` function of the underlying decoder.

            6. Populate a new ``BeamSearchLoopState`` object with the selected
               values and with the newly obtained decoder loop state.

            Note that this function expects the decoder to be called at least
            once prior to the first execution.

            Arguments:
                args: An instance of the ``BeamSearchLoopState`` structure.
                    (see the docs for this module)

            Returns:
                A ``BeamSearchLoopState`` after one step of the decoding.

            """
            loop_state = BeamSearchLoopState(*args)
            dec_loop_state = loop_state.decoder_loop_state
            search_state = loop_state.search_state
            search_results = loop_state.search_results

            # mask the probabilities
            # shape(logprobs) = [batch, beam, vocabulary]
            logprobs = search_state.prev_logprobs

            finished_mask = tf.expand_dims(
                tf.to_float(search_state.finished), 2)
            unfinished_logprobs = (1. - finished_mask) * logprobs

            finished_row = tf.one_hot(
                PAD_TOKEN_INDEX,
                len(self.vocabulary),
                dtype=tf.float32,
                on_value=0.,
                off_value=-INF)

            finished_logprobs = finished_mask * finished_row
            logprobs = unfinished_logprobs + finished_logprobs

            # update hypothesis scores
            # shape(hyp_probs) = [batch, beam, vocabulary]
            hyp_probs = tf.expand_dims(search_state.logprob_sum, 2) + logprobs

            # update hypothesis lengths
            hyp_lengths = search_state.lengths + 1 - tf.to_int32(
                search_state.finished)

            # shape(scores) = [batch, beam, vocabulary]
            scores = hyp_probs / tf.expand_dims(
                self._length_penalty(hyp_lengths), 2)

            # reshape to [batch, beam * vocabulary] for topk
            scores_flat = tf.reshape(
                scores, [-1, self.beam_size * len(self.vocabulary)])

            # shape(both) = [batch, beam]
            topk_scores, topk_indices = tf.nn.top_k(
                scores_flat, k=self.beam_size)

            topk_indices.set_shape([None, self.beam_size])
            topk_scores.set_shape([None, self.beam_size])

            next_word_ids = tf.mod(topk_indices, len(self.vocabulary))
            next_beam_ids = tf.div(topk_indices, len(self.vocabulary))

            # batch offset for tf.gather_nd
            batch_offset = tf.tile(
                tf.expand_dims(tf.range(self.batch_size), 1),
                [1, self.beam_size])
            batch_beam_ids = tf.stack([batch_offset, next_beam_ids], axis=2)

            # gather the topk logprob_sums
            next_beam_lengths = tf.gather_nd(hyp_lengths, batch_beam_ids)
            next_beam_logprob_sum = tf.gather_nd(
                tf.reshape(
                    hyp_probs, [-1, self.beam_size * len(self.vocabulary)]),
                tf.stack([batch_offset, topk_indices], axis=2))

            # mark finished beams
            next_finished = tf.gather_nd(search_state.finished, batch_beam_ids)
            next_just_finished = tf.equal(next_word_ids, END_TOKEN_INDEX)
            next_finished = tf.logical_or(next_finished, next_just_finished)

            # we need to flatten the feedables for the parent_decoder
            next_feedables = tf.contrib.framework.nest.map_structure(
                lambda x: gather_flat(x, batch_beam_ids,
                                      self.batch_size, self.beam_size),
                dec_loop_state.feedables)

            next_feedables = next_feedables._replace(
                input_symbol=tf.reshape(next_word_ids, [-1]),
                finished=tf.reshape(next_finished, [-1]))

            # histories have shape [len, batch, ...]
            def gather_fn(x):
                return partial_transpose(
                    gather_flat(
                        partial_transpose(x, [1, 0]),
                        batch_beam_ids,
                        self.batch_size,
                        self.beam_size),
                    [1, 0])

            next_histories = tf.contrib.framework.nest.map_structure(
                gather_fn, dec_loop_state.histories)

            dec_loop_state = dec_loop_state._replace(
                feedables=next_feedables,
                histories=next_histories)

            # CALL THE DECODER BODY FUNCTION
            next_loop_state = decoder_body(*dec_loop_state)

            next_search_state = SearchState(
                logprob_sum=next_beam_logprob_sum,
                prev_logprobs=tf.reshape(
                    tf.nn.log_softmax(next_loop_state.feedables.prev_logits),
                    [self.batch_size, self.beam_size, len(self.vocabulary)]),
                lengths=next_beam_lengths,
                finished=next_finished)

            next_token_ids = tf.transpose(search_results.token_ids, [1, 2, 0])
            next_token_ids = tf.gather_nd(next_token_ids, batch_beam_ids)
            next_token_ids = tf.transpose(next_token_ids, [2, 0, 1])
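            # the previously emitted tokens are reordered by parent beam as
            # well: token_ids ([len, batch, beam]) is moved to
            # [batch, beam, len], gathered with batch_beam_ids, and moved
            # back before the new words are appended.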
            next_output = SearchResults(
                scores=topk_scores,
                token_ids=append_tensor(next_token_ids, next_word_ids))

            return BeamSearchLoopState(
                search_state=next_search_state,
                search_results=next_output,
                decoder_loop_state=next_loop_state)
Example #5
        def body(*args: Any) -> BeamSearchLoopState:
            """Execute a single beam search step.

            An implementation of the beam search algorithm, which works as
            follows:

            1. Create a valid ``logprobs`` tensor which contains distributions
               over the output tokens for each hypothesis in the beam. For
               finished hypotheses, the log probabilities of all tokens except
               the padding token are set to negative infinity.

            2. Expand the beam by appending every possible token to every
               existing hypothesis. Update the log probability sum of each
               hypothesis and its length (add one for unfinished hypotheses).
               For each hypothesis, compute the score using the length penalty
               term.

            3. Select the ``beam_size`` best hypotheses from the score pool.
               This is implemented by flattening the scores tensor and using
               the ``tf.nn.top_k`` function.

            4. Reconstruct the beam by gathering elements from the original
               data structures using the data indices computed in the previous
               step.

            5. Call the ``body`` function of the underlying decoder.

            6. Populate a new ``BeamSearchLoopState`` object with the selected
               values and with the newly obtained decoder loop state.

            Note that this function expects the decoder to be called at least
            once prior to the first execution.

            Arguments:
                args: An instance of the ``BeamSearchLoopState`` structure.
                    (see the docs for this module)

            Returns:
                A ``BeamSearchLoopState`` after one step of the decoding.

            """
            loop_state = BeamSearchLoopState(*args)
            dec_loop_state = loop_state.decoder_loop_state
            search_state = loop_state.search_state
            search_results = loop_state.search_results

            # mask the probabilities
            # shape(logprobs) = [batch, beam, vocabulary]
            logprobs = search_state.prev_logprobs

            finished_mask = tf.expand_dims(
                tf.to_float(search_state.finished), 2)
            unfinished_logprobs = (1. - finished_mask) * logprobs

            finished_row = tf.one_hot(
                PAD_TOKEN_INDEX,
                len(self.vocabulary),
                dtype=tf.float32,
                on_value=0.,
                off_value=-INF)

            finished_logprobs = finished_mask * finished_row
            logprobs = unfinished_logprobs + finished_logprobs

            # update hypothesis scores
            # shape(hyp_probs) = [batch, beam, vocabulary]
            hyp_probs = tf.expand_dims(search_state.logprob_sum, 2) + logprobs

            # update hypothesis lengths
            hyp_lengths = search_state.lengths + 1 - tf.to_int32(
                search_state.finished)

            # shape(scores) = [batch, beam, vocabulary]
            scores = hyp_probs / tf.expand_dims(
                self._length_penalty(hyp_lengths), 2)

            # reshape to [batch, beam * vocabulary] for topk
            scores_flat = tf.reshape(
                scores, [-1, self.beam_size * len(self.vocabulary)])

            # shape(both) = [batch, beam]
            topk_scores, topk_indices = tf.nn.top_k(
                scores_flat, k=self.beam_size)

            topk_indices.set_shape([None, self.beam_size])
            topk_scores.set_shape([None, self.beam_size])

            next_word_ids = tf.to_int64(
                tf.mod(topk_indices, len(self.vocabulary)))
            next_beam_ids = tf.div(topk_indices, len(self.vocabulary))

            # batch offset for tf.gather_nd
            batch_offset = tf.tile(
                tf.expand_dims(tf.range(self.batch_size), 1),
                [1, self.beam_size])
            batch_beam_ids = tf.stack([batch_offset, next_beam_ids], axis=2)

            # gather the topk logprob_sums
            next_beam_lengths = tf.gather_nd(hyp_lengths, batch_beam_ids)
            next_beam_logprob_sum = tf.gather_nd(
                tf.reshape(
                    hyp_probs, [-1, self.beam_size * len(self.vocabulary)]),
                tf.stack([batch_offset, topk_indices], axis=2))

            # mark finished beams
            next_finished = tf.gather_nd(search_state.finished, batch_beam_ids)
            next_just_finished = tf.equal(next_word_ids, END_TOKEN_INDEX)
            next_finished = tf.logical_or(next_finished, next_just_finished)

            # we need to flatten the feedables for the parent_decoder
            next_feedables = tf.contrib.framework.nest.map_structure(
                lambda x: gather_flat(x, batch_beam_ids,
                                      self.batch_size, self.beam_size),
                dec_loop_state.feedables)

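            # unlike the earlier examples, which feed raw symbol ids through
            # input_symbol, this variant feeds embedded word vectors.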
            next_feedables = next_feedables._replace(
                embedded_input=self.parent_decoder.embed_input_symbols(
                    tf.reshape(next_word_ids, [-1])),
                finished=tf.reshape(next_finished, [-1]))

            # histories have shape [len, batch, ...]
            def gather_fn(x):
                if len(x.shape.dims) < 2:
                    return x

                return partial_transpose(
                    gather_flat(
                        partial_transpose(x, [1, 0]),
                        batch_beam_ids,
                        self.batch_size,
                        self.beam_size),
                    [1, 0])
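            # histories with rank < 2 lack the [len, batch] leading axes the
            # transpose-and-gather assumes, so they are returned unchanged.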

            next_histories = tf.contrib.framework.nest.map_structure(
                gather_fn, dec_loop_state.histories)

            dec_loop_state = dec_loop_state._replace(
                feedables=next_feedables,
                histories=next_histories)

            # CALL THE DECODER BODY FUNCTION
            next_loop_state = decoder_body(*dec_loop_state)

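            # this variant reads the scoring distribution from the last
            # step of the logits history instead of feedables.prev_logits.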
            logits = next_loop_state.histories.logits[-1, :, :]
            next_search_state = SearchState(
                logprob_sum=next_beam_logprob_sum,
                prev_logprobs=tf.reshape(
                    tf.nn.log_softmax(logits),
                    [self.batch_size, self.beam_size, len(self.vocabulary)]),
                lengths=next_beam_lengths,
                finished=next_finished)

            next_token_ids = tf.transpose(search_results.token_ids, [1, 2, 0])
            next_token_ids = tf.gather_nd(next_token_ids, batch_beam_ids)
            next_token_ids = tf.transpose(next_token_ids, [2, 0, 1])
            next_output = SearchResults(
                scores=topk_scores,
                token_ids=append_tensor(next_token_ids, next_word_ids))

            return BeamSearchLoopState(
                search_state=next_search_state,
                search_results=next_output,
                decoder_loop_state=next_loop_state)
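
All three examples select the next beam with the same flatten-and-top_k trick from step 3 of the docstring: the [batch, beam, vocabulary] scores are reshaped to [batch, beam * vocabulary] so that a single top-k ranges over every expansion of every hypothesis, and the flat indices are then decomposed back into (parent beam, next word) pairs. Below is a minimal NumPy sketch of just that selection step, with toy numbers and without the length penalty or finished-hypothesis masking; the names mirror the examples above, but the snippet is illustrative, not part of the original code.

import numpy as np

# toy dimensions
batch_size, beam_size, vocab_size = 2, 2, 3
rng = np.random.default_rng(0)

# hypothetical per-hypothesis token scores, shape [batch, beam, vocabulary]
scores = np.log(rng.dirichlet(np.ones(vocab_size),
                              size=(batch_size, beam_size)))

# flatten the beam and vocabulary axes so one top-k spans the whole beam
scores_flat = scores.reshape(batch_size, beam_size * vocab_size)

# pick the beam_size best expansions per batch entry
topk_indices = np.argsort(-scores_flat, axis=1)[:, :beam_size]
topk_scores = np.take_along_axis(scores_flat, topk_indices, axis=1)

# decompose flat indices into (parent beam, next word) pairs,
# mirroring tf.div / tf.mod in the examples above
next_beam_ids = topk_indices // vocab_size
next_word_ids = topk_indices % vocab_size

print(topk_scores)    # [batch, beam] best scores
print(next_beam_ids)  # which hypothesis each new one extends
print(next_word_ids)  # which token it appends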