def test_sample_entities(self):
    """Sampling entities from a featurized embedding is unsupported and must raise."""
    torch.manual_seed(42)
    weights = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
    emb = FeaturizedEmbedding(weight=weights)
    with self.assertRaises(NotImplementedError):
        emb.sample_entities(2, 2)
def test_get_all_entities(self):
    """Enumerating all entities of a featurized embedding is unsupported and must raise."""
    weights = torch.tensor([
        [1.0, 1.0, 1.0],
        [2.0, 2.0, 2.0],
        [3.0, 3.0, 3.0],
    ])
    emb = FeaturizedEmbedding(weight=weights)
    with self.assertRaises(NotImplementedError):
        emb.get_all_entities()
def test_forward(self):
    """Forward averages each entity's feature embeddings (an entity with no
    features yields a zero vector) and gradients flow back to the weights.
    """
    weights = torch.tensor(
        [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]],
        requires_grad=True,
    )
    emb = FeaturizedEmbedding(weight=weights)
    # Offsets [0, 1, 3, 6, 6] partition the feature indices into per-entity
    # groups: [0], [2, 1], [0, 1, 0], and [] (empty → zero vector).
    offsets = torch.tensor([0, 1, 3, 6, 6])
    features = torch.tensor([0, 2, 1, 0, 1, 0])
    result = emb(EntityList.from_tensor_list(TensorList(offsets, features)))
    expected = torch.tensor([
        [1.0000, 1.0000, 1.0000],
        [2.5000, 2.5000, 2.5000],
        [1.3333, 1.3333, 1.3333],
        [0.0000, 0.0000, 0.0000],
    ])
    self.assertTensorEqual(result, expected)
    # Backpropagating must leave a nonzero gradient on the weight tensor
    # (densified first, since the gradient may be sparse).
    result.sum().backward()
    self.assertTrue((weights.grad.to_dense() != 0).any())
def test_empty(self):
    """A feature list with zero entries produces an empty (0, dim) output."""
    emb = FeaturizedEmbedding(weight=torch.empty((0, 3)))
    no_entities = EntityList.from_tensor_list(
        TensorList(
            torch.zeros((1,), dtype=torch.long),
            torch.empty((0,), dtype=torch.long),
        )
    )
    self.assertTensorEqual(emb(no_entities), torch.empty((0, 3)))
def test_max_norm(self):
    """With max_norm set, rows whose norm exceeds it are rescaled to that norm
    before averaging; rows already within the norm pass through unchanged
    (row [1, 1, 1] has norm sqrt(3) < 2 and is untouched).
    """
    weights = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
    emb = FeaturizedEmbedding(weight=weights, max_norm=2)
    result = emb(
        EntityList.from_tensor_list(
            TensorList(
                torch.tensor([0, 1, 3, 6, 6]),
                torch.tensor([0, 2, 1, 0, 1, 0]),
            )
        )
    )
    expected = torch.tensor([
        [1.0000, 1.0000, 1.0000],
        [1.1547, 1.1547, 1.1547],
        [1.0516, 1.0516, 1.0516],
        [0.0000, 0.0000, 0.0000],
    ])
    self.assertTensorEqual(result, expected)