def _test_lsh_self_attention_deterministic_given_seed(self, causal=False):
  # Once the initialization and the call seeds are pinned down we have
  # deterministic output.
  with fastmath.use_backend(fastmath.Backend.JAX):
    layer = efficient_attention.LSHSelfAttention(
        n_heads=5, d_qk=7, d_v=17, causal=causal,
        chunk_len=8, n_chunks_before=1, n_chunks_after=0,
        n_hashes=2, n_buckets=4,
        use_reference_code=True, attention_dropout=0.0, mode='train')
    x = np.ones((3, 32, 8)).astype(np.float32)

    def get_output():
      _, _ = layer.init(shapes.signature(x), jax.random.PRNGKey(0))
      return layer(x, rng=jax.random.PRNGKey(1))

    ys = [get_output() for _ in range(10)]

    self.assertEqual(ys[0].shape, x.shape)
    for y in ys[1:]:
      np.testing.assert_array_almost_equal(ys[0], y, decimal=6)

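# The helper above is parameterized on `causal` but is never invoked in this
# excerpt. A minimal sketch of the two wrapper tests that would exercise it
# follows; the test names are assumptions, not part of the original file.
def test_lsh_self_attention_deterministic_non_causal(self):
  self._test_lsh_self_attention_deterministic_given_seed(causal=False)

def test_lsh_self_attention_deterministic_causal(self):
  self._test_lsh_self_attention_deterministic_given_seed(causal=True)
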
def test_lsh_self_attention_masked_non_causal(self):
  # Test that changing the input in the masked region does not change the
  # attention output at the unmasked positions, while the output in the
  # masked region does change.
  with fastmath.use_backend(fastmath.Backend.JAX):
    layer = efficient_attention.LSHSelfAttention(
        n_heads=5, d_qk=7, d_v=17, causal=False, masked=True,
        chunk_len=8, n_chunks_before=1, n_chunks_after=0,
        n_hashes=2, n_buckets=4,
        use_reference_code=True, attention_dropout=0.0, mode='train')

    batch = 5
    max_len = 32
    hidden = 8

    x = np.random.uniform(size=(batch, max_len, hidden))
    mask = np.ones((batch, max_len)).astype(bool)
    # rngs[i] is the number of unmasked positions in example i.
    rngs = jax.random.randint(
        jax.random.PRNGKey(0), (batch,), minval=1, maxval=max_len - 1)
    # Set some suffix of each mask[b] to 0.
    for i in range(batch):
      mask[i, rngs[i]:] = 0

    # Fix rngs and get the output for the LSH layer.
    def get_output(x, mask):
      xs = [x, mask]
      _, _ = layer.init(shapes.signature(xs), jax.random.PRNGKey(0))
      return layer(xs, rng=jax.random.PRNGKey(1))

    # Get the attention output for masked x.
    y = get_output(x, mask)

    # Change x, but only in the masked regions.
    for i in range(batch):
      x[i, rngs[i]:] = np.random.uniform(size=(max_len - rngs[i], hidden))

    y2 = get_output(x, mask)

    for i in range(batch):
      # y and y2 should be identical in the non-masked part.
      np.testing.assert_array_almost_equal(
          y[i, :rngs[i]], y2[i, :rngs[i]], decimal=6)
      # In the masked out part, they should be different.
      self.assertGreater(
          np.mean(np.abs(y[i, rngs[i]:] - y2[i, rngs[i]:])), 1e-5)

def test_lsh_self_attention_tf(self):
  with fastmath.use_backend(fastmath.Backend.TFNP):
    layer = efficient_attention.LSHSelfAttention(
        n_heads=5, d_qk=7, d_v=17, causal=True,
        chunk_len=8, n_chunks_before=1, n_chunks_after=0,
        n_hashes=2, n_buckets=4,
        use_reference_code=True, attention_dropout=0.0, mode='train')
    x = np.ones((3, 32, 8)).astype(np.float32)
    _, _ = layer.init(shapes.signature(x))
    y = layer(x)
    self.assertEqual(y.shape, x.shape)

def test_lsh_self_attention(self):
  with fastmath.use_backend(fastmath.Backend.JAX):
    input_signature = ShapeDtype((3, 32, 8))
    layer = efficient_attention.LSHSelfAttention(
        n_heads=5, d_qk=7, d_v=17, causal=True,
        chunk_len=8, n_chunks_before=1, n_chunks_after=0,
        n_hashes=2, n_buckets=4,
        use_reference_code=True, attention_dropout=0.0, mode='train')
    final_shape = base.check_shape_agreement(layer, input_signature)
    self.assertEqual((3, 32, 8), final_shape)

def test_lsh_and_pure_lsh_self_attention_equivalence(self):
  # Given the same weight matrices and random seeds, check that
  # LSHSelfAttention and PureLSHSelfAttention produce the same output.
  with fastmath.use_backend(fastmath.Backend.JAX):
    n_heads = 4
    d_head = 4
    d_model = n_heads * d_head
    pure_lsh_layer = efficient_attention.PureLSHSelfAttention(
        n_heads=n_heads, d_qk=d_head, d_v=d_head,
        causal=True, masked=False,
        chunk_len=8, n_chunks_before=1, n_chunks_after=0,
        n_hashes=4, n_buckets=8,
        use_reference_code=False, attention_dropout=0.0,
        use_python_loop=True, bias=False, mode='train')
    lsh_layer = efficient_attention.LSHSelfAttention(
        n_heads=n_heads, d_qk=d_head, d_v=d_head,
        causal=True, masked=False,
        chunk_len=8, n_chunks_before=1, n_chunks_after=0,
        n_hashes=4, n_buckets=8,
        use_reference_code=False, attention_dropout=0.0,
        use_python_loop=True, mode='train')

    batch, seqlen = 3, 32
    input_shape = (batch, seqlen, d_model)

    x = jax.random.uniform(
        jax.random.PRNGKey(0), input_shape, dtype=jnp.float32)
    lsh_layer_input = x

    call_rng = jax.random.PRNGKey(42)

    lsh_layer_weights, lsh_layer_state = lsh_layer.init(
        shapes.signature(lsh_layer_input))
    lsh_layer.rng = call_rng
    lsh_layer_output = lsh_layer(lsh_layer_input)

    # Shapes are: (n_heads, d_model, d_head), (n_heads, d_model, d_head),
    # (n_heads, d_head, d_model), abbreviated as hmn, hmn, hnm.
    w_qk, w_v, w_o = lsh_layer_weights

    qk = jnp.einsum('blm,hmn->bhln', x, w_qk)
    qk = qk.reshape((-1, qk.shape[2], qk.shape[3]))

    v = jnp.einsum('blm,hmn->bhln', x, w_v)
    v = v.reshape((-1, v.shape[2], v.shape[3]))

    pure_lsh_layer_input = (qk, v)
    _, _ = pure_lsh_layer.init(shapes.signature(pure_lsh_layer_input))
    pure_lsh_layer.rng = call_rng
    pure_lsh_layer.state = lsh_layer_state
    pure_lsh_layer_output = pure_lsh_layer(pure_lsh_layer_input)

    # Output shape is (batch * n_heads, seqlen, d_head); split the heads back
    # out and apply the output projection.
    pure_lsh_layer_output = pure_lsh_layer_output.reshape(
        (batch, -1) + pure_lsh_layer_output.shape[1:])
    pure_lsh_layer_output_projected = (
        jnp.einsum('bhld,hdm->blm', pure_lsh_layer_output, w_o))

    diff = pure_lsh_layer_output_projected - lsh_layer_output
    avg_diff = jnp.sum(jnp.abs(diff)) / jnp.sum(jnp.ones_like(diff))

    self.assertLess(avg_diff, 1e-5)

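# A minimal sketch (not in the original tests; the test name is illustrative
# only) of the reshape convention used above: PureLSHSelfAttention consumes
# inputs with the batch and head axes merged, so (batch, n_heads, len, d_head)
# is flattened to (batch * n_heads, len, d_head) before the call and split back
# afterwards. The round trip below checks that no values are reordered.
def test_merge_and_split_heads_roundtrip_sketch(self):
  batch, n_heads, seqlen, d_head = 3, 4, 32, 4
  t = np.arange(batch * n_heads * seqlen * d_head, dtype=np.float32).reshape(
      (batch, n_heads, seqlen, d_head))
  merged = t.reshape((-1,) + t.shape[2:])  # (batch * n_heads, len, d_head)
  split = merged.reshape((batch, -1) + merged.shape[1:])
  np.testing.assert_array_equal(t, split)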