def test_hetero_graph_basics(self):
        """Rebuild a HeteroGraph from its own tensors and check its accessors.

        A HeteroGraph is first constructed from ``generate_simple_hete_graph()``.
        Its ``node_feature``, ``node_label``, ``edge_feature``, ``edge_label``
        and ``edge_index`` are then passed to the keyword constructor (with
        ``directed=True``), and the per-type counts reported by the rebuilt
        graph are compared with the values expected for that fixture.
        """
        seed_graph = HeteroGraph(generate_simple_hete_graph())
        graph = HeteroGraph(
            node_feature=seed_graph.node_feature,
            node_label=seed_graph.node_label,
            edge_feature=seed_graph.edge_feature,
            edge_label=seed_graph.edge_label,
            edge_index=seed_graph.edge_index,
            directed=True,
        )

        n1_e1_n1 = ('n1', 'e1', 'n1')
        n1_e2_n2 = ('n1', 'e2', 'n2')

        # num_node_features per node type, then num_edge_features per edge key.
        for node_type, expected in (('n1', 10), ('n2', 12)):
            self.assertEqual(graph.num_node_features(node_type), expected)
        for edge_key, expected in ((n1_e1_n1, 8), (n1_e2_n2, 12)):
            self.assertEqual(graph.num_edge_features(edge_key), expected)

        # Node counts per type, then how many node / edge types exist.
        for node_type, expected in (('n1', 4), ('n2', 5)):
            self.assertEqual(graph.num_nodes(node_type), expected)
        self.assertEqual(len(graph.node_types), 2)
        self.assertEqual(len(graph.edge_types), 2)

        # NOTE(review): 7 message types are expected even though only 2 edge
        # types exist; the number is taken from the fixture — confirm how
        # HeteroGraph derives message_types if this ever changes.
        all_message_types = graph.message_types
        self.assertEqual(len(all_message_types), 7)

        # num_node_labels per node type, then num_edge_labels per edge key.
        for node_type in ('n1', 'n2'):
            self.assertEqual(graph.num_node_labels(node_type), 2)
        for edge_key in (n1_e1_n1, n1_e2_n2):
            self.assertEqual(graph.num_edge_labels(edge_key), 2)

        # NOTE(review): this check depends on the ordering of message_types.
        self.assertEqual(graph.num_edges(all_message_types[0]), 3)
        self.assertEqual(len(graph.node_label_index), 2)
Пример #2
0
    )

    # Node feature and node label to device
    for key in hetero_graph.node_feature:
        hetero_graph.node_feature[key] = hetero_graph.node_feature[key].to(
            args.device)
    for key in hetero_graph.node_label:
        hetero_graph.node_label[key] = hetero_graph.node_label[key].to(
            args.device)

    # Edge_index to sparse tensor and to device
    for key in hetero_graph.edge_index:
        edge_index = hetero_graph.edge_index[key]
        adj = SparseTensor(row=edge_index[0],
                           col=edge_index[1],
                           sparse_sizes=(hetero_graph.num_nodes('paper'),
                                         hetero_graph.num_nodes('paper')))
        hetero_graph.edge_index[key] = adj.t().to(args.device)
    print(hetero_graph.edge_index[message_type_1])
    print(hetero_graph.edge_index[message_type_2])

    # Mean aggregation
    best_model = None
    best_val = 0

    model = HeteroGNN(hetero_graph, args, aggr="mean").to(args.device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    for epoch in range(args.epochs):