Example #1
        def body(*args) -> LoopState:

            loop_state = LoopState(*args)
            feedables = loop_state.feedables
            histories = loop_state.histories

            with tf.variable_scope(self._variable_scope, reuse=tf.AUTO_REUSE):
                output_state, dec_other, hist_other = self.next_state(
                    loop_state)

                logits = state_to_logits(output_state)
                logits /= temperature

                next_symbols = logits_to_symbols(logits, loop_state)
                finished = is_finished(feedables.finished, next_symbols)

            next_feedables = DecoderFeedables(
                step=feedables.step + 1,
                finished=finished,
                embedded_input=self.embed_input_symbols(next_symbols),
                other=dec_other)

            next_histories = DecoderHistories(
                logits=append_tensor(histories.logits, logits),
                output_states=append_tensor(histories.output_states,
                                            output_state),
                output_symbols=append_tensor(histories.output_symbols,
                                             next_symbols),
                output_mask=append_tensor(histories.output_mask,
                                          tf.logical_not(finished)),
                other=hist_other)

            return LoopState(feedables=next_feedables,
                             histories=next_histories,
                             constants=loop_state.constants)
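
A note on wiring: a ``body`` function like the one above is meant to be driven
by ``tf.while_loop``. Below is a minimal sketch of that wiring, not the
framework's actual decoding loop; ``initial_loop_state``, ``max_steps`` and
``get_state_shape_invariants`` are assumptions here.

import tensorflow as tf

def cond(*args) -> tf.Tensor:
    loop_state = LoopState(*args)
    # Loop until every sequence in the batch has finished,
    # or until the step limit is reached.
    return tf.logical_and(
        tf.less(loop_state.feedables.step, max_steps),
        tf.logical_not(tf.reduce_all(loop_state.feedables.finished)))

# The history tensors grow along the time axis, so the shape invariants
# must leave that dimension unspecified (the helper name is an assumption).
final_loop_state = LoopState(*tf.while_loop(
    cond, body, initial_loop_state,
    shape_invariants=tf.contrib.framework.nest.map_structure(
        get_state_shape_invariants, initial_loop_state)))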
Example #2
    def next_state(self, loop_state: LoopState) -> Tuple[tf.Tensor, Any, Any]:
        feedables = loop_state.feedables
        tr_feedables = feedables.other
        tr_histories = loop_state.histories.other

        with tf.variable_scope(self._variable_scope, reuse=tf.AUTO_REUSE):
            # shape (time, batch)
            input_sequence = append_tensor(
                tr_feedables.input_sequence, feedables.embedded_input, 1)

            unfinished_mask = tf.to_float(tf.logical_not(feedables.finished))
            input_mask = append_tensor(
                tr_feedables.input_mask,
                tf.expand_dims(unfinished_mask, -1),
                axis=1)

            last_layer = self.layer(
                self.depth, input_sequence, tf.squeeze(input_mask, -1))

            # (batch, state_size)
            output_state = last_layer.temporal_states[:, -1, :]

        new_feedables = TransformerFeedables(
            input_sequence=input_sequence,
            input_mask=input_mask)

        # TODO: do something more interesting here
        new_histories = tr_histories

        return (output_state, new_feedables, new_histories)
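
The ``append_tensor`` helper used throughout these examples is assumed to
expand the appended value with a new axis and concatenate it onto the
accumulator along that axis (the time axis, 0, by default; this example
passes axis 1). A plausible sketch:

import tensorflow as tf

def append_tensor(tensor: tf.Tensor, appendval: tf.Tensor,
                  axis: int = 0) -> tf.Tensor:
    """Append a single element to a tensor along the given axis."""
    return tf.concat([tensor, tf.expand_dims(appendval, axis)], axis)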
Example #3
        def body(*args) -> LoopState:

            loop_state = LoopState(*args)
            feedables = loop_state.feedables
            histories = loop_state.histories

            with tf.variable_scope(self._variable_scope, reuse=tf.AUTO_REUSE):
                output_state, dec_other, hist_other = self.next_state(
                    loop_state)

                logits = state_to_logits(output_state)
                logits /= temperature

                next_symbols = logits_to_symbols(logits, loop_state)
                finished = is_finished(feedables.finished, next_symbols)

            next_feedables = DecoderFeedables(
                step=feedables.step + 1,
                finished=finished,
                embedded_input=self.embed_input_symbols(next_symbols),
                other=dec_other)

            next_histories = DecoderHistories(
                logits=append_tensor(histories.logits, logits),
                output_states=append_tensor(
                    histories.output_states, output_state),
                output_symbols=append_tensor(
                    histories.output_symbols, next_symbols),
                output_mask=append_tensor(
                    histories.output_mask, tf.logical_not(finished)),
                other=hist_other)

            return LoopState(
                feedables=next_feedables,
                histories=next_histories,
                constants=loop_state.constants)
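
For orientation, a sketch of the generic decoder structures these examples
populate; the field layout below is inferred from the calls above and is not
the authoritative definition:

from typing import Any, NamedTuple
import tensorflow as tf

class DecoderFeedables(NamedTuple):
    """Data fed to a single decoding step (inferred layout)."""
    step: tf.Tensor            # scalar int32, current decoding step
    finished: tf.Tensor        # (batch,) bool, per-sequence finished flags
    embedded_input: tf.Tensor  # (batch, embedding_size), last output embedded
    other: Any                 # decoder-specific feedables

class DecoderHistories(NamedTuple):
    """Tensors accumulated across decoding steps (inferred layout)."""
    logits: tf.Tensor          # (time, batch, vocabulary)
    output_states: tf.Tensor   # (time, batch, state_size)
    output_symbols: tf.Tensor  # (time, batch), int32
    output_mask: tf.Tensor     # (time, batch), bool; True for live positions
    other: Any                 # decoder-specific histories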
Example #4
        def body(*args) -> LoopState:
            loop_state = LoopState(*args)
            step = loop_state.feedables.step

            with tf.variable_scope(self.step_scope):
                # Compute the input to the RNN
                rnn_input = self.input_projection(*loop_state)

                # Run the RNN.
                cell = self._get_rnn_cell()
                if self._rnn_cell_str in ["GRU", "NematusGRU"]:
                    cell_output, next_state = cell(
                        rnn_input, loop_state.feedables.prev_rnn_output)

                    attns = [
                        a.attention(cell_output,
                                    loop_state.feedables.prev_rnn_output,
                                    rnn_input, att_loop_state)
                        for a, att_loop_state in zip(
                            self.attentions,
                            loop_state.histories.attention_histories)
                    ]
                    if self.attentions:
                        contexts, att_loop_states = zip(*attns)
                    else:
                        contexts, att_loop_states = [], []

                    if self._conditional_gru:
                        cell_cond = self._get_conditional_gru_cell()
                        cond_input = tf.concat(contexts, -1)
                        cell_output, next_state = cell_cond(
                            cond_input, next_state, scope="cond_gru_2_cell")

                elif self._rnn_cell_str == "LSTM":
                    prev_state = tf.contrib.rnn.LSTMStateTuple(
                        loop_state.feedables.prev_rnn_state,
                        loop_state.feedables.prev_rnn_output)
                    cell_output, state = cell(rnn_input, prev_state)
                    next_state = state.c
                    attns = [
                        a.attention(cell_output,
                                    loop_state.feedables.prev_rnn_output,
                                    rnn_input, att_loop_state)
                        for a, att_loop_state in zip(
                            self.attentions,
                            loop_state.histories.attention_histories)
                    ]
                    if self.attentions:
                        contexts, att_loop_states = zip(*attns)
                    else:
                        contexts, att_loop_states = [], []
                else:
                    raise ValueError("Unknown RNN cell.")

                with tf.name_scope("rnn_output_projection"):
                    embedded_input = tf.nn.embedding_lookup(
                        self.embedding_matrix,
                        loop_state.feedables.input_symbol)

                    output = self.output_projection(cell_output,
                                                    embedded_input,
                                                    list(contexts),
                                                    self.train_mode)

                logits = self.get_logits(output) / temperature

            self.step_scope.reuse_variables()

            if sample:
                next_symbols = tf.to_int32(
                    tf.squeeze(tf.multinomial(logits, num_samples=1), axis=1))
            elif train_mode:
                next_symbols = loop_state.constants.train_inputs[step]
            else:
                next_symbols = tf.to_int32(tf.argmax(logits, axis=1))
                int_unfinished_mask = tf.to_int32(
                    tf.logical_not(loop_state.feedables.finished))

                # Note: this works only when PAD_TOKEN_INDEX is 0. Otherwise,
                # this has to be rewritten.
                assert PAD_TOKEN_INDEX == 0
                next_symbols = next_symbols * int_unfinished_mask

            has_just_finished = tf.equal(next_symbols, END_TOKEN_INDEX)
            has_finished = tf.logical_or(loop_state.feedables.finished,
                                         has_just_finished)
            not_finished = tf.logical_not(has_finished)

            # pylint: disable=not-callable
            new_feedables = RNNFeedables(step=step + 1,
                                         finished=has_finished,
                                         input_symbol=next_symbols,
                                         prev_logits=logits,
                                         prev_rnn_state=next_state,
                                         prev_rnn_output=cell_output,
                                         prev_contexts=list(contexts))

            new_histories = RNNHistories(
                attention_histories=list(att_loop_states),
                logits=append_tensor(loop_state.histories.logits, logits),
                decoder_outputs=append_tensor(
                    loop_state.histories.decoder_outputs, cell_output),
                outputs=append_tensor(loop_state.histories.outputs,
                                      next_symbols),
                mask=append_tensor(loop_state.histories.mask, not_finished))
            # pylint: enable=not-callable

            new_loop_state = LoopState(histories=new_histories,
                                       constants=loop_state.constants,
                                       feedables=new_feedables)

            return new_loop_state
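
The greedy branch above forces already-finished sequences to emit the padding
token by multiplying the argmax picks with an integer mask, which only works
because PAD_TOKEN_INDEX is 0. A tiny self-contained illustration of the trick:

import tensorflow as tf

PAD_TOKEN_INDEX = 0  # the multiplication trick requires this value

argmax_symbols = tf.constant([4, 7, 2])       # greedy picks per sequence
finished = tf.constant([False, True, False])  # already-finished sequences

int_unfinished_mask = tf.to_int32(tf.logical_not(finished))
# Finished positions are zeroed out, i.e. forced to <PAD>:
next_symbols = argmax_symbols * int_unfinished_mask

with tf.Session() as sess:
    print(sess.run(next_symbols))  # [4 0 2]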
Example #5
        def body(*args: Any) -> BeamSearchLoopState:
            """Execute a single beam search step.

            An implementation of the beam search algorithm, which works as
            follows:

            1. Create a valid ``logprobs`` tensor which contains distributions
               over the output tokens for each hypothesis in the beam. For
               finished hypotheses, the log probabilities of all tokens except
               the padding token are set to negative infinity.

            2. Expand the beam by appending every possible token to every
               existing hypothesis. Update the log probability sum of each
               hypothesis and its length (add one for unfinished hypotheses).
               For each hypothesis, compute the score using the length penalty
               term.

            3. Select the ``beam_size`` best hypotheses from the score pool.
               This is implemented by flattening the scores tensor and using
               the ``tf.nn.top_k`` function.

            4. Reconstruct the beam by gathering elements from the original
               data structures using the data indices computed in the previous
               step.

            5. Call the ``body`` function of the underlying decoder.

            6. Populate a new ``BeamSearchLoopState`` object with the selected
               values and with the newly obtained decoder loop state.

            Note that this function expects the decoder to be called at least
            once prior to the first execution.

            Arguments:
                args: An instance of the ``BeamSearchLoopState`` structure.
                    (see the docs for this module)

            Returns:
                A ``BeamSearchLoopState`` after one step of the decoding.

            """
            loop_state = BeamSearchLoopState(*args)
            dec_loop_state = loop_state.decoder_loop_state
            search_state = loop_state.search_state
            search_results = loop_state.search_results

            # mask the probabilities
            # >> shape(logprobs) = [batch, beam, vocabulary]
            logprobs = search_state.prev_logprobs

            # >> shape(finished_mask) = [batch, beam, 1]
            # float, 0 or 1; 1 for finished hypotheses, 0 otherwise
            finished_mask = tf.expand_dims(tf.to_float(search_state.finished),
                                           2)
            # >> shape(unfinished_logprobs) = [batch, beam, vocabulary]
            # finished hypotheses have logprob 0 for all vocabulary tokens
            unfinished_logprobs = (1. - finished_mask) * logprobs

            # >> shape(finished_row) = [vocabulary]
            # -INF everywhere, except at the <PAD> index, which is 0
            finished_row = tf.one_hot(PAD_TOKEN_INDEX,
                                      len(self.vocabulary),
                                      dtype=tf.float32,
                                      on_value=0.,
                                      off_value=-INF)

            # >> shape(finished_logprobs) = [batch, beam, vocabulary]
            # unfinished hypotheses have 0 over the whole vocabulary
            # finished hypotheses have -INF logprob for every vocabulary token
            # except <PAD>, which gets 0; this ensures padding when nothing
            # better is available.
            finished_logprobs = finished_mask * finished_row
            # >> shape(logprobs) = [batch, beam, vocabulary]
            # token probabilities for the hypotheses; finished hypotheses
            # have -INF everywhere except <PAD>, which has 0
            logprobs = unfinished_logprobs + finished_logprobs

            # update hypothesis scores
            # >> shape(hyp_probs) = [batch, beam, vocabulary]
            #
            # hyp_probs holds the score of the whole hypothesis
            # after adding the given token
            hyp_probs = tf.expand_dims(search_state.logprob_sum, 2) + logprobs

            # update hypothesis lengths
            #
            # >> shape(hyp_lengths) = [batch, beam]
            # increase the length of unfinished hypotheses by 1
            hyp_lengths = search_state.lengths + 1 - tf.to_int32(
                search_state.finished)

            # >> shape(scores) = [batch, beam, vocabulary]
            #
            # apply the length penalty to the hypothesis scores
            # scores now holds the new hypothesis scores
            scores = hyp_probs / tf.expand_dims(
                self._length_penalty(hyp_lengths), 2)

            # reshape to [batch, beam * vocabulary] for topk
            # >> shape(scores_flat) = [batch, beam * vocabulary]
            scores_flat = tf.reshape(
                scores, [-1, self.beam_size * len(self.vocabulary)])

            # >> shape(both) = [batch, beam]
            #
            # topk_scores contains the top `beam_size` hypothesis scores
            # topk_indices contains their indices in scores_flat
            topk_scores, topk_indices = tf.nn.top_k(scores_flat,
                                                    k=self.beam_size)

            topk_indices.set_shape([None, self.beam_size])
            topk_scores.set_shape([None, self.beam_size])

            # >> shape(next_word_ids) = [batch, beam]
            # next_word_ids contains vocabulary indices of the new tokens
            #
            # >> shape(next_beam_ids) = [batch, beam]
            # next_beam_ids contains the indices of the beams from which
            # the new hypotheses originate; these become parent_ids
            next_word_ids = tf.mod(topk_indices, len(self.vocabulary))
            next_beam_ids = tf.div(topk_indices, len(self.vocabulary))

            parent_ids = next_beam_ids

            # batch offset for tf.gather_nd
            batch_offset = tf.tile(
                tf.expand_dims(tf.range(self.batch_size), 1),
                [1, self.beam_size])
            batch_beam_ids = tf.stack([batch_offset, next_beam_ids], axis=2)

            # gather the topk logprob_sums
            #
            # >> shape(next_beam_lengths) = [batch, beam]
            # next_beam_lengths holds the lengths of the new
            # beam_size-many hypotheses
            next_beam_lengths = tf.gather_nd(hyp_lengths, batch_beam_ids)

            # >> shape(next_beam_logprob_sum) = [batch, beam]
            # next_beam_logprob_sum holds the scores of the new hypotheses
            next_beam_logprob_sum = tf.gather_nd(
                tf.reshape(hyp_probs,
                           [-1, self.beam_size * len(self.vocabulary)]),
                tf.stack([batch_offset, topk_indices], axis=2))

            # mark finished beams
            next_finished = tf.gather_nd(search_state.finished, batch_beam_ids)
            next_just_finished = tf.equal(next_word_ids, END_TOKEN_INDEX)
            next_finished = tf.logical_or(next_finished, next_just_finished)

            # we need to flatten the feedables for the parent_decoder
            next_feedables = tf.contrib.framework.nest.map_structure(
                lambda x: gather_flat(x, batch_beam_ids,
                                      self.batch_size, self.beam_size),
                dec_loop_state.feedables)

            next_feedables = next_feedables._replace(
                input_symbol=tf.reshape(next_word_ids, [-1]),
                finished=tf.reshape(next_finished, [-1]))

            # histories have shape [len, batch, ...]
            def gather_fn(x):
                return partial_transpose(
                    gather_flat(partial_transpose(x, [1, 0]), batch_beam_ids,
                                self.batch_size, self.beam_size), [1, 0])

            next_histories = tf.contrib.framework.nest.map_structure(
                gather_fn, dec_loop_state.histories)

            dec_loop_state = dec_loop_state._replace(feedables=next_feedables,
                                                     histories=next_histories)

            # CALL THE DECODER BODY FUNCTION
            next_loop_state = decoder_body(*dec_loop_state)

            next_search_state = SearchState(
                logprob_sum=next_beam_logprob_sum,
                prev_logprobs=tf.reshape(
                    tf.nn.log_softmax(next_loop_state.feedables.prev_logits),
                    [self.batch_size, self.beam_size,
                     len(self.vocabulary)]),
                lengths=next_beam_lengths,
                finished=next_finished)

            # next_token_ids = tf.transpose(search_results.token_ids, [1, 2, 0])
            # next_token_ids = tf.gather_nd(next_token_ids, batch_beam_ids)
            # next_token_ids = tf.transpose(next_token_ids, [2, 0, 1])
            # commented out because we want the original tokens
            next_output = SearchResults(
                scores=append_tensor(search_results.scores, topk_scores),
                #token_ids=append_tensor(next_token_ids, next_word_ids),
                token_ids=append_tensor(search_results.token_ids,
                                        next_word_ids),
                parent_ids=append_tensor(search_results.parent_ids,
                                         parent_ids))

            return BeamSearchLoopState(search_state=next_search_state,
                                       search_results=next_output,
                                       decoder_loop_state=next_loop_state)
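
``search_results`` accumulates ``token_ids`` and ``parent_ids`` with shape
``(time, batch, beam)``; the finished hypotheses are recovered after the loop
by walking the parent pointers backwards from the last step. A NumPy
post-processing sketch (the function name is ours, not the framework's):

import numpy as np

def reconstruct_hypotheses(token_ids: np.ndarray,
                           parent_ids: np.ndarray) -> np.ndarray:
    """Backtrack through parent pointers.

    Both inputs and the result have shape (time, batch, beam).
    """
    time, batch, beam = token_ids.shape
    output = np.zeros_like(token_ids)
    beam_idx = np.tile(np.arange(beam), (batch, 1))  # (batch, beam)
    batch_idx = np.arange(batch)[:, np.newaxis]      # (batch, 1)
    for t in reversed(range(time)):
        output[t] = token_ids[t, batch_idx, beam_idx]
        # Move to the beam slots these hypotheses occupied one step earlier.
        beam_idx = parent_ids[t, batch_idx, beam_idx]
    return output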
Example #6
        def body(*args) -> LoopState:

            loop_state = LoopState(*args)
            histories = loop_state.histories
            feedables = loop_state.feedables

            # shape (time, batch)
            decoded_symbols = append_tensor(histories.decoded_symbols,
                                            feedables.input_symbol)

            unfinished_mask = tf.to_float(tf.logical_not(feedables.finished))
            input_mask = append_tensor(histories.input_mask, unfinished_mask)

            # shape (batch, time)
            decoded_symbols_in_batch = tf.transpose(decoded_symbols)

            # mask (time, batch)
            mask = input_mask

            with tf.variable_scope(self._variable_scope, reuse=tf.AUTO_REUSE):
                # shape (batch, time, dimension)
                embedded_inputs = self.embed_inputs(decoded_symbols_in_batch)

                last_layer = self.layer(self.depth, embedded_inputs,
                                        tf.transpose(mask))

                # (batch, state_size)
                output_state = last_layer.temporal_states[:, -1, :]

                # See train_logits definition
                logits = tf.matmul(output_state, self.decoding_w)
                logits += self.decoding_b

                # apply temperature
                logits /= temperature

                if sample:
                    next_symbols = tf.squeeze(tf.multinomial(logits,
                                                             num_samples=1),
                                              axis=1)
                    next_symbols = tf.to_int32(next_symbols)
                else:
                    next_symbols = tf.to_int32(tf.argmax(logits, axis=1))
                    int_unfinished_mask = tf.to_int32(
                        tf.logical_not(loop_state.feedables.finished))

                    # Note: this works only when PAD_TOKEN_INDEX is 0.
                    # Otherwise, this has to be rewritten.
                    assert PAD_TOKEN_INDEX == 0
                    next_symbols = next_symbols * int_unfinished_mask

                has_just_finished = tf.equal(next_symbols, END_TOKEN_INDEX)
                has_finished = tf.logical_or(feedables.finished,
                                             has_just_finished)
                not_finished = tf.logical_not(has_finished)

            new_feedables = DecoderFeedables(step=feedables.step + 1,
                                             finished=has_finished,
                                             input_symbol=next_symbols,
                                             prev_logits=logits)

            # TransformerHistories is a type and should be callable
            # pylint: disable=not-callable
            new_histories = TransformerHistories(
                logits=append_tensor(histories.logits, logits),
                decoder_outputs=append_tensor(histories.decoder_outputs,
                                              output_state),
                mask=append_tensor(histories.mask, not_finished),
                outputs=append_tensor(histories.outputs, next_symbols),
                # transformer-specific:
                decoded_symbols=decoded_symbols,
                # TODO(all) handle these!
                # self_attention_histories=histories.self_attention_histories,
                # inter_attention_histories analogously
                input_mask=input_mask)
            # pylint: enable=not-callable

            new_loop_state = LoopState(histories=new_histories,
                                       constants=loop_state.constants,
                                       feedables=new_feedables)

            return new_loop_state
Example #7
        def body(*args: Any) -> BeamSearchLoopState:
            """Execute a single beam search step.

            An implementation of the beam search algorithm, which works as
            follows:

            1. Create a valid ``logprobs`` tensor which contains distributions
               over the output tokens for each hypothesis in the beam. For
               finished hypotheses, the log probabilities of all tokens except
               the padding token are set to negative infinity.

            2. Expand the beam by appending every possible token to every
               existing hypothesis. Update the log probability sum of each
               hypothesis and its length (add one for unfinished hypotheses).
               For each hypothesis, compute the score using the length penalty
               term.

            3. Select the ``beam_size`` best hypotheses from the score pool.
               This is implemented by flattening the scores tensor and using
               the ``tf.nn.top_k`` function.

            4. Reconstruct the beam by gathering elements from the original
               data structures using the data indices computed in the previous
               step.

            5. Call the ``body`` function of the underlying decoder.

            6. Populate a new ``BeamSearchLoopState`` object with the selected
               values and with the newly obtained decoder loop state.

            Note that this function expects the decoder to be called at least
            once prior to the first execution.

            Arguments:
                args: An instance of the ``BeamSearchLoopState`` structure.
                    (see the docs for this module)

            Returns:
                A ``BeamSearchLoopState`` after one step of the decoding.

            """
            loop_state = BeamSearchLoopState(*args)
            dec_loop_state = loop_state.decoder_loop_state
            search_state = loop_state.search_state
            search_results = loop_state.search_results

            # mask the probabilities
            # shape(logprobs) = [batch, beam, vocabulary]
            logprobs = search_state.prev_logprobs

            finished_mask = tf.expand_dims(
                tf.to_float(search_state.finished), 2)
            unfinished_logprobs = (1. - finished_mask) * logprobs

            finished_row = tf.one_hot(
                PAD_TOKEN_INDEX,
                len(self.vocabulary),
                dtype=tf.float32,
                on_value=0.,
                off_value=-INF)

            finished_logprobs = finished_mask * finished_row
            logprobs = unfinished_logprobs + finished_logprobs

            # update hypothesis scores
            # shape(hyp_probs) = [batch, beam, vocabulary]
            hyp_probs = tf.expand_dims(search_state.logprob_sum, 2) + logprobs

            # update hypothesis lengths
            hyp_lengths = search_state.lengths + 1 - tf.to_int32(
                search_state.finished)

            # shape(scores) = [batch, beam, vocabulary]
            scores = hyp_probs / tf.expand_dims(
                self._length_penalty(hyp_lengths), 2)

            # reshape to [batch, beam * vocabulary] for topk
            scores_flat = tf.reshape(
                scores, [-1, self.beam_size * len(self.vocabulary)])

            # shape(both) = [batch, beam]
            topk_scores, topk_indices = tf.nn.top_k(
                scores_flat, k=self.beam_size)

            topk_indices.set_shape([None, self.beam_size])
            topk_scores.set_shape([None, self.beam_size])

            next_word_ids = tf.mod(topk_indices, len(self.vocabulary))
            next_beam_ids = tf.div(topk_indices, len(self.vocabulary))

            # batch offset for tf.gather_nd
            batch_offset = tf.tile(
                tf.expand_dims(tf.range(self.batch_size), 1),
                [1, self.beam_size])
            batch_beam_ids = tf.stack([batch_offset, next_beam_ids], axis=2)

            # gather the topk logprob_sums
            next_beam_lengths = tf.gather_nd(hyp_lengths, batch_beam_ids)
            next_beam_logprob_sum = tf.gather_nd(
                tf.reshape(
                    hyp_probs, [-1, self.beam_size * len(self.vocabulary)]),
                tf.stack([batch_offset, topk_indices], axis=2))

            # mark finished beams
            next_finished = tf.gather_nd(search_state.finished, batch_beam_ids)
            next_just_finished = tf.equal(next_word_ids, END_TOKEN_INDEX)
            next_finished = tf.logical_or(next_finished, next_just_finished)

            # we need to flatten the feedables for the parent_decoder
            next_feedables = tf.contrib.framework.nest.map_structure(
                lambda x: gather_flat(x, batch_beam_ids,
                                      self.batch_size, self.beam_size),
                dec_loop_state.feedables)

            next_feedables = next_feedables._replace(
                input_symbol=tf.reshape(next_word_ids, [-1]),
                finished=tf.reshape(next_finished, [-1]))

            # histories have shape [len, batch, ...]
            def gather_fn(x):
                return partial_transpose(
                    gather_flat(
                        partial_transpose(x, [1, 0]),
                        batch_beam_ids,
                        self.batch_size,
                        self.beam_size),
                    [1, 0])

            next_histories = tf.contrib.framework.nest.map_structure(
                gather_fn, dec_loop_state.histories)

            dec_loop_state = dec_loop_state._replace(
                feedables=next_feedables,
                histories=next_histories)

            # CALL THE DECODER BODY FUNCTION
            next_loop_state = decoder_body(*dec_loop_state)

            next_search_state = SearchState(
                logprob_sum=next_beam_logprob_sum,
                prev_logprobs=tf.reshape(
                    tf.nn.log_softmax(next_loop_state.feedables.prev_logits),
                    [self.batch_size, self.beam_size, len(self.vocabulary)]),
                lengths=next_beam_lengths,
                finished=next_finished)

            next_token_ids = tf.transpose(search_results.token_ids, [1, 2, 0])
            next_token_ids = tf.gather_nd(next_token_ids, batch_beam_ids)
            next_token_ids = tf.transpose(next_token_ids, [2, 0, 1])
            next_output = SearchResults(
                scores=topk_scores,
                token_ids=append_tensor(next_token_ids, next_word_ids))

            return BeamSearchLoopState(
                search_state=next_search_state,
                search_results=next_output,
                decoder_loop_state=next_loop_state)
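
``self._length_penalty`` divides the accumulated log probabilities by a
function of hypothesis length. A common choice consistent with this usage is
the GNMT formula of Wu et al. (2016); the sketch below assumes a
``length_normalization`` hyper-parameter playing the role of alpha:

import tensorflow as tf

def _length_penalty(self, lengths: tf.Tensor) -> tf.Tensor:
    """lp(Y) = ((5 + |Y|) / 6) ** alpha  (assumed, GNMT-style form)."""
    return ((5. + tf.to_float(lengths)) / 6.) ** self.length_normalization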
Example #8
    def next_state(self, loop_state: LoopState) -> Tuple[tf.Tensor, Any, Any]:
        rnn_feedables = loop_state.feedables.other
        rnn_histories = loop_state.histories.other

        with tf.variable_scope(self.step_scope):
            rnn_input = self.input_projection(*loop_state)

            cell = self._get_rnn_cell()
            if self._rnn_cell_str in ["GRU", "NematusGRU"]:
                cell_output, next_state = cell(rnn_input,
                                               rnn_feedables.prev_rnn_output)

                attns = [
                    a.attention(cell_output, rnn_feedables.prev_rnn_output,
                                rnn_input, att_loop_state)
                    for a, att_loop_state in zip(
                        self.attentions, rnn_histories.attention_histories)
                ]
                if self.attentions:
                    contexts, att_loop_states = zip(*attns)
                else:
                    contexts, att_loop_states = [], []

                if self._conditional_gru:
                    cell_cond = self._get_conditional_gru_cell()
                    cond_input = tf.concat(contexts, -1)
                    cell_output, next_state = cell_cond(
                        cond_input, next_state, scope="cond_gru_2_cell")

            elif self._rnn_cell_str == "LSTM":
                prev_state = tf.contrib.rnn.LSTMStateTuple(
                    rnn_feedables.prev_rnn_state,
                    rnn_feedables.prev_rnn_output)
                cell_output, state = cell(rnn_input, prev_state)
                next_state = state.c
                attns = [
                    a.attention(cell_output, rnn_feedables.prev_rnn_output,
                                rnn_input, att_loop_state)
                    for a, att_loop_state in zip(
                        self.attentions, rnn_histories.attention_histories)
                ]
                if self.attentions:
                    contexts, att_loop_states = zip(*attns)
                else:
                    contexts, att_loop_states = [], []
            else:
                raise ValueError("Unknown RNN cell.")

            # TODO: attention functions should apply dropout on output
            #       themselves before returning the tensors
            contexts = [
                dropout(ctx, self.dropout_keep_prob, self.train_mode)
                for ctx in list(contexts)
            ]
            cell_output = dropout(cell_output, self.dropout_keep_prob,
                                  self.train_mode)

            with tf.name_scope("rnn_output_projection"):
                if self.embedding_size != self.output_dimension:
                    raise ValueError(
                        "The dimension ({}) of the output projection must be "
                        "same as the dimension of the input embedding "
                        "({})".format(self.output_dimension,
                                      self.embedding_size))
                # pylint: disable=not-callable
                output = self.output_projection(
                    cell_output, loop_state.feedables.embedded_input,
                    list(contexts), self.train_mode)
                # pylint: enable=not-callable

        new_feedables = RNNFeedables(prev_rnn_state=next_state,
                                     prev_rnn_output=cell_output,
                                     prev_contexts=list(contexts))

        new_histories = RNNHistories(
            rnn_outputs=append_tensor(rnn_histories.rnn_outputs, cell_output),
            attention_histories=list(att_loop_states))

        return (output, new_feedables, new_histories)
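
The ``dropout`` helper above is assumed to be a train-mode-conditional wrapper
around ``tf.nn.dropout``; a sketch of that assumed behavior:

import tensorflow as tf

def dropout(variable: tf.Tensor, keep_prob: float,
            train_mode: tf.Tensor) -> tf.Tensor:
    """Apply dropout to the tensor only when train_mode is True."""
    if keep_prob == 1.0:
        return variable
    return tf.cond(train_mode,
                   lambda: tf.nn.dropout(variable, keep_prob),
                   lambda: variable)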
Example #9
    def next_state(self, loop_state: LoopState) -> Tuple[tf.Tensor, Any, Any]:
        rnn_feedables = loop_state.feedables.other
        rnn_histories = loop_state.histories.other

        with tf.variable_scope(self.step_scope):
            rnn_input = self.input_projection(*loop_state)

            cell = self._get_rnn_cell()
            if self._rnn_cell_str in ["GRU", "NematusGRU"]:
                cell_output, next_state = cell(
                    rnn_input, rnn_feedables.prev_rnn_output)

                attns = [
                    a.attention(
                        cell_output, rnn_feedables.prev_rnn_output,
                        rnn_input, att_loop_state)
                    for a, att_loop_state in zip(
                        self.attentions,
                        rnn_histories.attention_histories)]
                if self.attentions:
                    contexts, att_loop_states = zip(*attns)
                else:
                    contexts, att_loop_states = [], []

                if self._conditional_gru:
                    cell_cond = self._get_conditional_gru_cell()
                    cond_input = tf.concat(contexts, -1)
                    cell_output, next_state = cell_cond(
                        cond_input, next_state, scope="cond_gru_2_cell")

            elif self._rnn_cell_str == "LSTM":
                prev_state = tf.contrib.rnn.LSTMStateTuple(
                    rnn_feedables.prev_rnn_state,
                    rnn_feedables.prev_rnn_output)
                cell_output, state = cell(rnn_input, prev_state)
                next_state = state.c
                attns = [
                    a.attention(
                        cell_output, rnn_feedables.prev_rnn_output,
                        rnn_input, att_loop_state)
                    for a, att_loop_state in zip(
                        self.attentions,
                        rnn_histories.attention_histories)]
                if self.attentions:
                    contexts, att_loop_states = zip(*attns)
                else:
                    contexts, att_loop_states = [], []
            else:
                raise ValueError("Unknown RNN cell.")

            # TODO: attention functions should apply dropout on output
            #       themselves before returning the tensors
            contexts = [dropout(ctx, self.dropout_keep_prob, self.train_mode)
                        for ctx in list(contexts)]
            cell_output = dropout(
                cell_output, self.dropout_keep_prob, self.train_mode)

            with tf.name_scope("rnn_output_projection"):
                if self.embedding_size != self.output_dimension:
                    raise ValueError(
                        "The dimension ({}) of the output projection must be "
                        "same as the dimension of the input embedding "
                        "({})".format(self.output_dimension,
                                      self.embedding_size))
                # pylint: disable=not-callable
                output = self.output_projection(
                    cell_output, loop_state.feedables.embedded_input,
                    list(contexts), self.train_mode)
                # pylint: enable=not-callable

        new_feedables = RNNFeedables(
            prev_rnn_state=next_state,
            prev_rnn_output=cell_output,
            prev_contexts=list(contexts))

        new_histories = RNNHistories(
            rnn_outputs=append_tensor(rnn_histories.rnn_outputs, cell_output),
            attention_histories=list(att_loop_states))

        return (output, new_feedables, new_histories)
Example #10
        def body(*args: Any) -> BeamSearchLoopState:
            """Execute a single beam search step.

            An implementation of the beam search algorithm, which works as
            follows:

            1. Create a valid ``logprobs`` tensor which contains distributions
               over the output tokens for each hypothesis in the beam. For
               finished hypotheses, the log probabilities of all tokens except
               the padding token are set to negative infinity.

            2. Expand the beam by appending every possible token to every
               existing hypothesis. Update the log probability sum of each
               hypothesis and its length (add one for unfinished hypotheses).
               For each hypothesis, compute the score using the length penalty
               term.

            3. Select the ``beam_size`` best hypotheses from the score pool.
               This is implemented by flattening the scores tensor and using
               the ``tf.nn.top_k`` function.

            4. Reconstruct the beam by gathering elements from the original
               data structures using the data indices computed in the previous
               step.

            5. Call the ``body`` function of the underlying decoder.

            6. Populate a new ``BeamSearchLoopState`` object with the selected
               values and with the newly obtained decoder loop state.

            Note that this function expects the decoder to be called at least
            once prior to the first execution.

            Arguments:
                args: An instance of the ``BeamSearchLoopState`` structure.
                    (see the docs for this module)

            Returns:
                A ``BeamSearchLoopState`` after one step of the decoding.

            """
            loop_state = BeamSearchLoopState(*args)
            dec_loop_state = loop_state.decoder_loop_state
            search_state = loop_state.search_state
            search_results = loop_state.search_results

            # mask the probabilities
            # shape(logprobs) = [batch, beam, vocabulary]
            logprobs = search_state.prev_logprobs

            finished_mask = tf.expand_dims(
                tf.to_float(search_state.finished), 2)
            unfinished_logprobs = (1. - finished_mask) * logprobs

            finished_row = tf.one_hot(
                PAD_TOKEN_INDEX,
                len(self.vocabulary),
                dtype=tf.float32,
                on_value=0.,
                off_value=-INF)

            finished_logprobs = finished_mask * finished_row
            logprobs = unfinished_logprobs + finished_logprobs

            # update hypothesis scores
            # shape(hyp_probs) = [batch, beam, vocabulary]
            hyp_probs = tf.expand_dims(search_state.logprob_sum, 2) + logprobs

            # update hypothesis lengths
            hyp_lengths = search_state.lengths + 1 - tf.to_int32(
                search_state.finished)

            # shape(scores) = [batch, beam, vocabulary]
            scores = hyp_probs / tf.expand_dims(
                self._length_penalty(hyp_lengths), 2)

            # reshape to [batch, beam * vocabulary] for topk
            scores_flat = tf.reshape(
                scores, [-1, self.beam_size * len(self.vocabulary)])

            # shape(both) = [batch, beam]
            topk_scores, topk_indices = tf.nn.top_k(
                scores_flat, k=self.beam_size)

            topk_indices.set_shape([None, self.beam_size])
            topk_scores.set_shape([None, self.beam_size])

            next_word_ids = tf.to_int64(
                tf.mod(topk_indices, len(self.vocabulary)))
            next_beam_ids = tf.div(topk_indices, len(self.vocabulary))

            # batch offset for tf.gather_nd
            batch_offset = tf.tile(
                tf.expand_dims(tf.range(self.batch_size), 1),
                [1, self.beam_size])
            batch_beam_ids = tf.stack([batch_offset, next_beam_ids], axis=2)

            # gather the topk logprob_sums
            next_beam_lengths = tf.gather_nd(hyp_lengths, batch_beam_ids)
            next_beam_logprob_sum = tf.gather_nd(
                tf.reshape(
                    hyp_probs, [-1, self.beam_size * len(self.vocabulary)]),
                tf.stack([batch_offset, topk_indices], axis=2))

            # mark finished beams
            next_finished = tf.gather_nd(search_state.finished, batch_beam_ids)
            next_just_finished = tf.equal(next_word_ids, END_TOKEN_INDEX)
            next_finished = tf.logical_or(next_finished, next_just_finished)

            # we need to flatten the feedables for the parent_decoder
            next_feedables = tf.contrib.framework.nest.map_structure(
                lambda x: gather_flat(x, batch_beam_ids,
                                      self.batch_size, self.beam_size),
                dec_loop_state.feedables)

            next_feedables = next_feedables._replace(
                embedded_input=self.parent_decoder.embed_input_symbols(
                    tf.reshape(next_word_ids, [-1])),
                finished=tf.reshape(next_finished, [-1]))

            # histories have shape [len, batch, ...]
            def gather_fn(x):
                if len(x.shape.dims) < 2:
                    return x

                return partial_transpose(
                    gather_flat(
                        partial_transpose(x, [1, 0]),
                        batch_beam_ids,
                        self.batch_size,
                        self.beam_size),
                    [1, 0])

            next_histories = tf.contrib.framework.nest.map_structure(
                gather_fn, dec_loop_state.histories)

            dec_loop_state = dec_loop_state._replace(
                feedables=next_feedables,
                histories=next_histories)

            # CALL THE DECODER BODY FUNCTION
            next_loop_state = decoder_body(*dec_loop_state)

            logits = next_loop_state.histories.logits[-1, :, :]
            next_search_state = SearchState(
                logprob_sum=next_beam_logprob_sum,
                prev_logprobs=tf.reshape(
                    tf.nn.log_softmax(logits),
                    [self.batch_size, self.beam_size, len(self.vocabulary)]),
                lengths=next_beam_lengths,
                finished=next_finished)

            next_token_ids = tf.transpose(search_results.token_ids, [1, 2, 0])
            next_token_ids = tf.gather_nd(next_token_ids, batch_beam_ids)
            next_token_ids = tf.transpose(next_token_ids, [2, 0, 1])
            next_output = SearchResults(
                scores=topk_scores,
                token_ids=append_tensor(next_token_ids, next_word_ids))

            return BeamSearchLoopState(
                search_state=next_search_state,
                search_results=next_output,
                decoder_loop_state=next_loop_state)
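
``gather_flat`` and ``partial_transpose`` are framework helpers; the sketches
below are plausible reconstructions inferred from how they are called above
(feedables carry a flattened leading batch * beam dimension, histories a
leading time dimension), assuming the trailing dimensions are statically
known:

import tensorflow as tf

def gather_flat(x: tf.Tensor, indices: tf.Tensor,
                batch_size: int, beam_size: int) -> tf.Tensor:
    """Reorder a tensor whose first dimension is flattened (batch * beam):
    unflatten it, gather with (batch, beam, 2) indices, flatten back."""
    if len(x.shape.dims) == 0:
        return x
    shape = [batch_size, beam_size] + x.shape.as_list()[1:]
    gathered = tf.gather_nd(tf.reshape(x, shape), indices)
    return tf.reshape(gathered, [-1] + shape[2:])

def partial_transpose(x: tf.Tensor, indices) -> tf.Tensor:
    """Transpose the leading dimensions of x, leaving the rest in place."""
    rank = len(x.shape.dims)
    perm = list(indices) + list(range(len(indices), rank))
    return tf.transpose(x, perm)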