def call(self, inputs, mask=None, **kwargs):
    """Compute scaled dot-product attention.

    Args:
        inputs: a single tensor used as query, key and value, or a list
            ``[query, key, value]``.
        mask: optional key mask. When a list is given, its second element
            (the entry matching ``key``) is used. Presumably shaped
            (batch, key_len) -- TODO confirm with callers.

    Returns:
        ``batch_dot(attention, value)``, or ``[values, attention]`` when
        ``self.return_attention`` is set.
    """
    if isinstance(inputs, list):
        query, key, value = inputs
    else:
        query = key = value = inputs
    if isinstance(mask, list):
        mask = mask[1]
    feature_dim = K.shape(query)[-1]
    e = K.batch_dot(query, key, axes=2) / K.sqrt(
        K.cast(feature_dim, dtype=K.floatx()))

    # 0/1 tensor marking the keys each query may attend to (None = all keys).
    keep = None
    if self.history_only:
        # Query i may only see keys j <= i.
        query_len, key_len = K.shape(query)[1], K.shape(key)[1]
        indices = K.expand_dims(K.arange(0, key_len), axis=0)
        upper = K.expand_dims(K.arange(0, query_len), axis=-1)
        keep = K.expand_dims(K.cast(indices <= upper, K.floatx()), axis=0)
    if mask is not None:
        key_mask = K.cast(K.expand_dims(mask, axis=-2), K.floatx())
        keep = key_mask if keep is None else keep * key_mask

    if keep is None:
        e = K.exp(e - K.max(e, axis=-1, keepdims=True))
    else:
        # Stabilise the softmax with the max over *allowed* keys only.
        # Subtracting the max over all keys lets a large score at a masked
        # position drive every allowed exp(...) to 0 (underflow). Rows with
        # no allowed key fall back to the plain row max so exp() stays
        # finite; ``keep`` zeroes them afterwards. Assumes score gaps far
        # below 1e4.
        allowed_max = K.max(e - 10000.0 * (1.0 - keep), axis=-1, keepdims=True)
        any_allowed = K.max(keep, axis=-1, keepdims=True)
        row_max = (any_allowed * allowed_max
                   + (1.0 - any_allowed) * K.max(e, axis=-1, keepdims=True))
        e = K.exp(e - row_max) * keep

    s = K.sum(e, axis=-1, keepdims=True)
    # Rows whose weights are all zero divide by 1 instead of 0, yielding zeros.
    s += K.cast(K.equal(s, 0.0), K.floatx())
    a = e / s
    v = K.batch_dot(a, value)
    if self.return_attention:
        return [v, a]
    return v
def call(self, inputs, training=None, **kwargs):
    """Build the content-stream and query-stream attention masks.

    ``inputs`` is unpacked as ``[inputs, memory]``; only their shapes are
    read. Returns ``[content_mask, query_mask]``. Each is assembled as
    (batch, mem_len + seq_len, seq_len) and then transposed to
    (batch, seq_len, mem_len + seq_len). Presumably entry ``[b, j, i]`` = 1
    means position ``j`` may attend to key ``i`` -- TODO confirm against the
    attention layer that consumes these masks.
    """
    inputs, memory = inputs
    batch_size = K.shape(inputs)[0]
    seq_len = K.shape(inputs)[1]
    # Memory positions are always visible: ones of shape (batch, mem_len, seq_len).
    mem_mask = K.tile(K.ones_like(memory[:, :, :1], dtype=K.floatx()), [1, 1, seq_len])

    # Build content mask with random permutation
    # ranges[i, b] = i, shape (seq_len, batch).
    ranges = K.tile(K.expand_dims(K.arange(0, seq_len), axis=-1), [1, batch_size])
    if self.enabled:
        # NOTE(review): random_shuffle is defined elsewhere. If it permutes
        # along axis 0 (like tf.random.shuffle), every batch column gets the
        # same permutation, because each row of ``ranges`` is constant.
        shuffle = random_shuffle(ranges)
    else:
        shuffle = ranges
    if self.directional:
        # Outside training the identity order is used, giving an i <= j mask.
        shuffled = K.in_train_phase(shuffle, ranges, training)
    else:
        # ``ranges + seq_len`` exceeds every index in ``ranges``, so the
        # comparison below is all ones (no masking) in that branch.
        if self.enabled:
            shuffled = K.in_train_phase(shuffle, ranges + seq_len, training)
        else:
            shuffled = ranges + seq_len
    ranges = K.expand_dims(K.permute_dimensions(ranges, [1, 0]), axis=-1)  # (batch, seq_len, 1)
    shuffled = K.expand_dims(K.permute_dimensions(shuffled, [1, 0]), axis=1)  # (batch, 1, seq_len)
    content_mask = K.cast(ranges <= shuffled, dtype=K.floatx())  # (batch, seq_len, seq_len)

    # Build query mask based on content mask
    ranges = K.arange(0, seq_len)
    eye = K.equal(K.expand_dims(ranges, axis=0), K.expand_dims(ranges, axis=-1))
    eye = K.expand_dims(K.cast(eye, dtype=K.floatx()), axis=0)  # (1, seq_len, seq_len)
    # Same as the content mask but with the diagonal zeroed, i.e. a position
    # never sees itself (presumably the XLNet query stream).
    query_mask = content_mask * (1.0 - eye)

    # Prepend the always-visible memory rows.
    content_mask = K.concatenate([mem_mask, content_mask], axis=1)
    query_mask = K.concatenate([mem_mask, query_mask], axis=1)
    return [
        K.permute_dimensions(content_mask, [0, 2, 1]),
        K.permute_dimensions(query_mask, [0, 2, 1]),
    ]
