def test_forward(self):
    """Check pooled embedding values and the exact gradient w.r.t. the weight.

    The offsets ``[0, 1, 3, 6, 6]`` split the data ``[0, 2, 1, 0, 1, 0]``
    into four entities: ``[0]``, ``[2, 1]``, ``[0, 1, 0]`` and ``[]``.
    Per the expected output, each entity's row is the mean of the
    referenced embedding rows, and the empty entity yields zeros.
    """
    embeddings = torch.tensor(
        [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], requires_grad=True
    )
    module = FeaturizedEmbedding(weight=embeddings)
    result = module(
        EntityList.from_tensor_list(
            TensorList(
                torch.tensor([0, 1, 3, 6, 6]), torch.tensor([0, 2, 1, 0, 1, 0])
            )
        )
    )
    self.assertTensorEqual(
        result,
        torch.tensor(
            [
                [1.0000, 1.0000, 1.0000],
                [2.5000, 2.5000, 2.5000],
                [1.3333, 1.3333, 1.3333],
                [0.0000, 0.0000, 0.0000],
            ]
        ),
    )
    result.sum().backward()
    # With mean pooling, d(sum)/d(row r) sums, over entities, the number of
    # times r occurs in the entity divided by the entity's size:
    #   row 0: 1/1 (entity [0]) + 2/3 (entity [0, 1, 0]) = 5/3
    #   row 1: 1/2 (entity [2, 1]) + 1/3 (entity [0, 1, 0]) = 5/6
    #   row 2: 1/2 (entity [2, 1]) = 1/2
    # An exact check also catches a wrong index mapping or pooling weight,
    # which a mere "some gradient is non-zero" check would let through.
    # The gradient may be sparse; to_dense() normalizes it for comparison.
    self.assertTensorEqual(
        embeddings.grad.to_dense(),
        torch.tensor(
            [
                [5.0 / 3.0, 5.0 / 3.0, 5.0 / 3.0],
                [5.0 / 6.0, 5.0 / 6.0, 5.0 / 6.0],
                [0.5, 0.5, 0.5],
            ]
        ),
    )
def test_empty(self):
    """Pooling over zero entities with an empty weight gives a (0, 3) result.

    A single offset of 0 with no data describes an entity list with no
    entries, so the output is expected to have zero rows.
    """
    weight = torch.empty((0, 3))
    module = FeaturizedEmbedding(weight=weight)
    offsets = torch.zeros((1, ), dtype=torch.long)
    data = torch.empty((0, ), dtype=torch.long)
    entities = EntityList.from_tensor_list(TensorList(offsets, data))
    self.assertTensorEqual(module(entities), torch.empty((0, 3)))
def test_max_norm(self):
    """Check mean pooling when embedding rows are capped at an L2 norm of 2.

    Row ``[1, 1, 1]`` has norm sqrt(3) < 2 and is unchanged. Rows
    ``[2, 2, 2]`` and ``[3, 3, 3]`` exceed 2; the expected values match them
    being rescaled to norm 2, i.e. 2 / sqrt(3) ~= 1.1547 per component.
    The means are then ``[2, 1]`` -> 1.1547 and ``[0, 1, 0]`` ->
    (1 + 1.1547 + 1) / 3 ~= 1.0516; the empty entity gives zeros.
    """
    weight = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
    module = FeaturizedEmbedding(weight=weight, max_norm=2)
    offsets = torch.tensor([0, 1, 3, 6, 6])
    data = torch.tensor([0, 2, 1, 0, 1, 0])
    actual = module(EntityList.from_tensor_list(TensorList(offsets, data)))
    expected = torch.tensor([
        [1.0000, 1.0000, 1.0000],
        [1.1547, 1.1547, 1.1547],
        [1.0516, 1.0516, 1.0516],
        [0.0000, 0.0000, 0.0000],
    ])
    self.assertTensorEqual(actual, expected)
def test_from_tensor_list(self):
    """Building an EntityList from a TensorList fills the first field with -1s.

    The expected EntityList has one -1 per entity (two here) alongside the
    original tensor list. The -1 presumably means "no type id"; confirm
    against EntityList.
    """
    tensor_list = tensor_list_from_lists([[3, 4], [0, 2]])
    actual = EntityList.from_tensor_list(tensor_list)
    placeholder_ids = torch.full((2, ), -1, dtype=torch.long)
    self.assertEqual(actual, EntityList(placeholder_ids, tensor_list))