def test_get_param_groups_for_optimizer(self):
    word_embedding = WordEmbedding(
        num_embeddings=5,
        embedding_dim=4,
        embeddings_weight=None,
        init_range=(-1, 1),
        unk_token_idx=4,
        pad_token_idx=3,
        mlp_layer_dims=[],
    )
    char_embedding = CharacterEmbedding(
        num_embeddings=5, embed_dim=4, out_channels=2, kernel_sizes=[1, 2]
    )
    embedding_list = EmbeddingList([word_embedding, char_embedding], concat=True)
    param_groups = embedding_list.get_param_groups_for_optimizer()

    # All parameters should land in a single group, keyed by parameter name.
    self.assertEqual(len(param_groups), 1)
    param_names = set(param_groups[0].keys())
    expected_param_names = {name for name, _ in embedding_list.named_parameters()}
    self.assertSetEqual(param_names, expected_param_names)
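# Hedged sketch, not in the original suite: this assumes that with concat=True
# EmbeddingList reports an embedding_dim equal to the sum of its children's
# dims (word: 4 with no MLP; char: assumed out_channels * len(kernel_sizes) = 4
# here). Asserting against the children's own embedding_dim attributes keeps
# the check valid even if the char-side arithmetic assumption is off.
def test_embedding_list_concat_dim(self):
    word_embedding = WordEmbedding(
        num_embeddings=5,
        embedding_dim=4,
        embeddings_weight=None,
        init_range=(-1, 1),
        unk_token_idx=4,
        pad_token_idx=3,
        mlp_layer_dims=[],
    )
    char_embedding = CharacterEmbedding(
        num_embeddings=5, embed_dim=4, out_channels=2, kernel_sizes=[1, 2]
    )
    embedding_list = EmbeddingList([word_embedding, char_embedding], concat=True)
    self.assertEqual(
        embedding_list.embedding_dim,
        word_embedding.embedding_dim + char_embedding.embedding_dim,
    )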
def test_empty_mlp_layer_dims(self):
    num_embeddings = 5
    embedding_dim = 4
    embedding_module = WordEmbedding(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        embeddings_weight=None,
        init_range=[-1, 1],
        unk_token_idx=4,
        mlp_layer_dims=[],
    )
    # With no MLP layers on top, the reported dim is the raw embedding dim.
    self.assertEqual(embedding_module.embedding_dim, embedding_dim)
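# Hedged sketch, not in the original suite: assumes that with mlp_layer_dims=[]
# the forward output's last dimension also equals the raw embedding_dim,
# mirroring the shape check in test_basic below. The module is rebuilt here so
# the test stands alone.
def test_empty_mlp_output_shape(self):
    num_embeddings, embedding_dim = 5, 4
    embedding_module = WordEmbedding(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        embeddings_weight=None,
        init_range=[-1, 1],
        unk_token_idx=4,
        mlp_layer_dims=[],
    )
    token_ids = torch.randint(low=0, high=num_embeddings, size=[2, 3])
    output_embedding = embedding_module(token_ids)
    self.assertEqual(list(output_embedding.size()), [2, 3, embedding_dim])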
def test_basic(self):
    # Setup embedding
    num_embeddings = 5
    output_dim = 6
    embedding_module = WordEmbedding(
        num_embeddings=num_embeddings,
        embedding_dim=4,
        embeddings_weight=None,
        init_range=[-1, 1],
        unk_token_idx=4,
        mlp_layer_dims=[3, output_dim],
    )
    self.assertEqual(embedding_module.embedding_dim, output_dim)

    # Check output shape
    input_batch_size, input_len = 4, 6
    token_ids = torch.randint(
        low=0, high=num_embeddings, size=[input_batch_size, input_len]
    )
    output_embedding = embedding_module(token_ids)
    expected_output_dims = [input_batch_size, input_len, output_dim]
    self.assertEqual(list(output_embedding.size()), expected_output_dims)
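# Hedged sketch, not in the original suite: assumes EmbeddingList.forward
# accepts one input tensor per child embedding (word token ids of shape
# [batch, num_tokens], then char token ids of shape
# [batch, num_tokens, max_word_len]) and, with concat=True, returns a tensor
# of shape [batch, num_tokens, embedding_list.embedding_dim].
def test_embedding_list_forward_shape(self):
    word_embedding = WordEmbedding(
        num_embeddings=5,
        embedding_dim=4,
        embeddings_weight=None,
        init_range=(-1, 1),
        unk_token_idx=4,
        pad_token_idx=3,
        mlp_layer_dims=[],
    )
    char_embedding = CharacterEmbedding(
        num_embeddings=5, embed_dim=4, out_channels=2, kernel_sizes=[1, 2]
    )
    embedding_list = EmbeddingList([word_embedding, char_embedding], concat=True)

    batch_size, num_tokens, max_word_len = 2, 3, 4
    token_ids = torch.randint(low=0, high=5, size=[batch_size, num_tokens])
    char_ids = torch.randint(
        low=0, high=5, size=[batch_size, num_tokens, max_word_len]
    )
    output = embedding_list(token_ids, char_ids)
    self.assertEqual(
        list(output.size()),
        [batch_size, num_tokens, embedding_list.embedding_dim],
    )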