コード例 #1
0
 def test_sample_entities(self):
     """sample_entities(2, 2) returns a 2x2 grid of rows drawn from the weight."""
     # The expected grid below is tied to this seed; changing it (or the
     # RNG consumption before sampling) changes which rows are drawn.
     torch.manual_seed(42)
     row1, row2, row3 = [1.0] * 3, [2.0] * 3, [3.0] * 3
     weight = torch.tensor([row1, row2, row3])
     embedding = SimpleEmbedding(weight=weight)
     sampled = embedding.sample_entities(2, 2)
     expected = torch.tensor([[row1, row3], [row2, row2]])
     self.assertTensorEqual(sampled, expected)
コード例 #2
0
 def test_sample_entities_max_norm(self):
     """With max_norm=2, sampled rows whose norm exceeds 2 come back rescaled."""
     # Same seed as the unconstrained test, so the same rows are expected
     # to be drawn; only their magnitudes differ.
     torch.manual_seed(42)
     weight = torch.tensor([[float(k)] * 3 for k in (1, 2, 3)])
     embedding = SimpleEmbedding(weight=weight, max_norm=2)
     # [1, 1, 1] has norm sqrt(3) < 2 and is expected unchanged; [2, 2, 2]
     # and [3, 3, 3] exceed max_norm=2 and are expected to be rescaled to
     # norm 2, i.e. 2 / sqrt(3) ~ 1.1547 in every component.
     unscaled = [1.0000, 1.0000, 1.0000]
     rescaled = [1.1547, 1.1547, 1.1547]
     sampled = embedding.sample_entities(2, 2)
     self.assertTensorEqual(
         sampled,
         torch.tensor([[unscaled, rescaled], [rescaled, rescaled]]),
     )