Example #1
    def _test_lsh_self_attention_deterministic_given_seed(self, causal=False):
        # Once the initialization and call seeds are pinned down, the output
        # is deterministic.
        with math.use_backend('jax'):
            layer = efficient_attention.LSHSelfAttention(
                n_heads=5,
                d_qk=7,
                d_v=17,
                causal=causal,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=2,
                n_buckets=4,
                use_reference_code=True,
                attention_dropout=0.0,
                mode='train')
            x = np.ones((3, 32, 8)).astype(np.float32)

            def get_output():
                _, _ = layer.init(shapes.signature(x), jax.random.PRNGKey(0))
                return layer(x, rng=jax.random.PRNGKey(1))

            ys = [get_output() for _ in range(10)]

            self.assertEqual(ys[0].shape, x.shape)

            for y in ys[1:]:
                np.testing.assert_array_almost_equal(ys[0], y, decimal=6)
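These snippets omit their imports. For readers who want to run Example #1 outside the test class, a minimal self-contained sketch is below; it assumes the Trax package layout these tests appear to come from (trax.layers.research.efficient_attention, trax.shapes, trax.fastmath), which is an assumption rather than part of the original listing.

import jax
import numpy as np

from trax import fastmath
from trax import shapes
from trax.layers.research import efficient_attention

with fastmath.use_backend(fastmath.Backend.JAX):
    # Same hyperparameters as Example #1 (causal=False).
    layer = efficient_attention.LSHSelfAttention(
        n_heads=5, d_qk=7, d_v=17, causal=False,
        chunk_len=8, n_chunks_before=1, n_chunks_after=0,
        n_hashes=2, n_buckets=4,
        use_reference_code=True, attention_dropout=0.0, mode='train')

    x = np.ones((3, 32, 8), dtype=np.float32)

    # Pin both the initialization seed and the call seed so the output
    # is reproducible across runs.
    layer.init(shapes.signature(x), rng=jax.random.PRNGKey(0))
    y = layer(x, rng=jax.random.PRNGKey(1))
    assert y.shape == x.shape  # (3, 32, 8)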
Example #2
    def test_lsh_self_attention_masked_non_causal(self):
        # Test that changing the input inside the masked area does not change
        # the attention output for the un-masked positions, while the output
        # in the masked region does change.
        with math.use_backend('jax'):
            layer = efficient_attention.LSHSelfAttention(
                n_heads=5,
                d_qk=7,
                d_v=17,
                causal=False,
                masked=True,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=2,
                n_buckets=4,
                use_reference_code=True,
                attention_dropout=0.0,
                mode='train')

            batch = 5
            max_len = 32
            hidden = 8

            x = np.random.uniform(size=(batch, max_len, hidden))
            mask = np.ones((batch, max_len)).astype(bool)  # np.bool alias is removed in recent NumPy
            rngs = jax.random.randint(jax.random.PRNGKey(0), (batch, ),
                                      minval=1,
                                      maxval=max_len - 1)

            # Set some suffix of each mask[b] to 0.
            for i in range(batch):
                mask[i, rngs[i]:] = 0

            # Fix rngs and get the output for the LSH layer.
            def get_output(x, mask):
                xs = [x, mask]
                _, _ = layer.init(shapes.signature(xs), jax.random.PRNGKey(0))
                return layer(xs, rng=jax.random.PRNGKey(1))

            # Get the attention output for masked x.
            y = get_output(x, mask)

            # Change x, but only in the masked regions.
            for i in range(batch):
                x[i, rngs[i]:] = np.random.uniform(size=(max_len - rngs[i],
                                                         hidden))

            y2 = get_output(x, mask)

            for i in range(batch):
                # y and y2 should be identical in the non-masked part.
                np.testing.assert_array_almost_equal(y[i, :rngs[i]],
                                                     y2[i, :rngs[i]],
                                                     decimal=6)

                # In the masked out part, they should be different.
                self.assertGreater(
                    np.mean(np.abs(y[i, rngs[i]:] - y2[i, rngs[i]:])), 1e-5)
Example #3
  def test_lsh_self_attention_tf(self):
    with fastmath.use_backend(fastmath.Backend.TFNP):
      layer = efficient_attention.LSHSelfAttention(
          n_heads=5, d_qk=7, d_v=17, causal=True,
          chunk_len=8, n_chunks_before=1, n_chunks_after=0,
          n_hashes=2, n_buckets=4,
          use_reference_code=True, attention_dropout=0.0, mode='train')
      x = np.ones((3, 32, 8)).astype(np.float32)
      _, _ = layer.init(shapes.signature(x))
      y = layer(x)
      self.assertEqual(y.shape, x.shape)
Example #4
    def test_lsh_self_attention(self):
        with math.use_backend('jax'):
            input_signature = ShapeDtype((3, 32, 8))
            layer = efficient_attention.LSHSelfAttention(
                n_heads=5,
                d_qk=7,
                d_v=17,
                causal=True,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=2,
                n_buckets=4,
                use_reference_code=True,
                attention_dropout=0.0,
                mode='train')
            final_shape = base.check_shape_agreement(layer, input_signature)
            self.assertEqual((3, 32, 8), final_shape)
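Example #4 goes through the test helper base.check_shape_agreement. A rough manual equivalent using only the public init and call API is sketched below; it is not necessarily what the helper does internally, and it assumes the same Trax imports as the sketch after Example #1 plus the default JAX backend.

layer = efficient_attention.LSHSelfAttention(
    n_heads=5, d_qk=7, d_v=17, causal=True,
    chunk_len=8, n_chunks_before=1, n_chunks_after=0,
    n_hashes=2, n_buckets=4,
    use_reference_code=True, attention_dropout=0.0, mode='train')

# Describe the input by shape and dtype only; ShapeDtype defaults to float32.
sig = shapes.ShapeDtype((3, 32, 8))
layer.init(sig, rng=jax.random.PRNGKey(0))

# Feed a concrete array with the same shape/dtype and compare output shapes.
y = layer(np.zeros(sig.shape, dtype=sig.dtype), rng=jax.random.PRNGKey(1))
assert y.shape == (3, 32, 8)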
Example #5
    def test_lsh_and_pure_lsh_self_attention_equivalence(self):
        # Check that, given the same weight matrices and random numbers, the
        # two layers produce the same output.
        with fastmath.use_backend(fastmath.Backend.JAX):
            n_heads = 4
            d_head = 4
            d_model = n_heads * d_head
            pure_lsh_layer = efficient_attention.PureLSHSelfAttention(
                n_heads=n_heads,
                d_qk=d_head,
                d_v=d_head,
                causal=True,
                masked=False,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=4,
                n_buckets=8,
                use_reference_code=False,
                attention_dropout=0.0,
                use_python_loop=True,
                bias=False,
                mode='train')
            lsh_layer = efficient_attention.LSHSelfAttention(
                n_heads=n_heads,
                d_qk=d_head,
                d_v=d_head,
                causal=True,
                masked=False,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=4,
                n_buckets=8,
                use_reference_code=False,
                attention_dropout=0.0,
                use_python_loop=True,
                mode='train')

            batch, seqlen = 3, 32
            input_shape = (batch, seqlen, d_model)

            x = jax.random.uniform(jax.random.PRNGKey(0),
                                   input_shape,
                                   dtype=jnp.float32)
            lsh_layer_input = x

            call_rng = jax.random.PRNGKey(42)

            lsh_layer_weights, lsh_layer_state = lsh_layer.init(
                shapes.signature(lsh_layer_input))
            lsh_layer.rng = call_rng
            lsh_layer_output = lsh_layer(lsh_layer_input)

            # Shapes are: (n_heads, d_model, d_head), (n_heads, d_model, d_head),
            # (n_heads, d_head, d_model)
            # Abbreviated as - hmn, hmn, hnm
            w_qk, w_v, w_o = lsh_layer_weights

            qk = jnp.einsum('blm,hmn->bhln', x, w_qk)
            qk = qk.reshape((-1, qk.shape[2], qk.shape[3]))

            v = jnp.einsum('blm,hmn->bhln', x, w_v)
            v = v.reshape((-1, v.shape[2], v.shape[3]))

            pure_lsh_layer_input = (qk, v)
            _, _ = pure_lsh_layer.init(shapes.signature(pure_lsh_layer_input))
            pure_lsh_layer.rng = call_rng
            pure_lsh_layer.state = lsh_layer_state
            pure_lsh_layer_output = pure_lsh_layer(pure_lsh_layer_input)

            # Reshape (batch * n_heads, seqlen, d_head) back to
            # (batch, n_heads, seqlen, d_head).
            pure_lsh_layer_output = pure_lsh_layer_output.reshape(
                (batch, -1) + pure_lsh_layer_output.shape[1:])
            pure_lsh_layer_output_projected = (jnp.einsum(
                'bhld,hdm->blm', pure_lsh_layer_output, w_o))

            diff = pure_lsh_layer_output_projected - lsh_layer_output
            avg_diff = jnp.sum(jnp.abs(diff)) / jnp.sum(jnp.ones_like(diff))

            self.assertLess(avg_diff, 1e-5)
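The shape bookkeeping around the pure-LSH layer in Example #5 can be sanity-checked in isolation. The standalone NumPy sketch below (names are illustrative, and zeros stand in for real weights) mirrors the two einsum projections and the head folding/unfolding reshapes:

import numpy as np

batch, seqlen, n_heads, d_head = 3, 32, 4, 4
d_model = n_heads * d_head

x = np.zeros((batch, seqlen, d_model))       # 'blm'
w_qk = np.zeros((n_heads, d_model, d_head))  # 'hmn'
w_o = np.zeros((n_heads, d_head, d_model))   # 'hnm'

# Per-head projection, then fold heads into the batch dimension.
qk = np.einsum('blm,hmn->bhln', x, w_qk)     # (batch, n_heads, seqlen, d_head)
qk = qk.reshape((-1,) + qk.shape[2:])        # (batch * n_heads, seqlen, d_head)

# After the layer: unfold heads, then project back to d_model.
out = qk.reshape((batch, -1) + qk.shape[1:])  # (batch, n_heads, seqlen, d_head)
y = np.einsum('bhld,hdm->blm', out, w_o)      # (batch, seqlen, d_model)

assert y.shape == (batch, seqlen, d_model)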