Esempio n. 1
0
 def test_shapes_single(self, seq_len, embed_size, d_key, num_heads,
                        d_value, d_out):
     """Self-attention on 2-D (seq_len, embed_size) inputs yields (seq_len, d_out)."""
     inputs = jnp.zeros((seq_len, embed_size))
     module = attention.MultiHeadAttention(
         key_size=d_key,
         num_heads=num_heads,
         value_size=d_value,
         model_size=d_out,
         w_init_scale=1.0,
     )
     # The same array serves as query, key and value (self-attention).
     output = module(inputs, inputs, inputs)
     expected_shape = (seq_len, d_out)
     self.assertEqual(output.shape, expected_shape)
Esempio n. 2
0
 def test_different_seq_lengths(self):
     """Query and key/value inputs may have different sequence lengths."""
     query_len, kv_len, embed_size, model_size = 2, 5, 3, 15
     query = jnp.zeros((query_len, embed_size))
     key = value = jnp.zeros((kv_len, embed_size))
     module = attention.MultiHeadAttention(
         key_size=7,
         num_heads=11,
         value_size=13,
         model_size=model_size,
         w_init_scale=1.0,
     )
     output = module(query, key, value)
     # The output length follows the query (2), not the key/value length (5).
     self.assertEqual(output.shape, (query_len, model_size))
Esempio n. 3
0
    def test_mask_arg(self):
        """Passing an explicit causal `mask` keeps the output shape intact."""
        seq_len, embed_size, model_size = 3, 2, 15
        inputs = jnp.zeros((seq_len, embed_size))
        # Lower-triangular ones matrix with a leading singleton axis, giving
        # shape (1, seq_len, seq_len); presumably that axis broadcasts across
        # heads -- confirm against MultiHeadAttention.
        causal_mask = jnp.tril(jnp.ones((seq_len, seq_len)))[None]

        module = attention.MultiHeadAttention(
            key_size=7,
            num_heads=11,
            value_size=13,
            model_size=model_size,
            w_init_scale=1.0,
        )
        output = module(inputs, inputs, inputs, mask=causal_mask)
        self.assertEqual(output.shape, (seq_len, model_size))
Esempio n. 4
0
 def f(query, key, value):
     """Apply a MultiHeadAttention (key_size=3, num_heads=5) to the inputs.

     value_size and model_size are deliberately left unset.
     """
     mha = attention.MultiHeadAttention(
         key_size=3,
         num_heads=5,
         w_init_scale=1.0,
     )
     return mha(query, key, value)
Esempio n. 5
0
 def test_default_sizes(self):
     """Unset value_size/model_size are derived from key_size and num_heads."""
     mha = attention.MultiHeadAttention(
         key_size=3,
         num_heads=5,
         w_init_scale=1.0,
     )
     expected_value_size = mha.key_size
     expected_model_size = mha.key_size * mha.num_heads
     self.assertEqual(mha.value_size, expected_value_size)
     self.assertEqual(mha.model_size, expected_model_size)
Esempio n. 6
0
 def test_shapes(self, batch_size, seq_len, embed_size, d_key, num_heads):
     """Batched self-attention yields (batch_size, seq_len, d_key * num_heads).

     No model_size is given, so the default model size
     key_size * num_heads is expected in the last output axis.
     """
     query = key = value = np.zeros((batch_size, seq_len, embed_size))
     # Sizes are passed by keyword, as in the other tests. The expected model
     # size d_key * num_heads is symmetric, so if key_size and num_heads were
     # swapped positionally the assertion below could not detect it.
     mha = attention.MultiHeadAttention(key_size=d_key,
                                        num_heads=num_heads,
                                        w_init_scale=1.0)(query, key, value)
     self.assertEqual(mha.shape, (batch_size, seq_len, d_key * num_heads))