Example #1
 def _reshape_mask(mask, head_num):
     # Repeat the padding mask once per attention head so it matches the
     # (batch * head_num, seq_len, ...) layout used after splitting heads.
     if mask is None:
         return mask
     seq_len = K.shape(mask)[1]
     mask = K.expand_dims(mask, axis=1)              # (batch, 1, seq_len)
     mask = K.tile(mask, K.stack([1, head_num, 1]))  # (batch, head_num, seq_len)
     return K.reshape(mask, (-1, seq_len))           # (batch * head_num, seq_len)
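A minimal way to exercise this helper, assuming K is the tf.keras backend module and _reshape_mask is in scope (the values below are illustrative, not from the original):

 import numpy as np
 from tensorflow.keras import backend as K

 # Padding mask for a batch of 2 sequences of length 3.
 mask = K.constant(np.array([[1, 1, 0], [1, 0, 0]]))
 tiled = _reshape_mask(mask, head_num=4)
 print(K.int_shape(tiled))  # (8, 3): one mask row per (sample, head) pair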
Example #2
 def call(self, inputs, **kwargs):
     inputs, tasks = inputs
     if K.dtype(tasks) != 'int32':
         tasks = K.cast(tasks, 'int32')
     # Look up one embedding vector per task id and add it to the inputs.
     task_embed = K.gather(self.embeddings, tasks)
     if self.mask_zero:
         # Zero the embedding wherever the task id equals the padding value 0.
         task_embed = task_embed * K.expand_dims(
             K.cast(K.not_equal(tasks, 0), K.floatx()), axis=-1)
     return inputs + task_embed
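The gather-and-mask step can be checked on its own with backend ops; a small sketch with an assumed (num_tasks, dim) embedding table (names and values are illustrative):

 import numpy as np
 from tensorflow.keras import backend as K

 embeddings = K.constant(np.random.rand(3, 4))    # 3 task ids, dim 4
 tasks = K.constant([[2], [0]], dtype='int32')    # one task id per sample; 0 = padding
 task_embed = K.gather(embeddings, tasks)         # (2, 1, 4)
 # With mask_zero, the padding id 0 contributes an all-zero vector.
 task_embed = task_embed * K.expand_dims(
     K.cast(K.not_equal(tasks, 0), K.floatx()), axis=-1)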
Example #3
 def call(self, inputs, **kwargs):
     inputs, tasks = inputs
     if K.dtype(tasks) != 'int32':
         tasks = K.cast(tasks, 'int32')
     task_embed = K.gather(self.embeddings, tasks)
     if self.mask_zero:
         task_embed = task_embed * K.expand_dims(
             K.cast(K.not_equal(tasks, 0), K.floatx()), axis=-1)
     if K.backend() == 'theano':
         # Theano will not broadcast over the time axis here, so tile explicitly.
         task_embed = K.tile(task_embed, (1, K.shape(inputs)[1], 1))
     return inputs + task_embed
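The only difference from the previous example is the explicit tile on the Theano backend: TensorFlow broadcasts the (batch, 1, dim) task embedding over the time axis of inputs automatically, while Theano needs it repeated to (batch, seq_len, dim) before the addition.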
Example #4
 def call(self, inputs, mask=None, **kwargs):
     if isinstance(inputs, list):
         query, key, value = inputs
     else:
         query = key = value = inputs
     if isinstance(mask, list):
         mask = mask[1]
     # Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.
     feature_dim = K.shape(query)[-1]
     e = K.batch_dot(query, key, axes=2) / K.sqrt(
         K.cast(feature_dim, dtype=K.floatx()))
     if self.history_only:
         # Causal masking: keep the lower triangle, push future positions
         # to -1e9 (tf.linalg.band_part is the current name of the
         # removed tf.matrix_band_part).
         query_len, key_len = K.shape(query)[1], K.shape(key)[1]
         ones = tf.ones((query_len, key_len))
         e -= (ones - tf.linalg.band_part(ones, -1, 0)) * 1e9
     if mask is not None:
         # A large negative bias drives masked keys to ~0 after the softmax.
         e -= (1.0 - K.cast(K.expand_dims(mask, axis=-2), K.floatx())) * 1e9
     a = keras.activations.softmax(e)
     v = K.batch_dot(a, value)
     if self.return_attention:
         return [v, a]
     return v
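The causal-mask trick can be verified with raw TensorFlow 2 tensors; a sketch under that assumption (all names are illustrative):

 import tensorflow as tf

 q = tf.random.normal((2, 5, 8))                       # (batch, seq_len, dim)
 e = tf.matmul(q, q, transpose_b=True) / tf.sqrt(8.0)
 ones = tf.ones((5, 5))
 e -= (ones - tf.linalg.band_part(ones, -1, 0)) * 1e9  # hide future positions
 a = tf.nn.softmax(e)  # each row sums to 1 with zero weight on the future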
Example #5
 def call(self, inputs, **kwargs):
     if self.mode == self.MODE_EXPAND:
         # Inputs are (possibly negative) position offsets; clip them to
         # [-input_dim, input_dim] and shift into the embedding table.
         if K.dtype(inputs) != 'int32':
             inputs = K.cast(inputs, 'int32')
         return K.gather(
             self.embeddings,
             K.minimum(K.maximum(inputs, -self.input_dim), self.input_dim) +
             self.input_dim,
         )
     input_shape = K.shape(inputs)
     if self.mode == self.MODE_ADD:
         batch_size, seq_len, output_dim = input_shape[0], input_shape[1], input_shape[2]
     else:
         batch_size, seq_len, output_dim = input_shape[0], input_shape[1], self.output_dim
     # Broadcast the first seq_len rows of the table across the batch.
     pos_embeddings = K.tile(
         K.expand_dims(self.embeddings[:seq_len, :self.output_dim], axis=0),
         K.stack([batch_size, 1, 1]),
     )
     if self.mode == self.MODE_ADD:
         return inputs + pos_embeddings
     return K.concatenate([inputs, pos_embeddings], axis=-1)
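The MODE_ADD tiling can be reproduced standalone; a sketch assuming a learned table of shape (max_len, dim) (names are illustrative):

 import numpy as np
 from tensorflow.keras import backend as K

 table = K.constant(np.random.rand(10, 4))      # (max_len=10, dim=4)
 batch_size, seq_len = 2, 6
 pos = K.tile(K.expand_dims(table[:seq_len], axis=0),
              K.stack([batch_size, 1, 1]))
 print(K.int_shape(pos))  # (2, 6, 4), ready to add to the inputs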
Example #6
 def call(self, inputs, mask=None):
     if mask is not None:
         # Push masked timesteps far below any real activation so the
         # max over the time axis never selects them.
         mask = K.cast(mask, K.floatx())
         inputs -= K.expand_dims((1.0 - mask) * 1e6, axis=-1)
     return K.max(inputs, axis=-2)
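A quick sanity check of the masked max, assuming the usual Keras mask shape (batch, seq_len) (illustrative values):

 import numpy as np
 from tensorflow.keras import backend as K

 x = K.constant(np.array([[[1.0], [5.0], [9.0]]]))  # (1, 3, 1)
 mask = K.constant(np.array([[1.0, 1.0, 0.0]]))     # last timestep is padding
 x -= K.expand_dims((1.0 - mask) * 1e6, axis=-1)
 print(K.eval(K.max(x, axis=-2)))                   # [[5.]]; the masked 9.0 is ignored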
Example #7
 def call(self, inputs, mask=None):
     if mask is not None:
         # Zero out padded timesteps before the inherited Conv1D runs.
         mask = K.cast(mask, K.floatx())
         inputs *= K.expand_dims(mask, axis=-1)
     return super(MaskedConv1D, self).call(inputs)
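Note that multiplying by the mask only zeroes the padded inputs; the convolution itself is inherited unchanged, so kernels near the boundary still see those zeros in their receptive field. A sketch of the masking step alone (illustrative values):

 import numpy as np
 from tensorflow.keras import backend as K

 x = K.constant(np.ones((1, 4, 2)))                  # (batch, steps, channels)
 mask = K.constant(np.array([[1.0, 1.0, 0.0, 0.0]]))
 x *= K.expand_dims(mask, axis=-1)                   # padded steps become zeros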