import tensorflow as tf


def stacked_multihead_attention(x,
                                num_blocks,
                                num_heads,
                                use_residual,
                                is_training,
                                dropout_rate,
                                reuse=False):
    """Stack num_blocks blocks of multi-head self-attention, each followed by
    a position-wise feed-forward layer, and return the transformed tensor
    together with the attention weights of the last block."""
    num_hiddens = x.get_shape().as_list()[-1]
    with tf.variable_scope('stacked_multihead_attention', reuse=reuse):
        for i in range(num_blocks):
            with tf.variable_scope('multihead_block_{}'.format(i),
                                   reuse=reuse):
                # Self-attention: queries, keys, and values all come from x.
                x, attentions = multihead_attention(x,
                                                    x,
                                                    x,
                                                    use_residual,
                                                    is_training,
                                                    dropout_rate,
                                                    num_heads=num_heads,
                                                    reuse=reuse)
                # Position-wise feed-forward projection back to num_hiddens.
                x = feed_forward(x,
                                 num_hiddens=num_hiddens,
                                 activation=tf.nn.relu,
                                 reuse=reuse)
    # Note: only the attention weights of the final block are returned.
    return x, attentions
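

# feed_forward is defined elsewhere in the repo and not shown here. The sketch
# below is only an illustrative assumption of its shape, inferred from how it
# is called above (a single dense projection with an optional activation); it
# is not the repo's actual implementation.
def feed_forward_sketch(x, num_hiddens, activation=None, reuse=False):
    with tf.variable_scope('feed_forward', reuse=reuse):
        # One position-wise dense layer projecting to num_hiddens units.
        return tf.layers.dense(x, num_hiddens, activation=activation)
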
def multihead_attention(queries,
                        keys,
                        values,
                        use_residual,
                        is_training,
                        dropout_rate,
                        num_units=None,
                        num_heads=8,
                        reuse=False):
    """Multi-head attention in the style of 'Attention Is All You Need':
    project queries, keys, and values, split them into num_heads heads, run
    scaled dot-product attention per head, then recombine and project."""
    with tf.variable_scope('multihead-attention', reuse=reuse):
        if num_units is None:
            num_units = queries.get_shape().as_list()[-1]
        # Linear projections for queries, keys, and values.
        Q = linear(queries)
        K = linear(keys)
        V = linear(values)

        # Split the last dimension into num_heads heads and fold the heads
        # into the batch dimension: (N, T, C) -> (h*N, T, C/h).
        Q = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)
        K = tf.concat(tf.split(K, num_heads, axis=2), axis=0)
        V = tf.concat(tf.split(V, num_heads, axis=2), axis=0)

        # Per-head scaled dot-product attention, followed by dropout.
        Q_K_V, attentions = scaled_dot_product_attention(Q, K, V)
        Q_K_V = dropout(Q_K_V, is_training, rate=dropout_rate)
        # Undo the head split: (h*N, T, C/h) -> (N, T, C).
        Q_K_V_ = tf.concat(tf.split(Q_K_V, num_heads, axis=0), axis=2)

        # Final output projection back to num_units.
        output = feed_forward(Q_K_V_, num_units, reuse=reuse)

        # Optional residual connection from the original queries.
        if use_residual:
            output = residual(output, queries, reuse=reuse)
        # output = normalization(output)

    return output, attentions
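

# The remaining helpers (linear, dropout, residual, scaled_dot_product_attention)
# are also defined elsewhere in the repo. The sketches below are illustrative
# assumptions based on how they are called above and on the standard
# "Attention Is All You Need" formulation; they are not the repo's actual code.

def linear_sketch(x, reuse=False):
    # Assumed: a dense projection that keeps the last dimension unchanged.
    num_units = x.get_shape().as_list()[-1]
    with tf.variable_scope('linear', reuse=reuse):
        return tf.layers.dense(x, num_units)


def dropout_sketch(x, is_training, rate=0.1):
    # Assumed: dropout that is only active during training.
    return tf.layers.dropout(x, rate=rate, training=is_training)


def residual_sketch(x, inputs, reuse=False):
    # Assumed: a plain additive skip connection.
    return x + inputs


def scaled_dot_product_attention_sketch(Q, K, V):
    # softmax(Q K^T / sqrt(d_k)) V, returning both the weighted values and
    # the attention weights.
    d_k = tf.cast(tf.shape(K)[-1], tf.float32)
    scores = tf.matmul(Q, K, transpose_b=True) / tf.sqrt(d_k)
    attentions = tf.nn.softmax(scores, axis=-1)
    return tf.matmul(attentions, V), attentions


if __name__ == '__main__':
    # Illustrative usage, assuming the repo's real helper functions are
    # available in this module: encode a batch of 50-step, 256-dimensional
    # sequences with 2 blocks of 8-head self-attention.
    inputs = tf.placeholder(tf.float32, [None, 50, 256])
    training = tf.placeholder(tf.bool, [])
    encoded, last_attentions = stacked_multihead_attention(inputs,
                                                           num_blocks=2,
                                                           num_heads=8,
                                                           use_residual=True,
                                                           is_training=training,
                                                           dropout_rate=0.1)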