Example 1
def _reshape_to_batches(x, head_num):
    # Split the feature axis into `head_num` heads and fold the heads into
    # the batch axis: (batch, seq_len, dim) -> (batch * head_num, seq_len, dim // head_num).
    input_shape = K.shape(x)
    batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]
    head_dim = feature_dim // head_num
    x = K.reshape(x, (batch_size, seq_len, head_num, head_dim))
    x = K.permute_dimensions(x, [0, 2, 1, 3])  # (batch, head_num, seq_len, head_dim)
    return K.reshape(x, (batch_size * head_num, seq_len, head_dim))
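
A minimal shape check for the helper above, assuming the TF 2.x Keras backend (the snippets use it as `K` without showing the import):

import numpy as np
from tensorflow.keras import backend as K

x = K.constant(np.random.rand(2, 5, 8))  # (batch=2, seq_len=5, feature_dim=8)
y = _reshape_to_batches(x, head_num=4)
print(K.int_shape(y))  # (8, 5, 2): (batch * head_num, seq_len, head_dim)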
Example 2
def _reshape_mask(mask, head_num):
    # Repeat each sample's mask once per attention head so it lines up with
    # the (batch * head_num, seq_len, head_dim) tensors produced above.
    if mask is None:
        return mask
    seq_len = K.shape(mask)[1]
    mask = K.expand_dims(mask, axis=1)              # (batch, 1, seq_len)
    mask = K.tile(mask, K.stack([1, head_num, 1]))  # (batch, head_num, seq_len)
    return K.reshape(mask, (-1, seq_len))           # (batch * head_num, seq_len)
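
A quick sanity check under the same assumptions (imports as in the previous sketch):

import numpy as np
from tensorflow.keras import backend as K

mask = K.constant(np.array([[1, 1, 0], [1, 0, 0]], dtype=bool))  # (batch=2, seq_len=3)
m = _reshape_mask(mask, head_num=2)
print(K.int_shape(m))  # (4, 3): each head gets a copy of its sample's mask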
Example 3
def call(self, inputs, mask=None):
    # Accept either a single tensor (self-attention) or [query, key, value].
    if isinstance(inputs, list):
        q, k, v = inputs
    else:
        q = k = v = inputs
    if isinstance(mask, list):
        q_mask, k_mask, v_mask = mask
    else:
        q_mask = k_mask = v_mask = mask
    # Linear projections for query, key, and value.
    q = K.dot(q, self.Wq)
    k = K.dot(k, self.Wk)
    v = K.dot(v, self.Wv)
    if self.use_bias:
        q += self.bq
        k += self.bk
        v += self.bv
    if self.activation is not None:
        q = self.activation(q)
        k = self.activation(k)
        v = self.activation(v)
    # Fold the heads into the batch axis and run scaled dot-product
    # attention over all heads at once.
    y = ScaledDotProductAttention(
        history_only=self.history_only,
        name='%s-Attention' % self.name,
    )(
        inputs=[
            self._reshape_to_batches(q, self.head_num),
            self._reshape_to_batches(k, self.head_num),
            self._reshape_to_batches(v, self.head_num),
        ],
        mask=[
            self._reshape_mask(q_mask, self.head_num),
            self._reshape_mask(k_mask, self.head_num),
            self._reshape_mask(v_mask, self.head_num),
        ],
    )
    # Undo the batch folding and apply the output projection.
    y = self._reshape_from_batches(y, self.head_num)
    y = K.dot(y, self.Wo)
    if self.use_bias:
        y += self.bo
    if self.activation is not None:
        y = self.activation(y)
    # Restore static shape information; note that seq_len=512 and
    # feature_dim=768 are hard-coded here, so this layer only accepts
    # inputs of exactly that shape.
    y = K.reshape(y, (-1, 512, 768))
    return y
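
`call` also relies on `_reshape_from_batches`, which this listing omits. A sketch of the inverse of `_reshape_to_batches` from Example 1, consistent with the shapes used there:

def _reshape_from_batches(x, head_num):
    # Inverse of _reshape_to_batches: (batch * head_num, seq_len, head_dim)
    # -> (batch, seq_len, head_dim * head_num), i.e. concatenate the heads.
    input_shape = K.shape(x)
    batch_size, seq_len, head_dim = input_shape[0], input_shape[1], input_shape[2]
    x = K.reshape(x, (batch_size // head_num, head_num, seq_len, head_dim))
    x = K.permute_dimensions(x, [0, 2, 1, 3])
    return K.reshape(x, (batch_size // head_num, seq_len, head_dim * head_num))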
Example 4
                                          INIT,
                                          seq_len=MAX_SEQ_LEN,
                                          trainable=True)

x = bert.output
# Pool by taking the representation of the first ([CLS]) token.
x = tf.keras.layers.Lambda(lambda x: x[:, 0], name='Pooler')(x)
x = tf.keras.layers.Dense(units=config['hidden_size'],
                          activation='tanh',
                          name='Pooler-Dense')(x)

# One scalar score per candidate sequence.
output = tf.keras.layers.Dropout(rate=0.1)(x)
output = tf.keras.layers.Dense(units=1)(output)
model = tf.keras.models.Model(bert.inputs, output)

# Drop the trailing singleton dimension: (batch, 1) -> (batch,).
output = tf.keras.layers.Lambda(lambda x: x[:, 0], name='Squeeze')(output)

# Training model: group the flat scores into NUM_TRAIN_CANDS candidates per
# example and normalize them with a softmax.
toutput = tf.keras.layers.Lambda(lambda x: K.reshape(x, [-1, NUM_TRAIN_CANDS]),
                                 name='Reshape')(output)
tprobs = tf.keras.layers.Softmax(name='Softmax')(toutput)
train_model = tf.keras.models.Model(bert.inputs, tprobs)

# Prediction model: same structure, but with NUM_CANDS candidates per example.
poutput = tf.keras.layers.Lambda(lambda x: K.reshape(x, [-1, NUM_CANDS]),
                                 name='Reshape')(output)
pprobs = tf.keras.layers.Softmax(name='Softmax')(poutput)
predict_model = tf.keras.models.Model(bert.inputs, pprobs)

train_model.compile(
    loss='categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(1e-5),  # use a sufficiently small learning rate
    # optimizer=PiecewiseLinearLearningRate(Adam(5e-5), {10000: 1, 30000: 0.1}),
    metrics=['accuracy'],
)
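
The Dense head scores one candidate sequence at a time, and the Reshape/Softmax pair ranks candidates in groups, so the training targets are one-hot over the candidate axis. A minimal sketch of building such targets; NUM_TRAIN_CANDS, the group count, and the gold indices here are assumptions for illustration:

import numpy as np

NUM_TRAIN_CANDS = 4                     # hypothetical candidate-group size
gold = np.array([0, 2, 1])              # index of the correct candidate per group
labels = np.eye(NUM_TRAIN_CANDS)[gold]  # one-hot targets, shape (3, 4)
# train_model's Reshape layer folds the flat per-candidate scores into
# (n_groups, NUM_TRAIN_CANDS) before the Softmax, so these rows line up with
# its output for categorical_crossentropy.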