def test_hetero_graph_basics(self):
    """Check per-type counts on a HeteroGraph rebuilt from its own tensors.

    A HeteroGraph is first built from ``generate_simple_hete_graph()`` and
    then rebuilt as a directed graph from its node/edge features, labels and
    ``edge_index``; every assertion runs against the rebuilt graph. Edge
    features and labels are queried with message-type tuples such as
    ``('n1', 'e1', 'n1')``.
    """
    # NOTE(review): this file defines ``test_hetero_graph_basics`` more than
    # once; if the definitions share a TestCase, only the last one is kept
    # and the others never run. Consider giving each variant its own name.
    source = HeteroGraph(generate_simple_hete_graph())
    graph = HeteroGraph(
        node_feature=source.node_feature,
        node_label=source.node_label,
        edge_feature=source.edge_feature,
        edge_label=source.edge_label,
        edge_index=source.edge_index,
        directed=True,
    )
    n1_e1_n1 = ('n1', 'e1', 'n1')
    n1_e2_n2 = ('n1', 'e2', 'n2')

    for node_type, expected in (('n1', 10), ('n2', 12)):
        self.assertEqual(graph.num_node_features(node_type), expected)
    for message_type, expected in ((n1_e1_n1, 8), (n1_e2_n2, 12)):
        self.assertEqual(graph.num_edge_features(message_type), expected)
    for node_type, expected in (('n1', 4), ('n2', 5)):
        self.assertEqual(graph.num_nodes(node_type), expected)

    self.assertEqual(len(graph.node_types), 2)
    self.assertEqual(len(graph.edge_types), 2)

    message_types = graph.message_types
    self.assertEqual(len(message_types), 7)

    for node_type, expected in (('n1', 2), ('n2', 2)):
        self.assertEqual(graph.num_node_labels(node_type), expected)
    for message_type, expected in ((n1_e1_n1, 2), (n1_e2_n2, 2)):
        self.assertEqual(graph.num_edge_labels(message_type), expected)

    self.assertEqual(graph.num_edges(message_types[0]), 3)
    self.assertEqual(len(graph.node_label_index), 2)
def test_hetero_graph_batch(self):
    """Batch clones of a rebuilt HeteroGraph through a shuffling DataLoader.

    The graph from ``generate_simple_hete_graph()`` is rebuilt as a directed
    HeteroGraph from its own tensors, cloned ``num_graphs`` times, and
    collated with ``Batch.collate()``. Checks the number of batches and the
    number of graphs held by each batch.
    """
    # NOTE(review): this file defines ``test_hetero_graph_batch`` more than
    # once; if the definitions share a TestCase, only the last one is kept
    # and the others never run. Consider giving each variant its own name.
    num_graphs = 30
    batch_size = 3

    source = HeteroGraph(generate_simple_hete_graph())
    graph = HeteroGraph(
        node_feature=source.node_feature,
        node_label=source.node_label,
        edge_feature=source.edge_feature,
        edge_label=source.edge_label,
        edge_index=source.edge_index,
        directed=True,
    )
    dataset = [graph.clone() for _ in range(num_graphs)]
    dataloader = DataLoader(
        dataset,
        collate_fn=Batch.collate(),
        batch_size=batch_size,
        shuffle=True,
    )
    self.assertEqual(len(dataloader), math.ceil(num_graphs / batch_size))
    # num_graphs is a multiple of batch_size, so every batch -- including the
    # last one -- is expected to hold exactly batch_size graphs.
    for batch in dataloader:
        self.assertEqual(batch.num_graphs, batch_size)
