  def test_softmax_noncausal_attention_block_output(self):
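    # FAVOR+ with the softmax kernel should closely approximate exact softmax attention.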
    batch_size = 1
    length = 2
    num_heads = 1
    dim = 8
    num_random_features = 30000
    query = tf.random.normal([batch_size, length, num_heads, dim])
    key = tf.random.normal([batch_size, length, num_heads, dim])
    value = tf.random.normal([batch_size, length, num_heads, dim])
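    # Softmax kernel features; the large projection (30000 features) keeps the estimate tight.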
    kernel_transformation = fast_attention.softmax_kernel_transformation
    projection_matrix = fast_attention.create_projection_matrix(
        num_random_features, dim)
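    # Noncausal FAVOR attention using the random-feature approximation.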
    attention_block_output = fast_attention.favor_attention(
        query, key, value, kernel_transformation, False, projection_matrix)

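    # Ground truth: exact softmax attention, with queries scaled by 1/sqrt(dim).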
    query = tf.multiply(query, 1.0 / math.sqrt(float(dim)))
    attention_scores = tf.einsum("BXHD,BYHD->BXYH", query, key)
    attention_scores = tf.nn.softmax(attention_scores, axis=2)
    exact_attention_block_output = tf.einsum("BXYH,BYHD->BXHD",
                                             attention_scores, value)
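    # The random-feature estimator is stochastic, so only a loose relative-error bound is asserted.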
    max_error = 2.0
    error = tf.math.abs(
        (exact_attention_block_output - attention_block_output) /
        exact_attention_block_output)
    self.assertLess(tf.math.reduce_max(error), max_error)

  def test_relu_causal_attention_block_output_shape(self):
    # Causal FAVOR attention with the ReLU kernel should preserve the input shape.
    batch_size = 1
    length = 10
    num_heads = 1
    dim = 4
    query = tf.ones([batch_size, length, num_heads, dim])
    key = tf.ones([batch_size, length, num_heads, dim])
    value = tf.ones([batch_size, length, num_heads, dim])
    kernel_transformation = fast_attention.relu_kernel_transformation
    # causal=True selects unidirectional (causal) attention.
    attention_block_output = fast_attention.favor_attention(
        query, key, value, kernel_transformation, True)
    self.assertListEqual(attention_block_output.get_shape().as_list(),
                         [batch_size, length, num_heads, dim])

  def test_softmax_noncausal_attention_block_output_shape(self):
    # Noncausal FAVOR attention with the softmax kernel should preserve the input shape.
    batch_size = 1
    length = 10
    num_heads = 1
    dim = 4
    num_random_features = 350
    query = tf.ones([batch_size, length, num_heads, dim])
    key = tf.ones([batch_size, length, num_heads, dim])
    value = tf.ones([batch_size, length, num_heads, dim])
    kernel_transformation = fast_attention.softmax_kernel_transformation
    projection_matrix = fast_attention.create_projection_matrix(
        num_random_features, dim)
    attention_block_output = fast_attention.favor_attention(
        query, key, value, kernel_transformation, False, projection_matrix)
    self.assertListEqual(attention_block_output.get_shape().as_list(),
                         [batch_size, length, num_heads, dim])