Example #1
0
 def test_sample_entities(self):
     # sample_entities is expected to be unsupported by FeaturizedEmbedding:
     # the call must raise NotImplementedError.
     # The global RNG is seeded first so any sampling that did run would be
     # reproducible.
     torch.manual_seed(42)
     weight = torch.tensor([
         [1.0, 1.0, 1.0],
         [2.0, 2.0, 2.0],
         [3.0, 3.0, 3.0],
     ])
     embedding_module = FeaturizedEmbedding(weight=weight)
     with self.assertRaises(NotImplementedError):
         embedding_module.sample_entities(2, 2)
 def test_get_all_entities(self):
     # get_all_entities is expected to be unsupported by FeaturizedEmbedding:
     # the call must raise NotImplementedError.
     rows = [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]
     embedding_module = FeaturizedEmbedding(weight=torch.tensor(rows))
     with self.assertRaises(NotImplementedError):
         embedding_module.get_all_entities()
Example #3
0
 def test_forward(self):
     # Forward pass over four entities described by (offsets, feature ids).
     # The expected rows equal the mean of the weight rows each entity
     # references; the last entity has no features and maps to zeros.
     weight = torch.tensor(
         [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], requires_grad=True
     )
     embedding_module = FeaturizedEmbedding(weight=weight)
     offsets = torch.tensor([0, 1, 3, 6, 6])
     feature_ids = torch.tensor([0, 2, 1, 0, 1, 0])
     entities = EntityList.from_tensor_list(TensorList(offsets, feature_ids))
     output = embedding_module(entities)
     expected = torch.tensor([
         [1.0000, 1.0000, 1.0000],  # row 0
         [2.5000, 2.5000, 2.5000],  # mean of rows 2 and 1
         [1.3333, 1.3333, 1.3333],  # mean of rows 0, 1, 0
         [0.0000, 0.0000, 0.0000],  # no features
     ])
     self.assertTensorEqual(output, expected)
     # Gradients from the output must flow back into the weight matrix.
     output.sum().backward()
     self.assertTrue((weight.grad.to_dense() != 0).any())
 def test_empty(self):
     # An empty (0, 3) weight matrix, a single zero offset and no feature ids
     # are expected to yield an empty (0, 3) output.
     weight = torch.empty((0, 3))
     embedding_module = FeaturizedEmbedding(weight=weight)
     offsets = torch.zeros((1, ), dtype=torch.long)
     feature_ids = torch.empty((0, ), dtype=torch.long)
     output = embedding_module(
         EntityList.from_tensor_list(TensorList(offsets, feature_ids)))
     self.assertTensorEqual(output, torch.empty((0, 3)))
Example #5
0
 def test_max_norm(self):
     # With max_norm=2 the expected values correspond to rows whose L2 norm
     # exceeds 2 ([2, 2, 2] and [3, 3, 3]) being rescaled to norm 2
     # (2 / sqrt(3) ~= 1.1547 per component) before averaging. Row [1, 1, 1]
     # (norm ~1.732) is left unchanged.
     weight = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0],
                            [3.0, 3.0, 3.0]])
     embedding_module = FeaturizedEmbedding(weight=weight, max_norm=2)
     offsets = torch.tensor([0, 1, 3, 6, 6])
     feature_ids = torch.tensor([0, 2, 1, 0, 1, 0])
     output = embedding_module(
         EntityList.from_tensor_list(TensorList(offsets, feature_ids)))
     expected = torch.tensor([
         [1.0000, 1.0000, 1.0000],  # row 0, below max_norm
         [1.1547, 1.1547, 1.1547],  # rows 2 and 1, both rescaled
         [1.0516, 1.0516, 1.0516],  # (1 + 1.1547 + 1) / 3
         [0.0000, 0.0000, 0.0000],  # no features
     ])
     self.assertTensorEqual(output, expected)