def call(self, inputs, **kwargs):
    """Append ``inputs`` to the stored memory and return the preceding context.

    ``inputs`` is unpacked as ``[inputs, memory_length]``; ``memory_length[0][0]``
    is read as the number of memory steps requested. Registers an update of
    the variable ``self.memory`` via ``add_update``. Presumably
    ``self.memory`` is (self.batch_size, self.memory_len + self.target_len,
    self.output_dim), as implied by the slice sizes below -- TODO confirm in
    ``build``.
    """
    inputs, memory_length = inputs
    memory_length = K.cast(memory_length[0][0], 'int32')
    batch_size = K.cast(K.shape(inputs)[0], 'int32')
    seq_len = K.cast(K.shape(inputs)[1], 'int32')

    # Build new memory
    # Pad the batch up to the fixed ``self.batch_size`` by repeating sample 0.
    pad = K.tile(inputs[0:1, ...], (self.batch_size - batch_size, 1, 1))
    padded = K.concatenate([inputs, pad], axis=0)  # (self.batch_size, seq_len, output_dim)
    new_memory = K.concatenate([self.memory, padded], axis=1)  # (self.batch_size, memory_len + target_len + seq_len, ...) if the shape above holds
    # Drop the oldest ``seq_len`` steps so the stored window keeps a fixed length.
    new_memory = tf.slice(  # (self.batch_size, self.memory_len + self.target_len, output_dim)
        new_memory,
        (0, seq_len, 0),
        (self.batch_size, self.memory_len + self.target_len, self.output_dim),
    )
    self.add_update(K.update(self.memory, new_memory), inputs)

    # Build output
    # The slice starts at memory_len + target_len - seq_len - memory_length
    # (clamped to 0), i.e. just before the current ``seq_len`` steps when that
    # is non-negative, and takes min(memory_len, memory_length) steps for the
    # real (unpadded) batch only.
    old_memory = tf.slice(  # (batch_size, min(memory_len, memory_length), output_dim)
        new_memory,
        (0, K.maximum(0, self.memory_len + self.target_len - seq_len - memory_length), 0),
        (batch_size, K.minimum(self.memory_len, memory_length), self.output_dim),
    )

    return old_memory
def _attention_regularizer(self, attention):
    """Return the weighted penalty ``sum((A @ A^T - I) ** 2) / batch``.

    ``A`` is ``attention``. The identity's size is the last axis of ``A``,
    and ``batch`` is its first axis. The result is scaled by
    ``self.attention_regularizer_weight``.
    """
    size = K.shape(attention)[-1]
    positions = K.arange(0, size)
    identity = K.cast(
        K.equal(K.expand_dims(positions, axis=0), K.expand_dims(positions, axis=-1)),
        K.floatx(),
    )
    gram = K.batch_dot(attention, K.permute_dimensions(attention, (0, 2, 1)))
    penalty = K.sum(K.square(gram - identity))
    num_samples = K.cast(K.shape(attention)[0], K.floatx())
    return self.attention_regularizer_weight * penalty / num_samples
