Code Example #1
 def test_output_last_dim_is_part_of_config_1(self):
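   """Checks that 'output_last_dim' is in get_config() and defaults to None."""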
   some_num_attention_heads = 2
   some_inner_dim = 32
   some_inner_activation = 'relu'
   encoder = TransformerEncoderBlock(
       num_attention_heads=some_num_attention_heads,
       inner_dim=some_inner_dim,
       inner_activation=some_inner_activation)
   self.assertIn('output_last_dim', encoder.get_config())
   self.assertIsNone(encoder.get_config()['output_last_dim'])
Code Example #2
 def test_use_query_residual_is_part_of_config_1(self):
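   """Checks that 'use_query_residual' is in get_config() and defaults to True."""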
   some_num_attention_heads = 2
   some_inner_dim = 32
   some_inner_activation = 'relu'
   encoder = TransformerEncoderBlock(
       num_attention_heads=some_num_attention_heads,
       inner_dim=some_inner_dim,
       inner_activation=some_inner_activation)
   self.assertIn('use_query_residual', encoder.get_config())
   self.assertTrue(encoder.get_config()['use_query_residual'])
Code Example #3
 def test_diff_q_kv_att_layer_norm_is_part_of_config_1(self):
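   """Checks that 'diff_q_kv_att_layer_norm' is in get_config() and is False."""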
   some_num_attention_heads = 2
   some_inner_dim = 32
   some_inner_activation = 'relu'
   encoder = TransformerEncoderBlock(
       num_attention_heads=some_num_attention_heads,
       inner_dim=some_inner_dim,
       inner_activation=some_inner_activation,
       norm_first=False)
   self.assertIn('diff_q_kv_att_layer_norm', encoder.get_config())
   self.assertFalse(encoder.get_config()['diff_q_kv_att_layer_norm'])
Code Example #4
 def test_value_dim_is_part_of_config_2(self):
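   """Checks that an explicitly passed value_dim is reflected in get_config()."""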
   some_num_attention_heads = 2
   some_inner_dim = 32
   some_inner_activation = 'relu'
   value_dim = 10
   encoder = TransformerEncoderBlock(
       num_attention_heads=some_num_attention_heads,
       inner_dim=some_inner_dim,
       inner_activation=some_inner_activation,
       value_dim=value_dim)
   self.assertIn('value_dim', encoder.get_config())
   self.assertEqual(value_dim, encoder.get_config()['value_dim'])
Code Example #5
 def test_get_config(self):
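     """Checks that from_config(get_config()) reproduces the same config."""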
     num_attention_heads = 2
     encoder_block = TransformerEncoderBlock(
         num_attention_heads=num_attention_heads,
         inner_dim=32,
         inner_activation='relu',
         output_dropout=0.1,
         attention_dropout=0.1,
         use_bias=False,
         norm_first=True,
         norm_epsilon=1e-6,
         inner_dropout=0.1,
         attention_initializer=tf.keras.initializers.RandomUniform(
             minval=0., maxval=1.))
     encoder_block_config = encoder_block.get_config()
     new_encoder_block = TransformerEncoderBlock.from_config(
         encoder_block_config)
     self.assertEqual(encoder_block_config, new_encoder_block.get_config())
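
The examples above only exercise the configuration API. For orientation, here is a minimal usage sketch of the layer itself. It assumes TransformerEncoderBlock is official.nlp.modeling.layers.TransformerEncoderBlock from the TensorFlow Model Garden and that it is called on a [batch, sequence_length, hidden_size] tensor; the import path, input shapes, and dimensions below are assumptions, not part of the examples above.

import tensorflow as tf
from official.nlp.modeling.layers import TransformerEncoderBlock  # assumed import path

# Build the block with only the required arguments, as in Code Example #1.
encoder_block = TransformerEncoderBlock(
    num_attention_heads=2,
    inner_dim=32,
    inner_activation='relu')

# Hypothetical input: [batch, sequence_length, hidden_size]; hidden_size
# should be divisible by num_attention_heads.
inputs = tf.random.uniform((4, 8, 16))
outputs = encoder_block(inputs)
print(outputs.shape)  # expected: (4, 8, 16), same shape as the input by default

# Round-trip through get_config()/from_config(), as in Code Example #5.
restored = TransformerEncoderBlock.from_config(encoder_block.get_config())
assert encoder_block.get_config() == restored.get_config()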