def call(self, inputs, mask=None, a_mask=None, p_bias=None): """实现多头注意力 q_mask: 对输入的query序列的mask。 主要是将输出结果的padding部分置0。 v_mask: 对输入的value序列的mask。 主要是防止attention读取到padding信息。 a_mask: 对attention矩阵的mask。 不同的attention mask对应不同的应用。 p_bias: 在attention里的位置偏置。 一般用来指定相对位置编码的种类。 """ q, k, v = inputs[:3] q_mask, v_mask, n = None, None, 3 if mask is not None: if mask[0] is not None: q_mask = K.cast(mask[0], K.floatx()) if mask[2] is not None: v_mask = K.cast(mask[2], K.floatx()) if a_mask: a_mask = inputs[n] n += 1 # 线性变换 qw = self.q_dense(q) kw = self.k_dense(k) vw = self.v_dense(v) # 形状变换 qw = K.reshape(qw, (-1, K.shape(q)[1], self.heads, self.key_size)) kw = K.reshape(kw, (-1, K.shape(k)[1], self.heads, self.key_size)) vw = K.reshape(vw, (-1, K.shape(v)[1], self.heads, self.head_size)) # Attention a = tf.einsum('bjhd,bkhd->bhjk', qw, kw) # 处理位置编码 if p_bias == 'typical_relative': pos_embeddings = inputs[n] a = a + tf.einsum('bjhd,jkd->bhjk', qw, pos_embeddings) elif p_bias == 't5_relative': pos_embeddings = K.permute_dimensions(inputs[n], (2, 0, 1)) a = a + K.expand_dims(pos_embeddings, 0) # Attention(续) if self.attention_scale: a = a / self.key_size**0.5 a = sequence_masking(a, v_mask, 1, -1) if a_mask is not None: a = a - (1 - a_mask) * 1e12 a = K.softmax(a) # 完成输出 o = tf.einsum('bhjk,bkhd->bjhd', a, vw) if p_bias == 'typical_relative': o = o + tf.einsum('bhjk,jkd->bjhd', a, pos_embeddings) o = K.reshape(o, (-1, K.shape(o)[1], self.out_dim)) o = self.o_dense(o) # 返回结果 o = sequence_masking(o, q_mask, 0) return o
def call(self, inputs, mask=None): if mask is not None: mask = K.cast(mask, K.floatx()) return sequence_masking(inputs, mask, 1, 1)