def test_hetero_graph_basics(self):
    """Round-trip a HeteroGraph through its constructor and check per-type stats.

    A graph is built from ``generate_simple_hete_graph()``. Its node/edge
    features, node/edge labels and edge indices are passed back into the
    ``HeteroGraph`` keyword constructor with ``directed=True``. The rebuilt
    object's per-type feature sizes, node counts, label counts and edge
    counts are then compared against fixed expected values.
    """
    source = HeteroGraph(generate_simple_hete_graph())
    graph = HeteroGraph(
        node_feature=source.node_feature,
        node_label=source.node_label,
        edge_feature=source.edge_feature,
        edge_label=source.edge_label,
        edge_index=source.edge_index,
        directed=True,
    )

    # Feature dimensionality, first per node type, then per message type.
    for node_type, want in (("n1", 10), ("n2", 12)):
        self.assertEqual(graph.num_node_features(node_type), want)
    for edge_type, want in (
        (("n1", "e1", "n1"), 8),
        (("n1", "e2", "n2"), 12),
    ):
        self.assertEqual(graph.num_edge_features(edge_type), want)

    # Node counts per node type.
    for node_type, want in (("n1", 4), ("n2", 5)):
        self.assertEqual(graph.num_nodes(node_type), want)

    self.assertEqual(len(graph.node_types), 2)
    self.assertEqual(len(graph.edge_types), 2)

    # Read once and reuse: the edge-count check below indexes this sequence.
    msg_types = graph.message_types
    self.assertEqual(len(msg_types), 7)

    # Label cardinality per node type, then per message type.
    for node_type in ("n1", "n2"):
        self.assertEqual(graph.num_node_labels(node_type), 2)
    for edge_type in (("n1", "e1", "n1"), ("n1", "e2", "n2")):
        self.assertEqual(graph.num_edge_labels(edge_type), 2)

    # NOTE(review): relies on whatever ordering HeteroGraph gives
    # message_types; index 0 is expected to carry 3 edges.
    self.assertEqual(graph.num_edges(msg_types[0]), 3)
    # presumably one entry per labelled node type — confirm against HeteroGraph
    self.assertEqual(len(graph.node_label_index), 2)
) # Node feature and node label to device for key in hetero_graph.node_feature: hetero_graph.node_feature[key] = hetero_graph.node_feature[key].to( args.device) for key in hetero_graph.node_label: hetero_graph.node_label[key] = hetero_graph.node_label[key].to( args.device) # Edge_index to sparse tensor and to device for key in hetero_graph.edge_index: edge_index = hetero_graph.edge_index[key] adj = SparseTensor(row=edge_index[0], col=edge_index[1], sparse_sizes=(hetero_graph.num_nodes('paper'), hetero_graph.num_nodes('paper'))) hetero_graph.edge_index[key] = adj.t().to(args.device) print(hetero_graph.edge_index[message_type_1]) print(hetero_graph.edge_index[message_type_2]) # Mean aggregation best_model = None best_val = 0 model = HeteroGNN(hetero_graph, args, aggr="mean").to(args.device) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) for epoch in range(args.epochs):