Example #1
    def attention_core(self, q, k, v, attn_mask):
        """ Core math operations of multihead attention layer """
        q = self._split_heads(
            q)  # [batch_size * n_heads * n_q * (k_dim/n_heads)]
        k = self._split_heads(
            k)  # [batch_size * n_heads * n_kv * (k_dim/n_heads)]
        v = self._split_heads(
            v)  # [batch_size * n_heads * n_kv * (v_dim/n_heads)]

        key_depth_per_head = self.key_depth / self.num_heads
        q = q / math.sqrt(key_depth_per_head)

        # Dot-product attention
        # logits: (batch_size * n_heads * n_q * n_kv)
        attn_bias = MultiHeadAttn.ATTN_BIAS_VALUE * (1 - attn_mask)
        logits = tf.matmul(q, tf.transpose(k, perm=[0, 1, 3, 2])) + attn_bias
        weights = tf.nn.softmax(logits)

        tf.add_to_collection("AttnWeights", weights)
        tf.add_to_collection(
            lib.meta.ATTENTIONS,
            lib.meta.Attention(self.scope, weights, logits, attn_mask))

        if is_dropout_enabled():
            weights = dropout(weights, 1.0 - self.attn_dropout)
        x = tf.matmul(
            weights,  # [batch_size * n_heads * n_q * n_kv]
            v  # [batch_size * n_heads * n_kv * (v_depth/n_heads)]
        )
        combined_x = self._combine_heads(x)

        if is_dropout_enabled():
            combined_x = dropout(combined_x, 1.0 - self.attn_value_dropout)
        return combined_x
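Both examples call self._split_heads and self._combine_heads, which are not shown here. Below is a minimal, self-contained sketch of how such helpers are commonly written; the reshape/transpose pattern is an assumption, not the class's actual code, and it assumes the last dimension is statically known and divisible by the number of heads.

import tensorflow as tf


def split_heads(x, num_heads):
    # [batch_size, length, depth] -> [batch_size, num_heads, length, depth // num_heads]
    batch = tf.shape(x)[0]
    length = tf.shape(x)[1]
    depth = int(x.shape[-1])
    x = tf.reshape(x, [batch, length, num_heads, depth // num_heads])
    return tf.transpose(x, perm=[0, 2, 1, 3])


def combine_heads(x):
    # Inverse of split_heads:
    # [batch_size, num_heads, length, d] -> [batch_size, length, num_heads * d]
    x = tf.transpose(x, perm=[0, 2, 1, 3])  # [batch_size, length, num_heads, d]
    batch = tf.shape(x)[0]
    length = tf.shape(x)[1]
    depth = int(x.shape[2]) * int(x.shape[3])
    return tf.reshape(x, [batch, length, depth])

With helpers like these, the shape comments in attention_core line up: q, k, and v enter as [batch_size, length, depth] and leave as [batch_size, n_heads, length, depth / n_heads].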
Example #2
    def __call__(self, query_inp, attn_mask, kv_inp=None, kv=None):
        """
        query_inp: [batch_size * n_q * inp_dim]
        attn_mask: [batch_size * 1 * n_q * n_kv]
        kv_inp: [batch_size * n_kv * inp_dim]
        -----------------------------------------------
        results: [batch_size * n_q * output_depth]
        """
        assert kv is None or kv_inp is None, "please only feed one of kv or kv_inp"
        with tf.name_scope(self.name) as scope:
            if kv_inp is not None or kv is not None:
                q = self.query_conv(query_inp)
                if kv is None:
                    kv = self.kv_conv(kv_inp)
                k, v = tf.split(kv, [self.key_depth, self.value_depth], axis=2)
            else:
                combined = self.combined_conv(query_inp)
                q, k, v = tf.split(
                    combined,
                    [self.key_depth, self.key_depth, self.value_depth],
                    axis=2)
            q = self._split_heads(
                q)  # [batch_size * n_heads * n_q * (k_dim/n_heads)]
            k = self._split_heads(
                k)  # [batch_size * n_heads * n_kv * (k_dim/n_heads)]
            v = self._split_heads(
                v)  # [batch_size * n_heads * n_kv * (v_dim/n_heads)]

            key_depth_per_head = self.key_depth / self.num_heads
            q = q / math.sqrt(key_depth_per_head)

            # Dot-product attention
            # logits: (batch_size * n_heads * n_q * n_kv)
            attn_bias = MultiHeadAttn.ATTN_BIAS_VALUE * (1 - attn_mask)
            logits = tf.matmul(
                q, tf.transpose(k, perm=[0, 1, 3, 2])) + attn_bias
            weights = tf.nn.softmax(logits)

            tf.add_to_collection("AttnWeights", weights)

            tf.add_to_collection(
                lib.meta.ATTENTIONS,
                lib.meta.Attention(scope, weights, logits, attn_mask))

            if is_dropout_enabled():
                weights = dropout(weights, 1.0 - self.attn_dropout)
            x = tf.matmul(
                weights,  # [batch_size * n_heads * n_q * n_kv]
                v  # [batch_size * n_heads * n_kv * (v_depth/n_heads)]
            )
            # x: [batch_size, n_heads, n_q, (v_depth/n_heads)]

            # ========================  apply the gate  ========================
            gated_x = self.gate * x
            # ==================================================================

            combined_x = self._combine_heads(gated_x)

            if is_dropout_enabled():
                combined_x = dropout(combined_x, 1.0 - self.attn_value_dropout)

            outputs = self.out_conv(combined_x)

            return outputs
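Both versions expect attn_mask to hold 1.0 at positions that may be attended to and 0.0 at masked positions, since the additive bias is ATTN_BIAS_VALUE * (1 - attn_mask) (presumably a large negative constant). Below is a minimal sketch of building such masks; the helper names are hypothetical, and the shapes rely on broadcasting against the [batch_size, n_heads, n_q, n_kv] logits.

import tensorflow as tf


def make_padding_mask(lengths, n_kv):
    # 1.0 where a key position is real, 0.0 where it is padding.
    # Shape [batch_size, 1, 1, n_kv], broadcast over heads and queries.
    mask = tf.sequence_mask(lengths, maxlen=n_kv, dtype=tf.float32)
    return mask[:, tf.newaxis, tf.newaxis, :]


def make_causal_mask(n_q):
    # Lower-triangular mask for decoder self-attention.
    # Shape [1, 1, n_q, n_q], broadcast over batch and heads.
    mask = tf.linalg.band_part(tf.ones([n_q, n_q]), -1, 0)
    return mask[tf.newaxis, tf.newaxis, :, :]


# Example: combined decoder self-attention mask, [batch_size, 1, n_q, n_q],
# matching the [batch_size * 1 * n_q * n_kv] shape in the __call__ docstring:
# attn_mask = make_causal_mask(n_q) * make_padding_mask(lengths, n_q)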