def test_gnn_explainer_explain_node(model, return_type, allow_edge_mask, feat_mask_type):
    """Smoke-test ``GNNExplainer.explain_node`` on a small bidirectional path.

    Verifies the repr, the mask shapes/ranges for each ``feat_mask_type``,
    both visualizer code paths, and that edges outside node 2's neighborhood
    stay zeroed when ``allow_edge_mask=False``.
    """
    explainer = GNNExplainer(
        model,
        log=False,
        return_type=return_type,
        allow_edge_mask=allow_edge_mask,
        feat_mask_type=feat_mask_type,
    )
    assert repr(explainer) == 'GNNExplainer()'

    feats = torch.randn(8, 3)
    labels = torch.tensor([0, 1, 1, 0, 1, 0, 1, 0])
    # Path graph 0-1-...-7, both directions (14 directed edges).
    edges = torch.tensor([
        [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
        [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6],
    ])

    node_feat_mask, edge_mask = explainer.explain_node(2, feats, edges)

    # Exercise both visualizer entry points.
    if feat_mask_type == 'scalar':
        _, _ = explainer.visualize_subgraph(2, edges, edge_mask, y=labels,
                                            threshold=0.8,
                                            node_alpha=node_feat_mask)
    else:
        edge_labels = torch.randint(low=0, high=30, size=(edges.size(1), ))
        _, _ = explainer.visualize_subgraph(2, edges, edge_mask, y=labels,
                                            edge_y=edge_labels, threshold=0.8)

    # Expected feature-mask shape depends on the mask type.
    if feat_mask_type == 'individual_feature':
        expected = feats.size()
    elif feat_mask_type == 'scalar':
        expected = (feats.size(0), )
    else:
        expected = (feats.size(1), )
    assert node_feat_mask.size() == expected
    assert node_feat_mask.min() >= 0
    assert node_feat_mask.max() <= 1

    assert edge_mask.size() == (edges.size(1), )
    assert edge_mask.min() >= 0
    assert edge_mask.max() <= 1

    if not allow_edge_mask:
        # Without a learnable edge mask, only the first 8 edges (node 2's
        # neighborhood) survive; the remaining 6 are hard-zeroed.
        assert edge_mask[:8].tolist() == [1.] * 8
        assert edge_mask[8:].tolist() == [0.] * 6
def test_gnn_explainer_explain_graph(model, return_type, allow_edge_mask, feat_mask_type):
    """Smoke-test ``GNNExplainer.explain_graph`` on two disjoint paths.

    Fix: removed a stray ``pass`` statement that sat directly before the
    ``visualize_subgraph`` call in the ``feat_mask_type == 'scalar'`` branch
    (a dead statement that is a syntax error when collapsed onto one line).
    """
    explainer = GNNExplainer(model, log=False, return_type=return_type,
                             allow_edge_mask=allow_edge_mask,
                             feat_mask_type=feat_mask_type)
    x = torch.randn(8, 3)
    # Two disconnected bidirectional paths: 0-1-2-3 and 4-5-6-7.
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 5, 4, 6, 5, 7, 6]])

    node_feat_mask, edge_mask = explainer.explain_graph(x, edge_index)

    # Exercise both visualizer code paths (-1 selects whole-graph mode).
    if feat_mask_type == 'scalar':
        _, _ = explainer.visualize_subgraph(-1, edge_index, edge_mask,
                                            y=torch.tensor(2), threshold=0.8,
                                            node_alpha=node_feat_mask)
    else:
        edge_y = torch.randint(low=0, high=30, size=(edge_index.size(1), ))
        _, _ = explainer.visualize_subgraph(-1, edge_index, edge_mask,
                                            edge_y=edge_y, y=torch.tensor(2),
                                            threshold=0.8)

    # Feature-mask shape depends on the configured mask type.
    if feat_mask_type == 'individual_feature':
        assert node_feat_mask.size() == x.size()
    elif feat_mask_type == 'scalar':
        assert node_feat_mask.size() == (x.size(0), )
    else:
        assert node_feat_mask.size() == (x.size(1), )
    assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.max() <= 1 and edge_mask.min() >= 0
def test_gnn_explainer_explain_graph(model):
    """Default-settings smoke test for ``GNNExplainer.explain_graph``."""
    explainer = GNNExplainer(model, log=False)

    features = torch.randn(8, 3)
    # Two disconnected bidirectional paths: 0-1-2-3 and 4-5-6-7.
    graph_edges = torch.tensor([
        [0, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7],
        [1, 0, 2, 1, 3, 2, 5, 4, 6, 5, 7, 6],
    ])

    feat_mask, edge_mask = explainer.explain_graph(features, graph_edges)
    _, _ = explainer.visualize_subgraph(-1, graph_edges, edge_mask,
                                        y=torch.tensor(2), threshold=0.8)

    # Both masks must be per-feature / per-edge and normalized into [0, 1].
    assert feat_mask.size() == (features.size(1), )
    assert feat_mask.min() >= 0
    assert feat_mask.max() <= 1
    assert edge_mask.size() == (graph_edges.size(1), )
    assert edge_mask.min() >= 0
    assert edge_mask.max() <= 1
