def test_hetero_graph_basics(self):
        """Check per-type feature, label and count accessors on a rebuilt HeteroGraph.

        A HeteroGraph is built from the ``generate_simple_hete_graph()``
        fixture, then rebuilt from its own tensors through the keyword
        constructor with ``directed=True``; all assertions target the rebuilt
        graph.
        """
        source = HeteroGraph(generate_simple_hete_graph())
        graph = HeteroGraph(
            node_feature=source.node_feature,
            node_label=source.node_label,
            edge_feature=source.edge_feature,
            edge_label=source.edge_label,
            edge_index=source.edge_index,
            directed=True,
        )

        # Feature dimensions, keyed by node type and by (src, relation, dst).
        for node_type, dim in (('n1', 10), ('n2', 12)):
            self.assertEqual(graph.num_node_features(node_type), dim)
        for edge_type, dim in ((('n1', 'e1', 'n1'), 8), (('n1', 'e2', 'n2'), 12)):
            self.assertEqual(graph.num_edge_features(edge_type), dim)

        # Node counts and the number of distinct node / edge types.
        for node_type, count in (('n1', 4), ('n2', 5)):
            self.assertEqual(graph.num_nodes(node_type), count)
        self.assertEqual(len(graph.node_types), 2)
        self.assertEqual(len(graph.edge_types), 2)

        message_types = graph.message_types
        self.assertEqual(len(message_types), 7)

        # Every node type and every listed edge type carries two label classes.
        for node_type in ('n1', 'n2'):
            self.assertEqual(graph.num_node_labels(node_type), 2)
        for edge_type in (('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')):
            self.assertEqual(graph.num_edge_labels(edge_type), 2)

        self.assertEqual(graph.num_edges(message_types[0]), 3)
        self.assertEqual(len(graph.node_label_index), 2)
    def test_hetero_graph_batch(self):
        """Batch 30 clones of a rebuilt HeteroGraph three graphs at a time.

        The fixture graph is rebuilt from its own tensors with
        ``directed=True``. The DataLoader must report ``ceil(30 / 3)``
        batches, and every collated batch must contain exactly 3 graphs.
        """
        fixture = HeteroGraph(generate_simple_hete_graph())
        rebuilt = HeteroGraph(
            node_feature=fixture.node_feature,
            node_label=fixture.node_label,
            edge_feature=fixture.edge_feature,
            edge_label=fixture.edge_label,
            edge_index=fixture.edge_index,
            directed=True,
        )

        # Each dataset entry comes from its own clone() call.
        dataset = [rebuilt.clone() for _ in range(30)]
        loader = DataLoader(
            dataset,
            collate_fn=Batch.collate(),
            batch_size=3,
            shuffle=True,
        )

        self.assertEqual(len(loader), math.ceil(30 / 3))
        for batch in loader:
            self.assertEqual(batch.num_graphs, 3)
 def test_hetero_graph_none(self):
     """Message types of a graph generated without edge types have a None relation.

     The fixture is generated with ``add_edge_type=False`` and the HeteroGraph
     is rebuilt from its own tensors with ``directed=True``. Index 1 of every
     message type must then be ``None``.
     """
     G = generate_simple_hete_graph(add_edge_type=False)
     hete = HeteroGraph(G)
     hete = HeteroGraph(node_feature=hete.node_feature,
                        node_label=hete.node_label,
                        edge_feature=hete.edge_feature,
                        edge_label=hete.edge_label,
                        edge_index=hete.edge_index,
                        directed=True)
     message_types = hete.message_types
     # Without this guard, an empty list would skip the loop and pass vacuously.
     self.assertGreater(len(message_types), 0)
     for message_type in message_types:
         # subTest reports every offending message type, not just the first.
         with self.subTest(message_type=message_type):
             self.assertIsNone(message_type[1])
# Example no. 4
# 0
    def test_hetero_graph_batch(self):
        """Batch clones of the fixture HeteroGraph and check the batch sizes.

        Thirty clones are loaded three per batch. The DataLoader must report
        ``ceil(num_graphs / batch_size)`` batches, and each collated batch must
        hold ``batch_size`` graphs. This relies on ``num_graphs`` being a
        multiple of ``batch_size``, so the last batch is not short.
        """
        num_graphs = 30
        batch_size = 3

        G = generate_simple_hete_graph()
        hete = HeteroGraph(G)

        # The loop index is unused, so build the dataset with a comprehension.
        heteGraphDataset = [hete.clone() for _ in range(num_graphs)]
        dataloader = DataLoader(heteGraphDataset,
                                collate_fn=Batch.collate(),
                                batch_size=batch_size,
                                shuffle=True)

        self.assertEqual(len(dataloader), math.ceil(num_graphs / batch_size))
        for data in dataloader:
            self.assertEqual(data.num_graphs, batch_size)
# Example no. 5
# 0
    def test_hetero_graph_basics(self):
        """Check the ``get_num_*`` accessors of a HeteroGraph built from the fixture.

        Node types are 'n1' / 'n2', and edge types are queried by their
        relation names 'e1' / 'e2'.
        """
        graph = HeteroGraph(generate_simple_hete_graph())

        # Feature dimensions per node type and per edge type.
        for node_type, dim in (('n1', 10), ('n2', 12)):
            self.assertEqual(graph.get_num_node_features(node_type), dim)
        for edge_type, dim in (('e1', 8), ('e2', 12)):
            self.assertEqual(graph.get_num_edge_features(edge_type), dim)

        # Node counts and the number of distinct node / edge types.
        for node_type, count in (('n1', 4), ('n2', 5)):
            self.assertEqual(graph.get_num_nodes(node_type), count)
        self.assertEqual(len(graph.node_types), 2)
        self.assertEqual(len(graph.edge_types), 2)

        message_types = graph.message_types
        self.assertEqual(len(message_types), 7)

        # Label-class counts: two per node type, three per edge type.
        for node_type in ('n1', 'n2'):
            self.assertEqual(graph.get_num_node_labels(node_type), 2)
        for edge_type in ('e1', 'e2'):
            self.assertEqual(graph.get_num_edge_labels(edge_type), 3)

        self.assertEqual(graph.get_num_edges(message_types[0]), 3)
        self.assertEqual(len(graph.node_label_index), 2)
# Example no. 6
# 0
 def test_hetero_graph_none(self):
     """Message types of a graph generated without edge types have a None relation.

     The fixture is generated with ``no_edge_type=True``. Index 1 of every
     message type of the resulting HeteroGraph must be ``None``.
     """
     G = generate_simple_hete_graph(no_edge_type=True)
     hete = HeteroGraph(G)
     message_types = hete.message_types
     # Without this guard, an empty list would skip the loop and pass vacuously.
     self.assertGreater(len(message_types), 0)
     for message_type in message_types:
         # subTest reports every offending message type, not just the first.
         with self.subTest(message_type=message_type):
             self.assertIsNone(message_type[1])