Example #1
    def call(self, inputs, mask=None, a_mask=None, position_bias=None):
        """
        Multi-head attention.
        :param inputs: [q, k, v, a_mask, position_bias]
        :param mask: masks matching the inputs [q, k, v]; mask[0] is the
            query mask (masks padding in the query sequence), mask[2] is the
            value mask (hides certain value positions, e.g. padding)
        :param a_mask: Boolean, whether to apply an attention mask
        :param position_bias: type of position bias; offsets the attention
            scores with the specified kind of positional encoding
        :return:
        """
        q, k, v = inputs[:3]
        q_mask, v_mask, idx = None, None, 3
        if mask is not None:
            if mask[0] is not None:
                q_mask = K.cast(mask[0], K.floatx())
            if mask[2] is not None:
                v_mask = K.cast(mask[2], K.floatx())
        if a_mask is not None:
            # a truthy flag means the actual attention mask tensor is passed in inputs
            a_mask = inputs[idx]
            idx += 1

        # Linear projections
        qw = self.q_dense(q)
        kw = self.k_dense(k)
        vw = self.v_dense(v)

        # Reshape to split out the heads
        qw = K.reshape(qw, [-1, K.shape(q)[1], self.head_nums, self.key_size])
        kw = K.reshape(kw, [-1, K.shape(k)[1], self.head_nums, self.key_size])
        vw = K.reshape(vw, [-1, K.shape(v)[1], self.head_nums, self.head_size])
        # Compute attention scores
        att = tf.einsum('bjhd,bkhd->bhjk', qw, kw)
        # Handle positional encoding: add relative position bias to the scores
        if position_bias == 'relative':
            position_embeddings = inputs[idx]
            att = att + tf.einsum('bjhd,jkd->bhjk', qw, position_embeddings)

        if self.attention_scale:
            att = att / self.key_size**0.5

        # value mask: hide padded value positions (additive, applied on the last axis)
        att = sequence_masking(att, v_mask, 'add', -1)
        # attention mask: push masked positions to a large negative score
        if a_mask is not None:
            att = att - (1 - a_mask) * 1e12

        att = K.softmax(att)
        output = tf.einsum('bhjk,bkhd->bjhd', att, vw)
        # Positional encoding, continued: add the relative position term to the output
        if position_bias == 'relative':
            output = output + tf.einsum('bhjk,jkd->bjhd', att,
                                        position_embeddings)
        output = K.reshape(output, (-1, K.shape(output)[1], self.output_dim))
        output = self.combine_dense(output)
        # query mask: zero out padded query positions (multiplicative)
        output = sequence_masking(output, q_mask, 'mul')
        return output
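For reference, the two einsum patterns above operate on per-head tensors. The standalone snippet below (with illustrative sizes, not values taken from the layer) traces the shapes involved:

import tensorflow as tf

# illustrative sizes: batch=2, q_len=k_len=5, heads=4, per-head dim=8
b, j, k_len, h, d = 2, 5, 5, 4, 8
qw = tf.random.normal([b, j, h, d])      # queries, split into heads
kw = tf.random.normal([b, k_len, h, d])  # keys, split into heads
vw = tf.random.normal([b, k_len, h, d])  # values, split into heads
pos = tf.random.normal([j, k_len, d])    # relative position embeddings

att = tf.einsum('bjhd,bkhd->bhjk', qw, kw)   # scores: (2, 4, 5, 5)
att += tf.einsum('bjhd,jkd->bhjk', qw, pos)  # add relative position bias
att = tf.nn.softmax(att / d ** 0.5)          # normalize over the key axis
out = tf.einsum('bhjk,bkhd->bjhd', att, vw)  # context: (2, 5, 4, 8)
print(att.shape, out.shape)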
Example #2
    def call(self, inputs, mask=None):
        # Only computes the loss; the input itself is not modified
        if mask is not None:
            mask = K.cast(mask, K.floatx())

        # additive masking (mode=1) along the sequence axis (axis=1)
        return sequence_masking(inputs, mask, 1, 1)
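Both examples depend on a sequence_masking helper that is not shown here. Below is a minimal sketch consistent with the calls above (and with the bert4keras helper of the same name): mode 0/'mul' zeroes out masked positions, mode 1/'add' subtracts a large constant so that softmax ignores them, and axis selects the sequence dimension, with -1 meaning the last one.

from keras import backend as K

def sequence_masking(x, mask, mode=0, axis=None):
    """Mask the sequence dimension of x with a [batch, seq_len] mask."""
    if mask is None or mode not in [0, 1, 'mul', 'add']:
        return x
    if axis is None:
        axis = 1
    if axis == -1:
        axis = K.ndim(x) - 1
    # broadcast the mask to the rank of x, placing seq_len on `axis`
    for _ in range(axis - 1):
        mask = K.expand_dims(mask, 1)
    for _ in range(K.ndim(x) - K.ndim(mask)):
        mask = K.expand_dims(mask, K.ndim(mask))
    if mode in [0, 'mul']:
        return x * mask
    return x - (1 - mask) * 1e12

With this definition, sequence_masking(att, v_mask, 'add', -1) in Example #1 blocks attention to padded key/value positions, while sequence_masking(output, q_mask, 'mul') zeroes padded query positions in the final output.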