def test_get_all_entities(self):
    """Without a norm limit, get_all_entities() yields the weight table as given."""
    # Row i holds the constant value i + 1 in each of its three columns.
    rows = [[float(value)] * 3 for value in (1, 2, 3)]
    weight = torch.tensor(rows)
    # A separate tensor (not `weight` itself) is used for the expectation.
    expected = torch.tensor(rows)

    embedding = SimpleEmbedding(weight=weight)
    self.assertTensorEqual(embedding.get_all_entities(), expected)
def test_get_all_entities_max_norm(self):
    """With max_norm=2, get_all_entities() returns the norm-limited rows.

    The expected values below correspond to rescaling every row whose L2 norm
    exceeds ``max_norm`` down to norm ``max_norm``:

    * ``[1, 1, 1]`` has norm sqrt(3) ~= 1.732 < 2, so it is unchanged.
    * ``[2, 2, 2]`` (norm ~= 3.464) and ``[3, 3, 3]`` (norm ~= 5.196) both
      become ``2 / sqrt(3) ~= 1.1547`` per component.
    """
    weight = torch.tensor([
        [1.0, 1.0, 1.0],
        [2.0, 2.0, 2.0],
        [3.0, 3.0, 3.0],
    ])
    # 2 / sqrt(3), rounded to four decimals; relies on assertTensorEqual
    # comparing with a tolerance rather than exactly.
    clipped_row = [1.1547, 1.1547, 1.1547]
    expected = torch.tensor([[1.0000, 1.0000, 1.0000], clipped_row, clipped_row])

    embedding = SimpleEmbedding(weight=weight, max_norm=2)
    self.assertTensorEqual(embedding.get_all_entities(), expected)