def test_cat_scalar_different(self):
    """EdgeList.cat over inputs whose scalar relation types differ.

    One EdgeList carries scalar relation type 0 and the other scalar
    relation type 1. Their concatenation is expected to contain all four
    edges in input order, with a per-edge relation type tensor
    ``[0, 0, 1, 1]``.
    """
    def entities(ids):
        # Build an EntityList from a long tensor of entity ids.
        return EntityList.from_tensor(torch.tensor(ids, dtype=torch.long))

    def rel(value):
        # Relation type tensor: a scalar for one type, a list for per-edge types.
        return torch.tensor(value, dtype=torch.long)

    part_a = EdgeList(entities([3, 4]), entities([0, 2]), rel(0))
    part_b = EdgeList(entities([1, 0]), entities([1, 3]), rel(1))
    combined = EdgeList.cat([part_a, part_b])

    expected = EdgeList(
        entities([3, 4, 1, 0]),
        entities([0, 2, 1, 3]),
        rel([0, 0, 1, 1]),
    )
    self.assertEqual(combined, expected)
def test_cat_vector(self):
    """EdgeList.cat over inputs that already carry per-edge relation types.

    The relation type tensors of the two inputs, ``[2, 1]`` and ``[3, 0]``,
    are expected to be concatenated alongside the lhs/rhs entities, in
    input order.
    """
    def entities(ids):
        # Build an EntityList from a long tensor of entity ids.
        return EntityList.from_tensor(torch.tensor(ids, dtype=torch.long))

    def rel(value):
        # Relation type tensor with long dtype.
        return torch.tensor(value, dtype=torch.long)

    part_a = EdgeList(entities([3, 4]), entities([0, 2]), rel([2, 1]))
    part_b = EdgeList(entities([1, 0]), entities([1, 3]), rel([3, 0]))
    combined = EdgeList.cat([part_a, part_b])

    expected = EdgeList(
        entities([3, 4, 1, 0]),
        entities([0, 2, 1, 3]),
        rel([2, 1, 3, 0]),
    )
    self.assertEqual(combined, expected)
def test_basic(self):
    """batch_edges_group_by_relation_type splits edges into typed batches.

    Checks made on each yielded batch:

    - it is an EdgeList;
    - it holds at most ``batch_size`` (3) edges;
    - it reports a scalar relation type.

    Concatenating all batches of a given relation type must equal the
    expected EdgeList for that type. The expected edges appear in the same
    relative order as in the input.
    """
    def entities(ids):
        # Build an EntityList from a long tensor of entity ids.
        return EntityList.from_tensor(torch.tensor(ids, dtype=torch.long))

    def rel(value):
        # Relation type tensor: a scalar for one type, a list for per-edge types.
        return torch.tensor(value, dtype=torch.long)

    max_batch = 3
    edges = EdgeList(
        entities([93, 24, 13, 31, 70, 66, 77, 38, 5, 5]),
        entities([90, 75, 9, 25, 23, 31, 49, 64, 42, 50]),
        rel([1, 0, 0, 1, 2, 2, 0, 0, 2, 2]),
    )

    # Collect the yielded batches per scalar relation type, in yield order.
    batches_by_type = {}
    for batch in batch_edges_group_by_relation_type(edges, batch_size=max_batch):
        self.assertIsInstance(batch, EdgeList)
        self.assertLessEqual(len(batch), max_batch)
        self.assertTrue(batch.has_scalar_relation_type())
        key = batch.get_relation_type_as_scalar()
        batches_by_type.setdefault(key, []).append(batch)

    merged = {
        rel_type: EdgeList.cat(group)
        for rel_type, group in batches_by_type.items()
    }
    expected = {
        0: EdgeList(
            entities([24, 13, 77, 38]),
            entities([75, 9, 49, 64]),
            rel(0),
        ),
        1: EdgeList(
            entities([93, 31]),
            entities([90, 25]),
            rel(1),
        ),
        2: EdgeList(
            entities([70, 66, 5, 5]),
            entities([23, 31, 42, 50]),
            rel(2),
        ),
    }
    self.assertEqual(merged, expected)