def test_hetero_graph_basics(self):
        """Check the basic size accessors of a re-assembled ``HeteroGraph``.

        A ``HeteroGraph`` is first built from the generated graph. A second,
        directed one is then constructed from the first one's feature, label
        and edge-index attributes, and every assertion below runs against
        that second instance.
        """
        source = HeteroGraph(generate_simple_hete_graph())

        # Attributes are read from `source` in this order and forwarded
        # unchanged as keyword arguments of the same name.
        carried_over = (
            'node_feature',
            'node_label',
            'edge_feature',
            'edge_label',
            'edge_index',
        )
        rebuild_kwargs = {name: getattr(source, name) for name in carried_over}
        hete = HeteroGraph(**rebuild_kwargs, directed=True)

        # The two edge types that the assertions query.
        e1_type = ('n1', 'e1', 'n1')
        e2_type = ('n1', 'e2', 'n2')

        # Expected num_*_features per node type, then per edge type.
        for node_type, width in (('n1', 10), ('n2', 12)):
            self.assertEqual(hete.num_node_features(node_type), width)
        for edge_type, width in ((e1_type, 8), (e2_type, 12)):
            self.assertEqual(hete.num_edge_features(edge_type), width)

        # Expected node counts per node type, and how many types exist.
        for node_type, count in (('n1', 4), ('n2', 5)):
            self.assertEqual(hete.num_nodes(node_type), count)
        self.assertEqual(len(hete.node_types), 2)
        self.assertEqual(len(hete.edge_types), 2)

        message_types = hete.message_types
        self.assertEqual(len(message_types), 7)

        # Every queried node type and edge type reports 2 labels.
        for node_type in ('n1', 'n2'):
            self.assertEqual(hete.num_node_labels(node_type), 2)
        for edge_type in (e1_type, e2_type):
            self.assertEqual(hete.num_edge_labels(edge_type), 2)

        first_message_type = message_types[0]
        self.assertEqual(hete.num_edges(first_message_type), 3)
        self.assertEqual(len(hete.node_label_index), 2)
Пример #2
0
            f"test micro {round(accs[2][0] * 100, 2)}%, test macro {round(accs[2][1] * 100, 2)}%"
        )
    best_accs, _, _ = test(best_model, hetero_graph,
                           [train_idx, val_idx, test_idx])
    print(
        f"Best model: "
        f"train micro {round(best_accs[0][0] * 100, 2)}%, train macro {round(best_accs[0][1] * 100, 2)}%, "
        f"valid micro {round(best_accs[1][0] * 100, 2)}%, valid macro {round(best_accs[1][1] * 100, 2)}%, "
        f"test micro {round(best_accs[2][0] * 100, 2)}%, test macro {round(best_accs[2][1] * 100, 2)}%"
    )

    # Attention aggregation
    best_model = None
    best_val = 0

    output_size = hetero_graph.num_node_labels('paper')
    model = HeteroGNN(hetero_graph, args, aggr="attn").to(args.device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    for epoch in range(args.epochs):
        loss = train(model, optimizer, hetero_graph, train_idx)
        accs, best_model, best_val = test(model, hetero_graph,
                                          [train_idx, val_idx, test_idx],
                                          best_model, best_val)
        print(
            f"Epoch {epoch + 1}: loss {round(loss, 5)}, "
            f"train micro {round(accs[0][0] * 100, 2)}%, train macro {round(accs[0][1] * 100, 2)}%, "
            f"valid micro {round(accs[1][0] * 100, 2)}%, valid macro {round(accs[1][1] * 100, 2)}%, "
            f"test micro {round(accs[2][0] * 100, 2)}%, test macro {round(accs[2][1] * 100, 2)}%"