示例#1
0
 def test_cat_scalar_different(self):
     """Check ``EdgeList.cat`` when the inputs carry different scalars.

     Two edge lists, each built from two-element entity tensors, are given
     the scalar tensors 0 and 1 as their third constructor argument. The
     concatenation must compare equal to an edge list whose third argument
     is the vector ``[0, 0, 1, 1]``.
     """
     def long_tensor(values):
         return torch.tensor(values, dtype=torch.long)

     def entities(ids):
         return EntityList.from_tensor(long_tensor(ids))

     first_part = EdgeList(entities([3, 4]), entities([0, 2]), long_tensor(0))
     second_part = EdgeList(entities([1, 0]), entities([1, 3]), long_tensor(1))
     combined = EdgeList.cat([first_part, second_part])

     # Entity ids are expected in input order, and each input's scalar is
     # expected once per edge of the list it came from.
     expected = EdgeList(
         entities([3, 4, 1, 0]),
         entities([0, 2, 1, 3]),
         long_tensor([0, 0, 1, 1]),
     )
     self.assertEqual(combined, expected)
示例#2
0
 def test_cat_vector(self):
     """Check ``EdgeList.cat`` when the inputs carry per-edge vectors.

     Each input edge list receives a two-element tensor as its third
     constructor argument (``[2, 1]`` and ``[3, 0]``). The concatenation
     must compare equal to an edge list whose three components are the
     inputs' components joined in order.
     """
     def long_tensor(values):
         return torch.tensor(values, dtype=torch.long)

     def entities(ids):
         return EntityList.from_tensor(long_tensor(ids))

     first_part = EdgeList(
         entities([3, 4]), entities([0, 2]), long_tensor([2, 1]))
     second_part = EdgeList(
         entities([1, 0]), entities([1, 3]), long_tensor([3, 0]))
     combined = EdgeList.cat([first_part, second_part])

     expected = EdgeList(
         entities([3, 4, 1, 0]),
         entities([0, 2, 1, 3]),
         long_tensor([2, 1, 3, 0]),
     )
     self.assertEqual(combined, expected)
示例#3
0
 def test_basic(self):
     """Check batching of a ten-edge list grouped by relation type.

     ``batch_edges_group_by_relation_type`` is iterated with
     ``batch_size=3``. Every yielded batch must be an ``EdgeList`` of at
     most three edges with a scalar relation type. The batches are then
     collected per relation type, concatenated, and compared against the
     edges expected for types 0, 1 and 2.
     """
     def long_tensor(values):
         return torch.tensor(values, dtype=torch.long)

     def entities(ids):
         return EntityList.from_tensor(long_tensor(ids))

     def expected_group(first_ids, second_ids, relation):
         # Expected per-type result: entity ids plus a scalar relation type.
         return EdgeList(
             entities(first_ids), entities(second_ids), long_tensor(relation))

     edges = EdgeList(
         entities([93, 24, 13, 31, 70, 66, 77, 38, 5, 5]),
         entities([90, 75, 9, 25, 23, 31, 49, 64, 42, 50]),
         long_tensor([1, 0, 0, 1, 2, 2, 0, 0, 2, 2]),
     )

     batches_per_type = {}
     for batch in batch_edges_group_by_relation_type(edges, batch_size=3):
         self.assertIsInstance(batch, EdgeList)
         self.assertLessEqual(len(batch), 3)
         self.assertTrue(batch.has_scalar_relation_type())
         relation = batch.get_relation_type_as_scalar()
         batches_per_type.setdefault(relation, []).append(batch)

     merged = {
         relation: EdgeList.cat(batches)
         for relation, batches in batches_per_type.items()
     }
     expected = {
         0: expected_group([24, 13, 77, 38], [75, 9, 49, 64], 0),
         1: expected_group([93, 31], [90, 25], 1),
         2: expected_group([70, 66, 5, 5], [23, 31, 42, 50], 2),
     }
     self.assertEqual(merged, expected)