def _reshape_mask(mask, head_num):
    """Repeat an attention mask once per head.

    Expands a ``(batch, seq_len)`` mask to ``(batch * head_num, seq_len)``
    so it lines up with tensors whose head axis has been folded into the
    batch axis (see ``_reshape_to_batches``). Returns ``None`` unchanged.
    """
    if mask is None:
        return mask
    seq_len = K.shape(mask)[1]
    # Insert a head axis, copy the mask head_num times along it, then
    # fold the head axis into the batch axis.
    repeated = K.tile(K.expand_dims(mask, axis=1), K.stack([1, head_num, 1]))
    return K.reshape(repeated, (-1, seq_len))
def _reshape_to_batches(x, head_num):
    """Split the feature axis into heads and fold heads into the batch axis.

    ``(batch, seq_len, feature_dim)`` ->
    ``(batch * head_num, seq_len, feature_dim // head_num)``.
    Assumes ``feature_dim`` is divisible by ``head_num`` — TODO confirm
    the layer validates this upstream.
    """
    shape = K.shape(x)
    batch_size = shape[0]
    seq_len = shape[1]
    feature_dim = shape[2]
    head_dim = feature_dim // head_num
    # Expose the head axis, move it next to the batch axis, then merge.
    with_heads = K.reshape(x, (batch_size, seq_len, head_num, head_dim))
    transposed = K.permute_dimensions(with_heads, [0, 2, 1, 3])
    return K.reshape(transposed, (batch_size * head_num, seq_len, head_dim))
def call(self, inputs, **kwargs):
    """Add a per-task embedding vector to every position of the input.

    ``inputs`` is a pair ``[sequence, tasks]``; the task ids index into
    ``self.embeddings`` and the looked-up vector is broadcast-added to
    the sequence.
    """
    inputs, tasks = inputs
    if K.dtype(tasks) != 'int32':
        tasks = K.cast(tasks, 'int32')
    task_embed = K.gather(self.embeddings, tasks)
    if self.mask_zero:
        # Zero the embedding wherever the task id is 0 (padding id).
        keep = K.cast(K.not_equal(tasks, 0), K.floatx())
        task_embed = task_embed * K.expand_dims(keep, axis=-1)
    if K.backend() == 'theano':
        # Theano does not broadcast here the way TF does; tile the
        # embedding explicitly along the time axis.
        task_embed = K.tile(task_embed, (1, K.shape(inputs)[1], 1))
    return inputs + task_embed
def call(self, inputs, mask=None, **kwargs):
    """Scaled dot-product attention.

    ``inputs`` is either a single tensor used as query/key/value
    (self-attention) or a ``[query, key, value]`` list. When
    ``self.history_only`` is set, positions may only attend to
    themselves and earlier positions; masked logits get a large
    negative bias before the softmax.
    """
    if isinstance(inputs, list):
        query, key, value = inputs
    else:
        query = key = value = inputs
    if isinstance(mask, list):
        # Only the key-side mask matters for attention weights.
        mask = mask[1]
    # Scale the dot products by sqrt(d) to keep softmax gradients stable.
    depth = K.cast(K.shape(query)[-1], dtype=K.floatx())
    scores = K.batch_dot(query, key, axes=2) / K.sqrt(depth)
    if self.history_only:
        query_len = K.shape(query)[1]
        key_len = K.shape(key)[1]
        # Lower-triangular band keeps past/present; everything above the
        # diagonal is pushed to -1e9 so softmax assigns it ~0 weight.
        # NOTE(review): tf.matrix_band_part is the TF1-era name
        # (tf.linalg.band_part in TF2) — kept for compatibility.
        full = tf.ones((query_len, key_len))
        scores -= (full - tf.matrix_band_part(full, -1, 0)) * 1e9
    if mask is not None:
        scores -= (1.0 - K.cast(K.expand_dims(mask, axis=-2), K.floatx())) * 1e9
    attention = keras.activations.softmax(scores)
    output = K.batch_dot(attention, value)
    if self.return_attention:
        return [output, attention]
    return output
def call(self, inputs, **kwargs):
    """Compute position embeddings in one of three modes.

    - MODE_EXPAND: ``inputs`` are (signed) position indices; they are
      clamped to ``[-input_dim, input_dim]`` and shifted to non-negative
      before indexing ``self.embeddings``.
    - MODE_ADD: the first ``seq_len`` embedding rows are broadcast-added
      to ``inputs`` (assumes the feature dims match — TODO confirm
      ``build`` enforces this).
    - otherwise (concat mode): the embeddings are concatenated to
      ``inputs`` along the last axis.

    Fix vs. original: the locals ``output_dim`` (and the branch that
    chose between ``input_shape[2]`` and ``self.output_dim``) were dead
    code — the slice below always uses ``self.output_dim`` — so they
    have been removed. Behavior is unchanged.
    """
    if self.mode == self.MODE_EXPAND:
        if K.dtype(inputs) != 'int32':
            inputs = K.cast(inputs, 'int32')
        # Clamp indices to [-input_dim, input_dim], then offset by
        # input_dim so the gather index is non-negative.
        clamped = K.minimum(K.maximum(inputs, -self.input_dim), self.input_dim)
        return K.gather(self.embeddings, clamped + self.input_dim)
    input_shape = K.shape(inputs)
    batch_size, seq_len = input_shape[0], input_shape[1]
    pos_embeddings = K.tile(
        K.expand_dims(self.embeddings[:seq_len, :self.output_dim], axis=0),
        K.stack([batch_size, 1, 1]),
    )
    if self.mode == self.MODE_ADD:
        return inputs + pos_embeddings
    return K.concatenate([inputs, pos_embeddings], axis=-1)