Example #1
def transformer_concated_decoder_internal(inputs,
                                          memory,
                                          bias,
                                          mem_bias,
                                          params,
                                          state=None,
                                          scope=None,
                                          reuse=False):
    # Thin wrapper: delegates directly to transformer_decoder with the
    # same argument order.
    return transformer_decoder(inputs, memory, bias, mem_bias, params,
                               state, scope, reuse)
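
The wrapper above adds nothing beyond default arguments; it simply re-exports transformer_decoder under another name. As a rough, hypothetical call sketch (the tensor names and shapes are illustrative assumptions, not taken from this snippet; the unpacked return values follow Example #3 below):

# Hypothetical call site; names and shapes are illustrative assumptions.
#   decoder_input  [batch, tgt_len, hidden]  shifted target embeddings
#   encoder_output [batch, src_len, hidden]  encoder memory
dec_bias = attention_bias(tf.shape(decoder_input)[1], "causal")
all_att_weights, decoder_output = transformer_concated_decoder_internal(
    decoder_input, encoder_output, dec_bias, enc_attn_bias, params,
    scope="decoder")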
Example #2
    def decode_infer(self, inputs, state):
        # During inference, the following graph is constructed step by step
        # for beam search decoding.
        target_sequence = inputs['target']
        target_length = inputs['target_length']
        tgt_mask = tf.sequence_mask(target_length,
                                    maxlen=tf.shape(target_sequence)[1],
                                    dtype=tf.float32)  # [b, q_2]

        # [b, q_2, e]
        self.tgt_embed = tf.nn.embedding_lookup(
            self.word_embeddings, target_sequence) * (self.embed_dim**0.5)
        self.masked_tgt_embed = self.tgt_embed * tf.expand_dims(tgt_mask, -1)
        self.dec_attn_bias = attention_bias(
            tf.shape(self.masked_tgt_embed)[1], "causal")
        self.decoder_input = tf.pad(
            self.masked_tgt_embed,
            [[0, 0], [1, 0], [0, 0]])[:, :-1, :]  # Shift targets right by one step
        self.decoder_input = add_timing_signal(self.decoder_input)
        if self.params.residual_dropout > 0:
            # TF1 tf.nn.dropout takes keep_prob, hence 1.0 - dropout rate.
            self.decoder_input = tf.nn.dropout(
                self.decoder_input, 1.0 - self.params.residual_dropout)
        # Feed only the newest position; earlier steps are cached in `state`.
        self.infer_decoder_input = self.decoder_input[:, -1:, :]
        self.infer_dec_attn_bias = self.dec_attn_bias[:, :, -1:, :]

        self.all_att_weights, self.decoder_output, self.decoder_state = transformer_decoder(
            self.infer_decoder_input,
            state['encoder'],
            self.infer_dec_attn_bias,
            self.enc_attn_bias,
            self.params,
            state=state['decoder'])
        self.decoder_output = self.decoder_output[:, -1, :]  # [batch_size, hidden]
        self.logits = tf.matmul(self.decoder_output, self.decoder_weights,
                                transpose_a=False,
                                transpose_b=True)  # [batch_size, vocab_size]
        self.log_prob = tf.nn.log_softmax(self.logits)
        return self.log_prob, {
            'encoder': state['encoder'],
            'decoder': self.decoder_state
        }
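
Both this example and the next call attention_bias(length, "causal"), whose definition is not shown here. A minimal sketch of what the causal variant typically computes in THUMT-style codebases, assuming the usual convention of adding a large negative value to masked attention logits:

import tensorflow as tf

def causal_attention_bias(length, inf=-1e9):
    # Lower-triangular matrix of ones: position i may attend to positions <= i.
    lower_triangle = tf.matrix_band_part(tf.ones([length, length]), -1, 0)
    # 0.0 where attention is allowed, -1e9 where it is masked; the bias is
    # added to the attention logits before the softmax.
    bias = inf * (1.0 - lower_triangle)
    # Shape [1, 1, length, length] so it broadcasts over batch and heads,
    # matching the [:, :, -1:, :] slicing used above.
    return tf.reshape(bias, [1, 1, length, length])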
Example #3
def _decode_train(self):
    # [b, q_2, e]
    self.tgt_embed = tf.nn.embedding_lookup(
        self.word_embeddings, self.tgt_seq) * (self.embed_dim**0.5)
    self.masked_tgt_embed = self.tgt_embed * tf.expand_dims(
        self.tgt_mask, -1)
    self.dec_attn_bias = attention_bias(
        tf.shape(self.masked_tgt_embed)[1], "causal")
    self.decoder_input = tf.pad(
        self.masked_tgt_embed,
        [[0, 0], [1, 0], [0, 0]])[:, :-1, :]  # Shift targets right by one step
    self.decoder_input = add_timing_signal(self.decoder_input)
    if self.params.residual_dropout > 0:
        self.decoder_input = tf.nn.dropout(
            self.decoder_input, 1.0 - self.params.residual_dropout)
    self.all_att_weights, self.decoder_output = transformer_decoder(
        self.decoder_input, self.encoder_output, self.dec_attn_bias,
        self.enc_attn_bias, self.params)
    # [b, q_2, e] => [b*q_2, v]
    self.decoder_output = tf.reshape(self.decoder_output,
                                     [-1, self.hidden_size])
    self.logits = tf.matmul(self.decoder_output, self.decoder_weights,
                            transpose_a=False, transpose_b=True)
    self._compute_loss()
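
The method ends by delegating to _compute_loss, which is not shown. A minimal sketch of a masked cross-entropy that matches the flattened [b*q_2, v] logits above; only self.logits, self.tgt_seq, and self.tgt_mask come from the snippet, the rest is an assumption:

def _compute_loss(self):
    # Flatten labels and mask to align with the [b*q_2, v] logits.
    labels = tf.reshape(self.tgt_seq, [-1])  # [b*q_2]
    mask = tf.reshape(self.tgt_mask, [-1])   # [b*q_2], 1.0 at real tokens
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=self.logits)   # [b*q_2]
    # Average the per-token loss over non-padding positions only.
    self.loss = tf.reduce_sum(xent * mask) / tf.reduce_sum(mask)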