def test_residual_block_with_relative_attention(self,
                                                use_pre_activation_order):
  np.random.seed(1234)

  batch_size = 2
  seq_len = 4
  hidden_size = 10

  inputs = tf.constant(
      np.random.normal(size=[batch_size, seq_len, hidden_size]), tf.float32)
  att_mask = tf.stack([
      # Force each element in the first example to only attend to itself.
      tf.eye(seq_len, dtype=tf.int32),
      # The second example can attend everywhere.
      tf.ones([seq_len, seq_len], dtype=tf.int32)
  ])

  inner_layer = etc_layers.RelativeAttention(
      hidden_size=hidden_size,
      num_heads=2,
      relative_vocab_size=2,
      initializer=tf.keras.initializers.Identity())

  residual_block = etc_layers.ResidualBlock(
      inner_layer=inner_layer,
      normalization_layer=tf.keras.layers.Lambda(lambda x: x),
      dropout_probability=0.0,
      use_pre_activation_order=use_pre_activation_order)

  relative_att_ids1 = tf.zeros([batch_size, seq_len, seq_len], dtype=tf.int32)

  result1 = residual_block(
      inputs, att_mask=att_mask, relative_att_ids=relative_att_ids1)
  self.evaluate(tf.compat.v1.global_variables_initializer())

  # For the first example, the self-only mask plus identity-initialized
  # projections make the attention output equal the input, so the residual
  # connection doubles it.
  self.assertAllClose(2 * inputs[0], result1[0])

  relative_att_ids2 = tf.tile([[[0, 1, 0, 1]]], [batch_size, seq_len, 1])

  result2 = residual_block(
      inputs, att_mask=att_mask, relative_att_ids=relative_att_ids2)

  # Changing the relative ids only affects the example that attends beyond
  # itself.
  self.assertAllClose(result1[0], result2[0])
  self.assertNotAllClose(result1[1], result2[1])

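# A minimal shape-only sketch (not part of the original suite): it reuses the
# `ResidualBlock` keyword arguments from the test above, but the assumption
# that `inner_layer` can be any shape-preserving Keras layer (here a plain
# Dense layer) is ours and is not verified by the original test.
def test_residual_block_with_dense_inner_layer_sketch(self):
  np.random.seed(1234)

  inputs = tf.constant(
      np.random.normal(size=[2, 4, 10]), tf.float32)

  residual_block = etc_layers.ResidualBlock(
      inner_layer=tf.keras.layers.Dense(10),  # Assumed valid inner layer.
      normalization_layer=tf.keras.layers.Lambda(lambda x: x),
      dropout_probability=0.0,
      use_pre_activation_order=False)

  result = residual_block(inputs)
  self.evaluate(tf.compat.v1.global_variables_initializer())

  # The residual connection requires the inner layer to preserve the input
  # shape, so the output shape must match the input shape.
  self.assertAllEqual([2, 4, 10], result.shape)
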
def test_relative_attention_self_attention(self, use_one_hot_lookup):
  tf.compat.v1.random.set_random_seed(1234)
  np.random.seed(1234)

  batch_size = 3
  seq_len = 16
  num_heads = 5
  input_hidden_size = 11
  output_hidden_size = 12
  total_key_size = 10
  total_value_size = 15
  relative_vocab_size = 21

  inputs = tf.constant(
      np.random.normal(size=[batch_size, seq_len, input_hidden_size]),
      tf.float32)
  att_mask = tf.constant(
      np.random.binomial(n=1, p=0.9, size=[batch_size, seq_len, seq_len]))
  relative_att_ids = tf.constant(
      np.random.randint(
          relative_vocab_size, size=[batch_size, seq_len, seq_len]))

  layer = etc_layers.RelativeAttention(
      hidden_size=output_hidden_size,
      num_heads=num_heads,
      total_key_size=total_key_size,
      total_value_size=total_value_size,
      relative_vocab_size=relative_vocab_size,
      use_one_hot_lookup=use_one_hot_lookup)

  result1 = layer(inputs, att_mask=att_mask, relative_att_ids=relative_att_ids)
  self.assertAllEqual([batch_size, seq_len, output_hidden_size], result1.shape)

  # Passing the same tensor as both `from_seq` and `to_seq` must reproduce the
  # single-argument self-attention result.
  result2 = layer(
      from_seq=inputs,
      to_seq=inputs,
      att_mask=att_mask,
      relative_att_ids=relative_att_ids)

  self.evaluate(tf.compat.v1.global_variables_initializer())

  self.assertAllEqual(result1, result2)

def test_relative_attention(self, use_one_hot_lookup):
  tf.compat.v1.random.set_random_seed(1234)
  np.random.seed(1234)

  batch_size = 3
  from_seq_len = 16
  to_seq_len = 17
  num_heads = 5
  from_hidden_size = 11
  to_hidden_size = 12
  output_hidden_size = 13
  total_key_size = 10
  total_value_size = 15
  relative_vocab_size = 21

  from_seq = tf.random.normal([batch_size, from_seq_len, from_hidden_size])
  to_seq = tf.random.normal([batch_size, to_seq_len, to_hidden_size])
  att_mask = tf.constant(
      np.random.binomial(
          n=1, p=0.9, size=[batch_size, from_seq_len, to_seq_len]))
  relative_att_ids = tf.random.uniform(
      [batch_size, from_seq_len, to_seq_len],
      maxval=relative_vocab_size,
      dtype=tf.int32)

  layer = etc_layers.RelativeAttention(
      hidden_size=output_hidden_size,
      num_heads=num_heads,
      total_key_size=total_key_size,
      total_value_size=total_value_size,
      relative_vocab_size=relative_vocab_size,
      use_one_hot_lookup=use_one_hot_lookup)

  result = layer(
      from_seq=from_seq,
      to_seq=to_seq,
      att_mask=att_mask,
      relative_att_ids=relative_att_ids)

  self.assertAllEqual([batch_size, from_seq_len, output_hidden_size],
                      result.shape)

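# A small sketch (not part of the original tests): it reuses the constructor
# arguments from `test_relative_attention` above and only checks the output
# shape when `att_mask` and `relative_att_ids` are omitted, which the unmasked
# call inside the fused-attention equivalence test below suggests is allowed.
def test_relative_attention_without_optional_inputs_sketch(self):
  batch_size = 3
  from_seq_len = 16
  to_seq_len = 17
  output_hidden_size = 13

  from_seq = tf.random.normal([batch_size, from_seq_len, 11])
  to_seq = tf.random.normal([batch_size, to_seq_len, 12])

  layer = etc_layers.RelativeAttention(
      hidden_size=output_hidden_size,
      num_heads=5,
      total_key_size=10,
      total_value_size=15,
      relative_vocab_size=21)

  # Neither `att_mask` nor `relative_att_ids` is passed here.
  result = layer(from_seq=from_seq, to_seq=to_seq)

  self.assertAllEqual([batch_size, from_seq_len, output_hidden_size],
                      result.shape)
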
def test_fused_global_local_attention_special_case_equivalence(self):
  # To test for correctness, we make sure the output is equivalent to
  # standard attention in the special case where `local_radius` covers the
  # entire long sequence length and projection weights are shared.
  # For simplicity, we don't use attention masks or relative attention ids
  # in this test.
  tf.compat.v1.random.set_random_seed(1234)
  np.random.seed(1234)

  batch_size = 3
  long_seq_len = 12
  global_seq_len = 6
  hidden_size = 10
  num_heads = 5
  local_radius = 15  # Must be >= `long_seq_len - 1` to remove sparsity.
  # relative_vocab_size = 9

  long_input = tf.constant(
      np.random.normal(size=[batch_size, long_seq_len, hidden_size]))
  global_input = tf.constant(
      np.random.normal(size=[batch_size, global_seq_len, hidden_size]))

  fused_att_layer = etc_layers.FusedGlobalLocalAttention(
      long_hidden_size=hidden_size,
      global_hidden_size=hidden_size,
      num_heads=num_heads,
      local_radius=local_radius,
      share_qkv_projections=True,
      share_att_output_projection=True)

  long_output, global_output = fused_att_layer(
      long_input, global_input, att_implementation='sparse')

  # [batch_size, long_seq_len + global_seq_len, hidden_size]
  fused_output = tf.concat([long_output, global_output], axis=1)

  # Create concatenated input for standard attention.
  # [batch_size, long_seq_len + global_seq_len, hidden_size]
  concat_input = tf.concat([long_input, global_input], axis=1)

  standard_att_layer = etc_layers.RelativeAttention(
      hidden_size=hidden_size,
      num_heads=num_heads,
      query_projection=fused_att_layer.long_query_projection,
      key_projection=fused_att_layer.l2l_key_projection,
      value_projection=fused_att_layer.l2l_value_projection,
      output_projection=fused_att_layer.long_output_projection)

  expected_output = standard_att_layer(concat_input)

  self.evaluate(tf.compat.v1.global_variables_initializer())

  self.assertAllClose(expected_output, fused_output)

  # Make sure 'full' att_implementation gives the same output.
  long_output_full_att, global_output_full_att = fused_att_layer(
      long_input, global_input, att_implementation='full')

  self.assertAllClose(long_output, long_output_full_att)
  self.assertAllClose(global_output, global_output_full_att)

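# A shape-only sketch (not part of the original tests): it mirrors the fused
# layer configuration above and assumes the long and global outputs have
# shapes [batch_size, long_seq_len, hidden_size] and
# [batch_size, global_seq_len, hidden_size], which is implied by the axis-1
# concatenation in the equivalence test.
def test_fused_global_local_attention_output_shapes_sketch(self):
  np.random.seed(1234)

  batch_size = 3
  long_seq_len = 12
  global_seq_len = 6
  hidden_size = 10

  long_input = tf.constant(
      np.random.normal(size=[batch_size, long_seq_len, hidden_size]))
  global_input = tf.constant(
      np.random.normal(size=[batch_size, global_seq_len, hidden_size]))

  fused_att_layer = etc_layers.FusedGlobalLocalAttention(
      long_hidden_size=hidden_size,
      global_hidden_size=hidden_size,
      num_heads=5,
      local_radius=15,
      share_qkv_projections=True,
      share_att_output_projection=True)

  long_output, global_output = fused_att_layer(
      long_input, global_input, att_implementation='sparse')

  self.assertAllEqual([batch_size, long_seq_len, hidden_size],
                      long_output.shape)
  self.assertAllEqual([batch_size, global_seq_len, hidden_size],
                      global_output.shape)
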
def test_relative_attention_shared_sublayers(self, use_one_hot_lookup):
  tf.compat.v1.random.set_random_seed(1234)
  np.random.seed(1234)

  batch_size = 3
  from_seq_len = 16
  to_seq_len = 17
  num_heads = 5
  from_hidden_size = 11
  to_hidden_size = 12
  output_hidden_size = 13
  total_key_size = 10
  total_value_size = 15
  relative_vocab_size = 9

  from_seq = tf.constant(
      np.random.random(size=[batch_size, from_seq_len, from_hidden_size]))
  to_seq = tf.constant(
      np.random.random(size=[batch_size, to_seq_len, to_hidden_size]))
  att_mask = tf.constant(
      np.random.binomial(
          n=1, p=0.9, size=[batch_size, from_seq_len, to_seq_len]))

  layer = etc_layers.RelativeAttention(
      hidden_size=output_hidden_size,
      num_heads=num_heads,
      total_key_size=total_key_size,
      total_value_size=total_value_size,
      relative_vocab_size=relative_vocab_size,
      use_one_hot_lookup=use_one_hot_lookup)

  # This layer reuses all of the sublayers from `layer` above, so it shares
  # the same variables and must produce identical outputs.
  sharing_layer = etc_layers.RelativeAttention(
      hidden_size=output_hidden_size,
      num_heads=num_heads,
      total_key_size=total_key_size,
      total_value_size=total_value_size,
      query_projection=layer.query_projection,
      key_projection=layer.key_projection,
      value_projection=layer.value_projection,
      qkv_relative_attention=layer.qkv_relative_attention,
      output_projection=layer.output_projection)

  # This layer creates its own independently initialized sublayers instead.
  different_layer = etc_layers.RelativeAttention(
      hidden_size=output_hidden_size,
      num_heads=num_heads,
      total_key_size=total_key_size,
      total_value_size=total_value_size,
      relative_vocab_size=relative_vocab_size,
      use_one_hot_lookup=use_one_hot_lookup)

  result1 = layer(
      from_seq=from_seq, to_seq=to_seq, att_mask=att_mask,
      relative_att_ids=None)
  result2 = sharing_layer(
      from_seq=from_seq, to_seq=to_seq, att_mask=att_mask,
      relative_att_ids=None)
  result3 = different_layer(
      from_seq=from_seq, to_seq=to_seq, att_mask=att_mask,
      relative_att_ids=None)

  self.evaluate(tf.compat.v1.global_variables_initializer())

  self.assertAllEqual(result1, result2)
  self.assertNotAllClose(result1, result3)