def test_self_attention_tf(self):
  # Run SelfAttention on the TensorFlow-numpy backend and check that the
  # output shape matches the input shape.
  with fastmath.use_backend(fastmath.Backend.TFNP):
    layer = efficient_attention.SelfAttention(
        n_heads=5, d_qk=7, d_v=17, share_qk=False, causal=True,
        chunk_len=8, n_chunks_before=1, n_chunks_after=0,
        use_reference_code=True, attention_dropout=0.0, mode='train')
    x = np.ones((3, 32, 8)).astype(np.float32)
    _, _ = layer.init(shapes.signature(x))
    y = layer(x)
    self.assertEqual(y.shape, x.shape)
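# A hedged variant sketch, not from the original file: the same configuration
# with use_reference_code=False, which should exercise the non-reference
# (memory-efficient) attention path. The test name and the assumption that
# this path accepts the identical keyword arguments are mine.
def test_self_attention_tf_fast_path(self):
  with fastmath.use_backend(fastmath.Backend.TFNP):
    layer = efficient_attention.SelfAttention(
        n_heads=5, d_qk=7, d_v=17, share_qk=False, causal=True,
        chunk_len=8, n_chunks_before=1, n_chunks_after=0,
        use_reference_code=False, attention_dropout=0.0, mode='train')
    x = np.ones((3, 32, 8)).astype(np.float32)
    _, _ = layer.init(shapes.signature(x))
    y = layer(x)
    # The optimized path should preserve the output shape of the reference.
    self.assertEqual(y.shape, x.shape)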
def test_self_attention(self):
  # Same layer configuration on the JAX backend, checked via the shape
  # agreement helper. (Unified on fastmath; the older `math` module alias
  # is not used elsewhere in this file.)
  with fastmath.use_backend(fastmath.Backend.JAX):
    input_signature = ShapeDtype((3, 32, 8))
    layer = efficient_attention.SelfAttention(
        n_heads=5, d_qk=7, d_v=17, share_qk=False, causal=True,
        chunk_len=8, n_chunks_before=1, n_chunks_after=0,
        use_reference_code=True, attention_dropout=0.0, mode='train')
    final_shape = base.check_shape_agreement(layer, input_signature)
    self.assertEqual((3, 32, 8), final_shape)
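# A hedged sketch, not from the original file: the same layer with
# share_qk=True, tying the query and key projections as in LSH-style
# attention. The test name is mine; it assumes SelfAttention supports
# share_qk=True with this configuration, as its signature suggests.
def test_self_attention_shared_qk(self):
  with fastmath.use_backend(fastmath.Backend.JAX):
    layer = efficient_attention.SelfAttention(
        n_heads=5, d_qk=7, d_v=17, share_qk=True, causal=True,
        chunk_len=8, n_chunks_before=1, n_chunks_after=0,
        use_reference_code=True, attention_dropout=0.0, mode='train')
    x = np.ones((3, 32, 8)).astype(np.float32)
    _, _ = layer.init(shapes.signature(x))
    y = layer(x)
    # Sharing QK changes the projections, not the output shape.
    self.assertEqual(y.shape, x.shape)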