def test_chunked_causal_attention(self):
        """Chunked causal FAVOR attention must match the unchunked version."""
        b, l, h, d = 1, 128, 1, 16
        shape = [b, l, h, d]
        # Draw random inputs, then promote to float64 so the comparison
        # tolerance below is dominated by the algorithm, not by precision.
        tensors = [tf.cast(tf.random.normal(shape), tf.float64)
                   for _ in range(3)]
        query, key, value = tensors

        # Reference: causal FAVOR attention computed without chunking.
        reference_output = favor.favor_attention(
            query, key, value, None, favor.relu_kernel_transformation, True)

        # Same computation with the chunked causal code path enabled.
        chunked_output_op = favor.favor_attention(
            query, key, value, None, favor.relu_kernel_transformation, True,
            use_chunked_causal=True)

        tolerance = 0.0001
        with self.session(use_gpu=False) as sess:
            chunked_np, reference_np = sess.run(
                [chunked_output_op, reference_output])
            self.assertLess(np.max(np.abs(reference_np - chunked_np)),
                            tolerance)
    def test_softmax_noncausal_attention_block_output(self):
        batch_size = 1
        length = 2
        num_heads = 1
        dim = 8
        num_random_features = 1000
        query = tf.random.normal([batch_size, length, num_heads, dim])
        key = tf.random.normal([batch_size, length, num_heads, dim])
        value = tf.random.normal([batch_size, length, num_heads, dim])
        kernel_transformation = favor.softmax_kernel_transformation
        projection_matrix = favor.create_projection_matrix(
            num_random_features, dim)
        query = tf.cast(query, tf.float64)
        key = tf.cast(key, tf.float64)
        value = tf.cast(value, tf.float64)
        projection_matrix = tf.cast(projection_matrix, tf.float64)
        attention_block_output = favor.favor_attention(query, key, value, None,
                                                       kernel_transformation,
                                                       False,
                                                       projection_matrix)

        query = tf.multiply(query, 1.0 / math.sqrt(float(dim)))
        attention_scores = tf.einsum("BXHD,BYHD->BXYH", query, key)
        attention_scores = tf.nn.softmax(attention_scores, axis=2)
        exact_attention_block_output = tf.einsum("BXYH,BYHD->BXHD",
                                                 attention_scores, value)
        max_error = 0.5
        with self.session(use_gpu=False) as sess:
            favor_output, groundtruth_output = sess.run(
                [exact_attention_block_output, attention_block_output])
            error = np.max(
                np.abs(
                    (groundtruth_output - favor_output) / groundtruth_output))
            self.assertLess(error, max_error)