def test_fused_global_local_attention_shared_sublayers(
    self,
    share_kv_projections=False,
    share_qkv_projections=False,
    share_att_output_projection=False):
  hidden_size = 10

  # Projection names follow the ETC convention: `l2l` is long-to-long,
  # `l2g` long-to-global, `g2g` global-to-global, `g2l` global-to-long.
  layer = etc_layers.FusedGlobalLocalAttention(
      long_hidden_size=hidden_size,
      global_hidden_size=hidden_size,
      num_heads=5,
      local_radius=7,
      relative_vocab_size=9,
      share_kv_projections=share_kv_projections,
      share_qkv_projections=share_qkv_projections,
      share_att_output_projection=share_att_output_projection)

  # Run layer to make sure all variables are built.
  layer(
      long_input=tf.ones([1, 1, hidden_size]),
      global_input=tf.ones([1, 1, hidden_size]))

  if share_qkv_projections:
    # Sharing q/k/v projections implies shared queries and shared keys and
    # values across all four attention directions.
    self.assertIs(layer.long_query_projection, layer.global_query_projection)
    self.assert_all_identical(layer.l2l_key_projection,
                              layer.l2g_key_projection,
                              layer.g2g_key_projection,
                              layer.g2l_key_projection)
    self.assert_all_identical(layer.l2l_value_projection,
                              layer.l2g_value_projection,
                              layer.g2g_value_projection,
                              layer.g2l_value_projection)
  elif share_kv_projections:
    # Keys and values are shared within the long side and within the global
    # side, but not across sides, and queries are not shared.
    self.assertIsNot(layer.long_query_projection,
                     layer.global_query_projection)
    self.assertIs(layer.l2l_key_projection, layer.l2g_key_projection)
    self.assertIs(layer.g2g_key_projection, layer.g2l_key_projection)
    self.assertIsNot(layer.l2l_key_projection, layer.g2g_key_projection)
    self.assertIs(layer.l2l_value_projection, layer.l2g_value_projection)
    self.assertIs(layer.g2g_value_projection, layer.g2l_value_projection)
    self.assertIsNot(layer.l2l_value_projection, layer.g2g_value_projection)
  else:
    # No sharing: all projections must be distinct objects.
    self.assertIsNot(layer.long_query_projection,
                     layer.global_query_projection)
    self.assertIsNot(layer.l2l_key_projection, layer.l2g_key_projection)
    self.assertIsNot(layer.l2l_key_projection, layer.g2g_key_projection)
    self.assertIsNot(layer.l2l_value_projection, layer.l2g_value_projection)
    self.assertIsNot(layer.l2l_value_projection, layer.g2g_value_projection)
    self.assertIsNot(layer.long_query_projection, layer.l2l_key_projection)
    self.assertIsNot(layer.long_query_projection, layer.l2l_value_projection)
    self.assertIsNot(layer.l2l_key_projection, layer.l2l_value_projection)

  if share_att_output_projection:
    self.assertIs(layer.long_output_projection,
                  layer.global_output_projection)
  else:
    self.assertIsNot(layer.long_output_projection,
                     layer.global_output_projection)
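# Note: the method above takes its sharing flags as parameters, so in the
# full test file it is presumably driven by `absl.testing.parameterized`.
# A minimal sketch of such a decorator (the case names below are
# illustrative, not taken from the source):
#
#   @parameterized.named_parameters(
#       ('share_nothing',),
#       ('share_kv_projections', True),
#       ('share_qkv_projections', False, True),
#       ('share_qkv_and_output_projection', False, True, True),
#   )
#   def test_fused_global_local_attention_shared_sublayers(self, ...):
#     ...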
def test_fused_global_local_attention_custom_total_att_size(self):
  tf.compat.v1.random.set_random_seed(1234)
  np.random.seed(1234)

  batch_size = 3
  long_seq_len = 12
  global_seq_len = 6
  hidden_size = 11
  num_heads = 5
  local_radius = 2
  total_att_size = 10
  relative_vocab_size = 9

  long_input = tf.constant(
      np.random.normal(size=[batch_size, long_seq_len, hidden_size]))
  global_input = tf.constant(
      np.random.normal(size=[batch_size, global_seq_len, hidden_size]))

  l2l_att_mask = tf.constant(
      np.random.binomial(
          n=1, p=0.9, size=[batch_size, long_seq_len, 2 * local_radius + 1]))
  g2g_att_mask = tf.constant(
      np.random.binomial(
          n=1, p=0.9, size=[batch_size, global_seq_len, global_seq_len]))
  l2g_att_mask = tf.constant(
      np.random.binomial(
          n=1, p=0.9, size=[batch_size, long_seq_len, global_seq_len]))
  g2l_att_mask = tf.constant(
      np.random.binomial(
          n=1, p=0.9, size=[batch_size, global_seq_len, long_seq_len]))

  l2l_relative_att_ids = tf.constant(
      np.random.randint(
          relative_vocab_size,
          size=[batch_size, long_seq_len, 2 * local_radius + 1]))
  g2g_relative_att_ids = tf.constant(
      np.random.randint(
          relative_vocab_size,
          size=[batch_size, global_seq_len, global_seq_len]))
  l2g_relative_att_ids = tf.constant(
      np.random.randint(
          relative_vocab_size,
          size=[batch_size, long_seq_len, global_seq_len]))
  g2l_relative_att_ids = tf.constant(
      np.random.randint(
          relative_vocab_size,
          size=[batch_size, global_seq_len, long_seq_len]))

  fused_att_layer = etc_layers.FusedGlobalLocalAttention(
      long_hidden_size=hidden_size,
      global_hidden_size=hidden_size,
      num_heads=num_heads,
      local_radius=local_radius,
      long_total_att_size=total_att_size,
      global_total_att_size=total_att_size,
      relative_vocab_size=relative_vocab_size,
      share_qkv_projections=True,
      share_att_output_projection=True)

  long_output, global_output = fused_att_layer(
      long_input,
      global_input,
      l2l_att_mask=l2l_att_mask,
      g2g_att_mask=g2g_att_mask,
      l2g_att_mask=l2g_att_mask,
      g2l_att_mask=g2l_att_mask,
      l2l_relative_att_ids=l2l_relative_att_ids,
      g2g_relative_att_ids=g2g_relative_att_ids,
      l2g_relative_att_ids=l2g_relative_att_ids,
      g2l_relative_att_ids=g2l_relative_att_ids)

  self.evaluate(tf.compat.v1.global_variables_initializer())

  self.assertAllEqual([batch_size, long_seq_len, hidden_size],
                      long_output.shape)
  self.assertAllEqual([batch_size, global_seq_len, hidden_size],
                      global_output.shape)
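# The interesting configuration above appears to be that `hidden_size = 11`
# is not divisible by `num_heads = 5`, so the layer cannot derive a per-head
# attention size from the hidden size alone; `total_att_size = 10` decouples
# the two. A quick sketch of the arithmetic (illustrative, not part of the
# test):
#
#   size_per_head = total_att_size // num_heads  # 10 // 5 == 2
#   assert total_att_size % num_heads == 0       # heads divide evenly
#   assert hidden_size % num_heads != 0          # 11 % 5 == 1: needs override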
def test_fused_global_local_attention_special_case_equivalence(self):
  # To test for correctness, we make sure the output is equivalent to
  # standard attention in the special case where `local_radius` covers the
  # entire long sequence length and projection weights are shared.
  # For simplicity, we don't use attention masks or relative attention ids
  # in this test.

  tf.compat.v1.random.set_random_seed(1234)
  np.random.seed(1234)

  batch_size = 3
  long_seq_len = 12
  global_seq_len = 6
  hidden_size = 10
  num_heads = 5
  local_radius = 15  # Must be >= `long_seq_len - 1` to remove sparsity.
  # relative_vocab_size = 9

  long_input = tf.constant(
      np.random.normal(size=[batch_size, long_seq_len, hidden_size]))
  global_input = tf.constant(
      np.random.normal(size=[batch_size, global_seq_len, hidden_size]))

  fused_att_layer = etc_layers.FusedGlobalLocalAttention(
      long_hidden_size=hidden_size,
      global_hidden_size=hidden_size,
      num_heads=num_heads,
      local_radius=local_radius,
      share_qkv_projections=True,
      share_att_output_projection=True)

  long_output, global_output = fused_att_layer(
      long_input, global_input, att_implementation='sparse')

  # [batch_size, long_seq_len + global_seq_len, hidden_size]
  fused_output = tf.concat([long_output, global_output], axis=1)

  # Create concatenated input for standard attention.
  # [batch_size, long_seq_len + global_seq_len, hidden_size]
  concat_input = tf.concat([long_input, global_input], axis=1)

  standard_att_layer = etc_layers.RelativeAttention(
      hidden_size=hidden_size,
      num_heads=num_heads,
      query_projection=fused_att_layer.long_query_projection,
      key_projection=fused_att_layer.l2l_key_projection,
      value_projection=fused_att_layer.l2l_value_projection,
      output_projection=fused_att_layer.long_output_projection)

  expected_output = standard_att_layer(concat_input)

  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.assertAllClose(expected_output, fused_output)

  # Make sure 'full' att_implementation gives the same output.
  long_output_full_att, global_output_full_att = fused_att_layer(
      long_input, global_input, att_implementation='full')

  self.assertAllClose(long_output, long_output_full_att)
  self.assertAllClose(global_output, global_output_full_att)
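# Why `local_radius = 15` removes sparsity above: each long token i attends
# to the window [i - local_radius, i + local_radius]. With long_seq_len = 12,
# every window covers the whole sequence once local_radius >= long_seq_len - 1
# = 11, so local attention degenerates to full long-to-long attention and the
# fused layer should match standard attention over the concatenated inputs.
# A sketch of that coverage check (illustrative, not part of the test):
#
#   for i in range(long_seq_len):
#     window = range(max(0, i - local_radius),
#                    min(long_seq_len, i + local_radius + 1))
#     assert list(window) == list(range(long_seq_len))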