Example 1
 # requires numpy as np and the Keras backend as K in scope
 def compute_position_ids(self, inputs):
     """T5's relative position bucketing (ported directly from the official T5 source).
     i-j:    0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 ...
     f(i-j): 0 1 2 3 4 5 6 7 8 8 8  8  9  9  9  ...
     """
     q, v = inputs
     # compute the matrix of relative position differences j - i
     q_idxs = K.arange(0, K.shape(q)[1], dtype='int32')
     q_idxs = K.expand_dims(q_idxs, 1)
     v_idxs = K.arange(0, K.shape(v)[1], dtype='int32')
     v_idxs = K.expand_dims(v_idxs, 0)
     pos_ids = v_idxs - q_idxs
     # map the differences into buckets: exact for small distances,
     # logarithmically spaced beyond max_exact
     num_buckets, max_distance = self.input_dim, self.max_distance
     ret = 0
     n = -pos_ids
     if self.bidirectional:
         # half the buckets for each direction
         num_buckets //= 2
         ret += K.cast(K.less(n, 0), 'int32') * num_buckets
         n = K.abs(n)
     else:
         n = K.maximum(n, 0)
     # now n is in the range [0, inf)
     max_exact = num_buckets // 2
     is_small = K.less(n, max_exact)
     val_if_large = max_exact + K.cast(
         K.log(K.cast(n, K.floatx()) / max_exact) /
         np.log(max_distance / max_exact) * (num_buckets - max_exact),
         'int32',
     )
     val_if_large = K.minimum(val_if_large, num_buckets - 1)
     ret += K.switch(is_small, n, val_if_large)
     return ret
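For a quick sanity check of the docstring's mapping, here is a minimal standalone NumPy sketch of the same bucketing. It assumes T5's defaults, num_buckets=32 (the layer's self.input_dim) and max_distance=128; those values are not stated in the snippet itself.

    import numpy as np

    def t5_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        ret = 0
        n = -relative_position
        if bidirectional:
            num_buckets //= 2
            ret += (n < 0).astype(np.int32) * num_buckets
            n = np.abs(n)
        else:
            n = np.maximum(n, 0)
        max_exact = num_buckets // 2
        is_small = n < max_exact
        # the epsilon only guards log(0); is_small masks that branch anyway
        val_if_large = max_exact + (
            np.log(n / max_exact + 1e-9) / np.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).astype(np.int32)
        val_if_large = np.minimum(val_if_large, num_buckets - 1)
        return ret + np.where(is_small, n, val_if_large)

    # i - j = 0..14, i.e. relative_position = j - i = -(i - j)
    print(t5_bucket(-np.arange(15)))
    # [0 1 2 3 4 5 6 7 8 8 8 8 9 9 9], matching the f(i-j) row above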
Example 2
 def compute_position_idx(self, inputs):
     q, v = inputs
     q_idx = K.arange(0, K.shape(q)[1], dtype='int32')
     q_idx = K.expand_dims(q_idx, 1)
     v_idx = K.arange(0, K.shape(v)[1], dtype='int32')
     v_idx = K.expand_dims(v_idx, 0)
     # relative position differences j - i
     position_idx = v_idx - q_idx
     # clip to the representable window, then shift into [0, input_dim)
     max_position = (self.input_dim - 1) // 2
     position_idx = K.clip(position_idx, -max_position, max_position)
     position_idx = position_idx + max_position
     return position_idx

 def get_labels_of_similarity(self, y_pred):
     # rows come in consecutive pairs (0,1), (2,3), ...; each sample's
     # positive is its partner: even index i pairs with i+1, odd with i-1
     idxs = K.arange(0, K.shape(y_pred)[0])
     idxs_1 = idxs[None, :]
     idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]
     labels = K.equal(idxs_1, idxs_2)
     labels = K.cast(labels, K.floatx())
     return labels
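The index trick in get_labels_of_similarity evidently assumes the batch is laid out as consecutive positive pairs (each sample followed by its twin, as in SimCSE-style training): idxs + 1 - idxs % 2 * 2 sends every even index i to i + 1 and every odd index to i - 1. A minimal NumPy sketch of the resulting label matrix for a batch of 4:

    import numpy as np

    idxs = np.arange(4)
    partner = idxs + 1 - idxs % 2 * 2   # [1, 0, 3, 2]: pairs 2k <-> 2k+1
    labels = (idxs[None, :] == partner[:, None]).astype(np.float32)
    print(labels)
    # [[0. 1. 0. 0.]
    #  [1. 0. 0. 0.]
    #  [0. 0. 0. 1.]
    #  [0. 0. 1. 0.]]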
Example 4
    def call(self, inputs):
        # PE_{2i}(p)   = sin(p / 10000^(2i/d_pos))
        # PE_{2i+1}(p) = cos(p / 10000^(2i/d_pos))
        batch_size = K.shape(inputs)[0]
        seq_len = K.shape(inputs)[1]
        word_emb_dim = K.shape(inputs)[2]
        if not self.embedding_dim or self.method == 'add':
            self.embedding_dim = word_emb_dim
        t = 2 * K.arange(self.embedding_dim / 2, dtype='float32') / K.cast(
            self.embedding_dim, dtype='float32')
        embedding_wise_pos = 1. / K.pow(10000., t)  # 1/10000^(2i/d_pos), shape (p_dim/2,)
        embedding_wise_pos = K.expand_dims(embedding_wise_pos, 0)  # (1, p_dim/2)
        # positions 1..seq_len via a cumulative sum of ones
        word_wise_pos = K.cumsum(K.ones_like(inputs[:, :, 0]), axis=1)  # (batch_size, seq_len)
        word_wise_pos = K.expand_dims(word_wise_pos, 2)  # (batch_size, seq_len, 1)
        # outer product of positions and inverse frequencies
        position_embedding = K.dot(word_wise_pos, embedding_wise_pos)  # (batch_size, seq_len, p_dim/2)

        # interleave sin and cos along the last axis: [sin_0, cos_0, sin_1, cos_1, ...]
        position_embedding = K.expand_dims(position_embedding, 3)
        position_embedding = K.reshape(
            K.concatenate([K.sin(position_embedding), K.cos(position_embedding)], axis=-1),
            shape=(batch_size, seq_len, -1))

        if self.method == 'add':
            return inputs + position_embedding

        return K.concatenate([inputs, position_embedding], axis=-1)
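For a shape check, a minimal NumPy sketch of the same construction: positions start at 1 (the K.cumsum over ones above is 1-based), and the expand/concatenate/reshape sequence interleaves sin and cos along the feature axis.

    import numpy as np

    def sinusoidal_pe(seq_len, dim):
        pos = np.arange(1, seq_len + 1, dtype=np.float32)[:, None]        # (seq_len, 1), 1-based
        inv_freq = 1. / np.power(10000., 2. * np.arange(dim // 2) / dim)  # (dim/2,)
        angles = pos * inv_freq[None, :]                                  # (seq_len, dim/2)
        # stack-then-reshape interleaves: [sin_0, cos_0, sin_1, cos_1, ...]
        return np.stack([np.sin(angles), np.cos(angles)], axis=-1).reshape(seq_len, dim)

    print(sinusoidal_pe(seq_len=6, dim=8).shape)  # (6, 8)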
Example 6
    def call(self, inputs, **kwargs):
        # (requires tensorflow as tf and the Keras backend as K in scope)
        q, v = inputs
        seq_len = K.shape(q)[1]
        # prepend a virtual head position so [CLS] is decoupled from the rest
        position_ids = K.arange(0, seq_len + 1, 'int32')
        x = K.gather(self.embeddings, position_ids)
        x = self.pos_ln(x)
        q = self.q_dense(x) * self.pos_scaling
        k = self.k_dense(x)
        q = K.reshape(q, (seq_len + 1, self.num_attention_heads, -1))
        k = K.reshape(k, (seq_len + 1, self.num_attention_heads, -1))

        abs_pos_bias = tf.einsum('jhd,khd->hjk', q, k)
        # p_0 . p_0 is the cls-to-others bias
        cls_2_other = abs_pos_bias[:, 0, 0]
        # p_1 . p_1 is the others-to-cls bias
        other_2_cls = abs_pos_bias[:, 1, 1]
        # drop the virtual position, keeping the (seq_len, seq_len) block
        abs_pos_bias = abs_pos_bias[:, 1:, 1:]
        # TF tensors do not support item assignment, so rebuild column 0
        # (everything attending to cls) and row 0 (cls attending to everything)
        col = K.tile(K.reshape(other_2_cls, (-1, 1, 1)), (1, seq_len, 1))
        abs_pos_bias = K.concatenate([col, abs_pos_bias[:, :, 1:]], axis=2)
        row = K.tile(K.reshape(cls_2_other, (-1, 1, 1)), (1, 1, seq_len))
        abs_pos_bias = K.concatenate([row, abs_pos_bias[:, 1:, :]], axis=1)
        if self.relative_position_bias:
            rel_pos_bias = self.compute_rel_pos_bias(inputs)
            abs_pos_bias += rel_pos_bias
        return abs_pos_bias
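The layer returns a per-head (seq_len, seq_len) bias, in the spirit of TUPE's untied positional attention. A hedged sketch of how such a bias is typically consumed, added to the content attention logits before the softmax; every name and shape below is illustrative rather than taken from the snippet:

    import tensorflow as tf

    batch, heads, L, d = 2, 8, 16, 32
    content_q = tf.random.normal((batch, L, heads, d))
    content_k = tf.random.normal((batch, L, heads, d))
    abs_pos_bias = tf.random.normal((heads, L, L))  # stand-in for this layer's output

    content_logits = tf.einsum('bqhd,bkhd->bhqk', content_q, content_k)
    logits = content_logits / (d ** 0.5) + abs_pos_bias[None]  # broadcast over the batch axis
    probs = tf.nn.softmax(logits, axis=-1)
    print(probs.shape)  # (2, 8, 16, 16)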