Example #1
0
    def test_small_example_evaluate_unidirectional(self):
        """Compare unidirectional fast attention with hand-written values.

        Tensor layouts passed to the attention call:
          query -> [batch_size, dim1, dim2, ..., dimN, num_heads, mem_channels]
          key   -> [batch_size, dim1, dim2, ..., dimN, num_heads, mem_channels]
          value -> [batch_size, dim1, dim2, ..., dimN, num_heads, value_channels]

        Query and key are all ones, and value is 1 at the first position and
        0 elsewhere. The reference output (1, 1/2, 1/3) equals the running
        mean of value along the length-3 axis. Both the unstructured and the
        orthogonal random-feature variants must stay within 0.02 of it.
        """
        batch_size, seq_len, num_heads, qk_dim = 1, 3, 1, 1
        nb_random_features = 10000

        qk_shape = (batch_size, seq_len, num_heads, qk_dim)
        query = jnp.ones(qk_shape)
        key = jnp.ones(qk_shape)

        # Value is [1, 0, 0] along the sequence axis.
        value = jnp.array(onp.array([1.0, 0.0, 0.0]).reshape((1, 3, 1, 1)))
        expected = onp.array([1.0, 0.5, 1.0 / 3.0]).reshape((1, 3, 1, 1))

        # Positional constructor arguments that follow the two creators.
        renormalize_attention = True
        numerical_stabilizer = 0.0
        redraw_features = False
        unidirectional = True
        attention_flags = (renormalize_attention, numerical_stabilizer,
                           redraw_features, unidirectional)

        # Order matters for the assertions below: unstructured first,
        # orthogonal second.
        matrix_creators = [
            functools.partial(matrix_cls, nb_random_features, qk_dim)
            for matrix_cls in (
                fast_attention.GaussianUnstructuredRandomMatrix,
                fast_attention.GaussianOrthogonalRandomMatrix,
            )
        ]
        attention_modules = [
            fast_attention.FastAttentionviaLowRankDecomposition(
                creator, kernel_feature_creator, *attention_flags)
            for creator in matrix_creators
        ]
        outputs = [
            module.dot_product_attention(query, key, value)
            for module in attention_modules
        ]

        # Element-wise absolute deviation from the reference values.
        abs_errors = [jnp.abs(output - expected) for output in outputs]

        tolerance = 0.02
        for abs_error in abs_errors:
            self.assertLess(jnp.max(abs_error), tolerance)
Example #2
0
    def test_evaluate_parameter(self):
        """Compare fast attention with standard attention on random inputs.

        Tensor layouts passed to the attention calls:
          query -> [batch_size, dim1, dim2, ..., dimN, num_heads, mem_channels]
          key   -> [batch_size, dim1, dim2, ..., dimN, num_heads, mem_channels]
          value -> [batch_size, dim1, dim2, ..., dimN, num_heads, value_channels]

        The check is on the maximum element-wise relative error with respect
        to `attention.dot_product_attention`: below 0.33 for the unstructured
        random-feature variant and below 2.0 for the orthogonal one.
        """
        qk_dim, v_dim = 8, 10
        nb_random_features = 10000
        # (batch_size, dim1, dim2, dim3, num_heads)
        leading_dims = (1, 2, 1, 1, 1)

        shape_qk = leading_dims + (qk_dim,)
        shape_v = leading_dims + (v_dim,)

        # Every draw uses a fresh PRNGKey(0). Query and key share a shape,
        # so they hold identical values.
        query = random.normal(random.PRNGKey(0), shape_qk)
        key = random.normal(random.PRNGKey(0), shape_qk)
        value = random.normal(random.PRNGKey(0), shape_v)

        # Positional constructor arguments that follow the two creators.
        renormalize_attention = True
        numerical_stabilizer = 0.0
        redraw_features = False
        unidirectional = False
        attention_flags = (renormalize_attention, numerical_stabilizer,
                           redraw_features, unidirectional)

        # Order matters for the thresholds below: unstructured first,
        # orthogonal second.
        matrix_creators = [
            functools.partial(matrix_cls, nb_random_features, qk_dim)
            for matrix_cls in (
                fast_attention.GaussianUnstructuredRandomMatrix,
                fast_attention.GaussianOrthogonalRandomMatrix,
            )
        ]
        attention_modules = [
            fast_attention.FastAttentionviaLowRankDecomposition(
                creator, kernel_feature_creator, *attention_flags)
            for creator in matrix_creators
        ]

        reference = attention.dot_product_attention(query, key, value)
        approximations = [
            module.dot_product_attention(query, key, value)
            for module in attention_modules
        ]

        # Relative error, element-wise, against the standard attention output.
        relative_errors = [
            jnp.abs((reference - approx) / reference)
            for approx in approximations
        ]

        max_error = 0.33
        max_ortho_error = 2.0
        for rel_error, bound in zip(relative_errors,
                                    (max_error, max_ortho_error)):
            self.assertLess(jnp.max(rel_error), bound)