Ejemplo n.º 1
0
 def test_get_all_entities(self):
     embeddings = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
     module = SimpleEmbedding(weight=embeddings)
     self.assertTensorEqual(
         module.get_all_entities(),
         torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]),
     )
Ejemplo n.º 2
0
 def test_sample_entities(self):
     torch.manual_seed(42)
     embeddings = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
     module = SimpleEmbedding(weight=embeddings)
     self.assertTensorEqual(
         module.sample_entities(2, 2),
         torch.tensor(
             [[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]], [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]]
         ),
     )
Ejemplo n.º 3
0
 def test_sample_entities_max_norm(self):
     torch.manual_seed(42)
     embeddings = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0],
                                [3.0, 3.0, 3.0]])
     module = SimpleEmbedding(weight=embeddings, max_norm=2)
     self.assertTensorEqual(
         module.sample_entities(2, 2),
         torch.tensor([
             [[1.0000, 1.0000, 1.0000], [1.1547, 1.1547, 1.1547]],
             [[1.1547, 1.1547, 1.1547], [1.1547, 1.1547, 1.1547]],
         ]),
     )
Ejemplo n.º 4
0
 def test_get_all_entities_max_norm(self):
     embeddings = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0],
                                [3.0, 3.0, 3.0]])
     module = SimpleEmbedding(weight=embeddings, max_norm=2)
     self.assertTensorEqual(
         module.get_all_entities(),
         torch.tensor([
             [1.0000, 1.0000, 1.0000],
             [1.1547, 1.1547, 1.1547],
             [1.1547, 1.1547, 1.1547],
         ]),
     )
Ejemplo n.º 5
0
 def test_empty(self):
     embeddings = torch.empty((0, 3))
     module = SimpleEmbedding(weight=embeddings)
     self.assertTensorEqual(
         module(EntityList.from_tensor(torch.empty((0, ),
                                                   dtype=torch.long))),
         torch.empty((0, 3)))
Ejemplo n.º 6
0
 def test_forward(self):
     embeddings = torch.tensor(
         [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], requires_grad=True
     )
     module = SimpleEmbedding(weight=embeddings)
     result = module(EntityList.from_tensor(torch.tensor([2, 0, 0])))
     self.assertTensorEqual(
         result, torch.tensor([[3.0, 3.0, 3.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
     )
     result.sum().backward()
     self.assertTrue((embeddings.grad.to_dense() != 0).any())
Ejemplo n.º 7
0
 def test_max_norm(self):
     embeddings = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0],
                                [3.0, 3.0, 3.0]])
     module = SimpleEmbedding(weight=embeddings, max_norm=2)
     self.assertTensorEqual(
         module(EntityList.from_tensor(torch.tensor([2, 0, 0]))),
         torch.tensor([
             [1.1547, 1.1547, 1.1547],
             [1.0000, 1.0000, 1.0000],
             [1.0000, 1.0000, 1.0000],
         ]),
     )