    def build_key(self):
        with tf.compat.v1.variable_scope("embeddings"):
            input_tensor = self.get_embeddings(self.input_ids,
                                               self.segment_ids)
            self.input_shape = bc.get_shape_list(input_tensor, expected_rank=3)

        with tf.compat.v1.variable_scope("encoder"):
            self.attention_mask = bc.create_attention_mask_from_input_mask(
                input_tensor, self.input_mask)
            prev_output = bc.reshape_to_matrix(input_tensor)

            for layer_idx in range(self.layers_before_key_pooling):
                with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                    intermediate_output, prev_output = self.forward_layer(
                        prev_output)
                    intermediate_output = tf.reshape(intermediate_output, [
                        self.batch_size * self.seq_length,
                        self.config.intermediate_size
                    ])
                    final_output = bc.reshape_from_matrix(
                        prev_output, self.input_shape)
                    self.all_layer_outputs.append(final_output)

        self.last_intermediate_output = intermediate_output

        self.last_key_layer = prev_output
        with tf.compat.v1.variable_scope("mr_key"):
            key_vectors = bc.dense(self.key_dimension,
                                   self.initializer)(intermediate_output)
            self.debug1 = key_vectors
            key_vectors = tf.reshape(
                key_vectors,
                [self.batch_size, self.seq_length, self.key_dimension])
            key_output = self.key_pooling(key_vectors)
        return key_output
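Throughout these examples, `bc` refers to a shared BERT-common helper module whose source is not shown here. Purely as an assumption inferred from the call sites above (e.g. `bc.dense(self.key_dimension, self.initializer)(intermediate_output)`), the `dense` helper is presumably a thin factory around tf.keras.layers.Dense, roughly:

# Hypothetical sketch of the bc.dense helper, inferred from how it is called.
import tensorflow as tf

def dense(output_dim, kernel_initializer, activation=None):
    # Returns a layer object so callers can write bc.dense(dim, init)(tensor).
    return tf.keras.layers.Dense(output_dim,
                                 activation=activation,
                                 kernel_initializer=kernel_initializer)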
Example No. 2
    def call(self, input_ids, input_mask, segment_ids):
        with tf.compat.v1.variable_scope("embeddings"):
            self.embedding_layer = Embedding2()
            input_tensor = self.embedding_layer.apply(
                input_ids, segment_ids, self.config.initializer_range,
                self.config.vocab_size, self.config.embedding_size,
                self.config.type_vocab_size,
                self.config.max_position_embeddings,
                self.config.hidden_dropout_prob, self.use_one_hot_embeddings)
            input_tensor = self.embedding_projection(input_tensor)
            self.embedding_output = input_tensor
            input_shape = bc.get_shape_list2(input_tensor)
            batch_size, seq_length, _ = input_shape
        with tf.compat.v1.variable_scope("encoder"):
            self.attention_mask = bc.create_attention_mask_from_input_mask2(
                input_tensor, input_mask)
            prev_output = bc.reshape_to_matrix(input_tensor)
            with tf.compat.v1.variable_scope("layer"):
                intermediate_output, prev_output = self.layer.apply(
                    prev_output, batch_size, seq_length, self.attention_mask)
                final_output = bc.reshape_from_matrix2(prev_output,
                                                       input_shape)
                self.all_layer_outputs.append(final_output)

            for layer_idx in range(1, self.config.num_hidden_layers):
                with tf.compat.v1.variable_scope("layer", reuse=True):
                    intermediate_output, prev_output = self.layer.apply(
                        prev_output, batch_size, seq_length,
                        self.attention_mask)
                    final_output = bc.reshape_from_matrix2(
                        prev_output, input_shape)
                    self.all_layer_outputs.append(final_output)

        return prev_output
Example No. 3
    def call(self, input_vectors, use_context):
        # input_vectors : [batch_size, num_window, hidden_size]
        batch_size, seq_length, hidden_dim = bc.get_shape_list2(input_vectors)
        # Add position embedding
        input_vectors = bc.embedding_postprocessor2(
            input_tensor=input_vectors,
            token_type_table=self.token_type_table,
            full_position_embeddings=self.full_position_embeddings,
            use_token_type=False,
            token_type_ids=None,
            token_type_vocab_size=1,
            use_position_embeddings=True,
            max_position_embeddings=self.config.max_num_window,
            dropout_prob=self.config.hidden_dropout_prob)

        input_shape = [batch_size, seq_length]

        attention_mask = tf.ones([batch_size, seq_length, seq_length],
                                 tf.int32) * tf.expand_dims(use_context, 2)
        with tf.compat.v1.variable_scope("mid"):
            prev_output = bc.reshape_to_matrix(input_vectors)
            for layer_idx in range(self.n_layers):
                with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                    intermediate_output, prev_output = self.layer_list[
                        layer_idx].apply(prev_output, batch_size, seq_length,
                                         attention_mask)
                    final_output = bc.reshape_from_matrix2(
                        prev_output, input_shape)
                    self.all_layer_outputs.append(final_output)

        return prev_output
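The attention mask in this example is built by broadcasting `use_context` over the key dimension: `tf.expand_dims(use_context, 2)` has shape [batch, seq, 1], so its product with the all-ones tensor yields a mask in which row i of each batch element simply repeats `use_context[b, i]`. A tiny standalone check of that broadcast, with made-up sizes:

# Standalone illustration of the use_context mask broadcast (toy sizes).
import tensorflow as tf

batch_size, seq_length = 2, 3
use_context = tf.constant([[1, 1, 0], [0, 1, 1]], tf.int32)  # [batch, seq]
attention_mask = tf.ones([batch_size, seq_length, seq_length],
                         tf.int32) * tf.expand_dims(use_context, 2)
# attention_mask[b, i, j] == use_context[b, i] for every j.
print(attention_mask.numpy())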
Example No. 4
    def apply_3d(self, input_tensor, batch_size, seq_length, attention_mask):
        input_shape = bc.get_shape_list2(input_tensor)
        input_tensor = bc.reshape_to_matrix(input_tensor)
        intermediate_output, layer_output = self.apply(input_tensor,
                                                       batch_size, seq_length,
                                                       attention_mask)

        return bc.reshape_from_matrix2(layer_output, input_shape)

    def build_by_attention(self, key):
        hidden_size = self.config.hidden_size
        with tf.compat.v1.variable_scope("embeddings"):
            lexical_tensor = self.get_lexical_lookup()
            self.embedding_output = self.embedding_postprocessor(
                d_input_ids=self.input_ids,
                input_tensor=lexical_tensor,
                use_token_type=True,
                token_type_ids=self.segment_ids,
                token_type_vocab_size=self.config.type_vocab_size,
                token_type_embedding_name="token_type_embeddings",
                use_position_embeddings=True,
                position_embedding_name="position_embeddings",
                initializer_range=self.config.initializer_range,
                max_position_embeddings=self.config.max_position_embeddings,
                dropout_prob=self.config.hidden_dropout_prob)
            input_tensor = self.embedding_output
            #[ def_per_batch, seq_length, hidden_size]

        with tf.compat.v1.variable_scope("encoder"):
            num_key_tokens = self.ssdr_config.num_key_tokens
            project_dim = hidden_size * num_key_tokens
            raw_key = bc.dense(project_dim, self.initializer)(key)
            key_tokens = tf.reshape(
                raw_key, [self.batch_size, num_key_tokens, hidden_size])

            input_tensor = tf.concat([key_tokens, input_tensor], axis=1)
            input_shape = bc.get_shape_list(input_tensor, expected_rank=3)

            mask_for_key = tf.ones([self.batch_size, num_key_tokens],
                                   dtype=tf.int64)
            self.input_mask = tf.cast(self.input_mask, tf.int64)
            self.input_mask = tf.concat([mask_for_key, self.input_mask],
                                        axis=1)
            self.seq_length = self.seq_length + num_key_tokens

            self.attention_mask = bc.create_attention_mask_from_input_mask(
                input_tensor, self.input_mask)
            prev_output = bc.reshape_to_matrix(input_tensor)
            for layer_idx in range(self.ssdr_config.num_hidden_layers):
                with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                    intermediate_output, prev_output = self.forward_layer(
                        prev_output)
                    self.all_layer_outputs.append(prev_output)

            final_output = bc.reshape_from_matrix(prev_output, input_shape)
            self.scores = bc.dense(1, self.initializer)(final_output[:, 0, :])

            if self.ssdr_config.info_pooling_method == "first_tokens":
                self.info_output = final_output[:, :num_key_tokens, :]
            elif self.ssdr_config.info_pooling_method == "max_pooling":
                self.info_output = tf.reduce_max(final_output, axis=1)

        return self.scores, self.info_output
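build_by_attention injects the externally supplied key as extra tokens: the key vector is projected to num_key_tokens * hidden_size, reshaped into num_key_tokens pseudo-tokens, prepended to the word embeddings, and the input mask and sequence length are extended to cover them. A self-contained shape check of that wiring, with made-up sizes:

# Toy shape check of the key-token prepending used in build_by_attention.
import tensorflow as tf

batch_size, seq_length, hidden_size, num_key_tokens = 2, 5, 8, 3
key = tf.random.normal([batch_size, 16])                  # pooled key vector
input_tensor = tf.random.normal([batch_size, seq_length, hidden_size])
input_mask = tf.ones([batch_size, seq_length], tf.int64)

raw_key = tf.keras.layers.Dense(hidden_size * num_key_tokens)(key)
key_tokens = tf.reshape(raw_key, [batch_size, num_key_tokens, hidden_size])

input_tensor = tf.concat([key_tokens, input_tensor], axis=1)
mask_for_key = tf.ones([batch_size, num_key_tokens], dtype=tf.int64)
input_mask = tf.concat([mask_for_key, input_mask], axis=1)

print(input_tensor.shape)  # (2, 8, 8): seq_length grows by num_key_tokens
print(input_mask.shape)    # (2, 8)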
Example No. 6
    def apply_topic_vector(self, input_tensor, topic_ids, layer_idx):
        if layer_idx == 0:
            return self.add_topic_vector(input_tensor, topic_ids)
        else:
            input_tensor = bc.reshape_from_matrix2(input_tensor,
                                                   self.input_shape)
            # Drop the topic slots appended at the previous layer, then append
            # this layer's topic tensor.
            input_tensor = input_tensor[:, :-self.topic_emb_len, :]
            input_tensor = tf.concat(
                [input_tensor, self.topic_tensor[layer_idx]], axis=1)
            input_tensor = bc.reshape_to_matrix(input_tensor)
            return input_tensor
Example No. 7
    def call(self, input_vectors, attention_mask):
        prev_output = input_vectors
        input_shape = bc.get_shape_list2(input_vectors)
        batch_size, seq_length, _ = input_shape
        prev_output = bc.reshape_to_matrix(prev_output)
        for layer_idx in range(self.n_layers):
            with tf.compat.v1.variable_scope(
                    "layer_%d" % (layer_idx + self.layer_idx_base)):
                layer = self.layer_list[layer_idx]
                intermediate_output, prev_output = layer.apply(
                    prev_output, batch_size, seq_length, attention_mask)
                final_output = bc.reshape_from_matrix2(prev_output,
                                                       input_shape)
                self.all_layer_outputs.append(final_output)

        return prev_output
Example No. 8
    def call(self, input_ids, input_mask, segment_ids, topic_ids):
        with tf.compat.v1.variable_scope("embeddings"):
            self.embedding_layer = Embedding(self.config,
                                             self.use_one_hot_embeddings)
            input_tensor = self.embedding_layer.apply(input_ids, segment_ids)
            self.embedding_output = input_tensor

        input_mask = self.extend_input_mask(input_mask)
        topic_tensor, _ = bc.embedding_lookup2(topic_ids, self.n_topics,
                                               self.topic_embedding,
                                               self.topic_embedding_size,
                                               self.use_one_hot_embeddings)
        self.topic_tensor = tf.reshape(
            topic_tensor, [-1, self.topic_emb_len, self.hidden_size])

        input_tensor = tf.concat([input_tensor, self.topic_tensor], axis=1)
        input_shape = bc.get_shape_list2(input_tensor)
        batch_size, seq_length, _ = input_shape

        with tf.compat.v1.variable_scope("encoder"):
            self.attention_mask = bc.create_attention_mask_from_input_mask2(
                input_tensor, input_mask)
            prev_output = bc.reshape_to_matrix(input_tensor)
            for layer_idx in range(self.n_layers):
                with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                    layer = self.layer_list[layer_idx]
                    intermediate_output, prev_output = layer.apply(
                        prev_output, batch_size, seq_length,
                        self.attention_mask)
                    final_output = bc.reshape_from_matrix2(
                        prev_output, input_shape)
                    self.all_layer_outputs.append(final_output)

        self.embedding_table = self.embedding_layer.embedding_table
        self.sequence_output = final_output[:, :-self.topic_emb_len]
        self.pooled_output = mimic_pooling(self.sequence_output,
                                           self.config.hidden_size,
                                           self.config.initializer_range)
        return self.sequence_output
Example No. 9
    def build(self, value_out, locations):
        with tf.compat.v1.variable_scope("embeddings"):
            input_tensor = self.get_embeddings(self.input_ids,
                                               self.segment_ids)
            self.input_shape = bc.get_shape_list(input_tensor, expected_rank=3)

        with tf.compat.v1.variable_scope("encoder"):
            self.attention_mask = bc.create_attention_mask_from_input_mask(
                input_tensor, self.input_mask)
            prev_output = bc.reshape_to_matrix(input_tensor)
            prev_output = tf.tensor_scatter_nd_update(prev_output, locations,
                                                      value_out)

            for layer_idx in range(self.config.num_hidden_layers):
                with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                    intermediate_output, prev_output = self.forward_layer(
                        prev_output)
                    final_output = bc.reshape_from_matrix(
                        prev_output, self.input_shape)
                    self.all_layer_outputs.append(final_output)

        return self.all_layer_outputs
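The build method above (Example No. 9) overwrites selected rows of the flattened [batch * seq, hidden] embedding matrix with externally computed vectors before running the encoder. A minimal standalone demo of the tf.tensor_scatter_nd_update step, with toy sizes:

# Toy demo of scattering external vectors into rows of a flattened matrix.
import tensorflow as tf

prev_output = tf.zeros([6, 4])          # batch_size * seq_length = 6, hidden = 4
locations = tf.constant([[1], [4]])     # flattened token indices to overwrite
value_out = tf.ones([2, 4])             # replacement vectors, one per location
prev_output = tf.tensor_scatter_nd_update(prev_output, locations, value_out)
print(prev_output.numpy())              # rows 1 and 4 are now all ones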
Example No. 10
    def call(self, input_ids, input_mask, segment_ids):
        with tf.compat.v1.variable_scope("embeddings"):
            self.embedding_layer = Embedding(self.config,
                                             self.use_one_hot_embeddings)
            input_tensor = self.embedding_layer.apply(input_ids, segment_ids)
            self.embedding_output = input_tensor
            input_shape = bc.get_shape_list2(input_tensor)
            batch_size, seq_length, _ = input_shape
        with tf.compat.v1.variable_scope("lower"):
            self.attention_mask = bc.create_attention_mask_from_input_mask2(
                input_tensor, input_mask)
            prev_output = bc.reshape_to_matrix(input_tensor)
            for layer_idx in range(self.n_layers):
                with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                    layer = self.layer_list[layer_idx]
                    intermediate_output, prev_output = layer.apply(
                        prev_output, batch_size, seq_length,
                        self.attention_mask)
                    final_output = bc.reshape_from_matrix2(
                        prev_output, input_shape)
                    self.all_layer_outputs.append(final_output)

        return prev_output
Example No. 11
    def build(self):
        with tf.compat.v1.variable_scope("dict"):
            with tf.compat.v1.variable_scope("embeddings"):
                input_tensor = self.get_embeddings(self.input_ids,
                                                   self.segment_ids)

            with tf.compat.v1.variable_scope("encoder"):
                num_key_tokens = self.ssdr_config.num_key_tokens
                input_shape = bc.get_shape_list(input_tensor, expected_rank=3)

                mask_for_key = tf.ones([self.batch_size, num_key_tokens],
                                       dtype=tf.int64)
                self.input_mask = tf.cast(self.input_mask, tf.int64)
                self.input_mask = tf.concat([mask_for_key, self.input_mask],
                                            axis=1)
                self.seq_length = self.seq_length + num_key_tokens

                self.attention_mask = bc.create_attention_mask_from_input_mask(
                    input_tensor, self.input_mask)
                prev_output = bc.reshape_to_matrix(input_tensor)
                for layer_idx in range(self.ssdr_config.num_hidden_layers):
                    with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                        intermediate_output, prev_output = self.forward_layer(
                            prev_output)
                        self.all_layer_outputs.append(prev_output)

                final_output = bc.reshape_from_matrix(prev_output, input_shape)
                self.scores = bc.dense(1, self.initializer)(
                    final_output[:, 0, :])

                if self.ssdr_config.info_pooling_method == "first_tokens":
                    self.info_output = final_output[:, :num_key_tokens, :]
                elif self.ssdr_config.info_pooling_method == "max_pooling":
                    self.info_output = tf.reduce_max(final_output, axis=1)

            return self.scores, self.info_output
Example No. 12
def transformer_model(input_tensor,
                      attention_mask=None,
                      input_mask=None,
                      hidden_size=768,
                      num_hidden_layers=12,
                      num_attention_heads=12,
                      mr_num_route=10,
                      intermediate_size=3072,
                      intermediate_act_fn=gelu,
                      hidden_dropout_prob=0.1,
                      attention_probs_dropout_prob=0.1,
                      initializer_range=0.02,
                      is_training=True,
                      do_return_all_layers=False):
    """Multi-headed, multi-layer Transformer from "Attention is All You Need".

    This is almost an exact implementation of the original Transformer encoder.

    See the original paper:
    https://arxiv.org/abs/1706.03762

    Also see:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py

    Args:
        input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
        attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
            seq_length], with 1 for positions that can be attended to and 0 for
            positions that should not be.
        input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
        hidden_size: int. Hidden size of the Transformer.
        num_hidden_layers: int. Number of layers (blocks) in the Transformer.
        num_attention_heads: int. Number of attention heads in the Transformer.
        mr_num_route: int. Number of routes the MR layers can select from when
            gathering slices of the extension tensors.
        intermediate_size: int. The size of the "intermediate" (a.k.a., feed
            forward) layer.
        intermediate_act_fn: function. The non-linear activation function to apply
            to the output of the intermediate/feed-forward layer.
        hidden_dropout_prob: float. Dropout probability for the hidden layers.
        attention_probs_dropout_prob: float. Dropout probability of the attention
            probabilities.
        initializer_range: float. Range of the initializer (stddev of truncated
            normal).
        is_training: bool. If True, the routing key is sampled from the `mr_key`
            logits; otherwise the argmax is used.
        do_return_all_layers: Whether to also return all layers or just the final
            layer.

    Returns:
        A tuple (output, key). `output` is a float Tensor of shape [batch_size,
        seq_length, hidden_size] holding the final hidden layer of the
        Transformer, or a list of such Tensors (one per layer) when
        `do_return_all_layers` is True. `key` is the routing key chosen for the
        MR layers.

    Raises:
        ValueError: A Tensor shape or parameter is invalid.
    """
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))

    attention_head_size = int(hidden_size / num_attention_heads)
    input_shape = get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    input_width = input_shape[2]

    initializer = create_initializer(initializer_range)

    ext_tensor = tf.compat.v1.get_variable("ext_tensor",
                                 shape=[num_hidden_layers, mr_num_route, EXT_SIZE ,hidden_size],
                                 initializer=initializer,
                                 )
    ext_tensor_inter = tf.compat.v1.get_variable("ext_tensor_inter",
                                       shape=[num_hidden_layers, mr_num_route, intermediate_size],
                                       initializer=initializer,
                                           )
    # The Transformer performs sum residuals on all layers so the input needs
    # to be the same as the hidden size.
    if input_width != hidden_size:
        raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
                                         (input_width, hidden_size))

    # We keep the representation as a 2D tensor to avoid re-shaping it back and
    # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
    # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
    # help the optimizer.
    prev_output = reshape_to_matrix(input_tensor)

    def is_mr_layer(layer_idx):
        return layer_idx > 1

    all_layer_outputs = []
    for layer_idx in range(num_hidden_layers):
        if not is_mr_layer(layer_idx):
            with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                layer_input = prev_output

                with tf.compat.v1.variable_scope("attention"):
                    attention_heads = []
                    with tf.compat.v1.variable_scope("self"):
                        attention_head = attention_layer(
                                from_tensor=layer_input,
                                to_tensor=layer_input,
                                attention_mask=attention_mask,
                                num_attention_heads=num_attention_heads,
                                size_per_head=attention_head_size,
                                attention_probs_dropout_prob=attention_probs_dropout_prob,
                                initializer_range=initializer_range,
                                do_return_2d_tensor=True,
                                batch_size=batch_size,
                                from_seq_length=seq_length,
                                to_seq_length=seq_length)
                        attention_heads.append(attention_head)

                    attention_output = None
                    if len(attention_heads) == 1:
                        attention_output = attention_heads[0]
                    else:
                        # In the case where we have other sequences, we just concatenate
                        # them to the self-attention head before the projection.
                        attention_output = tf.concat(attention_heads, axis=-1)

                    # Run a linear projection of `hidden_size` then add a residual
                    # with `layer_input`.
                    with tf.compat.v1.variable_scope("output"):
                        attention_output = dense(hidden_size, initializer)(attention_output)
                        attention_output = dropout(attention_output, hidden_dropout_prob)
                        attention_output = layer_norm(attention_output + layer_input)

                # The activation is only applied to the "intermediate" hidden layer.
                with tf.compat.v1.variable_scope("intermediate"):
                    intermediate_output = dense(intermediate_size, initializer,
                                                activation=intermediate_act_fn)(attention_output)

                # Down-project back to `hidden_size` then add the residual.
                with tf.compat.v1.variable_scope("output"):
                    layer_output = dense(hidden_size, initializer)(intermediate_output)
                    layer_output = dropout(layer_output, hidden_dropout_prob)
                    layer_output = layer_norm(layer_output + attention_output)
                    prev_output = layer_output
                    all_layer_outputs.append(layer_output)

                with tf.compat.v1.variable_scope("mr_key"):
                    key_output = tf.keras.layers.Dense(
                        mr_num_route,
                        kernel_initializer=create_initializer(initializer_range))(intermediate_output)
                    key_output = dropout(key_output, hidden_dropout_prob)

                    if is_training:
                        key = tf.random.categorical(key_output, 1) # [batch_size * seq_length, 1]
                        key = tf.reshape(key, [-1])
                    else:
                        key = tf.math.argmax(input=key_output, axis=1)

        else: # Case MR layer
            with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                layer_input = prev_output
                ext_slice = tf.gather(ext_tensor[layer_idx], key)
                ext_interm_slice = tf.gather(ext_tensor_inter[layer_idx], key)
                print("ext_slice (batch*seq, ", ext_slice.shape)
                with tf.compat.v1.variable_scope("attention"):
                    attention_heads = []
                    with tf.compat.v1.variable_scope("self"):
                        attention_head = attention_layer_w_ext(
                            from_tensor=layer_input,
                            to_tensor=layer_input,
                            attention_mask=attention_mask,
                            ext_slice=ext_slice,
                            num_attention_heads=num_attention_heads,
                            size_per_head=attention_head_size,
                            attention_probs_dropout_prob=attention_probs_dropout_prob,
                            initializer_range=initializer_range,
                            do_return_2d_tensor=True,
                            batch_size=batch_size,
                            from_seq_length=seq_length,
                            to_seq_length=seq_length)
                        attention_head = attention_head + ext_slice[:,EXT_ATT_OUT,:]
                        attention_heads.append(attention_head)

                    attention_output = None
                    if len(attention_heads) == 1:
                        attention_output = attention_heads[0]
                    else:
                        # In the case where we have other sequences, we just concatenate
                        # them to the self-attention head before the projection.
                        attention_output = tf.concat(attention_heads, axis=-1)

                    # Run a linear projection of `hidden_size` then add a residual
                    # with `layer_input`.
                    with tf.compat.v1.variable_scope("output"):
                        attention_output = dense(hidden_size, initializer)(attention_output)
                        attention_output = dropout(attention_output, hidden_dropout_prob)
                        attention_output = attention_output + ext_slice[:,EXT_ATT_PROJ,:]
                        attention_output = layer_norm(attention_output + layer_input)

                # The activation is only applied to the "intermediate" hidden layer.
                with tf.compat.v1.variable_scope("intermediate"):
                    intermediate_output = dense(intermediate_size, initializer,
                                                activation=intermediate_act_fn)(attention_output)
                    intermediate_output = ext_interm_slice + intermediate_output
                # Down-project back to `hidden_size` then add the residual.
                with tf.compat.v1.variable_scope("output"):
                    layer_output = dense(hidden_size, initializer)(intermediate_output)
                    layer_output = layer_output + ext_slice[:, EXT_LAYER_OUT,:]
                    layer_output = dropout(layer_output, hidden_dropout_prob)
                    layer_output = layer_norm(layer_output + attention_output)
                    prev_output = layer_output
                    all_layer_outputs.append(layer_output)

    if do_return_all_layers:
        final_outputs = []
        for layer_output in all_layer_outputs:
            final_output = reshape_from_matrix(layer_output, input_shape)
            final_outputs.append(final_output)
        return final_outputs, key
    else:
        final_output = reshape_from_matrix(prev_output, input_shape)
        return final_output, key
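In the MR layers above, each token is routed to one of mr_num_route extension slices: a key is sampled per token from the mr_key logits during training (argmax at inference), and tf.gather then selects that route's slice of ext_tensor for every token. A standalone illustration of the routing gather, with toy sizes and a made-up stand-in value for EXT_SIZE:

# Toy illustration of the per-token routing gather used in the MR layers.
import tensorflow as tf

mr_num_route, EXT_SIZE, hidden_size = 4, 7, 8   # EXT_SIZE value is made up here
num_tokens = 6                                  # i.e. batch_size * seq_length

ext_tensor_layer = tf.random.normal([mr_num_route, EXT_SIZE, hidden_size])
key_logits = tf.random.normal([num_tokens, mr_num_route])
key = tf.reshape(tf.random.categorical(key_logits, 1), [-1])   # [num_tokens]

ext_slice = tf.gather(ext_tensor_layer, key)
print(ext_slice.shape)  # (6, 7, 8): one [EXT_SIZE, hidden_size] slice per token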
Example No. 13
def attention_layer_w_ext(from_tensor,
                          to_tensor,
                          attention_mask=None,
                          num_attention_heads=1,
                          size_per_head=512,
                          ext_slice=None,  # [Num_tokens, n_items, hidden_dim]
                          query_act=None,
                          key_act=None,
                          value_act=None,
                          attention_probs_dropout_prob=0.0,
                          initializer_range=0.02,
                          do_return_2d_tensor=False,
                          batch_size=None,
                          from_seq_length=None,
                          to_seq_length=None):
    """Performs multi-headed attention from `from_tensor` to `to_tensor`.

    This is an implementation of multi-headed attention based on "Attention
    is all you Need". If `from_tensor` and `to_tensor` are the same, then
    this is self-attention. Each timestep in `from_tensor` attends to the
    corresponding sequence in `to_tensor`, and returns a fixed-width vector.

    This function first projects `from_tensor` into a "query" tensor and
    `to_tensor` into "key" and "value" tensors. These are (effectively) a list
    of tensors of length `num_attention_heads`, where each tensor is of shape
    [batch_size, seq_length, size_per_head].

    Then, the query and key tensors are dot-producted and scaled. These are
    softmaxed to obtain attention probabilities. The value tensors are then
    interpolated by these probabilities, then concatenated back to a single
    tensor and returned.

    In practice, the multi-headed attention is done with transposes and
    reshapes rather than actual separate tensors.

    Args:
        from_tensor: float Tensor of shape [batch_size, from_seq_length,
            from_width].
        to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
        attention_mask: (optional) int32 Tensor of shape [batch_size,
            from_seq_length, to_seq_length]. The values should be 1 or 0. The
            attention scores will effectively be set to -infinity for any positions in
            the mask that are 0, and will be unchanged for positions that are 1.
        num_attention_heads: int. Number of attention heads.
        size_per_head: int. Size of each attention head.
        ext_slice: (optional) float Tensor of shape [num_tokens, n_items,
            hidden_dim]. Per-token extension vectors that are added to the
            query/key/value inputs and outputs.
        query_act: (optional) Activation function for the query transform.
        key_act: (optional) Activation function for the key transform.
        value_act: (optional) Activation function for the value transform.
        attention_probs_dropout_prob: (optional) float. Dropout probability of the
            attention probabilities.
        initializer_range: float. Range of the weight initializer.
        do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
            * from_seq_length, num_attention_heads * size_per_head]. If False, the
            output will be of shape [batch_size, from_seq_length, num_attention_heads
            * size_per_head].
        batch_size: (Optional) int. If the input is 2D, this might be the batch size
            of the 3D version of the `from_tensor` and `to_tensor`.
        from_seq_length: (Optional) If the input is 2D, this might be the seq length
            of the 3D version of the `from_tensor`.
        to_seq_length: (Optional) If the input is 2D, this might be the seq length
            of the 3D version of the `to_tensor`.

    Returns:
        float Tensor of shape [batch_size, from_seq_length,
            num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
            true, this will be of shape [batch_size * from_seq_length,
            num_attention_heads * size_per_head]).

    Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
    """

    def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
                             seq_length, width):
        output_tensor = tf.reshape(
            input_tensor, [batch_size, seq_length, num_attention_heads, width])

        output_tensor = tf.transpose(a=output_tensor, perm=[0, 2, 1, 3])
        return output_tensor

    from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
    to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])

    if len(from_shape) != len(to_shape):
        raise ValueError(
                "The rank of `from_tensor` must match the rank of `to_tensor`.")

    if len(from_shape) == 3:
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]
        to_seq_length = to_shape[1]
    elif len(from_shape) == 2:
        if (batch_size is None or from_seq_length is None or to_seq_length is None):
            raise ValueError(
                    "When passing in rank 2 tensors to attention_layer, the values "
                    "for `batch_size`, `from_seq_length`, and `to_seq_length` "
                    "must all be specified.")

    # Scalar dimensions referenced here:
    #     B = batch size (number of sequences)
    #     F = `from_tensor` sequence length
    #     T = `to_tensor` sequence length
    #     N = `num_attention_heads`
    #     H = `size_per_head`

    from_tensor_2d = reshape_to_matrix(from_tensor)
    to_tensor_2d = reshape_to_matrix(to_tensor)

    def get_ext_slice(idx):
        return ext_slice[:, idx, :]

    print("from_tensor_2d ", from_tensor_2d.shape)

    # NOTE: the next line overwrites query_in, so the EXT_QUERY_IN offset is
    # effectively disabled for the query input.
    query_in = from_tensor_2d + get_ext_slice(EXT_QUERY_IN)
    query_in = from_tensor_2d

    # `query_layer` = [B*F, N*H]
    query_layer = tf.keras.layers.Dense(
            num_attention_heads * size_per_head,
            activation=query_act,
            name="query",
            kernel_initializer=create_initializer(initializer_range))(query_in)

    query_layer = query_layer + get_ext_slice(EXT_QUERY_OUT)

    key_in = to_tensor_2d
    key_in = to_tensor_2d + get_ext_slice(EXT_KEY_IN)
    # `key_layer` = [B*T, N*H]
    key_layer = tf.keras.layers.Dense(
            num_attention_heads * size_per_head,
            activation=key_act,
            name="key",
            kernel_initializer=create_initializer(initializer_range))(key_in)

    key_layer = key_layer + get_ext_slice(EXT_KEY_OUT)

    value_in = to_tensor_2d
    value_in = to_tensor_2d + get_ext_slice(EXT_VALUE_IN)
    # `value_layer` = [B*T, N*H]
    value_layer = tf.keras.layers.Dense(
            num_attention_heads * size_per_head,
            activation=value_act,
            name="value",
            kernel_initializer=create_initializer(initializer_range))(value_in)

    value_layer = value_layer + get_ext_slice(EXT_VALUE_OUT)

    # `query_layer` = [B, N, F, H]
    query_layer = transpose_for_scores(query_layer, batch_size,
                                       num_attention_heads, from_seq_length,
                                       size_per_head)

    # `key_layer` = [B, N, T, H]
    key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
                                     to_seq_length, size_per_head)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    # `attention_scores` = [B, N, F, T]
    attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(size_per_head)))

    if attention_mask is not None:
        # `attention_mask` = [B, 1, F, T]
        attention_mask = tf.expand_dims(attention_mask, axis=[1])

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        attention_scores += adder


    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, N, F, T]
    attention_probs = tf.nn.softmax(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = dropout(attention_probs, attention_probs_dropout_prob)

    # `value_layer` = [B, T, N, H]
    value_layer = tf.reshape(
            value_layer,
            [batch_size, to_seq_length, num_attention_heads, size_per_head])

    # `value_layer` = [B, N, T, H]
    value_layer = tf.transpose(a=value_layer, perm=[0, 2, 1, 3])

    # `context_layer` = [B, N, F, H]
    context_layer = tf.matmul(attention_probs, value_layer)

    # `context_layer` = [B, F, N, H]
    context_layer = tf.transpose(a=context_layer, perm=[0, 2, 1, 3])

    if do_return_2d_tensor:
        # `context_layer` = [B*F, N*H]
        context_layer = tf.reshape(
                context_layer,
                [batch_size * from_seq_length, num_attention_heads * size_per_head])
    else:
        # `context_layer` = [B, F, N*H]
        context_layer = tf.reshape(
                context_layer,
                [batch_size, from_seq_length, num_attention_heads * size_per_head])

    return context_layer
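The masking in attention_layer_w_ext follows the usual BERT trick: instead of masking inside the softmax, a large negative adder is mixed into the raw scores so masked positions receive near-zero probability. A small numeric check of that step in isolation:

# Numeric check of the additive attention-masking trick used above.
import tensorflow as tf

scores = tf.constant([[2.0, 1.0, 0.5]])   # raw attention scores for one query
mask = tf.constant([[1, 1, 0]])           # 1 = can attend, 0 = masked out
adder = (1.0 - tf.cast(mask, tf.float32)) * -10000.0
probs = tf.nn.softmax(scores + adder)
print(probs.numpy())                      # the masked position gets ~0 weight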