Example #1
    def create(self,
               fixed_embeddings,
               linked_embeddings,
               context_tensor_arrays,
               attention_tensor,
               during_training,
               stride=None):
        """Forwards the lengths and scores."""
        check.NotNone(stride, 'MstSolverNetwork requires stride')

        lengths = network_units.lookup_named_tensor('lengths',
                                                    linked_embeddings)
        lengths_b = tf.to_int32(tf.squeeze(lengths.tensor, [1]))

        scores = network_units.lookup_named_tensor('scores', linked_embeddings)
        scores_bnxn = scores.tensor
        max_length = tf.shape(scores_bnxn)[1]
        scores_bxnxn = tf.reshape(scores_bnxn,
                                  [stride, max_length, max_length])

        _, argmax_sources_bxn = mst_ops.maximum_spanning_tree(
            forest=self._attrs['forest'],
            num_nodes=lengths_b,
            scores=scores_bxnxn)
        argmax_sources_bn = tf.reshape(argmax_sources_bxn, [-1])
        arcs_bnxn = tf.one_hot(argmax_sources_bn, max_length, dtype=tf.float32)

        return [lengths_b, scores_bxnxn, scores_bnxn, arcs_bnxn]
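To make the reshape and one-hot bookkeeping at the end concrete, here is a minimal NumPy sketch with toy sizes; a plain per-target argmax stands in for the real mst_ops.maximum_spanning_tree() op, which solves a global MST rather than a row-wise maximum.

import numpy as np

# Toy sizes: batch of 2 sentences, max length 3 (made up for illustration).
stride, max_length = 2, 3
scores_bnxn = np.random.randn(stride * max_length, max_length)

# [batch*num_tokens, num_tokens] -> [batch, num_tokens, num_tokens].
scores_bxnxn = scores_bnxn.reshape(stride, max_length, max_length)

# Stand-in for the MST op: independently pick the best source per token.
argmax_sources_bxn = scores_bxnxn.argmax(axis=2)    # [batch, num_tokens]
argmax_sources_bn = argmax_sources_bxn.reshape(-1)  # [batch*num_tokens]

# One-hot arc indicators: row i selects the predicted source of token i.
arcs_bnxn = np.eye(max_length, dtype=np.float32)[argmax_sources_bn]
assert arcs_bnxn.shape == (stride * max_length, max_length)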
Example #2
    def create(self,
               fixed_embeddings,
               linked_embeddings,
               context_tensor_arrays,
               attention_tensor,
               during_training,
               stride=None):
        """Requires |stride|; otherwise see base class."""
        del context_tensor_arrays, attention_tensor
        if stride is None:
            raise RuntimeError(
                "PairwiseBilinearLabelNetwork needs 'stride' and must "
                "be called in a bulk component.")

        sources = network_units.lookup_named_tensor('sources',
                                                    linked_embeddings)
        sources_tensor = tf.reshape(sources.tensor,
                                    [stride, -1, self._source_dim])

        targets = network_units.lookup_named_tensor('targets',
                                                    linked_embeddings)
        targets_tensor = tf.reshape(targets.tensor,
                                    [stride, -1, self._target_dim])

        # Dimensions: source_dim x num_labels x target_dim
        bilinear_params = self._component.get_variable('bilinear')

        # Ensures that num_steps is the same for both inputs
        num_steps = tf.shape(sources_tensor)[1]
        with tf.control_dependencies([
                tf.assert_equal(num_steps,
                                tf.shape(targets_tensor)[1],
                                name='num_steps_mismatch')
        ]):
            # Dimensions:
            # (batch_size*num_steps x source_dim) *
            #   (source_dim x num_labels*target_dim)
            #     = (batch_size*num_steps x num_labels*target_dim)
            lin = tf.matmul(
                tf.reshape(sources_tensor, [-1, self._source_dim]),
                tf.reshape(bilinear_params, [self._source_dim, -1]))

            # (batch_size x num_steps*num_labels x target_dim) *
            #   (batch_size x num_steps x target_dim)^T
            #     = (batch_size x num_steps*num_labels x num_steps)
            bilin = tf.matmul(tf.reshape(
                lin, [-1, num_steps * self._num_labels, self._target_dim]),
                              targets_tensor,
                              transpose_b=True)

        # (batch_size x num_steps*num_labels x num_steps) ->
        #   (batch_size x num_steps x num_steps*num_labels)
        scores = tf.transpose(bilin, [0, 2, 1])

        return [
            tf.reshape(scores, [-1, num_steps * self._num_labels],
                       name='reshape_activations')
        ]
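The two chained matmuls compute a pairwise bilinear form without ever materializing a four-way tensor product. A NumPy sketch with made-up toy dimensions, checking the reshape/matmul chain against a direct einsum definition of scores[b, j, i, l] = sum_{s,t} sources[b,i,s] * W[s,l,t] * targets[b,j,t]:

import numpy as np

# Toy dimensions (batch, steps, source_dim, target_dim, num_labels).
B, N, S, T, L = 2, 4, 5, 6, 3
sources = np.random.randn(B, N, S)
targets = np.random.randn(B, N, T)
bilinear = np.random.randn(S, L, T)   # source_dim x num_labels x target_dim

# The two-matmul chain from the snippet, transcribed into NumPy.
lin = sources.reshape(-1, S) @ bilinear.reshape(S, -1)         # [B*N, L*T]
bilin = lin.reshape(B, N * L, T) @ targets.transpose(0, 2, 1)  # [B, N*L, N]
scores = bilin.transpose(0, 2, 1)                              # [B, N, N*L]

# Direct definition of the same quantity.
direct = np.einsum('bis,slt,bjt->bjil', sources, bilinear, targets)
assert np.allclose(scores.reshape(B, N, N, L), direct)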
Example #3
    def create(self,
               fixed_embeddings,
               linked_embeddings,
               context_tensor_arrays,
               attention_tensor,
               during_training,
               stride=None):
        """Requires |stride|; otherwise see base class."""
        check.NotNone(
            stride,
            'BiaffineDigraphNetwork requires "stride" and must be called '
            'in the bulk feature extractor component.')

        # TODO(googleuser): Add dropout during training.
        del during_training

        # Retrieve (possibly averaged) weights.
        weights_arc = self._component.get_variable('weights_arc')
        weights_source = self._component.get_variable('weights_source')
        root = self._component.get_variable('root')

        # Extract the source and target token activations.  Use |stride| to collapse
        # batch and beam into a single dimension.
        sources = network_units.lookup_named_tensor('sources',
                                                    linked_embeddings)
        targets = network_units.lookup_named_tensor('targets',
                                                    linked_embeddings)
        source_tokens_bxnxs = tf.reshape(sources.tensor,
                                         [stride, -1, self._source_dim])
        target_tokens_bxnxt = tf.reshape(targets.tensor,
                                         [stride, -1, self._target_dim])
        num_tokens = tf.shape(source_tokens_bxnxs)[1]

        # Compute the arc, source, and root potentials.
        arcs_bxnxn = digraph_ops.ArcPotentialsFromTokens(
            source_tokens_bxnxs, target_tokens_bxnxt, weights_arc)
        sources_bxnxn = digraph_ops.ArcSourcePotentialsFromTokens(
            source_tokens_bxnxs, weights_source)
        roots_bxn = digraph_ops.RootPotentialsFromTokens(
            root, target_tokens_bxnxt, weights_arc, weights_source)

        # Combine them into a single matrix with the roots on the diagonal.
        adjacency_bxnxn = digraph_ops.CombineArcAndRootPotentials(
            arcs_bxnxn + sources_bxnxn, roots_bxn)

        # The adjacency matrix currently has sources on rows and targets on columns,
        # but we want targets on rows so that maximizing within a row corresponds to
        # selecting sources for a given target.
        adjacency_bxnxn = tf.matrix_transpose(adjacency_bxnxn)

        return [tf.reshape(adjacency_bxnxn, [-1, num_tokens])]
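A NumPy sketch of the combine-and-transpose step. Per the comments above, CombineArcAndRootPotentials places the root potentials on the diagonal; everything else here (toy shapes, the greedy argmax at the end) is illustrative only, not the digraph_ops implementation.

import numpy as np

# Toy potentials; the _bxnxn/_bxn suffixes mirror the snippet's convention.
b, n = 2, 4
arcs_bxnxn = np.random.randn(b, n, n)  # arc + source potentials
roots_bxn = np.random.randn(b, n)      # root-selection potentials

# Roots on the diagonal: token i choosing itself as source means "root".
adjacency_bxnxn = arcs_bxnxn.copy()
idx = np.arange(n)
adjacency_bxnxn[:, idx, idx] = roots_bxn

# Transpose so rows index targets; now the argmax within each row picks
# the source (head) for that target token.
adjacency_bxnxn = adjacency_bxnxn.transpose(0, 2, 1)
heads_bxn = adjacency_bxnxn.argmax(axis=2)
assert heads_bxn.shape == (b, n)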
Example #4
  def create(self,
             fixed_embeddings,
             linked_embeddings,
             context_tensor_arrays,
             attention_tensor,
             during_training,
             stride=None):
    """Requires |stride|; otherwise see base class."""
    check.NotNone(stride,
                  'BiaffineDigraphNetwork requires "stride" and must be called '
                  'in the bulk feature extractor component.')

    # TODO(googleuser): Add dropout during training.
    del during_training

    # Retrieve (possibly averaged) weights.
    weights_arc = self._component.get_variable('weights_arc')
    weights_source = self._component.get_variable('weights_source')
    root = self._component.get_variable('root')

    # Extract the source and target token activations.  Use |stride| to collapse
    # batch and beam into a single dimension.
    sources = network_units.lookup_named_tensor('sources', linked_embeddings)
    targets = network_units.lookup_named_tensor('targets', linked_embeddings)
    source_tokens_bxnxs = tf.reshape(sources.tensor,
                                     [stride, -1, self._source_dim])
    target_tokens_bxnxt = tf.reshape(targets.tensor,
                                     [stride, -1, self._target_dim])
    num_tokens = tf.shape(source_tokens_bxnxs)[1]

    # Compute the arc, source, and root potentials.
    arcs_bxnxn = digraph_ops.ArcPotentialsFromTokens(
        source_tokens_bxnxs, target_tokens_bxnxt, weights_arc)
    sources_bxnxn = digraph_ops.ArcSourcePotentialsFromTokens(
        source_tokens_bxnxs, weights_source)
    roots_bxn = digraph_ops.RootPotentialsFromTokens(
        root, target_tokens_bxnxt, weights_arc, weights_source)

    # Combine them into a single matrix with the roots on the diagonal.
    adjacency_bxnxn = digraph_ops.CombineArcAndRootPotentials(
        arcs_bxnxn + sources_bxnxn, roots_bxn)

    # The adjacency matrix currently has sources on rows and targets on columns,
    # but we want targets on rows so that maximizing within a row corresponds to
    # selecting sources for a given target.
    adjacency_bxnxn = tf.matrix_transpose(adjacency_bxnxn)

    return [tf.reshape(adjacency_bxnxn, [-1, num_tokens])]
Example #5
  def create(self,
             fixed_embeddings,
             linked_embeddings,
             context_tensor_arrays,
             attention_tensor,
             during_training,
             stride=None):
    """Requires |stride|; otherwise see base class."""
    del context_tensor_arrays, attention_tensor
    if stride is None:
      raise RuntimeError("PairwiseBilinearLabelNetwork needs 'stride' and must "
                         "be called in a bulk component.")

    sources = network_units.lookup_named_tensor('sources', linked_embeddings)
    sources_tensor = tf.reshape(sources.tensor, [stride, -1, self._source_dim])

    targets = network_units.lookup_named_tensor('targets', linked_embeddings)
    targets_tensor = tf.reshape(targets.tensor, [stride, -1, self._target_dim])

    # Dimensions: source_dim x num_labels x target_dim
    bilinear_params = self._component.get_variable('bilinear')

    # Ensures that num_steps is the same for both inputs
    num_steps = tf.shape(sources_tensor)[1]
    with tf.control_dependencies([tf.assert_equal(num_steps,
                                                  tf.shape(targets_tensor)[1],
                                                  name='num_steps_mismatch')]):
      # Dimensions:
      # (batch_size*num_steps x source_dim) *
      #   (source_dim x num_labels*target_dim)
      #     = (batch_size*num_steps x num_labels*target_dim)
      lin = tf.matmul(tf.reshape(sources_tensor, [-1, self._source_dim]),
                      tf.reshape(bilinear_params, [self._source_dim, -1]))

      # (batch_size x num_steps*num_labels x target_dim) *
      #   (batch_size x num_steps x target_dim)^T
      #     = (batch_size x num_steps*num_labels x num_steps)
      bilin = tf.matmul(
          tf.reshape(lin, [-1, num_steps*self._num_labels, self._target_dim]),
          targets_tensor, transpose_b=True)

    # (batch_size x num_steps*num_labels x num_steps) ->
    #   (batch_size x num_steps x num_steps*num_labels)
    scores = tf.transpose(bilin, [0, 2, 1])

    return [tf.reshape(scores, [-1, num_steps*self._num_labels],
                       name='reshape_activations')]
Example #6
    def create(self,
               fixed_embeddings,
               linked_embeddings,
               context_tensor_arrays,
               attention_tensor,
               during_training,
               stride=None):
        """Requires |stride|; otherwise see base class."""
        check.NotNone(
            stride, 'BulkBiLSTMNetwork requires "stride" and must be called '
            'in the bulk feature extractor component.')

        # Flatten the lengths into a vector.
        lengths = dragnn.lookup_named_tensor('lengths', linked_embeddings)
        lengths_s = tf.squeeze(lengths.tensor, [1])

        # Collect all other inputs into a batched tensor.
        linked_embeddings = [
            named_tensor for named_tensor in linked_embeddings
            if named_tensor.name != 'lengths'
        ]
        inputs_sxnxd = dragnn.get_input_tensor_with_stride(
            fixed_embeddings, linked_embeddings, stride)

        # Since get_input_tensor_with_stride() concatenates the input embeddings, it
        # obscures the static activation dimension, which the RNN library requires.
        # Restore it using set_shape().  Note that set_shape() merges into the known
        # shape, so only specify the activation dimension.
        inputs_sxnxd.set_shape(
            [tf.Dimension(None),
             tf.Dimension(None), self._input_dim])

        initial_states_forward, initial_states_backward = (
            self._create_initial_states(stride))

        if during_training:
            cells_forward = self._train_cells_forward
            cells_backward = self._train_cells_backward
        else:
            cells_forward = self._inference_cells_forward
            cells_backward = self._inference_cells_backward

        def _bilstm_closure(scope):
            """Applies the bi-LSTM to the current inputs."""
            outputs_sxnxd, _, _ = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
                cells_forward,
                cells_backward,
                inputs_sxnxd,
                initial_states_fw=initial_states_forward,
                initial_states_bw=initial_states_backward,
                sequence_length=lengths_s,
                parallel_iterations=self._attrs['parallel_iterations'],
                scope=scope)
            return outputs_sxnxd

        # Layer outputs are not batched; flatten out the batch dimension.
        outputs_sxnxd = self._apply_with_captured_variables(_bilstm_closure)
        outputs_snxd = tf.reshape(outputs_sxnxd, [-1, self._output_dim])
        return self._append_base_layers([outputs_snxd])
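The sxnxd/snxd suffixes encode the layout convention used throughout these networks. A toy NumPy sketch (sizes made up) showing that flattening [stride, num_steps, dim] into [stride*num_steps, dim] keeps batch-major row order, so row b*num_steps + t holds step t of batch element b:

import numpy as np

stride, num_steps, output_dim = 2, 3, 4
outputs_sxnxd = np.random.randn(stride, num_steps, output_dim)

# Flatten out the batch dimension, as the snippet does before returning.
outputs_snxd = outputs_sxnxd.reshape(-1, output_dim)
b, t = 1, 2
assert np.array_equal(outputs_snxd[b * num_steps + t], outputs_sxnxd[b, t])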
Example #7
    def create(self,
               fixed_embeddings,
               linked_embeddings,
               context_tensor_arrays,
               attention_tensor,
               during_training,
               stride=None):
        """Requires |stride|; otherwise see base class."""
        check.NotNone(
            stride,
            'BiaffineLabelNetwork requires "stride" and must be called '
            'in the bulk feature extractor component.')

        # TODO(googleuser): Add dropout during training.
        del during_training

        # Retrieve (possibly averaged) weights.
        weights_pair = self._component.get_variable('weights_pair')
        weights_source = self._component.get_variable('weights_source')
        weights_target = self._component.get_variable('weights_target')
        biases = self._component.get_variable('biases')

        # Extract and shape the source and target token activations.  Use |stride|
        # to collapse batch and beam into a single dimension.
        sources = network_units.lookup_named_tensor('sources',
                                                    linked_embeddings)
        targets = network_units.lookup_named_tensor('targets',
                                                    linked_embeddings)
        sources_bxnxs = tf.reshape(sources.tensor,
                                   [stride, -1, self._source_dim])
        targets_bxnxt = tf.reshape(targets.tensor,
                                   [stride, -1, self._target_dim])

        # Compute the pair, source, and target potentials.
        pairs_bxnxl = digraph_ops.LabelPotentialsFromTokenPairs(
            sources_bxnxs, targets_bxnxt, weights_pair)
        sources_bxnxl = digraph_ops.LabelPotentialsFromTokens(
            sources_bxnxs, weights_source)
        targets_bxnxl = digraph_ops.LabelPotentialsFromTokens(
            targets_bxnxt, weights_target)

        # Combine them with the biases.
        labels_bxnxl = pairs_bxnxl + sources_bxnxl + targets_bxnxl + biases

        # Flatten out the batch dimension.
        return [tf.reshape(labels_bxnxl, [-1, self._num_labels])]
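The three potential terms add up to a standard biaffine scorer: one bilinear pair term plus two linear terms and a bias. A NumPy sketch with assumed weight layouts (the actual layouts inside digraph_ops may differ):

import numpy as np

# Toy shapes: batch, tokens, source_dim, target_dim, num_labels.
b, n, s, t, l = 2, 3, 4, 5, 6
sources_bxnxs = np.random.randn(b, n, s)
targets_bxnxt = np.random.randn(b, n, t)
weights_pair = np.random.randn(s, l, t)  # assumed [source, label, target]
weights_source = np.random.randn(l, s)   # assumed [label, source]
weights_target = np.random.randn(l, t)   # assumed [label, target]
biases = np.random.randn(l)

# Bilinear pair term plus linear source/target terms plus bias.
pairs_bxnxl = np.einsum('bns,slt,bnt->bnl',
                        sources_bxnxs, weights_pair, targets_bxnxt)
sources_bxnxl = np.einsum('bns,ls->bnl', sources_bxnxs, weights_source)
targets_bxnxl = np.einsum('bnt,lt->bnl', targets_bxnxt, weights_target)
labels_bxnxl = pairs_bxnxl + sources_bxnxl + targets_bxnxl + biases
assert labels_bxnxl.shape == (b, n, l)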
Example #8
  def create(self,
             fixed_embeddings,
             linked_embeddings,
             context_tensor_arrays,
             attention_tensor,
             during_training,
             stride=None):
    """Requires |stride|; otherwise see base class."""
    check.NotNone(stride,
                  'BulkBiLSTMNetwork requires "stride" and must be called '
                  'in the bulk feature extractor component.')

    # Flatten the lengths into a vector.
    lengths = dragnn.lookup_named_tensor('lengths', linked_embeddings)
    lengths_s = tf.squeeze(lengths.tensor, [1])

    # Collect all other inputs into a batched tensor.
    linked_embeddings = [
        named_tensor for named_tensor in linked_embeddings
        if named_tensor.name != 'lengths'
    ]
    inputs_sxnxd = dragnn.get_input_tensor_with_stride(
        fixed_embeddings, linked_embeddings, stride)

    # Since get_input_tensor_with_stride() concatenates the input embeddings, it
    # obscures the static activation dimension, which the RNN library requires.
    # Restore it using set_shape().  Note that set_shape() merges into the known
    # shape, so only specify the activation dimension.
    inputs_sxnxd.set_shape(
        [tf.Dimension(None), tf.Dimension(None), self._input_dim])

    initial_states_forward, initial_states_backward = (
        self._create_initial_states(stride))

    if during_training:
      cells_forward = self._train_cells_forward
      cells_backward = self._train_cells_backward
    else:
      cells_forward = self._inference_cells_forward
      cells_backward = self._inference_cells_backward

    def _bilstm_closure(scope):
      """Applies the bi-LSTM to the current inputs."""
      outputs_sxnxd, _, _ = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
          cells_forward,
          cells_backward,
          inputs_sxnxd,
          initial_states_fw=initial_states_forward,
          initial_states_bw=initial_states_backward,
          sequence_length=lengths_s,
          parallel_iterations=self._attrs['parallel_iterations'],
          scope=scope)
      return outputs_sxnxd

    # Layer outputs are not batched; flatten out the batch dimension.
    outputs_sxnxd = self._apply_with_captured_variables(_bilstm_closure)
    outputs_snxd = tf.reshape(outputs_sxnxd, [-1, self._output_dim])
    return self._append_base_layers([outputs_snxd])
Example #9
  def create(self,
             fixed_embeddings,
             linked_embeddings,
             context_tensor_arrays,
             attention_tensor,
             during_training,
             stride=None):
    """Requires |stride|; otherwise see base class."""
    check.NotNone(stride,
                  'BiaffineLabelNetwork requires "stride" and must be called '
                  'in the bulk feature extractor component.')

    # TODO(googleuser): Add dropout during training.
    del during_training

    # Retrieve (possibly averaged) weights.
    weights_pair = self._component.get_variable('weights_pair')
    weights_source = self._component.get_variable('weights_source')
    weights_target = self._component.get_variable('weights_target')
    biases = self._component.get_variable('biases')

    # Extract and shape the source and target token activations.  Use |stride|
    # to collapse batch and beam into a single dimension.
    sources = network_units.lookup_named_tensor('sources', linked_embeddings)
    targets = network_units.lookup_named_tensor('targets', linked_embeddings)
    sources_bxnxs = tf.reshape(sources.tensor, [stride, -1, self._source_dim])
    targets_bxnxt = tf.reshape(targets.tensor, [stride, -1, self._target_dim])

    # Compute the pair, source, and target potentials.
    pairs_bxnxl = digraph_ops.LabelPotentialsFromTokenPairs(sources_bxnxs,
                                                            targets_bxnxt,
                                                            weights_pair)
    sources_bxnxl = digraph_ops.LabelPotentialsFromTokens(sources_bxnxs,
                                                          weights_source)
    targets_bxnxl = digraph_ops.LabelPotentialsFromTokens(targets_bxnxt,
                                                          weights_target)

    # Combine them with the biases.
    labels_bxnxl = pairs_bxnxl + sources_bxnxl + targets_bxnxl + biases

    # Flatten out the batch dimension.
    return [tf.reshape(labels_bxnxl, [-1, self._num_labels])]
Example #10
  def create(self,
             fixed_embeddings,
             linked_embeddings,
             context_tensor_arrays,
             attention_tensor,
             during_training,
             stride=None):
    """Forwards the lengths and scores."""
    check.NotNone(stride, 'MstSolverNetwork requires stride')

    lengths = network_units.lookup_named_tensor('lengths', linked_embeddings)
    lengths_b = tf.to_int32(tf.squeeze(lengths.tensor, [1]))

    scores = network_units.lookup_named_tensor('scores', linked_embeddings)
    scores_bnxn = scores.tensor
    max_length = tf.shape(scores_bnxn)[1]
    scores_bxnxn = tf.reshape(scores_bnxn, [stride, max_length, max_length])

    _, argmax_sources_bxn = mst_ops.maximum_spanning_tree(
        forest=self._attrs['forest'], num_nodes=lengths_b, scores=scores_bxnxn)
    argmax_sources_bn = tf.reshape(argmax_sources_bxn, [-1])
    arcs_bnxn = tf.one_hot(argmax_sources_bn, max_length, dtype=tf.float32)

    return [lengths_b, scores_bxnxn, scores_bnxn, arcs_bnxn]
Example #11
  def create(self,
             fixed_embeddings,
             linked_embeddings,
             context_tensor_arrays,
             attention_tensor,
             during_training,
             stride=None):
    """Requires |stride|; otherwise see base class."""
    del context_tensor_arrays, attention_tensor
    if stride is None:
      raise RuntimeError("TransformerEncoderNetwork needs 'stride' and must be "
                         "called in the bulk feature extractor component.")

    lengths = network_units.lookup_named_tensor('lengths', linked_embeddings)
    lengths_s = tf.to_int32(tf.squeeze(lengths.tensor, [1]))
    num_steps = tf.reduce_max(lengths_s)

    in_tensor = network_units.lookup_named_tensor('features', linked_embeddings)
    input_tensor = tf.reshape(in_tensor.tensor, [stride, num_steps, -1])

    if self._timing_signal:
      input_tensor = add_timing_signal_1d(input_tensor)

    # Adds a dimension for conv2d
    input_tensor = tf.expand_dims(input_tensor, 1)

    # For masking padding in attention
    mask = compute_padding_mask(lengths_s)

    conv = tf.nn.conv2d(input_tensor,
                        self._component.get_variable('init_proj'),
                        [1, 1, 1, 1], padding='SAME')
    conv = tf.nn.bias_add(conv, self._component.get_variable('init_bias'))

    for i in range(self._num_layers):
      with tf.variable_scope('transform_%d' % i, reuse=True):
        attn_weights = self._component.get_variable('attn_weights')
        attn_combined = tf.nn.conv2d(conv,
                                     attn_weights,
                                     [1, 1, 1, 1],
                                     padding='SAME')
        attn_combined = tf.squeeze(attn_combined, 1)

        # Splits combined projection into queries, keys, and values
        queries, keys, values = tf.split(attn_combined,
                                         [self._combined_filters]*3,
                                         axis=2)

        # Splits each of queries, keys, values into attention heads
        queries = split_heads(queries, self._num_heads)
        keys = split_heads(keys, self._num_heads)
        values = split_heads(values, self._num_heads)
        if self._scale_attn:
          queries *= self._filter_size**-0.5

        # Performs dot product attention and concatenates the resulting heads
        attended = dot_product_attention(queries, keys, values,
                                         self._attention_dropout, mask)
        attended = combine_heads(attended)

        # Projects combined heads
        attended = tf.expand_dims(attended, 1)
        proj = tf.nn.conv2d(attended,
                            self._component.get_variable('proj_weights'),
                            [1, 1, 1, 1],
                            padding='SAME')

        # Residual connection between input and attended input
        attn_layer_norm_params = None
        if self._layer_norm_res:
          attn_layer_norm_params = self._layer_norms['attn_layer_norm_%d' % i]
        proj_res = residual(conv, proj, self._residual_dropout,
                            attn_layer_norm_params)

        # Feed forward
        with tf.variable_scope('mlp'):
          ff = mlp(self._component, proj_res, self._mlp_dropout,
                   self._mlp_depth)

        # Residual connection between attended input and feed forward layers
        ff_layer_norm_params = None
        if self._layer_norm_res:
          ff_layer_norm_params = self._layer_norms['ff_layer_norm_%d' % i]
        conv = residual(proj_res, ff, self._residual_dropout,
                        ff_layer_norm_params)

    return [tf.reshape(conv, [-1, self._combined_filters],
                       name='reshape_activations')]
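The head splitting and recombining is the usual multi-head attention bookkeeping. A NumPy sketch, assuming the common tensor2tensor-style contract for split_heads/combine_heads (the real helpers used above may differ in detail):

import numpy as np

def split_heads(x, num_heads):
  # [batch, steps, channels] -> [batch, heads, steps, channels // heads].
  batch, steps, channels = x.shape
  return x.reshape(batch, steps, num_heads, channels // num_heads).transpose(
      0, 2, 1, 3)

def combine_heads(x):
  # Inverse of split_heads: concatenate the heads back together.
  batch, num_heads, steps, depth = x.shape
  return x.transpose(0, 2, 1, 3).reshape(batch, steps, num_heads * depth)

x = np.random.randn(2, 5, 8)
assert np.allclose(combine_heads(split_heads(x, num_heads=4)), x)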