Example #1
0
 def _reshape_mask(mask, head_num):
     """Repeat an attention mask once per head and fold heads into the batch.

     Returns None unchanged; otherwise a (batch_size * head_num, seq_len)
     mask matching tensors reshaped by the companion batch-split helper.
     """
     if mask is None:
         return None
     seq_len = K.shape(mask)[1]
     tiled = K.tile(K.expand_dims(mask, axis=1), K.stack([1, head_num, 1]))
     return K.reshape(tiled, (-1, seq_len))
Example #2
0
 def _reshape_to_batches(x, head_num):
     """Split the feature axis into `head_num` heads and fold them into batch.

     (batch, seq, feature) -> (batch * head_num, seq, feature // head_num).
     Assumes the feature dimension is divisible by `head_num`.
     """
     shape = K.shape(x)
     batch_size = shape[0]
     seq_len = shape[1]
     feature_dim = shape[2]
     sub_dim = feature_dim // head_num
     split = K.reshape(x, (batch_size, seq_len, head_num, sub_dim))
     # Move the head axis next to batch before merging the two.
     transposed = K.permute_dimensions(split, [0, 2, 1, 3])
     return K.reshape(transposed, (batch_size * head_num, seq_len, sub_dim))
Example #3
0
 def call(self, inputs, **kwargs):
     """Add a per-task embedding to the input sequence.

     `inputs` is a pair (sequence tensor, task-id tensor); the gathered task
     embedding is broadcast-added along the sequence axis.
     """
     inputs, tasks = inputs
     if K.dtype(tasks) != 'int32':
         tasks = K.cast(tasks, 'int32')
     task_embed = K.gather(self.embeddings, tasks)
     if self.mask_zero:
         # Zero out the embedding wherever the task id is the padding value 0.
         keep = K.cast(K.not_equal(tasks, 0), K.floatx())
         task_embed = task_embed * K.expand_dims(keep, axis=-1)
     if K.backend() == 'theano':
         # Theano cannot broadcast here implicitly; tile along the sequence axis.
         task_embed = K.tile(task_embed, (1, K.shape(inputs)[1], 1))
     return inputs + task_embed
Example #4
0
 def call(self, inputs, mask=None, **kwargs):
     """Scaled dot-product attention.

     `inputs` is either a single tensor used as query/key/value, or a
     [query, key, value] list. `mask` (or its second element, the key mask,
     when a list) disables attention to masked key positions. When
     `self.history_only` is set, an upper-triangular penalty prevents each
     query position from attending to later key positions. Returns the
     attended values, plus the attention weights when
     `self.return_attention` is set.
     """
     if isinstance(inputs, list):
         query, key, value = inputs
     else:
         query = key = value = inputs
     if isinstance(mask, list):
         # Only the key mask matters for masking attention scores.
         mask = mask[1]
     feature_dim = K.shape(query)[-1]
     # Scale by sqrt(d) to keep softmax logits in a well-behaved range.
     e = K.batch_dot(query, key, axes=2) / K.sqrt(
         K.cast(feature_dim, dtype=K.floatx()))
     if self.history_only:
         query_len, key_len = K.shape(query)[1], K.shape(key)[1]
         ones = tf.ones((query_len, key_len))
         # tf.linalg.band_part is the stable name for the TF 1.x alias
         # tf.matrix_band_part (removed in TF 2.x); (-1, 0) keeps the lower
         # triangle, so future positions get a -1e9 penalty before softmax.
         e -= (ones - tf.linalg.band_part(ones, -1, 0)) * 1e9
     if mask is not None:
         # Push masked key positions toward zero weight after softmax.
         e -= (1.0 - K.cast(K.expand_dims(mask, axis=-2), K.floatx())) * 1e9
     a = keras.activations.softmax(e)
     v = K.batch_dot(a, value)
     if self.return_attention:
         return [v, a]
     return v
Example #5
0
 def call(self, inputs, **kwargs):
     """Apply positional embeddings in one of three modes.

     MODE_EXPAND: treat `inputs` as indices, clip them to
     [-input_dim, input_dim], shift to non-negative, and look them up.
     MODE_ADD: add position embeddings to `inputs` elementwise.
     Otherwise: concatenate position embeddings along the feature axis.
     """
     if self.mode == self.MODE_EXPAND:
         indices = inputs
         if K.dtype(indices) != 'int32':
             indices = K.cast(indices, 'int32')
         # Clip to the supported index range, then shift so all indices >= 0.
         clipped = K.minimum(
             K.maximum(indices, -self.input_dim), self.input_dim)
         return K.gather(self.embeddings, clipped + self.input_dim)
     shape = K.shape(inputs)
     batch_size, seq_len = shape[0], shape[1]
     # One copy of the first `seq_len` position embeddings per batch entry.
     pos_embeddings = K.tile(
         K.expand_dims(self.embeddings[:seq_len, :self.output_dim], axis=0),
         K.stack([batch_size, 1, 1]),
     )
     if self.mode == self.MODE_ADD:
         return inputs + pos_embeddings
     return K.concatenate([inputs, pos_embeddings], axis=-1)