def test_gnn_explainer(model):
    """Basic ``explain_node`` + ``visualize_subgraph`` smoke test."""
    explainer = GNNExplainer(model, log=False)
    assert repr(explainer) == 'GNNExplainer()'

    feats = torch.randn(8, 3)
    # Bidirectional path graph over 8 nodes (14 directed edges).
    edges = torch.tensor([
        [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
        [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6],
    ])
    labels = torch.randint(0, 6, (8, ), dtype=torch.long)

    feat_mask, edge_mask = explainer.explain_node(2, feats, edges)

    assert feat_mask.size() == (feats.size(1), )
    assert feat_mask.min() >= 0
    assert feat_mask.max() <= 1
    assert edge_mask.size() == (edges.size(1), )
    assert edge_mask.min() >= 0
    assert edge_mask.max() <= 1

    # All four visualizer configurations must run without raising.
    explainer.visualize_subgraph(2, edges, edge_mask, threshold=None)
    explainer.visualize_subgraph(2, edges, edge_mask, threshold=0.5)
    explainer.visualize_subgraph(2, edges, edge_mask, threshold=None, y=labels)
    explainer.visualize_subgraph(2, edges, edge_mask, threshold=0.5, y=labels)
# --- Edge explainability with Captum Integrated Gradients -------------------
# Attribute the model output at `output_idx` to every edge, starting from a
# differentiable all-ones edge mask (i.e. the unmasked graph).
captum_model = to_captum(model, mask_type='edge', output_idx=output_idx)
edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)

ig = IntegratedGradients(captum_model)
ig_attr_edge = ig.attribute(edge_mask.unsqueeze(0), target=target,
                            additional_forward_args=(data.x, data.edge_index),
                            internal_batch_size=1)

# Scale attributions to [0, 1]:
# NOTE(review): divides by the max attribution — NaN if all attributions are
# exactly zero; acceptable for an example script.
ig_attr_edge = ig_attr_edge.squeeze(0).abs()
ig_attr_edge /= ig_attr_edge.max()

# Visualize absolute values of attributions with GNNExplainer visualizer
explainer = GNNExplainer(model)  # TODO: Change to general Explainer visualizer
ax, G = explainer.visualize_subgraph(output_idx, data.edge_index, ig_attr_edge)
plt.show()

# Node explainability
# ===================
captum_model = to_captum(model, mask_type='node', output_idx=output_idx)
ig = IntegratedGradients(captum_model)
# Fix: `(data.edge_index)` was not a tuple (missing trailing comma). Captum
# also accepts a bare tensor, so behavior is unchanged, but a proper
# one-element tuple matches the edge case above and the documented API.
ig_attr_node = ig.attribute(data.x.unsqueeze(0), target=target,
                            additional_forward_args=(data.edge_index, ),
                            internal_batch_size=1)

# Scale attributions to [0, 1]:
ig_attr_node = ig_attr_node.squeeze(0).abs().sum(dim=1)
# NOTE(review): this chunk begins inside `Net.__init__`; the `class Net(...)`
# and `def __init__(self)` headers lie outside this view, and the original
# indentation was lost — layout reconstructed, structure unverified.
super(Net, self).__init__()
# NOTE(review): `self.lin` is never referenced in forward() below — looks
# like dead code; confirm before removing.
self.lin = Sequential(Linear(10, 10))
self.conv1 = GCNConv(dataset.num_features, 16)  # input features -> 16 hidden
self.conv2 = GCNConv(16, dataset.num_classes)   # hidden -> per-class scores

def forward(self, x, edge_index):
    # Two-layer GCN returning per-node log-probabilities.
    x = F.relu(self.conv1(x, edge_index))
    x = F.dropout(x, training=self.training)  # active only while training
    x = self.conv2(x, edge_index)
    return F.log_softmax(x, dim=1)


# ---- module-level training + explanation script ----
# NOTE(review): `data` and `dataset` are defined outside this chunk.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
x, edge_index = data.x, data.edge_index

# Full-batch training for 200 epochs on the train mask.
for epoch in range(1, 201):
    model.train()
    optimizer.zero_grad()
    log_logits = model(x, edge_index)
    loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

explainer = GNNExplainer(model, epochs=200)
node_idx = 10
node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
# NOTE(review): the (ax, G) result is bound to the name `plt`, which shadows
# matplotlib.pyplot if it was imported as `plt` — confusing; confirm intent.
plt = explainer.visualize_subgraph(node_idx, edge_index, edge_mask, y=data.y)
# plt.show()
# NOTE(review): chunk opens mid-function — the `def train_model(...)` header
# (and the loop defining `test_loss`, `epoch`, `log`, `best_loss`) is outside
# this view; indentation reconstructed.
best_loss = test_loss if test_loss < best_loss else best_loss  # running minimum
if log and epoch % 50 == 1:
    print(best_loss)
