Example #1
0
 def test_get_all_entities(self):
     """get_all_entities() without max_norm returns the weight rows as given."""
     rows = [
         [1.0, 1.0, 1.0],
         [2.0, 2.0, 2.0],
         [3.0, 3.0, 3.0],
     ]
     embedding = SimpleEmbedding(weight=torch.tensor(rows))
     actual = embedding.get_all_entities()
     # A separate tensor with the same values, so the comparison is by
     # value rather than against the weight object itself.
     expected = torch.tensor(rows)
     self.assertTensorEqual(actual, expected)
Example #2
0
 def test_get_all_entities_max_norm(self):
     """With max_norm=2, rows whose norm exceeds 2 are expected rescaled.

     Row [1, 1, 1] has L2 norm sqrt(3) ~= 1.732, which is below 2, and is
     expected unchanged. Rows [2, 2, 2] (norm ~= 3.464) and [3, 3, 3]
     (norm ~= 5.196) are expected to be scaled down to norm 2, i.e.
     2 / sqrt(3) ~= 1.1547 per component.
     """
     rows = [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]
     embedding = SimpleEmbedding(weight=torch.tensor(rows), max_norm=2)
     actual = embedding.get_all_entities()
     # NOTE(review): 1.1547 is a 4-decimal rounding of 2/sqrt(3); this
     # presumably relies on assertTensorEqual comparing with a tolerance
     # rather than exactly -- confirm against its definition.
     scaled = [1.1547, 1.1547, 1.1547]
     expected = torch.tensor([
         [1.0000, 1.0000, 1.0000],
         list(scaled),
         list(scaled),
     ])
     self.assertTensorEqual(actual, expected)