def test_output_last_dim_is_part_of_config_1(self):
    """Checks `output_last_dim` is in the config and None when not passed."""
    encoder = TransformerEncoderBlock(
        num_attention_heads=2,
        inner_dim=32,
        inner_activation='relu',
    )

    # The key must be serialized even though the argument was never supplied.
    self.assertIn('output_last_dim', encoder.get_config())
    self.assertIsNone(encoder.get_config()['output_last_dim'])
def test_use_query_residual_is_part_of_config_1(self):
    """Checks `use_query_residual` is in the config and truthy when not passed."""
    encoder = TransformerEncoderBlock(
        num_attention_heads=2,
        inner_dim=32,
        inner_activation='relu',
    )

    # The key must be serialized even though the argument was never supplied.
    self.assertIn('use_query_residual', encoder.get_config())
    self.assertTrue(encoder.get_config()['use_query_residual'])
def test_diff_q_kv_att_layer_norm_is_part_of_config_1(self):
    """Checks `diff_q_kv_att_layer_norm` is in the config and falsy when not passed.

    The block is built with `norm_first=False` and without an explicit
    `diff_q_kv_att_layer_norm` argument.
    """
    encoder = TransformerEncoderBlock(
        num_attention_heads=2,
        inner_dim=32,
        inner_activation='relu',
        norm_first=False,
    )

    self.assertIn('diff_q_kv_att_layer_norm', encoder.get_config())
    self.assertFalse(encoder.get_config()['diff_q_kv_att_layer_norm'])
def test_value_dim_is_part_of_config_2(self):
    """Checks an explicitly passed `value_dim` is echoed back by the config."""
    expected_value_dim = 10
    encoder = TransformerEncoderBlock(
        num_attention_heads=2,
        inner_dim=32,
        inner_activation='relu',
        value_dim=expected_value_dim,
    )

    self.assertIn('value_dim', encoder.get_config())
    self.assertEqual(expected_value_dim, encoder.get_config()['value_dim'])
def test_get_config(self):
    """Checks a block rebuilt via `from_config` reports the same config."""
    uniform_init = tf.keras.initializers.RandomUniform(minval=0., maxval=1.)
    original_block = TransformerEncoderBlock(
        num_attention_heads=2,
        inner_dim=32,
        inner_activation='relu',
        output_dropout=0.1,
        attention_dropout=0.1,
        use_bias=False,
        norm_first=True,
        norm_epsilon=1e-6,
        inner_dropout=0.1,
        attention_initializer=uniform_init,
    )

    # Round-trip: serialize, rebuild from the serialized form, serialize again.
    original_config = original_block.get_config()
    rebuilt_block = TransformerEncoderBlock.from_config(original_config)

    self.assertEqual(original_config, rebuilt_block.get_config())