Example No. 1
  def _create_gates(self, inputs, memory):
    """Create input and forget gates for this step using `inputs` and `memory`.

    Args:
      inputs: Tensor input.
      memory: The current state of memory.

    Returns:
      input_gate: An LSTM-like input gate.
      forget_gate: An LSTM-like forget gate.
    """
    # We'll create the input and forget gates at once. Hence, calculate double
    # the gate size.
    num_gates = 2 * self._calculate_gate_size()

    memory = tf.tanh(memory)
    inputs = basic.BatchFlatten()(inputs)
    gate_inputs = basic.BatchApply(basic.Linear(num_gates), n_dims=1)(inputs)
    gate_inputs = tf.expand_dims(gate_inputs, axis=1)
    gate_memory = basic.BatchApply(basic.Linear(num_gates))(memory)
    gates = tf.split(gate_memory + gate_inputs, num_or_size_splits=2, axis=2)
    input_gate, forget_gate = gates

    input_gate = tf.sigmoid(input_gate + self._input_bias)
    forget_gate = tf.sigmoid(forget_gate + self._forget_bias)

    return input_gate, forget_gate
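
The gating above mirrors an LSTM: one linear projection of the flattened input and one of the tanh-squashed memory are summed, split into two halves, and squashed by sigmoids with learned biases. A minimal NumPy sketch of that arithmetic, assuming toy shapes and random stand-in weights (the real module uses Sonnet Linear/BatchApply modules and learned scalar biases):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

B, N, D = 2, 3, 4                          # batch, memory slots, gate size
rng = np.random.default_rng(0)
memory = rng.normal(size=(B, N, D))
inputs = rng.normal(size=(B, 5))           # already flattened input
w_in = rng.normal(size=(5, 2 * D))         # input and forget gates stacked
w_mem = rng.normal(size=(D, 2 * D))

gate_inputs = (inputs @ w_in)[:, None, :]       # [B, 1, 2D], broadcasts over slots
gate_memory = np.tanh(memory) @ w_mem           # [B, N, 2D]
input_gate, forget_gate = np.split(gate_memory + gate_inputs, 2, axis=2)
input_gate = sigmoid(input_gate + (-2.0))       # the bias values here are arbitrary;
forget_gate = sigmoid(forget_gate + 1.0)        # the module learns them
print(input_gate.shape, forget_gate.shape)      # (2, 3, 4) (2, 3, 4)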
Example No. 2
    def _build(self, inputs, memory, treat_input_as_matrix=False):
        """Adds relational memory to the TensorFlow graph.
    Args:
      inputs: Tensor input.
      memory: Memory output from the previous time step.
      treat_input_as_matrix: Optional; whether to treat `inputs` as a sequence
        of matrices. Defaults to False, in which case the input is flattened
        into a vector.
    Returns:
      output: This time step's output.
      next_memory: The next version of memory to use.
    """
        if treat_input_as_matrix:
            inputs = basic.BatchFlatten(preserve_dims=2)(inputs)
            inputs_reshape = basic.BatchApply(basic.Linear(self._mem_size),
                                              n_dims=2)(inputs)
        else:
            inputs = basic.BatchFlatten()(inputs)
            inputs = basic.Linear(self._mem_size)(inputs)
            inputs_reshape = tf.expand_dims(inputs, 1)

        memory_plus_input = tf.concat([memory, inputs_reshape], axis=1)
        next_memory = self._attend_over_memory(memory_plus_input)

        n = inputs_reshape.get_shape().as_list()[1]
        next_memory = next_memory[:, :-n, :]

        if self._gate_style == 'unit' or self._gate_style == 'memory':
            self._input_gate, self._forget_gate = self._create_gates(
                inputs_reshape, memory)
            next_memory = self._input_gate * tf.tanh(next_memory)
            next_memory += self._forget_gate * memory

        output = basic.BatchFlatten()(next_memory)
        return output, next_memory
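
A rough NumPy walk-through of the update above with toy sizes, using a simple averaging stand-in for _attend_over_memory and skipping the gating, to make the concat-attend-crop shape handling concrete:

import numpy as np

B, mem_slots, mem_size = 2, 4, 8
rng = np.random.default_rng(1)
memory = rng.normal(size=(B, mem_slots, mem_size))
inputs_reshape = rng.normal(size=(B, 1, mem_size))    # one projected input row

memory_plus_input = np.concatenate([memory, inputs_reshape], axis=1)   # [B, 5, 8]
attended = memory_plus_input + memory_plus_input.mean(axis=1, keepdims=True)

n = inputs_reshape.shape[1]
next_memory = attended[:, :-n, :]    # crop the appended input rows back off
print(next_memory.shape)             # (2, 4, 8)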
Example No. 3
  def _multihead_attention(self, memory):
    """Perform multi-head attention from 'Attention is All You Need'.

    Implementation of the attention mechanism from
    https://arxiv.org/abs/1706.03762.

    Args:
      memory: Memory tensor to perform attention on.

    Returns:
      new_memory: New memory tensor.
    """
    key_size = self._key_size
    value_size = self._head_size

    qkv_size = 2 * key_size + value_size
    total_size = qkv_size * self._num_heads  # Denoted as F.
    qkv = basic.BatchApply(basic.Linear(total_size))(memory)
    qkv = basic.BatchApply(layer_norm.LayerNorm())(qkv)

    mem_slots = memory.get_shape().as_list()[1]  # Denoted as N.

    # [B, N, F] -> [B, N, H, F/H]
    qkv_reshape = basic.BatchReshape([mem_slots, self._num_heads,
                                      qkv_size])(qkv)

    # [B, N, H, F/H] -> [B, H, N, F/H]
    qkv_transpose = tf.transpose(qkv_reshape, [0, 2, 1, 3])
    q, k, v = tf.split(qkv_transpose, [key_size, key_size, value_size], -1)

    q *= qkv_size ** -0.5
    dot_product = tf.matmul(q, k, transpose_b=True)  # [B, H, N, N]
    weights = tf.nn.softmax(dot_product)

    output = tf.matmul(weights, v)  # [B, H, N, V]

    # [B, H, N, V] -> [B, N, H, V]
    output_transpose = tf.transpose(output, [0, 2, 1, 3])

    # [B, N, H, V] -> [B, N, H * V]
    new_memory = basic.BatchFlatten(preserve_dims=2)(output_transpose)
    return new_memory
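
A minimal NumPy sketch of the same shape flow, with a random matrix standing in for the Linear + LayerNorm projection and toy sizes for B, N, H, key and value:

import numpy as np

B, N, H, K, V = 2, 4, 2, 3, 3
qkv_size = 2 * K + V
rng = np.random.default_rng(2)
memory = rng.normal(size=(B, N, 16))
w_qkv = rng.normal(size=(16, H * qkv_size))

qkv = memory @ w_qkv                                          # [B, N, F]
qkv = qkv.reshape(B, N, H, qkv_size).transpose(0, 2, 1, 3)    # [B, H, N, F/H]
q, k, v = np.split(qkv, [K, 2 * K], axis=-1)

q = q * qkv_size ** -0.5
logits = q @ k.transpose(0, 1, 3, 2)                          # [B, H, N, N]
weights = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)
output = weights @ v                                          # [B, H, N, V]
new_memory = output.transpose(0, 2, 1, 3).reshape(B, N, H * V)
print(new_memory.shape)                                       # (2, 4, 6)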
def default_mlp(hidden_sizes, activate_final=False, init_std=2., **kwargs):
    """Standard batch-applied MLP for transformer modules."""
    init = {
        'w': tf.variance_scaling_initializer(init_std, distribution='normal')
    }
    mlp = snt_mlp.MLP(hidden_sizes,
                      activate_final=activate_final,
                      use_dropout=True,
                      initializers=init,
                      **kwargs)
    return basic.BatchApply(mlp)
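
For context, the transformer tower below calls this helper as an embedding layer; a usage sketch assuming the same TensorFlow 1.x / Sonnet imports as the examples above and a hypothetical [batch, steps, features] tensor token_embeddings:

embed = default_mlp([256], activate_final=True)(
    token_embeddings, is_training=True, dropout_keep_prob=0.9)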
Example No. 5
    def _attend_over_memory(self, memory):
        """Perform multiheaded attention over `memory`.
    Args:
      memory: Current relational memory.
    Returns:
      The attended-over memory.
    """
        attention_mlp = basic.BatchApply(
            mlp.MLP([self._mem_size] * self._attention_mlp_layers))
        for _ in range(self._num_blocks):
            attended_memory = self._multihead_attention(memory)

            # Add a skip connection to the multiheaded attention's input.
            memory = basic.BatchApply(layer_norm.LayerNorm())(memory +
                                                              attended_memory)

            # Add a skip connection to the attention_mlp's input.
            memory = basic.BatchApply(
                layer_norm.LayerNorm())(attention_mlp(memory) + memory)

        return memory
    def _build(self,
               inputs,
               state=None,
               condition=None,
               is_training=True,
               final_layer_key_value_inputs=None):
        """Calculates multi-layer self attention and mlp transformation.

    Args:
      inputs: Tensor of shape [batch_size, num_steps, dim_size].
      state: optional list of length num_layers of tensors of shape
        [batch_size, memory_size, dim_size].
      condition: optional tensor to condition on, of shape
        [batch_size, dim_size].
      is_training: If true, dropout is applied.
      final_layer_key_value_inputs: optional Tensor to be used as the key and
        value for the final multi-head attention layer of shape
        [batch_size, num_steps, dim_size]. Useful when the tower is a Seq2Seq
        decoder and it can attend to encoder outputs.

    Returns:
      output: tensor of shape [batch_size, num_steps, output_dim_size].
      state: list of length `num_layers` containing AttentionState tuples.
    """
        # inputs: [B, N, F]
        if final_layer_key_value_inputs is not None and state is not None and len(
                state) == (self._num_layers - 1):
            raise ValueError(
                'When final_layer_key_value_inputs is set, exclude '
                'the state of the last layer.')

        if condition is not None:
            condition_tile = tf.tile(tf.expand_dims(condition, 1),
                                     [1, tf.shape(inputs)[1], 1])
            inputs = tf.concat([inputs, condition_tile], -1)

        # Map inputs to be of `embedding_size` dimension.
        if inputs.get_shape().as_list()[-1] != self._embedding_size:
            inputs = default_mlp([self._embedding_size], activate_final=True)(
                inputs,
                is_training=is_training,
                dropout_keep_prob=1 - self._dropout_rate)

        if state is None:
            memory_sizes = [0]
        elif isinstance(state[0], CompressedMemoryState):
            cm_mem_size = max(_memory_size(s.compressed_memory) for s in state)
            em_mem_size = max(_memory_size(s.episodic_memory) for s in state)
            memory_sizes = [cm_mem_size, em_mem_size]
        else:
            memory_sizes = [max([_memory_size(s) for s in state])]
        chunk_size = inputs.get_shape().as_list()[1]
        self._positional_encodings = []
        # Creates positional encodings for different memory types.
        for i, memory_size in enumerate(memory_sizes):
            seq_len = chunk_size + memory_size
            key_positions = get_position_encodings(
                sequence_length=seq_len,
                hidden_size=inputs.get_shape().as_list()[2],
                clamp_value=self._clamp_time_range,
            )
            if is_training:
                key_positions = tf.nn.dropout(key_positions,
                                              rate=self._dropout_rate)
            key_positions = tf.cast(key_positions, dtype=inputs.dtype)
            query_positions = key_positions[:, -chunk_size:, :]
            self._positional_encodings.append((key_positions, query_positions))

        if self._causal:
            self._mask = create_mask(inputs, state,
                                     self._same_attention_length)

        layer_i_inputs = inputs
        attention_states = []
        key_value_inputs = None

        for i in range(self._num_layers):
            with tf.variable_scope('layer_%d' % i, reuse=tf.AUTO_REUSE):
                multihead_attention, object_mlp = self.get_sublayers(
                    is_training)
                # Multihead attention with residuals.
                state_i = None if state is None else state[i]
                if i == (self._num_layers -
                         1) and final_layer_key_value_inputs is not None:
                    # When final_layer_key_value_inputs is set, the final layer of
                    # attention uses it as the key and value, so no state is needed.
                    key_value_inputs = final_layer_key_value_inputs
                    state_i = None

                attention_outputs, attention_state = multihead_attention(
                    layer_i_inputs,
                    state=state_i,
                    is_training=is_training,
                    dropout_keep_prob=1. - self._dropout_rate,
                    key_value_inputs=key_value_inputs)
                attention_states.append(attention_state)
                # Feed-forward with residuals.
                output = object_mlp(attention_outputs,
                                    is_training=is_training,
                                    dropout_keep_prob=1 - self._dropout_rate)
                layer_i_inputs = output

        if self._output_size is not None:
            output = basic.BatchApply(
                basic.Linear(self._output_size, use_bias=False))(output)

        return output, attention_states
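
create_mask is referenced above but not shown. A minimal NumPy sketch of the kind of additive causal mask the attention layers add to their logits, broadcastable to [B, H, N, N + M], where memory positions stay visible and future positions within the chunk are pushed to a large negative value (the -1e9 constant and exact construction are assumptions):

import numpy as np

def causal_mask(chunk_size, memory_size, neg=-1e9):
    total = memory_size + chunk_size
    mask = np.zeros((1, 1, chunk_size, total))
    for i in range(chunk_size):
        mask[0, 0, i, memory_size + i + 1:] = neg   # block attention to the future
    return mask

print(causal_mask(chunk_size=3, memory_size=2)[0, 0])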
    def _build(self,
               inputs,
               query_inputs=None,
               state=None,
               is_training=False,
               dropout_keep_prob=0.5,
               key_value_inputs=None):
        """Calculates multi-layer self attention.

    Args:
      inputs: Tensor of shape [batch_size, num_steps, output_dim_size]. Inputs
        used as the query, key, and value to the attention layer.
      query_inputs: optional Tensor of shape [batch_size, num_steps,
        output_dim_size]. Query inputs to the attention layer. Set when
        query_inputs is different from the inputs argument.
      state: optional CompressedMemoryState or a Tensor of shape [batch_size,
        memory_size, dim_size] concatenated to the inputs. Set to attend to
        memory from previous steps.
      is_training: if currently training.
      dropout_keep_prob: keep probability for the dropout applied to the
        attention weights.
      key_value_inputs: optional Tensor of shape [batch_size, num_steps,
        output_dim_size]. It is used as the key and value of the multihead
        attention. Set when the key and value are different from the inputs
        argument.

    Returns:
      output: the result Tensor of shape
        [batch_size, num_steps, output_dim_size].
      attention_state: named tuple of AttentionState.
    """
        if key_value_inputs is not None and state is not None:
            raise ValueError(
                'Only one of key_value_inputs and state may be set.')
        embedding_size = self._value_size * self._num_heads

        q_inputs = inputs if query_inputs is None else query_inputs
        # Denoted by L. If query_inputs is None, L = N.
        _, query_size = q_inputs.get_shape().as_list()[:2]

        if key_value_inputs is not None:
            k_inputs = key_value_inputs
            v_inputs = k_inputs
        elif state is not None:
            if isinstance(state, CompressedMemoryState):
                state_memory_list = [
                    state.compressed_memory, state.episodic_memory
                ]
            else:
                state_memory_list = [state]

            k_inputs = tf.concat(state_memory_list + [inputs], 1)
            v_inputs = k_inputs
        else:
            k_inputs = inputs
            v_inputs = inputs

        # Batch size denoted by B
        batch_size = tf.shape(inputs)[0]
        # Chunk_size denoted by N
        chunk_size = inputs.get_shape().as_list()[1]
        # Denoted by N + M
        att_size = k_inputs.get_shape().as_list()[1]

        if self._positional_encodings and not self._use_relative_positions:
            if len(self._positional_encodings) != 1:
                raise ValueError(
                    'Absolute positional encodings only supported for 1 memory. '
                    'Found %i.' % len(self._positional_encodings))
            key_positions, query_positions = self._positional_encodings[0]
            k_inputs += key_positions
            q_inputs += query_positions

        # [B, H, L, K]
        q = self.multihead_linear(q_inputs, 'query')
        # [B, H, N + M, K]
        k = self.multihead_linear(k_inputs, 'key')
        # [B, H, N + M, V]
        v = self.multihead_linear(v_inputs, 'value')

        # Scaling the dot-product
        if self._scaling:
            q *= self._key_size**-0.5

        # [B, H, L, N + M]
        if self._use_relative_positions:
            r_w_bias = tf.get_variable('r_w_bias',
                                       [1, self._num_heads, 1, self._key_size],
                                       dtype=inputs.dtype)
            content_logits = tf.matmul(q + r_w_bias, k, transpose_b=True)
            all_relative_logits = []
            # Loop over multiple positional encodings, for the case of multiple
            # memory types.
            for i, positional_encodings in enumerate(
                    self._positional_encodings):
                key_positions, query_positions = positional_encodings
                if key_positions.get_shape().as_list()[-1] != att_size:
                    # Crop to the layer's memory size.
                    key_positions = key_positions[:, -att_size:]
                is_final = i == len(self._positional_encodings) - 1
                suffix = '' if is_final else '_%d' % i
                relative_keys = self.multihead_linear(key_positions,
                                                      name='relative_keys' +
                                                      suffix)
                # [B, H, N, D]
                r_r_bias = tf.get_variable(
                    'r_r_bias' + suffix,
                    [1, self._num_heads, 1, self._key_size],
                    dtype=inputs.dtype)
                relative_keys = tf.tile(relative_keys, [batch_size, 1, 1, 1])
                relative_logits = tf.matmul(q + r_r_bias,
                                            relative_keys,
                                            transpose_b=True)
                relative_logits = rel_shift(relative_logits)
                if not is_final:  # Include relative positions for input sequence.
                    relative_logits = relative_logits[:, :, :, :-chunk_size]
                all_relative_logits.append(relative_logits)
            all_relative_logits = tf.concat(all_relative_logits, 3)
            logits = content_logits + all_relative_logits
        else:
            # [B, H, N, N + M]
            logits = tf.matmul(q, k, transpose_b=True)
            content_logits = logits

        if self._mask is not None:
            if self._mask.get_shape().as_list()[-1] != att_size:
                mask = self._mask[:, :, :, -att_size:]
            else:
                mask = self._mask
            logits += mask

        weights = tf.nn.softmax(logits)
        if is_training:
            weights = tf.nn.dropout(weights, dropout_keep_prob)
        # [B, L, H, V], where V is value_size
        output_transpose = tf.einsum('bhij,bhjk->bihk', weights, v)

        # [B, L, H, V] -> [B, L, HV]
        attended_inputs = basic.BatchReshape(
            [query_size, embedding_size])(output_transpose)
        # Apply final mlp to mix information between heads.
        output = basic.BatchApply(
            basic.Linear(embedding_size))(attended_inputs)

        attention_state = AttentionState(queries=q,
                                         keys=k,
                                         values=v,
                                         weights=weights,
                                         logits=content_logits,
                                         embeddings=inputs,
                                         read_words=output)
        return output, attention_state
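
rel_shift is used above but not defined in these snippets. A minimal NumPy sketch of the Transformer-XL style relative shift it typically implements: pad one zero column, fold the last two axes, drop the first row, and reshape back so each query row's relative-position logits line up with the key positions.

import numpy as np

def rel_shift(x):
    b, h, n, l = x.shape
    x = np.pad(x, ((0, 0), (0, 0), (0, 0), (1, 0)))   # [B, H, N, L + 1]
    x = x.reshape(b, h, l + 1, n)
    x = x[:, :, 1:, :]                                # drop the padded row
    return x.reshape(b, h, n, l)

x = np.arange(12, dtype=float).reshape(1, 1, 3, 4)
print(rel_shift(x)[0, 0])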
def _layer_norm(inputs):
    if inputs.get_shape().ndims > 2:
        return basic.BatchApply(snt_ln.LayerNorm())(inputs)
    else:
        return snt_ln.LayerNorm()(inputs)
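
For reference, a minimal NumPy sketch of layer normalization over the feature axis, which is what these LayerNorm modules compute in the 2-D case (learned scale and offset omitted):

import numpy as np

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x = np.random.default_rng(3).normal(size=(6, 8))
y = layer_norm(x)
print(y.mean(axis=-1).round(6), y.std(axis=-1).round(3))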
Example No. 9
    def _build(self, memory, query, memory_mask=None):
        """Perform a differentiable read.

    Args:
      memory: [batch_size, memory_size, memory_word_size]-shaped Tensor of
        dtype float32. This represents, for each example and memory slot, a
        single embedding to attend over.
      query: [batch_size, query_word_size]-shaped Tensor of dtype float32.
        Represents, for each example, a single embedding representing a query.
      memory_mask: None or [batch_size, memory_size]-shaped Tensor of dtype
        bool. An entry of False indicates that a memory slot should not enter
        the resulting weighted sum. If None, all memory is used.

    Returns:
      An AttentionOutput instance containing:
        read: [batch_size, memory_word_size]-shaped Tensor of dtype float32.
          This represents, for each example, a weighted sum of the contents of
          the memory.
        weights: [batch_size, memory_size]-shaped Tensor of dtype float32. This
          represents, for each example and memory slot, the attention weights
          used to compute the read.
        weight_logits: [batch_size, memory_size]-shaped Tensor of dtype float32.
          This represents, for each example and memory slot, the logits of the
          attention weights, that is, `weights` is calculated by taking the
          softmax of the weight logits.

    Raises:
      UnderspecifiedError: if memory_word_size or query_word_size can not be
        inferred.
      IncompatibleShapeError: if memory, query, memory_mask, or output of
        attention_logit_mod do not match expected shapes.
    """
        if len(memory.get_shape()) != 3:
            raise base.IncompatibleShapeError(
                "memory must have shape [batch_size, memory_size, memory_word_size]."
            )

        if len(query.get_shape()) != 2:
            raise base.IncompatibleShapeError(
                "query must have shape [batch_size, query_word_size].")

        if memory_mask is not None and len(memory_mask.get_shape()) != 2:
            raise base.IncompatibleShapeError(
                "memory_mask must have shape [batch_size, memory_size].")

        # Ensure final dimensions are defined, else the attention logit module will
        # be unable to infer input size when constructing variables.
        inferred_memory_word_size = memory.get_shape()[2].value
        inferred_query_word_size = query.get_shape()[1].value
        if inferred_memory_word_size is None or inferred_query_word_size is None:
            raise base.UnderspecifiedError(
                "memory_word_size and query_word_size must be known at graph "
                "construction time.")

        memory_shape = tf.shape(memory)
        batch_size = memory_shape[0]
        memory_size = memory_shape[1]

        query_shape = tf.shape(query)
        query_batch_size = query_shape[0]

        # Transform query to have same number of words as memory.
        #
        # expanded_query: [batch_size, memory_size, query_word_size].
        expanded_query = tf.tile(tf.expand_dims(query, axis=1),
                                 [1, memory_size, 1])

        # Compute attention weights for each memory slot.
        #
        # attention_weight_logits: [batch_size, memory_size]
        with tf.control_dependencies(
            [tf.assert_equal(batch_size, query_batch_size)]):
            concatenated_embeddings = tf.concat(
                values=[memory, expanded_query], axis=2)

        batch_apply_attention_logit = basic.BatchApply(
            self._attention_logit_mod,
            n_dims=2,
            name="batch_apply_attention_logit")
        attention_weight_logits = batch_apply_attention_logit(
            concatenated_embeddings)

        # Note: basic.BatchApply() will automatically reshape the [batch_size *
        # memory_size, 1]-shaped result of self._attention_logit_mod(...) into a
        # [batch_size, memory_size, 1]-shaped Tensor. If
        # self._attention_logit_mod(...) returns something with more dimensions,
        # then attention_weight_logits will have extra dimensions, too.
        if len(attention_weight_logits.get_shape()) != 3:
            raise base.IncompatibleShapeError(
                "attention_weight_logits must be a rank-3 Tensor. Are you sure that "
                "attention_logit_mod() returned [batch_size * memory_size, 1]-shaped"
                " Tensor?")

        # Remove final length-1 dimension.
        attention_weight_logits = tf.squeeze(attention_weight_logits, [2])

        # Mask out ignored memory slots by assigning them very small logits. The
        # assertion below ensures that every example keeps at least one valid
        # memory slot; otherwise the softmax would average all memory slots equally.
        if memory_mask is not None:
            num_remaining_memory_slots = tf.reduce_sum(tf.cast(memory_mask,
                                                               dtype=tf.int32),
                                                       axis=[1])
            with tf.control_dependencies(
                [tf.assert_positive(num_remaining_memory_slots)]):
                finfo = np.finfo(np.float32)
                kept_indices = tf.cast(memory_mask, dtype=tf.float32)
                ignored_indices = tf.cast(tf.logical_not(memory_mask),
                                          dtype=tf.float32)
                lower_bound = finfo.max * kept_indices + finfo.min * ignored_indices
                attention_weight_logits = tf.minimum(attention_weight_logits,
                                                     lower_bound)

        # attended_memory: [batch_size, memory_word_size].
        attention_weight = tf.reshape(tf.nn.softmax(attention_weight_logits),
                                      shape=[batch_size, memory_size, 1])
        # The multiplication is elementwise and relies on broadcasting the weights
        # across memory_word_size. Then we sum across the memory slots.
        attended_memory = tf.reduce_sum(memory * attention_weight, axis=[1])

        # Infer shape of result as much as possible.
        inferred_batch_size, _, inferred_memory_word_size = (
            memory.get_shape().as_list())
        attended_memory.set_shape(
            [inferred_batch_size, inferred_memory_word_size])

        return AttentionOutput(read=attended_memory,
                               weights=tf.squeeze(attention_weight, [2]),
                               weight_logits=attention_weight_logits)
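
A minimal NumPy sketch of the masked, differentiable read above, with a toy dot-product standing in for the learned attention_logit_mod:

import numpy as np

B, M, D = 2, 5, 3
rng = np.random.default_rng(4)
memory = rng.normal(size=(B, M, D))
query = rng.normal(size=(B, D))
memory_mask = np.array([[True, True, False, True, False],
                        [True, False, True, True, True]])

logits = np.einsum('bmd,bd->bm', memory, query)        # toy attention logits
finfo = np.finfo(np.float32)
lower_bound = np.where(memory_mask, finfo.max, finfo.min)
logits = np.minimum(logits, lower_bound)               # masked slots -> ~ -inf

weights = np.exp(logits - logits.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)              # [B, M]
read = (memory * weights[..., None]).sum(axis=1)       # [B, D]
print(read.shape, weights.round(2))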
Example No. 10
    def _build(self, inputs, state=None, condition=None, is_training=True):
        """Calculates multi-layer self attention and mlp transformation.

    Args:
      inputs: Tensor of shape [batch_size, num_steps, dim_size].
      state: optional list of length num_layers of tensors of shape
        [batch_size, memory_size, dim_size].
      condition: optional tensor to condition on, of shape
        [batch_size, dim_size].
      is_training: If true, dropout is applied.

    Returns:
      output: tensor of shape [batch_size, num_steps, output_dim_size].
      state: list of length `num_layers` containing AttentionState tuples.
    """

        # inputs: [B, N, F]
        if condition is not None:
            condition_tile = tf.tile(tf.expand_dims(condition, 1),
                                     [1, tf.shape(inputs)[1], 1])
            inputs = tf.concat([inputs, condition_tile], -1)

        if state is None:
            memory_sizes = [0]
        elif isinstance(state[0], CompressedMemoryState):
            cm_mem_size = max(_memory_size(s.compressed_memory) for s in state)
            em_mem_size = max(_memory_size(s.episodic_memory) for s in state)
            memory_sizes = [cm_mem_size, em_mem_size]
        else:
            memory_sizes = [max([_memory_size(s) for s in state])]
        chunk_size = inputs.get_shape().as_list()[1]
        self._positional_encodings = []
        # Creates positional encodings for different memory types.
        for i, memory_size in enumerate(memory_sizes):
            seq_len = chunk_size + memory_size
            key_positions = get_position_encodings(
                sequence_length=seq_len,
                hidden_size=inputs.get_shape().as_list()[2],
                clamp_value=self._clamp_time_range,
            )
            if is_training:
                key_positions = tf.nn.dropout(key_positions,
                                              rate=self._dropout_rate)
            key_positions = tf.cast(key_positions, dtype=inputs.dtype)
            query_positions = key_positions[:, -chunk_size:, :]
            self._positional_encodings.append((key_positions, query_positions))

        if self._causal:
            self._mask = create_mask(inputs, state,
                                     self._same_attention_length)

        layer_i_inputs = inputs
        attention_states = []
        for i in range(self._num_layers):
            with tf.variable_scope('layer_%d' % i, reuse=tf.AUTO_REUSE):
                multihead_attention, object_mlp = self.get_sublayers(
                    is_training)
                # Multihead attention with residuals.
                state_i = None if state is None else state[i]
                attention_outputs, attention_state = multihead_attention(
                    layer_i_inputs,
                    state=state_i,
                    is_training=is_training,
                    dropout_keep_prob=1. - self._dropout_rate)
                attention_states.append(attention_state)
                # Feed-forward with residuals.
                output = object_mlp(attention_outputs,
                                    is_training=is_training,
                                    dropout_keep_prob=1 - self._dropout_rate)
                layer_i_inputs = output

        if self._output_size is not None:
            output = basic.BatchApply(
                basic.Linear(self._output_size, use_bias=False))(output)

        return output, attention_states
Example No. 11
    def _build(self,
               inputs,
               query_inputs=None,
               state=None,
               is_training=False,
               dropout_keep_prob=0.5):
        embedding_size = self._value_size * self._num_heads

        q_inputs = inputs if query_inputs is None else query_inputs
        # Denoted by L. If query_inputs is None, L = N.
        _, query_size = q_inputs.get_shape().as_list()[:2]

        if state is not None:
            if isinstance(state, CompressedMemoryState):
                state_memory_list = [
                    state.compressed_memory, state.episodic_memory
                ]
            else:
                state_memory_list = [state]

            k_inputs = tf.concat(state_memory_list + [inputs], 1)
            v_inputs = k_inputs
        else:
            k_inputs = inputs
            v_inputs = inputs

        # Batch size denoted by B
        batch_size = tf.shape(inputs)[0]
        # Chunk_size denoted by N
        chunk_size = inputs.get_shape().as_list()[1]
        # Denoted by N + M
        att_size = k_inputs.get_shape().as_list()[1]

        if self._positional_encodings and not self._use_relative_positions:
            key_positions, query_positions = self._positional_encodings
            k_inputs += key_positions
            q_inputs += query_positions

        # [B, H, L, K]
        q = self.multihead_linear(q_inputs, 'query')
        # [B, H, N + M, K]
        k = self.multihead_linear(k_inputs, 'key')
        # [B, H, N + M, V]
        v = self.multihead_linear(v_inputs, 'value')

        # Scaling the dot-product
        if self._scaling:
            q *= self._key_size**-0.5

        # [B, H, L, N + M]
        if self._use_relative_positions:
            r_w_bias = tf.get_variable('r_w_bias',
                                       [1, self._num_heads, 1, self._key_size],
                                       dtype=inputs.dtype)
            content_logits = tf.matmul(q + r_w_bias, k, transpose_b=True)
            all_relative_logits = []
            # Loop over multiple positional encodings, for the case of multiple
            # memory types.
            for i, positional_encodings in enumerate(
                    self._positional_encodings):
                key_positions, query_positions = positional_encodings
                if key_positions.get_shape().as_list()[-1] != att_size:
                    # Crop to the layer's memory size.
                    key_positions = key_positions[:, -att_size:]
                is_final = i == len(self._positional_encodings) - 1
                suffix = '' if is_final else '_%d' % i
                relative_keys = self.multihead_linear(key_positions,
                                                      name='relative_keys' +
                                                      suffix)
                # [B, H, N, D]
                r_r_bias = tf.get_variable(
                    'r_r_bias' + suffix,
                    [1, self._num_heads, 1, self._key_size],
                    dtype=inputs.dtype)
                relative_keys = tf.tile(relative_keys, [batch_size, 1, 1, 1])
                relative_logits = tf.matmul(q + r_r_bias,
                                            relative_keys,
                                            transpose_b=True)
                relative_logits = rel_shift(relative_logits)
                if not is_final:  # Include relative positions for input sequence.
                    relative_logits = relative_logits[:, :, :, :-chunk_size]
                all_relative_logits.append(relative_logits)
            all_relative_logits = tf.concat(all_relative_logits, 3)
            logits = content_logits + all_relative_logits
        else:
            # [B, H, N, N + M]
            logits = tf.matmul(q, k, transpose_b=True)
            content_logits = logits

        if self._mask is not None:
            if self._mask.get_shape().as_list()[-1] != att_size:
                mask = self._mask[:, :, :, -att_size:]
            else:
                mask = self._mask
            logits += mask

        weights = tf.nn.softmax(logits)
        if is_training:
            weights = tf.nn.dropout(weights, dropout_keep_prob)
        # [B, L, H, V], where V is value_size
        output_transpose = tf.einsum('bhij,bhjk->bihk', weights, v)

        # [B, L, H, V] -> [B, L, HV]
        attended_inputs = basic.BatchReshape(
            [query_size, embedding_size])(output_transpose)
        # Apply final mlp to mix information between heads.
        output = basic.BatchApply(
            basic.Linear(embedding_size))(attended_inputs)

        attention_state = AttentionState(queries=q,
                                         keys=k,
                                         values=v,
                                         weights=weights,
                                         logits=content_logits,
                                         embeddings=inputs,
                                         read_words=output)
        return output, attention_state