def cat(cls, edge_lists: Sequence["EdgeList"]) -> "EdgeList":
        """Concatenate several edge lists, in order, into a single one.

        The lhs and rhs sides are joined with ``EntityList.cat``.

        Weights are concatenated when every input carries one. When no input
        carries a weight, the result has ``None``.

        If every input has a scalar relation type and all of them share the
        same value, the result stores that value as a single ``torch.long``
        tensor. Otherwise each input's relation is expanded to its edge count
        and the pieces are concatenated.

        Args:
            edge_lists: the edge lists to join.

        Returns:
            A new ``cls`` instance holding the concatenated edges.

        Raises:
            RuntimeError: if some, but not all, inputs have a weight field.
        """
        merged_lhs = EntityList.cat([edges.lhs for edges in edge_lists])
        merged_rhs = EntityList.cat([edges.rhs for edges in edge_lists])

        merged_weight = None
        if any(edges.has_weight() for edges in edge_lists):
            # A partially-weighted concatenation has no well-defined result.
            if not all(edges.has_weight() for edges in edge_lists):
                raise RuntimeError(
                    "Can't concatenate edgelists with and without weight field."
                )
            # NOTE(review): expand() presumably broadcasts a scalar weight to
            # one entry per edge; a per-edge weight is left as-is.
            merged_weight = torch.cat(
                [edges.weight.expand((len(edges),)) for edges in edge_lists]
            )

        every_rel_scalar = all(
            edges.has_scalar_relation_type() for edges in edge_lists
        )
        if every_rel_scalar:
            distinct_rels = {
                edges.get_relation_type_as_scalar() for edges in edge_lists
            }
            if len(distinct_rels) == 1:
                # All inputs share one relation: keep it as a single value
                # instead of materializing a per-edge tensor.
                (shared_rel,) = distinct_rels
                shared_rel_tensor = torch.tensor(shared_rel, dtype=torch.long)
                return cls(merged_lhs, merged_rhs, shared_rel_tensor, merged_weight)

        # Mixed or non-scalar relations: broadcast each input's relation to
        # its edge count so the pieces can be concatenated edge-wise.
        merged_rel = torch.cat(
            [edges.rel.expand((len(edges),)) for edges in edge_lists]
        )
        return cls(merged_lhs, merged_rhs, merged_rel, merged_weight)
Example #2
0
 def cat(cls, edge_lists: Sequence["EdgeList"]) -> "EdgeList":
     """Concatenate several edge lists, in order, into a single one.

     The lhs and rhs sides are joined with ``EntityList.cat``.

     If every input has a scalar relation type and all share the same value,
     the result stores that value as a single ``torch.long`` tensor.
     Otherwise each input's relation is expanded to its edge count and the
     pieces are concatenated.

     Args:
         edge_lists: the edge lists to join.

     Returns:
         A new ``cls`` instance holding the concatenated edges. Both code
         paths construct ``cls``, so subclasses are preserved.
     """
     cat_lhs = EntityList.cat([el.lhs for el in edge_lists])
     cat_rhs = EntityList.cat([el.rhs for el in edge_lists])
     if all(el.has_scalar_relation_type() for el in edge_lists):
         rel_types = {el.get_relation_type_as_scalar() for el in edge_lists}
         if len(rel_types) == 1:
             # A single shared relation: keep it as one value rather than
             # materializing a per-edge tensor.
             (rel_type,) = rel_types
             return cls(cat_lhs, cat_rhs, torch.tensor(rel_type, dtype=torch.long))
     # Mixed or non-scalar relations: broadcast each input's relation to its
     # edge count so the pieces can be concatenated edge-wise.
     cat_rel = torch.cat([el.rel.expand((len(el),)) for el in edge_lists])
     # Use ``cls`` (not a hard-coded ``EdgeList``) so this path returns the
     # same type as the single-relation path above.
     return cls(cat_lhs, cat_rhs, cat_rel)
 def test_cat(self):
     """EntityList.cat joins both the tensor part and the tensor-list part, in order."""
     first = EntityList(
         torch.tensor([2, 3], dtype=torch.long),
         tensor_list_from_lists([[3, 4], [0]]),
     )
     second = EntityList(
         torch.tensor([0, 1], dtype=torch.long),
         tensor_list_from_lists([[1, 2, 0], []]),
     )
     # Expected result: first's entries followed by second's.
     expected = EntityList(
         torch.tensor([2, 3, 0, 1], dtype=torch.long),
         tensor_list_from_lists([[3, 4], [0], [1, 2, 0], []]),
     )
     self.assertEqual(EntityList.cat([first, second]), expected)
 def test_basic(self):
     """Edges are split into per-relation-type batches of at most 3 edges.

     Concatenating each relation's batches back together must reproduce that
     relation's lhs and rhs entities in their original order.
     """

     def entities(ids):
         # Build an EntityList from a plain list of long ids.
         return EntityList.from_tensor(torch.tensor(ids, dtype=torch.long))

     all_lhs = entities([93, 24, 13, 31, 70, 66, 77, 38, 5, 5])
     all_rhs = entities([90, 75, 9, 25, 23, 31, 49, 64, 42, 50])
     all_rel = torch.tensor([1, 0, 0, 1, 2, 2, 0, 0, 2, 2], dtype=torch.long)

     lhs_groups = defaultdict(list)
     rhs_groups = defaultdict(list)
     batches = batch_edges_group_by_relation_type(
         all_lhs, all_rhs, all_rel, batch_size=3)
     for batch_lhs, batch_rhs, rel_type in batches:
         # Apply the same checks to the lhs side first, then the rhs side.
         for side, groups in ((batch_lhs, lhs_groups), (batch_rhs, rhs_groups)):
             self.assertIsInstance(side, EntityList)
             self.assertLessEqual(side.size(0), 3)
             groups[rel_type].append(side)

     self.assertCountEqual(lhs_groups.keys(), [0, 1, 2])
     self.assertCountEqual(rhs_groups.keys(), [0, 1, 2])

     # relation type -> (expected lhs ids, expected rhs ids)
     expected = {
         0: ([24, 13, 77, 38], [75, 9, 49, 64]),
         1: ([93, 31], [90, 25]),
         2: ([70, 66, 5, 5], [23, 31, 42, 50]),
     }
     for rel_type, (want_lhs, want_rhs) in expected.items():
         self.assertEqual(EntityList.cat(lhs_groups[rel_type]), entities(want_lhs))
         self.assertEqual(EntityList.cat(rhs_groups[rel_type]), entities(want_rhs))