def transformer_concated_decoder_internal(inputs, memory, bias, mem_bias, params, state=None, scope=None, reuse=False):
    """Thin alias over ``transformer_decoder``.

    Forwards every argument unchanged; exists so callers can refer to the
    concatenated-decoder entry point by a distinct name.
    """
    return transformer_decoder(
        inputs,
        memory,
        bias,
        mem_bias,
        params,
        state=state,
        scope=scope,
        reuse=reuse,
    )
def decode_infer(self, inputs, state):
    """Build one step of the beam-search inference graph.

    Embeds the current target prefix exactly as in training (mask, causal
    bias, shift-left, timing signal), then feeds ONLY the last position to
    the decoder together with the cached incremental ``state['decoder']``,
    and returns per-step log-probabilities over the vocabulary.

    Args:
        inputs: dict with keys 'target' and 'target_length' — assumed to be
            int ids of shape [batch, tgt_len] and lengths of shape [batch];
            TODO confirm against the beam-search caller.
        state: dict with 'encoder' (encoder memory passed through unchanged)
            and 'decoder' (incremental attention cache, updated here).

    Returns:
        Tuple of (log_prob, new_state) where log_prob is
        [batch, vocab_size] and new_state carries the updated decoder cache.
    """
    # during infer, following graph are constructed using beam search
    target_sequence = inputs['target']
    target_length = inputs['target_length']
    # Valid-token mask over the target prefix.  # [b, q_2]
    tgt_mask = tf.sequence_mask(target_length,
                                maxlen=tf.shape(target_sequence)[1],
                                dtype=tf.float32)
    # Scaled embedding lookup (sqrt(d_model) scaling).  # [b, q_2, e]
    self.tgt_embed = tf.nn.embedding_lookup(
        self.word_embeddings, target_sequence) * (self.embed_dim**0.5)
    # Zero out embeddings of padding positions.
    self.masked_tgt_embed = self.tgt_embed * tf.expand_dims(tgt_mask, -1)
    # Causal (lower-triangular) self-attention bias.
    self.dec_attn_bias = attention_bias(
        tf.shape(self.masked_tgt_embed)[1], "causal")
    # Prepend a zero vector and drop the last position so position t
    # attends only to tokens < t (teacher-forcing alignment).
    self.decoder_input = tf.pad(
        self.masked_tgt_embed, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]  # Shift left
    self.decoder_input = add_timing_signal(self.decoder_input)
    # NOTE(review): dropout is applied here on the INFERENCE path; dropout
    # is normally training-only — confirm callers set residual_dropout=0
    # (or rebuild params) for inference. TF1 API: 2nd arg is keep_prob.
    if self.params.residual_dropout > 0:
        self.decoder_input = tf.nn.dropout(
            self.decoder_input, 1.0 - self.params.residual_dropout)
    # Incremental decoding: only the newest position is fed to the decoder;
    # earlier positions live in the cached state['decoder'].
    self.infer_decoder_input = self.decoder_input[:, -1:, :]
    self.infer_dec_attn_bias = self.dec_attn_bias[:, :, -1:, :]
    self.all_att_weights, self.decoder_output, self.decoder_state = transformer_decoder(
        self.infer_decoder_input,
        state['encoder'],
        self.infer_dec_attn_bias,
        self.enc_attn_bias,
        self.params,
        state=state['decoder'])
    self.decoder_output = self.decoder_output[:, -1, :]  # [batch_size, hidden]
    # Project to vocabulary with tied/output weights; transpose_b=True.
    self.logits = tf.matmul(self.decoder_output, self.decoder_weights,
                            False, True)  # [batch_size, vocab_size]
    self.log_prob = tf.nn.log_softmax(self.logits)
    # Encoder memory is passed through untouched; only the decoder cache
    # advances between steps.
    return self.log_prob, {
        'encoder': state['encoder'],
        'decoder': self.decoder_state
    }
def _decode_train(self):
    """Build the training-time decoder graph and compute the loss.

    Mirrors ``decode_infer``'s input pipeline (scaled embedding, padding
    mask, causal bias, shift-left, timing signal, residual dropout) but
    decodes the full target sequence at once against
    ``self.encoder_output``, flattens the outputs, projects them to
    vocabulary logits, and delegates to ``self._compute_loss()``.

    Side effects: sets self.tgt_embed, self.masked_tgt_embed,
    self.dec_attn_bias, self.decoder_input, self.all_att_weights,
    self.decoder_output and self.logits on the instance.
    """
    # Scaled embedding lookup (sqrt(d_model) scaling).  # [b, q_2, e]
    self.tgt_embed = tf.nn.embedding_lookup(
        self.word_embeddings, self.tgt_seq) * (self.embed_dim**0.5)
    # Zero out embeddings of padding positions.
    self.masked_tgt_embed = self.tgt_embed * tf.expand_dims(
        self.tgt_mask, -1)
    # Causal (lower-triangular) self-attention bias.
    self.dec_attn_bias = attention_bias(
        tf.shape(self.masked_tgt_embed)[1], "causal")
    # Prepend a zero vector and drop the last position so position t
    # is predicted from tokens < t (teacher forcing).
    self.decoder_input = tf.pad(
        self.masked_tgt_embed, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]  # Shift left
    self.decoder_input = add_timing_signal(self.decoder_input)
    # TF1 dropout API: second argument is keep_prob, hence 1.0 - rate.
    if self.params.residual_dropout > 0:
        self.decoder_input = tf.nn.dropout(
            self.decoder_input, 1.0 - self.params.residual_dropout)
    # Full-sequence decode (no incremental state at training time).
    self.all_att_weights, self.decoder_output = transformer_decoder(
        self.decoder_input, self.encoder_output, self.dec_attn_bias,
        self.enc_attn_bias, self.params)
    # [b, q_2, e] => [b*q_2, v]
    self.decoder_output = tf.reshape(self.decoder_output,
                                     [-1, self.hidden_size])
    # Project to vocabulary with the (tied) output weights; transpose_b=True.
    self.logits = tf.matmul(self.decoder_output, self.decoder_weights,
                            False, True)
    self._compute_loss()