Example #1
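This snippet appears to come from Trax's efficient-attention test suite: given identical weight matrices and RNG seeds, it checks that LSHSelfAttention and PureLSHSelfAttention (which takes pre-projected queries/keys and values) produce the same output up to a small numerical tolerance.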
# Assumed imports for this snippet (the original test module is not shown):
import jax
import jax.numpy as jnp

from absl.testing import absltest

from trax import fastmath
from trax import shapes
from trax.layers.research import efficient_attention


class EfficientAttentionTest(absltest.TestCase):  # hypothetical wrapper class

    def test_lsh_and_pure_lsh_self_attention_equivalence(self):
        # Given the same weight matrices and random numbers, do these produce
        # the same output?
        with fastmath.use_backend(fastmath.Backend.JAX):
            n_heads = 4
            d_head = 4
            d_model = n_heads * d_head
            pure_lsh_layer = efficient_attention.PureLSHSelfAttention(
                n_heads=n_heads,
                d_qk=d_head,
                d_v=d_head,
                causal=True,
                masked=False,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=4,
                n_buckets=8,
                use_reference_code=False,
                attention_dropout=0.0,
                use_python_loop=True,
                bias=False,
                mode='train')
            lsh_layer = efficient_attention.LSHSelfAttention(
                n_heads=n_heads,
                d_qk=d_head,
                d_v=d_head,
                causal=True,
                masked=False,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=4,
                n_buckets=8,
                use_reference_code=False,
                attention_dropout=0.0,
                use_python_loop=True,
                mode='train')

            batch, seqlen = 3, 32
            input_shape = (batch, seqlen, d_model)

            x = jax.random.uniform(jax.random.PRNGKey(0),
                                   input_shape,
                                   dtype=jnp.float32)
            lsh_layer_input = x

            call_rng = jax.random.PRNGKey(42)

            lsh_layer_weights, lsh_layer_state = lsh_layer.init(
                shapes.signature(lsh_layer_input))
            lsh_layer.rng = call_rng
            lsh_layer_output = lsh_layer(lsh_layer_input)

            # Shapes are: (n_heads, d_model, d_head), (n_heads, d_model, d_head),
            # (n_heads, d_head, d_model); abbreviated below as hmn, hmn, hnm.
            w_qk, w_v, w_o = lsh_layer_weights

            qk = jnp.einsum('blm,hmn->bhln', x, w_qk)
            qk = qk.reshape((-1, qk.shape[2], qk.shape[3]))

            v = jnp.einsum('blm,hmn->bhln', x, w_v)
            v = v.reshape((-1, v.shape[2], v.shape[3]))

            pure_lsh_layer_input = (qk, v)
            _, _ = pure_lsh_layer.init(shapes.signature(pure_lsh_layer_input))
            pure_lsh_layer.rng = call_rng
            pure_lsh_layer.state = lsh_layer_state
            pure_lsh_layer_output = pure_lsh_layer(pure_lsh_layer_input)

            # Un-fold the heads: (b*h, l, n) -> (b, h, l, n).
            pure_lsh_layer_output = pure_lsh_layer_output.reshape(
                (batch, -1) + pure_lsh_layer_output.shape[1:])
            pure_lsh_layer_output_projected = (jnp.einsum(
                'bhld,hdm->blm', pure_lsh_layer_output, w_o))

            diff = pure_lsh_layer_output_projected - lsh_layer_output
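            # Mean absolute difference; summing jnp.ones_like(diff) gives the
            # element count, so this equals jnp.mean(jnp.abs(diff)).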
            avg_diff = jnp.sum(jnp.abs(diff)) / jnp.sum(jnp.ones_like(diff))

            self.assertLess(avg_diff, 1e-5)
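
For reference, a minimal sketch of the jnp.ones_like call used above (assuming jnp is jax.numpy): it builds an array of ones with the same shape and dtype as its argument, which is why summing it yields the element count.

import jax.numpy as jnp

x = jnp.zeros((2, 3), dtype=jnp.float32)
ones = jnp.ones_like(x)   # shape (2, 3), dtype float32, all ones
count = jnp.sum(ones)     # 6.0 == number of elements in x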
Example #2
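A one-line state initializer: it returns an all-ones array shaped like the given weights. The surrounding class is not shown; in Trax-style code, an init(self, weights) method of this form typically creates per-parameter state such as optimizer slots.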
    def init(self, weights):
        # Return an all-ones array with the same shape and dtype as `weights`
        # (jnp is assumed to be jax.numpy or trax.fastmath.numpy).
        return jnp.ones_like(weights)
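
A hypothetical caller for the method above, just to illustrate the contract (the class name and test values are assumptions, not from the original):

import jax.numpy as jnp

class OnesStateInit:  # hypothetical host class for the init method above
    def init(self, weights):
        return jnp.ones_like(weights)

state = OnesStateInit().init(jnp.zeros((4, 2)))
assert state.shape == (4, 2) and float(state.sum()) == 8.0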