return model, x, data.y, edge_index


def extract_subgraph(node_idx, num_hops, edge_index):
    # Restrict edge_index to the num_hops-hop neighborhood of node_idx.
    nodes, new_edge_index, mapping, _ = k_hop_subgraph(node_idx, num_hops,
                                                       edge_index)
    # NOTE(review): returns the caller's node_idx, not `mapping` (the node's
    # index inside the subgraph); `nodes` is unused — confirm intent.
    return new_edge_index, node_idx


def run_model(edge_mask, edge_index, model, node_idx):
    # Apply a boolean edge mask and return the model output row for one node.
    # NOTE(review): relies on module-level globals `x` and `device`.
    edge_index_1 = edge_index[:, torch.tensor(edge_mask).to(device).bool()]
    out = model(x, edge_index_1).detach().cpu()
    return out[node_idx].numpy()


if __name__ == '__main__':
    # NOTE(review): unpacking into the attribute `data.y` assumes a `data`
    # object already exists at module scope — confirm.
    model, x, data.y, edge_index = train_model(log=True)
    explainer = GNNExplainer(model, epochs=1000, num_hops=1)
    node_idx = 549
    node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
    # Keep only edges whose mask value is an extreme outlier (z-score > 10).
    scores = stats.zscore(edge_mask.cpu().numpy())
    idxs = scores > 10
    print(sum(idxs))
    ax, G = explainer.visualize_subgraph(node_idx, edge_index.T[idxs].T,
                                         edge_mask[idxs], y=data.y)
    plt.savefig('explain/node%d.png' % node_idx, dpi=300)
# explain a node explainer = GNNExplainer(model, epochs=200) # need to do the union of all the nodes to get full graph explanation... feat_mask_all = torch.zeros(dataset.num_features) edge_mask_all = torch.zeros(single_graph.num_edges) node_idx = 0 node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index, batch=torch.zeros( single_graph.num_nodes, dtype=torch.int64)) ax, G = explainer.visualize_subgraph(node_idx, edge_index, edge_mask, threshold=0.5) #, y=single_graph.y) plt.show() # # """ # Cora - node classification # """ # import os.path as osp # import torch # import torch.nn.functional as F # import matplotlib.pyplot as plt # from torch_geometric.datasets import Planetoid # import torch_geometric.transforms as T
# NOTE(review): chunk opens mid-way through a training loop inside a
# `train_model(...)` function whose header is outside this view; indentation
# reconstructed.
loss.backward()
optimizer.step()

# Testing step
model.eval()
# NOTE(review): reuses `log_logits` from the training forward pass instead of
# re-running the model after `eval()` — confirm this is intended.
test_loss = F.nll_loss(log_logits[test_mask], data.y[test_mask]).item()
best_loss = test_loss if test_loss < best_loss else best_loss  # running minimum
if log and epoch % 50 == 1:
    print(best_loss)
return model, x, data.y, edge_index


def extract_subgraph(node_idx, num_hops, edge_index):
    # Restrict edge_index to the num_hops-hop neighborhood of node_idx.
    nodes, new_edge_index, mapping, _ = k_hop_subgraph(node_idx, num_hops,
                                                       edge_index)
    # NOTE(review): returns the original node_idx rather than `mapping` (its
    # position inside the subgraph); `nodes` is unused — confirm intent.
    return new_edge_index, node_idx


def run_model(edge_mask, edge_index, model, node_idx):
    # Apply a boolean edge mask and return the model output row for one node.
    # NOTE(review): relies on module-level globals `x` and `device`.
    edge_index_1 = edge_index[:, torch.tensor(edge_mask).to(device).bool()]
    out = model(x, edge_index_1).detach().cpu()
    return out[node_idx].numpy()


if __name__ == '__main__':
    # NOTE(review): unpacking into the attribute `data.y` assumes a `data`
    # object already exists at module scope — confirm.
    model, x, data.y, edge_index = train_model(log=True)
    explainer = GNNExplainer(model, epochs=1000, num_hops=1)
    node_idx = 1
    node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
    # Move results to CPU for networkx / matplotlib output.
    node_feat_mask = node_feat_mask.to('cpu')
    edge_mask = edge_mask.to('cpu')
    y = data.y.to('cpu')
    ax, G = explainer.visualize_subgraph(node_idx, edge_index, edge_mask, y=y)
    nx.readwrite.edgelist.write_edgelist(
        G, 'scripts/BI/random/explain/node%d.el' % node_idx, data=['att'])
    plt.savefig('scripts/BI/random/explain/node%d.png' % node_idx, dpi=300)