def _relative_shift(x, key_len_expected=-1):
    """Shift relative-position scores so column k lines up with key k.

    ``x`` is a rank-3 tensor (n, q_len, k_len). It is reinterpreted as
    (n, k_len, q_len), its first row is dropped, and it is reinterpreted back
    as (n, q_len, k_len - 1). Finally it is cut to ``key_len_expected``
    columns; the default ``-1`` keeps all of them (``tf.slice`` semantics).
    """
    dims = K.shape(x)
    n, rows, cols = dims[0], dims[1], dims[2]
    trimmed = K.reshape(x, (n, cols, rows))[:, 1:, :]
    shifted = K.reshape(trimmed, (n, rows, cols - 1))
    return tf.slice(shifted, (0, 0, 0), (-1, -1, key_len_expected))
def call(self, inputs, **kwargs):
    """Sinusoidal embeddings of descending relative distances.

    ``length`` is the sum of axis 1 of ``inputs[0]`` and ``inputs[1]``.
    Distances ``length - 1, ..., 0`` are optionally clipped to
    ``[0, self.clamp_len]``. They are embedded as the sin half followed by
    the cos half of ``distance / 10000 ** (k / output_dim)`` for
    ``k = 0, 2, 4, ...``, and tiled over the batch axis of ``inputs[0]``.
    """
    batch_size = K.shape(inputs[0])[0]
    length = K.shape(inputs[0])[1] + K.shape(inputs[1])[1]
    distances = K.arange(length - 1, -1, -1, dtype=K.floatx())
    distances = K.tile(K.expand_dims(distances, axis=0), [batch_size, 1])
    if self.clamp_len is not None:
        distances = K.clip(distances, min_value=0, max_value=self.clamp_len)
    exponents = K.arange(0.0, self.output_dim, 2.0) / K.cast(self.output_dim, K.floatx())
    inv_freq = 1.0 / K.pow(10000.0, K.expand_dims(exponents, axis=0))
    angles = K.expand_dims(distances, axis=-1) * inv_freq
    return K.concatenate([K.sin(angles), K.cos(angles)], axis=-1)
def _reshape_mask(mask, head_num):
    """Repeat a rank-2 mask once per head: (batch, seq_len) -> (batch * head_num, seq_len).

    Rows are ordered batch-major (all heads of sample 0 first). A ``None``
    mask is returned as-is.
    """
    if mask is None:
        return mask
    length = K.shape(mask)[1]
    per_head = K.tile(K.expand_dims(mask, axis=1), [1, head_num, 1])
    return K.reshape(per_head, (-1, length))
def _reshape_to_batches(self, x):
    """Split the last axis into heads and fold the heads into the batch axis.

    (batch, seq_len, num_head * units_head) ->
    (batch * num_head, seq_len, units_head), ordered batch-major (all heads
    of sample 0 first).
    """
    input_shape = K.shape(x)
    # The feature size is implied by num_head * units_head, so it is not read here.
    batch_size, seq_len = input_shape[0], input_shape[1]
    x = K.reshape(x, (batch_size, seq_len, self.num_head, self.units_head))
    x = K.permute_dimensions(x, [0, 2, 1, 3])
    return K.reshape(x, (batch_size * self.num_head, seq_len, self.units_head))
def _reshape_to_batches(x, head_num):
    """Fold ``head_num`` heads into the batch axis.

    (batch, seq_len, feature_dim) ->
    (batch * head_num, seq_len, feature_dim // head_num), ordered batch-major.
    """
    shape = K.shape(x)
    batch, length, width = shape[0], shape[1], shape[2]
    per_head = width // head_num
    split = K.reshape(x, (batch, length, head_num, per_head))
    heads_first = K.permute_dimensions(split, [0, 2, 1, 3])
    return K.reshape(heads_first, (batch * head_num, length, per_head))
