def test_gnn_explainer_explain_graph(model, return_type, allow_edge_mask,
                                     feat_mask_type):
    explainer = GNNExplainer(model, log=False, return_type=return_type,
                             allow_edge_mask=allow_edge_mask,
                             feat_mask_type=feat_mask_type)
    x = torch.randn(8, 3)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 5, 4, 6, 5, 7, 6]])

    node_feat_mask, edge_mask = explainer.explain_graph(x, edge_index)
    if feat_mask_type == 'scalar':
        _, _ = explainer.visualize_subgraph(-1, edge_index, edge_mask,
                                            y=torch.tensor(2), threshold=0.8,
                                            node_alpha=node_feat_mask)
    else:
        edge_y = torch.randint(low=0, high=30, size=(edge_index.size(1), ))
        _, _ = explainer.visualize_subgraph(-1, edge_index, edge_mask,
                                            edge_y=edge_y, y=torch.tensor(2),
                                            threshold=0.8)

    if feat_mask_type == 'individual_feature':
        assert node_feat_mask.size() == x.size()
    elif feat_mask_type == 'scalar':
        assert node_feat_mask.size() == (x.size(0), )
    else:
        assert node_feat_mask.size() == (x.size(1), )
    assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.min() >= 0 and edge_mask.max() <= 1
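# The tests in this section receive `model` (and, where parametrized,
# `return_type`, `allow_edge_mask`, `feat_mask_type`) as pytest fixtures or
# parameters that are not shown here. A minimal sketch of a compatible
# fixture, assuming a two-layer GCN over the 3 input features used by the
# tests; the class name and layer sizes are illustrative, not taken from the
# original suite:
import pytest
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(3, 16)
        self.conv2 = GCNConv(16, 7)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        return F.log_softmax(self.conv2(x, edge_index), dim=1)


@pytest.fixture
def model():
    return GCN()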
def find_noise_feats_by_GNNExplainer(model, data, args):
    explainer = GNNExplainer(model, epochs=args.masks_epochs,
                             lr=args.masks_lr, num_hops=args.hop, log=False)
    node_indices = extract_test_nodes(data, args.test_samples)

    num_noise_feats = []
    for node_idx in tqdm(node_indices, desc='explain node', leave=False):
        node_feat_mask, edge_mask = explainer.explain_node(
            node_idx, data.x, data.edge_index)
        node_feat_mask = node_feat_mask.detach().cpu().numpy()
        # Keep the K highest-scoring features that clear the mask threshold.
        feat_indices = node_feat_mask.argsort()[-args.K:]
        feat_indices = [
            idx for idx in feat_indices
            if node_feat_mask[idx] > args.masks_threshold
        ]
        # Feature indices beyond the original input dimension are noise.
        num_noise_feat = sum(idx >= INPUT_DIM[args.dataset]
                             for idx in feat_indices)
        num_noise_feats.append(num_noise_feat)
    return num_noise_feats
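# `extract_test_nodes`, `INPUT_DIM`, and the `args` fields above come from the
# surrounding project and are not shown here. A minimal sketch of
# `extract_test_nodes`, assuming it draws `num_samples` random node indices
# from the dataset's test mask (the original helper may differ):
import numpy as np


def extract_test_nodes(data, num_samples):
    # Indices of all nodes in the test split.
    test_indices = data.test_mask.cpu().numpy().nonzero()[0]
    # Sample without replacement so each node is explained at most once.
    return np.random.choice(test_indices, num_samples,
                            replace=False).tolist()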
def test_gnn_explainer_to_log_prob(model):
    raw_to_log = GNNExplainer(model, return_type='raw').__to_log_prob__
    prob_to_log = GNNExplainer(model, return_type='prob').__to_log_prob__
    log_to_log = GNNExplainer(model, return_type='log_prob').__to_log_prob__

    raw = torch.tensor([[1, 3.2, 6.1], [9, 9, 0.1]])
    prob = raw.softmax(dim=-1)
    log_prob = raw.log_softmax(dim=-1)

    assert torch.allclose(raw_to_log(raw), prob_to_log(prob))
    assert torch.allclose(prob_to_log(prob), log_to_log(log_prob))
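# For reference, the three conversions exercised above reduce to:
#   'raw'      -> log_softmax(out, dim=-1)
#   'prob'     -> log(out)
#   'log_prob' -> identity
# so all three calls yield the same log-probabilities for the same underlying
# distribution, which is what the allclose checks verify.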
def test_gnn_explainer_with_existing_self_loops(model, return_type):
    explainer = GNNExplainer(model, log=False, return_type=return_type)
    x = torch.randn(8, 3)
    edge_index = torch.tensor(
        [[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
         [0, 1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])

    node_feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)
    assert node_feat_mask.size() == (x.size(1), )
    assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.min() >= 0 and edge_mask.max() <= 1
def test_gnn_explainer(model):
    explainer = GNNExplainer(model, log=False)
    assert explainer.__repr__() == 'GNNExplainer()'

    x = torch.randn(8, 3)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])

    node_feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)
    assert node_feat_mask.size() == (x.size(1), )
    assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.min() >= 0 and edge_mask.max() <= 1
def test_graph_explainer(model):
    x = torch.randn(8, 3)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 5, 4, 6, 5, 7, 6]])

    explainer = GNNExplainer(model, log=False)
    node_feat_mask, edge_mask = explainer.explain_graph(x, edge_index)
    assert_edgemask_clear(model)
    assert node_feat_mask.size() == (x.size(1), )
    assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
    assert edge_mask.shape[0] == edge_index.shape[1]
    assert edge_mask.min() >= 0 and edge_mask.max() <= 1
def test_gnn_explainer_explain_graph(model):
    explainer = GNNExplainer(model, log=False)
    x = torch.randn(8, 3)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 5, 4, 6, 5, 7, 6]])

    node_feat_mask, edge_mask = explainer.explain_graph(x, edge_index)
    _, _ = explainer.visualize_subgraph(-1, edge_index, edge_mask,
                                        y=torch.tensor(2), threshold=0.8)

    assert node_feat_mask.size() == (x.size(1), )
    assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.min() >= 0 and edge_mask.max() <= 1
def test_gnn_explainer_explain_node(model, return_type, allow_edge_mask,
                                    feat_mask_type):
    explainer = GNNExplainer(model, log=False, return_type=return_type,
                             allow_edge_mask=allow_edge_mask,
                             feat_mask_type=feat_mask_type)
    assert explainer.__repr__() == 'GNNExplainer()'

    x = torch.randn(8, 3)
    y = torch.tensor([0, 1, 1, 0, 1, 0, 1, 0])
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])

    node_feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)
    if feat_mask_type == 'scalar':
        _, _ = explainer.visualize_subgraph(2, edge_index, edge_mask, y=y,
                                            threshold=0.8,
                                            node_alpha=node_feat_mask)
    else:
        edge_y = torch.randint(low=0, high=30, size=(edge_index.size(1), ))
        _, _ = explainer.visualize_subgraph(2, edge_index, edge_mask, y=y,
                                            edge_y=edge_y, threshold=0.8)

    if feat_mask_type == 'individual_feature':
        assert node_feat_mask.size() == x.size()
    elif feat_mask_type == 'scalar':
        assert node_feat_mask.size() == (x.size(0), )
    else:
        assert node_feat_mask.size() == (x.size(1), )
    assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.min() >= 0 and edge_mask.max() <= 1

    if not allow_edge_mask:
        # With edge masking disabled, edges inside the explained subgraph
        # keep mask value 1 and all remaining edges stay at 0.
        assert edge_mask[:8].tolist() == [1.] * 8
        assert edge_mask[8:].tolist() == [0.] * 6
else:  # Presumably: no saved model was found, so train from scratch.
    for epoch in range(1, 201):
        model.train()
        optimizer.zero_grad()
        log_logits = model(x, edge_index)
        loss = F.nll_loss(log_logits[data.train_mask],
                          data.y[data.train_mask])
        loss.backward()
        optimizer.step()
    torch.save(model, model_path)

log_logits = model(x, edge_index)
y_pred_test = torch.argmax(log_logits, dim=1)[np.arange(
    len(U_train + U_dev), len(U_train + U_dev + U_test))]
y_pred_test = y_pred_test.detach().numpy()
mean, median, acc, _, _, _ = geo_eval(Y_test, y_pred_test, U_test,
                                      classLatMedian, classLonMedian,
                                      userLocation)
print(f"mean: {mean} median: {median} acc: {acc}")

explainer = GNNExplainer(model, epochs=200)
node_idx = 9000
print(userLocation[U[node_idx]])
node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
print("top features:\n", node_feat_mask.argsort()[-10:])
print("top edges:\n", edge_index[:, edge_mask.argsort()[-10:]])

only_topk_edges = 20
all_distances = torch.FloatTensor(
    [get_distance(U[node_idx], u, userLocation) for u in U])
# ax, G = visualize_subgraph2(explainer, node_idx, edge_index, edge_mask,
#                             y=all_distances, threshold=None,
#                             only_topk_edges=only_topk_edges,
#                             cmap=plt.cm.cool)
# plt.savefig(f"{node_idx}-{U[node_idx]}-{only_topk_edges}.pdf")
# plt.close()
# print(data.y)
ax, G = visualize_subgraph3(explainer, node_idx, edge_index, edge_mask,
                            y=all_distances, threshold=None,
                            only_topk_edges=only_topk_edges, cmap=plt.cm.cool)
plt.savefig(f"{node_idx}-{U[node_idx]}-{only_topk_edges} - original.pdf")
# Gather some statistics about the first graph.
print(f'Number of nodes: {single_graph.num_nodes}')
print(f'Number of edges: {single_graph.num_edges}')
print(f'Average node degree: '
      f'{single_graph.num_edges / single_graph.num_nodes:.2f}')
print(f'Contains isolated nodes: {single_graph.contains_isolated_nodes()}')
print(f'Contains self-loops: {single_graph.contains_self_loops()}')
print(f'Is undirected: {single_graph.is_undirected()}')

single_graph = single_graph.to(device)
x, edge_index = single_graph.x, single_graph.edge_index

# Explain a node.
explainer = GNNExplainer(model, epochs=200)
# Need to take the union over all nodes to get a full-graph explanation...
feat_mask_all = torch.zeros(dataset.num_features)
edge_mask_all = torch.zeros(single_graph.num_edges)
node_idx = 0
node_feat_mask, edge_mask = explainer.explain_node(
    node_idx, x, edge_index,
    batch=torch.zeros(single_graph.num_nodes, dtype=torch.int64))
ax, G = explainer.visualize_subgraph(node_idx, edge_index, edge_mask)
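# A sketch of the union mentioned above: accumulate per-node masks over every
# node in the graph. The element-wise maximum is one reasonable choice of
# "union"; the original intent may have been a sum or a mean instead:
for node_idx in range(single_graph.num_nodes):
    node_feat_mask, edge_mask = explainer.explain_node(
        node_idx, x, edge_index,
        batch=torch.zeros(single_graph.num_nodes, dtype=torch.int64))
    feat_mask_all = torch.maximum(feat_mask_all, node_feat_mask.cpu())
    edge_mask_all = torch.maximum(edge_mask_all, edge_mask.cpu())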
        self.conv1 = GCNConv(dataset.num_features, 16, normalize=False)
        self.conv2 = GCNConv(16, dataset.num_classes, normalize=False)

    def forward(self, x, edge_index, edge_weight):
        x = F.relu(self.conv1(x, edge_index, edge_weight))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)
        return F.log_softmax(x, dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight

for epoch in range(1, 201):
    model.train()
    optimizer.zero_grad()
    log_logits = model(x, edge_index, edge_weight)
    loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

explainer = GNNExplainer(model, epochs=200, return_type='log_prob')
node_idx = 10
node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index,
                                                   edge_weight=edge_weight)
ax, G = explainer.visualize_subgraph(node_idx, edge_index, edge_mask,
                                     y=data.y)
plt.show()
        best_loss = min(test_loss, best_loss)
        if log and epoch % 50 == 1:
            print(best_loss)
    return model, x, data.y, edge_index


def extract_subgraph(node_idx, num_hops, edge_index):
    nodes, new_edge_index, mapping, _ = k_hop_subgraph(node_idx, num_hops,
                                                       edge_index)
    return new_edge_index, node_idx


def run_model(edge_mask, edge_index, model, node_idx):
    edge_index_1 = edge_index[:, torch.tensor(edge_mask).to(device).bool()]
    out = model(x, edge_index_1).detach().cpu()
    return out[node_idx].numpy()


if __name__ == '__main__':
    model, x, data.y, edge_index = train_model(log=True)
    explainer = GNNExplainer(model, epochs=1000, num_hops=1)
    node_idx = 549
    node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
    # Keep only edges whose mask value is a strong outlier (z-score > 10).
    scores = stats.zscore(edge_mask.cpu().numpy())
    idxs = scores > 10
    print(sum(idxs))
    ax, G = explainer.visualize_subgraph(node_idx, edge_index[:, idxs],
                                         edge_mask[idxs], y=data.y)
    plt.savefig('explain/node%d.png' % node_idx, dpi=300)
from sklearn.metrics import roc_auc_score

auc = 0
auc_gnn_exp = 0
pbar = range(x.shape[0])
done = 0
for n in pbar:
    try:
        k = 3
        sharp = 1e-12
        splines = 6
        explainer = BayesianExplainer(model, n, k, x, edge_index, sharp,
                                      splines)
        avgs = explainer.train(epochs=3000, lr=5, lambd=5e-11, window=500,
                               p=1.1, log=False)
        edge_mask = explainer.edge_mask()
        edges = explainer.edge_index_adj
        labs = edge_labels[explainer.subset, :][:, explainer.subset][
            edges[0, :], edges[1, :]]
        sub_idx = (labs.long().cpu().detach().numpy() == 1)
        itr_auc = roc_auc_score(labs.long().cpu().detach().numpy()[sub_idx],
                                edge_mask.cpu().detach().numpy()[sub_idx])
        auc += itr_auc

        e_subset = explainer.edge_mask_hard
        explainer = GNNExplainer(model.to(device), epochs=1000, log=False)
        _, edge_mask = explainer.explain_node(n, x.to(device),
                                              edge_index.to(device))
        auc_gnn_exp += roc_auc_score(
            labs.long().cpu().detach().numpy()[sub_idx],
            edge_mask[e_subset].cpu().detach().numpy()[sub_idx])
        done += 1
        if n % 10 == 0:
            print('EPOCH: %d | AUC: %.3f | AUC GNN_EXP: %.3f | ITR AUC: %.3f'
                  % (n, auc / done, auc_gnn_exp / done, itr_auc))
    except Exception:
        # Skip nodes where the explanation or AUC computation fails, e.g.
        # when only one class is present among the edge labels.
        pass

print('FINAL | AUC: %.3f | AUC GNN_EXP: %.3f | ITR AUC: %.3f'
      % (auc / done, auc_gnn_exp / done, itr_auc))
def test_gnn_explainer(model):
    explainer = GNNExplainer(model, log=False)
    assert explainer.__repr__() == 'GNNExplainer()'

    x = torch.randn(8, 3)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])
    y = torch.randint(0, 6, (8, ), dtype=torch.long)

    node_feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)
    assert node_feat_mask.size() == (x.size(1), )
    assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.min() >= 0 and edge_mask.max() <= 1

    explainer.visualize_subgraph(2, edge_index, edge_mask, threshold=None)
    explainer.visualize_subgraph(2, edge_index, edge_mask, threshold=0.5)
    explainer.visualize_subgraph(2, edge_index, edge_mask, y=y, threshold=None)
    explainer.visualize_subgraph(2, edge_index, edge_mask, y=y, threshold=0.5)
        super(Net, self).__init__()
        self.lin = Sequential(Linear(10, 10))
        self.conv1 = GCNConv(dataset.num_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
x, edge_index = data.x, data.edge_index

for epoch in range(1, 201):
    model.train()
    optimizer.zero_grad()
    log_logits = model(x, edge_index)
    loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

explainer = GNNExplainer(model, epochs=200)
node_idx = 10
node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
ax, G = explainer.visualize_subgraph(node_idx, edge_index, edge_mask,
                                     y=data.y)
# plt.show()
# equal to the number of samples. Therefore, we use unsqueeze(0).
captum_model = to_captum(model, mask_type='edge', output_idx=output_idx)
edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)

ig = IntegratedGradients(captum_model)
ig_attr_edge = ig.attribute(edge_mask.unsqueeze(0), target=target,
                            additional_forward_args=(data.x,
                                                     data.edge_index),
                            internal_batch_size=1)

# Scale attributions to [0, 1]:
ig_attr_edge = ig_attr_edge.squeeze(0).abs()
ig_attr_edge /= ig_attr_edge.max()

# Visualize absolute values of attributions with the GNNExplainer visualizer:
explainer = GNNExplainer(model)  # TODO: Change to general Explainer visualizer
ax, G = explainer.visualize_subgraph(output_idx, data.edge_index,
                                     ig_attr_edge)
plt.show()

# Node explainability
# ===================
captum_model = to_captum(model, mask_type='node', output_idx=output_idx)
ig = IntegratedGradients(captum_model)
ig_attr_node = ig.attribute(data.x.unsqueeze(0), target=target,
                            additional_forward_args=(data.edge_index, ),
                            internal_batch_size=1)

# Scale attributions to [0, 1]:
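# The snippet breaks off at the scaling step above; a plausible continuation,
# mirroring the edge case, would reduce the per-feature attributions to one
# score per node and normalize (an assumption, not verified against the
# original script):
#
#     ig_attr_node = ig_attr_node.squeeze(0).abs().sum(dim=1)
#     ig_attr_node /= ig_attr_node.max()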
    test_correct = int((pred[test_idx] == data.y[test_idx]).sum())
    test_acc = test_correct / test_idx.size(0)
    return train_acc, test_acc


for epoch in range(1, 2001):
    loss = train()
    if epoch % 200 == 0:
        train_acc, test_acc = test()
        print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, '
              f'Train: {train_acc:.4f}, Test: {test_acc:.4f}')

model.eval()
targets, preds = [], []
expl = GNNExplainer(model, epochs=300, return_type='raw', log=False)

# Explanation ROC AUC over all test nodes:
self_loop_mask = data.edge_index[0] != data.edge_index[1]
for node_idx in tqdm(data.expl_mask.nonzero(as_tuple=False).view(-1).tolist()):
    _, expl_edge_mask = expl.explain_node(node_idx, data.x, data.edge_index,
                                          edge_weight=data.edge_weight)
    subgraph = k_hop_subgraph(node_idx, num_hops=3,
                              edge_index=data.edge_index)
    expl_edge_mask = expl_edge_mask[self_loop_mask]
    subgraph_edge_mask = subgraph[3][self_loop_mask]
    targets.append(data.edge_label[subgraph_edge_mask].cpu())
    preds.append(expl_edge_mask[subgraph_edge_mask].cpu())

auc = roc_auc_score(torch.cat(targets), torch.cat(preds))