def encode(self, src_tensor, src_postion, turns_tensor):
    # encode embedding: word + position + turn embeddings
    encode = self.word_embedding(src_tensor) + self.pos_embedding(
        src_postion) + self.turn_embedding(turns_tensor)
    # encode mask
    slf_attn_mask = common.get_attn_key_pad_mask(src_tensor, src_tensor)
    non_pad_mask = common.get_non_pad_mask(src_tensor)
    # encode
    enc_output = self.enc(encode, slf_attn_mask, non_pad_mask)
    return enc_output
def decode(self, tgt_tensor, src_tensor, enc_output):
    # decode embedding
    dec_output = self.word_embedding(tgt_tensor)
    dec_output = self.droupout(dec_output)
    # decode mask
    non_pad_mask = common.get_non_pad_mask(tgt_tensor)
    slf_attn_mask_subseq = common.get_subsequent_mask(tgt_tensor)
    slf_attn_mask_keypad = common.get_attn_key_pad_mask(
        tgt_tensor, tgt_tensor, True)
    slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
    dec_enc_attn_mask = common.get_attn_key_pad_mask(
        src_tensor, tgt_tensor)
    # decode
    dec_output, m_dec_output = self.dec(dec_output, enc_output,
                                        non_pad_mask, slf_attn_mask,
                                        dec_enc_attn_mask)
    # pointer network
    distributes = self.attention(m_dec_output, enc_output)
    return distributes
def forward(self, src_tensor, src_postion, turns_tensor, tgt_tensor):
    # encode embedding
    encode = self.word_embedding(src_tensor) + self.pos_embedding(
        src_postion) + self.turn_embedding(turns_tensor)
    encode = self.droupout(encode)
    # encode mask
    slf_attn_mask = common.get_attn_key_pad_mask(src_tensor, src_tensor)
    non_pad_mask = common.get_non_pad_mask(src_tensor)
    # encode
    enc_output = self.enc(encode, slf_attn_mask, non_pad_mask)
    # decode embedding
    dec_output = self.word_embedding(tgt_tensor)
    dec_output = self.droupout(dec_output)
    # decode mask
    non_pad_mask = common.get_non_pad_mask(tgt_tensor)
    slf_attn_mask_subseq = common.get_subsequent_mask(tgt_tensor)
    slf_attn_mask_keypad = common.get_attn_key_pad_mask(
        tgt_tensor, tgt_tensor, True)
    slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
    dec_enc_attn_mask = common.get_attn_key_pad_mask(
        src_tensor, tgt_tensor)
    # decode
    dec_output, m_dec_output = self.dec(dec_output, enc_output,
                                        non_pad_mask, slf_attn_mask,
                                        dec_enc_attn_mask)
    # pointer network
    distributes = self.attention(m_dec_output, enc_output)
    return distributes
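# The methods above rely on mask helpers imported from `common`, whose
# implementation is not shown here. The following is a minimal sketch of what
# such helpers typically compute in PyTorch, assuming a padding index of 0.
# The names mirror the calls above, but the bodies are illustrative only; in
# particular, the extra boolean argument passed to get_attn_key_pad_mask in
# the decode path is not modeled here.
import torch

PAD = 0  # assumed padding index


def get_non_pad_mask(seq):
    # (batch, len, 1) float mask: 1.0 for real tokens, 0.0 for padding
    return seq.ne(PAD).type(torch.float).unsqueeze(-1)


def get_attn_key_pad_mask(seq_k, seq_q):
    # (batch, len_q, len_k) bool mask: True where the key position is padding
    len_q = seq_q.size(1)
    padding_mask = seq_k.eq(PAD)
    return padding_mask.unsqueeze(1).expand(-1, len_q, -1)


def get_subsequent_mask(seq):
    # (batch, len, len) upper-triangular mask that hides future positions
    batch_size, seq_len = seq.size()
    mask = torch.triu(
        torch.ones((seq_len, seq_len), device=seq.device, dtype=torch.uint8),
        diagonal=1)
    return mask.unsqueeze(0).expand(batch_size, -1, -1)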
def encode(self, src_tensor, src_postion, turns_tensor, src_max_len):
    args = self.args
    with tf.compat.v1.variable_scope("encode", reuse=tf.compat.v1.AUTO_REUSE):
        # embedding
        enc_output = tf.nn.embedding_lookup(self.word_embedding, src_tensor)
        enc_output *= args.d_model**0.5
        enc_output += tf.nn.embedding_lookup(self.pos_embedding, src_postion)
        turn_enc_output = tf.nn.embedding_lookup(self.turn_embedding,
                                                 turns_tensor)
        enc_output += turn_enc_output * (args.d_model**0.5)
        enc_output = tf.nn.dropout(enc_output, keep_prob=self.dropout_rate)
        # encode mask
        slf_attn_mask = common.get_attn_key_pad_mask(
            src_tensor, src_tensor, src_max_len)
        non_pad_mask = common.get_non_pad_mask(src_tensor)
        # encode
        for i in range(args.enc_stack_layers):
            with tf.compat.v1.variable_scope(
                    "num_blocks_{}".format(i), reuse=tf.compat.v1.AUTO_REUSE):
                enc_output, enc_slf_attn = multi_head_attention(
                    enc_output, enc_output, enc_output, slf_attn_mask,
                    args.n_head, args.d_model, args.d_k, args.d_v,
                    self.dropout_rate, self.initializer)
                enc_output *= non_pad_mask
                enc_output = position_wise(enc_output, args.d_model,
                                           args.d_ff, self.dropout_rate,
                                           self.initializer)
                enc_output *= non_pad_mask
    return enc_output, non_pad_mask
def decode(self, tgt_tensor, tgt_postion, tgt_max_len, src_tensor,
           enc_output):
    args = self.args
    with tf.compat.v1.variable_scope("decode", reuse=tf.compat.v1.AUTO_REUSE):
        # decode embedding
        dec_output = tf.nn.embedding_lookup(self.word_embedding, tgt_tensor)
        dec_output *= args.d_model**0.5
        dec_output += tf.nn.embedding_lookup(self.pos_embedding, tgt_postion)
        dec_output = tf.nn.dropout(dec_output, keep_prob=self.dropout_rate)
        # decode mask
        non_pad_mask = common.get_non_pad_mask(tgt_tensor)
        slf_attn_mask_subseq = common.get_subsequent_mask(
            tgt_tensor, self.batch_size, tgt_max_len)
        slf_attn_mask_keypad = common.get_attn_key_pad_mask(
            tgt_tensor, tgt_tensor, tgt_max_len)
        slf_attn_mask = tf.math.greater(
            (slf_attn_mask_keypad + slf_attn_mask_subseq), 0)
        dec_enc_attn_mask = common.get_attn_key_pad_mask(
            src_tensor, tgt_tensor, tgt_max_len)
        # decode
        for i in range(args.dec_stack_layers):
            with tf.compat.v1.variable_scope(
                    f"num_blocks_{i}", reuse=tf.compat.v1.AUTO_REUSE):
                dec_output, dec_slf_attn = multi_head_attention(
                    dec_output, dec_output, dec_output, slf_attn_mask,
                    args.n_head, args.d_model, args.d_k, args.d_v,
                    self.dropout_rate, self.initializer,
                    scope="self_attention")
                dec_output *= non_pad_mask
                m_dec_output = dec_output
                dec_output, dec_enc_attn = multi_head_attention(
                    dec_output, enc_output, enc_output, dec_enc_attn_mask,
                    args.n_head, args.d_model, args.d_k, args.d_v,
                    self.dropout_rate, self.initializer,
                    scope="vanilla_attention")
                dec_output *= non_pad_mask
                dec_output = position_wise(dec_output, args.d_model,
                                           args.d_ff, self.dropout_rate,
                                           self.initializer)
                dec_output *= non_pad_mask
        dec_output = m_dec_output
    return dec_output, non_pad_mask
def decode(self, tgt_tensor, tgt_max_len, src_tensor, enc_output):
    args = self.args
    with tf.compat.v1.variable_scope("decode", reuse=tf.compat.v1.AUTO_REUSE):
        # decode embedding
        dec_output = tf.nn.embedding_lookup(self.word_embedding, tgt_tensor)
        dec_output *= args.d_model**0.5
        dec_output = tf.nn.dropout(dec_output, keep_prob=self.dropout_rate)
        # decode mask
        non_pad_mask = common.get_non_pad_mask(tgt_tensor)
        slf_attn_mask_subseq = common.get_subsequent_mask(
            tgt_tensor, self.batch_size, tgt_max_len)
        slf_attn_mask_keypad = common.get_attn_key_pad_mask(
            tgt_tensor, tgt_tensor, tgt_max_len)
        slf_attn_mask = tf.math.greater(
            (slf_attn_mask_keypad + slf_attn_mask_subseq), 0)
        dec_enc_attn_mask = common.get_attn_key_pad_mask(
            src_tensor, tgt_tensor, tgt_max_len)
        # decode
        for i in range(args.dec_stack_layers):
            with tf.compat.v1.variable_scope(
                    "num_blocks_{}".format(i), reuse=tf.compat.v1.AUTO_REUSE):
                dec_output, dec_slf_attn = multi_head_attention(
                    dec_output, dec_output, dec_output, slf_attn_mask,
                    args.n_head, args.d_model, args.d_k, args.d_v,
                    self.dropout_rate, self.initializer,
                    scope="self_attention")
                dec_output *= non_pad_mask
                m_dec_output = dec_output
                dec_output, dec_enc_attn = multi_head_attention(
                    dec_output, enc_output, enc_output, dec_enc_attn_mask,
                    args.n_head, args.d_model, args.d_k, args.d_v,
                    self.dropout_rate, self.initializer,
                    scope="vanilla_attention")
                dec_output *= non_pad_mask
                dec_output = position_wise(dec_output, args.d_model,
                                           args.d_ff, self.dropout_rate,
                                           self.initializer)
                dec_output *= non_pad_mask
        dec_output = m_dec_output
    with tf.compat.v1.variable_scope("pointer", reuse=tf.compat.v1.AUTO_REUSE):
        last_enc_output = tf.layers.dense(
            enc_output,
            args.d_model,
            use_bias=False,
            kernel_initializer=self.initializer)  # bsz slen dim
        last_enc_output = tf.expand_dims(last_enc_output, 0)  # 1 bsz slen dim
        dec_output_trans = tf.transpose(dec_output, [1, 0, 2])  # tlen bsz dim
        dec_output_trans = tf.layers.dense(
            dec_output_trans,
            args.d_model,
            kernel_initializer=self.initializer,
            use_bias=False,
            name="pointer_decode",
            reuse=tf.compat.v1.AUTO_REUSE)  # tlen bsz dim
        dec_output_trans = tf.expand_dims(dec_output_trans, 2)  # tlen bsz 1 dim
        attn_encode = tf.nn.tanh(dec_output_trans + last_enc_output)  # tlen bsz slen dim
        attn_encode = tf.layers.dense(
            attn_encode,
            1,
            kernel_initializer=self.initializer,
            use_bias=False,
            name="pointer_v",
            reuse=tf.compat.v1.AUTO_REUSE)  # tlen bsz slen 1
        attn_encode = tf.transpose(tf.squeeze(attn_encode, 3),
                                   [1, 0, 2])  # bsz tlen slen
        distributes = tf.nn.log_softmax(attn_encode, axis=-1) + 1e-9
    return distributes, dec_output
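# The "pointer" scope above scores every source position for every decode
# step with additive (Bahdanau-style) attention: tanh(W_q * dec + W_k * enc)
# is projected to a scalar by a vector v, then log-softmax is taken over the
# source length. Below is a minimal, framework-free NumPy sketch of that
# computation with random weights, assuming shapes (bsz, slen, dim) for
# enc_output and (bsz, tlen, dim) for dec_output; it is illustrative only and
# shares no variables with the TensorFlow graph.
import numpy as np


def pointer_log_distribution(enc_output, dec_output, rng=np.random):
    bsz, slen, dim = enc_output.shape
    w_k = rng.standard_normal((dim, dim)) / np.sqrt(dim)  # encoder projection
    w_q = rng.standard_normal((dim, dim)) / np.sqrt(dim)  # decoder projection
    v = rng.standard_normal((dim,)) / np.sqrt(dim)        # scoring vector
    keys = enc_output @ w_k                   # (bsz, slen, dim)
    queries = dec_output @ w_q                # (bsz, tlen, dim)
    # broadcast to (bsz, tlen, slen, dim) and score each (step, position) pair
    scores = np.tanh(queries[:, :, None, :] + keys[:, None, :, :]) @ v
    # numerically stable log-softmax over source positions,
    # matching tf.nn.log_softmax(attn_encode, axis=-1)
    scores -= scores.max(axis=-1, keepdims=True)
    return scores - np.log(np.exp(scores).sum(axis=-1, keepdims=True))


# Example: pointer_log_distribution(np.zeros((2, 5, 8)), np.zeros((2, 3, 8)))
# returns an array of shape (2, 3, 5): one log-distribution over the 5 source
# positions per decode step.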