def call(self, x):
    x, mask = x
    # self.weight is only two-dimensional, so the mask's dimensions are handled here first
    mask = K.squeeze(mask, axis=2)  # drop axis=2; the data itself is unchanged
    # linear transformation: K.dot() is the matrix product, then self.bias is added
    key = K.bias_add(K.dot(x, self.weight), self.bias)
    # compute attention logits, one score per time step
    outputs = K.squeeze(K.dot(key, self.query), axis=2)
    # push padded positions towards -inf so softmax gives them ~0 weight
    outputs -= 1e32 * (1 - mask)
    attn_scores = K.softmax(outputs)  # attention weights via softmax
    attn_scores *= mask
    # weighted sum over the sequence: (batch, 1, seq) x (batch, seq, units) -> (batch, units)
    attn_scores = K.reshape(attn_scores, shape=(-1, 1, attn_scores.shape[-1]))
    outputs = K.squeeze(K.batch_dot(attn_scores, key), axis=1)
    return outputs
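For context, the call above expects inputs = [x, mask], with x of shape (batch, seq_len, dim) and mask of shape (batch, seq_len, 1), and it relies on weights created in the layer's build. Below is a minimal sketch of what that enclosing layer might look like; the class name AttentionPooling1D, the units argument, and the weight shapes are assumptions for illustration, not taken from the original code.

from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer

class AttentionPooling1D(Layer):
    """Sketch: pools (batch, seq_len, dim) into (batch, units) with additive attention."""

    def __init__(self, units, **kwargs):
        super(AttentionPooling1D, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # input_shape is [x_shape, mask_shape]; the weights match the call() shown above
        feature_dim = int(input_shape[0][-1])
        self.weight = self.add_weight(name='weight', shape=(feature_dim, self.units),
                                      initializer='glorot_uniform')
        self.bias = self.add_weight(name='bias', shape=(self.units,),
                                    initializer='zeros')
        self.query = self.add_weight(name='query', shape=(self.units, 1),
                                     initializer='glorot_uniform')
        super(AttentionPooling1D, self).build(input_shape)

With such a wrapper, AttentionPooling1D(units)([x, mask]) would return one pooled vector of size units per sequence.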
def call(self, inputs, q_mask=False, v_mask=False, a_mask=False):
    """Multi-head attention.
    q_mask: mask for the input query sequence;
            mainly zeroes out the padded positions of the output.
    v_mask: mask for the input value sequence;
            mainly keeps attention from reading padding information.
    a_mask: mask for the attention matrix;
            different attention masks correspond to different applications.
    """
    q, k, v = inputs[:3]
    # handle the masks: extra tensors in `inputs` are consumed in the order q_mask, v_mask, a_mask
    idx = 3
    if q_mask:
        q_mask = inputs[idx]
        idx += 1
    else:
        q_mask = None
    if v_mask:
        v_mask = inputs[idx]
        idx += 1
    else:
        v_mask = None
    if a_mask:
        if len(inputs) > idx:
            a_mask = inputs[idx]
        else:
            a_mask = 'history_only'
    else:
        a_mask = None
    # linear projections
    qw = self.q_dense(q)
    kw = self.k_dense(k)
    vw = self.v_dense(v)
    input_shape = K.shape(q)
    batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]
    # reshape to (batch, seq_len, heads, size); assumes q, k, v share the same seq_len
    qw = K.reshape(qw, (-1, K.shape(q)[1], self.heads, self.key_size))
    kw = K.reshape(kw, (-1, K.shape(k)[1], self.heads, self.key_size))
    vw = K.reshape(vw, (-1, K.shape(v)[1], self.heads, self.head_size))
    # move heads next to the batch axis, then fold them into it: (batch * heads, seq_len, size)
    qw = K.permute_dimensions(qw, [0, 2, 1, 3])
    kw = K.permute_dimensions(kw, [0, 2, 1, 3])
    vw = K.permute_dimensions(vw, [0, 2, 1, 3])
    qw = K.reshape(qw, (batch_size * self.heads, seq_len, self.key_size))
    kw = K.reshape(kw, (batch_size * self.heads, seq_len, self.key_size))
    vw = K.reshape(vw, (batch_size * self.heads, seq_len, self.head_size))
    # scaled dot-product attention
    # NOTE: this reworked path applies only q_mask; v_mask and a_mask are parsed above
    # but are only used in the commented-out variant below.
    a = K.batch_dot(qw, kw, axes=2) / K.sqrt(K.cast(self.key_size, dtype=K.floatx()))
    a = K.softmax(a)
    o = K.batch_dot(a, vw)
    # restore (batch, seq_len, heads * head_size) and project
    o = K.reshape(o, (batch_size, self.heads, seq_len, self.head_size))
    o = K.permute_dimensions(o, [0, 2, 1, 3])
    o = K.reshape(o, (batch_size, seq_len, self.out_dim))
    o = self.o_dense(o)
    o = sequence_masking(o, q_mask, 0)
    # The einsum-based variant is kept commented out for reference:
    # # Attention
    # a = tf.einsum('bjhd,bkhd->bhjk', qw, kw) / self.key_size**0.5
    # a = sequence_masking(a, v_mask, 1, -1)
    # if a_mask is not None:
    #     if is_string(a_mask) and a_mask == 'history_only':
    #         ones = K.ones_like(a[:1, :1])
    #         a_mask = (ones - tf.linalg.band_part(ones, -1, 0)) * 1e12
    #         a = a - a_mask
    #     else:
    #         a = a - (1 - a_mask) * 1e12
    # a = K.softmax(a)
    # # produce the output
    # o = tf.einsum('bhjk,bkhd->bjhd', a, vw)
    # o = K.reshape(o, (-1, K.shape(o)[1], self.out_dim))
    # o = self.o_dense(o)
    # o = sequence_masking(o, q_mask, 0)
    return o
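Because the optional masks travel inside the inputs list rather than as separate arguments, calling either version looks roughly like the sketch below. The layer name MultiHeadAttention and its constructor arguments are assumptions for illustration; only the input convention (q, k, v first, then any mask tensors in the order q_mask, v_mask, a_mask, with the matching flags set to True) follows from the call signature itself.

# Hypothetical usage sketch; layer and constructor names are assumptions.
import numpy as np
from tensorflow.keras import backend as K

x = K.constant(np.random.rand(2, 6, 32).astype('float32'))   # (batch, seq_len, dim)
v_mask = K.constant([[1, 1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1, 1]], dtype='float32')    # (batch, seq_len), 0 = padding

attn = MultiHeadAttention(heads=4, head_size=8)               # assumed constructor
# self-attention: q = k = v = x; the extra tensor is consumed as v_mask because v_mask=True
o = attn([x, x, x, v_mask], v_mask=True)                      # -> (batch, seq_len, out_dim)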
def call(self, inputs, q_mask=False, v_mask=False, a_mask=False):
    """Multi-head attention.
    q_mask: mask for the input query sequence;
            mainly zeroes out the padded positions of the output.
    v_mask: mask for the input value sequence;
            mainly keeps attention from reading padding information.
    a_mask: mask for the attention matrix;
            different attention masks correspond to different applications.
    """
    q, k, v = inputs[:3]
    # handle the masks: extra tensors in `inputs` are consumed in the order q_mask, v_mask, a_mask
    idx = 3
    if q_mask:
        q_mask = inputs[idx]
        idx += 1
    else:
        q_mask = None
    if v_mask:
        v_mask = inputs[idx]
        idx += 1
    else:
        v_mask = None
    if a_mask:
        if len(inputs) > idx:
            a_mask = inputs[idx]
        else:
            a_mask = 'history_only'
    else:
        a_mask = None
    # linear projections
    qw = self.q_dense(q)
    kw = self.k_dense(k)
    vw = self.v_dense(v)
    # reshape to (batch, seq_len, heads, size)
    qw = K.reshape(qw, (-1, K.shape(q)[1], self.heads, self.key_size))
    kw = K.reshape(kw, (-1, K.shape(k)[1], self.heads, self.key_size))
    vw = K.reshape(vw, (-1, K.shape(v)[1], self.heads, self.head_size))
    # transpose to (batch, heads, seq_len, size)
    qw = K.permute_dimensions(qw, (0, 2, 1, 3))
    kw = K.permute_dimensions(kw, (0, 2, 1, 3))
    vw = K.permute_dimensions(vw, (0, 2, 1, 3))
    # fold heads into the batch axis: (batch * heads, seq_len, size)
    qw = K.reshape(qw, (-1, K.shape(q)[1], self.key_size))
    kw = K.reshape(kw, (-1, K.shape(k)[1], self.key_size))
    vw = K.reshape(vw, (-1, K.shape(v)[1], self.head_size))
    # Attention
    a = K.batch_dot(qw, kw, [2, 2]) / self.key_size**0.5
    a = sequence_masking(a, v_mask, 1, -1, self.heads)
    if a_mask is not None:
        if is_string(a_mask) and a_mask == 'history_only':
            # causal mask: positions above the diagonal are pushed towards -inf
            ones = K.ones_like(a[:1])
            a_mask = (ones - tf.linalg.band_part(ones, -1, 0)) * 1e12
            a = a - a_mask
        else:
            a = a - (1 - a_mask) * 1e12
    a = K.softmax(a)
    # produce the output
    o = K.batch_dot(a, vw, [2, 1])
    o = K.reshape(o, (-1, self.heads, K.shape(q)[1], self.head_size))
    o = K.permute_dimensions(o, (0, 2, 1, 3))
    o = K.reshape(o, (-1, K.shape(o)[1], self.out_dim))
    o = self.o_dense(o)
    o = sequence_masking(o, q_mask, 0)
    return o
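Both implementations also depend on helpers that are not shown here: sequence_masking and, in the a_mask branch, is_string. Based only on how sequence_masking is called above (mode 0 to zero out padded output rows, mode 1 with an axis and a heads count to push padded attention logits towards -inf), a plausible sketch follows; the actual helper in the source codebase may differ in signature and details.

from tensorflow.keras import backend as K

def sequence_masking(x, mask, mode=0, axis=None, heads=1):
    """Sketch of the assumed helper.
    mask:  (batch, seq_len) or (batch, seq_len, 1) tensor of 1s (keep) and 0s (padding).
    mode:  0 multiplies, i.e. zeroes out masked positions;
           1 subtracts a large constant so softmax ignores masked positions.
    axis:  which axis of x the mask applies to (negative counts from the end).
    heads: repeats the mask along the batch axis when x is shaped (batch * heads, ...).
    """
    if mask is None:
        return x
    mask = K.cast(mask, K.floatx())
    if K.ndim(mask) == 3:
        mask = K.squeeze(mask, axis=2)                    # -> (batch, seq_len)
    if heads > 1:
        mask = K.repeat_elements(mask, heads, axis=0)     # match (batch * heads, ...)
    if axis is None:
        axis = 1
    if axis < 0:
        axis = K.ndim(x) + axis
    # broadcast the mask to the rank of x, aligned with `axis`
    for _ in range(axis - 1):
        mask = K.expand_dims(mask, 1)
    for _ in range(K.ndim(x) - K.ndim(mask)):
        mask = K.expand_dims(mask, -1)
    if mode == 0:
        return x * mask
    else:
        return x - (1 - mask) * 1e12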