Example #1
    def _compute_qkv(self, query, memory):
        """ Computes linear transformations of query, key
         and value.

        Args:
            query: Attention query tensor with shape
              [batch_size, length_q, channels_query].
            memory: Attention values tensor with shape
              [batch_size, length_m, channels_value]
        Returns: A tuple `(query_transformed, key_transformed,
          memory_transformed)`.
        """
        if query is None:
            # indicating self-attention, query and key are both the same as memory
            combined = conv1d(memory,
                              self._attention_key_depth * 2 +
                              self._attention_value_depth,
                              kernel_size=1,
                              name="qkv_transform")
            q, k, v = tf.split(combined, [
                self._attention_key_depth, self._attention_key_depth,
                self._attention_value_depth
            ],
                               axis=2)
            return q, k, v
        else:
            # encoder-decoder attention
            q = conv1d(query,
                       self._attention_key_depth,
                       kernel_size=1,
                       name="q_transform",
                       padding="VALID")
            kv_combined = conv1d(memory,
                                 self._attention_key_depth +
                                 self._attention_value_depth,
                                 kernel_size=1,
                                 name="kv_transform",
                                 padding="VALID")
            k, v = tf.split(
                kv_combined,
                [self._attention_key_depth, self._attention_value_depth],
                axis=2)
            return q, k, v
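The self-attention branch above fuses the three projections into a single kernel_size-1 convolution and then splits the result along the channel axis. Below is a minimal sketch of that split with made-up depths (the real code uses self._attention_key_depth and self._attention_value_depth):

import tensorflow as tf

# Illustrative depths and shapes, not the class attributes used above.
key_depth, value_depth = 8, 16
batch_size, length_m = 2, 5

# Stand-in for the output of the fused "qkv_transform" projection.
combined = tf.zeros([batch_size, length_m, key_depth * 2 + value_depth])
q, k, v = tf.split(combined, [key_depth, key_depth, value_depth], axis=2)
# q, k: [2, 5, 8]; v: [2, 5, 16]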
Example #2
    def _compute_qkv(self, query, memory, cache):
        """ Computes linear transformations of query, keys and values.

        Args:
            query: Attention query tensor with shape [batch_size, length_q, channels_query].
              If None, it indicates self-attention.
            memory: Attention values tensor with shape
              [batch_size, length_m, channels_value]
            cache: A dictionary containing pre-projected keys and values.

        Returns: A tuple `(query_transformed, key_transformed, memory_transformed)`.
        """
        if query is None:
            # indicates self-attention
            q, k, v = self.compute_qkv(memory)
            if cache is not None:
                # for self-attention in transformer decoder when mode=INFER
                k = tf.concat([cache["keys"], k], axis=1)
                v = tf.concat([cache["values"], v], axis=1)
                cache["keys"] = k
                cache["values"] = v
        else:
            q = conv1d(
                query,
                self._attention_key_depth,
                kernel_size=1,
                name="q_transform",
                padding="VALID")
            # indicates encoder-decoder attention
            if cache is not None and "attention_keys" in cache:
                k = cache["attention_keys"]
                if "attention_values" in cache:
                    v = cache["attention_values"]
                else:
                    v = conv1d(memory,
                               self._attention_value_depth,
                               kernel_size=1,
                               name="v_transform",
                               padding="VALID")
            else:
                k, v = self.compute_kv(memory)
        return q, k, v
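When decoding incrementally (mode=INFER), the cache holds the keys and values projected at earlier steps, and the projection of the current step is appended along the time axis. A minimal sketch of that update with illustrative shapes:

import tensorflow as tf

key_depth, value_depth = 8, 16
cache = {
    "keys": tf.zeros([2, 3, key_depth]),      # projections of 3 earlier steps
    "values": tf.zeros([2, 3, value_depth]),
}
k_step = tf.zeros([2, 1, key_depth])          # projection of the current step
v_step = tf.zeros([2, 1, value_depth])

cache["keys"] = tf.concat([cache["keys"], k_step], axis=1)      # [2, 4, key_depth]
cache["values"] = tf.concat([cache["values"], v_step], axis=1)  # [2, 4, value_depth]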
Example #4
    def compute_qkv(self, memory):
        """ Computes linear transformations of query, keys and values, especially
        for self-attention in transformer encoder.

        Args:
            memory: Attention values tensor with shape
              [batch_size, length_m, channels_value]

        Returns: A tuple `(query_transformed, key_transformed, memory_transformed)`.
        """
        combined = conv1d(
            memory,
            self._attention_key_depth * 2 + self._attention_value_depth,
            kernel_size=1, name="qkv_transform")
        q, k, v = tf.split(
            combined,
            [self._attention_key_depth, self._attention_key_depth,
             self._attention_value_depth],
            axis=2)
        return q, k, v
    def compute_kv(self, memory):
        """ Computes linear transformations of keys and values, especially
        for encoder decoder attention.

        Args:
            memory: Attention values tensor with shape
              [batch_size, length_m, channels_value]

        Returns: A tuple `(key_transformed, memory_transformed)`.
        """
        kv_combined = conv1d(memory,
                             self._attention_key_depth +
                             self._attention_value_depth,
                             kernel_size=1,
                             name="kv_transform",
                             padding="VALID")
        k, v = tf.split(
            kv_combined,
            [self._attention_key_depth, self._attention_value_depth],
            axis=2)
        return k, v
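All of the projections above are kernel_size-1 convolutions, i.e. the same per-position linear map as a dense layer applied to the last axis. A small sketch of that equivalence, using tf.keras layers only as a stand-in for the library's conv1d wrapper:

import tensorflow as tf

x = tf.zeros([2, 5, 32])
conv = tf.keras.layers.Conv1D(filters=24, kernel_size=1, use_bias=False)
dense = tf.keras.layers.Dense(units=24, use_bias=False)
y_conv = conv(x)    # [2, 5, 24]
y_dense = dense(x)  # [2, 5, 24] -- the same mapping, up to the learned weights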
Example #7
    def build(self,
              query,
              memory,
              memory_length=None,
              memory_bias=None,
              cache=None):
        """ Builds attention context.

        Args:
            query: Attention query tensor with shape [batch_size, length_q, channels_query].
              If None, it indicates self-attention.
            memory: Attention values tensor with shape [batch_size, length_m, channels_value].
            memory_length: The number of attention values, [batch_size,].
            memory_bias: The bias tensor for attention values with shape [batch_size, 1, 1, timesteps].
            cache: A dictionary containing pre-projected keys and values.

        Returns: The result of the attention transformation. A tuple
        `(attention_scores, attention_context)`. The `attention_scores`
        has shape [batch_size, num_heads, length_q, length_k]. The
        `attention_context` has shape [batch_size, length_q, output_depth].
        """
        with tf.variable_scope(self.name, values=[query, memory]):
            query_is_2d = False
            if query is not None and query.get_shape().ndims == 2:
                # for using MultiHeadAttention in RNN-based decoders
                query_is_2d = True
                query = tf.expand_dims(query, axis=1)
            # compute q, k, v
            q, k, v = self._compute_qkv(query, memory, cache)

            # after split_last_dimension: [batch_size, length, depth]
            #           ==> [batch_size, length, num_heads, depth/num_heads]
            # after split_head: ==> [batch_size, num_heads, length, depth/num_heads]
            split_head = lambda _x, _nh: tf.transpose(
                split_last_dimension(_x, _nh), [0, 2, 1, 3])
            # [batch_size, num_heads, length, depth/num_heads] ==> [batch_size, length, depth]
            combine_head = lambda _x: combine_last_two_dimensions(
                tf.transpose(_x, [0, 2, 1, 3]))

            # [batch_size, num_heads, length_q/k/v, depth/num_heads]
            q = split_head(q, self._num_heads)
            k = split_head(k, self._num_heads)
            v = split_head(v, self._num_heads)
            # last dim of q, k, v after split_head
            key_depth_per_head = self._attention_key_depth // self._num_heads
            q *= key_depth_per_head**(-0.5)  # scale the query

            if memory_bias is None:
                if memory_length is not None:
                    assert query is not None, "Unseen error may occur. Please CHECK."
                    memory_bias = MultiHeadAttention.attention_length_to_bias(
                        tf.shape(memory)[1], memory_length)

            # compute attention weight, [batch_size, num_heads, length_q, length_k]
            attention_weight = self.att_fn(q, k, memory_bias)
            # sum over attention values, [batch_size, num_heads, length_q, depth/num_heads]
            attention_context = tf.matmul(attention_weight, v)

            # combined: [batch_size, length_q, depth_value]
            attention_context = combine_head(attention_context)
            # linear transform
            attention_context = conv1d(attention_context,
                                       self._output_depth,
                                       kernel_size=1,
                                       name="output_transform")
            if query_is_2d:
                # attention context: [batch_size, depth_value]
                attention_context = tf.squeeze(attention_context, axis=1)
                # attention weight: [batch_size, num_heads, length_k]
                attention_weight = tf.squeeze(attention_weight, axis=2)
            return attention_weight, attention_context
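The head split/merge and the scaled dot-product inside build() can be sketched in isolation as below; the helper names mirror the lambdas above, while the concrete shapes and the softmax-based attention function are illustrative assumptions rather than the library's att_fn:

import tensorflow as tf

batch, length, num_heads, depth = 2, 5, 4, 32
q = tf.zeros([batch, length, depth])
k = tf.zeros([batch, length, depth])
v = tf.zeros([batch, length, depth])

def split_head(x, nh):
    # [batch, length, depth] ==> [batch, num_heads, length, depth/num_heads]
    b, l, d = x.shape.as_list()
    return tf.transpose(tf.reshape(x, [b, l, nh, d // nh]), [0, 2, 1, 3])

def combine_head(x):
    # [batch, num_heads, length, depth/num_heads] ==> [batch, length, depth]
    x = tf.transpose(x, [0, 2, 1, 3])
    b, l, nh, dh = x.shape.as_list()
    return tf.reshape(x, [b, l, nh * dh])

qh = split_head(q, num_heads) * (depth // num_heads) ** -0.5  # scale the query
kh, vh = split_head(k, num_heads), split_head(v, num_heads)
weights = tf.nn.softmax(tf.matmul(qh, kh, transpose_b=True), axis=-1)  # [2, 4, 5, 5]
context = combine_head(tf.matmul(weights, vh))                         # [2, 5, 32]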
Example #9
    def _compute_qkv(self,
                     query,
                     keys,
                     memory,
                     query_is_projected=False,
                     key_is_projected=False):
        """ Computes linear transformations of query, key
         and value.

        Args:
            query: Attention query tensor with shape [batch_size, length_q, channels_query].
              If None, it indicates self-attention.
            keys: Attention keys tensor with shape [batch_size, length_k, channels_key].
            memory: Attention values tensor with shape
              [batch_size, length_m, channels_value]
            query_is_projected: Whether the `query` is already projected.
            key_is_projected: Whether the `keys` is already projected.
        Returns: A tuple `(query_transformed, key_transformed,
          memory_transformed)`.
        """
        if query is None:
            # indicating self-attention, query and key are both the same as memory
            _ = keys
            combined = conv1d(memory,
                              self._attention_key_depth * 2 +
                              self._attention_value_depth,
                              kernel_size=1,
                              name="qkv_transform")
            q, k, v = tf.split(combined, [
                self._attention_key_depth, self._attention_key_depth,
                self._attention_value_depth
            ],
                               axis=2)
            return q, k, v
        else:
            # encoder-decoder attention
            if query_is_projected:
                q = query
            else:
                q = conv1d(query,
                           self._attention_key_depth,
                           kernel_size=1,
                           name="q_transform",
                           padding="VALID")
            if key_is_projected:
                k = keys
                v = conv1d(memory,
                           self._attention_value_depth,
                           kernel_size=1,
                           name="v_transform",
                           padding="VALID")
            else:
                kv_combined = conv1d(memory,
                                     self._attention_key_depth +
                                     self._attention_value_depth,
                                     kernel_size=1,
                                     name="kv_transform",
                                     padding="VALID")
                k, v = tf.split(
                    kv_combined,
                    [self._attention_key_depth, self._attention_value_depth],
                    axis=2)
            return q, k, v
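A short illustration of why key_is_projected exists: the encoder memory is fixed during decoding, so its key projection can be computed once and reused at every decoder step. The Dense layer below is only a stand-in for the library's key transform:

import tensorflow as tf

k_proj = tf.keras.layers.Dense(8, use_bias=False)  # stand-in for the key transform
memory = tf.zeros([2, 7, 128])                     # fixed encoder output

cached_keys = k_proj(memory)  # projected once, outside the decoding loop
for _step in range(3):
    keys = cached_keys        # reused each step; corresponds to key_is_projected=True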
Example #10
    def build(self,
              query,
              keys,
              memory,
              memory_length=None,
              memory_bias=None,
              query_is_projected=False,
              key_is_projected=False):
        """ Builds attention context.

        Args:
            query: Attention query tensor with shape [batch_size, length_q, channels_query].
              If None, it indicates self-attention.
            keys: Attention keys tensor with shape [batch_size, length_k, channels_key].
            memory: Attention values tensor with shape [batch_size, length_m, channels_value].
            memory_length: The number of attention values, [batch_size,].
            memory_bias: The bias tensor for attention values with shape [batch_size, 1, 1, timesteps].
            query_is_projected: Whether the `query` is already projected.
            key_is_projected: Whether the `keys` are already projected.

        Returns: The result of the attention transformation. A tuple
        `(attention_scores, attention_context)`. The `attention_scores`
        has shape [batch_size, num_heads, length_q, length_k]. The
        `attention_context` has shape [batch_size, length_q, output_depth].
        """
        with tf.variable_scope(self.name, values=[query, memory]):
            query_is_2d = False
            if query is not None and query.get_shape().ndims == 2:
                # for using MultiHeadAttention in RNN-based decoders
                query_is_2d = True
                query = tf.expand_dims(query, axis=1)
            # compute q, k, v
            q, k, v = self._compute_qkv(query,
                                        keys,
                                        memory,
                                        query_is_projected=query_is_projected,
                                        key_is_projected=key_is_projected)

            # after split_last_dimension: [batch_size, length, depth]
            #           ==> [batch_size, length, num_heads, depth/num_heads]
            # after split_head: ==> [batch_size, num_heads, length, depth/num_heads]
            split_head = lambda _x, _nh: tf.transpose(
                split_last_dimension(_x, _nh), [0, 2, 1, 3])
            # [batch_size, num_heads, length, depth/num_heads] ==> [batch_size, length, depth]
            combine_head = lambda _x: combine_last_two_dimensions(
                tf.transpose(_x, [0, 2, 1, 3]))

            # [batch_size, num_heads, length_q/k/v, depth/num_heads]
            q = split_head(q, self._num_heads)
            k = split_head(k, self._num_heads)
            v = split_head(v, self._num_heads)
            # last dim of q, k, v after split_head
            key_depth_per_head = self._attention_key_depth // self._num_heads
            q *= key_depth_per_head**(-0.5)  # scale the query

            if memory_bias is None:
                if memory_length is not None:
                    assert query is not None, "Unseen error may occur. Please CHECK."
                    input_padding = embedding_to_padding(memory, memory_length)
                    # [batch_size, 1, 1, timesteps], FLOAT_MIN for padding, 0.0 for non-padding
                    memory_bias = attention_bias_ignore_padding(input_padding)

            # compute attention weight, [batch_size, num_heads, length_q, length_k]
            attention_weight = self.att_fn(q, k, memory_bias)
            # sum over attention values, [batch_size, num_heads, length_q, depth/num_heads]
            attention_context = tf.matmul(attention_weight, v)

            # combined: [batch_size, length_q, depth_value]
            attention_context = combine_head(attention_context)
            # linear transform
            attention_context = conv1d(attention_context,
                                       self._output_depth,
                                       kernel_size=1,
                                       name="output_transform")
            # TODO here the dimension of attention_weight is not checked for output_attention
            if query_is_2d:
                attention_context = tf.squeeze(attention_context, axis=1)
            return attention_weight, attention_context
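When memory_bias is not supplied, it is derived from memory_length via embedding_to_padding and attention_bias_ignore_padding. The sketch below approximates their combined effect with standard TensorFlow ops (the helper names above belong to the library; this is not their implementation):

import tensorflow as tf

memory_length = tf.constant([3, 5])  # true lengths in a batch of 2
max_len = 5

# 1.0 for real positions, 0.0 for padding.
mask = tf.sequence_mask(memory_length, maxlen=max_len, dtype=tf.float32)  # [2, 5]
# Large negative bias on padded positions so the softmax gives them ~zero weight.
bias = (1.0 - mask) * -1e9
bias = tf.reshape(bias, [2, 1, 1, max_len])  # broadcasts over heads and length_q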
Example #11
    def build(self,
              query,
              keys,
              memory,
              memory_length=None,
              memory_bias=None,
              query_is_projected=True,
              key_is_projected=True):
        """ Builds attention context.

        Args:
            query: Attention query tensor with shape
              [batch_size, length_q, channels_query].
            keys: Attention keys tensor with shape
              [batch_size, length_k, channels_key]. Not used here;
              `memory` is used as the keys instead.
            memory: Attention values tensor with shape
              [batch_size, length_m, channels_value]
            memory_length: The number of attention values, [batch_size,].
            memory_bias: The bias tensor for attention values with
              shape [batch_size, 1, 1, timesteps].
            query_is_projected: Whether the `query` is already projected.
              Not used here.
            key_is_projected: Whether the `keys` are already projected.
              Not used here.

        Returns: The result of the attention transformation. A tuple
        `(attention_scores, attention_context)`. The `attention_scores`
        has shape [batch_size, num_heads, length_q, length_k]. The
        `attention_context` has shape [batch_size, length_q, output_depth].
        """
        _ = keys
        _ = query_is_projected
        _ = key_is_projected
        with tf.variable_scope(self.name, values=[query, memory]):
            # compute q, k, v
            q, k, v = self._compute_qkv(query, memory)

            # after split_last_dimension: [batch_size, length, depth]
            #           ==> [batch_size, length, num_heads, depth/num_heads]
            # after split_head: ==> [batch_size, num_heads, length, depth/num_heads]
            split_head = lambda _x, _nh: tf.transpose(
                algebra_ops.split_last_dimension(_x, _nh), [0, 2, 1, 3])
            # [batch_size, num_heads, length, depth/num_heads] ==> [batch_size, length, depth]
            combine_head = lambda _x: algebra_ops.combine_last_two_dimensions(
                tf.transpose(_x, [0, 2, 1, 3]))

            # [batch_size, num_heads, length_q/k/v, depth/num_heads]
            q = split_head(q, self._num_heads)
            k = split_head(k, self._num_heads)
            v = split_head(v, self._num_heads)
            # last dim of q, k, v after split_head
            key_depth_per_head = self._attention_key_depth // self._num_heads
            q *= key_depth_per_head**(-0.5)  # scale the query

            # compute attention weight, [batch_size, num_heads, length_q, length_k]
            attention_weight = self.att_fn(q, k, memory_bias)
            # sum over attention values, [batch_size, num_heads, length_q, depth/num_heads]
            attention_context = tf.matmul(attention_weight, v)

            # combined: [batch_size, length_q, depth_value]
            attention_context = combine_head(attention_context)
            # linear transform
            attention_context = conv1d(attention_context,
                                       self._output_depth,
                                       kernel_size=1,
                                       name="output_transform")
            return attention_weight, attention_context
Example #12
    def build(self,
              query,
              memory,
              memory_length=None,
              memory_bias=None,
              cache=None):
        """ Builds attention context.

        Args:
            query: Attention query tensor with shape [batch_size, length_q, channels_query].
              If None, it indicates self-attention.
            memory: Attention values tensor with shape [batch_size, length_m, channels_value].
            memory_length: The number of attention values, [batch_size,].
            memory_bias: The bias tensor for attention values with shape [batch_size, 1, 1, timesteps].
            cache: A dictionary containing pre-projected keys and values.

        Returns: The result of the attention transformation. A tuple
        `(attention_scores, attention_context)`. The `attention_scores`
        has shape [batch_size, num_heads, length_q, length_k]. The
        `attention_context` has shape [batch_size, length_q, output_depth].
        """
        with tf.variable_scope(self.name, values=[query, memory]):
            query_is_2d = False
            if query is not None and query.get_shape().ndims == 2:
                # for using MultiHeadAttention in RNN-based decoders
                query_is_2d = True
                query = tf.expand_dims(query, axis=1)
            # compute q, k, v
            q, k, v = self._compute_qkv(query, memory, cache)

            # after split_last_dimension: [batch_size, length, depth]
            #           ==> [batch_size, length, num_heads, depth/num_heads]
            # after split_head: ==> [batch_size, num_heads, length, depth/num_heads]
            split_head = lambda _x, _nh: tf.transpose(split_last_dimension(_x, _nh),
                                                      [0, 2, 1, 3])
            # [batch_size, num_heads, length, depth/num_heads] ==> [batch_size, length, depth]
            combine_head = lambda _x: combine_last_two_dimensions(
                tf.transpose(_x, [0, 2, 1, 3]))

            # [batch_size, num_heads, length_q/k/v, depth/num_heads]
            q = split_head(q, self._num_heads)
            k = split_head(k, self._num_heads)
            v = split_head(v, self._num_heads)
            # last dim of q, k, v after split_head
            key_depth_per_head = self._attention_key_depth // self._num_heads
            q *= key_depth_per_head ** (-0.5)  # scale the query

            if memory_bias is None:
                if memory_length is not None:
                    assert query is not None, "Unseen error may occur. Please CHECK."
                    memory_bias = MultiHeadAttention.attention_length_to_bias(tf.shape(memory)[1], memory_length)

            # compute attention weight, [batch_size, num_heads, length_q, length_k]
            attention_weight = self.att_fn(q, k, memory_bias)
            # sum over attention values, [batch_size, num_heads, length_q, depth/num_heads]
            attention_context = tf.matmul(attention_weight, v)

            # combined: [batch_size, length_q, depth_value]
            attention_context = combine_head(attention_context)
            # linear transform
            attention_context = conv1d(attention_context, self._output_depth, kernel_size=1,
                                       name="output_transform")
            if query_is_2d:
                # attention context: [batch_size, depth_value]
                attention_context = tf.squeeze(attention_context, axis=1)
                # attention weight: [batch_size, num_heads, length_k]
                attention_weight = tf.squeeze(attention_weight, axis=2)
            return attention_weight, attention_context
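Example #12 also handles 2-D queries coming from RNN-based decoders: the query is expanded to a length-1 sequence before attention, and the singleton axes are squeezed out of the results afterwards. An illustrative shape walk-through:

import tensorflow as tf

query_2d = tf.zeros([2, 32])                  # [batch_size, channels_query]
query = tf.expand_dims(query_2d, axis=1)      # [2, 1, 32], length_q == 1

# Stand-ins for the attention outputs produced above.
attention_weight = tf.zeros([2, 4, 1, 7])     # [batch, heads, length_q, length_k]
attention_context = tf.zeros([2, 1, 64])      # [batch, length_q, output_depth]

attention_context = tf.squeeze(attention_context, axis=1)  # [2, 64]
attention_weight = tf.squeeze(attention_weight, axis=2)    # [2, 4, 7]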