예제 #1
0
    def _get_symbols_to_logits_fn(self, max_decode_length):
        """Build the closure that produces next-token logits for one decode step.

        The positional-encoding table and the decoder self-attention bias are
        created once here and sliced per step inside the returned function.

        Args:
          max_decode_length: upper bound on the number of decode steps. The
            positional encoding is built for `max_decode_length + 1` positions
            and the self-attention bias for `max_decode_length`.

        Returns:
          A callable `symbols_to_logits_fn(ids, i, cache)` returning
          `(logits, cache)`.
        """
        hidden_size = self.params.hidden_size
        position_table = positional_encoding(max_decode_length + 1, hidden_size)
        full_self_attention_bias = get_target_mask(max_decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Compute logits for the next potential IDs at step `i`.

            Args:
              ids: Current decoded sequences. int tensor with shape
                [batch_size * beam_size, i + 1]
              i: Loop index
              cache: dictionary of values storing the encoder output,
                encoder-decoder attention bias, and previous decoder
                attention values.

            Returns:
              Tuple of
                (logits with shape [batch_size * beam_size, vocab_size],
                 the `cache` object that was passed in; it is handed to the
                 decoder stack, which presumably updates it -- TODO confirm)
            """
            # Only the most recently generated ID of each sequence is fed in.
            newest_ids = ids[:, -1:]

            # Embed that ID and add the positional-encoding row for step i.
            token_embedding = self.embedding_softmax_layer(newest_ids)
            step_input = token_embedding + position_table[i:i + 1]

            # Row i of the precomputed bias, restricted to columns 0..i.
            step_bias = full_self_attention_bias[:, :, i:i + 1, :i + 1]

            encoder_outputs = cache.get("encoder_outputs")
            encoder_decoder_bias = cache.get("encoder_decoder_attention_bias")
            step_outputs = self.decoder_stack(
                step_input,
                features=encoder_outputs,
                target_mask=step_bias,
                input_mask=encoder_decoder_bias,
                cache=cache)

            # Project decoder outputs to vocabulary logits, then drop the
            # length-1 axis at position 1.
            step_logits = self.embedding_softmax_layer(step_outputs,
                                                       mode="linear")
            return tf.squeeze(step_logits, axis=[1]), cache

        return symbols_to_logits_fn
예제 #2
0
    def decode(self, targets, encoder_outputs, attention_bias):
        """Generate logits for each value in the target sequence.

        The embedded targets are shifted one position to the right, combined
        with the positional encoding, optionally passed through dropout, and
        then run through the decoder stack and the output projection.

        Dropout is controlled by `self.is_train`; there is no `training`
        argument on this method.

        Args:
          targets: target values for the output sequence. int tensor with shape
            [batch_size, target_length]
          encoder_outputs: continuous representation of input sequence. float
            tensor with shape [batch_size, input_length, hidden_size]
          attention_bias: float tensor with shape
            [batch_size, 1, 1, input_length]

        Returns:
          float32 tensor with shape [batch_size, target_length, vocab_size]
        """
        with tf.name_scope("decode"):
            embedded_targets = self.embedding_softmax_layer(targets)

            with tf.name_scope("shift_targets"):
                # Pad one step at the front of the length axis and drop the
                # last step, i.e. shift the targets right by one position.
                padded = tf.pad(embedded_targets, [[0, 0], [1, 0], [0, 0]])
                shifted_targets = padded[:, :-1, :]

            with tf.name_scope("add_pos_encoding"):
                length = tf.shape(shifted_targets)[1]
                position_signal = positional_encoding(length,
                                                      self.params.hidden_size)
                decoder_inputs = shifted_targets + position_signal

            if self.is_train:
                # keep_prob is a keep probability; tf.nn.dropout takes the
                # complementary drop rate.
                drop_rate = 1 - self.params.keep_prob
                decoder_inputs = tf.nn.dropout(decoder_inputs, rate=drop_rate)

            # Self-attention bias over the target positions, then the decoder
            # stack followed by the projection to vocabulary logits.
            target_self_attention_bias = get_target_mask(length)
            decoder_outputs = self.decoder_stack(
                decoder_inputs,
                features=encoder_outputs,
                input_mask=attention_bias,
                target_mask=target_self_attention_bias,
            )
            return self.embedding_softmax_layer(decoder_outputs, mode="linear")