def test_hetero_graph_none(self):
    """Message types carry ``None`` in the edge-type slot for untyped edges.

    The graph is generated with ``add_edge_type=False`` and rebuilt as a
    directed HeteroGraph from its own tensors before inspecting
    ``message_types``.
    """
    # NOTE(review): this file defines ``test_hetero_graph_none`` more than
    # once; if the definitions share a TestCase, only the last one is kept
    # and the others never run. Consider giving each variant its own name.
    source = HeteroGraph(generate_simple_hete_graph(add_edge_type=False))
    graph = HeteroGraph(
        node_feature=source.node_feature,
        node_label=source.node_label,
        edge_feature=source.edge_feature,
        edge_label=source.edge_label,
        edge_index=source.edge_index,
        directed=True,
    )
    message_types = graph.message_types
    # Without this, an empty message_types would skip the loop and the test
    # would pass without checking anything.
    self.assertGreater(len(message_types), 0)
    for message_type in message_types:
        # Position 1 is the edge-type slot of a message-type tuple.
        self.assertIsNone(message_type[1])
def test_hetero_graph_batch(self):
    """Batch clones of a networkx-built HeteroGraph through a DataLoader.

    The HeteroGraph is built directly from ``generate_simple_hete_graph()``,
    cloned ``num_graphs`` times, and collated with ``Batch.collate()`` in
    shuffled batches. Checks the number of batches and graphs per batch.
    """
    # NOTE(review): this file defines ``test_hetero_graph_batch`` more than
    # once; if the definitions share a TestCase, only the last one is kept
    # and the others never run. Consider giving each variant its own name.
    num_graphs = 30
    batch_size = 3

    graph = HeteroGraph(generate_simple_hete_graph())
    dataset = [graph.clone() for _ in range(num_graphs)]
    dataloader = DataLoader(
        dataset,
        collate_fn=Batch.collate(),
        batch_size=batch_size,
        shuffle=True,
    )
    self.assertEqual(len(dataloader), math.ceil(num_graphs / batch_size))
    # num_graphs is a multiple of batch_size, so every batch -- including the
    # last one -- is expected to hold exactly batch_size graphs.
    for batch in dataloader:
        self.assertEqual(batch.num_graphs, batch_size)
def test_hetero_graph_basics(self):
    """Check per-type counts on a HeteroGraph built straight from networkx.

    Exercises the ``get_num_*`` accessors, which this variant queries by
    node type (``'n1'``/``'n2'``) or edge type (``'e1'``/``'e2'``).
    """
    # NOTE(review): this file defines ``test_hetero_graph_basics`` more than
    # once; if the definitions share a TestCase, only the last one is kept
    # and the others never run. Consider giving each variant its own name.
    graph = HeteroGraph(generate_simple_hete_graph())

    for node_type, expected in (('n1', 10), ('n2', 12)):
        self.assertEqual(graph.get_num_node_features(node_type), expected)
    for edge_type, expected in (('e1', 8), ('e2', 12)):
        self.assertEqual(graph.get_num_edge_features(edge_type), expected)
    for node_type, expected in (('n1', 4), ('n2', 5)):
        self.assertEqual(graph.get_num_nodes(node_type), expected)

    self.assertEqual(len(graph.node_types), 2)
    self.assertEqual(len(graph.edge_types), 2)

    message_types = graph.message_types
    self.assertEqual(len(message_types), 7)

    for node_type, expected in (('n1', 2), ('n2', 2)):
        self.assertEqual(graph.get_num_node_labels(node_type), expected)
    for edge_type, expected in (('e1', 3), ('e2', 3)):
        self.assertEqual(graph.get_num_edge_labels(edge_type), expected)

    self.assertEqual(graph.get_num_edges(message_types[0]), 3)
    self.assertEqual(len(graph.node_label_index), 2)
def test_hetero_graph_none(self):
    """Message types carry ``None`` in the edge-type slot for untyped edges.

    The HeteroGraph is built directly from a graph generated with
    ``no_edge_type=True``.
    """
    # NOTE(review): this file defines ``test_hetero_graph_none`` more than
    # once; if the definitions share a TestCase, only the last one is kept
    # and the others never run. Consider giving each variant its own name.
    graph = HeteroGraph(generate_simple_hete_graph(no_edge_type=True))
    message_types = graph.message_types
    # Without this, an empty message_types would skip the loop and the test
    # would pass without checking anything.
    self.assertGreater(len(message_types), 0)
    for message_type in message_types:
        # Position 1 is the edge-type slot of a message-type tuple.
        self.assertIsNone(message_type[1])