  # Assumed parameterization: the original decorator is not shown in this
  # excerpt, so these cases (covering both norm_first settings, via absl's
  # `parameterized`, assumed imported above) are a reconstruction.
  @parameterized.named_parameters(('norm_first_is_true', True),
                                  ('norm_first_is_false', False))
  def test_use_query_residual_false_removes_add_op(self, norm_first):
    graph_with_res = tf.Graph()
    with graph_with_res.as_default():
      layer = TransformerEncoderBlock(
          num_attention_heads=2,
          inner_dim=128,
          inner_activation='relu',
          norm_first=norm_first)
      inputs = tf.keras.Input(shape=(None, None, 2))
      outputs = layer(inputs)
      tf.keras.Model(inputs=inputs, outputs=outputs)

    graph_without_res = tf.Graph()
    with graph_without_res.as_default():
      layer = TransformerEncoderBlock(
          num_attention_heads=2,
          inner_dim=128,
          inner_activation='relu',
          norm_first=norm_first,
          use_query_residual=False)
      inputs = tf.keras.Input(shape=(None, None, 2))
      outputs = layer(inputs)
      tf.keras.Model(inputs=inputs, outputs=outputs)

    graph_with_res_names = {x.name for x in graph_with_res.get_operations()}
    graph_without_res_names = {
        x.name for x in graph_without_res.get_operations()
    }

    # The only ops present with the residual but absent without it should be
    # the residual-add itself.
    self.assertIn('transformer_encoder_block/add',
                  list(graph_with_res_names - graph_without_res_names)[0])
    self.assertEmpty(graph_without_res_names - graph_with_res_names)

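  # Assumed parameterization (original decorator elided in this excerpt):
  # exercise output_range both unset and set.
  @parameterized.named_parameters(('output_range_none', None),
                                  ('output_range_set', 2))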
  def test_needs_diff_q_kv_att_layer_norm_to_be_true_for_diff_q_and_kv_dims(
      self, output_range):
    test_layer = TransformerEncoderBlock(
        num_attention_heads=2,
        inner_dim=128,
        inner_activation='relu',
        output_range=output_range,
        norm_first=True)
    # Forward path.
    q_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
    kv_tensor = tf.zeros([2, 8, 32], dtype=tf.float32)
    dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
    inputs = [q_tensor, kv_tensor, dummy_mask]
    with self.assertRaises(tf.errors.InvalidArgumentError):
      test_layer(inputs)

    test_layer = TransformerEncoderBlock(
        num_attention_heads=2,
        inner_dim=128,
        inner_activation='relu',
        diff_q_kv_att_layer_norm=True,
        norm_first=True)
    # Forward path.
    test_layer(inputs)

  def test_output_last_dim_is_part_of_config_1(self):
    some_num_attention_heads = 2
    some_inner_dim = 32
    some_inner_activation = 'relu'
    encoder = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation)
    self.assertIn('output_last_dim', encoder.get_config())
    self.assertIsNone(encoder.get_config()['output_last_dim'])

  def test_use_query_residual_is_part_of_config_1(self):
    some_num_attention_heads = 2
    some_inner_dim = 32
    some_inner_activation = 'relu'
    encoder = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation)
    self.assertIn('use_query_residual', encoder.get_config())
    self.assertTrue(encoder.get_config()['use_query_residual'])

  def test_diff_q_kv_att_layer_norm_is_part_of_config_1(self):
    some_num_attention_heads = 2
    some_inner_dim = 32
    some_inner_activation = 'relu'
    encoder = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation,
        norm_first=False)
    self.assertIn('diff_q_kv_att_layer_norm', encoder.get_config())
    self.assertFalse(encoder.get_config()['diff_q_kv_att_layer_norm'])

  def test_value_dim_is_part_of_config_2(self):
    some_num_attention_heads = 2
    some_inner_dim = 32
    some_inner_activation = 'relu'
    value_dim = 10
    encoder = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation,
        value_dim=value_dim)
    self.assertIn('value_dim', encoder.get_config())
    self.assertEqual(value_dim, encoder.get_config()['value_dim'])

  def test_get_config(self):
    num_attention_heads = 2
    encoder_block = TransformerEncoderBlock(
        num_attention_heads=num_attention_heads,
        inner_dim=32,
        inner_activation='relu',
        output_dropout=0.1,
        attention_dropout=0.1,
        use_bias=False,
        norm_first=True,
        norm_epsilon=1e-6,
        inner_dropout=0.1,
        attention_initializer=tf.keras.initializers.RandomUniform(
            minval=0., maxval=1.))
    encoder_block_config = encoder_block.get_config()
    new_encoder_block = TransformerEncoderBlock.from_config(
        encoder_block_config)
    self.assertEqual(encoder_block_config, new_encoder_block.get_config())

  def test_norm_first_false_and_diff_q_kv_att_layer_norm_true_raises(self):
    some_num_attention_heads = 2
    some_inner_dim = 32
    some_inner_activation = 'relu'
    with self.assertRaises(ValueError):
      TransformerEncoderBlock(
          num_attention_heads=some_num_attention_heads,
          inner_dim=some_inner_dim,
          inner_activation=some_inner_activation,
          norm_first=False,
          diff_q_kv_att_layer_norm=True)

  def test_raises_invalid_arg_error_when_q_kv_dims_are_different(self):
    test_layer = TransformerEncoderBlock(
        num_attention_heads=2,
        inner_dim=128,
        inner_activation='relu',
        norm_first=True)
    # Forward path.
    q_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
    kv_tensor = tf.zeros([2, 8, 32], dtype=tf.float32)
    dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
    inputs = [q_tensor, kv_tensor, dummy_mask]
    with self.assertRaises(tf.errors.InvalidArgumentError):
      test_layer(inputs)

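  # Assumed parameterization (original decorator elided). With value_dim=None,
  # tf.keras MultiHeadAttention falls back to key_dim, which this block sets
  # to q_tensor_last_dim // num_attention_heads (16 // 2 = 8).
  @parameterized.named_parameters(
      ('value_dim_is_none', None, 16, 2, 8),
      ('value_dim_is_set', 10, 16, 2, 10))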
  def test_value_dim(self, value_dim, q_tensor_last_dim,
                     some_num_attention_heads, expected):
    some_inner_dim = 32
    some_inner_activation = 'relu'
    test_layer = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation,
        value_dim=value_dim)
    q_tensor = tf.zeros([2, 4, q_tensor_last_dim], dtype=tf.float32)
    kv_tensor = tf.zeros([2, 8, 32], dtype=tf.float32)
    dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
    test_layer([q_tensor, kv_tensor, dummy_mask])
    self.assertEqual(expected,
                     test_layer._attention_layer.get_config()['value_dim'])

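  # Assumed parameterization (original decorator elided): attend over either
  # or both of the two non-batch, non-feature axes.
  @parameterized.named_parameters(
      ('attention_axes_none', None),
      ('attention_axes_row', [1]),
      ('attention_axes_col', [2]),
      ('attention_axes_row_and_col', [1, 2]))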
  def test_several_attention_axes(self, attention_axes):
    test_layer = TransformerEncoderBlock(
        inner_dim=32,
        inner_activation='relu',
        output_dropout=0.1,
        attention_dropout=0.1,
        use_bias=False,
        norm_first=True,
        norm_epsilon=1e-6,
        inner_dropout=0.1,
        num_attention_heads=10,
        attention_axes=attention_axes)
    num_rows = 21
    num_cols = 13
    width = 80
    # Create a 3-dimensional input (the first dimension is implicit).
    data_tensor = tf.keras.Input(shape=(num_rows, num_cols, width))
    output_tensor = test_layer(data_tensor)
    # The default output of a transformer layer should be the same as the
    # input.
    self.assertEqual(data_tensor.shape.as_list(),
                     output_tensor.shape.as_list())

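  # Assumed parameterization (original decorator elided): with
  # output_last_dim=None the output keeps the query's last dimension;
  # otherwise it is projected to output_last_dim, which requires
  # use_query_residual=False.
  @parameterized.named_parameters(
      ('output_last_dim_none', True, None, 16, 16),
      ('output_last_dim_set', False, 24, 16, 24))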
  def test_output_last_dim(self, use_query_residual, output_last_dim,
                           q_tensor_last_dim, expected):
    some_num_attention_heads = 2
    some_inner_dim = 32
    some_inner_activation = 'relu'
    test_layer = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation,
        # Must be false for multi-head output to be different from
        # first input's last dim.
        use_query_residual=use_query_residual,
        output_last_dim=output_last_dim)
    q_tensor = tf.zeros([2, 4, q_tensor_last_dim], dtype=tf.float32)
    kv_tensor = tf.zeros([2, 8, 32], dtype=tf.float32)
    dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
    output = test_layer([q_tensor, kv_tensor, dummy_mask])
    self.assertEqual(output.numpy().shape[-1], expected)

  def test_use_bias_norm_first(self):
    num_attention_heads = 2
    hidden_size = 16
    encoder_block = TransformerEncoderBlock(
        num_attention_heads=num_attention_heads,
        inner_dim=32,
        inner_activation='relu',
        output_dropout=0.1,
        attention_dropout=0.1,
        use_bias=False,
        norm_first=True,
        norm_epsilon=1e-6,
        inner_dropout=0.1,
        attention_initializer=tf.keras.initializers.RandomUniform(
            minval=0., maxval=1.))
    # Forward path.
    dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
    dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
    inputs = [dummy_tensor, dummy_mask]
    output = encoder_block(inputs)
    self.assertEqual(output.shape, (2, 4, hidden_size))

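  # Assumed parameterization (original decorator elided): distinct rates so
  # each dropout layer's configured rate can be verified independently.
  @parameterized.named_parameters(
      ('all_same_rate', 0.1, 0.1, 0.1),
      ('distinct_rates', 0.1, 0.2, 0.3))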
  def test_dropout_config(self, output_dropout, attention_dropout,
                          inner_dropout):
    test_layer = TransformerEncoderBlock(
        num_attention_heads=2,
        inner_dim=32,
        inner_activation='relu',
        output_dropout=output_dropout,
        attention_dropout=attention_dropout,
        inner_dropout=inner_dropout)
    seq_len = 21
    hidden_size = 512
    input_tensor = tf.keras.Input(shape=(seq_len, hidden_size))
    _ = test_layer(input_tensor)

    true_output_dropout = test_layer._output_dropout.get_config()['rate']
    true_attention_dropout = test_layer._attention_dropout.get_config()['rate']
    true_inner_dropout = test_layer._inner_dropout_layer.get_config()['rate']
    self.assertEqual(true_output_dropout, output_dropout)
    self.assertEqual(true_attention_dropout, attention_dropout)
    self.assertEqual(true_inner_dropout, inner_dropout)