Example #1 — Transformer encoder (PyTorch): sum word, position, and turn embeddings, build the padding masks, and run the encoder stack.
    def encode(self, src_tensor, src_postion, turns_tensor):
        # encode embedding
        encode = self.word_embedding(src_tensor) + self.pos_embedding(
            src_postion) + self.turn_embedding(turns_tensor)

        # encode mask
        slf_attn_mask = common.get_attn_key_pad_mask(src_tensor, src_tensor)
        non_pad_mask = common.get_non_pad_mask(src_tensor)

        # encode
        enc_output = self.enc(encode, slf_attn_mask, non_pad_mask)

        return enc_output
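Examples #1–#3 depend on mask helpers from a `common` module that is not shown in this listing. As a point of reference, a minimal PyTorch sketch of `get_non_pad_mask` and `get_attn_key_pad_mask` could look like the following, assuming padding index 0 and the shapes implied by how the masks are used above; the extra flag the decoder passes as a third argument is ignored here because its meaning is not visible in the snippets.

import torch

PAD = 0  # assumed padding index (not confirmed by the snippets)

def get_non_pad_mask(seq):
    # (batch, len) -> (batch, len, 1): 1.0 for real tokens, 0.0 for padding
    return seq.ne(PAD).type(torch.float).unsqueeze(-1)

def get_attn_key_pad_mask(seq_k, seq_q, *_):
    # (batch, len_q, len_k): 1 where the key position is padding
    len_q = seq_q.size(1)
    padding_mask = seq_k.eq(PAD).type(torch.uint8)
    return padding_mask.unsqueeze(1).expand(-1, len_q, -1)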
Example #2 — Transformer decoder (PyTorch): embed the target, build the causal and key-padding masks, run the decoder stack, and compute the pointer distribution.
    def decode(self, tgt_tensor, src_tensor, enc_output):
        # decode embedding
        dec_output = self.word_embedding(tgt_tensor)
        dec_output = self.droupout(dec_output)

        # decode mask
        non_pad_mask = common.get_non_pad_mask(tgt_tensor)
        slf_attn_mask_subseq = common.get_subsequent_mask(tgt_tensor)
        slf_attn_mask_keypad = common.get_attn_key_pad_mask(
            tgt_tensor, tgt_tensor, True)
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

        dec_enc_attn_mask = common.get_attn_key_pad_mask(
            src_tensor, tgt_tensor)

        # decode
        dec_output, m_dec_output = self.dec(dec_output, enc_output,
                                            non_pad_mask, slf_attn_mask,
                                            dec_enc_attn_mask)

        # pointer network
        distributes = self.attention(m_dec_output, enc_output)

        return distributes
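The decoder above blocks both padded keys and future positions: the key-padding mask and the subsequent-position mask are summed and thresholded with `.gt(0)`, which amounts to an element-wise OR. A matching sketch of `get_subsequent_mask`, under the same assumptions as the helpers sketched earlier (integer masks so the `+`/`.gt(0)` idiom works), might be:

import torch

def get_subsequent_mask(seq):
    # (batch, len, len): 1 strictly above the diagonal, i.e. the future
    # positions a decoder query must not attend to
    batch, length = seq.size()
    subsequent = torch.triu(
        torch.ones((length, length), dtype=torch.uint8, device=seq.device),
        diagonal=1)
    return subsequent.unsqueeze(0).expand(batch, -1, -1)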
Example #3 — Full forward pass (PyTorch): encoder and decoder combined, ending in the pointer network.
    def forward(self, src_tensor, src_postion, turns_tensor, tgt_tensor):
        # encode embedding
        encode = self.word_embedding(src_tensor) + self.pos_embedding(
            src_postion) + self.turn_embedding(turns_tensor)
        encode = self.droupout(encode)

        # encode mask
        slf_attn_mask = common.get_attn_key_pad_mask(src_tensor, src_tensor)
        non_pad_mask = common.get_non_pad_mask(src_tensor)

        # encode
        enc_output = self.enc(encode, slf_attn_mask, non_pad_mask)

        # decode embedding
        dec_output = self.word_embedding(tgt_tensor)
        dec_output = self.droupout(dec_output)

        # decode mask
        non_pad_mask = common.get_non_pad_mask(tgt_tensor)
        slf_attn_mask_subseq = common.get_subsequent_mask(tgt_tensor)
        slf_attn_mask_keypad = common.get_attn_key_pad_mask(
            tgt_tensor, tgt_tensor, True)
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

        dec_enc_attn_mask = common.get_attn_key_pad_mask(
            src_tensor, tgt_tensor)

        # decode
        dec_output, m_dec_output = self.dec(dec_output, enc_output,
                                            non_pad_mask, slf_attn_mask,
                                            dec_enc_attn_mask)

        # pointer network
        distributes = self.attention(m_dec_output, enc_output)

        return distributes
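As a quick sanity check of the decoder mask construction used in this forward pass, here is a tiny self-contained example (hypothetical token ids, padding index 0, uint8 masks): the padded key position and all future positions end up blocked.

import torch

PAD = 0
tgt = torch.tensor([[5, 7, 9, PAD]])        # one target sequence, last token padded
length = tgt.size(1)

keypad = tgt.eq(PAD).type(torch.uint8).unsqueeze(1).expand(-1, length, -1)
subseq = torch.triu(torch.ones((1, length, length), dtype=torch.uint8), diagonal=1)
slf_attn_mask = (keypad + subseq).gt(0)     # True = attention to this key is blocked

print(slf_attn_mask[0].int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 1]], dtype=torch.int32)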
Example #4 — Transformer encoder (TensorFlow 1.x compat): scaled embedding lookups, padding masks, and a stack of multi-head self-attention plus position-wise feed-forward blocks.
    def encode(self, src_tensor, src_postion, turns_tensor, src_max_len):
        args = self.args

        with tf.compat.v1.variable_scope("encode",
                                         reuse=tf.compat.v1.AUTO_REUSE):
            # embedding
            enc_output = tf.nn.embedding_lookup(self.word_embedding,
                                                src_tensor)
            enc_output *= args.d_model**0.5

            enc_output += tf.nn.embedding_lookup(self.pos_embedding,
                                                 src_postion)

            turn_enc_output = tf.nn.embedding_lookup(self.turn_embedding,
                                                     turns_tensor)
            enc_output += turn_enc_output * (args.d_model**0.5)

            enc_output = tf.nn.dropout(enc_output, keep_prob=self.dropout_rate)

            # encode mask
            slf_attn_mask = common.get_attn_key_pad_mask(
                src_tensor, src_tensor, src_max_len)
            non_pad_mask = common.get_non_pad_mask(src_tensor)

            # encode
            for i in range(args.enc_stack_layers):
                with tf.compat.v1.variable_scope(
                        "num_blocks_{}".format(i),
                        reuse=tf.compat.v1.AUTO_REUSE):
                    enc_output, enc_slf_attn = multi_head_attention(
                        enc_output, enc_output, enc_output, slf_attn_mask,
                        args.n_head, args.d_model, args.d_k, args.d_v,
                        self.dropout_rate, self.initializer)
                    enc_output *= non_pad_mask

                    enc_output = position_wise(enc_output, args.d_model,
                                               args.d_ff, self.dropout_rate,
                                               self.initializer)
                    enc_output *= non_pad_mask

        return enc_output, non_pad_mask
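The TensorFlow examples call the `common` mask helpers with an extra length argument (the maximum query length), presumably so the padding mask can be tiled to a static shape. A minimal sketch under the same assumptions as before (padding id 0, float masks so the `+`/`tf.math.greater` combination in the decoder works) could be:

import tensorflow as tf

PAD = 0  # assumed padding id

def get_non_pad_mask(seq):
    # (batch, len) -> (batch, len, 1): 1.0 for real tokens, 0.0 for padding
    return tf.expand_dims(tf.cast(tf.not_equal(seq, PAD), tf.float32), -1)

def get_attn_key_pad_mask(seq_k, seq_q, len_q):
    # (batch, len_q, len_k): 1.0 where the key position is padding
    pad_mask = tf.cast(tf.equal(seq_k, PAD), tf.float32)  # (batch, len_k)
    pad_mask = tf.expand_dims(pad_mask, 1)                # (batch, 1, len_k)
    return tf.tile(pad_mask, [1, len_q, 1])               # (batch, len_q, len_k)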
Example #5 — Transformer decoder (TensorFlow 1.x compat): masked self-attention, encoder-decoder attention, and position-wise feed-forward blocks; returns the last block's self-attention output.
    def decode(self, tgt_tensor, tgt_postion, tgt_max_len, src_tensor,
               enc_output):
        args = self.args

        with tf.compat.v1.variable_scope("decode",
                                         reuse=tf.compat.v1.AUTO_REUSE):

            # decode embedding
            dec_output = tf.nn.embedding_lookup(self.word_embedding,
                                                tgt_tensor)
            dec_output *= args.d_model**0.5

            dec_output += tf.nn.embedding_lookup(self.pos_embedding,
                                                 tgt_postion)

            dec_output = tf.nn.dropout(dec_output, keep_prob=self.dropout_rate)

            # decode mask
            non_pad_mask = common.get_non_pad_mask(tgt_tensor)
            slf_attn_mask_subseq = common.get_subsequent_mask(
                tgt_tensor, self.batch_size, tgt_max_len)
            slf_attn_mask_keypad = common.get_attn_key_pad_mask(
                tgt_tensor, tgt_tensor, tgt_max_len)
            slf_attn_mask = tf.math.greater(
                (slf_attn_mask_keypad + slf_attn_mask_subseq), 0)
            dec_enc_attn_mask = common.get_attn_key_pad_mask(
                src_tensor, tgt_tensor, tgt_max_len)

            # decode
            for i in range(args.dec_stack_layers):
                with tf.compat.v1.variable_scope(
                        f"num_blocks_{i}", reuse=tf.compat.v1.AUTO_REUSE):
                    dec_output, dec_slf_attn = multi_head_attention(
                        dec_output,
                        dec_output,
                        dec_output,
                        slf_attn_mask,
                        args.n_head,
                        args.d_model,
                        args.d_k,
                        args.d_v,
                        self.dropout_rate,
                        self.initializer,
                        scope="self_attention")
                    dec_output *= non_pad_mask
                    m_dec_output = dec_output

                    dec_output, dec_enc_attn = multi_head_attention(
                        dec_output,
                        enc_output,
                        enc_output,
                        dec_enc_attn_mask,
                        args.n_head,
                        args.d_model,
                        args.d_k,
                        args.d_v,
                        self.dropout_rate,
                        self.initializer,
                        scope="vanilla_attention")
                    dec_output *= non_pad_mask

                    dec_output = position_wise(dec_output, args.d_model,
                                               args.d_ff, self.dropout_rate,
                                               self.initializer)
                    dec_output *= non_pad_mask

            # return the last block's self-attention sub-layer output
            dec_output = m_dec_output

        return dec_output, non_pad_mask
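`get_subsequent_mask` in the TensorFlow version also takes the batch size and the maximum target length explicitly. A sketch consistent with that call signature (the `seq` argument is kept only to mirror the signature and is unused here) might be:

import tensorflow as tf

def get_subsequent_mask(seq, batch_size, max_len):
    # (batch, max_len, max_len): 1.0 strictly above the diagonal, i.e. the
    # future positions a decoder query must not attend to
    ones = tf.ones((max_len, max_len), dtype=tf.float32)
    upper = ones - tf.linalg.band_part(ones, -1, 0)  # strictly upper triangular
    return tf.tile(tf.expand_dims(upper, 0), [batch_size, 1, 1])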
Example #6 — Transformer decoder with pointer network (TensorFlow 1.x compat): decoder stack followed by additive pointer attention over the encoder outputs, returning a log-distribution over source positions.
    def decode(self, tgt_tensor, tgt_max_len, src_tensor, enc_output):
        args = self.args

        with tf.compat.v1.variable_scope("decode",
                                         reuse=tf.compat.v1.AUTO_REUSE):
            # decode embedding
            dec_output = tf.nn.embedding_lookup(self.word_embedding,
                                                tgt_tensor)
            dec_output *= args.d_model**0.5
            dec_output = tf.nn.dropout(dec_output, keep_prob=self.dropout_rate)

            # decode mask
            non_pad_mask = common.get_non_pad_mask(tgt_tensor)
            slf_attn_mask_subseq = common.get_subsequent_mask(
                tgt_tensor, self.batch_size, tgt_max_len)
            slf_attn_mask_keypad = common.get_attn_key_pad_mask(
                tgt_tensor, tgt_tensor, tgt_max_len)
            slf_attn_mask = tf.math.greater(
                (slf_attn_mask_keypad + slf_attn_mask_subseq), 0)
            dec_enc_attn_mask = common.get_attn_key_pad_mask(
                src_tensor, tgt_tensor, tgt_max_len)

            # decode
            for i in range(args.dec_stack_layers):
                with tf.compat.v1.variable_scope(
                        "num_blocks_{}".format(i),
                        reuse=tf.compat.v1.AUTO_REUSE):
                    dec_output, dec_slf_attn = multi_head_attention(
                        dec_output,
                        dec_output,
                        dec_output,
                        slf_attn_mask,
                        args.n_head,
                        args.d_model,
                        args.d_k,
                        args.d_v,
                        self.dropout_rate,
                        self.initializer,
                        scope="self_attention")
                    dec_output *= non_pad_mask
                    m_dec_output = dec_output

                    dec_output, dec_enc_attn = multi_head_attention(
                        dec_output,
                        enc_output,
                        enc_output,
                        dec_enc_attn_mask,
                        args.n_head,
                        args.d_model,
                        args.d_k,
                        args.d_v,
                        self.dropout_rate,
                        self.initializer,
                        scope="vanilla_attention")
                    dec_output *= non_pad_mask

                    dec_output = position_wise(dec_output, args.d_model,
                                               args.d_ff, self.dropout_rate,
                                               self.initializer)
                    dec_output *= non_pad_mask

                    # keep the self-attention sub-layer output as the block output
                    dec_output = m_dec_output

        with tf.compat.v1.variable_scope("pointer",
                                         reuse=tf.compat.v1.AUTO_REUSE):
            last_enc_output = tf.layers.dense(
                enc_output,
                args.d_model,
                use_bias=False,
                kernel_initializer=self.initializer)  # bsz slen dim
            last_enc_output = tf.expand_dims(last_enc_output,
                                             0)  # 1 bsz slen dim

            dec_output_trans = tf.transpose(dec_output,
                                            [1, 0, 2])  # tlen bsz dim
            dec_output_trans = tf.layers.dense(
                dec_output_trans,
                args.d_model,
                kernel_initializer=self.initializer,
                use_bias=False,
                name="pointer_decode",
                reuse=tf.compat.v1.AUTO_REUSE)  # tlen bsz dim
            dec_output_trans = tf.expand_dims(dec_output_trans,
                                              2)  # tlen bsz 1 dim

            attn_encode = tf.nn.tanh(dec_output_trans +
                                     last_enc_output)  # tlen bsz slen dim
            attn_encode = tf.layers.dense(
                attn_encode,
                1,
                kernel_initializer=self.initializer,
                use_bias=False,
                name="pointer_v",
                reuse=tf.compat.v1.AUTO_REUSE)  # tlen bsz slen 1
            attn_encode = tf.transpose(tf.squeeze(attn_encode, 3),
                                       [1, 0, 2])  # bsz tlen slen
            distributes = tf.nn.log_softmax(attn_encode, axis=-1) + 1e-9

        return distributes, dec_output
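The pointer block in Example #6 is additive attention realized through broadcasting: the projected decoder states of shape (tlen, bsz, 1, dim) and the projected encoder states of shape (1, bsz, slen, dim) add up to a (tlen, bsz, slen, dim) tensor, which is squashed with tanh, reduced to one score per source position, and normalized with a log-softmax over the source axis. A minimal shape check of that broadcast, with purely hypothetical sizes:

import tensorflow as tf

tlen, bsz, slen, dim = 4, 2, 6, 8       # hypothetical sizes
dec = tf.zeros((tlen, bsz, 1, dim))     # stands in for the projected decoder states
enc = tf.zeros((1, bsz, slen, dim))     # stands in for the projected encoder states
scores = tf.nn.tanh(dec + enc)          # broadcasts to (tlen, bsz, slen, dim)
print(scores.shape)                     # (4, 2, 6, 8)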