Example #1
    def __init__(self,
                 num_units,
                 num_heads,
                 ffn_inner_dim,
                 num_sources=1,
                 dropout=0.1,
                 attention_dropout=0.1,
                 ffn_dropout=0.1,
                 ffn_activation=tf.nn.relu,
                 **kwargs):
        """Initializes the layer.

        Args:
          num_units: The number of hidden units.
          num_heads: The number of heads in the multi-head attention.
          ffn_inner_dim: The number of units of the inner linear transformation
            in the feed forward layer.
          num_sources: The number of source contexts.
          dropout: The probability to drop units from the outputs.
          attention_dropout: The probability to drop units from the attention.
          ffn_dropout: The probability to drop units from the activation output in
            the feed forward layer.
          ffn_activation: The activation function to apply between the two linear
            transformations of the feed forward layer.
          **kwargs: Additional layer arguments.
        """
        super(_SelfAttentionDecoderLayer, self).__init__(**kwargs)
        self.self_attention = transformer.MultiHeadAttention(
            num_heads,
            num_units,
            dropout=attention_dropout,
            name="masked_multi_head_attention")
        self.self_attention = transformer.TransformerLayerWrapper(
            self.self_attention, dropout, name="sub_layer_0")
        self.attention = []
        for i in range(num_sources):
            attention = transformer.MultiHeadAttention(
                num_heads,
                num_units,
                dropout=attention_dropout,
                return_attention=num_sources == 1,
                name="multi_head_attention")
            attention = transformer.TransformerLayerWrapper(
                attention, dropout, name="sub_layer_%d" % (i + 1))
            self.attention.append(attention)
        self.ffn = transformer.FeedForwardNetwork(ffn_inner_dim,
                                                  num_units,
                                                  dropout=ffn_dropout,
                                                  activation=ffn_activation,
                                                  name="feed_forward")
        self.ffn = transformer.TransformerLayerWrapper(self.ffn,
                                                       dropout,
                                                       name="sub_layer_%d" %
                                                       (num_sources + 1))
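The constructor above only assembles wrapped sub-layers from the transformer module. As a rough standalone illustration, not part of the quoted example, the sketch below builds and runs one feed-forward sub-layer the same way; the opennmt.layers.transformer import path and the layer sizes are assumptions.

import tensorflow as tf
from opennmt.layers import transformer  # assumed import path for the module used above

# One wrapped sub-layer like the feed-forward block created in __init__,
# with placeholder sizes (num_units=32, ffn_inner_dim=128).
ffn = transformer.TransformerLayerWrapper(
    transformer.FeedForwardNetwork(128, 32, dropout=0.1, activation=tf.nn.relu),
    0.1,
    name="feed_forward")

x = tf.random.uniform([2, 7, 32])  # (batch, time, num_units)
y = ffn(x)  # the wrapper adds normalization, dropout and the residual connection
print(y.shape)  # (2, 7, 32)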
Example #2
 def testMultiHeadAttention(self):
     attention = transformer.MultiHeadAttention(4, 20)
     queries = tf.random.uniform([4, 5, 10])
     memory = tf.random.uniform([4, 3, 10])
     mask = tf.sequence_mask([1, 3, 2, 2])
     context, _ = attention(queries, memory=memory, mask=mask)
     self.assertListEqual(context.shape.as_list(), [4, 5, 20])
Example #3
 def testMultiHeadSelfAttentionRelativePositionsWithCache(self):
     attention = transformer.MultiHeadAttention(4,
                                                20,
                                                maximum_relative_position=6)
     x = tf.random.uniform([4, 1, 10])
     cache = (tf.zeros([4, 4, 0, 5]), tf.zeros([4, 4, 0, 5]))
     _, cache = attention(x, cache=cache)
Example #4
 def testMultiHeadSelfAttentionRelativePositions(self):
     attention = transformer.MultiHeadAttention(4,
                                                20,
                                                maximum_relative_position=6)
     x = tf.random.uniform([2, 9, 10])
     mask = tf.sequence_mask([9, 7])
     y = attention(x, mask=mask)
Example #5
  def __init__(self,
               num_layers,
               num_units,
               num_heads,
               dropout=0.3,
               cell_class=None,
               **kwargs):
    """Initializes the decoder parameters.

    Args:
      num_layers: The number of layers.
      num_units: The number of units in each layer.
      num_heads: The number of attention heads.
      dropout: The probability to drop units from the decoder input and in each
        layer output.
      cell_class: The inner cell class or a callable taking :obj:`num_units` as
        argument and returning a cell. Defaults to a layer normalized LSTM cell.
      **kwargs: Additional layer arguments.
    """
    super(RNMTPlusDecoder, self).__init__(**kwargs)
    if cell_class is None:
      cell_class = tfa.rnn.LayerNormLSTMCell
    self.num_heads = num_heads
    self.num_units = num_units
    self.dropout = dropout
    self.cells = [cell_class(num_units) for _ in range(num_layers)]
    self.multi_head_attention = transformer.MultiHeadAttention(
        num_heads,
        num_units,
        dropout=dropout,
        return_attention=True)
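Because return_attention=True, the attention layer built here also returns its attention weights when called (compare testMultiHeadAttentionMask further down). A minimal standalone sketch with toy sizes, assuming the opennmt.layers.transformer import path:

import tensorflow as tf
from opennmt.layers import transformer  # assumed import path

attention = transformer.MultiHeadAttention(
    4, 20, dropout=0.3, return_attention=True)

queries = tf.random.uniform([2, 5, 20])
memory = tf.random.uniform([2, 7, 20])
mask = tf.sequence_mask([7, 4])  # valid memory lengths per batch entry

# With return_attention=True the call also yields the attention weights.
context, _, weights = attention(queries, memory=memory, mask=mask)
print(context.shape)  # (2, 5, 20)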
Example #6
    def __init__(self,
                 num_units,
                 num_heads,
                 ffn_inner_dim,
                 dropout=0.1,
                 attention_dropout=0.1,
                 relu_dropout=0.1,
                 **kwargs):
        """Initializes the layer.

        Args:
          num_units: The number of hidden units.
          num_heads: The number of heads in the multi-head attention.
          ffn_inner_dim: The number of units of the inner linear transformation
            in the feed forward layer.
          dropout: The probability to drop units from the outputs.
          attention_dropout: The probability to drop units from the attention.
          relu_dropout: The probability to drop units from the ReLU activation in
            the feed forward layer.
          **kwargs: Additional layer arguments.
        """
        super(_SelfAttentionEncoderLayer, self).__init__(**kwargs)
        self.self_attention = transformer.MultiHeadAttention(
            num_heads, num_units, dropout=attention_dropout)
        self.self_attention = common.LayerWrapper(self.self_attention,
                                                  normalize_input=True,
                                                  output_dropout=dropout,
                                                  residual_connection=True)
        self.ffn = transformer.FeedForwardNetwork(ffn_inner_dim,
                                                  num_units,
                                                  dropout=relu_dropout)
        self.ffn = common.LayerWrapper(self.ffn,
                                       normalize_input=True,
                                       output_dropout=dropout,
                                       residual_connection=True)
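In the encoder layer's forward pass these two wrapped sub-layers are applied in sequence: wrapped self-attention first, then the wrapped feed-forward network. The sketch below mirrors that flow with placeholder sizes; the opennmt.layers import paths and the assumption that the wrapper forwards the attention layer's extra cache output are not confirmed by the quoted snippet.

import tensorflow as tf
from opennmt.layers import common, transformer  # assumed import paths

num_units, num_heads, ffn_inner_dim = 32, 4, 64  # placeholder sizes

self_attention = common.LayerWrapper(
    transformer.MultiHeadAttention(num_heads, num_units, dropout=0.1),
    normalize_input=True, output_dropout=0.1, residual_connection=True)
ffn = common.LayerWrapper(
    transformer.FeedForwardNetwork(ffn_inner_dim, num_units, dropout=0.1),
    normalize_input=True, output_dropout=0.1, residual_connection=True)

x = tf.random.uniform([2, 9, num_units])
mask = tf.sequence_mask([9, 6])

y, _ = self_attention(x, mask=mask)  # assumes the wrapper returns (output, cache)
y = ffn(y)  # position-wise feed forward with residual
print(y.shape)  # (2, 9, 32)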
Example #7
 def testMultiHeadSelfAttentionRelativePositionsEmpty(self):
     attention = transformer.MultiHeadAttention(4,
                                                20,
                                                maximum_relative_position=6)
     x = tf.random.uniform([1, 0, 10])
     mask = tf.sequence_mask([0])
     y, _ = attention(x, mask=mask)
     self.assertListEqual(y.shape.as_list(), [1, 0, 20])
Example #8
 def testMultiHeadSelfAttentionSpan(self):
     attention = transformer.MultiHeadAttention(4,
                                                20,
                                                attention_span=1,
                                                num_attended_heads=3)
     queries = tf.random.uniform([4, 5, 10])
     mask = tf.sequence_mask([4, 3, 5, 2])
     context, _ = attention(queries, mask=mask)
     self.assertListEqual(context.shape.as_list(), [4, 5, 20])
Example #9
 def testMultiHeadSelfAttentionWithCache(self):
     cache = (tf.zeros([4, 4, 0, 5]), tf.zeros([4, 4, 0, 5]))
     attention = transformer.MultiHeadAttention(4, 20)
     x = tf.random.uniform([4, 1, 10])
     _, cache = attention(x, cache=cache)
     self.assertEqual(cache[0].shape[2], 1)
     self.assertEqual(cache[1].shape[2], 1)
     _, cache = attention(x, cache=cache)
     self.assertEqual(cache[0].shape[2], 2)
     self.assertEqual(cache[1].shape[2], 2)
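The test above is the incremental-decoding pattern in miniature: each call consumes one new timestep and the cached keys and values grow along the time dimension. The same idea written as a loop, reusing the shapes from the test:

import tensorflow as tf
from opennmt.layers import transformer  # assumed import path

attention = transformer.MultiHeadAttention(4, 20)
# Empty cache: (keys, values), each of shape (batch, num_heads, 0, depth_per_head).
cache = (tf.zeros([4, 4, 0, 5]), tf.zeros([4, 4, 0, 5]))

for step in range(3):
    x = tf.random.uniform([4, 1, 10])  # one new timestep per call
    y, cache = attention(x, cache=cache)
    print(step, cache[0].shape[2])  # cached length grows: 1, 2, 3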
Example #10
 def testMultiHeadAttentionWithCache(self):
     cache = (tf.zeros([4, 4, 0, 5]), tf.zeros([4, 4, 0, 5]))
     attention = transformer.MultiHeadAttention(4, 20)
     memory = tf.random.uniform([4, 3, 10])
     mask = tf.sequence_mask([1, 3, 2, 2])
     x = tf.random.uniform([4, 1, 10])
     y1, cache = attention(x, memory=memory, mask=mask, cache=cache)
     self.assertEqual(cache[0].shape[2], 3)
     self.assertEqual(cache[1].shape[2], 3)
     y2, cache = attention(x, memory=memory, mask=mask, cache=cache)
     self.assertAllEqual(y1, y2)
Example #11
 def testMultiHeadAttentionMask(self):
     attention = transformer.MultiHeadAttention(4,
                                                20,
                                                return_attention=True)
     queries = tf.random.uniform([4, 5, 10])
     memory = tf.random.uniform([4, 3, 10])
     mask = tf.sequence_mask([1, 3, 2, 2])
     _, _, attention = attention(queries, memory=memory, mask=mask)
     attention = tf.reshape(attention, [4, -1, 3])
     mask = tf.broadcast_to(tf.expand_dims(mask, 1), attention.shape)
     padding = tf.boolean_mask(attention, tf.logical_not(mask))
     self.assertAllEqual(tf.reduce_sum(padding), 0)
Example #12
    def testMultiHeadSelfAttentionRelativeGradients(self):
        attention = transformer.MultiHeadAttention(4,
                                                   20,
                                                   maximum_relative_position=6)

        @tf.function
        def _compute_gradients_in_function(x):
            with tf.GradientTape() as tape:
                y, _ = attention(x)
                loss = tf.math.reduce_sum(y)
            gradients = tape.gradient(loss, attention.weights)
            for gradient in gradients:
                self.assertTrue(gradient.shape.is_fully_defined())

        _compute_gradients_in_function(tf.random.uniform([4, 1, 10]))
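Building on the gradient check above, a hypothetical training-step sketch wires the same tape-based gradients into a standard Keras optimizer; the optimizer choice and learning rate are placeholders and not part of the quoted test:

import tensorflow as tf
from opennmt.layers import transformer  # assumed import path

attention = transformer.MultiHeadAttention(4, 20, maximum_relative_position=6)
optimizer = tf.keras.optimizers.Adam(1e-3)  # placeholder optimizer

@tf.function
def train_step(x):
    with tf.GradientTape() as tape:
        y, _ = attention(x)
        loss = tf.math.reduce_sum(y)
    gradients = tape.gradient(loss, attention.trainable_weights)
    optimizer.apply_gradients(zip(gradients, attention.trainable_weights))
    return loss

print(train_step(tf.random.uniform([4, 1, 10])))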
Example #13
 def testMultiHeadSelfAttention(self):
     attention = transformer.MultiHeadAttention(4, 20)
     queries = tf.random.uniform([4, 5, 10])
     mask = tf.expand_dims(tf.sequence_mask([4, 3, 5, 2]), 1)
     context, _ = attention(queries, mask=mask)
     self.assertListEqual(context.shape.as_list(), [4, 5, 20])