def test_small_example_evaluate_unidirectional(self):
    """Compare unidirectional fast attention with a hand-written expected result.

    Query and key are all-ones tensors and `value` is 1.0 at the first
    position and 0.0 elsewhere. The expected output (1, 1/2, 1/3) equals the
    running mean of `value` along the sequence axis, which is what causal
    attention with uniform weights would yield -- assuming
    `unidirectional=True` selects causal attention (TODO confirm against
    fast_attention). Both the unstructured and the orthogonal Gaussian
    random-feature variants must land within 0.02 of that result.
    """
    # Layouts noted by the original author:
    #   query -> [batch_size, dim1, dim2, ..., dimN, num_heads, mem_channels]
    #   key   -> [batch_size, dim1, dim2, ..., dimN, num_heads, mem_channels]
    #   value -> [batch_size, dim1, dim2, ..., dimN, num_heads, value_channels]
    qk_dim = 1
    batch_size = 1
    dim = 3
    num_heads = 1
    nb_random_features = 10000

    # Query and key share one shape, so a single tuple serves both.
    qk_shape = (batch_size, dim, num_heads, qk_dim)
    query = jnp.ones(qk_shape)
    key = jnp.ones(qk_shape)

    # Shape (1, 3, 1, 1): only the first sequence position carries a value.
    value = jnp.array(onp.array([[[[1.0]], [[0.0]], [[0.0]]]]))
    groundtruth = onp.array([[[[1.0]], [[0.5]], [[1.0 / 3.0]]]])

    renormalize_attention = True
    numerical_stabilizer = 0.0
    redraw_features = False
    unidirectional = True

    def make_fast_attention(random_matrix_cls):
        # Bind the feature count and the query/key width to the matrix class,
        # then hand the creator to the low-rank attention wrapper.
        matrix_creator = functools.partial(
            random_matrix_cls, nb_random_features, qk_dim)
        return fast_attention.FastAttentionviaLowRankDecomposition(
            matrix_creator, kernel_feature_creator, renormalize_attention,
            numerical_stabilizer, redraw_features, unidirectional)

    unstruct_attention = make_fast_attention(
        fast_attention.GaussianUnstructuredRandomMatrix)
    ortho_attention = make_fast_attention(
        fast_attention.GaussianOrthogonalRandomMatrix)

    unstruct_result = unstruct_attention.dot_product_attention(
        query, key, value)
    ortho_result = ortho_attention.dot_product_attention(query, key, value)

    # Largest absolute deviation from the expected output, per variant.
    max_error = 0.02
    worst_unstruct = jnp.max(jnp.abs(unstruct_result - groundtruth))
    worst_ortho = jnp.max(jnp.abs(ortho_result - groundtruth))

    self.assertLess(worst_unstruct, max_error)
    self.assertLess(worst_ortho, max_error)
def test_evaluate_parameter(self):
    """Compare fast attention with `attention.dot_product_attention`.

    Random query/key/value tensors are passed through the reference
    attention and through two random-feature approximations (unstructured
    and orthogonal Gaussian matrices, 10000 features each). The maximum
    element-wise relative error must stay below 0.33 for the unstructured
    variant and below 2.0 for the orthogonal variant.
    """
    # Layouts noted by the original author:
    #   query -> [batch_size, dim1, dim2, ..., dimN, num_heads, mem_channels]
    #   key   -> [batch_size, dim1, dim2, ..., dimN, num_heads, mem_channels]
    #   value -> [batch_size, dim1, dim2, ..., dimN, num_heads, value_channels]
    qk_dim = 8
    v_dim = 10
    batch_size = 1
    dim1 = 2
    dim2 = 1
    dim3 = 1
    num_heads = 1
    nb_random_features = 10000

    leading_dims = (batch_size, dim1, dim2, dim3, num_heads)
    shape_query = leading_dims + (qk_dim,)
    shape_key = leading_dims + (qk_dim,)
    shape_value = leading_dims + (v_dim,)

    # The same key (seed 0) drives all three draws, so query and key, which
    # share a shape, receive identical samples.
    rng = random.PRNGKey(0)
    query = random.normal(rng, shape_query)
    key = random.normal(rng, shape_key)
    value = random.normal(rng, shape_value)

    renormalize_attention = True
    numerical_stabilizer = 0.0
    redraw_features = False
    unidirectional = False

    def make_fast_attention(random_matrix_cls):
        # Bind the feature count and the query/key width to the matrix class,
        # then hand the creator to the low-rank attention wrapper.
        matrix_creator = functools.partial(
            random_matrix_cls, nb_random_features, qk_dim)
        return fast_attention.FastAttentionviaLowRankDecomposition(
            matrix_creator, kernel_feature_creator, renormalize_attention,
            numerical_stabilizer, redraw_features, unidirectional)

    unstruct_attention = make_fast_attention(
        fast_attention.GaussianUnstructuredRandomMatrix)
    ortho_attention = make_fast_attention(
        fast_attention.GaussianOrthogonalRandomMatrix)

    standard_attention_result = attention.dot_product_attention(
        query, key, value)
    unstruct_result = unstruct_attention.dot_product_attention(
        query, key, value)
    ortho_result = ortho_attention.dot_product_attention(query, key, value)

    def relative_error(approximation):
        # Element-wise |reference - approximation| / |reference|.
        return jnp.abs(
            (standard_attention_result - approximation)
            / standard_attention_result)

    worst_unstruct = jnp.max(relative_error(unstruct_result))
    worst_ortho = jnp.max(relative_error(ortho_result))

    max_error = 0.33
    self.assertLess(worst_unstruct, max_error)

    # The orthogonal variant is checked against a looser bound.
    max_ortho_error = 2.0
    self.assertLess(worst_ortho, max_ortho_error)