def attention(
        self, query: tf.Tensor, decoder_prev_state: tf.Tensor,
        decoder_input: tf.Tensor, loop_state: MultiHeadLoopState
    ) -> Tuple[tf.Tensor, MultiHeadLoopState]:
        """Compute a multi-head attention context for a single query.

        Thin wrapper around the module-level 'attention' function. The
        query of shape(batch, query_size) is expanded to
        shape(batch, 1, query_size) before the call. The raw outputs —
        a context of shape(batch, 1, value_size) and weights of
        shape(batch, n_heads, 1, time(k)) — are squeezed back down and
        appended to the attention loop state.

        Arguments:
            query: Input query for the current decoding step
                of shape(batch, query_size).
            decoder_prev_state: Previous state of the decoder.
            decoder_input: Input to the RNN cell of the decoder.
            loop_state: Attention loop state.

        Returns:
            Vector of contexts and the following attention loop state.
        """

        def _dropout(tensor: tf.Tensor) -> tf.Tensor:
            # Dropout applied inside the attention computation; keeps the
            # keep-probability and train-mode flag from this attention object.
            return dropout(tensor, self.dropout_keep_prob, self.train_mode)

        raw_context, raw_weights = attention(
            queries=tf.expand_dims(query, 1),
            keys=self.attention_keys,
            values=self.attention_values,
            keys_mask=self.attention_mask,
            num_heads=self.n_heads,
            dropout_callback=_dropout)

        # Drop the singleton time axis:
        # (batch, 1, value_size) -> (batch, value_size)
        context = tf.squeeze(raw_context, axis=1)

        # Split (batch, n_heads, 1, time(k)) along the head axis, then
        # squeeze each piece down to (batch, time(k)).
        per_head_weights = [
            tf.squeeze(piece, axis=[1, 2])
            for piece in tf.split(raw_weights, self.n_heads, axis=1)
        ]

        # Append this step's outputs along the leading (time) dimension
        # of the accumulated loop state.
        next_contexts = tf.concat(
            [loop_state.contexts, tf.expand_dims(context, 0)], axis=0)
        next_head_weights = [
            tf.concat([prev, tf.expand_dims(new, 0)], axis=0)
            for prev, new in zip(loop_state.head_weights, per_head_weights)
        ]

        next_loop_state = MultiHeadLoopState(
            contexts=next_contexts, head_weights=next_head_weights)

        return context, next_loop_state
def empty_multi_head_loop_state(
        batch_size: Union[int, tf.Tensor], num_heads: Union[int, tf.Tensor],
        length: Union[int, tf.Tensor],
        dimension: Union[int, tf.Tensor]) -> MultiHeadLoopState:
    """Create an initial, empty multi-head attention loop state.

    The contexts tensor and each per-head weight tensor start with a
    zero-length leading (time) dimension, ready to be grown by
    concatenation during decoding.

    Arguments:
        batch_size: Number of sentences in a batch.
        num_heads: Number of attention heads.
        length: Number of keys (time steps) attended over.
        dimension: Size of the context (value) vectors.

    Returns:
        An empty MultiHeadLoopState.
    """
    empty_contexts = tf.zeros(
        shape=[0, batch_size, dimension],
        dtype=tf.float32,
        name="contexts")

    empty_head_weights = [
        tf.zeros(
            shape=[0, batch_size, length],
            dtype=tf.float32,
            name="distributions_head{}".format(i))
        for i in range(num_heads)]

    return MultiHeadLoopState(
        contexts=empty_contexts, head_weights=empty_head_weights)