def test_shapes_single(self, seq_len, embed_size, d_key, num_heads, d_value, d_out):
    # Unbatched inputs: query, key, and value share the shape (seq_len, embed_size).
    query = key = value = jnp.zeros((seq_len, embed_size))
    mha = attention.MultiHeadAttention(
        key_size=d_key,
        num_heads=num_heads,
        value_size=d_value,
        model_size=d_out,
        w_init_scale=1.0)(query, key, value)
    self.assertEqual(mha.shape, (seq_len, d_out))
def test_different_seq_lengths(self):
    # The query may be shorter than key/value; the output follows the query length.
    query = jnp.zeros((2, 3))
    key = value = jnp.zeros((5, 3))
    mha = attention.MultiHeadAttention(
        key_size=7,
        num_heads=11,
        value_size=13,
        model_size=15,
        w_init_scale=1.0)(query, key, value)
    self.assertEqual(mha.shape, (2, 15))
def test_mask_arg(self):
    seq_len = 3
    embed_size = 2
    model_size = 15
    query = key = value = jnp.zeros((seq_len, embed_size))
    # Lower-triangular causal mask, with a leading axis that broadcasts over heads.
    causal_mask = jnp.tril(jnp.ones((seq_len, seq_len)))
    causal_mask = causal_mask[None, :, :]
    mha = attention.MultiHeadAttention(
        key_size=7,
        num_heads=11,
        value_size=13,
        model_size=model_size,
        w_init_scale=1.0)(query, key, value, mask=causal_mask)
    self.assertEqual(mha.shape, (seq_len, model_size))
def f(query, key, value):
    return attention.MultiHeadAttention(
        key_size=3, num_heads=5, w_init_scale=1.0)(query, key, value)
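# A minimal sketch (not part of the original tests) of how a standalone module
# function like `f` above could be initialised and applied, assuming this is
# Haiku's MultiHeadAttention and that `jax`/`haiku` are importable as below.
# The input shapes and RNG seed are illustrative only.
def example_run_f():
    import jax
    import jax.numpy as jnp
    import haiku as hk

    transformed = hk.transform(f)
    query = key = value = jnp.zeros((4, 3))
    rng = jax.random.PRNGKey(42)
    params = transformed.init(rng, query, key, value)
    # With key_size=3 and num_heads=5, model_size defaults to 3 * 5 = 15.
    out = transformed.apply(params, rng, query, key, value)
    return out.shape  # (4, 15)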
def test_default_sizes(self):
    mha = attention.MultiHeadAttention(key_size=3, num_heads=5, w_init_scale=1.0)
    # When unspecified, value_size defaults to key_size and
    # model_size defaults to key_size * num_heads.
    self.assertEqual(mha.value_size, mha.key_size)
    self.assertEqual(mha.model_size, mha.key_size * mha.num_heads)
def test_shapes(self, batch_size, seq_len, embed_size, d_key, num_heads):
    # Batched inputs; with no explicit model_size the output width defaults to
    # key_size * num_heads.
    query = key = value = np.zeros((batch_size, seq_len, embed_size))
    mha = attention.MultiHeadAttention(d_key, num_heads, 1.0)(query, key, value)
    self.assertEqual(mha.shape, (batch_size, seq_len, d_key * num_heads))