Example #1
 def compute_position_ids(self, inputs):
     """T5的相对位置分桶(直接翻译自官方T5源码)
     i-i:   0 1 2 3 4 5 6 7 8 9 10 11 12 13 14...
     f(i-j):0 1 2 3 4 5 6 7 8 8 8  8  9   9  9 ...
     """
     q, v = inputs
     # compute the relative position differences
     q_idxs = K.arange(0, K.shape(q)[1], dtype='int32')
     q_idxs = K.expand_dims(q_idxs, 1)
     v_idxs = K.arange(0, K.shape(v)[1], dtype='int32')
     v_idxs = K.expand_dims(v_idxs, 0)
     pos_ids = v_idxs - q_idxs
     # post-processing: map the differences to buckets
     num_buckets, max_distance = self.input_dim, self.max_distance
     ret = 0
     n = -pos_ids
     if self.bidirectional:
         num_buckets //= 2
         ret += K.cast(K.less(n, 0), 'int32') * num_buckets
         n = K.abs(n)
     else:
         n = K.maximum(n, 0)
     # now n is in the range [0, inf)
     max_exact = num_buckets // 2
     is_small = K.less(n, max_exact)
     val_if_large = max_exact + K.cast(
         K.log(K.cast(n, K.floatx()) / max_exact) /
         np.log(max_distance / max_exact) * (num_buckets - max_exact),
         'int32',
     )
     val_if_large = K.minimum(val_if_large, num_buckets - 1)
     ret += K.switch(is_small, n, val_if_large)
     return ret
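For reference, here is a minimal NumPy sketch of the same bucketing rule; the function name t5_bucket and the defaults num_buckets=32, max_distance=128 are assumptions (the usual T5 settings), not values taken from the snippet above. It reproduces the mapping shown in the docstring:

import numpy as np

def t5_bucket(pos_ids, bidirectional=True, num_buckets=32, max_distance=128):
    # pos_ids = v_idx - q_idx, as in the layer above; n = i - j
    n = -pos_ids
    ret = 0
    if bidirectional:
        num_buckets //= 2
        ret += (n < 0).astype('int32') * num_buckets
        n = np.abs(n)
    else:
        n = np.maximum(n, 0)
    max_exact = num_buckets // 2
    is_small = n < max_exact
    val_if_large = max_exact + (
        np.log(n.clip(1) / max_exact) / np.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).astype('int32')
    val_if_large = np.minimum(val_if_large, num_buckets - 1)
    return ret + np.where(is_small, n, val_if_large)

# i - j = 0..14 (i.e. pos_ids = -(i - j)) gives the row from the docstring:
print(t5_bucket(-np.arange(15)))  # [0 1 2 3 4 5 6 7 8 8 8 8 9 9 9]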
Example #2
    def call(self, inputs):
        #     PE_2i(p) = sin(p/10000^(2i/d_pos))
        #     PE_2i+1(p) = cos(p/10000^(2i/d_pos))
        batch_size, seq_len, word_emb_dim = (K.shape(inputs)[0],
                                             K.shape(inputs)[1],
                                             K.shape(inputs)[2])
        if not self.embedding_dim or self.method == 'add':
            self.embedding_dim = word_emb_dim
        # 1 / 10000^(2i/d_pos), shape = (p_dim/2,)
        t = (2 * K.arange(self.embedding_dim / 2, dtype='float32') /
             K.cast(self.embedding_dim, dtype='float32'))
        embedding_wise_pos = 1. / K.pow(10000., t)
        embedding_wise_pos = K.expand_dims(embedding_wise_pos, 0)  # (1, p_dim/2)
        # positions 1..seq_len, shape = (batch_size, seq_len)
        word_wise_pos = K.cumsum(K.ones_like(inputs[:, :, 0]), axis=1)
        word_wise_pos = K.expand_dims(word_wise_pos, 2)  # (batch_size, seq_len, 1)
        # outer product of positions and frequencies -> (batch_size, seq_len, p_dim/2)
        position_embedding = K.dot(word_wise_pos, embedding_wise_pos)

        # interleave sin/cos per frequency and flatten -> (batch_size, seq_len, p_dim)
        position_embedding = K.expand_dims(position_embedding, 3)
        position_embedding = K.reshape(
            K.concatenate([K.sin(position_embedding),
                           K.cos(position_embedding)], axis=-1),
            shape=(batch_size, seq_len, -1))

        if self.method == 'add':
            return inputs + position_embedding

        return K.concatenate([inputs, position_embedding], axis=-1)
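As a sanity check, a small NumPy sketch of the formula in the comments, PE_2i(p) = sin(p / 10000^(2i/d_pos)) and PE_2i+1(p) = cos(p / 10000^(2i/d_pos)); the function name is illustrative, positions start at 1 because the layer uses a cumulative sum of ones, and an even dimension is assumed:

import numpy as np

def sinusoidal_position_embedding(seq_len, dim):
    pos = np.arange(1, seq_len + 1)[:, None]           # positions 1..seq_len, shape (seq_len, 1)
    i = np.arange(dim // 2)[None, :]                    # frequency index, shape (1, dim/2)
    angle = pos / np.power(10000., 2. * i / dim)        # p / 10000^(2i/d_pos)
    pe = np.stack([np.sin(angle), np.cos(angle)], -1)   # (seq_len, dim/2, 2)
    return pe.reshape(seq_len, dim)                     # sin/cos interleaved, as in the reshape above

print(sinusoidal_position_embedding(seq_len=50, dim=8).shape)  # (50, 8)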
Example #3
 def log_norm_step(self, inputs, states):
     """递归求解归一化因子"""
     inputs, mask = inputs[:, :-1], inputs[:, -1:]
     states = K.expand_dims(states[0], 2)  # batch_size,output_dim, 1
     trans = K.expand_dims(self.trans, 0)  # 1, output_dim, output_dim
     outputs = K.logsumexp(states + trans, 1)
     outputs += inputs
     outputs = mask * outputs + (1 - mask) * states[:, :, 0]
     return outputs, [outputs]
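This is one step of the forward recursion for the CRF partition function; below is a minimal NumPy sketch of the full recursion over an unmasked sequence. The function name is illustrative, and the convention trans[i, j] = score of moving from tag i to tag j follows the broadcasting above:

import numpy as np
from scipy.special import logsumexp

def crf_log_norm(emissions, trans):
    # emissions: (seq_len, num_tags), trans: (num_tags, num_tags)
    states = emissions[0]                                        # scores after the first step
    for t in range(1, len(emissions)):
        # states[:, None] + trans pairs every previous tag with every current tag
        states = logsumexp(states[:, None] + trans, axis=0) + emissions[t]
    return logsumexp(states)                                     # log Z

print(crf_log_norm(np.random.randn(10, 5), np.random.randn(5, 5)))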
Example #4
 def compute_position_idx(self, inputs):
     q, v = inputs
     q_idx = K.arange(0, K.shape(q)[1], dtype='int32')
     q_idx = K.expand_dims(q_idx, 1)
     v_idx = K.arange(0, K.shape(v)[1], dtype='int32')
     v_idx = K.expand_dims(v_idx, 0)
     # relative position differences
     position_idx = v_idx - q_idx
     max_position = (self.input_dim - 1) // 2
     position_idx = K.clip(position_idx, -max_position, max_position)
     position_idx = position_idx + max_position
     return position_idx
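A quick NumPy check of the clipped relative-position indices, using toy values (input_dim = 9 is an assumption, i.e. relative positions in [-4, 4] shifted to [0, 8]):

import numpy as np

q_len, v_len, input_dim = 5, 5, 9
max_position = (input_dim - 1) // 2                              # 4
position_idx = np.arange(v_len)[None, :] - np.arange(q_len)[:, None]
position_idx = np.clip(position_idx, -max_position, max_position) + max_position
print(position_idx)  # row i holds the clipped offsets j - i, shifted into [0, 8]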
Example #5
    def call(self, inputs):
        """
        conditional 时, condition 放在inputs后面,[inputs, condition]
        """
        if self.conditional:
            inputs, cond = inputs
            if self.condition_hidden_units is not None:
                cond = self.condition_hidden_dense(cond)
            # expand cond until its rank matches inputs
            for _ in range(K.ndim(inputs) - K.ndim(cond)):
                cond = K.expand_dims(cond, 1)
            if self.center:
                beta = self.beta_dense(cond) + self.beta
            if self.scale:
                gamma = self.gamma_dense(cond) + self.gamma
        else:
            beta = self.beta
            gamma = self.gamma

        output = inputs
        if self.center:
            mean = K.mean(inputs, axis=-1, keepdims=True)
            output = output - mean

        if self.scale:
            var = K.mean(K.square(output), axis=-1, keepdims=True)
            std = K.sqrt(var + self.epsilon)
            output = output / std
            output = output * gamma

        if self.center:
            output = output + beta

        return output
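The transformation itself is ordinary layer normalization, with beta and gamma optionally predicted from the condition; a minimal NumPy sketch of the unconditional path (the epsilon value here is an assumption):

import numpy as np

def layer_norm(x, gamma, beta, epsilon=1e-12):
    # normalize over the last axis: gamma * (x - mean) / sqrt(var + eps) + beta
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + epsilon) + beta

x = np.random.randn(2, 4, 8)
print(layer_norm(x, gamma=np.ones(8), beta=np.zeros(8)).shape)  # (2, 4, 8)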
Example #6
    def compute_mask(self, inputs, mask=None):
        if self.conditional:
            masks = mask if mask is not None else []
            masks = [K.expand_dims(m, 0) for m in masks if m is not None]
            if len(masks) == 0:
                return None
            else:
                return K.all(K.concatenate(masks, axis=0), axis=0)

        return mask
Example #7
 def call(self, x, mask=None):
     x0 = x
     x = self.k_dense(x0)
     x = self.o_dense(x)
     if mask is not None:
         mask = K.cast(mask, K.floatx())
         mask = K.expand_dims(mask, 2)
         x = x - (1 - mask) * 1e12  # push masked positions towards -inf before the softmax
     x = K.softmax(x, 1)  # attention weights over the sequence axis
     x = K.sum(x0 * x, 1)  # weighted sum over the sequence -> (batch_size, dim)
     return x
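The pattern above is masked attention pooling: a large negative bias suppresses padded positions before the softmax, and the result is a weighted sum over the time axis. A NumPy sketch of the same pooling, assuming the attention scores are already computed (so the two Dense layers are skipped):

import numpy as np

def attention_pool(x, scores, mask):
    # x: (batch, seq, dim), scores: (batch, seq, 1), mask: (batch, seq) with 1 = keep
    scores = scores - (1 - mask[:, :, None]) * 1e12              # suppress padded positions
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights = weights / weights.sum(axis=1, keepdims=True)       # softmax over the seq axis
    return (x * weights).sum(axis=1)                             # (batch, dim)

x = np.random.randn(2, 6, 8)
scores = np.random.randn(2, 6, 1)
mask = np.array([[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1]], dtype='float64')
print(attention_pool(x, scores, mask).shape)  # (2, 8)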
Example #8
    def call(self, inputs):
        input_shape = K.shape(inputs)
        batch_size, seq_length = input_shape[0], input_shape[1]
        pos_embedding = self.embeddings[:seq_length]
        pos_embedding = K.expand_dims(pos_embedding, 0)
        if self.merge_mode != 'add':
            pos_embedding = K.tile(pos_embedding, [batch_size, 1, 1])

        if self.merge_mode == 'add':
            return inputs + pos_embedding
        return K.concatenate([inputs, pos_embedding], axis=-1)
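Note that in 'add' mode the position embedding broadcasts over the batch, so tiling is only needed for the concatenation branch; a toy NumPy shape check (all sizes assumed):

import numpy as np

batch_size, seq_len, dim = 2, 5, 8
inputs = np.random.randn(batch_size, seq_len, dim)
pos = np.random.randn(seq_len, dim)[None]                        # (1, seq_len, dim)

added = inputs + pos                                             # broadcasts over the batch
concat = np.concatenate([inputs, np.tile(pos, (batch_size, 1, 1))], axis=-1)
print(added.shape, concat.shape)  # (2, 5, 8) (2, 5, 16)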
Example #9
    def call(self, x, mask=None):
        x0 = x
        if mask is not None:
            mask = K.cast(mask, K.floatx())
            mask = K.expand_dims(mask, 2)
        #         x = x0 * mask if mask is not None else x0
        x0 = Lambda(lambda x_: x_, output_shape=lambda s: s)(x0)  # drop the mask so it is not passed on to conv1d
        x = self.conv1d(x0)
        x, g = x[:, :, :self.o_dim], x[:, :, self.o_dim:]
        if self.dropout_rate is not None:
            g = K.in_train_phase(K.dropout(g, self.dropout_rate), g)
        g = K.sigmoid(g)
        # if no mask was given, fall back to all ones
        mask = mask if mask is not None else K.ones_like(x)

        if self.skip_connection:
            if K.int_shape(x0)[-1] != self.o_dim:
                x0 = self.conv1d_1x1(x0)
            return (x0 * (1 - g) + x * g) * mask
        return x * g * mask
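The return value in the skip-connection branch is a highway-style gate between the residual and the convolution output; a NumPy sketch of just that combination (names illustrative, the conv1d outputs stubbed with random tensors):

import numpy as np

def gated_residual(x0, x, g_logits):
    # out = x0 * (1 - g) + x * g, with g = sigmoid(g_logits)
    g = 1. / (1. + np.exp(-g_logits))
    return x0 * (1 - g) + x * g

x0 = np.random.randn(2, 6, 8)        # residual branch
x = np.random.randn(2, 6, 8)         # first half of the conv1d output
g_logits = np.random.randn(2, 6, 8)  # second half of the conv1d output (gate logits)
print(gated_residual(x0, x, g_logits).shape)  # (2, 6, 8)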
Example #10
    def call(self, inputs, **kwargs):
        logits, token_seq = inputs[:2]
        seq_shape = K.shape(token_seq)
        batch_size, seq_length = seq_shape[0], seq_shape[1]
        if self.pad_token_id is None:
            # no padding token configured: use the last position of every sequence
            # (tf.fill builds a per-sample tensor; a Python list repeated by a symbolic
            # batch_size would fail in graph mode)
            sequence_lengths = tf.fill([batch_size], seq_length - 1)
        else:
            # index of the last non-padding token in each sequence
            sequence_lengths = K.sum(
                K.cast(K.not_equal(token_seq, self.pad_token_id), dtype='int32'),
                axis=-1,
                keepdims=False,
            ) - 1
        # only tf2
        # return tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)

        indices = K.expand_dims(sequence_lengths, -1)
        return tf.gather_nd(logits, indices, batch_dims=1)
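The gather picks, for each sequence, the logits at the last non-padding position; a NumPy sketch of the same indexing (pad_token_id = 0 assumed):

import numpy as np

logits = np.random.randn(2, 6, 5)                      # (batch, seq, num_labels)
token_seq = np.array([[7, 8, 9, 0, 0, 0],
                      [5, 6, 7, 8, 9, 3]])             # 0 = padding
sequence_lengths = (token_seq != 0).sum(axis=-1) - 1   # index of the last real token
last_logits = logits[np.arange(len(logits)), sequence_lengths]
print(last_logits.shape)  # (2, 5)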
Example #11
 def call(self, inputs, mask=None):
     if mask is not None:
         mask = K.cast(mask, K.floatx())
         mask = K.expand_dims(mask, 2)
         inputs = inputs - (1.0 - mask) * 1e12
     return K.softmax(inputs, 1)
Example #12
 def call(self, x):
     seq, vec = x
     vec = K.expand_dims(vec, 1)
     vec = K.tile(vec, [1, K.shape(seq)[1], 1])
     return K.concatenate([seq, vec], 2)
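A toy NumPy equivalent of broadcasting a per-sample vector along the time axis and concatenating it to the sequence (shapes assumed):

import numpy as np

seq = np.random.randn(2, 5, 8)                          # (batch, seq_len, dim)
vec = np.random.randn(2, 3)                             # (batch, vec_dim)
tiled = np.tile(vec[:, None, :], (1, seq.shape[1], 1))  # (batch, seq_len, vec_dim)
print(np.concatenate([seq, tiled], axis=2).shape)       # (2, 5, 11)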