Example #1

import unittest

# Import path assumes Facebook's PyText (pytext.models.embeddings).
from pytext.models.embeddings import CharacterEmbedding, EmbeddingList, WordEmbedding


class EmbeddingListTest(unittest.TestCase):
    def test_get_param_groups_for_optimizer(self):
        # A concatenated EmbeddingList should expose a single optimizer
        # param group covering every parameter of its sub-embeddings.
        word_embedding = WordEmbedding(
            num_embeddings=5,
            embedding_dim=4,
            embeddings_weight=None,
            init_range=(-1, 1),
            unk_token_idx=4,
            pad_token_idx=3,
            mlp_layer_dims=[],
        )
        char_embedding = CharacterEmbedding(num_embeddings=5,
                                            embed_dim=4,
                                            out_channels=2,
                                            kernel_sizes=[1, 2])
        embedding_list = EmbeddingList([word_embedding, char_embedding],
                                       concat=True)
        param_groups = embedding_list.get_param_groups_for_optimizer()
        self.assertEqual(len(param_groups), 1)

        param_names = set(param_groups[0].keys())
        expected_param_names = {
            name
            for name, _ in embedding_list.named_parameters()
        }
        self.assertSetEqual(param_names, expected_param_names)
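
The assertions above imply that each returned group is a dict mapping parameter names to tensors. As a minimal sketch of consuming those groups (not PyText's own training loop; the learning rate is illustrative), the names can be dropped to build torch.optim param groups:

import torch
from pytext.models.embeddings import EmbeddingList, WordEmbedding

embedding_list = EmbeddingList(
    [WordEmbedding(num_embeddings=5, embedding_dim=4, embeddings_weight=None,
                   init_range=(-1, 1), unk_token_idx=4, mlp_layer_dims=[])],
    concat=True,
)
param_groups = embedding_list.get_param_groups_for_optimizer()
# Each group maps parameter names to tensors; torch.optim expects a
# "params" list per group, so the names are dropped here.
optimizer = torch.optim.Adam(
    [{"params": list(group.values())} for group in param_groups],
    lr=1e-3,  # illustrative value
)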
Example #2

import unittest

from pytext.models.embeddings import WordEmbedding


class WordEmbeddingTest(unittest.TestCase):
    def test_empty_mlp_layer_dims(self):
        # With no MLP projection layers, the module's output dimension
        # equals the raw embedding dimension.
        num_embeddings = 5
        embedding_dim = 4
        embedding_module = WordEmbedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            embeddings_weight=None,
            init_range=[-1, 1],
            unk_token_idx=4,
            mlp_layer_dims=[],
        )
        self.assertEqual(embedding_module.embedding_dim, embedding_dim)
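
Following the same pattern as Example #3 below, a forward pass through this unprojected module should yield tensors whose last dimension equals the raw embedding_dim. A quick sketch under the same assumptions:

import torch
from pytext.models.embeddings import WordEmbedding

module = WordEmbedding(
    num_embeddings=5, embedding_dim=4, embeddings_weight=None,
    init_range=[-1, 1], unk_token_idx=4, mlp_layer_dims=[],
)
token_ids = torch.randint(low=0, high=5, size=(2, 3))  # (batch, seq_len)
print(module(token_ids).shape)  # expected: torch.Size([2, 3, 4])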
Example #3

import unittest

import torch
from pytext.models.embeddings import WordEmbedding


class WordEmbeddingTest(unittest.TestCase):
    def test_basic(self):
        # Setup embedding
        num_embeddings = 5
        output_dim = 6
        embedding_module = WordEmbedding(
            num_embeddings=num_embeddings,
            embedding_dim=4,
            embeddings_weight=None,
            init_range=[-1, 1],
            unk_token_idx=4,
            mlp_layer_dims=[3, output_dim],
        )
        self.assertEqual(embedding_module.embedding_dim, output_dim)

        # Check output shape
        input_batch_size, input_len = 4, 6
        token_ids = torch.randint(
            low=0, high=num_embeddings, size=[input_batch_size, input_len]
        )
        output_embedding = embedding_module(token_ids)
        expected_output_dims = [input_batch_size, input_len, output_dim]
        self.assertEqual(list(output_embedding.size()), expected_output_dims)
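
Why does embedding_dim report 6 here? mlp_layer_dims=[3, output_dim] chains linear projections after the 4-dim lookup (4 -> 3 -> 6), so the module's reported output width is the last MLP dim, as the first assertion confirms. A rough structural sketch, not PyText's actual implementation (the activation choice is an assumption):

import torch.nn as nn

# Approximate shape-equivalent of the WordEmbedding above:
# a 4-dim lookup table followed by a 4 -> 3 -> 6 MLP.
sketch = nn.Sequential(
    nn.Embedding(5, 4),
    nn.Linear(4, 3),
    nn.ReLU(),  # assumed activation
    nn.Linear(3, 6),
)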