def call(self, inputs, **kwargs):
    q, v = inputs
    seq_len = K.shape(q)[1]
    # Prepend a virtual head position so that [CLS] can be decoupled from the rest
    position_ids = K.arange(0, seq_len + 1, dtype='int32')
    x = K.gather(self.embeddings, position_ids)
    x = self.pos_ln(x)
    q = self.q_dense(x) * self.pos_scaling
    k = self.k_dense(x)
    q = K.reshape(q, (seq_len + 1, self.num_attention_heads, -1))
    k = K.reshape(k, (seq_len + 1, self.num_attention_heads, -1))
    abs_pos_bias = tf.einsum('jhd,khd->hjk', q, k)
    # p_0 \dot p_0 is the cls-to-others bias
    cls_2_other = abs_pos_bias[:, 0, 0]
    # p_1 \dot p_1 is the others-to-cls bias
    other_2_cls = abs_pos_bias[:, 1, 1]
    # Drop the virtual position, then overwrite column 0 (keys from [CLS]) and
    # row 0 (queries from [CLS]); TF tensors do not support in-place slice
    # assignment, so the rows/columns are rebuilt with concatenation.
    abs_pos_bias = abs_pos_bias[:, 1:, 1:]
    col_0 = K.tile(K.reshape(other_2_cls, (-1, 1, 1)), [1, seq_len, 1])
    abs_pos_bias = K.concatenate([col_0, abs_pos_bias[:, :, 1:]], axis=2)
    row_0 = K.tile(K.reshape(cls_2_other, (-1, 1, 1)), [1, 1, seq_len])
    abs_pos_bias = K.concatenate([row_0, abs_pos_bias[:, 1:, :]], axis=1)
    if self.relative_position_bias:
        rel_pos_bias = self.compute_rel_pos_bias(inputs)
        abs_pos_bias += rel_pos_bias
    return abs_pos_bias
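The layer above returns a per-head position bias of shape (num_heads, seq_len, seq_len). As a minimal sketch (the consuming attention layer is not shown here, and `content_scores` and `apply_position_bias` are hypothetical names), such an untied position bias would typically be folded into the attention logits like this:

import tensorflow as tf

def apply_position_bias(content_scores, abs_pos_bias):
    # content_scores: (batch, heads, seq_len, seq_len) query-key content logits
    # abs_pos_bias:   (heads, seq_len, seq_len) output of the layer above
    # Broadcast the bias over the batch axis and add it before the softmax.
    logits = content_scores + tf.expand_dims(abs_pos_bias, 0)
    return tf.nn.softmax(logits, axis=-1)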
def call(self, inputs):
    # Look up the relative position bias for every query-key pair
    relative_position_idx = self.compute_position_idx(inputs)
    return K.gather(self.embeddings, relative_position_idx)
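`compute_position_idx` is not shown in this snippet. As a rough sketch only, assuming a simple clipped relative-distance scheme and a hypothetical `max_relative_position` attribute, such a helper commonly looks like:

def compute_position_idx(self, inputs):
    # Hypothetical sketch: pairwise offsets clipped to a fixed window and
    # shifted to non-negative indices into self.embeddings.
    q, v = inputs
    q_idx = K.expand_dims(K.arange(0, K.shape(q)[1], dtype='int32'), 1)
    v_idx = K.expand_dims(K.arange(0, K.shape(v)[1], dtype='int32'), 0)
    rel = v_idx - q_idx
    rel = K.clip(rel, -self.max_relative_position, self.max_relative_position)
    return rel + self.max_relative_position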