def cat(cls, edge_lists: Sequence["EdgeList"]) -> "EdgeList":
    """Concatenate several edge lists into a single instance of ``cls``.

    The left- and right-hand sides are joined with ``EntityList.cat``.

    Weights are kept only when every input has one. Each input's weight is
    expanded to that input's length before concatenation.

    If every input has a scalar relation type and they all share the same
    value, the result stores that value as a scalar long tensor. Otherwise
    each input's relation is expanded to its length and the pieces are
    concatenated.

    Raises:
        RuntimeError: if some, but not all, of the inputs have a weight.
    """
    lhs = EntityList.cat([edges.lhs for edges in edge_lists])
    rhs = EntityList.cat([edges.rhs for edges in edge_lists])

    if not any(edges.has_weight() for edges in edge_lists):
        weight = None
    elif all(edges.has_weight() for edges in edge_lists):
        weight = torch.cat(
            [edges.weight.expand((len(edges),)) for edges in edge_lists]
        )
    else:
        raise RuntimeError(
            "Can't concatenate edgelists with and without weight field."
        )

    if all(edges.has_scalar_relation_type() for edges in edge_lists):
        distinct = {edges.get_relation_type_as_scalar() for edges in edge_lists}
        if len(distinct) == 1:
            # All inputs agree on one relation type: keep it compact as a scalar.
            only_rel = next(iter(distinct))
            return cls(lhs, rhs, torch.tensor(only_rel, dtype=torch.long), weight)

    # Mixed or per-edge relation types: materialize one entry per edge.
    rel = torch.cat([edges.rel.expand((len(edges),)) for edges in edge_lists])
    return cls(lhs, rhs, rel, weight)
def cat(cls, edge_lists: Sequence["EdgeList"]) -> "EdgeList":
    """Concatenate several edge lists into a single instance of ``cls``.

    The left- and right-hand sides are joined with ``EntityList.cat``.

    If every input has a scalar relation type and they all share the same
    value, the result stores that value as a scalar long tensor. Otherwise
    each input's relation is expanded to its length and the pieces are
    concatenated.
    """
    cat_lhs = EntityList.cat([el.lhs for el in edge_lists])
    cat_rhs = EntityList.cat([el.rhs for el in edge_lists])
    if all(el.has_scalar_relation_type() for el in edge_lists):
        rel_types = {el.get_relation_type_as_scalar() for el in edge_lists}
        if len(rel_types) == 1:
            (rel_type,) = rel_types
            return cls(cat_lhs, cat_rhs, torch.tensor(rel_type, dtype=torch.long))
    cat_rel = torch.cat([el.rel.expand((len(el),)) for el in edge_lists])
    # Construct via ``cls`` rather than the hard-coded ``EdgeList`` so that
    # calling ``cat`` on a subclass returns that subclass. This matches the
    # scalar-relation branch above.
    return cls(cat_lhs, cat_rhs, cat_rel)
def test_cat(self):
    """Concatenating two EntityLists gives the EntityList whose tensor and
    tensor list are the in-order concatenations of the inputs' parts."""
    first = EntityList(
        torch.tensor([2, 3], dtype=torch.long),
        tensor_list_from_lists([[3, 4], [0]]),
    )
    second = EntityList(
        torch.tensor([0, 1], dtype=torch.long),
        tensor_list_from_lists([[1, 2, 0], []]),
    )
    expected = EntityList(
        torch.tensor([2, 3, 0, 1], dtype=torch.long),
        tensor_list_from_lists([[3, 4], [0], [1, 2, 0], []]),
    )
    self.assertEqual(EntityList.cat([first, second]), expected)
def test_basic(self):
    """Check batching by relation type.

    Every batch yielded by batch_edges_group_by_relation_type is an
    EntityList of at most ``batch_size`` rows. Re-concatenating the batches
    of each relation type recovers that type's lhs/rhs entries in their
    original order.
    """
    lhs = EntityList.from_tensor(
        torch.tensor([93, 24, 13, 31, 70, 66, 77, 38, 5, 5], dtype=torch.long))
    rhs = EntityList.from_tensor(
        torch.tensor([90, 75, 9, 25, 23, 31, 49, 64, 42, 50], dtype=torch.long))
    rel = torch.tensor([1, 0, 0, 1, 2, 2, 0, 0, 2, 2], dtype=torch.long)

    collected_lhs = defaultdict(list)
    collected_rhs = defaultdict(list)
    batches = batch_edges_group_by_relation_type(lhs, rhs, rel, batch_size=3)
    for part_lhs, part_rhs, type_id in batches:
        # Check the lhs side fully, then the rhs side (same order as before).
        for part, collected in ((part_lhs, collected_lhs),
                                (part_rhs, collected_rhs)):
            self.assertIsInstance(part, EntityList)
            self.assertLessEqual(part.size(0), 3)
            collected[type_id].append(part)

    self.assertCountEqual(collected_lhs.keys(), [0, 1, 2])
    self.assertCountEqual(collected_rhs.keys(), [0, 1, 2])

    # (relation type, expected lhs ids, expected rhs ids), taken from the
    # positions of each type in ``rel`` above.
    expectations = [
        (0, [24, 13, 77, 38], [75, 9, 49, 64]),
        (1, [93, 31], [90, 25]),
        (2, [70, 66, 5, 5], [23, 31, 42, 50]),
    ]
    for type_id, want_lhs, want_rhs in expectations:
        self.assertEqual(
            EntityList.cat(collected_lhs[type_id]),
            EntityList.from_tensor(torch.tensor(want_lhs, dtype=torch.long)))
        self.assertEqual(
            EntityList.cat(collected_rhs[type_id]),
            EntityList.from_tensor(torch.tensor(want_rhs, dtype=torch.long)))