Example #1
    def test_odd_embed_dim(self):
        # odd embedding_dim is allowed
        SinusoidalPositionalEmbedding(
            num_positions=4, embedding_dim=5,
            padding_idx=self.padding_idx).to(torch_device)

        # odd num_positions is allowed
        SinusoidalPositionalEmbedding(
            num_positions=5, embedding_dim=4,
            padding_idx=self.padding_idx).to(torch_device)
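The test above only checks that construction succeeds for an odd embedding_dim and for an odd num_positions. For context, here is a minimal sketch of the table construction these tests exercise, assuming the fairseq/Marian recipe (sines in the first half of the dimensions, cosines in the second half, and a zero column appended when embedding_dim is odd). build_sinusoidal_table is a hypothetical helper for illustration, not part of the tested module:

import math
import torch

def build_sinusoidal_table(num_positions, embedding_dim, padding_idx=None):
    # Assumption: mirrors the fairseq-style sinusoidal recipe, for illustration only.
    half_dim = embedding_dim // 2
    # Geometric progression of frequencies, one per sine/cosine column.
    freq = torch.exp(torch.arange(half_dim, dtype=torch.float) * -(math.log(10000.0) / (half_dim - 1)))
    args = torch.arange(num_positions, dtype=torch.float).unsqueeze(1) * freq.unsqueeze(0)
    table = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if embedding_dim % 2 == 1:
        # Odd embedding_dim: pad with a zero column, which is why an odd value is allowed.
        table = torch.cat([table, torch.zeros(num_positions, 1)], dim=1)
    if padding_idx is not None:
        table[padding_idx, :] = 0
    return table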
Example #2
    def test_positional_emb_weights_against_marian(self):

        desired_weights = torch.tensor([
            [0, 0, 0, 0, 0],
            [0.84147096, 0.82177866, 0.80180490, 0.78165019, 0.76140374],
            [0.90929741, 0.93651021, 0.95829457, 0.97505713, 0.98720258],
        ])
        emb1 = SinusoidalPositionalEmbedding(
            num_positions=512, embedding_dim=512,
            padding_idx=self.padding_idx).to(torch_device)
        weights = emb1.weights.data[:3, :5]
        # XXX: only the 1st and 3rd lines match - this is testing against
        # verbatim copy of SinusoidalPositionalEmbedding from fairseq
        self.assertTrue(
            torch.allclose(weights, desired_weights, atol=self.tolerance),
            msg=f"\nexp:\n{desired_weights}\ngot:\n{weights}\n",
        )

        # test that forward pass is just a lookup, there is no ignore padding logic
        input_ids = torch.tensor(
            [[4, 10, self.padding_idx, self.padding_idx, self.padding_idx]],
            dtype=torch.long,
            device=torch_device)
        no_cache_pad_zero = emb1(input_ids)[0]
        # XXX: only the 1st line matches the 3rd
        self.assertTrue(
            torch.allclose(desired_weights.to(torch_device),
                           no_cache_pad_zero[:3, :5],
                           atol=1e-3))
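The hard-coded desired_weights rows can be reproduced from the same recipe. A quick check using the hypothetical build_sinusoidal_table sketched after Example #1 (still assuming the fairseq formulation):

table = build_sinusoidal_table(num_positions=512, embedding_dim=512)
print(table[1, :5])  # ~[0.8415, 0.8218, 0.8018, 0.7817, 0.7614]
print(table[2, :5])  # ~[0.9093, 0.9365, 0.9583, 0.9751, 0.9872]

Row 0 is all zeros in the first five columns because those columns are sines and sin(0) = 0.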
Example #3
    def test_basic(self):
        input_ids = torch.tensor([[4, 10]], dtype=torch.long, device=torch_device)
        emb1 = SinusoidalPositionalEmbedding(num_positions=6, embedding_dim=6, padding_idx=self.padding_idx).to(
            torch_device
        )
        emb = emb1(input_ids)
        desired_weights = torch.tensor(
            [
                [9.0930e-01, 1.9999e-02, 2.0000e-04, -4.1615e-01, 9.9980e-01, 1.0000e00],
                [1.4112e-01, 2.9995e-02, 3.0000e-04, -9.8999e-01, 9.9955e-01, 1.0000e00],
            ]
        ).to(torch_device)
        self.assertTrue(
            torch.allclose(emb[0], desired_weights, atol=self.tolerance),
            msg=f"\nexp:\n{desired_weights}\ngot:\n{emb[0]}\n",
        )
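With embedding_dim=6 the table has half_dim=3, so columns 0-2 are sines and columns 3-5 are cosines. The two expected rows correspond to table positions 2 and 3; the forward pass appears to assign non-padding tokens positions starting right after the padding index, so the exact rows depend on self.padding_idx (the indices below are an assumption, not a documented guarantee):

table = build_sinusoidal_table(num_positions=6, embedding_dim=6)  # hypothetical helper from Example #1
print(table[2])  # ~[ 0.9093, 0.0200, 0.0002, -0.4161, 0.9998, 1.0000]
print(table[3])  # ~[ 0.1411, 0.0300, 0.0003, -0.9900, 0.9996, 1.0000]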