def test_get_all_entities(self):
    """get_all_entities() returns every row of the weight the module wraps."""
    weight = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
    embedding = SimpleEmbedding(weight=weight)
    # Built as a separate tensor so the comparison is not against the weight object itself.
    expected = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
    all_rows = embedding.get_all_entities()
    self.assertTensorEqual(all_rows, expected)
def test_sample_entities(self):
    """sample_entities(2, 2) yields a 2 x 2 grid of weight rows.

    The expected rows are tied to ``torch.manual_seed(42)``. The seed is set
    before anything else so that any RNG use during construction is also fixed.
    """
    torch.manual_seed(42)
    weight = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
    embedding = SimpleEmbedding(weight=weight)
    sampled = embedding.sample_entities(2, 2)
    first_draw = [[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]]
    second_draw = [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]
    self.assertTensorEqual(sampled, torch.tensor([first_draw, second_draw]))
def test_sample_entities_max_norm(self):
    """Sampled rows are clipped to ``max_norm`` just like looked-up rows.

    Under seed 42 the draw matches ``test_sample_entities``. With max_norm=2,
    the rows [2, 2, 2] and [3, 3, 3] (norms ~3.46 and ~5.20) are expected as
    1.1547 ~= 2 / sqrt(3) per component, i.e. norm 2. The row [1, 1, 1]
    (norm ~1.73) is expected unchanged.
    """
    torch.manual_seed(42)
    weight = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
    embedding = SimpleEmbedding(weight=weight, max_norm=2)
    sampled = embedding.sample_entities(2, 2)
    unit_row = [1.0000, 1.0000, 1.0000]
    clipped_row = [1.1547, 1.1547, 1.1547]
    expected = torch.tensor([
        [unit_row, clipped_row],
        [clipped_row, clipped_row],
    ])
    self.assertTensorEqual(sampled, expected)
def test_get_all_entities_max_norm(self):
    """get_all_entities() applies ``max_norm`` to every returned row.

    Rows whose norm exceeds 2 are expected at 1.1547 (~2 / sqrt(3)) per
    component. The first row (norm ~1.73) is expected unchanged.
    """
    weight = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
    embedding = SimpleEmbedding(weight=weight, max_norm=2)
    all_rows = embedding.get_all_entities()
    unit_row = [1.0000, 1.0000, 1.0000]
    clipped_row = [1.1547, 1.1547, 1.1547]
    self.assertTensorEqual(
        all_rows,
        torch.tensor([unit_row, clipped_row, clipped_row]),
    )
def test_empty(self):
    """Looking up zero entities in an empty (0, 3) weight returns a (0, 3) tensor."""
    weight = torch.empty(0, 3)
    embedding = SimpleEmbedding(weight=weight)
    no_ids = EntityList.from_tensor(torch.zeros(0, dtype=torch.long))
    self.assertTensorEqual(embedding(no_ids), torch.empty(0, 3))
def test_forward(self):
    """Forward looks up rows by index and backpropagates into the weight.

    The indices [2, 0, 0] select row 2 once and row 0 twice. The gradient of
    ``result.sum()`` with respect to the weight is therefore 2 for every
    element of row 0, 0 for row 1 and 1 for row 2. Asserting the exact
    gradient, rather than merely "some entry is non-zero", catches gradients
    routed to the wrong row and duplicate indices that fail to accumulate.
    """
    embeddings = torch.tensor(
        [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], requires_grad=True
    )
    module = SimpleEmbedding(weight=embeddings)
    result = module(EntityList.from_tensor(torch.tensor([2, 0, 0])))
    self.assertTensorEqual(
        result, torch.tensor([[3.0, 3.0, 3.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
    )
    result.sum().backward()
    # to_dense() is kept from the original check; presumably the gradient may
    # be sparse (e.g. a sparse embedding lookup) — confirm against SimpleEmbedding.
    self.assertTensorEqual(
        embeddings.grad.to_dense(),
        torch.tensor([[2.0, 2.0, 2.0], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]),
    )
def test_max_norm(self):
    """Forward lookups clip rows whose norm exceeds ``max_norm``.

    Row 2 ([3, 3, 3], norm ~5.20) is expected as 1.1547 (~2 / sqrt(3)) per
    component. Row 0 ([1, 1, 1], norm ~1.73) is expected unchanged, for both
    of its occurrences.
    """
    weight = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
    embedding = SimpleEmbedding(weight=weight, max_norm=2)
    ids = EntityList.from_tensor(torch.tensor([2, 0, 0]))
    looked_up = embedding(ids)
    clipped_row = [1.1547, 1.1547, 1.1547]
    unit_row = [1.0000, 1.0000, 1.0000]
    self.assertTensorEqual(
        looked_up,
        torch.tensor([clipped_row, unit_row, unit_row]),
    )