def call(self, inputs, **kwargs):
    """Sinusoidal embeddings of relative distances for XLNet-style attention.

    With ``q_len`` / ``m_len`` taken from axis 1 of ``inputs[0]`` /
    ``inputs[1]`` and ``k_len = q_len + m_len``, the distances are:

    * directional: ``k_len, k_len - 1, ..., 0`` (k_len + 1 values);
    * otherwise: ``k_len, ..., -q_len + 1`` (k_len + q_len values, the tail
      being negative).

    When ``self.clamp_len`` is set, distances are clipped symmetrically to
    ``[-clamp_len, clamp_len]``. The result is the sin half followed by the
    cos half of ``distance / 10000 ** (k / output_dim)`` for ``k = 0, 2, ...``,
    tiled over the batch axis of ``inputs[0]``.
    """
    q_len, m_len = K.shape(inputs[0])[1], K.shape(inputs[1])[1]
    k_len = q_len + m_len
    start, stop = k_len, -1
    if not self.directional:
        stop = -q_len
    inputs = K.tile(
        K.expand_dims(K.arange(start, stop, -1, dtype=K.floatx()), axis=0),
        [K.shape(inputs[0])[0], 1],
    )
    if self.clamp_len is not None:
        # Clip symmetrically: in the bidirectional case the distances are
        # negative for future positions, and a lower bound of 0 would
        # collapse all of them onto distance 0. Directional distances are all
        # >= 0, so they are unaffected by the lower bound.
        inputs = K.clip(inputs, min_value=-self.clamp_len, max_value=self.clamp_len)
    inputs = K.expand_dims(inputs, axis=-1)
    output_dim = K.cast(self.output_dim, K.floatx())
    ranges = K.expand_dims(K.arange(0.0, self.output_dim, 2.0), axis=0) / output_dim
    inverse = 1.0 / K.pow(10000.0, ranges)
    positions = inputs * inverse
    return K.concatenate([K.sin(positions), K.cos(positions)], axis=-1)
