def test_gnn_explainer_explain_graph(model, return_type, allow_edge_mask,
                                     feat_mask_type):
    """Check ``GNNExplainer.explain_graph`` masks and ``visualize_subgraph``.

    Explains a small 8-node, 3-feature graph as a whole, renders the
    explanation with ``visualize_subgraph`` and asserts that the node feature
    mask has the size implied by ``feat_mask_type``, that the edge mask has
    one entry per edge, and that both masks lie in ``[0, 1]``.

    Args:
        model: Model under test, forwarded to ``GNNExplainer``.
        return_type: Forwarded to ``GNNExplainer(return_type=...)``.
        allow_edge_mask: Forwarded to ``GNNExplainer(allow_edge_mask=...)``.
        feat_mask_type: Forwarded to ``GNNExplainer(feat_mask_type=...)``;
            also selects the plotting call and the expected mask size.
    """
    explainer = GNNExplainer(model, log=False, return_type=return_type,
                             allow_edge_mask=allow_edge_mask,
                             feat_mask_type=feat_mask_type)

    x = torch.randn(8, 3)
    # Two disconnected paths, 0-1-2-3 and 4-5-6-7, with every edge stored in
    # both directions.
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 5, 4, 6, 5, 7, 6]])

    node_feat_mask, edge_mask = explainer.explain_graph(x, edge_index)

    # NOTE(review): -1 is passed as the node index, presumably meaning
    # "whole graph" to match explain_graph — confirm against GNNExplainer.
    # Unpacking into two names also checks that the call returns a pair.
    if feat_mask_type == 'scalar':
        # A scalar mask has one entry per node (asserted below), so it is
        # passed as the per-node alpha values.
        _, _ = explainer.visualize_subgraph(-1, edge_index, edge_mask,
                                            y=torch.tensor(2), threshold=0.8,
                                            node_alpha=node_feat_mask)
    else:
        # Random integer edge labels exercise the ``edge_y`` argument.
        edge_y = torch.randint(low=0, high=30, size=(edge_index.size(1), ))
        _, _ = explainer.visualize_subgraph(-1, edge_index, edge_mask,
                                            edge_y=edge_y, y=torch.tensor(2),
                                            threshold=0.8)

    if feat_mask_type == 'individual_feature':
        # One weight per (node, feature) entry of ``x``.
        assert node_feat_mask.size() == x.size()
    elif feat_mask_type == 'scalar':
        # One weight per node.
        assert node_feat_mask.size() == (x.size(0), )
    else:
        # One weight per feature column.
        assert node_feat_mask.size() == (x.size(1), )
    assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.max() <= 1 and edge_mask.min() >= 0
def test_graph_explainer(model):
    """Check default ``explain_graph`` masks and model edge-mask cleanup.

    Runs a default-configured ``GNNExplainer`` on a small 8-node, 3-feature
    graph, then asserts that ``assert_edgemask_clear`` passes for the model
    and that both returned masks have the expected shape and lie in
    ``[0, 1]``.

    Args:
        model: Model under test, forwarded to ``GNNExplainer``.
    """
    x = torch.randn(8, 3)
    # Two disconnected paths, 0-1-2-3 and 4-5-6-7, with every edge stored in
    # both directions.
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 5, 4, 6, 5, 7, 6]])
    explainer = GNNExplainer(model, log=False)

    node_feat_mask, edge_mask = explainer.explain_graph(x, edge_index)
    # NOTE(review): presumably verifies that explain_graph left no edge mask
    # behind on the model — confirm against assert_edgemask_clear.
    assert_edgemask_clear(model)

    # With the default configuration the test expects one weight per feature
    # column.
    assert node_feat_mask.size() == (x.size(1), )
    assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
    # Compare the full shape, not just the first dimension, so a mask with
    # extra dimensions is rejected.
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.max() <= 1 and edge_mask.min() >= 0
def test_gnn_explainer_explain_graph(model):
    """Smoke-test ``explain_graph`` followed by ``visualize_subgraph``.

    A default-configured ``GNNExplainer`` explains an 8-node graph and plots
    the result. The feature mask must have one entry per feature column, the
    edge mask one entry per edge, and every value of both must be in [0, 1].

    Args:
        model: Model under test, forwarded to ``GNNExplainer``.
    """
    # NOTE(review): an earlier test in this module has this same name; this
    # later ``def`` rebinds it, so pytest collects only this version and the
    # earlier one never runs. Consider renaming one of the two.
    explainer = GNNExplainer(model, log=False)

    features = torch.randn(8, 3)
    # Paths 0-1-2-3 and 4-5-6-7, each edge listed in both directions.
    sources = [0, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7]
    targets = [1, 0, 2, 1, 3, 2, 5, 4, 6, 5, 7, 6]
    edge_index = torch.tensor([sources, targets])

    feat_mask, edge_mask = explainer.explain_graph(features, edge_index)

    plot_result = explainer.visualize_subgraph(-1, edge_index, edge_mask,
                                               y=torch.tensor(2),
                                               threshold=0.8)
    # Unpacking checks that the plotting call returned exactly two items.
    _, _ = plot_result

    expected_feat_shape = (features.size(1), )
    assert feat_mask.size() == expected_feat_shape
    assert feat_mask.min() >= 0
    assert feat_mask.max() <= 1

    expected_edge_shape = (edge_index.size(1), )
    assert edge_mask.size() == expected_edge_shape
    assert edge_mask.max() <= 1
    assert edge_mask.min() >= 0