def __init__(self,
             num_units,
             num_heads,
             ffn_inner_dim,
             num_sources=1,
             dropout=0.1,
             attention_dropout=0.1,
             ffn_dropout=0.1,
             ffn_activation=tf.nn.relu,
             **kwargs):
  """Initializes the layer.

  Args:
    num_units: The number of hidden units.
    num_heads: The number of heads in the multi-head attention.
    ffn_inner_dim: The number of units of the inner linear transformation in
      the feed forward layer.
    num_sources: The number of source contexts.
    dropout: The probability to drop units from the outputs.
    attention_dropout: The probability to drop units from the attention.
    ffn_dropout: The probability to drop units from the activation output in
      the feed forward layer.
    ffn_activation: The activation function to apply between the two linear
      transformations of the feed forward layer.
    **kwargs: Additional layer arguments.
  """
  super(_SelfAttentionDecoderLayer, self).__init__(**kwargs)
  self.self_attention = transformer.MultiHeadAttention(
      num_heads,
      num_units,
      dropout=attention_dropout,
      name="masked_multi_head_attention")
  self.self_attention = transformer.TransformerLayerWrapper(
      self.self_attention, dropout, name="sub_layer_0")
  self.attention = []
  for i in range(num_sources):
    attention = transformer.MultiHeadAttention(
        num_heads,
        num_units,
        dropout=attention_dropout,
        return_attention=num_sources == 1,
        name="multi_head_attention")
    attention = transformer.TransformerLayerWrapper(
        attention, dropout, name="sub_layer_%d" % (i + 1))
    self.attention.append(attention)
  self.ffn = transformer.FeedForwardNetwork(
      ffn_inner_dim,
      num_units,
      dropout=ffn_dropout,
      activation=ffn_activation,
      name="feed_forward")
  self.ffn = transformer.TransformerLayerWrapper(
      self.ffn, dropout, name="sub_layer_%d" % (num_sources + 1))
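# Illustrative construction sketch (not from the original source): assuming
# this __init__ belongs to the private _SelfAttentionDecoderLayer used by an
# OpenNMT-tf style self-attention decoder, a stack of such layers could be
# built as below. The dimension values are hypothetical examples, not library
# defaults; `name` is forwarded to the Keras base layer through **kwargs.
decoder_layers = [
    _SelfAttentionDecoderLayer(
        num_units=512,
        num_heads=8,
        ffn_inner_dim=2048,
        num_sources=1,
        dropout=0.1,
        attention_dropout=0.1,
        name="layer_%d" % i)
    for i in range(6)
]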
def testMultiHeadAttention(self):
  attention = transformer.MultiHeadAttention(4, 20)
  queries = tf.random.uniform([4, 5, 10])
  memory = tf.random.uniform([4, 3, 10])
  mask = tf.sequence_mask([1, 3, 2, 2])
  context, _ = attention(queries, memory=memory, mask=mask)
  self.assertListEqual(context.shape.as_list(), [4, 5, 20])
def testMultiHeadSelfAttentionRelativePositionsWithCache(self):
  attention = transformer.MultiHeadAttention(4, 20, maximum_relative_position=6)
  x = tf.random.uniform([4, 1, 10])
  cache = (tf.zeros([4, 4, 0, 5]), tf.zeros([4, 4, 0, 5]))
  _, cache = attention(x, cache=cache)
def testMultiHeadSelfAttentionRelativePositions(self):
  attention = transformer.MultiHeadAttention(4, 20, maximum_relative_position=6)
  x = tf.random.uniform([2, 9, 10])
  mask = tf.sequence_mask([9, 7])
  y = attention(x, mask=mask)
def __init__(self,
             num_layers,
             num_units,
             num_heads,
             dropout=0.3,
             cell_class=None,
             **kwargs):
  """Initializes the decoder parameters.

  Args:
    num_layers: The number of layers.
    num_units: The number of units in each layer.
    num_heads: The number of attention heads.
    dropout: The probability to drop units from the decoder input and in each
      layer output.
    cell_class: The inner cell class or a callable taking :obj:`num_units` as
      argument and returning a cell. Defaults to a layer normalized LSTM cell.
    **kwargs: Additional layer arguments.
  """
  super(RNMTPlusDecoder, self).__init__(**kwargs)
  if cell_class is None:
    cell_class = tfa.rnn.LayerNormLSTMCell
  self.num_heads = num_heads
  self.num_units = num_units
  self.dropout = dropout
  self.cells = [cell_class(num_units) for _ in range(num_layers)]
  self.multi_head_attention = transformer.MultiHeadAttention(
      num_heads, num_units, dropout=dropout, return_attention=True)
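# Illustrative construction sketch (assumption: this __init__ belongs to the
# RNMTPlusDecoder class named in the super() call). By default the inner cell
# is tfa.rnn.LayerNormLSTMCell from TensorFlow Addons; cell_class also accepts
# any callable taking num_units, as shown here with a plain Keras LSTMCell.
# The dimension values are hypothetical examples.
decoder = RNMTPlusDecoder(
    num_layers=4,
    num_units=512,
    num_heads=4,
    dropout=0.3,
    cell_class=lambda num_units: tf.keras.layers.LSTMCell(num_units))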
def __init__(self,
             num_units,
             num_heads,
             ffn_inner_dim,
             dropout=0.1,
             attention_dropout=0.1,
             relu_dropout=0.1,
             **kwargs):
  """Initializes the layer.

  Args:
    num_units: The number of hidden units.
    num_heads: The number of heads in the multi-head attention.
    ffn_inner_dim: The number of units of the inner linear transformation in
      the feed forward layer.
    dropout: The probability to drop units from the outputs.
    attention_dropout: The probability to drop units from the attention.
    relu_dropout: The probability to drop units from the ReLU activation in
      the feed forward layer.
    **kwargs: Additional layer arguments.
  """
  super(_SelfAttentionEncoderLayer, self).__init__(**kwargs)
  self.self_attention = transformer.MultiHeadAttention(
      num_heads, num_units, dropout=attention_dropout)
  self.self_attention = common.LayerWrapper(
      self.self_attention,
      normalize_input=True,
      output_dropout=dropout,
      residual_connection=True)
  self.ffn = transformer.FeedForwardNetwork(
      ffn_inner_dim, num_units, dropout=relu_dropout)
  self.ffn = common.LayerWrapper(
      self.ffn,
      normalize_input=True,
      output_dropout=dropout,
      residual_connection=True)
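# Illustrative construction sketch (not from the original source): assuming
# this __init__ belongs to the private _SelfAttentionEncoderLayer of the
# corresponding self-attention encoder, a single encoder layer could be
# instantiated as below. The dimension values are hypothetical examples.
encoder_layer = _SelfAttentionEncoderLayer(
    num_units=512,
    num_heads=8,
    ffn_inner_dim=2048,
    dropout=0.1,
    attention_dropout=0.1,
    relu_dropout=0.1)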
def testMultiHeadSelfAttentionRelativePositionsEmpty(self):
  attention = transformer.MultiHeadAttention(4, 20, maximum_relative_position=6)
  x = tf.random.uniform([1, 0, 10])
  mask = tf.sequence_mask([0])
  y, _ = attention(x, mask=mask)
  self.assertListEqual(y.shape.as_list(), [1, 0, 20])
def testMultiHeadSelfAttentionSpan(self):
  attention = transformer.MultiHeadAttention(
      4, 20, attention_span=1, num_attended_heads=3)
  queries = tf.random.uniform([4, 5, 10])
  mask = tf.sequence_mask([4, 3, 5, 2])
  context, _ = attention(queries, mask=mask)
  self.assertListEqual(context.shape.as_list(), [4, 5, 20])
def testMultiHeadSelfAttentionWithCache(self):
  cache = (tf.zeros([4, 4, 0, 5]), tf.zeros([4, 4, 0, 5]))
  attention = transformer.MultiHeadAttention(4, 20)
  x = tf.random.uniform([4, 1, 10])
  _, cache = attention(x, cache=cache)
  self.assertEqual(cache[0].shape[2], 1)
  self.assertEqual(cache[1].shape[2], 1)
  _, cache = attention(x, cache=cache)
  self.assertEqual(cache[0].shape[2], 2)
  self.assertEqual(cache[1].shape[2], 2)
def testMultiHeadAttentionWithCache(self):
  cache = (tf.zeros([4, 4, 0, 5]), tf.zeros([4, 4, 0, 5]))
  attention = transformer.MultiHeadAttention(4, 20)
  memory = tf.random.uniform([4, 3, 10])
  mask = tf.sequence_mask([1, 3, 2, 2])
  x = tf.random.uniform([4, 1, 10])
  y1, cache = attention(x, memory=memory, mask=mask, cache=cache)
  self.assertEqual(cache[0].shape[2], 3)
  self.assertEqual(cache[1].shape[2], 3)
  y2, cache = attention(x, memory=memory, mask=mask, cache=cache)
  self.assertAllEqual(y1, y2)
def testMultiHeadAttentionMask(self):
  attention = transformer.MultiHeadAttention(4, 20, return_attention=True)
  queries = tf.random.uniform([4, 5, 10])
  memory = tf.random.uniform([4, 3, 10])
  mask = tf.sequence_mask([1, 3, 2, 2])
  _, _, attention = attention(queries, memory=memory, mask=mask)
  attention = tf.reshape(attention, [4, -1, 3])
  mask = tf.broadcast_to(tf.expand_dims(mask, 1), attention.shape)
  padding = tf.boolean_mask(attention, tf.logical_not(mask))
  self.assertAllEqual(tf.reduce_sum(padding), 0)
def testMultiHeadSelfAttentionRelativeGradients(self):
  attention = transformer.MultiHeadAttention(4, 20, maximum_relative_position=6)

  @tf.function
  def _compute_gradients_in_function(x):
    with tf.GradientTape() as tape:
      y, _ = attention(x)
      loss = tf.math.reduce_sum(y)
    gradients = tape.gradient(loss, attention.weights)
    for gradient in gradients:
      self.assertTrue(gradient.shape.is_fully_defined())

  _compute_gradients_in_function(tf.random.uniform([4, 1, 10]))
def testMultiHeadSelfAttention(self):
  attention = transformer.MultiHeadAttention(4, 20)
  queries = tf.random.uniform([4, 5, 10])
  mask = tf.expand_dims(tf.sequence_mask([4, 3, 5, 2]), 1)
  context, _ = attention(queries, mask=mask)
  self.assertListEqual(context.shape.as_list(), [4, 5, 20])