def call(self, inputs, mask=None):
    """Trigonometric position embedding.

    Modes:
      * ``MODE_ADD``: positions ``0..seq_len-1`` are embedded with the size of
        the input's last axis, and the embedding is added to ``inputs``.
      * ``MODE_CONCAT``: positions are embedded with ``self.output_dim`` and
        concatenated after ``inputs`` on the last axis.
      * any other mode: ``inputs`` themselves are used as the positions, and
        the bare embedding is returned.

    Channel ``2i`` holds ``sin(p / 10000 ** (2i / d))`` and channel ``2i + 1``
    the matching ``cos``. ``mask`` is accepted but not used.
    """
    shape = K.shape(inputs)
    if self.mode == self.MODE_ADD or self.mode == self.MODE_CONCAT:
        dim = shape[2] if self.mode == self.MODE_ADD else self.output_dim
        steps = K.expand_dims(K.arange(0, shape[1]), axis=0)
        pos_input = K.tile(steps, [shape[0], 1])
    else:
        dim = self.output_dim
        pos_input = inputs
    if K.dtype(pos_input) != K.floatx():
        pos_input = K.cast(pos_input, K.floatx())

    # Even channel 2i and odd channel 2i + 1 share the frequency 1 / 10000^(2i/d).
    half = K.arange(0, dim // 2)
    freq = K.expand_dims(
        1.0 / K.pow(10000.0, K.cast(half * 2, K.floatx()) / K.cast(dim, K.floatx())),
        0,
    )
    angles = K.dot(K.expand_dims(pos_input, -1), freq)
    # Stacking on a new last axis and then flattening interleaves sin/cos channels.
    interleaved = K.stack([K.sin(angles), K.cos(angles)], axis=-1)
    output = K.reshape(interleaved, [-1, K.shape(inputs)[1], dim])
    if self.mode == self.MODE_CONCAT:
        output = K.concatenate([inputs, output], axis=-1)
    if self.mode == self.MODE_ADD:
        output += inputs
    return output
def call(self, x, mask=None):
    """Attention pooling over axis 1 of ``x``.

    Each step is scored with ``x . W (+ b)``. The scores are normalised with
    a (masked) softmax over axis 1, and the weighted sum of ``x`` over axis 1
    is returned. Presumably ``x`` is (batch, time, features) and ``W`` is
    (features, 1) -- TODO confirm in ``build``.

    Returns:
        The pooled tensor, or ``[pooled, weights]`` when
        ``self.return_attention`` is set.
    """
    logits = K.dot(x, self.W)
    if self.use_bias:
        logits += self.b
    x_shape = K.shape(x)
    logits = K.reshape(logits, (x_shape[0], x_shape[1]))
    if mask is None:
        ai = K.exp(logits - K.max(logits, axis=-1, keepdims=True))
    else:
        mask = K.cast(mask, K.floatx())
        # Stabilise with the max over unmasked steps only. Using the max over
        # all steps lets a masked step with a large score underflow every
        # unmasked exp(...) to 0. Fully masked rows fall back to the plain max
        # so exp() stays finite; the mask zeroes them afterwards. Assumes
        # score gaps far below 1e4.
        kept_max = K.max(logits - 10000.0 * (1.0 - mask), axis=-1, keepdims=True)
        any_kept = K.max(mask, axis=-1, keepdims=True)
        row_max = (any_kept * kept_max
                   + (1.0 - any_kept) * K.max(logits, axis=-1, keepdims=True))
        ai = K.exp(logits - row_max) * mask
    att_weights = ai / (K.sum(ai, axis=1, keepdims=True) + K.epsilon())
    weighted_input = x * K.expand_dims(att_weights)
    result = K.sum(weighted_input, axis=1)
    if self.return_attention:
        return [result, att_weights]
    return result
def _call_additive_emission(self, inputs):
    """Additive attention scores for every pair of steps.

    Computes h[t, t'] = tanh(x_t . Wt + x_t' . Wx (+ bh)) and
    e[t, t'] = h[t, t'] . Wa (+ ba), reshaped to (batch, len, len).
    """
    dims = K.shape(inputs)
    batch_size, input_len = dims[0], dims[1]

    # Broadcast the query term along axis 2 and the key term along axis 1
    # so that their sum covers all (t, t') pairs.
    q = K.expand_dims(K.dot(inputs, self.Wt), 2)
    k = K.expand_dims(K.dot(inputs, self.Wx), 1)
    pre_activation = q + k + self.bh if self.use_additive_bias else q + k
    h = K.tanh(pre_activation)

    scores = K.dot(h, self.Wa)
    if self.use_attention_bias:
        scores = scores + self.ba
    return K.reshape(scores, (batch_size, input_len, input_len))
def call(self, inputs, mask=None, **kwargs):
    """Sequence self-attention.

    Scores come from ``_call_additive_emission`` or
    ``_call_multiplicative_emission`` depending on ``self.attention_type``,
    with an optional ``self.attention_activation`` applied. They are then
    restricted to:

    * a local window when ``self.attention_width`` is set: keys
      ``[t - width + 1, t]`` if ``self.history_only``, otherwise
      ``[t - width // 2, t - width // 2 + width)``;
    * unmasked query and key steps when ``mask`` is given.

    The restricted scores are softmax-normalised (rows with nothing allowed
    become all zeros) and used to average ``inputs``. Adds the attention
    regulariser loss when its weight is positive.

    Returns:
        The attended tensor, or ``[values, attention]`` when
        ``self.return_attention`` is set.
    """
    input_len = K.shape(inputs)[1]

    if self.attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
        e = self._call_additive_emission(inputs)
    elif self.attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
        e = self._call_multiplicative_emission(inputs)

    if self.attention_activation is not None:
        e = self.attention_activation(e)

    # 0/1 tensor of allowed (query, key) pairs (None = everything allowed).
    keep = None
    if self.attention_width is not None:
        if self.history_only:
            lower = K.arange(0, input_len) - (self.attention_width - 1)
        else:
            lower = K.arange(0, input_len) - self.attention_width // 2
        lower = K.expand_dims(lower, axis=-1)
        upper = lower + self.attention_width
        indices = K.expand_dims(K.arange(0, input_len), axis=0)
        keep = K.cast(lower <= indices, K.floatx()) * K.cast(indices < upper, K.floatx())
    if mask is not None:
        mask = K.expand_dims(K.cast(mask, K.floatx()))
        # Zero both masked query rows and masked key columns.
        pair_mask = mask * K.permute_dimensions(mask, (0, 2, 1))
        keep = pair_mask if keep is None else keep * pair_mask

    if keep is None:
        e = K.exp(e - K.max(e, axis=-1, keepdims=True))
    else:
        # Stabilise with the max over allowed keys only. Taking the max over
        # the whole row lets a large score outside the window (or at padding)
        # underflow every allowed exp(...) to 0. Rows with nothing allowed
        # fall back to the plain row max so exp() stays finite; ``keep``
        # zeroes them afterwards. Assumes score gaps far below 1e4.
        kept_max = K.max(e - 10000.0 * (1.0 - keep), axis=-1, keepdims=True)
        any_kept = K.max(keep, axis=-1, keepdims=True)
        row_max = (any_kept * kept_max
                   + (1.0 - any_kept) * K.max(e, axis=-1, keepdims=True))
        e = K.exp(e - row_max) * keep

    # a_t = softmax(e_t); epsilon keeps fully-masked rows at 0 instead of NaN.
    s = K.sum(e, axis=-1, keepdims=True)
    a = e / (s + K.epsilon())

    # l_t = sum over t' of a[t, t'] * x[t']
    v = K.batch_dot(a, inputs)
    if self.attention_regularizer_weight > 0.0:
        self.add_loss(self._attention_regularizer(a))

    if self.return_attention:
        return [v, a]
    return v
def call(self, inputs, mask=None, training=None):
    """Relative partial multi-head attention with segment and permutation terms.

    ``inputs`` is unpacked as ``(inputs, content, memories, segment_mat,
    segment_embed, relatives, bias_context, bias_relative, bias_segment,
    permutation)``. Queries come from ``inputs``; keys and values come from
    ``memories`` concatenated with ``content``. Attention logits are the sum
    of a content term, a relative-position term (after ``_relative_shift``)
    and a segment term, scaled by ``1 / sqrt(units_head)``. They are
    multiplied after exponentiation by ``permutation`` and by the key mask
    (``mask[0]``, with memories always visible), then normalised. Returns
    the projected attention output.

    NOTE(review): the max used for softmax stabilisation is taken before the
    permutation and key masks are applied, so a large logit at a masked
    position can underflow every allowed weight to 0.
    """
    (inputs, content, memories, segment_mat, segment_embed, relatives, bias_context, bias_relative, bias_segment, permutation) = inputs
    full = K.concatenate([memories, content], axis=1)  # (batch, prev_len + seq_len, units)

    # The fused kernel/bias holds five projections: q | k, v | r | o.
    kernel_q = self.kernel[:, :self.units]
    kernel_kv = self.kernel[:, self.units:self.units * 3]
    kernel_r = self.kernel[:, self.units * 3:self.units * 4]
    kernel_o = self.kernel[:, self.units * 4:self.units * 5]
    bias_q, bias_kv, bias_r, bias_o = (None, ) * 4
    if self.use_bias:
        bias_q = self.bias[:self.units]
        bias_kv = self.bias[self.units:self.units * 3]
        bias_r = self.bias[self.units * 3:self.units * 4]
        bias_o = self.bias[self.units * 4:self.units * 5]

    w_q = K.dot(inputs, kernel_q)  # (batch, seq_len, units)
    w_kv = K.dot(full, kernel_kv)  # (batch, prev_len + seq_len, units * 2)
    # NOTE(review): the length of ``relatives`` is set by the caller; if it
    # comes from the XLNet relative embedding it is k_len + 1 (directional),
    # which _relative_shift then trims.
    w_r = K.dot(relatives, kernel_r)  # (batch, prev_len + seq_len, units)
    if self.use_bias:
        w_q = K.bias_add(w_q, bias_q)
        w_kv = K.bias_add(w_kv, bias_kv)
        w_r = K.bias_add(w_r, bias_r)
    if self.activation is not None:
        w_q = self.activation(w_q)
        w_kv = self.activation(w_kv)
        w_r = self.activation(w_r)

    w_k = w_kv[:, :, :self.units]  # (batch, prev_len + seq_len, units)
    w_v = w_kv[:, :, self.units:]  # (batch, prev_len + seq_len, units)
    batch_size, q_len, k_len = K.shape(inputs)[0], K.shape(w_q)[1], K.shape(w_k)[1]

    # Content term: (q + bias_context) . k
    w_qc = K.bias_add(w_q, bias_context)
    w_qc = self._reshape_to_batches(w_qc)  # (batch * n_head, seq_len, units_head)
    w_k = self._reshape_to_batches(w_k)  # (batch * n_head, prev_len + seq_len, units_head)
    a_context = K.batch_dot(w_qc, w_k, axes=2)  # (batch * n_head, seq_len, prev_len + seq_len)

    # Relative-position term: (q + bias_relative) . r, shifted to align with keys.
    w_qr = K.bias_add(w_q, bias_relative)
    w_qr = self._reshape_to_batches(w_qr)  # (batch * n_head, seq_len, units_head)
    w_r = self._reshape_to_batches(w_r)  # (batch * n_head, prev_len + seq_len, units_head)
    a_relative = K.batch_dot(w_qr, w_r, axes=2)  # (batch * n_head, seq_len, prev_len + seq_len)
    a_relative = self._relative_shift(  # (batch * n_head, seq_len, prev_len + seq_len)
        a_relative,
        key_len_expected=K.shape(a_context)[-1],
    )

    # Segment term: (q + bias_segment) . segment_embed per head, then selected
    # by ``segment_mat``. Presumably segment_mat is (batch, seq_len,
    # prev_len + seq_len, 2), as implied by batch_dot axes=(3, 2) -- TODO confirm.
    w_qs = K.bias_add(w_q, bias_segment)
    w_qs = K.reshape(w_qs, (-1, q_len, self.num_head, self.units_head))
    w_qs = K.permute_dimensions(w_qs, (2, 0, 1, 3))  # (n_head, batch, seq_len, units_head)
    segment_embed = K.reshape(K.transpose(segment_embed), (self.num_head, 1, self.units_head, 2))
    segment_embed = K.tile(segment_embed, (1, batch_size, 1, 1))
    a_segment = K.batch_dot(w_qs, segment_embed, axes=(3, 2))  # (n_head, batch, seq_len, 2)
    a_segment = K.permute_dimensions(a_segment, (1, 2, 3, 0))  # (batch, seq_len, 2, n_head)
    a_segment = K.batch_dot(segment_mat, a_segment, axes=(3, 2))  # (batch, seq_len, prev_len + seq_len, n_head)
    a_segment = K.reshape(K.permute_dimensions(a_segment, (0, 3, 1, 2)), (-1, q_len, k_len))

    att = (a_context + a_relative + a_segment) / K.sqrt(K.constant(self.units_head, dtype=K.floatx()))
    exp = K.exp(att - K.max(att, axis=-1, keepdims=True))

    # Repeat the (batch, seq_len, k_len) permutation mask for every head,
    # batch-major to match _reshape_to_batches.
    permutation = K.tile(K.expand_dims(permutation, axis=1), [1, self.num_head, 1, 1])
    permutation = K.reshape(permutation, (-1, q_len, k_len))
    exp *= permutation
    if mask is not None and mask[0] is not None:
        mask = K.cast(mask[0], K.floatx())
        # Memory positions are always attendable.
        mask = K.concatenate([K.ones_like(memories[:, :, 0]), mask], axis=1)
        exp *= K.expand_dims(self._reshape_mask(mask), axis=1)

    att = exp / (K.sum(exp, axis=-1, keepdims=True) + K.epsilon())
    if self.att_drop_layer is not None:
        att = self.att_drop_layer(att, training=training)
    w_v = self._reshape_to_batches(w_v)  # (batch * n_head, prev_len + seq_len, units_head)
    w_o = K.batch_dot(att, w_v)  # (batch * n_head, seq_len, units_head)

    w_o = self._reshape_from_batches(w_o)  # (batch, seq_len, units)
    w_o = K.dot(w_o, kernel_o)  # (batch, seq_len, units)
    if self.use_bias:
        w_o = K.bias_add(w_o, bias_o)
    if self.activation is not None:
        w_o = self.activation(w_o)

    if TF_KERAS:
        # Add shape information to tensor when using `tf.keras`
        input_shape = K.int_shape(inputs)
        if input_shape[1] is not None:
            w_o = K.reshape(w_o, (-1, ) + input_shape[1:])
    return w_o
def _reshape_mask(self, mask): seq_len = K.shape(mask)[1] mask = K.expand_dims(mask, axis=1) mask = K.tile(mask, [1, self.num_head, 1]) return K.reshape(mask, (-1, seq_len))