# Module-level dependencies assumed by these tests (the exact import path
# for fast_attention depends on where this file lives in the repository):
#   import math
#   import tensorflow as tf
#   from performer.fast_attention.tensorflow import fast_attention

def test_softmax_noncausal_attention_block_output(self):
  # Compares the FAVOR+ softmax-kernel approximation against exact
  # softmax attention and checks that the relative error stays bounded.
  batch_size = 1
  length = 2
  num_heads = 1
  dim = 8
  num_random_features = 30000
  query = tf.random.normal([batch_size, length, num_heads, dim])
  key = tf.random.normal([batch_size, length, num_heads, dim])
  value = tf.random.normal([batch_size, length, num_heads, dim])

  # FAVOR+ approximation with a large number of random features.
  kernel_transformation = fast_attention.softmax_kernel_transformation
  projection_matrix = fast_attention.create_projection_matrix(
      num_random_features, dim)
  attention_block_output = fast_attention.favor_attention(
      query, key, value, kernel_transformation, False, projection_matrix)

  # Exact (quadratic) softmax attention as the reference.
  query = tf.multiply(query, 1.0 / math.sqrt(float(dim)))
  attention_scores = tf.einsum("BXHD,BYHD->BXYH", query, key)
  attention_scores = tf.nn.softmax(attention_scores, axis=2)
  exact_attention_block_output = tf.einsum("BXYH,BYHD->BXHD",
                                           attention_scores, value)

  # Elementwise relative error; max_error is a loose bound since the
  # approximation is stochastic. error is already nonnegative, so the
  # redundant second tf.math.abs from the original is dropped.
  max_error = 2.0
  error = tf.math.abs(
      (exact_attention_block_output - attention_block_output) /
      exact_attention_block_output)
  self.assertLess(tf.math.reduce_max(error), max_error)
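
# Hypothetical companion test (not in the original section): with
# num_random_features equal to dim, create_projection_matrix draws a single
# orthogonal block, so the row-normalized Gram matrix should be close to the
# identity. A sketch only; the test name and tolerance are assumptions.
def test_create_projection_matrix_rows_are_orthogonal(self):
  dim = 8
  projection_matrix = fast_attention.create_projection_matrix(dim, dim)
  # Row norms are randomized by the default scaling, so normalize first.
  normalized = projection_matrix / tf.norm(
      projection_matrix, axis=1, keepdims=True)
  gram = tf.matmul(normalized, normalized, transpose_b=True)
  self.assertLess(
      tf.math.reduce_max(tf.math.abs(gram - tf.eye(dim))), 1e-4)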

def test_relu_causal_attention_block_output_shape(self):
  # Causal FAVOR attention with the ReLU kernel (no projection matrix)
  # should preserve the input shape [batch_size, length, num_heads, dim].
  batch_size = 1
  length = 10
  num_heads = 1
  dim = 4
  query = tf.ones([batch_size, length, num_heads, dim])
  key = tf.ones([batch_size, length, num_heads, dim])
  value = tf.ones([batch_size, length, num_heads, dim])
  kernel_transformation = fast_attention.relu_kernel_transformation
  attention_block_output = fast_attention.favor_attention(
      query, key, value, kernel_transformation, True)
  self.assertListEqual(attention_block_output.get_shape().as_list(),
                       [batch_size, length, num_heads, dim])
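
# Hypothetical companion test (an assumption, not from the original section):
# since the test above calls favor_attention without a projection matrix,
# relu_kernel_transformation must handle projection_matrix=None; its ReLU
# features plus the numerical stabilizer should then be strictly positive,
# which keeps the attention normalizer well defined.
def test_relu_kernel_transformation_is_positive(self):
  data = tf.random.normal([1, 10, 1, 4])
  features = fast_attention.relu_kernel_transformation(data, True)
  self.assertGreater(tf.math.reduce_min(features), 0.0)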

def test_softmax_noncausal_attention_block_output_shape(self):
  # Non-causal FAVOR attention with the softmax kernel should preserve
  # the input shape [batch_size, length, num_heads, dim].
  batch_size = 1
  length = 10
  num_heads = 1
  dim = 4
  num_random_features = 350
  query = tf.ones([batch_size, length, num_heads, dim])
  key = tf.ones([batch_size, length, num_heads, dim])
  value = tf.ones([batch_size, length, num_heads, dim])
  kernel_transformation = fast_attention.softmax_kernel_transformation
  projection_matrix = fast_attention.create_projection_matrix(
      num_random_features, dim)
  attention_block_output = fast_attention.favor_attention(
      query, key, value, kernel_transformation, False, projection_matrix)
  self.assertListEqual(attention_block_output.get_shape().as_list(),
                       [batch_size, length, num_heads, dim])
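
# Assumed entry point so the file can be run directly; omit if the
# surrounding module already defines one.
if __name__ == "__main__":
  tf.test.main()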