Example #1
import numpy as np

from absl.testing import absltest

from trax import fastmath
from trax import shapes
from trax.layers import base
from trax.layers.research import efficient_attention


# Test-class wrapper added so both snippets run as a standalone file.
class SelfAttentionTest(absltest.TestCase):

  def test_self_attention_tf(self):
    # Exercise the layer under the TensorFlow-NumPy backend.
    with fastmath.use_backend(fastmath.Backend.TFNP):
      layer = efficient_attention.SelfAttention(
          n_heads=5, d_qk=7, d_v=17, share_qk=False, causal=True,
          chunk_len=8, n_chunks_before=1, n_chunks_after=0,
          use_reference_code=True, attention_dropout=0.0, mode='train')
      x = np.ones((3, 32, 8)).astype(np.float32)
      _, _ = layer.init(shapes.signature(x))  # init returns (weights, state)
      y = layer(x)
      # Self-attention preserves the (batch, seq_len, d_model) input shape.
      self.assertEqual(y.shape, x.shape)
  def test_self_attention(self):
    # Same layer under the JAX backend. fastmath replaces the older
    # trax.math alias; check_shape_agreement runs init plus a forward
    # pass and returns the output shape.
    with fastmath.use_backend('jax'):
      input_signature = shapes.ShapeDtype((3, 32, 8))
      layer = efficient_attention.SelfAttention(
          n_heads=5, d_qk=7, d_v=17, share_qk=False, causal=True,
          chunk_len=8, n_chunks_before=1, n_chunks_after=0,
          use_reference_code=True, attention_dropout=0.0, mode='train')
      final_shape = base.check_shape_agreement(layer, input_signature)
      self.assertEqual((3, 32, 8), final_shape)


if __name__ == '__main__':
  absltest.main()
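Outside the test harness, the same pattern is just init-then-call. Below is a minimal sketch reusing the constructor arguments and input shape from the tests above; the chunking comments describe how this SelfAttention variant restricts attention to nearby chunks, and the shape values are illustrative rather than requirements of the layer.

import numpy as np

from trax import shapes
from trax.layers.research import efficient_attention

layer = efficient_attention.SelfAttention(
    n_heads=5, d_qk=7, d_v=17, share_qk=False, causal=True,
    chunk_len=8,        # attention is computed within chunks of 8 positions
    n_chunks_before=1,  # each chunk may also attend to the previous chunk
    n_chunks_after=0,   # causal mode: no attention to later chunks
    use_reference_code=True, attention_dropout=0.0, mode='train')

x = np.ones((3, 32, 8), dtype=np.float32)  # (batch, seq_len, d_model)
layer.init(shapes.signature(x))            # create weights and state
y = layer(x)                               # y.shape == (3, 32, 8)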