def test_unbatch_nested(self):
    """Round-trip nested ``node_property`` dicts through ``Batch``.

    One ``Graph`` is built per entry of ``G_sizes``. Graph ``i`` carries two
    node properties, ``node_prop0`` and ``node_prop1``, with shapes
    ``(G_sizes[i], dims[0])`` and ``(G_sizes[i], dims[1])``. Both are filled
    with the value ``i``. After ``Batch.from_data_list`` followed by
    ``to_data_list``, the test asserts two things:

    * one graph comes back per input graph;
    * every property of every graph recovers its original shape.
    """
    dims = [2, 3]
    G_sizes = [10, 5]
    prop_keys = ["node_prop0", "node_prop1"]
    G_list = []
    for i, size in enumerate(G_sizes):
        G = Graph()
        # NOTE(review): the networkx graph has i + 1 nodes, which does not
        # match `size`; the test only exercises node_property slicing —
        # confirm this mismatch is intended.
        G.G = nx.complete_graph(i + 1)
        G.node_property = {
            "node_prop0": torch.ones(size, dims[0]) * i,
            "node_prop1": torch.ones(size, dims[1]) * i,
        }
        G_list.append(G)
    batch = Batch.from_data_list(G_list)

    # reconstruct graph list
    G_list_recon = batch.to_data_list()
    self.assertEqual(len(G_list_recon), len(G_sizes))

    # Derive expectations from the fixtures instead of hard-coding them, and
    # check every (graph, property) pair rather than a hand-picked subset.
    for i, size in enumerate(G_sizes):
        props = G_list_recon[i].node_property
        for key, dim in zip(prop_keys, dims):
            with self.subTest(graph=i, prop=key):
                self.assertEqual(props[key].size(0), size)
                self.assertEqual(props[key].size(1), dim)
def test_collate_batch_nested(self):
    """Collate graphs carrying nested ``node_property`` dicts into a ``Batch``.

    One ``Graph`` is built per entry of ``G_sizes``. Graph ``i`` carries
    ``node_prop0``/``node_prop1`` tensors of shape ``(G_sizes[i], dims[k])``
    filled with ``i``. The test asserts that the batch reports one graph per
    input. It also asserts that each nested property is stacked along the
    first dimension: the row count equals ``sum(G_sizes)`` and the feature
    width is unchanged.
    """
    dims = [2, 3]
    G_sizes = [10, 5]
    prop_keys = ["node_prop0", "node_prop1"]
    G_list = []
    for i, size in enumerate(G_sizes):
        G = Graph()
        G.G = nx.complete_graph(i + 1)
        G.node_property = {
            "node_prop0": torch.ones(size, dims[0]) * i,
            "node_prop1": torch.ones(size, dims[1]) * i,
        }
        G_list.append(G)
    batch = Batch.from_data_list(G_list)

    self.assertEqual(batch.num_graphs, len(G_sizes))
    # Check every nested property, not just the first one, and verify the
    # feature dimension survives collation.
    for key, dim in zip(prop_keys, dims):
        with self.subTest(prop=key):
            self.assertEqual(batch.node_property[key].size(0), sum(G_sizes))
            self.assertEqual(batch.node_property[key].size(1), dim)