def _compute_qkv(self, query, memory):
  """ Computes linear transformations of query, key and value.

  Args:
    query: Attention query tensor with shape
      [batch_size, length_q, channels_query].
    memory: Attention values tensor with shape
      [batch_size, length_m, channels_value].

  Returns:
    A tuple `(query_transformed, key_transformed, memory_transformed)`.
  """
  if query is None:
    # indicating self-attention, query and key are both the same as memory
    combined = conv1d(
      memory,
      self._attention_key_depth * 2 + self._attention_value_depth,
      kernel_size=1,
      name="qkv_transform")
    q, k, v = tf.split(
      combined,
      [self._attention_key_depth,
       self._attention_key_depth,
       self._attention_value_depth],
      axis=2)
    return q, k, v
  else:
    # encoder-decoder attention
    q = conv1d(
      query,
      self._attention_key_depth,
      kernel_size=1,
      name="q_transform",
      padding="VALID")
    kv_combined = conv1d(
      memory,
      self._attention_key_depth + self._attention_value_depth,
      kernel_size=1,
      name="kv_transform",
      padding="VALID")
    k, v = tf.split(
      kv_combined,
      [self._attention_key_depth, self._attention_value_depth],
      axis=2)
    return q, k, v
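# A small standalone sketch (an assumption about the `conv1d` helper used
# above, not the project's own implementation): a 1-D convolution with
# kernel_size=1 is simply a position-wise dense projection, i.e. every time
# step is multiplied by the same [channels_in, channels_out] weight matrix.
# numpy is used here as a stand-in so the sketch runs without TensorFlow.
import numpy as np

def pointwise_conv1d(x, weight, bias=None):
  # x: [batch_size, length, channels_in], weight: [channels_in, channels_out]
  y = np.einsum("blc,cd->bld", x, weight)
  if bias is not None:
    y = y + bias
  return y

x = np.random.randn(2, 5, 16)        # [batch_size, length, channels_in]
w = np.random.randn(16, 32)          # shared across all positions
print(pointwise_conv1d(x, w).shape)  # (2, 5, 32)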
def _compute_qkv(self, query, memory, cache):
  """ Computes linear transformations of query, keys and values.

  Args:
    query: Attention query tensor with shape
      [batch_size, length_q, channels_query]. If None, it indicates
      self-attention.
    memory: Attention values tensor with shape
      [batch_size, length_m, channels_value].
    cache: A dictionary containing pre-projected keys and values.

  Returns:
    A tuple `(query_transformed, key_transformed, memory_transformed)`.
  """
  if query is None:
    # indicates self-attention
    q, k, v = self.compute_qkv(memory)
    if cache is not None:
      # for self-attention in transformer decoder when mode=INFER
      k = tf.concat([cache["keys"], k], axis=1)
      v = tf.concat([cache["values"], v], axis=1)
      cache["keys"] = k
      cache["values"] = v
  else:
    # indicates encoder-decoder attention
    q = conv1d(
      query,
      self._attention_key_depth,
      kernel_size=1,
      name="q_transform",
      padding="VALID")
    if cache is not None and "attention_keys" in cache:
      k = cache["attention_keys"]
      if "attention_values" in cache:
        v = cache["attention_values"]
      else:
        v = conv1d(
          memory,
          self._attention_value_depth,
          kernel_size=1,
          name="v_transform",
          padding="VALID")
    else:
      k, v = self.compute_kv(memory)
  return q, k, v
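# A minimal sketch (not part of the original code) of how the decoder-side
# cache above is expected to behave during incremental inference: at every
# step the freshly projected key/value for the current position is appended
# along the time axis, so the cached tensors grow by one position per step.
# The plain-numpy stand-ins and shapes are illustrative assumptions only.
import numpy as np

def _append_to_cache(cache, new_k, new_v):
  # new_k / new_v: [batch_size, 1, depth] for the current decoding step
  cache["keys"] = np.concatenate([cache["keys"], new_k], axis=1)
  cache["values"] = np.concatenate([cache["values"], new_v], axis=1)
  return cache["keys"], cache["values"]

batch_size, key_depth, value_depth = 2, 8, 8
cache = {"keys": np.zeros([batch_size, 0, key_depth]),
         "values": np.zeros([batch_size, 0, value_depth])}
for step in range(3):
  k_t = np.random.randn(batch_size, 1, key_depth)
  v_t = np.random.randn(batch_size, 1, value_depth)
  k_all, v_all = _append_to_cache(cache, k_t, v_t)
print(k_all.shape)  # (2, 3, 8): keys for all three decoded positions so far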
def compute_qkv(self, memory):
  """ Computes linear transformations of query, keys and values,
  especially for self-attention in transformer encoder.

  Args:
    memory: Attention values tensor with shape
      [batch_size, length_m, channels_value].

  Returns:
    A tuple `(query_transformed, key_transformed, memory_transformed)`.
  """
  combined = conv1d(
    memory,
    self._attention_key_depth * 2 + self._attention_value_depth,
    kernel_size=1,
    name="qkv_transform")
  q, k, v = tf.split(
    combined,
    [self._attention_key_depth,
     self._attention_key_depth,
     self._attention_value_depth],
    axis=2)
  return q, k, v
def compute_kv(self, memory):
  """ Computes linear transformations of keys and values,
  especially for encoder-decoder attention.

  Args:
    memory: Attention values tensor with shape
      [batch_size, length_m, channels_value].

  Returns:
    A tuple `(key_transformed, memory_transformed)`.
  """
  kv_combined = conv1d(
    memory,
    self._attention_key_depth + self._attention_value_depth,
    kernel_size=1,
    name="kv_transform",
    padding="VALID")
  k, v = tf.split(
    kv_combined,
    [self._attention_key_depth, self._attention_value_depth],
    axis=2)
  return k, v
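# A minimal numpy sketch (illustrative only) of why the fused "kv_transform"
# projection above can be split along the channel axis: projecting into
# key_depth + value_depth channels at once and then slicing is equivalent to
# two separate projections whose weight matrices were concatenated column-wise.
import numpy as np

key_depth, value_depth = 4, 6
memory = np.random.randn(2, 7, 16)          # [batch_size, length_m, channels_value]
w_k = np.random.randn(16, key_depth)
w_v = np.random.randn(16, value_depth)
w_kv = np.concatenate([w_k, w_v], axis=1)   # fused projection weight

kv = np.einsum("blc,cd->bld", memory, w_kv)  # [batch_size, length_m, key_depth + value_depth]
k, v = np.split(kv, [key_depth], axis=2)     # same role as the tf.split above
assert np.allclose(k, np.einsum("blc,cd->bld", memory, w_k))
assert np.allclose(v, np.einsum("blc,cd->bld", memory, w_v))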
def build(self, query, memory, memory_length=None,
          memory_bias=None, cache=None):
  """ Builds attention context.

  Args:
    query: Attention query tensor with shape
      [batch_size, length_q, channels_query]. If None, it indicates
      self-attention.
    memory: Attention values tensor with shape
      [batch_size, length_m, channels_value].
    memory_length: The number of attention values, [batch_size,].
    memory_bias: The bias tensor for attention values with shape
      [batch_size, 1, 1, timesteps].
    cache: A dictionary containing pre-projected keys and values.

  Returns:
    The result of the attention transformation. A tuple
    `(attention_scores, attention_context)`. The `attention_scores`
    has shape [batch_size, num_heads, length_q, length_k]. The
    `attention_context` has shape [batch_size, length_q, output_depth].
  """
  with tf.variable_scope(self.name, values=[query, memory]):
    query_is_2d = False
    if query is not None and query.get_shape().ndims == 2:
      # for using MultiHeadAttention in RNN-based decoders
      query_is_2d = True
      query = tf.expand_dims(query, axis=1)
    # compute q, k, v
    q, k, v = self._compute_qkv(query, memory, cache)
    # after split_last_dimension: [batch_size, length, depth]
    #   ==> [batch_size, length, num_heads, depth/num_heads]
    # after split_head: ==> [batch_size, num_heads, length, depth/num_heads]
    split_head = lambda _x, _nh: tf.transpose(
      split_last_dimension(_x, _nh), [0, 2, 1, 3])
    # [batch_size, num_heads, length, depth/num_heads]
    #   ==> [batch_size, length, depth]
    combine_head = lambda _x: combine_last_two_dimensions(
      tf.transpose(_x, [0, 2, 1, 3]))
    # [batch_size, num_heads, length_q/k/v, depth/num_heads]
    q = split_head(q, self._num_heads)
    k = split_head(k, self._num_heads)
    v = split_head(v, self._num_heads)
    # last dim of q, k, v after split_head
    key_depth_per_head = self._attention_key_depth // self._num_heads
    q *= key_depth_per_head ** (-0.5)  # scale the query
    if memory_bias is None:
      if memory_length is not None:
        assert query is not None, "Unseen error may occur. Please CHECK."
        memory_bias = MultiHeadAttention.attention_length_to_bias(
          tf.shape(memory)[1], memory_length)
    # compute attention weight, [batch_size, num_heads, length_q, length_k]
    attention_weight = self.att_fn(q, k, memory_bias)
    # sum over attention values, [batch_size, num_heads, length_q, depth/num_heads]
    attention_context = tf.matmul(attention_weight, v)
    # combined: [batch_size, length_q, depth_value]
    attention_context = combine_head(attention_context)
    # linear transform
    attention_context = conv1d(
      attention_context,
      self._output_depth,
      kernel_size=1,
      name="output_transform")
    if query_is_2d:
      # attention context: [batch_size, depth_value]
      attention_context = tf.squeeze(attention_context, axis=1)
      # attention weight: [batch_size, num_heads, length_k]
      attention_weight = tf.squeeze(attention_weight, axis=2)
    return attention_weight, attention_context
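# A hedged sketch of what `split_last_dimension` / `combine_last_two_dimensions`
# are assumed to do inside `build` above (the real helpers live elsewhere in
# the code base): reshape the channel axis into [num_heads, depth/num_heads],
# move the head axis in front of time, and invert the whole thing afterwards.
import numpy as np

def split_head(x, num_heads):
  # [batch, length, depth] ==> [batch, num_heads, length, depth // num_heads]
  batch, length, depth = x.shape
  x = x.reshape(batch, length, num_heads, depth // num_heads)
  return x.transpose(0, 2, 1, 3)

def combine_head(x):
  # [batch, num_heads, length, depth_per_head] ==> [batch, length, depth]
  batch, num_heads, length, depth_per_head = x.shape
  return x.transpose(0, 2, 1, 3).reshape(batch, length, num_heads * depth_per_head)

x = np.random.randn(2, 5, 8)
assert np.allclose(combine_head(split_head(x, num_heads=4)), x)  # exact round trip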
def _compute_qkv(self, query, keys, memory,
                 query_is_projected=False, key_is_projected=False):
  """ Computes linear transformations of query, key and value.

  Args:
    query: Attention query tensor with shape
      [batch_size, length_q, channels_query]. If None, it indicates
      self-attention.
    keys: Attention keys tensor with shape
      [batch_size, length_k, channels_key].
    memory: Attention values tensor with shape
      [batch_size, length_m, channels_value].
    query_is_projected: Whether the `query` is already projected.
    key_is_projected: Whether the `keys` is already projected.

  Returns:
    A tuple `(query_transformed, key_transformed, memory_transformed)`.
  """
  if query is None:
    # indicating self-attention, query and key are both the same as memory
    _ = keys
    combined = conv1d(
      memory,
      self._attention_key_depth * 2 + self._attention_value_depth,
      kernel_size=1,
      name="qkv_transform")
    q, k, v = tf.split(
      combined,
      [self._attention_key_depth,
       self._attention_key_depth,
       self._attention_value_depth],
      axis=2)
    return q, k, v
  else:
    # encoder-decoder attention
    if query_is_projected:
      q = query
    else:
      q = conv1d(
        query,
        self._attention_key_depth,
        kernel_size=1,
        name="q_transform",
        padding="VALID")
    if key_is_projected:
      k = keys
      v = conv1d(
        memory,
        self._attention_value_depth,
        kernel_size=1,
        name="v_transform",
        padding="VALID")
    else:
      kv_combined = conv1d(
        memory,
        self._attention_key_depth + self._attention_value_depth,
        kernel_size=1,
        name="kv_transform",
        padding="VALID")
      k, v = tf.split(
        kv_combined,
        [self._attention_key_depth, self._attention_value_depth],
        axis=2)
    return q, k, v
def build(self, query, keys, memory, memory_length=None, memory_bias=None,
          query_is_projected=False, key_is_projected=False):
  """ Builds attention context.

  Args:
    query: Attention query tensor with shape
      [batch_size, length_q, channels_query]. If None, it indicates
      self-attention.
    keys: Attention keys tensor with shape
      [batch_size, length_k, channels_key].
    memory: Attention values tensor with shape
      [batch_size, length_m, channels_value].
    memory_length: The number of attention values, [batch_size,].
    memory_bias: The bias tensor for attention values with shape
      [batch_size, 1, 1, timesteps].
    query_is_projected: Whether the `query` is already projected.
    key_is_projected: Whether the `keys` is already projected.

  Returns:
    The result of the attention transformation. A tuple
    `(attention_scores, attention_context)`. The `attention_scores`
    has shape [batch_size, num_heads, length_q, length_k]. The
    `attention_context` has shape [batch_size, length_q, output_depth].
  """
  with tf.variable_scope(self.name, values=[query, memory]):
    query_is_2d = False
    if query is not None and query.get_shape().ndims == 2:
      # for using MultiHeadAttention in RNN-based decoders
      query_is_2d = True
      query = tf.expand_dims(query, axis=1)
    # compute q, k, v
    q, k, v = self._compute_qkv(
      query, keys, memory,
      query_is_projected=query_is_projected,
      key_is_projected=key_is_projected)
    # after split_last_dimension: [batch_size, length, depth]
    #   ==> [batch_size, length, num_heads, depth/num_heads]
    # after split_head: ==> [batch_size, num_heads, length, depth/num_heads]
    split_head = lambda _x, _nh: tf.transpose(
      split_last_dimension(_x, _nh), [0, 2, 1, 3])
    # [batch_size, num_heads, length, depth/num_heads]
    #   ==> [batch_size, length, depth]
    combine_head = lambda _x: combine_last_two_dimensions(
      tf.transpose(_x, [0, 2, 1, 3]))
    # [batch_size, num_heads, length_q/k/v, depth/num_heads]
    q = split_head(q, self._num_heads)
    k = split_head(k, self._num_heads)
    v = split_head(v, self._num_heads)
    # last dim of q, k, v after split_head
    key_depth_per_head = self._attention_key_depth // self._num_heads
    q *= key_depth_per_head ** (-0.5)  # scale the query
    if memory_bias is None:
      if memory_length is not None:
        assert query is not None, "Unseen error may occur. Please CHECK."
        input_padding = embedding_to_padding(memory, memory_length)
        # [batch_size, 1, 1, timesteps], FLOAT_MIN for padding, 0.0 for non-padding
        memory_bias = attention_bias_ignore_padding(input_padding)
    # compute attention weight, [batch_size, num_heads, length_q, length_k]
    attention_weight = self.att_fn(q, k, memory_bias)
    # sum over attention values, [batch_size, num_heads, length_q, depth/num_heads]
    attention_context = tf.matmul(attention_weight, v)
    # combined: [batch_size, length_q, depth_value]
    attention_context = combine_head(attention_context)
    # linear transform
    attention_context = conv1d(
      attention_context,
      self._output_depth,
      kernel_size=1,
      name="output_transform")
    # TODO: here the dimension of attention_weight is not checked for output_attention
    if query_is_2d:
      attention_context = tf.squeeze(attention_context, axis=1)
    return attention_weight, attention_context
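# A minimal sketch (assumptions only) of the padding bias that
# `embedding_to_padding` + `attention_bias_ignore_padding` above (and
# `attention_length_to_bias` in the cache-based `build`) are expected to
# produce: a [batch_size, 1, 1, timesteps] tensor that is 0.0 on real
# positions and a very large negative number on padded positions, so that
# adding it to the attention logits drives padded positions to ~0 after
# the softmax.
import numpy as np

def length_to_bias(timesteps, memory_length, float_min=-1e9):
  # memory_length: [batch_size,] number of valid positions per sequence
  positions = np.arange(timesteps)[None, :]                    # [1, timesteps]
  is_padding = positions >= np.asarray(memory_length)[:, None]  # [batch_size, timesteps]
  bias = np.where(is_padding, float_min, 0.0)
  return bias[:, None, None, :]                                # [batch_size, 1, 1, timesteps]

# first sequence has 3 valid positions out of 5: zeros, then large negatives
print(length_to_bias(5, [3, 5])[0, 0, 0])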
def build(self, query, keys, memory, memory_length=None, memory_bias=None,
          query_is_projected=True, key_is_projected=True):
  """ Builds attention context.

  Args:
    query: Attention query tensor with shape
      [batch_size, length_q, channels_query].
    keys: Attention keys tensor with shape
      [batch_size, length_k, channels_key]. Not used here; `memory`
      is used as the keys.
    memory: Attention values tensor with shape
      [batch_size, length_m, channels_value].
    memory_length: The number of attention values, [batch_size,].
    memory_bias: The bias tensor for attention values with shape
      [batch_size, 1, 1, timesteps].
    query_is_projected: Whether the `query` is already projected. Not used here.
    key_is_projected: Whether the `keys` is already projected. Not used here.

  Returns:
    The result of the attention transformation. A tuple
    `(attention_scores, attention_context)`. The `attention_scores`
    has shape [batch_size, num_heads, length_q, length_k]. The
    `attention_context` has shape [batch_size, length_q, output_depth].
  """
  _ = keys
  _ = query_is_projected
  _ = key_is_projected
  with tf.variable_scope(self.name, values=[query, memory]):
    # compute q, k, v
    q, k, v = self._compute_qkv(query, memory)
    # after split_last_dimension: [batch_size, length, depth]
    #   ==> [batch_size, length, num_heads, depth/num_heads]
    # after split_head: ==> [batch_size, num_heads, length, depth/num_heads]
    split_head = lambda _x, _nh: tf.transpose(
      algebra_ops.split_last_dimension(_x, _nh), [0, 2, 1, 3])
    # [batch_size, num_heads, length, depth/num_heads]
    #   ==> [batch_size, length, depth]
    combine_head = lambda _x: algebra_ops.combine_last_two_dimensions(
      tf.transpose(_x, [0, 2, 1, 3]))
    # [batch_size, num_heads, length_q/k/v, depth/num_heads]
    q = split_head(q, self._num_heads)
    k = split_head(k, self._num_heads)
    v = split_head(v, self._num_heads)
    # last dim of q, k, v after split_head
    key_depth_per_head = self._attention_key_depth // self._num_heads
    q *= key_depth_per_head ** (-0.5)  # scale the query
    # compute attention weight, [batch_size, num_heads, length_q, length_k]
    attention_weight = self.att_fn(q, k, memory_bias)
    # sum over attention values, [batch_size, num_heads, length_q, depth/num_heads]
    attention_context = tf.matmul(attention_weight, v)
    # combined: [batch_size, length_q, depth_value]
    attention_context = combine_head(attention_context)
    # linear transform
    attention_context = conv1d(
      attention_context,
      self._output_depth,
      kernel_size=1,
      name="output_transform")
    return attention_weight, attention_context
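# A hedged numpy sketch of what `self.att_fn` is assumed to compute in the
# `build` methods above (the queries are already scaled by
# key_depth_per_head ** -0.5 before the call): dot-product logits between
# queries and keys, plus the additive bias, followed by a softmax over the
# key dimension. The exact attention function in the code base may differ
# (e.g. it may apply dropout).
import numpy as np

def dot_product_attention_weights(q, k, bias=None):
  # q: [batch, heads, length_q, depth], k: [batch, heads, length_k, depth]
  logits = np.einsum("bhqd,bhkd->bhqk", q, k)
  if bias is not None:
    logits = logits + bias                                 # e.g. the padding bias above
  logits = logits - logits.max(axis=-1, keepdims=True)     # numerically stable softmax
  weights = np.exp(logits)
  return weights / weights.sum(axis=-1, keepdims=True)

q = np.random.randn(2, 4, 3, 8)
k = np.random.randn(2, 4, 6, 8)
print(dot_product_attention_weights(q, k).shape)  # (2, 4, 3, 6); each row sums to 1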