Example 1: an odd embedding_dim raises NotImplementedError, while an odd num_positions is allowed.
    def test_odd_embed_dim(self):
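        # an odd embedding_dim is not supported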
        with self.assertRaises(NotImplementedError):
            SinusoidalPositionalEmbedding(num_positions=4,
                                          embedding_dim=5,
                                          padding_idx=0).to(torch_device)

        # odd num_positions is allowed
        SinusoidalPositionalEmbedding(num_positions=5,
                                      embedding_dim=4,
                                      padding_idx=0).to(torch_device)
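
For orientation, here is a minimal usage sketch of the even-dimension case that the test above allows. It is an illustration only: the import and the concrete sizes are assumptions, and the call signature is copied from Examples 4 and 5 below.

import torch

# Assumes SinusoidalPositionalEmbedding has been imported from the module that defines it
# in this codebase.
emb = SinusoidalPositionalEmbedding(num_positions=8, embedding_dim=6, padding_idx=0)
input_ids = torch.tensor([[5, 7, 0]], dtype=torch.long)
out = emb(input_ids, use_cache=False)  # a plain position lookup (see Example 5),
                                       # independent of token values and padding
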
Example 2: an encoder __init__ that picks sinusoidal or learned positional embeddings based on config.static_position_embeddings.
    def __init__(self, config: BartConfig, embed_tokens):
        super().__init__(config, embed_tokens)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop
        self.visual = None

        embed_dim = embed_tokens.embedding_dim
        self.embed_scale = math.sqrt(
            embed_dim) if config.scale_embedding else 1.0
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = config.max_position_embeddings

        self.embed_tokens = embed_tokens
        if config.static_position_embeddings:
            self.embed_positions = SinusoidalPositionalEmbedding(
                config.max_position_embeddings, embed_dim, self.padding_idx)
        else:
            self.embed_positions = LearnedPositionalEmbedding(
                config.max_position_embeddings,
                embed_dim,
                self.padding_idx,
                config.extra_pos_embeddings,
            )
        self.layers = nn.ModuleList(
            [EncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layernorm_embedding = LayerNorm(
            embed_dim) if config.normalize_embedding else nn.Identity()
        # mbart has one extra layer_norm
        self.layer_norm = LayerNorm(
            config.d_model) if config.normalize_before else None
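
As a side note, config.static_position_embeddings is the only switch between the two positional-embedding variants above. A minimal sketch of toggling it, assuming the older-style BartConfig these snippets rely on (the field names below are taken from the code above, not from the current transformers API):

from transformers import BartConfig

config = BartConfig()                      # defaults of whatever version exposes these fields
config.static_position_embeddings = True   # the encoder builds SinusoidalPositionalEmbedding(...)
config.static_position_embeddings = False  # ... or LearnedPositionalEmbedding, which here also
                                           # receives config.extra_pos_embeddings as a fourth argument
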
Example 3: the matching decoder __init__, with the same positional-embedding switch.
    def __init__(self, config: BartConfig, embed_tokens: nn.Embedding):
        super().__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0
        self.embed_tokens = embed_tokens
        if config.static_position_embeddings:
            self.embed_positions = SinusoidalPositionalEmbedding(
                config.max_position_embeddings, config.d_model,
                config.pad_token_id)
        else:
            self.embed_positions = LearnedPositionalEmbedding(
                config.max_position_embeddings,
                config.d_model,
                self.padding_idx,
            )
        self.layers = nn.ModuleList([
            DecoderLayer(config) for _ in range(config.decoder_layers)
        ])  # type: List[DecoderLayer]
        self.layernorm_embedding = LayerNorm(
            config.d_model) if config.normalize_embedding else nn.Identity()
        self.layer_norm = LayerNorm(
            config.d_model) if config.add_final_layer_norm else None
Example 4: with use_cache=True the embedding returns only the last position, shaped for broadcasting.
    def test_positional_emb_cache_logic(self):
        pad = 1
        input_ids = torch.tensor([[4, 10]], dtype=torch.long, device=torch_device)
        emb1 = SinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6, padding_idx=pad).to(torch_device)
        no_cache = emb1(input_ids, use_cache=False)
        yes_cache = emb1(input_ids, use_cache=True)
        self.assertEqual((1, 1, 6), yes_cache.shape)  # extra dim to allow broadcasting, feel free to delete!
        self.assertListEqual(no_cache[-1].tolist(), yes_cache[0][0].tolist())
Example 5: the sinusoidal weights match Marian reference values, and the forward pass is a plain lookup with no padding logic.
    def test_positional_emb_weights_against_marian(self):
        pad = 1
        emb1 = SinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512, padding_idx=pad).to(torch_device)
        weights = emb1.weight.data[:3, :5].tolist()
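        # self.desired_weights is defined elsewhere on the test class (reference values from Marian)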
        for i, (expected_weight, actual_weight) in enumerate(zip(self.desired_weights, weights)):
            for j in range(5):
                self.assertAlmostEqual(expected_weight[j], actual_weight[j], places=3)

        # test that forward pass is just a lookup, there is no ignore padding logic
        input_ids = torch.tensor([[4, 10, pad, pad, pad]], dtype=torch.long, device=torch_device)
        no_cache_pad_zero = emb1(input_ids)
        self.assertTrue(
            torch.allclose(
                torch.tensor(self.desired_weights, device=torch_device), no_cache_pad_zero[:3, :5], atol=1e-3
            )
        )