def test_config(self):
  num_heads = 12
  key_dim = 64
  test_layer = attention.KernelAttention(
      num_heads=num_heads,
      key_dim=key_dim,
      feature_transform='exp',
      num_random_features=128,
      is_short_seq=True)
  new_layer = attention.KernelAttention.from_config(
      test_layer.get_config())
  # If the serialization was successful, the new config should match the old.
  self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
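
# The test above relies on the standard tf.keras get_config()/from_config()
# round-trip. A minimal sketch of the same pattern with a stock Keras layer
# (tf.keras.layers.Dense is used purely for illustration):
import tensorflow as tf

dense = tf.keras.layers.Dense(units=8, activation="relu")
config = dense.get_config()          # plain dict of constructor arguments
restored = tf.keras.layers.Dense.from_config(config)
# Should print True: from_config(get_config()) reproduces an equivalent layer.
print(restored.get_config() == config)
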
def test_attention_scale_by_length(self, seq_length):
  num_heads = 12
  key_dim = 64
  batch_size = 2
  test_layer = attention.KernelAttention(
      num_heads=num_heads,
      key_dim=key_dim,
      num_random_features=0,
      scale_by_length=True)
  query = tf.random.normal(shape=(batch_size, seq_length, key_dim))
  value = query
  encoder_inputs_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
  masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
  output_scale_by_length = test_layer(
      query=query, value=value, attention_mask=masks)

  test_layer._scale_by_length = False
  output_no_scale_by_length = test_layer(
      query=query, value=value, attention_mask=masks)
  if seq_length == 512:  # Outputs are equal because log(seq_length, base=512) = 1.0.
    self.assertAllClose(output_scale_by_length, output_no_scale_by_length)
  else:
    self.assertNotAllClose(output_scale_by_length, output_no_scale_by_length)
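
# The inline comment above suggests the length scaling is log(seq_length) in
# base 512, so the scaled and unscaled outputs coincide exactly at
# seq_length == 512. A quick check of that factor (the formula is inferred from
# the comment, not taken from the layer's implementation):
import math

def length_scale(seq_length, base=512):
  return math.log(seq_length) / math.log(base)

print(length_scale(512))   # 1.0  -> outputs match
print(length_scale(1024))  # ~1.11 -> outputs differ
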
def test_attention_projection(self, feature_transform, num_random_features,
                              training, redraw, is_short, begin_kernel):
  num_heads = 12
  key_dim = 64
  seq_length = 1024
  batch_size = 2
  test_layer = attention.KernelAttention(
      num_heads=num_heads,
      key_dim=key_dim,
      feature_transform=feature_transform,
      num_random_features=num_random_features,
      redraw=redraw,
      is_short_seq=is_short,
      begin_kernel=begin_kernel)
  query = tf.random.normal(shape=(batch_size, seq_length, key_dim))
  value = query
  encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
  masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
  output = test_layer(
      query=query, value=value, attention_mask=masks, training=training)
  self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
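
# test_attention_projection above takes its arguments from a parameterized-test
# decorator that is not part of this excerpt. A hypothetical sketch of how such
# a sweep could be declared with absl.testing.parameterized (the class name and
# parameter tuples are illustrative assumptions, not the original test's values):
from absl.testing import parameterized
import tensorflow as tf

class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(
      ("exp", 128, True, False, False, 0),
      ("elu", 0, False, False, True, 512),
  )
  def test_attention_projection(self, feature_transform, num_random_features,
                                training, redraw, is_short, begin_kernel):
    ...  # body as shown above
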
def test_redraw_true_no_projection(self):
  with self.assertRaisesRegex(
      ValueError, "There is nothing to redraw when num_random_features.*"):
    _ = attention.KernelAttention(
        num_heads=2, key_dim=64, feature_transform="elu",
        num_random_features=0, redraw=True)
def test_unsupported_feature_transform(self):
  with self.assertRaisesRegex(ValueError, "Unsupported feature_transform.*"):
    _ = attention.KernelAttention(feature_transform="test")
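
# A minimal end-to-end sketch assembled from the calls shown in the tests above.
# The import path is an assumption (the excerpt only shows the module aliased as
# `attention`); all constructor and call arguments mirror the tests.
import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention as attention

layer = attention.KernelAttention(
    num_heads=12, key_dim=64, feature_transform="exp", num_random_features=128)

batch_size, seq_length, key_dim = 2, 1024, 64
query = tf.random.normal(shape=(batch_size, seq_length, key_dim))
mask = tf.ones((batch_size, seq_length), dtype=tf.float32)

output = layer(query=query, value=query, attention_mask=mask)
print(output.shape)  # (2, 1024, 64)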